1#pragma clang diagnostic ignored "-Wunused-variable"
  2#pragma clang diagnostic ignored "-Wunused-function"
  3#pragma clang diagnostic ignored "-Wunused-but-set-variable"
  4
  5#include <HAP_farf.h>
  6#include <HAP_perf.h>
  7
  8#include <math.h>
  9#include <string.h>
 10
 11#include "hex-dma.h"
 12#include "hvx-utils.h"
 13
 14#define GGML_COMMON_DECL_C
 15#include "ggml-common.h"
 16#include "htp-ctx.h"
 17#include "htp-msg.h"
 18#include "htp-ops.h"
 19
 20#define set_rows_preamble \
 21    const uint32_t ne00 = octx->src0.ne[0]; \
 22    const uint32_t ne01 = octx->src0.ne[1]; \
 23    const uint32_t ne02 = octx->src0.ne[2]; \
 24    const uint32_t ne03 = octx->src0.ne[3]; \
 25                                            \
 26    const uint32_t ne10 = octx->src1.ne[0]; \
 27    const uint32_t ne11 = octx->src1.ne[1]; \
 28    const uint32_t ne12 = octx->src1.ne[2]; \
 29                                            \
 30    const uint32_t nb01 = octx->src0.nb[1]; \
 31    const uint32_t nb02 = octx->src0.nb[2]; \
 32    const uint32_t nb03 = octx->src0.nb[3]; \
 33                                            \
 34    const uint32_t nb10 = octx->src1.nb[0]; \
 35    const uint32_t nb11 = octx->src1.nb[1]; \
 36    const uint32_t nb12 = octx->src1.nb[2]; \
 37                                            \
 38    const uint32_t nb1 = octx->dst.nb[1];   \
 39    const uint32_t nb2 = octx->dst.nb[2];   \
 40    const uint32_t nb3 = octx->dst.nb[3];   \
 41                                            \
 42    const uint32_t ne1 = octx->dst.ne[1];   \
 43                                            \
 44    const uint32_t nr  = ne01;
 45
 46static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
 47    set_rows_preamble;
 48
 49    // parallelize by rows of src0
 50    const uint32_t dr  = octx->src0_nrows_per_thread;
 51    const uint32_t ir0 = dr * ith;
 52    const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
 53
 54    const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
 55
 56    for (uint32_t i03 = 0; i03 < ne03; ++i03) {
 57        for (uint32_t i02 = 0; i02 < ne02; ++i02) {
 58            for (uint32_t i = ir0; i < ir1; ++i) {
 59                const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
 60                const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
 61                const uint32_t i10 = i;
 62
 63                const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
 64
 65                uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
 66                if (i1 >= ne1) {
 67                    // ignore invalid indices
 68                    continue;
 69                }
 70
 71                const uintptr_t src0_ptr = octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
 72                const uintptr_t dst_ptr  = octx->dst.data  + i1*nb1 + i02*nb2  + i03*nb3;
 73
 74                // copy row
 75                hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
 76            }
 77        }
 78    }
 79
 80    return HTP_STATUS_OK;
 81}
 82
 83static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) {
 84    set_rows_preamble;
 85
 86    // parallelize by rows of src0
 87    const uint32_t dr  = octx->src0_nrows_per_thread;
 88    const uint32_t ir0 = dr * ith;
 89    const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
 90
 91    const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
 92
 93    for (uint32_t i03 = 0; i03 < ne03; ++i03) {
 94        for (uint32_t i02 = 0; i02 < ne02; ++i02) {
 95            for (uint32_t i = ir0; i < ir1; ++i) {
 96                const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
 97                const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
 98                const uint32_t i10 = i;
 99
100                const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
101
102                uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
103                if (i1 >= ne1) {
104                    // ignore invalid indices
105                    continue;
106                }
107
108                const uint8_t* src0_ptr = (const uint8_t *) octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
109                uint8_t*       dst_ptr  = (uint8_t *)       octx->dst.data  + i1*nb1 + i02*nb2  + i03*nb3;
110
111                hvx_copy_f16_f32_uu(dst_ptr, src0_ptr, ne00);
112            }
113        }
114    }
115
116    return HTP_STATUS_OK;
117}
118
// Worker-pool callback: runs job `i` of `n` of the f16-destination set_rows
// kernel on the ops context passed through `data`. The kernel's status code is
// discarded.
static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) {
    struct htp_ops_context * const ctx = (struct htp_ops_context *) data;
    (void) set_rows_thread_f16_f32(ctx, (int) n, (int) i);
}
122
// Worker-pool callback: runs job `i` of `n` of the f32-destination set_rows
// kernel on the ops context passed through `data`. The kernel's status code is
// discarded.
static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
    struct htp_ops_context * const ctx = (struct htp_ops_context *) data;
    (void) set_rows_thread_f32_f32(ctx, (int) n, (int) i);
}
126
127int op_set_rows(struct htp_ops_context * octx) {
128    set_rows_preamble;
129
130    if (octx->src0.type != HTP_TYPE_F32) {
131        return HTP_STATUS_NO_SUPPORT;
132    }
133
134    if (octx->dst.type != HTP_TYPE_F32 && octx->dst.type != HTP_TYPE_F16) {
135        return HTP_STATUS_NO_SUPPORT;
136    }
137
138    if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {
139        return HTP_STATUS_NO_SUPPORT;
140    }
141
142    if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
143        return HTP_STATUS_OK;
144    }
145
146    octx->set_rows_div_ne12 = init_fastdiv_values(ne12);
147    octx->set_rows_div_ne11 = init_fastdiv_values(ne11);
148
149    const uint32_t n_jobs = MIN(nr, octx->n_threads);
150    octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
151
152    switch(octx->dst.type) {
153    case HTP_TYPE_F32:
154        worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs);
155        break;
156    case HTP_TYPE_F16:
157        worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs);
158        break;
159    default:
160        return HTP_STATUS_NO_SUPPORT;
161    }
162
163    return HTP_STATUS_OK;
164}