#ifndef HVX_ARITH_H
#define HVX_ARITH_H

#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <math.h>

#include "hvx-base.h"
#include "hex-utils.h"

//
// Binary operations (add, mul, sub)
//

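// All f32 kernels below walk the data in whole 128-byte HVX vectors (32 floats per
// vector). The tail of fewer than 32 elements still loads a full vector from each
// source and is written back with a partial store of nloe * sizeof(float) bytes.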
#define hvx_arith_loop_body(dst_type, src0_type, src1_type, vec_store, vec_op) \
    do { \
        dst_type  * restrict vdst  = (dst_type *)  dst;  \
        src0_type * restrict vsrc0 = (src0_type *) src0; \
        src1_type * restrict vsrc1 = (src1_type *) src1; \
        \
        const uint32_t elem_size = sizeof(float); \
        const uint32_t epv  = 128 / elem_size; \
        const uint32_t nvec = n / epv; \
        const uint32_t nloe = n % epv; \
        \
        uint32_t i = 0; \
        \
        _Pragma("unroll(4)") \
        for (; i < nvec; i++) { \
            vdst[i] = vec_op(vsrc0[i], vsrc1[i]); \
        } \
        if (nloe) { \
            HVX_Vector v = vec_op(vsrc0[i], vsrc1[i]); \
            vec_store((void *) &vdst[i], nloe * elem_size, v); \
        } \
    } while (0)

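// Pre-v79 HVX produces float add/sub/mpy results in qf32 format, so each result is
// converted back to IEEE sf with Q6_Vsf_equals_Vqf32; v79 and later use the native
// sf-in/sf-out instructions directly.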
#if __HVX_ARCH__ < 79
#define HVX_OP_ADD(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
#define HVX_OP_SUB(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b))
#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#else
#define HVX_OP_ADD(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
#define HVX_OP_SUB(a, b) Q6_Vsf_vsub_VsfVsf(a, b)
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif

// Generic macro defining all eight alignment permutations of an op
// (dst/src0/src1 order, a = 128-byte aligned, u = unaligned)
#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO) \
static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
    assert((uintptr_t) dst % 128 == 0); \
    assert((uintptr_t) src0 % 128 == 0); \
    assert((uintptr_t) src1 % 128 == 0); \
    hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
    assert((uintptr_t) dst % 128 == 0); \
    assert((uintptr_t) src0 % 128 == 0); \
    hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
    assert((uintptr_t) dst % 128 == 0); \
    assert((uintptr_t) src1 % 128 == 0); \
    hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
    assert((uintptr_t) dst % 128 == 0); \
    hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
    assert((uintptr_t) src0 % 128 == 0); \
    assert((uintptr_t) src1 % 128 == 0); \
    hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
} \
static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
    assert((uintptr_t) src0 % 128 == 0); \
    hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
} \
static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
    assert((uintptr_t) src1 % 128 == 0); \
    hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
} \
static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
    hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
}

DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD)
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB)
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL)

// Dispatcher: selects the alignment-specific variant at run time
#define HVX_BINARY_DISPATCHER(OP_NAME) \
static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \
    if (hex_is_aligned((void *) dst, 128)) { \
        if (hex_is_aligned((void *) src0, 128)) { \
            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \
            else OP_NAME##_aau(dst, src0, src1, num_elems); \
        } else { \
            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \
            else OP_NAME##_auu(dst, src0, src1, num_elems); \
        } \
    } else { \
        if (hex_is_aligned((void *) src0, 128)) { \
            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \
            else OP_NAME##_uau(dst, src0, src1, num_elems); \
        } else { \
            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \
            else OP_NAME##_uuu(dst, src0, src1, num_elems); \
        } \
    } \
}

HVX_BINARY_DISPATCHER(hvx_add_f32)
HVX_BINARY_DISPATCHER(hvx_sub_f32)
HVX_BINARY_DISPATCHER(hvx_mul_f32)
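
// Example (illustrative; a, b, c are caller-provided float buffers, n the element
// count): element-wise c = a + b. The dispatcher inspects the pointers and calls
// the matching aligned/unaligned variant.
//
//   hvx_add_f32((uint8_t *) c, (const uint8_t *) a, (const uint8_t *) b, n);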

// Fused mul-mul: dst = src0 * src1 * src2 in one pass, all pointers 128-byte aligned
static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) {
    assert((uintptr_t) dst % 128 == 0);
    assert((uintptr_t) src0 % 128 == 0);
    assert((uintptr_t) src1 % 128 == 0);
    assert((uintptr_t) src2 % 128 == 0);

    HVX_Vector * restrict vdst  = (HVX_Vector *) dst;
    HVX_Vector * restrict vsrc0 = (HVX_Vector *) src0;
    HVX_Vector * restrict vsrc1 = (HVX_Vector *) src1;
    HVX_Vector * restrict vsrc2 = (HVX_Vector *) src2;

    const uint32_t elem_size = sizeof(float);
    const uint32_t epv  = 128 / elem_size;
    const uint32_t nvec = num_elems / epv;
    const uint32_t nloe = num_elems % epv;

    uint32_t i = 0;

    _Pragma("unroll(4)")
    for (; i < nvec; i++) {
        HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]);
        vdst[i] = HVX_OP_MUL(v1, vsrc2[i]);
    }

    if (nloe) {
        HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]);
        HVX_Vector v2 = HVX_OP_MUL(v1, vsrc2[i]);
        hvx_vec_store_a((void *) &vdst[i], nloe * elem_size, v2);
    }
}
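
// Example (illustrative; a, b, c, d are 128-byte aligned float buffers, n the
// element count): d = a * b * c without a temporary buffer.
//
//   hvx_mul_mul_f32_aa((uint8_t *) d, (const uint8_t *) a, (const uint8_t *) b, (const uint8_t *) c, n);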

// Scalar Operations

#define hvx_scalar_loop_body(dst_type, src_type, vec_store, scalar_op_macro) \
    do { \
        dst_type * restrict vdst = (dst_type *) dst; \
        src_type * restrict vsrc = (src_type *) src; \
        \
        const uint32_t elem_size = sizeof(float); \
        const uint32_t epv  = 128 / elem_size; \
        const uint32_t nvec = n / epv; \
        const uint32_t nloe = n % epv; \
        \
        uint32_t i = 0; \
        \
        _Pragma("unroll(4)") \
        for (; i < nvec; i++) { \
            HVX_Vector v = vsrc[i]; \
            vdst[i] = scalar_op_macro(v); \
        } \
        if (nloe) { \
            HVX_Vector v = vsrc[i]; \
            v = scalar_op_macro(v); \
            vec_store((void *) &vdst[i], nloe * elem_size, v); \
        } \
    } while (0)

#define HVX_OP_ADD_SCALAR(v) \
    ({ \
        const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, v); \
        HVX_Vector out = HVX_OP_ADD(v, val_vec); \
        Q6_V_vmux_QVV(pred_inf, inf, out); \
    })

#define HVX_OP_MUL_SCALAR(v) HVX_OP_MUL(v, val_vec)
#define HVX_OP_SUB_SCALAR(v) HVX_OP_SUB(v, val_vec)
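
// Note: the *_SCALAR macros rely on a splatted `val_vec` being in scope, and
// HVX_OP_ADD_SCALAR additionally on `inf`; each variant below sets these up.
// The mux in HVX_OP_ADD_SCALAR passes elements that are exactly +INFINITY through
// unchanged instead of adding to them.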

// Add Scalar Variants

static inline void hvx_add_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
    assert((uintptr_t) dst % 128 == 0);
    assert((uintptr_t) src % 128 == 0);
    hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD_SCALAR);
}

static inline void hvx_add_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
    assert((uintptr_t) dst % 128 == 0);
    hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD_SCALAR);
}

static inline void hvx_add_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
    assert((uintptr_t) src % 128 == 0);
    hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD_SCALAR);
}

static inline void hvx_add_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
    hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD_SCALAR);
}

// Sub Scalar Variants

static inline void hvx_sub_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    assert((uintptr_t) dst % 128 == 0);
    assert((uintptr_t) src % 128 == 0);
    hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB_SCALAR);
}

static inline void hvx_sub_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    assert((uintptr_t) dst % 128 == 0);
    hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB_SCALAR);
}

static inline void hvx_sub_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    assert((uintptr_t) src % 128 == 0);
    hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB_SCALAR);
}

static inline void hvx_sub_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB_SCALAR);
}

// Mul Scalar Variants

static inline void hvx_mul_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    assert((uintptr_t) dst % 128 == 0);
    assert((uintptr_t) src % 128 == 0);
    hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL_SCALAR);
}

static inline void hvx_mul_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    assert((uintptr_t) dst % 128 == 0);
    hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL_SCALAR);
}

static inline void hvx_mul_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    assert((uintptr_t) src % 128 == 0);
    hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL_SCALAR);
}

static inline void hvx_mul_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL_SCALAR);
}

static inline void hvx_add_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
    if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
        hvx_add_scalar_f32_aa(dst, src, val, num_elems);
    } else if (hex_is_aligned((void *) dst, 128)) {
        hvx_add_scalar_f32_au(dst, src, val, num_elems);
    } else if (hex_is_aligned((void *) src, 128)) {
        hvx_add_scalar_f32_ua(dst, src, val, num_elems);
    } else {
        hvx_add_scalar_f32_uu(dst, src, val, num_elems);
    }
}

static inline void hvx_mul_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
    if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
        hvx_mul_scalar_f32_aa(dst, src, val, num_elems);
    } else if (hex_is_aligned((void *) dst, 128)) {
        hvx_mul_scalar_f32_au(dst, src, val, num_elems);
    } else if (hex_is_aligned((void *) src, 128)) {
        hvx_mul_scalar_f32_ua(dst, src, val, num_elems);
    } else {
        hvx_mul_scalar_f32_uu(dst, src, val, num_elems);
    }
}

static inline void hvx_sub_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
    if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
        hvx_sub_scalar_f32_aa(dst, src, val, num_elems);
    } else if (hex_is_aligned((void *) dst, 128)) {
        hvx_sub_scalar_f32_au(dst, src, val, num_elems);
    } else if (hex_is_aligned((void *) src, 128)) {
        hvx_sub_scalar_f32_ua(dst, src, val, num_elems);
    } else {
        hvx_sub_scalar_f32_uu(dst, src, val, num_elems);
    }
}
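
// Example (illustrative; x and y are caller-provided float buffers, n the element
// count): y = 0.5f * x.
//
//   hvx_mul_scalar_f32((uint8_t *) y, (const uint8_t *) x, 0.5f, n);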

// MIN Scalar variants

#define HVX_OP_MIN_SCALAR(v) Q6_Vsf_vmin_VsfVsf(val_vec, v)

static inline void hvx_min_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    assert((uintptr_t) dst % 128 == 0);
    assert((uintptr_t) src % 128 == 0);
    hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MIN_SCALAR);
}

static inline void hvx_min_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    assert((uintptr_t) dst % 128 == 0);
    hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MIN_SCALAR);
}

static inline void hvx_min_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    assert((uintptr_t) src % 128 == 0);
    hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MIN_SCALAR);
}

static inline void hvx_min_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
    hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MIN_SCALAR);
}

static inline void hvx_min_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
    if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
        hvx_min_scalar_f32_aa(dst, src, val, num_elems);
    } else if (hex_is_aligned((void *) dst, 128)) {
        hvx_min_scalar_f32_au(dst, src, val, num_elems);
    } else if (hex_is_aligned((void *) src, 128)) {
        hvx_min_scalar_f32_ua(dst, src, val, num_elems);
    } else {
        hvx_min_scalar_f32_uu(dst, src, val, num_elems);
    }
}
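
// Example (illustrative): cap every element at 6.0f, i.e. y[i] = fminf(x[i], 6.0f).
//
//   hvx_min_scalar_f32((uint8_t *) y, (const uint8_t *) x, 6.0f, n);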

// CLAMP Scalar variants

#define HVX_OP_CLAMP_SCALAR(v) \
    ({ \
        HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(v, max_vec); \
        HVX_VectorPred pred_cap_left  = Q6_Q_vcmp_gt_VsfVsf(min_vec, v); \
        HVX_Vector tmp = Q6_V_vmux_QVV(pred_cap_right, max_vec, v); \
        Q6_V_vmux_QVV(pred_cap_left, min_vec, tmp); \
    })

static inline void hvx_clamp_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
    const HVX_Vector min_vec = hvx_vec_splat_f32(min);
    const HVX_Vector max_vec = hvx_vec_splat_f32(max);
    assert((uintptr_t) dst % 128 == 0);
    assert((uintptr_t) src % 128 == 0);
    hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
}

static inline void hvx_clamp_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
    const HVX_Vector min_vec = hvx_vec_splat_f32(min);
    const HVX_Vector max_vec = hvx_vec_splat_f32(max);
    assert((uintptr_t) dst % 128 == 0);
    hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
}

static inline void hvx_clamp_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
    const HVX_Vector min_vec = hvx_vec_splat_f32(min);
    const HVX_Vector max_vec = hvx_vec_splat_f32(max);
    assert((uintptr_t) src % 128 == 0);
    hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
}

static inline void hvx_clamp_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
    const HVX_Vector min_vec = hvx_vec_splat_f32(min);
    const HVX_Vector max_vec = hvx_vec_splat_f32(max);
    hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
}

static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, const int num_elems) {
    if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
        hvx_clamp_scalar_f32_aa(dst, src, min, max, num_elems);
    } else if (hex_is_aligned((void *) dst, 128)) {
        hvx_clamp_scalar_f32_au(dst, src, min, max, num_elems);
    } else if (hex_is_aligned((void *) src, 128)) {
        hvx_clamp_scalar_f32_ua(dst, src, min, max, num_elems);
    } else {
        hvx_clamp_scalar_f32_uu(dst, src, min, max, num_elems);
    }
}
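
// Example (illustrative): ReLU6-style clamp of every element to [0, 6].
//
//   hvx_clamp_scalar_f32((uint8_t *) y, (const uint8_t *) x, 0.0f, 6.0f, n);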

//
// Square
//

#define hvx_sqr_loop_body(dst_type, src_type, vec_store) \
    do { \
        dst_type * restrict vdst = (dst_type *) dst; \
        src_type * restrict vsrc = (src_type *) src; \
        \
        const uint32_t elem_size = sizeof(float); \
        const uint32_t epv  = 128 / elem_size; \
        const uint32_t nvec = n / epv; \
        const uint32_t nloe = n % epv; \
        \
        uint32_t i = 0; \
        \
        _Pragma("unroll(4)") \
        for (; i < nvec; i++) { \
            vdst[i] = HVX_OP_MUL(vsrc[i], vsrc[i]); \
        } \
        if (nloe) { \
            HVX_Vector v = HVX_OP_MUL(vsrc[i], vsrc[i]); \
            vec_store((void *) &vdst[i], nloe * elem_size, v); \
        } \
    } while (0)

static inline void hvx_sqr_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
    assert((uintptr_t) dst % 128 == 0);
    assert((uintptr_t) src % 128 == 0);
    hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}

static inline void hvx_sqr_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
    assert((uintptr_t) dst % 128 == 0);
    // src may be unaligned here, so it must be loaded as HVX_UVector
    hvx_sqr_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
}

static inline void hvx_sqr_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
    assert((uintptr_t) src % 128 == 0);
    hvx_sqr_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
}

static inline void hvx_sqr_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
    hvx_sqr_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}

static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) {
    if (hex_is_aligned((void *) dst, 128)) {
        if (hex_is_aligned((void *) src, 128)) {
            hvx_sqr_f32_aa(dst, src, num_elems);
        } else {
            hvx_sqr_f32_au(dst, src, num_elems);
        }
    } else {
        if (hex_is_aligned((void *) src, 128)) {
            hvx_sqr_f32_ua(dst, src, num_elems);
        } else {
            hvx_sqr_f32_uu(dst, src, num_elems);
        }
    }
}
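
// Example (illustrative): y[i] = x[i] * x[i] over n floats.
//
//   hvx_sqr_f32((uint8_t *) y, (const uint8_t *) x, n);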

#undef HVX_OP_ADD
#undef HVX_OP_SUB
#undef HVX_OP_MUL
#undef hvx_arith_loop_body
#undef HVX_OP_ADD_SCALAR
#undef HVX_OP_SUB_SCALAR
#undef HVX_OP_MUL_SCALAR
#undef hvx_scalar_loop_body
#undef hvx_sqr_loop_body
#undef HVX_OP_MIN_SCALAR
#undef HVX_OP_CLAMP_SCALAR
#undef DEFINE_HVX_BINARY_OP_VARIANTS
#undef HVX_BINARY_DISPATCHER

#endif // HVX_ARITH_H