1#pragma clang diagnostic ignored "-Wunused-variable"
2#pragma clang diagnostic ignored "-Wunused-function"
3#pragma clang diagnostic ignored "-Wunused-but-set-variable"
4
5#include <HAP_farf.h>
6#include <HAP_perf.h>
7
8#include <math.h>
9#include <string.h>
10
11#define GGML_COMMON_DECL_C
12#include "ggml-common.h"
13#include "htp-ctx.h"
14#include "htp-msg.h"
15#include "htp-ops.h"
16#include "hvx-utils.h"
17
// Per-op state for one CPY op. op_cpy fills it in once; every worker job then
// receives a pointer to it through cpy_work_func. The kernels in this file only
// read it.
struct htp_copy_context {
    // Ops context the CPY op was invoked with (source/destination tensors, flags).
    struct htp_ops_context * octx;

    // Bytes per src0 block (4 for F32, 2 for F16 in op_cpy).
    uint32_t src0_type_size;
    // Elements per src0 block (always 1 for the F32/F16 types handled here;
    // set by op_cpy but not read by the kernels in this file).
    uint32_t src0_block_size;

    // Bytes per dst block (4 for F32, 2 for F16 in op_cpy).
    uint32_t dst_type_size;
    // Elements per dst block (always 1 for the F32/F16 types handled here;
    // set by op_cpy but not read by the kernels in this file).
    uint32_t dst_block_size;

    // Blocks per row along dim 0: ne00 / src0 block size and ne0 / dst block size.
    uint32_t src0_blocks_per_row;
    uint32_t dst_blocks_per_row;

    // Number of src0 rows (dim 1) each job handles: ceil(nr / n_jobs), computed in op_cpy.
    uint32_t src0_nrows_per_thread;

    // Kernel selected by op_cpy from the src0/dst types and layouts; called once per job.
    void (*copy)(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith);
};
34
// Declares local shorthands in the enclosing function:
//   src0, dst          -> &octx->src0, &octx->dst
//   ne00..ne03, nb00.. -> src0 shape (ne) and byte strides (nb), stored as uint32_t
//   ne0..ne3, nb0..nb3 -> dst shape and byte strides, stored as uint32_t
//   nr                 -> ne01, the src0 row count along dim 1 (the dimension
//                         the copy kernels split across worker jobs)
// Requires a variable named `octx` (struct htp_ops_context *) in scope.
// Most expansions leave several of these names unused, which is presumably why
// the -Wunused-* diagnostics are disabled at the top of the file.
// NOTE: the body is a backslash-continued macro, so no comments belong inside it.
#define cpy_preamble \
    struct htp_tensor *src0 = &octx->src0; \
    struct htp_tensor *dst = &octx->dst; \
    \
    const uint32_t ne00 = src0->ne[0]; \
    const uint32_t ne01 = src0->ne[1]; \
    const uint32_t ne02 = src0->ne[2]; \
    const uint32_t ne03 = src0->ne[3]; \
    \
    const uint32_t nb00 = src0->nb[0]; \
    const uint32_t nb01 = src0->nb[1]; \
    const uint32_t nb02 = src0->nb[2]; \
    const uint32_t nb03 = src0->nb[3]; \
    \
    const uint32_t ne0 = dst->ne[0]; \
    const uint32_t ne1 = dst->ne[1]; \
    const uint32_t ne2 = dst->ne[2]; \
    const uint32_t ne3 = dst->ne[3]; \
    \
    const uint32_t nb0 = dst->nb[0]; \
    const uint32_t nb1 = dst->nb[1]; \
    const uint32_t nb2 = dst->nb[2]; \
    const uint32_t nb3 = dst->nb[3]; \
    \
    const uint32_t nr = ne01;
60
61static void cpy_thread_sametype_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) {
62 cpy_preamble;
63
64 // parallelize by src0 rows
65 const uint32_t dr = ct->src0_nrows_per_thread;
66 const uint32_t ir0 = dr * ith;
67 const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
68
69 // copy by rows
70 for (uint32_t i03 = 0; i03 < ne03; i03++) {
71 for (uint32_t i02 = 0; i02 < ne02; i02++) {
72 #pragma unroll(2)
73 for (uint32_t i01 = ir0; i01 < ir1; i01++) {
74 uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
75 uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
76 hex_l2fetch(src0_ptr, ne00 * ct->src0_type_size, nb01, 2);
77 hvx_copy_uu(dst_ptr, src0_ptr, ne00, ct->src0_type_size);
78 }
79 }
80 }
81}
82
83static void cpy_thread_sametype_reshape(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith) {
84 cpy_preamble;
85
86 // parallelize by src0 rows
87 const uint32_t dr = ct->src0_nrows_per_thread;
88 const uint32_t ir0 = dr * ith;
89 const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
90
91 // dst counters
92 int64_t k10 = 0;
93 int64_t i11 = 0;
94 int64_t i12 = 0;
95 int64_t i13 = 0;
96
97 // number of blocks in a row
98 const int64_t nk00 = ct->src0_blocks_per_row;
99 const int64_t nk0 = ct->dst_blocks_per_row;
100
101 for (int64_t i03 = 0; i03 < ne03; i03++) {
102 for (int64_t i02 = 0; i02 < ne02; i02++) {
103 k10 += nk00 * ir0;
104 while (k10 >= nk0) {
105 k10 -= nk0;
106 if (++i11 == ne1) {
107 i11 = 0;
108 if (++i12 == ne2) {
109 i12 = 0;
110 if (++i13 == ne3) {
111 i13 = 0;
112 }
113 }
114 }
115 }
116 for (int64_t i01 = ir0; i01 < ir1; i01++) {
117 for (int64_t k00 = 0; k00 < nk00; k00++) {
118 const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
119 char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
120 memcpy(dst_ptr, src0_ptr, ct->dst_type_size);
121
122 if (++k10 == nk0) {
123 k10 = 0;
124 if (++i11 == ne1) {
125 i11 = 0;
126 if (++i12 == ne2) {
127 i12 = 0;
128 if (++i13 == ne3) {
129 i13 = 0;
130 }
131 }
132 }
133 }
134 }
135 }
136 k10 += nk00 * (ne01 - ir1);
137 while (k10 >= nk0) {
138 k10 -= nk0;
139 if (++i11 == ne1) {
140 i11 = 0;
141 if (++i12 == ne2) {
142 i12 = 0;
143 if (++i13 == ne3) {
144 i13 = 0;
145 }
146 }
147 }
148 }
149 }
150 }
151}
152
153static void cpy_thread_f16_f32_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) {
154 cpy_preamble;
155
156 // parallelize by src0 rows
157 const uint32_t dr = ct->src0_nrows_per_thread;
158 const uint32_t ir0 = dr * ith;
159 const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
160
161 // copy by rows
162 for (uint32_t i03 = 0; i03 < ne03; i03++) {
163 for (uint32_t i02 = 0; i02 < ne02; i02++) {
164 #pragma unroll(2)
165 for (uint32_t i01 = ir0; i01 < ir1; i01++) {
166 uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
167 uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
168 hex_l2fetch(src0_ptr, ne00 * sizeof(float), nb01, 2);
169 hvx_copy_f16_f32_uu(dst_ptr, src0_ptr, ne00);
170 }
171 }
172 }
173}
174
175static void cpy_thread_f32_f16_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) {
176 cpy_preamble;
177
178 // parallelize by src0 rows
179 const uint32_t dr = ct->src0_nrows_per_thread;
180 const uint32_t ir0 = dr * ith;
181 const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
182
183 // copy by rows
184 for (uint32_t i03 = 0; i03 < ne03; i03++) {
185 for (uint32_t i02 = 0; i02 < ne02; i02++) {
186 #pragma unroll(2)
187 for (uint32_t i01 = ir0; i01 < ir1; i01++) {
188 uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
189 uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
190 hex_l2fetch(src0_ptr, ne00 * sizeof(__fp16), nb01, 2);
191 hvx_copy_f32_f16_uu(dst_ptr, src0_ptr, ne00);
192 }
193 }
194 }
195}
196
197static void cpy_work_func(unsigned int n, unsigned int i, void *data) {
198 struct htp_copy_context *ct = (struct htp_copy_context *) data;
199 ct->copy(ct, ct->octx, n, i);
200}
201
202int op_cpy(struct htp_ops_context * octx) {
203 cpy_preamble;
204
205 struct htp_copy_context ct;
206 ct.octx = octx;
207
208 switch (src0->type) {
209 case HTP_TYPE_F32: ct.src0_type_size = 4; ct.src0_block_size = 1; ct.src0_blocks_per_row = ne00 / 1; break;
210 case HTP_TYPE_F16: ct.src0_type_size = 2; ct.src0_block_size = 1; ct.src0_blocks_per_row = ne00 / 1; break;
211 default:
212 return HTP_STATUS_NO_SUPPORT;
213 }
214
215 switch (dst->type) {
216 case HTP_TYPE_F32: ct.dst_type_size = 4; ct.dst_block_size = 1; ct.dst_blocks_per_row = ne0 / 1; break;
217 case HTP_TYPE_F16: ct.dst_type_size = 2; ct.dst_block_size = 1; ct.dst_blocks_per_row = ne0 / 1; break;
218 default:
219 return HTP_STATUS_NO_SUPPORT;
220 }
221
222 if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
223 return HTP_STATUS_OK;
224 }
225
226 const bool sametype = (src0->type == dst->type);
227 const bool transposed = (nb00 > nb01) || (nb0 > nb1);
228 const bool sameshape = !transposed && (ne00 == ne0 && ne01 == ne1 && ne02 == ne2 && ne03 == ne3);
229
230 const uint32_t n_jobs = MIN(nr, octx->n_threads);
231 ct.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
232
233 if (sametype && sameshape) {
234 ct.copy = cpy_thread_sametype_sameshape;
235 } else if (sameshape) {
236 /**/ if (dst->type == HTP_TYPE_F16 && src0->type == HTP_TYPE_F32)
237 ct.copy = cpy_thread_f16_f32_sameshape;
238 else if (dst->type == HTP_TYPE_F32 && src0->type == HTP_TYPE_F16)
239 ct.copy = cpy_thread_f32_f16_sameshape;
240 else
241 return HTP_STATUS_NO_SUPPORT;
242 } else if (sametype) {
243 ct.copy = cpy_thread_sametype_reshape;
244 } else {
245 return HTP_STATUS_NO_SUPPORT;
246 }
247
248 worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_jobs);
249
250 return HTP_STATUS_OK;
251}