1#pragma once
  2
  3// GGML CPU internal header
  4
  5#include "ggml.h"
  6#include "ggml-impl.h"
  7
  8#include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
  9//#include <stddef.h>
 10#include <stdbool.h>
 11#include <string.h> // memcpy
 12#include <math.h>   // fabsf
 13
 14#ifdef __cplusplus
 15extern "C" {
 16#endif
 17
// Per-thread parameters handed to every CPU compute kernel.
struct ggml_compute_params {
    // ith = thread index, nth = number of threads
    int ith, nth;

    // work buffer for all threads (shared scratch; size in bytes)
    size_t wsize;
    void * wdata;

    struct ggml_threadpool * threadpool;

    // use reference implementation
    bool use_ref;
};
 31
 32
#if defined(_MSC_VER)

// NOTE(review): on MSVC these are identity wrappers - presumably because
// MSVC rejects C-style casts between AVX-512 vector types; confirm if
// MSVC support is revisited
#define m512bh(p) p
#define m512i(p) p

#else

// explicit reinterpret casts between AVX-512 vector types (GCC/Clang)
#define m512bh(p) (__m512bh)(p)
#define m512i(p) (__m512i)(p)

#endif
 44
 45// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
 46#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
 47#ifndef __FMA__
 48#define __FMA__
 49#endif
 50#ifndef __F16C__
 51#define __F16C__
 52#endif
 53#endif
 54
 55// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available
 56#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__))
 57#ifndef __SSE3__
 58#define __SSE3__
 59#endif
 60#ifndef __SSSE3__
 61#define __SSSE3__
 62#endif
 63#endif
 64
 65#if defined(__s390x__) && defined(__VEC__)
 66#ifndef __VXE__
 67#define __VXE__
 68#endif  // __VXE__
 69#ifndef __VXE2__
 70#define __VXE2__
 71#endif  // __VXE2__
 72#endif  // __s390x__ && __VEC__
 73
 74#if defined(__ARM_FEATURE_SVE) && defined(__linux__)
 75#include <sys/prctl.h>
 76#endif
 77
 78#if defined(__ARM_NEON)
 79
 80// ref: https://github.com/ggml-org/llama.cpp/pull/5404
 81#ifdef _MSC_VER
 82#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
 83#else
 84#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
 85#endif // _MSC_VER
 86
 87#if !defined(__aarch64__)
 88
 89// 32-bit ARM compatibility
 90
 91// vaddlvq_s16
 92// vpaddq_s16
 93// vpaddq_s32
 94// vaddvq_s32
 95// vaddvq_f32
 96// vmaxvq_f32
 97// vcvtnq_s32_f32
 98// vzip1_u8
 99// vzip2_u8
100
// Emulates AArch64 vaddlvq_s16: widening sum of all 8 int16 lanes to int32.
inline static int32_t vaddlvq_s16(int16x8_t v) {
    // pairwise-widen twice: 8x s16 -> 4x s32 -> 2x s64; after reinterpreting
    // to s32x4, lanes 0 and 2 hold the low 32-bit words of the two 64-bit
    // partial sums (little-endian lane order assumed), and any sum of eight
    // 16-bit values fits in 32 bits, so adding the low words is exact
    int32x4_t v0 = vreinterpretq_s32_s64(vpaddlq_s32(vpaddlq_s16(v)));
    return vgetq_lane_s32(v0, 0) + vgetq_lane_s32(v0, 2);
}
105
106inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
107    int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
108    int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
109    return vcombine_s16(a0, b0);
110}
111
112inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
113    int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
114    int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
115    return vcombine_s32(a0, b0);
116}
117
118inline static int32_t vaddvq_s32(int32x4_t v) {
119    return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
120}
121
122inline static float vaddvq_f32(float32x4_t v) {
123    return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
124}
125
126inline static float vmaxvq_f32(float32x4_t v) {
127    return
128        MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
129            MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
130}
131
132inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
133    int32x4_t res;
134
135    res[0] = roundf(vgetq_lane_f32(v, 0));
136    res[1] = roundf(vgetq_lane_f32(v, 1));
137    res[2] = roundf(vgetq_lane_f32(v, 2));
138    res[3] = roundf(vgetq_lane_f32(v, 3));
139
140    return res;
141}
142
143inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
144    uint8x8_t res;
145
146    res[0] = a[0]; res[1] = b[0];
147    res[2] = a[1]; res[3] = b[1];
148    res[4] = a[2]; res[5] = b[2];
149    res[6] = a[3]; res[7] = b[3];
150
151    return res;
152}
153
154inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
155    uint8x8_t res;
156
157    res[0] = a[4]; res[1] = b[4];
158    res[2] = a[5]; res[3] = b[5];
159    res[4] = a[6]; res[5] = b[6];
160    res[6] = a[7]; res[7] = b[7];
161
162    return res;
163}
164
165// vld1q_s16_x2
166// vld1q_u8_x2
167// vld1q_u8_x4
168// vld1q_s8_x2
169// vld1q_s8_x4
170// TODO: double-check these work correctly
171
172typedef struct ggml_int16x8x2_t {
173    int16x8_t val[2];
174} ggml_int16x8x2_t;
175
176inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
177    ggml_int16x8x2_t res;
178
179    res.val[0] = vld1q_s16(ptr + 0);
180    res.val[1] = vld1q_s16(ptr + 8);
181
182    return res;
183}
184
185typedef struct ggml_uint8x16x2_t {
186    uint8x16_t val[2];
187} ggml_uint8x16x2_t;
188
189inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
190    ggml_uint8x16x2_t res;
191
192    res.val[0] = vld1q_u8(ptr + 0);
193    res.val[1] = vld1q_u8(ptr + 16);
194
195    return res;
196}
197
198typedef struct ggml_uint8x16x4_t {
199    uint8x16_t val[4];
200} ggml_uint8x16x4_t;
201
202inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
203    ggml_uint8x16x4_t res;
204
205    res.val[0] = vld1q_u8(ptr + 0);
206    res.val[1] = vld1q_u8(ptr + 16);
207    res.val[2] = vld1q_u8(ptr + 32);
208    res.val[3] = vld1q_u8(ptr + 48);
209
210    return res;
211}
212
213typedef struct ggml_int8x16x2_t {
214    int8x16_t val[2];
215} ggml_int8x16x2_t;
216
217inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
218    ggml_int8x16x2_t res;
219
220    res.val[0] = vld1q_s8(ptr + 0);
221    res.val[1] = vld1q_s8(ptr + 16);
222
223    return res;
224}
225
226typedef struct ggml_int8x16x4_t {
227    int8x16_t val[4];
228} ggml_int8x16x4_t;
229
230inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
231    ggml_int8x16x4_t res;
232
233    res.val[0] = vld1q_s8(ptr + 0);
234    res.val[1] = vld1q_s8(ptr + 16);
235    res.val[2] = vld1q_s8(ptr + 32);
236    res.val[3] = vld1q_s8(ptr + 48);
237
238    return res;
239}
240
241// NOTE: not tested
242inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
243    int8x16_t res;
244
245    res[ 0] = a[b[ 0]];
246    res[ 1] = a[b[ 1]];
247    res[ 2] = a[b[ 2]];
248    res[ 3] = a[b[ 3]];
249    res[ 4] = a[b[ 4]];
250    res[ 5] = a[b[ 5]];
251    res[ 6] = a[b[ 6]];
252    res[ 7] = a[b[ 7]];
253    res[ 8] = a[b[ 8]];
254    res[ 9] = a[b[ 9]];
255    res[10] = a[b[10]];
256    res[11] = a[b[11]];
257    res[12] = a[b[12]];
258    res[13] = a[b[13]];
259    res[14] = a[b[14]];
260    res[15] = a[b[15]];
261
262    return res;
263}
264
265// NOTE: not tested
266inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
267    uint8x16_t res;
268
269    res[ 0] = a[b[ 0]];
270    res[ 1] = a[b[ 1]];
271    res[ 2] = a[b[ 2]];
272    res[ 3] = a[b[ 3]];
273    res[ 4] = a[b[ 4]];
274    res[ 5] = a[b[ 5]];
275    res[ 6] = a[b[ 6]];
276    res[ 7] = a[b[ 7]];
277    res[ 8] = a[b[ 8]];
278    res[ 9] = a[b[ 9]];
279    res[10] = a[b[10]];
280    res[11] = a[b[11]];
281    res[12] = a[b[12]];
282    res[13] = a[b[13]];
283    res[14] = a[b[14]];
284    res[15] = a[b[15]];
285
286    return res;
287}
288
289#else
290
291#define ggml_int16x8x2_t  int16x8x2_t
292#define ggml_uint8x16x2_t uint8x16x2_t
293#define ggml_uint8x16x4_t uint8x16x4_t
294#define ggml_int8x16x2_t  int8x16x2_t
295#define ggml_int8x16x4_t  int8x16x4_t
296
297#define ggml_vld1q_s16_x2 vld1q_s16_x2
298#define ggml_vld1q_u8_x2  vld1q_u8_x2
299#define ggml_vld1q_u8_x4  vld1q_u8_x4
300#define ggml_vld1q_s8_x2  vld1q_s8_x2
301#define ggml_vld1q_s8_x4  vld1q_s8_x4
302#define ggml_vqtbl1q_s8   vqtbl1q_s8
303#define ggml_vqtbl1q_u8   vqtbl1q_u8
304
305#endif // !defined(__aarch64__)
306
307#if !defined(__ARM_FEATURE_DOTPROD)
308
309inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
310    const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
311    const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
312
313    return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
314}
315
316#else
317
318#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
319
320#endif // !defined(__ARM_FEATURE_DOTPROD)
321
322#endif // defined(__ARM_NEON)
323
324#ifdef __wasm_simd128__
325#include <wasm_simd128.h>
326#endif
327
328#ifdef __POWER9_VECTOR__
329#include <altivec.h>
330#endif
331
332#if defined(_MSC_VER) || defined(__MINGW32__)
333#include <intrin.h>
334#elif defined(__SSE__) || defined(__SSE3__) || defined(__SSSE3__) || defined(__AVX__) || defined(__F16C__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX512BF16__)
335#include <immintrin.h>
336#endif
337
338#ifdef __riscv_v_intrinsic
339#include <riscv_vector.h>
340#endif
341
342#if defined(__loongarch64)
343#if defined(__loongarch_asx)
344#include <lasxintrin.h>
345#endif
346#if defined(__loongarch_sx)
347#include <lsxintrin.h>
348#endif
349#endif
350
351#if defined(__VXE__) || defined(__VXE2__)
352#include <vecintrin.h>
353
// scalar-style operators over the GCC generic vector types defined below
#define vec_neg(a)    (-(a))                // Vector Negate
#define vec_add(a, b) ((a) + (b))           // Vector Add
#define vec_sub(a, b) ((a) - (b))           // Vector Subtract
#define vec_mul(a, b) ((a) * (b))           // Vector Multiply
#define vec_div(a, b) ((a) / (b))           // Vector Divide
#define vec_sl(a, b)  ((a) << (b))          // Vector Shift Left
#define vec_sra(a, b) ((a) >> (b))          // Vector Shift Right Algebraic
#define vec_sr(a, b)  ((a) >> (b))          // Vector Shift Right (NOTE(review): >> is arithmetic for signed element types, so this is only a logical shift when used on unsigned vectors - confirm callers)
#define vec_slo(a, b) vec_slb(a, (b) << 64) // Vector Shift Left by Octet
#define vec_sro(a, b) vec_srb(a, (b) << 64) // Vector Shift Right by Octet

#ifndef vec_and
#define vec_and(a, b) ((a) & (b)) // Vector AND
#endif

#ifndef vec_or
#define vec_or(a, b)  ((a) | (b)) // Vector OR
#endif

#ifndef vec_xor
#define vec_xor(a, b) ((a) ^ (b)) // Vector XOR
#endif

// GCC generic vector types mirroring the NEON-style names used by the
// shared quantization kernels
typedef signed   char char8x16_t  __attribute__((vector_size(16)));
typedef unsigned char uchar8x16_t __attribute__((vector_size(16)));

typedef int8_t  int8x16_t __attribute__((vector_size(16)));
typedef int16_t int16x8_t __attribute__((vector_size(16)));
typedef int32_t int32x4_t __attribute__((vector_size(16)));

typedef uint8_t  uint8x16_t __attribute__((vector_size(16)));
typedef uint16_t uint16x8_t __attribute__((vector_size(16)));
typedef uint32_t uint32x4_t __attribute__((vector_size(16)));

typedef float  float32x4_t  __attribute__((vector_size(16)));
typedef double double64x2_t __attribute__((vector_size(16)));

typedef signed   long long long64x2_t  __attribute__((vector_size(16)));
typedef unsigned long long ulong64x2_t __attribute__((vector_size(16)));
393
394typedef struct ggml_uint8x16x2_t {
395    uint8x16_t val[2];
396} ggml_uint8x16x2_t;
397
398inline static ggml_uint8x16x2_t ggml_vec_xl_u8x2(const uint8_t * ptr) {
399    ggml_uint8x16x2_t res;
400
401    res.val[0] = vec_xl( 0, ptr);
402    res.val[1] = vec_xl(16, ptr);
403
404    return res;
405}
406
407typedef struct ggml_uint8x16x4_t {
408    uint8x16_t val[4];
409} ggml_uint8x16x4_t;
410
411inline static ggml_uint8x16x4_t ggml_vec_xl_u8x4(const uint8_t * ptr) {
412    ggml_uint8x16x4_t res;
413
414    res.val[0] = vec_xl( 0, ptr);
415    res.val[1] = vec_xl(16, ptr);
416    res.val[2] = vec_xl(32, ptr);
417    res.val[3] = vec_xl(48, ptr);
418
419    return res;
420}
421
422typedef struct ggml_int8x16x4_t {
423    int8x16_t val[4];
424} ggml_int8x16x4_t;
425
426inline static ggml_int8x16x4_t ggml_vec_xl_s8x4(const int8_t * ptr) {
427    ggml_int8x16x4_t res;
428
429    res.val[0] = vec_xl( 0, ptr);
430    res.val[1] = vec_xl(16, ptr);
431    res.val[2] = vec_xl(32, ptr);
432    res.val[3] = vec_xl(48, ptr);
433
434    return res;
435}
436
437typedef struct ggml_int16x8x2_t {
438    int16x8_t val[2];
439} ggml_int16x8x2_t;
440
441inline static ggml_int16x8x2_t ggml_vec_xl_s16x2(const int16_t * ptr) {
442    ggml_int16x8x2_t res;
443
444    res.val[0] = vec_xl( 0, ptr);
445    res.val[1] = vec_xl(16, ptr);
446
447    return res;
448}
449
450/*
451    ! WARNING: Very slow. Use vec_perm if possible. Refer to iq4_xs
452    !          or iq4_nl for example implementation.
453*/
454inline static int8x16_t ggml_vec_tbl(int8x16_t a, uint8x16_t b) {
455    int8x16_t res;
456
457    res[ 0] = a[b[ 0]];
458    res[ 1] = a[b[ 1]];
459    res[ 2] = a[b[ 2]];
460    res[ 3] = a[b[ 3]];
461    res[ 4] = a[b[ 4]];
462    res[ 5] = a[b[ 5]];
463    res[ 6] = a[b[ 6]];
464    res[ 7] = a[b[ 7]];
465    res[ 8] = a[b[ 8]];
466    res[ 9] = a[b[ 9]];
467    res[10] = a[b[10]];
468    res[11] = a[b[11]];
469    res[12] = a[b[12]];
470    res[13] = a[b[13]];
471    res[14] = a[b[14]];
472    res[15] = a[b[15]];
473
474    return res;
475}
476
// Pairwise add of adjacent int16 lanes across a and b (NEON vpaddq_s16
// equivalent): vec_pack extracts one 16-bit half of every 32-bit lane
// pair, vec_perm (via v_maske byte indices, where 16..29 select from b)
// extracts the other half, and their sum is the pairwise sum.
// NOTE(review): the even/odd split here is tied to s390x big-endian
// lane order - verify against vpaddq_s16 semantics before reuse.
inline static int16x8_t vec_padd_s16(int16x8_t a, int16x8_t b) {
    const uchar8x16_t v_maske = {  0,  1,  4,  5,  8,  9, 12, 13,
                                  16, 17, 20, 21, 24, 25, 28, 29 };

    const int16x8_t v_abo = vec_pack((int32x4_t)a, (int32x4_t)b);
    const int16x8_t v_abe = vec_perm(a, b, v_maske);
    return v_abo + v_abe;
}
485
486/**
487 * @see https://github.com/ggml-org/llama.cpp/pull/14037
488 */
489inline static float vec_hsum_f32x4(float32x4_t v) {
490    float32x4_t v_temp = v + vec_reve(v);
491    return v_temp[0] + v_temp[1];
492}
493
494inline static int32_t vec_hsum_i32x4(int32x4_t v) {
495    int32x4_t v_temp = v + vec_reve(v);
496    return v_temp[0] + v_temp[1];
497}
498
// int8 dot product (s390x counterpart of ggml_vdotq_s32): vec_mule/vec_mulo
// multiply even/odd int8 lane pairs into int16, so each lane of p is the sum
// of two products; vec_unpackh/vec_unpackl sign-extend the two halves of p
// to int32, and each lane of acc accumulates four adjacent products.
inline static int32x4_t ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b) {
    const int16x8_t p = vec_mule(a, b) + vec_mulo(a, b);
    return acc + (vec_unpackh(p) + vec_unpackl(p));
}
503
504#endif
505
506#if defined(__loongarch_sx)
507/* float type data load instructions */
508static __m128 __lsx_vreplfr2vr_s(const float val) {
509    v4f32 res = {val, val, val, val};
510    return (__m128)res;
511}
512#endif
513
514#if defined(__loongarch_asx)
515static __m256 __lasx_xvreplfr2vr_s(const float val) {
516    v8f32 res = {val, val, val, val, val, val, val, val};
517    return (__m256)res;
518}
519#endif
520
// TODO: move to ggml-threading
// NOTE(review): semantics below inferred from names and call sites
// elsewhere - confirm against the implementation.
// Barrier across all threads of the threadpool.
void ggml_barrier(struct ggml_threadpool * tp);

// Shared chunk counter used to hand out work ranges to threads:
// set resets it, add bumps it and returns a value (presumably the
// pre- or post-increment value - verify in the implementation).
void ggml_threadpool_chunk_set(struct ggml_threadpool * tp, int value);
int  ggml_threadpool_chunk_add(struct ggml_threadpool * tp, int value);
526
527#ifdef __cplusplus
528}
529#endif