// Vectorized functions for fundamental operations

#pragma once

#include "ggml-impl.h"
#include "simd-mappings.h"
#include "ggml.h"
#include "ggml-cpu.h"

#if defined(GGML_USE_ACCELERATE)
#include <Accelerate/Accelerate.h>
#endif

// floating point type used to accumulate sums
typedef double ggml_float;

#define GGML_GELU_FP16
#define GGML_GELU_QUICK_FP16

#define GGML_SOFT_MAX_UNROLL 4
#define GGML_VEC_DOT_UNROLL  2
#define GGML_VEC_MAD_UNROLL  32

#ifdef __cplusplus
extern "C" {
#endif

//
// global data
//

// precomputed gelu table for f16 (128 KB)
extern ggml_fp16_t ggml_table_gelu_f16[1 << 16];

// precomputed quick gelu table for f16 (128 KB)
extern ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
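
// The two tables above are indexed by the raw 16-bit pattern of an fp16
// value, so one 64 Ki-entry table covers the whole fp16 domain and a lookup
// is just a reinterpret + array access. Illustrative sketch (assumes the
// tables have already been filled, which happens once at CPU backend init):
//
//     static ggml_fp16_t gelu_lookup_f16(ggml_fp16_t h) {
//         uint16_t bits;
//         memcpy(&bits, &h, sizeof(bits));
//         return ggml_table_gelu_f16[bits];
//     }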

//
// fundamental operations
//

void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc);
void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc);
void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc);
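
// Usage sketch (illustrative only; the buffers and sizes are made up). In the
// common single-row case nrc == 1 and the byte strides bs/bx/by are
// effectively unused:
//
//     float s;
//     float a[128], b[128];
//     // ... fill a and b ...
//     ggml_vec_dot_f32(128, &s, 0, a, 0, b, 0, 1);  // s = sum_i a[i]*b[i]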

void ggml_vec_silu_f32(const int n, float * y, const float * x);
ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean); // note: also centers y, i.e. y[i] -= mean
ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max);
ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max);
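
// ggml_vec_soft_max_f32 writes y[i] = expf(x[i] - max) and returns the sum of
// those exponentials, so the caller normalizes in a second pass. Sketch under
// that contract (illustrative only):
//
//     float row[256];
//     // ... fill row, compute max over row ...
//     const ggml_float sum = ggml_vec_soft_max_f32(256, row, row, max);
//     ggml_vec_scale_f32(256, row, 1.0f/(float)sum);  // defined below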

inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }

inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t   v) { for (int i = 0; i < n; ++i) x[i] = v;    }
inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }

inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const ggml_fp16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }

inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) {
    int i = 0;
#if defined(__AVX2__)
    for (; i + 7 < n; i += 8) {
        __m256 vx = _mm256_loadu_ps(x + i);
        __m256 vy = _mm256_loadu_ps(y + i);
        __m256 vz = _mm256_add_ps(vx, vy);
        _mm256_storeu_ps(z + i, vz);
    }
#endif
    for (; i < n; ++i) {
        z[i] = x[i] + y[i];
    }
}

inline static void ggml_vec_add_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
    for (int i = 0; i < n; ++i) {
        z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) + GGML_CPU_FP16_TO_FP32(y[i]));
    }
}
inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float   v) { for (int i = 0; i < n; ++i) z[i]  = x[i] + v;    }
inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i] += x[i];        }
inline static void ggml_vec_acc1_f32(const int n, float * y, const float   v)                  { for (int i = 0; i < n; ++i) y[i] += v;           }
inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] - y[i]; }
inline static void ggml_vec_sub_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
    for (int i = 0; i < n; ++i) {
        z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) - GGML_CPU_FP16_TO_FP32(y[i]));
    }
}
inline static void ggml_vec_set_f32 (const int n, float * x, const float   v)                  { for (int i = 0; i < n; ++i) x[i]  = v;           }
inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = x[i];        }
inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = -x[i];       }
inline static void ggml_vec_neg_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(-GGML_CPU_FP16_TO_FP32(x[i]));
    }
}

inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]*y[i];   }
inline static void ggml_vec_mul_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
    for (int i = 0; i < n; ++i) {
        z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) * GGML_CPU_FP16_TO_FP32(y[i]));
    }
}
inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]/y[i];   }
inline static void ggml_vec_div_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
    for (int i = 0; i < n; ++i) {
        z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) / GGML_CPU_FP16_TO_FP32(y[i]));
    }
}

// compute GGML_VEC_DOT_UNROLL dot products at once
// xs - x row stride in bytes
inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GGML_RESTRICT s, void * GGML_RESTRICT xv, ggml_fp16_t * GGML_RESTRICT y) {
    ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };

    ggml_fp16_t * GGML_RESTRICT x[GGML_VEC_DOT_UNROLL];

    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
        x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
    }

#if defined(GGML_SIMD)
    #if defined(__ARM_FEATURE_SVE)

        const int sve_register_length = svcntb() * 8;
        const int ggml_f16_epr = sve_register_length / 16; // fp16 elements per register
        const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers

        const int np = (n & ~(ggml_f16_step - 1));

        svfloat16_t sum_00 = svdup_n_f16(0.0f);
        svfloat16_t sum_01 = svdup_n_f16(0.0f);
        svfloat16_t sum_02 = svdup_n_f16(0.0f);
        svfloat16_t sum_03 = svdup_n_f16(0.0f);

        svfloat16_t sum_10 = svdup_n_f16(0.0f);
        svfloat16_t sum_11 = svdup_n_f16(0.0f);
        svfloat16_t sum_12 = svdup_n_f16(0.0f);
        svfloat16_t sum_13 = svdup_n_f16(0.0f);

        svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
        svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;

        for (int i = 0; i < np; i += ggml_f16_step) {
            ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); // 8 elements

            ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elements
            sum_00 = GGML_F16x_VEC_FMA(sum_00, ax1, ay1);     // sum_00 = sum_00+ax1*ay1
            ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 0*ggml_f16_epr, 0); // 8 elements
            sum_10 = GGML_F16x_VEC_FMA(sum_10, ax1, ay1);

            ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); // next 8 elements

            ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 elements
            sum_01 = GGML_F16x_VEC_FMA(sum_01, ax2, ay2);
            ax2 = GGML_F16x_VEC_LOAD(x[1] + i + 1*ggml_f16_epr, 1);
            sum_11 = GGML_F16x_VEC_FMA(sum_11, ax2, ay2);

            ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);

            ax3 = GGML_F16x_VEC_LOAD(x[0] + i + 2*ggml_f16_epr, 2);
            sum_02 = GGML_F16x_VEC_FMA(sum_02, ax3, ay3);
            ax3 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2);
            sum_12 = GGML_F16x_VEC_FMA(sum_12, ax3, ay3);

            ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);

            ax4 = GGML_F16x_VEC_LOAD(x[0] + i + 3*ggml_f16_epr, 3);
            sum_03 = GGML_F16x_VEC_FMA(sum_03, ax4, ay4);
            ax4 = GGML_F16x_VEC_LOAD(x[1] + i + 3*ggml_f16_epr, 3);
            sum_13 = GGML_F16x_VEC_FMA(sum_13, ax4, ay4);

            ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);

            ax5 = GGML_F16x_VEC_LOAD(x[0] + i + 4*ggml_f16_epr, 4);

            sum_00 = GGML_F16x_VEC_FMA(sum_00, ax5, ay5);
            ax5 = GGML_F16x_VEC_LOAD(x[1] + i + 4*ggml_f16_epr, 4);
            sum_10 = GGML_F16x_VEC_FMA(sum_10, ax5, ay5);

            ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);

            ax6 = GGML_F16x_VEC_LOAD(x[0] + i + 5*ggml_f16_epr, 5);

            sum_01 = GGML_F16x_VEC_FMA(sum_01, ax6, ay6);
            ax6 = GGML_F16x_VEC_LOAD(x[1] + i + 5*ggml_f16_epr, 5);
            sum_11 = GGML_F16x_VEC_FMA(sum_11, ax6, ay6);

            ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);

            ax7 = GGML_F16x_VEC_LOAD(x[0] + i + 6*ggml_f16_epr, 6);

            sum_02 = GGML_F16x_VEC_FMA(sum_02, ax7, ay7);
            ax7 = GGML_F16x_VEC_LOAD(x[1] + i + 6*ggml_f16_epr, 6);
            sum_12 = GGML_F16x_VEC_FMA(sum_12, ax7, ay7);

            ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);

            ax8 = GGML_F16x_VEC_LOAD(x[0] + i + 7*ggml_f16_epr, 7);

            sum_03 = GGML_F16x_VEC_FMA(sum_03, ax8, ay8);
            ax8 = GGML_F16x_VEC_LOAD(x[1] + i + 7*ggml_f16_epr, 7);
            sum_13 = GGML_F16x_VEC_FMA(sum_13, ax8, ay8);
        }

        const int np2 = (n & ~(ggml_f16_epr - 1));
        for (int k = np; k < np2; k += ggml_f16_epr) {
            svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);

            svfloat16_t rx = GGML_F16x_VEC_LOAD(x[0] + k, 0);
            sum_00 = GGML_F16x_VEC_FMA(sum_00, rx, ry);
            rx = GGML_F16x_VEC_LOAD(x[1] + k, 0);
            sum_10 = GGML_F16x_VEC_FMA(sum_10, rx, ry);
        }

        if (np2 < n) {
            svbool_t pg = svwhilelt_b16(np2, n);
            svfloat16_t hx_0 = svld1_f16(pg, (const __fp16 *)(x[0] + np2));
            svfloat16_t hx_1 = svld1_f16(pg, (const __fp16 *)(x[1] + np2));
            svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));

            sum_00 = svmad_f16_x(pg, hx_0, hy, sum_00);
            sum_10 = svmad_f16_x(pg, hx_1, hy, sum_10);
        }
        GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03);
        GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13);

    #elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
        size_t vl = __riscv_vsetvlmax_e32m4();

        // initialize accumulators to all zeroes
        vfloat32m4_t vsum0_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
        vfloat32m4_t vsum0_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
        vfloat32m4_t vsum1_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
        vfloat32m4_t vsum1_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);

        // calculate step size
        const size_t epr = __riscv_vsetvlmax_e16m2();
        const size_t step = epr * 2;
        const int np = (n & ~(step - 1));

        // unroll by 2 along the row dimension
        for (int i = 0; i < np; i += step) {
            vfloat16m2_t ay0 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), epr);
            vfloat16m2_t ax0_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), epr);
            vfloat16m2_t ax1_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), epr);
            vsum0_0 = __riscv_vfwmacc_vv_f32m4(vsum0_0, ax0_0, ay0, epr);
            vsum1_0 = __riscv_vfwmacc_vv_f32m4(vsum1_0, ax1_0, ay0, epr);

            vfloat16m2_t ay1 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i + epr), epr);
            vfloat16m2_t ax0_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i + epr), epr);
            vfloat16m2_t ax1_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i + epr), epr);
            vsum0_1 = __riscv_vfwmacc_vv_f32m4(vsum0_1, ax0_1, ay1, epr);
            vsum1_1 = __riscv_vfwmacc_vv_f32m4(vsum1_1, ax1_1, ay1, epr);
        }

        vfloat32m4_t vsum0 = __riscv_vfadd_vv_f32m4(vsum0_0, vsum0_1, vl);
        vfloat32m4_t vsum1 = __riscv_vfadd_vv_f32m4(vsum1_0, vsum1_1, vl);

        // leftovers
        for (int i = np; i < n; i += vl) {
            vl = __riscv_vsetvl_e16m2(n - i);
            vfloat16m2_t ay = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), vl);
            vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), vl);
            vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), vl);

            vsum0 = __riscv_vfwmacc_vv_f32m4(vsum0, ax0, ay, vl);
            vsum1 = __riscv_vfwmacc_vv_f32m4(vsum1, ax1, ay, vl);
        }

        // reduce
        vl = __riscv_vsetvlmax_e32m2();
        vfloat32m2_t acc0_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum0, 0),
                                    __riscv_vget_v_f32m4_f32m2(vsum0, 1), vl);
        vl = __riscv_vsetvlmax_e32m1();
        vfloat32m1_t acc0_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc0_0, 0),
                                    __riscv_vget_v_f32m2_f32m1(acc0_0, 1), vl);
        vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1(
                                    acc0_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);

        vl = __riscv_vsetvlmax_e32m2();
        vfloat32m2_t acc1_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum1, 0),
                                    __riscv_vget_v_f32m4_f32m2(vsum1, 1), vl);
        vl = __riscv_vsetvlmax_e32m1();
        vfloat32m1_t acc1_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc1_0, 0),
                                    __riscv_vget_v_f32m2_f32m1(acc1_0, 1), vl);
        vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1(
                                    acc1_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
        sumf[0] = __riscv_vfmv_f_s_f32m1_f32(redsum0);
        sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1);

    #else
        const int np = (n & ~(GGML_F16_STEP - 1));

        GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };

        GGML_F16_VEC ax[GGML_F16_ARR];
        GGML_F16_VEC ay[GGML_F16_ARR];

        for (int i = 0; i < np; i += GGML_F16_STEP) {
            for (int j = 0; j < GGML_F16_ARR; j++) {
                ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);

                for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
                    ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);

                    sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
                }
            }
        }

        // reduce sum0..sum3 to sum0
        for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
            GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
        }

        // leftovers
        for (int i = np; i < n; ++i) {
            for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
                sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
            }
        }
    #endif
#else
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
            sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
        }
    }
#endif

    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
        s[i] = (float)sumf[i];
    }
}
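
// Usage sketch for ggml_vec_dot_f16_unroll (illustrative only): two fp16 rows
// stored back to back are dotted against the same y in one call; xs is the
// row stride in bytes:
//
//     float s[GGML_VEC_DOT_UNROLL];
//     ggml_fp16_t x[GGML_VEC_DOT_UNROLL][128];
//     ggml_fp16_t y[128];
//     // ... fill x and y ...
//     ggml_vec_dot_f16_unroll(128, sizeof(x[0]), s, x, y);  // s[k] = dot(x[k], y)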

inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const float * GGML_RESTRICT x, const float v) {
#if defined(GGML_SIMD)
    #if defined(__ARM_FEATURE_SVE)

        const int sve_register_length = ggml_cpu_get_sve_cnt() * 8;
        const int ggml_f32_epr = sve_register_length / 32; // f32 elements per register (SVE128: 4, SVE256: 8, SVE512: 16)
        const int ggml_f32_step = 8 * ggml_f32_epr; // choose 8 SVE registers
        GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);

        const int np = (n & ~(ggml_f32_step - 1));
        svfloat32_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
        svfloat32_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
        for (int i = 0; i < np; i += ggml_f32_step) {

            ax1 = GGML_F32_VEC_LOAD(x + i);
            ay1 = GGML_F32_VEC_LOAD(y + i);
            ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);

            GGML_F32_VEC_STORE(y + i, ay1);

            ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
            ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
            ay2 = GGML_F32_VEC_FMA(ay2, ax2, vx);

            GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);

            ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
            ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
            ay3 = GGML_F32_VEC_FMA(ay3, ax3, vx);

            GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3);

            ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
            ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
            ay4 = GGML_F32_VEC_FMA(ay4, ax4, vx);

            GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4);

            ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
            ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
            ay5 = GGML_F32_VEC_FMA(ay5, ax5, vx);

            GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5);

            ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
            ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
            ay6 = GGML_F32_VEC_FMA(ay6, ax6, vx);

            GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6);

            ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
            ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
            ay7 = GGML_F32_VEC_FMA(ay7, ax7, vx);

            GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7);

            ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
            ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
            ay8 = GGML_F32_VEC_FMA(ay8, ax8, vx);

            GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8);
        }
        // leftovers
        // since the loop above is unrolled 8x, the leftovers lie in the range [0, ggml_f32_step), handled by the loop below
        const int np2 = (n & ~(ggml_f32_epr - 1));
        for (int i = np; i < np2; i += ggml_f32_epr) {
            ax1 = GGML_F32_VEC_LOAD(x + i);
            ay1 = GGML_F32_VEC_LOAD(y + i);
            ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);

            GGML_F32_VEC_STORE(y + i, ay1);
        }
        // maximum number of leftover elements will be less than ggml_f32_epr; apply predicated svmad on the available elements only
        if (np2 < n) {
            svbool_t pg = svwhilelt_b32(np2, n);
            ax1 = svld1_f32(pg, x + np2);
            ay1 = svld1_f32(pg, y + np2);
            ay1 = svmad_f32_m(pg, ax1, vx, ay1);

            svst1_f32(pg, y + np2, ay1);
        }
    #elif defined(__riscv_v_intrinsic)
        for (int i = 0, avl; i < n; i += avl) {
            avl = __riscv_vsetvl_e32m8(n - i);
            vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);
            vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
            vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, v, ay, avl);
            __riscv_vse32_v_f32m8(&y[i], ny, avl);
        }
    #else
        const int np = (n & ~(GGML_F32_STEP - 1));

        GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);

        GGML_F32_VEC ax[GGML_F32_ARR];
        GGML_F32_VEC ay[GGML_F32_ARR];

        for (int i = 0; i < np; i += GGML_F32_STEP) {
            for (int j = 0; j < GGML_F32_ARR; j++) {
                ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
                ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
                ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);

                GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
            }
        }

        // leftovers
        for (int i = np; i < n; ++i) {
            y[i] += x[i]*v;
        }
    #endif
#else
    // scalar
    for (int i = 0; i < n; ++i) {
        y[i] += x[i]*v;
    }
#endif
}
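
// ggml_vec_mad_f32 is an axpy: y[i] += x[i]*v. Usage sketch (illustrative):
//
//     float acc[64] = {0};
//     float row[64];
//     // ... fill row ...
//     ggml_vec_mad_f32(64, acc, row, 0.5f);  // acc += 0.5f*row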

inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {
#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)
    const int sve_register_length = svcntb() * 8;
    const int ggml_f16_epr = sve_register_length / 16;
    const int ggml_f16_step = 8 * ggml_f16_epr;

    GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);

    int np = (n & ~(ggml_f16_step - 1));

    svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
    svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
    for (int i = 0; i < np; i += ggml_f16_step) {
        ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
        ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
        ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx);

        GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0);

        ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
        ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
        ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx);

        GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1);

        ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
        ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
        ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx);

        GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2);

        ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
        ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
        ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx);

        GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3);

        ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
        ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
        ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx);

        GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4);

        ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
        ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
        ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx);

        GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5);

        ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
        ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
        ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx);

        GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6);

        ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
        ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
        ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx);

        GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7);
    }
    const int np2 = (n & ~(ggml_f16_epr - 1));
    for (int k = np; k < np2; k += ggml_f16_epr) {
        svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
        svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
        ry = GGML_F16x_VEC_FMA(ry, rx, vx);

        GGML_F16x_VEC_STORE(y + k, ry, 0);
    }

    if (np2 < n) {
        svbool_t pg = svwhilelt_b16(np2, n);
        svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
        svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
        hy = svmad_f16_x(pg, hx, vx, hy);
        svst1_f16(pg, (__fp16 *)(y + np2), hy);
    }
    np = n;
#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic
    const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
    const _Float16 scale = *(const _Float16*)(&s);

    // calculate step size
    const int epr = __riscv_vsetvlmax_e16m4();
    const int step = epr * 2;
    int np = (n & ~(step - 1));

    // unroll by 2
    for (int i = 0; i < np; i += step) {
        vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr);
        vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
        ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr);
        __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
        __asm__ __volatile__ ("" ::: "memory");

        vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr);
        vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
        ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr);
        __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
        __asm__ __volatile__ ("" ::: "memory");
    }

    // leftovers
    int vl;
    for (int i = np; i < n; i += vl) {
        vl = __riscv_vsetvl_e16m4(n - i);
        vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, vl);
        vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
        ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl);
        __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
    }
    np = n;
#elif defined(GGML_SIMD)
    const int np = (n & ~(GGML_F16_STEP - 1));

    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);

    GGML_F16_VEC ax[GGML_F16_ARR];
    GGML_F16_VEC ay[GGML_F16_ARR];

    for (int i = 0; i < np; i += GGML_F16_STEP) {
        for (int j = 0; j < GGML_F16_ARR; j++) {
            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
            ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);

            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
        }
    }
#else
    const int np = 0;
#endif

    // leftovers
    for (int i = np; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
    }
}

// xs and vs are byte strides of x and v
inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * GGML_RESTRICT y, const float * GGML_RESTRICT xv, const float * GGML_RESTRICT vv) {

    const float * GGML_RESTRICT x[GGML_VEC_MAD_UNROLL];
    const float * GGML_RESTRICT v[GGML_VEC_MAD_UNROLL];

    for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) {
        x[i] = (const float *) ((const char *) xv + i*xs);
        v[i] = (const float *) ((const char *) vv + i*vs);
    }

#if defined(GGML_SIMD)
    #if defined(__ARM_FEATURE_SVE)
        // fall back to the scalar implementation // TODO: write SVE code
        for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
            for (int i = 0; i < n; ++i) {
                y[i] += x[k][i]*v[k][0];
            }
        }
    #elif defined(__riscv_v_intrinsic)
        for (int i = 0, avl; i < n; i += avl) {
            avl = __riscv_vsetvl_e32m8(n - i);
            vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
            for (int k = 0; k < GGML_VEC_MAD_UNROLL; k++) {
                vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[k][i], avl);
                ay = __riscv_vfmadd_vf_f32m8(ax, v[k][0], ay, avl);
            }
            __riscv_vse32_v_f32m8(&y[i], ay, avl);
        }
    #else
        const int np = (n & ~(GGML_F32_STEP - 1));

        GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL];

        for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
            vx[k] = GGML_F32_VEC_SET1(v[k][0]);
        }

        GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR];
        GGML_F32_VEC ay[GGML_F32_ARR];

        for (int i = 0; i < np; i += GGML_F32_STEP) {
            for (int j = 0; j < GGML_F32_ARR; j++) {
                ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);

                for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
                    ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR);
                    ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);
                }

                GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
            }
        }

        // leftovers
        for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
            for (int i = np; i < n; ++i) {
                y[i] += x[k][i]*v[k][0];
            }
        }
    #endif
#else
    // scalar
    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
        for (int i = 0; i < n; ++i) {
            y[i] += x[k][i]*v[k][0];
        }
    }
#endif
}
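
// Usage sketch for ggml_vec_mad_f32_unroll (illustrative only): accumulate
// GGML_VEC_MAD_UNROLL scaled rows into y at once; xs and vs are byte strides,
// and only v[k][0] of each scale row is read:
//
//     float y[128] = {0};
//     float x[GGML_VEC_MAD_UNROLL][128];
//     float v[GGML_VEC_MAD_UNROLL];
//     // ... fill x and v ...
//     ggml_vec_mad_f32_unroll(128, sizeof(x[0]), sizeof(v[0]), y, x, v);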

inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, const float s, const float b) {
#if defined(GGML_USE_ACCELERATE)
    vDSP_vsmsa(x, 1, &s, &b, y, 1, n);
#elif defined(GGML_SIMD)
    #if defined(__ARM_FEATURE_SVE)
        // scalar; TODO: write SVE code
        for (int i = 0; i < n; ++i) {
            y[i] = x[i]*s + b;
        }
    #elif defined(__riscv_v_intrinsic)
        for (int i = 0, avl; i < n; i += avl) {
            avl = __riscv_vsetvl_e32m8(n - i);
            vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);
            vfloat32m8_t vb = __riscv_vfmv_v_f_f32m8(b, avl);
            vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, s, vb, avl);
            __riscv_vse32_v_f32m8(&y[i], ny, avl);
        }
    #else
        const int np = (n & ~(GGML_F32_STEP - 1));

        GGML_F32_VEC vs = GGML_F32_VEC_SET1(s);
        GGML_F32_VEC vb = GGML_F32_VEC_SET1(b);

        GGML_F32_VEC ay[GGML_F32_ARR];

        for (int i = 0; i < np; i += GGML_F32_STEP) {
            for (int j = 0; j < GGML_F32_ARR; j++) {
                ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
                ay[j] = GGML_F32_VEC_FMA(vb, ay[j], vs);

                GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
            }
        }

        // leftovers
        for (int i = np; i < n; ++i) {
            y[i] = x[i]*s + b;
        }
    #endif
#else
    // scalar
    for (int i = 0; i < n; ++i) {
        y[i] = x[i]*s + b;
    }
#endif
}
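
// ggml_vec_mad1_f32 is a scale-and-shift: y[i] = x[i]*s + b. Sketch
// (illustrative only):
//
//     float out[64], in[64];
//     // ... fill in ...
//     ggml_vec_mad1_f32(64, out, in, 2.0f, 1.0f);  // out = 2*in + 1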

//inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) { for (int i = 0; i < n; ++i) y[i] *= v;          }
inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) {
#if defined(GGML_USE_ACCELERATE)
    vDSP_vsmul(y, 1, &v, y, 1, n);
#elif defined(GGML_SIMD)
    #if defined(__ARM_FEATURE_SVE)
        const int sve_register_length = ggml_cpu_get_sve_cnt() * 8;
        const int ggml_f32_epr = sve_register_length / 32; // f32 elements per register (SVE128: 4, SVE256: 8, SVE512: 16)
        const int ggml_f32_step = 2 * ggml_f32_epr;

        GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
        const int np = (n & ~(ggml_f32_step - 1));
        svfloat32_t ay1;
        svfloat32_t ay2;
        for (int i = 0; i < np; i += ggml_f32_step) {
            ay1 = GGML_F32_VEC_LOAD(y + i);
            ay1 = GGML_F32_VEC_MUL(ay1, vx);
            GGML_F32_VEC_STORE(y + i, ay1);

            ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
            ay2 = GGML_F32_VEC_MUL(ay2, vx);
            GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);
        }
        // leftovers
        // maximum number of leftover elements will be less than ggml_f32_step; apply predicated svmul on the available elements only
        for (int i = np; i < n; i += ggml_f32_epr) {
            svbool_t pg = svwhilelt_b32(i, n);
            ay1 = svld1_f32(pg, y + i);
            ay1 = svmul_f32_m(pg, ay1, vx);
            svst1_f32(pg, y + i, ay1);
        }
    #elif defined(__riscv_v_intrinsic)
        for (int i = 0, avl; i < n; i += avl) {
            avl = __riscv_vsetvl_e32m8(n - i);
            vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
            vfloat32m8_t ny = __riscv_vfmul_vf_f32m8(ay, v, avl);
            __riscv_vse32_v_f32m8(&y[i], ny, avl);
        }
    #else
        const int np = (n & ~(GGML_F32_STEP - 1));

        GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);

        GGML_F32_VEC ay[GGML_F32_ARR];

        for (int i = 0; i < np; i += GGML_F32_STEP) {
            for (int j = 0; j < GGML_F32_ARR; j++) {
                ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
                ay[j] = GGML_F32_VEC_MUL(ay[j], vx);

                GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
            }
        }

        // leftovers
        for (int i = np; i < n; ++i) {
            y[i] *= v;
        }
    #endif
#else
    // scalar
    for (int i = 0; i < n; ++i) {
        y[i] *= v;
    }
#endif
}
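
// ggml_vec_scale_f32 scales in place: y[i] *= v. Usage sketch (illustrative;
// a typical caller divides a row by a softmax denominator or a temperature):
//
//     float logits[4096];
//     // ... fill logits ...
//     ggml_vec_scale_f32(4096, logits, 1.0f/0.8f);  // e.g. temperature 0.8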

inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)
    const int sve_register_length = svcntb() * 8;
    const int ggml_f16_epr = sve_register_length / 16;
    const int ggml_f16_step = 2 * ggml_f16_epr;

    GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
    const int np = (n & ~(ggml_f16_step - 1));
    svfloat16_t ay1, ay2;

    for (int i = 0; i < np; i += ggml_f16_step) {
        ay1 = GGML_F16x_VEC_LOAD(y + i + 0*ggml_f16_epr, 0);
        ay1 = GGML_F16x_VEC_MUL(ay1, vx);
        GGML_F16x_VEC_STORE(y + i + 0*ggml_f16_epr, ay1, 0);

        ay2 = GGML_F16x_VEC_LOAD(y + i + 1*ggml_f16_epr, 1);
        ay2 = GGML_F16x_VEC_MUL(ay2, vx);
        GGML_F16x_VEC_STORE(y + i + 1*ggml_f16_epr, ay2, 1);
    }
    // leftovers
    // maximum number of leftover elements will be less than ggml_f16_epr; apply predicated svmul on the available elements only
    if (np < n) {
        svbool_t pg = svwhilelt_b16(np, n);
        svfloat16_t hy = svld1_f16(pg, (__fp16 *)(y + np));
        svfloat16_t out = svmul_f16_m(pg, hy, vx);
        svst1_f16(pg, (__fp16 *)(y + np), out);
    }
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
    const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
    const _Float16 scale = *(const _Float16*)(&s);

    // calculate step size
    const int epr = __riscv_vsetvlmax_e16m4();
    const int step = epr * 2;
    const int np = (n & ~(step - 1));

    // unroll by 2
    for (int i = 0; i < np; i += step) {
        vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
        ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, epr);
        __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
        __asm__ __volatile__ ("" ::: "memory");

        vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
        ay1 = __riscv_vfmul_vf_f16m4(ay1, scale, epr);
        __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
        __asm__ __volatile__ ("" ::: "memory");
    }

    // leftovers
    int vl;
    for (int i = np; i < n; i += vl) {
        vl = __riscv_vsetvl_e16m4(n - i);
        vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
        ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, vl);
        __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
    }
#elif defined(GGML_SIMD)
    const int np = (n & ~(GGML_F16_STEP - 1));

    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);

    GGML_F16_VEC ay[GGML_F16_ARR];

    for (int i = 0; i < np; i += GGML_F16_STEP) {
        for (int j = 0; j < GGML_F16_ARR; j++) {
            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
            ay[j] = GGML_F16_VEC_MUL(ay[j], vx);

            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
        }
    }

    // leftovers
    for (int i = np; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
    }
#else
    // scalar
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
    }
#endif
}

inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s);   }
inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }
inline static void ggml_vec_sqr_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        float v = GGML_CPU_FP16_TO_FP32(x[i]);
        y[i] = GGML_CPU_FP32_TO_FP16(v*v);
    }
}
inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
inline static void ggml_vec_sqrt_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(sqrtf(GGML_CPU_FP16_TO_FP32(x[i])));
    }
}
inline static void ggml_vec_log_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]);  }
inline static void ggml_vec_log_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(logf(GGML_CPU_FP16_TO_FP32(x[i])));
    }
}
inline static void ggml_vec_sin_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]);  }
inline static void ggml_vec_sin_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(sinf(GGML_CPU_FP16_TO_FP32(x[i])));
    }
}
inline static void ggml_vec_cos_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]);  }
inline static void ggml_vec_cos_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(cosf(GGML_CPU_FP16_TO_FP32(x[i])));
    }
}
inline static void ggml_vec_abs_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
inline static void ggml_vec_abs_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(fabsf(GGML_CPU_FP16_TO_FP32(x[i])));
    }
}
inline static void ggml_vec_sgn_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
inline static void ggml_vec_sgn_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        float v = GGML_CPU_FP16_TO_FP32(x[i]);
        y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? 1.f : ((v < 0.f) ? -1.f : 0.f));
    }
}
inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
inline static void ggml_vec_step_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16((GGML_CPU_FP16_TO_FP32(x[i]) > 0.f) ? 1.f : 0.f);
    }
}
inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]);  }
inline static void ggml_vec_tanh_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(tanhf(GGML_CPU_FP16_TO_FP32(x[i])));
    }
}
inline static void ggml_vec_elu_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); }
inline static void ggml_vec_elu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        const float v = GGML_CPU_FP16_TO_FP32(x[i]);
        y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v : expm1f(v));
    }
}
inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
inline static void ggml_vec_relu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        float v = GGML_CPU_FP16_TO_FP32(x[i]);
        y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v : 0.f);
    }
}
inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
inline static void ggml_vec_leaky_relu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const float ns) {
    for (int i = 0; i < n; ++i) {
        float v = GGML_CPU_FP16_TO_FP32(x[i]);
        y[i] = GGML_CPU_FP32_TO_FP16(((v > 0.f) ? v : 0.f) + ns * ((v < 0.0f) ? v : 0.f));
    }
}
inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }
inline static void ggml_vec_sigmoid_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(1.f / (1.f + expf(-GGML_CPU_FP16_TO_FP32(x[i]))));
    }
}
// TODO: optimize performance
inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
inline static void ggml_vec_hardswish_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        float v = GGML_CPU_FP16_TO_FP32(x[i]);
        y[i] = GGML_CPU_FP32_TO_FP16(v * fminf(1.0f, fmaxf(0.0f, (v + 3.0f) / 6.0f)));
    }
}
inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
inline static void ggml_vec_hardsigmoid_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(fminf(1.0f, fmaxf(0.0f, (GGML_CPU_FP16_TO_FP32(x[i]) + 3.0f) / 6.0f)));
    }
}
inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }
inline static void ggml_vec_exp_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(expf(GGML_CPU_FP16_TO_FP32(x[i])));
    }
}

static const float GELU_COEF_A     = 0.044715f;
static const float GELU_QUICK_COEF = -1.702f;
static const float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
static const float SQRT_2_INV      = 0.70710678118654752440084436210484f;

inline static float ggml_gelu_f32(float x) {
    return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}

inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    const uint16_t * i16 = (const uint16_t *) x;
    for (int i = 0; i < n; ++i) {
        y[i] = ggml_table_gelu_f16[i16[i]];
    }
}

inline static void ggml_vec_gelu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        float xi = GGML_CPU_FP16_TO_FP32(x[i]);
        float res = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV));
        y[i] = GGML_CPU_FP32_TO_FP16(res);
    }
}

#ifdef GGML_GELU_FP16
inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
    uint16_t t;
    for (int i = 0; i < n; ++i) {
        if (x[i] <= -10.0f) {
            y[i] = 0.0f;
        } else if (x[i] >= 10.0f) {
            y[i] = x[i];
        } else {
            ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);
            memcpy(&t, &fp16, sizeof(uint16_t));
            y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[t]);
        }
    }
}
#else
inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = ggml_gelu_f32(x[i]);
    }
}
#endif
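
// Note on the clamp above: for x <= -10 the exact GELU is negligibly small
// (on the order of 1e-22) and for x >= 10 it equals x to within f32 rounding,
// so the tails are returned directly and the fp16 table is only consulted in
// between. Self-check sketch (illustrative only; assumes <assert.h>):
//
//     for (float x = -12.0f; x <= 12.0f; x += 0.25f) {
//         float out;
//         ggml_vec_gelu_f32(1, &out, &x);
//         assert(fabsf(out - ggml_gelu_f32(x)) < 0.05f);  // table is fp16-accurate
//     }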

inline static void ggml_vec_gelu_erf_f32(const int n, float * y, const float * x) {
    for (int i = 0; i < n; ++i) {
        float xi = x[i];
        y[i] = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV));
    }
}

inline static float ggml_gelu_quick_f32(float x) {
    return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
}

//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
//    const uint16_t * i16 = (const uint16_t *) x;
//    for (int i = 0; i < n; ++i) {
//        y[i] = ggml_table_gelu_quick_f16[i16[i]];
//    }
//}

#ifdef GGML_GELU_QUICK_FP16
inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
    uint16_t t;
    for (int i = 0; i < n; ++i) {
        ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);
        memcpy(&t, &fp16, sizeof(uint16_t));
        y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]);
    }
}
#else
inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = ggml_gelu_quick_f32(x[i]);
    }
}
#endif

inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        float v = GGML_CPU_FP16_TO_FP32(x[i]);
        y[i] = GGML_CPU_FP32_TO_FP16(v*(1.0f/(1.0f+expf(GELU_QUICK_COEF*v))));
    }
}

// Sigmoid Linear Unit (SiLU) function
inline static float ggml_silu_f32(float x) {
    return x/(1.0f + expf(-x));
}
inline static ggml_fp16_t ggml_silu_f16(ggml_fp16_t x) {
    float v = GGML_CPU_FP16_TO_FP32(x);
    return GGML_CPU_FP32_TO_FP16(v/(1.0f + expf(-v)));
}

#if __FINITE_MATH_ONLY__
#error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix"
#error "ref: https://github.com/ggml-org/llama.cpp/pull/7154#issuecomment-2143844461"
#endif

/* Below function was borrowed from the GitHub repository:
https://github.com/openvinotoolkit/openvino/blob/master/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp */
#if defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
    inline static svfloat32_t exp_ps_sve(svbool_t pg, svfloat32_t src) {
        // Constants
        const svfloat32_t log2_e = svdup_n_f32(1.4426950409f);
        const svfloat32_t ln2 = svdup_n_f32(0.6931473921f);
        const svfloat32_t half_ln2_sq = svdup_n_f32(0.2413862043f);
        const svuint32_t not_mask17 = svdup_n_u32(~((1u << 17) - 1));
        const svfloat32_t one = svdup_n_f32(1.0f);
        const svfloat32_t inactive1 = svdup_n_f32(0.0f);
        const svint32_t inactive2 = svdup_n_s32(0);

        // Algorithm starts here
        svfloat32_t t0 = svmul_f32_m(pg, src, log2_e);  // y = x * log2(e)
        svfloat32_t t1 = svrintm_f32_m(inactive1, pg, t0);         // round down to int (float)
        svint32_t t2 = svcvt_s32_f32_m(inactive2, pg, t1);         // n

        t1 = svsub_f32_m(pg, t0, t1);   // a = y - floor(y)
        t1 = svadd_f32_m(pg, t1, one);  // b = a + 1

        svuint32_t t3 = svlsr_n_u32_m(pg, svreinterpret_u32_f32(t1), 17);  // v = b >> 17 (u32)
        svfloat32_t t4 = svexpa_f32(t3);                                   // c = fexpa(v)
        t4 = svscale_f32_m(pg, t4, t2);                                    // fexpa(v) * 2^(n)

        // and_(t2.d, t1.d, not_mask17.d)
        svfloat32_t t5 = svreinterpret_f32_u32(svand_u32_m(pg, svreinterpret_u32_f32(t1), not_mask17));
        t5 = svsub_f32_m(pg, t1, t5);                // z
        t0 = svmla_f32_m(pg, ln2, t5, half_ln2_sq);  // ln2 + half_ln2_sq * z
        t0 = svmla_f32_m(pg, one, t5, t0);           // 1 + (ln2 * z) + (half_ln2_sq * z * z)
        t0 = svmul_f32_m(pg, t0, t4);                // Final result

        return t0;
    }
#endif

#if defined(__ARM_FEATURE_SVE) && defined(__aarch64__)

inline static svfloat32_t ggml_v_expf(svbool_t pg, svfloat32_t x) {
    const svfloat32_t r = svdup_n_f32_x(pg, 0x1.8p23f);
    const svfloat32_t z = svmla_n_f32_x(pg, r, x, 0x1.715476p+0f);
    const svfloat32_t n = svsub_f32_x(pg, z, r);
    const svfloat32_t b = svmls_n_f32_x(pg, svmls_n_f32_x(pg, x, n, 0x1.62e4p-1f), n, 0x1.7f7d1cp-20f);
    const svuint32_t e = svlsl_n_u32_x(pg, svreinterpret_u32_f32(z), 23);
    const svfloat32_t k = svreinterpret_f32_u32(svadd_u32_x(pg, e, svreinterpret_u32_f32(svdup_n_f32_x(pg, 1))));
    const svbool_t c = svacgt_n_f32(pg, n, 126);
    const svfloat32_t u = svmul_f32_x(pg, b, b);
    const svfloat32_t j = svmla_f32_x(pg,
        svmul_n_f32_x(pg, b, 0x1.ffffecp-1f),
        svmla_f32_x(pg, svmla_f32_x(pg, svdup_n_f32_x(pg, 0x1.fffdb6p-2f), svdup_n_f32_x(pg, 0x1.555e66p-3f), b),
                        svmla_f32_x(pg, svdup_n_f32_x(pg, 0x1.573e2ep-5f), svdup_n_f32_x(pg, 0x1.0e4020p-7f), b), u), u);
    const svuint32_t d = svdup_n_u32_z(svcmple_n_f32(pg, n, 0.0), 0x82000000);
    const svfloat32_t s1 = svreinterpret_f32_u32(svadd_n_u32_x(pg, d, 0x7f000000));
    const svfloat32_t s2 = svreinterpret_f32_u32(svsub_u32_x(pg, e, d));
    return svsel_f32(svacgt_f32(pg, n, svdup_n_f32_x(pg, 192)), svmul_f32_x(pg, s1, s1),
                     svsel_f32(c, svmul_f32_x(pg, svmla_f32_x(pg, s2, s2, j), s1), svmla_f32_x(pg, k, k, j)));
}

// computes silu x/(1+exp(-x)) in single precision vector
inline static svfloat32_t ggml_v_silu(svbool_t pg, svfloat32_t x) {
    const svfloat32_t one = svdup_n_f32_x(pg, 1.0f);
    const svfloat32_t zero = svdup_n_f32_x(pg, 0.0f);
    const svfloat32_t neg_x = svsub_f32_x(pg, zero, x);
    const svfloat32_t exp_neg_x = ggml_v_expf(pg, neg_x);
    const svfloat32_t one_plus_exp_neg_x = svadd_f32_x(pg, one, exp_neg_x);
    return svdiv_f32_x(pg, x, one_plus_exp_neg_x);
}

#elif defined(__ARM_NEON) && defined(__aarch64__)

// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
inline static float32x4_t ggml_v_expf(float32x4_t x) {
    const float32x4_t r = vdupq_n_f32(0x1.8p23f);
    const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
    const float32x4_t n = vsubq_f32(z, r);
    const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
                                    vdupq_n_f32(0x1.7f7d1cp-20f));
    const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
    const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
    const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
    const float32x4_t u = vmulq_f32(b, b);
    const float32x4_t j = vfmaq_f32(
        vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
        vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
                  vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
    if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
        return vfmaq_f32(k, j, k);
    const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
    const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
    const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
    return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
                     vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
}

// computes silu x/(1+exp(-x)) in single precision vector
inline static float32x4_t ggml_v_silu(float32x4_t x) {
    const float32x4_t one = vdupq_n_f32(1.0f);
    const float32x4_t zero = vdupq_n_f32(0.0f);
    const float32x4_t neg_x = vsubq_f32(zero, x);
    const float32x4_t exp_neg_x = ggml_v_expf(neg_x);
    const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
    return vdivq_f32(x, one_plus_exp_neg_x);
}
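
// Sketch of how these 4-wide kernels are typically driven (this mirrors the
// pattern ggml_vec_silu_f32 uses; illustrative only):
//
//     int i = 0;
//     for (; i + 3 < n; i += 4) {
//         vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
//     }
//     for (; i < n; ++i) {
//         y[i] = ggml_silu_f32(x[i]);
//     }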
1180
1181#elif defined(__AVX512F__) && defined(__AVX512DQ__)
1182
1183// adapted from arm limited optimized routine
1184// the maximum error is 1.45358 plus 0.5 ulps
1185// numbers above 88.38 will flush to infinity
1186// numbers beneath -103.97 will flush to zero
1187inline static __m512 ggml_v_expf(__m512 x) {
1188  const __m512 r = _mm512_set1_ps(0x1.8p23f);
1189  const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
1190  const __m512 n = _mm512_sub_ps(z, r);
1191  const __m512 b =
1192      _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
1193                       _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
1194  const __mmask16 d =
1195      _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
1196  const __m512 u = _mm512_mul_ps(b, b);
1197  const __m512 j = _mm512_fmadd_ps(
1198      _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
1199                                      _mm512_set1_ps(0x1.573e2ep-5f)),
1200                      u,
1201                      _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
1202                                      _mm512_set1_ps(0x1.fffdb6p-2f))),
1203      u,
1204      _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
1205  const __m512 res = _mm512_scalef_ps(j, n);
1206  if (_mm512_kortestz(d, d))
1207    return res;
1208  const __m512 zero = _mm512_setzero_ps();
1209  const __m512 alt = _mm512_mask_blend_ps(
1210      _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
1211  return _mm512_mask_blend_ps(d, res, alt);
1212}

// computes silu x/(1+exp(-x)) in single precision vector
inline static __m512 ggml_v_silu(__m512 x) {
    const __m512 one = _mm512_set1_ps(1);
    const __m512 zero = _mm512_setzero_ps();
    const __m512 neg_x = _mm512_sub_ps(zero, x);
    const __m512 exp_neg_x = ggml_v_expf(neg_x);
    const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
    return _mm512_div_ps(x, one_plus_exp_neg_x);
}

#elif defined(__AVX2__) && defined(__FMA__)

// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
inline static __m256 ggml_v_expf(__m256 x) {
  const __m256 r = _mm256_set1_ps(0x1.8p23f);
  const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
  const __m256 n = _mm256_sub_ps(z, r);
  const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
                                    _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
  const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
  const __m256 k = _mm256_castsi256_ps(
      _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
  const __m256i c = _mm256_castps_si256(
      _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
                    _mm256_set1_ps(126), _CMP_GT_OQ));
  const __m256 u = _mm256_mul_ps(b, b);
  const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
                                                                   _mm256_set1_ps(0x1.573e2ep-5f)), u,
                                                   _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
                                                                   _mm256_set1_ps(0x1.fffdb6p-2f))),
                                   u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
  if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
    return _mm256_fmadd_ps(j, k, k);
  const __m256i g = _mm256_and_si256(
      _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
      _mm256_set1_epi32(0x82000000u));
  const __m256 s1 =
      _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
  const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
  const __m256i d = _mm256_castps_si256(
      _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
                    _mm256_set1_ps(192), _CMP_GT_OQ));
  return _mm256_or_ps(
      _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
      _mm256_andnot_ps(
          _mm256_castsi256_ps(d),
          _mm256_or_ps(
              _mm256_and_ps(_mm256_castsi256_ps(c),
                            _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
              _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
}

// computes silu x/(1+exp(-x)) in single precision vector
inline static __m256 ggml_v_silu(__m256 x) {
    const __m256 one = _mm256_set1_ps(1);
    const __m256 zero = _mm256_setzero_ps();
    const __m256 neg_x = _mm256_sub_ps(zero, x);
    const __m256 exp_neg_x = ggml_v_expf(neg_x);
    const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
    return _mm256_div_ps(x, one_plus_exp_neg_x);
}

#elif defined(__SSE2__) // __AVX2__ / __ARM_NEON

#if defined(__FMA__)
#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
#else
#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
#endif
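
// note: without __FMA__ these fall back to a separate multiply and add/sub
// with two roundings instead of one, so the plain SSE2 build can be slightly
// less accurate than the fused paths above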

// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
inline static __m128 ggml_v_expf(__m128 x) {
    const __m128 r = _mm_set1_ps(0x1.8p23f);
    const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r);
    const __m128 n = _mm_sub_ps(z, r);
    const __m128 b =
        NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x));
    const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23);
    const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1))));
    const __m128i c =
        _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126)));
    const __m128 u = _mm_mul_ps(b, b);
    const __m128 j =
        MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u,
                        MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))),
                u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b));
    if (!_mm_movemask_epi8(c))
        return MADD128(j, k, k);
    const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())),
                                    _mm_set1_epi32(0x82000000u));
    const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u)));
    const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g));
    const __m128i d =
        _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192)));
    return _mm_or_ps(
        _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)),
        _mm_andnot_ps(_mm_castsi128_ps(d),
                      _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)),
                                _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k)))));
}

// computes silu x/(1+exp(-x)) in single precision vector
inline static __m128 ggml_v_silu(__m128 x) {
    const __m128 one = _mm_set1_ps(1);
    const __m128 zero = _mm_setzero_ps();
    const __m128 neg_x = _mm_sub_ps(zero, x);
    const __m128 exp_neg_x = ggml_v_expf(neg_x);
    const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x);
    return _mm_div_ps(x, one_plus_exp_neg_x);
}

#elif defined(__riscv_v_intrinsic)

// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
inline static vfloat32m2_t ggml_v_expf_m2(vfloat32m2_t x, int vl) {
    const vfloat32m2_t r = __riscv_vfmv_v_f_f32m2(0x1.8p23f, vl);
#ifdef __riscv_xtheadvector
    // workaround for compiler bug (gcc 14.3.0: Error: unrecognized opcode `th.vmv1r.v v2,v4')
    vfloat32m2_t z = __riscv_vfadd_vf_f32m2(r, 0.0f, vl);
    z = __riscv_vfmacc_vf_f32m2(z, 0x1.715476p+0f, x, vl);
#else
    const vfloat32m2_t z = __riscv_vfmacc_vf_f32m2(r, 0x1.715476p+0f, x, vl);
#endif
    const vfloat32m2_t n = __riscv_vfsub_vv_f32m2(z, r, vl);
    const vfloat32m2_t b = __riscv_vfnmsac_vf_f32m2(__riscv_vfnmsac_vf_f32m2(x, 0x1.62e4p-1f, n, vl),
                                                    0x1.7f7d1cp-20f, n, vl);
    const vuint32m2_t e = __riscv_vsll_vx_u32m2(__riscv_vreinterpret_v_f32m2_u32m2(z), 23, vl);
    const vfloat32m2_t k = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(e, 0x3f800000, vl)); // 1.0f
    const vbool16_t c = __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 126.0f, vl);
    const vfloat32m2_t u = __riscv_vfmul_vv_f32m2(b, b, vl);
    const vfloat32m2_t j = __riscv_vfmacc_vv_f32m2(
        __riscv_vfmul_vf_f32m2(b, 0x1.ffffecp-1f, vl),
        __riscv_vfmacc_vv_f32m2(
            __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.fffdb6p-2f, vl), 0x1.555e66p-3f, b, vl),
            __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.573e2ep-5f, vl), 0x1.0e4020p-7f, b, vl),
            u, vl), u, vl);
    if (!__riscv_vcpop_m_b16(c, vl))
        return __riscv_vfmacc_vv_f32m2(k, j, k, vl);
    const vbool16_t  dm = __riscv_vmfle_vf_f32m2_b16(n, 0.0f, vl);
    const vuint32m2_t d = __riscv_vmerge_vxm_u32m2(__riscv_vmv_v_x_u32m2(0, vl), 0x82000000, dm, vl);
    const vfloat32m2_t s1 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(d, 0x7f000000, vl));
    const vfloat32m2_t s2 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vsub_vv_u32m2(e, d, vl));
    const vfloat32m2_t r1 = __riscv_vmerge_vvm_f32m2(
        __riscv_vfmacc_vv_f32m2(k, k, j, vl),
        __riscv_vfmul_vv_f32m2(__riscv_vfmacc_vv_f32m2(s2, s2, j, vl), s1, vl),
        c, vl);
    return __riscv_vmerge_vvm_f32m2(
        r1, __riscv_vfmul_vv_f32m2(s1, s1, vl),
        __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 192.0f, vl),
        vl);
}

// computes silu x/(1+exp(-x)) in single precision vector
inline static vfloat32m2_t ggml_v_silu_m2(vfloat32m2_t x, int vl) {
    const vfloat32m2_t neg_x = __riscv_vfneg_v_f32m2(x, vl);
    const vfloat32m2_t exp_neg_x = ggml_v_expf_m2(neg_x, vl);
    const vfloat32m2_t one_plus_exp_neg_x = __riscv_vfadd_vf_f32m2(exp_neg_x, 1.0f, vl);
    return __riscv_vfdiv_vv_f32m2(x, one_plus_exp_neg_x, vl);
}
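
// note: the RVV helpers take the active vector length vl explicitly, so the
// expectation is that callers shrink vl for the final partial vector instead
// of running a scalar tail loop as the fixed-width paths do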

#endif // __ARM_NEON / __AVX2__ / __SSE2__ / __riscv_v_intrinsic

inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = ggml_silu_f16(x[i]);
    }
}

inline static float ggml_silu_backward_f32(float x, float dy) {
    const float s = 1.0f/(1.0f + expf(-x));
    return dy*s*(1.0f + x*(1.0f - s));
}
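
// derivation: with s(x) = 1/(1 + exp(-x)) and silu(x) = x*s(x),
//   silu'(x) = s(x) + x*s'(x) = s(x) + x*s(x)*(1 - s(x)) = s(x)*(1 + x*(1 - s(x)))
// which is the factor applied to the incoming gradient dy above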

inline static ggml_fp16_t ggml_silu_backward_f16(ggml_fp16_t x, ggml_fp16_t dy) {
    const float v = GGML_CPU_FP16_TO_FP32(x);
    const float s = 1.0f/(1.0f + expf(-v));
    return GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(dy)*s*(1.0f + v*(1.0f - s)));
}

inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
    for (int i = 0; i < n; ++i) {
        dx[i] = ggml_silu_backward_f32(x[i], dy[i]);
    }
}

inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, const ggml_fp16_t * x, const ggml_fp16_t * dy) {
    for (int i = 0; i < n; ++i) {
        dx[i] = ggml_silu_backward_f16(x[i], dy[i]);
    }
}

inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x, const float * g) {
    for (int i = 0; i < n; ++i) {
        y[i] = (x[i] > 0.f) ? x[i] * g[i] : 0.f;
    }
}

inline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
    for (int i = 0; i < n; ++i) {
        float v = GGML_CPU_FP16_TO_FP32(x[i]);
        y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v * GGML_CPU_FP16_TO_FP32(g[i]) : 0.f);
    }
}

#ifdef GGML_GELU_FP16
inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
    uint16_t t;
    for (int i = 0; i < n; ++i) {
        if (x[i] <= -10.0f) {
            y[i] = 0.0f;
        } else if (x[i] >= 10.0f) {
            y[i] = x[i] * g[i];
        } else {
            ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);
            memcpy(&t, &fp16, sizeof(uint16_t));
            y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[t]) * g[i];
        }
    }
}
#else
inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
    for (int i = 0; i < n; ++i) {
        y[i] = ggml_gelu_f32(x[i]) * g[i];
    }
}
#endif
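
// note: the GGML_GELU_FP16 path rounds x to half precision and uses its raw
// 16-bit pattern to index the precomputed ggml_table_gelu_f16; the +/-10.0f
// clamp short-circuits the saturated tails where gelu(x) is ~0 or ~x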

inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
    const uint16_t * i16 = (const uint16_t *) x;
    for (int i = 0; i < n; ++i) {
        float v = GGML_CPU_FP16_TO_FP32(g[i]);
        y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * v);
    }
}

void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g);

inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
    for (int i = 0; i < n; ++i) {
        float xi = GGML_CPU_FP16_TO_FP32(x[i]);
        float gi = GGML_CPU_FP16_TO_FP32(g[i]);
        y[i] = GGML_CPU_FP32_TO_FP16((xi/(1.0f + expf(-xi))) * gi);
    }
}

inline static void ggml_vec_geglu_erf_f32(const int n, float * y, const float * x, const float * g) {
    for (int i = 0; i < n; ++i) {
        float xi = x[i];
        y[i] = 0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * g[i];
    }
}

inline static void ggml_vec_geglu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
    for (int i = 0; i < n; ++i) {
        float xi = GGML_CPU_FP16_TO_FP32(x[i]);
        float gi = GGML_CPU_FP16_TO_FP32(g[i]);
        y[i] = GGML_CPU_FP32_TO_FP16(0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * gi);
    }
}
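
// note: these evaluate the exact erf-based GELU, gelu(x) = 0.5*x*(1 + erf(x/sqrt(2))),
// rather than a table or tanh approximation; SQRT_2_INV (assumed to be defined
// earlier in this file) is 1/sqrt(2)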

#ifdef GGML_GELU_QUICK_FP16
inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
    uint16_t t;
    for (int i = 0; i < n; ++i) {
        ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);
        memcpy(&t, &fp16, sizeof(uint16_t));
        y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]) * g[i];
    }
}
#else
inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
    for (int i = 0; i < n; ++i) {
        y[i] = ggml_gelu_quick_f32(x[i]) * g[i];
    }
}
#endif

inline static void ggml_vec_geglu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
    const uint16_t * i16 = (const uint16_t *) x;
    for (int i = 0; i < n; ++i) {
        float v = GGML_CPU_FP16_TO_FP32(g[i]);
        y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[i16[i]]) * v);
    }
}

inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
#ifndef GGML_USE_ACCELERATE
    ggml_float sum = 0.0;
    for (int i = 0; i < n; ++i) {
        sum += (ggml_float)x[i];
    }
    *s = (float)sum;
#else
    vDSP_sve(x, 1, s, n);
#endif
}
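
// note: the scalar path accumulates in ggml_float (double) and only narrows
// to float at the end, which keeps long reductions from losing low-order bits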

inline static void ggml_vec_cumsum_f32(const int n, float * y, const float * x) {
    if (n == 0) {
        return;
    }
    y[0] = x[0];
    for (int i = 1; i < n; ++i) {
        y[i] = y[i - 1] + x[i];
    }
}
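
// note: each element of the cumulative sum depends on the previous partial
// sum, so unlike the elementwise helpers above this loop carries a serial
// dependency and has no SIMD specialization here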

inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) {
    ggml_float sum = 0.0;
    for (int i = 0; i < n; ++i) {
        sum += (ggml_float)x[i];
    }
    *s = sum;
}

inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_t * x) {
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        sum += GGML_CPU_FP16_TO_FP32(x[i]);
    }
    *s = sum;
}

inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) {
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        sum += GGML_BF16_TO_FP32(x[i]);
    }
    *s = sum;
}

inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
#ifndef GGML_USE_ACCELERATE
    float max = -INFINITY;
    for (int i = 0; i < n; ++i) {
        max = MAX(max, x[i]);
    }
    *s = max;
#else
    vDSP_maxv(x, 1, s, n);
#endif
}

inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) {
    ggml_vec_norm_f32(n, s, x);
    *s = 1.f/(*s);
}

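// note: if the maximum occurs more than once, the last such index wins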
inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) {
    float max = -INFINITY;
    int idx = 0;
    for (int i = 0; i < n; ++i) {
        max = MAX(max, x[i]);
        if (max == x[i]) { idx = i; }
    }
    *s = idx;
}

#ifdef __cplusplus
}
#endif