#ifndef HVX_ARITH_H
#define HVX_ARITH_H

#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <math.h>

#include "hvx-base.h"
#include "hex-utils.h"

//
// Binary operations (add, sub, mul)
//

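// Shared loop shape for the element-wise kernels below: process full 128-byte
// vectors first, then handle the remaining n % epv elements with one more
// full-vector compute and a partial store. The tail still loads a whole vector
// from the sources, but hvx_vec_store_a/_u write only nloe * elem_size bytes.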
#define hvx_arith_loop_body(dst_type, src0_type, src1_type, vec_store, vec_op) \
    do {                                                                       \
        dst_type * restrict vdst         = (dst_type *) dst;                   \
        const src0_type * restrict vsrc0 = (const src0_type *) src0;           \
        const src1_type * restrict vsrc1 = (const src1_type *) src1;           \
                                                                               \
        const uint32_t elem_size = sizeof(float);                              \
        const uint32_t epv  = 128 / elem_size;                                 \
        const uint32_t nvec = n / epv;                                         \
        const uint32_t nloe = n % epv;                                         \
                                                                               \
        uint32_t i = 0;                                                        \
                                                                               \
        _Pragma("unroll(4)")                                                   \
        for (; i < nvec; i++) {                                                \
            vdst[i] = vec_op(vsrc0[i], vsrc1[i]);                              \
        }                                                                      \
        if (nloe) {                                                            \
            HVX_Vector v = vec_op(vsrc0[i], vsrc1[i]);                         \
            vec_store((void *) &vdst[i], nloe * elem_size, v);                 \
        }                                                                      \
    } while(0)

#if __HVX_ARCH__ < 79
#define HVX_OP_ADD(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
#define HVX_OP_SUB(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b))
#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#else
#define HVX_OP_ADD(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
#define HVX_OP_SUB(a, b) Q6_Vsf_vsub_VsfVsf(a, b)
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
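
// On HVX architectures before v79, float add/sub/mul are implemented via the
// qf32 intermediate format and converted back to IEEE sf afterwards; on v79+
// the native Vsf instructions are used directly.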

// Generic macro to define all alignment permutations of an op
// (suffix letters give dst/src0/src1 alignment: 'a' = 128-byte aligned, 'u' = unaligned)
#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO) \
static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
    assert((uintptr_t) dst % 128 == 0); \
    assert((uintptr_t) src0 % 128 == 0); \
    assert((uintptr_t) src1 % 128 == 0); \
    hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
    assert((uintptr_t) dst % 128 == 0); \
    assert((uintptr_t) src0 % 128 == 0); \
    hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
    assert((uintptr_t) dst % 128 == 0); \
    assert((uintptr_t) src1 % 128 == 0); \
    hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
    assert((uintptr_t) dst % 128 == 0); \
    hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
    assert((uintptr_t) src0 % 128 == 0); \
    assert((uintptr_t) src1 % 128 == 0); \
    hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
} \
static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
    assert((uintptr_t) src0 % 128 == 0); \
    hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
} \
static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
    assert((uintptr_t) src1 % 128 == 0); \
    hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
} \
static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
    hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
}

DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD)
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB)
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL)

// Dispatchers: check pointer alignment at runtime and call the matching variant
#define HVX_BINARY_DISPATCHER(OP_NAME) \
static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \
    if (hex_is_aligned((void *) dst, 128)) { \
        if (hex_is_aligned((void *) src0, 128)) { \
            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \
            else                                    OP_NAME##_aau(dst, src0, src1, num_elems); \
        } else { \
            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \
            else                                    OP_NAME##_auu(dst, src0, src1, num_elems); \
        } \
    } else { \
        if (hex_is_aligned((void *) src0, 128)) { \
            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \
            else                                    OP_NAME##_uau(dst, src0, src1, num_elems); \
        } else { \
            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \
            else                                    OP_NAME##_uuu(dst, src0, src1, num_elems); \
        } \
    } \
}

HVX_BINARY_DISPATCHER(hvx_add_f32)
HVX_BINARY_DISPATCHER(hvx_sub_f32)
HVX_BINARY_DISPATCHER(hvx_mul_f32)
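
// Usage sketch (illustrative; buffer names are hypothetical). Callers only
// need the generic entry points; alignment is probed at runtime:
//
//   float a[256]   __attribute__((aligned(128)));
//   float b[256]   __attribute__((aligned(128)));
//   float out[256] __attribute__((aligned(128)));
//   ...
//   hvx_add_f32((uint8_t *) out, (const uint8_t *) a, (const uint8_t *) b, 256);
//   // resolves to hvx_add_f32_aaa since all three pointers are 128-byte aligned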

// Fused multiply-multiply: dst = src0 * src1 * src2 in a single pass over
// memory (all four pointers must be 128-byte aligned)
static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) {
    assert((uintptr_t) dst % 128 == 0);
    assert((uintptr_t) src0 % 128 == 0);
    assert((uintptr_t) src1 % 128 == 0);
    assert((uintptr_t) src2 % 128 == 0);

    HVX_Vector * restrict vdst        = (HVX_Vector *) dst;
    const HVX_Vector * restrict vsrc0 = (const HVX_Vector *) src0;
    const HVX_Vector * restrict vsrc1 = (const HVX_Vector *) src1;
    const HVX_Vector * restrict vsrc2 = (const HVX_Vector *) src2;

    const uint32_t elem_size = sizeof(float);
    const uint32_t epv  = 128 / elem_size;
    const uint32_t nvec = num_elems / epv;
    const uint32_t nloe = num_elems % epv;

    uint32_t i = 0;

    _Pragma("unroll(4)")
    for (; i < nvec; i++) {
        HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]);
        vdst[i] = HVX_OP_MUL(v1, vsrc2[i]);
    }

    if (nloe) {
        HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]);
        HVX_Vector v2 = HVX_OP_MUL(v1, vsrc2[i]);
        hvx_vec_store_a((void *) &vdst[i], nloe * elem_size, v2);
    }
}
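
// Usage sketch (illustrative; names hypothetical): one pass over memory instead
// of two hvx_mul_f32 calls and an intermediate buffer:
//
//   hvx_mul_mul_f32_aa((uint8_t *) out, (const uint8_t *) a,
//                      (const uint8_t *) b, (const uint8_t *) c, n);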

// Scalar Operations

#define hvx_scalar_loop_body(dst_type, src_type, vec_store, scalar_op_macro)   \
    do {                                                                       \
        dst_type * restrict vdst       = (dst_type *) dst;                     \
        const src_type * restrict vsrc = (const src_type *) src;               \
                                                                               \
        const uint32_t elem_size = sizeof(float);                              \
        const uint32_t epv  = 128 / elem_size;                                 \
        const uint32_t nvec = n / epv;                                         \
        const uint32_t nloe = n % epv;                                         \
                                                                               \
        uint32_t i = 0;                                                        \
                                                                               \
        _Pragma("unroll(4)")                                                   \
        for (; i < nvec; i++) {                                                \
            HVX_Vector v = vsrc[i];                                            \
            vdst[i] = scalar_op_macro(v);                                      \
        }                                                                      \
        if (nloe) {                                                            \
            HVX_Vector v = vsrc[i];                                            \
            v = scalar_op_macro(v);                                            \
            vec_store((void *) &vdst[i], nloe * elem_size, v);                 \
        }                                                                      \
    } while(0)

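// HVX_OP_ADD_SCALAR preserves lanes whose input is +INF by muxing the original
// value back in; the qf32 round trip used on pre-v79 HVX does not appear to
// keep INF intact, so it is patched explicitly. The bitwise word compare
// matches only +INF exactly (not -INF or NaN).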
#define HVX_OP_ADD_SCALAR(v) \
    ({ \
        const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, v); \
        HVX_Vector out = HVX_OP_ADD(v, val_vec); \
        Q6_V_vmux_QVV(pred_inf, inf, out); \
    })

#define HVX_OP_MUL_SCALAR(v) HVX_OP_MUL(v, val_vec)
#define HVX_OP_SUB_SCALAR(v) HVX_OP_SUB(v, val_vec)

// Add Scalar Variants

static inline void hvx_add_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
    assert((uintptr_t) dst % 128 == 0);
    assert((uintptr_t) src % 128 == 0);
    hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD_SCALAR);
}

static inline void hvx_add_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
    assert((uintptr_t) dst % 128 == 0);
    hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD_SCALAR);
}

static inline void hvx_add_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
    assert((uintptr_t) src % 128 == 0);
    hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD_SCALAR);
}

static inline void hvx_add_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
    hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD_SCALAR);
}

// Sub Scalar Variants

static inline void hvx_sub_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    assert((uintptr_t) dst % 128 == 0);
    assert((uintptr_t) src % 128 == 0);
    hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB_SCALAR);
}

static inline void hvx_sub_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    assert((uintptr_t) dst % 128 == 0);
    hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB_SCALAR);
}

static inline void hvx_sub_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    assert((uintptr_t) src % 128 == 0);
    hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB_SCALAR);
}

static inline void hvx_sub_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB_SCALAR);
}

// Mul Scalar Variants

static inline void hvx_mul_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    assert((uintptr_t) dst % 128 == 0);
    assert((uintptr_t) src % 128 == 0);
    hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL_SCALAR);
}

static inline void hvx_mul_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    assert((uintptr_t) dst % 128 == 0);
    hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL_SCALAR);
}

static inline void hvx_mul_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    assert((uintptr_t) src % 128 == 0);
    hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL_SCALAR);
}

static inline void hvx_mul_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL_SCALAR);
}

static inline void hvx_add_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const uint32_t num_elems) {
    if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
        hvx_add_scalar_f32_aa(dst, src, val, num_elems);
    } else if (hex_is_aligned((void *) dst, 128)) {
        hvx_add_scalar_f32_au(dst, src, val, num_elems);
    } else if (hex_is_aligned((void *) src, 128)) {
        hvx_add_scalar_f32_ua(dst, src, val, num_elems);
    } else {
        hvx_add_scalar_f32_uu(dst, src, val, num_elems);
    }
}

static inline void hvx_mul_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const uint32_t num_elems) {
    if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
        hvx_mul_scalar_f32_aa(dst, src, val, num_elems);
    } else if (hex_is_aligned((void *) dst, 128)) {
        hvx_mul_scalar_f32_au(dst, src, val, num_elems);
    } else if (hex_is_aligned((void *) src, 128)) {
        hvx_mul_scalar_f32_ua(dst, src, val, num_elems);
    } else {
        hvx_mul_scalar_f32_uu(dst, src, val, num_elems);
    }
}

static inline void hvx_sub_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const uint32_t num_elems) {
    if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
        hvx_sub_scalar_f32_aa(dst, src, val, num_elems);
    } else if (hex_is_aligned((void *) dst, 128)) {
        hvx_sub_scalar_f32_au(dst, src, val, num_elems);
    } else if (hex_is_aligned((void *) src, 128)) {
        hvx_sub_scalar_f32_ua(dst, src, val, num_elems);
    } else {
        hvx_sub_scalar_f32_uu(dst, src, val, num_elems);
    }
}
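
// Usage sketch (illustrative; 'out', 'in', and 'n' are hypothetical):
//
//   // scale n floats by 0.5f; the dispatcher picks the aligned/unaligned variant
//   hvx_mul_scalar_f32((uint8_t *) out, (const uint8_t *) in, 0.5f, n);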

// MIN Scalar variants

#define HVX_OP_MIN_SCALAR(v) Q6_Vsf_vmin_VsfVsf(val_vec, v)

static inline void hvx_min_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    assert((uintptr_t) dst % 128 == 0);
    assert((uintptr_t) src % 128 == 0);
    hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MIN_SCALAR);
}

static inline void hvx_min_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    assert((uintptr_t) dst % 128 == 0);
    hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MIN_SCALAR);
}

static inline void hvx_min_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    assert((uintptr_t) src % 128 == 0);
    hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MIN_SCALAR);
}

static inline void hvx_min_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MIN_SCALAR);
}

static inline void hvx_min_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const uint32_t num_elems) {
    if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
        hvx_min_scalar_f32_aa(dst, src, val, num_elems);
    } else if (hex_is_aligned((void *) dst, 128)) {
        hvx_min_scalar_f32_au(dst, src, val, num_elems);
    } else if (hex_is_aligned((void *) src, 128)) {
        hvx_min_scalar_f32_ua(dst, src, val, num_elems);
    } else {
        hvx_min_scalar_f32_uu(dst, src, val, num_elems);
    }
}

// CLAMP Scalar variants

#define HVX_OP_CLAMP_SCALAR(v) \
    ({ \
        HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(v, max_vec); \
        HVX_VectorPred pred_cap_left  = Q6_Q_vcmp_gt_VsfVsf(min_vec, v); \
        HVX_Vector tmp = Q6_V_vmux_QVV(pred_cap_right, max_vec, v); \
        Q6_V_vmux_QVV(pred_cap_left, min_vec, tmp); \
    })
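
// The clamp applies two muxes: lanes above max are replaced by max, then lanes
// below min by min. Lanes failing both compares pass through unchanged; assuming
// IEEE-style unordered compares, NaN inputs therefore pass through unclamped.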

static inline void hvx_clamp_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
    const HVX_Vector min_vec = hvx_vec_splat_f32(min);
    const HVX_Vector max_vec = hvx_vec_splat_f32(max);
    assert((uintptr_t) dst % 128 == 0);
    assert((uintptr_t) src % 128 == 0);
    hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
}

static inline void hvx_clamp_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
    const HVX_Vector min_vec = hvx_vec_splat_f32(min);
    const HVX_Vector max_vec = hvx_vec_splat_f32(max);
    assert((uintptr_t) dst % 128 == 0);
    hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
}

static inline void hvx_clamp_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
    const HVX_Vector min_vec = hvx_vec_splat_f32(min);
    const HVX_Vector max_vec = hvx_vec_splat_f32(max);
    assert((uintptr_t) src % 128 == 0);
    hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
}

static inline void hvx_clamp_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
    const HVX_Vector min_vec = hvx_vec_splat_f32(min);
    const HVX_Vector max_vec = hvx_vec_splat_f32(max);
    hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
}

static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, const uint32_t num_elems) {
    if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
        hvx_clamp_scalar_f32_aa(dst, src, min, max, num_elems);
    } else if (hex_is_aligned((void *) dst, 128)) {
        hvx_clamp_scalar_f32_au(dst, src, min, max, num_elems);
    } else if (hex_is_aligned((void *) src, 128)) {
        hvx_clamp_scalar_f32_ua(dst, src, min, max, num_elems);
    } else {
        hvx_clamp_scalar_f32_uu(dst, src, min, max, num_elems);
    }
}

//
// Square
//

#define hvx_sqr_loop_body(dst_type, src_type, vec_store)                       \
    do {                                                                       \
        dst_type * restrict vdst       = (dst_type *) dst;                     \
        const src_type * restrict vsrc = (const src_type *) src;               \
                                                                               \
        const uint32_t elem_size = sizeof(float);                              \
        const uint32_t epv  = 128 / elem_size;                                 \
        const uint32_t nvec = n / epv;                                         \
        const uint32_t nloe = n % epv;                                         \
                                                                               \
        uint32_t i = 0;                                                        \
                                                                               \
        _Pragma("unroll(4)")                                                   \
        for (; i < nvec; i++) {                                                \
            vdst[i] = HVX_OP_MUL(vsrc[i], vsrc[i]);                            \
        }                                                                      \
        if (nloe) {                                                            \
            HVX_Vector v = HVX_OP_MUL(vsrc[i], vsrc[i]);                       \
            vec_store((void *) &vdst[i], nloe * elem_size, v);                 \
        }                                                                      \
    } while(0)

static inline void hvx_sqr_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
    assert((uintptr_t) dst % 128 == 0);
    assert((uintptr_t) src % 128 == 0);
    hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}

static inline void hvx_sqr_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
    assert((uintptr_t) dst % 128 == 0);
    hvx_sqr_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
}

static inline void hvx_sqr_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
    assert((uintptr_t) src % 128 == 0);
    hvx_sqr_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
}

static inline void hvx_sqr_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
    hvx_sqr_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}

static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) {
    if (hex_is_aligned((void *) dst, 128)) {
        if (hex_is_aligned((void *) src, 128)) {
            hvx_sqr_f32_aa(dst, src, num_elems);
        } else {
            hvx_sqr_f32_au(dst, src, num_elems);
        }
    } else {
        if (hex_is_aligned((void *) src, 128)) {
            hvx_sqr_f32_ua(dst, src, num_elems);
        } else {
            hvx_sqr_f32_uu(dst, src, num_elems);
        }
    }
}

#undef HVX_OP_ADD
#undef HVX_OP_SUB
#undef HVX_OP_MUL
#undef hvx_arith_loop_body
#undef HVX_OP_ADD_SCALAR
#undef HVX_OP_SUB_SCALAR
#undef HVX_OP_MUL_SCALAR
#undef hvx_scalar_loop_body
#undef HVX_OP_MIN_SCALAR
#undef HVX_OP_CLAMP_SCALAR
#undef hvx_sqr_loop_body
#undef DEFINE_HVX_BINARY_OP_VARIANTS
#undef HVX_BINARY_DISPATCHER

#endif // HVX_ARITH_H