#include "vec.h"

#include <cassert>

// precomputed gelu table for f16 (128 KB)
ggml_fp16_t ggml_table_gelu_f16[1 << 16];

// precomputed quick gelu table for f16 (128 KB)
ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
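
// Both tables are indexed by the raw 16-bit pattern of the fp16 argument, so each
// holds 1 << 16 entries of 2 bytes (128 KB). A minimal sketch of how such a table is
// typically filled (assumption: the actual initialization lives in the CPU backend
// init, not in this file, and uses the scalar helpers ggml_gelu_f32 / ggml_gelu_quick_f32):
//
//     for (uint32_t b = 0; b < (1 << 16); ++b) {
//         union { uint16_t u; ggml_fp16_t h; } v = { (uint16_t) b }; // reinterpret the bit pattern
//         const float f = GGML_CPU_FP16_TO_FP32(v.h);
//         ggml_table_gelu_f16[b]       = GGML_CPU_FP32_TO_FP16(ggml_gelu_f32(f));
//         ggml_table_gelu_quick_f16[b] = GGML_CPU_FP32_TO_FP16(ggml_gelu_quick_f32(f));
//     }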

void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc) {
    assert(nrc == 1);
    GGML_UNUSED(nrc);
    GGML_UNUSED(bx);
    GGML_UNUSED(by);
    GGML_UNUSED(bs);
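
    // nrc lets a vec_dot implementation compute several dot products at once, with
    // bs/bx/by as the corresponding strides for s, x and y; this version only
    // supports nrc == 1, so the strides are unused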

#if defined(GGML_SIMD)
    float sumf = 0.0f;

    #if defined(__ARM_FEATURE_SVE)
        const int sve_register_length = ggml_cpu_get_sve_cnt() * 8;
        const int ggml_f32_epr = sve_register_length / 32; // 32-bit elements per register (SVE128: 4, SVE256: 8, SVE512: 16)
        const int ggml_f32_step = 8 * ggml_f32_epr; // choose 8 SVE registers

        const int np = (n & ~(ggml_f32_step - 1));
        svfloat32_t sum1 = svdup_n_f32(0.0f);
        svfloat32_t sum2 = svdup_n_f32(0.0f);
        svfloat32_t sum3 = svdup_n_f32(0.0f);
        svfloat32_t sum4 = svdup_n_f32(0.0f);
        svfloat32_t sum5 = svdup_n_f32(0.0f);
        svfloat32_t sum6 = svdup_n_f32(0.0f);
        svfloat32_t sum7 = svdup_n_f32(0.0f);
        svfloat32_t sum8 = svdup_n_f32(0.0f);
        svfloat32_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
        svfloat32_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
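
        // eight independent accumulators keep consecutive FMAs off the same
        // dependency chain, which helps hide FMA latency; the partial sums are
        // combined only once, in the reduce step below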
        for (int i = 0; i < np; i += ggml_f32_step) {
            ax1 = GGML_F32_VEC_LOAD(x + i);
            ay1 = GGML_F32_VEC_LOAD(y + i);
            sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);

            ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
            ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
            sum2 = GGML_F32_VEC_FMA(sum2, ax2, ay2);

            ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
            ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
            sum3 = GGML_F32_VEC_FMA(sum3, ax3, ay3);

            ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
            ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
            sum4 = GGML_F32_VEC_FMA(sum4, ax4, ay4);

            ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
            ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
            sum5 = GGML_F32_VEC_FMA(sum5, ax5, ay5);

            ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
            ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
            sum6 = GGML_F32_VEC_FMA(sum6, ax6, ay6);

            ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
            ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
            sum7 = GGML_F32_VEC_FMA(sum7, ax7, ay7);

            ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
            ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
            sum8 = GGML_F32_VEC_FMA(sum8, ax8, ay8);
        }
        // leftovers
        // since the loop above unrolls by 8, the leftovers lie in the range [0, ggml_f32_step) and are handled by the loop below
        const int np2 = (n & ~(ggml_f32_epr - 1));
        for (int i = np; i < np2; i += ggml_f32_epr) {
            ax1 = GGML_F32_VEC_LOAD(x + i);
            ay1 = GGML_F32_VEC_LOAD(y + i);
            sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
        }
        // at most ggml_f32_epr - 1 elements remain; apply a predicated svmad on the available elements only
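        // svwhilelt_b32(np2, n) builds a predicate that is active for exactly the
        // n - np2 remaining lanes, so the predicated load and svmad touch only valid
        // elements and no scalar fallback loop is needed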
        if (np2 < n) {
            svbool_t pg = svwhilelt_b32(np2, n);
            ax1 = svld1_f32(pg, x + np2);
            ay1 = svld1_f32(pg, y + np2);
            sum1 = svmad_f32_m(pg, ax1, ay1, sum1);
        }
        // reduce sum1..sum8 to sumf
        GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8);
    #elif defined(__riscv_v_intrinsic)
        int vl = __riscv_vsetvlmax_e32m8();
        vfloat32m1_t vs = __riscv_vfmv_v_f_f32m1(0.0f, 1);
        vfloat32m8_t vsum;
        vfloat32m8_t ax;
        vfloat32m8_t ay;
        vsum = __riscv_vfmv_v_f_f32m8_tu(vsum, 0.0f, vl);
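        // strip-mined loop: vl shrinks on the final iteration, and the tail-undisturbed
        // (_tu) loads and FMA preserve the previously accumulated values in the inactive
        // tail lanes, so no separate leftover loop is needed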
        for (int i = 0; i < n; i += vl) {
            vl = __riscv_vsetvl_e32m8(n - i);
            ax = __riscv_vle32_v_f32m8_tu(ax, &x[i], vl);
            ay = __riscv_vle32_v_f32m8_tu(ay, &y[i], vl);
            vsum = __riscv_vfmacc_vv_f32m8_tu(vsum, ax, ay, vl);
        }
        vl = __riscv_vsetvlmax_e32m8();
        vs = __riscv_vfredusum_vs_f32m8_f32m1(vsum, vs, vl);
        sumf += __riscv_vfmv_f_s_f32m1_f32(vs);
    #else
        const int np = (n & ~(GGML_F32_STEP - 1));

        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };

        GGML_F32_VEC ax[GGML_F32_ARR];
        GGML_F32_VEC ay[GGML_F32_ARR];
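
        // generic SIMD path: GGML_F32_STEP = GGML_F32_ARR * GGML_F32_EPR, i.e. each
        // iteration of the loop below processes GGML_F32_ARR registers of GGML_F32_EPR
        // floats, one independent accumulator per register (see the per-arch SIMD mappings)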

        for (int i = 0; i < np; i += GGML_F32_STEP) {
            for (int j = 0; j < GGML_F32_ARR; j++) {
                ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
                ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);

                sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
            }
        }

        // reduce the partial sums into sumf
        GGML_F32_VEC_REDUCE(sumf, sum);

        // leftovers
        for (int i = np; i < n; ++i) {
            sumf += x[i]*y[i];
        }
    #endif
#else
    // scalar
    ggml_float sumf = 0.0;
    for (int i = 0; i < n; ++i) {
        sumf += (ggml_float)(x[i]*y[i]);
    }
#endif

    *s = sumf;
}
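
// Example (hypothetical caller) for ggml_vec_dot_f32 above: one dot product over two
// contiguous rows of length n, with nrc = 1 and the unused strides passed as 0:
//
//     float d;
//     ggml_vec_dot_f32(n, &d, 0, row_x, 0, row_y, 0, 1);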

void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc) {
    assert(nrc == 1);
    GGML_UNUSED(nrc);
    GGML_UNUSED(bx);
    GGML_UNUSED(by);
    GGML_UNUSED(bs);
    int i = 0;
    ggml_float sumf = 0;
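
    // bf16 keeps the sign, exponent and top 7 mantissa bits of an IEEE f32, so a value
    // can be widened to f32 by placing its 16 bits in the upper half of a 32-bit word
    // (a 16-bit left shift); the AVX LOAD macros below rely on this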

#if defined(__AVX512BF16__)
    __m512 c1 = _mm512_setzero_ps();
    __m512 c2 = _mm512_setzero_ps();
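    // _mm512_dpbf16_ps multiplies 32 bf16 pairs per accumulator and adds the products
    // into 16 f32 lanes, so each iteration below consumes 64 elements from each input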
    for (; i + 64 <= n; i += 64) {
        c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
                              m512bh(_mm512_loadu_si512((y + i))));
        c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
                              m512bh(_mm512_loadu_si512((y + i + 32))));
    }
    sumf += (ggml_float)_mm512_reduce_add_ps(c1);
    sumf += (ggml_float)_mm512_reduce_add_ps(c2);

#elif defined(__AVX512F__)
#define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16))
    __m512 c1 = _mm512_setzero_ps();
    __m512 c2 = _mm512_setzero_ps();
    for (; i + 32 <= n; i += 32) {
        c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
        c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2);
    }
    sumf += (ggml_float)_mm512_reduce_add_ps(c1);
    sumf += (ggml_float)_mm512_reduce_add_ps(c2);

#undef LOAD
#elif defined(__AVX2__) || defined(__AVX__)
#if defined(__AVX2__)
#define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16))
#else
#define LOAD(p) _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16)), (_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_bsrli_si128(_mm_loadu_si128((const __m128i *)(p)), 8)), 16)), 1))
#endif
    __m256 c1 = _mm256_setzero_ps();
    __m256 c2 = _mm256_setzero_ps();
    __m256 c3 = _mm256_setzero_ps();
    __m256 c4 = _mm256_setzero_ps();
    for (; i + 32 <= n; i += 32) {
        c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
        c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2);
        c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3);
        c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4);
    }
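    // horizontal reduction: combine the four accumulators, add the upper 128-bit lane
    // onto the lower one, then fold 4 -> 2 -> 1 floats with movehl/movehdup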
    __m128 g;
    c1 = _mm256_add_ps(_mm256_add_ps(c1, c3),
                       _mm256_add_ps(c2, c4));
    g = _mm_add_ps(_mm256_extractf128_ps(c1, 1),
                   _mm256_castps256_ps128(c1));
    g = _mm_add_ps(g, _mm_movehl_ps(g, g));
    g = _mm_add_ss(g, _mm_movehdup_ps(g));
    sumf += (ggml_float)_mm_cvtss_f32(g);

#undef LOAD
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfbfwma)
    size_t vl = __riscv_vsetvlmax_e32m4();

    // initialize accumulators to all zeroes
    vfloat32m4_t vsum0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
    vfloat32m4_t vsum1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);

    // calculate step size
    const size_t epr = __riscv_vsetvlmax_e16m2();
    const size_t step = epr * 2;
    const int np = (n & ~(step - 1));

    // unroll by 2
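    // the empty inline asm statements in the loop below act as compiler barriers;
    // presumably they keep the compiler from merging or reordering the two load/FMA
    // groups, preserving the intended 2-way unrolled schedule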
    for (; i < np; i += step) {
        vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i], epr);
        vbfloat16m2_t ay0 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i], epr);
        vsum0 = __riscv_vfwmaccbf16_vv_f32m4(vsum0, ax0, ay0, epr);
        __asm__ __volatile__ ("" ::: "memory");

        vbfloat16m2_t ax1 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i + epr], epr);
        vbfloat16m2_t ay1 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i + epr], epr);
        vsum1 = __riscv_vfwmaccbf16_vv_f32m4(vsum1, ax1, ay1, epr);
        __asm__ __volatile__ ("" ::: "memory");
    }

    // accumulate in 1 register
    vsum0 = __riscv_vfadd_vv_f32m4(vsum0, vsum1, vl);

    // leftovers
    for (i = np; i < n; i += vl) {
        vl = __riscv_vsetvl_e16m2(n - i);
        vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i], vl);
        vbfloat16m2_t ay0 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i], vl);
        vsum0 = __riscv_vfwmaccbf16_vv_f32m4(vsum0, ax0, ay0, vl);
    }

    // reduce
    vl = __riscv_vsetvlmax_e32m4();
    vfloat32m1_t redsum = __riscv_vfredusum_vs_f32m4_f32m1(vsum0, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
    sumf += __riscv_vfmv_f_s_f32m1_f32(redsum);

#endif
#if defined(__POWER9_VECTOR__)
    const int np = (n & ~(GGML_BF16_STEP - 1));
    if (np > 0) {
        GGML_F32_VEC sum[4] = {GGML_F32_VEC_ZERO};
        for (; i < np; i += GGML_BF16_STEP) {
            GGML_BF16_VEC vx0 = GGML_BF16_VEC_LOAD(x + i);
            GGML_BF16_VEC vx1 = GGML_BF16_VEC_LOAD(x + i + 8);
            GGML_BF16_VEC vy0 = GGML_BF16_VEC_LOAD(y + i);
            GGML_BF16_VEC vy1 = GGML_BF16_VEC_LOAD(y + i + 8);
            GGML_BF16_FMA_LO(sum[0], vx0, vy0);
            GGML_BF16_FMA_HI(sum[1], vx0, vy0);
            GGML_BF16_FMA_LO(sum[2], vx1, vy1);
            GGML_BF16_FMA_HI(sum[3], vx1, vy1);
        }
        GGML_F32x4_REDUCE_4(sumf, sum[0], sum[1], sum[2], sum[3]);
    }
#endif

    for (; i < n; ++i) {
        sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
                             GGML_BF16_TO_FP32(y[i]));
    }
    *s = sumf;
}

void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc) {
    assert(nrc == 1);
    GGML_UNUSED(nrc);
    GGML_UNUSED(bx);
    GGML_UNUSED(by);
    GGML_UNUSED(bs);

    ggml_float sumf = 0.0;

#if defined(GGML_SIMD)
    #if defined(__ARM_FEATURE_SVE)
        const int sve_register_length = svcntb() * 8; // SVE vector length in bits
        const int ggml_f16_epr = sve_register_length / 16; // 16-bit elements per register
        const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers

        const int np = (n & ~(ggml_f16_step - 1));
        svfloat16_t sum1 = svdup_n_f16(0.0f);
        svfloat16_t sum2 = svdup_n_f16(0.0f);
        svfloat16_t sum3 = svdup_n_f16(0.0f);
        svfloat16_t sum4 = svdup_n_f16(0.0f);

        svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
        svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
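        // only four f16 accumulators are used: the second half of each unrolled
        // iteration below reuses sum1..sum4, so eight register pairs are loaded per
        // step while the accumulator dependency chains stay four deep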
        for (int i = 0; i < np; i += ggml_f16_step) {
            ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
            ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
            sum1 = GGML_F16x_VEC_FMA(sum1, ax1, ay1);

            ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
            ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
            sum2 = GGML_F16x_VEC_FMA(sum2, ax2, ay2);

            ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
            ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
            sum3 = GGML_F16x_VEC_FMA(sum3, ax3, ay3);

            ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
            ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
            sum4 = GGML_F16x_VEC_FMA(sum4, ax4, ay4);

            ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
            ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
            sum1 = GGML_F16x_VEC_FMA(sum1, ax5, ay5);

            ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
            ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
            sum2 = GGML_F16x_VEC_FMA(sum2, ax6, ay6);

            ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
            ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
            sum3 = GGML_F16x_VEC_FMA(sum3, ax7, ay7);

            ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
            ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
            sum4 = GGML_F16x_VEC_FMA(sum4, ax8, ay8);
        }

        const int np2 = (n & ~(ggml_f16_epr - 1)); // round down to a multiple of ggml_f16_epr
        for (int k = np; k < np2; k += ggml_f16_epr) {
            svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
            svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
            sum1 = GGML_F16x_VEC_FMA(sum1, rx, ry);
        }

        if (np2 < n) {
            svbool_t pg = svwhilelt_b16(np2, n);
            svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
            svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));

            sum1 = svmad_f16_x(pg, hx, hy, sum1);
        }
        GGML_F16x_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4);
    #elif defined(__riscv_v_intrinsic)
        #if defined(__riscv_zvfh)
            int vl = __riscv_vsetvlmax_e32m2();
            vfloat32m1_t vs = __riscv_vfmv_v_f_f32m1(0.0f, 1);
            vfloat32m2_t vsum;
            vfloat16m1_t ax;
            vfloat16m1_t ay;
            vsum = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vmv_v_x_u32m2(0, vl));
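            // vfwmacc widens each f16 product to f32 before accumulating, keeping the
            // running sum in f32 precision; the tail-undisturbed (_tu) ops leave the
            // inactive tail lanes untouched on the final, shorter iteration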
            for (int i = 0; i < n; i += vl) {
                vl = __riscv_vsetvl_e16m1(n - i);
                ax = __riscv_vle16_v_f16m1_tu(ax, (const _Float16 *)&x[i], vl);
                ay = __riscv_vle16_v_f16m1_tu(ay, (const _Float16 *)&y[i], vl);
                vsum = __riscv_vfwmacc_vv_f32m2_tu(vsum, ax, ay, vl);
            }
            vl = __riscv_vsetvlmax_e32m1();
            vfloat32m1_t ac0 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(vsum, 0), __riscv_vget_v_f32m2_f32m1(vsum, 1), vl);
            vs = __riscv_vfredusum_vs_f32m1_f32m1(ac0, vs, vl);
            sumf += __riscv_vfmv_f_s_f32m1_f32(vs);
        #else
            for (int i = 0; i < n; ++i) {
                sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
            }
        #endif // __riscv_zvfh
    #else
        const int np = (n & ~(GGML_F16_STEP - 1));

        GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };

        GGML_F16_VEC ax[GGML_F16_ARR];
        GGML_F16_VEC ay[GGML_F16_ARR];

        for (int i = 0; i < np; i += GGML_F16_STEP) {
            for (int j = 0; j < GGML_F16_ARR; j++) {
                ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
                ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);

                sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
            }
        }

        // reduce the partial sums into sumf
        GGML_F16_VEC_REDUCE(sumf, sum);

        // leftovers
        for (int i = np; i < n; ++i) {
            sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
        }
        // if you hit this, you are likely running outside the FP range
        assert(!isnan(sumf) && !isinf(sumf));
    #endif
#else
    for (int i = 0; i < n; ++i) {
        sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
    }
#endif // GGML_SIMD

    *s = sumf;
}

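// silu(x) = x * sigmoid(x) = x / (1 + exp(-x)); the SIMD paths below use the
// ggml_v_silu* helpers (see vec.h), the scalar tail falls back to ggml_silu_f32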
void ggml_vec_silu_f32(const int n, float * y, const float * x) {
    int i = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
    for (; i + 15 < n; i += 16) {
        _mm512_storeu_ps(y + i, ggml_v_silu(_mm512_loadu_ps(x + i)));
    }
#elif defined(__AVX2__) && defined(__FMA__)
    for (; i + 7 < n; i += 8) {
        _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i)));
    }
#elif defined(__SSE2__)
    for (; i + 3 < n; i += 4) {
        _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
    }
#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
    const int vlen = svcntw();
    for (; i < n; i += vlen) {
        const svbool_t pg = svwhilelt_b32_s32(i, n);
        svst1_f32(pg, y + i, ggml_v_silu(pg, svld1_f32(pg, x + i)));
    }
#elif defined(__ARM_NEON) && defined(__aarch64__)
    for (; i + 3 < n; i += 4) {
        vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
    }
#elif defined(__riscv_v_intrinsic)
    for (int vl; i < n; i += vl) {
        vl = __riscv_vsetvl_e32m2(n - i);
        vfloat32m2_t vx = __riscv_vle32_v_f32m2(&x[i], vl);
        vfloat32m2_t vy = ggml_v_silu_m2(vx, vl);
        __riscv_vse32_v_f32m2(&y[i], vy, vl);
    }
#endif
    for (; i < n; ++i) {
        y[i] = ggml_silu_f32(x[i]);
    }
}

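// swiglu: y[i] = silu(x[i]) * g[i], i.e. the silu-gated product used by SwiGLU FFN blocks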
void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g) {
    int i = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
    for (; i + 15 < n; i += 16) {
        _mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(g + i)));
    }
#elif defined(__AVX2__) && defined(__FMA__)
    for (; i + 7 < n; i += 8) {
        _mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(g + i)));
    }
#elif defined(__SSE2__)
    for (; i + 3 < n; i += 4) {
        _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i)));
    }
#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
    const int vlen = svcntw();
    for (; i < n; i += vlen) {
        const svbool_t pg = svwhilelt_b32_s32(i, n);
        svst1_f32(pg, y + i, svmul_f32_x(pg, ggml_v_silu(pg, svld1_f32(pg, x + i)), svld1_f32(pg, g + i)));
    }
#elif defined(__ARM_NEON) && defined(__aarch64__)
    for (; i + 3 < n; i += 4) {
        vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i)));
    }
#elif defined(__riscv_v_intrinsic)
    for (int vl; i < n; i += vl) {
        vl = __riscv_vsetvl_e32m2(n - i);
        vfloat32m2_t vx = __riscv_vle32_v_f32m2(&x[i], vl);
        vfloat32m2_t vg = __riscv_vle32_v_f32m2(&g[i], vl);
        vfloat32m2_t vy = __riscv_vfmul_vv_f32m2(ggml_v_silu_m2(vx, vl), vg, vl);
        __riscv_vse32_v_f32m2(&y[i], vy, vl);
    }
#endif
    for (; i < n; ++i) {
        y[i] = ggml_silu_f32(x[i]) * g[i];
    }
}

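// writes the centered values y[i] = x[i] - mean and returns the variance
// sum((x[i] - mean)^2) / n over the input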
ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean) {
    int i = 0;
    ggml_float sum = 0;
// TODO: optimize to process the remaining elements in groups using the smaller vector sizes from AVX2 and SSE
// ref: https://github.com/ggml-org/llama.cpp/pull/15953#pullrequestreview-3310928344
#if defined(__AVX512F__) && defined(__AVX512DQ__)
    for (; i + 15 < n; i += 16) {
        __m512 val = _mm512_sub_ps(_mm512_loadu_ps(x + i),
                                   _mm512_set1_ps(mean));
        _mm512_storeu_ps(y + i, val);
        sum += (ggml_float)_mm512_reduce_add_ps(_mm512_mul_ps(val, val));
    }
#elif defined(__AVX2__) && defined(__FMA__)
    for (; i + 7 < n; i += 8) {
        __m256 val = _mm256_sub_ps(_mm256_loadu_ps(x + i),
                                   _mm256_set1_ps(mean));
        _mm256_storeu_ps(y + i, val);
        val = _mm256_mul_ps(val, val);
        __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
                                 _mm256_castps256_ps128(val));
        val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
        val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
        sum += (ggml_float)_mm_cvtss_f32(val2);
    }
#elif defined(__SSE2__)
    for (; i + 3 < n; i += 4) {
        __m128 val = _mm_sub_ps(_mm_loadu_ps(x + i),
                                _mm_set1_ps(mean));
        _mm_storeu_ps(y + i, val);
        val = _mm_mul_ps(val, val);
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
        val = _mm_add_ps(val, _mm_movehl_ps(val, val));
        val = _mm_add_ss(val, _mm_movehdup_ps(val));
#else
        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
        val = _mm_add_ps(val, tmp);
        tmp = _mm_movehl_ps(tmp, val);
        val = _mm_add_ss(val, tmp);
#endif  // __AVX__ || __AVX2__ || __AVX512F__
        sum += (ggml_float)_mm_cvtss_f32(val);
    }
#elif defined(__ARM_NEON) && defined(__aarch64__)
    for (; i + 3 < n; i += 4) {
        float32x4_t val = vsubq_f32(vld1q_f32(x + i),
                                    vdupq_n_f32(mean));
        vst1q_f32(y + i, val);
        val = vmulq_f32(val, val);
        sum += (ggml_float)vaddvq_f32(val);
    }
#elif defined(__VXE__) || defined(__VXE2__)
    for (; i + 3 < n; i += 4) {
        float32x4_t val = vec_sub(vec_xl(0, x + i), vec_splats(mean));
        vec_xst(val, 0, y + i);
        val = vec_mul(val, val);
        sum += (ggml_float)vec_hsum_f32x4(val);
    }
#elif defined(__riscv_v_intrinsic)
    vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1);
    for (int vl; i < n; i += vl) {
        vl = __riscv_vsetvl_e32m2(n - i);
        vfloat32m2_t val = __riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], vl), mean, vl);
        __riscv_vse32_v_f32m2(&y[i], val, vl);
        val = __riscv_vfmul_vv_f32m2(val, val, vl);
        vsum = __riscv_vfwredusum_vs_f32m2_f64m1(val, vsum, vl);
    }
    sum = (ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum);
#endif
    for (; i < n; ++i) {
        float val = x[i] - mean;
        y[i] = val;
        val *= val;
        sum += (ggml_float)val;
    }
    return sum/n;
}

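// computes y[i] = exp(x[i] - max) and returns the sum of the exponentials; subtracting
// the pre-computed max keeps expf() from overflowing, and the caller normalizes by the
// returned sum to finish the softmax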
ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
    int i = 0;
    ggml_float sum = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
    for (; i + 15 < n; i += 16) {
        __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
                                               _mm512_set1_ps(max)));
        _mm512_storeu_ps(y + i, val);
        sum += (ggml_float)_mm512_reduce_add_ps(val);
    }
#elif defined(__AVX2__) && defined(__FMA__)
    for (; i + 7 < n; i += 8) {
        __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
                                               _mm256_set1_ps(max)));
        _mm256_storeu_ps(y + i, val);
        __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
                                 _mm256_castps256_ps128(val));
        val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
        val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
        sum += (ggml_float)_mm_cvtss_f32(val2);
    }
#elif defined(__SSE2__)
    for (; i + 3 < n; i += 4) {
        __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
                                            _mm_set1_ps(max)));
        _mm_storeu_ps(y + i, val);
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
        val = _mm_add_ps(val, _mm_movehl_ps(val, val));
        val = _mm_add_ss(val, _mm_movehdup_ps(val));
#else
        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
        val = _mm_add_ps(val, tmp);
        tmp = _mm_movehl_ps(tmp, val);
        val = _mm_add_ss(val, tmp);
#endif
        sum += (ggml_float)_mm_cvtss_f32(val);
    }
#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
    const int vlen = svcntw();
    for (; i < n; i += vlen) {
        const svbool_t pg = svwhilelt_b32_s32(i, n);
        svfloat32_t val = ggml_v_expf(pg, svsub_f32_x(pg, svld1_f32(pg, x + i),
                                                svdup_n_f32_x(pg, max)));
        svst1_f32(pg, y + i, val);
        sum += (ggml_float)svaddv_f32(pg, val);
    }
#elif defined(__ARM_NEON) && defined(__aarch64__)
    for (; i + 3 < n; i += 4) {
        float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
                                                vdupq_n_f32(max)));
        vst1q_f32(y + i, val);
        sum += (ggml_float)vaddvq_f32(val);
    }
#elif defined(__riscv_v_intrinsic)
    vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1);
    for (int avl; i < n; i += avl) {
        avl = __riscv_vsetvl_e32m2(n - i);
        vfloat32m2_t val = ggml_v_expf_m2(__riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], avl), max, avl), avl);
        __riscv_vse32_v_f32m2(&y[i], val, avl);
        vsum = __riscv_vfwredusum_vs_f32m2_f64m1(val, vsum, avl);
    }
    return (ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum);
#endif
    for (; i < n; ++i) {
        float val = expf(x[i] - max);
        sum += (ggml_float)val;
        y[i] = val;
    }
    return sum;
}

ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) {
    // log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_sum)

    int i = 0;
    ggml_float sum = 0;
    for (; i < n; ++i) {
        float val = x[i] - max;
        y[i] = val;
        sum += (ggml_float)expf(val);
    }
    return (ggml_float)logf(sum);
}