1#pragma clang diagnostic ignored "-Wunused-variable"
  2#pragma clang diagnostic ignored "-Wunused-function"
  3#pragma clang diagnostic ignored "-Wunused-but-set-variable"
  4
  5#include <assert.h>
  6#include <HAP_farf.h>
  7#include <HAP_perf.h>
  8#include <math.h>
  9#include <string.h>
 10
 11#include "hex-dma.h"
 12#include "hvx-utils.h"
 13
 14#define GGML_COMMON_DECL_C
 15#include "ggml-common.h"
 16#include "htp-ctx.h"
 17#include "htp-msg.h"
 18#include "htp-ops.h"
 19
 20static inline HVX_Vector hvx_load_f32_to_f16(const HVX_Vector * restrict src, const HVX_Vector zero) {
 21    HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(src[0], zero);  // 32 elements
 22    HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(src[1], zero);  // 32 elements
 23    return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
 24}
 25
 26// Dot product of FP32 and FP16 vectors, accumulating to float
 27static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
 28    const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
 29    const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
 30
 31    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
 32    uint32_t nloe = n % VLEN_FP16; // leftover elements
 33
 34    const HVX_Vector zero = Q6_V_vsplat_R(0);
 35    HVX_Vector       rsum = Q6_V_vsplat_R(0);
 36
 37    uint32_t i = 0;
 38
 39    #pragma unroll(4)
 40    for (i = 0; i < nvec; i++) {
 41        // Load y (fp32) and convert into fp16
 42        HVX_Vector y_hf  = hvx_load_f32_to_f16(&vy[i*2], zero);
 43
 44        // Load x (fp16)
 45        HVX_Vector x_hf  = vx[i];
 46
 47        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
 48
 49        rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
 50    }
 51
 52    if (nloe) {
 53        // Load y (fp32) and convert into fp16
 54        HVX_Vector y_hf  = hvx_load_f32_to_f16(&vy[i*2], zero);
 55
 56        // Load x (fp16)
 57        HVX_Vector x_hf  = vx[i];
 58
 59        // Zero-out unused elements
 60        // Note that we need to clear both x and y because they may contain NANs
 61        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
 62        x_hf = Q6_V_vand_QV(bmask, x_hf);
 63        y_hf = Q6_V_vand_QV(bmask, y_hf);
 64
 65        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
 66
 67        rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
 68    }
 69
 70    rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
 71    hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum));
 72}
 73
 74// Dot product of FP32 and FP16 vectors, accumulating to float
 75static inline void hvx_dot_f32_f16_aa_rx2(float * restrict r,
 76                                          const void * restrict y,
 77                                          const void * restrict x0,
 78                                          const void * restrict x1,
 79                                          unsigned int n,
 80                                          float        s) {
 81    const HVX_Vector * restrict vy  = (const HVX_Vector * restrict) y;   // fp32
 82    const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0;  // fp16
 83    const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1;  // fp16
 84
 85    uint32_t nvec = n / VLEN_FP16;                                       // num full fp16 hvx vectors
 86    uint32_t nloe = n % VLEN_FP16;                                       // leftover elements
 87
 88    const HVX_Vector zero  = Q6_V_vsplat_R(0);
 89    HVX_Vector       rsum0 = Q6_V_vsplat_R(0);
 90    HVX_Vector       rsum1 = Q6_V_vsplat_R(0);
 91
 92    uint32_t i = 0;
 93
 94    #pragma unroll(2)
 95    for (i = 0; i < nvec; i++) {
 96        // Load y (fp32) and convert into fp16
 97        HVX_Vector y_hf  = hvx_load_f32_to_f16(&vy[i*2], zero);
 98        // Load x (fp16)
 99        HVX_Vector x0_hf = vx0[i];
100        HVX_Vector x1_hf = vx1[i];
101
102        HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
103        HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
104
105        rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
106        rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
107    }
108
109    if (nloe) {
110        // Load y (fp32) and convert into fp16
111        HVX_Vector y_hf  = hvx_load_f32_to_f16(&vy[i*2], zero);
112
113        // Load x (fp16)
114        HVX_Vector x0_hf = vx0[i];
115        HVX_Vector x1_hf = vx1[i];
116
117        // Zero-out unused elements
118        // Note that we need to clear both x and y because they may contain NANs
119        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
120        x0_hf                = Q6_V_vand_QV(bmask, x0_hf);
121        x1_hf                = Q6_V_vand_QV(bmask, x1_hf);
122        y_hf                 = Q6_V_vand_QV(bmask, y_hf);
123
124        HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
125        HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
126
127        rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
128        rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
129    }
130
131    HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1));
132    hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
133}
134
135// Dot product of two F16 vectors, accumulating to float
136static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
137    const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
138    const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
139
140    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
141    uint32_t nloe = n % VLEN_FP16; // leftover elements
142
143    const HVX_Vector zero = Q6_V_vsplat_R(0);
144    HVX_Vector       rsum = Q6_V_vsplat_R(0);
145
146    uint32_t i = 0;
147
148    #pragma unroll(4)
149    for (i = 0; i < nvec; i++) {
150        HVX_Vector y_hf = vy[i];
151        HVX_Vector x_hf = vx[i];
152
153        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
154
155        rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
156    }
157
158    if (nloe) {
159        HVX_Vector y_hf = vy[i];
160
161        // Load x (fp16) and zero-out unused elements
162        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
163        HVX_Vector      x_hf = Q6_V_vand_QV(bmask, vx[i]);
164
165        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
166
167        rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
168    }
169
170    rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
171    hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum));
172}
173
174static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
175                                          const void * restrict y,
176                                          const void * restrict x0,
177                                          const void * restrict x1,
178                                          unsigned int n,
179                                          float        s) {
180    const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0;  // fp16
181    const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1;  // fp16
182    const HVX_Vector * restrict vy  = (const HVX_Vector * restrict) y;   // fp16
183
184    uint32_t nvec = n / VLEN_FP16;                                       // num full fp16 hvx vectors
185    uint32_t nloe = n % VLEN_FP16;                                       // leftover elements
186
187    const HVX_Vector zero  = Q6_V_vsplat_R(0);
188    HVX_Vector       rsum0 = Q6_V_vsplat_R(0);
189    HVX_Vector       rsum1 = Q6_V_vsplat_R(0);
190
191    uint32_t i = 0;
192
193    #pragma unroll(4)
194    for (i = 0; i < nvec; i++) {
195        HVX_Vector y_hf  = vy[i];
196        HVX_Vector x0_hf = vx0[i];
197        HVX_Vector x1_hf = vx1[i];
198
199        HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
200        HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
201
202        rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
203        rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
204    }
205
206    if (nloe) {
207        HVX_Vector y_hf = vy[i];
208
209        // Load x (fp16) and zero-out unused elements
210        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
211        HVX_Vector     x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
212        HVX_Vector     x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
213
214        HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
215        HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
216
217        rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
218        rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
219    }
220
221    HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1));
222    hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
223}
224
225// MAD: y (F32) += x (F16) * s (float)
226static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
227    const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
228    HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
229
230    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
231    uint32_t nloe = n % VLEN_FP16; // leftover elements
232
233    HVX_Vector S = hvx_vec_splat_f16(s);
234
235    uint32_t i = 0;
236    #pragma unroll(4)
237    for (i = 0; i < nvec; ++i) {
238        // Multiply x * s -> pair of F32 vectors
239        HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
240        ptr_y[i*2]   = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2]));
241        ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1]));
242    }
243
244    if (nloe) {
245        HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
246
247        HVX_Vector xs = Q6_V_lo_W(xs_p);
248        i = 2 * i; // index for ptr_y
249
250        if (nloe >= 32) {
251            ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
252            nloe -= 32; ++i; xs = Q6_V_hi_W(xs_p);
253        }
254
255        if (nloe) {
256            HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
257            hvx_vec_store_a(&ptr_y[i], nloe * 4, xy);
258        }
259    }
260}
261
262#define FLASH_ATTN_BLOCK_SIZE 128
263
264static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) {
265    const struct htp_tensor * q = &octx->src0;
266    const struct htp_tensor * k = &octx->src1;
267    const struct htp_tensor * v = &octx->src2;
268    const struct htp_tensor * mask  = (octx->src3.data) ? &octx->src3 : NULL;
269    const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;
270    struct htp_tensor * dst = &octx->dst;
271
272    const uint32_t neq0 = q->ne[0];
273    const uint32_t neq1 = q->ne[1];
274    const uint32_t neq2 = q->ne[2];
275    const uint32_t neq3 = q->ne[3];
276
277    const uint32_t nek0 = k->ne[0];
278    const uint32_t nek1 = k->ne[1];
279    const uint32_t nek2 = k->ne[2];
280    const uint32_t nek3 = k->ne[3];
281
282    const uint32_t nev0 = v->ne[0];
283    const uint32_t nev1 = v->ne[1];
284    const uint32_t nev2 = v->ne[2];
285    const uint32_t nev3 = v->ne[3];
286
287    const uint32_t nbq1 = q->nb[1];
288    const uint32_t nbq2 = q->nb[2];
289    const uint32_t nbq3 = q->nb[3];
290
291    const uint32_t nbk1 = k->nb[1];
292    const uint32_t nbk2 = k->nb[2];
293    const uint32_t nbk3 = k->nb[3];
294
295    const uint32_t nbv1 = v->nb[1];
296    const uint32_t nbv2 = v->nb[2];
297    const uint32_t nbv3 = v->nb[3];
298
299    const uint32_t ne1 = dst->ne[1];
300    const uint32_t ne2 = dst->ne[2];
301    const uint32_t ne3 = dst->ne[3];
302
303    const uint32_t nb1 = dst->nb[1];
304    const uint32_t nb2 = dst->nb[2];
305    const uint32_t nb3 = dst->nb[3];
306
307    float scale         = 1.0f;
308    float max_bias      = 0.0f;
309    float logit_softcap = 0.0f;
310
311    memcpy(&scale,         (float *) octx->op_params + 0, sizeof(float));
312    memcpy(&max_bias,      (float *) octx->op_params + 1, sizeof(float));
313    memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
314
315    if (logit_softcap != 0) {
316        scale /= logit_softcap;
317    }
318
319    // total rows in q
320    const uint32_t nr = neq1*neq2*neq3;
321
322    const uint32_t dr = (nr + nth - 1) / nth;
323    const uint32_t ir0 = dr * ith;
324    const uint32_t ir1 = MIN(ir0 + dr, nr);
325
326    if (ir0 >= ir1) return;
327
328    dma_queue * dma = octx->ctx->dma[ith];
329
330    const uint32_t DK = nek0;
331    const uint32_t DV = nev0;
332
333    const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2);
334    const size_t size_q_row_padded = hex_round_up(size_q_row, 128);
335
336    const size_t size_k_row = DK * sizeof(__fp16);
337    const size_t size_v_row = DV * sizeof(__fp16);
338    const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask
339
340    const size_t size_k_row_padded = hex_round_up(size_k_row, 128);
341    const size_t size_v_row_padded = hex_round_up(size_v_row, 128);
342
343    const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
344    const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
345    const size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
346
347    // Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator
348    uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith;
349    uint8_t * spad_k = octx->src1_spad.data + octx->src1_spad.size_per_thread * ith;
350    uint8_t * spad_v = octx->src2_spad.data + octx->src2_spad.size_per_thread * ith;
351    uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith;
352    uint8_t * spad_a = octx->dst_spad.data  + octx->dst_spad.size_per_thread  * ith;
353
354    const uint32_t n_head = neq2;
355    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
356    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
357    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
358
359    for (uint32_t ir = ir0; ir < ir1; ++ir) {
360        const uint32_t iq3 = fastdiv(ir, &octx->src0_div21);
361        const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1);
362        const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1);
363
364        const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3);
365        const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2);
366
367        const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3);
368        const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2);
369
370        // Fetch Q row
371        const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3);
372        dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1);
373
374        const uint32_t h = iq2; // head index
375        const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;
376
377        float S = 0.0f;      // sum
378        float M = -INFINITY; // maximum KQ value
379
380        // Clear accumulator
381        hvx_splat_f32_a(spad_a, 0, DV);
382        float * VKQ32 = (float *) spad_a;
383
384        const __fp16 * mp_base = NULL;
385        if (mask) {
386            const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2);
387            const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3);
388            mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]);
389        }
390
391        const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
392
393        // Prefetch first two blocks
394        for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) {
395            const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
396            const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
397
398            // K
399            const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
400            uint8_t * k_dst = spad_k + (ib % 2) * size_k_block;
401            dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size);
402
403            // V
404            const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
405            uint8_t * v_dst = spad_v + (ib % 2) * size_v_block;
406            dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size);
407
408            // Mask
409            if (mask) {
410                const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
411                uint8_t * m_dst = spad_m + (ib % 2) * size_m_block;
412                // Mask is 1D contiguous for this row
413                dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
414            }
415        }
416
417        const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
418
419        for (uint32_t ib = 0; ib < n_blocks; ++ib) {
420            const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
421            const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
422
423            // Wait for DMA
424            uint8_t * k_base = dma_queue_pop(dma).dst; // K
425            uint8_t * v_base = dma_queue_pop(dma).dst; // V
426            __fp16  * m_base = mask ? dma_queue_pop(dma).dst : NULL; // M
427
428            // Inner loop processing the block from VTCM
429            uint32_t ic = 0;
430
431            const bool is_q_fp32 = (q->type == HTP_TYPE_F32);
432
433            // Process in blocks of 32 (VLEN_FP32)
434            static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
435            HVX_Vector_x4 scores_x4;
436            HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);
437            for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
438                // 1. Compute scores
439                float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
440                for (int j = 0; j < VLEN_FP32; j += 2) {
441                    const uint32_t cur_ic = ic + j;
442                    const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
443                    if (is_q_fp32) {
444                        hvx_dot_f32_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
445                    } else {
446                        hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
447                    }
448                }
449
450                HVX_Vector scores = *(HVX_Vector *) scores_arr;
451
452                // 2. Softcap
453                if (logit_softcap != 0.0f) {
454                    scores = hvx_vec_tanh_f32(scores);
455                    scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_f32(logit_softcap));
456                    scores = Q6_Vsf_equals_Vqf32(scores);
457                }
458
459                // 3. Mask
460                if (mask) {
461                    const __fp16 * mp = m_base + ic;
462                    HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp;
463
464                    HVX_Vector one_f16 = Q6_Vh_vsplat_R(0x3c00);
465                    HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), one_f16);
466
467                    HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair));
468
469                    HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
470                    HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_f32, slope_vec);
471                    scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val));
472                    scores = Q6_Vsf_equals_Vqf32(scores);
473                }
474
475                scores_x4.v[iv] = scores;
476                v_max = Q6_Vsf_vmax_VsfVsf(scores, v_max);
477            }
478
479            {
480                // 4. Online Softmax Update
481                v_max = hvx_vec_reduce_max_f32(v_max);
482                float m_block = hvx_vec_get_f32(v_max);
483                float M_old = M;
484                float M_new = (m_block > M) ? m_block : M;
485                M = M_new;
486
487                const float ms = expf(M_old - M_new);
488                hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
489
490                HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new);
491                HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
492                for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {
493                    HVX_Vector scores = scores_x4.v[iv];
494                    HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
495                    HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
496
497                    p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));
498
499                    // 5. Accumulate V
500                    float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
501                    *(HVX_Vector*)p_arr = P;
502
503                    for (int j = 0; j < VLEN_FP32; ++j) {
504                        const uint32_t cur_ic = ic2 + j;
505                        const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
506                        hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
507                    }
508                }
509
510                p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
511                S = S * ms + hvx_vec_get_f32(p_sum_vec);
512            }
513
514            // Leftover
515            for (; ic < current_block_size; ++ic) {
516                float s_val;
517                const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
518
519                if (is_q_fp32) {
520                    hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
521                } else {
522                    hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
523                }
524
525                if (logit_softcap != 0.0f) {
526                    s_val = logit_softcap * tanhf(s_val);
527                }
528
529                if (mask) {
530                    const float m_val = m_base[ic];
531                    s_val += slope * m_val;
532                }
533
534                const float Mold = M;
535                float ms = 1.0f;
536                float vs = 1.0f;
537
538                if (s_val > M) {
539                    M = s_val;
540                    ms = expf(Mold - M);
541                    hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
542                } else {
543                    vs = expf(s_val - M);
544                }
545
546                const uint8_t * v_ptr = v_base + ic * size_v_row_padded;
547
548                hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs);
549
550                S = S * ms + vs;
551            }
552
553            // Issue DMA for next+1 block (if exists)
554            if (ib + 2 < n_blocks) {
555                const uint32_t next_ib = ib + 2;
556                const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE;
557                const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start);
558
559                // K
560                const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
561                dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size);
562
563                // V
564                const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
565                dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size);
566
567                // Mask
568                if (mask) {
569                    const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start);
570                    dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
571                }
572            }
573        }
574
575        // sinks
576        if (sinks) {
577            const float s = ((float *)((char *) sinks->data))[h];
578
579            float ms = 1.0f;
580            float vs = 1.0f;
581
582            if (s > M) {
583                ms = expf(M - s);
584                hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
585            } else {
586                vs = expf(s - M);
587            }
588
589            S = S * ms + vs;
590        }
591
592        const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
593        hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, S_inv);
594
595        // Store result
596        // dst indices
597        const int i1 = iq1;
598        const int i2 = iq2;
599        const int i3 = iq3;
600
601        // dst is permuted
602        uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1;
603
604        if (dst->type == HTP_TYPE_F32) {
605            hvx_copy_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
606        } else if (dst->type == HTP_TYPE_F16) {
607            hvx_copy_f16_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
608        }
609    }
610}
611
// worker_pool callback: n is the number of workers, i is this worker's index and data is the
// htp_ops_context passed to worker_pool_run_func. Forwards to the per-thread FA kernel.
static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) {
    flash_attn_ext_f16_thread((struct htp_ops_context *) data, (int) i, (int) n);
}
616
617int op_flash_attn_ext(struct htp_ops_context * octx) {
618    const struct htp_tensor * q = &octx->src0;
619    const struct htp_tensor * k = &octx->src1;
620    const struct htp_tensor * v = &octx->src2;
621    const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL;
622    struct htp_tensor * dst = &octx->dst;
623
624    // Check support
625    if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) ||
626        k->type != HTP_TYPE_F16 ||
627        v->type != HTP_TYPE_F16) {
628        return HTP_STATUS_NO_SUPPORT;
629    }
630
631    octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
632    octx->src0_div1  = init_fastdiv_values(q->ne[1]);
633
634    octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
635    octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
636    octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
637    octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
638
639    if (mask) {
640        octx->src3_div2 = init_fastdiv_values(mask->ne[2]);
641        octx->src3_div3 = init_fastdiv_values(mask->ne[3]);
642    }
643
644    size_t size_q_row_padded = hex_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128);
645    size_t size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128);
646    size_t size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128);
647
648    size_t size_q_block = size_q_row_padded * 1; // single row for now
649    size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
650    size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
651    size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
652
653    size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
654
655    octx->src0_spad.size_per_thread = size_q_block * 1;
656    octx->src1_spad.size_per_thread = size_k_block * 2;
657    octx->src2_spad.size_per_thread = size_v_block * 2;
658    octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0;
659    octx->dst_spad.size_per_thread  = size_vkq_acc;
660
661    octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
662    octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
663    octx->src2_spad.size = octx->src2_spad.size_per_thread * octx->n_threads;
664    octx->src3_spad.size = octx->src3_spad.size_per_thread * octx->n_threads;
665    octx->dst_spad.size  = octx->dst_spad.size_per_thread  * octx->n_threads;
666
667    size_t total_spad = octx->src0_spad.size + octx->src1_spad.size + octx->src2_spad.size + octx->src3_spad.size + octx->dst_spad.size;
668
669    if (octx->ctx->vtcm_size < total_spad) {
670        return HTP_STATUS_VTCM_TOO_SMALL;
671    }
672
673    octx->src0_spad.data = octx->ctx->vtcm_base;
674    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
675    octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
676    octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size;
677    octx->dst_spad.data  = octx->src3_spad.data + octx->src3_spad.size;
678
679    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
680        worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads);
681    }
682
683    return HTP_STATUS_OK;
684}