1#include "unary-ops.h"
2
// Absolute value of x.
static inline float op_abs(float x) {
    const float magnitude = fabsf(x);
    return magnitude;
}
6
// Sign of x: 1 for positive input, -1 for negative input, 0 otherwise
// (the "otherwise" case covers zero and any value for which both comparisons are false).
static inline float op_sgn(float x) {
    if (x > 0.f) {
        return 1.f;
    }
    if (x < 0.f) {
        return -1.f;
    }
    return 0.f;
}
10
// Arithmetic negation of x.
static inline float op_neg(float x) {
    const float negated = -x;
    return negated;
}
14
// Unit step: 1 when x is strictly positive, 0 otherwise.
static inline float op_step(float x) {
    if (x > 0.f) {
        return 1.f;
    }
    return 0.f;
}
18
// Hyperbolic tangent of x.
static inline float op_tanh(float x) {
    const float result = tanhf(x);
    return result;
}
22
// ELU activation: identity for strictly positive x, exp(x) - 1 otherwise.
static inline float op_elu(float x) {
    if (x > 0.f) {
        return x;
    }
    return expm1f(x);
}
26
// ReLU activation: x when strictly positive, 0 otherwise.
static inline float op_relu(float x) {
    if (x > 0.f) {
        return x;
    }
    return 0.f;
}
30
// Logistic sigmoid: 1 / (1 + exp(-x)).
static inline float op_sigmoid(float x) {
    const float exp_neg_x = expf(-x);
    return 1.f / (1.f + exp_neg_x);
}
34
// Hard sigmoid: (x + 3) / 6 clamped to the range [0, 1].
static inline float op_hardsigmoid(float x) {
    const float ramp    = (x + 3.0f) / 6.0f;
    const float floored = fmaxf(0.0f, ramp);
    return fminf(1.0f, floored);
}
38
// Natural exponential of x.
static inline float op_exp(float x) {
    const float result = expf(x);
    return result;
}
42
// Hard swish: x multiplied by (x + 3) / 6 clamped to the range [0, 1].
static inline float op_hardswish(float x) {
    const float gate = fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
    return x * gate;
}
46
// Square of x.
static inline float op_sqr(float x) {
    const float squared = x * x;
    return squared;
}
50
// Square root of x.
static inline float op_sqrt(float x) {
    const float root = sqrtf(x);
    return root;
}
54
// xIELU activation.
//   x >  0: alpha_p * x^2 + beta * x
//   x <= 0: (expm1(min(x, eps)) - x) * alpha_n + beta * x
static inline float op_xielu(float x, float alpha_n, float alpha_p, float beta, float eps) {
    // Negative (and zero) branch first; the exponent argument is capped at eps.
    if (!(x > 0.0f)) {
        const float capped = fminf(x, eps);
        return (expm1f(capped) - x) * alpha_n + beta * x;
    }

    return alpha_p * x * x + beta * x;
}
63
// Sine of x.
static inline float op_sin(float x) {
    const float result = sinf(x);
    return result;
}
67
// Cosine of x.
static inline float op_cos(float x) {
    const float result = cosf(x);
    return result;
}
71
// Natural logarithm of x.
static inline float op_log(float x) {
    const float result = logf(x);
    return result;
}
75
// exp(x) - 1.
// Computed with expm1f rather than expf(x) - 1.0f: for |x| close to zero, expf(x)
// rounds to 1.0f and the subtraction would return 0, discarding the whole result.
static inline float op_expm1(float x) {
    return expm1f(x);
}
79
// Softplus: log(1 + exp(x)).
// For x > 20 the result is returned as x itself.
// log1pf is used for the remaining range: when exp(x) is tiny, 1.0f + exp(x) rounds to
// 1.0f and logf of that sum would return 0 instead of a small positive value.
static inline float op_softplus(float x) {
    return (x > 20.0f) ? x : log1pf(expf(x));
}
83
// Largest integral value not greater than x.
static inline float op_floor(float x) {
    const float result = floorf(x);
    return result;
}
87
// Smallest integral value not less than x.
static inline float op_ceil(float x) {
    const float result = ceilf(x);
    return result;
}
91
// Nearest integral value to x, with halfway cases rounded away from zero (roundf).
static inline float op_round(float x) {
    const float result = roundf(x);
    return result;
}
95
// x with its fractional part discarded (rounding toward zero).
static inline float op_trunc(float x) {
    const float result = truncf(x);
    return result;
}
99
100template <float (*op)(float), typename src0_t, typename dst_t>
101static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
102 constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
103 constexpr auto f32_to_dst = type_conversion_table<dst_t >::from_f32;
104
105 for (int i = 0; i < n; i++) {
106 y[i] = f32_to_dst(op(src0_to_f32(x[i])));
107 }
108}
109
110template <float (*op)(float), typename src0_t, typename dst_t>
111static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
112 const ggml_tensor * src0 = dst->src[0];
113
114 GGML_ASSERT(ggml_is_contiguous_rows(src0) && ggml_is_contiguous_rows(dst) && ggml_are_same_shape(src0, dst));
115
116 GGML_TENSOR_UNARY_OP_LOCALS
117
118 GGML_ASSERT( nb0 == sizeof(dst_t));
119 GGML_ASSERT(nb00 == sizeof(src0_t));
120
121 const auto [ir0, ir1] = get_thread_range(params, src0);
122
123 for (int64_t ir = ir0; ir < ir1; ++ir) {
124 const int64_t i03 = ir/(ne02*ne01);
125 const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
126 const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
127
128 dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
129 const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
130
131 vec_unary_op<op>(ne0, dst_ptr, src0_ptr);
132 }
133}
134
135// TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates
136template <float (*op)(float)>
137static void unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
138 const ggml_tensor * src0 = dst->src[0];
139
140 /* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
141 apply_unary_op<op, float, float>(params, dst);
142 } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
143 apply_unary_op<op, ggml_fp16_t, ggml_fp16_t>(params, dst);
144 } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
145 apply_unary_op<op, ggml_bf16_t, ggml_bf16_t>(params, dst);
146 } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
147 apply_unary_op<op, ggml_bf16_t, float>(params, dst);
148 } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
149 apply_unary_op<op, ggml_fp16_t, float>(params, dst);
150 } else {
151 fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
152 ggml_type_name(dst->type), ggml_type_name(src0->type));
153 GGML_ABORT("fatal error");
154 }
155}
156
// Type dispatcher for unary kernels whose signature is float(float, ggml_tensor *).
// It mirrors unary_op: the (src0, dst) type pair selects an apply_unary_op
// instantiation, and an unsupported pair is reported on stderr and aborts.
//
// NOTE(review): every branch forwards `op` to apply_unary_op, whose first template
// parameter in this file is `float (*)(float)`, not `float (*)(float, ggml_tensor *)`.
// Unless a matching apply_unary_op overload exists outside the visible part of this
// file, instantiating this template should fail to compile. Nothing in this file
// instantiates it - confirm whether it is dead code or needs its own apply helper.
template <float (*op)(float, ggml_tensor *)>
static void unary_op_params(const ggml_compute_params * params, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    /*  */ if (src0->type == GGML_TYPE_F32  && dst->type == GGML_TYPE_F32) { // all f32
        apply_unary_op<op, float, float>(params, dst);
    } else if (src0->type == GGML_TYPE_F16  && dst->type == GGML_TYPE_F16) { // all f16
        apply_unary_op<op, ggml_fp16_t, ggml_fp16_t>(params, dst);
    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
        apply_unary_op<op, ggml_bf16_t, ggml_bf16_t>(params, dst);
    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
        apply_unary_op<op, ggml_bf16_t, float>(params, dst);
    } else if (src0->type == GGML_TYPE_F16  && dst->type == GGML_TYPE_F32) {
        apply_unary_op<op, ggml_fp16_t, float>(params, dst);
    } else {
        // No kernel exists for this type pair: report it and abort.
        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
            ggml_type_name(dst->type), ggml_type_name(src0->type));
        GGML_ABORT("fatal error");
    }
}
177
178// Extend vec_unary_op to support functors
179template <typename Op, typename src0_t, typename dst_t>
180static inline void vec_unary_op_functor(int64_t n, dst_t * y, const src0_t * x, Op op) {
181 constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
182 constexpr auto f32_to_dst = type_conversion_table<dst_t >::from_f32;
183
184 for (int i = 0; i < n; i++) {
185 y[i] = f32_to_dst(op(src0_to_f32(x[i])));
186 }
187}
188
189// Extend apply_unary_op to support functors
190template <typename Op, typename src0_t, typename dst_t>
191static void apply_unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) {
192 const ggml_tensor * src0 = dst->src[0];
193
194 GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));
195
196 GGML_TENSOR_UNARY_OP_LOCALS
197
198 GGML_ASSERT( nb0 == sizeof(dst_t));
199 GGML_ASSERT(nb00 == sizeof(src0_t));
200
201 const auto [ir0, ir1] = get_thread_range(params, src0);
202
203 for (int64_t ir = ir0; ir < ir1; ++ir) {
204 const int64_t i03 = ir/(ne02*ne01);
205 const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
206 const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
207
208 dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
209 const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
210
211 vec_unary_op_functor(ne0, dst_ptr, src0_ptr, op);
212 }
213}
214
215// Generic dispatcher for functors
216template <typename Op>
217static void unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) {
218 const ggml_tensor * src0 = dst->src[0];
219
220 /* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
221 apply_unary_op_functor<Op, float, float>(params, dst, op);
222 } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
223 apply_unary_op_functor<Op, ggml_fp16_t, ggml_fp16_t>(params, dst, op);
224 } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
225 apply_unary_op_functor<Op, ggml_bf16_t, ggml_bf16_t>(params, dst, op);
226 } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
227 apply_unary_op_functor<Op, ggml_bf16_t, float>(params, dst, op);
228 } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
229 apply_unary_op_functor<Op, ggml_fp16_t, float>(params, dst, op);
230 } else {
231 fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
232 ggml_type_name(dst->type), ggml_type_name(src0->type));
233 GGML_ABORT("fatal error");
234 }
235}
236
// Element-wise absolute value: runs op_abs from dst->src[0] into dst via the
// type-dispatching unary_op helper.
void ggml_compute_forward_abs(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_abs>(params, dst);
}
240
// Element-wise sign (-1, 0 or 1): runs op_sgn from dst->src[0] into dst via unary_op.
void ggml_compute_forward_sgn(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_sgn>(params, dst);
}
244
// Element-wise negation: runs op_neg from dst->src[0] into dst via unary_op.
void ggml_compute_forward_neg(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_neg>(params, dst);
}
248
// Element-wise unit step (1 if x > 0, else 0): runs op_step from dst->src[0] into dst via unary_op.
void ggml_compute_forward_step(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_step>(params, dst);
}
252
// Element-wise hyperbolic tangent: runs op_tanh from dst->src[0] into dst via unary_op.
void ggml_compute_forward_tanh(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_tanh>(params, dst);
}
256
// Element-wise ELU activation: runs op_elu from dst->src[0] into dst via unary_op.
void ggml_compute_forward_elu(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_elu>(params, dst);
}
260
// Element-wise ReLU activation: runs op_relu from dst->src[0] into dst via unary_op.
void ggml_compute_forward_relu(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_relu>(params, dst);
}
264
// Element-wise logistic sigmoid: runs op_sigmoid from dst->src[0] into dst via unary_op.
void ggml_compute_forward_sigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_sigmoid>(params, dst);
}
268
// Element-wise hard sigmoid: runs op_hardsigmoid from dst->src[0] into dst via unary_op.
void ggml_compute_forward_hardsigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_hardsigmoid>(params, dst);
}
272
// Element-wise natural exponential: runs op_exp from dst->src[0] into dst via unary_op.
void ggml_compute_forward_exp(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_exp>(params, dst);
}
276
// Element-wise hard swish: runs op_hardswish from dst->src[0] into dst via unary_op.
void ggml_compute_forward_hardswish(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_hardswish>(params, dst);
}
280
// Element-wise square: runs op_sqr from dst->src[0] into dst via unary_op.
void ggml_compute_forward_sqr(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_sqr>(params, dst);
}
284
// Element-wise square root: runs op_sqrt from dst->src[0] into dst via unary_op.
void ggml_compute_forward_sqrt(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_sqrt>(params, dst);
}
288
// Element-wise sine: runs op_sin from dst->src[0] into dst via unary_op.
void ggml_compute_forward_sin(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_sin>(params, dst);
}
292
// Element-wise cosine: runs op_cos from dst->src[0] into dst via unary_op.
void ggml_compute_forward_cos(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_cos>(params, dst);
}
296
// Element-wise natural logarithm: runs op_log from dst->src[0] into dst via unary_op.
void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_log>(params, dst);
}
300
// Element-wise exp(x) - 1: runs op_expm1 from dst->src[0] into dst via unary_op.
void ggml_compute_forward_expm1(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_expm1>(params, dst);
}
304
// Element-wise softplus: runs op_softplus from dst->src[0] into dst via unary_op.
void ggml_compute_forward_softplus(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_softplus>(params, dst);
}
308
// Element-wise floor: runs op_floor from dst->src[0] into dst via unary_op.
void ggml_compute_forward_floor(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_floor>(params, dst);
}
312
// Element-wise ceiling: runs op_ceil from dst->src[0] into dst via unary_op.
void ggml_compute_forward_ceil(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_ceil>(params, dst);
}
316
// Element-wise rounding (roundf): runs op_round from dst->src[0] into dst via unary_op.
void ggml_compute_forward_round(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_round>(params, dst);
}
320
// Element-wise truncation toward zero: runs op_trunc from dst->src[0] into dst via unary_op.
void ggml_compute_forward_trunc(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_trunc>(params, dst);
}
324
325void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) {
326 const float alpha_n = ggml_get_op_params_f32(dst, 1);
327 const float alpha_p = ggml_get_op_params_f32(dst, 2);
328 const float beta = ggml_get_op_params_f32(dst, 3);
329 const float eps = ggml_get_op_params_f32(dst, 4);
330
331 const auto xielu_op_params = [alpha_n, alpha_p, beta, eps](float f) {
332 return op_xielu(f, alpha_n, alpha_p, beta, eps);
333 };
334
335 unary_op_functor(params, dst, xielu_op_params);
336}
337