1#pragma clang diagnostic ignored "-Wunused-variable"
  2#pragma clang diagnostic ignored "-Wunused-function"
  3#pragma clang diagnostic ignored "-Wunused-but-set-variable"
  4
  5#include <HAP_farf.h>
  6#include <HAP_perf.h>
  7
  8#include <math.h>
  9#include <string.h>
 10
 11#include "hex-dma.h"
 12#include "hvx-utils.h"
 13
 14#define GGML_COMMON_DECL_C
 15#include "ggml-common.h"
 16#include "htp-ctx.h"
 17#include "htp-msg.h"
 18#include "htp-ops.h"
 19
// Redefine the GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX values here since we can't include ggml.h
 21#define HTP_ROPE_TYPE_NORMAL 0
 22#define HTP_ROPE_TYPE_NEOX   2
 23
// Expands to local copies of the src0/dst tensor shapes (ne*) and byte strides (nb*).
// Not every function uses all of them; the -Wunused* pragmas at the top of this
// file silence the resulting warnings.
#define htp_rope_preamble              \
    const uint32_t ne00 = src0->ne[0]; \
    const uint32_t ne01 = src0->ne[1]; \
    const uint32_t ne02 = src0->ne[2]; \
    const uint32_t ne03 = src0->ne[3]; \
                                       \
    const uint32_t ne0 = dst->ne[0];   \
    const uint32_t ne1 = dst->ne[1];   \
    const uint32_t ne2 = dst->ne[2];   \
    const uint32_t ne3 = dst->ne[3];   \
                                       \
    const uint32_t nb00 = src0->nb[0]; \
    const uint32_t nb01 = src0->nb[1]; \
    const uint32_t nb02 = src0->nb[2]; \
    const uint32_t nb03 = src0->nb[3]; \
                                       \
    const uint32_t nb0 = dst->nb[0];   \
    const uint32_t nb1 = dst->nb[1];   \
    const uint32_t nb2 = dst->nb[2];   \
    const uint32_t nb3 = dst->nb[3];
 44
// Per-op rope parameters, unpacked once from octx->op_params by init_rope_ctx()
// and then shared read-only by all worker threads.
struct rope_th_ctx {
    int32_t n_dims;       // number of leading dimensions that get rotated (rest is copied)
    int32_t mode;         // rope variant flags (HTP_ROPE_TYPE_*)
    int32_t n_ctx_orig;   // original training context length (YaRN correction input)
    int32_t sections[4];  // section sizes (m-rope); unpacked but not read in this file

    float freq_base;      // base of the frequency ladder
    float freq_scale;     // linear position-interpolation scale
    float ext_factor;     // YaRN extrapolation mix factor (0.0f => plain rope path)
    float attn_factor;    // magnitude scale applied to the cos/sin coefficients
    float beta_fast;      // YaRN ramp parameter (fast-rotating dims)
    float beta_slow;      // YaRN ramp parameter (slow-rotating dims)
    float theta_scale;    // precomputed freq_base^(-2/n_dims)
    float corr_dims[2];   // precomputed YaRN correction range [start, end]

    struct htp_ops_context * octx;  // back-pointer to op context (tensors, spads, pool)
};
 62
 63static float rope_yarn_ramp(const float low, const float high, const int i0) {
 64    const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
 65
 66    return (1 - MIN(1, MAX(0, y)));
 67}
 68
 69static void rope_cache_init(const float    theta_base,
 70                            const float    freq_scale,
 71                            const float *  freq_factors,
 72                            float *        corr_dims,
 73                            const uint32_t ne0,
 74                            const float    ext_factor,
 75                            const float    mscale,
 76                            float *        cache,
 77                            const float    theta_scale) {
 78    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
 79    float theta = theta_base;
 80
 81    for (uint32_t i0 = 0; i0 < ne0; i0 += 2) {
 82        const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f;
 83
 84        float theta_extrap = theta / ff;
 85
 86        // Get n-d rotational scaling corrected for extrapolation
 87        float theta_interp = freq_scale * theta_extrap;
 88        float theta_final  = theta_interp;
 89        float mscale_final = mscale;
 90
 91        if (ext_factor != 0.0f) {
 92            float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
 93            theta_final    = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
 94
 95            // Get n-d magnitude scaling corrected for interpolation
 96            mscale_final *= 1.0f + 0.1f * logf(1.0f / freq_scale);
 97        }
 98
 99        cache[i0 + 0] = cosf(theta_final) * mscale_final;
100        cache[i0 + 1] = sinf(theta_final) * mscale_final;
101
102        theta *= theta_scale;
103    }
104}
105
// Some libm configurations already provide M_PI; define it only when missing to
// avoid a macro-redefinition warning (or a value mismatch with the system one).
#ifndef M_PI
#define M_PI 3.1415926535897932384626433
#endif
107
108static void rope_corr_dims(int     n_dims,
109                           int     n_ctx_orig,
110                           float   freq_base,
111                           float   beta_fast,
112                           float   beta_slow,
113                           float * dims) {
114    float start = floorf(n_dims * logf(n_ctx_orig / (beta_fast * 2 * (float) M_PI)) / (2 * logf(freq_base)));
115    float end   = ceilf(n_dims * logf(n_ctx_orig / (beta_slow * 2 * (float) M_PI)) / (2 * logf(freq_base)));
116    dims[0]     = MAX(0, start);
117    dims[1]     = MIN(n_dims - 1, end);
118}
119
120static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context * octx) {
121    memset(rope_ctx, 0, sizeof(struct rope_th_ctx));
122
123    const int32_t * op_params = &octx->op_params[0];
124
125    rope_ctx->n_dims     = ((const int32_t *) op_params)[1];
126    rope_ctx->mode       = ((const int32_t *) op_params)[2];
127    rope_ctx->n_ctx_orig = ((const int32_t *) op_params)[4];
128
129    memcpy(&rope_ctx->freq_base, (int32_t *) op_params + 5, sizeof(float));
130    memcpy(&rope_ctx->freq_scale, (int32_t *) op_params + 6, sizeof(float));
131    memcpy(&rope_ctx->ext_factor, (int32_t *) op_params + 7, sizeof(float));
132    memcpy(&rope_ctx->attn_factor, (int32_t *) op_params + 8, sizeof(float));
133    memcpy(&rope_ctx->beta_fast, (int32_t *) op_params + 9, sizeof(float));
134    memcpy(&rope_ctx->beta_slow, (int32_t *) op_params + 10, sizeof(float));
135    memcpy(&rope_ctx->sections, (int32_t *) op_params + 11, sizeof(int) * 4);
136
137    rope_ctx->theta_scale = powf(rope_ctx->freq_base, -2.0f / rope_ctx->n_dims);
138
139    rope_corr_dims(rope_ctx->n_dims, rope_ctx->n_ctx_orig, rope_ctx->freq_base, rope_ctx->beta_fast,
140                   rope_ctx->beta_slow, rope_ctx->corr_dims);
141
142    rope_ctx->octx = octx;
143    FARF(HIGH, "rope-f32 n_dims:%d, ext_factor:%.6f, theta_scale:%.6f, attn_factor:%.6f\n", rope_ctx->n_dims,
144         rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor);
145}
146
// HVX neox-style rope: rotates pairs (x[i], x[i + num_elems/2]) using the
// interleaved {cos, sin} coefficients in theta_cache. Scalar reference:
//
// for (int i = 0; i < num_elems; i += 2) {
//     const float cos_theta = theta_cache[i + 0];
//     const float sin_theta = theta_cache[i + 1];
//     const float x0 = src[0];
//     const float x1 = src[num_elems/2];
//     dst[0]           = x0*cos_theta - x1*sin_theta;
//     dst[num_elems/2] = x0*sin_theta + x1*cos_theta;
//     src += 1; dst += 1;
// }
//
// NOTE(review): assumes src0/dst/theta_cache are VLEN-aligned and num_elems is
// a multiple of 64 (two 32-float vectors per iteration); any remainder (e.g.
// num_elems == 96) is left unprocessed — confirm callers only take this path
// for multiples of 64.
static void hvx_calc_rope_neox_f32(const float * restrict src0,
                                   float * restrict dst,
                                   const int num_elems,
                                   const float * restrict theta_cache) {
    const uint8_t * restrict src0_curr  = (const uint8_t *) src0;
    const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
    uint8_t * restrict dst_curr         = (uint8_t *) dst;

    int step_of_1 = num_elems >> 6;  // 64 floats consumed per iteration (two vectors)
    int half_size = (sizeof(float) * (num_elems / 2));  // byte offset of the second half

    for (int i = 0; i < step_of_1; i++) {
        // One vector from each half of the row: v0 = x0 lanes, v1 = x1 lanes.
        HVX_Vector v0 = *(HVX_Vector *) src0_curr;
        HVX_Vector v1 = *(HVX_Vector *) (src0_curr + half_size);

        // Two vectors of interleaved {cos, sin}; deal splits them into planes.
        HVX_Vector v2 = *(HVX_Vector *) theta_curr;
        HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);

        HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4);  // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta

        // Cross products for the 2x2 rotation, in qf32.
        HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_lo_W(vcos_sin));
        HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_hi_W(vcos_sin));
        HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_lo_W(vcos_sin));
        HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_hi_W(vcos_sin));

        // out0 = x0*cos - x1*sin, out1 = x0*sin + x1*cos
        HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
        HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);

        // Convert back to IEEE sf and store each half in place.
        *(HVX_Vector *) dst_curr               = Q6_Vsf_equals_Vqf32(v4);
        *(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5);

        src0_curr += VLEN;
        theta_curr += 2 * VLEN;  // theta cache advances twice as fast (cos+sin interleave)
        dst_curr += VLEN;
    }
}
197
// HVX normal-style rope: rotates adjacent pairs (x[2i], x[2i+1]) using the
// interleaved {cos, sin} coefficients in theta_cache. Scalar reference:
//
// for (int i = 0; i < num_elems; i += 2) {
//     const float cos_theta = theta_cache[i + 0];
//     const float sin_theta = theta_cache[i + 1];
//     const float x0 = src[0];
//     const float x1 = src[1];
//     dst[0] = x0*cos_theta - x1*sin_theta;
//     dst[1] = x0*sin_theta + x1*cos_theta;
//     src += 2; dst += 2;
// }
//
// NOTE(review): assumes VLEN-aligned pointers and num_elems a multiple of 64
// (two vectors per iteration); remainders are left unprocessed — confirm
// callers gate on this.
static void hvx_calc_rope_f32(const float * restrict src0,
                              float * restrict dst,
                              const int num_elems,
                              const float * restrict theta_cache) {
    const uint8_t * restrict src0_curr  = (const uint8_t *) src0;
    const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
    uint8_t * restrict dst_curr         = (uint8_t *) dst;

    int step_of_1 = num_elems >> 6;  // 64 floats consumed per iteration (two vectors)

    for (int i = 0; i < step_of_1; i++) {
        HVX_Vector v0 = *(HVX_Vector *) src0_curr;
        HVX_Vector v1 = *(HVX_Vector *) (src0_curr + VLEN);

        HVX_Vector v2 = *(HVX_Vector *) theta_curr;
        HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);

        // De-interleave even/odd elements into x0/x1 and cos/sin planes.
        HVX_VectorPair vx0_x1   = Q6_W_vdeal_VVR(v1, v0, -4);  // vx0_x1[0] = x0, vx0_x1[1] = x1
        HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4);  // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta

        // Cross products for the 2x2 rotation, in qf32.
        HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_lo_W(vcos_sin));
        HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_hi_W(vcos_sin));
        HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_lo_W(vcos_sin));
        HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_hi_W(vcos_sin));

        // out0 = x0*cos - x1*sin, out1 = x0*sin + x1*cos
        HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
        HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);

        // Re-interleave outputs back into (out0, out1) pair order before storing.
        HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4);

        *(HVX_Vector *) dst_curr          = Q6_V_lo_W(vstore);
        *(HVX_Vector *) (dst_curr + VLEN) = Q6_V_hi_W(vstore);

        src0_curr += 2 * VLEN;
        theta_curr += 2 * VLEN;
        dst_curr += 2 * VLEN;
    }
}
250
// Applies rope to rows [ir0, ir1) of src0, writing results to dst.
// opt_path == 1 selects the HVX vector kernels (caller guarantees VLEN-aligned
// data and a VLEN-multiple row stride); otherwise a scalar fallback is used.
// The tail of each row beyond n_dims is copied through unrotated.
static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
                         const uint32_t       ir0,
                         const uint32_t       ir1,
                         int                  nth,
                         int                  ith,
                         const int            opt_path) {
    struct htp_ops_context * octx = rope_ctx->octx;

    const struct htp_tensor * src0 = &octx->src0;
    const struct htp_tensor * src1 = &octx->src1;  // per-token positions (int32)
    const struct htp_tensor * src2 = &octx->src2;  // optional per-frequency factors
    struct htp_tensor *       dst  = &octx->dst;

    const int32_t mode    = rope_ctx->mode;
    const bool    is_neox = mode & HTP_ROPE_TYPE_NEOX;

    htp_rope_preamble;

    const int32_t * pos = (const int32_t *) src1->data;

    // Per-thread {cos,sin} scratch row carved out of the src0 VTCM spad.
    float * wp0 = (float *) (octx->src0_spad.data + (ith * nb01));

    // NOTE(review): src2 points at an embedded struct member, so it is never
    // NULL and this branch always taken; correctness relies on src2->data
    // being 0 when no freq factors were supplied — confirm upstream zeroes it.
    const float * freq_factors = NULL;
    if (src2 != NULL) {
        freq_factors = (const float *) src2->data;
    }

    const uint32_t i1_end       = MIN(ir1, ne1);
    const int32_t  half_dims    = rope_ctx->n_dims / 2;
    const size_t   remain_bytes = (ne0 - rope_ctx->n_dims) * sizeof(float);  // unrotated tail per row
    for (uint32_t i3 = 0; i3 < ne3; i3++) {      // batch
        for (uint32_t i2 = 0; i2 < ne2; i2++) {  // seq-len
            const int32_t p = pos[i2];

            // Rebuild the cos/sin cache for this token position (theta_base = p).
            rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor,
                            rope_ctx->attn_factor, wp0, rope_ctx->theta_scale);

            // NOTE(review): ir0/ir1 are computed from the flattened row count
            // (ne01*ne02*ne03) but applied per (i3,i2) slice and clamped to ne1;
            // when ne2*ne3 > 1, threads other than the first can see an empty
            // range here — verify the partitioning intent for multi-token cases.
            for (uint32_t i1 = ir0; i1 < i1_end; i1++) {  // attn-heads
                const float * src      = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01);
                float *       dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1);

                const float * src_loc      = src;
                float *       dst_data_loc = dst_data;

                if (1 == opt_path) {
                    // Vectorized rotation of the first n_dims elements.
                    if (is_neox) {
                        hvx_calc_rope_neox_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
                    } else {
                        hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
                    }

                    src_loc += rope_ctx->n_dims;
                    dst_data_loc += rope_ctx->n_dims;
                } else {
                    // Scalar fallback; neox rotates (i, i+n_dims/2), normal rotates (2i, 2i+1).
                    for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) {
                        const float cos_theta = wp0[i0 + 0];
                        const float sin_theta = wp0[i0 + 1];

                        if (is_neox) {
                            const float x0 = src_loc[0];
                            const float x1 = src_loc[half_dims];

                            dst_data_loc[0]         = x0 * cos_theta - x1 * sin_theta;
                            dst_data_loc[half_dims] = x0 * sin_theta + x1 * cos_theta;

                            src_loc += 1;
                            dst_data_loc += 1;
                        } else {
                            const float x0 = src_loc[0];
                            const float x1 = src_loc[1];

                            dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
                            dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta;

                            src_loc += 2;
                            dst_data_loc += 2;
                        }
                    }

                    // Neox advanced only n_dims/2; skip over the second half it wrote via offsets.
                    src_loc += (is_neox ? half_dims : 0);
                    dst_data_loc += (is_neox ? half_dims : 0);
                }

                // TODO: use simd to speed up the remaining elements copy
                memcpy(dst_data_loc, src_loc, remain_bytes);
            }
        }
    }
}
340
341static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int ith) {
342    struct htp_ops_context * octx = rope_ctx->octx;
343
344    const struct htp_tensor * src0 = &octx->src0;
345    const struct htp_tensor * src1 = &octx->src1;
346    struct htp_tensor *       dst  = &octx->dst;
347
348    htp_rope_preamble;
349
350    const uint32_t src0_nrows            = ne01 * ne02 * ne03;  // src0 rows
351    const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
352
353    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
354    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
355
356    // no work for this thread
357    if (src0_start_row >= src0_end_row) {
358        return;
359    }
360
361    uint64_t t1, t2;
362    t1 = HAP_perf_get_qtimer_count();
363
364    int is_aligned = 1;
365    int opt_path   = 0;
366    if ((0 == hex_is_aligned((void *) src0->data, VLEN)) || (0 == hex_is_aligned((void *) src1->data, VLEN)) ||
367        (0 == hex_is_aligned((void *) dst->data, VLEN))) {
368        FARF(HIGH, "rope-f32: unaligned addresses in rope op, possibly slower execution\n");
369        is_aligned = 0;
370    }
371    if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
372        opt_path = 1;
373    }
374
375    rope_hex_f32(rope_ctx, src0_start_row, src0_end_row, nth, ith, opt_path);
376
377    t2 = HAP_perf_get_qtimer_count();
378
379    FARF(HIGH, "rope-f32: %d/%d/%d: (%u:%u) usec %u\n", ith, nth, opt_path, src0_start_row, src0_end_row,
380         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
381}
382
// Worker-pool trampoline: forwards (n = job count, i = job index) to the
// per-thread rope job with the shared context recovered from `data`.
static void rope_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
    rope_job_f32_per_thread((struct rope_th_ctx *) data, (int) n, (int) i);
}
388
// Sets up the per-thread VTCM scratchpads, validates the VTCM reservation, and
// fans the rope job out across the worker pool. Returns an HTP_STATUS_* code.
static int execute_op_rope_f32(struct htp_ops_context * octx) {
    int err = HTP_STATUS_OK;

    const struct htp_tensor * src0 = &octx->src0;
    const struct htp_tensor * src1 = &octx->src1;
    const struct htp_tensor * src2 = &octx->src2;
    struct htp_tensor *       dst  = &octx->dst;

    worker_callback_t op_func;
    const char *      op_type = NULL;

    struct rope_th_ctx rope_ctx;

    switch (octx->op) {
        case HTP_OP_ROPE:
            op_func = rope_job_dispatcher_f32;
            op_type = "rope-f32";

            init_rope_ctx(&rope_ctx, octx);
            break;

        default:
            FARF(ERROR, "Unsupported Op %u\n", octx->op);
            return HTP_STATUS_NO_SUPPORT;
    }

    const uint32_t n_threads = octx->n_threads;

    // src1 spad is deliberately sized from a src0 row (it backs the per-thread
    // cos/sin cache, which has one float per src0-row element).
    const size_t src0_row_size = src0->nb[1];
    const size_t src1_row_size = src0_row_size;
    const size_t dst_row_size  = dst->nb[1];

    // VTCM scratchpads for all tensors
    // N rows per thread, padded to HVX vector size
    octx->dst_spad.size  = hex_round_up(dst_row_size, 128) * n_threads;
    octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
    octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;

    size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;

    // Log shapes; include src2 (freq factors) only when present (ne[0] != 0).
    if (src2->ne[0]) {
        FARF(HIGH,
             "%s: %ux%ux%ux%u (x %ux%ux%ux%u x %ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u "
             "dst-spad-size %u\n",
             op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
             src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], dst->ne[2],
             dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
    } else {
        FARF(HIGH,
             "%s: %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
             op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
             src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
             octx->dst_spad.size);
    }

    // Make sure the reserved vtcm size is sufficient
    if (octx->ctx->vtcm_size < spad_size) {
        FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
             spad_size);
        return HTP_STATUS_VTCM_TOO_SMALL;
    }

    // Carve the three spads out of the VTCM reservation, back to back.
    octx->src0_spad.data = octx->ctx->vtcm_base;
    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
    octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;

    uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];

    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
        // NOTE(review): rope_ctx lives on this stack frame, so this assumes
        // worker_pool_run_func() does not return until all jobs complete — confirm.
        uint32_t n_jobs             = MIN(n_threads, src0_nrows);
        octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
        worker_pool_run_func(octx->ctx->worker_pool, op_func, &rope_ctx, n_jobs);
    }

    return err;
}
465
466int op_rope(struct htp_ops_context * octx) {
467    int err = HTP_STATUS_OK;
468
469    switch (octx->src0.type) {
470        case HTP_TYPE_F32:
471            err = execute_op_rope_f32(octx);
472            break;
473
474        default:
475            err = HTP_STATUS_NO_SUPPORT;
476            break;
477    }
478
479    return err;
480}