#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"

#include <HAP_farf.h>
#include <HAP_perf.h>

#include <math.h>
#include <string.h>

#include "hex-dma.h"
#include "hvx-utils.h"

#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-msg.h"
#include "htp-ops.h"

#ifndef MIN
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif

// Context for binary operations
struct htp_binary_context {
    struct htp_ops_context * octx;
    struct fastdiv_values dim1_div;
    struct fastdiv_values dim2_div;
    struct fastdiv_values dim12_div;

    struct fastdiv_values src1_dim1_div; // ne11
    struct fastdiv_values src1_dim2_div; // ne12
    struct fastdiv_values src1_dim3_div; // ne13

    uint32_t nrows_per_thread;
    bool split_at_ne01;
    bool split_at_ne02;

    // Precomputed values
    uint32_t block_max;
    size_t src0_row_size_aligned;
    size_t src1_row_size_aligned;
    size_t dst_row_size_aligned;
    uint32_t src1_fetch_rows; // 1 or block_max
    uint32_t src1_dma_stride; // 0 or stride
};
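
// All job functions below share the same structure: each thread walks its row
// range in blocks of up to block_max rows, double-buffered through the two
// halves of its per-thread VTCM scratchpads. A short prefetch preamble queues
// the first one or two blocks, then the steady-state loop pops a completed
// block, runs the HVX op on it, queues the dst writeback, and immediately
// queues the src0 read for the next block into the half that was just freed.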

#define htp_binary_preamble \
    const struct htp_tensor * src0 = &octx->src0; \
    const struct htp_tensor * src1 = &octx->src1; \
    struct htp_tensor * dst = &octx->dst; \
    \
    const uint32_t ne00 = src0->ne[0]; \
    const uint32_t ne01 = src0->ne[1]; \
    const uint32_t ne02 = src0->ne[2]; \
    const uint32_t ne03 = src0->ne[3]; \
    \
    const uint32_t ne10 = src1->ne[0]; \
    const uint32_t ne11 = src1->ne[1]; \
    const uint32_t ne12 = src1->ne[2]; \
    const uint32_t ne13 = src1->ne[3]; \
    \
    const uint32_t nb01 = src0->nb[1]; \
    const uint32_t nb02 = src0->nb[2]; \
    const uint32_t nb03 = src0->nb[3]; \
    \
    const uint32_t nb11 = src1->nb[1]; \
    const uint32_t nb12 = src1->nb[2]; \
    const uint32_t nb13 = src1->nb[3]; \
    \
    const uint32_t nb1 = dst->nb[1]; \
    const uint32_t nb2 = dst->nb[2]; \
    const uint32_t nb3 = dst->nb[3];

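// calc_block_size() decides how many consecutive rows (starting at flattened
// row index ir) can be processed as one DMA block. The flattened index is
// decomposed as ir = (i03 * ne02 + i02) * ne01 + i01 using the precomputed
// fastdiv values; e.g. with ne01 = 4, ne02 = 3, ir = 10 gives i03 = 0,
// i02 = 2, i01 = 2. Blocks are optionally clipped so they never cross an
// ne01 or ne02 boundary (see split_at_ne01 / split_at_ne02).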
static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_t ir, uint32_t end_row,
                                        uint32_t ne01, uint32_t ne02) {
    uint32_t i03, i02, i01, rem;
    i03 = fastdiv(ir, &bctx->dim12_div);
    rem = ir - i03 * (ne02 * ne01);
    i02 = fastdiv(rem, &bctx->dim1_div);
    i01 = rem - i02 * ne01;

    uint32_t rows_left = end_row - ir;
    uint32_t block_limit = rows_left;

    if (bctx->split_at_ne01) {
        block_limit = MIN(block_limit, ne01 - i01);
    }
    if (bctx->split_at_ne02) {
        uint32_t rows_in_plane = (ne02 * ne01) - rem;
        block_limit = MIN(block_limit, rows_in_plane);
    }

    return MIN(bctx->block_max, block_limit);
}

// Macro for scalar op switch
#define COMPUTE_SCALAR_OP(DST, SRC, VAL, N) \
    switch (octx->op) { \
        case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, VAL, N); break; \
        case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, VAL, N); break; \
        case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, VAL, N); break; \
        case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (VAL), N); break; \
        default: break; \
    }
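
// Note: HTP_OP_DIV above is lowered to a multiply by the reciprocal of VAL;
// this assumes the small precision difference vs. a true divide is acceptable here.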

// Macro for vector op switch (All Aligned)
#define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, N) \
    switch (octx->op) { \
        case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \
        case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \
        case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \
        case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \
        default: break; \
    }

// Macro for vector op switch (Dst Aligned, Src0 Aligned, Src1 Unaligned)
#define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, N) \
    switch (octx->op) { \
        case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \
        case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \
        case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \
        case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \
        default: break; \
    }

// Macro for vector op switch (All Unaligned - generic loop used in element repeat)
#define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, N) \
    switch (octx->op) { \
        case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \
        case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \
        case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \
        case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \
        default: break; \
    }

// 1. Scalar src1 (ne10 == 1)
static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
    struct htp_binary_context * bctx = (struct htp_binary_context *) data;
    struct htp_ops_context * octx = bctx->octx;
    htp_binary_preamble;

    const uint32_t total_rows = ne01 * ne02 * ne03;
    const uint32_t start_row = bctx->nrows_per_thread * ith;
    const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
    if (start_row >= end_row) return;

    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
    uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
    size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;

    dma_queue * q = octx->ctx->dma[ith];
    uint32_t ir_prefetch = start_row;
    int spad_idx = 0;

    // Prefetch preamble: queue up to two blocks ahead (one per scratchpad half)
    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
        uint32_t i03, i02, i01, rem;
        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
        rem = ir_prefetch - i03 * (ne02 * ne01);
        i02 = fastdiv(rem, &bctx->dim1_div);
        i01 = rem - i02 * ne01;

        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;

        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
        uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;

        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
        ir_prefetch += current_block_size;
        spad_idx ^= 1;
    }

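    // The pops below mirror the push order: the dst writeback descriptor comes
    // first (its .src is the VTCM half we are about to reuse), then the src0
    // read descriptor (its .dst holds the freshly fetched rows). Popping is
    // assumed to wait for that descriptor to complete, so the preamble's 0-row
    // dst push simply reserves the slot without issuing a transfer.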
    // Main loop
    for (uint32_t ir = start_row; ir < end_row; ) {
        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);

        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;

        uint32_t i03, i02, i01, rem;
        i03 = fastdiv(ir, &bctx->dim12_div);
        rem = ir - i03 * (ne02 * ne01);
        i02 = fastdiv(rem, &bctx->dim1_div);
        i01 = rem - i02 * ne01;

        // src1 indices (broadcast/repeat)
        uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);
        uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);
        uint32_t i11 = fastmodulo(i01, ne11, &bctx->src1_dim1_div);

        uint8_t * src1_ptr = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
        uint32_t s1_stride = (ne11 == 1) ? 0 : nb11;

        for (uint32_t r = 0; r < current_block_size; r++) {
            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
            uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
            float val = *(float *)src1_ptr;
            src1_ptr += s1_stride;
            COMPUTE_SCALAR_OP(r_dst, r_src0, val, ne00);
        }

        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);

        if (ir_prefetch < end_row) {
            uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
            uint32_t p03, p02, p01, prem;
            p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
            prem = ir_prefetch - p03 * (ne02 * ne01);
            p02 = fastdiv(prem, &bctx->dim1_div);
            p01 = prem - p02 * ne01;
            uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;

            dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
            ir_prefetch += next_block_size;
        }
        ir += current_block_size;
    }
    dma_queue_flush(q);
}

// 2. Vector same shape (ne1x == ne0x) or simple broadcast (each src1 dim equals ne0x or 1)
static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, void * data) {
    struct htp_binary_context * bctx = (struct htp_binary_context *) data;
    struct htp_ops_context * octx = bctx->octx;
    htp_binary_preamble;

    const uint32_t total_rows = ne01 * ne02 * ne03;
    const uint32_t start_row = bctx->nrows_per_thread * ith;
    const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
    if (start_row >= end_row) return;

    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
    uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
    uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);

    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
    size_t src1_spad_half = octx->src1_spad.size_per_thread / 2;
    size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;

    dma_queue * q = octx->ctx->dma[ith];
    uint32_t ir_prefetch = start_row;
    int spad_idx = 0;

    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
        uint32_t i03, i02, i01, rem;
        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
        rem = ir_prefetch - i03 * (ne02 * ne01);
        i02 = fastdiv(rem, &bctx->dim1_div);
        i01 = rem - i02 * ne01;

        uint32_t i13 = (ne13 == 1) ? 0 : i03;
        uint32_t i12 = (ne12 == 1) ? 0 : i02;
        uint32_t i11 = (ne11 == 1) ? 0 : i01;

        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
        uint8_t * src1_base = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;

        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
        uint8_t * s1_spad = src1_spad_base + spad_idx * src1_spad_half;
        uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;

        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
        dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), current_block_size);
        ir_prefetch += current_block_size;
        spad_idx ^= 1;
    }

    for (uint32_t ir = start_row; ir < end_row; ) {
        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
        uint8_t * s1_spad = (uint8_t *) dma_queue_pop(q).dst;

        for (uint32_t r = 0; r < current_block_size; r++) {
            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
            uint8_t * r_src1 = s1_spad + r * bctx->src1_row_size_aligned;
            uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
            COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00);
        }

        uint32_t i03, i02, i01, rem;
        i03 = fastdiv(ir, &bctx->dim12_div);
        rem = ir - i03 * (ne02 * ne01);
        i02 = fastdiv(rem, &bctx->dim1_div);
        i01 = rem - i02 * ne01;
        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);

        if (ir_prefetch < end_row) {
            uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
            uint32_t p03, p02, p01, prem;
            p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
            prem = ir_prefetch - p03 * (ne02 * ne01);
            p02 = fastdiv(prem, &bctx->dim1_div);
            p01 = prem - p02 * ne01;

            uint32_t p13 = (ne13 == 1) ? 0 : p03;
            uint32_t p12 = (ne12 == 1) ? 0 : p02;
            uint32_t p11 = (ne11 == 1) ? 0 : p01;

            uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
            uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11;

            dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
            dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), next_block_size);

            ir_prefetch += next_block_size;
        }
        ir += current_block_size;
    }
    dma_queue_flush(q);
}

// 3. Row broadcast (ne11 == ne12 == ne13 == 1: a single src1 row)
static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, void * data) {
    struct htp_binary_context * bctx = (struct htp_binary_context *) data;
    struct htp_ops_context * octx = bctx->octx;
    htp_binary_preamble;

    const uint32_t total_rows = ne01 * ne02 * ne03;
    const uint32_t start_row = bctx->nrows_per_thread * ith;
    const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
    if (start_row >= end_row) return;

    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
    uint8_t * src1_spad = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
    uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);

    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
    size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;

    dma_queue * q = octx->ctx->dma[ith];
    uint32_t ir_prefetch = start_row;
    int spad_idx = 0;

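    // s1_ptr points at the single src1 row that execute_op_binary_f32() preloads
    // into VTCM; src1_spad.size_per_thread is 0, so every thread shares the same
    // copy and it stays constant for the whole loop below.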
    void * s1_ptr = (void *) src1_spad;

    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
        uint32_t i03, i02, i01, rem;
        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
        rem = ir_prefetch - i03 * (ne02 * ne01);
        i02 = fastdiv(rem, &bctx->dim1_div);
        i01 = rem - i02 * ne01;

        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;

        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
        uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;

        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
        ir_prefetch += current_block_size;
        spad_idx ^= 1;
    }

    for (uint32_t ir = start_row; ir < end_row; ) {
        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;

        for (uint32_t r = 0; r < current_block_size; r++) {
            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
            uint8_t * r_src1 = (uint8_t *)s1_ptr; // Constant
            uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
            COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00);
        }

        uint32_t i03, i02, i01, rem;
        i03 = fastdiv(ir, &bctx->dim12_div);
        rem = ir - i03 * (ne02 * ne01);
        i02 = fastdiv(rem, &bctx->dim1_div);
        i01 = rem - i02 * ne01;
        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);

        if (ir_prefetch < end_row) {
            uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
            uint32_t p03, p02, p01, prem;
            p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
            prem = ir_prefetch - p03 * (ne02 * ne01);
            p02 = fastdiv(prem, &bctx->dim1_div);
            p01 = prem - p02 * ne01;
            uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
            dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
            ir_prefetch += next_block_size;
        }
        ir += current_block_size;
    }
    dma_queue_flush(q);
}

// 4. Vector Complex (ne10 == ne00, complex broadcast)
static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * data) {
    struct htp_binary_context * bctx = (struct htp_binary_context *) data;
    struct htp_ops_context * octx = bctx->octx;
    htp_binary_preamble;

    const uint32_t total_rows = ne01 * ne02 * ne03;
    const uint32_t start_row = bctx->nrows_per_thread * ith;
    const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
    if (start_row >= end_row) return;

    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
    uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
    size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;

    dma_queue * q = octx->ctx->dma[ith];
    uint32_t ir_prefetch = start_row;
    int spad_idx = 0;

    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
        uint32_t i03, i02, i01, rem;
        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
        rem = ir_prefetch - i03 * (ne02 * ne01);
        i02 = fastdiv(rem, &bctx->dim1_div);
        i01 = rem - i02 * ne01;

        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;

        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
        uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;

        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
        ir_prefetch += current_block_size;
        spad_idx ^= 1;
    }

    for (uint32_t ir = start_row; ir < end_row; ) {
        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;

        uint32_t i03, i02, i01, rem;
        i03 = fastdiv(ir, &bctx->dim12_div);
        rem = ir - i03 * (ne02 * ne01);
        i02 = fastdiv(rem, &bctx->dim1_div);
        i01 = rem - i02 * ne01;

        for (uint32_t r = 0; r < current_block_size; r++) {
            uint32_t r_i01 = i01 + r;
            uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);
            uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);
            uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div);

            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
            uint8_t * r_src1 = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
            uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;

            // Read src1 from DDR (unaligned)
            COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, ne00);
        }

        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);

        if (ir_prefetch < end_row) {
            uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
            uint32_t p03, p02, p01, prem;
            p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
            prem = ir_prefetch - p03 * (ne02 * ne01);
            p02 = fastdiv(prem, &bctx->dim1_div);
            p01 = prem - p02 * ne01;
            uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
            dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
            ir_prefetch += next_block_size;
        }
        ir += current_block_size;
    }
    dma_queue_flush(q);
}

// 5. Element Repeat (ne10 != ne00)
static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * data) {
    struct htp_binary_context * bctx = (struct htp_binary_context *) data;
    struct htp_ops_context * octx = bctx->octx;
    htp_binary_preamble;

    const uint32_t total_rows = ne01 * ne02 * ne03;
    const uint32_t start_row = bctx->nrows_per_thread * ith;
    const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
    if (start_row >= end_row) return;

    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
    uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
    size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;

    dma_queue * q = octx->ctx->dma[ith];
    uint32_t ir_prefetch = start_row;
    int spad_idx = 0;

    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
        uint32_t i03, i02, i01, rem;
        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
        rem = ir_prefetch - i03 * (ne02 * ne01);
        i02 = fastdiv(rem, &bctx->dim1_div);
        i01 = rem - i02 * ne01;

        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;

        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
        uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;

        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
        ir_prefetch += current_block_size;
        spad_idx ^= 1;
    }

    for (uint32_t ir = start_row; ir < end_row; ) {
        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;

        uint32_t i03, i02, i01, rem;
        i03 = fastdiv(ir, &bctx->dim12_div);
        rem = ir - i03 * (ne02 * ne01);
        i02 = fastdiv(rem, &bctx->dim1_div);
        i01 = rem - i02 * ne01;

        for (uint32_t r = 0; r < current_block_size; r++) {
            uint32_t r_i01 = i01 + r;
            uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);
            uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);
            uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div);

            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
            uint8_t * r_src1_row = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
            uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;

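            // ggml repeat semantics should guarantee ne00 is a multiple of ne10
            // here; the MIN below still guards a ragged tail just in case.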
            // Repeat src1 row
            for (uint32_t c = 0; c < ne00; c += ne10) {
                uint32_t len = MIN(ne10, ne00 - c);
                // Chunk offsets may not be HVX-vector aligned and src1 lives in DDR, so use the fully unaligned variant
                COMPUTE_VECTOR_OP_UUU(r_dst + c * sizeof(float), r_src0 + c * sizeof(float), r_src1_row, len);
            }
        }

        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);

        if (ir_prefetch < end_row) {
            uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
            uint32_t p03, p02, p01, prem;
            p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
            prem = ir_prefetch - p03 * (ne02 * ne01);
            p02 = fastdiv(prem, &bctx->dim1_div);
            p01 = prem - p02 * ne01;
            uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
            dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
            ir_prefetch += next_block_size;
        }
        ir += current_block_size;
    }
    dma_queue_flush(q);
}

// 6. ADD_ID (src1 gathered via src2 indices)
static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
    struct htp_binary_context * bctx = (struct htp_binary_context *) data;
    struct htp_ops_context * octx = bctx->octx;

    const struct htp_tensor * src0 = &octx->src0;
    const struct htp_tensor * src1 = &octx->src1;
    const struct htp_tensor * src2 = &octx->src2;
    struct htp_tensor * dst = &octx->dst;

    const uint32_t ne00 = src0->ne[0];
    const uint32_t ne01 = src0->ne[1];
    const uint32_t ne02 = src0->ne[2];
    const uint32_t ne03 = src0->ne[3];
    const uint32_t ne11 = src1->ne[1]; // for bounds check

    const uint32_t nb01 = src0->nb[1];
    const uint32_t nb02 = src0->nb[2];
    const uint32_t nb03 = src0->nb[3];
    const uint32_t nb11 = src1->nb[1]; // src1 row stride
    const uint32_t nb1 = dst->nb[1];
    const uint32_t nb2 = dst->nb[2];
    const uint32_t nb3 = dst->nb[3];

    const uint32_t total_rows = ne01 * ne02 * ne03;
    const uint32_t start_row = bctx->nrows_per_thread * ith;
    const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
    if (start_row >= end_row) return;

    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
    uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
    size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;

    dma_queue * q = octx->ctx->dma[ith];
    uint32_t ir_prefetch = start_row;
    int spad_idx = 0;

    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
        uint32_t i03, i02, i01, rem;
        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
        rem = ir_prefetch - i03 * (ne02 * ne01);
        i02 = fastdiv(rem, &bctx->dim1_div);
        i01 = rem - i02 * ne01;

        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;

        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
        uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;

        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
        ir_prefetch += current_block_size;
        spad_idx ^= 1;
    }

    for (uint32_t ir = start_row; ir < end_row; ) {
        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;

        uint32_t i03, i02, i01, rem;
        i03 = fastdiv(ir, &bctx->dim12_div);
        rem = ir - i03 * (ne02 * ne01);
        i02 = fastdiv(rem, &bctx->dim1_div);
        i01 = rem - i02 * ne01;

        for (uint32_t r = 0; r < current_block_size; r++) {
            uint32_t r_i01 = i01 + r; // linear within block since we split at ne01

            const int32_t idx = *(int32_t *)((char *)src2->data + r_i01 * src2->nb[0] + i02 * src2->nb[1]);

            uint8_t * r_src1 = (uint8_t *)src1->data + idx * nb11;
            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
            uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;

            hvx_add_f32_aau(r_dst, r_src0, r_src1, ne00);
        }

        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);

        if (ir_prefetch < end_row) {
            uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
            uint32_t p03, p02, p01, prem;
            p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
            prem = ir_prefetch - p03 * (ne02 * ne01);
            p02 = fastdiv(prem, &bctx->dim1_div);
            p01 = prem - p02 * ne01;
            uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
            dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
            ir_prefetch += next_block_size;
        }
        ir += current_block_size;
    }
    dma_queue_flush(q);
}

static int execute_op_binary_f32(struct htp_ops_context * octx) {
    const struct htp_tensor * src0 = &octx->src0;
    const struct htp_tensor * src1 = &octx->src1;
    struct htp_tensor * dst = &octx->dst;

    const uint32_t n_threads = octx->n_threads;
    const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];

    // Use packed row sizes for VTCM allocation
    const size_t src0_row_size = src0->ne[0] * sizeof(float);
    const size_t src1_row_size = src1->ne[0] * sizeof(float);
    const size_t dst_row_size = dst->ne[0] * sizeof(float);

    // Align to VLEN
    const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
    const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
    size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);

    bool is_add_id = (octx->op == HTP_OP_ADD_ID);
    bool is_scalar = !is_add_id && (src1->ne[0] == 1);

    // Determine which kernel we will use, so we can size VTCM and dispatch accordingly
    bool use_vector_same = !is_add_id && !is_scalar && src1->ne[0] == src0->ne[0] &&
                           (src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) &&
                           (src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) &&
                           (src1->ne[3] == src0->ne[3] || src1->ne[3] == 1);

    bool is_row_bcast = use_vector_same && (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1);
    bool use_complex = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] == src0->ne[0]);
    bool use_repeat = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] != src0->ne[0]);
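
    // In summary: ADD_ID gathers src1 rows via src2 indices; scalar handles
    // ne10 == 1; row broadcast handles a single src1 row; same-shape covers
    // whole-dim broadcasts; complex covers ne10 == ne00 with per-row modulo
    // indexing; element repeat covers ne10 != ne00.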

    size_t spad_row_total;
    if (is_scalar) {
        spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
    } else if (is_row_bcast) {
        spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
    } else if (use_vector_same) {
        spad_row_total = 2 * (src0_row_size_aligned + src1_row_size_aligned + dst_row_size_aligned);
    } else if (is_add_id) {
        spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); // src1 read directly
    } else {
        spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
    }
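
    // The factor of 2 above is for double buffering. The scalar, complex, repeat
    // and ADD_ID kernels read src1 directly from DDR, so no src1 rows are staged
    // in VTCM for them; the row-broadcast kernel stages a single shared src1 row
    // outside this per-thread accounting (handled below).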

    size_t rows_per_buffer = octx->ctx->vtcm_size / (n_threads * spad_row_total);
    // Adjust for static src1 in row_bcast case
    if (is_row_bcast) {
        size_t needed_static = src1_row_size_aligned;
        if (octx->ctx->vtcm_size < needed_static) return HTP_STATUS_VTCM_TOO_SMALL;
        size_t avail = octx->ctx->vtcm_size - needed_static;
        rows_per_buffer = avail / (n_threads * spad_row_total);
    }

    if (rows_per_buffer < 1) {
        FARF(ERROR, "binary-f32: VTCM too small\n");
        return HTP_STATUS_VTCM_TOO_SMALL;
    }

    octx->src0_spad.size_per_thread = rows_per_buffer * 2 * src0_row_size_aligned;
    octx->dst_spad.size_per_thread = rows_per_buffer * 2 * dst_row_size_aligned;

    if (is_scalar || use_complex || use_repeat || is_add_id) {
        octx->src1_spad.size_per_thread = 0;
    } else if (is_row_bcast) {
        octx->src1_spad.size_per_thread = 0;
    } else {
        octx->src1_spad.size_per_thread = rows_per_buffer * 2 * src1_row_size_aligned;
    }

    octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
    if (is_row_bcast) {
        octx->src1_spad.size = src1_row_size_aligned;
    } else {
        octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread;
    }
    octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;

    if (octx->ctx->vtcm_size < (octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size)) {
        return HTP_STATUS_VTCM_TOO_SMALL;
    }

    octx->src0_spad.data = octx->ctx->vtcm_base;
    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
    octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;

    if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
        return HTP_STATUS_OK;
    }

    uint32_t n_jobs = MIN(n_threads, src0_nrows);

    dma_queue * q = octx->ctx->dma[0];
    if (is_row_bcast) {
        dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * sizeof(float), 1);
    }

    struct htp_binary_context bctx;
    bctx.octx = octx;
    bctx.nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
    bctx.block_max = rows_per_buffer;
    bctx.src0_row_size_aligned = src0_row_size_aligned;
    bctx.src1_row_size_aligned = src1_row_size_aligned;
    bctx.dst_row_size_aligned = dst_row_size_aligned;

    bctx.dim1_div = init_fastdiv_values(src0->ne[1]);
    bctx.dim2_div = init_fastdiv_values(src0->ne[2]);
    bctx.dim12_div = init_fastdiv_values(src0->ne[1] * src0->ne[2]);

    bctx.src1_dim1_div = init_fastdiv_values(src1->ne[1]);
    bctx.src1_dim2_div = init_fastdiv_values(src1->ne[2]);
    bctx.src1_dim3_div = init_fastdiv_values(src1->ne[3]);

    bool src0_contig_dim1 = (src0->nb[2] == src0->ne[1] * src0->nb[1]);
    bool dst_contig_dim1 = (dst->nb[2] == src0->ne[1] * dst->nb[1]);

    bool src0_contig_dim2 = (src0->nb[3] == src0->ne[2] * src0->nb[2]);
    bool dst_contig_dim2 = (dst->nb[3] == src0->ne[2] * dst->nb[2]);

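    // A block is serviced by one strided DMA descriptor per tensor, so it must
    // not cross an ne01 (or ne02) boundary when the src1 row selection changes
    // there or when src0/dst are not contiguous across that boundary; in those
    // cases force the splits below.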
    bctx.split_at_ne01 = (src0->ne[2] > 1) &&
                         ((src1->ne[1] > 1) || (src1->ne[2] > 1) || !src0_contig_dim1 || !dst_contig_dim1);

    bctx.split_at_ne02 = (src0->ne[3] > 1) &&
                         ((src1->ne[2] > 1) || (src1->ne[3] > 1) || !src0_contig_dim2 || !dst_contig_dim2);

    // Precompute specific kernel parameters
    if (use_vector_same) {
        bctx.src1_dma_stride = (src1->ne[1] == 1) ? 0 : src1->nb[1];
        bctx.src1_fetch_rows = (src1->ne[1] == 1) ? 1 : rows_per_buffer;
    }

    worker_callback_t worker_func;
    if (is_add_id) worker_func = binary_job_add_id;
    else if (is_scalar) worker_func = binary_job_scalar;
    else if (is_row_bcast) worker_func = binary_job_vector_row_broadcast;
    else if (use_vector_same) worker_func = binary_job_vector_same_shape;
    else if (use_complex) worker_func = binary_job_vector_complex;
    else worker_func = binary_job_element_repeat;

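    // For row broadcast, drain the src1 row DMA queued above; the pop is assumed
    // to wait for that transfer, so the shared VTCM copy is ready before the
    // workers start using it.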
    if (is_row_bcast) {
        dma_queue_pop(q);
    }

    worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_jobs);

    return HTP_STATUS_OK;
}

int op_binary(struct htp_ops_context * octx) {
    if (octx->src0.type == HTP_TYPE_F32) {
        return execute_op_binary_f32(octx);
    }
    return HTP_STATUS_NO_SUPPORT;
}