// Copyright 2024 Mozilla Foundation
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

//
//                                _   _          ___ _      _   ___
//                               | |_(_)_ _ _  _| _ ) |    /_\ / __|
//                               |  _| | ' \ || | _ \ |__ / _ \\__ \.
//                                \__|_|_||_\_, |___/____/_/ \_\___/
//                                          |__/
//
//                    BASIC LINEAR ALGEBRA SUBPROGRAMS
//
//
// This file implements multithreaded CPU matrix multiplication for the
// common contiguous use case C = Aᵀ * B. These kernels are designed to
// have excellent performance[1] for matrices that fit in the CPU cache
// without imposing any overhead such as cache filling or malloc calls.
//
// This implementation does not guarantee any upper bound on rounding
// errors, which grow along with k. Our goal is to maximally exploit the
// hardware for performance, and then use whatever resources remain for
// improving numerical accuracy.
//
// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
//     Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
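//
// The public entry point is llamafile_sgemm(), declared in sgemm.h; it
// validates its arguments and dispatches to one of the tinyBLAS kernels
// below based on the operand data types and the CPU features compiled in.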

#if defined(__GNUC__)
#pragma GCC diagnostic ignored "-Wpedantic"
#pragma GCC diagnostic ignored "-Wignored-attributes"
#endif

#include "sgemm.h"
#include "ggml-impl.h"
#include "ggml-cpu-impl.h"
#include "ggml-quants.h"
#include "simd-mappings.h"

#include <array>
#include <type_traits>

#ifdef _MSC_VER
#define NOINLINE __declspec(noinline)
#else
#define NOINLINE __attribute__((__noinline__))
#endif

#if defined(__ARM_NEON) || defined(__AVX512F__) || defined(__VXE__) || defined(__VXE2__)
#define VECTOR_REGISTERS 32
#else
#define VECTOR_REGISTERS 16
#endif

#if defined(__riscv_v_intrinsic)
#define LMUL 4
#endif

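// Assemble a 256-bit integer vector from two 128-bit halves, with b in the
// low lane and a in the high lane.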
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

namespace {

inline float unhalf(ggml_fp16_t d) {
    return GGML_CPU_FP16_TO_FP32(d);
}

////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED ARITHMETIC OPERATIONS

#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); }
inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); }
inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); }
#endif // __SSE__

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); }
inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); }
#endif // __AVX__

#if defined(__AVX512F__)
inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); }
inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); }
inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); }
#endif // __AVX512F__

#if defined(__ARM_NEON)
inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); }
inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); }
inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); }
#endif // __ARM_NEON

#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); }
inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

#if defined(__VXE__) || defined(__VXE2__)
inline float32x4_t add(float32x4_t x, float32x4_t y) { return vec_add(x, y); }
inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vec_sub(x, y); }
inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
#endif

#if defined(__MMA__)
#include "sgemm-ppc.h"
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED FUSED MULTIPLY ADD

/**
 * Computes a * b + c.
 */
template <typename T, typename U>
inline U madd(T a, T b, U c) {
    return add(mul(a, b), c);
}
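// The generic madd() above rounds twice, once after the multiply and once
// after the add. The specializations below use fused multiply-add
// instructions where the hardware has them, which round only once and are
// both faster and slightly more accurate.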

#if defined(__FMA__)
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <>
inline __m256 madd(__m256 a, __m256 b, __m256 c) {
    return _mm256_fmadd_ps(a, b, c);
}
#endif
#if defined(__AVX512F__)
template <>
inline __m512 madd(__m512 a, __m512 b, __m512 c) {
    return _mm512_fmadd_ps(a, b, c);
}
#endif
#if defined(__AVX512BF16__)
template <>
inline __m512 madd(__m512bh a, __m512bh b, __m512 c) {
    return _mm512_dpbf16_ps(c, a, b);
}
template <>
inline __m256 madd(__m256bh a, __m256bh b, __m256 c) {
    return _mm256_dpbf16_ps(c, a, b);
}
#endif
#endif

#if defined(__ARM_FEATURE_FMA)
template <>
inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
    return vfmaq_f32(c, b, a);
}
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
template <>
inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
    return vfmaq_f16(c, b, a);
}
#endif
#endif

#if defined(__VXE__) || defined(__VXE2__)
template <>
inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
    return vec_madd(a, b, c);
}
#endif

#if defined(__riscv_zvfh)
template <>
inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) {
    return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
}
inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) {
    return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
}
inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) {
    return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
}
inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) {
    return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
}
inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) {
    return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
}
inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) {
    return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
}
inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) {
    return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
}
inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) {
    return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
}
#endif

#if defined(__riscv_zvfbfwma)
inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) {
    return __riscv_vfwmaccbf16_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
}
inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) {
    return __riscv_vfwmaccbf16_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
}
inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) {
    return __riscv_vfwmaccbf16_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED HORIZONTAL SUM

#if defined(__ARM_NEON)
inline float hsum(float32x4_t x) {
    return vaddvq_f32(x);
}
#endif // __ARM_NEON

#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
inline float hsum(float16x8_t x) {
    return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)),
                                vcvt_f32_f16(vget_high_f16(x))));
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

#if defined(__VXE__) || defined(__VXE2__)
inline float hsum(float32x4_t x) {
    float32x4_t tmp = x + vec_reve(x);
    return tmp[0] + tmp[1];
}
#endif

#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline float hsum(__m128 x) {
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
    x = _mm_add_ps(x, _mm_movehl_ps(x, x));
    x = _mm_add_ss(x, _mm_movehdup_ps(x));
#else
    __m128 t;
    t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
    x = _mm_add_ps(x, t);
    t = _mm_movehl_ps(t, x);
    x = _mm_add_ss(x, t);
#endif
    return _mm_cvtss_f32(x);
}
#endif

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline float hsum(__m256 x) {
    return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1),
                           _mm256_castps256_ps128(x)));
}
#endif // __AVX__

#if defined(__AVX512F__)
inline float hsum(__m512 x) {
    return _mm512_reduce_add_ps(x);
}
#endif // __AVX512F__

#if defined(__riscv_zvfh)
inline float hsum(vfloat32m1_t x) {
    return __riscv_vfmv_f_s_f32m1_f32(
        __riscv_vfredusum_vs_f32m1_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m1()));
}
inline float hsum(vfloat32m2_t x) {
    return __riscv_vfmv_f_s_f32m1_f32(
        __riscv_vfredusum_vs_f32m2_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m2()));
}
inline float hsum(vfloat32m4_t x) {
    return __riscv_vfmv_f_s_f32m1_f32(
        __riscv_vfredusum_vs_f32m4_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m4()));
}
inline float hsum(vfloat32m8_t x) {
    return __riscv_vfmv_f_s_f32m1_f32(
        __riscv_vfredusum_vs_f32m8_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m8()));
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED MEMORY LOADING

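// load<V>(p) reads one vector's worth of elements of type U starting at p
// and widens them into vector type V. Each ISA supplies its own
// specializations below.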
template <typename T, typename U> T load(const U *);

#if defined(__ARM_NEON)
template <> inline float32x4_t load(const float *p) {
    return vld1q_f32(p);
}
#if !defined(_MSC_VER)
// FIXME: this should check for __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <> inline float16x8_t load(const ggml_fp16_t *p) {
    return vld1q_f16((const float16_t *)p);
}
template <> inline float32x4_t load(const ggml_fp16_t *p) {
    return vcvt_f32_f16(vld1_f16((const float16_t *)p));
}
#endif // _MSC_VER
#endif // __ARM_NEON

#if defined(__VXE__) || defined(__VXE2__)
template <> inline float32x4_t load(const ggml_fp16_t * p) {
    float tmp[4];

    for (int i = 0; i < 4; i++) {
        tmp[i] = GGML_CPU_FP16_TO_FP32(p[i]);
    }

    return vec_xl(0, (const float *)(tmp));
}
template <> inline float32x4_t load(const float * p) {
    return vec_xl(0, p);
}
#endif

#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <> inline __m128 load(const float *p) {
    return _mm_loadu_ps(p);
}
#endif // __SSE__

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <> inline __m256 load(const float *p) {
    return _mm256_loadu_ps(p);
}
#endif // __AVX__

#if defined(__AVX2__) || defined(__AVX512F__)
template <> inline __m256 load(const ggml_bf16_t *p) {
    return _mm256_castsi256_ps(
        _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)p)), 16));
}
#endif // __AVX2__

#if defined(__F16C__)
template <> inline __m256 load(const ggml_fp16_t *p) {
    return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
}
#endif // __F16C__

#if defined(__AVX512F__)
template <> inline __m512 load(const float *p) {
    return _mm512_loadu_ps(p);
}
template <> inline __m512 load(const ggml_fp16_t *p) {
    return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
}
template <> inline __m512 load(const ggml_bf16_t *p) {
    return _mm512_castsi512_ps(
        _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)p)), 16));
}
#endif // __AVX512F__

#if defined(__AVX512BF16__)
template <> inline __m512bh load(const ggml_bf16_t *p) {
    return (__m512bh)_mm512_loadu_ps((const float *)p);
}
template <> inline __m256bh load(const ggml_bf16_t *p) {
    return (__m256bh)_mm256_loadu_ps((const float *)p);
}
template <> inline __m512bh load(const float *p) {
    return _mm512_cvtne2ps_pbh(_mm512_loadu_ps(p + 16), _mm512_loadu_ps(p));
}
template <> inline __m256bh load(const float *p) {
    return _mm512_cvtneps_pbh(_mm512_loadu_ps(p));
}
#endif

#if defined(__riscv_zvfh)
template <> inline vfloat16mf2_t load(const ggml_fp16_t *p) {
    return __riscv_vle16_v_f16mf2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16mf2());
}
template <> inline vfloat16m1_t load(const ggml_fp16_t *p) {
    return __riscv_vle16_v_f16m1(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m1());
}
template <> inline vfloat16m2_t load(const ggml_fp16_t *p) {
    return __riscv_vle16_v_f16m2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m2());
}
template <> inline vfloat16m4_t load(const ggml_fp16_t *p) {
    return __riscv_vle16_v_f16m4(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m4());
}
template <> inline vfloat32m1_t load(const float *p) {
    return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1());
}
template <> inline vfloat32m2_t load(const float *p) {
    return __riscv_vle32_v_f32m2(p, __riscv_vsetvlmax_e32m2());
}
template <> inline vfloat32m4_t load(const float *p) {
    return __riscv_vle32_v_f32m4(p, __riscv_vsetvlmax_e32m4());
}
template <> inline vfloat32m8_t load(const float *p) {
    return __riscv_vle32_v_f32m8(p, __riscv_vsetvlmax_e32m8());
}
#endif

#if defined(__riscv_zvfbfwma)
template <> inline vbfloat16mf2_t load(const ggml_bf16_t *p) {
    return __riscv_vle16_v_bf16mf2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16mf2());
}
template <> inline vbfloat16m1_t load(const ggml_bf16_t *p) {
    return __riscv_vle16_v_bf16m1(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m1());
}
template <> inline vbfloat16m2_t load(const ggml_bf16_t *p) {
    return __riscv_vle16_v_bf16m2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m2());
}
#endif

#if defined(__riscv_zvfh)
template <typename T> T set_zero();

template <> inline vfloat16mf2_t set_zero() {
    return __riscv_vfmv_v_f_f16mf2(0, __riscv_vsetvlmax_e16mf2());
}
template <> inline vfloat16m1_t set_zero() {
    return __riscv_vfmv_v_f_f16m1(0, __riscv_vsetvlmax_e16m1());
}
template <> inline vfloat16m2_t set_zero() {
    return __riscv_vfmv_v_f_f16m2(0, __riscv_vsetvlmax_e16m2());
}
template <> inline vfloat16m4_t set_zero() {
    return __riscv_vfmv_v_f_f16m4(0, __riscv_vsetvlmax_e16m4());
}
template <> inline vfloat32m1_t set_zero() {
    return __riscv_vfmv_v_f_f32m1(0.0f, __riscv_vsetvlmax_e32m1());
}
template <> inline vfloat32m2_t set_zero() {
    return __riscv_vfmv_v_f_f32m2(0, __riscv_vsetvlmax_e32m2());
}
template <> inline vfloat32m4_t set_zero() {
    return __riscv_vfmv_v_f_f32m4(0, __riscv_vsetvlmax_e32m4());
}
template <> inline vfloat32m8_t set_zero() {
    return __riscv_vfmv_v_f_f32m8(0, __riscv_vsetvlmax_e32m8());
}
#endif

#if defined(__riscv_v_intrinsic)
template <typename T> size_t vlmax() {
    if constexpr (std::is_same_v<T, vfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
    else if constexpr (std::is_same_v<T, vfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
    else if constexpr (std::is_same_v<T, vfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
    else if constexpr (std::is_same_v<T, vfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); }
    else if constexpr (std::is_same_v<T, vfloat32m1_t>) { return __riscv_vsetvlmax_e32m1(); }
    else if constexpr (std::is_same_v<T, vfloat32m2_t>) { return __riscv_vsetvlmax_e32m2(); }
    else if constexpr (std::is_same_v<T, vfloat32m4_t>) { return __riscv_vsetvlmax_e32m4(); }
    else if constexpr (std::is_same_v<T, vfloat32m8_t>) { return __riscv_vsetvlmax_e32m8(); }
    return 0;
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
// FLOATING POINT MATRIX MULTIPLICATION

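// Largest block size not exceeding M that divides m as evenly as possible.
// For example BLOCK_SIZE<6>(14) first computes NB_BLOC_M = ceil(14/6) = 3
// blocks, then returns ceil(14/3) = 5, so the 14 rows are covered by blocks
// of 5, 5 and 4 rather than 6, 6 and 2.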
template <int M>
static inline int64_t BLOCK_SIZE(size_t m) {
    const int64_t NB_BLOC_M = (m + M - 1) / M;
    return (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
}

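// Start offset of block ib, assuming the first ibN blocks hold bloc_size
// items each and the remaining blocks hold bloc_size - 1. Continuing the
// example above (ibN = 2, bloc_size = 5), the offsets are 0, 5, 10.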
static constexpr inline int64_t BLOC_POS(int64_t ib, int64_t ibN, int64_t bloc_size) {
    return ib < ibN ? ib * bloc_size : ibN * bloc_size + (ib - ibN) * (bloc_size - 1);
}

template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
class tinyBLAS {
  public:
    tinyBLAS(const ggml_compute_params * params, int64_t k,
             const TA *A, int64_t lda,
             const TB *B, int64_t ldb,
             TC *C, int64_t ldc)
        : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
    }

    bool matmul(int64_t m, int64_t n) {
        if (k % KN != 0)
            return false;
        // pick tile shapes so the n dimension only needs tiles of width RN and RN - 1
#if VECTOR_REGISTERS == 32
        if (m % 16 == 0 && (m/16 >= params->nth)) {
            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
            mnpack<4, 6, 4>(m, n, SIZE_N, 12);
            return true;
        }
        if (m % 8 == 0) {
            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
            mnpack<4, 6, 2>(m, n, SIZE_N, 12);
            return true;
        }
        if (m % 4 == 0) {
            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
            mnpack<4, 6, 1>(m, n, SIZE_N, 12);
            return true;
        }
#else // VECTOR_REGISTERS == 16
        if (m % 16 == 0 && (m/16 >= params->nth)) {
            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
            mnpack<4, 3, 4>(m, n, SIZE_N, 24);
            return true;
        }
        if (m % 8 == 0) {
            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
            mnpack<4, 3, 2>(m, n, SIZE_N, 24);
            return true;
        }
        if (m % 4 == 0) {
            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
            mnpack<4, 3, 1>(m, n, SIZE_N, 24);
            return true;
        }
#endif
        return false;
    }

  private:
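    // Peels RN at compile time until it matches the runtime SIZE_N, so each
    // supported block width gets its own fully unrolled gemm instantiation.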
    template <int RM, int RN, int BM>
    inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
        if (SIZE_N == RN) {
            return gemm<RM, RN, BM>(m, n, BN);
        }
        if constexpr (RN > 1) {
            return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
        } else {
            GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
            GGML_ASSERT(false); // we have missed something.
        }
    }

    template <int RM, int RN>
    inline void gemm_bloc(int64_t ii, int64_t jj) {
        D Cv[RN][RM] = {};
        for (int64_t l = 0; l < k; l += KN) {
            // help the compiler with the operation order.
            if constexpr (RM <= RN) {
                V Av[RM];
                for (int64_t i = 0; i < RM; ++i) {
                    Av[i] = load<V>(A + lda * (ii + i) + l);
                }
                for (int64_t j = 0; j < RN; ++j) {
                    V Bv = load<V>(B + ldb * (jj + j) + l);
                    for (int64_t i = 0; i < RM; ++i) {
                        Cv[j][i] = madd(Av[i], Bv, Cv[j][i]);
                    }
                }
            } else {
                V Bv[RN];
                for (int64_t j = 0; j < RN; ++j) {
                    Bv[j] = load<V>(B + ldb * (jj + j) + l);
                }
                for (int64_t i = 0; i < RM; ++i) {
                    V Av = load<V>(A + lda * (ii + i) + l);
                    for (int64_t j = 0; j < RN; ++j) {
                        Cv[j][i] = madd(Av, Bv[j], Cv[j][i]);
                    }
                }
            }
        }
        for (int64_t j = 0; j < RN; ++j)
            for (int64_t i = 0; i < RM; ++i)
                C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
    }

    template <int RM, int RN, int BM>
    NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
        GGML_ASSERT(m % (RM * BM) == 0);
        const int64_t ytiles = m / (RM * BM);
        const int64_t xtiles = (n + RN - 1) / RN;
        const int64_t jj_RN = (xtiles - (xtiles * RN - n));

        // "round" bloc_size to the "nearest" BN
        const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
        const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
        const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
        const int64_t nb_job = ytiles * NB_BN;

        if (params->ith == 0) {
            GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
            // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
            ggml_threadpool_chunk_set(params->threadpool, params->nth);
        }

        ggml_barrier(params->threadpool);

        int64_t job = params->ith;
        while (job < nb_job) {
            const int64_t ii = (job % ytiles) * RM * BM;
            const int64_t jb = job / ytiles;
            const int64_t jr0 = BLOC_POS(jb  , jj_BN, SIZE_BN);
            const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);

            const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
            const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
            const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
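            // [jj0, jj1) is covered with full RN-wide tiles; [jj1, jj2)
            // holds the narrower RN-1 tiles left over at the tail.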

            for (int64_t bi = 0; bi < BM * RM; bi += RM) {
                int64_t jj = jj0;
                for (; jj < jj1; jj += RN) {
                    gemm_bloc<RM, RN>(ii + bi, jj);
                }
                if constexpr (RN > 1) {
                    for (; jj < jj2; jj += RN - 1) {
                        gemm_bloc<RM, RN-1>(ii + bi, jj);
                    }
                }
                GGML_ASSERT(jj == jj2);
            }

            job = ggml_threadpool_chunk_add(params->threadpool, 1);
        }

        ggml_barrier(params->threadpool);
        return;
    }

    const ggml_compute_params * params;
    const TA *const A;
    const TB *const B;
    TC *const C;
    const int64_t k;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
};

#if defined(__riscv_v_intrinsic)
template <typename D, typename V, typename TA, typename TB, typename TC>
class tinyBLAS_RVV {
  public:
    tinyBLAS_RVV(const ggml_compute_params * params, int64_t k,
                 const TA *A, int64_t lda,
                 const TB *B, int64_t ldb,
                 TC *C, int64_t ldc)
        : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
    }

    bool matmul(int64_t m, int64_t n) {
        if (k % vlmax<V>() != 0) {
            return false;
        }

#if LMUL == 1
        if (m % 16 == 0 && (m/16 >= params->nth)) {
            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
            mnpack<4, 6, 4>(m, n, SIZE_N, 12);
            return true;
        }
        if (m % 8 == 0) {
            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
            mnpack<4, 6, 2>(m, n, SIZE_N, 12);
            return true;
        }
        if (m % 4 == 0) {
            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
            mnpack<4, 6, 1>(m, n, SIZE_N, 12);
            return true;
        }
#elif LMUL == 2
        if (m % 16 == 0 && (m/16 >= params->nth)) {
            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
            mnpack<4, 3, 4>(m, n, SIZE_N, 24);
            return true;
        }
        if (m % 8 == 0) {
            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
            mnpack<4, 3, 2>(m, n, SIZE_N, 24);
            return true;
        }
        if (m % 4 == 0) {
            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
            mnpack<4, 3, 1>(m, n, SIZE_N, 24);
            return true;
        }
#else // LMUL == 4
        if (m % 16 == 0 && (m/16 >= params->nth)) {
            const int64_t SIZE_N = BLOCK_SIZE<2>(n);
            mnpack<2, 2, 8>(m, n, SIZE_N, 36);
            return true;
        }
        if (m % 8 == 0) {
            const int64_t SIZE_N = BLOCK_SIZE<2>(n);
            mnpack<2, 2, 4>(m, n, SIZE_N, 36);
            return true;
        }
        if (m % 4 == 0) {
            const int64_t SIZE_N = BLOCK_SIZE<2>(n);
            mnpack<2, 2, 2>(m, n, SIZE_N, 36);
            return true;
        }
#endif
        return false;
    }

  private:
    template<int RM, int RN, int BM>
    inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
        if (SIZE_N == RN) {
            return gemm<RM, RN, BM>(m, n, BN);
        }
        if constexpr (RN > 1) {
            return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
        } else {
            GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
            GGML_ASSERT(false); // we have missed something.
        }
    }

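    // 4x6 kernel: 24 accumulators plus 6 B vectors and 1 A vector are live
    // at once, which just fits the 32 vector registers at LMUL == 1.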
    inline void gemm_bloc_4x6(int64_t ii, int64_t jj) {
        size_t vl = vlmax<V>();
        D Cv00 = set_zero<D>();
        D Cv01 = set_zero<D>();
        D Cv02 = set_zero<D>();
        D Cv03 = set_zero<D>();
        D Cv10 = set_zero<D>();
        D Cv11 = set_zero<D>();
        D Cv12 = set_zero<D>();
        D Cv13 = set_zero<D>();
        D Cv20 = set_zero<D>();
        D Cv21 = set_zero<D>();
        D Cv22 = set_zero<D>();
        D Cv23 = set_zero<D>();
        D Cv30 = set_zero<D>();
        D Cv31 = set_zero<D>();
        D Cv32 = set_zero<D>();
        D Cv33 = set_zero<D>();
        D Cv40 = set_zero<D>();
        D Cv41 = set_zero<D>();
        D Cv42 = set_zero<D>();
        D Cv43 = set_zero<D>();
        D Cv50 = set_zero<D>();
        D Cv51 = set_zero<D>();
        D Cv52 = set_zero<D>();
        D Cv53 = set_zero<D>();

        for (int64_t l = 0; l < k; l += vl) {
            V Bv0 = load<V>(B + ldb * (jj + 0) + l);
            V Bv1 = load<V>(B + ldb * (jj + 1) + l);
            V Bv2 = load<V>(B + ldb * (jj + 2) + l);
            V Bv3 = load<V>(B + ldb * (jj + 3) + l);
            V Bv4 = load<V>(B + ldb * (jj + 4) + l);
            V Bv5 = load<V>(B + ldb * (jj + 5) + l);

            V Av0 = load<V>(A + lda * (ii + 0) + l);
            Cv00 = madd(Av0, Bv0, Cv00);
            Cv10 = madd(Av0, Bv1, Cv10);
            Cv20 = madd(Av0, Bv2, Cv20);
            Cv30 = madd(Av0, Bv3, Cv30);
            Cv40 = madd(Av0, Bv4, Cv40);
            Cv50 = madd(Av0, Bv5, Cv50);

            V Av1 = load<V>(A + lda * (ii + 1) + l);
            Cv01 = madd(Av1, Bv0, Cv01);
            Cv11 = madd(Av1, Bv1, Cv11);
            Cv21 = madd(Av1, Bv2, Cv21);
            Cv31 = madd(Av1, Bv3, Cv31);
            Cv41 = madd(Av1, Bv4, Cv41);
            Cv51 = madd(Av1, Bv5, Cv51);

            V Av2 = load<V>(A + lda * (ii + 2) + l);
            Cv02 = madd(Av2, Bv0, Cv02);
            Cv12 = madd(Av2, Bv1, Cv12);
            Cv22 = madd(Av2, Bv2, Cv22);
            Cv32 = madd(Av2, Bv3, Cv32);
            Cv42 = madd(Av2, Bv4, Cv42);
            Cv52 = madd(Av2, Bv5, Cv52);

            V Av3 = load<V>(A + lda * (ii + 3) + l);
            Cv03 = madd(Av3, Bv0, Cv03);
            Cv13 = madd(Av3, Bv1, Cv13);
            Cv23 = madd(Av3, Bv2, Cv23);
            Cv33 = madd(Av3, Bv3, Cv33);
            Cv43 = madd(Av3, Bv4, Cv43);
            Cv53 = madd(Av3, Bv5, Cv53);
        }

        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
        C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
        C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
        C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
        C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
        C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
        C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
        C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
        C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
        C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
        C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
        C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40);
        C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41);
        C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42);
        C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43);
        C[ldc * (jj + 5) + (ii + 0)] = hsum(Cv50);
        C[ldc * (jj + 5) + (ii + 1)] = hsum(Cv51);
        C[ldc * (jj + 5) + (ii + 2)] = hsum(Cv52);
        C[ldc * (jj + 5) + (ii + 3)] = hsum(Cv53);
    }

    inline void gemm_bloc_4x5(int64_t ii, int64_t jj) {
        size_t vl = vlmax<V>();
        D Cv00 = set_zero<D>();
        D Cv01 = set_zero<D>();
        D Cv02 = set_zero<D>();
        D Cv03 = set_zero<D>();
        D Cv10 = set_zero<D>();
        D Cv11 = set_zero<D>();
        D Cv12 = set_zero<D>();
        D Cv13 = set_zero<D>();
        D Cv20 = set_zero<D>();
        D Cv21 = set_zero<D>();
        D Cv22 = set_zero<D>();
        D Cv23 = set_zero<D>();
        D Cv30 = set_zero<D>();
        D Cv31 = set_zero<D>();
        D Cv32 = set_zero<D>();
        D Cv33 = set_zero<D>();
        D Cv40 = set_zero<D>();
        D Cv41 = set_zero<D>();
        D Cv42 = set_zero<D>();
        D Cv43 = set_zero<D>();

        for (int64_t l = 0; l < k; l += vl) {
            V Bv0 = load<V>(B + ldb * (jj + 0) + l);
            V Bv1 = load<V>(B + ldb * (jj + 1) + l);
            V Bv2 = load<V>(B + ldb * (jj + 2) + l);
            V Bv3 = load<V>(B + ldb * (jj + 3) + l);
            V Bv4 = load<V>(B + ldb * (jj + 4) + l);

            V Av0 = load<V>(A + lda * (ii + 0) + l);
            Cv00 = madd(Av0, Bv0, Cv00);
            Cv10 = madd(Av0, Bv1, Cv10);
            Cv20 = madd(Av0, Bv2, Cv20);
            Cv30 = madd(Av0, Bv3, Cv30);
            Cv40 = madd(Av0, Bv4, Cv40);

            V Av1 = load<V>(A + lda * (ii + 1) + l);
            Cv01 = madd(Av1, Bv0, Cv01);
            Cv11 = madd(Av1, Bv1, Cv11);
            Cv21 = madd(Av1, Bv2, Cv21);
            Cv31 = madd(Av1, Bv3, Cv31);
            Cv41 = madd(Av1, Bv4, Cv41);

            V Av2 = load<V>(A + lda * (ii + 2) + l);
            Cv02 = madd(Av2, Bv0, Cv02);
            Cv12 = madd(Av2, Bv1, Cv12);
            Cv22 = madd(Av2, Bv2, Cv22);
            Cv32 = madd(Av2, Bv3, Cv32);
            Cv42 = madd(Av2, Bv4, Cv42);

            V Av3 = load<V>(A + lda * (ii + 3) + l);
            Cv03 = madd(Av3, Bv0, Cv03);
            Cv13 = madd(Av3, Bv1, Cv13);
            Cv23 = madd(Av3, Bv2, Cv23);
            Cv33 = madd(Av3, Bv3, Cv33);
            Cv43 = madd(Av3, Bv4, Cv43);
        }

        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
        C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
        C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
        C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
        C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
        C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
        C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
        C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
        C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
        C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
        C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
        C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40);
        C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41);
        C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42);
        C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43);
    }

    inline void gemm_bloc_4x4(int64_t ii, int64_t jj) {
        size_t vl = vlmax<V>();
        D Cv00 = set_zero<D>();
        D Cv01 = set_zero<D>();
        D Cv02 = set_zero<D>();
        D Cv03 = set_zero<D>();
        D Cv10 = set_zero<D>();
        D Cv11 = set_zero<D>();
        D Cv12 = set_zero<D>();
        D Cv13 = set_zero<D>();
        D Cv20 = set_zero<D>();
        D Cv21 = set_zero<D>();
        D Cv22 = set_zero<D>();
        D Cv23 = set_zero<D>();
        D Cv30 = set_zero<D>();
        D Cv31 = set_zero<D>();
        D Cv32 = set_zero<D>();
        D Cv33 = set_zero<D>();

        for (int64_t l = 0; l < k; l += vl) {
            V Av0 = load<V>(A + lda * (ii + 0) + l);
            V Av1 = load<V>(A + lda * (ii + 1) + l);
            V Av2 = load<V>(A + lda * (ii + 2) + l);
            V Av3 = load<V>(A + lda * (ii + 3) + l);

            V Bv0 = load<V>(B + ldb * (jj + 0) + l);
            Cv00 = madd(Av0, Bv0, Cv00);
            Cv01 = madd(Av1, Bv0, Cv01);
            Cv02 = madd(Av2, Bv0, Cv02);
            Cv03 = madd(Av3, Bv0, Cv03);

            V Bv1 = load<V>(B + ldb * (jj + 1) + l);
            Cv10 = madd(Av0, Bv1, Cv10);
            Cv11 = madd(Av1, Bv1, Cv11);
            Cv12 = madd(Av2, Bv1, Cv12);
            Cv13 = madd(Av3, Bv1, Cv13);

            V Bv2 = load<V>(B + ldb * (jj + 2) + l);
            Cv20 = madd(Av0, Bv2, Cv20);
            Cv21 = madd(Av1, Bv2, Cv21);
            Cv22 = madd(Av2, Bv2, Cv22);
            Cv23 = madd(Av3, Bv2, Cv23);

            V Bv3 = load<V>(B + ldb * (jj + 3) + l);
            Cv30 = madd(Av0, Bv3, Cv30);
            Cv31 = madd(Av1, Bv3, Cv31);
            Cv32 = madd(Av2, Bv3, Cv32);
            Cv33 = madd(Av3, Bv3, Cv33);
        }

        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
        C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
        C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
        C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
        C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
        C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
        C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
        C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
        C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
        C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
        C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
    }

    inline void gemm_bloc_4x3(int64_t ii, int64_t jj) {
        size_t vl = vlmax<V>();
        D Cv00 = set_zero<D>();
        D Cv01 = set_zero<D>();
        D Cv02 = set_zero<D>();
        D Cv03 = set_zero<D>();
        D Cv10 = set_zero<D>();
        D Cv11 = set_zero<D>();
        D Cv12 = set_zero<D>();
        D Cv13 = set_zero<D>();
        D Cv20 = set_zero<D>();
        D Cv21 = set_zero<D>();
        D Cv22 = set_zero<D>();
        D Cv23 = set_zero<D>();

        for (int64_t l = 0; l < k; l += vl) {
            V Av0 = load<V>(A + lda * (ii + 0) + l);
            V Av1 = load<V>(A + lda * (ii + 1) + l);
            V Av2 = load<V>(A + lda * (ii + 2) + l);
            V Av3 = load<V>(A + lda * (ii + 3) + l);

            V Bv0 = load<V>(B + ldb * (jj + 0) + l);
            Cv00 = madd(Av0, Bv0, Cv00);
            Cv01 = madd(Av1, Bv0, Cv01);
            Cv02 = madd(Av2, Bv0, Cv02);
            Cv03 = madd(Av3, Bv0, Cv03);

            V Bv1 = load<V>(B + ldb * (jj + 1) + l);
            Cv10 = madd(Av0, Bv1, Cv10);
            Cv11 = madd(Av1, Bv1, Cv11);
            Cv12 = madd(Av2, Bv1, Cv12);
            Cv13 = madd(Av3, Bv1, Cv13);

            V Bv2 = load<V>(B + ldb * (jj + 2) + l);
            Cv20 = madd(Av0, Bv2, Cv20);
            Cv21 = madd(Av1, Bv2, Cv21);
            Cv22 = madd(Av2, Bv2, Cv22);
            Cv23 = madd(Av3, Bv2, Cv23);
        }

        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
        C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
        C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
        C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
        C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
        C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
        C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
    }

    inline void gemm_bloc_4x2(int64_t ii, int64_t jj) {
        size_t vl = vlmax<V>();
        D Cv00 = set_zero<D>();
        D Cv01 = set_zero<D>();
        D Cv02 = set_zero<D>();
        D Cv03 = set_zero<D>();
        D Cv10 = set_zero<D>();
        D Cv11 = set_zero<D>();
        D Cv12 = set_zero<D>();
        D Cv13 = set_zero<D>();

        for (int64_t l = 0; l < k; l += vl) {
            V Av0 = load<V>(A + lda * (ii + 0) + l);
            V Av1 = load<V>(A + lda * (ii + 1) + l);
            V Av2 = load<V>(A + lda * (ii + 2) + l);
            V Av3 = load<V>(A + lda * (ii + 3) + l);

            V Bv0 = load<V>(B + ldb * (jj + 0) + l);
            Cv00 = madd(Av0, Bv0, Cv00);
            Cv01 = madd(Av1, Bv0, Cv01);
            Cv02 = madd(Av2, Bv0, Cv02);
            Cv03 = madd(Av3, Bv0, Cv03);

            V Bv1 = load<V>(B + ldb * (jj + 1) + l);
            Cv10 = madd(Av0, Bv1, Cv10);
            Cv11 = madd(Av1, Bv1, Cv11);
            Cv12 = madd(Av2, Bv1, Cv12);
            Cv13 = madd(Av3, Bv1, Cv13);
        }

        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
        C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
        C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
    }

    inline void gemm_bloc_4x1(int64_t ii, int64_t jj) {
        size_t vl = vlmax<V>();
        D Cv00 = set_zero<D>();
        D Cv01 = set_zero<D>();
        D Cv02 = set_zero<D>();
        D Cv03 = set_zero<D>();

        for (int64_t l = 0; l < k; l += vl) {
            V Av0 = load<V>(A + lda * (ii + 0) + l);
            V Av1 = load<V>(A + lda * (ii + 1) + l);
            V Av2 = load<V>(A + lda * (ii + 2) + l);
            V Av3 = load<V>(A + lda * (ii + 3) + l);

            V Bv0 = load<V>(B + ldb * (jj + 0) + l);
            Cv00 = madd(Av0, Bv0, Cv00);
            Cv01 = madd(Av1, Bv0, Cv01);
            Cv02 = madd(Av2, Bv0, Cv02);
            Cv03 = madd(Av3, Bv0, Cv03);
        }

        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
    }

    inline void gemm_bloc_2x2(int64_t ii, int64_t jj) {
        size_t vl = vlmax<V>();
        D Cv00 = set_zero<D>();
        D Cv01 = set_zero<D>();
        D Cv10 = set_zero<D>();
        D Cv11 = set_zero<D>();

        for (int64_t l = 0; l < k; l += vl) {
            V Av0 = load<V>(A + lda * (ii + 0) + l);
            V Av1 = load<V>(A + lda * (ii + 1) + l);

            V Bv0 = load<V>(B + ldb * (jj + 0) + l);
            Cv00 = madd(Av0, Bv0, Cv00);
            Cv01 = madd(Av1, Bv0, Cv01);

            V Bv1 = load<V>(B + ldb * (jj + 1) + l);
            Cv10 = madd(Av0, Bv1, Cv10);
            Cv11 = madd(Av1, Bv1, Cv11);
        }

        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
    }

    inline void gemm_bloc_2x1(int64_t ii, int64_t jj) {
        size_t vl = vlmax<V>();
        D Cv00 = set_zero<D>();
        D Cv01 = set_zero<D>();

        for (int64_t l = 0; l < k; l += vl) {
            V Av0 = load<V>(A + lda * (ii + 0) + l);
            V Av1 = load<V>(A + lda * (ii + 1) + l);

            V Bv0 = load<V>(B + ldb * (jj + 0) + l);
            Cv00 = madd(Av0, Bv0, Cv00);
            Cv01 = madd(Av1, Bv0, Cv01);
        }

        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
    }

    template <int RM, int RN>
    inline void gemm_bloc(int64_t ii, int64_t jj) {
        if constexpr (RM == 4) {
            if constexpr (RN == 6) { return gemm_bloc_4x6(ii, jj); }
            if constexpr (RN == 5) { return gemm_bloc_4x5(ii, jj); }
            if constexpr (RN == 4) { return gemm_bloc_4x4(ii, jj); }
            if constexpr (RN == 3) { return gemm_bloc_4x3(ii, jj); }
            if constexpr (RN == 2) { return gemm_bloc_4x2(ii, jj); }
            if constexpr (RN == 1) { return gemm_bloc_4x1(ii, jj); }
        } else if constexpr (RM == 2) {
            if constexpr (RN == 2) { return gemm_bloc_2x2(ii, jj); }
            if constexpr (RN == 1) { return gemm_bloc_2x1(ii, jj); }
        }
    }

    template <int RM, int RN, int BM>
    NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
        GGML_ASSERT(m % (RM * BM) == 0);
        const int64_t ytiles = m / (RM * BM);
        const int64_t xtiles = (n + RN - 1) / RN;
        const int64_t jj_RN = (xtiles - (xtiles * RN - n));

        // "round" bloc_size to the "nearest" BN
        const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
        const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
        const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
        const int64_t nb_job = ytiles * NB_BN;

        if (params->ith == 0) {
            GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
            // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
            ggml_threadpool_chunk_set(params->threadpool, params->nth);
        }

        ggml_barrier(params->threadpool);

        int64_t job = params->ith;
        while (job < nb_job) {
            const int64_t ii = (job % ytiles) * RM * BM;
            const int64_t jb = job / ytiles;
            const int64_t jr0 = BLOC_POS(jb  , jj_BN, SIZE_BN);
            const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);

            const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
            const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
            const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;

            for (int64_t bi = 0; bi < BM * RM; bi += RM) {
                int64_t jj = jj0;
                for (; jj < jj1; jj += RN) {
                    gemm_bloc<RM, RN>(ii + bi, jj);
                }
                if constexpr (RN > 1) {
                    for (; jj < jj2; jj += RN - 1) {
                        gemm_bloc<RM, RN-1>(ii + bi, jj);
                    }
                }
                GGML_ASSERT(jj == jj2);
            }

            job = ggml_threadpool_chunk_add(params->threadpool, 1);
        }

        ggml_barrier(params->threadpool);
        return;
    }

    const ggml_compute_params * params;
    const TA *const A;
    const TB *const B;
    TC *const C;
    const int64_t k;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
};
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
// QUANT ZERO MATRIX MULTIPLICATION

#if defined(__ARM_FEATURE_DOTPROD)
template <typename TA>
class tinyBLAS_Q0_ARM {
  public:
    tinyBLAS_Q0_ARM(int64_t k,
                    const TA *A, int64_t lda,
                    const block_q8_0 *B, int64_t ldb,
                    float *C, int64_t ldc,
                    int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
    }

    void matmul(int64_t m, int64_t n) {
        mnpack(0, m, 0, n);
    }

  private:
    NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t mc, nc, mp, np;
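        // The switch key packs the remaining rows and columns (each clamped
        // to 3) into one value, e.g. 0x32 means at least 3 rows and exactly
        // 2 columns remain to be handled.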
        switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
        case 0x33:
            mc = 3;
            nc = 3;
            gemm<3, 3>(m0, m, n0, n);
            break;
        case 0x32:
            mc = 3;
            nc = 2;
            gemm<3, 2>(m0, m, n0, n);
            break;
        case 0x23:
            mc = 2;
            nc = 3;
            gemm<2, 3>(m0, m, n0, n);
            break;
        case 0x22:
            mc = 2;
            nc = 2;
            gemm<2, 2>(m0, m, n0, n);
            break;
        case 0x31:
            mc = 3;
            nc = 1;
            gemm<3, 1>(m0, m, n0, n);
            break;
        case 0x13:
            mc = 1;
            nc = 3;
            gemm<1, 3>(m0, m, n0, n);
            break;
        case 0x21:
            mc = 2;
            nc = 1;
            gemm<2, 1>(m0, m, n0, n);
            break;
        case 0x12:
            mc = 1;
            nc = 2;
            gemm<1, 2>(m0, m, n0, n);
            break;
        case 0x11:
            mc = 1;
            nc = 1;
            gemm<1, 1>(m0, m, n0, n);
            break;
        default:
            return;
        }
        mp = m0 + (m - m0) / mc * mc;
        np = n0 + (n - n0) / nc * nc;
        mnpack(mp, m, n0, np);
        mnpack(m0, m, np, n);
    }

    template <int RM, int RN>
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            float32x4_t Cv[RN][RM] = {};
            for (int64_t l = 0; l < k; ++l)
                for (int64_t j = 0; j < RN; ++j)
                    for (int64_t i = 0; i < RM; ++i)
                        Cv[j][i] = vmlaq_n_f32(Cv[j][i],
                                               vcvtq_f32_s32(vdotq_s32(
                                                   vdotq_s32(vdupq_n_s32(0),
                                                             load_lo(A + lda * (ii + i) + l),
                                                             load_lo(B + ldb * (jj + j) + l)),
                                                   load_hi(A + lda * (ii + i) + l),
                                                   load_hi(B + ldb * (jj + j) + l))),
                                               unhalf(A[lda * (ii + i) + l].d) *
                                               unhalf(B[ldb * (jj + j) + l].d));
            for (int64_t j = 0; j < RN; ++j)
                for (int64_t i = 0; i < RM; ++i)
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
        }
    }

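    // load_lo/load_hi fetch the two 16-byte halves of a 32-weight quant
    // block; q4_0 nibbles are widened to signed bytes by masking or shifting
    // and then subtracting the bias of 8.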
    inline int8x16_t load_lo(const block_q8_0 *b) {
        return vld1q_s8(b->qs);
    }

    inline int8x16_t load_hi(const block_q8_0 *b) {
        return vld1q_s8(b->qs + 16);
    }

    inline int8x16_t load_lo(const block_q4_0 *b) {
        return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs),
                                                     vdupq_n_u8(0x0f))),
                        vdupq_n_s8(0x8));
    }

    inline int8x16_t load_hi(const block_q4_0 *b) {
        return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
                        vdupq_n_s8(0x8));
    }

    const TA *const A;
    const block_q8_0 *const B;
    float *const C;
    const int64_t k;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
    const int ith;
    const int nth;
};
#endif // __ARM_FEATURE_DOTPROD

#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
template <typename TA, typename TB, typename TC>
class tinyBLAS_Q0_AVX {
  public:
    tinyBLAS_Q0_AVX(int64_t k,
                    const TA *A, int64_t lda,
                    const TB *B, int64_t ldb,
                    TC *C, int64_t ldc,
                    int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
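        // 16-entry IQ4_NL codebook, loaded once so that load0/load1 can
        // decode nibbles with a single _mm_shuffle_epi8 table lookup.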
        const int8_t kvalues_iq4nl[16] = {
            -127, -104, -83, -65,
            -49, -35, -22, -10,
            1, 13, 25, 38,
            53, 69, 89, 113
        };

        iq4nlt = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
    }

    void matmul(int64_t m, int64_t n) {
        mnpack(0, m, 0, n);
    }

  private:
    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t mc, nc, mp, np;
        switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
#if VECTOR_REGISTERS == 32
        case 0x44:
            mc = 4;
            nc = 4;
#if defined(__AVX2__) && defined(__F16C__)
            gemm4xN<4>(m0, m, n0, n);
#else
            gemm<4, 4>(m0, m, n0, n);
#endif
            break;
        case 0x43:
            mc = 4;
            nc = 3;
#if defined(__AVX2__) && defined(__F16C__)
            gemm4xN<3>(m0, m, n0, n);
#else
            gemm<4, 3>(m0, m, n0, n);
#endif
            break;
        case 0x34:
            mc = 3;
            nc = 4;
#if defined(__AVX2__) && defined(__F16C__)
            gemmMx4<3>(m0, m, n0, n);
#else
            gemm<3, 4>(m0, m, n0, n);
#endif
            break;
        case 0x33:
            mc = 3;
            nc = 3;
            gemm<3, 3>(m0, m, n0, n);
            break;
        case 0x42:
            mc = 4;
            nc = 2;
#if defined(__AVX2__) && defined(__F16C__)
            gemm4xN<2>(m0, m, n0, n);
#else
            gemm<4, 2>(m0, m, n0, n);
#endif
            break;
        case 0x24:
            mc = 2;
            nc = 4;
#if defined(__AVX2__) && defined(__F16C__)
            gemmMx4<2>(m0, m, n0, n);
#else
            gemm<2, 4>(m0, m, n0, n);
#endif
            break;
#else
        case 0x44:
        case 0x43:
        case 0x42:
            mc = 4;
            nc = 2;
#if defined(__AVX2__) && defined(__F16C__)
            gemm4xN<2>(m0, m, n0, n);
#else
            gemm<4, 2>(m0, m, n0, n);
#endif
            break;
        case 0x34:
        case 0x24:
            mc = 2;
            nc = 4;
#if defined(__AVX2__) && defined(__F16C__)
            gemmMx4<2>(m0, m, n0, n);
#else
            gemm<2, 4>(m0, m, n0, n);
#endif
            break;
        case 0x33:
#endif
        case 0x32:
            mc = 3;
            nc = 2;
            gemm<3, 2>(m0, m, n0, n);
            break;
        case 0x23:
            mc = 2;
            nc = 3;
            gemm<2, 3>(m0, m, n0, n);
            break;
        case 0x41:
            mc = 4;
            nc = 1;
#if defined(__AVX2__) && defined(__F16C__)
            gemm4xN<1>(m0, m, n0, n);
#else
            gemm<4, 1>(m0, m, n0, n);
#endif
            break;
        case 0x22:
            mc = 2;
            nc = 2;
            gemm<2, 2>(m0, m, n0, n);
            break;
        case 0x14:
            mc = 1;
            nc = 4;
#if defined(__AVX2__) && defined(__F16C__)
            gemmMx4<1>(m0, m, n0, n);
#else
            gemm<1, 4>(m0, m, n0, n);
#endif
            break;
        case 0x31:
            mc = 3;
            nc = 1;
            gemm<3, 1>(m0, m, n0, n);
            break;
        case 0x13:
            mc = 1;
            nc = 3;
            gemm<1, 3>(m0, m, n0, n);
            break;
        case 0x21:
            mc = 2;
            nc = 1;
            gemm<2, 1>(m0, m, n0, n);
            break;
        case 0x12:
            mc = 1;
            nc = 2;
            gemm<1, 2>(m0, m, n0, n);
            break;
        case 0x11:
            mc = 1;
            nc = 1;
            gemm<1, 1>(m0, m, n0, n);
            break;
        default:
            return;
        }
        mp = m0 + (m - m0) / mc * mc;
        np = n0 + (n - n0) / nc * nc;
        mnpack(mp, m, n0, np);
        mnpack(m0, m, np, n);
    }

#if defined(__AVX2__) && defined(__F16C__)
    // Templated functions for gemm of dimensions 4xN
    template <int RN>
    NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / 4;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * 4;
            int64_t jj = n0 + job % xtiles * RN;
            __m256 Cv[RN][4] = {};
            for (int64_t l = 0; l < k; ++l) {
                uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d);
                // Convert the delta values of the four blocks to float
                __m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta));
                __m256i avec0 = load(A + lda * (ii + 0) + l);
                __m256i avec1 = load(A + lda * (ii + 1) + l);
                __m256i avec2 = load(A + lda * (ii + 2) + l);
                __m256i avec3 = load(A + lda * (ii + 3) + l);
                for (int64_t j = 0; j < RN; ++j) {
                    __m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d));
                    // Compute the products of the four blocks' delta values and replicate them across the 256-bit lane
                    __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
                    dvec = _mm256_permute2f128_ps(dvec, dvec, 0);
                    // Compute the dot products and scale each with the matching delta product
                    Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
                                    updot(_mm256_sign_epi8(avec0, avec0),
                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)),
                                    Cv[j][0]);
                    Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
                                    updot(_mm256_sign_epi8(avec1, avec1),
                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)),
                                    Cv[j][1]);
                    Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
                                    updot(_mm256_sign_epi8(avec2, avec2),
                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)),
                                    Cv[j][2]);
                    Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
                                    updot(_mm256_sign_epi8(avec3, avec3),
                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)),
                                    Cv[j][3]);
                }
            }

            for (int64_t j = 0; j < RN; ++j)
                for (int64_t i = 0; i < 4; ++i)
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
        }
    }

    // Templated functions for gemm of dimensions Mx4
    template <int RM>
    NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / 4;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * 4;
            __m256 Cv[4][RM] = {};
            for (int64_t l = 0; l < k; ++l) {
                uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d);
                // Convert the delta values of the four blocks to float
                __m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta));
                __m256i bvec0 = load(B + ldb * (jj + 0) + l);
                __m256i bvec1 = load(B + ldb * (jj + 1) + l);
                __m256i bvec2 = load(B + ldb * (jj + 2) + l);
                __m256i bvec3 = load(B + ldb * (jj + 3) + l);
                for (int64_t i = 0; i < RM; ++i) {
                    __m128 da = _mm_set1_ps(unhalf((A[lda * (ii + i) + l].d)));
                    // Compute the products of the four blocks' delta values and replicate them across the 256-bit lane
                    __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
                    dvec = _mm256_permute2f128_ps(dvec, dvec, 0);
                    // Compute the dot products and scale each with the matching delta product
                    Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
                                                           load(A + lda * (ii + i) + l)),
                                          _mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))),
                                    Cv[0][i]);
                    Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
                                                           load(A + lda * (ii + i) + l)),
                                          _mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))),
                                    Cv[1][i]);
                    Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
                                                           load(A + lda * (ii + i) + l)),
                                          _mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))),
                                    Cv[2][i]);
                    Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
                                                           load(A + lda * (ii + i) + l)),
                                          _mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))),
                                    Cv[3][i]);
                }
            }
            for (int64_t j = 0; j < 4; ++j)
                for (int64_t i = 0; i < RM; ++i)
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
        }
    }
#endif

    template <int RM, int RN>
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            __m256 Cv[RN][RM] = {};
            for (int64_t l = 0; l < k; ++l)
                for (int64_t j = 0; j < RN; ++j)
                    for (int64_t i = 0; i < RM; ++i) {
#if defined(__AVX2__)
                        __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
                                                              load(A + lda * (ii + i) + l)),
                                             _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
                                                              load(A + lda * (ii + i) + l)));
#else
                        __m128i ali0 = load0(A + lda * (ii + i) + l);
                        __m128i ali1 = load1(A + lda * (ii + i) + l);
                        __m128i blj0 = load0(B + ldb * (jj + j) + l);
                        __m128i blj1 = load1(B + ldb * (jj + j) + l);

                        __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
                        __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
                        __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
                        __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);

                        // updot
                        const __m128i oneFill = _mm_set1_epi16(1);
                        __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
                        __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
                        __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
#endif
                        Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
                                                       unhalf(B[ldb * (jj + j) + l].d)),
                                        udTmp,
                                        Cv[j][i]);
                    }
            for (int64_t j = 0; j < RN; ++j)
                for (int64_t i = 0; i < RM; ++i)
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
        }
    }

    inline __m256i load(const block_q8_0 *b) {
        return _mm256_loadu_si256((const __m256i *)b->qs);
    }

    inline __m128i load0(const block_q8_0 *b) {
        return _mm_loadu_si128((const __m128i *)b->qs);
    }

    inline __m128i load1(const block_q8_0 *b) {
        return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
    }

    inline __m256i load(const block_q4_0 *b) {
        return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
    }

    inline __m128i load0(const block_q4_0 *b) {
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
    }

    inline __m128i load1(const block_q4_0 *b) {
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
    }

    inline __m256i load(const block_q5_0 *b) {
        return _mm256_or_si256(denibble(b->qs), bittobyte(b->qh));
    }

    inline __m128i load0(const block_q5_0* b) {
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
        uint32_t x32;
        memcpy(&x32, b->qh, sizeof(uint32_t));
        __m128i qxl = _mm_and_si128(_mm_set1_epi8(15), x);
        __m128i bytesl = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
                                        _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
                                                     _mm_shuffle_epi8(_mm_set1_epi32(x32),
                                                                      _mm_set_epi64x(0x0101010101010101, 0x0000000000000000))));
        bytesl = _mm_andnot_si128(bytesl, _mm_set1_epi8((char)0xF0));
        return _mm_or_si128(qxl, bytesl);
    }

    inline __m128i load1(const block_q5_0* b) {
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
        uint32_t x32;
        memcpy(&x32, b->qh, sizeof(uint32_t));
        __m128i qxh = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4));
        __m128i bytesh = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
                                        _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
                                                     _mm_shuffle_epi8(_mm_set1_epi32(x32),
                                                                      _mm_set_epi64x(0x0303030303030303, 0x0202020202020202))));
        bytesh = _mm_andnot_si128(bytesh, _mm_set1_epi8((char)0xF0));
        return _mm_or_si128(qxh, bytesh);
    }

    inline __m256i load(const block_iq4_nl *b) {
        return MM256_SET_M128I(load1(b), load0(b));
    }

    inline __m128i load0(const block_iq4_nl *b) {
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
        return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), x));
    }

    inline __m128i load1(const block_iq4_nl *b) {
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
        return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)));
    }

1746 inline __m256 updot(__m256i u, __m256i s) {
1747 __m256i res;
1748#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
1749 res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
1750#elif defined(__AVXVNNI__)
1751 res = _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), u, s);
1752#else
1753 res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
1754#endif
1755 return _mm256_cvtepi32_ps(res);
1756 }
1757
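    // Splits 32 packed nibbles into 32 bytes: the low 16 results hold each
    // byte's low nibble, the high 16 hold each byte's high nibble.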
    static inline __m256i denibble(const uint8_t *p) {
        __m128i x = _mm_loadu_si128((const __m128i *)p);
        return _mm256_and_si256(_mm256_set1_epi8(15),
                                _mm256_insertf128_si256(_mm256_castsi128_si256(x),
                                                        _mm_srli_epi16(x, 4), 1));
    }

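    // Broadcasts each of the 32 bits at p into a byte: 0xF0 where the bit is
    // clear, 0x00 where it is set. OR-ing the result with a low nibble turns
    // q5_0's fifth bit into a sign extension of the signed byte value.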
    static inline __m256i bittobyte(const uint8_t *p) {
        uint32_t x32;
        memcpy(&x32, p, sizeof(uint32_t));
        __m256i bytes = _mm256_cmpeq_epi8(_mm256_set1_epi64x(-1),
                                          _mm256_or_si256(_mm256_set1_epi64x(0x7fbfdfeff7fbfdfe),
                                                          _mm256_shuffle_epi8(_mm256_set1_epi32(x32),
                                                                              _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202,
                                                                                                0x0101010101010101, 0x0000000000000000))));
        return _mm256_andnot_si256(bytes, _mm256_set1_epi8((char)0xF0));
    }

    const TA *const A;
    const TB *const B;
    TC *const C;
    const int64_t k;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
    const int ith;
    const int nth;
    __m128i iq4nlt;
};
#endif // __AVX__

// PPC implementation
#if defined(__MMA__)

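// Disassembles a 4x4 MMA accumulator into vec_C and scatters it into the
// column-major output tile starting at C[jj * ldc + ii].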
#define SAVE_ACC(ACC, ii, jj) \
    __builtin_mma_disassemble_acc(vec_C, ACC); \
    for (int I = 0; I < 4; I++) { \
        for (int J = 0; J < 4; J++) { \
            *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); \
        } \
    }

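// Selects the POWER10 MMA rank-2 outer-product update matching the element
// type: xvbf16ger2pp for bf16 operands, xvf16ger2pp for fp16 operands.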
template<typename T>
struct mma_instr;

template<>
struct mma_instr<ggml_bf16_t> {
    static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {
        __builtin_mma_xvbf16ger2pp(acc, a, b);
    }
};

template<>
struct mma_instr<ggml_fp16_t> {
    static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {
        __builtin_mma_xvf16ger2pp(acc, a, b);
    }
};

template <typename TA, typename TB, typename TC>
class tinyBLAS_HP16_PPC {
  public:
    tinyBLAS_HP16_PPC(int64_t k,
                      const TA *A, int64_t lda,
                      const TB *B, int64_t ldb,
                      TC *C, int64_t ldc,
                      int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
    }

    void matmul(int64_t m, int64_t n) {
        mnpack(0, m, 0, n);
    }

  private:
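    // Interleaves 2, 4, or 8 row vectors of 16-bit elements into the operand
    // layout the MMA outer-product instructions expect, storing the permuted
    // vectors at vecOffset.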
    void vector_permute_store(vec_t *c, int numVec, unsigned char *vecOffset) {
        vec_t t[8], s[8];
        vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
        vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
        vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
        vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};

        if (numVec == 2) {
            t[0] = vec_perm(c[0], c[1], swiz1);
            t[1] = vec_perm(c[2], c[3], swiz1);
            s[0] = vec_perm(t[0], t[1], swiz3);
            s[1] = vec_perm(t[0], t[1], swiz4);
            vec_xst(s[0], 0, (vec_t*)vecOffset);
            vec_xst(s[1], 0, (vec_t*)(vecOffset + 16));
        } else if (numVec == 4) {
            t[0] = vec_perm(c[0], c[1], swiz1);
            t[1] = vec_perm(c[0], c[1], swiz2);
            t[2] = vec_perm(c[2], c[3], swiz1);
            t[3] = vec_perm(c[2], c[3], swiz2);
            s[0] = vec_perm(t[0], t[2], swiz3);
            s[1] = vec_perm(t[0], t[2], swiz4);
            s[2] = vec_perm(t[1], t[3], swiz3);
            s[3] = vec_perm(t[1], t[3], swiz4);
            for (int i = 0; i < 4; ++i)
                vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
        } else if (numVec == 8) {
            for (int i = 0; i < 4; i += 2) {
                t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
                t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
            }
            for (int i = 4; i < 8; i += 2) {
                t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
                t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
            }
            s[0] = vec_perm(t[0], t[2], swiz3);
            s[1] = vec_perm(t[0], t[2], swiz4);
            s[2] = vec_perm(t[1], t[3], swiz3);
            s[3] = vec_perm(t[1], t[3], swiz4);
            s[4] = vec_perm(t[4], t[6], swiz3);
            s[5] = vec_perm(t[4], t[6], swiz4);
            s[6] = vec_perm(t[5], t[7], swiz3);
            s[7] = vec_perm(t[5], t[7], swiz4);
            for (int i = 0; i < 8; ++i)
                vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
        }
    }

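    // Packs a rows×cols panel of 16-bit elements into MMA operand order:
    // 8-row blocks first, then a 4-row block, then a 1-3 row remainder (the
    // switch statements below fall through on purpose to load only the rows
    // that exist).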
    void packNormal(const TA* a, int64_t lda, int rows, int cols, unsigned char* vec) {
        int64_t i, j;
        TA *aoffset = NULL;
        unsigned char *vecOffset = NULL;
        TA *aoffsets[8];
        vector unsigned char c_arr[8];
        aoffset = const_cast<TA*>(a);
        vecOffset = vec;
        j = (rows >> 3);
        if (j > 0) {
            do {
                if (cols == 4) {
                    aoffsets[0] = aoffset;
                    for (int it = 1; it < 4; ++it)
                        aoffsets[it] = aoffsets[it-1] + lda;
                    aoffset += 4 * lda;
                    for (int i = 0; i < 4; ++i)
                        c_arr[i] = vec_xl(0, (vector unsigned char*)aoffsets[i]);
                    vector_permute_store(c_arr, 4, vecOffset);
                    for (int i = 0; i < 4; i++)
                        aoffsets[i] = aoffsets[i] + lda;
                    vecOffset += 64;
                }
                i = (cols >> 3);
                if (i > 0) {
                    aoffsets[0] = aoffset;
                    for (int it = 1; it < 8; ++it) {
                        aoffsets[it] = aoffsets[it-1] + lda;
                    }
                    aoffset += 8 * lda;
                    do {
                        for (int it = 0; it < 8; ++it)
                            c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
                        vector_permute_store(c_arr, 8, vecOffset);
                        for (int it = 0; it < 8; ++it)
                            aoffsets[it] = aoffsets[it] + 8*lda;
                        vecOffset += 128;
                        i--;
                    } while (i > 0);
                }
                j--;
            } while (j > 0);
        }
        if (rows & 4) {
            aoffsets[0] = aoffset;
            for (int it = 1; it < 4; ++it)
                aoffsets[it] = aoffsets[it-1] + lda;
            aoffset += 4 * lda;
            if (cols == 4) {
                for (int it = 0; it < 4; ++it)
                    c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
                vector_permute_store(c_arr, 2, vecOffset);
                for (int it = 0; it < 4; it++)
                    aoffsets[it] = aoffsets[it] + lda;
                vecOffset += 32;
            }
            i = (cols >> 3);
            if (i > 0) {
                do {
                    for (int it = 0; it < 4; ++it)
                        c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
                    vector_permute_store(c_arr, 4, vecOffset);
                    for (int it = 0; it < 4; it++)
                        aoffsets[it] = aoffsets[it] + 8*lda;
                    vecOffset += 64;
                    i--;
                } while (i > 0);
            }
        }
        if (rows & 3) {
            aoffsets[0] = aoffset;
            for (int it = 1; it < 4; ++it)
                aoffsets[it] = aoffsets[it-1] + lda;
            if (cols == 4) {
                switch (rows) {
                    case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
                    case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
                    case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
                            break;
                }
                vector_permute_store(c_arr, 2, vecOffset);
                for (int it = 0; it < 4; it++)
                    aoffsets[it] = aoffsets[it] + lda;
                vecOffset += 32;
            }
            i = (cols >> 3);
            if (i > 0) {
                do {
                    switch (rows) {
                        case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
                        case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
                        case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
                                break;
                    }
                    vector_permute_store(c_arr, 4, vecOffset);
                    for (int it = 0; it < 4; it++)
                        aoffsets[it] = aoffsets[it] + 8*lda;
                    vecOffset += 64;
                    i--;
                } while (i > 0);
            }
        }
    }

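    // Recursive tile dispatch: run the largest kernel that fits the remaining
    // m×n region, then recurse on the two leftover edge strips.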
    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t mc, nc, mp, np;
        int m_rem = MIN(m - m0, 8);
        int n_rem = MIN(n - n0, 8);

        if (m_rem >= 8 && n_rem >= 8) {
            mc = 8;
            nc = 8;
            gemm<8,8>(m0, m, n0, n);
        } else if (m_rem >= 4 && n_rem >= 8) {
            mc = 4;
            nc = 8;
            gemm<4,8>(m0, m, n0, n);
        } else if (m_rem >= 8 && n_rem >= 4) {
            mc = 8;
            nc = 4;
            gemm<8,4>(m0, m, n0, n);
        } else if ((m_rem < 4) && (n_rem >= 8)) {
            nc = 8;
            switch (m_rem) {
                case 1:
                    mc = 1;
                    gemm_Mx8<1>(m0, m, n0, n);
                    break;
                case 2:
                    mc = 2;
                    gemm_Mx8<2>(m0, m, n0, n);
                    break;
                case 3:
                    mc = 3;
                    gemm_Mx8<3>(m0, m, n0, n);
                    break;
                default:
                    return;
            }
        } else if (m_rem >= 4 && n_rem >= 4) {
            mc = 4;
            nc = 4;
            gemm_small<4, 4>(m0, m, n0, n);
        } else if ((m_rem > 4) && (n_rem < 4)) {
            mc = 4;
            switch (n_rem) {
                case 1:
                    nc = 1;
                    gemm_small<4, 1>(m0, m, n0, n);
                    break;
                case 2:
                    nc = 2;
                    gemm_small<4, 2>(m0, m, n0, n);
                    break;
                case 3:
                    nc = 3;
                    gemm_small<4, 3>(m0, m, n0, n);
                    break;
                default:
                    return;
            }
        } else {
            switch ((m_rem << 4) | n_rem) {
                case 0x43:
                    mc = 4;
                    nc = 3;
                    gemm_small<4, 3>(m0, m, n0, n);
                    break;
                case 0x42:
                    mc = 4;
                    nc = 2;
                    gemm_small<4, 2>(m0, m, n0, n);
                    break;
                case 0x41:
                    mc = 4;
                    nc = 1;
                    gemm_small<4, 1>(m0, m, n0, n);
                    break;
                case 0x34:
                    mc = 3;
                    nc = 4;
                    gemm_small<3, 4>(m0, m, n0, n);
                    break;
                case 0x33:
                    mc = 3;
                    nc = 3;
                    gemm_small<3, 3>(m0, m, n0, n);
                    break;
                case 0x32:
                    mc = 3;
                    nc = 2;
                    gemm_small<3, 2>(m0, m, n0, n);
                    break;
                case 0x31:
                    mc = 3;
                    nc = 1;
                    gemm_small<3, 1>(m0, m, n0, n);
                    break;
                case 0x24:
                    mc = 2;
                    nc = 4;
                    gemm_small<2, 4>(m0, m, n0, n);
                    break;
                case 0x23:
                    mc = 2;
                    nc = 3;
                    gemm_small<2, 3>(m0, m, n0, n);
                    break;
                case 0x22:
                    mc = 2;
                    nc = 2;
                    gemm_small<2, 2>(m0, m, n0, n);
                    break;
                case 0x21:
                    mc = 2;
                    nc = 1;
                    gemm_small<2, 1>(m0, m, n0, n);
                    break;
                case 0x14:
                    mc = 1;
                    nc = 4;
                    gemm_small<1, 4>(m0, m, n0, n);
                    break;
                case 0x13:
                    mc = 1;
                    nc = 3;
                    gemm_small<1, 3>(m0, m, n0, n);
                    break;
                case 0x12:
                    mc = 1;
                    nc = 2;
                    gemm_small<1, 2>(m0, m, n0, n);
                    break;
                case 0x11:
                    mc = 1;
                    nc = 1;
                    gemm_small<1, 1>(m0, m, n0, n);
                    break;
                default:
                    return;
            }
        }
        mp = m0 + (m - m0) / mc * mc;
        np = n0 + (n - n0) / nc * nc;
        mnpack(mp, m, n0, np);
        mnpack(m0, m, np, n);
    }

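    // 4x8 macro-kernel: two 4x4 accumulators cover a 4-row by 8-column output
    // tile, with A and B panels repacked every 8 steps along k.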
    void KERNEL_4x8(int64_t ii, int64_t jj) {
        vec_t vec_A[4], vec_B[8], vec_C[4];
        acc_t acc_0, acc_1;
        __builtin_mma_xxsetaccz(&acc_0);
        __builtin_mma_xxsetaccz(&acc_1);
        for (int l = 0; l < k; l += 8) {
            packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
            packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
            for (int x = 0; x < 4; x++) {
                mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
                mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
            }
        }
        SAVE_ACC(&acc_0, ii, jj);
        SAVE_ACC(&acc_1, ii, jj+4);
    }

    void KERNEL_8x4(int64_t ii, int64_t jj) {
        vec_t vec_A[8], vec_B[4], vec_C[4];
        acc_t acc_0, acc_1;
        __builtin_mma_xxsetaccz(&acc_0);
        __builtin_mma_xxsetaccz(&acc_1);
        for (int l = 0; l < k; l += 8) {
            packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
            packNormal((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B);
            for (int x = 0; x < 4; x++) {
                mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
                mma_instr<TA>::outer_product(&acc_1, vec_A[x+4], vec_B[x]);
            }
        }
        SAVE_ACC(&acc_0, ii, jj);
        SAVE_ACC(&acc_1, ii+4, jj);
    }

    void KERNEL_8x8(int64_t ii, int64_t jj) {
        vec_t vec_A[8], vec_B[8], vec_C[4];
        acc_t acc_0, acc_1, acc_2, acc_3;
        __builtin_mma_xxsetaccz(&acc_0);
        __builtin_mma_xxsetaccz(&acc_1);
        __builtin_mma_xxsetaccz(&acc_2);
        __builtin_mma_xxsetaccz(&acc_3);
        for (int l = 0; l < k; l += 8) {
            packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
            packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
            for (int x = 0; x < 4; x++) {
                mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
                mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
                mma_instr<TA>::outer_product(&acc_2, vec_A[x+4], vec_B[x]);
                mma_instr<TA>::outer_product(&acc_3, vec_A[x+4], vec_B[x+4]);
            }
        }

        SAVE_ACC(&acc_0, ii, jj);
        SAVE_ACC(&acc_1, ii, jj+4);
        SAVE_ACC(&acc_2, ii+4, jj);
        SAVE_ACC(&acc_3, ii+4, jj+4);
    }

    template<int RM, int RN>
    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            vec_t vec_C[4];
            acc_t acc_0;
            __builtin_mma_xxsetaccz(&acc_0);
            vec_t vec_A[2], vec_B[2];
            for (int l = 0; l < k; l += 4) {
                packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);
                packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);
                for (int x = 0; x < 2; x++) {
                    mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
                }
            }
            __builtin_mma_disassemble_acc(vec_C, &acc_0);
            for (int I = 0; I < RM; I++) {
                for (int J = 0; J < RN; J++) {
                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
                }
            }
        }
    }

    template<int RM>
    void gemm_Mx8(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int RN = 8;
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            vec_t vec_C[4];
            acc_t acc_0, acc_1;
            __builtin_mma_xxsetaccz(&acc_0);
            __builtin_mma_xxsetaccz(&acc_1);
            vec_t vec_A[4], vec_B[8];
            for (int l = 0; l < k; l += 8) {
                packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);
                packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);
                for (int x = 0; x < 4; x++) {
                    mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
                    mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
                }
            }
            __builtin_mma_disassemble_acc(vec_C, &acc_0);
            for (int I = 0; I < RM; I++) {
                for (int J = 0; J < 4; J++) {
                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
                }
            }
            __builtin_mma_disassemble_acc(vec_C, &acc_1);
            for (int I = 0; I < RM; I++) {
                for (int J = 0; J < 4; J++) {
                    *((TC*)(C+ii+((jj+4+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
                }
            }
        }
    }

    template<int RM, int RN>
    inline void kernel(int64_t ii, int64_t jj) {
        if constexpr(RM == 4 && RN == 8) {
            KERNEL_4x8(ii, jj);
        } else if constexpr(RM == 8 && RN == 8) {
            KERNEL_8x8(ii, jj);
        } else if constexpr(RM == 8 && RN == 4) {
            KERNEL_8x4(ii, jj);
        } else {
            assert(false && "RN/RM values not supported");
        }
    }

    template <int RM, int RN>
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            kernel<RM, RN>(ii, jj);
        }
    }

    const TA *const A;
    const TB *const B;
    TC *const C;
    const int64_t k;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
    const int ith;
    const int nth;
};

template <typename TA>
tinyBLAS_Q0_PPC<TA>::tinyBLAS_Q0_PPC(int64_t k,
                                     const TA *A, int64_t lda,
                                     const block_q8_0 *B, int64_t ldb,
                                     float *C, int64_t ldc,
                                     int ith, int nth)
    : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
    kc = 64;
}

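// Chooses the cache-blocked path only when m, n, and k divide the block sizes
// exactly; otherwise falls back to the recursive dispatch in mnpack().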
template<typename TA>
void tinyBLAS_Q0_PPC<TA>::matmul(int64_t m, int64_t n) {
    int mc = 64; int nc = 64;
    if (n % 8 == 0 && n < nc) {
        nc = n;
        mc = 32;
        kc = 32;
    }
    const bool is_aligned = ((m & (mc - 1)) == 0) && ((n & (nc - 1)) == 0) && ((k & (kc - 1)) == 0);
    if (is_aligned) {
        this->matmul_tiled_q0(m, n, mc, nc, kc);
    } else {
        mnpack(0, m, 0, n);
    }
}

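// Unpacks q4_0 blocks: expands the 4-bit quants of up to 8 rows into signed
// bytes in MMA operand order (via process_q4_elements, defined earlier in
// this file) and records each row's quant sum in comparray for the bias
// correction applied later in compute().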
template<typename TA>
template<int size>
void tinyBLAS_Q0_PPC<TA>::packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
    int64_t i, j;
    TA *aoffset = NULL;
    int8_t *vecOffset = NULL;
    TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
    TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
    vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
    vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
    aoffset = const_cast<TA*>(a);
    vecOffset = vec;
    j = (rows >> 3);
    if (j > 0) {
        do {
            aoffset1 = aoffset;
            aoffset2 = aoffset1 + lda;
            aoffset3 = aoffset2 + lda;
            aoffset4 = aoffset3 + lda;
            aoffset5 = aoffset4 + lda;
            aoffset6 = aoffset5 + lda;
            aoffset7 = aoffset6 + lda;
            aoffset8 = aoffset7 + lda;
            aoffset += 8 * lda;
            i = (cols >> 2);
            if (i > 0) {
                do {
                    c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
                    c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
                    c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
                    c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
                    c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset5->qs));
                    c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset6->qs));
                    c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset7->qs));
                    c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset8->qs));

                    process_q4_elements(c1, &comparray[0]);
                    process_q4_elements(c2, &comparray[1]);
                    process_q4_elements(c3, &comparray[2]);
                    process_q4_elements(c4, &comparray[3]);
                    process_q4_elements(c5, &comparray[4]);
                    process_q4_elements(c6, &comparray[5]);
                    process_q4_elements(c7, &comparray[6]);
                    process_q4_elements(c8, &comparray[7]);
                    vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
                    vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
                    vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
                    vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
                    aoffset1 += lda;
                    aoffset2 += lda;
                    aoffset3 += lda;
                    aoffset4 += lda;
                    aoffset5 += lda;
                    aoffset6 += lda;
                    aoffset7 += lda;
                    aoffset8 += lda;
                    vecOffset += 256;
                    i--;
                } while (i > 0);
            }
            j--;
        } while (j > 0);
    }

    if (rows & 4) {
        aoffset1 = aoffset;
        aoffset2 = aoffset1 + lda;
        aoffset3 = aoffset2 + lda;
        aoffset4 = aoffset3 + lda;
        aoffset += 4 * lda;
        i = (cols >> 2);
        if (i > 0) {
            do {
                c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
                c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
                c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
                c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));

                process_q4_elements(c1, &comparray[0]);
                process_q4_elements(c2, &comparray[1]);
                process_q4_elements(c3, &comparray[2]);
                process_q4_elements(c4, &comparray[3]);
                vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
                vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
                aoffset1 += lda;
                aoffset2 += lda;
                aoffset3 += lda;
                aoffset4 += lda;
                vecOffset += 128;
                i--;
            } while (i > 0);
        }
    }

    if (rows & 3) {
        aoffset1 = aoffset;
        aoffset2 = aoffset1 + lda;
        aoffset3 = aoffset2 + lda;
        i = (cols >> 2);
        if (i > 0) {
            do {
                switch (rows) {
                    case 3: c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
                    case 2: c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
                    case 1: c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
                            break;
                }
                process_q4_elements(c1, &comparray[0]);
                process_q4_elements(c2, &comparray[1]);
                process_q4_elements(c3, &comparray[2]);
                process_q4_elements(c4, &comparray[3]);
                vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
                vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
                aoffset1 += lda;
                aoffset2 += lda;
                aoffset3 += lda;
                vecOffset += 128;
                i--;
            } while (i > 0);
        }
    }
}

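// Packs q8_0 rows for the MMA kernels: lxvp loads each row's 32 quants as a
// vector pair and vector_permute_store interleaves four rows at a time. When
// flip is set the values are biased into unsigned range for the signed x
// unsigned xvi8ger4pp update; the bias is cancelled later via comparray.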
template<typename TA>
template<typename VA, typename VB>
void tinyBLAS_Q0_PPC<TA>::packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
    int64_t i, j;
    block_q8_0 *aoffset = NULL;
    VA *vecOffset = NULL;
    block_q8_0* aoffsets[8];
    __vector_pair arr[8];
    VB c[8][2] = {0};
    VB c1[8] = {0}; VB c2[8] = {0};
    aoffset = const_cast<block_q8_0*>(a);
    vecOffset = vec;
    j = (rows >> 3);
    if (j > 0) {
        do {
            aoffsets[0] = aoffset;
            for (int it = 1; it < 8; it++)
                aoffsets[it] = aoffsets[it-1] + lda;
            aoffset += 8 * lda;

            i = (cols >> 3);
            if (i > 0) {
                do {
                    for (int it = 0; it < 8; it++) {
                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);
                        c1[it] = c[it][0];
                        c2[it] = c[it][1];
                    }
                    vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
                    vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
                    vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
                    vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
                    for (int it = 0; it < 8; it++)
                        aoffsets[it] += lda;
                    vecOffset += 256;
                    i--;
                } while (i > 0);
            }
            j--;
        } while (j > 0);
    }
    if (rows & 4) {
        aoffsets[0] = aoffset;
        for (int it = 1; it < 4; it++)
            aoffsets[it] = aoffsets[it-1] + lda;
        aoffset += 4 * lda;
        i = (cols >> 3);
        if (i > 0) {
            do {
                for (int it = 0; it < 4; it++) {
                    arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
                    __builtin_vsx_disassemble_pair(c[it], &arr[it]);
                    c1[it] = c[it][0];
                    c2[it] = c[it][1];
                }
                vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
                vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
                for (int it = 0; it < 4; it++) {
                    aoffsets[it] += lda;
                }
                vecOffset += 128;
                i--;
            } while (i > 0);
        }
    }

    if (rows & 3) {
        aoffsets[0] = aoffset;
        for (int it = 1; it < 3; it++)
            aoffsets[it] = aoffsets[it-1] + lda;
        i = (cols >> 3);
        if (i > 0) {
            do {
                switch (rows) {
                    case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs);
                            __builtin_vsx_disassemble_pair(c[2], &arr[2]);
                            c1[2] = c[2][0]; c2[2] = c[2][1];
                    case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs);
                            __builtin_vsx_disassemble_pair(c[1], &arr[1]);
                            c1[1] = c[1][0]; c2[1] = c[1][1];
                    case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs);
                            __builtin_vsx_disassemble_pair(c[0], &arr[0]);
                            c1[0] = c[0][0]; c2[0] = c[0][1];
                            break;
                }
                vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
                vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
                for (int it = 0; it < 3; it++)
                    aoffsets[it] += lda;
                vecOffset += 128;
                i--;
            } while (i > 0);
        }
    }
}

template<typename TA>
void tinyBLAS_Q0_PPC<TA>::mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
    int m_rem = MIN(m - m0, 16);
    int n_rem = MIN(n - n0, 16);

    int mc = 0, nc = 0;

    if (m_rem >= 8 && n_rem >= 8) {
        mc = 8;
        nc = 8;
        gemm<8, 8>(m0, m, n0, n);
    } else if (m_rem >= 4 && n_rem >= 8) {
        mc = 4;
        nc = 8;
        gemm<4, 8>(m0, m, n0, n);
    } else if (m_rem >= 8 && n_rem >= 4) {
        mc = 8;
        nc = 4;
        gemm<8, 4>(m0, m, n0, n);
    } else if (m_rem >= 4 && n_rem >= 4) {
        mc = 4;
        nc = 4;
        gemm_small(m0, m, n0, n, mc, nc);
    } else {
        mc = (m_rem >= 4) ? 4 : m_rem;
        nc = (n_rem >= 4) ? 4 : n_rem;
        if (mc == 0 || nc == 0)
            return;
        gemm_small(m0, m, n0, n, mc, nc);
    }

    int64_t mp = m0 + ((m - m0) / mc) * mc;
    int64_t np = n0 + ((n - n0) / nc) * nc;
    mnpack(mp, m, n0, np);
    mnpack(m0, m, np, n);
}

template<typename TA>
void tinyBLAS_Q0_PPC<TA>::KERNEL_4x8(int64_t ii, int64_t jj) {
    vec_t vec_A[8], vec_B[16] = {0};
    acc_t acc_0, acc_1;
    std::array<int, 4> comparray {};
    vector float fin_res[8] = {0};
    vector float vs[8] = {0};
    bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
    for (int l = 0; l < k; l++) {
        __builtin_mma_xxsetaccz(&acc_0);
        __builtin_mma_xxsetaccz(&acc_1);
        if (isAblock_q4) {
            packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
        } else {
            packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
        }
        packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
        for (int x = 0; x < 8; x++) {
            __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
            __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]);
        }
        for (int I = 0; I < 4; I++) {
            for (int J = 0; J < 4; J++) {
                *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
                *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
            }
        }
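        // q8_0 A blocks carry no precomputed sums, so accumulate each row's
        // quant sum here; compute() uses comparray to remove the bias that
        // packing introduced into B.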
        if (!isAblock_q4) {
            auto aoffset = A+(ii*lda)+l;
            for (int i = 0; i < 4; i++) {
                comparray[i] = 0;
                int ca = 0;
                auto *at = aoffset->qs;
                for (int j = 0; j < 32; j++)
                    ca += (int)*at++;
                comparray[i] = ca;
                aoffset += lda;
            }
        }
        compute(&acc_0, 0, 0, comparray, vs, fin_res);
        compute(&acc_1, 0, 4, comparray, vs, fin_res);
    }
    save_res(ii, jj, 0, fin_res);
    save_res(ii, jj+4, 4, fin_res);
}

template<typename TA>
void tinyBLAS_Q0_PPC<TA>::KERNEL_8x4(int64_t ii, int64_t jj) {
    vec_t vec_A[16], vec_B[8] = {0};
    acc_t acc_0, acc_1;
    std::array<int, 8> comparray {};
    vector float fin_res[8] = {0};
    vector float vs[8] = {0};
    bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
    for (int l = 0; l < k; l++) {
        __builtin_mma_xxsetaccz(&acc_0);
        __builtin_mma_xxsetaccz(&acc_1);
        if (isAblock_q4) {
            packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
        } else {
            packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
        }
        packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
        for (int x = 0; x < 8; x++) {
            __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
            __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
        }
        for (int I = 0; I < 8; I++) {
            for (int J = 0; J < 4; J++) {
                *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
            }
        }
        if (!isAblock_q4) {
            auto aoffset = A+(ii*lda)+l;
            for (int i = 0; i < 8; i++) {
                comparray[i] = 0;
                int ca = 0;
                auto *at = aoffset->qs;
                for (int j = 0; j < 32; j++)
                    ca += (int)*at++;
                comparray[i] = ca;
                aoffset += lda;
            }
        }
        compute(&acc_0, 0, 0, comparray, vs, fin_res);
        compute(&acc_1, 4, 4, comparray, vs, fin_res);
    }
    save_res(ii, jj, 0, fin_res);
    save_res(ii+4, jj, 4, fin_res);
}

template<typename TA>
void tinyBLAS_Q0_PPC<TA>::KERNEL_8x8(int64_t ii, int64_t jj) {
    vec_t vec_A[16], vec_B[16] = {0};
    acc_t acc_0, acc_1, acc_2, acc_3;
    std::array<int, 8> comparray {};
    vector float fin_res[16] = {0};
    vector float vs[16] = {0};
    bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
    for (int l = 0; l < k; l++) {
        __builtin_mma_xxsetaccz(&acc_0);
        __builtin_mma_xxsetaccz(&acc_1);
        __builtin_mma_xxsetaccz(&acc_2);
        __builtin_mma_xxsetaccz(&acc_3);
        if (isAblock_q4) {
            packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
        } else {
            packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
        }
        packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
        for (int x = 0; x < 8; x++) {
            __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
            __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
            __builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]);
            __builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]);
        }
        for (int I = 0; I < 8; I++) {
            for (int J = 0; J < 4; J++) {
                *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
                *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
            }
        }
        if (!isAblock_q4) {
            auto aoffset = A+(ii*lda)+l;
            for (int i = 0; i < 8; i++) {
                comparray[i] = 0;
                int ca = 0;
                auto *at = aoffset->qs;
                for (int j = 0; j < 32; j++)
                    ca += (int)*at++;
                comparray[i] = ca;
                aoffset += lda;
            }
        }
        compute(&acc_0, 0, 0, comparray, vs, fin_res);
        compute(&acc_1, 4, 4, comparray, vs, fin_res);
        compute(&acc_2, 0, 8, comparray, vs, fin_res);
        compute(&acc_3, 4, 12, comparray, vs, fin_res);
    }
    save_res(ii, jj, 0, fin_res);
    save_res(ii+4, jj, 4, fin_res);
    save_res(ii, jj+4, 8, fin_res);
    save_res(ii+4, jj+4, 12, fin_res);
}

template<typename TA>
void tinyBLAS_Q0_PPC<TA>::gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
    int64_t ytiles = (m - m0) / RM;
    int64_t xtiles = (n - n0) / RN;
    int64_t tiles = xtiles * ytiles;
    int64_t duty = (tiles + nth - 1) / nth;
    int64_t start = duty * ith;
    int64_t end = start + duty;
    vec_t vec_A[8] = {0}, vec_B[8] = {0};
    vector signed int vec_C[4];
    acc_t acc_0;
    bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;

    if (end > tiles)
        end = tiles;
    for (int64_t job = start; job < end; ++job) {
        int64_t ii = m0 + job / xtiles * RM;
        int64_t jj = n0 + job % xtiles * RN;
        std::array<int, 4> comparray{};
        vector float res[4] = {0};
        vector float fin_res[4] = {0};
        vector float vs[4] = {0};
        vector float CA[4] = {0};
        __builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value
        __builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value
        for (int l = 0; l < k; l++) {
            __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
            __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
            __builtin_mma_xxsetaccz(&acc_0);
            if (isAblock_q4) {
                packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
            } else {
                packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
            }
            packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
            for (int x = 0; x < 8; x += 4) {
                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]);
                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]);
                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]);
            }
            for (int I = 0; I < RM; I++) {
                for (int J = 0; J < RN; J++) {
                    *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
                }
            }
            __builtin_mma_disassemble_acc(vec_C, &acc_0);
            if (!isAblock_q4) {
                auto aoffset = A+(ii*lda)+l;
                for (int i = 0; i < RM; i++) {
                    comparray[i] = 0;
                    int ca = 0;
                    auto *at = aoffset->qs;
                    for (int j = 0; j < 32; j++)
                        ca += (int)*at++;
                    comparray[i] = ca;
                    aoffset += lda;
                }
            }
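            // The packed B values were biased into unsigned range, so the
            // accumulator holds A·(B+128): subtract 128*sum(A row), then
            // scale by the combined d_A*d_B deltas and accumulate.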
            for (int i = 0; i < RM; i++) {
                CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0));
                res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
                fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]);
            }
        }
        save_res(ii, jj, 0, fin_res, RM, RN);
    }
}

template<typename TA>
template <int RM, int RN>
NOINLINE void tinyBLAS_Q0_PPC<TA>::gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
    int64_t ytiles = (m - m0) / RM;
    int64_t xtiles = (n - n0) / RN;
    int64_t tiles = xtiles * ytiles;
    int64_t duty = (tiles + nth - 1) / nth;
    int64_t start = duty * ith;
    int64_t end = start + duty;
    if (end > tiles)
        end = tiles;
    for (int64_t job = start; job < end; ++job) {
        int64_t ii = m0 + job / xtiles * RM;
        int64_t jj = n0 + job % xtiles * RN;
        this->kernel<RM, RN>(ii, jj);
    }
}

template class tinyBLAS_Q0_PPC<block_q4_0>;
template class tinyBLAS_Q0_PPC<block_q8_0>;

class tinyBLAS_PPC {
  public:
    tinyBLAS_PPC(int64_t k,
                 const float * A, int64_t lda,
                 const float * B, int64_t ldb,
                 float * C, int64_t ldc,
                 int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
    }

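    // Prefer the cache-blocked path when every dimension divides the 256-wide
    // blocks exactly; otherwise use the recursive kernel dispatch.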
    void matmul(int64_t m, int64_t n) {
        int64_t mc = 256; int64_t nc = 256; int64_t kc = 256;
        if (m % mc == 0 && n % nc == 0 && k % kc == 0) {
            matmul_tiled(m, n, mc, nc, kc);
        } else {
            mnpack(0, m, 0, n);
        }
    }

  private:

    inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
        vec_t vec_C[4];
        __builtin_mma_disassemble_acc(vec_C, ACC);
        for (int I = 0; I < 4; I++) {
            for (int J = 0; J < 4; J++) {
                *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
            }
        }
    }

    inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
        vec_t vec_C[4];
        __builtin_mma_disassemble_acc(vec_C, ACC);
        for (int I = 0; I < 4; I++) {
            for (int J = 0; J < 4; J++) {
                float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
                *c_ptr += *((float *)&vec_C[I]+J);
            }
        }
    }

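    // 4x4 float transpose helpers: merge high/low element pairs, then swap
    // doublewords with xxpermdi, storing the transposed rows contiguously.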
    inline void vector_permute_store_4(vector float * src, float * vecOffset) {
        vector float t1, t2, t3, t4, t5, t6, t7, t8;
        t1 = vec_mergeh(src[0], src[1]);
        t2 = vec_mergeh(src[2], src[3]);
        t3 = vec_mergel(src[0], src[1]);
        t4 = vec_mergel(src[2], src[3]);

        t5 = vec_xxpermdi(t1, t2, 0);
        t6 = vec_xxpermdi(t1, t2, 3);
        t7 = vec_xxpermdi(t3, t4, 0);
        t8 = vec_xxpermdi(t3, t4, 3);

        vec_xst(t5, 0, vecOffset);
        vec_xst(t6, 0, vecOffset + 4);
        vec_xst(t7, 0, vecOffset + 8);
        vec_xst(t8, 0, vecOffset + 12);
    }

    inline void vector_permute_store_8(vector float * src, float * vecOffset) {
        vector float t1, t2, t3, t4, t5, t6, t7, t8;
        t1 = vec_mergeh(src[0], src[1]);
        t2 = vec_mergeh(src[2], src[3]);
        t3 = vec_mergeh(src[4], src[5]);
        t4 = vec_mergeh(src[6], src[7]);

        t5 = vec_xxpermdi(t1, t2, 0);
        t6 = vec_xxpermdi(t3, t4, 0);
        t7 = vec_xxpermdi(t1, t2, 3);
        t8 = vec_xxpermdi(t3, t4, 3);

        vec_xst(t5, 0, vecOffset);
        vec_xst(t6, 0, vecOffset + 4);
        vec_xst(t7, 0, vecOffset + 8);
        vec_xst(t8, 0, vecOffset + 12);

        t1 = vec_mergel(src[0], src[1]);
        t2 = vec_mergel(src[2], src[3]);
        t3 = vec_mergel(src[4], src[5]);
        t4 = vec_mergel(src[6], src[7]);

        t5 = vec_xxpermdi(t1, t2, 0);
        t6 = vec_xxpermdi(t3, t4, 0);
        t7 = vec_xxpermdi(t1, t2, 3);
        t8 = vec_xxpermdi(t3, t4, 3);

        vec_xst(t5, 0, vecOffset + 16);
        vec_xst(t6, 0, vecOffset + 20);
        vec_xst(t7, 0, vecOffset + 24);
        vec_xst(t8, 0, vecOffset + 28);
    }

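    // Transposes a rows×cols float panel into the 4x4-block operand layout
    // consumed by the xvf32gerpp kernels.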
    void packTranspose(const float * a, int64_t lda, int rows, int cols, float * vec) {
        int64_t i, j;
        float * aoffsets[8];
        float * aoffset = NULL, * boffset = NULL;
        __vector_pair arr[8];
        vector float c[8][2] = {0};
        vector float c1[8] = {0};
        vector float c2[8] = {0};
        aoffset = const_cast<float *>(a);
        boffset = vec;
        j = (rows >> 3);
        if (j > 0) {
            do {
                aoffsets[0] = aoffset;
                for (int it = 1; it < 8; it++)
                    aoffsets[it] = aoffsets[it-1] + lda;
                aoffset += 8 * lda;
                i = (cols >> 3);
                if (i > 0) {
                    do {
                        for (int it = 0; it < 8; it++) {
                            arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
                            __builtin_vsx_disassemble_pair(c[it], &arr[it]);
                            c1[it] = c[it][0];
                            c2[it] = c[it][1];
                        }

                        vector_permute_store_8(c1, boffset);
                        vector_permute_store_8(c2, boffset + 32);
                        boffset += 64;
                        i--;
                        if (i > 0) {
                            for (int it = 0; it < 8; it++) {
                                aoffsets[it] = aoffsets[it] + 8;
                            }
                        }
                    } while (i > 0);
                }
                if (cols & 4) {
                    for (int it = 0; it < 8; it++)
                        c1[it] = vec_xl(0, aoffsets[it]);
                    vector_permute_store_8(c1, boffset);
                }
                j--;
            } while (j > 0);
        }

        if (rows & 4) {
            aoffsets[0] = aoffset;
            for (int it = 1; it < 4; it++)
                aoffsets[it] = aoffsets[it-1] + lda;
            aoffset += 4 * lda;
            i = (cols >> 3);
            if (i > 0) {
                do {
                    for (int it = 0; it < 4; it++) {
                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);
                        c1[it] = c[it][0];
                        c2[it] = c[it][1];
                    }
                    vector_permute_store_4(c1, boffset);
                    vector_permute_store_4(c2, boffset + 16);
                    for (int it = 0; it < 4; it++)
                        aoffsets[it] += 8 * lda;
                    boffset += 32;
                    i--;
                } while (i > 0);
            }

            if (cols & 4) {
                for (int it = 0; it < 4; it++)
                    c1[it] = vec_xl(0, aoffsets[it]);
                vector_permute_store_4(c1, boffset);
            }
        }
        if (rows & 3) {
            aoffsets[0] = aoffset;
            for (int it = 1; it < 3; it++)
                aoffsets[it] = aoffsets[it-1] + lda;
            if (cols & 4) {
                for (int it = 0; it < 3; it++)
                    c1[it] = vec_xl(0, aoffsets[it]);
                vector_permute_store_4(c1, boffset);
            }
        }
    }

    void KERNEL_4x4(int64_t ii, int64_t jj) {
        vec_t vec_A[4], vec_B[4];
        acc_t acc_0;
        __builtin_mma_xxsetaccz(&acc_0);
        for (int l = 0; l < k; l += 4) {
            packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
            packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
        }
        save_acc(&acc_0, ii, jj);
    }

    void KERNEL_4x8(int64_t ii, int64_t jj) {
        vec_t vec_A[4], vec_B[8];
        acc_t acc_0, acc_1;
        __builtin_mma_xxsetaccz(&acc_0);
        __builtin_mma_xxsetaccz(&acc_1);
        for (int64_t l = 0; l < k; l += 4) {
            packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
            packTranspose(B + (jj * ldb) + l, ldb, 8, 4, (float *)vec_B);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
            __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
            __builtin_mma_xvf32gerpp(&acc_1, vec_A[1], (vec_t)vec_B[3]);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], (vec_t)vec_B[4]);
            __builtin_mma_xvf32gerpp(&acc_1, vec_A[2], (vec_t)vec_B[5]);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
            __builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
        }
        save_acc(&acc_0, ii, jj);
        save_acc(&acc_1, ii, jj + 4);
    }

    void KERNEL_8x4(int64_t ii, int64_t jj) {
        vec_t vec_A[8], vec_B[4];
        acc_t acc_0, acc_1;
        __builtin_mma_xxsetaccz(&acc_0);
        __builtin_mma_xxsetaccz(&acc_1);
        for (int64_t l = 0; l < k; l += 4) {
            packTranspose(A + (ii * lda) + l, lda, 8, 4, (float *)vec_A);
            packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[3], vec_B[1]);
            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[4], vec_B[2]);
            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[5], vec_B[2]);
            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
        }
        save_acc(&acc_0, ii, jj);
        save_acc(&acc_1, ii + 4, jj);
    }

    void KERNEL_8x8(int64_t ii, int64_t jj) {
        vec_t vec_A[16], vec_B[16];
        acc_t acc_0, acc_1, acc_2, acc_3;
        __builtin_mma_xxsetaccz(&acc_0);
        __builtin_mma_xxsetaccz(&acc_1);
        __builtin_mma_xxsetaccz(&acc_2);
        __builtin_mma_xxsetaccz(&acc_3);
        for (int l = 0; l < k; l += 8) {
            packTranspose(A + (ii * lda) + l, lda, 8, 8, (float *)vec_A);
            packTranspose(B + (jj * ldb) + l, ldb, 8, 8, (float *)vec_B);
            for (int x = 0; x < 16; x += 2) {
                __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
                __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x + 1]);
                __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x + 1], vec_B[x]);
                __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x + 1], vec_B[x + 1]);
            }
        }
        save_acc(&acc_0, ii, jj);
        save_acc(&acc_1, ii, jj + 4);
        save_acc(&acc_2, ii + 4, jj);
        save_acc(&acc_3, ii + 4, jj + 4);
    }

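    // Rank-2 updates for one 16x8 micro-tile: eight 4x4 accumulators tile the
    // 16 rows (vec_A0 = rows 0-7, vec_A1 = rows 8-15) against 8 columns of B.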
    inline void MMA_16x8(vec_t * vec_A0, vec_t * vec_A1, vec_t * vec_B, acc_t * acc) {
        for (int x = 0; x < 16; x += 2) {
            __builtin_mma_xvf32gerpp(&acc[0], vec_A0[x + 0], vec_B[x]);
            __builtin_mma_xvf32gerpp(&acc[1], vec_A0[x + 0], vec_B[x + 1]);
            __builtin_mma_xvf32gerpp(&acc[2], vec_A0[x + 1], vec_B[x]);
            __builtin_mma_xvf32gerpp(&acc[3], vec_A0[x + 1], vec_B[x + 1]);
            __builtin_mma_xvf32gerpp(&acc[4], vec_A1[x + 0], vec_B[x]);
            __builtin_mma_xvf32gerpp(&acc[5], vec_A1[x + 0], vec_B[x + 1]);
            __builtin_mma_xvf32gerpp(&acc[6], vec_A1[x + 1], vec_B[x]);
            __builtin_mma_xvf32gerpp(&acc[7], vec_A1[x + 1], vec_B[x + 1]);
        }
    }

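    // Macro-kernel over packed panels: A is stored in 8-row groups, so each
    // 16-row micro-tile reads two adjacent groups (A0/A1); kk selects store
    // versus accumulate for the k-blocked outer loop.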
    void KERNEL(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, vec_t * vec_A, vec_t * vec_B, int64_t kk) {
        for (int64_t i = 0; i < mc; i += 16) {
            int A_base_addr = (mc / 8) * (i / 8) * 16;
            for (int64_t j = 0; j < nc; j += 8) {
                int B_base_addr = (nc / 8) * (j / 8) * 16;
                acc_t acc[8];
                for (int x = 0; x < 8; x++)
                    __builtin_mma_xxsetaccz(&acc[x]);
                for (int64_t l = 0; l < kc; l += 8) {
                    int A0_block_idx = A_base_addr + (l / 8) * 16;
                    int A1_block_idx = A0_block_idx + (mc / 8) * 16;
                    int B_block_idx = B_base_addr + (l / 8) * 16;
                    vec_t * A0_block = &vec_A[A0_block_idx];
                    vec_t * A1_block = &vec_A[A1_block_idx];
                    vec_t * B_block = &vec_B[B_block_idx];
                    MMA_16x8(A0_block, A1_block, B_block, acc);
                }
                if (kk == 0) {
                    save_acc(&acc[0], ii + i, jj + j);
                    save_acc(&acc[1], ii + i, jj + j + 4);
                    save_acc(&acc[2], ii + i + 4, jj + j);
                    save_acc(&acc[3], ii + i + 4, jj + j + 4);
                    save_acc(&acc[4], ii + i + 8, jj + j);
                    save_acc(&acc[5], ii + i + 8, jj + j + 4);
                    save_acc(&acc[6], ii + i + 12, jj + j);
                    save_acc(&acc[7], ii + i + 12, jj + j + 4);
                } else {
                    add_save_acc(&acc[0], ii + i, jj + j);
                    add_save_acc(&acc[1], ii + i, jj + j + 4);
                    add_save_acc(&acc[2], ii + i + 4, jj + j);
                    add_save_acc(&acc[3], ii + i + 4, jj + j + 4);
                    add_save_acc(&acc[4], ii + i + 8, jj + j);
                    add_save_acc(&acc[5], ii + i + 8, jj + j + 4);
                    add_save_acc(&acc[6], ii + i + 12, jj + j);
                    add_save_acc(&acc[7], ii + i + 12, jj + j + 4);
                }
            }
        }
    }

    void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
        int64_t ytiles = m / mc;
        int64_t xtiles = n / nc;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles) {
            end = tiles;
        }
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = (job / xtiles) * mc;
            int64_t jj = (job % xtiles) * nc;
            for (int64_t kk = 0; kk < k; kk += kc) {
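                // Packed panels as variable-length arrays (a GNU extension):
                // kc*mc and kc*nc floats, four floats per vec_t.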
                vec_t A_pack[kc * mc / 4];
                vec_t B_pack[kc * nc / 4];
                packTranspose(A + (ii * lda) + kk, lda, kc, mc, (float *)A_pack);
                packTranspose(B + (jj * ldb) + kk, ldb, kc, nc, (float *)B_pack);
                KERNEL(ii, jj, mc, nc, kc, A_pack, B_pack, kk);
            }
        }
    }

    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int m_rem = MIN(m - m0, 8);
        int n_rem = MIN(n - n0, 8);
        int mc = 0, nc = 0;
        if (m_rem >= 8 && n_rem >= 8) {
            mc = 8;
            nc = 8;
            gemm<8, 8>(m0, m, n0, n);
        } else if (m_rem >= 4 && n_rem >= 8) {
            mc = 4;
            nc = 8;
            gemm<4, 8>(m0, m, n0, n);
        } else if (m_rem >= 8 && n_rem >= 4) {
            mc = 8;
            nc = 4;
            gemm<8, 4>(m0, m, n0, n);
        } else if (m_rem >= 4 && n_rem >= 4) {
            mc = 4;
            nc = 4;
            gemm<4, 4>(m0, m, n0, n);
        } else {
            mc = (m_rem >= 4) ? 4 : m_rem;
            nc = (n_rem >= 4) ? 4 : n_rem;
            if (mc == 0 || nc == 0)
                return;
            gemm_small(m0, m, n0, n, mc, nc);
        }
        int64_t mp = m0 + ((m - m0) / mc) * mc;
        int64_t np = n0 + ((n - n0) / nc) * nc;
        mnpack(mp, m, n0, np);
        mnpack(m0, m, np, n);
    }

    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            vec_t vec_C[4];
            acc_t acc_0;
            __builtin_mma_xxsetaccz(&acc_0);
            vec_t vec_A[4] = {0}, vec_B[4] = {0};
            for (int l = 0; l < k; l += 4) {
                /* The first two branches use 'GEMV forwarding': when one of
                 * the matrices has a single row or column, its elements are
                 * broadcast directly into vector registers instead of being
                 * run through the packing routine.
                 */
                if (RM == 1) {
                    float * a = const_cast<float *>(A + (ii) * lda + l);
                    packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
                    vec_A[0] = (vec_t)vec_xl(0, a);
                    vec_A[1] = (vec_t)vec_splats(*((float *)&vec_A+1));
                    vec_A[2] = (vec_t)vec_splats(*((float *)&vec_A+2));
                    vec_A[3] = (vec_t)vec_splats(*((float *)&vec_A+3));
                } else if (RN == 1) {
                    packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
                    float * b = const_cast<float *>(B + (jj) * ldb + l);
                    vec_B[0] = (vec_t)vec_xl(0, b);
                    vec_B[1] = (vec_t)vec_splats(*((float *)&vec_B+1));
                    vec_B[2] = (vec_t)vec_splats(*((float *)&vec_B+2));
                    vec_B[3] = (vec_t)vec_splats(*((float *)&vec_B+3));
                } else {
                    packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
                    packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
                }
                __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
                __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
                __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
                __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
            }
            __builtin_mma_disassemble_acc(vec_C, &acc_0);
            for (int I = 0; I < RM; I++) {
                for (int J = 0; J < RN; J++) {
                    *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
                }
            }
        }
    }

    template<int RM, int RN>
    inline void kernel(int64_t ii, int64_t jj) {
        if constexpr(RM == 4 && RN == 4) {
            KERNEL_4x4(ii, jj);
        } else if constexpr(RM == 4 && RN == 8) {
            KERNEL_4x8(ii, jj);
        } else if constexpr(RM == 8 && RN == 4) {
            KERNEL_8x4(ii, jj);
        } else if constexpr(RM == 8 && RN == 8) {
            KERNEL_8x8(ii, jj);
        } else {
            static_assert(false, "RN/RM values not supported");
        }
    }

    template <int RM, int RN>
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            kernel<RM, RN>(ii, jj);
        }
    }

    const float * const A;
    const float * const B;
    float * C;
    const int64_t k;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
    const int ith;
    const int nth;
};
#endif // __MMA__
} // namespace

/**
 * Performs optimized matrix multiplication on CPU.
 *
 * This subroutine may compute C = Aᵀ * B with column major ordering.
 * Despite its name, this isn't a generalized implementation. Work is
 * only performed when a handwritten kernel is available for the given
 * types and hardware. Otherwise the caller should fall back to a
 * general matmul routine.
 *
 * For example, for single-precision GEMM you can say
 *
 *     llamafile_sgemm(params, m, n, k, A, lda, B, ldb, C, ldc,
 *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
 *
 * where `params` supplies the thread id and thread count.
 *
 * @param params is the compute parameters; `params->ith` is the thread
 *               id (must be less than `params->nth`) and `params->nth`
 *               is the number of threads (must be greater than zero)
 * @param m is rows in `A` and `C`
 * @param n is cols in `B` and `C`
 * @param k is cols in `A` and rows in `B`
 * @param A is first input matrix (always transposed)
 * @param lda is row stride of `A`
 * @param B is second input matrix (never transposed)
 * @param ldb is row stride of `B`
 * @param C is the output matrix
 * @param ldc is row stride of `C`
 * @param Atype is GGML data type of `A`
 * @param Btype is GGML data type of `B`
 * @param Ctype is GGML data type of `C`
 * @return true if this function was able to service the matmul request
 */
bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
                     const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
                     int64_t ldc, int Atype, int Btype, int Ctype) {

    assert(m >= 0);
    assert(n >= 0);
    assert(k >= 0);
    assert(lda >= k);
    assert(ldb >= k);
    assert(ldc >= m);
    assert(params->nth > 0);
    assert(params->ith < params->nth);

    // only enable sgemm for prompt processing
#if !defined(__MMA__)
    if (n < 2)
        return false;
#endif

    if (Ctype != GGML_TYPE_F32)
        return false;

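    // Dispatch on the type of A, then pick a kernel for the current ISA;
    // every branch returns false when no handwritten kernel applies so the
    // caller can fall back to ggml's generic matmul.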
    switch (Atype) {

    case GGML_TYPE_F32: {
        if (Btype != GGML_TYPE_F32)
            return false;
#if defined(__AVX512F__)
        tinyBLAS<16, __m512, __m512, float, float, float> tb{ params,
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc};
        return tb.matmul(m, n);
#elif defined(__AVX__) || defined(__AVX2__)
        tinyBLAS<8, __m256, __m256, float, float, float> tb{ params,
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc};
        return tb.matmul(m, n);
#elif defined(__ARM_NEON)
        if (n < 4)
            return false;
        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc};
        return tb.matmul(m, n);
#elif defined(__VXE__) || defined(__VXE2__)
        if (n < 4)
            return false;
        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc};
        return tb.matmul(m, n);
#elif defined(__MMA__)
        if (k % 8)
            return false;
        tinyBLAS_PPC tb{
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#elif defined(__riscv_zvfh)
    #if LMUL == 1
        tinyBLAS_RVV<vfloat32m1_t, vfloat32m1_t, float, float, float> tb{ params,
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc};
    #elif LMUL == 2
        tinyBLAS_RVV<vfloat32m2_t, vfloat32m2_t, float, float, float> tb{ params,
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc};
    #else // LMUL == 4
        tinyBLAS_RVV<vfloat32m4_t, vfloat32m4_t, float, float, float> tb{ params,
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc};
    #endif
        return tb.matmul(m, n);
#else
        return false;
#endif
    }

    case GGML_TYPE_BF16: {
#if defined(__AVX512BF16__)
        if (Btype == GGML_TYPE_BF16) {
            tinyBLAS<32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
                (const ggml_bf16_t *)A, lda,
                (const ggml_bf16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif defined(__AVX512F__)
        if (Btype == GGML_TYPE_BF16) {
            tinyBLAS<16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
                (const ggml_bf16_t *)A, lda,
                (const ggml_bf16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif defined(__AVX2__)
        if (Btype == GGML_TYPE_BF16) {
            tinyBLAS<8, __m256, __m256, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
                (const ggml_bf16_t *)A, lda,
                (const ggml_bf16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif defined(__MMA__)
        if (k % 8) {
            return false;
        }

        if (Btype == GGML_TYPE_BF16) {
            tinyBLAS_HP16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
                (const ggml_bf16_t *)A, lda,
                (const ggml_bf16_t *)B, ldb,
                (float *)C, ldc,
                params->ith, params->nth };

            tb.matmul(m, n);
            return true;
        }
#elif defined(__riscv_zvfbfwma)
    #if LMUL == 1
        tinyBLAS_RVV<vfloat32m1_t, vbfloat16mf2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
            k, (const ggml_bf16_t *)A, lda,
            (const ggml_bf16_t *)B, ldb,
            (float *)C, ldc};
    #elif LMUL == 2
        tinyBLAS_RVV<vfloat32m2_t, vbfloat16m1_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
            k, (const ggml_bf16_t *)A, lda,
            (const ggml_bf16_t *)B, ldb,
            (float *)C, ldc};
    #else // LMUL == 4
        tinyBLAS_RVV<vfloat32m4_t, vbfloat16m2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
            k, (const ggml_bf16_t *)A, lda,
            (const ggml_bf16_t *)B, ldb,
            (float *)C, ldc};
    #endif
        return tb.matmul(m, n);
#endif
        return false;
    }

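    // fp16 inputs: prefer native fp16 arithmetic; otherwise widen to fp32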
    case GGML_TYPE_F16: {
#if defined(__AVX512F__)
        if (Btype == GGML_TYPE_F16) {
            tinyBLAS<16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
                (const ggml_fp16_t *)A, lda,
                (const ggml_fp16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
        if (Btype == GGML_TYPE_F16) {
            tinyBLAS<8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
                (const ggml_fp16_t *)A, lda,
                (const ggml_fp16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
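        // MSVC is excluded here: its NEON headers don't expose the fp16
        // vector arithmetic types this kernel relies on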
        if (n < 8)
            return false;
        if (Btype == GGML_TYPE_F16) {
            tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
                k, (const ggml_fp16_t *)A, lda,
                (const ggml_fp16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
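        // without vector fp16 arithmetic, A is widened from fp16 to fp32 on
        // load while B is consumed as fp32 directly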
        if (Btype == GGML_TYPE_F32) {
            tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ params,
                k, (const ggml_fp16_t *)A, lda,
                (const float *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif defined(__VXE__) || defined(__VXE2__)
        if (n < 4)
            return false;
        if (Btype == GGML_TYPE_F16) {
            tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
                k, (const ggml_fp16_t *)A, lda,
                (const ggml_fp16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif defined(__riscv_zvfh)
        if (Btype == GGML_TYPE_F16) {
    #if LMUL == 1
            tinyBLAS_RVV<vfloat32m1_t, vfloat16mf2_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
                k, (const ggml_fp16_t *)A, lda,
                (const ggml_fp16_t *)B, ldb,
                (float *)C, ldc};
    #elif LMUL == 2
            tinyBLAS_RVV<vfloat32m2_t, vfloat16m1_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
                k, (const ggml_fp16_t *)A, lda,
                (const ggml_fp16_t *)B, ldb,
                (float *)C, ldc};
    #else // LMUL == 4
            tinyBLAS_RVV<vfloat32m4_t, vfloat16m2_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
                k, (const ggml_fp16_t *)A, lda,
                (const ggml_fp16_t *)B, ldb,
                (float *)C, ldc};
    #endif
            return tb.matmul(m, n);
        }
#elif defined(__MMA__)
        if (k % 8) {
            return false;
        }

        if (Btype == GGML_TYPE_F16) {
            tinyBLAS_HP16_PPC<ggml_fp16_t, ggml_fp16_t, float> tb{ k,
                (const ggml_fp16_t *)A, lda,
                (const ggml_fp16_t *)B, ldb,
                (float *)C, ldc,
                params->ith, params->nth };

            tb.matmul(m, n);
            return true;
        }
#endif
        return false;
    }

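    // Q8_0 × Q8_0: integer dot-product kernels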
    case GGML_TYPE_Q8_0: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
        tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
            k, (const block_q8_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#elif defined(__ARM_FEATURE_DOTPROD)
        tinyBLAS_Q0_ARM<block_q8_0> tb{
            k, (const block_q8_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#elif defined(__MMA__)
        // TODO: remove this condition once gemv forwarding is enabled.
        if (n < 8 && n != 4)
            return false;
        if (m < 8 && m != 4)
            return false;
        tinyBLAS_Q0_PPC<block_q8_0> tb{
            k, (const block_q8_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#else
        return false;
#endif
    }

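    // Q4_0 × Q8_0: 4-bit weights against 8-bit activations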
    case GGML_TYPE_Q4_0: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
        tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
            k, (const block_q4_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#elif defined(__ARM_FEATURE_DOTPROD)
        tinyBLAS_Q0_ARM<block_q4_0> tb{
            k, (const block_q4_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#elif defined(__MMA__)
        // TODO: remove this condition once gemv forwarding is enabled.
        if (n < 8 && n != 4)
            return false;
        if (m < 8 && m != 4)
            return false;
        tinyBLAS_Q0_PPC<block_q4_0> tb{
            k, (const block_q4_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#else
        return false;
#endif
    }

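    // Q5_0 × Q8_0: currently only an AVX kernel exists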
    case GGML_TYPE_Q5_0: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
        tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float> tb{
            k, (const block_q5_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#else
        return false;
#endif
    }

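    // IQ4_NL × Q8_0: 4-bit non-linear codebook weights, likewise AVX-only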
    case GGML_TYPE_IQ4_NL: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
        tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float> tb{
            k, (const block_iq4_nl *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#else
        return false;
#endif
    }

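    // any other type combination is left to the caller's fallback path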
    default:
        return false;
    }

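    // unreachable, but the casts still count as uses and silence
    // -Wunused-parameter on builds where none of the kernels above compile in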
    (void)params;
    (void)m;
    (void)n;
    (void)k;
    (void)A;
    (void)lda;
    (void)B;
    (void)ldb;
    (void)C;
    (void)ldc;
    (void)Atype;
    (void)Btype;
    (void)Ctype;
}