1#ifndef HVX_SIGMOID_H
2#define HVX_SIGMOID_H
3
4#include "hvx-base.h"
5
// IEEE-754 single-precision bit patterns for the sigmoid approximation
// constants; splatted into vectors with Q6_V_vsplat_R.
#define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022 = log2(e)
#define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777  (odd polynomial coefficient)
#define FAST_SIGMOID_C2 (0x3e8d74bd) // 0.276281267 (even polynomial coefficient)
#define FAST_SIGMOID_C3 (0x3f000000) // 0.5
10
11static inline HVX_Vector hvx_vec_fast_sigmoid_f32(HVX_Vector v) {
12 v = Q6_Vqf32_vmpy_VsfVsf(v, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F));
13 v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v), Q6_V_vsplat_R(FAST_SIGMOID_C3));
14
15 HVX_Vector in_int = hvx_vec_truncate_f32(Q6_Vsf_equals_Vqf32(v));
16 HVX_Vector x = Q6_Vqf32_vsub_Vqf32Vsf(v, Q6_Vsf_equals_Vw(in_int));
17 HVX_Vector xx = Q6_Vqf32_vmpy_Vqf32Vqf32(x, x);
18
19 HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(xx), Q6_V_vsplat_R(FAST_SIGMOID_C2));
20 v1 = Q6_Vqf32_vadd_Vqf32Vsf(v1, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F));
21
22 HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(x), Q6_V_vsplat_R(FAST_SIGMOID_C1));
23 v2 = Q6_Vqf32_vmpy_Vqf32Vqf32(v2, xx);
24 v2 = Q6_Vqf32_vadd_Vqf32Vqf32(v2, x);
25
26 HVX_Vector v3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vqf32(v2, v1));
27 HVX_Vector v3_exponent = Q6_Vw_vasl_VwR(v3, 1);
28 v3_exponent = Q6_Vuw_vlsr_VuwR(v3_exponent, 24);
29 v3_exponent = Q6_Vw_vadd_VwVw(in_int, v3_exponent);
30 v3 = Q6_Vw_vaslacc_VwVwR(v3, in_int, 24);
31
32 HVX_Vector v4 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_Vqf32Vqf32(v2, v1));
33 HVX_Vector v5 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(v3, v4));
34
35 HVX_Vector res = hvx_vec_inverse_f32(v5);
36 res = Q6_Vqf32_vmpy_VsfVsf(v3, res);
37
38 return Q6_Vsf_equals_Vqf32(res);
39}
40
41static inline HVX_Vector hvx_vec_fast_sigmoid_f32_guard(HVX_Vector v,
42 HVX_Vector one,
43 HVX_Vector max_exp,
44 HVX_Vector min_exp) {
45 const HVX_VectorPred pred_max = Q6_Q_vcmp_gt_VsfVsf(max_exp, v);
46 const HVX_VectorPred pred_min = Q6_Q_vcmp_gt_VsfVsf(v, min_exp);
47
48 HVX_Vector out = hvx_vec_fast_sigmoid_f32(v);
49 out = Q6_V_vmux_QVV(pred_max, out, one);
50 return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero());
51}
52
53static inline HVX_Vector hvx_vec_tanh_f32(HVX_Vector x) {
54 // tanh(x) = 2 * sigmoid(2x) - 1
55 HVX_Vector two = hvx_vec_splat_f32(2.0f);
56 HVX_Vector one = hvx_vec_splat_f32(1.0f);
57 HVX_Vector x2 = Q6_Vqf32_vmpy_VsfVsf(x, two);
58
59 HVX_Vector max_exp = hvx_vec_splat_f32(87.f);
60 HVX_Vector min_exp = hvx_vec_splat_f32(-87.f);
61
62 HVX_Vector sig2x = hvx_vec_fast_sigmoid_f32_guard(Q6_Vsf_equals_Vqf32(x2), one, max_exp, min_exp);
63
64 HVX_Vector res = Q6_Vqf32_vmpy_VsfVsf(sig2x, two);
65 res = Q6_Vqf32_vsub_Vqf32Vsf(res, one);
66 return Q6_Vsf_equals_Vqf32(res);
67}
68
// Shared loop body for the sigmoid entry points below.
//   dst_type/src_type : HVX_Vector (aligned) or HVX_UVector (unaligned)
//   vec_store         : partial store used for the leftover elements
// Expects `dst`, `src` (uint8_t pointers) and `n` (f32 element count) in
// the expanding scope. `vsrc` is const-qualified: every caller passes a
// const src pointer, so the previous cast silently stripped const.
// NOTE(review): the leftover path loads a full 128B vector from vsrc[i]
// even when fewer than epv elements remain -- assumes the over-read is
// safe on this target; confirm against the buffer allocation policy.
#define hvx_sigmoid_loop_body(dst_type, src_type, vec_store) \
    do { \
        dst_type * restrict vdst = (dst_type *) dst; \
        const src_type * restrict vsrc = (const src_type *) src; \
        \
        const HVX_Vector one = hvx_vec_splat_f32(1.f); \
        const HVX_Vector max_exp = hvx_vec_splat_f32(87.f); \
        const HVX_Vector min_exp = hvx_vec_splat_f32(-87.f); \
        \
        const uint32_t epv = 128 / sizeof(float); /* f32 lanes per vector */ \
        const uint32_t nvec = n / epv; \
        const uint32_t nloe = n % epv; \
        \
        uint32_t i = 0; \
        \
        _Pragma("unroll(4)") \
        for (; i < nvec; i++) { \
            vdst[i] = hvx_vec_fast_sigmoid_f32_guard(vsrc[i], one, max_exp, min_exp); \
        } \
        if (nloe) { \
            HVX_Vector tmp = hvx_vec_fast_sigmoid_f32_guard(vsrc[i], one, max_exp, min_exp); \
            vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \
        } \
    } while(0)
93
// Shared loop body for the tanh entry points below; parameter conventions
// match hvx_sigmoid_loop_body. `vsrc` is const-qualified: callers pass a
// const src pointer, so the previous cast silently stripped const.
// NOTE(review): the leftover path loads a full 128B vector from vsrc[i]
// even when fewer than epv elements remain -- assumes the over-read is
// safe on this target; confirm against the buffer allocation policy.
#define hvx_tanh_loop_body(dst_type, src_type, vec_store) \
    do { \
        dst_type * restrict vdst = (dst_type *) dst; \
        const src_type * restrict vsrc = (const src_type *) src; \
        \
        const uint32_t epv = 128 / sizeof(float); /* f32 lanes per vector */ \
        const uint32_t nvec = n / epv; \
        const uint32_t nloe = n % epv; \
        \
        uint32_t i = 0; \
        \
        _Pragma("unroll(4)") \
        for (; i < nvec; i++) { \
            vdst[i] = hvx_vec_tanh_f32(vsrc[i]); \
        } \
        if (nloe) { \
            HVX_Vector tmp = hvx_vec_tanh_f32(vsrc[i]); \
            vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \
        } \
    } while(0)
114
115static inline void hvx_sigmoid_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
116 assert((unsigned long) dst % 128 == 0);
117 assert((unsigned long) src % 128 == 0);
118 hvx_sigmoid_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
119}
120
121static inline void hvx_sigmoid_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
122 assert((unsigned long) dst % 128 == 0);
123 hvx_sigmoid_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
124}
125
126static inline void hvx_sigmoid_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
127 assert((unsigned long) src % 128 == 0);
128 hvx_sigmoid_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
129}
130
// Sigmoid of n f32 elements; no alignment requirement on either pointer.
static inline void hvx_sigmoid_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
    hvx_sigmoid_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
134
135static inline void hvx_tanh_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
136 assert((unsigned long) dst % 128 == 0);
137 assert((unsigned long) src % 128 == 0);
138 hvx_tanh_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
139}
140
141#endif /* HVX_SIGMOID_H */