#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"

#include <HAP_farf.h>
#include <HAP_perf.h>

#include <math.h>
#include <string.h>

#include "hex-dma.h"
#include "hvx-utils.h"

#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-msg.h"
#include "htp-ops.h"

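// Expand src0/src1/dst shapes and strides into local constants.
// src1 is the optional mask; its dims/strides default to 1 when it is absent.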
#define htp_softmax_preamble3                              \
    const uint32_t ne00 = src0->ne[0];                     \
    const uint32_t ne01 = src0->ne[1];                     \
    const uint32_t ne02 = src0->ne[2];                     \
    const uint32_t ne03 = src0->ne[3];                     \
                                                           \
    const uint32_t nb00 = src0->nb[0];                     \
    const uint32_t nb01 = src0->nb[1];                     \
    const uint32_t nb02 = src0->nb[2];                     \
    const uint32_t nb03 = src0->nb[3];                     \
                                                           \
    const uint32_t ne10 = (src1->ne[0]) ? src1->ne[0] : 1; \
    const uint32_t ne11 = (src1->ne[0]) ? src1->ne[1] : 1; \
    const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1; \
    const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1; \
                                                           \
    const uint32_t nb10 = (src1->ne[0]) ? src1->nb[0] : 1; \
    const uint32_t nb11 = (src1->ne[0]) ? src1->nb[1] : 1; \
    const uint32_t nb12 = (src1->ne[0]) ? src1->nb[2] : 1; \
    const uint32_t nb13 = (src1->ne[0]) ? src1->nb[3] : 1; \
                                                           \
    const uint32_t ne0 = dst->ne[0];                       \
    const uint32_t ne1 = dst->ne[1];                       \
    const uint32_t ne2 = dst->ne[2];                       \
    const uint32_t ne3 = dst->ne[3];                       \
                                                           \
    const uint32_t nb0 = dst->nb[0];                       \
    const uint32_t nb1 = dst->nb[1];                       \
    const uint32_t nb2 = dst->nb[2];                       \
    const uint32_t nb3 = dst->nb[3];

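// Per-op softmax parameters shared by every worker thread.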
struct softmax_th_ctx {
    bool     use_f16;
    bool     use_src1;
    uint32_t n_head;
    uint32_t n_head_log2;

    float scale;
    float max_bias;
    float m0;
    float m1;

    struct htp_ops_context * octx;
};

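// Read scale and max_bias from op_params and precompute the ALiBi slope bases
// m0/m1 from the number of heads (src0->ne[2]).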
static void init_softmax_ctx(struct softmax_th_ctx * softmax_ctx, struct htp_ops_context * octx) {
    const struct htp_tensor * src0 = &octx->src0;
    const struct htp_tensor * src1 = &octx->src1;

    memset(softmax_ctx, 0, sizeof(struct softmax_th_ctx));

    memcpy(&softmax_ctx->scale, (float *) octx->op_params, sizeof(float));
    memcpy(&softmax_ctx->max_bias, (float *) octx->op_params + 1, sizeof(float));

    softmax_ctx->n_head      = src0->ne[2];
    softmax_ctx->n_head_log2 = 1u << (uint32_t) floor(log2(softmax_ctx->n_head));

    softmax_ctx->m0 = powf(2.0f, -(softmax_ctx->max_bias) / softmax_ctx->n_head_log2);
    softmax_ctx->m1 = powf(2.0f, -(softmax_ctx->max_bias / 2.0f) / softmax_ctx->n_head_log2);

    softmax_ctx->use_src1 = (src1->ne[0] != 0);
    softmax_ctx->use_f16  = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16);

    softmax_ctx->octx = octx;
}

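// Fast-path prep for the f32-mask case: dst = src * scale + mask * slope,
// processed 32 f32 elements per HVX vector, with no remainder handling.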
static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src,
                                      uint8_t * restrict dst,
                                      const int num_elems,
                                      float     scale,
                                      const uint8_t * restrict mask,
                                      float slope) {
    const uint8_t * restrict src_curr  = src;
    uint8_t * restrict dst_curr        = dst;
    const uint8_t * restrict mask_curr = mask;

    HVX_Vector scale_vec = hvx_vec_splat_f32(scale);
    HVX_Vector slope_vec = hvx_vec_splat_f32(slope);

    int step_of_1 = num_elems >> 5;

    #pragma unroll(4)
    for (int i = 0; i < step_of_1; i++) {
        HVX_Vector v1 = *(HVX_Vector *) src_curr;

        HVX_Vector v3 = *(HVX_Vector *) mask_curr;

        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_vec);

        HVX_Vector v4 = Q6_Vqf32_vmpy_VsfVsf(v3, slope_vec);

        HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(v2, v4);

        *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v5);

        src_curr += VLEN;
        dst_curr += VLEN;
        mask_curr += VLEN;
    }
}

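// Single-row softmax over full HVX vectors: pass 1 finds the row max, pass 2
// computes exp(x - max) into the scratch pad and accumulates the sum, pass 3
// scales by 1/sum (or by 1.0 if the sum is not positive).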
static void hvx_fast_softmax_f32(const uint8_t * restrict src,
                                 uint8_t * restrict dst,
                                 uint8_t * restrict pad,
                                 const int num_elems) {
    const HVX_Vector * restrict v_src = (HVX_Vector *) src;
    HVX_Vector * restrict v_pad       = (HVX_Vector *) pad;
    HVX_Vector * restrict v_dst       = (HVX_Vector *) dst;

    HVX_Vector sum_vec = Q6_V_vsplat_R(0x00000000);
    HVX_Vector max_vec = hvx_vec_splat_f32(((const float *) src)[0]);
    HVX_Vector zero_v  = Q6_V_vzero();
    HVX_Vector one_v   = hvx_vec_splat_f32(1.0);

    int step_of_1 = num_elems >> 5;

    #pragma unroll(4)
    for (int i = 0; i < step_of_1; i++) {
        HVX_Vector v1 = v_src[i];
        max_vec       = Q6_Vsf_vmax_VsfVsf(max_vec, v1);
    }

    HVX_Vector v = hvx_vec_reduce_max_f32(max_vec);
    max_vec      = hvx_vec_repl4(v);

    #pragma unroll(4)
    for (int i = 0; i < step_of_1; i++) {
        HVX_Vector v1 = v_src[i];
        HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, max_vec);

        HVX_Vector v3 = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(v2));

        sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), v3);

        v_pad[i] = v3;
    }

    v       = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec));
    sum_vec = hvx_vec_repl4(v);

    HVX_VectorPred pos_sum   = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v);
    HVX_Vector     v4        = hvx_vec_inverse_f32(sum_vec);
    HVX_Vector     scale_vec = Q6_V_vmux_QVV(pos_sum, v4, one_v);

    #pragma unroll(4)
    for (int i = 0; i < step_of_1; i++) {
        HVX_Vector v1 = v_pad[i];
        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_vec);
        v_dst[i]      = Q6_Vsf_equals_Vqf32(v2);
    }
}

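// Generic (non fast-path) row softmax step: spad = src - max, dst = exp(spad).
// Returns the sum of the exponentials; the caller applies the 1/sum scale.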
static float hvx_softmax_f32(const uint8_t * restrict src,
                             uint8_t * restrict dst,
                             uint8_t * restrict spad,
                             const int   num_elems,
                             const float max) {
    hvx_sub_scalar_f32(spad, src, max, num_elems);

    hvx_exp_f32(spad, dst, num_elems, false);

    float sum = hvx_reduce_sum_f32(dst, num_elems);

    return sum;
}

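// Row loop: thread ith handles rows i01 = ith, ith + nth, ... Each row is
// scaled and (optionally) offset by the ALiBi-sloped mask, then normalized.
// opt_path == 1 selects the single-pass HVX kernels for VLEN-aligned rows.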
static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ctx, int opt_path) {
    struct htp_ops_context * octx = softmax_ctx->octx;

    const struct htp_tensor * src0 = &octx->src0;
    const struct htp_tensor * src1 = &octx->src1;
    const struct htp_tensor * dst  = &octx->dst;

    htp_softmax_preamble3;

    uint8_t * src0_spad_data = octx->src0_spad.data + (ith * nb01);
    uint8_t * src1_spad_data = octx->src1_spad.data + (ith * nb01);
    uint8_t * dst_spad_data  = octx->dst_spad.data + (ith * nb1);

    float * wp0 = (float *) src0_spad_data;
    float * wp1 = (float *) src1_spad_data;
    float * wp2 = (float *) dst_spad_data;

    for (uint32_t i03 = 0; i03 < ne03; i03++) {
        for (uint32_t i02 = 0; i02 < ne02; i02++) {
            for (uint32_t i01 = ith; i01 < ne01; i01 += nth) {
                const uint32_t i11 = i01;
                const uint32_t i12 = i02 % ne12;
                const uint32_t i13 = i03 % ne13;

                // ALiBi
                const uint32_t h = i02;  // head

                const float slope = (softmax_ctx->max_bias > 0.0f) ?
                                        h < softmax_ctx->n_head_log2 ?
                                        powf(softmax_ctx->m0, h + 1) :
                                        powf(softmax_ctx->m1, 2 * (h - softmax_ctx->n_head_log2) + 1) :
                                        1.0f;

                float * sp = (float *) ((char *) octx->src0.data + i01 * nb01 + i02 * nb02 + i03 * nb03);
                float * dp = (float *) ((char *) octx->dst.data + i01 * nb1 + i02 * nb2 + i03 * nb3);

                // broadcast the mask across rows
                __fp16 * mp_f16 = (softmax_ctx->use_src1) ?
                                      (__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
                                      NULL;
                float *  mp_f32 = (softmax_ctx->use_src1) ?
                                      (float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
                                      NULL;

                if ((1 == opt_path) && (mp_f32) && !(softmax_ctx->use_f16)) {
                    hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale,
                                              (const uint8_t *) mp_f32, slope);
                } else {
                    hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale);
                    if (mp_f32) {
                        if (softmax_ctx->use_f16) {
                            for (int i = 0; i < ne00; ++i) {
                                wp0[i] += slope * (float) mp_f16[i];
                            }
                        } else {
                            for (int i = 0; i < ne00; ++i) {
                                wp0[i] += slope * mp_f32[i];
                            }
                        }
                    }
                }

                if (1 == opt_path) {
                    hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
                } else {
                    float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00);
                    float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
                    sum       = sum > 0.0 ? (1.0 / sum) : 1;
                    hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);
                }
            }
        }
    }
}

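// Per-thread entry: compute this thread's row range (early out if empty),
// pick the fast path when src0/dst are VLEN-aligned and the row stride is a
// multiple of VLEN, run the row loop, and log the elapsed time.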
static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int nth, int ith) {
    struct htp_ops_context * octx = softmax_ctx->octx;

    const struct htp_tensor * src0 = &octx->src0;
    const struct htp_tensor * src1 = &octx->src1;
    struct htp_tensor *       dst  = &octx->dst;

    htp_softmax_preamble3;

    const uint32_t src0_nrows            = ne01 * ne02 * ne03;  // src0 rows
    const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;

    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);

    // no work for this thread
    if (src0_start_row >= src0_end_row) {
        return;
    }

    uint64_t t1, t2;
    t1 = HAP_perf_get_qtimer_count();

    int is_aligned = 1;
    int opt_path   = 0;
    if (!hex_is_aligned((void *) src0->data, VLEN) || !hex_is_aligned((void *) dst->data, VLEN)) {
        is_aligned = 0;
        FARF(HIGH, "softmax-f32: unaligned addresses in elementwise op, possibly slower execution\n");
    }
    if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
        opt_path = 1;
    }

    softmax_htp_f32(nth, ith, softmax_ctx, opt_path);

    t2 = HAP_perf_get_qtimer_count();

    FARF(HIGH, "softmax-f32 %d/%d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
         softmax_ctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13,
         ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}

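// Worker-pool callback: n is the number of jobs, i is this job's index.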
static void softmax_job_dispatcher_f32(unsigned int n, unsigned int i, void * p_data) {
    struct softmax_th_ctx * p_softmax_ctx = (struct softmax_th_ctx *) p_data;
    softmax_job_f32_per_thread(p_softmax_ctx, n, i);
}

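// Set up per-thread VTCM scratchpads (one row each for src0, src1 and dst),
// verify the VTCM reservation is large enough, then fan the rows out across
// the worker pool.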
static int execute_op_softmax_f32(struct htp_ops_context * octx) {
    int err = HTP_STATUS_OK;

    const struct htp_tensor * src0 = &octx->src0;
    const struct htp_tensor * src1 = &octx->src1;
    struct htp_tensor *       dst  = &octx->dst;

    worker_callback_t op_func;
    const char *      op_type = NULL;

    struct softmax_th_ctx softmax_ctx;

    switch (octx->op) {
        case HTP_OP_SOFTMAX:
            op_func = softmax_job_dispatcher_f32;
            op_type = "softmax-f32";

            init_softmax_ctx(&softmax_ctx, octx);
            break;

        default:
            FARF(ERROR, "Unsupported Op %u\n", octx->op);
            return HTP_STATUS_NO_SUPPORT;
    }

    const uint32_t n_threads = octx->n_threads;

    const size_t src0_row_size = src0->nb[1];
    // the src1 spad is used as f32 scratch for a full row, so size it like src0
    const size_t src1_row_size = src0_row_size;
    const size_t dst_row_size  = dst->nb[1];

    // VTCM scratchpads for all tensors
    // one row per thread, padded to the HVX vector size
    octx->dst_spad.size  = hex_round_up(dst_row_size, 128) * n_threads;
    octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
    octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;

    size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;

    if (src1->ne[0]) {
        FARF(HIGH,
             "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
             op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
             src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
             octx->dst_spad.size);
    } else {
        FARF(HIGH, "%s: %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
             src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
             octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
    }

    // Make sure the reserved vtcm size is sufficient
    if (octx->ctx->vtcm_size < spad_size) {
        FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
             spad_size);
        return HTP_STATUS_VTCM_TOO_SMALL;
    }

    octx->src0_spad.data = octx->ctx->vtcm_base;
    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
    octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;

    uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];

    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
        uint32_t n_jobs             = MIN(n_threads, src0_nrows);
        octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
        worker_pool_run_func(octx->ctx->worker_pool, op_func, &softmax_ctx, n_jobs);
    }

    return err;
}

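// Public entry point: dispatch on src0 type (only F32 is supported).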
int op_softmax(struct htp_ops_context * octx) {
    int err = HTP_STATUS_OK;

    switch (octx->src0.type) {
        case HTP_TYPE_F32:
            err = execute_op_softmax_f32(octx);
            break;

        default:
            err = HTP_STATUS_NO_SUPPORT;
            break;
    }

    return err;
}