1#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
2#pragma clang diagnostic ignored "-Wunused-function"
3#pragma clang diagnostic ignored "-Wunused-variable"
4#pragma clang diagnostic ignored "-Wunused-but-set-variable"
5
6#include <HAP_farf.h>
7#include <HAP_perf.h>
8
9#include <math.h>
10#include <string.h>
11
12#include "hex-dma.h"
13#include "hvx-utils.h"
14#include "hvx-dump.h"
15
16#define GGML_COMMON_DECL_C
17#include "ggml-common.h"
18#include "htp-ctx.h"
19#include "htp-msg.h"
20#include "htp-ops.h"
21
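// Scratchpad (spad) capacities, in rows, for the src0 / src1 / dst staging buffers
// (interpretation inferred from the SPAD/NROWS naming convention).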
22#define MM_SPAD_SRC0_NROWS 16
23#define MM_SPAD_SRC1_NROWS 16
24#define MM_SPAD_DST_NROWS 2
25
26struct htp_matmul_context {
27 const char * type;
28 struct htp_ops_context * octx;
29
30 void (*vec_dot_1x1)(const int n, float * restrict s0,
31 const void * restrict vx0,
32 const void * restrict vy0);
33
34 void (*vec_dot_2x1)(const int n, float * restrict s0,
35 const void * restrict vx0, const void * restrict vx1,
36 const void * restrict vy0);
37
38 void (*vec_dot_2x2)(const int n, float * restrict s0, float * restrict s1,
39 const void * restrict vx0, const void * restrict vx1,
40 const void * restrict vy0, const void * restrict vy1);
41
42 // Precomputed values
43 uint32_t src0_nrows_per_thread;
44 uint32_t src1_nrows_per_thread;
45
46 struct fastdiv_values mm_div_ne12_ne1;
47 struct fastdiv_values mm_div_ne1;
48 struct fastdiv_values mm_div_r2;
49 struct fastdiv_values mm_div_r3;
50};
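
// Result layout produced by the vec_dot kernels below (see the *_2x1 / *_2x2 store code):
//   vec_dot_1x1: s0[0] = dot(vx0, vy0)
//   vec_dot_2x1: s0[0] = dot(vx0, vy0), s0[1] = dot(vx1, vy0)
//   vec_dot_2x2: s0[0] = dot(vx0, vy0), s0[1] = dot(vx1, vy0)
//                s1[0] = dot(vx0, vy1), s1[1] = dot(vx1, vy1)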
51
52// vdelta control to replicate first 4x fp32 values across lanes
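// e.g. an input with fp32 lanes { a, b, c, d, ... } becomes { a, b, c, d, a, b, c, d, ... } (8 copies)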
53static const uint8_t __attribute__((aligned(128))) repl_4x_f32[128] = {
54 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
55 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
56 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04,
57 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
58 0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04,
59 0x04, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
60 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10,
61};
62
63// vdelta control to replicate and interleave first 8x fp32 values across lanes
64static const uint8_t __attribute__((aligned(128))) repl_interleave_8x_f32[128] = {
65 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x00, 0x00, 0x00,
66 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
67 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04,
68 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
69 0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x44, 0x44, 0x44,
70 0x44, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
71 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20,
72};
73
74// vdelta control to replicate first fp32 value across all elements
75static const uint8_t __attribute__((aligned(128))) repl_1x_f32[128] = {
76 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
77 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
78 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08,
79 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08,
80 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04,
81 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10,
82 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
83};
84
85// vdelta control to replicate first fp16 value across all elements
86static const uint8_t __attribute__((aligned(128))) repl_1x_f16[128] = {
87 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02,
88 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04,
89 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08,
90 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02,
91 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02,
92 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10,
93 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
94};
95
// vdelta control to replicate first 2x fp16 values across all elements
97static const uint8_t __attribute__((aligned(128))) repl_2x_f16[128] = {
98 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
99 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
100 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
101 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
102 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
103 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
104 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
105 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
106};
107
108// vdelta control to expand first 32 e8m0 values into 32 uint32 elements
109static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = {
110 0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, 0x00,
111 0x00, 0x11, 0x10, 0x10, 0x10, 0x02, 0x00, 0x04, 0x00, 0x01, 0x02, 0x08, 0x08, 0x08, 0x08, 0x00, 0x00, 0x01, 0x04,
112 0x00, 0x00, 0x22, 0x20, 0x20, 0x20, 0x21, 0x22, 0x20, 0x24, 0x04, 0x00, 0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x02,
113 0x00, 0x04, 0x00, 0x11, 0x12, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08,
114 0x01, 0x02, 0x00, 0x04, 0x44, 0x40, 0x40, 0x40, 0x41, 0x40, 0x40, 0x40, 0x42, 0x40, 0x44, 0x40, 0x41, 0x42, 0x48,
115 0x48, 0x08, 0x08, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x12, 0x10, 0x10, 0x10, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00,
116 0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x22, 0x20, 0x24, 0x20, 0x21, 0x22, 0x20, 0x20,
117};
118
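// LUT for mxfp4 (e2m1) codes as signed int8: codes 0..7 -> { 0, 1, 2, 3, 4, 6, 8, 12 },
// codes 8..15 -> { 0, -1, -2, -3, -4, -6, -8, -12 } (0xff = -1, etc.), stored at even byte
// offsets with zero padding in between. Note: these are 2x the nominal e2m1 values
// (0, 0.5, 1, 1.5, 2, 3, 4, 6), which the 0.5 factor applied to the scales in the mxfp4
// vec_dot kernels below appears to compensate for.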
119static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = {
120 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 6, 0, 8, 0, 12, 0, 0, 0, 0xff, 0, 0xfe, 0, 0xfd, 0, 0xfc, 0,
121 0xfa, 0, 0xf8, 0, 0xf4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
122 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
123 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
124 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
125};
126
127// q4x4x2 and q8x4x2 are the flat q4/8_0 formats where all quants are stored first followed by all scales
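//
// Per-row layout (n = elements per row), as consumed by the vec_dot kernels below:
//   q4x4x2: [ n/2 bytes of packed int4 quants ][ n/32 fp16 scales ]  (row padded to a 128-byte multiple)
//   q8x4x2: [ n   bytes of int8 quants        ][ n/32 fp16 scales ]  (row padded to a 128-byte multiple)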
128
129static inline size_t q8x4x2_row_size(uint32_t ne) {
130 // ensures perfect alignment of quants and full row
131 const uint32_t qk = QK_Q8_0x4x2;
132 const uint32_t nb = (ne + qk - 1) / qk;
133 return hex_round_up(ne + nb * 8 * sizeof(__fp16), 128);
134}
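// e.g. ne = 2048: nb = 8, row size = hex_round_up(2048 + 8 * 8 * sizeof(__fp16), 128) = 2176 bytes
// (quants at offset 0, scales at offset 2048).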
135
136static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) {
137 const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
138
139 HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes)
140 HVX_Vector v2_3 = vptr[1]; // ...
141 HVX_Vector v4_5 = vptr[2]; // ...
142 HVX_Vector v6_7 = vptr[3]; // ...
143
144 const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
145 const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
146
147 HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F
148 HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4
149 HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F
150 HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4
151 HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F
152 HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4
153 HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F
154 HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4
155
156 // Convert uint4 to int4 (i.e. x - 8)
157 v0 = Q6_Vb_vsub_VbVb(v0, i8);
158 v1 = Q6_Vb_vsub_VbVb(v1, i8);
159 v2 = Q6_Vb_vsub_VbVb(v2, i8);
160 v3 = Q6_Vb_vsub_VbVb(v3, i8);
161 v4 = Q6_Vb_vsub_VbVb(v4, i8);
162 v5 = Q6_Vb_vsub_VbVb(v5, i8);
163 v6 = Q6_Vb_vsub_VbVb(v6, i8);
164 v7 = Q6_Vb_vsub_VbVb(v7, i8);
165
166 HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
167 return r;
168}
169
170static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) {
171 const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
172
173 HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes)
174 HVX_Vector v2_3 = vptr[1]; // ...
175 HVX_Vector v4_5 = vptr[2]; // ...
176 HVX_Vector v6_7 = vptr[3]; // ...
177
178 const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
179 const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut;
180
181 HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F
182 HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4
183 HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F
184 HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4
185 HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F
186 HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4
187 HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F
188 HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4
189
190 v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
191 v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
192 v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0);
193 v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0);
194 v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0);
195 v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0);
196 v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0);
197 v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0);
198
199 HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
200 return r;
201}
202
203static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) {
204 const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
205
206 HVX_Vector v0 = vptr[0]; // first 128 vals
207 HVX_Vector v1 = vptr[1]; // ...
208 HVX_Vector v2 = vptr[2]; // ...
209 HVX_Vector v3 = vptr[3]; // ...
210 HVX_Vector v4 = vptr[4]; // ...
211 HVX_Vector v5 = vptr[5]; // ...
212 HVX_Vector v6 = vptr[6]; // ...
213 HVX_Vector v7 = vptr[7]; // ...
214
215 HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
216 return r;
217}
218
// Reduce-multiply two 1024-element int8 vectors (32x q4/8 blocks held in 8x HVX vectors each).
// Accumulate each 32-element block into a single int32 value.
// Return a single HVX vector with 32x int32 accumulators.
// This version is parameterized to support fewer than 1024 elements.
// The if() checks are optimized out at compile time -- make sure to pass n as a compile-time constant.
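//
// The 8x vrmpy ops below produce 256 int32 partial sums (one per 4 elements); three vdeal+add
// passes then fold them into the 32 per-block totals (8 -> 4 -> 2 -> 1 vectors).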
224
225static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
226 HVX_Vector r0 = Q6_V_vsplat_R(0);
227 HVX_Vector r1 = Q6_V_vsplat_R(0);
228 HVX_Vector r2 = Q6_V_vsplat_R(0);
229 HVX_Vector r3 = Q6_V_vsplat_R(0);
230 HVX_Vector r4 = Q6_V_vsplat_R(0);
231 HVX_Vector r5 = Q6_V_vsplat_R(0);
232 HVX_Vector r6 = Q6_V_vsplat_R(0);
233 HVX_Vector r7 = Q6_V_vsplat_R(0);
234
235 HVX_VectorPair p3;
236 HVX_VectorPair p2;
237 HVX_VectorPair p1;
238 HVX_VectorPair p0;
239
240 if (n >= 128) { r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]); }
241 if (n >= 256) { r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]); }
242 if (n >= 384) { r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]); }
243 if (n >= 512) { r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]); }
244 if (n >= 640) { r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]); }
245 if (n >= 768) { r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]); }
246 if (n >= 896) { r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]); }
247 if (n >= 1024) { r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]); }
248
249 if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }
250 if (n >= 384) { p1 = Q6_W_vdeal_VVR(r3, r2, -4); }
251 if (n >= 640) { p2 = Q6_W_vdeal_VVR(r5, r4, -4); }
252 if (n >= 896) { p3 = Q6_W_vdeal_VVR(r7, r6, -4); }
253
254 if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }
255 if (n >= 384) { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); }
256 if (n >= 640) { r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2)); }
257 if (n >= 896) { r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3)); }
258
259 if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }
260 if (n >= 640) { p1 = Q6_W_vdeal_VVR(r3, r2, -4); }
261
262 if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }
263 if (n >= 640) { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); }
264
265 if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }
266 if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }
267
268 return r0;
269}
270
271static inline HVX_Vector hvx_vec_rmpy_x8_full(HVX_Vector_x8 x, HVX_Vector_x8 y) {
272 return hvx_vec_rmpy_x8_n(x, y, 1024);
273}
274
// Handle the most common cases where the element count is not a multiple of 1024.
276static inline HVX_Vector hvx_vec_rmpy_x8_nloe(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
277 if (n <= 256) { return hvx_vec_rmpy_x8_n(x, y, 256); };
278 if (n <= 512) { return hvx_vec_rmpy_x8_n(x, y, 512); };
279 if (n <= 768) { return hvx_vec_rmpy_x8_n(x, y, 768); };
280 return hvx_vec_rmpy_x8_n(x, y, 1024);
281}
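// e.g. nloe = 320 takes the 512-element path; the block sums beyond nloe may contain garbage and
// are zeroed by the caller (see the bmask logic in the vec_dot leftover paths).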
282
283static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
284 assert(n % 32 == 0); // min sub-block size
285 assert((unsigned long) vx0 % 128 == 0);
286 assert((unsigned long) vy0 % 128 == 0);
287
288 const uint32_t qk = QK_Q4_0x4x2 * 4;
289
290 const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
291 const uint32_t x_qblk_size = qk / 2; // int4
292 const uint32_t x_qrow_size = n / 2; // int4 (not padded)
293
294 const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
295 const uint32_t y_qblk_size = qk; // int8
296 const uint32_t y_qrow_size = n; // int8 (not padded)
297
298 const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
299 const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
300
301 const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
302 const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
303
304 // Row sum (sf)
305 HVX_Vector r0_sum = Q6_V_vsplat_R(0);
306
307 // Multiply and accumulate into int32.
308 // Compute combined scale (fp32).
309 // Apply scale to acc and accumulate into the row sum (qf32).
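    //
    // Scalar reference for one 32-element sub-block b (illustration only):
    //   acc_b  = sum_{k=0..31} x_q[32*b + k] * y_q[32*b + k];   // int32
    //   sum   += (float) acc_b * (float) (x_d[b] * y_d[b]);     // combined fp16 scales -> fp32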
310
311 const uint32_t nb = n / qk; // num full blocks
    const uint32_t nloe = n % qk;  // num leftover elements
313
314 uint32_t i = 0;
315 for (; i < nb; i++) {
316 HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
317 HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
318
319 HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
320
321 HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
322 HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
323
324 HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
325
326 HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
327
328 r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
329 }
330
    // Process leftovers; we still load a full 4x4x2 block but zero out unused scales/blocks
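    // e.g. n = 1056: nb = 1, nloe = 32 -> bmask = Q6_Q_vsetq_R(4) keeps a single int32 accumulator / fp32 scale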
332 if (nloe) {
333 HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
334 HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
335
336 HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
337
338 HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
339 HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
340
341 HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
342
343 // Zero out unused scales
344 HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
345 r0_dd = Q6_V_vand_QV(bmask, r0_dd);
346 r0_ia = Q6_V_vand_QV(bmask, r0_ia);
347
348 HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
349
350 r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
351 }
352
353 r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
354
355 hvx_vec_store_u(s0, 4, r0_sum);
356}
357
358static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
359 const void * restrict vx0, const void * restrict vx1,
360 const void * restrict vy0) {
361 assert(n % 32 == 0); // min sub-block size
362 assert((unsigned long) vx0 % 128 == 0);
363 assert((unsigned long) vx1 % 128 == 0);
364 assert((unsigned long) vy0 % 128 == 0);
365
366 const uint32_t qk = QK_Q4_0x4x2 * 4;
367
368 const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
369 const uint32_t x_qblk_size = qk / 2; // int4
370 const uint32_t x_qrow_size = n / 2; // int4 (not padded)
371
372 const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
373 const uint32_t y_qblk_size = qk; // int8
374 const uint32_t y_qrow_size = n; // int8 (not padded)
375
376 const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
377 const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
378 const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
379 const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
380
381 const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
382 const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
383
384 // Row sum (sf)
385 HVX_Vector r0_sum = Q6_V_vsplat_R(0);
386 HVX_Vector r1_sum = Q6_V_vsplat_R(0);
387
388 // Multiply and accumulate into int32.
389 // Compute combined scale (fp32).
390 // Apply scale to acc and accumulate into the row sum (qf32).
391
392 const uint32_t nb = n / qk; // num full blocks
    const uint32_t nloe = n % qk;  // num leftover elements
394
395 uint32_t i = 0;
396 for (; i < nb; i++) {
397 HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
398 HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
399 HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
400
401 HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
402 HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
403
404 HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
405 HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
406 HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
407
408 HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
409 HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
410
411 HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
412 HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
413
414 r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
415 r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
416 }
417
    // Process leftovers; we still load a full 4x4x2 block but zero out unused scales/blocks
419 if (nloe) {
420 HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
421 HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
422 HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
423
424 HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
425 HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));
426
427 HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
428 HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
429 HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
430
431 HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
432 HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
433
434 // Zero out unused scales
435 HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
436 r0_dd = Q6_V_vand_QV(bmask, r0_dd);
437 r1_dd = Q6_V_vand_QV(bmask, r1_dd);
438 r0_ia = Q6_V_vand_QV(bmask, r0_ia);
439 r1_ia = Q6_V_vand_QV(bmask, r1_ia);
440
441 HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
442 HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
443
444 r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
445 r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
446 }
447
448 HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
449 hvx_vec_store_u(s0, 8, rsum);
450}
451
452static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
453 const void * restrict vx0, const void * restrict vx1,
454 const void * restrict vy0, const void * restrict vy1) {
455 assert(n % 32 == 0);
456 assert((unsigned long) vx0 % 128 == 0);
457 assert((unsigned long) vx1 % 128 == 0);
458 assert((unsigned long) vy0 % 128 == 0);
459 assert((unsigned long) vy1 % 128 == 0);
460
461 const uint32_t qk = QK_Q4_0x4x2 * 4;
462
463 const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
464 const uint32_t x_qblk_size = qk / 2; // int4
465 const uint32_t x_qrow_size = n / 2; // int4 (not padded)
466
467 const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
468 const uint32_t y_qblk_size = qk; // int8
469 const uint32_t y_qrow_size = n; // int8 (not padded)
470
471 const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
472 const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
473 const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
474 const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
475
476 const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first
477 const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
478 const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first
479 const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales
480
    // Row sums (sf) - 4 accumulators for the 2x2 tile
482 HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
483 HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
484 HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
485 HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
486
487 const uint32_t nb = n / qk; // num full blocks
488 const uint32_t nloe = n % qk; // num leftover elements
489
490 uint32_t i = 0;
491 for (; i < nb; i++) {
492 // Load src1 columns (reused across both src0 rows)
493 HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
494 HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
495
496 // Load src0 rows (reused across both src1 columns)
497 HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
498 HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
499
        // Compute 4 dot products: r0*c0, r0*c1, r1*c0, r1*c1
501 HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
502 HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
503 HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
504 HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
505
506 // Load scales
507 HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
508 HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
509 HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
510 HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
511
512 // Compute combined scales
513 HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
514 HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
515 HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
516 HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
517
518 // Apply scales and accumulate
519 HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
520 HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
521 HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
522 HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
523
524 r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
525 r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
526 r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
527 r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
528 }
529
530 // Process leftovers
531 if (nloe) {
532 HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
533 HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
534 HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
535 HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
536
537 HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe));
538 HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe));
539 HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe));
540 HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe));
541
542 HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
543 HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
544 HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
545 HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
546
547 HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
548 HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
549 HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
550 HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
551
552 // Zero out unused scales
553 HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
554 r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
555 r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
556 r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
557 r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
558 r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
559 r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
560 r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
561 r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
562
563 HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
564 HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
565 HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
566 HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
567
568 r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
569 r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
570 r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
571 r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
572 }
573
574 // Reduce and store results
575 HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
576 HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
577
578 hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0
579 hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1
580}
581
582static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
583 assert(n % 32 == 0); // min sub-block size
584 assert((unsigned long) vx0 % 128 == 0);
585 assert((unsigned long) vy0 % 128 == 0);
586
    const uint32_t qk = QK_Q8_0x4x2 * 4;
588
589 const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
590 const uint32_t x_qblk_size = qk; // int8
591 const uint32_t x_qrow_size = n; // int8 (not padded)
592
593 const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
594 const uint32_t y_qblk_size = qk; // int8
595 const uint32_t y_qrow_size = n; // int8 (not padded)
596
597 const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
598 const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
599
600 const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
601 const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
602
603 // Row sum (sf)
604 HVX_Vector r0_sum = Q6_V_vsplat_R(0);
605
606 // Multiply and accumulate into int32.
607 // Compute combined scale (fp32).
608 // Apply scale to acc and accumulate into the row sum (qf32).
609
610 const uint32_t nb = n / qk; // num full blocks
    int32_t nloe = n % qk;  // num leftover elements (must be signed)
612
613 uint32_t i = 0;
614 for (; i < nb; i++) {
615 HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
616 HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
617
618 HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
619
620 HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
621 HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
622
623 HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
624
625 HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
626
627 r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
628 }
629
    // Process leftovers; we still load a full 4x4x2 block but zero out unused scales/blocks
631 if (nloe) {
632 HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
633 HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
634
635 HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
636
637 HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
638 HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
639
640 HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
641
642 // Zero out unused scales
643 HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
644 r0_dd = Q6_V_vand_QV(bmask, r0_dd);
645 r0_ia = Q6_V_vand_QV(bmask, r0_ia);
646
647 HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
648
649 r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
650 }
651
652 r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
653
654 hvx_vec_store_u(s0, 4, r0_sum);
655}
656
657static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,
658 const void * restrict vx0, const void * restrict vx1,
659 const void * restrict vy0) {
660 assert(n % 32 == 0); // min sub-block size
661 assert((unsigned long) vx0 % 128 == 0);
662 assert((unsigned long) vx1 % 128 == 0);
663 assert((unsigned long) vy0 % 128 == 0);
664
    const uint32_t qk = QK_Q8_0x4x2 * 4;
666
667 const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
668 const uint32_t x_qblk_size = qk; // int8
669 const uint32_t x_qrow_size = n; // int8 (not padded)
670
671 const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
672 const uint32_t y_qblk_size = qk; // int8
673 const uint32_t y_qrow_size = n; // int8 (not padded)
674
675 const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
676 const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
677 const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
678 const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
679
680 const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
681 const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
682
    // Row sum (sf)
684 HVX_Vector r0_sum = Q6_V_vsplat_R(0);
685 HVX_Vector r1_sum = Q6_V_vsplat_R(0);
686
687 // Multiply and accumulate into int32.
688 // Compute combined scale (fp32).
689 // Apply scale to acc and accumulate into the row sum (qf32).
690
691 const uint32_t nb = n / qk; // num full blocks
    int32_t nloe = n % qk;  // num leftover elements (must be signed)
693
694 uint32_t i = 0;
695 for (; i < nb; i++) {
696 HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
697 HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
698 HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
699
700 HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
701 HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
702
703 HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
704 HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
705 HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
706
707 HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
708 HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
709
710 HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
711 HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
712
713 r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
714 r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
715 }
716
    // Process leftovers; we still load a full 4x4x2 block but zero out unused scales/blocks
718 if (nloe) {
719 HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
720 HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
721 HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
722
723 HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
724 HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));
725
726 HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
727 HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
728 HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
729
730 HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
731 HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
732
733 // Zero out unused scales
734 HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
735 r0_dd = Q6_V_vand_QV(bmask, r0_dd);
736 r1_dd = Q6_V_vand_QV(bmask, r1_dd);
737 r0_ia = Q6_V_vand_QV(bmask, r0_ia);
738 r1_ia = Q6_V_vand_QV(bmask, r1_ia);
739
740 HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
741 HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
742
743 r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
744 r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
745 }
746
747 HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
748 hvx_vec_store_u(s0, 8, rsum);
749}
750
751static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
752 const void * restrict vx0, const void * restrict vx1,
753 const void * restrict vy0, const void * restrict vy1) {
754 assert(n % 32 == 0);
755 assert((unsigned long) vx0 % 128 == 0);
756 assert((unsigned long) vx1 % 128 == 0);
757 assert((unsigned long) vy0 % 128 == 0);
758 assert((unsigned long) vy1 % 128 == 0);
759
760 const uint32_t qk = QK_Q8_0x4x2 * 4;
761
762 const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
763 const uint32_t x_qblk_size = qk; // int8
764 const uint32_t x_qrow_size = n; // int8 (not padded)
765
766 const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
767 const uint32_t y_qblk_size = qk; // int8
768 const uint32_t y_qrow_size = n; // int8 (not padded)
769
770 const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
771 const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
772 const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
773 const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
774
775 const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first
776 const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
777 const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first
778 const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales
779
    // Row sums (sf) - 4 accumulators for the 2x2 tile
781 HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
782 HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
783 HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
784 HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
785
786 const uint32_t nb = n / qk; // num full blocks
787 const uint32_t nloe = n % qk; // num leftover elements
788
789 uint32_t i = 0;
790 for (; i < nb; i++) {
791 // Load src1 columns (reused across both src0 rows)
792 HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
793 HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
794
795 // Load src0 rows (reused across both src1 columns)
796 HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
797 HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
798
        // Compute 4 dot products: r0*c0, r0*c1, r1*c0, r1*c1
800 HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
801 HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
802 HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
803 HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
804
805 // Load scales
806 HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
807 HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
808 HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
809 HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
810
811 // Compute combined scales
812 HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
813 HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
814 HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
815 HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
816
817 // Apply scales and accumulate
818 HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
819 HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
820 HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
821 HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
822
823 r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
824 r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
825 r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
826 r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
827 }
828
829 // Process leftovers
830 if (nloe) {
831 HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
832 HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
833 HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
834 HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
835
836 HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe));
837 HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe));
838 HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe));
839 HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe));
840
841 HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
842 HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
843 HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
844 HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
845
846 HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
847 HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
848 HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
849 HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
850
851 // Zero out unused scales
852 HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
853 r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
854 r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
855 r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
856 r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
857 r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
858 r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
859 r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
860 r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
861
862 HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
863 HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
864 HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
865 HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
866
867 r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
868 r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
869 r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
870 r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
871 }
872
873 // Reduce and store results
874 HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
875 HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
876
877 hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0
878 hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
879}
880
881static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
882 assert(n % 32 == 0); // min sub-block size
883 assert((unsigned long) vx0 % 128 == 0);
884 assert((unsigned long) vy0 % 128 == 0);
885
886 const uint32_t qk = QK_MXFP4x4x2 * 4;
887
888 const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0
889 const uint32_t x_qblk_size = qk / 2; // fp4
890 const uint32_t x_qrow_size = n / 2; // fp4 (not padded)
891
892 const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
893 const uint32_t y_qblk_size = qk; // int8
894 const uint32_t y_qrow_size = n; // int8 (not padded)
895
896 const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
897 const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
898
899 const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
900 const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
901
902 // Row sum (sf)
903 HVX_Vector r0_sum = Q6_V_vsplat_R(0);
904
905 // Multiply and accumulate into int32.
906 // Compute combined scale (fp32).
907 // Apply scale to acc and accumulate into the row sum (qf32).
908
909 const uint32_t nb = n / qk; // num full blocks
    int32_t nloe = n % qk;  // num leftover elements (must be signed)
911
912 uint32_t i = 0;
913 for (; i < nb; i++) {
914 HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
915 HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
916
917 HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
918
919 HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
920 HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
921
922 // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
923 HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
924 vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
925 vy_d = Q6_Vsf_equals_Vqf32(vy_d);
926
927 // Convert rX_d scales from e8m0 to fp32
928 // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
929 // Left shift with zero fill to create FP32
930 // FIXME: might need to handle zero as a special case (see ggml-cpu code)
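        // e.g. e8m0 byte 0x7f << 23 -> 1.0f, 0x80 -> 2.0f, 0x7e -> 0.5f; a 0x00 byte yields 0.0f
        // rather than the nominal 2^-127 (likely the zero special case the FIXME refers to)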
931 HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
932 HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
933 r0_d = Q6_V_vdelta_VV(r0_d, expand);
934 r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
935 r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
936
937 HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
938
939 HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
940
941 r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
942 }
943
944 // Process leftovers
945 if (nloe) {
946 HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
947 HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
948
949 HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
950
951 HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
952 HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
953
954 // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
955 HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
956 vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
957 vy_d = Q6_Vsf_equals_Vqf32(vy_d);
958
959 // Convert rX_d scales from e8m0 to fp32
960 // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
961 // Left shift with zero fill to create FP32
962 // FIXME: might need to handle zero as a special case (see ggml-cpu code)
963 HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
964 HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
965 r0_d = Q6_V_vdelta_VV(r0_d, expand);
966 r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
967 r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
968
969 HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
970
971 // Zero-out unused scales
972 HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
973 r0_dd = Q6_V_vand_QV(bmask, r0_dd);
974 r0_ia = Q6_V_vand_QV(bmask, r0_ia);
975
976 HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
977
978 r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
979 }
980
981 r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
982
983 hvx_vec_store_u(s0, 4, r0_sum);
984}
985
986static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
987 const void * restrict vx0, const void * restrict vx1,
988 const void * restrict vy0) {
989 assert(n % 32 == 0); // min sub-block size
990 assert((unsigned long) vx0 % 128 == 0);
991 assert((unsigned long) vx1 % 128 == 0);
992 assert((unsigned long) vy0 % 128 == 0);
993
994 const uint32_t qk = QK_MXFP4x4x2 * 4;
995
996 const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0
997 const uint32_t x_qblk_size = qk / 2; // fp4
998 const uint32_t x_qrow_size = n / 2; // fp4 (not padded)
999
1000 const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
1001 const uint32_t y_qblk_size = qk; // int8
1002 const uint32_t y_qrow_size = n; // int8 (not padded)
1003
1004 const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
1005 const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
1006 const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
1007 const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
1008
1009 const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first
1010 const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
1011
1012 // Row sum (sf)
1013 HVX_Vector r0_sum = Q6_V_vsplat_R(0);
1014 HVX_Vector r1_sum = Q6_V_vsplat_R(0);
1015
1016 // Multiply and accumulate into int32.
1017 // Compute combined scale (fp32).
1018 // Apply scale to acc and accumulate into the row sum (f32).
1019
1020 const uint32_t nb = n / qk; // num full blocks
    int32_t nloe = n % qk;  // num leftover elements (must be signed)
1022
1023 uint32_t i = 0;
1024 for (; i < nb; i++) {
1025 HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
1026 HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
1027 HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
1028
1029 HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
1030 HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
1031
1032 HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
1033 HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
1034 HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
1035
1036 // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
1037 HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
1038 vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
1039 vy_d = Q6_Vsf_equals_Vqf32(vy_d);
1040
1041 // Convert rX_d scales from e8m0 to fp32
1042 // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
1043 // Left shift with zero fill to create FP32
1044 // FIXME: might need to handle zero as a special case (see ggml-cpu code)
1045 HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
1046 HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
1047 r0_d = Q6_V_vdelta_VV(r0_d, expand);
1048 r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
1049 r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
1050 r1_d = Q6_V_vdelta_VV(r1_d, expand);
1051 r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
1052 r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
1053
1054 HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
1055 HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
1056
1057 HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
1058 HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
1059
1060 r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
1061 r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
1062 }
1063
1064 // Process leftovers
1065 if (nloe) {
1066 HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
1067 HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
1068 HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
1069
1070 HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
1071 HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
1072
1073 HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
1074 HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
1075 HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
1076
1077 // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
1078 HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
1079 vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
1080 vy_d = Q6_Vsf_equals_Vqf32(vy_d);
1081
1082 // Convert rX_d scales from e8m0 to fp32
1083 // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
1084 // Left shift with zero fill to create FP32
1085 // FIXME: might need to handle zero as a special case (see ggml-cpu code)
1086 HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
1087 HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
1088 r0_d = Q6_V_vdelta_VV(r0_d, expand);
1089 r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
1090 r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
1091 r1_d = Q6_V_vdelta_VV(r1_d, expand);
1092 r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
1093 r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
1094
1095 HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
1096 HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
1097
1098 // Zero-out unused values
1099 HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
1100 r0_dd = Q6_V_vand_QV(bmask, r0_dd);
1101 r1_dd = Q6_V_vand_QV(bmask, r1_dd);
1102 r0_ia = Q6_V_vand_QV(bmask, r0_ia);
1103 r1_ia = Q6_V_vand_QV(bmask, r1_ia);
1104
1105 HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
1106 HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
1107
1108 r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
1109 r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
1110 }
1111
1112 HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
1113 hvx_vec_store_u(s0, 8, rsum);
1114}
1115
1116static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
1117 const void * restrict vx0, const void * restrict vx1,
1118 const void * restrict vy0, const void * restrict vy1) {
1119 assert(n % 32 == 0);
1120 assert((unsigned long) vx0 % 128 == 0);
1121 assert((unsigned long) vx1 % 128 == 0);
1122 assert((unsigned long) vy0 % 128 == 0);
1123 assert((unsigned long) vy1 % 128 == 0);
1124
1125 const uint32_t qk = QK_MXFP4x4x2 * 4;
1126
1127 const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0
1128 const uint32_t x_qblk_size = qk / 2; // fp4
1129 const uint32_t x_qrow_size = n / 2; // fp4 (not padded)
1130
1131 const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
1132 const uint32_t y_qblk_size = qk; // int8
1133 const uint32_t y_qrow_size = n; // int8 (not padded)
1134
1135 const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
1136 const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
1137 const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
1138 const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
1139
1140 const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first
1141 const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
1142 const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first
1143 const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales
1144
    // Row sums (sf) - 4 accumulators for the 2x2 tile
1146 HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
1147 HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
1148 HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
1149 HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
1150
1151 const uint32_t nb = n / qk; // num full blocks
1152 const uint32_t nloe = n % qk; // num leftover elements
1153
1154 uint32_t i = 0;
1155 for (; i < nb; i++) {
1156 // Load src1 columns (reused across both src0 rows)
1157 HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
1158 HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
1159
1160 // Load src0 rows (reused across both src1 columns)
1161 HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
1162 HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
1163
        // Compute 4 dot products: r0*c0, r0*c1, r1*c0, r1*c1
1165 HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
1166 HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
1167 HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
1168 HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
1169
1170 // Load scales
1171 HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size);
1172 HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size);
1173 HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
1174 HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
1175
1176 // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
1177 HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
1178 vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half));
1179 vy0_d = Q6_Vsf_equals_Vqf32(vy0_d);
1180 vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half));
1181 vy1_d = Q6_Vsf_equals_Vqf32(vy1_d);
1182
1183 // Convert rX_d scales from e8m0 to fp32
1184 // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
1185 // Left shift with zero fill to create FP32
1186 // FIXME: might need to handle zero as a special case (see ggml-cpu code)
1187 HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
1188 HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
1189 r0_d = Q6_V_vdelta_VV(r0_d, expand);
1190 r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
1191 r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
1192 r1_d = Q6_V_vdelta_VV(r1_d, expand);
1193 r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
1194 r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
1195
1196 // Compute combined scales
1197 HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d));
1198 HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d));
1199 HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d));
1200 HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d));
1201
1202 // Apply scales and accumulate
1203 HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
1204 HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
1205 HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
1206 HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
1207
1208 r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
1209 r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
1210 r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
1211 r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
1212 }
1213
1214 // Process leftovers
1215 if (nloe) {
1216 HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
1217 HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
1218 HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
1219 HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
1220
1221 HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe));
1222 HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe));
1223 HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe));
1224 HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe));
1225
1226 HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size);
1227 HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size);
1228 HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
1229 HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
1230
1231        // Convert vy_d from fp16 to fp32 while applying the 0.5 scaling used for e8m0 halving
1232 HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
1233 vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half));
1234 vy0_d = Q6_Vsf_equals_Vqf32(vy0_d);
1235 vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half));
1236 vy1_d = Q6_Vsf_equals_Vqf32(vy1_d);
1237
1238 // Convert rX_d scales from e8m0 to fp32
1239 // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
1240 // Left shift with zero fill to create FP32
1241 // FIXME: might need to handle zero as a special case (see ggml-cpu code)
1242 HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
1243 HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
1244 r0_d = Q6_V_vdelta_VV(r0_d, expand);
1245 r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
1246 r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
1247 r1_d = Q6_V_vdelta_VV(r1_d, expand);
1248 r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
1249 r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
1250
1251 HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d));
1252 HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d));
1253 HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d));
1254 HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d));
1255
1256 // Zero out unused scales
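        // nloe is a multiple of 32 and each fp32 lane accumulates one 32-element group,
        // so keep the first nloe/32 lanes (nloe/8 bytes) and clear the rest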
1257 HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
1258 r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
1259 r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
1260 r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
1261 r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
1262 r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
1263 r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
1264 r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
1265 r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
1266
1267 HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
1268 HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
1269 HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
1270 HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
1271
1272 r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
1273 r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
1274 r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
1275 r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
1276 }
1277
1278 // Reduce and store results
1279 HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
1280 HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
1281
1282 hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0
1283 hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
1284}
1285
1286static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1287 const HVX_Vector * restrict x = (const HVX_Vector *) vx;
1288 const HVX_Vector * restrict y = (const HVX_Vector *) vy;
1289
1290 uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
1291 uint32_t nloe = n % VLEN_FP16; // leftover elements
1292
1293 HVX_Vector rsum = Q6_V_vsplat_R(0);
1294
1295 uint32_t i = 0;
1296
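    // Each Wqf32 multiply widens the 64 fp16 products to qf32 across a vector pair;
    // adding the lo and hi halves folds them into the 32-lane qf32 running sum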
1297 #pragma unroll(4)
1298 for (i = 0; i < nvec; i++) {
1299 HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]);
1300 rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
1301 }
1302
1303 if (nloe) {
1304 HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
1305 HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
1306 HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
1307
1308 HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
1309 rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
1310 }
1311
1312 rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
1313 hvx_vec_store_u(&s[0], 4, rsum);
1314}
1315
1316static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0,
1317 const void * restrict vx0, const void * restrict vx1,
1318 const void * restrict vy0) {
1319 const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
1320 const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
1321 const HVX_Vector * restrict y = (const HVX_Vector *) vy0;
1322
1323 uint32_t nvec = n / VLEN_FP16;
1324 uint32_t nloe = n % VLEN_FP16;
1325
1326 HVX_Vector rsum0 = Q6_V_vsplat_R(0);
1327 HVX_Vector rsum1 = Q6_V_vsplat_R(0);
1328
1329 uint32_t i = 0;
1330
1331 #pragma unroll(2)
1332 for (i = 0; i < nvec; i++) {
1333 HVX_Vector y_hf = y[i];
1334 HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0[i], y_hf);
1335 HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1[i], y_hf);
1336
1337 rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)));
1338 rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)));
1339 }
1340
1341 if (nloe) {
1342 HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
1343 HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]);
1344 HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]);
1345 HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
1346
1347 HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
1348 HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
1349
1350 rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)));
1351 rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)));
1352 }
1353
1354 HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(Q6_Vsf_equals_Vqf32(rsum0), Q6_Vsf_equals_Vqf32(rsum1));
1355 hvx_vec_store_u(s0, 8, rsum);
1356}
1357
1358static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * restrict s1,
1359 const void * restrict vx0, const void * restrict vx1,
1360 const void * restrict vy0, const void * restrict vy1) {
1361 const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
1362 const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
1363 const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0;
1364 const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1;
1365
1366 uint32_t nvec = n / VLEN_FP16;
1367 uint32_t nloe = n % VLEN_FP16;
1368
1369    // Row sums (sf) - 4 accumulators for 2x2 tile
1370 HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
1371 HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
1372 HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
1373 HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
1374
1375 uint32_t i = 0;
1376
1377 #pragma unroll(2)
1378 for (i = 0; i < nvec; i++) {
1379 HVX_Vector r0_hf = x0[i];
1380 HVX_Vector r1_hf = x1[i];
1381 HVX_Vector c0_hf = y0[i];
1382 HVX_Vector c1_hf = y1[i];
1383
1384        // Compute 4 dot products: r0*c0, r0*c1, r1*c0, r1*c1
1385 HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf);
1386 HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf);
1387 HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf);
1388 HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf);
1389
1390 HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p));
1391 HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p));
1392 HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p));
1393 HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p));
1394
1395 r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum));
1396 r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum));
1397 r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum));
1398 r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum));
1399 }
1400
1401 if (nloe) {
1402 HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
1403
1404 HVX_Vector r0_hf = Q6_V_vand_QV(bmask, x0[i]);
1405 HVX_Vector r1_hf = Q6_V_vand_QV(bmask, x1[i]);
1406 HVX_Vector c0_hf = Q6_V_vand_QV(bmask, y0[i]);
1407 HVX_Vector c1_hf = Q6_V_vand_QV(bmask, y1[i]);
1408
1409 HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf);
1410 HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf);
1411 HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf);
1412 HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf);
1413
1414 HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p));
1415 HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p));
1416 HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p));
1417 HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p));
1418
1419 r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum));
1420 r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum));
1421 r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum));
1422 r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum));
1423
1424 }
1425
1426 // Reduce and store results
1427 HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
1428 HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
1429
1430 hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0
1431 hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
1432}
1433
1434static void vec_dot_f16_f16_uu_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1435 const HVX_UVector * restrict x = (const HVX_UVector *) vx;
1436 const HVX_UVector * restrict y = (const HVX_UVector *) vy;
1437
1438 uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
1439 uint32_t nloe = n % VLEN_FP16; // leftover elements
1440
1441 HVX_Vector rsum = Q6_V_vsplat_R(0);
1442
1443 uint32_t i = 0;
1444
1445 #pragma unroll(4)
1446 for (i = 0; i < nvec; i++) {
1447 HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]);
1448 rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
1449 }
1450
1451 if (nloe) {
1452 HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
1453 HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
1454 HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
1455
1456 HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
1457 rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
1458 }
1459
1460 rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
1461 hvx_vec_store_u(&s[0], 4, rsum);
1462}
1463
1464static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
1465 const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
1466 const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
1467
1468 uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
1469 uint32_t nloe = n % VLEN_FP16; // leftover elements
1470
1471 const HVX_Vector zero = Q6_V_vsplat_R(0);
1472
1473 HVX_Vector rsum = Q6_V_vsplat_R(0);
1474
1475 uint32_t i = 0;
1476
1477 #pragma unroll(2)
1478 for (i = 0; i < nvec; i++) {
1479 // Load y (fp32) and convert into fp16
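        // (the subtract-zero via Q6_Vqf32_vsub_VsfVsf is the usual sf -> qf32 conversion idiom;
        // Q6_Vhf_equals_Wqf32 then narrows the qf32 pair to one fp16 vector and vdeal restores element order)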
1480 HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
1481 HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
1482 HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
1483
1484 // Load x (fp16)
1485 HVX_Vector x_hf = vx[i];
1486
1487 HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
1488
1489 rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
1490 }
1491
1492 if (nloe) {
1493 // Load y (fp32) and convert into fp16
1494 HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
1495 HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
1496 HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
1497
1498 // Load x (fp16)
1499 HVX_Vector x_hf = vx[i];
1500
1501 // Zero-out unused elements
1502 // Note that we need to clear both x and y because they may contain NANs
1503 HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
1504 x_hf = Q6_V_vand_QV(bmask, x_hf);
1505 y_hf = Q6_V_vand_QV(bmask, y_hf);
1506
1507 HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
1508
1509 rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
1510 }
1511
1512 // Convert into fp32 and reduce
1513 rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
1514 hvx_vec_store_u(&s[0], 4, rsum);
1515}
1516
1517#define htp_matmul_tensors_preamble \
1518 struct htp_tensor * restrict src0 = &octx->src0; \
1519 struct htp_tensor * restrict src1 = &octx->src1; \
1520 struct htp_tensor * restrict src2 = &octx->src2; \
1521 struct htp_tensor * restrict dst = &octx->dst; \
1522 struct htp_spad * restrict src0_spad = &octx->src0_spad; \
1523 struct htp_spad * restrict src1_spad = &octx->src1_spad; \
1524 struct htp_spad * restrict dst_spad = &octx->dst_spad; \
1525 \
1526 const uint32_t ne00 = src0->ne[0]; \
1527 const uint32_t ne01 = src0->ne[1]; \
1528 const uint32_t ne02 = src0->ne[2]; \
1529 const uint32_t ne03 = src0->ne[3]; \
1530 \
1531 const uint32_t ne10 = src1->ne[0]; \
1532 const uint32_t ne11 = src1->ne[1]; \
1533 const uint32_t ne12 = src1->ne[2]; \
1534 const uint32_t ne13 = src1->ne[3]; \
1535 \
1536 const uint32_t ne20 = src2->ne[0]; \
1537 const uint32_t ne21 = src2->ne[1]; \
1538 const uint32_t ne22 = src2->ne[2]; \
1539 const uint32_t ne23 = src2->ne[3]; \
1540 \
1541 const uint32_t ne0 = dst->ne[0]; \
1542 const uint32_t ne1 = dst->ne[1]; \
1543 const uint32_t ne2 = dst->ne[2]; \
1544 const uint32_t ne3 = dst->ne[3]; \
1545 \
1546 const uint32_t nb00 = src0->nb[0]; \
1547 const uint32_t nb01 = src0->nb[1]; \
1548 const uint32_t nb02 = src0->nb[2]; \
1549 const uint32_t nb03 = src0->nb[3]; \
1550 \
1551 const uint32_t nb10 = src1->nb[0]; \
1552 const uint32_t nb11 = src1->nb[1]; \
1553 const uint32_t nb12 = src1->nb[2]; \
1554 const uint32_t nb13 = src1->nb[3]; \
1555 \
1556 const uint32_t nb0 = dst->nb[0]; \
1557 const uint32_t nb1 = dst->nb[1]; \
1558 const uint32_t nb2 = dst->nb[2]; \
1559 const uint32_t nb3 = dst->nb[3];
1560
1561#define htp_matmul_preamble \
1562 struct htp_matmul_context * mmctx = data; \
1563 struct htp_ops_context * octx = mmctx->octx; \
1564 htp_matmul_tensors_preamble; \
1565 dma_queue *dma_queue = octx->ctx->dma[ith]; \
1566 uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread;
1567
1568// *** matmul with support for 4d tensors and full broadcasting
1569
1570static void matmul_4d(unsigned int nth, unsigned int ith, void * data) {
1571 htp_matmul_preamble;
1572
1573 uint64_t t1, t2;
1574 t1 = HAP_perf_get_qtimer_count();
1575
1576 assert(ne12 % ne02 == 0);
1577 assert(ne13 % ne03 == 0);
1578
1579    // This is the size of the first dimension of the result (== ne01, the number of src0 rows), so we can iterate over it directly
1580 const uint32_t nr0 = ne0;
1581
1582 // This is the size of the rest of the dimensions of the result
1583 const uint32_t nr1 = ne1 * ne2 * ne3;
1584
1585 // distribute the thread work across the inner or outer loop based on which one is larger
1586 uint32_t nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
1587 uint32_t nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
1588
1589 // The number of elements in each chunk
1590 const uint32_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
1591 const uint32_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
1592
1593 uint32_t current_chunk = ith;
1594
1595 const uint32_t ith0 = current_chunk % nchunk0;
1596 const uint32_t ith1 = current_chunk / nchunk0;
1597
1598 const uint32_t ir0_start = dr0 * ith0;
1599 const uint32_t ir0_end = MIN(ir0_start + dr0, nr0);
1600
1601 const uint32_t ir1_start = dr1 * ith1;
1602 const uint32_t ir1_end = MIN(ir1_start + dr1, nr1);
1603
1604 // no work for this thread
1605 if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
1606 return;
1607 }
1608
1609 // block-tiling attempt
1610 const uint32_t blck_0 = 64;
1611 const uint32_t blck_1 = 64;
1612
1613 for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
1614 for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
1615 for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) {
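                // decompose the flattened dst row index: ir1 = i13*(ne12*ne1) + i12*ne1 + i11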
1616 const uint32_t i13 = fastdiv(ir1, &mmctx->mm_div_ne12_ne1);
1617 const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &mmctx->mm_div_ne1);
1618 const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
1619
1620 // broadcast src0 into src1
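                // i03 = i13 / (ne13/ne03), i02 = i12 / (ne12/ne02)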
1621 const uint32_t i03 = fastdiv(i13, &mmctx->mm_div_r3);
1622 const uint32_t i02 = fastdiv(i12, &mmctx->mm_div_r2);
1623
1624 const uint32_t i1 = i11;
1625 const uint32_t i2 = i12;
1626 const uint32_t i3 = i13;
1627
1628 const uint8_t * restrict src0_base = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03);
1629 const uint8_t * restrict src1_col = (const uint8_t *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13);
1630 float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
1631
1632 const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end);
1633 for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) {
1634 const uint8_t * restrict src0_row = src0_base + ir0 * nb01;
1635 mmctx->vec_dot_1x1(ne00, &dst_col[ir0], src0_row, src1_col);
1636 }
1637 }
1638 }
1639 }
1640
1641 t2 = HAP_perf_get_qtimer_count();
1642
1643 FARF(HIGH, "matmul-4d %d/%d: %ux%ux%ux%u (%u:%u %u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
1644 src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0_start, ir0_end, ir1_start, ir1_end, src1->ne[0],
1645 src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
1646 (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
1647}
1648
1649// src1 tensor is already in VTCM spad
1650static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
1651 htp_matmul_preamble;
1652
1653 const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
1654 const uint32_t src1_nrows = ne11 * ne12 * ne13; // src1 rows
1655
1656 const uint32_t src0_start_row = src0_nrows_per_thread * ith;
1657 const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
1658 const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);
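    // src0_end_row_x2 marks the end of the even-sized prefix that is processed with the 2-row kernels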
1659
1660 // no work for this thread
1661 if (src0_start_row >= src0_end_row) {
1662 return;
1663 }
1664
1665 const size_t dst_row_size = nb1;
1666 const size_t src0_row_size = nb01;
1667 const size_t src1_row_size = nb11;
1668
1669 const size_t src0_stride = src0_spad->stride;
1670 const size_t src1_stride = src1_spad->stride;
1671
1672 // Per-thread VTCM scratchpads for all tensors
1673 // Note that the entire src1 tensor is already in VTCM
1674 // For other tensors we allocate N rows per thread, padded to HVX vector size
1675 uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith;
1676 uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
1677 uint8_t * restrict src1_data = src1_spad->data;
1678
1679 volatile uint64_t t1, t2;
1680 t1 = HAP_perf_get_qtimer_count();
1681
1682 const uint8_t * restrict src0_row = (const uint8_t *) src0->data;
1683
1684 // Prefill spad with src0 rows
1685 #pragma unroll(4)
1686 for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
1687 const int is0 = (ir0 - src0_start_row);
1688 if (is0 >= MM_SPAD_SRC0_NROWS) {
1689 break;
1690 }
1691 dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
1692 src0_stride, src0_row_size, 2);
1693 }
1694
1695 // Process src0 rows
1696 for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
1697 const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
1698
1699        // Process src1 columns in pairs (2x2 tiling)
1700 uint32_t ir1 = 0;
1701 for (; ir1 + 1 < src1_nrows; ir1 += 2) {
1702 const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride);
1703 const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride);
1704 float * restrict dst_row0 = (float *) (dst->data + ((ir1+0) * dst_row_size));
1705 float * restrict dst_row1 = (float *) (dst->data + ((ir1+1) * dst_row_size));
1706 mmctx->vec_dot_2x2(ne00, &dst_row0[ir0], &dst_row1[ir0], ss0, ss0 + src0_stride, src1_col0, src1_col1);
1707 }
1708
1709        // Handle remaining src1 rows (fallback to 2x1)
1710 for (; ir1 < src1_nrows; ++ir1) {
1711 const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
1712 float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
1713 mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_stride, src1_col);
1714 }
1715
1716        // Prefetch the rows MM_SPAD_SRC0_NROWS ahead to keep the spad ring buffer full
1717 const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
1718 const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
1719 if (pr0 < src0_end_row_x2) {
1720 dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size),
1721 src0_stride, src0_row_size, 2);
1722 }
1723 }
1724
1725 // Process the last row (if any)
1726 if (src0_end_row != src0_end_row_x2) {
1727 uint32_t ir0 = src0_end_row_x2;
1728 const int is0 = (ir0 - src0_start_row);
1729 dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
1730 src0_stride, src0_row_size, 1);
1731 const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
1732
1733 #pragma unroll(2)
1734 for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
1735 const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
1736 float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
1737 mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
1738 }
1739 }
1740
1741 t2 = HAP_perf_get_qtimer_count();
1742
1743 FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth,
1744 src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
1745 src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
1746 (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
1747}
1748
1749// q8x4x2 src1 tensor is already in VTCM spad
1750static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
1751 htp_matmul_preamble;
1752
1753 const uint32_t src0_nrows = ne01;
1754
1755 const uint32_t src0_start_row = src0_nrows_per_thread * ith;
1756 const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
1757 const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);
1758
1759 // no work for this thread
1760 if (src0_start_row >= src0_end_row) {
1761 return;
1762 }
1763
1764 const size_t dst_row_size = nb1;
1765 const size_t src0_row_size = nb01;
1766 const size_t src1_row_size = nb11;
1767
1768 const size_t src0_stride = src0_spad->stride;
1769 const size_t src1_stride = src1_spad->stride;
1770
1771 // Per-thread VTCM scratchpads for all tensors
1772 // Note that the entire src1 tensor is already in VTCM
1773 // For other tensors we allocate N rows per thread, padded to HVX vector size
1774 uint8_t * spad_dst = dst_spad->data + dst_spad->size_per_thread * ith;
1775 uint8_t * spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
1776 uint8_t * src1_data = src1_spad->data;
1777
1778 uint64_t t1, t2;
1779 t1 = HAP_perf_get_qtimer_count();
1780
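    // dot-product results are written to the per-thread VTCM dst scratchpad and copied to DDR in one pass at the end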
1781 float * tmp = (float *) spad_dst;
1782
1783 const uint8_t * restrict src0_row = (const uint8_t *) src0->data;
1784 const uint8_t * restrict src1_col = (const uint8_t *) src1_data;
1785 float * restrict dst_col = (float *) dst->data;
1786
1787    // Prefill spad with src0 rows (two rows per DMA transfer)
1788 #pragma unroll(2)
1789 for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
1790 const uint32_t is0 = (ir0 - src0_start_row);
1791 if (is0 >= MM_SPAD_SRC0_NROWS) {
1792 break;
1793 }
1794 dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
1795 src0_stride, src0_row_size, 2);
1796 }
1797
1798 // Process src0 rows
1799 for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
1800 const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
1801 mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col);
1802
1803        // Prefetch the rows MM_SPAD_SRC0_NROWS ahead to keep the spad ring buffer full
1804 const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
1805 const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
1806 if (pr0 < src0_end_row_x2) {
1807 dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size),
1808 src0_stride, src0_row_size, 2);
1809 }
1810 }
1811
1812 // Process the last row (if any)
1813 if (src0_end_row != src0_end_row_x2) {
1814 const uint32_t ir0 = src0_end_row_x2;
1815 const uint32_t is0 = (ir0 - src0_start_row);
1816 dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
1817 src0_stride, src0_row_size, 1);
1818 const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
1819 mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
1820 }
1821
1822 hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row);
1823
1824 t2 = HAP_perf_get_qtimer_count();
1825
1826 FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth,
1827 src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
1828 src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
1829 (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
1830}
1831
1832#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ids->ne[0] * ids->ne[1] + (i1)]
1833
1834struct mmid_row_mapping {
1835 uint32_t i1;
1836 uint32_t i2;
1837};
1838
1839// src1 tensor is already in VTCM spad
1840static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
1841 htp_matmul_preamble;
1842
1843 struct htp_tensor * restrict ids = &octx->src2;
1844 struct htp_spad * restrict src2_spad = &octx->src2_spad;
1845
1846 uint64_t t1, t2;
1847 t1 = HAP_perf_get_qtimer_count();
1848
1849 const uint32_t src0_nrows = ne01; // src0 rows per expert
1850 const uint32_t src1_nrows = ne11;
1851
1852 const uint32_t src0_start_row = src0_nrows_per_thread * ith;
1853 const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
1854 const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);
1855
1856 // no work for this thread
1857 if (src0_start_row >= src0_end_row) {
1858 return;
1859 }
1860
1861 const uint32_t n_ids = ids->ne[0]; // n_expert_used
1862 const uint32_t n_as = ne02; // n_expert
1863
1864 const size_t matrix_row_counts_size = n_as * sizeof(uint32_t);
1865 const size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping);
1866
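    // src2 spad layout: [matrix_row_counts: n_as x uint32_t][matrix_rows: mmid_row_mapping entries]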
1867 const uint32_t * matrix_row_counts = (const uint32_t *) src2_spad->data + 0;
1868 const struct mmid_row_mapping * matrix_rows = (const void *) src2_spad->data + matrix_row_counts_size;
1869
1870 const size_t dst_row_size = nb1;
1871 const size_t src0_row_size = nb01;
1872 const size_t src1_row_size = q8x4x2_row_size(ne10);
1873
1874 const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
1875
1876 // Per-thread VTCM scratchpads for all tensors
1877 // Note that the entire src1 tensor is already in VTCM
1878 // For other tensors we allocate N rows per thread, padded to HVX vector size
1879 uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith;
1880 uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
1881 uint8_t * restrict src1_data = src1_spad->data;
1882
1883 for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) {
1884 const int32_t cne1 = matrix_row_counts[cur_a];
1885
1886 if (cne1 == 0) {
1887 continue;
1888 }
1889
1890 const uint8_t * src0_row = (const uint8_t *) src0->data + (0 + cur_a * nb02 + 0);
1891
1892 // Prefill spad with src0 rows
1893 #pragma unroll(4)
1894 for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
1895 const int is0 = (ir0 - src0_start_row);
1896 if (is0 >= MM_SPAD_SRC0_NROWS) {
1897 break;
1898 }
1899 dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
1900 src0_row_size_padded, src0_row_size, 2);
1901 }
1902
1903 // Process src0 rows
1904 for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
1905 const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
1906
1907 for (uint32_t cid = 0; cid < cne1; ++cid) {
1908 struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid);
1909 const int rm1 = row_mapping.i1; // expert idx
1910 const int rm2 = row_mapping.i2; // token idx
1911
1912 const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx
1913 const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
1914 float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0));
1915
1916 mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);
1917 }
1918
1919            // Prefetch the rows MM_SPAD_SRC0_NROWS ahead to keep the spad ring buffer full
1920 const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
1921 const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
1922 if (pr0 < src0_end_row_x2) {
1923 dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size),
1924 src0_row_size_padded, src0_row_size, 2);
1925 }
1926 }
1927
1928 // Process the last row (if any)
1929 if (src0_end_row != src0_end_row_x2) {
1930 uint32_t ir0 = src0_end_row_x2;
1931 const uint32_t is0 = (ir0 - src0_start_row);
1932 dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
1933 src0_row_size_padded, src0_row_size, 1);
1934 const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
1935
1936 for (uint32_t cid = 0; cid < cne1; ++cid) {
1937 struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid);
1938 const int rm1 = row_mapping.i1; // expert idx
1939 const int rm2 = row_mapping.i2; // token idx
1940
1941 const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx
1942 const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
1943 float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0));
1944
1945 mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
1946 }
1947 }
1948 }
1949
1950 t2 = HAP_perf_get_qtimer_count();
1951
1952 FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type,
1953 ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0],
1954 src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1],
1955 dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
1956}
1957
1958// src1 tensor is already in VTCM spad
1959static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
1960 htp_matmul_preamble;
1961
1962 struct htp_tensor * restrict ids = &octx->src2;
1963 struct htp_spad * restrict src2_spad = &octx->src2_spad;
1964
1965 uint64_t t1, t2;
1966 t1 = HAP_perf_get_qtimer_count();
1967
1968 const uint32_t src0_nrows = ne01; // src0 rows per expert
1969
1970 const uint32_t src0_start_row = src0_nrows_per_thread * ith;
1971 const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
1972 const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);
1973
1974 // no work for this thread
1975 if (src0_start_row >= src0_end_row) {
1976 return;
1977 }
1978
1979 assert(ne13 % ne03 == 0);
1980
1981 const size_t dst_row_size = nb1;
1982 const size_t src0_row_size = nb01;
1983 const size_t src1_row_size = q8x4x2_row_size(ne10);
1984
1985 const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
1986
1987 const uint32_t n_aids = src2->ne[0]; // num activated experts
1988 const uint32_t n_ids = ne02; // num experts
1989
1990 // Per-thread VTCM scratchpads for all tensors
1991 // Note that the entire src1 tensor is already in VTCM
1992 // For other tensors we allocate N rows per thread, padded to HVX vector size
1993 uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith;
1994 uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
1995 uint8_t * restrict src1_data = src1_spad->data;
1996
1997 for (uint32_t ie1 = 0; ie1 < n_aids; ++ie1) { // for each expert
1998 const uint32_t eid = *(const int32_t *) ((const uint8_t *) src2->data + ie1 * src2->nb[0]);
1999 assert(eid < n_ids);
2000
2001 const uint8_t * restrict src0_row = (const uint8_t *) src0->data + eid * nb02;
2002 const uint8_t * restrict src1_col = (const uint8_t *) src1_data;
2003 float * restrict dst_row = (float *) (dst->data + ie1 * nb1);
2004
2005 // Prefill spad with src0 rows
2006 #pragma unroll(4)
2007 for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
2008 const int is0 = (ir0 - src0_start_row);
2009 if (is0 >= MM_SPAD_SRC0_NROWS) {
2010 break;
2011 }
2012 dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
2013 src0_row_size_padded, src0_row_size, 2);
2014 }
2015
2016 // Process src0 rows
2017 for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
2018 const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
2019 mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);
2020
2021            // Prefetch the rows MM_SPAD_SRC0_NROWS ahead to keep the spad ring buffer full
2022 const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
2023 const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
2024 if (pr0 < src0_end_row_x2) {
2025 dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size),
2026 src0_row_size_padded, src0_row_size, 2);
2027 }
2028 }
2029
2030 // Process the last row (if any)
2031 if (src0_end_row != src0_end_row_x2) {
2032 uint32_t ir0 = src0_end_row_x2;
2033 const uint32_t is0 = (ir0 - src0_start_row);
2034 dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
2035 src0_row_size_padded, src0_row_size, 1);
2036 const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
2037 mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
2038 }
2039 }
2040
2041 t2 = HAP_perf_get_qtimer_count();
2042
2043 FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type,
2044 ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0],
2045 src1->ne[1], src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0],
2046 dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
2047}
2048
2049// *** dynamic quant
2050
2051static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
2052 assert((unsigned long) x % 128 == 0);
2053 assert((unsigned long) y_q % 128 == 0);
2054
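    // Quantizes 128 fp32 values (4 HVX vectors) to int8 with one fp16 scale per 32-value group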
2055 HVX_Vector * vx = (HVX_Vector *) x;
2056 HVX_Vector zero = Q6_V_vsplat_R(0);
2057
2058 // Use reduce max fp32 to find max(abs(e)) first
2059 HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0]));
2060 HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1]));
2061 HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2]));
2062 HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3]));
2063 // Load and convert into QF32
2064 HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
2065 HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
2066 HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements
2067 HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements
2068
2069 // Convert to QF32
2070 HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero);
2071 HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero);
2072 HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero);
2073 HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero);
2074
2075 // Combine and convert to fp16
2076 HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf)));
2077 HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf)));
2078
2079 // Convert into fp16
2080 HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
2081 HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
2082
2083 // Replicate first fp16 scale across all lanes
2084 HVX_Vector ctrl = *(const HVX_Vector *) repl_2x_f16;
2085 vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
2086 vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
2087
2088 HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
2089 HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
2090 HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
2091 HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16);
2092
2093 hvx_vec_store_u(y_d + 0, 2, vd01_hf);
2094 HVX_Vector rotated_vd_hf = Q6_V_vror_VR(vd01_hf, 64);
2095 hvx_vec_store_u(y_d + 2, 2, rotated_vd_hf);
2096
2097 hvx_vec_store_u(y_d + 4, 2, vd23_hf);
2098 rotated_vd_hf = Q6_V_vror_VR(vd23_hf, 64);
2099 hvx_vec_store_u(y_d + 6, 2, rotated_vd_hf);
2100
2101 // Divide input by the scale
2102 HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf);
2103 HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf);
2104 vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
2105 vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));
2106
2107 // Convert to int8
2108 HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
2109 HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
2110 HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);
2111
2112 *(HVX_Vector *) y_q = vx_i8;
2113}
2114
2115static inline void quantize_block_f32_q8x2(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
2116 assert((unsigned long) x % 128 == 0);
2117 assert((unsigned long) y_q % 128 == 0);
2118
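    // Quantizes 128 fp32 values to int8 with one fp16 scale per 64-value group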
2119 HVX_Vector * vx = (HVX_Vector *) x;
2120
2121 // Load and convert into QF32
2122 HVX_Vector zero = Q6_V_vsplat_R(0);
2123 HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
2124 HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
2125 HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements
2126 HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements
2127
2128 // Convert into fp16
2129 HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
2130 HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
2131
2132 // Compute max and scale
2133 HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf));
2134 HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf));
2135
2136 // Replicate first fp16 scale across all lanes
2137 HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16;
2138 vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
2139 vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
2140
2141 HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
2142 HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
2143 HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
2144 HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16);
2145
2146 hvx_vec_store_u(y_d + 0, 4, vd01_hf);
2147 hvx_vec_store_u(y_d + 4, 4, vd23_hf);
2148
2149 // Divide input by the scale
2150 HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf);
2151 HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf);
2152 vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
2153 vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));
2154
2155 // Convert to int8
2156 HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
2157 HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
2158 HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);
2159
2160 *(HVX_Vector *) y_q = vx_i8;
2161}
2162
2163static inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
2164 assert((unsigned long) x % 128 == 0);
2165 assert((unsigned long) y_q % 128 == 0);
2166
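    // Quantizes 128 fp32 values to int8 with a single fp16 scale covering the whole 128-value group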
2167 HVX_Vector * vx = (HVX_Vector *) x;
2168
2169 // Load and convert into QF32
2170 HVX_Vector zero = Q6_V_vsplat_R(0);
2171 HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
2172 HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
2173 HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements
2174 HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements
2175
2176 // Convert into fp16
2177 HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
2178 HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
2179
2180 // Compute max and scale
2181 HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf));
2182 vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf);
2183
2184 // Replicate first fp16 scale across all lanes
2185 HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16;
2186 vmax_hf = Q6_V_vdelta_VV(vmax_hf, ctrl);
2187
2188 HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
2189 HVX_Vector vd_hf = Q6_Vhf_equals_Vqf16(vd_qf16);
2190
2191 *(HVX_UVector *) y_d = vd_hf;
2192
2193 // Divide input by the scale
2194 HVX_Vector vd_inv_hf = hvx_vec_inverse_f16(vd_hf);
2195 vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd_inv_hf));
2196 vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd_inv_hf));
2197
2198 // Convert to int8
2199 HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
2200 HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
2201 HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);
2202
2203 *(HVX_Vector *) y_q = vx_i8;
2204}
2205
2206// Overwrites input x
2207static void quantize_row_f32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) {
2208 assert(k % 32 == 0);
2209 const uint32_t qk = QK_Q8_0x4x2;
2210 const uint32_t nb = (k + qk - 1) / qk;
2211
2212 const uint32_t qrow_size = k; // int8
2213
2214 const uint32_t dblk_size = 8 * 2; // 8x __fp16
2215 const uint32_t qblk_size = QK_Q8_0x4x2; // int8
2216
2217 uint8_t * restrict y_q = (y + 0); // quants first
2218 uint8_t * restrict y_d = (y + qrow_size); // then scales
2219
2220    // Temp scales overwrite the input row since we're working off the aligned temp buffer in VTCM
2221 uint8_t * restrict t_d = (uint8_t *) x;
2222
2223 for (uint32_t i = 0; i < nb; i++) {
2224#if FP32_QUANTIZE_GROUP_SIZE == 32
2225 quantize_block_f32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
2226 quantize_block_f32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
2227#elif FP32_QUANTIZE_GROUP_SIZE == 64
2228 quantize_block_f32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
2229 quantize_block_f32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
2230#elif FP32_QUANTIZE_GROUP_SIZE == 128
2231 quantize_block_f32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
2232 quantize_block_f32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
2233#else
2234#error "FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128"
2235#endif
2236 }
2237
2238 // now copy the scales into final location
2239 hvx_copy_f16_ua(y_d, t_d, nb * 8);
2240}
2241
2242static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) {
2243 struct htp_matmul_context * mmctx = data;
2244 struct htp_ops_context * octx = mmctx->octx;
2245
2246 const struct htp_tensor * src = &octx->src1;
2247 uint8_t * restrict dst = octx->src1_spad.data;
2248 struct htp_spad * spad = &octx->src0_spad;
2249 uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
2250
2251 uint64_t t1 = HAP_perf_get_qtimer_count();
2252
2253 const uint32_t ne0 = src->ne[0];
2254 const uint32_t ne1 = src->ne[1];
2255 const uint32_t ne2 = src->ne[2];
2256 const uint32_t ne3 = src->ne[3];
2257
2258 const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
2259
2260 const uint32_t ir_first = nrows_per_thread * ith; // first row
2261 const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
2262
2263 const size_t src_row_size = src->nb[1];
2264 const size_t dst_row_size = q8x4x2_row_size(ne0);
2265
2266 uint8_t * restrict src_data = (uint8_t *) src->data + (src_row_size * ir_first);
2267 uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first);
2268 uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith);
2269
2270 const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float));
2271 memset(tmp_data, 0, src_row_size_padded); // zero-out temp row data for padding
2272
2273 for (uint32_t i = ir_first; i < ir_last; ++i) {
2274 hex_l2fetch(src_data, src_row_size, src_row_size, 2);
2275 hvx_copy_f32_aa(tmp_data, src_data, ne0);
2276
2277 // FARF(HIGH, "quantize-q8x4-row: %u\n", i);
2278 quantize_row_f32_q8x4x2((float *) tmp_data, dst_data, ne0);
2279 dst_data += dst_row_size;
2280 src_data += src_row_size;
2281 }
2282
2283 uint64_t t2 = HAP_perf_get_qtimer_count();
2284
2285 FARF(HIGH, "quantize-f32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first,
2286 ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
2287}
2288
2289static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
2290 struct htp_matmul_context * mmctx = data;
2291 struct htp_ops_context * octx = mmctx->octx;
2292
2293 const struct htp_tensor * src = &octx->src1;
2294 uint8_t * restrict dst = octx->src1_spad.data;
2295 uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
2296 uint32_t dst_stride = octx->src1_spad.stride;
2297
2298 uint64_t t1 = HAP_perf_get_qtimer_count();
2299
2300 const uint32_t ne0 = src->ne[0];
2301 const uint32_t ne1 = src->ne[1];
2302 const uint32_t ne2 = src->ne[2];
2303 const uint32_t ne3 = src->ne[3];
2304
2305 const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
2306
2307 const uint32_t ir_first = nrows_per_thread * ith; // first row
2308 const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
2309
2310 const size_t src_row_size = ne0 * sizeof(float);
2311 const size_t src_stride = src->nb[1];
2312
2313 uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first);
2314 uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first);
2315
2316 for (uint32_t i = ir_first; i < ir_last; ++i) {
2317 hex_l2fetch(src_data, src_row_size, src_stride, 2);
2318 hvx_copy_f16_f32_au(dst_data, src_data, ne0);
2319
2320 dst_data += dst_stride;
2321 src_data += src_stride;
2322 }
2323
2324 uint64_t t2 = HAP_perf_get_qtimer_count();
2325
2326 FARF(HIGH, "quantize-f32-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
2327 ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
2328}
2329
2330// TODO: this is just a plain copy that should be done via DMA during op setup
2331static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
2332 struct htp_matmul_context * mmctx = data;
2333 struct htp_ops_context * octx = mmctx->octx;
2334
2335 const struct htp_tensor * src = &octx->src1;
2336 uint8_t * restrict dst = octx->src1_spad.data;
2337 uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
2338 uint32_t dst_stride = octx->src1_spad.stride;
2339
2340 uint64_t t1 = HAP_perf_get_qtimer_count();
2341
2342 const uint32_t ne0 = src->ne[0];
2343 const uint32_t ne1 = src->ne[1];
2344 const uint32_t ne2 = src->ne[2];
2345 const uint32_t ne3 = src->ne[3];
2346
2347 const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
2348
2349 const uint32_t ir_first = nrows_per_thread * ith; // first row
2350 const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
2351
2352    const size_t src_row_size = ne0 * sizeof(__fp16); // src1 rows are fp16 here
2353 const size_t src_stride = src->nb[1];
2354
2355 uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first);
2356 uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first);
2357
2358 for (uint32_t i = ir_first; i < ir_last; ++i) {
2359 hex_l2fetch(src_data, src_row_size, src_stride, 2);
2360 hvx_copy_f16_au(dst_data, src_data, ne0);
2361
2362 dst_data += dst_stride;
2363 src_data += src_stride;
2364 }
2365
2366 uint64_t t2 = HAP_perf_get_qtimer_count();
2367
2368 FARF(HIGH, "quantize-f16-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
2369 ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
2370}
2371
2372
2373static inline bool htp_is_permuted(const struct htp_tensor * t) {
2374 return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3];
2375}
2376
2377static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_type type) {
2378 switch (type) {
2379 case HTP_TYPE_Q4_0:
2380 mmctx->type = "q4x4x2-f32";
2381 mmctx->vec_dot_1x1 = vec_dot_q4x4x2_q8x4x2_1x1;
2382 mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1;
2383 mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2;
2384 return 0;
2385 case HTP_TYPE_Q8_0:
2386 mmctx->type = "q8x4x2-f32";
2387 mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1;
2388 mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1;
2389 mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2;
2390 return 0;
2391 case HTP_TYPE_MXFP4:
2392 mmctx->type = "mxfp4x4x2-f32";
2393 mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1;
2394 mmctx->vec_dot_2x1 = vec_dot_mxfp4x4x2_q8x4x2_2x1;
2395 mmctx->vec_dot_2x2 = vec_dot_mxfp4x4x2_q8x4x2_2x2;
2396 return 0;
2397 default:
2398 return -1;
2399 }
2400}
2401
2402static void htp_mminit_spad(struct htp_ops_context * octx,
2403 size_t dst_row_size,
2404 size_t src0_row_size_padded,
2405 size_t src1_row_size,
2406 uint32_t src1_nrows,
2407 size_t src2_spad_size_per_thread) {
2408 octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
2409 octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
2410 octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
2411
2412 if (src2_spad_size_per_thread > 0) {
2413 octx->src2_spad.size_per_thread = src2_spad_size_per_thread;
2414 octx->src2_spad.size = octx->src2_spad.size_per_thread;
2415 }
2416
2417 // src0 spad is also used in dynamic quantizer to store padded src1 rows
2418 size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
2419 if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
2420 octx->src0_spad.size_per_thread = src1_row_size_padded;
2421 }
2422
2423 octx->src1_spad.size = octx->src1_spad.size_per_thread;
2424 octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
2425 octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
2426}
2427
2428int op_matmul(struct htp_ops_context * octx) {
2429 htp_matmul_tensors_preamble;
2430
2431 struct htp_matmul_context mmctx_struct = {0};
2432 struct htp_matmul_context * mmctx = &mmctx_struct;
2433 mmctx->octx = octx;
2434
2435 const uint32_t src0_nrows = ne01 * ne02 * ne03;
2436 const uint32_t src1_nrows = ne11 * ne12 * ne13;
2437
2438 // Compute src0_nrows_per_thread
2439 mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
2440 mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even
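    // an even split keeps thread slices aligned for the paired (vec_dot_2x*) kernels; a trailing odd row falls back to vec_dot_1x1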
2441
2442 const size_t src0_row_size = nb01;
2443 const size_t dst_row_size = nb1;
2444 size_t src1_row_size = nb11;
2445
2446 const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
2447 size_t src1_row_size_padded;
2448
2449 worker_callback_t quant_job_func;
2450 worker_callback_t matmul_job_func = src1_nrows > 1 ? matmul_2d : matvec_2d;
2451
2452 bool need_quant = !(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE);
2453
2454 if (src0->type == HTP_TYPE_F16) {
2455 // Try optimized f16-f16 path first (src1 in VTCM)
2456 const size_t f16_src1_row_size = hex_round_up(ne10 * 2, 128);
2457 const size_t f16_src1_spad_size = hex_round_up(f16_src1_row_size * src1_nrows, 256);
2458 const size_t f16_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads;
2459 const size_t f16_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads;
2460
2461 const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size;

        // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting).
        // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul.
        const bool is_batched = (ne02 > 1) || (ne03 > 1);
        const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1);

        if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) {
            // Optimized path
            quant_job_func = (src1->type == HTP_TYPE_F32) ? quantize_f32_f16 : quantize_f16_f16;
            mmctx->type = "f16-f16";
            mmctx->vec_dot_1x1 = vec_dot_f16_f16_aa_1x1;
            mmctx->vec_dot_2x1 = vec_dot_f16_f16_aa_2x1;
            mmctx->vec_dot_2x2 = vec_dot_f16_f16_aa_2x2;

            src1_row_size = f16_src1_row_size; // row size post quantization

            octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
            octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
            octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);

            octx->src1_spad.size = octx->src1_spad.size_per_thread;
            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
            octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
        } else {
            // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required
            quant_job_func = NULL;
            if (src1->type == HTP_TYPE_F32) {
                mmctx->type = "f16-f32";
                mmctx->vec_dot_1x1 = vec_dot_f16_f32_uu_1x1;
                matmul_job_func = matmul_4d;
            } else {
                mmctx->type = "f16-f16";
                mmctx->vec_dot_1x1 = vec_dot_f16_f16_uu_1x1;
                matmul_job_func = matmul_4d;
            }

            src1_row_size = nb11; // original row size in DDR

            octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
            octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256);
            octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256);

            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
            octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
            octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;

            // Init fastdiv for matmul_4d (supports broadcasting)
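            // r2/r3 are the broadcast ratios along ne2/ne3 (src1 dims divided by src0 dims)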
            mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]);
            mmctx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]);
            mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]);
            mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]);

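            // The DDR fallback reads src1 in place, so no dynamic quantization pass is needed.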
            need_quant = false;
        }
    } else {
        if (htp_mminit_vec_dot(mmctx, src0->type) != 0) {
            return HTP_STATUS_NO_SUPPORT;
        }

        quant_job_func = quantize_f32_q8x4x2;
        src1_row_size = q8x4x2_row_size(ne10);
        htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, 0);
    }

    // VTCM scratchpads for all tensors
    size_t spad_size = octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;

    FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type,
         octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, spad_size);

    FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, src0->ne[0],
         src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0],
         dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data);

    // Make sure the reserved VTCM size is sufficient
    if (octx->ctx->vtcm_size < spad_size) {
        FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type,
             octx->ctx->vtcm_size, spad_size);
        return HTP_STATUS_VTCM_TOO_SMALL;
    }

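    // Carve the VTCM reservation into contiguous scratchpads: src0 | src1 | dst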
    octx->src0_spad.data = octx->ctx->vtcm_base;
    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
    octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;

    octx->src0_spad.stride = src0_row_size_padded;
    octx->src1_spad.stride = src1_row_size;

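    // Dynamic quantization pass: repack src1 rows into the VTCM src1 scratchpad,
    // splitting the rows across up to n_threads jobs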
    if (need_quant) {
        const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
        mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
        worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
    }

    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
        const uint32_t n_matmul_jobs = octx->n_threads;
        worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs);
    }

    return HTP_STATUS_OK;
}

int op_matmul_id(struct htp_ops_context * octx) {
    htp_matmul_tensors_preamble;

    struct htp_matmul_context mmctx_struct = {0};
    struct htp_matmul_context * mmctx = &mmctx_struct;
    mmctx->octx = octx;

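    // src2 carries the expert ids selected for each token (MoE routing)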
    struct htp_tensor * restrict ids = &octx->src2;

    const size_t src0_row_size = nb01;
    const size_t dst_row_size = nb1;

    const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);

    const uint32_t src0_nrows = ne01; // per expert
    const uint32_t src1_nrows = ne11 * ne12 * ne13;

    worker_callback_t quant_job_func;
    worker_callback_t matmul_id_job_func = src1_nrows > 1 ? matmul_id : matvec_id;

    // Compute src0_nrows_per_thread
    mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
    mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even

    size_t src1_row_size;

    // row groups
    const int n_ids = ids->ne[0]; // n_expert_used
    const int n_as = ne02;        // n_expert

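    // Per-expert row counts plus a row map (which src1 rows are routed to each expert)
    // live in the shared src2 scratchpad and are sized here.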
    size_t matrix_row_counts_size = n_as * sizeof(uint32_t);
    size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping);

    if (htp_mminit_vec_dot(mmctx, src0->type) != 0) {
        return HTP_STATUS_NO_SUPPORT;
    }

    quant_job_func = quantize_f32_q8x4x2;
    src1_row_size = q8x4x2_row_size(ne10);

    const size_t src2_spad_size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
    htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread);

    size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;

    FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type,
         octx->src0_spad.size, octx->src1_spad.size, octx->src2_spad.size, octx->dst_spad.size, spad_size);

    FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type,
         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
         ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data,
         src1->data, dst->data);

    // Make sure the reserved VTCM size is sufficient
    if (octx->ctx->vtcm_size < spad_size) {
        FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type,
             octx->ctx->vtcm_size, spad_size);
        return HTP_STATUS_VTCM_TOO_SMALL;
    }

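    // Carve the VTCM reservation into contiguous scratchpads: src0 | src1 | src2 | dst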
    octx->src0_spad.data = octx->ctx->vtcm_base;
    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
    octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
    octx->dst_spad.data = octx->src2_spad.data + octx->src2_spad.size;

    octx->src0_spad.stride = src0_row_size_padded;
    octx->src1_spad.stride = src1_row_size;

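    // The per-expert row grouping is only needed on the multi-row (matmul_id) path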
    if (src1_nrows > 1) {
        // initialize matrix_row_counts and map
        uint32_t * matrix_row_counts = (uint32_t *) octx->src2_spad.data + 0;
        struct mmid_row_mapping * matrix_rows = (void *) octx->src2_spad.data + matrix_row_counts_size;

        memset(matrix_row_counts, 0, n_as * sizeof(uint32_t));

        // group rows by src0 matrix (expert)
        for (uint32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { // token idx
            for (uint32_t id = 0; id < n_ids; ++id) {        // expert idx
                const uint32_t i02 = *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);

                assert(i02 < (uint32_t) n_as);

                MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) { id, iid1 };
                matrix_row_counts[i02] += 1;
            }
        }
    }

    // Setup worker pool callbacks
    if (!(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE)) {
        const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
        mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
        worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
    }

    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
        const uint32_t n_matmul_jobs = octx->n_threads;
        worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs);
    }

    return HTP_STATUS_OK;
}