1#ifndef HVX_REDUCE_H
2#define HVX_REDUCE_H
3
4#include <math.h>
5#include <stdbool.h>
6#include <stdint.h>
7#include <assert.h>
8
9#include "hex-utils.h"
10#include "hvx-base.h"
11#include "hvx-types.h"
12
13static inline HVX_Vector hvx_vec_reduce_sum_n_i32(HVX_Vector in, unsigned int n) {
14 unsigned int total = n * 4; // total vec nbytes
15 unsigned int width = 4; // int32
16
17 HVX_Vector sum = in, sum_t;
18 while (width < total) {
19 sum_t = Q6_V_vror_VR(sum, width); // rotate right
20 sum = Q6_Vw_vadd_VwVw(sum_t, sum); // elementwise sum
21 width = width << 1;
22 }
23 return sum;
24}
25
26static inline HVX_Vector hvx_vec_reduce_sum_i32(HVX_Vector in) {
27 return hvx_vec_reduce_sum_n_i32(in, 32);
28}
29
30static inline HVX_Vector hvx_vec_reduce_sum_n_qf32(HVX_Vector in, unsigned int n) {
31 unsigned int total = n * 4; // total vec nbytes
32 unsigned int width = 4; // fp32 nbytes
33
34 HVX_Vector sum = in, sum_t;
35 while (width < total) {
36 sum_t = Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum), width); // rotate right
37 sum = Q6_Vqf32_vadd_Vqf32Vsf(sum, sum_t); // elementwise sum
38 width = width << 1;
39 }
40 return sum;
41}
42
43static inline HVX_Vector hvx_vec_reduce_sum_qf32(HVX_Vector in) {
44 return hvx_vec_reduce_sum_n_qf32(in, 32);
45}
46
47#if __HVX_ARCH__ > 75
48
// Reduce-sum of all fp32 lanes of two vectors at once, in plain sf (IEEE
// fp32) arithmetic — available on arch > 75 where native sf adds exist.
// The shuffle is a permutation of the combined 64 lanes, so lo+hi yields a
// single vector whose lanes together still sum to sum(in0) + sum(in1).
// NOTE(review): the rotation tree stops at VLEN/16 (8 bytes) rather than
// 4 bytes — presumably the 4-byte shuffle already pairs the lanes so the
// final 1-lane rotation is unnecessary; confirm against the HVX manual.
static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) {
    // Interleave in0/in1 at 4-byte granularity, then add the two halves to
    // fold both inputs into one vector of partial sums.
    HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4);
    HVX_Vector sum_sf = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump));

    // Halving rotate-and-add ladder: 64-, 32-, 16-, then 8-byte rotations.
    sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 2));
    sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 4));
    sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 8));
    sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 16));
    return sum_sf;
}
59
60static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) {
61 unsigned int total = n * 4; // total vec nbytes
62 unsigned int width = 4; // fp32 nbytes
63
64 HVX_Vector sum = in, sum_t;
65 while (width < total) {
66 sum_t = Q6_V_vror_VR(sum, width); // rotate right
67 sum = Q6_Vsf_vadd_VsfVsf(sum, sum_t); // elementwise sum
68 width = width << 1;
69 }
70 return sum;
71}
72
73#else
74
// Reduce-sum of all fp32 lanes of two vectors at once. Pre-arch-76 fallback:
// no native sf add, so accumulate in qf32 and convert to sf only for the
// byte rotations (which need the IEEE bit layout) and the final result.
// The shuffle is a permutation of the combined 64 lanes, so lo+hi yields a
// single vector whose lanes together still sum to sum(in0) + sum(in1).
// NOTE(review): the rotation tree stops at VLEN/16 (8 bytes) rather than
// 4 bytes — presumably the 4-byte shuffle already pairs the lanes so the
// final 1-lane rotation is unnecessary; confirm against the HVX manual.
static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) {
    // Interleave in0/in1 at 4-byte granularity, then add the two halves to
    // fold both inputs into one vector of qf32 partial sums.
    HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4);
    HVX_Vector sum_qf = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump));

    // Halving rotate-and-add ladder; each step converts qf32 -> sf, rotates,
    // and folds the rotated copy back into the qf32 accumulator.
    sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 2));
    sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 4));
    sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 8));
    sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 16));
    return Q6_Vsf_equals_Vqf32(sum_qf);
}
85
86static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) {
87 unsigned int total = n * 4; // total vec nbytes
88 unsigned int width = 4; // fp32 nbytes
89
90 HVX_Vector sum = in, sum_t;
91 while (width < total) {
92 sum_t = Q6_V_vror_VR(sum, width); // rotate right
93 sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sum, sum_t)); // elementwise sum
94 width = width << 1;
95 }
96 return sum;
97}
98
99#endif
100
101static inline HVX_Vector hvx_vec_reduce_sum_f32(HVX_Vector in) {
102 return hvx_vec_reduce_sum_n_f32(in, 32);
103}
104
105static inline HVX_Vector hvx_vec_reduce_max_f16(HVX_Vector in) {
106 unsigned total = 128; // total vec nbytes
107 unsigned width = 2; // fp16 nbytes
108
109 HVX_Vector _max = in, _max_t;
110 while (width < total) {
111 _max_t = Q6_V_vror_VR(_max, width); // rotate right
112 _max = Q6_Vhf_vmax_VhfVhf(_max_t, _max); // elementwise max
113 width = width << 1;
114 }
115
116 return _max;
117}
118
119static inline HVX_Vector hvx_vec_reduce_max2_f16(HVX_Vector in, HVX_Vector _max) {
120 unsigned total = 128; // total vec nbytes
121 unsigned width = 2; // fp32 nbytes
122
123 HVX_Vector _max_t;
124
125 _max = Q6_Vhf_vmax_VhfVhf(in, _max);
126 while (width < total) {
127 _max_t = Q6_V_vror_VR(_max, width); // rotate right
128 _max = Q6_Vhf_vmax_VhfVhf(_max_t, _max); // elementwise max
129 width = width << 1;
130 }
131
132 return _max;
133}
134
135static inline HVX_Vector hvx_vec_reduce_max_f32(HVX_Vector in) {
136 unsigned total = 128; // total vec nbytes
137 unsigned width = 4; // fp32 nbytes
138
139 HVX_Vector _max = in, _max_t;
140 while (width < total) {
141 _max_t = Q6_V_vror_VR(_max, width); // rotate right
142 _max = Q6_Vsf_vmax_VsfVsf(_max_t, _max); // elementwise max
143 width = width << 1;
144 }
145
146 return _max;
147}
148
149static inline HVX_Vector hvx_vec_reduce_max2_f32(HVX_Vector in, HVX_Vector _max) {
150 unsigned total = 128; // total vec nbytes
151 unsigned width = 4; // fp32 nbytes
152
153 HVX_Vector _max_t;
154
155 _max = Q6_Vsf_vmax_VsfVsf(in, _max);
156 while (width < total) {
157 _max_t = Q6_V_vror_VR(_max, width); // rotate right
158 _max = Q6_Vsf_vmax_VsfVsf(_max_t, _max); // elementwise max
159 width = width << 1;
160 }
161
162 return _max;
163}
164
// Shared loop body for the fp32 reductions below. Expands inside a function
// that provides `src` (byte pointer) and `num_elems` (fp32 element count),
// and returns from that function.
//   src_type      - HVX_Vector for 128-byte-aligned src, HVX_UVector otherwise
//   init_vec      - initial vector accumulator
//   pad_vec       - fill value for the unused lanes of a partial final vector
//                   (identity of vec_op: 0 for sum, splat(first elem) for max)
//   vec_op        - per-vector accumulate: acc = vec_op(acc, data)
//   reduce_op     - cross-lane reduction of the final accumulator
//   scalar_reduce - extracts the float result from the reduced vector
// Comments use /* */ only: a // comment would swallow the trailing '\'
// via line splicing.
#define hvx_reduce_loop_body(src_type, init_vec, pad_vec, vec_op, reduce_op, scalar_reduce) \
    do { \
        src_type * restrict vsrc = (src_type *) src; \
        HVX_Vector acc = init_vec; \
        \
        const uint32_t elem_size = sizeof(float); /* fp32-only helper */ \
        const uint32_t epv = 128 / elem_size;     /* elements per 128-byte vector */ \
        const uint32_t nvec = num_elems / epv;    /* full vectors */ \
        const uint32_t nloe = num_elems % epv;    /* leftover elements */ \
        \
        uint32_t i = 0; \
        _Pragma("unroll(4)") \
        for (; i < nvec; i++) { \
            acc = vec_op(acc, vsrc[i]); \
        } \
        if (nloe) { \
            const float * srcf = (const float *) src + i * epv; \
            /* NOTE(review): unaligned full-vector load reads up to 128 bytes  */ \
            /* from srcf, i.e. past the last valid element; valign then drops  */ \
            /* the overread bytes. Confirm callers guarantee this is readable. */ \
            HVX_Vector in = *(HVX_UVector *) srcf; \
            /* Place the nloe valid elements in the top of temp and fill the  */ \
            /* remaining lanes from pad_vec, so vec_op sees no garbage.       */ \
            HVX_Vector temp = Q6_V_valign_VVR(in, pad_vec, nloe * elem_size); \
            acc = vec_op(acc, temp); \
        } \
        HVX_Vector v = reduce_op(acc); \
        return scalar_reduce(v); \
    } while(0)
189
// Per-vector accumulate ops and final scalar extraction for
// hvx_reduce_loop_body. Sum accumulators start from Q6_V_vsplat_R(0), whose
// all-zero bit pattern is also qf32 zero (HVX_SUM_SQ_OP already relies on
// this).
#define HVX_REDUCE_MAX_OP(acc, val) Q6_Vsf_vmax_VsfVsf(acc, val)
// Keep the running sum in qf32 for the whole loop instead of converting
// qf32 -> sf on every step: the old Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(acc), val)
// rounded the accumulator to fp32 each iteration, discarding qf32's extra
// precision. This also matches HVX_SUM_SQ_OP and hvx_vec_reduce_sum_n_qf32.
#define HVX_REDUCE_SUM_OP(acc, val) Q6_Vqf32_vadd_Vqf32Vsf(acc, val)
#define HVX_SUM_SQ_OP(acc, val) Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(val, val))
#define HVX_REDUCE_MAX_SCALAR(v) hvx_vec_get_f32(v)
#define HVX_REDUCE_SUM_SCALAR(v) hvx_vec_get_f32(Q6_Vsf_equals_Vqf32(v))
195
196// Max variants
197
198static inline float hvx_reduce_max_f32_a(const uint8_t * restrict src, const int num_elems) {
199 HVX_Vector init_vec = hvx_vec_splat_f32(((const float *) src)[0]);
200 assert((unsigned long) src % 128 == 0);
201 hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_REDUCE_MAX_OP, hvx_vec_reduce_max_f32, HVX_REDUCE_MAX_SCALAR);
202}
203
204static inline float hvx_reduce_max_f32_u(const uint8_t * restrict src, const int num_elems) {
205 HVX_Vector init_vec = hvx_vec_splat_f32(((const float *) src)[0]);
206 hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_REDUCE_MAX_OP, hvx_vec_reduce_max_f32, HVX_REDUCE_MAX_SCALAR);
207}
208
// Max over num_elems fp32 values: dispatch to the aligned fast path when src
// sits on a 128-byte boundary, otherwise use unaligned loads.
static inline float hvx_reduce_max_f32(const uint8_t * restrict src, const int num_elems) {
    return hex_is_aligned((void *) src, 128) ? hvx_reduce_max_f32_a(src, num_elems)
                                             : hvx_reduce_max_f32_u(src, num_elems);
}
216
217// Sum variants
218
219static inline float hvx_reduce_sum_f32_a(const uint8_t * restrict src, const int num_elems) {
220 HVX_Vector init_vec = Q6_V_vsplat_R(0);
221 assert((unsigned long) src % 128 == 0);
222 hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_REDUCE_SUM_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);
223}
224
225static inline float hvx_reduce_sum_f32_u(const uint8_t * restrict src, const int num_elems) {
226 HVX_Vector init_vec = Q6_V_vsplat_R(0);
227 hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_REDUCE_SUM_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);
228}
229
// Sum of num_elems fp32 values: dispatch to the aligned fast path when src
// sits on a 128-byte boundary, otherwise use unaligned loads.
static inline float hvx_reduce_sum_f32(const uint8_t * restrict src, const int num_elems) {
    return hex_is_aligned((void *) src, 128) ? hvx_reduce_sum_f32_a(src, num_elems)
                                             : hvx_reduce_sum_f32_u(src, num_elems);
}
237
238// Sum of squares variants
239
240static inline float hvx_sum_of_squares_f32_a(const uint8_t * restrict src, const int num_elems) {
241 HVX_Vector init_vec = Q6_V_vsplat_R(0);
242 assert((uintptr_t) src % 128 == 0);
243 hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_SUM_SQ_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);
244}
245
246static inline float hvx_sum_of_squares_f32_u(const uint8_t * restrict src, const int num_elems) {
247 HVX_Vector init_vec = Q6_V_vsplat_R(0);
248 hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_SUM_SQ_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);
249}
250
// Sum of squares of num_elems fp32 values: dispatch to the aligned fast path
// when src sits on a 128-byte boundary, otherwise use unaligned loads.
static inline float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems) {
    return hex_is_aligned((void *) src, 128) ? hvx_sum_of_squares_f32_a(src, num_elems)
                                             : hvx_sum_of_squares_f32_u(src, num_elems);
}
258
259#undef hvx_reduce_loop_body
260#undef HVX_REDUCE_MAX_OP
261#undef HVX_REDUCE_SUM_OP
262#undef HVX_REDUCE_MAX_SCALAR
263#undef HVX_REDUCE_SUM_SCALAR
264#undef HVX_SUM_SQ_OP
265
266#endif /* HVX_REDUCE_H */