1#pragma clang diagnostic ignored "-Wunused-variable"
2#pragma clang diagnostic ignored "-Wunused-function"
3#pragma clang diagnostic ignored "-Wunused-but-set-variable"
4
5#include <HAP_farf.h>
6#include <HAP_perf.h>
7
8#include <math.h>
9#include <string.h>
10
11#include "hex-dma.h"
12#include "hvx-utils.h"
13
14#define GGML_COMMON_DECL_C
15#include "ggml-common.h"
16#include "htp-ctx.h"
17#include "htp-msg.h"
18#include "htp-ops.h"
19
/*
 * set_rows_preamble: declares const local copies of the tensor geometry used by
 * the SET_ROWS kernels. It reads through a variable named `octx` that must be in
 * scope at the expansion site.
 *
 *   src0 (rows to copy)  : ne00..ne03 extents, nb01..nb03 byte strides
 *   src1 (row indices)   : ne10..ne12 extents, nb10..nb12 byte strides
 *   dst                  : ne1 row count, nb1..nb3 byte strides
 *   nr                   : number of src0 rows (== ne01), split across threads
 *
 * Not every expansion site uses every name; the resulting unused-variable
 * warnings are suppressed by the pragmas at the top of the file.
 */
#define set_rows_preamble                       \
    const uint32_t ne00 = octx->src0.ne[0];     \
    const uint32_t ne01 = octx->src0.ne[1];     \
    const uint32_t nb01 = octx->src0.nb[1];     \
    const uint32_t ne02 = octx->src0.ne[2];     \
    const uint32_t nb02 = octx->src0.nb[2];     \
    const uint32_t ne03 = octx->src0.ne[3];     \
    const uint32_t nb03 = octx->src0.nb[3];     \
                                                \
    const uint32_t ne10 = octx->src1.ne[0];     \
    const uint32_t nb10 = octx->src1.nb[0];     \
    const uint32_t ne11 = octx->src1.ne[1];     \
    const uint32_t nb11 = octx->src1.nb[1];     \
    const uint32_t ne12 = octx->src1.ne[2];     \
    const uint32_t nb12 = octx->src1.nb[2];     \
                                                \
    const uint32_t ne1  = octx->dst.ne[1];      \
    const uint32_t nb1  = octx->dst.nb[1];      \
    const uint32_t nb2  = octx->dst.nb[2];      \
    const uint32_t nb3  = octx->dst.nb[3];      \
                                                \
    const uint32_t nr   = ne01;
45
46static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
47 set_rows_preamble;
48
49 // parallelize by rows of src0
50 const uint32_t dr = octx->src0_nrows_per_thread;
51 const uint32_t ir0 = dr * ith;
52 const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
53
54 const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
55
56 for (uint32_t i03 = 0; i03 < ne03; ++i03) {
57 for (uint32_t i02 = 0; i02 < ne02; ++i02) {
58 for (uint32_t i = ir0; i < ir1; ++i) {
59 const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
60 const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
61 const uint32_t i10 = i;
62
63 const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
64
65 uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
66 if (i1 >= ne1) {
67 // ignore invalid indices
68 continue;
69 }
70
71 const uintptr_t src0_ptr = octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
72 const uintptr_t dst_ptr = octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3;
73
74 // copy row
75 hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
76 }
77 }
78 }
79
80 return HTP_STATUS_OK;
81}
82
83static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) {
84 set_rows_preamble;
85
86 // parallelize by rows of src0
87 const uint32_t dr = octx->src0_nrows_per_thread;
88 const uint32_t ir0 = dr * ith;
89 const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
90
91 const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
92
93 for (uint32_t i03 = 0; i03 < ne03; ++i03) {
94 for (uint32_t i02 = 0; i02 < ne02; ++i02) {
95 for (uint32_t i = ir0; i < ir1; ++i) {
96 const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
97 const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
98 const uint32_t i10 = i;
99
100 const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
101
102 uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
103 if (i1 >= ne1) {
104 // ignore invalid indices
105 continue;
106 }
107
108 const uint8_t* src0_ptr = (const uint8_t *) octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
109 uint8_t* dst_ptr = (uint8_t *) octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3;
110
111 hvx_copy_f16_f32_uu(dst_ptr, src0_ptr, ne00);
112 }
113 }
114 }
115
116 return HTP_STATUS_OK;
117}
118
/*
 * Callback handed to worker_pool_run_func() by op_set_rows() when dst is F16.
 * `data` is the op context; (n, i) are forwarded as (nth, ith) to the kernel,
 * whose status result is discarded.
 */
static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) {
    struct htp_ops_context * const op_ctx = data;
    (void) set_rows_thread_f16_f32(op_ctx, (int) n, (int) i);
}
122
/*
 * Callback handed to worker_pool_run_func() by op_set_rows() when dst is F32.
 * `data` is the op context; (n, i) are forwarded as (nth, ith) to the kernel,
 * whose status result is discarded.
 */
static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
    struct htp_ops_context * const op_ctx = data;
    (void) set_rows_thread_f32_f32(op_ctx, (int) n, (int) i);
}
126
127int op_set_rows(struct htp_ops_context * octx) {
128 set_rows_preamble;
129
130 if (octx->src0.type != HTP_TYPE_F32) {
131 return HTP_STATUS_NO_SUPPORT;
132 }
133
134 if (octx->dst.type != HTP_TYPE_F32 && octx->dst.type != HTP_TYPE_F16) {
135 return HTP_STATUS_NO_SUPPORT;
136 }
137
138 if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {
139 return HTP_STATUS_NO_SUPPORT;
140 }
141
142 if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
143 return HTP_STATUS_OK;
144 }
145
146 octx->set_rows_div_ne12 = init_fastdiv_values(ne12);
147 octx->set_rows_div_ne11 = init_fastdiv_values(ne11);
148
149 const uint32_t n_jobs = MIN(nr, octx->n_threads);
150 octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
151
152 switch(octx->dst.type) {
153 case HTP_TYPE_F32:
154 worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs);
155 break;
156 case HTP_TYPE_F16:
157 worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs);
158 break;
159 default:
160 return HTP_STATUS_NO_SUPPORT;
161 }
162
163 return HTP_STATUS_OK;
164}