1#ifndef HVX_EXP_H
2#define HVX_EXP_H
3
4#include <stdbool.h>
5#include <stdint.h>
6
7#include "hvx-base.h"
8#include "hvx-floor.h"
9
10#define EXP_COEFF_5 (0x39506967) // 0.000198757 = 1/(7!)
11#define EXP_COEFF_4 (0x3AB743CE) // 0.0013982 = 1/(6!)
12#define EXP_COEFF_3 (0x3C088908) // 0.00833345 = 1/(5!)
#define EXP_COEFF_2 (0x3D2AA9C1) // 0.0416658 = 1/(4!)
14#define EXP_COEFF_1 (0x3E2AAAAA) // 0.16666667 = 1/(3!)
15#define EXP_COEFF_0 (0x3F000000) // 0.5 = 1/(2!)
16#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805
17#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408
18#define EXP_ONE (0x3f800000) // 1.0
19#define EXP_RANGE_R (0x41a00000) // 20.0
20#define EXP_RANGE_L (0xc1a00000) // -20.0
21
22static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) {
23 HVX_Vector z_qf32_v;
24 HVX_Vector x_v;
25 HVX_Vector x_qf32_v;
26 HVX_Vector y_v;
27 HVX_Vector k_v;
28 HVX_Vector f_v;
29 HVX_Vector epsilon_v;
30 HVX_Vector log2e = Q6_V_vsplat_R(EXP_LOG2E);
31 HVX_Vector logn2 = Q6_V_vsplat_R(EXP_LOGN2);
32 HVX_Vector E_const;
33 HVX_Vector zero_v = Q6_V_vzero();
34
35 // exp(x) is approximated as follows:
36 // f = floor(x/ln(2)) = floor(x*log2(e))
37 // epsilon = x - f*ln(2)
38 // exp(x) = exp(epsilon+f*ln(2))
39 // = exp(epsilon)*exp(f*ln(2))
40 // = exp(epsilon)*2^f
41 //
42 // Since epsilon is close to zero, it can be approximated with its Taylor series:
43 // exp(x) ~= 1+x+x^2/2!+x^3/3!+...+x^n/n!+...
44 // Preserving the first eight elements, we get:
45 // exp(x) ~= 1+x+e0*x^2+e1*x^3+e2*x^4+e3*x^5+e4*x^6+e5*x^7
46 // = 1+x+(E0+(E1+(E2+(E3+(E4+E5*x)*x)*x)*x)*x)*x^2
47
48 HVX_Vector temp_v = in_vec;
49
50 // Clamp inputs to (-20.0, 20.0)
51 HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, Q6_V_vsplat_R(EXP_RANGE_R));
52 HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(Q6_V_vsplat_R(EXP_RANGE_L), in_vec);
53
54 in_vec = Q6_V_vmux_QVV(pred_cap_right, Q6_V_vsplat_R(EXP_RANGE_R), temp_v);
55 in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), temp_v);
56
57 epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, in_vec);
58 epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v);
59
60 // f_v is the floating point result and k_v is the integer result
61 f_v = hvx_vec_floor_f32(epsilon_v);
62 k_v = hvx_vec_truncate_f32(f_v);
63
64 x_qf32_v = Q6_Vqf32_vadd_VsfVsf(in_vec, zero_v);
65
66 // x = x - f_v * logn2;
67 epsilon_v = Q6_Vqf32_vmpy_VsfVsf(f_v, logn2);
68 x_qf32_v = Q6_Vqf32_vsub_Vqf32Vqf32(x_qf32_v, epsilon_v);
69 // normalize before every QFloat's vmpy
70 x_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v);
71
72 // z = x * x;
73 z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v);
74 z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v);
75
76 x_v = Q6_Vsf_equals_Vqf32(x_qf32_v);
77
78 // y = E4 + E5 * x;
79 E_const = Q6_V_vsplat_R(EXP_COEFF_5);
80 y_v = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v);
81 E_const = Q6_V_vsplat_R(EXP_COEFF_4);
82 y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
83 y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
84
85 // y = E3 + y * x;
86 E_const = Q6_V_vsplat_R(EXP_COEFF_3);
87 y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
88 y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
89 y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
90
91 // y = E2 + y * x;
92 E_const = Q6_V_vsplat_R(EXP_COEFF_2);
93 y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
94 y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
95 y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
96
97 // y = E1 + y * x;
98 E_const = Q6_V_vsplat_R(EXP_COEFF_1);
99 y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
100 y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
101 y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
102
103 // y = E0 + y * x;
104 E_const = Q6_V_vsplat_R(EXP_COEFF_0);
105 y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
106 y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
107 y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
108
109 // y = x + y * z;
110 y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, z_qf32_v);
111 y_v = Q6_Vqf32_vadd_Vqf32Vqf32(y_v, x_qf32_v);
112 y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
113
114 // y = y + 1.0;
115 y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, Q6_V_vsplat_R(EXP_ONE));
116
117 // insert exponents
118 // y = ldexpf(y, k);
119 // y_v += k_v; // qf32
120 // modify exponent
121
122 y_v = Q6_Vsf_equals_Vqf32(y_v);
123
124 // add k_v to the exponent of y_v
125 HVX_Vector y_v_exponent = Q6_Vw_vasl_VwR(y_v, 1);
126
127 y_v_exponent = Q6_Vuw_vlsr_VuwR(y_v_exponent, IEEE_VSF_MANTLEN + 1);
128 y_v_exponent = Q6_Vw_vadd_VwVw(k_v, y_v_exponent);
129
130 // exponent cannot be negative; if overflow is detected, result is set to zero
131 HVX_VectorPred qy_v_negative_exponent = Q6_Q_vcmp_gt_VwVw(zero_v, y_v_exponent);
132
133 y_v = Q6_Vw_vaslacc_VwVwR(y_v, k_v, IEEE_VSF_MANTLEN);
134
135 y_v = Q6_V_vmux_QVV(qy_v_negative_exponent, zero_v, y_v);
136
137 return y_v;
138}
139
140static inline HVX_Vector hvx_vec_exp_f32_guard(HVX_Vector in_vec, HVX_Vector max_exp, HVX_Vector inf) {
141 const HVX_VectorPred pred0 = Q6_Q_vcmp_gt_VsfVsf(in_vec, max_exp);
142
143 HVX_Vector out = hvx_vec_exp_f32(in_vec);
144
145 return Q6_V_vmux_QVV(pred0, inf, out);
146}
147
148static inline void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) {
149 int left_over = num_elems & (VLEN_FP32 - 1);
150 int num_elems_whole = num_elems - left_over;
151
152 int unaligned_addr = 0;
153 int unaligned_loop = 0;
154 if ((0 == hex_is_aligned((void *) src, VLEN)) || (0 == hex_is_aligned((void *) dst, VLEN))) {
155 unaligned_addr = 1;
156 }
157 // assert((0 == unaligned_addr) || (0 == num_elems_whole));
158 if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
159 unaligned_loop = 1;
160 }
161
162 HVX_Vector vec_out = Q6_V_vzero();
163
164 static const float kInf = INFINITY;
165 static const float kMaxExp = 88.02f; // log(INF)
166
167 const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp);
168 const HVX_Vector inf = hvx_vec_splat_f32(kInf);
169
170 if (0 == unaligned_loop) {
171 HVX_Vector * p_vec_in1 = (HVX_Vector *) src;
172 HVX_Vector * p_vec_out = (HVX_Vector *) dst;
173
174 #pragma unroll(4)
175 for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
176 if (true == negate) {
177 HVX_Vector neg_vec_in = hvx_vec_neg_f32(*p_vec_in1++);
178 *p_vec_out++ = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf);
179 } else {
180 *p_vec_out++ = hvx_vec_exp_f32_guard(*p_vec_in1++, max_exp, inf);
181 }
182 }
183 } else {
184 #pragma unroll(4)
185 for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
186 HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
187
188 if (true == negate) {
189 HVX_Vector neg_vec_in = hvx_vec_neg_f32(in);
190 *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf);
191 } else {
192 *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_f32_guard(in, max_exp, inf);
193 }
194 }
195 }
196
197 if (left_over > 0) {
198 const float * srcf = (float *) src + num_elems_whole;
199 float * dstf = (float *) dst + num_elems_whole;
200
201 HVX_Vector in = *(HVX_UVector *) srcf;
202
203 if (true == negate) {
204 HVX_Vector neg_vec_in = hvx_vec_neg_f32(in);
205
206 vec_out = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf);
207 } else {
208 vec_out = hvx_vec_exp_f32_guard(in, max_exp, inf);
209 }
210
211 hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out);
212 }
213}
214
215#endif /* HVX_EXP_H */