1#include "unary-ops.h"
  2
  3static inline float op_abs(float x) {
  4    return fabsf(x);
  5}
  6
  7static inline float op_sgn(float x) {
  8    return (x > 0.f) ? 1.f : ((x < 0.f) ? -1.f : 0.f);
  9}
 10
 11static inline float op_neg(float x) {
 12    return -x;
 13}
 14
 15static inline float op_step(float x) {
 16    return (x > 0.f) ? 1.f : 0.f;
 17}
 18
 19static inline float op_tanh(float x) {
 20    return tanhf(x);
 21}
 22
 23static inline float op_elu(float x) {
 24    return (x > 0.f) ? x : expm1f(x);
 25}
 26
 27static inline float op_relu(float x) {
 28    return (x > 0.f) ? x : 0.f;
 29}
 30
 31static inline float op_sigmoid(float x) {
 32    return 1.f / (1.f + expf(-x));
 33}
 34
 35static inline float op_hardsigmoid(float x) {
 36    return fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
 37}
 38
 39static inline float op_exp(float x) {
 40    return expf(x);
 41}
 42
 43static inline float op_hardswish(float x) {
 44    return x * fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
 45}
 46
 47static inline float op_sqr(float x) {
 48    return x * x;
 49}
 50
 51static inline float op_sqrt(float x) {
 52    return sqrtf(x);
 53}
 54
 55static inline float op_xielu(float x, float alpha_n, float alpha_p, float beta, float eps) {
 56    if (x > 0.0f) {
 57        return alpha_p * x * x + beta * x;
 58    } else {
 59        const float min_x_eps = fminf(x, eps);
 60        return (expm1f(min_x_eps) - x) * alpha_n + beta * x;
 61    }
 62}
 63
 64static inline float op_sin(float x) {
 65    return sinf(x);
 66}
 67
 68static inline float op_cos(float x) {
 69    return cosf(x);
 70}
 71
 72static inline float op_log(float x) {
 73    return logf(x);
 74}
 75
 76static inline float op_expm1(float x) {
 77    return expf(x) - 1.0f;
 78}
 79
 80static inline float op_softplus(float x) {
 81    return (x > 20.0f) ? x : logf(1.0f + expf(x));
 82}
 83
 84static inline float op_floor(float x) {
 85    return floorf(x);
 86}
 87
 88static inline float op_ceil(float x) {
 89    return ceilf(x);
 90}
 91
 92static inline float op_round(float x) {
 93    return roundf(x);
 94}
 95
 96static inline float op_trunc(float x) {
 97    return truncf(x);
 98}
 99
100template <float (*op)(float), typename src0_t, typename dst_t>
101static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
102    constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
103    constexpr auto f32_to_dst  = type_conversion_table<dst_t >::from_f32;
104
105    for (int i = 0; i < n; i++) {
106        y[i] = f32_to_dst(op(src0_to_f32(x[i])));
107    }
108}
109
110template <float (*op)(float), typename src0_t, typename dst_t>
111static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
112    const ggml_tensor * src0 = dst->src[0];
113
114    GGML_ASSERT(ggml_is_contiguous_rows(src0) && ggml_is_contiguous_rows(dst) && ggml_are_same_shape(src0, dst));
115
116    GGML_TENSOR_UNARY_OP_LOCALS
117
118    GGML_ASSERT( nb0 == sizeof(dst_t));
119    GGML_ASSERT(nb00 == sizeof(src0_t));
120
121    const auto [ir0, ir1] = get_thread_range(params, src0);
122
123    for (int64_t ir = ir0; ir < ir1; ++ir) {
124        const int64_t i03 = ir/(ne02*ne01);
125        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
126        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
127
128        dst_t        * dst_ptr  = (dst_t  *)       ((char *)       dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
129        const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
130
131        vec_unary_op<op>(ne0, dst_ptr, src0_ptr);
132    }
133}
134
135// TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates
136template <float (*op)(float)>
137static void unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
138    const ggml_tensor * src0 = dst->src[0];
139
140    /*  */ if (src0->type == GGML_TYPE_F32  && dst->type == GGML_TYPE_F32) { // all f32
141        apply_unary_op<op, float, float>(params, dst);
142    } else if (src0->type == GGML_TYPE_F16  && dst->type == GGML_TYPE_F16) { // all f16
143        apply_unary_op<op, ggml_fp16_t, ggml_fp16_t>(params, dst);
144    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
145        apply_unary_op<op, ggml_bf16_t, ggml_bf16_t>(params, dst);
146    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
147        apply_unary_op<op, ggml_bf16_t, float>(params, dst);
148    } else if (src0->type == GGML_TYPE_F16  && dst->type == GGML_TYPE_F32) {
149        apply_unary_op<op, ggml_fp16_t, float>(params, dst);
150    } else {
151        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
152            ggml_type_name(dst->type), ggml_type_name(src0->type));
153        GGML_ABORT("fatal error");
154    }
155}
156
157template <float (*op)(float, ggml_tensor *)>
158static void unary_op_params(const ggml_compute_params * params, ggml_tensor * dst) {
159    const ggml_tensor * src0 = dst->src[0];
160
161    /*  */ if (src0->type == GGML_TYPE_F32  && dst->type == GGML_TYPE_F32) { // all f32
162        apply_unary_op<op, float, float>(params, dst);
163    } else if (src0->type == GGML_TYPE_F16  && dst->type == GGML_TYPE_F16) { // all f16
164        apply_unary_op<op, ggml_fp16_t, ggml_fp16_t>(params, dst);
165    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
166        apply_unary_op<op, ggml_bf16_t, ggml_bf16_t>(params, dst);
167    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
168        apply_unary_op<op, ggml_bf16_t, float>(params, dst);
169    } else if (src0->type == GGML_TYPE_F16  && dst->type == GGML_TYPE_F32) {
170        apply_unary_op<op, ggml_fp16_t, float>(params, dst);
171    } else {
172        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
173            ggml_type_name(dst->type), ggml_type_name(src0->type));
174        GGML_ABORT("fatal error");
175    }
176}
177
178// Extend vec_unary_op to support functors
179template <typename Op, typename src0_t, typename dst_t>
180static inline void vec_unary_op_functor(int64_t n, dst_t * y, const src0_t * x, Op op) {
181    constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
182    constexpr auto f32_to_dst  = type_conversion_table<dst_t >::from_f32;
183
184    for (int i = 0; i < n; i++) {
185        y[i] = f32_to_dst(op(src0_to_f32(x[i])));
186    }
187}
188
189// Extend apply_unary_op to support functors
190template <typename Op, typename src0_t, typename dst_t>
191static void apply_unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) {
192    const ggml_tensor * src0 = dst->src[0];
193
194    GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));
195
196    GGML_TENSOR_UNARY_OP_LOCALS
197
198    GGML_ASSERT( nb0 == sizeof(dst_t));
199    GGML_ASSERT(nb00 == sizeof(src0_t));
200
201    const auto [ir0, ir1] = get_thread_range(params, src0);
202
203    for (int64_t ir = ir0; ir < ir1; ++ir) {
204        const int64_t i03 = ir/(ne02*ne01);
205        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
206        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
207
208        dst_t        * dst_ptr  = (dst_t  *)       ((char *)       dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
209        const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
210
211        vec_unary_op_functor(ne0, dst_ptr, src0_ptr, op);
212    }
213}
214
215// Generic dispatcher for functors
216template <typename Op>
217static void unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) {
218    const ggml_tensor * src0 = dst->src[0];
219
220    /*  */ if (src0->type == GGML_TYPE_F32  && dst->type == GGML_TYPE_F32) { // all f32
221        apply_unary_op_functor<Op, float, float>(params, dst, op);
222    } else if (src0->type == GGML_TYPE_F16  && dst->type == GGML_TYPE_F16) { // all f16
223        apply_unary_op_functor<Op, ggml_fp16_t, ggml_fp16_t>(params, dst, op);
224    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
225        apply_unary_op_functor<Op, ggml_bf16_t, ggml_bf16_t>(params, dst, op);
226    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
227        apply_unary_op_functor<Op, ggml_bf16_t, float>(params, dst, op);
228    } else if (src0->type == GGML_TYPE_F16  && dst->type == GGML_TYPE_F32) {
229        apply_unary_op_functor<Op, ggml_fp16_t, float>(params, dst, op);
230    } else {
231        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
232            ggml_type_name(dst->type), ggml_type_name(src0->type));
233        GGML_ABORT("fatal error");
234    }
235}
236
237void ggml_compute_forward_abs(const ggml_compute_params * params, ggml_tensor * dst) {
238    unary_op<op_abs>(params, dst);
239}
240
241void ggml_compute_forward_sgn(const ggml_compute_params * params, ggml_tensor * dst) {
242    unary_op<op_sgn>(params, dst);
243}
244
245void ggml_compute_forward_neg(const ggml_compute_params * params, ggml_tensor * dst) {
246    unary_op<op_neg>(params, dst);
247}
248
249void ggml_compute_forward_step(const ggml_compute_params * params, ggml_tensor * dst) {
250    unary_op<op_step>(params, dst);
251}
252
253void ggml_compute_forward_tanh(const ggml_compute_params * params, ggml_tensor * dst) {
254    unary_op<op_tanh>(params, dst);
255}
256
257void ggml_compute_forward_elu(const ggml_compute_params * params, ggml_tensor * dst) {
258    unary_op<op_elu>(params, dst);
259}
260
261void ggml_compute_forward_relu(const ggml_compute_params * params, ggml_tensor * dst) {
262    unary_op<op_relu>(params, dst);
263}
264
265void ggml_compute_forward_sigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
266    unary_op<op_sigmoid>(params, dst);
267}
268
269void ggml_compute_forward_hardsigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
270    unary_op<op_hardsigmoid>(params, dst);
271}
272
273void ggml_compute_forward_exp(const ggml_compute_params * params, ggml_tensor * dst) {
274    unary_op<op_exp>(params, dst);
275}
276
277void ggml_compute_forward_hardswish(const ggml_compute_params * params, ggml_tensor * dst) {
278    unary_op<op_hardswish>(params, dst);
279}
280
281void ggml_compute_forward_sqr(const ggml_compute_params * params, ggml_tensor * dst) {
282    unary_op<op_sqr>(params, dst);
283}
284
285void ggml_compute_forward_sqrt(const ggml_compute_params * params, ggml_tensor * dst) {
286    unary_op<op_sqrt>(params, dst);
287}
288
289void ggml_compute_forward_sin(const ggml_compute_params * params, ggml_tensor * dst) {
290    unary_op<op_sin>(params, dst);
291}
292
293void ggml_compute_forward_cos(const ggml_compute_params * params, ggml_tensor * dst) {
294    unary_op<op_cos>(params, dst);
295}
296
297void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) {
298    unary_op<op_log>(params, dst);
299}
300
301void ggml_compute_forward_expm1(const ggml_compute_params * params, ggml_tensor * dst) {
302    unary_op<op_expm1>(params, dst);
303}
304
305void ggml_compute_forward_softplus(const ggml_compute_params * params, ggml_tensor * dst) {
306    unary_op<op_softplus>(params, dst);
307}
308
309void ggml_compute_forward_floor(const ggml_compute_params * params, ggml_tensor * dst) {
310    unary_op<op_floor>(params, dst);
311}
312
313void ggml_compute_forward_ceil(const ggml_compute_params * params, ggml_tensor * dst) {
314    unary_op<op_ceil>(params, dst);
315}
316
317void ggml_compute_forward_round(const ggml_compute_params * params, ggml_tensor * dst) {
318    unary_op<op_round>(params, dst);
319}
320
321void ggml_compute_forward_trunc(const ggml_compute_params * params, ggml_tensor * dst) {
322    unary_op<op_trunc>(params, dst);
323}
324
325void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) {
326    const float alpha_n = ggml_get_op_params_f32(dst, 1);
327    const float alpha_p = ggml_get_op_params_f32(dst, 2);
328    const float beta = ggml_get_op_params_f32(dst, 3);
329    const float eps = ggml_get_op_params_f32(dst, 4);
330
331    const auto xielu_op_params = [alpha_n, alpha_p, beta, eps](float f) {
332        return op_xielu(f, alpha_n, alpha_p, beta, eps);
333    };
334
335    unary_op_functor(params, dst, xielu_op_params);
336}
337