diff options
Diffstat (limited to 'llama.cpp/ggml/src/ggml-hexagon/htp/hvx-base.h')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-hexagon/htp/hvx-base.h | 173 |
1 file changed, 173 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-base.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-base.h new file mode 100644 index 0000000..12a1b7f --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-base.h | |||
| @@ -0,0 +1,173 @@ | |||
| 1 | #ifndef HVX_BASE_H | ||
| 2 | #define HVX_BASE_H | ||
| 3 | |||
#include <assert.h>
#include <math.h>
#include <stdbool.h>
#include <stdint.h>

#include "hex-utils.h"
#include "hvx-types.h"
| 9 | |||
// Store the first n bytes of vector v to an arbitrary (unaligned) address.
// Uses predicated stores so that no byte outside [dst, dst + n) is written,
// including the case where the span straddles two 128-byte vector lines.
static inline void hvx_vec_store_u(void * restrict dst, uint32_t n, HVX_Vector v) {
    // Rotate as needed.
    // (vlalign by the low bits of dst shifts the payload so its bytes land at
    // the right in-line offsets for the two masked stores below.)
    v = Q6_V_vlalign_VVR(v, v, (size_t) dst);

    uint32_t left_off = (size_t) dst & 127;   // byte offset of dst within its 128-byte line
    uint32_t right_off = left_off + n;        // end offset relative to the start of that line

    // ql_not: lanes BELOW dst's offset (bytes that must be skipped on the left).
    HVX_VectorPred ql_not = Q6_Q_vsetq_R((size_t) dst);
    // qr: lanes below the end offset (bytes that are inside the span).
    HVX_VectorPred qr = Q6_Q_vsetq2_R(right_off);

    if (right_off > 128) {
        // Span crosses into the next vector line: qr (mod 128) selects the
        // spill-over bytes at the start of line dst+1; store them first.
        Q6_vmem_QRIV(qr, (HVX_Vector *) dst + 1, v);
        // all 1's
        // (so the first-line store below is limited only by ql_not)
        qr = Q6_Q_vcmp_eq_VbVb(v, v);
    }

    // Combine: skip lanes left of dst OR right of the span end, store the rest.
    ql_not = Q6_Q_or_QQn(ql_not, qr);
    Q6_vmem_QnRIV(ql_not, (HVX_Vector *) dst, v);
}
| 29 | |||
| 30 | static inline void hvx_vec_store_a(void * restrict dst, uint32_t n, HVX_Vector v) { | ||
| 31 | assert((unsigned long) dst % 128 == 0); | ||
| 32 | HVX_VectorPred m = Q6_Q_or_QQn(Q6_Q_vsetq_R((unsigned long) dst), Q6_Q_vsetq2_R(n)); | ||
| 33 | Q6_vmem_QnRIV(m, (HVX_Vector *) dst, v); | ||
| 34 | } | ||
| 35 | |||
| 36 | static inline HVX_Vector hvx_vec_splat_f32(float v) { | ||
| 37 | union { float f; uint32_t i; } u = { .f = v }; | ||
| 38 | return Q6_V_vsplat_R(u.i); | ||
| 39 | } | ||
| 40 | |||
| 41 | static inline HVX_Vector hvx_vec_splat_f16(float v) { | ||
| 42 | union { __fp16 f; uint16_t i; } u = { .f = v }; | ||
| 43 | return Q6_Vh_vsplat_R(u.i); | ||
| 44 | } | ||
| 45 | |||
// Replicate the first 4-byte element of v across all 32 lanes of the vector.
static inline HVX_Vector hvx_vec_repl4(HVX_Vector v) {
    // vdelta control to replicate first 4 bytes across all elements
    // NOTE(review): each control byte encodes a routing distance for the
    // vdelta butterfly network — do not edit individual entries by hand.
    static const uint8_t __attribute__((aligned(128))) repl[128] = {
        0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
        0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
        0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
        0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
    };

    // Table is 128-byte aligned, so a plain vector load is safe.
    HVX_Vector ctrl = *(HVX_Vector *) repl;
    return Q6_V_vdelta_VV(v, ctrl);
}
| 62 | |||
| 63 | static inline float hvx_vec_get_f32(HVX_Vector v) { | ||
| 64 | float __attribute__((aligned(128))) x; | ||
| 65 | hvx_vec_store_a(&x, 4, v); | ||
| 66 | return x; | ||
| 67 | } | ||
| 68 | |||
| 69 | static inline int32_t hvx_vec_get_i32(HVX_Vector v) { | ||
| 70 | int32_t __attribute__((aligned(128))) x; | ||
| 71 | hvx_vec_store_a(&x, 4, v); | ||
| 72 | return x; | ||
| 73 | } | ||
| 74 | |||
| 75 | static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) { | ||
| 76 | // abs by clearing the fp16 sign bit | ||
| 77 | HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff); | ||
| 78 | return Q6_V_vand_VV(v, mask); | ||
| 79 | } | ||
| 80 | |||
| 81 | static inline HVX_Vector hvx_vec_neg_f16(HVX_Vector v) { | ||
| 82 | // neg by setting the fp16 sign bit | ||
| 83 | HVX_Vector mask = Q6_Vh_vsplat_R(0x8000); | ||
| 84 | return Q6_V_vxor_VV(v, mask); | ||
| 85 | } | ||
| 86 | |||
| 87 | static inline HVX_Vector hvx_vec_abs_f32(HVX_Vector v) { | ||
| 88 | // abs by clearing the fp32 sign bit | ||
| 89 | HVX_Vector mask = Q6_V_vsplat_R(0x7fffffff); | ||
| 90 | return Q6_V_vand_VV(v, mask); | ||
| 91 | } | ||
| 92 | |||
| 93 | static inline HVX_Vector hvx_vec_neg_f32(HVX_Vector v) { | ||
| 94 | #if __HVX_ARCH__ > 75 | ||
| 95 | return Q6_Vsf_vfneg_Vsf(v); | ||
| 96 | #else | ||
| 97 | // neg by setting the fp32 sign bit | ||
| 98 | HVX_Vector mask = Q6_V_vsplat_R(0x80000000); | ||
| 99 | return Q6_V_vxor_VV(v, mask); | ||
| 100 | #endif // __HVX_ARCH__ > 75 | ||
| 101 | } | ||
| 102 | |||
// Per-lane fp16 NaN test. A half-float lane is NaN iff its exponent bits
// (0x7C00) are all ones AND its fraction bits (0x03FF) are non-zero.
static inline HVX_VectorPred hvx_vec_is_nan_f16(HVX_Vector v) {
    const HVX_Vector vnan_exp = Q6_Vh_vsplat_R(0x7C00);    // exponent field mask
    const HVX_Vector vnan_frac = Q6_Vh_vsplat_R(0x7FFF);   // exponent + fraction mask

    // get pred of which are NaN, i.e., exponent bits all 1s and fraction bits non 0s
    HVX_VectorPred p_exp = Q6_Q_vcmp_eq_VhVh(Q6_V_vand_VV(v, vnan_exp), vnan_exp);
    // (v & 0x7FFF) == 0x7C00 would mean the fraction is zero (an infinity),
    // so negating that comparison yields "fraction non-zero".
    HVX_VectorPred p_frac = Q6_Q_not_Q(Q6_Q_vcmp_eq_VhVh(Q6_V_vand_VV(v, vnan_frac), vnan_exp));
    return Q6_Q_and_QQ(p_exp, p_frac);
}
| 112 | |||
// Narrow two fp32 vectors (64 lanes total) into a single fp16 vector.
// v0 and v1 are combined, converted via the qf32 intermediate format, and
// vdeal'ed so the output lane order matches the input word order.
static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
    const HVX_Vector zero = Q6_V_vsplat_R(0);
    // Adding 0.0 is the idiom for converting IEEE sf lanes to qf32.
    HVX_Vector q0 = Q6_Vqf32_vadd_VsfVsf(v0, zero);
    HVX_Vector q1 = Q6_Vqf32_vadd_VsfVsf(v1, zero);
    HVX_Vector v = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0)));

#if __HVX_ARCH__ < 79
    // replace NaNs with -INF, older arches produce NaNs for (-INF + 0.0)
    // NOTE(review): INFINITY comes from <math.h> — assumed to be reachable
    // via the project headers; confirm the include chain.
    const HVX_Vector neg_inf = hvx_vec_splat_f16(-INFINITY);
    HVX_VectorPred nan = hvx_vec_is_nan_f16(v);
    v = Q6_V_vmux_QVV(nan, neg_inf, v);
#endif

    return v;
}
| 128 | |||
/* Q6_Vsf_equals_Vw is only available on v73+.*/
#if __HVX_ARCH__ < 73
// Convert int32 lanes to qf32 format by hand: normalize the mantissa into
// the top 24 bits and place the biased exponent in the low 8 bits.
static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in)
{
    HVX_Vector const vzero = Q6_V_vzero();
    // Zero has no leading 1 to normalize on, so handle it separately.
    HVX_VectorPred is_zero = Q6_Q_vcmp_eq_VwVw(in, vzero);
    // Count of redundant sign bits; shifting left by it normalizes the lane.
    HVX_Vector lshift = Q6_Vw_vnormamt_Vw(in);
    HVX_Vector normalized = Q6_Vw_vasl_VwVw(in, lshift);
    // Biased exponent: (127 + 30) minus the normalization shift.
    HVX_Vector vexp = Q6_Vw_vsub_VwVw(Q6_V_vsplat_R(0x7f + 30), lshift);
    // Keep the top 24 bits as the mantissa; the low byte is left for vexp.
    HVX_Vector mant = Q6_V_vand_VV(Q6_V_vsplat_R(0xFFFFFF00), normalized);
    HVX_Vector ret = Q6_V_vmux_QVV(is_zero, vzero, Q6_Vw_vadd_VwVw(mant, vexp));
    return ret;
}

// Pre-v73 fallback for the int32 -> fp32 conversion intrinsic: go through
// the hand-built qf32 representation, then convert qf32 -> sf.
static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in)
{
    return Q6_Vsf_equals_Vqf32(hvx_vec_i32_to_qf32(in));
}
#endif
| 148 | |||
// Convert fp16 lanes to int16 with round-to-nearest and saturation.
static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) {
    // This looks complicated.
    // Ideally should just be Q6_Vh_equals_Vhf(vin)
    // but that instruction does not do proper rounding.

    // convert to qf32, multiplying by 1.0 in the process.
    // (0x3C00 is fp16 1.0; the widening multiply yields a qf32 pair.)
    HVX_VectorPair v32 = Q6_Wqf32_vmpy_VhfVhf(vin, Q6_Vh_vsplat_R(0x3C00));

    // 'in-range' values are +/32752.
    // add 192K to it, convert to sf
    // (0x48400000 is 196608.0f; the bias pins the exponent for in-range inputs.)
    HVX_Vector v192K = Q6_V_vsplat_R(0x48400000);
    HVX_Vector vsf_0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(v32), v192K));
    HVX_Vector vsf_1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(v32), v192K));

    // for in-range cases, result is {163858... 229360} so the exponent is always 144.
    // if we extract bits 21..0 as a signed quantity, and round 6 bits off, that will be the answer.
    // Start by <<10 to get the final 'sign' bit in bit 15...
    vsf_0 = Q6_Vw_vasl_VwR(vsf_0, 10);
    vsf_1 = Q6_Vw_vasl_VwR(vsf_1, 10);

    // now round down to 16
    // (vround narrows each 32-bit lane to 16 bits with rounding/saturation
    // and interleaves the two sources back into one vector.)
    return Q6_Vh_vround_VwVw_sat(vsf_1, vsf_0);
}
| 172 | |||
| 173 | #endif /* HVX_BASE_H */ | ||
