1#pragma once
2
3// GGML CPU internal header
4
5#include "ggml.h"
6#include "ggml-impl.h"
7
8#include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
9//#include <stddef.h>
10#include <stdbool.h>
11#include <string.h> // memcpy
12#include <math.h> // fabsf
13
14#ifdef __cplusplus
15extern "C" {
16#endif
17
// Per-operation compute parameters passed to every worker thread of a graph
// evaluation. All threads share the same work buffer (wdata/wsize).
struct ggml_compute_params {
    // ith = thread index, nth = number of threads
    int ith, nth;

    // work buffer for all threads
    size_t wsize;
    void * wdata;

    // threadpool executing this computation (see ggml_barrier below)
    struct ggml_threadpool * threadpool;

    // use reference implementation
    bool use_ref;
};
31
32
#if defined(_MSC_VER)

// NOTE(review): on MSVC these are identity macros — presumably MSVC accepts
// the AVX-512 vector types interchangeably without an explicit cast; confirm.
#define m512bh(p) p
#define m512i(p) p

#else

// GCC/Clang need an explicit conversion between __m512bh/__m512i and
// other 512-bit vector types.
#define m512bh(p) (__m512bh)(p)
#define m512i(p) (__m512i)(p)

#endif
44
// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
#ifndef __FMA__
#define __FMA__
#endif
#ifndef __F16C__
#define __F16C__
#endif
#endif

// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available
#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__))
#ifndef __SSE3__
#define __SSE3__
#endif
#ifndef __SSSE3__
#define __SSSE3__
#endif
#endif

// s390x: treat compiler vector support (__VEC__) as the vector-enhancements
// facilities used by the VXE code paths below
#if defined(__s390x__) && defined(__VEC__)
#ifndef __VXE__
#define __VXE__
#endif // __VXE__
#ifndef __VXE2__
#define __VXE2__
#endif // __VXE2__
#endif // __s390x__ && __VEC__

// prctl is used to query/configure SVE state on Linux
#if defined(__ARM_FEATURE_SVE) && defined(__linux__)
#include <sys/prctl.h>
#endif
77
#if defined(__ARM_NEON)

// ref: https://github.com/ggml-org/llama.cpp/pull/5404
#ifdef _MSC_VER
// NOTE(review): MSVC's NEON vector initializers take 64-bit chunks, so each
// pair of 32-bit lanes is packed into one uint64 (lane order assumes
// little-endian) — see the PR above for background.
#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
#else
#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
#endif // _MSC_VER

#if !defined(__aarch64__)

// 32-bit ARM compatibility
// Scalar/paired fallbacks for AArch64-only intrinsics:

// vaddlvq_s16
// vpaddq_s16
// vpaddq_s32
// vaddvq_s32
// vaddvq_f32
// vmaxvq_f32
// vcvtnq_s32_f32
// vzip1_u8
// vzip2_u8
100
101inline static int32_t vaddlvq_s16(int16x8_t v) {
102 int32x4_t v0 = vreinterpretq_s32_s64(vpaddlq_s32(vpaddlq_s16(v)));
103 return vgetq_lane_s32(v0, 0) + vgetq_lane_s32(v0, 2);
104}
105
106inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
107 int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
108 int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
109 return vcombine_s16(a0, b0);
110}
111
112inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
113 int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
114 int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
115 return vcombine_s32(a0, b0);
116}
117
118inline static int32_t vaddvq_s32(int32x4_t v) {
119 return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
120}
121
122inline static float vaddvq_f32(float32x4_t v) {
123 return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
124}
125
126inline static float vmaxvq_f32(float32x4_t v) {
127 return
128 MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
129 MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
130}
131
// Convert each float lane to int32, rounding to nearest.
// NOTE(review): the real vcvtnq_s32_f32 rounds ties to even, while roundf
// rounds ties away from zero — results differ on exact .5 inputs; confirm
// callers tolerate this.
inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
    int32x4_t res;

    res[0] = roundf(vgetq_lane_f32(v, 0));
    res[1] = roundf(vgetq_lane_f32(v, 1));
    res[2] = roundf(vgetq_lane_f32(v, 2));
    res[3] = roundf(vgetq_lane_f32(v, 3));

    return res;
}
142
143inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
144 uint8x8_t res;
145
146 res[0] = a[0]; res[1] = b[0];
147 res[2] = a[1]; res[3] = b[1];
148 res[4] = a[2]; res[5] = b[2];
149 res[6] = a[3]; res[7] = b[3];
150
151 return res;
152}
153
154inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
155 uint8x8_t res;
156
157 res[0] = a[4]; res[1] = b[4];
158 res[2] = a[5]; res[3] = b[5];
159 res[4] = a[6]; res[5] = b[6];
160 res[6] = a[7]; res[7] = b[7];
161
162 return res;
163}
164
165// vld1q_s16_x2
166// vld1q_u8_x2
167// vld1q_u8_x4
168// vld1q_s8_x2
169// vld1q_s8_x4
170// TODO: double-check these work correctly
171
172typedef struct ggml_int16x8x2_t {
173 int16x8_t val[2];
174} ggml_int16x8x2_t;
175
176inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
177 ggml_int16x8x2_t res;
178
179 res.val[0] = vld1q_s16(ptr + 0);
180 res.val[1] = vld1q_s16(ptr + 8);
181
182 return res;
183}
184
185typedef struct ggml_uint8x16x2_t {
186 uint8x16_t val[2];
187} ggml_uint8x16x2_t;
188
189inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
190 ggml_uint8x16x2_t res;
191
192 res.val[0] = vld1q_u8(ptr + 0);
193 res.val[1] = vld1q_u8(ptr + 16);
194
195 return res;
196}
197
198typedef struct ggml_uint8x16x4_t {
199 uint8x16_t val[4];
200} ggml_uint8x16x4_t;
201
202inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
203 ggml_uint8x16x4_t res;
204
205 res.val[0] = vld1q_u8(ptr + 0);
206 res.val[1] = vld1q_u8(ptr + 16);
207 res.val[2] = vld1q_u8(ptr + 32);
208 res.val[3] = vld1q_u8(ptr + 48);
209
210 return res;
211}
212
213typedef struct ggml_int8x16x2_t {
214 int8x16_t val[2];
215} ggml_int8x16x2_t;
216
217inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
218 ggml_int8x16x2_t res;
219
220 res.val[0] = vld1q_s8(ptr + 0);
221 res.val[1] = vld1q_s8(ptr + 16);
222
223 return res;
224}
225
226typedef struct ggml_int8x16x4_t {
227 int8x16_t val[4];
228} ggml_int8x16x4_t;
229
230inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
231 ggml_int8x16x4_t res;
232
233 res.val[0] = vld1q_s8(ptr + 0);
234 res.val[1] = vld1q_s8(ptr + 16);
235 res.val[2] = vld1q_s8(ptr + 32);
236 res.val[3] = vld1q_s8(ptr + 48);
237
238 return res;
239}
240
241// NOTE: not tested
242inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
243 int8x16_t res;
244
245 res[ 0] = a[b[ 0]];
246 res[ 1] = a[b[ 1]];
247 res[ 2] = a[b[ 2]];
248 res[ 3] = a[b[ 3]];
249 res[ 4] = a[b[ 4]];
250 res[ 5] = a[b[ 5]];
251 res[ 6] = a[b[ 6]];
252 res[ 7] = a[b[ 7]];
253 res[ 8] = a[b[ 8]];
254 res[ 9] = a[b[ 9]];
255 res[10] = a[b[10]];
256 res[11] = a[b[11]];
257 res[12] = a[b[12]];
258 res[13] = a[b[13]];
259 res[14] = a[b[14]];
260 res[15] = a[b[15]];
261
262 return res;
263}
264
265// NOTE: not tested
266inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
267 uint8x16_t res;
268
269 res[ 0] = a[b[ 0]];
270 res[ 1] = a[b[ 1]];
271 res[ 2] = a[b[ 2]];
272 res[ 3] = a[b[ 3]];
273 res[ 4] = a[b[ 4]];
274 res[ 5] = a[b[ 5]];
275 res[ 6] = a[b[ 6]];
276 res[ 7] = a[b[ 7]];
277 res[ 8] = a[b[ 8]];
278 res[ 9] = a[b[ 9]];
279 res[10] = a[b[10]];
280 res[11] = a[b[11]];
281 res[12] = a[b[12]];
282 res[13] = a[b[13]];
283 res[14] = a[b[14]];
284 res[15] = a[b[15]];
285
286 return res;
287}
288
289#else
290
// AArch64: the multi-register load intrinsics and table lookups exist
// natively — alias the ggml_* names directly to them.
#define ggml_int16x8x2_t int16x8x2_t
#define ggml_uint8x16x2_t uint8x16x2_t
#define ggml_uint8x16x4_t uint8x16x4_t
#define ggml_int8x16x2_t int8x16x2_t
#define ggml_int8x16x4_t int8x16x4_t

#define ggml_vld1q_s16_x2 vld1q_s16_x2
#define ggml_vld1q_u8_x2 vld1q_u8_x2
#define ggml_vld1q_u8_x4 vld1q_u8_x4
#define ggml_vld1q_s8_x2 vld1q_s8_x2
#define ggml_vld1q_s8_x4 vld1q_s8_x4
#define ggml_vqtbl1q_s8 vqtbl1q_s8
#define ggml_vqtbl1q_u8 vqtbl1q_u8
304
305#endif // !defined(__aarch64__)
306
307#if !defined(__ARM_FEATURE_DOTPROD)
308
309inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
310 const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
311 const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
312
313 return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
314}
315
316#else
317
318#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
319
320#endif // !defined(__ARM_FEATURE_DOTPROD)
321
322#endif // defined(__ARM_NEON)
323
324#ifdef __wasm_simd128__
325#include <wasm_simd128.h>
326#endif
327
328#ifdef __POWER9_VECTOR__
329#include <altivec.h>
330#endif
331
332#if defined(_MSC_VER) || defined(__MINGW32__)
333#include <intrin.h>
334#elif defined(__SSE__) || defined(__SSE3__) || defined(__SSSE3__) || defined(__AVX__) || defined(__F16C__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX512BF16__)
335#include <immintrin.h>
336#endif
337
338#ifdef __riscv_v_intrinsic
339#include <riscv_vector.h>
340#endif
341
342#if defined(__loongarch64)
343#if defined(__loongarch_asx)
344#include <lasxintrin.h>
345#endif
346#if defined(__loongarch_sx)
347#include <lsxintrin.h>
348#endif
349#endif
350
#if defined(__VXE__) || defined(__VXE2__)
#include <vecintrin.h>

// Operator shorthands for GCC/Clang vector extensions. Note that for `>>`
// the element type's signedness decides between an algebraic (signed) and
// logical (unsigned) shift, so vec_sra/vec_sr share the same expansion.
#define vec_neg(a) (-(a)) // Vector Negate
#define vec_add(a, b) ((a) + (b)) // Vector Add
#define vec_sub(a, b) ((a) - (b)) // Vector Subtract
#define vec_mul(a, b) ((a) * (b)) // Vector Multiply
#define vec_div(a, b) ((a) / (b)) // Vector Divide
#define vec_sl(a, b) ((a) << (b)) // Vector Shift Left
#define vec_sra(a, b) ((a) >> (b)) // Vector Shift Right Algebraic
#define vec_sr(a, b) ((a) >> (b)) // Vector Shift Right
#define vec_slo(a, b) vec_slb(a, (b) << 64) // Vector Shift Left by Octet
#define vec_sro(a, b) vec_srb(a, (b) << 64) // Vector Shift Right by Octet

#ifndef vec_and
#define vec_and(a, b) ((a) & (b)) // Vector AND
#endif

#ifndef vec_or
#define vec_or(a, b) ((a) | (b)) // Vector OR
#endif

#ifndef vec_xor
#define vec_xor(a, b) ((a) ^ (b)) // Vector XOR
#endif
376
// 128-bit GCC generic vector types mirroring the NEON naming convention,
// so the s390x code paths can reuse NEON-style type names.
typedef signed char char8x16_t __attribute__((vector_size(16)));
typedef unsigned char uchar8x16_t __attribute__((vector_size(16)));

typedef int8_t int8x16_t __attribute__((vector_size(16)));
typedef int16_t int16x8_t __attribute__((vector_size(16)));
typedef int32_t int32x4_t __attribute__((vector_size(16)));

typedef uint8_t uint8x16_t __attribute__((vector_size(16)));
typedef uint16_t uint16x8_t __attribute__((vector_size(16)));
typedef uint32_t uint32x4_t __attribute__((vector_size(16)));

typedef float float32x4_t __attribute__((vector_size(16)));
typedef double double64x2_t __attribute__((vector_size(16)));

typedef signed long long long64x2_t __attribute__((vector_size(16)));
typedef unsigned long long ulong64x2_t __attribute__((vector_size(16)));
393
394typedef struct ggml_uint8x16x2_t {
395 uint8x16_t val[2];
396} ggml_uint8x16x2_t;
397
398inline static ggml_uint8x16x2_t ggml_vec_xl_u8x2(const uint8_t * ptr) {
399 ggml_uint8x16x2_t res;
400
401 res.val[0] = vec_xl( 0, ptr);
402 res.val[1] = vec_xl(16, ptr);
403
404 return res;
405}
406
407typedef struct ggml_uint8x16x4_t {
408 uint8x16_t val[4];
409} ggml_uint8x16x4_t;
410
411inline static ggml_uint8x16x4_t ggml_vec_xl_u8x4(const uint8_t * ptr) {
412 ggml_uint8x16x4_t res;
413
414 res.val[0] = vec_xl( 0, ptr);
415 res.val[1] = vec_xl(16, ptr);
416 res.val[2] = vec_xl(32, ptr);
417 res.val[3] = vec_xl(48, ptr);
418
419 return res;
420}
421
422typedef struct ggml_int8x16x4_t {
423 int8x16_t val[4];
424} ggml_int8x16x4_t;
425
426inline static ggml_int8x16x4_t ggml_vec_xl_s8x4(const int8_t * ptr) {
427 ggml_int8x16x4_t res;
428
429 res.val[0] = vec_xl( 0, ptr);
430 res.val[1] = vec_xl(16, ptr);
431 res.val[2] = vec_xl(32, ptr);
432 res.val[3] = vec_xl(48, ptr);
433
434 return res;
435}
436
437typedef struct ggml_int16x8x2_t {
438 int16x8_t val[2];
439} ggml_int16x8x2_t;
440
441inline static ggml_int16x8x2_t ggml_vec_xl_s16x2(const int16_t * ptr) {
442 ggml_int16x8x2_t res;
443
444 res.val[0] = vec_xl( 0, ptr);
445 res.val[1] = vec_xl(16, ptr);
446
447 return res;
448}
449
450/*
451 ! WARNING: Very slow. Use vec_perm if possible. Refer to iq4_xs
452 ! or iq4_nl for example implementation.
453*/
454inline static int8x16_t ggml_vec_tbl(int8x16_t a, uint8x16_t b) {
455 int8x16_t res;
456
457 res[ 0] = a[b[ 0]];
458 res[ 1] = a[b[ 1]];
459 res[ 2] = a[b[ 2]];
460 res[ 3] = a[b[ 3]];
461 res[ 4] = a[b[ 4]];
462 res[ 5] = a[b[ 5]];
463 res[ 6] = a[b[ 6]];
464 res[ 7] = a[b[ 7]];
465 res[ 8] = a[b[ 8]];
466 res[ 9] = a[b[ 9]];
467 res[10] = a[b[10]];
468 res[11] = a[b[11]];
469 res[12] = a[b[12]];
470 res[13] = a[b[13]];
471 res[14] = a[b[14]];
472 res[15] = a[b[15]];
473
474 return res;
475}
476
// Pairwise add of adjacent int16 lanes (s390x analogue of NEON vpaddq_s16):
// v_abo packs one 16-bit half of each 32-bit pair (via vec_pack of the
// reinterpreted vectors), v_abe selects the other half of each pair with the
// byte permute mask; their sum is the per-pair sum.
// NOTE(review): the half selection relies on s390x big-endian lane layout —
// confirm if ever reused elsewhere.
inline static int16x8_t vec_padd_s16(int16x8_t a, int16x8_t b) {
    const uchar8x16_t v_maske = { 0, 1, 4, 5, 8, 9, 12, 13,
                                  16, 17, 20, 21, 24, 25, 28, 29 };

    const int16x8_t v_abo = vec_pack((int32x4_t)a, (int32x4_t)b);
    const int16x8_t v_abe = vec_perm(a, b, v_maske);
    return v_abo + v_abe;
}
485
/**
 * Horizontal sum of all four float lanes.
 * v + vec_reve(v) yields {v0+v3, v1+v2, v2+v1, v3+v0}; adding the first two
 * lanes gives the total. (Do not reorder — float rounding depends on it.)
 * @see https://github.com/ggml-org/llama.cpp/pull/14037
 */
inline static float vec_hsum_f32x4(float32x4_t v) {
    float32x4_t v_temp = v + vec_reve(v);
    return v_temp[0] + v_temp[1];
}
493
// Horizontal sum of all four int32 lanes (same reverse-and-add trick as
// vec_hsum_f32x4; the vector addition wraps on overflow).
inline static int32_t vec_hsum_i32x4(int32x4_t v) {
    int32x4_t v_temp = v + vec_reve(v);
    return v_temp[0] + v_temp[1];
}
498
// int8 dot product accumulate (s390x analogue of ggml_vdotq_s32):
// vec_mule/vec_mulo give the even/odd int8 products as int16; unpackh/unpackl
// widen the summed products to int32 and both halves are added into acc.
// NOTE(review): the even+odd product sum is formed in int16 — it can wrap for
// extreme inputs (e.g. both operands -128); confirm callers' value ranges.
inline static int32x4_t ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b) {
    const int16x8_t p = vec_mule(a, b) + vec_mulo(a, b);
    return acc + (vec_unpackh(p) + vec_unpackl(p));
}
503
504#endif
505
506#if defined(__loongarch_sx)
507/* float type data load instructions */
508static __m128 __lsx_vreplfr2vr_s(const float val) {
509 v4f32 res = {val, val, val, val};
510 return (__m128)res;
511}
512#endif
513
514#if defined(__loongarch_asx)
515static __m256 __lasx_xvreplfr2vr_s(const float val) {
516 v8f32 res = {val, val, val, val, val, val, val, val};
517 return (__m256)res;
518}
519#endif
520
// TODO: move to ggml-threading
// Full barrier across all threads of the threadpool.
void ggml_barrier(struct ggml_threadpool * tp);

// Set / add to the threadpool's shared chunk counter (used to hand out work
// chunks to threads). NOTE(review): presumably atomic and _add returns the
// pre-add value — confirm against the definitions in the .c file.
void ggml_threadpool_chunk_set(struct ggml_threadpool * tp, int value);
int ggml_threadpool_chunk_add(struct ggml_threadpool * tp, int value);
526
527#ifdef __cplusplus
528}
529#endif