1#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
   2#pragma clang diagnostic ignored "-Wunused-function"
   3#pragma clang diagnostic ignored "-Wunused-variable"
   4#pragma clang diagnostic ignored "-Wunused-but-set-variable"
   5
   6#include <HAP_farf.h>
   7#include <HAP_perf.h>
   8
   9#include <math.h>
  10#include <string.h>
  11
  12#include "hex-dma.h"
  13#include "hvx-utils.h"
  14#include "hvx-dump.h"
  15
  16#define GGML_COMMON_DECL_C
  17#include "ggml-common.h"
  18#include "htp-ctx.h"
  19#include "htp-msg.h"
  20#include "htp-ops.h"
  21
// Row counts for the matmul scratchpad ("SPAD") staging buffers.
// NOTE(review): meaning inferred from the names only (rows of src0 / src1 / dst staged
// at a time); these macros are not referenced in this part of the file -- confirm at
// the use sites. Kept as macros in case they are used in preprocessor conditionals.
#define MM_SPAD_SRC0_NROWS 16
#define MM_SPAD_SRC1_NROWS 16
#define MM_SPAD_DST_NROWS  2
  25
// Per-op matmul state: the selected vec-dot kernels plus values precomputed once
// before the worker threads run.
struct htp_matmul_context {
    const char * type;  // NOTE(review): presumably a printable name of the selected kernel/type combination -- confirm
    struct htp_ops_context * octx;  // op context this matmul belongs to

    // One x row against one y row; writes a single float to s0[0]
    // (signature matches vec_dot_q4x4x2_q8x4x2_1x1 in this file).
    void (*vec_dot_1x1)(const int n, float * restrict s0,
         const void * restrict vx0,
         const void * restrict vy0);

    // Two x rows (vx0, vx1) against one y row (vy0); writes s0[0] and s0[1]
    // (signature matches vec_dot_q4x4x2_q8x4x2_2x1 in this file).
    void (*vec_dot_2x1)(const int n, float * restrict s0,
         const void * restrict vx0, const void * restrict vx1,
         const void * restrict vy0);

    // 2x2 tile: s0[0..1] = rows vx0/vx1 against vy0, s1[0..1] = rows vx0/vx1 against vy1
    // (signature matches vec_dot_q4x4x2_q8x4x2_2x2 in this file).
    void (*vec_dot_2x2)(const int n, float * restrict s0, float * restrict s1,
         const void * restrict vx0, const void * restrict vx1,
         const void * restrict vy0, const void * restrict vy1);

    // Precomputed values
    // NOTE(review): presumably the number of src0 / src1 rows each worker thread handles -- confirm.
    uint32_t src0_nrows_per_thread;
    uint32_t src1_nrows_per_thread;

    // Fast-division constants; names suggest divisors ne12*ne1, ne1 and the
    // broadcast ratios r2 / r3 -- NOTE(review): confirm where they are initialized.
    struct fastdiv_values mm_div_ne12_ne1;
    struct fastdiv_values mm_div_ne1;
    struct fastdiv_values mm_div_r2;
    struct fastdiv_values mm_div_r3;
};
  51
// vdelta control to replicate first 4x fp32 values across lanes.
// Each byte is a routing mask for the HVX vdelta byte network: at the stage with
// offset 2^k (64 down to 1), output byte i takes byte i ^ 2^k when bit 2^k is set.
// Tracing the bytes, fp32 value k (k = 0..3) appears to fill lanes 8k..8k+7.
// NOTE(review): presumably passed as the control operand of Q6_V_vdelta_VV; not used in this part of the file.
static const uint8_t __attribute__((aligned(128))) repl_4x_f32[128] = {
    0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
    0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
    0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04,
    0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
    0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04,
    0x04, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
    0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10,
};
  62
// vdelta control to replicate and interleave first 8x fp32 values across lanes.
// Each byte is a routing mask for the HVX vdelta byte network: at the stage with
// offset 2^k (64 down to 1), output byte i takes byte i ^ 2^k when bit 2^k is set.
// NOTE(review): the exact output lane order is not documented here -- trace the control
// bytes or check the use site before relying on a particular interleave.
static const uint8_t __attribute__((aligned(128))) repl_interleave_8x_f32[128] = {
    0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x00, 0x00, 0x00,
    0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
    0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04,
    0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
    0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x44, 0x44, 0x44,
    0x44, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
    0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20,
};
  73
// vdelta control to replicate first fp32 value across all elements.
// Each byte is a routing mask for the HVX vdelta byte network: at the stage with
// offset 2^k (64 down to 1), output byte i takes byte i ^ 2^k when bit 2^k is set.
// Bytes 64..67 carry 0x40, so word 0 is first copied into the upper half and then
// fanned out by the smaller offsets until all 32 fp32 lanes equal lane 0.
static const uint8_t __attribute__((aligned(128))) repl_1x_f32[128] = {
    0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
    0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
    0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08,
    0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08,
    0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04,
    0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10,
    0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
};
  84
// vdelta control to replicate first fp16 value across all elements.
// Each byte is a routing mask for the HVX vdelta byte network: at the stage with
// offset 2^k (64 down to 1), output byte i takes byte i ^ 2^k when bit 2^k is set.
// Bytes 64..65 carry 0x40, so halfword 0 reaches the upper 64 bytes too and all
// 64 fp16 lanes end up equal to lane 0.
static const uint8_t __attribute__((aligned(128))) repl_1x_f16[128] = {
    0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02,
    0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04,
    0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08,
    0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02,
    0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02,
    0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10,
    0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
};
  95
// vdelta control to replicate two fp16 values, one per 64-byte half:
// halfword 0 is broadcast to fp16 lanes 0..31 and halfword 32 (byte 64) to lanes 32..63.
// Each byte is a routing mask for the HVX vdelta byte network: at the stage with
// offset 2^k (64 down to 1), output byte i takes byte i ^ 2^k when bit 2^k is set.
// Unlike repl_1x_f16, no byte here carries the 0x40 bit (byte 64 is 0x00), so nothing
// crosses between the two halves and each half replicates its own first halfword.
// NOTE(review): derived by tracing the control bytes -- confirm against the use site.
static const uint8_t __attribute__((aligned(128))) repl_2x_f16[128] = {
    0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
    0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
    0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
    0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
    0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
    0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
    0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
    0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
};
 107
// vdelta control to expand first 32 e8m0 values into 32 uint32 elements.
// Each byte is a routing mask for the HVX vdelta byte network: at the stage with
// offset 2^k (64 down to 1), output byte i takes byte i ^ 2^k when bit 2^k is set.
// NOTE(review): vdelta only permutes bytes, so the other three bytes of each 32-bit
// lane hold whatever bytes were routed there; presumably the caller masks or shifts
// each word afterwards -- confirm at the use site.
static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = {
    0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, 0x00,
    0x00, 0x11, 0x10, 0x10, 0x10, 0x02, 0x00, 0x04, 0x00, 0x01, 0x02, 0x08, 0x08, 0x08, 0x08, 0x00, 0x00, 0x01, 0x04,
    0x00, 0x00, 0x22, 0x20, 0x20, 0x20, 0x21, 0x22, 0x20, 0x24, 0x04, 0x00, 0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x02,
    0x00, 0x04, 0x00, 0x11, 0x12, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08,
    0x01, 0x02, 0x00, 0x04, 0x44, 0x40, 0x40, 0x40, 0x41, 0x40, 0x40, 0x40, 0x42, 0x40, 0x44, 0x40, 0x41, 0x42, 0x48,
    0x48, 0x08, 0x08, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x12, 0x10, 0x10, 0x10, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00,
    0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x22, 0x20, 0x24, 0x20, 0x21, 0x22, 0x20, 0x20,
};
 118
// Lookup table used by hvx_vec_load_mxfp4x4x8 (via Q6_Vb_vlut32_VbVbI) to turn 4-bit
// MXFP4 codes into int8. Entries are stored as (value, 0) byte pairs; the values at
// even bytes 0..30 are, for codes 0..15:
//   {0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12}
// i.e. the FP4 (E2M1) magnitudes scaled by 2 so they are exact integers.
// NOTE(review): presumably vlut32 reads the low byte of halfword `code`, hence the
// pair layout; the remaining bytes pad the table to one full vector. Confirm that the
// factor of 2 is compensated where the e8m0 scales are applied.
static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = {
    0,    0, 1,    0, 2,    0, 3, 0, 4, 0, 6, 0, 8, 0, 12, 0, 0, 0, 0xff, 0, 0xfe, 0, 0xfd, 0, 0xfc, 0,
    0xfa, 0, 0xf8, 0, 0xf4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0, 0, 0, 0,    0, 0,    0, 0,    0, 0,    0,
    0,    0, 0,    0, 0,    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0, 0, 0, 0,    0, 0,    0, 0,    0, 0,    0,
    0,    0, 0,    0, 0,    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0, 0, 0, 0,    0, 0,    0, 0,    0, 0,    0,
    0,    0, 0,    0, 0,    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0, 0, 0, 0,    0, 0,    0, 0,    0,
};
 126
 127// q4x4x2 and q8x4x2 are the flat q4/8_0 formats where all quants are stored first followed by all scales
 128
 129static inline size_t q8x4x2_row_size(uint32_t ne) {
 130    // ensures perfect alignment of quants and full row
 131    const uint32_t qk = QK_Q8_0x4x2;
 132    const uint32_t nb = (ne + qk - 1) / qk;
 133    return hex_round_up(ne + nb * 8 * sizeof(__fp16), 128);
 134}
 135
 136static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) {
 137    const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
 138
 139    HVX_Vector v0_1 = vptr[0];  // first 256 elements (128 bytes)
 140    HVX_Vector v2_3 = vptr[1];  // ...
 141    HVX_Vector v4_5 = vptr[2];  // ...
 142    HVX_Vector v6_7 = vptr[3];  // ...
 143
 144    const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
 145    const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
 146
 147    HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4);  // & 0x0F
 148    HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4);    // >> 4
 149    HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4);  // & 0x0F
 150    HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4);    // >> 4
 151    HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4);  // & 0x0F
 152    HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4);    // >> 4
 153    HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4);  // & 0x0F
 154    HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4);    // >> 4
 155
 156    // Convert uint4 to int4 (i.e. x - 8)
 157    v0 = Q6_Vb_vsub_VbVb(v0, i8);
 158    v1 = Q6_Vb_vsub_VbVb(v1, i8);
 159    v2 = Q6_Vb_vsub_VbVb(v2, i8);
 160    v3 = Q6_Vb_vsub_VbVb(v3, i8);
 161    v4 = Q6_Vb_vsub_VbVb(v4, i8);
 162    v5 = Q6_Vb_vsub_VbVb(v5, i8);
 163    v6 = Q6_Vb_vsub_VbVb(v6, i8);
 164    v7 = Q6_Vb_vsub_VbVb(v7, i8);
 165
 166    HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
 167    return r;
 168}
 169
 170static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) {
 171    const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
 172
 173    HVX_Vector v0_1 = vptr[0];  // first 256 elements (128 bytes)
 174    HVX_Vector v2_3 = vptr[1];  // ...
 175    HVX_Vector v4_5 = vptr[2];  // ...
 176    HVX_Vector v6_7 = vptr[3];  // ...
 177
 178    const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
 179    const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut;
 180
 181    HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4);  // & 0x0F
 182    HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4);    // >> 4
 183    HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4);  // & 0x0F
 184    HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4);    // >> 4
 185    HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4);  // & 0x0F
 186    HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4);    // >> 4
 187    HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4);  // & 0x0F
 188    HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4);    // >> 4
 189
 190    v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
 191    v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
 192    v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0);
 193    v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0);
 194    v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0);
 195    v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0);
 196    v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0);
 197    v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0);
 198
 199    HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
 200    return r;
 201}
 202
 203static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) {
 204    const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
 205
 206    HVX_Vector v0 = vptr[0];  // first  128 vals
 207    HVX_Vector v1 = vptr[1];  // ...
 208    HVX_Vector v2 = vptr[2];  // ...
 209    HVX_Vector v3 = vptr[3];  // ...
 210    HVX_Vector v4 = vptr[4];  // ...
 211    HVX_Vector v5 = vptr[5];  // ...
 212    HVX_Vector v6 = vptr[6];  // ...
 213    HVX_Vector v7 = vptr[7];  // ...
 214
 215    HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
 216    return r;
 217}
 218
// Reduce-multiply up to 1024 x 1024 int8 elements (32x q4/8 blocks held in 8x HVX vectors).
// Each 32-element block is reduced to a single int32 accumulator.
// Returns a single HVX vector with 32x int32 accumulators; lane j holds block j
// (input vector k covers blocks 4k..4k+3, and every deal/add stage below keeps
// block order while halving the number of int32 partial sums per block).
//
// This version is parameterized to support fewer than 1024 elements: input vector k
// participates only when n >= 128 * (k + 1).
// The if() checks are optimized out at compile time -- make sure to pass N as a constexpr
// (hvx_vec_rmpy_x8_full / hvx_vec_rmpy_x8_nloe pass literal 256/512/768/1024).
//
// NOTE: for n < 1024 only lanes [0, n/32) are meaningful. The skipped stages do not
// reset every intermediate register, so stale partial sums are dealt into the upper
// lanes. For example, when 256 <= n < 384, r1 still holds the raw vrmpy of x.v[1] when
// it is dealt again in the second stage. Upper lanes are therefore NOT zero; callers
// must mask lanes >= n/32 before using the result.

static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
    HVX_Vector r0 = Q6_V_vsplat_R(0);
    HVX_Vector r1 = Q6_V_vsplat_R(0);
    HVX_Vector r2 = Q6_V_vsplat_R(0);
    HVX_Vector r3 = Q6_V_vsplat_R(0);
    HVX_Vector r4 = Q6_V_vsplat_R(0);
    HVX_Vector r5 = Q6_V_vsplat_R(0);
    HVX_Vector r6 = Q6_V_vsplat_R(0);
    HVX_Vector r7 = Q6_V_vsplat_R(0);

    HVX_VectorPair p3;
    HVX_VectorPair p2;
    HVX_VectorPair p1;
    HVX_VectorPair p0;

    // Stage 0: per-vector reduce-multiply. Each int32 lane sums 4 int8 products,
    // so one 32-element block spans 8 lanes of one r register.
    if (n >=  128) { r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]); }
    if (n >=  256) { r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]); }
    if (n >=  384) { r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]); }
    if (n >=  512) { r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]); }
    if (n >=  640) { r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]); }
    if (n >=  768) { r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]); }
    if (n >=  896) { r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]); }
    if (n >= 1024) { r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]); }

    // Stage 1 (8 -> 4 registers, 4 lanes per block): deinterleave even/odd 32-bit
    // words of each register pair and add them, summing adjacent lanes.
    if (n >=  128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }
    if (n >=  384) { p1 = Q6_W_vdeal_VVR(r3, r2, -4); }
    if (n >=  640) { p2 = Q6_W_vdeal_VVR(r5, r4, -4); }
    if (n >=  896) { p3 = Q6_W_vdeal_VVR(r7, r6, -4); }

    if (n >=  128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }
    if (n >=  384) { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); }
    if (n >=  640) { r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2)); }
    if (n >=  896) { r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3)); }

    // Stage 2 (4 -> 2 registers, 2 lanes per block). When a stage-1 update above was
    // skipped, the r register dealt here still holds an older value; its sums only
    // reach lanes >= n/32 (see NOTE above).
    if (n >=  128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }
    if (n >=  640) { p1 = Q6_W_vdeal_VVR(r3, r2, -4); }

    if (n >=  128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }
    if (n >=  640) { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); }

    // Stage 3 (2 -> 1 register, 1 lane per block).
    if (n >=  128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }
    if (n >=  128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }

    return r0;
}
 270
 271static inline HVX_Vector hvx_vec_rmpy_x8_full(HVX_Vector_x8 x, HVX_Vector_x8 y) {
 272    return hvx_vec_rmpy_x8_n(x, y, 1024);
 273}
 274
 275// Handle most common cases of tensors not multiple of 1024.
 276static inline HVX_Vector hvx_vec_rmpy_x8_nloe(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
 277    if (n <= 256) { return hvx_vec_rmpy_x8_n(x, y, 256); };
 278    if (n <= 512) { return hvx_vec_rmpy_x8_n(x, y, 512); };
 279    if (n <= 768) { return hvx_vec_rmpy_x8_n(x, y, 768); };
 280    return hvx_vec_rmpy_x8_n(x, y, 1024);
 281}
 282
 283static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
 284    assert(n % 32 == 0);  // min sub-block size
 285    assert((unsigned long) vx0 % 128 == 0);
 286    assert((unsigned long) vy0 % 128 == 0);
 287
 288    const uint32_t qk = QK_Q4_0x4x2 * 4;
 289
 290    const uint32_t x_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
 291    const uint32_t x_qblk_size = qk / 2;                                      // int4
 292    const uint32_t x_qrow_size = n / 2;                                       // int4 (not padded)
 293
 294    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
 295    const uint32_t y_qblk_size = qk;                                          // int8
 296    const uint32_t y_qrow_size = n;                                           // int8 (not padded)
 297
 298    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0);            // quants first
 299    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size);  // then scales
 300
 301    const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0);               // quants first
 302    const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size);     // then scales
 303
 304    // Row sum (sf)
 305    HVX_Vector r0_sum = Q6_V_vsplat_R(0);
 306
 307    // Multiply and accumulate into int32.
 308    // Compute combined scale (fp32).
 309    // Apply scale to acc and accumulate into the row sum (qf32).
 310
 311    const uint32_t nb   = n / qk;  // num full blocks
 312    const uint32_t nloe = n % qk;  // num leftover elemements
 313
 314    uint32_t i = 0;
 315    for (; i < nb; i++) {
 316        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
 317        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
 318
 319        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
 320
 321        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
 322        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
 323
 324        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
 325
 326        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
 327
 328        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
 329    }
 330
 331    // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
 332    if (nloe) {
 333        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
 334        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
 335
 336        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
 337
 338        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
 339        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
 340
 341        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
 342
 343        // Zero out unused scales
 344        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
 345        r0_dd                = Q6_V_vand_QV(bmask, r0_dd);
 346        r0_ia                = Q6_V_vand_QV(bmask, r0_ia);
 347
 348        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
 349
 350        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
 351    }
 352
 353    r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
 354
 355    hvx_vec_store_u(s0, 4, r0_sum);
 356}
 357
 358static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
 359                                      const void * restrict vx0, const void * restrict vx1,
 360                                      const void * restrict vy0) {
 361    assert(n % 32 == 0);  // min sub-block size
 362    assert((unsigned long) vx0 % 128 == 0);
 363    assert((unsigned long) vx1 % 128 == 0);
 364    assert((unsigned long) vy0 % 128 == 0);
 365
 366    const uint32_t qk = QK_Q4_0x4x2 * 4;
 367
 368    const uint32_t x_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
 369    const uint32_t x_qblk_size = qk / 2;                                      // int4
 370    const uint32_t x_qrow_size = n / 2;                                       // int4 (not padded)
 371
 372    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
 373    const uint32_t y_qblk_size = qk;                                          // int8
 374    const uint32_t y_qrow_size = n;                                           // int8 (not padded)
 375
 376    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;            // quants first
 377    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;  // then scales
 378    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;            // quants first
 379    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;  // then scales
 380
 381    const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0);               // quants first
 382    const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size);     // then scales
 383
 384    // Row sum (sf)
 385    HVX_Vector r0_sum = Q6_V_vsplat_R(0);
 386    HVX_Vector r1_sum = Q6_V_vsplat_R(0);
 387
 388    // Multiply and accumulate into int32.
 389    // Compute combined scale (fp32).
 390    // Apply scale to acc and accumulate into the row sum (qf32).
 391
 392    const uint32_t nb   = n / qk;  // num full blocks
 393    const uint32_t nloe = n % qk;  // num leftover elemements
 394
 395    uint32_t i = 0;
 396    for (; i < nb; i++) {
 397        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
 398        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
 399        HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
 400
 401        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
 402        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
 403
 404        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
 405        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
 406        HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
 407
 408        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
 409        HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
 410
 411        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
 412        HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
 413
 414        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
 415        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
 416    }
 417
 418    // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
 419    if (nloe) {
 420        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
 421        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
 422        HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
 423
 424        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
 425        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));
 426
 427        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
 428        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
 429        HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
 430
 431        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
 432        HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
 433
 434        // Zero out unused scales
 435        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
 436        r0_dd                = Q6_V_vand_QV(bmask, r0_dd);
 437        r1_dd                = Q6_V_vand_QV(bmask, r1_dd);
 438        r0_ia                = Q6_V_vand_QV(bmask, r0_ia);
 439        r1_ia                = Q6_V_vand_QV(bmask, r1_ia);
 440
 441        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
 442        HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
 443
 444        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
 445        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
 446    }
 447
 448    HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
 449    hvx_vec_store_u(s0, 8, rsum);
 450}
 451
 452static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
 453                                        const void * restrict vx0, const void * restrict vx1,
 454                                        const void * restrict vy0, const void * restrict vy1) {
 455    assert(n % 32 == 0);
 456    assert((unsigned long) vx0 % 128 == 0);
 457    assert((unsigned long) vx1 % 128 == 0);
 458    assert((unsigned long) vy0 % 128 == 0);
 459    assert((unsigned long) vy1 % 128 == 0);
 460
 461    const uint32_t qk = QK_Q4_0x4x2 * 4;
 462
 463    const uint32_t x_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
 464    const uint32_t x_qblk_size = qk / 2;                                      // int4
 465    const uint32_t x_qrow_size = n / 2;                                       // int4 (not padded)
 466
 467    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
 468    const uint32_t y_qblk_size = qk;                                          // int8
 469    const uint32_t y_qrow_size = n;                                           // int8 (not padded)
 470
 471    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;            // quants first
 472    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;  // then scales
 473    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;            // quants first
 474    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;  // then scales
 475
 476    const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0;              // quants first
 477    const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size;    // then scales
 478    const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0;              // quants first
 479    const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size;    // then scales
 480
 481    // Row sums (sf) - 4 accumulators for 2ร—2 tile
 482    HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
 483    HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
 484    HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
 485    HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
 486
 487    const uint32_t nb   = n / qk;  // num full blocks
 488    const uint32_t nloe = n % qk;  // num leftover elements
 489
 490    uint32_t i = 0;
 491    for (; i < nb; i++) {
 492        // Load src1 columns (reused across both src0 rows)
 493        HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
 494        HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
 495
 496        // Load src0 rows (reused across both src1 columns)
 497        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
 498        HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
 499
 500        // Compute 4 dot products: r0ร—c0, r0ร—c1, r1ร—c0, r1ร—c1
 501        HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
 502        HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
 503        HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
 504        HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
 505
 506        // Load scales
 507        HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
 508        HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
 509        HVX_Vector r0_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
 510        HVX_Vector r1_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
 511
 512        // Compute combined scales
 513        HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
 514        HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
 515        HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
 516        HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
 517
 518        // Apply scales and accumulate
 519        HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
 520        HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
 521        HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
 522        HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
 523
 524        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
 525        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
 526        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
 527        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
 528    }
 529
 530    // Process leftovers
 531    if (nloe) {
 532        HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
 533        HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
 534        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
 535        HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
 536
 537        HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe));
 538        HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe));
 539        HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe));
 540        HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe));
 541
 542        HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
 543        HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
 544        HVX_Vector r0_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
 545        HVX_Vector r1_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
 546
 547        HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
 548        HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
 549        HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
 550        HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
 551
 552        // Zero out unused scales
 553        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
 554        r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
 555        r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
 556        r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
 557        r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
 558        r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
 559        r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
 560        r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
 561        r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
 562
 563        HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
 564        HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
 565        HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
 566        HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
 567
 568        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
 569        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
 570        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
 571        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
 572    }
 573
 574    // Reduce and store results
 575    HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
 576    HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
 577
 578    hvx_vec_store_u(s0, 8, r0_r1_c0_sum);  // row0,col0 row1,col0
 579    hvx_vec_store_u(s1, 8, r0_r1_c1_sum);  // row0,col1 row1,col1
 580}
 581
 582static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
 583    assert(n % 32 == 0);  // min sub-block size
 584    assert((unsigned long) vx0 % 128 == 0);
 585    assert((unsigned long) vy0 % 128 == 0);
 586
 587    const uint32_t qk = QK_Q4_0x4x2 * 4;
 588
 589    const uint32_t x_dblk_size = 8 * 4 * 2;                                  // 32x __fp16
 590    const uint32_t x_qblk_size = qk;                                         // int8
 591    const uint32_t x_qrow_size = n;                                          // int8 (not padded)
 592
 593    const uint32_t y_dblk_size = 8 * 4 * 2;                                  // 32x __fp16
 594    const uint32_t y_qblk_size = qk;                                         // int8
 595    const uint32_t y_qrow_size = n;                                          // int8 (not padded)
 596
 597    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0);           // quants first
 598    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
 599
 600    const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0);              // quants first
 601    const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size);    // then scales
 602
 603    // Row sum (sf)
 604    HVX_Vector r0_sum = Q6_V_vsplat_R(0);
 605
 606    // Multiply and accumulate into int32.
 607    // Compute combined scale (fp32).
 608    // Apply scale to acc and accumulate into the row sum (qf32).
 609
 610    const uint32_t nb   = n / qk;  // num full blocks
 611    int32_t        nloe = n % qk;  // num leftover elemements (must be signed)
 612
 613    uint32_t i = 0;
 614    for (; i < nb; i++) {
 615        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
 616        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
 617
 618        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
 619
 620        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
 621        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
 622
 623        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
 624
 625        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
 626
 627        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
 628    }
 629
 630    // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
 631    if (nloe) {
 632        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
 633        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
 634
 635        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
 636
 637        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
 638        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
 639
 640        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
 641
 642        // Zero out unused scales
 643        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
 644        r0_dd                = Q6_V_vand_QV(bmask, r0_dd);
 645        r0_ia                = Q6_V_vand_QV(bmask, r0_ia);
 646
 647        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
 648
 649        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
 650    }
 651
 652    r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
 653
 654    hvx_vec_store_u(s0, 4, r0_sum);
 655}
 656
 657static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,
 658                                      const void * restrict vx0, const void * restrict vx1,
 659                                      const void * restrict vy0) {
 660    assert(n % 32 == 0);  // min sub-block size
 661    assert((unsigned long) vx0 % 128 == 0);
 662    assert((unsigned long) vx1 % 128 == 0);
 663    assert((unsigned long) vy0 % 128 == 0);
 664
 665    const uint32_t qk = QK_Q4_0x4x2 * 4;
 666
 667    const uint32_t x_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
 668    const uint32_t x_qblk_size = qk;                                          // int8
 669    const uint32_t x_qrow_size = n;                                           // int8 (not padded)
 670
 671    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
 672    const uint32_t y_qblk_size = qk;                                          // int8
 673    const uint32_t y_qrow_size = n;                                           // int8 (not padded)
 674
 675    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;            // quants first
 676    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;  // then scales
 677    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;            // quants first
 678    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;  // then scales
 679
 680    const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0);               // quants first
 681    const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size);     // then scales
 682
 683    // Row sum (qf32)
 684    HVX_Vector r0_sum = Q6_V_vsplat_R(0);
 685    HVX_Vector r1_sum = Q6_V_vsplat_R(0);
 686
 687    // Multiply and accumulate into int32.
 688    // Compute combined scale (fp32).
 689    // Apply scale to acc and accumulate into the row sum (qf32).
 690
 691    const uint32_t nb   = n / qk;  // num full blocks
 692    int32_t        nloe = n % qk;  // num leftover elemements (must be signed)
 693
 694    uint32_t i = 0;
 695    for (; i < nb; i++) {
 696        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
 697        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
 698        HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
 699
 700        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
 701        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
 702
 703        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
 704        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
 705        HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
 706
 707        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
 708        HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
 709
 710        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
 711        HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
 712
 713        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
 714        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
 715    }
 716
 717    // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
 718    if (nloe) {
 719        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
 720        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
 721        HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
 722
 723        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
 724        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));
 725
 726        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
 727        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
 728        HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
 729
 730        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
 731        HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
 732
 733        // Zero out unused scales
 734        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
 735        r0_dd                = Q6_V_vand_QV(bmask, r0_dd);
 736        r1_dd                = Q6_V_vand_QV(bmask, r1_dd);
 737        r0_ia                = Q6_V_vand_QV(bmask, r0_ia);
 738        r1_ia                = Q6_V_vand_QV(bmask, r1_ia);
 739
 740        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
 741        HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
 742
 743        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
 744        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
 745    }
 746
 747    HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
 748    hvx_vec_store_u(s0, 8, rsum);
 749}
 750
 751static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
 752                                        const void * restrict vx0, const void * restrict vx1,
 753                                        const void * restrict vy0, const void * restrict vy1) {
 754    assert(n % 32 == 0);
 755    assert((unsigned long) vx0 % 128 == 0);
 756    assert((unsigned long) vx1 % 128 == 0);
 757    assert((unsigned long) vy0 % 128 == 0);
 758    assert((unsigned long) vy1 % 128 == 0);
 759
 760    const uint32_t qk = QK_Q8_0x4x2 * 4;
 761
 762    const uint32_t x_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
 763    const uint32_t x_qblk_size = qk;                                          // int8
 764    const uint32_t x_qrow_size = n;                                           // int8 (not padded)
 765
 766    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
 767    const uint32_t y_qblk_size = qk;                                          // int8
 768    const uint32_t y_qrow_size = n;                                           // int8 (not padded)
 769
 770    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;            // quants first
 771    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;  // then scales
 772    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;            // quants first
 773    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;  // then scales
 774
 775    const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0;              // quants first
 776    const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size;    // then scales
 777    const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0;              // quants first
 778    const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size;    // then scales
 779
 780    // Row sums (sf) - 4 accumulators for 2ร—2 tile
 781    HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
 782    HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
 783    HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
 784    HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
 785
 786    const uint32_t nb   = n / qk;  // num full blocks
 787    const uint32_t nloe = n % qk;  // num leftover elements
 788
 789    uint32_t i = 0;
 790    for (; i < nb; i++) {
 791        // Load src1 columns (reused across both src0 rows)
 792        HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
 793        HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
 794
 795        // Load src0 rows (reused across both src1 columns)
 796        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
 797        HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
 798
 799        // Compute 4 dot products: r0ร—c0, r0ร—c1, r1ร—c0, r1ร—c1
 800        HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
 801        HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
 802        HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
 803        HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
 804
 805        // Load scales
 806        HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
 807        HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
 808        HVX_Vector r0_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
 809        HVX_Vector r1_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
 810
 811        // Compute combined scales
 812        HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
 813        HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
 814        HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
 815        HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
 816
 817        // Apply scales and accumulate
 818        HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
 819        HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
 820        HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
 821        HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
 822
 823        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
 824        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
 825        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
 826        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
 827    }
 828
 829    // Process leftovers
 830    if (nloe) {
 831        HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
 832        HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
 833        HVX_Vector_x8 r0_q  = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
 834        HVX_Vector_x8 r1_q  = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
 835
 836        HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe));
 837        HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe));
 838        HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe));
 839        HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe));
 840
 841        HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
 842        HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
 843        HVX_Vector r0_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
 844        HVX_Vector r1_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
 845
 846        HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
 847        HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
 848        HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
 849        HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
 850
 851        // Zero out unused scales
 852        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
 853        r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
 854        r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
 855        r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
 856        r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
 857        r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
 858        r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
 859        r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
 860        r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
 861
 862        HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
 863        HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
 864        HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
 865        HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
 866
 867        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
 868        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
 869        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
 870        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
 871    }
 872
 873    // Reduce and store results
 874    HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
 875    HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
 876
 877    hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum);  // row0,col0 row1,col0
 878    hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum);  // row0,col1 row1,col1
 879}
 880
 881static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
 882    assert(n % 32 == 0);  // min sub-block size
 883    assert((unsigned long) vx0 % 128 == 0);
 884    assert((unsigned long) vy0 % 128 == 0);
 885
 886    const uint32_t qk = QK_MXFP4x4x2 * 4;
 887
 888    const uint32_t x_dblk_size = 8 * 4 * 1;                                  // 32x e8m0
 889    const uint32_t x_qblk_size = qk / 2;                                     // fp4
 890    const uint32_t x_qrow_size = n / 2;                                      // fp4 (not padded)
 891
 892    const uint32_t y_dblk_size = 8 * 4 * 2;                                  // 32x __fp16
 893    const uint32_t y_qblk_size = qk;                                         // int8
 894    const uint32_t y_qrow_size = n;                                          // int8 (not padded)
 895
 896    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0);           // quants first
 897    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
 898
 899    const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0);              // quants first
 900    const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size);    // then scales
 901
 902    // Row sum (sf)
 903    HVX_Vector r0_sum = Q6_V_vsplat_R(0);
 904
 905    // Multiply and accumulate into int32.
 906    // Compute combined scale (fp32).
 907    // Apply scale to acc and accumulate into the row sum (qf32).
 908
 909    const uint32_t nb   = n / qk;  // num full blocks
 910    int32_t        nloe = n % qk;  // num leftover elemements (must be signed)
 911
 912    uint32_t i = 0;
 913    for (; i < nb; i++) {
 914        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
 915        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
 916
 917        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
 918
 919        HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
 920        HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
 921
 922        // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
 923        HVX_Vector half = Q6_Vh_vsplat_R(0x3800);  // 0.5 in fp16
 924        vy_d            = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
 925        vy_d            = Q6_Vsf_equals_Vqf32(vy_d);
 926
 927        // Convert rX_d scales from e8m0 to fp32
 928        // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
 929        // Left shift with zero fill to create FP32
 930        // FIXME: might need to handle zero as a special case (see ggml-cpu code)
 931        HVX_Vector expand    = *(const HVX_Vector *) expand_x32_e8m0;
 932        HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
 933        r0_d                 = Q6_V_vdelta_VV(r0_d, expand);
 934        r0_d                 = Q6_V_vand_VV(r0_d, e8m0_mask);
 935        r0_d                 = Q6_Vw_vasl_VwR(r0_d, 23);
 936
 937        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
 938
 939        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
 940
 941        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
 942    }
 943
 944    // Process leftovers
 945    if (nloe) {
 946        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
 947        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
 948
 949        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
 950
 951        HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
 952        HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
 953
 954        // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
 955        HVX_Vector half = Q6_Vh_vsplat_R(0x3800);  // 0.5 in fp16
 956        vy_d            = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
 957        vy_d            = Q6_Vsf_equals_Vqf32(vy_d);
 958
 959        // Convert rX_d scales from e8m0 to fp32
 960        // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
 961        // Left shift with zero fill to create FP32
 962        // FIXME: might need to handle zero as a special case (see ggml-cpu code)
 963        HVX_Vector expand    = *(const HVX_Vector *) expand_x32_e8m0;
 964        HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
 965        r0_d                 = Q6_V_vdelta_VV(r0_d, expand);
 966        r0_d                 = Q6_V_vand_VV(r0_d, e8m0_mask);
 967        r0_d                 = Q6_Vw_vasl_VwR(r0_d, 23);
 968
 969        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
 970
 971        // Zero-out unused scales
 972        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
 973        r0_dd                = Q6_V_vand_QV(bmask, r0_dd);
 974        r0_ia                = Q6_V_vand_QV(bmask, r0_ia);
 975
 976        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
 977
 978        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
 979    }
 980
 981    r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
 982
 983    hvx_vec_store_u(s0, 4, r0_sum);
 984}
 985
 986static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
 987                                      const void * restrict vx0, const void * restrict vx1,
 988                                      const void * restrict vy0) {
 989    assert(n % 32 == 0);  // min sub-block size
 990    assert((unsigned long) vx0 % 128 == 0);
 991    assert((unsigned long) vx1 % 128 == 0);
 992    assert((unsigned long) vy0 % 128 == 0);
 993
 994    const uint32_t qk = QK_MXFP4x4x2 * 4;
 995
 996    const uint32_t x_dblk_size = 8 * 4 * 1;                                   // 32x e8m0
 997    const uint32_t x_qblk_size = qk / 2;                                      // fp4
 998    const uint32_t x_qrow_size = n / 2;                                       // fp4 (not padded)
 999
1000    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
1001    const uint32_t y_qblk_size = qk;                                          // int8
1002    const uint32_t y_qrow_size = n;                                           // int8 (not padded)
1003
1004    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;            // quants first
1005    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;  // then scales
1006    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;            // quants first
1007    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;  // then scales
1008
1009    const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0;               // quants first
1010    const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size;     // then scales
1011
1012    // Row sum (sf)
1013    HVX_Vector r0_sum = Q6_V_vsplat_R(0);
1014    HVX_Vector r1_sum = Q6_V_vsplat_R(0);
1015
1016    // Multiply and accumulate into int32.
1017    // Compute combined scale (fp32).
1018    // Apply scale to acc and accumulate into the row sum (f32).
1019
1020    const uint32_t nb   = n / qk;  // num full blocks
1021    int32_t        nloe = n % qk;  // num leftover elemements (must be signed)
1022
1023    uint32_t i = 0;
1024    for (; i < nb; i++) {
1025        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
1026        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
1027        HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
1028
1029        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
1030        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
1031
1032        HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
1033        HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
1034        HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
1035
1036        // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
1037        HVX_Vector half = Q6_Vh_vsplat_R(0x3800);  // 0.5 in fp16
1038        vy_d            = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
1039        vy_d            = Q6_Vsf_equals_Vqf32(vy_d);
1040
1041        // Convert rX_d scales from e8m0 to fp32
1042        // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
1043        // Left shift with zero fill to create FP32
1044        // FIXME: might need to handle zero as a special case (see ggml-cpu code)
1045        HVX_Vector expand    = *(const HVX_Vector *) expand_x32_e8m0;
1046        HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
1047        r0_d                 = Q6_V_vdelta_VV(r0_d, expand);
1048        r0_d                 = Q6_V_vand_VV(r0_d, e8m0_mask);
1049        r0_d                 = Q6_Vw_vasl_VwR(r0_d, 23);
1050        r1_d                 = Q6_V_vdelta_VV(r1_d, expand);
1051        r1_d                 = Q6_V_vand_VV(r1_d, e8m0_mask);
1052        r1_d                 = Q6_Vw_vasl_VwR(r1_d, 23);
1053
1054        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
1055        HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
1056
1057        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
1058        HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
1059
1060        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
1061        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
1062    }
1063
1064    // Process leftovers
1065    if (nloe) {
1066        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
1067        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
1068        HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
1069
1070        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
1071        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
1072
1073        HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
1074        HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
1075        HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
1076
1077        // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
1078        HVX_Vector half = Q6_Vh_vsplat_R(0x3800);  // 0.5 in fp16
1079        vy_d            = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
1080        vy_d            = Q6_Vsf_equals_Vqf32(vy_d);
1081
1082        // Convert rX_d scales from e8m0 to fp32
1083        // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
1084        // Left shift with zero fill to create FP32
1085        // FIXME: might need to handle zero as a special case (see ggml-cpu code)
1086        HVX_Vector expand    = *(const HVX_Vector *) expand_x32_e8m0;
1087        HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
1088        r0_d                 = Q6_V_vdelta_VV(r0_d, expand);
1089        r0_d                 = Q6_V_vand_VV(r0_d, e8m0_mask);
1090        r0_d                 = Q6_Vw_vasl_VwR(r0_d, 23);
1091        r1_d                 = Q6_V_vdelta_VV(r1_d, expand);
1092        r1_d                 = Q6_V_vand_VV(r1_d, e8m0_mask);
1093        r1_d                 = Q6_Vw_vasl_VwR(r1_d, 23);
1094
1095        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
1096        HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
1097
1098        // Zero-out unused values
1099        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
1100        r0_dd                = Q6_V_vand_QV(bmask, r0_dd);
1101        r1_dd                = Q6_V_vand_QV(bmask, r1_dd);
1102        r0_ia                = Q6_V_vand_QV(bmask, r0_ia);
1103        r1_ia                = Q6_V_vand_QV(bmask, r1_ia);
1104
1105        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
1106        HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
1107
1108        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
1109        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
1110    }
1111
1112    HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
1113    hvx_vec_store_u(s0, 8, rsum);
1114}
1115
1116static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
1117                                        const void * restrict vx0, const void * restrict vx1,
1118                                        const void * restrict vy0, const void * restrict vy1) {
1119    assert(n % 32 == 0);
1120    assert((unsigned long) vx0 % 128 == 0);
1121    assert((unsigned long) vx1 % 128 == 0);
1122    assert((unsigned long) vy0 % 128 == 0);
1123    assert((unsigned long) vy1 % 128 == 0);
1124
1125    const uint32_t qk = QK_MXFP4x4x2 * 4;
1126
1127    const uint32_t x_dblk_size = 8 * 4 * 1;                                   // 32x e8m0
1128    const uint32_t x_qblk_size = qk / 2;                                      // fp4
1129    const uint32_t x_qrow_size = n / 2;                                       // fp4 (not padded)
1130
1131    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
1132    const uint32_t y_qblk_size = qk;                                          // int8
1133    const uint32_t y_qrow_size = n;                                           // int8 (not padded)
1134
1135    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;            // quants first
1136    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;  // then scales
1137    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;            // quants first
1138    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;  // then scales
1139
1140    const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0;              // quants first
1141    const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size;    // then scales
1142    const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0;              // quants first
1143    const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size;    // then scales
1144
1145    // Row sums (sf) - 4 accumulators for 2ร—2 tile
1146    HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
1147    HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
1148    HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
1149    HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
1150
1151    const uint32_t nb   = n / qk;  // num full blocks
1152    const uint32_t nloe = n % qk;  // num leftover elements
1153
1154    uint32_t i = 0;
1155    for (; i < nb; i++) {
1156        // Load src1 columns (reused across both src0 rows)
1157        HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
1158        HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
1159
1160        // Load src0 rows (reused across both src1 columns)
1161        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
1162        HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
1163
1164        // Compute 4 dot products: r0ร—c0, r0ร—c1, r1ร—c0, r1ร—c1
1165        HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
1166        HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
1167        HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
1168        HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
1169
1170        // Load scales
1171        HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d   + i * y_dblk_size);
1172        HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d   + i * y_dblk_size);
1173        HVX_Vector r0_d  = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
1174        HVX_Vector r1_d  = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
1175
1176        // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
1177        HVX_Vector half = Q6_Vh_vsplat_R(0x3800);  // 0.5 in fp16
1178        vy0_d           = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half));
1179        vy0_d           = Q6_Vsf_equals_Vqf32(vy0_d);
1180        vy1_d           = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half));
1181        vy1_d           = Q6_Vsf_equals_Vqf32(vy1_d);
1182
1183        // Convert rX_d scales from e8m0 to fp32
1184        // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
1185        // Left shift with zero fill to create FP32
1186        // FIXME: might need to handle zero as a special case (see ggml-cpu code)
1187        HVX_Vector expand    = *(const HVX_Vector *) expand_x32_e8m0;
1188        HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
1189        r0_d                 = Q6_V_vdelta_VV(r0_d, expand);
1190        r0_d                 = Q6_V_vand_VV(r0_d, e8m0_mask);
1191        r0_d                 = Q6_Vw_vasl_VwR(r0_d, 23);
1192        r1_d                 = Q6_V_vdelta_VV(r1_d, expand);
1193        r1_d                 = Q6_V_vand_VV(r1_d, e8m0_mask);
1194        r1_d                 = Q6_Vw_vasl_VwR(r1_d, 23);
1195
1196        // Compute combined scales
1197        HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d));
1198        HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d));
1199        HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d));
1200        HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d));
1201
1202        // Apply scales and accumulate
1203        HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
1204        HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
1205        HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
1206        HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
1207
1208        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
1209        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
1210        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
1211        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
1212    }
1213
1214    // Process leftovers
1215    if (nloe) {
1216        HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
1217        HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
1218        HVX_Vector_x8 r0_q  = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
1219        HVX_Vector_x8 r1_q  = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
1220
1221        HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe));
1222        HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe));
1223        HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe));
1224        HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe));
1225
1226        HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d   + i * y_dblk_size);
1227        HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d   + i * y_dblk_size);
1228        HVX_Vector r0_d  = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
1229        HVX_Vector r1_d  = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
1230
1231        // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
1232        HVX_Vector half = Q6_Vh_vsplat_R(0x3800);  // 0.5 in fp16
1233        vy0_d           = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half));
1234        vy0_d           = Q6_Vsf_equals_Vqf32(vy0_d);
1235        vy1_d           = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half));
1236        vy1_d           = Q6_Vsf_equals_Vqf32(vy1_d);
1237
1238        // Convert rX_d scales from e8m0 to fp32
1239        // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
1240        // Left shift with zero fill to create FP32
1241        // FIXME: might need to handle zero as a special case (see ggml-cpu code)
1242        HVX_Vector expand    = *(const HVX_Vector *) expand_x32_e8m0;
1243        HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
1244        r0_d                 = Q6_V_vdelta_VV(r0_d, expand);
1245        r0_d                 = Q6_V_vand_VV(r0_d, e8m0_mask);
1246        r0_d                 = Q6_Vw_vasl_VwR(r0_d, 23);
1247        r1_d                 = Q6_V_vdelta_VV(r1_d, expand);
1248        r1_d                 = Q6_V_vand_VV(r1_d, e8m0_mask);
1249        r1_d                 = Q6_Vw_vasl_VwR(r1_d, 23);
1250
1251        HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d));
1252        HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d));
1253        HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d));
1254        HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d));
1255
1256        // Zero out unused scales
1257        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
1258        r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
1259        r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
1260        r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
1261        r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
1262        r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
1263        r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
1264        r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
1265        r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
1266
1267        HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
1268        HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
1269        HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
1270        HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
1271
1272        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
1273        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
1274        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
1275        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
1276    }
1277
1278    // Reduce and store results
1279    HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
1280    HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
1281
1282    hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum);  // row0,col0 row1,col0
1283    hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum);  // row0,col1 row1,col1
1284}
1285
1286static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1287    const HVX_Vector * restrict x = (const HVX_Vector *) vx;
1288    const HVX_Vector * restrict y = (const HVX_Vector *) vy;
1289
1290    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
1291    uint32_t nloe = n % VLEN_FP16; // leftover elements
1292
1293    HVX_Vector rsum = Q6_V_vsplat_R(0);
1294
1295    uint32_t i = 0;
1296
1297    #pragma unroll(4)
1298    for (i = 0; i < nvec; i++) {
1299        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]);
1300        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));
1301    }
1302
1303    if (nloe) {
1304        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
1305        HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
1306        HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
1307
1308        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
1309        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));
1310    }
1311
1312    rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
1313    hvx_vec_store_u(&s[0], 4, rsum);
1314}
1315
1316static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0,
1317                                const void * restrict vx0, const void * restrict vx1,
1318                                const void * restrict vy0) {
1319    const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
1320    const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
1321    const HVX_Vector * restrict y  = (const HVX_Vector *) vy0;
1322
1323    uint32_t nvec = n / VLEN_FP16;
1324    uint32_t nloe = n % VLEN_FP16;
1325
1326    HVX_Vector rsum0 = Q6_V_vsplat_R(0);
1327    HVX_Vector rsum1 = Q6_V_vsplat_R(0);
1328
1329    uint32_t i = 0;
1330
1331    #pragma unroll(2)
1332    for (i = 0; i < nvec; i++) {
1333        HVX_Vector y_hf = y[i];
1334        HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0[i], y_hf);
1335        HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1[i], y_hf);
1336
1337        rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)));
1338        rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)));
1339    }
1340
1341    if (nloe) {
1342        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
1343        HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]);
1344        HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]);
1345        HVX_Vector y_hf  = Q6_V_vand_QV(bmask, y[i]);
1346
1347        HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
1348        HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
1349
1350        rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)));
1351        rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)));
1352    }
1353
1354    HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(Q6_Vsf_equals_Vqf32(rsum0), Q6_Vsf_equals_Vqf32(rsum1));
1355    hvx_vec_store_u(s0, 8, rsum);
1356}
1357
1358static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * restrict s1,
1359                                const void * restrict vx0, const void * restrict vx1,
1360                                const void * restrict vy0, const void * restrict vy1) {
1361    const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
1362    const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
1363    const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0;
1364    const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1;
1365
1366    uint32_t nvec = n / VLEN_FP16;
1367    uint32_t nloe = n % VLEN_FP16;
1368
1369    // Row sums (sf) - 4 accumulators for 2ร—2 tile
1370    HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
1371    HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
1372    HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
1373    HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
1374
1375    uint32_t i = 0;
1376
1377    #pragma unroll(2)
1378    for (i = 0; i < nvec; i++) {
1379        HVX_Vector r0_hf = x0[i];
1380        HVX_Vector r1_hf = x1[i];
1381        HVX_Vector c0_hf = y0[i];
1382        HVX_Vector c1_hf = y1[i];
1383
1384        // Compute 4 dot products: r0ร—c0, r0ร—c1, r1ร—c0, r1ร—c1
1385        HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf);
1386        HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf);
1387        HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf);
1388        HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf);
1389
1390        HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p));
1391        HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p));
1392        HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p));
1393        HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p));
1394
1395        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum));
1396        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum));
1397        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum));
1398        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum));
1399    }
1400
1401    if (nloe) {
1402        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
1403
1404        HVX_Vector r0_hf = Q6_V_vand_QV(bmask, x0[i]);
1405        HVX_Vector r1_hf = Q6_V_vand_QV(bmask, x1[i]);
1406        HVX_Vector c0_hf = Q6_V_vand_QV(bmask, y0[i]);
1407        HVX_Vector c1_hf = Q6_V_vand_QV(bmask, y1[i]);
1408
1409        HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf);
1410        HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf);
1411        HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf);
1412        HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf);
1413
1414        HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p));
1415        HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p));
1416        HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p));
1417        HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p));
1418
1419        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum));
1420        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum));
1421        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum));
1422        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum));
1423
1424    }
1425
1426    // Reduce and store results
1427    HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
1428    HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
1429
1430    hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum);  // row0,col0 row1,col0
1431    hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum);  // row0,col1 row1,col1
1432}
1433
1434static void vec_dot_f16_f16_uu_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1435    const HVX_UVector * restrict x = (const HVX_UVector *) vx;
1436    const HVX_UVector * restrict y = (const HVX_UVector *) vy;
1437
1438    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
1439    uint32_t nloe = n % VLEN_FP16; // leftover elements
1440
1441    HVX_Vector rsum = Q6_V_vsplat_R(0);
1442
1443    uint32_t i = 0;
1444
1445    #pragma unroll(4)
1446    for (i = 0; i < nvec; i++) {
1447        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]);
1448        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));
1449    }
1450
1451    if (nloe) {
1452        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
1453        HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
1454        HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
1455
1456        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
1457        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));
1458    }
1459
1460    rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
1461    hvx_vec_store_u(&s[0], 4, rsum);
1462}
1463
1464static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
1465    const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
1466    const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
1467
1468    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
1469    uint32_t nloe = n % VLEN_FP16; // leftover elements
1470
1471    const HVX_Vector zero = Q6_V_vsplat_R(0);
1472
1473    HVX_Vector       rsum = Q6_V_vsplat_R(0);
1474
1475    uint32_t i = 0;
1476
1477    #pragma unroll(2)
1478    for (i = 0; i < nvec; i++) {
1479        // Load y (fp32) and convert into fp16
1480        HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero);  // 32 elements
1481        HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero);  // 32 elements
1482        HVX_Vector y_hf  = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
1483
1484        // Load x (fp16)
1485        HVX_Vector x_hf  = vx[i];
1486
1487        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
1488
1489        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));
1490    }
1491
1492    if (nloe) {
1493        // Load y (fp32) and convert into fp16
1494        HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero);  // 32 elements
1495        HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero);  // 32 elements
1496        HVX_Vector y_hf  = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
1497
1498        // Load x (fp16)
1499        HVX_Vector x_hf  = vx[i];
1500
1501        // Zero-out unused elements
1502        // Note that we need to clear both x and y because they may contain NANs
1503        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
1504        x_hf = Q6_V_vand_QV(bmask, x_hf);
1505        y_hf = Q6_V_vand_QV(bmask, y_hf);
1506
1507        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
1508
1509        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));
1510    }
1511
1512    // Convert into fp32 and reduce
1513    rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
1514    hvx_vec_store_u(&s[0], 4, rsum);
1515}
1516
// Declares local aliases for the matmul kernels. Requires `octx` (struct htp_ops_context *)
// to be in scope at the expansion site. It declares:
//   src0, src1, src2, dst            - pointers to octx's tensors
//   src0_spad, src1_spad, dst_spad   - pointers to octx's scratchpad descriptors
//   neXY / neY                       - uint32_t copies of ne[] for src0, src1, src2 and dst
//   nbXY / nbY                       - uint32_t copies of nb[] for src0, src1 and dst
// Most kernels use only a subset of these names. -Wunused-variable is disabled at the top of
// the file for that reason. Comments must stay outside the macro: a `//` comment inside a
// backslash-continued line would swallow the rest of the definition.
#define htp_matmul_tensors_preamble    \
    struct htp_tensor * restrict src0    = &octx->src0;      \
    struct htp_tensor * restrict src1    = &octx->src1;      \
    struct htp_tensor * restrict src2    = &octx->src2;      \
    struct htp_tensor * restrict dst     = &octx->dst;       \
    struct htp_spad * restrict src0_spad = &octx->src0_spad; \
    struct htp_spad * restrict src1_spad = &octx->src1_spad; \
    struct htp_spad * restrict dst_spad  = &octx->dst_spad;  \
                                                             \
    const uint32_t ne00 = src0->ne[0]; \
    const uint32_t ne01 = src0->ne[1]; \
    const uint32_t ne02 = src0->ne[2]; \
    const uint32_t ne03 = src0->ne[3]; \
                                       \
    const uint32_t ne10 = src1->ne[0]; \
    const uint32_t ne11 = src1->ne[1]; \
    const uint32_t ne12 = src1->ne[2]; \
    const uint32_t ne13 = src1->ne[3]; \
                                       \
    const uint32_t ne20 = src2->ne[0]; \
    const uint32_t ne21 = src2->ne[1]; \
    const uint32_t ne22 = src2->ne[2]; \
    const uint32_t ne23 = src2->ne[3]; \
                                       \
    const uint32_t ne0 = dst->ne[0];   \
    const uint32_t ne1 = dst->ne[1];   \
    const uint32_t ne2 = dst->ne[2];   \
    const uint32_t ne3 = dst->ne[3];   \
                                       \
    const uint32_t nb00 = src0->nb[0]; \
    const uint32_t nb01 = src0->nb[1]; \
    const uint32_t nb02 = src0->nb[2]; \
    const uint32_t nb03 = src0->nb[3]; \
                                       \
    const uint32_t nb10 = src1->nb[0]; \
    const uint32_t nb11 = src1->nb[1]; \
    const uint32_t nb12 = src1->nb[2]; \
    const uint32_t nb13 = src1->nb[3]; \
                                       \
    const uint32_t nb0 = dst->nb[0];   \
    const uint32_t nb1 = dst->nb[1];   \
    const uint32_t nb2 = dst->nb[2];   \
    const uint32_t nb3 = dst->nb[3];
1560
// Common prologue for matmul worker functions. Requires `data` (a pointer to
// struct htp_matmul_context) and `ith` (this thread's index) to be in scope. It declares:
//   mmctx                  - the matmul context cast from `data`
//   octx                   - mmctx->octx
//   everything from htp_matmul_tensors_preamble
//   dma_queue              - this thread's DMA queue, octx->ctx->dma[ith]
//   src0_nrows_per_thread  - copied from mmctx
#define htp_matmul_preamble                                     \
    struct htp_matmul_context * mmctx = data;                   \
    struct htp_ops_context * octx  = mmctx->octx;               \
    htp_matmul_tensors_preamble;                                \
    dma_queue *dma_queue           = octx->ctx->dma[ith];       \
    uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread;
1567
1568// *** matmul with support for 4d tensors and full broadcasting
1569
1570static void matmul_4d(unsigned int nth, unsigned int ith, void * data) {
1571    htp_matmul_preamble;
1572
1573    uint64_t t1, t2;
1574    t1 = HAP_perf_get_qtimer_count();
1575
1576    assert(ne12 % ne02 == 0);
1577    assert(ne13 % ne03 == 0);
1578
1579    // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
1580    const uint32_t nr0 = ne0;
1581
1582    // This is the size of the rest of the dimensions of the result
1583    const uint32_t nr1 = ne1 * ne2 * ne3;
1584
1585    // distribute the thread work across the inner or outer loop based on which one is larger
1586    uint32_t nchunk0 = nr0 > nr1 ? nth : 1;  // parallelize by src0 rows
1587    uint32_t nchunk1 = nr0 > nr1 ? 1 : nth;  // parallelize by src1 rows
1588
1589    // The number of elements in each chunk
1590    const uint32_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
1591    const uint32_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
1592
1593    uint32_t current_chunk = ith;
1594
1595    const uint32_t ith0 = current_chunk % nchunk0;
1596    const uint32_t ith1 = current_chunk / nchunk0;
1597
1598    const uint32_t ir0_start = dr0 * ith0;
1599    const uint32_t ir0_end   = MIN(ir0_start + dr0, nr0);
1600
1601    const uint32_t ir1_start = dr1 * ith1;
1602    const uint32_t ir1_end   = MIN(ir1_start + dr1, nr1);
1603
1604    // no work for this thread
1605    if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
1606        return;
1607    }
1608
1609    // block-tiling attempt
1610    const uint32_t blck_0 = 64;
1611    const uint32_t blck_1 = 64;
1612
1613    for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
1614        for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
1615            for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) {
1616                const uint32_t i13 = fastdiv(ir1, &mmctx->mm_div_ne12_ne1);
1617                const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &mmctx->mm_div_ne1);
1618                const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
1619
1620                // broadcast src0 into src1
1621                const uint32_t i03 = fastdiv(i13, &mmctx->mm_div_r3);
1622                const uint32_t i02 = fastdiv(i12, &mmctx->mm_div_r2);
1623
1624                const uint32_t i1 = i11;
1625                const uint32_t i2 = i12;
1626                const uint32_t i3 = i13;
1627
1628                const uint8_t * restrict src0_base = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03);
1629                const uint8_t * restrict src1_col  = (const uint8_t *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13);
1630                float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
1631
1632                const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end);
1633                for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) {
1634                    const uint8_t * restrict src0_row = src0_base + ir0 * nb01;
1635                    mmctx->vec_dot_1x1(ne00, &dst_col[ir0], src0_row, src1_col);
1636                }
1637            }
1638        }
1639    }
1640
1641    t2 = HAP_perf_get_qtimer_count();
1642
1643    FARF(HIGH, "matmul-4d %d/%d: %ux%ux%ux%u (%u:%u %u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
1644         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0_start, ir0_end, ir1_start, ir1_end, src1->ne[0],
1645         src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
1646         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
1647}
1648
1649// src1 tensor is already in VTCM spad
1650static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
1651    htp_matmul_preamble;
1652
1653    const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows
1654    const uint32_t src1_nrows = ne11 * ne12 * ne13;  // src1 rows
1655
1656    const uint32_t src0_start_row  = src0_nrows_per_thread * ith;
1657    const uint32_t src0_end_row    = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
1658    const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);
1659
1660    // no work for this thread
1661    if (src0_start_row >= src0_end_row) {
1662        return;
1663    }
1664
1665    const size_t dst_row_size  = nb1;
1666    const size_t src0_row_size = nb01;
1667    const size_t src1_row_size = nb11;
1668
1669    const size_t src0_stride = src0_spad->stride;
1670    const size_t src1_stride = src1_spad->stride;
1671
1672    // Per-thread VTCM scratchpads for all tensors
1673    // Note that the entire src1 tensor is already in VTCM
1674    // For other tensors we allocate N rows per thread, padded to HVX vector size
1675    uint8_t * restrict spad_dst  = dst_spad->data  + dst_spad->size_per_thread  * ith;
1676    uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
1677    uint8_t * restrict src1_data = src1_spad->data;
1678
1679    volatile uint64_t t1, t2;
1680    t1 = HAP_perf_get_qtimer_count();
1681
1682    const uint8_t * restrict src0_row = (const uint8_t *) src0->data;
1683
1684    // Prefill spad with src0 rows
1685    #pragma unroll(4)
1686    for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
1687        const int is0 = (ir0 - src0_start_row);
1688        if (is0 >= MM_SPAD_SRC0_NROWS) {
1689            break;
1690        }
1691        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
1692                       src0_stride, src0_row_size, 2);
1693    }
1694
1695    // Process src0 rows
1696    for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
1697        const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
1698
1699        // Process src1 columns in pairs (2ร—2 tiling)
1700        uint32_t ir1 = 0;
1701        for (; ir1 + 1 < src1_nrows; ir1 += 2) {
1702            const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride);
1703            const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride);
1704            float * restrict dst_row0 = (float *) (dst->data + ((ir1+0) * dst_row_size));
1705            float * restrict dst_row1 = (float *) (dst->data + ((ir1+1) * dst_row_size));
1706            mmctx->vec_dot_2x2(ne00, &dst_row0[ir0], &dst_row1[ir0], ss0, ss0 + src0_stride, src1_col0, src1_col1);
1707        }
1708
1709        // Handle remaining src1 rows (fallback to 2ร—1)
1710        for (; ir1 < src1_nrows; ++ir1) {
1711            const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
1712            float * restrict dst_row          = (float *) (dst->data + (ir1 * dst_row_size));
1713            mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_stride, src1_col);
1714        }
1715
1716        // Prefetch next (n + spad_nrows) row
1717        const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
1718        const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
1719        if (pr0 < src0_end_row_x2) {
1720            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size),
1721                           src0_stride, src0_row_size, 2);
1722        }
1723    }
1724
1725    // Process the last row (if any)
1726    if (src0_end_row != src0_end_row_x2) {
1727        uint32_t  ir0 = src0_end_row_x2;
1728        const int is0 = (ir0 - src0_start_row);
1729        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
1730                       src0_stride, src0_row_size, 1);
1731        const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
1732
1733        #pragma unroll(2)
1734        for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
1735            const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
1736            float * restrict dst_row          = (float *) (dst->data + (ir1 * dst_row_size));
1737            mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
1738        }
1739    }
1740
1741    t2 = HAP_perf_get_qtimer_count();
1742
1743    FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth,
1744         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
1745         src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
1746         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
1747}
1748
1749// q8x4x2 src1 tensor is already in VTCM spad
1750static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
1751    htp_matmul_preamble;
1752
1753    const uint32_t src0_nrows = ne01;
1754
1755    const uint32_t src0_start_row  = src0_nrows_per_thread * ith;
1756    const uint32_t src0_end_row    = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
1757    const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);
1758
1759    // no work for this thread
1760    if (src0_start_row >= src0_end_row) {
1761        return;
1762    }
1763
1764    const size_t dst_row_size  = nb1;
1765    const size_t src0_row_size = nb01;
1766    const size_t src1_row_size = nb11;
1767
1768    const size_t src0_stride = src0_spad->stride;
1769    const size_t src1_stride = src1_spad->stride;
1770
1771    // Per-thread VTCM scratchpads for all tensors
1772    // Note that the entire src1 tensor is already in VTCM
1773    // For other tensors we allocate N rows per thread, padded to HVX vector size
1774    uint8_t * spad_dst  = dst_spad->data + dst_spad->size_per_thread * ith;
1775    uint8_t * spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
1776    uint8_t * src1_data = src1_spad->data;
1777
1778    uint64_t t1, t2;
1779    t1 = HAP_perf_get_qtimer_count();
1780
1781    float * tmp = (float *) spad_dst;
1782
1783    const uint8_t * restrict src0_row = (const uint8_t *) src0->data;
1784    const uint8_t * restrict src1_col = (const uint8_t *) src1_data;
1785    float * restrict dst_col          = (float *) dst->data;
1786
1787    // Prefill spad with 2x src0 rows
1788    #pragma unroll(2)
1789    for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
1790        const uint32_t is0 = (ir0 - src0_start_row);
1791        if (is0 >= MM_SPAD_SRC0_NROWS) {
1792            break;
1793        }
1794        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
1795                       src0_stride, src0_row_size, 2);
1796    }
1797
1798    // Process src0 rows
1799    for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
1800        const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
1801        mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col);
1802
1803        // Prefetch next (n + spad_nrows) row
1804        const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
1805        const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
1806        if (pr0 < src0_end_row_x2) {
1807            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size),
1808                           src0_stride, src0_row_size, 2);
1809        }
1810    }
1811
1812    // Process the last row (if any)
1813    if (src0_end_row != src0_end_row_x2) {
1814        const uint32_t ir0 = src0_end_row_x2;
1815        const uint32_t is0 = (ir0 - src0_start_row);
1816        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
1817                       src0_stride, src0_row_size, 1);
1818        const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
1819        mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
1820    }
1821
1822    hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row);
1823
1824    t2 = HAP_perf_get_qtimer_count();
1825
1826    FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth,
1827         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
1828         src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
1829         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
1830}
1831
// Routing-table accessor used by matmul_id: entry `i1` of expert `row_id` in `matrix_rows`,
// where each expert owns ids->ne[0] * ids->ne[1] consecutive entries.
// Expands to references to `matrix_rows` and `ids`, which must both be in scope at the use site.
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ids->ne[0] * ids->ne[1] + (i1)]

// One routed position for an expert in the matmul_id routing table.
// matmul_id uses i1 as the dim-1 index into dst (and into src1, unless src1 has a single
// row) and i2 as the dim-2 index into dst and src1; its comments label them
// "expert idx" and "token idx" respectively.
struct mmid_row_mapping {
    uint32_t i1;
    uint32_t i2;
};
1838
1839// src1 tensor is already in VTCM spad
1840static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
1841    htp_matmul_preamble;
1842
1843    struct htp_tensor * restrict     ids = &octx->src2;
1844    struct htp_spad * restrict src2_spad = &octx->src2_spad;
1845
1846    uint64_t t1, t2;
1847    t1 = HAP_perf_get_qtimer_count();
1848
1849    const uint32_t src0_nrows = ne01;  // src0 rows per expert
1850    const uint32_t src1_nrows = ne11;
1851
1852    const uint32_t src0_start_row  = src0_nrows_per_thread * ith;
1853    const uint32_t src0_end_row    = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
1854    const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);
1855
1856    // no work for this thread
1857    if (src0_start_row >= src0_end_row) {
1858        return;
1859    }
1860
1861    const uint32_t n_ids = ids->ne[0];  // n_expert_used
1862    const uint32_t n_as  = ne02;        // n_expert
1863
1864    const size_t matrix_row_counts_size = n_as * sizeof(uint32_t);
1865    const size_t matrix_row_map_size    = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping);
1866
1867    const uint32_t *                matrix_row_counts = (const uint32_t *) src2_spad->data + 0;
1868    const struct mmid_row_mapping * matrix_rows       = (const void *) src2_spad->data + matrix_row_counts_size;
1869
1870    const size_t dst_row_size  = nb1;
1871    const size_t src0_row_size = nb01;
1872    const size_t src1_row_size = q8x4x2_row_size(ne10);
1873
1874    const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
1875
1876    // Per-thread VTCM scratchpads for all tensors
1877    // Note that the entire src1 tensor is already in VTCM
1878    // For other tensors we allocate N rows per thread, padded to HVX vector size
1879    uint8_t * restrict spad_dst  = dst_spad->data + dst_spad->size_per_thread * ith;
1880    uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
1881    uint8_t * restrict src1_data = src1_spad->data;
1882
1883    for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) {
1884        const int32_t cne1 = matrix_row_counts[cur_a];
1885
1886        if (cne1 == 0) {
1887            continue;
1888        }
1889
1890        const uint8_t * src0_row = (const uint8_t *) src0->data + (0 + cur_a * nb02 + 0);
1891
1892        // Prefill spad with src0 rows
1893        #pragma unroll(4)
1894        for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
1895            const int is0 = (ir0 - src0_start_row);
1896            if (is0 >= MM_SPAD_SRC0_NROWS) {
1897                break;
1898            }
1899            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
1900                           src0_row_size_padded, src0_row_size, 2);
1901        }
1902
1903        // Process src0 rows
1904        for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
1905            const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
1906
1907            for (uint32_t cid = 0; cid < cne1; ++cid) {
1908                struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid);
1909                const int               rm1         = row_mapping.i1;  // expert idx
1910                const int               rm2         = row_mapping.i2;  // token idx
1911
1912                const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1;        // src1 row idx
1913                const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
1914                float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0));
1915
1916                mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);
1917            }
1918
1919            // Prefetch next (n + spad_nrows) row
1920            const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
1921            const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
1922            if (pr0 < src0_end_row_x2) {
1923                dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size),
1924                               src0_row_size_padded, src0_row_size, 2);
1925            }
1926        }
1927
1928        // Process the last row (if any)
1929        if (src0_end_row != src0_end_row_x2) {
1930            uint32_t       ir0 = src0_end_row_x2;
1931            const uint32_t is0 = (ir0 - src0_start_row);
1932            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
1933                           src0_row_size_padded, src0_row_size, 1);
1934            const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
1935
1936            for (uint32_t cid = 0; cid < cne1; ++cid) {
1937                struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid);
1938                const int               rm1         = row_mapping.i1;  // expert idx
1939                const int               rm2         = row_mapping.i2;  // token idx
1940
1941                const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1;        // src1 row idx
1942                const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
1943                float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0));
1944
1945                mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
1946            }
1947        }
1948    }
1949
1950    t2 = HAP_perf_get_qtimer_count();
1951
1952    FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type,
1953         ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0],
1954         src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1],
1955         dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
1956}
1957
1958// src1 tensor is already in VTCM spad
1959static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
1960    htp_matmul_preamble;
1961
1962    struct htp_tensor * restrict     ids = &octx->src2;
1963    struct htp_spad * restrict src2_spad = &octx->src2_spad;
1964
1965    uint64_t t1, t2;
1966    t1 = HAP_perf_get_qtimer_count();
1967
1968    const uint32_t src0_nrows = ne01;  // src0 rows per expert
1969
1970    const uint32_t src0_start_row  = src0_nrows_per_thread * ith;
1971    const uint32_t src0_end_row    = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
1972    const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);
1973
1974    // no work for this thread
1975    if (src0_start_row >= src0_end_row) {
1976        return;
1977    }
1978
1979    assert(ne13 % ne03 == 0);
1980
1981    const size_t dst_row_size  = nb1;
1982    const size_t src0_row_size = nb01;
1983    const size_t src1_row_size = q8x4x2_row_size(ne10);
1984
1985    const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
1986
1987    const uint32_t n_aids = src2->ne[0];  // num activated experts
1988    const uint32_t n_ids  = ne02;         // num experts
1989
1990    // Per-thread VTCM scratchpads for all tensors
1991    // Note that the entire src1 tensor is already in VTCM
1992    // For other tensors we allocate N rows per thread, padded to HVX vector size
1993    uint8_t * restrict spad_dst  = dst_spad->data + dst_spad->size_per_thread * ith;
1994    uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
1995    uint8_t * restrict src1_data = src1_spad->data;
1996
1997    for (uint32_t ie1 = 0; ie1 < n_aids; ++ie1) {  // for each expert
1998        const uint32_t eid = *(const int32_t *) ((const uint8_t *) src2->data + ie1 * src2->nb[0]);
1999        assert(eid < n_ids);
2000
2001        const uint8_t * restrict src0_row = (const uint8_t *) src0->data + eid * nb02;
2002        const uint8_t * restrict src1_col = (const uint8_t *) src1_data;
2003        float * restrict dst_row          = (float *) (dst->data + ie1 * nb1);
2004
2005        // Prefill spad with src0 rows
2006        #pragma unroll(4)
2007        for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
2008            const int is0 = (ir0 - src0_start_row);
2009            if (is0 >= MM_SPAD_SRC0_NROWS) {
2010                break;
2011            }
2012            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
2013                           src0_row_size_padded, src0_row_size, 2);
2014        }
2015
2016        // Process src0 rows
2017        for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
2018            const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
2019            mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);
2020
2021            // Prefetch next (n + spad_nrows) row
2022            const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
2023            const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
2024            if (pr0 < src0_end_row_x2) {
2025                dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size),
2026                               src0_row_size_padded, src0_row_size, 2);
2027            }
2028        }
2029
2030        // Process the last row (if any)
2031        if (src0_end_row != src0_end_row_x2) {
2032            uint32_t       ir0 = src0_end_row_x2;
2033            const uint32_t is0 = (ir0 - src0_start_row);
2034            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
2035                           src0_row_size_padded, src0_row_size, 1);
2036            const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
2037            mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
2038        }
2039    }
2040
2041    t2 = HAP_perf_get_qtimer_count();
2042
2043    FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type,
2044         ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0],
2045         src1->ne[1], src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0],
2046         dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
2047}
2048
2049// *** dynamic quant
2050
// Quantize 128 f32 values (four 32-float HVX vectors) to int8 with one fp16 scale
// per 32-value group (the FP32_QUANTIZE_GROUP_SIZE == 32 variant).
//
//   x   : input, 128-byte aligned (asserted), read as vx[0..3].
//   y_q : output, 128-byte aligned (asserted); receives one full HVX vector of 128 int8 quants.
//   y_d : output, 8 bytes: the fp16 scales of groups 0..3 at byte offsets 0, 2, 4, 6.
//
// Per group: d = max(|x|) * (1/127) in fp16, then q = x * (1/d), converted to int16
// with rounding/saturation and packed to int8 with saturation.
// NOTE(review): an all-zero group yields d == 0; the result then depends on how
// hvx_vec_inverse_f16 handles 0 — confirm.
static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
    assert((unsigned long) x % 128 == 0);
    assert((unsigned long) y_q % 128 == 0);

    HVX_Vector * vx = (HVX_Vector *) x;
    HVX_Vector zero   = Q6_V_vsplat_R(0);

    // Use reduce max fp32 to find max(abs(e)) first
    // (one reduction per 32-value group, computed in fp32 before any fp16 narrowing)
    HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0]));
    HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1]));
    HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2]));
    HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3]));
    // Load and convert into QF32 (sf - 0 yields the qf32 representation)
    HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero);  // 32 elements
    HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero);  // 32 elements
    HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero);  // 32 elements
    HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero);  // 32 elements

    // Convert to QF32
    HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero);
    HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero);
    HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero);
    HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero);

    // Combine and convert to fp16
    // Each pair of qf32 vectors narrows into one fp16 vector; vdeal un-interleaves the result.
    // Presumably group 0's data then occupies bytes 0..63 and group 1's bytes 64..127
    // (consistent with the 64-byte vror used for the scale stores below).
    HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf)));
    HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf)));

    // Convert into fp16
    HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
    HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));

    // Replicate first fp16 scale across all lanes
    // (repl_2x_f16 is defined elsewhere; presumably it replicates each half's leading
    // scale across that half, giving one scale per 32-value group)
    HVX_Vector ctrl = *(const HVX_Vector *) repl_2x_f16;
    vmax01_hf         = Q6_V_vdelta_VV(vmax01_hf, ctrl);
    vmax23_hf         = Q6_V_vdelta_VV(vmax23_hf, ctrl);

    HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008));  // 1.0 / 127.0
    HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008));  // 1.0 / 127.0
    HVX_Vector vd01_hf   = Q6_Vhf_equals_Vqf16(vd01_qf16);
    HVX_Vector vd23_hf   = Q6_Vhf_equals_Vqf16(vd23_qf16);

    // Store one fp16 scale per group: the leading lane, then the lane 64 bytes in
    // (brought to the front by vror) for the second group of each pair.
    hvx_vec_store_u(y_d + 0, 2, vd01_hf);
    HVX_Vector rotated_vd_hf = Q6_V_vror_VR(vd01_hf, 64);
    hvx_vec_store_u(y_d + 2, 2, rotated_vd_hf);

    hvx_vec_store_u(y_d + 4, 2, vd23_hf);
    rotated_vd_hf = Q6_V_vror_VR(vd23_hf, 64);
    hvx_vec_store_u(y_d + 6, 2, rotated_vd_hf);

    // Divide input by the scale
    HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf);
    HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf);
    vx01_hf              = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
    vx23_hf              = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));

    // Convert to int8
    HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
    HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
    HVX_Vector vx_i8    = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);

    *(HVX_Vector *) y_q = vx_i8;
}
2114
// Quantize 128 f32 values (four 32-float HVX vectors) to int8 with one fp16 scale
// per 64-value group (the FP32_QUANTIZE_GROUP_SIZE == 64 variant).
//
//   x   : input, 128-byte aligned (asserted), read as vx[0..3].
//   y_q : output, 128-byte aligned (asserted); receives one full HVX vector of 128 int8 quants.
//   y_d : output, 8 bytes. Bytes 0..3 hold the scale of values 0..63 and bytes 4..7 the
//         scale of values 64..127, each written as two identical fp16 copies, so the layout
//         still has one fp16 entry per 32 values.
//
// Unlike the q8x1 variant, the max is computed after narrowing to fp16.
// NOTE(review): an all-zero group yields d == 0; the result then depends on how
// hvx_vec_inverse_f16 handles 0 — confirm.
static inline void quantize_block_f32_q8x2(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
    assert((unsigned long) x % 128 == 0);
    assert((unsigned long) y_q % 128 == 0);

    HVX_Vector * vx = (HVX_Vector *) x;

    // Load and convert into QF32
    HVX_Vector zero   = Q6_V_vsplat_R(0);
    HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero);  // 32 elements
    HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero);  // 32 elements
    HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero);  // 32 elements
    HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero);  // 32 elements

    // Convert into fp16
    // (vx01_hf holds values 0..63, vx23_hf values 64..127)
    HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
    HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));

    // Compute max and scale
    HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf));
    HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf));

    // Replicate first fp16 scale across all lanes
    HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16;
    vmax01_hf         = Q6_V_vdelta_VV(vmax01_hf, ctrl);
    vmax23_hf         = Q6_V_vdelta_VV(vmax23_hf, ctrl);

    HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008));  // 1.0 / 127.0
    HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008));  // 1.0 / 127.0
    HVX_Vector vd01_hf   = Q6_Vhf_equals_Vqf16(vd01_qf16);
    HVX_Vector vd23_hf   = Q6_Vhf_equals_Vqf16(vd23_qf16);

    // 4 bytes = two copies of the (replicated) fp16 scale per 64-value group
    hvx_vec_store_u(y_d + 0, 4, vd01_hf);
    hvx_vec_store_u(y_d + 4, 4, vd23_hf);

    // Divide input by the scale
    HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf);
    HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf);
    vx01_hf              = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
    vx23_hf              = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));

    // Convert to int8
    HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
    HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
    HVX_Vector vx_i8    = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);

    *(HVX_Vector *) y_q = vx_i8;
}
2162
2163static inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
2164    assert((unsigned long) x % 128 == 0);
2165    assert((unsigned long) y_q % 128 == 0);
2166
2167    HVX_Vector * vx = (HVX_Vector *) x;
2168
2169    // Load and convert into QF32
2170    HVX_Vector zero   = Q6_V_vsplat_R(0);
2171    HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero);  // 32 elements
2172    HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero);  // 32 elements
2173    HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero);  // 32 elements
2174    HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero);  // 32 elements
2175
2176    // Convert into fp16
2177    HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
2178    HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
2179
2180    // Compute max and scale
2181    HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf));
2182    vmax_hf            = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf);
2183
2184    // Replicate first fp16 scale across all lanes
2185    HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16;
2186    vmax_hf         = Q6_V_vdelta_VV(vmax_hf, ctrl);
2187
2188    HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008));  // 1.0 / 127.0
2189    HVX_Vector vd_hf   = Q6_Vhf_equals_Vqf16(vd_qf16);
2190
2191    *(HVX_UVector *) y_d = vd_hf;
2192
2193    // Divide input by the scale
2194    HVX_Vector vd_inv_hf = hvx_vec_inverse_f16(vd_hf);
2195    vx01_hf              = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd_inv_hf));
2196    vx23_hf              = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd_inv_hf));
2197
2198    // Convert to int8
2199    HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
2200    HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
2201    HVX_Vector vx_i8    = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);
2202
2203    *(HVX_Vector *) y_q = vx_i8;
2204}
2205
2206// Overrides input x
2207static void quantize_row_f32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) {
2208    assert(k % 32 == 0);
2209    const uint32_t qk = QK_Q8_0x4x2;
2210    const uint32_t nb = (k + qk - 1) / qk;
2211
2212    const uint32_t qrow_size = k;              // int8
2213
2214    const uint32_t dblk_size = 8 * 2;          // 8x __fp16
2215    const uint32_t qblk_size = QK_Q8_0x4x2;    // int8
2216
2217    uint8_t * restrict y_q = (y + 0);          // quants first
2218    uint8_t * restrict y_d = (y + qrow_size);  // then scales
2219
2220    // Temp scales override input since we're working off of the aligned temp buffer in VTCM
2221    uint8_t * restrict t_d = (uint8_t *) x;
2222
2223    for (uint32_t i = 0; i < nb; i++) {
2224#if FP32_QUANTIZE_GROUP_SIZE == 32
2225        quantize_block_f32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
2226        quantize_block_f32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
2227#elif FP32_QUANTIZE_GROUP_SIZE == 64
2228        quantize_block_f32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
2229        quantize_block_f32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
2230#elif FP32_QUANTIZE_GROUP_SIZE == 128
2231        quantize_block_f32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
2232        quantize_block_f32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
2233#else
2234#error "FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128"
2235#endif
2236    }
2237
2238    // now copy the scales into final location
2239    hvx_copy_f16_ua(y_d, t_d, nb * 8);
2240}
2241
2242static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) {
2243    struct htp_matmul_context * mmctx = data;
2244    struct htp_ops_context * octx = mmctx->octx;
2245
2246    const struct htp_tensor * src = &octx->src1;
2247    uint8_t * restrict dst = octx->src1_spad.data;
2248    struct htp_spad * spad = &octx->src0_spad;
2249    uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
2250
2251    uint64_t t1 = HAP_perf_get_qtimer_count();
2252
2253    const uint32_t ne0 = src->ne[0];
2254    const uint32_t ne1 = src->ne[1];
2255    const uint32_t ne2 = src->ne[2];
2256    const uint32_t ne3 = src->ne[3];
2257
2258    const uint32_t nrows = ne1 * ne2 * ne3;                             // total n_rows
2259
2260    const uint32_t ir_first = nrows_per_thread * ith;                   // first row
2261    const uint32_t ir_last  = MIN(ir_first + nrows_per_thread, nrows);  // last row
2262
2263    const size_t src_row_size = src->nb[1];
2264    const size_t dst_row_size = q8x4x2_row_size(ne0);
2265
2266    uint8_t * restrict src_data = (uint8_t *) src->data + (src_row_size * ir_first);
2267    uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first);
2268    uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith);
2269
2270    const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float));
2271    memset(tmp_data, 0, src_row_size_padded);  // zero-out temp row data for padding
2272
2273    for (uint32_t i = ir_first; i < ir_last; ++i) {
2274        hex_l2fetch(src_data, src_row_size, src_row_size, 2);
2275        hvx_copy_f32_aa(tmp_data, src_data, ne0);
2276
2277        // FARF(HIGH, "quantize-q8x4-row: %u\n", i);
2278        quantize_row_f32_q8x4x2((float *) tmp_data, dst_data, ne0);
2279        dst_data += dst_row_size;
2280        src_data += src_row_size;
2281    }
2282
2283    uint64_t t2 = HAP_perf_get_qtimer_count();
2284
2285    FARF(HIGH, "quantize-f32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first,
2286         ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
2287}
2288
2289static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
2290    struct htp_matmul_context * mmctx = data;
2291    struct htp_ops_context * octx = mmctx->octx;
2292
2293    const struct htp_tensor * src = &octx->src1;
2294    uint8_t * restrict dst = octx->src1_spad.data;
2295    uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
2296    uint32_t dst_stride = octx->src1_spad.stride;
2297
2298    uint64_t t1 = HAP_perf_get_qtimer_count();
2299
2300    const uint32_t ne0 = src->ne[0];
2301    const uint32_t ne1 = src->ne[1];
2302    const uint32_t ne2 = src->ne[2];
2303    const uint32_t ne3 = src->ne[3];
2304
2305    const uint32_t nrows = ne1 * ne2 * ne3;                             // total n_rows
2306
2307    const uint32_t ir_first = nrows_per_thread * ith;                   // first row
2308    const uint32_t ir_last  = MIN(ir_first + nrows_per_thread, nrows);  // last row
2309
2310    const size_t src_row_size = ne0 * sizeof(float);
2311    const size_t src_stride   = src->nb[1];
2312
2313    uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first);
2314    uint8_t * restrict dst_data = (uint8_t *) dst       + (dst_stride * ir_first);
2315
2316    for (uint32_t i = ir_first; i < ir_last; ++i) {
2317        hex_l2fetch(src_data, src_row_size, src_stride, 2);
2318        hvx_copy_f16_f32_au(dst_data, src_data, ne0);
2319
2320        dst_data += dst_stride;
2321        src_data += src_stride;
2322    }
2323
2324    uint64_t t2 = HAP_perf_get_qtimer_count();
2325
2326    FARF(HIGH, "quantize-f32-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
2327        ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
2328}
2329
2330// TODO just a plain copy that should be done via the DMA during the Op setup
2331static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
2332    struct htp_matmul_context * mmctx = data;
2333    struct htp_ops_context * octx = mmctx->octx;
2334
2335    const struct htp_tensor * src = &octx->src1;
2336    uint8_t * restrict dst = octx->src1_spad.data;
2337    uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
2338    uint32_t dst_stride = octx->src1_spad.stride;
2339
2340    uint64_t t1 = HAP_perf_get_qtimer_count();
2341
2342    const uint32_t ne0 = src->ne[0];
2343    const uint32_t ne1 = src->ne[1];
2344    const uint32_t ne2 = src->ne[2];
2345    const uint32_t ne3 = src->ne[3];
2346
2347    const uint32_t nrows = ne1 * ne2 * ne3;                             // total n_rows
2348
2349    const uint32_t ir_first = nrows_per_thread * ith;                   // first row
2350    const uint32_t ir_last  = MIN(ir_first + nrows_per_thread, nrows);  // last row
2351
2352    const size_t src_row_size = ne0 * sizeof(float);
2353    const size_t src_stride   = src->nb[1];
2354
2355    uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first);
2356    uint8_t * restrict dst_data = (uint8_t *) dst       + (dst_stride * ir_first);
2357
2358    for (uint32_t i = ir_first; i < ir_last; ++i) {
2359        hex_l2fetch(src_data, src_row_size, src_stride, 2);
2360        hvx_copy_f16_au(dst_data, src_data, ne0);
2361
2362        dst_data += dst_stride;
2363        src_data += src_stride;
2364    }
2365
2366    uint64_t t2 = HAP_perf_get_qtimer_count();
2367
2368    FARF(HIGH, "quantize-f16-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
2369        ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
2370}
2371
2372
2373static inline bool htp_is_permuted(const struct htp_tensor * t) {
2374    return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3];
2375}
2376
2377static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_type type) {
2378    switch (type) {
2379        case HTP_TYPE_Q4_0:
2380            mmctx->type        = "q4x4x2-f32";
2381            mmctx->vec_dot_1x1 = vec_dot_q4x4x2_q8x4x2_1x1;
2382            mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1;
2383            mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2;
2384            return 0;
2385        case HTP_TYPE_Q8_0:
2386            mmctx->type        = "q8x4x2-f32";
2387            mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1;
2388            mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1;
2389            mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2;
2390            return 0;
2391        case HTP_TYPE_MXFP4:
2392            mmctx->type        = "mxfp4x4x2-f32";
2393            mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1;
2394            mmctx->vec_dot_2x1 = vec_dot_mxfp4x4x2_q8x4x2_2x1;
2395            mmctx->vec_dot_2x2 = vec_dot_mxfp4x4x2_q8x4x2_2x2;
2396            return 0;
2397        default:
2398            return -1;
2399    }
2400}
2401
2402static void htp_mminit_spad(struct htp_ops_context * octx,
2403                                 size_t dst_row_size,
2404                                 size_t src0_row_size_padded,
2405                                 size_t src1_row_size,
2406                                 uint32_t src1_nrows,
2407                                 size_t src2_spad_size_per_thread) {
2408    octx->dst_spad.size_per_thread  = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
2409    octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
2410    octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
2411
2412    if (src2_spad_size_per_thread > 0) {
2413        octx->src2_spad.size_per_thread = src2_spad_size_per_thread;
2414        octx->src2_spad.size            = octx->src2_spad.size_per_thread;
2415    }
2416
2417    // src0 spad is also used in dynamic quantizer to store padded src1 rows
2418    size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
2419    if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
2420        octx->src0_spad.size_per_thread = src1_row_size_padded;
2421    }
2422
2423    octx->src1_spad.size = octx->src1_spad.size_per_thread;
2424    octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
2425    octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
2426}
2427
2428int op_matmul(struct htp_ops_context * octx) {
2429    htp_matmul_tensors_preamble;
2430
2431    struct htp_matmul_context mmctx_struct = {0};
2432    struct htp_matmul_context * mmctx = &mmctx_struct;
2433    mmctx->octx = octx;
2434
2435    const uint32_t src0_nrows = ne01 * ne02 * ne03;
2436    const uint32_t src1_nrows = ne11 * ne12 * ne13;
2437
2438    // Compute src0_nrows_per_thread
2439    mmctx->src0_nrows_per_thread  = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
2440    mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even
2441
2442    const size_t src0_row_size = nb01;
2443    const size_t dst_row_size  = nb1;
2444    size_t       src1_row_size = nb11;
2445
2446    const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
2447    size_t       src1_row_size_padded;
2448
2449    worker_callback_t quant_job_func;
2450    worker_callback_t matmul_job_func = src1_nrows > 1 ? matmul_2d : matvec_2d;
2451
2452    bool need_quant = !(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE);
2453
2454    if (src0->type == HTP_TYPE_F16) {
2455        // Try optimized f16-f16 path first (src1 in VTCM)
2456        const size_t f16_src1_row_size  = hex_round_up(ne10 * 2, 128);
2457        const size_t f16_src1_spad_size = hex_round_up(f16_src1_row_size * src1_nrows, 256);
2458        const size_t f16_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads;
2459        const size_t f16_dst_spad_size  = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads;
2460
2461        const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size;
2462
2463        // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting).
2464        // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul.
2465        const bool is_batched  = (ne02 > 1) || (ne03 > 1);
2466        const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1);
2467
2468        if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) {
2469            // Optimized path
2470            quant_job_func     = (src1->type == HTP_TYPE_F32) ? quantize_f32_f16 : quantize_f16_f16;
2471            mmctx->type        = "f16-f16";
2472            mmctx->vec_dot_1x1 = vec_dot_f16_f16_aa_1x1;
2473            mmctx->vec_dot_2x1 = vec_dot_f16_f16_aa_2x1;
2474            mmctx->vec_dot_2x2 = vec_dot_f16_f16_aa_2x2;
2475
2476            src1_row_size = f16_src1_row_size;  // row size post quantization
2477
2478            octx->dst_spad.size_per_thread  = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
2479            octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
2480            octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
2481
2482            octx->src1_spad.size = octx->src1_spad.size_per_thread;
2483            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
2484            octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
2485        } else {
2486            // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required
2487            quant_job_func = NULL;
2488            if (src1->type == HTP_TYPE_F32) {
2489                mmctx->type        = "f16-f32";
2490                mmctx->vec_dot_1x1 = vec_dot_f16_f32_uu_1x1;
2491                matmul_job_func    = matmul_4d;
2492            } else {
2493                mmctx->type        = "f16-f16";
2494                mmctx->vec_dot_1x1 = vec_dot_f16_f16_uu_1x1;
2495                matmul_job_func    = matmul_4d;
2496            }
2497
2498            src1_row_size = nb11;  // original row size in DDR
2499
2500            octx->dst_spad.size_per_thread  = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
2501            octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256);
2502            octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256);
2503
2504            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
2505            octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
2506            octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
2507
2508            // Init fastdiv for matmul_4d (supports broadcasting)
2509            mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]);
2510            mmctx->mm_div_ne1      = init_fastdiv_values(dst->ne[1]);
2511            mmctx->mm_div_r2       = init_fastdiv_values(src1->ne[2] / src0->ne[2]);
2512            mmctx->mm_div_r3       = init_fastdiv_values(src1->ne[3] / src0->ne[3]);
2513
2514            need_quant = false;
2515        }
2516    } else {
2517        if (htp_mminit_vec_dot(mmctx, src0->type) != 0) {
2518            return HTP_STATUS_NO_SUPPORT;
2519        }
2520
2521        quant_job_func = quantize_f32_q8x4x2;
2522        src1_row_size  = q8x4x2_row_size(ne10);
2523        htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, 0);
2524    }
2525
2526    // VTCM scratchpads for all tensors
2527    size_t spad_size = octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;
2528
2529    FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type,
2530         octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, spad_size);
2531
2532    FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, src0->ne[0],
2533         src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0],
2534         dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data);
2535
2536    // Make sure the reserved vtcm size is sufficient
2537    if (octx->ctx->vtcm_size < spad_size) {
2538        FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type,
2539             octx->ctx->vtcm_size, spad_size);
2540        return HTP_STATUS_VTCM_TOO_SMALL;
2541    }
2542
2543    octx->src0_spad.data = octx->ctx->vtcm_base;
2544    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
2545    octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;
2546
2547    octx->src0_spad.stride = src0_row_size_padded;
2548    octx->src1_spad.stride = src1_row_size;
2549
2550    if (need_quant) {
2551        const uint32_t n_quant_jobs  = MIN(src1_nrows, octx->n_threads);
2552        mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
2553        worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
2554    }
2555
2556    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
2557        const uint32_t n_matmul_jobs = octx->n_threads;
2558        worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs);
2559    }
2560
2561    return HTP_STATUS_OK;
2562}
2563
2564int op_matmul_id(struct htp_ops_context * octx) {
2565    htp_matmul_tensors_preamble;
2566
2567    struct htp_matmul_context mmctx_struct = {0};
2568    struct htp_matmul_context * mmctx = &mmctx_struct;
2569    mmctx->octx = octx;
2570
2571    struct htp_tensor * restrict ids = &octx->src2;
2572
2573    const size_t src0_row_size = nb01;
2574    const size_t dst_row_size  = nb1;
2575
2576    const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
2577
2578    const uint32_t src0_nrows = ne01;  // per expert
2579    const uint32_t src1_nrows = ne11 * ne12 * ne13;
2580
2581    worker_callback_t quant_job_func;
2582    worker_callback_t matmul_id_job_func = src1_nrows > 1 ? matmul_id : matvec_id;
2583
2584    // Compute src0_nrows_per_thread
2585    mmctx->src0_nrows_per_thread  = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
2586    mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even
2587
2588    size_t src1_row_size;
2589    size_t src1_row_size_padded;
2590
2591    // row groups
2592    const int n_ids = ids->ne[0];  // n_expert_used
2593    const int n_as  = ne02;        // n_expert
2594
2595    size_t matrix_row_counts_size = n_as * sizeof(uint32_t);
2596    size_t matrix_row_map_size    = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping);
2597
2598    if (htp_mminit_vec_dot(mmctx, src0->type) != 0) {
2599        return HTP_STATUS_NO_SUPPORT;
2600    }
2601
2602    quant_job_func = quantize_f32_q8x4x2;
2603    src1_row_size  = q8x4x2_row_size(ne10);
2604
2605    const size_t src2_spad_size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
2606    htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread);
2607
2608    size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;
2609
2610    FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type,
2611         octx->src0_spad.size, octx->src1_spad.size, octx->src2_spad.size, octx->dst_spad.size, spad_size);
2612
2613    FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type,
2614         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
2615         ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data,
2616         src1->data, dst->data);
2617
2618    // Make sure the reserved vtcm size is sufficient
2619    if (octx->ctx->vtcm_size < spad_size) {
2620        FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size);
2621        return HTP_STATUS_VTCM_TOO_SMALL;
2622    }
2623
2624    octx->src0_spad.data = octx->ctx->vtcm_base;
2625    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
2626    octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
2627    octx->dst_spad.data  = octx->src2_spad.data + octx->src2_spad.size;
2628
2629    octx->src0_spad.stride = src0_row_size_padded;
2630    octx->src1_spad.stride = src1_row_size;
2631
2632    if (src1_nrows > 1) {
2633        // initialize matrix_row_counts and map
2634        uint32_t *                matrix_row_counts = (uint32_t *) octx->src2_spad.data + 0;
2635        struct mmid_row_mapping * matrix_rows       = (void *) octx->src2_spad.data + matrix_row_counts_size;
2636
2637        memset(matrix_row_counts, 0, n_as * sizeof(uint32_t));
2638
2639        // group rows by src0 matrix
2640        for (uint32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {  // token idx
2641            for (uint32_t id = 0; id < n_ids; ++id) {         // expert idx
2642                const uint32_t i02 = *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
2643
2644                assert(i02 >= 0 && i02 < n_as);
2645
2646                MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) { id, iid1 };
2647                matrix_row_counts[i02] += 1;
2648            }
2649        }
2650    }
2651
2652    // Setup worker pool callbacks
2653    if (!(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE)) {
2654        const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
2655        mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
2656        worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
2657    }
2658
2659    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
2660        const uint32_t n_matmul_jobs = octx->n_threads;
2661        worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs);
2662    }
2663
2664    return HTP_STATUS_OK;
2665}