#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"

#include <assert.h>
#include <HAP_farf.h>
#include <HAP_perf.h>
#include <math.h>
#include <string.h>

#include "hex-dma.h"
#include "hvx-utils.h"

#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-msg.h"
#include "htp-ops.h"

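// Load two fp32 HVX vectors (64 values) and convert them into a single fp16 vector.
// The vsub against zero converts IEEE fp32 into qf32; the trailing vdeal restores
// linear element order after the narrowing qf32 -> fp16 convert.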
static inline HVX_Vector hvx_load_f32_to_f16(const HVX_Vector * restrict src, const HVX_Vector zero) {
    HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(src[0], zero); // 32 elements
    HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(src[1], zero); // 32 elements
    return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
}

// Dot product of FP32 and FP16 vectors, accumulating to float
static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
    const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
    const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16

    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
    uint32_t nloe = n % VLEN_FP16; // leftover elements

    const HVX_Vector zero = Q6_V_vsplat_R(0);
    HVX_Vector rsum = Q6_V_vsplat_R(0);

    uint32_t i = 0;

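    // Per iteration: widen the fp16 x*y products into a qf32 pair, fold both halves into
    // the running sum, then renormalize so the accumulator stays in IEEE fp32 (sf) format.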
    #pragma unroll(4)
    for (i = 0; i < nvec; i++) {
        // Load y (fp32) and convert into fp16
        HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);

        // Load x (fp16)
        HVX_Vector x_hf = vx[i];

        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);

        rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
    }

    if (nloe) {
        // Load y (fp32) and convert into fp16
        HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);

        // Load x (fp16)
        HVX_Vector x_hf = vx[i];

        // Zero-out unused elements
        // Note that we need to clear both x and y because they may contain NANs
        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
        x_hf = Q6_V_vand_QV(bmask, x_hf);
        y_hf = Q6_V_vand_QV(bmask, y_hf);

        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);

        rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
    }

    rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
    hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum));
}

// Dot product of one FP32 vector (y) against two FP16 vectors (x0, x1), accumulating to two floats
static inline void hvx_dot_f32_f16_aa_rx2(float * restrict r,
                                           const void * restrict y,
                                           const void * restrict x0,
                                           const void * restrict x1,
                                           unsigned int n,
                                           float s) {
    const HVX_Vector * restrict vy  = (const HVX_Vector * restrict) y;  // fp32
    const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16
    const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16

    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
    uint32_t nloe = n % VLEN_FP16; // leftover elements

    const HVX_Vector zero = Q6_V_vsplat_R(0);
    HVX_Vector rsum0 = Q6_V_vsplat_R(0);
    HVX_Vector rsum1 = Q6_V_vsplat_R(0);

    uint32_t i = 0;

    #pragma unroll(2)
    for (i = 0; i < nvec; i++) {
        // Load y (fp32) and convert into fp16
        HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
        // Load x (fp16)
        HVX_Vector x0_hf = vx0[i];
        HVX_Vector x1_hf = vx1[i];

        HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
        HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);

        rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
        rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
    }

    if (nloe) {
        // Load y (fp32) and convert into fp16
        HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);

        // Load x (fp16)
        HVX_Vector x0_hf = vx0[i];
        HVX_Vector x1_hf = vx1[i];

        // Zero-out unused elements
        // Note that we need to clear both x and y because they may contain NANs
        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
        x0_hf = Q6_V_vand_QV(bmask, x0_hf);
        x1_hf = Q6_V_vand_QV(bmask, x1_hf);
        y_hf  = Q6_V_vand_QV(bmask, y_hf);

        HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
        HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);

        rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
        rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
    }

    HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1));
    hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
}

// Dot product of two F16 vectors, accumulating to float
static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
    const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
    const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16

    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
    uint32_t nloe = n % VLEN_FP16; // leftover elements

    const HVX_Vector zero = Q6_V_vsplat_R(0);
    HVX_Vector rsum = Q6_V_vsplat_R(0);

    uint32_t i = 0;

    #pragma unroll(4)
    for (i = 0; i < nvec; i++) {
        HVX_Vector y_hf = vy[i];
        HVX_Vector x_hf = vx[i];

        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);

        rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
    }

    if (nloe) {
        HVX_Vector y_hf = vy[i];

        // Load x (fp16) and zero-out unused elements
        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
        HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);

        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);

        rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
    }

    rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
    hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum));
}

// Dot product of one F16 vector (y) against two F16 vectors (x0, x1), accumulating to two floats
static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
                                           const void * restrict y,
                                           const void * restrict x0,
                                           const void * restrict x1,
                                           unsigned int n,
                                           float s) {
    const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16
    const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
    const HVX_Vector * restrict vy  = (const HVX_Vector * restrict) y;  // fp16

    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
    uint32_t nloe = n % VLEN_FP16; // leftover elements

    const HVX_Vector zero = Q6_V_vsplat_R(0);
    HVX_Vector rsum0 = Q6_V_vsplat_R(0);
    HVX_Vector rsum1 = Q6_V_vsplat_R(0);

    uint32_t i = 0;

    #pragma unroll(4)
    for (i = 0; i < nvec; i++) {
        HVX_Vector y_hf  = vy[i];
        HVX_Vector x0_hf = vx0[i];
        HVX_Vector x1_hf = vx1[i];

        HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
        HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);

        rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
        rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
    }

    if (nloe) {
        HVX_Vector y_hf = vy[i];

        // Load x (fp16) and zero-out unused elements
        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
        HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
        HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);

        HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
        HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);

        rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
        rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
    }

    HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1));
    hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
}

// MAD: y (F32) += x (F16) * s (float)
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
    const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
    HVX_Vector * restrict ptr_y = (HVX_Vector *) y;

    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
    uint32_t nloe = n % VLEN_FP16; // leftover elements

    HVX_Vector S = hvx_vec_splat_f16(s);

    uint32_t i = 0;
    #pragma unroll(4)
    for (i = 0; i < nvec; ++i) {
        // Multiply x * s -> pair of F32 vectors
        HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
        ptr_y[i*2]   = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2]));
        ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1]));
    }

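    // Handle the leftover (< 64) elements: the low half of the qf32 pair covers the first
    // 32 products, the high half the next 32; the final partial vector is written with a
    // masked store so elements past n are left untouched.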
    if (nloe) {
        HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);

        HVX_Vector xs = Q6_V_lo_W(xs_p);
        i = 2 * i; // index for ptr_y

        if (nloe >= 32) {
            ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
            nloe -= 32; ++i; xs = Q6_V_hi_W(xs_p);
        }

        if (nloe) {
            HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
            hvx_vec_store_a(&ptr_y[i], nloe * 4, xy);
        }
    }
}

// Number of K/V rows processed per VTCM block (also the mask chunk size)
#define FLASH_ATTN_BLOCK_SIZE 128

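// Per-thread flash-attention worker. Each thread owns a contiguous slice of Q rows.
// For every Q row: the row is DMA'd into VTCM, then K/V (and mask) are streamed through
// VTCM in blocks of FLASH_ATTN_BLOCK_SIZE rows using double-buffered DMA. Scores are
// computed with the HVX dot-product helpers above, folded into an online softmax
// (running maximum M and running sum S), and the V rows are accumulated into the fp32
// VKQ32 scratch buffer, which is finally normalized by 1/S and written to dst.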
static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) {
    const struct htp_tensor * q = &octx->src0;
    const struct htp_tensor * k = &octx->src1;
    const struct htp_tensor * v = &octx->src2;
    const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
    const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;
    struct htp_tensor * dst = &octx->dst;

    const uint32_t neq0 = q->ne[0];
    const uint32_t neq1 = q->ne[1];
    const uint32_t neq2 = q->ne[2];
    const uint32_t neq3 = q->ne[3];

    const uint32_t nek0 = k->ne[0];
    const uint32_t nek1 = k->ne[1];
    const uint32_t nek2 = k->ne[2];
    const uint32_t nek3 = k->ne[3];

    const uint32_t nev0 = v->ne[0];
    const uint32_t nev1 = v->ne[1];
    const uint32_t nev2 = v->ne[2];
    const uint32_t nev3 = v->ne[3];

    const uint32_t nbq1 = q->nb[1];
    const uint32_t nbq2 = q->nb[2];
    const uint32_t nbq3 = q->nb[3];

    const uint32_t nbk1 = k->nb[1];
    const uint32_t nbk2 = k->nb[2];
    const uint32_t nbk3 = k->nb[3];

    const uint32_t nbv1 = v->nb[1];
    const uint32_t nbv2 = v->nb[2];
    const uint32_t nbv3 = v->nb[3];

    const uint32_t ne1 = dst->ne[1];
    const uint32_t ne2 = dst->ne[2];
    const uint32_t ne3 = dst->ne[3];

    const uint32_t nb1 = dst->nb[1];
    const uint32_t nb2 = dst->nb[2];
    const uint32_t nb3 = dst->nb[3];

    float scale         = 1.0f;
    float max_bias      = 0.0f;
    float logit_softcap = 0.0f;

    memcpy(&scale,         (float *) octx->op_params + 0, sizeof(float));
    memcpy(&max_bias,      (float *) octx->op_params + 1, sizeof(float));
    memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));

    if (logit_softcap != 0) {
        scale /= logit_softcap;
    }

    // total rows in q
    const uint32_t nr = neq1*neq2*neq3;

    const uint32_t dr  = (nr + nth - 1) / nth;
    const uint32_t ir0 = dr * ith;
    const uint32_t ir1 = MIN(ir0 + dr, nr);

    if (ir0 >= ir1) return;

    dma_queue * dma = octx->ctx->dma[ith];

    const uint32_t DK = nek0;
    const uint32_t DV = nev0;

    const size_t size_q_row        = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2);
    const size_t size_q_row_padded = hex_round_up(size_q_row, 128);

    const size_t size_k_row = DK * sizeof(__fp16);
    const size_t size_v_row = DV * sizeof(__fp16);
    const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask

    const size_t size_k_row_padded = hex_round_up(size_k_row, 128);
    const size_t size_v_row_padded = hex_round_up(size_v_row, 128);

    const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
    const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
    const size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);

    // Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator
    uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith;
    uint8_t * spad_k = octx->src1_spad.data + octx->src1_spad.size_per_thread * ith;
    uint8_t * spad_v = octx->src2_spad.data + octx->src2_spad.size_per_thread * ith;
    uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith;
    uint8_t * spad_a = octx->dst_spad.data  + octx->dst_spad.size_per_thread  * ith;

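    // ALiBi: precompute the per-head slope bases; heads below n_head_log2 use powers of m0,
    // the rest use powers of m1 (max_bias == 0.0f disables the bias and the slope stays 1.0f)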
    const uint32_t n_head      = neq2;
    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    for (uint32_t ir = ir0; ir < ir1; ++ir) {
        const uint32_t iq3 = fastdiv(ir, &octx->src0_div21);
        const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1);
        const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1);

        const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3);
        const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2);

        const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3);
        const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2);

        // Fetch Q row
        const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3);
        dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1);

        const uint32_t h = iq2; // head index
        const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;

        float S = 0.0f;      // sum
        float M = -INFINITY; // maximum KQ value

        // Clear accumulator
        hvx_splat_f32_a(spad_a, 0, DV);
        float * VKQ32 = (float *) spad_a;

        const __fp16 * mp_base = NULL;
        if (mask) {
            const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2);
            const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3);
            mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]);
        }

        const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;

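        // K/V (and mask) blocks are double-buffered in VTCM: while block ib is being
        // processed, block ib+1 is already in flight, and the DMA for block ib+2 is issued
        // at the end of each iteration. Pops below must match the push order (K, V, mask).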
        // Prefetch first two blocks
        for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) {
            const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
            const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);

            // K
            const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
            uint8_t * k_dst = spad_k + (ib % 2) * size_k_block;
            dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size);

            // V
            const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
            uint8_t * v_dst = spad_v + (ib % 2) * size_v_block;
            dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size);

            // Mask
            if (mask) {
                const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
                uint8_t * m_dst = spad_m + (ib % 2) * size_m_block;
                // Mask is 1D contiguous for this row
                dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
            }
        }

        const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;

        for (uint32_t ib = 0; ib < n_blocks; ++ib) {
            const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
            const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);

            // Wait for DMA
            uint8_t * k_base = dma_queue_pop(dma).dst; // K
            uint8_t * v_base = dma_queue_pop(dma).dst; // V
            __fp16 *  m_base = mask ? dma_queue_pop(dma).dst : NULL; // M

            // Inner loop processing the block from VTCM
            uint32_t ic = 0;

            const bool is_q_fp32 = (q->type == HTP_TYPE_F32);

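            // Scores for the block are computed 32 at a time (one fp32 HVX vector), two K
            // rows per dot-product call via the _rx2 helpers, and kept in scores_x4 so the
            // softmax pass below can reuse them without recomputing.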
            // Process in blocks of 32 (VLEN_FP32)
            static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
            HVX_Vector_x4 scores_x4;
            HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);
            for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
                // 1. Compute scores
                float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
                for (int j = 0; j < VLEN_FP32; j += 2) {
                    const uint32_t cur_ic = ic + j;
                    const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
                    if (is_q_fp32) {
                        hvx_dot_f32_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
                    } else {
                        hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
                    }
                }

                HVX_Vector scores = *(HVX_Vector *) scores_arr;

                // 2. Softcap
                if (logit_softcap != 0.0f) {
                    scores = hvx_vec_tanh_f32(scores);
                    scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_f32(logit_softcap));
                    scores = Q6_Vsf_equals_Vqf32(scores);
                }

                // 3. Mask: widen the fp16 mask values to fp32 (multiply by 1.0 == 0x3c00)
                // and add slope * mask to the scores
                if (mask) {
                    const __fp16 * mp = m_base + ic;
                    HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp;

                    HVX_Vector one_f16 = Q6_Vh_vsplat_R(0x3c00);
                    HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), one_f16);

                    HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair));

                    HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
                    HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_f32, slope_vec);
                    scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val));
                    scores = Q6_Vsf_equals_Vqf32(scores);
                }

                scores_x4.v[iv] = scores;
                v_max = Q6_Vsf_vmax_VsfVsf(scores, v_max);
            }

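            // Online softmax over this block: M tracks the running maximum of the scores,
            // the previously accumulated VKQ32 is rescaled by exp(M_old - M_new), and the
            // running normalizer S picks up the sum of exp(score - M_new) for the block.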
            {
                // 4. Online Softmax Update
                v_max = hvx_vec_reduce_max_f32(v_max);
                float m_block = hvx_vec_get_f32(v_max);
                float M_old = M;
                float M_new = (m_block > M) ? m_block : M;
                M = M_new;

                const float ms = expf(M_old - M_new);
                hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);

                HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new);
                HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
                for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {
                    HVX_Vector scores = scores_x4.v[iv];
                    HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
                    HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));

                    p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));

                    // 5. Accumulate V
                    float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
                    *(HVX_Vector *) p_arr = P;

                    for (int j = 0; j < VLEN_FP32; ++j) {
                        const uint32_t cur_ic = ic2 + j;
                        const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
                        hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
                    }
                }

                p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
                S = S * ms + hvx_vec_get_f32(p_sum_vec);
            }

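            // Scalar tail: columns past the last full group of 32 are handled one at a time
            // with the same online-softmax update.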
            // Leftover
            for (; ic < current_block_size; ++ic) {
                float s_val;
                const uint8_t * k_ptr = k_base + ic * size_k_row_padded;

                if (is_q_fp32) {
                    hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
                } else {
                    hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
                }

                if (logit_softcap != 0.0f) {
                    s_val = logit_softcap * tanhf(s_val);
                }

                if (mask) {
                    const float m_val = m_base[ic];
                    s_val += slope * m_val;
                }

                const float Mold = M;
                float ms = 1.0f;
                float vs = 1.0f;

                if (s_val > M) {
                    M = s_val;
                    ms = expf(Mold - M);
                    hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
                } else {
                    vs = expf(s_val - M);
                }

                const uint8_t * v_ptr = v_base + ic * size_v_row_padded;

                hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs);

                S = S * ms + vs;
            }

            // Issue DMA for next+1 block (if exists)
            if (ib + 2 < n_blocks) {
                const uint32_t next_ib = ib + 2;
                const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE;
                const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start);

                // K
                const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
                dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size);

                // V
                const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
                dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size);

                // Mask
                if (mask) {
                    const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start);
                    dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
                }
            }
        }

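        // Attention sinks: an extra per-head logit that participates in the softmax
        // normalizer (it can rescale VKQ32 and always contributes to S) but adds no V.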
        // sinks
        if (sinks) {
            const float s = ((float *)((char *) sinks->data))[h];

            float ms = 1.0f;
            float vs = 1.0f;

            if (s > M) {
                ms = expf(M - s);
                hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
            } else {
                vs = expf(s - M);
            }

            S = S * ms + vs;
        }

        const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
        hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, S_inv);

        // Store result
        // dst indices
        const int i1 = iq1;
        const int i2 = iq2;
        const int i3 = iq3;

        // dst is permuted
        uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1;

        if (dst->type == HTP_TYPE_F32) {
            hvx_copy_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
        } else if (dst->type == HTP_TYPE_F16) {
            hvx_copy_f16_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
        }
    }
}

static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) {
    struct htp_ops_context * octx = data;
    flash_attn_ext_f16_thread(octx, i, n);
}

int op_flash_attn_ext(struct htp_ops_context * octx) {
    const struct htp_tensor * q = &octx->src0;
    const struct htp_tensor * k = &octx->src1;
    const struct htp_tensor * v = &octx->src2;
    const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL;
    struct htp_tensor * dst = &octx->dst;

    // Check support
    if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) ||
        k->type != HTP_TYPE_F16 ||
        v->type != HTP_TYPE_F16) {
        return HTP_STATUS_NO_SUPPORT;
    }

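    // Precompute fastdiv/fastmodulo constants used in the per-row index math
    // (Q row decomposition, K/V head broadcast ratios, mask wrapping)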
    octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
    octx->src0_div1  = init_fastdiv_values(q->ne[1]);

    octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
    octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
    octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
    octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);

    if (mask) {
        octx->src3_div2 = init_fastdiv_values(mask->ne[2]);
        octx->src3_div3 = init_fastdiv_values(mask->ne[3]);
    }

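    // Per-thread VTCM scratchpad layout: one padded Q row, double-buffered K and V blocks,
    // a double-buffered mask block (when present), and the fp32 VKQ accumulator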
    size_t size_q_row_padded = hex_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128);
    size_t size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128);
    size_t size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128);

    size_t size_q_block = size_q_row_padded * 1; // single row for now
    size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
    size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
    size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);

    size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32

    octx->src0_spad.size_per_thread = size_q_block * 1;
    octx->src1_spad.size_per_thread = size_k_block * 2;
    octx->src2_spad.size_per_thread = size_v_block * 2;
    octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0;
    octx->dst_spad.size_per_thread  = size_vkq_acc;

    octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
    octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
    octx->src2_spad.size = octx->src2_spad.size_per_thread * octx->n_threads;
    octx->src3_spad.size = octx->src3_spad.size_per_thread * octx->n_threads;
    octx->dst_spad.size  = octx->dst_spad.size_per_thread  * octx->n_threads;

    size_t total_spad = octx->src0_spad.size + octx->src1_spad.size + octx->src2_spad.size + octx->src3_spad.size + octx->dst_spad.size;

    if (octx->ctx->vtcm_size < total_spad) {
        return HTP_STATUS_VTCM_TOO_SMALL;
    }

    octx->src0_spad.data = octx->ctx->vtcm_base;
    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
    octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
    octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size;
    octx->dst_spad.data  = octx->src3_spad.data + octx->src3_spad.size;

    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
        worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads);
    }

    return HTP_STATUS_OK;
}