1#ifndef HVX_REDUCE_H
  2#define HVX_REDUCE_H
  3
  4#include <math.h>
  5#include <stdbool.h>
  6#include <stdint.h>
  7#include <assert.h>
  8
  9#include "hex-utils.h"
 10#include "hvx-base.h"
 11#include "hvx-types.h"
 12
 13static inline HVX_Vector hvx_vec_reduce_sum_n_i32(HVX_Vector in, unsigned int n) {
 14    unsigned int total = n * 4;  // total vec nbytes
 15    unsigned int width = 4;      // int32
 16
 17    HVX_Vector sum = in, sum_t;
 18    while (width < total) {
 19        sum_t = Q6_V_vror_VR(sum, width);     // rotate right
 20        sum   = Q6_Vw_vadd_VwVw(sum_t, sum);  // elementwise sum
 21        width = width << 1;
 22    }
 23    return sum;
 24}
 25
 26static inline HVX_Vector hvx_vec_reduce_sum_i32(HVX_Vector in) {
 27    return hvx_vec_reduce_sum_n_i32(in, 32);
 28}
 29
 30static inline HVX_Vector hvx_vec_reduce_sum_n_qf32(HVX_Vector in, unsigned int n) {
 31    unsigned int total = n * 4;  // total vec nbytes
 32    unsigned int width = 4;      // fp32 nbytes
 33
 34    HVX_Vector sum = in, sum_t;
 35    while (width < total) {
 36        sum_t = Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum), width);  // rotate right
 37        sum   = Q6_Vqf32_vadd_Vqf32Vsf(sum, sum_t);             // elementwise sum
 38        width = width << 1;
 39    }
 40    return sum;
 41}
 42
 43static inline HVX_Vector hvx_vec_reduce_sum_qf32(HVX_Vector in) {
 44    return hvx_vec_reduce_sum_n_qf32(in, 32);
 45}
 46
 47#if __HVX_ARCH__ > 75
 48
 49static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) {
 50    HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4);
 51    HVX_Vector  sum_sf  = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump));
 52
 53    sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 2));
 54    sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 4));
 55    sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 8));
 56    sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 16));
 57    return sum_sf;
 58}
 59
 60static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) {
 61    unsigned int total = n * 4;  // total vec nbytes
 62    unsigned int width = 4;      // fp32 nbytes
 63
 64    HVX_Vector sum = in, sum_t;
 65    while (width < total) {
 66        sum_t = Q6_V_vror_VR(sum, width);       // rotate right
 67        sum   = Q6_Vsf_vadd_VsfVsf(sum, sum_t); // elementwise sum
 68        width = width << 1;
 69    }
 70    return sum;
 71}
 72
 73#else
 74
 75static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) {
 76    HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4);
 77    HVX_Vector  sum_qf  = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump));
 78
 79    sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 2));
 80    sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 4));
 81    sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 8));
 82    sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 16));
 83    return Q6_Vsf_equals_Vqf32(sum_qf);
 84}
 85
 86static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) {
 87    unsigned int total = n * 4;  // total vec nbytes
 88    unsigned int width = 4;      // fp32 nbytes
 89
 90    HVX_Vector sum = in, sum_t;
 91    while (width < total) {
 92        sum_t = Q6_V_vror_VR(sum, width);                               // rotate right
 93        sum   = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sum, sum_t));  // elementwise sum
 94        width = width << 1;
 95    }
 96    return sum;
 97}
 98
 99#endif
100
101static inline HVX_Vector hvx_vec_reduce_sum_f32(HVX_Vector in) {
102    return hvx_vec_reduce_sum_n_f32(in, 32);
103}
104
105static inline HVX_Vector hvx_vec_reduce_max_f16(HVX_Vector in) {
106    unsigned total = 128;  // total vec nbytes
107    unsigned width = 2;    // fp16 nbytes
108
109    HVX_Vector _max = in, _max_t;
110    while (width < total) {
111        _max_t = Q6_V_vror_VR(_max, width);         // rotate right
112        _max   = Q6_Vhf_vmax_VhfVhf(_max_t, _max);  // elementwise max
113        width  = width << 1;
114    }
115
116    return _max;
117}
118
119static inline HVX_Vector hvx_vec_reduce_max2_f16(HVX_Vector in, HVX_Vector _max) {
120    unsigned total = 128;  // total vec nbytes
121    unsigned width = 2;    // fp32 nbytes
122
123    HVX_Vector _max_t;
124
125    _max = Q6_Vhf_vmax_VhfVhf(in, _max);
126    while (width < total) {
127        _max_t = Q6_V_vror_VR(_max, width);         // rotate right
128        _max   = Q6_Vhf_vmax_VhfVhf(_max_t, _max);  // elementwise max
129        width  = width << 1;
130    }
131
132    return _max;
133}
134
135static inline HVX_Vector hvx_vec_reduce_max_f32(HVX_Vector in) {
136    unsigned total = 128;  // total vec nbytes
137    unsigned width = 4;    // fp32 nbytes
138
139    HVX_Vector _max = in, _max_t;
140    while (width < total) {
141        _max_t = Q6_V_vror_VR(_max, width);         // rotate right
142        _max   = Q6_Vsf_vmax_VsfVsf(_max_t, _max);  // elementwise max
143        width  = width << 1;
144    }
145
146    return _max;
147}
148
149static inline HVX_Vector hvx_vec_reduce_max2_f32(HVX_Vector in, HVX_Vector _max) {
150    unsigned total = 128;  // total vec nbytes
151    unsigned width = 4;    // fp32 nbytes
152
153    HVX_Vector _max_t;
154
155    _max = Q6_Vsf_vmax_VsfVsf(in, _max);
156    while (width < total) {
157        _max_t = Q6_V_vror_VR(_max, width);         // rotate right
158        _max   = Q6_Vsf_vmax_VsfVsf(_max_t, _max);  // elementwise max
159        width  = width << 1;
160    }
161
162    return _max;
163}
164
/*
 * hvx_reduce_loop_body - shared driver for the scalar-returning f32
 * reductions below.  Expands (in the enclosing function, which must have
 * `src` and `num_elems` in scope) to:
 *   1. accumulate whole 128-byte vectors of `src` into `acc` with vec_op;
 *   2. handle the `nloe` leftover floats by valign-ing the trailing load
 *      against `pad_vec`, so lanes past the data carry a reduction-neutral
 *      value (zero for sums, the first element for max);
 *   3. cross-lane reduce `acc` with reduce_op, extract a scalar with
 *      scalar_reduce, and `return` it from the enclosing function.
 *
 * src_type is HVX_Vector for aligned loads, HVX_UVector for unaligned.
 *
 * NOTE(review): the leftover path loads a full vector at `srcf`, i.e. it
 * may read up to 128 bytes past the last element -- presumably tolerated
 * on the target; verify buffer padding guarantees with callers.
 * (Block comments only in here: a // comment would absorb the trailing
 * continuation backslash.)
 */
#define hvx_reduce_loop_body(src_type, init_vec, pad_vec, vec_op, reduce_op, scalar_reduce) \
    do {                                                                                    \
        src_type * restrict vsrc = (src_type *) src;                                        \
        HVX_Vector acc = init_vec;                                                          \
                                                                                            \
        const uint32_t elem_size = sizeof(float);                                           \
        const uint32_t epv  = 128 / elem_size;                                              \
        const uint32_t nvec = num_elems / epv;                                              \
        const uint32_t nloe = num_elems % epv;                                              \
                                                                                            \
        uint32_t i = 0;                                                                     \
        _Pragma("unroll(4)")                                                                \
        for (; i < nvec; i++) {                                                             \
            acc = vec_op(acc, vsrc[i]);                                                     \
        }                                                                                   \
        if (nloe) {                                                                         \
            const float * srcf = (const float *) src + i * epv;                             \
            HVX_Vector in = *(HVX_UVector *) srcf;                                          \
            HVX_Vector temp = Q6_V_valign_VVR(in, pad_vec, nloe * elem_size);               \
            acc = vec_op(acc, temp);                                                        \
        }                                                                                   \
        HVX_Vector v = reduce_op(acc);                                                      \
        return scalar_reduce(v);                                                            \
    } while(0)
189
// Lane-wise IEEE f32 max of accumulator and input.
#define HVX_REDUCE_MAX_OP(acc, val) Q6_Vsf_vmax_VsfVsf(acc, val)
// Accumulate: convert qf32 acc to sf, add sf input, producing a qf32 result.
// NOTE(review): the qf32->sf round-trip each step discards qf32's extra
// precision; Q6_Vqf32_vadd_Vqf32Vsf would keep it -- confirm intent.
#define HVX_REDUCE_SUM_OP(acc, val) Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(acc), val)
// Accumulate val*val into a qf32 accumulator (for sum-of-squares).
#define HVX_SUM_SQ_OP(acc, val) Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(val, val))
// Extract a scalar float from an sf vector (via hvx_vec_get_f32 from hvx-base).
#define HVX_REDUCE_MAX_SCALAR(v) hvx_vec_get_f32(v)
// Convert qf32 result to sf, then extract a scalar float.
#define HVX_REDUCE_SUM_SCALAR(v) hvx_vec_get_f32(Q6_Vsf_equals_Vqf32(v))
195
196// Max variants
197
198static inline float hvx_reduce_max_f32_a(const uint8_t * restrict src, const int num_elems) {
199    HVX_Vector init_vec = hvx_vec_splat_f32(((const float *) src)[0]);
200    assert((unsigned long) src % 128 == 0);
201    hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_REDUCE_MAX_OP, hvx_vec_reduce_max_f32, HVX_REDUCE_MAX_SCALAR);
202}
203
204static inline float hvx_reduce_max_f32_u(const uint8_t * restrict src, const int num_elems) {
205    HVX_Vector init_vec = hvx_vec_splat_f32(((const float *) src)[0]);
206    hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_REDUCE_MAX_OP, hvx_vec_reduce_max_f32, HVX_REDUCE_MAX_SCALAR);
207}
208
209static inline float hvx_reduce_max_f32(const uint8_t * restrict src, const int num_elems) {
210    if (hex_is_aligned((void *) src, 128)) {
211        return hvx_reduce_max_f32_a(src, num_elems);
212    } else {
213        return hvx_reduce_max_f32_u(src, num_elems);
214    }
215}
216
217// Sum variants
218
219static inline float hvx_reduce_sum_f32_a(const uint8_t * restrict src, const int num_elems) {
220    HVX_Vector init_vec = Q6_V_vsplat_R(0);
221    assert((unsigned long) src % 128 == 0);
222    hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_REDUCE_SUM_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);
223}
224
225static inline float hvx_reduce_sum_f32_u(const uint8_t * restrict src, const int num_elems) {
226    HVX_Vector init_vec = Q6_V_vsplat_R(0);
227    hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_REDUCE_SUM_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);
228}
229
230static inline float hvx_reduce_sum_f32(const uint8_t * restrict src, const int num_elems) {
231    if (hex_is_aligned((void *) src, 128)) {
232        return hvx_reduce_sum_f32_a(src, num_elems);
233    } else {
234        return hvx_reduce_sum_f32_u(src, num_elems);
235    }
236}
237
238// Sum of squares variants
239
240static inline float hvx_sum_of_squares_f32_a(const uint8_t * restrict src, const int num_elems) {
241    HVX_Vector init_vec = Q6_V_vsplat_R(0);
242    assert((uintptr_t) src % 128 == 0);
243    hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_SUM_SQ_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);
244}
245
246static inline float hvx_sum_of_squares_f32_u(const uint8_t * restrict src, const int num_elems) {
247    HVX_Vector init_vec = Q6_V_vsplat_R(0);
248    hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_SUM_SQ_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);
249}
250
251static inline float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems) {
252    if (hex_is_aligned((void *) src, 128)) {
253        return hvx_sum_of_squares_f32_a(src, num_elems);
254    } else {
255        return hvx_sum_of_squares_f32_u(src, num_elems);
256    }
257}
258
259#undef hvx_reduce_loop_body
260#undef HVX_REDUCE_MAX_OP
261#undef HVX_REDUCE_SUM_OP
262#undef HVX_REDUCE_MAX_SCALAR
263#undef HVX_REDUCE_SUM_SCALAR
264#undef HVX_SUM_SQ_OP
265
266#endif /* HVX_REDUCE_H */