1#pragma clang diagnostic ignored "-Wunused-variable"
2#pragma clang diagnostic ignored "-Wunused-function"
3#pragma clang diagnostic ignored "-Wunused-but-set-variable"
4
5#include <HAP_farf.h>
6#include <HAP_perf.h>
7
8#include <string.h>
9#include <math.h>
10
11#include "hex-dma.h"
12#include "hvx-utils.h"
13
14#define GGML_COMMON_DECL_C
15#include "ggml-common.h"
16#include "htp-ctx.h"
17#include "htp-msg.h"
18#include "htp-ops.h"
19
20
// sum_rows_preamble: expands to the local declarations shared by the sum_rows
// entry point and its per-thread worker. It must be expanded where a variable
// named `octx` (struct htp_ops_context *) is in scope, and it declares:
//   src0, dst        - pointers to octx->src0 / octx->dst
//   ne00..ne03, ne0..ne3 - per-dimension sizes of src0 / dst
//   nb00..nb03, nb0..nb3 - per-dimension strides of src0 / dst (used as byte
//                          offsets by the worker in this file)
#define sum_rows_preamble                                        \
    struct htp_tensor * src0 = &octx->src0;                      \
    struct htp_tensor * dst  = &octx->dst;                       \
    /* src0 sizes and strides */                                 \
    const uint32_t ne00 = src0->ne[0], ne01 = src0->ne[1],       \
                   ne02 = src0->ne[2], ne03 = src0->ne[3];       \
    const uint32_t nb00 = src0->nb[0], nb01 = src0->nb[1],       \
                   nb02 = src0->nb[2], nb03 = src0->nb[3];       \
    /* dst sizes and strides */                                  \
    const uint32_t ne0 = dst->ne[0], ne1 = dst->ne[1],           \
                   ne2 = dst->ne[2], ne3 = dst->ne[3];           \
    const uint32_t nb0 = dst->nb[0], nb1 = dst->nb[1],           \
                   nb2 = dst->nb[2], nb3 = dst->nb[3];
44
45static int sum_rows_thread_f32(struct htp_ops_context * octx, const int nth, const int ith) {
46 sum_rows_preamble;
47
48 const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
49 const size_t src0_row_size = nb01;
50 const size_t dst_row_size = nb1;
51
52 const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
53
54 const uint32_t src0_start_row = src0_nrows_per_thread * ith;
55 const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
56
57 // no work for this thread
58 if (src0_start_row >= src0_end_row) {
59 return HTP_STATUS_OK;
60 }
61
62 int opt_path = 0;
63 if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {
64 opt_path = 1;
65 }
66
67 const uint8_t * restrict data_src = (const uint8_t *) src0->data;
68 uint8_t * restrict data_dst = (uint8_t *) dst->data;
69
70 const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
71 float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
72
73 for (uint32_t ir = 0; ir < src0_nrows_per_thread; ir++) {
74 const float * restrict src_local = src_th + (ir * ne00);
75
76 if (ir + 1 < src0_nrows_per_thread) {
77 hex_l2fetch(src_local + ne00, src0_row_size, src0_row_size, 1);
78 }
79
80 if (1 == opt_path) {
81 dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00);
82 } else {
83 dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00);
84 }
85 }
86
87 return HTP_STATUS_OK;
88}
89
// Worker-pool trampoline for SUM_ROWS: recovers the op context from the opaque
// callback argument and reduces the rows assigned to job `i` of `n`.
// The status returned by the per-thread routine is intentionally discarded.
static void sum_rows_work_f32(unsigned int n, unsigned int i, void *data) {
    struct htp_ops_context * const op_ctx = (struct htp_ops_context *) data;

    (void) sum_rows_thread_f32(op_ctx, (int) n, (int) i);
}
93
94int op_sum_rows(struct htp_ops_context * octx) {
95 sum_rows_preamble;
96
97 if (octx->src0.type != HTP_TYPE_F32) {
98 return HTP_STATUS_NO_SUPPORT;
99 }
100
101 if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
102 return HTP_STATUS_OK;
103 }
104
105 const int n_threads = octx->n_threads;
106 const uint32_t src0_nrows = ne01 * ne02 * ne03;
107
108 uint32_t n_jobs = MIN(n_threads, src0_nrows);
109 octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
110
111 worker_pool_run_func(octx->ctx->worker_pool, sum_rows_work_f32, octx, n_jobs);
112
113 return HTP_STATUS_OK;
114}
115