#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"

#include <HAP_farf.h>
#include <HAP_perf.h>

#include <math.h>
#include <string.h>

#include "hex-dma.h"
#include "hvx-utils.h"

#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-msg.h"
#include "htp-ops.h"

#ifndef MIN
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif

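// Binary element-wise ops (ADD/SUB/MUL/DIV and ADD_ID) for f32 tensors.
//
// Every kernel below follows the same pattern: each worker thread owns a
// slice of the dst rows and streams them through VTCM with a two-deep
// (double-buffered) DMA pipeline -- while block N is being computed,
// block N+1 is being fetched from DDR and earlier results are being
// written back.
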
// Context for binary operations
struct htp_binary_context {
    struct htp_ops_context * octx;

    struct fastdiv_values dim1_div;  // divide by ne01
    struct fastdiv_values dim2_div;  // divide by ne02
    struct fastdiv_values dim12_div; // divide by ne01 * ne02

    struct fastdiv_values src1_dim1_div; // ne11
    struct fastdiv_values src1_dim2_div; // ne12
    struct fastdiv_values src1_dim3_div; // ne13

    uint32_t nrows_per_thread;
    bool     split_at_ne01;
    bool     split_at_ne02;

    // Precomputed values
    uint32_t block_max;
    size_t   src0_row_size_aligned;
    size_t   src1_row_size_aligned;
    size_t   dst_row_size_aligned;
    uint32_t src1_fetch_rows; // 1 or block_max
    uint32_t src1_dma_stride; // 0 or stride
};

#define htp_binary_preamble                       \
    const struct htp_tensor * src0 = &octx->src0; \
    const struct htp_tensor * src1 = &octx->src1; \
    struct htp_tensor *       dst  = &octx->dst;  \
                                                  \
    const uint32_t ne00 = src0->ne[0];            \
    const uint32_t ne01 = src0->ne[1];            \
    const uint32_t ne02 = src0->ne[2];            \
    const uint32_t ne03 = src0->ne[3];            \
                                                  \
    const uint32_t ne10 = src1->ne[0];            \
    const uint32_t ne11 = src1->ne[1];            \
    const uint32_t ne12 = src1->ne[2];            \
    const uint32_t ne13 = src1->ne[3];            \
                                                  \
    const uint32_t nb01 = src0->nb[1];            \
    const uint32_t nb02 = src0->nb[2];            \
    const uint32_t nb03 = src0->nb[3];            \
                                                  \
    const uint32_t nb11 = src1->nb[1];            \
    const uint32_t nb12 = src1->nb[2];            \
    const uint32_t nb13 = src1->nb[3];            \
                                                  \
    const uint32_t nb1 = dst->nb[1];              \
    const uint32_t nb2 = dst->nb[2];              \
    const uint32_t nb3 = dst->nb[3];

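// Decompose the flat row index ir into (i01, i02, i03) and return how many
// consecutive rows can be processed as one DMA block: at most block_max, never
// past end_row, and never across an ne01/ne02 boundary when the corresponding
// split flag is set (see execute_op_binary_f32 for when the flags are needed).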
static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_t ir, uint32_t end_row,
                                       uint32_t ne01, uint32_t ne02) {
    uint32_t i03, i02, i01, rem;
    i03 = fastdiv(ir, &bctx->dim12_div);
    rem = ir - i03 * (ne02 * ne01);
    i02 = fastdiv(rem, &bctx->dim1_div);
    i01 = rem - i02 * ne01;

    uint32_t rows_left   = end_row - ir;
    uint32_t block_limit = rows_left;

    if (bctx->split_at_ne01) {
        block_limit = MIN(block_limit, ne01 - i01);
    }
    if (bctx->split_at_ne02) {
        uint32_t rows_in_plane = (ne02 * ne01) - rem;
        block_limit = MIN(block_limit, rows_in_plane);
    }

    return MIN(bctx->block_max, block_limit);
}

// Macro for scalar op switch. Note that DIV is implemented as multiplication
// by the reciprocal, which can differ from true IEEE division in the last ulp.
#define COMPUTE_SCALAR_OP(DST, SRC, VAL, N) \
    switch (octx->op) { \
        case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, VAL, N); break; \
        case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, VAL, N); break; \
        case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, VAL, N); break; \
        case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (VAL), N); break; \
        default: break; \
    }

// Macro for vector op switch (All Aligned)
#define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, N) \
    switch (octx->op) { \
        case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \
        case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \
        case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \
        case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \
        default: break; \
    }

// Macro for vector op switch (Dst Aligned, Src0 Aligned, Src1 Unaligned)
#define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, N) \
    switch (octx->op) { \
        case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \
        case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \
        case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \
        case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \
        default: break; \
    }

// Macro for vector op switch (All Unaligned - generic loop used in element repeat)
#define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, N) \
    switch (octx->op) { \
        case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \
        case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \
        case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \
        case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \
        default: break; \
    }


// 1. Scalar src1 (ne10 == 1)
static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
    struct htp_binary_context * bctx = (struct htp_binary_context *) data;
    struct htp_ops_context * octx = bctx->octx;
    htp_binary_preamble;

    const uint32_t total_rows = ne01 * ne02 * ne03;
    const uint32_t start_row = bctx->nrows_per_thread * ith;
    const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
    if (start_row >= end_row) return;

    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
    uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);
    size_t src0_spad_half    = octx->src0_spad.size_per_thread / 2;
    size_t dst_spad_half     = octx->dst_spad.size_per_thread  / 2;

    dma_queue * q = octx->ctx->dma[ith];
    uint32_t ir_prefetch = start_row;
    int spad_idx = 0;

    // Preamble: prime both halves of the double buffer with in-flight DMA
    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
        uint32_t i03, i02, i01, rem;
        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
        rem = ir_prefetch - i03 * (ne02 * ne01);
        i02 = fastdiv(rem, &bctx->dim1_div);
        i01 = rem - i02 * ne01;

        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
        uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;

        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
        uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;

        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
        ir_prefetch += current_block_size;
        spad_idx ^= 1;
    }

    // Main loop: pop completed transfers, compute, then queue the writeback and the next prefetch
    for (uint32_t ir = start_row; ir < end_row; ) {
        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);

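        // Pop in push order: first the dst writeback descriptor (its .src is
        // the VTCM half we may now reuse), then the src0 fetch (its .dst is
        // the VTCM half that now holds valid src0 rows).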
        uint8_t * d_spad  = (uint8_t *) dma_queue_pop(q).src;
        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;

        uint32_t i03, i02, i01, rem;
        i03 = fastdiv(ir, &bctx->dim12_div);
        rem = ir - i03 * (ne02 * ne01);
        i02 = fastdiv(rem, &bctx->dim1_div);
        i01 = rem - i02 * ne01;

        // src1 indices (broadcast/repeat)
        uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);
        uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);
        uint32_t i11 = fastmodulo(i01, ne11, &bctx->src1_dim1_div);

        uint8_t * src1_ptr = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
        uint32_t s1_stride = (ne11 == 1) ? 0 : nb11;

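        // NOTE: the pointer walk below assumes ne11 == 1 (s1_stride == 0) or
        // that a block never crosses the ne11 boundary; a wrap of i11 within
        // a block (1 < ne11 < ne01) is not handled here.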
        for (uint32_t r = 0; r < current_block_size; r++) {
            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
            uint8_t * r_dst  = d_spad + r * bctx->dst_row_size_aligned;
            float val = *(float *)src1_ptr;
            src1_ptr += s1_stride;
            COMPUTE_SCALAR_OP(r_dst, r_src0, val, ne00);
        }

        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);

        if (ir_prefetch < end_row) {
            uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
            uint32_t p03, p02, p01, prem;
            p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
            prem = ir_prefetch - p03 * (ne02 * ne01);
            p02 = fastdiv(prem, &bctx->dim1_div);
            p01 = prem - p02 * ne01;
            uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;

            dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
            ir_prefetch += next_block_size;
        }
        ir += current_block_size;
    }
    dma_queue_flush(q);
}

// 2. Vector Same Shape (ne1x == ne0x) or Simple Broadcast
static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, void * data) {
    struct htp_binary_context * bctx = (struct htp_binary_context *) data;
    struct htp_ops_context * octx = bctx->octx;
    htp_binary_preamble;

    const uint32_t total_rows = ne01 * ne02 * ne03;
    const uint32_t start_row = bctx->nrows_per_thread * ith;
    const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
    if (start_row >= end_row) return;

    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
    uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
    uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);

    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
    size_t src1_spad_half = octx->src1_spad.size_per_thread / 2;
    size_t dst_spad_half  = octx->dst_spad.size_per_thread  / 2;

    dma_queue * q = octx->ctx->dma[ith];
    uint32_t ir_prefetch = start_row;
    int spad_idx = 0;

    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
        uint32_t i03, i02, i01, rem;
        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
        rem = ir_prefetch - i03 * (ne02 * ne01);
        i02 = fastdiv(rem, &bctx->dim1_div);
        i01 = rem - i02 * ne01;

        uint32_t i13 = (ne13 == 1) ? 0 : i03;
        uint32_t i12 = (ne12 == 1) ? 0 : i02;
        uint32_t i11 = (ne11 == 1) ? 0 : i01;

        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
        uint8_t * src1_base = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
        uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;

        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
        uint8_t * s1_spad = src1_spad_base + spad_idx * src1_spad_half;
        uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;

        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
        dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), current_block_size);
        ir_prefetch += current_block_size;
        spad_idx ^= 1;
    }

    for (uint32_t ir = start_row; ir < end_row; ) {
        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
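        // Pops mirror the push order: dst writeback, src0 fetch, src1 fetch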
        uint8_t * d_spad  = (uint8_t *) dma_queue_pop(q).src;
        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
        uint8_t * s1_spad = (uint8_t *) dma_queue_pop(q).dst;

        for (uint32_t r = 0; r < current_block_size; r++) {
            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
            uint8_t * r_src1 = s1_spad + r * bctx->src1_row_size_aligned;
            uint8_t * r_dst  = d_spad  + r * bctx->dst_row_size_aligned;
            COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00);
        }

        uint32_t i03, i02, i01, rem;
        i03 = fastdiv(ir, &bctx->dim12_div);
        rem = ir - i03 * (ne02 * ne01);
        i02 = fastdiv(rem, &bctx->dim1_div);
        i01 = rem - i02 * ne01;
        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);

        if (ir_prefetch < end_row) {
            uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
            uint32_t p03, p02, p01, prem;
            p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
            prem = ir_prefetch - p03 * (ne02 * ne01);
            p02 = fastdiv(prem, &bctx->dim1_div);
            p01 = prem - p02 * ne01;

            uint32_t p13 = (ne13 == 1) ? 0 : p03;
            uint32_t p12 = (ne12 == 1) ? 0 : p02;
            uint32_t p11 = (ne11 == 1) ? 0 : p01;

            uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
            uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11;

            dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
            dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), next_block_size);

            ir_prefetch += next_block_size;
        }
        ir += current_block_size;
    }
    dma_queue_flush(q);
}

// 3. Row Broadcast (ne11 == ne12 == ne13 == 1: a single src1 row for the whole tensor)
static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, void * data) {
    struct htp_binary_context * bctx = (struct htp_binary_context *) data;
    struct htp_ops_context * octx = bctx->octx;
    htp_binary_preamble;

    const uint32_t total_rows = ne01 * ne02 * ne03;
    const uint32_t start_row = bctx->nrows_per_thread * ith;
    const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
    if (start_row >= end_row) return;

    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
    uint8_t * src1_spad = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
    uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);

    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
    size_t dst_spad_half  = octx->dst_spad.size_per_thread  / 2;

    dma_queue * q = octx->ctx->dma[ith];
    uint32_t ir_prefetch = start_row;
    int spad_idx = 0;

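    // The single src1 row was preloaded into VTCM by execute_op_binary_f32
    // before dispatch; src1_spad.size_per_thread is 0 here, so every thread
    // points at the same static scratchpad location.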
    void * s1_ptr = (void *) src1_spad;

    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
        uint32_t i03, i02, i01, rem;
        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
        rem = ir_prefetch - i03 * (ne02 * ne01);
        i02 = fastdiv(rem, &bctx->dim1_div);
        i01 = rem - i02 * ne01;

        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
        uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;

        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
        uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;

        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
        ir_prefetch += current_block_size;
        spad_idx ^= 1;
    }

    for (uint32_t ir = start_row; ir < end_row; ) {
        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;

        for (uint32_t r = 0; r < current_block_size; r++) {
            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
            uint8_t * r_src1 = (uint8_t *)s1_ptr; // Constant
            uint8_t * r_dst  = d_spad + r * bctx->dst_row_size_aligned;
            COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00);
        }

        uint32_t i03, i02, i01, rem;
        i03 = fastdiv(ir, &bctx->dim12_div);
        rem = ir - i03 * (ne02 * ne01);
        i02 = fastdiv(rem, &bctx->dim1_div);
        i01 = rem - i02 * ne01;
        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);

        if (ir_prefetch < end_row) {
            uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
            uint32_t p03, p02, p01, prem;
            p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
            prem = ir_prefetch - p03 * (ne02 * ne01);
            p02 = fastdiv(prem, &bctx->dim1_div);
            p01 = prem - p02 * ne01;
            uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
            dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
            ir_prefetch += next_block_size;
        }
        ir += current_block_size;
    }
    dma_queue_flush(q);
}

// 4. Vector Complex (ne10 == ne00, complex broadcast)
static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * data) {
    struct htp_binary_context * bctx = (struct htp_binary_context *) data;
    struct htp_ops_context * octx = bctx->octx;
    htp_binary_preamble;

    const uint32_t total_rows = ne01 * ne02 * ne03;
    const uint32_t start_row = bctx->nrows_per_thread * ith;
    const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
    if (start_row >= end_row) return;

    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
    uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);
    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
    size_t dst_spad_half  = octx->dst_spad.size_per_thread  / 2;

    dma_queue * q = octx->ctx->dma[ith];
    uint32_t ir_prefetch = start_row;
    int spad_idx = 0;

    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
        uint32_t i03, i02, i01, rem;
        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
        rem = ir_prefetch - i03 * (ne02 * ne01);
        i02 = fastdiv(rem, &bctx->dim1_div);
        i01 = rem - i02 * ne01;

        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
        uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;

        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
        uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;

        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
        ir_prefetch += current_block_size;
        spad_idx ^= 1;
    }

    for (uint32_t ir = start_row; ir < end_row; ) {
        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;

        uint32_t i03, i02, i01, rem;
        i03 = fastdiv(ir, &bctx->dim12_div);
        rem = ir - i03 * (ne02 * ne01);
        i02 = fastdiv(rem, &bctx->dim1_div);
        i01 = rem - i02 * ne01;

        // i13/i12 depend only on i03/i02, which are fixed for this block
        uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);
        uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);

        for (uint32_t r = 0; r < current_block_size; r++) {
            uint32_t r_i01 = i01 + r;
            uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div);

            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
            uint8_t * r_src1 = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
            uint8_t * r_dst  = d_spad + r * bctx->dst_row_size_aligned;

            // Read src1 from DDR (unaligned)
            COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, ne00);
        }

        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);

        if (ir_prefetch < end_row) {
            uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
            uint32_t p03, p02, p01, prem;
            p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
            prem = ir_prefetch - p03 * (ne02 * ne01);
            p02 = fastdiv(prem, &bctx->dim1_div);
            p01 = prem - p02 * ne01;
            uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
            dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
            ir_prefetch += next_block_size;
        }
        ir += current_block_size;
    }
    dma_queue_flush(q);
}

// 5. Element Repeat (ne10 != ne00)
static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * data) {
    struct htp_binary_context * bctx = (struct htp_binary_context *) data;
    struct htp_ops_context * octx = bctx->octx;
    htp_binary_preamble;

    const uint32_t total_rows = ne01 * ne02 * ne03;
    const uint32_t start_row = bctx->nrows_per_thread * ith;
    const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
    if (start_row >= end_row) return;

    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
    uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);
    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
    size_t dst_spad_half  = octx->dst_spad.size_per_thread  / 2;

    dma_queue * q = octx->ctx->dma[ith];
    uint32_t ir_prefetch = start_row;
    int spad_idx = 0;

    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
        uint32_t i03, i02, i01, rem;
        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
        rem = ir_prefetch - i03 * (ne02 * ne01);
        i02 = fastdiv(rem, &bctx->dim1_div);
        i01 = rem - i02 * ne01;

        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
        uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;

        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
        uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;

        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
        ir_prefetch += current_block_size;
        spad_idx ^= 1;
    }

    for (uint32_t ir = start_row; ir < end_row; ) {
        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;

        uint32_t i03, i02, i01, rem;
        i03 = fastdiv(ir, &bctx->dim12_div);
        rem = ir - i03 * (ne02 * ne01);
        i02 = fastdiv(rem, &bctx->dim1_div);
        i01 = rem - i02 * ne01;

        for (uint32_t r = 0; r < current_block_size; r++) {
            uint32_t r_i01 = i01 + r;
            uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);
            uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);
            uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div);

            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
            uint8_t * r_src1_row = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
            uint8_t * r_dst  = d_spad + r * bctx->dst_row_size_aligned;

            // Tile the src1 row across the (wider) src0/dst row
            for (uint32_t c = 0; c < ne00; c += ne10) {
                uint32_t len = MIN(ne10, ne00 - c);
                // The c offsets are not necessarily vector-aligned, so use the fully unaligned variant
                COMPUTE_VECTOR_OP_UUU(r_dst + c * sizeof(float), r_src0 + c * sizeof(float), r_src1_row, len);
            }
        }

        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);

        if (ir_prefetch < end_row) {
            uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
            uint32_t p03, p02, p01, prem;
            p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
            prem = ir_prefetch - p03 * (ne02 * ne01);
            p02 = fastdiv(prem, &bctx->dim1_div);
            p01 = prem - p02 * ne01;
            uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
            dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
            ir_prefetch += next_block_size;
        }
        ir += current_block_size;
    }
    dma_queue_flush(q);
}

// 6. ADD_ID (src1 gathered via src2 indices)
static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
    struct htp_binary_context * bctx = (struct htp_binary_context *) data;
    struct htp_ops_context * octx = bctx->octx;

    const struct htp_tensor * src0 = &octx->src0;
    const struct htp_tensor * src1 = &octx->src1;
    const struct htp_tensor * src2 = &octx->src2;
    struct htp_tensor *       dst  = &octx->dst;

    const uint32_t ne00 = src0->ne[0];
    const uint32_t ne01 = src0->ne[1];
    const uint32_t ne02 = src0->ne[2];
    const uint32_t ne03 = src0->ne[3];
    const uint32_t ne11 = src1->ne[1]; // src1 row count (currently unused; indices from src2 are not bounds-checked here)

    const uint32_t nb01 = src0->nb[1];
    const uint32_t nb02 = src0->nb[2];
    const uint32_t nb03 = src0->nb[3];
    const uint32_t nb11 = src1->nb[1]; // src1 row stride
    const uint32_t nb1 = dst->nb[1];
    const uint32_t nb2 = dst->nb[2];
    const uint32_t nb3 = dst->nb[3];

    const uint32_t total_rows = ne01 * ne02 * ne03;
    const uint32_t start_row = bctx->nrows_per_thread * ith;
    const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
    if (start_row >= end_row) return;

    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
    uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);
    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
    size_t dst_spad_half  = octx->dst_spad.size_per_thread  / 2;

    dma_queue * q = octx->ctx->dma[ith];
    uint32_t ir_prefetch = start_row;
    int spad_idx = 0;

    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
        uint32_t i03, i02, i01, rem;
        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
        rem = ir_prefetch - i03 * (ne02 * ne01);
        i02 = fastdiv(rem, &bctx->dim1_div);
        i01 = rem - i02 * ne01;

        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
        uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;

        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
        uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;

        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
        ir_prefetch += current_block_size;
        spad_idx ^= 1;
    }

    for (uint32_t ir = start_row; ir < end_row; ) {
        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;

        uint32_t i03, i02, i01, rem;
        i03 = fastdiv(ir, &bctx->dim12_div);
        rem = ir - i03 * (ne02 * ne01);
        i02 = fastdiv(rem, &bctx->dim1_div);
        i01 = rem - i02 * ne01;

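        // Gather: src2 holds one int32 per (row, plane) pair selecting which
        // src1 row is added to the corresponding src0 row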
        for (uint32_t r = 0; r < current_block_size; r++) {
            uint32_t r_i01 = i01 + r; // linear within block since we split at ne01

            const int32_t idx = *(int32_t *)((char *)src2->data + r_i01 * src2->nb[0] + i02 * src2->nb[1]);

            uint8_t * r_src1 = (uint8_t *)src1->data + idx * nb11;
            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
            uint8_t * r_dst  = d_spad + r * bctx->dst_row_size_aligned;

            hvx_add_f32_aau(r_dst, r_src0, r_src1, ne00);
        }

        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);

        if (ir_prefetch < end_row) {
            uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
            uint32_t p03, p02, p01, prem;
            p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
            prem = ir_prefetch - p03 * (ne02 * ne01);
            p02 = fastdiv(prem, &bctx->dim1_div);
            p01 = prem - p02 * ne01;
            uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
            dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
            ir_prefetch += next_block_size;
        }
        ir += current_block_size;
    }
    dma_queue_flush(q);
}

static int execute_op_binary_f32(struct htp_ops_context * octx) {
    const struct htp_tensor * src0 = &octx->src0;
    const struct htp_tensor * src1 = &octx->src1;
    struct htp_tensor *       dst  = &octx->dst;

    const uint32_t n_threads  = octx->n_threads;
    const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];

    // Use packed row sizes for VTCM allocation
    const size_t src0_row_size = src0->ne[0] * sizeof(float);
    const size_t src1_row_size = src1->ne[0] * sizeof(float);
    const size_t dst_row_size  = dst->ne[0] * sizeof(float);

    // Align to VLEN
    const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
    const size_t dst_row_size_aligned  = hex_round_up(dst_row_size, VLEN);
    size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);

    bool is_add_id = (octx->op == HTP_OP_ADD_ID);
    bool is_scalar = !is_add_id && (src1->ne[0] == 1);

    // Decide which kernel will run, so VTCM can be sized and the right job dispatched
    bool use_vector_same = !is_add_id && !is_scalar && src1->ne[0] == src0->ne[0] &&
               (src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) &&
               (src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) &&
               (src1->ne[3] == src0->ne[3] || src1->ne[3] == 1);

    bool is_row_bcast = use_vector_same && (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1);
    bool use_complex = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] == src0->ne[0]);
    bool use_repeat  = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] != src0->ne[0]);
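    // Dispatch summary (most specific first):
    //   is_add_id       -> binary_job_add_id               (src1 rows gathered via src2)
    //   is_scalar       -> binary_job_scalar               (ne10 == 1: one src1 value per row)
    //   is_row_bcast    -> binary_job_vector_row_broadcast (one src1 row for the whole tensor)
    //   use_vector_same -> binary_job_vector_same_shape    (full rows; per-dim repeats of 1)
    //   use_complex     -> binary_job_vector_complex       (full rows; arbitrary repeat counts)
    //   use_repeat      -> binary_job_element_repeat       (ne10 != ne00: tile src1 within the row)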

    // Bytes per block row, doubled for double buffering: src0 and dst always
    // go through VTCM; src1 is staged only by the same-shape vector kernel,
    // the other kernels read it directly from DDR or use a single preloaded row.
    size_t spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
    if (use_vector_same && !is_row_bcast) {
        spad_row_total += 2 * src1_row_size_aligned;
    }

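    // Block size in rows: whatever fits once the VTCM is divided across threads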
    size_t rows_per_buffer = octx->ctx->vtcm_size / (n_threads * spad_row_total);
    // Adjust for static src1 in row_bcast case
    if (is_row_bcast) {
        size_t needed_static = src1_row_size_aligned;
        if (octx->ctx->vtcm_size < needed_static) return HTP_STATUS_VTCM_TOO_SMALL;
        size_t avail = octx->ctx->vtcm_size - needed_static;
        rows_per_buffer = avail / (n_threads * spad_row_total);
    }

    if (rows_per_buffer < 1) {
        FARF(ERROR, "binary-f32: VTCM too small\n");
        return HTP_STATUS_VTCM_TOO_SMALL;
    }

    octx->src0_spad.size_per_thread = rows_per_buffer * 2 * src0_row_size_aligned;
    octx->dst_spad.size_per_thread  = rows_per_buffer * 2 * dst_row_size_aligned;

    // Only the same-shape vector kernel double-buffers src1 per thread
    octx->src1_spad.size_per_thread = (use_vector_same && !is_row_bcast)
        ? rows_per_buffer * 2 * src1_row_size_aligned : 0;

    octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
    if (is_row_bcast) {
        octx->src1_spad.size = src1_row_size_aligned;
    } else {
        octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread;
    }
    octx->dst_spad.size  = n_threads * octx->dst_spad.size_per_thread;

    if (octx->ctx->vtcm_size < (octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size)) {
        return HTP_STATUS_VTCM_TOO_SMALL;
    }

    octx->src0_spad.data = octx->ctx->vtcm_base;
    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
    octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;

    if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
        return HTP_STATUS_OK;
    }

    uint32_t n_jobs = MIN(n_threads, src0_nrows);

    dma_queue * q = octx->ctx->dma[0];
    if (is_row_bcast) {
        // Preload the single src1 row into VTCM; the pop below, just before dispatch, waits for it
        dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * sizeof(float), 1);
    }

    struct htp_binary_context bctx = {0}; // zero-init: some fields are only set for specific kernels
    bctx.octx = octx;
    bctx.nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
    bctx.block_max = rows_per_buffer;
    bctx.src0_row_size_aligned = src0_row_size_aligned;
    bctx.src1_row_size_aligned = src1_row_size_aligned;
    bctx.dst_row_size_aligned  = dst_row_size_aligned;

    bctx.dim1_div  = init_fastdiv_values(src0->ne[1]);
    bctx.dim2_div  = init_fastdiv_values(src0->ne[2]);
    bctx.dim12_div = init_fastdiv_values(src0->ne[1] * src0->ne[2]);

    bctx.src1_dim1_div = init_fastdiv_values(src1->ne[1]);
    bctx.src1_dim2_div = init_fastdiv_values(src1->ne[2]);
    bctx.src1_dim3_div = init_fastdiv_values(src1->ne[3]);

    bool src0_contig_dim1 = (src0->nb[2] == src0->ne[1] * src0->nb[1]);
    bool dst_contig_dim1  = (dst->nb[2] == src0->ne[1] * dst->nb[1]);

    bool src0_contig_dim2 = (src0->nb[3] == src0->ne[2] * src0->nb[2]);
    bool dst_contig_dim2  = (dst->nb[3] == src0->ne[2] * dst->nb[2]);

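    // Split blocks at a dimension boundary whenever crossing it would change
    // the src1 broadcast row mid-block or break the single strided run the
    // block DMA transfers assume.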
    bctx.split_at_ne01 = (src0->ne[2] > 1) &&
                         ((src1->ne[1] > 1) || (src1->ne[2] > 1) || !src0_contig_dim1 || !dst_contig_dim1);

    bctx.split_at_ne02 = (src0->ne[3] > 1) &&
                         ((src1->ne[2] > 1) || (src1->ne[3] > 1) || !src0_contig_dim2 || !dst_contig_dim2);

    // Precompute kernel-specific parameters
    if (use_vector_same) {
        bctx.src1_dma_stride = (src1->ne[1] == 1) ? 0 : src1->nb[1];     // stride 0 replicates the single src1 row
        bctx.src1_fetch_rows = (src1->ne[1] == 1) ? 1 : rows_per_buffer; // currently unused by the kernels
    }

    worker_callback_t worker_func;
    if (is_add_id) worker_func = binary_job_add_id;
    else if (is_scalar) worker_func = binary_job_scalar;
    else if (is_row_bcast) worker_func = binary_job_vector_row_broadcast;
    else if (use_vector_same) worker_func = binary_job_vector_same_shape;
    else if (use_complex) worker_func = binary_job_vector_complex;
    else worker_func = binary_job_element_repeat;

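    // Wait for the src1 preload pushed above to complete before workers start
    // reading it out of VTCM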
    if (is_row_bcast) {
        dma_queue_pop(q);
    }

    worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_jobs);

    return HTP_STATUS_OK;
}

int op_binary(struct htp_ops_context * octx) {
    if (octx->src0.type == HTP_TYPE_F32) {
        return execute_op_binary_f32(octx);
    }
    return HTP_STATUS_NO_SUPPORT;
}