diff options
Diffstat (limited to 'llama.cpp/ggml/src/ggml-hexagon/htp/hvx-sqrt.h')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-hexagon/htp/hvx-sqrt.h | 126 |
1 files changed, 126 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-sqrt.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-sqrt.h new file mode 100644 index 0000000..e31a100 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-sqrt.h | |||
| @@ -0,0 +1,126 @@ | |||
| 1 | #ifndef HVX_SQRT_H | ||
| 2 | #define HVX_SQRT_H | ||
| 3 | |||
| 4 | #include <stdbool.h> | ||
| 5 | #include <stdint.h> | ||
| 6 | |||
| 7 | #include "hex-utils.h" | ||
| 8 | |||
| 9 | #include "hvx-base.h" | ||
| 10 | |||
// Raw 32-bit patterns that hvx_vec_rsqrt_f32() broadcasts into vectors.
#define RSQRT_CONST        0x5f3759df  // magic seed of the fast inverse square root
#define RSQRT_ONE_HALF     0x3f000000  // bit pattern of 0.5f
#define RSQRT_THREE_HALVES 0x3fc00000  // bit pattern of 1.5f
| 14 | |||
// HVX_OP_MUL(a, b): multiply two "sf" vectors and yield an "sf" vector.
// When __HVX_ARCH__ is 79 or newer the sf*sf intrinsic is used directly;
// otherwise (including when __HVX_ARCH__ is undefined) the product is formed
// with the qf32 multiply and converted back to sf.
#if __HVX_ARCH__ >= 79
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#else
#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#endif
| 20 | |||
| 21 | static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) { | ||
| 22 | //Algorithm : | ||
| 23 | // x2 = input*0.5 | ||
| 24 | // y = * (long *) &input | ||
| 25 | // y = 0x5f3759df - (y>>1) | ||
| 26 | // y = y*(threehalfs - x2*y*y) | ||
| 27 | |||
| 28 | HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST); | ||
| 29 | HVX_Vector onehalf = Q6_V_vsplat_R(RSQRT_ONE_HALF); | ||
| 30 | HVX_Vector threehalfs = Q6_V_vsplat_R(RSQRT_THREE_HALVES); | ||
| 31 | |||
| 32 | HVX_Vector x2, y, ypower2, temp; | ||
| 33 | |||
| 34 | x2 = Q6_Vqf32_vmpy_VsfVsf(in_vec, onehalf); | ||
| 35 | x2 = Q6_Vqf32_vadd_Vqf32Vsf(x2, Q6_V_vzero()); | ||
| 36 | |||
| 37 | y = Q6_Vw_vasr_VwR(in_vec, 1); | ||
| 38 | y = Q6_Vw_vsub_VwVw(rsqrtconst, y); | ||
| 39 | |||
| 40 | // 1st iteration | ||
| 41 | ypower2 = Q6_Vqf32_vmpy_VsfVsf(y, y); | ||
| 42 | ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero()); | ||
| 43 | temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2); | ||
| 44 | temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp)); | ||
| 45 | temp = Q6_Vqf32_vmpy_VsfVsf(y, Q6_Vsf_equals_Vqf32(temp)); | ||
| 46 | |||
| 47 | // 2nd iteration | ||
| 48 | y = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero()); | ||
| 49 | ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y); | ||
| 50 | ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero()); | ||
| 51 | temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2); | ||
| 52 | temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp)); | ||
| 53 | temp = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp); | ||
| 54 | |||
| 55 | // 3rd iteration | ||
| 56 | y = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero()); | ||
| 57 | ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y); | ||
| 58 | ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero()); | ||
| 59 | temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2); | ||
| 60 | temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp)); | ||
| 61 | temp = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp); | ||
| 62 | |||
| 63 | return Q6_Vsf_equals_Vqf32(temp); | ||
| 64 | } | ||
| 65 | |||
/*
 * Shared body of the hvx_sqrt_f32_* kernels: sqrt(x) is computed as
 * x * inv_sqrt(x), one vector at a time.
 *
 * The macro is not hygienic: it must be expanded in a scope that provides
 * `dst`, `src` and `n` (number of fp32 elements to process).
 *
 *   dst_type   vector type used to access dst
 *   src_type   vector type used to access src
 *   vec_store  invoked as vec_store(ptr, nbytes, vec) to write the partial
 *              tail when n is not a multiple of VLEN_FP32
 *
 * The tail path still loads one whole vector from src; only the store is
 * limited to the remaining bytes.
 * NOTE(review): callers presumably guarantee that src is readable through
 * the end of that last vector - confirm.
 */
#define hvx_sqrt_f32_loop_body(dst_type, src_type, vec_store)                       \
    do {                                                                            \
        dst_type * restrict out_vecs = (dst_type *) dst;                            \
        src_type * restrict in_vecs  = (src_type *) src;                            \
                                                                                    \
        const uint32_t full_vecs  = n / VLEN_FP32;                                  \
        const uint32_t tail_elems = n % VLEN_FP32;                                  \
                                                                                    \
        _Pragma("unroll(4)")                                                        \
        for (uint32_t k = 0; k < full_vecs; k++) {                                  \
            const HVX_Vector x     = in_vecs[k];                                    \
            const HVX_Vector rsqrt = hvx_vec_rsqrt_f32(x);                          \
            out_vecs[k] = HVX_OP_MUL(rsqrt, x);                                     \
        }                                                                           \
        if (tail_elems) {                                                           \
            const HVX_Vector x      = in_vecs[full_vecs];                           \
            const HVX_Vector rsqrt  = hvx_vec_rsqrt_f32(x);                         \
            const HVX_Vector result = HVX_OP_MUL(rsqrt, x);                         \
            vec_store((void *) &out_vecs[full_vecs], tail_elems * SIZEOF_FP32, result); \
        }                                                                           \
    } while(0)
| 89 | |||
| 90 | static inline void hvx_sqrt_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { | ||
| 91 | assert((unsigned long) dst % 128 == 0); | ||
| 92 | assert((unsigned long) src % 128 == 0); | ||
| 93 | hvx_sqrt_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); | ||
| 94 | } | ||
| 95 | |||
| 96 | static inline void hvx_sqrt_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { | ||
| 97 | assert((unsigned long) dst % 128 == 0); | ||
| 98 | hvx_sqrt_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); | ||
| 99 | } | ||
| 100 | |||
| 101 | static inline void hvx_sqrt_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { | ||
| 102 | assert((unsigned long) src % 128 == 0); | ||
| 103 | hvx_sqrt_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); | ||
| 104 | } | ||
| 105 | |||
| 106 | static inline void hvx_sqrt_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { | ||
| 107 | hvx_sqrt_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); | ||
| 108 | } | ||
| 109 | |||
// Element-wise square root of num_elems fp32 values: dst[i] = sqrt(src[i]).
//
// Picks one of the four hvx_sqrt_f32_{aa,au,ua,uu} kernels according to
// whether dst and src are 128-byte aligned (first letter: dst, second: src;
// 'a' = aligned, 'u' = unaligned).
//
// A non-positive num_elems is a no-op. The kernels take the count as
// uint32_t, so without this guard a negative value would be converted to a
// huge element count and the kernels would run far outside both buffers.
static inline void hvx_sqrt_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems) {
    if (num_elems <= 0) {
        return;
    }

    const uint32_t n           = (uint32_t) num_elems;
    const bool     dst_aligned = ((uintptr_t) dst % 128) == 0;
    const bool     src_aligned = ((uintptr_t) src % 128) == 0;

    if (dst_aligned && src_aligned) {
        hvx_sqrt_f32_aa(dst, src, n);
    } else if (dst_aligned) {
        hvx_sqrt_f32_au(dst, src, n);
    } else if (src_aligned) {
        hvx_sqrt_f32_ua(dst, src, n);
    } else {
        hvx_sqrt_f32_uu(dst, src, n);
    }
}
| 125 | |||
| 126 | #endif /* HVX_SQRT_H */ | ||
