1#pragma clang diagnostic ignored "-Wunused-variable"
  2#pragma clang diagnostic ignored "-Wunused-function"
  3#pragma clang diagnostic ignored "-Wunused-but-set-variable"
  4
  5#include <HAP_farf.h>
  6#include <HAP_perf.h>
  7
  8#include <math.h>
  9#include <string.h>
 10
 11#include "hex-dma.h"
 12#include "hvx-utils.h"
 13
 14#define GGML_COMMON_DECL_C
 15#include "ggml-common.h"
 16#include "htp-ctx.h"
 17#include "htp-msg.h"
 18#include "htp-ops.h"
 19
// Expands the per-dimension element counts (ne*) and byte strides (nb*) of the
// unary-op source (`src`) and destination (`dst`) tensors into local constants.
// Expects `src` and `dst` to be struct htp_tensor pointers in the enclosing scope.
// Kept as a macro (not a function) so each kernel gets plain const locals the
// compiler can fold; do not add //-comments inside the continuation lines.
#define htp_unary_preamble            \
    const uint32_t ne00 = src->ne[0]; \
    const uint32_t ne01 = src->ne[1]; \
    const uint32_t ne02 = src->ne[2]; \
    const uint32_t ne03 = src->ne[3]; \
                                      \
    const uint32_t ne0 = dst->ne[0];  \
    const uint32_t ne1 = dst->ne[1];  \
    const uint32_t ne2 = dst->ne[2];  \
    const uint32_t ne3 = dst->ne[3];  \
                                      \
    const uint32_t nb00 = src->nb[0]; \
    const uint32_t nb01 = src->nb[1]; \
    const uint32_t nb02 = src->nb[2]; \
    const uint32_t nb03 = src->nb[3]; \
                                      \
    const uint32_t nb0 = dst->nb[0];  \
    const uint32_t nb1 = dst->nb[1];  \
    const uint32_t nb2 = dst->nb[2];  \
    const uint32_t nb3 = dst->nb[3];
 40
// Fully vectorized RMS normalization of one f32 row:
//   dst[i] = src[i] / sqrt(mean(src^2) + epsilon)
// Assumes `src` and `dst` are VLEN-aligned and `num_elems` is a multiple of 32
// (one 128-byte HVX vector holds 32 f32 lanes); any tail elements beyond the
// last full vector are NOT processed — the caller gates this via opt_path.
// `pad` is currently unused — NOTE(review): presumably reserved scratch space;
// confirm against callers before removing it from the signature.
static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
                                  uint8_t * restrict dst,
                                  uint8_t * restrict pad,
                                  const int num_elems,
                                  float     epsilon) {
    const HVX_Vector * restrict v_src = (HVX_Vector *) src;
    HVX_Vector * restrict v_dst       = (HVX_Vector *) dst;

    // Accumulate the sum of squares in qf32 format (zero splat is valid qf32 zero).
    HVX_Vector sum_v     = Q6_V_vsplat_R(0x00000000);
    HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);

    int step_of_1 = num_elems >> 5;  // full 32-lane vectors per row
    #pragma unroll(4)
    for (int i = 0; i < step_of_1; i++) {
        HVX_Vector v1 = v_src[i];
        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);  // x*x in qf32
        sum_v         = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
    }

    // Horizontal reduce to a scalar sum, then broadcast it across all lanes.
    HVX_Vector reduced_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
    sum_v                  = hvx_vec_repl4(reduced_sum);

    // mean = sum * (1/n); then add epsilon before the reciprocal sqrt.
    HVX_Vector t_v            = hvx_vec_splat_f32((float) num_elems);
    HVX_Vector denom_v        = hvx_vec_inverse_f32(t_v);
    HVX_Vector mean_v         = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v);
    HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v);

    // scale = 1/sqrt(mean + eps), broadcast in every lane.
    HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v));

    // Second pass: write dst = src * scale.
    #pragma unroll(4)
    for (int i = 0; i < step_of_1; i++) {
        HVX_Vector v1 = v_src[i];
        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
        v_dst[i]      = Q6_Vsf_equals_Vqf32(v2);
    }
}
 77
 78static void scale_htp_f32(const float * restrict src,
 79                          float * restrict dst,
 80                          uint8_t * restrict spad,
 81                          const uint32_t num_rows,
 82                          const uint32_t row_elems,
 83                          const size_t   row_size,
 84                          int32_t *      op_params,
 85                          int            opt_path) {
 86    float scale = 0.f;
 87    float bias  = 0.f;
 88    memcpy(&scale, &op_params[0], sizeof(float));
 89    memcpy(&bias,  &op_params[1], sizeof(float));
 90
 91    for (uint32_t ir = 0; ir < num_rows; ir++) {
 92        const float * restrict src_local = src + (ir * row_elems);
 93        float * restrict dst_local       = dst + (ir * row_elems);
 94
 95        if (ir + 1 < num_rows) {
 96            hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
 97        }
 98
 99        hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
100    }
101}
102
// Row-wise RMS normalization: y = x / sqrt(mean(x^2) + eps).
// op_params[0] carries the f32 bit pattern of epsilon.
// opt_path == 1 selects the single-kernel vectorized path (rows must be
// VLEN-aligned and a whole number of vectors); otherwise a generic two-step
// reduce-then-scale path is used.
static void rms_norm_htp_f32(const float * restrict src,
                             float * restrict dst,
                             uint8_t * restrict spad,
                             const uint32_t num_rows,
                             const uint32_t row_elems,
                             const size_t   row_size,
                             int32_t *      op_params,
                             int            opt_path) {
    float eps;
    memcpy(&eps, op_params, sizeof(eps));

    const float * restrict row_in = src;
    float * restrict row_out      = dst;

    for (uint32_t row = 0; row < num_rows; row++, row_in += row_elems, row_out += row_elems) {
        // Warm L2 with the next row while this one is normalized.
        if (row + 1 < num_rows) {
            hex_l2fetch(row_in + row_elems, row_size, row_size, 1);
        }

        if (opt_path == 1) {
            hvx_fast_rms_norm_f32((const uint8_t *) row_in, (uint8_t *) row_out, spad, row_elems, eps);
            continue;
        }

        // Generic path: vector reduce to sum of squares, then scalar math,
        // then a vector scale of the row.
        const float ssq   = hvx_sum_of_squares_f32((const uint8_t *) row_in, row_elems);
        const float scale = 1.0f / sqrtf((ssq / row_elems) + eps);

        hvx_scale_f32((uint8_t *) row_out, (const uint8_t *) row_in, row_elems, scale);
    }
}
134
// Element-wise square of each f32 row: y = x * x.
// opt_path == 1 picks the aligned HVX kernel (_aa variant); `spad` and
// `op_params` are unused but keep the common unary-op kernel signature.
static void sqr_htp_f32(const float * restrict src,
                          float * restrict dst,
                          uint8_t * restrict spad,
                          const uint32_t num_rows,
                          const uint32_t row_elems,
                          const size_t   row_size,
                          int32_t *      op_params,
                          int            opt_path) {
    for (uint32_t row = 0; row < num_rows; ++row) {
        const uint32_t offs          = row * row_elems;
        const float * restrict x_row = src + offs;
        float * restrict y_row       = dst + offs;

        // Stream the next row into L2 ahead of time.
        if (row + 1 < num_rows) {
            hex_l2fetch(x_row + row_elems, row_size, row_size, 1);
        }

        if (opt_path == 1) {
            hvx_sqr_f32_aa((uint8_t *) y_row, (const uint8_t *) x_row, row_elems);
        } else {
            hvx_sqr_f32((uint8_t *) y_row, (const uint8_t *) x_row, row_elems);
        }
    }
}
159
// Element-wise square root of each f32 row: y = sqrt(x).
// opt_path == 1 picks the aligned HVX kernel (_aa variant); `spad` and
// `op_params` are unused but keep the common unary-op kernel signature.
static void sqrt_htp_f32(const float * restrict src,
                          float * restrict dst,
                          uint8_t * restrict spad,
                          const uint32_t num_rows,
                          const uint32_t row_elems,
                          const size_t   row_size,
                          int32_t *      op_params,
                          int            opt_path) {
    for (uint32_t row = 0; row < num_rows; ++row) {
        const uint32_t offs          = row * row_elems;
        const float * restrict x_row = src + offs;
        float * restrict y_row       = dst + offs;

        // Stream the next row into L2 ahead of time.
        if (row + 1 < num_rows) {
            hex_l2fetch(x_row + row_elems, row_size, row_size, 1);
        }

        if (opt_path == 1) {
            hvx_sqrt_f32_aa((uint8_t *) y_row, (const uint8_t *) x_row, row_elems);
        } else {
            hvx_sqrt_f32((uint8_t *) y_row, (const uint8_t *) x_row, row_elems);
        }
    }
}
184
// Runs one worker thread's share of a unary f32 op.
// Rows of src are split evenly across `nth` workers; this worker (`ith`)
// handles rows [src0_start_row, src0_end_row). The op-specific kernel is
// selected by `htp_op`; `op_params` is forwarded untouched.
static void unary_job_f32_per_thread(const struct htp_tensor * src,
                                     struct htp_tensor *       dst,
                                     uint8_t *                 spad,
                                     int                       htp_op,
                                     int32_t *                 op_params,
                                     uint32_t                  nth,
                                     uint32_t                  ith,
                                     uint32_t                  src0_nrows_per_thread) {
    htp_unary_preamble;

    const size_t src0_row_size = nb01;
    const size_t dst_row_size  = nb1;

    const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows

    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);

    // no work for this thread
    if (src0_start_row >= src0_end_row) {
        return;
    }

    uint64_t t1, t2;
    t1 = HAP_perf_get_qtimer_count();

    // Take the fast vectorized path only when both buffers are VLEN-aligned
    // AND the row stride is a whole number of HVX vectors (nb01 % VLEN == 0),
    // so row kernels never see a partial trailing vector.
    int is_aligned = 1;
    int opt_path   = 0;
    if ((0 == hex_is_aligned((void *) src->data, VLEN)) || (0 == hex_is_aligned((void *) dst->data, VLEN))) {
        is_aligned = 0;
    }
    if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
        opt_path = 1;
    }

    const uint8_t * restrict data_src = (const uint8_t *) src->data;
    uint8_t * restrict data_dst       = (uint8_t *) dst->data;

    const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
    float * restrict dst_th       = (float *) (data_dst + (src0_start_row * dst_row_size));
    // Per-thread scratch slice. NOTE(review): the stride here is nb01, while
    // the allocation in execute_op_unary_f32 uses hex_round_up(nb01, 128) per
    // thread — stays in bounds (rounded stride >= nb01) but slices may not be
    // 128-byte aligned when nb01 isn't; confirm whether the spad consumer
    // needs alignment (currently unused by the kernels).
    uint8_t * restrict spad_th    = (uint8_t *) spad + (ith * nb01);

    switch (htp_op) {
        case HTP_OP_RMS_NORM:
            rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
            break;
        case HTP_OP_SCALE:
            scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
            break;
        case HTP_OP_SQR:
            sqr_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
            break;
        case HTP_OP_SQRT:
            sqrt_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
            break;

        default:
            break;
    }

    t2 = HAP_perf_get_qtimer_count();

    FARF(HIGH, "unary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, src->ne[0],
         src->ne[1], src->ne[2], src->ne[3], src0_start_row, src0_end_row, dst->ne[0], dst->ne[1], dst->ne[2],
         dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
251
252static void unary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
253    struct htp_ops_context * octx = (struct htp_ops_context *) data;
254
255    unary_job_f32_per_thread(&octx->src0, &octx->dst, octx->src0_spad.data, octx->op, octx->op_params, n, i,
256                             octx->src0_nrows_per_thread);
257}
258
259static int execute_op_unary_f32(struct htp_ops_context * octx) {
260    int err = HTP_STATUS_OK;
261
262    const struct htp_tensor * src0 = &octx->src0;
263    struct htp_tensor *       dst  = &octx->dst;
264
265    worker_callback_t unary_op_func;
266    const char *      op_type = NULL;
267
268    switch (octx->op) {
269        case HTP_OP_RMS_NORM:
270            unary_op_func = unary_job_dispatcher_f32;
271            op_type       = "rmsnorm-f32";
272            break;
273        case HTP_OP_SCALE:
274            unary_op_func = unary_job_dispatcher_f32;
275            op_type       = "scale-f32";
276            break;
277        case HTP_OP_SQR:
278            unary_op_func = unary_job_dispatcher_f32;
279            op_type       = "sqr-f32";
280            break;
281        case HTP_OP_SQRT:
282            unary_op_func = unary_job_dispatcher_f32;
283            op_type       = "sqrt-f32";
284            break;
285
286        default:
287            FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
288            return HTP_STATUS_NO_SUPPORT;
289    }
290
291    const int      n_threads  = octx->n_threads;
292    const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
293
294    const size_t src0_row_size = src0->nb[1];
295    const size_t dst_row_size  = dst->nb[1];
296
297    // VTCM scratchpads for all tensors
298    octx->dst_spad.size  = hex_round_up(dst_row_size, 128) * n_threads;
299    octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
300
301    size_t spad_size = octx->src0_spad.size + octx->dst_spad.size;
302
303    FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
304         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
305         octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
306
307    // Make sure the reserved vtcm size is sufficient
308    if (octx->ctx->vtcm_size < spad_size) {
309        FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
310             spad_size);
311        return HTP_STATUS_VTCM_TOO_SMALL;
312    }
313
314    octx->src0_spad.data = octx->ctx->vtcm_base;
315    octx->dst_spad.data  = octx->src0_spad.data + octx->src0_spad.size;
316
317    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
318        uint32_t n_jobs = MIN(n_threads, src0_nrows);
319
320        octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
321
322        worker_pool_run_func(octx->ctx->worker_pool, unary_op_func, octx, n_jobs);
323    }
324
325    return err;
326}
327
328int op_unary(struct htp_ops_context * octx) {
329    int err = HTP_STATUS_OK;
330
331    switch (octx->src0.type) {
332        case HTP_TYPE_F32:
333            err = execute_op_unary_f32(octx);
334            break;
335
336        default:
337            err = HTP_STATUS_NO_SUPPORT;
338            break;
339    }
340
341    return err;
342}