1#pragma clang diagnostic ignored "-Wunused-variable"
  2#pragma clang diagnostic ignored "-Wunused-function"
  3#pragma clang diagnostic ignored "-Wunused-but-set-variable"
  4
  5#include <HAP_farf.h>
  6#include <HAP_perf.h>
  7
  8#include <math.h>
  9#include <string.h>
 10
 11#define GGML_COMMON_DECL_C
 12#include "ggml-common.h"
 13#include "htp-ctx.h"
 14#include "htp-msg.h"
 15#include "htp-ops.h"
 16#include "hvx-utils.h"
 17
// Declares the locals shared by the get_rows functions below.
// Requires a `struct htp_ops_context * octx` to be in scope at the point of expansion.
//   ne00..ne03 / nb01..nb03 - src0 dimensions and byte strides
//   ne10..ne12 / nb10..nb12 - src1 (index tensor) dimensions and byte strides
//   nb1..nb3                - dst byte strides
//   nr                      - number of src1 elements (ne10 * ne11 * ne12); one dst row per element
// Not every user reads every local; the -Wunused-variable pragma at the top of the file covers that.
#define get_rows_preamble \
    const uint32_t ne00 = octx->src0.ne[0]; \
    const uint32_t ne01 = octx->src0.ne[1]; \
    const uint32_t ne02 = octx->src0.ne[2]; \
    const uint32_t ne03 = octx->src0.ne[3]; \
                                            \
    const uint32_t ne10 = octx->src1.ne[0]; \
    const uint32_t ne11 = octx->src1.ne[1]; \
    const uint32_t ne12 = octx->src1.ne[2]; \
                                            \
    const uint32_t nb01 = octx->src0.nb[1]; \
    const uint32_t nb02 = octx->src0.nb[2]; \
    const uint32_t nb03 = octx->src0.nb[3]; \
                                            \
    const uint32_t nb10 = octx->src1.nb[0]; \
    const uint32_t nb11 = octx->src1.nb[1]; \
    const uint32_t nb12 = octx->src1.nb[2]; \
                                            \
    const uint32_t nb1 = octx->dst.nb[1];   \
    const uint32_t nb2 = octx->dst.nb[2];   \
    const uint32_t nb3 = octx->dst.nb[3];   \
                                            \
    const uint32_t nr = ne10 * ne11 * ne12;
 41
 42static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
 43    get_rows_preamble;
 44
 45    // parallelize by src1 elements (which correspond to dst rows)
 46    const uint32_t dr  = octx->src1_nrows_per_thread;
 47    const uint32_t ir0 = dr * ith;
 48    const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
 49
 50    const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
 51
 52    for (uint32_t i = ir0; i < ir1; ++i) {
 53        const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11);
 54        const uint32_t rem = i - i12 * ne11 * ne10;
 55        const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10);
 56        const uint32_t i10 = rem - i11 * ne10;
 57
 58        const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
 59
 60        uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
 61
 62        if (i01 >= ne01) {
 63            // invalid index, skip for now to avoid crash
 64            continue;
 65        }
 66
 67        const uintptr_t src0_ptr = octx->src0.data + i01*nb01 + i11*nb02 + i12*nb03;
 68        const uintptr_t dst_ptr  = octx->dst.data  + i10*nb1  + i11*nb2  + i12*nb3;
 69        hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
 70    }
 71
 72    return HTP_STATUS_OK;
 73}
 74
 75static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
 76    get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
 77}
 78
 79int op_get_rows(struct htp_ops_context * octx) {
 80    get_rows_preamble;
 81
 82    if (octx->src0.type != HTP_TYPE_F32) {
 83        return HTP_STATUS_NO_SUPPORT;
 84    }
 85
 86    if (octx->dst.type != HTP_TYPE_F32) {
 87        return HTP_STATUS_NO_SUPPORT;
 88    }
 89
 90    if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {
 91        return HTP_STATUS_NO_SUPPORT;
 92    }
 93
 94    if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
 95        return HTP_STATUS_OK;
 96    }
 97
 98    octx->get_rows_div_ne10      = init_fastdiv_values(octx->src1.ne[0]);
 99    octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
100
101    const uint32_t n_jobs = MIN(nr, octx->n_threads);
102    octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
103
104    worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs);
105    return HTP_STATUS_OK;
106}