1#ifndef HVX_SQRT_H
  2#define HVX_SQRT_H
  3
  4#include <stdbool.h>
  5#include <stdint.h>
  6
  7#include "hex-utils.h"
  8
  9#include "hvx-base.h"
 10
 11#define RSQRT_CONST        0x5f3759df  // Constant for fast inverse square root calculation
 12#define RSQRT_ONE_HALF     0x3f000000  // 0.5
 13#define RSQRT_THREE_HALVES 0x3fc00000  // 1.5
 14
 15#if __HVX_ARCH__ < 79
 16#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
 17#else
 18#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
 19#endif
 20
 21static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) {
 22    //Algorithm :
 23    //  x2 = input*0.5
 24    //  y  = * (long *) &input
 25    //  y  = 0x5f3759df - (y>>1)
 26    //  y  = y*(threehalfs - x2*y*y)
 27
 28    HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST);
 29    HVX_Vector onehalf    = Q6_V_vsplat_R(RSQRT_ONE_HALF);
 30    HVX_Vector threehalfs = Q6_V_vsplat_R(RSQRT_THREE_HALVES);
 31
 32    HVX_Vector x2, y, ypower2, temp;
 33
 34    x2 = Q6_Vqf32_vmpy_VsfVsf(in_vec, onehalf);
 35    x2 = Q6_Vqf32_vadd_Vqf32Vsf(x2, Q6_V_vzero());
 36
 37    y = Q6_Vw_vasr_VwR(in_vec, 1);
 38    y = Q6_Vw_vsub_VwVw(rsqrtconst, y);
 39
 40    // 1st iteration
 41    ypower2 = Q6_Vqf32_vmpy_VsfVsf(y, y);
 42    ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
 43    temp    = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
 44    temp    = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
 45    temp    = Q6_Vqf32_vmpy_VsfVsf(y, Q6_Vsf_equals_Vqf32(temp));
 46
 47    // 2nd iteration
 48    y       = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero());
 49    ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y);
 50    ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
 51    temp    = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
 52    temp    = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
 53    temp    = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp);
 54
 55    // 3rd iteration
 56    y       = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero());
 57    ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y);
 58    ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
 59    temp    = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
 60    temp    = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
 61    temp    = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp);
 62
 63    return Q6_Vsf_equals_Vqf32(temp);
 64}
 65
 66// Compute sqrt(x) as x*inv_sqrt(x)
 67#define hvx_sqrt_f32_loop_body(dst_type, src_type, vec_store)                \
 68    do {                                                                     \
 69        dst_type * restrict vdst = (dst_type *) dst;                         \
 70        src_type * restrict vsrc = (src_type *) src;                         \
 71                                                                             \
 72        const uint32_t nvec = n / VLEN_FP32;                                 \
 73        const uint32_t nloe = n % VLEN_FP32;                                 \
 74                                                                             \
 75        uint32_t i = 0;                                                      \
 76                                                                             \
 77        _Pragma("unroll(4)")                                                 \
 78        for (; i < nvec; i++) {                                              \
 79            HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]);                \
 80            HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]);             \
 81            vdst[i] = sqrt_res;                                              \
 82        }                                                                    \
 83        if (nloe) {                                                          \
 84            HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]);                \
 85            HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]);             \
 86            vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, sqrt_res);      \
 87        }                                                                    \
 88    } while(0)
 89
 90static inline void hvx_sqrt_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
 91    assert((unsigned long) dst % 128 == 0);
 92    assert((unsigned long) src % 128 == 0);
 93    hvx_sqrt_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
 94}
 95
 96static inline void hvx_sqrt_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
 97    assert((unsigned long) dst % 128 == 0);
 98    hvx_sqrt_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
 99}
100
101static inline void hvx_sqrt_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
102    assert((unsigned long) src % 128 == 0);
103    hvx_sqrt_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
104}
105
// sqrt over `n` fp32 values when neither pointer is known to be 128-byte
// aligned: both sides go through HVX_UVector and the tail is written with
// hvx_vec_store_u. No alignment is asserted.
static inline void hvx_sqrt_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
    hvx_sqrt_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
109
// Element-wise sqrt of `num_elems` fp32 values from `src` into `dst`.
//
// Picks the aligned/unaligned kernel variant from the 128-byte alignment of
// each pointer. A non-positive `num_elems` is a no-op: passing a negative int
// straight through would wrap to a huge uint32_t count and run the kernels far
// past the end of both buffers.
static inline void hvx_sqrt_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems) {
    if (num_elems <= 0) {
        return;
    }

    const uint32_t n           = (uint32_t) num_elems;
    const bool     dst_aligned = ((uintptr_t) dst % 128) == 0;
    const bool     src_aligned = ((uintptr_t) src % 128) == 0;

    if (dst_aligned && src_aligned) {
        hvx_sqrt_f32_aa(dst, src, n);
    } else if (dst_aligned) {
        hvx_sqrt_f32_au(dst, src, n);
    } else if (src_aligned) {
        hvx_sqrt_f32_ua(dst, src, n);
    } else {
        hvx_sqrt_f32_uu(dst, src, n);
    }
}
125
126#endif /* HVX_SQRT_H */