1#pragma clang diagnostic ignored "-Wunused-variable"
  2#pragma clang diagnostic ignored "-Wunused-function"
  3#pragma clang diagnostic ignored "-Wunused-but-set-variable"
  4
  5#include <HAP_farf.h>
  6#include <HAP_perf.h>
  7
  8#include <math.h>
  9#include <string.h>
 10
 11#define GGML_COMMON_DECL_C
 12#include "ggml-common.h"
 13#include "htp-ctx.h"
 14#include "htp-msg.h"
 15#include "htp-ops.h"
 16#include "hvx-utils.h"
 17
 struct htp_copy_context {
    // Operation being executed; the copy kernels reach src0/dst through it.
    struct htp_ops_context * octx;

    // src0 element size in bytes, and elements per block (blocks_per_row = ne00 / block_size).
    uint32_t src0_type_size;
    uint32_t src0_block_size;

    // dst element size in bytes, and elements per block (blocks_per_row = ne0 / block_size).
    uint32_t dst_type_size;
    uint32_t dst_block_size;

    // Number of blocks along dimension 0 of src0 and of dst respectively.
    uint32_t src0_blocks_per_row;
    uint32_t dst_blocks_per_row;

    // How many src0 rows (dimension-1 indices) each worker job is assigned.
    uint32_t src0_nrows_per_thread;

    // Kernel picked by op_cpy; invoked once per job with the job count and the job index.
    void (*copy)(struct htp_copy_context * cctx, struct htp_ops_context * ops, int n_jobs, int job_idx);
};
 34
 // cpy_preamble: expands to local declarations for a function that has
// `struct htp_ops_context * octx` in scope:
//   src0, dst              - pointers to octx->src0 and octx->dst
//   ne00..ne03, nb00..nb03 - src0 element counts and byte strides per dimension
//   ne0..ne3,   nb0..nb3   - dst element counts and byte strides per dimension
//   nr                     - number of src0 rows (ne01), the unit work is split by
// The expansion already ends in ';', so `cpy_preamble;` adds a harmless empty statement.
#define cpy_preamble                                                    \
    struct htp_tensor *src0 = &octx->src0;                              \
    struct htp_tensor *dst  = &octx->dst;                               \
                                                                        \
    const uint32_t ne00 = src0->ne[0], ne01 = src0->ne[1],              \
                   ne02 = src0->ne[2], ne03 = src0->ne[3];              \
    const uint32_t nb00 = src0->nb[0], nb01 = src0->nb[1],              \
                   nb02 = src0->nb[2], nb03 = src0->nb[3];              \
                                                                        \
    const uint32_t ne0 = dst->ne[0], ne1 = dst->ne[1],                  \
                   ne2 = dst->ne[2], ne3 = dst->ne[3];                  \
    const uint32_t nb0 = dst->nb[0], nb1 = dst->nb[1],                  \
                   nb2 = dst->nb[2], nb3 = dst->nb[3];                  \
                                                                        \
    const uint32_t nr = ne01;
 60
 61static void cpy_thread_sametype_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) {
 62    cpy_preamble;
 63
 64    // parallelize by src0 rows
 65    const uint32_t dr  = ct->src0_nrows_per_thread;
 66    const uint32_t ir0 = dr * ith;
 67    const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
 68
 69    // copy by rows
 70    for (uint32_t i03 = 0; i03 < ne03; i03++) {
 71        for (uint32_t i02 = 0; i02 < ne02; i02++) {
 72            #pragma unroll(2)
 73            for (uint32_t i01 = ir0; i01 < ir1; i01++) {
 74                uint8_t* dst_ptr  = (uint8_t*) dst->data  + i01*nb1  + i02*nb2  + i03*nb3;
 75                uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
 76                hex_l2fetch(src0_ptr, ne00 * ct->src0_type_size, nb01, 2);
 77                hvx_copy_uu(dst_ptr, src0_ptr, ne00, ct->src0_type_size);
 78            }
 79        }
 80    }
 81}
 82
 83static void cpy_thread_sametype_reshape(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith) {
 84    cpy_preamble;
 85
 86    // parallelize by src0 rows
 87    const uint32_t dr  = ct->src0_nrows_per_thread;
 88    const uint32_t ir0 = dr * ith;
 89    const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
 90
 91    // dst counters
 92    int64_t k10 = 0;
 93    int64_t i11 = 0;
 94    int64_t i12 = 0;
 95    int64_t i13 = 0;
 96
 97    // number of blocks in a row
 98    const int64_t nk00 = ct->src0_blocks_per_row;
 99    const int64_t nk0  = ct->dst_blocks_per_row;
100
101    for (int64_t i03 = 0; i03 < ne03; i03++) {
102        for (int64_t i02 = 0; i02 < ne02; i02++) {
103            k10 += nk00 * ir0;
104            while (k10 >= nk0) {
105                k10 -= nk0;
106                if (++i11 == ne1) {
107                    i11 = 0;
108                    if (++i12 == ne2) {
109                        i12 = 0;
110                        if (++i13 == ne3) {
111                            i13 = 0;
112                        }
113                    }
114                }
115            }
116            for (int64_t i01 = ir0; i01 < ir1; i01++) {
117                for (int64_t k00 = 0; k00 < nk00; k00++) {
118                    const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
119                          char * dst_ptr  = ((char *)  dst->data + k10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
120                    memcpy(dst_ptr, src0_ptr, ct->dst_type_size);
121
122                    if (++k10 == nk0) {
123                        k10 = 0;
124                        if (++i11 == ne1) {
125                            i11 = 0;
126                            if (++i12 == ne2) {
127                                i12 = 0;
128                                if (++i13 == ne3) {
129                                    i13 = 0;
130                                }
131                            }
132                        }
133                    }
134                }
135            }
136            k10 += nk00 * (ne01 - ir1);
137            while (k10 >= nk0) {
138                k10 -= nk0;
139                if (++i11 == ne1) {
140                    i11 = 0;
141                    if (++i12 == ne2) {
142                        i12 = 0;
143                        if (++i13 == ne3) {
144                            i13 = 0;
145                        }
146                    }
147                }
148            }
149        }
150    }
151}
152
153static void cpy_thread_f16_f32_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) {
154    cpy_preamble;
155
156    // parallelize by src0 rows
157    const uint32_t dr  = ct->src0_nrows_per_thread;
158    const uint32_t ir0 = dr * ith;
159    const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
160
161    // copy by rows
162    for (uint32_t i03 = 0; i03 < ne03; i03++) {
163        for (uint32_t i02 = 0; i02 < ne02; i02++) {
164            #pragma unroll(2)
165            for (uint32_t i01 = ir0; i01 < ir1; i01++) {
166                uint8_t* dst_ptr  = (uint8_t*) dst->data  + i01*nb1  + i02*nb2  + i03*nb3;
167                uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
168                hex_l2fetch(src0_ptr, ne00 * sizeof(float), nb01, 2);
169                hvx_copy_f16_f32_uu(dst_ptr, src0_ptr, ne00);
170            }
171        }
172    }
173}
174
175static void cpy_thread_f32_f16_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) {
176    cpy_preamble;
177
178    // parallelize by src0 rows
179    const uint32_t dr  = ct->src0_nrows_per_thread;
180    const uint32_t ir0 = dr * ith;
181    const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
182
183    // copy by rows
184    for (uint32_t i03 = 0; i03 < ne03; i03++) {
185        for (uint32_t i02 = 0; i02 < ne02; i02++) {
186            #pragma unroll(2)
187            for (uint32_t i01 = ir0; i01 < ir1; i01++) {
188                uint8_t* dst_ptr  = (uint8_t*) dst->data  + i01*nb1  + i02*nb2  + i03*nb3;
189                uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
190                hex_l2fetch(src0_ptr, ne00 * sizeof(__fp16), nb01, 2);
191                hvx_copy_f32_f16_uu(dst_ptr, src0_ptr, ne00);
192            }
193        }
194    }
195}
196
197static void cpy_work_func(unsigned int n, unsigned int i, void *data) {
198    struct htp_copy_context *ct = (struct htp_copy_context *) data;
199    ct->copy(ct, ct->octx, n, i);
200}
201
202int op_cpy(struct htp_ops_context * octx) {
203    cpy_preamble;
204
205    struct htp_copy_context ct;
206    ct.octx = octx;
207
208    switch (src0->type) {
209    case HTP_TYPE_F32: ct.src0_type_size = 4; ct.src0_block_size = 1; ct.src0_blocks_per_row = ne00 / 1; break;
210    case HTP_TYPE_F16: ct.src0_type_size = 2; ct.src0_block_size = 1; ct.src0_blocks_per_row = ne00 / 1; break;
211    default:
212        return HTP_STATUS_NO_SUPPORT;
213    }
214
215    switch (dst->type) {
216    case HTP_TYPE_F32: ct.dst_type_size = 4; ct.dst_block_size = 1; ct.dst_blocks_per_row = ne0 / 1; break;
217    case HTP_TYPE_F16: ct.dst_type_size = 2; ct.dst_block_size = 1; ct.dst_blocks_per_row = ne0 / 1; break;
218    default:
219        return HTP_STATUS_NO_SUPPORT;
220    }
221
222    if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
223        return HTP_STATUS_OK;
224    }
225
226    const bool sametype   = (src0->type == dst->type);
227    const bool transposed = (nb00 > nb01) || (nb0 > nb1);
228    const bool sameshape  = !transposed && (ne00 == ne0 && ne01 == ne1 && ne02 == ne2 && ne03 == ne3);
229
230    const uint32_t n_jobs = MIN(nr, octx->n_threads);
231    ct.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
232
233    if (sametype && sameshape) {
234        ct.copy = cpy_thread_sametype_sameshape;
235    } else if (sameshape) {
236        /**/ if (dst->type == HTP_TYPE_F16 && src0->type == HTP_TYPE_F32)
237            ct.copy = cpy_thread_f16_f32_sameshape;
238        else if (dst->type == HTP_TYPE_F32 && src0->type == HTP_TYPE_F16)
239            ct.copy = cpy_thread_f32_f16_sameshape;
240        else
241            return HTP_STATUS_NO_SUPPORT;
242    } else if (sametype) {
243        ct.copy = cpy_thread_sametype_reshape;
244    } else {
245        return HTP_STATUS_NO_SUPPORT;
246    }
247
248    worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_jobs);
249
250    return HTP_STATUS_OK;
251}