1#ifndef HVX_SIGMOID_H
  2#define HVX_SIGMOID_H
  3
#include <assert.h>
#include <stdint.h>

#include "hvx-base.h"
  5
  6#define FAST_SIGMOID_LOG2F (0x3fb8aa3b)  // 1.442695022
  7#define FAST_SIGMOID_C1    (0x3d009076)  // 0.03138777
  8#define FAST_SIGMOID_C2    (0x3e8d74bd)  // 0.276281267
  9#define FAST_SIGMOID_C3    (0x3f000000)  // 0.5
 10
 11static inline HVX_Vector hvx_vec_fast_sigmoid_f32(HVX_Vector v) {
 12    v = Q6_Vqf32_vmpy_VsfVsf(v, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F));
 13    v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v), Q6_V_vsplat_R(FAST_SIGMOID_C3));
 14
 15    HVX_Vector in_int = hvx_vec_truncate_f32(Q6_Vsf_equals_Vqf32(v));
 16    HVX_Vector x      = Q6_Vqf32_vsub_Vqf32Vsf(v, Q6_Vsf_equals_Vw(in_int));
 17    HVX_Vector xx     = Q6_Vqf32_vmpy_Vqf32Vqf32(x, x);
 18
 19    HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(xx), Q6_V_vsplat_R(FAST_SIGMOID_C2));
 20    v1            = Q6_Vqf32_vadd_Vqf32Vsf(v1, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F));
 21
 22    HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(x), Q6_V_vsplat_R(FAST_SIGMOID_C1));
 23    v2            = Q6_Vqf32_vmpy_Vqf32Vqf32(v2, xx);
 24    v2            = Q6_Vqf32_vadd_Vqf32Vqf32(v2, x);
 25
 26    HVX_Vector v3          = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vqf32(v2, v1));
 27    HVX_Vector v3_exponent = Q6_Vw_vasl_VwR(v3, 1);
 28    v3_exponent            = Q6_Vuw_vlsr_VuwR(v3_exponent, 24);
 29    v3_exponent            = Q6_Vw_vadd_VwVw(in_int, v3_exponent);
 30    v3                     = Q6_Vw_vaslacc_VwVwR(v3, in_int, 24);
 31
 32    HVX_Vector v4 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_Vqf32Vqf32(v2, v1));
 33    HVX_Vector v5 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(v3, v4));
 34
 35    HVX_Vector res = hvx_vec_inverse_f32(v5);
 36    res            = Q6_Vqf32_vmpy_VsfVsf(v3, res);
 37
 38    return Q6_Vsf_equals_Vqf32(res);
 39}
 40
 41static inline HVX_Vector hvx_vec_fast_sigmoid_f32_guard(HVX_Vector v,
 42                                                         HVX_Vector one,
 43                                                         HVX_Vector max_exp,
 44                                                         HVX_Vector min_exp) {
 45    const HVX_VectorPred pred_max = Q6_Q_vcmp_gt_VsfVsf(max_exp, v);
 46    const HVX_VectorPred pred_min = Q6_Q_vcmp_gt_VsfVsf(v, min_exp);
 47
 48    HVX_Vector out = hvx_vec_fast_sigmoid_f32(v);
 49    out            = Q6_V_vmux_QVV(pred_max, out, one);
 50    return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero());
 51}
 52
 53static inline HVX_Vector hvx_vec_tanh_f32(HVX_Vector x) {
 54    // tanh(x) = 2 * sigmoid(2x) - 1
 55    HVX_Vector two = hvx_vec_splat_f32(2.0f);
 56    HVX_Vector one = hvx_vec_splat_f32(1.0f);
 57    HVX_Vector x2  = Q6_Vqf32_vmpy_VsfVsf(x, two);
 58
 59    HVX_Vector max_exp = hvx_vec_splat_f32(87.f);
 60    HVX_Vector min_exp = hvx_vec_splat_f32(-87.f);
 61
 62    HVX_Vector sig2x = hvx_vec_fast_sigmoid_f32_guard(Q6_Vsf_equals_Vqf32(x2), one, max_exp, min_exp);
 63
 64    HVX_Vector res = Q6_Vqf32_vmpy_VsfVsf(sig2x, two);
 65    res = Q6_Vqf32_vsub_Vqf32Vsf(res, one);
 66    return Q6_Vsf_equals_Vqf32(res);
 67}
 68
 69#define hvx_sigmoid_loop_body(dst_type, src_type, vec_store)    \
 70    do {                                                        \
 71        dst_type * restrict vdst = (dst_type *) dst;            \
 72        src_type * restrict vsrc = (src_type *) src;            \
 73                                                                \
 74        const HVX_Vector one     = hvx_vec_splat_f32(1.f);      \
 75        const HVX_Vector max_exp = hvx_vec_splat_f32(87.f);     \
 76        const HVX_Vector min_exp = hvx_vec_splat_f32(-87.f);    \
 77                                                                \
 78        const uint32_t epv  = 128 / sizeof(float);              \
 79        const uint32_t nvec = n / epv;                          \
 80        const uint32_t nloe = n % epv;                          \
 81                                                                \
 82        uint32_t i = 0;                                         \
 83                                                                \
 84        _Pragma("unroll(4)")                                    \
 85        for (; i < nvec; i++) {                                 \
 86             vdst[i] = hvx_vec_fast_sigmoid_f32_guard(vsrc[i], one, max_exp, min_exp); \
 87        }                                                       \
 88        if (nloe) {                                             \
 89             HVX_Vector tmp = hvx_vec_fast_sigmoid_f32_guard(vsrc[i], one, max_exp, min_exp); \
 90             vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \
 91        }                                                       \
 92    } while(0)
 93
 94#define hvx_tanh_loop_body(dst_type, src_type, vec_store)       \
 95    do {                                                        \
 96        dst_type * restrict vdst = (dst_type *) dst;            \
 97        src_type * restrict vsrc = (src_type *) src;            \
 98                                                                \
 99        const uint32_t epv  = 128 / sizeof(float);              \
100        const uint32_t nvec = n / epv;                          \
101        const uint32_t nloe = n % epv;                          \
102                                                                \
103        uint32_t i = 0;                                         \
104                                                                \
105        _Pragma("unroll(4)")                                    \
106        for (; i < nvec; i++) {                                 \
107             vdst[i] = hvx_vec_tanh_f32(vsrc[i]);               \
108        }                                                       \
109        if (nloe) {                                             \
110             HVX_Vector tmp = hvx_vec_tanh_f32(vsrc[i]);        \
111             vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \
112        }                                                       \
113    } while(0)
114
115static inline void hvx_sigmoid_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
116    assert((unsigned long) dst % 128 == 0);
117    assert((unsigned long) src % 128 == 0);
118    hvx_sigmoid_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
119}
120
121static inline void hvx_sigmoid_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
122    assert((unsigned long) dst % 128 == 0);
123    hvx_sigmoid_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
124}
125
126static inline void hvx_sigmoid_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
127    assert((unsigned long) src % 128 == 0);
128    hvx_sigmoid_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
129}
130
131static inline void hvx_sigmoid_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
132    hvx_sigmoid_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
133}
134
135static inline void hvx_tanh_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
136    assert((unsigned long) dst % 128 == 0);
137    assert((unsigned long) src % 128 == 0);
138    hvx_tanh_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
139}
140
141#endif /* HVX_SIGMOID_H */