Diffstat (limited to 'llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c')
 -rw-r--r--  llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c  | 4052
 1 file changed, 4052 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c b/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c
new file mode 100644
index 0000000..b390ab6
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c
@@ -0,0 +1,4052 @@
#define GGML_COMMON_IMPL_C
#include "ggml-common.h"
#include "ggml-quants.h"
#include "ggml-impl.h"
#include "ggml-cpu.h"
#include "simd-mappings.h"

#include "../../quants.h"
#include "../../ggml-cpu-impl.h"

#include <math.h>
#include <string.h>
#include <assert.h>
#include <float.h>
#include <stdlib.h> // for qsort
#include <stdio.h>  // for GGML_ASSERT

#define GROUP_MAX_EPS 1e-15f
#define GROUP_MAX_EPS_IQ3_XXS 1e-8f
#define GROUP_MAX_EPS_IQ2_S 1e-8f
#define GROUP_MAX_EPS_IQ1_M 1e-7f
#define GROUP_MAX_EPS_IQ1_S 1e-12f

#define UNUSED GGML_UNUSED

#if defined(__ARM_NEON)
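// B1..B8 enumerate all 256 byte values recursively: entry i of B8(c,s) is a 64-bit
// literal whose j-th byte is the hex pair `c` when bit j of i is clear and `s` when set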
#define B1(c,s,n)  0x ## n ## c ,  0x ## n ## s
#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
#define B8(c,s  ) B7(c,s,    c), B7(c,s,    s)

// precomputed tables for expanding 8bits to 8 bytes:
static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
#endif

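// quantize a row to q8_0: for each block of 32 floats, take the absolute maximum,
// derive the scale d = amax/127 (stored as fp16) and round each value/d to int8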
void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(QK8_0 == 32);
    assert(k % QK8_0 == 0);
    const int nb = k / QK8_0;

    block_q8_0 * GGML_RESTRICT y = vy;

#if defined(__ARM_NEON)
    for (int i = 0; i < nb; i++) {
        float32x4_t srcv [8];
        float32x4_t asrcv[8];
        float32x4_t amaxv[8];

        for (int j = 0; j < 8; j++) srcv[j]  = vld1q_f32(x + i*32 + 4*j);
        for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);

        for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
        for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
        for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);

        const float amax = vmaxvq_f32(amaxv[0]);

        const float d = amax / ((1 << 7) - 1);
        const float id = d ? 1.0f/d : 0.0f;

        y[i].d = GGML_CPU_FP32_TO_FP16(d);

        for (int j = 0; j < 8; j++) {
            const float32x4_t v  = vmulq_n_f32(srcv[j], id);
            const int32x4_t   vi = vcvtnq_s32_f32(v);

            y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
            y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
            y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
            y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
        }
    }
#else
    GGML_UNUSED(nb);
    // scalar
    quantize_row_q8_0_ref(x, y, k);
#endif
}

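// q8_1 additionally stores s = d * sum(quants); q4_1/q5_1 dot products use it to
// fold their per-block offset m into a single m*s term instead of a second loop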
void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(k % QK8_1 == 0);
    const int nb = k / QK8_1;

    block_q8_1 * GGML_RESTRICT y = vy;
#if defined(__ARM_NEON)
    for (int i = 0; i < nb; i++) {
        float32x4_t srcv [8];
        float32x4_t asrcv[8];
        float32x4_t amaxv[8];

        for (int j = 0; j < 8; j++) srcv[j]  = vld1q_f32(x + i*32 + 4*j);
        for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);

        for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
        for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
        for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);

        const float amax = vmaxvq_f32(amaxv[0]);

        const float d = amax / ((1 << 7) - 1);
        const float id = d ? 1.0f/d : 0.0f;

        y[i].d = GGML_CPU_FP32_TO_FP16(d);

        int32x4_t accv = vdupq_n_s32(0);

        for (int j = 0; j < 8; j++) {
            const float32x4_t v  = vmulq_n_f32(srcv[j], id);
            const int32x4_t   vi = vcvtnq_s32_f32(v);

            y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
            y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
            y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
            y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);

            accv = vaddq_s32(accv, vi);
        }

        y[i].s = GGML_CPU_FP32_TO_FP16(d * vaddvq_s32(accv));
    }
#else
    GGML_UNUSED(nb);
    // scalar
    quantize_row_q8_1_ref(x, y, k);
#endif
}

// placeholder implementation for Apple targets
void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
    quantize_row_q8_K_ref(x, y, k);
}

//===================================== Dot products =================================

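// q4_0 packs 32 quants as 16 bytes of two nibbles with an implicit -8 offset and a
// shared fp16 scale d; the dot product accumulates int8 products and applies d_x*d_y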
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;

    assert(n % qk == 0);
#if defined(__ARM_FEATURE_MATMUL_INT8)
    assert((nrc == 2) || (nrc == 1));
#else
    assert(nrc == 1);
#endif
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q4_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

#if defined(__ARM_FEATURE_MATMUL_INT8)
    if (nrc == 2) {
        const block_q4_0 * GGML_RESTRICT vx0 = vx;
        const block_q4_0 * GGML_RESTRICT vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx);
        const block_q8_0 * GGML_RESTRICT vy0 = vy;
        const block_q8_0 * GGML_RESTRICT vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);

        float32x4_t sumv0 = vdupq_n_f32(0.0f);

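        // i8mm path: each vmmlaq_s32 multiplies a 2x8 int8 tile from rows {x0,x1}
        // with the transpose of a 2x8 tile from rows {y0,y1}, accumulating a 2x2
        // int32 tile; the 64-bit zips below build those row-pair operands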
        for (int i = 0; i < nb; i++) {
            const block_q4_0 * GGML_RESTRICT b_x0 = &vx0[i];
            const block_q4_0 * GGML_RESTRICT b_x1 = &vx1[i];
            const block_q8_0 * GGML_RESTRICT b_y0 = &vy0[i];
            const block_q8_0 * GGML_RESTRICT b_y1 = &vy1[i];

            const uint8x16_t m4b = vdupq_n_u8(0x0F);
            const int8x16_t  s8b = vdupq_n_s8(0x8);

            const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
            const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);

            // 4-bit -> 8-bit
            const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
            const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
            const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
            const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

            // sub 8
            const int8x16_t x0_l = vsubq_s8(v0_0l, s8b);
            const int8x16_t x0_h = vsubq_s8(v0_0h, s8b);
            const int8x16_t x1_l = vsubq_s8(v0_1l, s8b);
            const int8x16_t x1_h = vsubq_s8(v0_1h, s8b);

            // load y
            const int8x16_t y0_l = vld1q_s8(b_y0->qs);
            const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
            const int8x16_t y1_l = vld1q_s8(b_y1->qs);
            const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);

            float32_t _scale[4] = {
                GGML_CPU_FP16_TO_FP32(b_x0->d)*GGML_CPU_FP16_TO_FP32(b_y0->d),
                GGML_CPU_FP16_TO_FP32(b_x0->d)*GGML_CPU_FP16_TO_FP32(b_y1->d),
                GGML_CPU_FP16_TO_FP32(b_x1->d)*GGML_CPU_FP16_TO_FP32(b_y0->d),
                GGML_CPU_FP16_TO_FP32(b_x1->d)*GGML_CPU_FP16_TO_FP32(b_y1->d)
            };
            float32x4_t scale = vld1q_f32(_scale);

            int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
            int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));

            int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
            int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));

            int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
            int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));

            int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
            int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));

            sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
                                                                                l1, r1)), l2, r2)), l3, r3))), scale);
        }

        float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);

        vst1_f32(s,      vget_low_f32 (sumv2));
        vst1_f32(s + bs, vget_high_f32(sumv2));

        return;
    }
#endif

    int ib = 0;
    float sumf = 0;

#if defined(__ARM_FEATURE_SVE)
    svfloat32_t sumv0 = svdup_n_f32(0.0f);
    svfloat32_t sumv1 = svdup_n_f32(0.0f);

    const int vector_length = ggml_cpu_get_sve_cnt()*8;

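    // ggml_cpu_get_sve_cnt() reports the SVE register width in bytes, so *8 gives
    // bits; each case below is a kernel specialized for that runtime vector length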
    // VLA Implementation using switch case
    switch (vector_length) {
        case 128:
            {
                // predicate for activating higher lanes for 4 float32 elements
                const svbool_t ph4 = svptrue_pat_b32(SV_VL4);

                for (; ib + 1 < nb; ib += 2) {
                    const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0];
                    const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];
                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

                    // load x
                    const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
                    const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);

                    // 4-bit -> 8-bit
                    const svint8_t qx0l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx0r, 0x0F));
                    const svint8_t qx0h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx0r, 0x04));
                    const svint8_t qx1l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx1r, 0x0F));
                    const svint8_t qx1h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx1r, 0x04));

                    // sub 8
                    const svint8_t qx0ls = svsub_n_s8_x(svptrue_b8(), qx0h, 8);
                    const svint8_t qx0hs = svsub_n_s8_x(svptrue_b8(), qx0l, 8);
                    const svint8_t qx1ls = svsub_n_s8_x(svptrue_b8(), qx1h, 8);
                    const svint8_t qx1hs = svsub_n_s8_x(svptrue_b8(), qx1l, 8);

                    // load y
                    const svint8_t qy0h = svld1_s8(svptrue_b8(), y0->qs);
                    const svint8_t qy0l = svld1_s8(svptrue_b8(), y0->qs + 16);
                    const svint8_t qy1h = svld1_s8(svptrue_b8(), y1->qs);
                    const svint8_t qy1l = svld1_s8(svptrue_b8(), y1->qs + 16);

                    // dot product
                    sumv0 = svmla_n_f32_x(ph4, sumv0, svcvt_f32_s32_x(ph4, svadd_x(ph4,
                                    svdot_s32(svdup_n_s32(0), qx0ls, qy0l),
                                    svdot_s32(svdup_n_s32(0), qx0hs, qy0h))), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
                    sumv1 = svmla_n_f32_x(ph4, sumv1, svcvt_f32_s32_x(ph4, svadd_x(ph4,
                                    svdot_s32(svdup_n_s32(0), qx1ls, qy1l),
                                    svdot_s32(svdup_n_s32(0), qx1hs, qy1h))), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
                }

                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
            } break;
        case 256:
            {
                // predicate for activating higher lanes for 16 int8 elements
                const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
                // predicate for activating lower lanes for 16 int8 elements
                const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);

                for (; ib + 1 < nb; ib += 2) {
                    const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0];
                    const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];
                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

                    // load x
                    const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
                    const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);

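                    // svld1rq replicated the 16 packed bytes into both 128-bit halves,
                    // so masking the low half (ph16 & 0x0F) and shifting the high half
                    // (pl16 >> 4) unpacks all 32 quants in a single register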
                    // 4-bit -> 8-bit
                    const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
                    const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));

                    // sub 8
                    const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
                    const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);

                    // load y
                    const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
                    const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);

                    // dot product
                    sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
                                    svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
                    sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
                                    svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
                }

                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
            } break;
        case 512:
            {
                // predicate for activating higher lanes for 32 int8 elements
                const svbool_t ph32 = svptrue_pat_b8(SV_VL32);

                // predicate for activating higher lanes for 16 int8 elements
                const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
                // predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes
                const svbool_t pl16 = svnot_b_z(ph32, ph16);

                for (; ib + 1 < nb; ib += 2) {
                    const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0];
                    const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];
                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

                    // load x
                    const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs);
                    const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs);

                    // 4-bit -> 8-bit
                    const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
                    const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));

                    // sub 8
                    const svint8_t qx0s = svsub_n_s8_x(ph32, qx0, 8);
                    const svint8_t qx1s = svsub_n_s8_x(ph32, qx1, 8);

                    // load y
                    const svint8_t qy0 = svld1_s8(ph32, y0->qs);
                    const svint8_t qy1 = svld1_s8(ph32, y1->qs);

                    // dot product
                    sumv0 = svmla_n_f32_x(ph32, sumv0, svcvt_f32_s32_x(ph32,
                                    svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
                    sumv1 = svmla_n_f32_x(ph32, sumv1, svcvt_f32_s32_x(ph32,
                                    svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
                }

                sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1));
            } break;
        default:
            assert(false && "Unsupported vector length");
            break;
    }

#elif defined(__ARM_NEON)
    float32x4_t sumv0 = vdupq_n_f32(0.0f);
    float32x4_t sumv1 = vdupq_n_f32(0.0f);

    for (; ib + 1 < nb; ib += 2) {
        const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0];
        const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];
        const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
        const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

        const uint8x16_t m4b = vdupq_n_u8(0x0F);
        const int8x16_t  s8b = vdupq_n_s8(0x8);

        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
        const uint8x16_t v0_1 = vld1q_u8(x1->qs);

        // 4-bit -> 8-bit
        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

        // sub 8
        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);

        // load y
        const int8x16_t v1_0l = vld1q_s8(y0->qs);
        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
        const int8x16_t v1_1l = vld1q_s8(y1->qs);
        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

        // dot product into int32x4_t
        const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
        const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);

        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
    }

    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
#endif
    for (; ib < nb; ++ib) {
        int sumi0 = 0;
        int sumi1 = 0;

        for (int j = 0; j < qk/2; ++j) {
            const int v0 = (x[ib].qs[j] & 0x0F) - 8;
            const int v1 = (x[ib].qs[j] >>   4) - 8;

            sumi0 += (v0 * y[ib].qs[j]);
            sumi1 += (v1 * y[ib].qs[j + qk/2]);
        }

        int sumi = sumi0 + sumi1;
        sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);
    }

    *s = sumf;
}

void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_1;
    const int nb = n / qk;

    assert(n % qk == 0);
#if defined(__ARM_FEATURE_MATMUL_INT8)
    assert((nrc == 2) || (nrc == 1));
#else
    assert(nrc == 1);
#endif
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q4_1 * GGML_RESTRICT x = vx;
    const block_q8_1 * GGML_RESTRICT y = vy;

#if defined(__ARM_FEATURE_MATMUL_INT8)
    if (nrc == 2) {
        const block_q4_1 * GGML_RESTRICT vx0 = vx;
        const block_q4_1 * GGML_RESTRICT vx1 = (const block_q4_1 *) ((const uint8_t*)vx + bx);
        const block_q8_1 * GGML_RESTRICT vy0 = vy;
        const block_q8_1 * GGML_RESTRICT vy1 = (const block_q8_1 *) ((const uint8_t*)vy + by);

        float32x4_t sumv0 = vdupq_n_f32(0.0f);
        float32x4_t summs0 = vdupq_n_f32(0.0f);

        for (int i = 0; i < nb; i++) {
            const block_q4_1 * GGML_RESTRICT b_x0 = &vx0[i];
            const block_q4_1 * GGML_RESTRICT b_x1 = &vx1[i];
            const block_q8_1 * GGML_RESTRICT b_y0 = &vy0[i];
            const block_q8_1 * GGML_RESTRICT b_y1 = &vy1[i];

            float32_t summs_t[4] = {
                GGML_CPU_FP16_TO_FP32(b_x0->m) * GGML_CPU_FP16_TO_FP32(b_y0->s),
                GGML_CPU_FP16_TO_FP32(b_x1->m) * GGML_CPU_FP16_TO_FP32(b_y0->s),
                GGML_CPU_FP16_TO_FP32(b_x0->m) * GGML_CPU_FP16_TO_FP32(b_y1->s),
                GGML_CPU_FP16_TO_FP32(b_x1->m) * GGML_CPU_FP16_TO_FP32(b_y1->s)
            };
            summs0 = vaddq_f32(summs0, vld1q_f32(summs_t));

            const uint8x16_t m4b = vdupq_n_u8(0x0F);

            const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
            const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);

            // 4-bit -> 8-bit
            const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
            const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
            const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
            const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

            // load y
            const int8x16_t y0_l = vld1q_s8(b_y0->qs);
            const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
            const int8x16_t y1_l = vld1q_s8(b_y1->qs);
            const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);

            // mmla into int32x4_t
            float32_t _scale[4] = {
                GGML_CPU_FP16_TO_FP32(b_x0->d)*GGML_CPU_FP16_TO_FP32(b_y0->d),
                GGML_CPU_FP16_TO_FP32(b_x0->d)*GGML_CPU_FP16_TO_FP32(b_y1->d),
                GGML_CPU_FP16_TO_FP32(b_x1->d)*GGML_CPU_FP16_TO_FP32(b_y0->d),
                GGML_CPU_FP16_TO_FP32(b_x1->d)*GGML_CPU_FP16_TO_FP32(b_y1->d)
            };
            float32x4_t scale = vld1q_f32(_scale);

            int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
            int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));

            int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
            int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));

            int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
            int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));

            int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
            int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
            sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
                                                                                l1, r1)), l2, r2)), l3, r3))), scale);
        }

        float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);

        sumv2 = vaddq_f32(sumv2, summs0);

        vst1_f32(s,      vget_low_f32 (sumv2));
        vst1_f32(s + bs, vget_high_f32(sumv2));

        return;
    }
#endif

    int ib = 0;
    float sumf = 0;

#if defined(__ARM_NEON)
    float32x4_t sumv0 = vdupq_n_f32(0.0f);
    float32x4_t sumv1 = vdupq_n_f32(0.0f);

    float summs = 0;

    for (; ib + 1 < nb; ib += 2) {
        const block_q4_1 * GGML_RESTRICT x0 = &x[ib + 0];
        const block_q4_1 * GGML_RESTRICT x1 = &x[ib + 1];
        const block_q8_1 * GGML_RESTRICT y0 = &y[ib + 0];
        const block_q8_1 * GGML_RESTRICT y1 = &y[ib + 1];

        summs += GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s) + GGML_CPU_FP16_TO_FP32(x1->m) * GGML_CPU_FP16_TO_FP32(y1->s);

        const uint8x16_t m4b = vdupq_n_u8(0x0F);

        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
        const uint8x16_t v0_1 = vld1q_u8(x1->qs);

        // 4-bit -> 8-bit
        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

        // load y
        const int8x16_t v1_0l = vld1q_s8(y0->qs);
        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
        const int8x16_t v1_1l = vld1q_s8(y1->qs);
        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

        // dot product into int32x4_t
        const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
        const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);

        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
    }

    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;

#endif
    for (; ib < nb; ++ib) {
        int sumi0 = 0;
        int sumi1 = 0;

        for (int j = 0; j < qk/2; ++j) {
            const int v0 = (x[ib].qs[j] & 0x0F);
            const int v1 = (x[ib].qs[j] >>   4);

            sumi0 += (v0 * y[ib].qs[j]);
            sumi1 += (v1 * y[ib].qs[j + qk/2]);
        }

        int sumi = sumi0 + sumi1;
        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
    }

    *s = sumf;
}

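// mxfp4: 4-bit codes index the kvalues_mxfp4 lookup table and share one e8m0
// (power-of-two) exponent per block; ggml_vqtbl1q_s8 expands 16 codes per lookup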
void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);
    assert(n % QK_MXFP4 == 0);
    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");

    const block_mxfp4 * GGML_RESTRICT x = vx;
    const block_q8_0  * GGML_RESTRICT y = vy;

    const int nb = n / QK_MXFP4;

    int ib = 0;
    float sumf = 0;

#if defined __ARM_NEON
    const int8x16_t values = vld1q_s8(kvalues_mxfp4);
    const uint8x16_t m4b = vdupq_n_u8(0x0f);
    uint8x16x2_t q4bits;
    int8x16x4_t q4b;
    int8x16x4_t q8b;
    int32x4_t prod_1;
    int32x4_t prod_2;

    for (; ib + 1 < nb; ib += 2) {
        q4bits.val[0] = vld1q_u8(x[ib + 0].qs);
        q4bits.val[1] = vld1q_u8(x[ib + 1].qs);
        q8b.val[0]    = vld1q_s8(y[ib + 0].qs);
        q8b.val[1]    = vld1q_s8(y[ib + 0].qs + 16);
        q8b.val[2]    = vld1q_s8(y[ib + 1].qs);
        q8b.val[3]    = vld1q_s8(y[ib + 1].qs + 16);

        q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b));
        q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
        q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b));
        q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));

        prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
        prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);

        sumf +=
            GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) +
            GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2);
    }

#endif
    for (; ib < nb; ++ib) {
        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
        int sumi1 = 0;
        int sumi2 = 0;
        for (int j = 0; j < QK_MXFP4/2; ++j) {
            sumi1 += y[ib].qs[j +          0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >>  4];
        }
        sumf += d * (sumi1 + sumi2);
    }
    *s = sumf;
}

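// q5_0: the fifth bit of each quant lives in the packed 32-bit qh field; expanding
// one qh byte through table_b2b_1 yields eight bytes of (!b) << 4, and subtracting
// that from the low nibble equals ((nibble | b << 4) - 16), i.e. quants in [-16, 15]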
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;

    int ib = 0;
    float sumf = 0;

    assert(n % qk == 0);
    assert(qk == QK5_0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q5_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

#if defined(__ARM_NEON)
    float32x4_t sumv0 = vdupq_n_f32(0.0f);
    float32x4_t sumv1 = vdupq_n_f32(0.0f);

    uint32_t qh0;
    uint32_t qh1;

    uint64_t tmp0[4];
    uint64_t tmp1[4];

    for (; ib + 1 < nb; ib += 2) {
        const block_q5_0 * GGML_RESTRICT x0 = &x[ib];
        const block_q5_0 * GGML_RESTRICT x1 = &x[ib + 1];
        const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
        const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

        const uint8x16_t m4b = vdupq_n_u8(0x0F);

        // extract the 5th bit via lookup table ((!b) << 4)
        memcpy(&qh0, x0->qh, sizeof(qh0));
        memcpy(&qh1, x1->qh, sizeof(qh1));

        tmp0[0] = table_b2b_1[(qh0 >>  0) & 0xFF];
        tmp0[1] = table_b2b_1[(qh0 >>  8) & 0xFF];
        tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];
        tmp0[3] = table_b2b_1[(qh0 >> 24)       ];

        tmp1[0] = table_b2b_1[(qh1 >>  0) & 0xFF];
        tmp1[1] = table_b2b_1[(qh1 >>  8) & 0xFF];
        tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];
        tmp1[3] = table_b2b_1[(qh1 >> 24)       ];

        const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
        const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
        const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
        const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));

        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
        const uint8x16_t v0_1 = vld1q_u8(x1->qs);

        // 4-bit -> 8-bit
        int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
        int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
        int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
        int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

        // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
        const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0);
        const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0);
        const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1);
        const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1);

        // load y
        const int8x16_t v1_0l = vld1q_s8(y0->qs);
        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
        const int8x16_t v1_1l = vld1q_s8(y1->qs);
        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
    }

    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);

#endif
    for (; ib < nb; ++ib) {
        uint32_t qh;
        memcpy(&qh, x[ib].qh, sizeof(qh));

        int sumi0 = 0;
        int sumi1 = 0;

        for (int j = 0; j < qk/2; ++j) {
            const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
            const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));

            const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16);
            const int32_t x1 = (int8_t)(((x[ib].qs[j] >>   4) | xh_1) - 16);

            sumi0 += (x0 * y[ib].qs[j]);
            sumi1 += (x1 * y[ib].qs[j + qk/2]);
        }

        int sumi = sumi0 + sumi1;
        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d)) * sumi;
    }

    *s = sumf;
}

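// q5_1 is the unsigned variant: quants span [0, 31] with a per-block minimum m, the
// high bit is OR'ed in via table_b2b_0 ((b) << 4) and the offset folds into m*s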
void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_1;
    const int nb = n / qk;

    int ib = 0;
    float sumf = 0;

    assert(n % qk == 0);
    assert(qk == QK5_1);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q5_1 * GGML_RESTRICT x = vx;
    const block_q8_1 * GGML_RESTRICT y = vy;

#if defined(__ARM_NEON)
    float32x4_t sumv0 = vdupq_n_f32(0.0f);
    float32x4_t sumv1 = vdupq_n_f32(0.0f);

    float summs0 = 0.0f;
    float summs1 = 0.0f;

    uint32_t qh0;
    uint32_t qh1;

    uint64_t tmp0[4];
    uint64_t tmp1[4];

    for (; ib + 1 < nb; ib += 2) {
        const block_q5_1 * GGML_RESTRICT x0 = &x[ib];
        const block_q5_1 * GGML_RESTRICT x1 = &x[ib + 1];
        const block_q8_1 * GGML_RESTRICT y0 = &y[ib];
        const block_q8_1 * GGML_RESTRICT y1 = &y[ib + 1];

        const uint8x16_t m4b = vdupq_n_u8(0x0F);

        summs0 += GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s);
        summs1 += GGML_CPU_FP16_TO_FP32(x1->m) * GGML_CPU_FP16_TO_FP32(y1->s);

        // extract the 5th bit via lookup table ((b) << 4)
        memcpy(&qh0, x0->qh, sizeof(qh0));
        memcpy(&qh1, x1->qh, sizeof(qh1));

        tmp0[0] = table_b2b_0[(qh0 >>  0) & 0xFF];
        tmp0[1] = table_b2b_0[(qh0 >>  8) & 0xFF];
        tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];
        tmp0[3] = table_b2b_0[(qh0 >> 24)       ];

        tmp1[0] = table_b2b_0[(qh1 >>  0) & 0xFF];
        tmp1[1] = table_b2b_0[(qh1 >>  8) & 0xFF];
        tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];
        tmp1[3] = table_b2b_0[(qh1 >> 24)       ];

        const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
        const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
        const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
        const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));

        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
        const uint8x16_t v0_1 = vld1q_u8(x1->qs);

        // 4-bit -> 8-bit
        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

        // add high bit
        const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0);
        const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0);
        const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1);
        const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1);

        // load y
        const int8x16_t v1_0l = vld1q_s8(y0->qs);
        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
        const int8x16_t v1_1l = vld1q_s8(y1->qs);
        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
    }

    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;

#endif
    for (; ib < nb; ++ib) {
        uint32_t qh;
        memcpy(&qh, x[ib].qh, sizeof(qh));

        int sumi0 = 0;
        int sumi1 = 0;

        for (int j = 0; j < qk/2; ++j) {
            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;

            const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0;
            const int32_t x1 = (x[ib].qs[j] >>  4) | xh_1;

            sumi0 += (x0 * y[ib].qs[j]);
            sumi1 += (x1 * y[ib].qs[j + qk/2]);
        }

        int sumi = sumi0 + sumi1;
        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
    }

    *s = sumf;
}

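// q8_0 x q8_0: both sides are already int8, so each block pair reduces to a straight
// int8 dot product scaled by d_x*d_y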
void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;

    assert(n % qk == 0);
#if defined(__ARM_FEATURE_MATMUL_INT8)
    assert((nrc == 2) || (nrc == 1));
#else
    assert(nrc == 1);
#endif
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q8_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

#if defined(__ARM_FEATURE_MATMUL_INT8)
    if (nrc == 2) {
        const block_q8_0 * GGML_RESTRICT vx0 = vx;
        const block_q8_0 * GGML_RESTRICT vx1 = (const block_q8_0 *) ((const uint8_t*)vx + bx);
        const block_q8_0 * GGML_RESTRICT vy0 = vy;
        const block_q8_0 * GGML_RESTRICT vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);

        float32x4_t sumv0 = vdupq_n_f32(0.0f);

        for (int i = 0; i < nb; i++) {
            const block_q8_0 * GGML_RESTRICT b_x0 = &vx0[i];
            const block_q8_0 * GGML_RESTRICT b_y0 = &vy0[i];

            const block_q8_0 * GGML_RESTRICT b_x1 = &vx1[i];
            const block_q8_0 * GGML_RESTRICT b_y1 = &vy1[i];

            const int8x16_t x0_l = vld1q_s8(b_x0->qs);
            const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16);
            const int8x16_t x1_l = vld1q_s8(b_x1->qs);
            const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16);

            // load y
            const int8x16_t y0_l = vld1q_s8(b_y0->qs);
            const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
            const int8x16_t y1_l = vld1q_s8(b_y1->qs);
            const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);

            float32_t _scale[4] = {
                GGML_CPU_FP16_TO_FP32(b_x0->d)*GGML_CPU_FP16_TO_FP32(b_y0->d),
                GGML_CPU_FP16_TO_FP32(b_x0->d)*GGML_CPU_FP16_TO_FP32(b_y1->d),
                GGML_CPU_FP16_TO_FP32(b_x1->d)*GGML_CPU_FP16_TO_FP32(b_y0->d),
                GGML_CPU_FP16_TO_FP32(b_x1->d)*GGML_CPU_FP16_TO_FP32(b_y1->d)
            };
            float32x4_t scale = vld1q_f32(_scale);

            int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
            int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));

            int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
            int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));

            int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
            int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));

            int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
            int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));

            sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
                                                                                l1, r1)), l2, r2)), l3, r3))), scale);
        }

        float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);

        vst1_f32(s,      vget_low_f32 (sumv2));
        vst1_f32(s + bs, vget_high_f32(sumv2));

        return;
    }
#endif

    int ib = 0;
    float sumf = 0;

#if defined(__ARM_FEATURE_SVE)
    svfloat32_t sumv0 = svdup_n_f32(0.0f);
    svfloat32_t sumv1 = svdup_n_f32(0.0f);

    const int vector_length = ggml_cpu_get_sve_cnt()*8;

    // VLA Implementation for SVE
    switch (vector_length) {
        case 128:
            {
                // predicate for activating lanes for 16 Int8 elements
                const svbool_t ph16 = svptrue_pat_b8 (SV_VL16);
                const svbool_t pl16 = svptrue_pat_b32(SV_VL4);

                for (; ib + 1 < nb; ib += 2) {
                    const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0];
                    const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1];
                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

                    // load x
                    const svint8_t qx0_0 = svld1_s8(ph16, x0->qs);
                    const svint8_t qx0_1 = svld1_s8(ph16, x0->qs+16);
                    const svint8_t qx1_0 = svld1_s8(ph16, x1->qs);
                    const svint8_t qx1_1 = svld1_s8(ph16, x1->qs+16);

                    // load y
                    const svint8_t qy0_0 = svld1_s8(ph16, y0->qs);
                    const svint8_t qy0_1 = svld1_s8(ph16, y0->qs+16);
                    const svint8_t qy1_0 = svld1_s8(ph16, y1->qs);
                    const svint8_t qy1_1 = svld1_s8(ph16, y1->qs+16);

                    sumv0 = svmla_n_f32_x(pl16, sumv0, svcvt_f32_s32_x(pl16, svadd_x(pl16,
                                    svdot_s32(svdup_n_s32(0), qx0_0, qy0_0),
                                    svdot_s32(svdup_n_s32(0), qx0_1, qy0_1))), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
                    sumv1 = svmla_n_f32_x(pl16, sumv1, svcvt_f32_s32_x(pl16, svadd_x(pl16,
                                    svdot_s32(svdup_n_s32(0), qx1_0, qy1_0),
                                    svdot_s32(svdup_n_s32(0), qx1_1, qy1_1))), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
                }

                sumf = svaddv_f32(pl16, svadd_f32_x(pl16, sumv0, sumv1));
            } break;
        case 256:
            {
                for (; ib + 1 < nb; ib += 2) {
                    const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0];
                    const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1];
                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

                    // load x
                    const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
                    const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);

                    // load y
                    const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
                    const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);

                    sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
                                    svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
                    sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
                                    svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
                }

                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
            } break;
        case 512:
            {
                // predicate for activating high 256 bit
                const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
                // predicate for activating low 256 bit
                const svbool_t pl32 = svnot_b_z(svptrue_b8(), ph32);

                // predicate for activating high lanes for 8 float32 elements
                const svbool_t ph8 = svptrue_pat_b32(SV_VL8);
                // predicate for activating low lanes for 8 float32 elements
                const svbool_t pl8 = svnot_b_z(svptrue_b32(), ph8);

                svfloat32_t sumv00 = svdup_n_f32(0.0f);

                for (; ib + 1 < nb; ib += 2) {
                    const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0];
                    const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1];
                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

                    // load x0's 32 quants into the low 32 lanes (ph32); the second load
                    // offsets past the next block's fp16 delta (qs + 2 lands on x1->qs
                    // at lane 32) so x1's 32 quants fill the high lanes (pl32); adding
                    // the two zero-predicated loads merges them into one 64-byte vector
                    // load x
                    const svint8_t qx_32 = svld1_s8(ph32, x0->qs);
                    svint8_t qx_64 = svld1_s8(pl32, x0->qs + 2);

                    qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64);

                    // load y
                    const svint8_t qy_32 = svld1_s8(ph32, y0->qs);
                    svint8_t qy_64 = svld1_s8(pl32, y0->qs + 2);

                    qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64);

                    // scale creation
                    const float32_t deq1 = GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d);
                    const float32_t deq2 = GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d);

                    // duplicate deq1 in first half of vector and deq2 in second half of vector
                    const svfloat32_t temp = svdup_f32_m(svdup_f32_z(ph8, deq1), pl8, deq2);

                    const svfloat32_t sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64));

                    sumv00 = svmla_f32_m(svptrue_b32(), sumv00, sumvt, temp);
                }

                sumf = svaddv_f32(svptrue_b32(), sumv00);
                break;
            }
        default:
            assert(false && "Unsupported vector length");
            break;
    }
#elif defined(__ARM_NEON)
    float32x4_t sumv0 = vdupq_n_f32(0.0f);
    float32x4_t sumv1 = vdupq_n_f32(0.0f);

    for (; ib + 1 < nb; ib += 2) {
        const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0];
        const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1];
        const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
        const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

        const int8x16_t x0_0 = vld1q_s8(x0->qs);
        const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
        const int8x16_t x1_0 = vld1q_s8(x1->qs);
        const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);

        // load y
        const int8x16_t y0_0 = vld1q_s8(y0->qs);
        const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
        const int8x16_t y1_0 = vld1q_s8(y1->qs);
        const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);

        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
                        ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
                        ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));

        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
                        ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
                        ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
    }

    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
#endif
    for (; ib < nb; ++ib) {
        int sumi = 0;

        for (int j = 0; j < qk; j++) {
            sumi += x[ib].qs[j]*y[ib].qs[j];
        }

        sumf += sumi*(GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d));
    }

    *s = sumf;
}

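// tq1_0 packs ternary weights base-3: each qs byte carries 5 trits (stored as
// {0, 1, 2}) and each qh byte carries 4 more; multiplying a byte by 3^k moves trit k
// to the top, where ((3*q) >> 8) -- computed overflow-free with vhaddq -- reads it out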
void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_tq1_0 * GGML_RESTRICT x = vx;
    const block_q8_K  * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined(__ARM_NEON)
    float sumf = 0.0f;

    uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};

    const uint8x16_t shift = vld1q_u8(k_shift);

    for (int i = 0; i < nb; ++i) {
#if defined(__ARM_FEATURE_DOTPROD)
        int32x4_t sumi0 = vdupq_n_s32(0);
        int32x4_t sumi1 = vdupq_n_s32(0);
#else
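        // without SDOT: accumulate products in int16 lanes via vmlal_s8; each lane
        // sums at most 16 products of |trit| <= 2 by |q8| <= 128, safely below overflow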
| 1154 | int16x8_t sumi0 = vdupq_n_s16(0); | ||
| 1155 | int16x8_t sumi1 = vdupq_n_s16(0); | ||
| 1156 | #endif | ||
| 1157 | |||
| 1158 | // first 32 bytes of 5 elements | ||
| 1159 | { | ||
| 1160 | uint8x16_t qx0 = vld1q_u8(x[i].qs + 0); | ||
| 1161 | uint8x16_t qx1 = vld1q_u8(x[i].qs + 16); | ||
| 1162 | uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3)); | ||
| 1163 | uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3)); | ||
| 1164 | uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9)); | ||
| 1165 | uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9)); | ||
| 1166 | uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27)); | ||
| 1167 | uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27)); | ||
| 1168 | uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81)); | ||
| 1169 | uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81)); | ||
| 1170 | |||
| 1171 | // multiply by 3 and keep the 2 bits above 8 bits | ||
| 1172 | int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6)); | ||
| 1173 | int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6)); | ||
| 1174 | int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6)); | ||
| 1175 | int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6)); | ||
| 1176 | int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6)); | ||
| 1177 | int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6)); | ||
| 1178 | int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6)); | ||
| 1179 | int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6)); | ||
| 1180 | int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6)); | ||
| 1181 | int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6)); | ||
| 1182 | |||
| 1183 | const int8x16_t qy0 = vld1q_s8(y[i].qs + 0); | ||
| 1184 | const int8x16_t qy1 = vld1q_s8(y[i].qs + 16); | ||
| 1185 | const int8x16_t qy2 = vld1q_s8(y[i].qs + 32); | ||
| 1186 | const int8x16_t qy3 = vld1q_s8(y[i].qs + 48); | ||
| 1187 | const int8x16_t qy4 = vld1q_s8(y[i].qs + 64); | ||
| 1188 | const int8x16_t qy5 = vld1q_s8(y[i].qs + 80); | ||
| 1189 | const int8x16_t qy6 = vld1q_s8(y[i].qs + 96); | ||
| 1190 | const int8x16_t qy7 = vld1q_s8(y[i].qs + 112); | ||
| 1191 | const int8x16_t qy8 = vld1q_s8(y[i].qs + 128); | ||
| 1192 | const int8x16_t qy9 = vld1q_s8(y[i].qs + 144); | ||
| 1193 | |||
| 1194 | #if defined(__ARM_FEATURE_DOTPROD) | ||
| 1195 | sumi0 = vdotq_s32(sumi0, sqx0, qy0); | ||
| 1196 | sumi1 = vdotq_s32(sumi1, sqx1, qy1); | ||
| 1197 | sumi0 = vdotq_s32(sumi0, sqx2, qy2); | ||
| 1198 | sumi1 = vdotq_s32(sumi1, sqx3, qy3); | ||
| 1199 | sumi0 = vdotq_s32(sumi0, sqx4, qy4); | ||
| 1200 | sumi1 = vdotq_s32(sumi1, sqx5, qy5); | ||
| 1201 | sumi0 = vdotq_s32(sumi0, sqx6, qy6); | ||
| 1202 | sumi1 = vdotq_s32(sumi1, sqx7, qy7); | ||
| 1203 | sumi0 = vdotq_s32(sumi0, sqx8, qy8); | ||
| 1204 | sumi1 = vdotq_s32(sumi1, sqx9, qy9); | ||
| 1205 | #else | ||
| 1206 | sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); | ||
| 1207 | sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); | ||
| 1208 | sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); | ||
| 1209 | sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); | ||
| 1210 | sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); | ||
| 1211 | sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); | ||
| 1212 | sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); | ||
| 1213 | sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); | ||
| 1214 | sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); | ||
| 1215 | sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); | ||
| 1216 | sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); | ||
| 1217 | sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); | ||
| 1218 | sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6)); | ||
| 1219 | sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6)); | ||
| 1220 | sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7)); | ||
| 1221 | sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7)); | ||
| 1222 | sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx8), vget_low_s8(qy8)); | ||
| 1223 | sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx8), vget_high_s8(qy8)); | ||
| 1224 | sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx9), vget_low_s8(qy9)); | ||
| 1225 | sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx9), vget_high_s8(qy9)); | ||
| 1226 | #endif | ||
| 1227 | } | ||
| 1228 | |||
| 1229 | // last 16 bytes of qs (5 elements per byte), along with the 4 qh bytes (4 elements per byte) | ||
| 1230 | { | ||
| 1231 | uint8x16_t qx0 = vld1q_u8(x[i].qs + 32); | ||
| 1232 | uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3)); | ||
| 1233 | uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9)); | ||
| 1234 | uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27)); | ||
| 1235 | uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81)); | ||
| 1236 | uint32_t qh; | ||
| 1237 | memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned | ||
| 1238 | uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh)); | ||
| 1239 | qx5 = vmulq_u8(qx5, shift); | ||
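| | // the 4 qh bytes pack 4 trits each; duplicating the 32-bit word and scaling by | ||
| | // the per-lane `shift` multipliers (powers of 3, set up earlier in this function) | ||
| | // brings one digit per group of four lanes into extraction position | ||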
| 1240 | |||
| 1241 | // multiply by 3 and keep the 2 bits above 8 bits | ||
| 1242 | int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6)); | ||
| 1243 | int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6)); | ||
| 1244 | int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6)); | ||
| 1245 | int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6)); | ||
| 1246 | int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6)); | ||
| 1247 | int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6)); | ||
| 1248 | |||
| 1249 | const int8x16_t qy0 = vld1q_s8(y[i].qs + 160); | ||
| 1250 | const int8x16_t qy1 = vld1q_s8(y[i].qs + 176); | ||
| 1251 | const int8x16_t qy2 = vld1q_s8(y[i].qs + 192); | ||
| 1252 | const int8x16_t qy3 = vld1q_s8(y[i].qs + 208); | ||
| 1253 | const int8x16_t qy4 = vld1q_s8(y[i].qs + 224); | ||
| 1254 | const int8x16_t qy5 = vld1q_s8(y[i].qs + 240); | ||
| 1255 | |||
| 1256 | #if defined(__ARM_FEATURE_DOTPROD) | ||
| 1257 | sumi0 = vdotq_s32(sumi0, sqx0, qy0); | ||
| 1258 | sumi1 = vdotq_s32(sumi1, sqx1, qy1); | ||
| 1259 | sumi0 = vdotq_s32(sumi0, sqx2, qy2); | ||
| 1260 | sumi1 = vdotq_s32(sumi1, sqx3, qy3); | ||
| 1261 | sumi0 = vdotq_s32(sumi0, sqx4, qy4); | ||
| 1262 | sumi1 = vdotq_s32(sumi1, sqx5, qy5); | ||
| 1263 | #else | ||
| 1264 | sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); | ||
| 1265 | sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); | ||
| 1266 | sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); | ||
| 1267 | sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); | ||
| 1268 | sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); | ||
| 1269 | sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); | ||
| 1270 | sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); | ||
| 1271 | sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); | ||
| 1272 | sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); | ||
| 1273 | sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); | ||
| 1274 | sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); | ||
| 1275 | sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); | ||
| 1276 | #endif | ||
| 1277 | } | ||
| 1278 | |||
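| | // the dot products above used unsigned digits in {0, 1, 2}; subtracting the | ||
| | // per-16 sums of y (bsums) below shifts them to the signed ternary {-1, 0, 1} | ||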
| 1279 | const int16x8_t ysum0 = vld1q_s16(y[i].bsums); | ||
| 1280 | const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8); | ||
| 1281 | |||
| 1282 | const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; | ||
| 1283 | |||
| 1284 | #if defined(__ARM_FEATURE_DOTPROD) | ||
| 1285 | sumi0 = vaddq_s32(sumi0, sumi1); | ||
| 1286 | sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1))); | ||
| 1287 | |||
| 1288 | sumf += d * (float) vaddvq_s32(sumi0); | ||
| 1289 | #else | ||
| 1290 | sumi0 = vaddq_s16(sumi0, sumi1); | ||
| 1291 | sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1)); | ||
| 1292 | |||
| 1293 | sumf += d * (float) vaddlvq_s16(sumi0); | ||
| 1294 | #endif | ||
| 1295 | } | ||
| 1296 | |||
| 1297 | *s = sumf; | ||
| 1298 | |||
| 1299 | #else | ||
| 1300 | UNUSED(x); | ||
| 1301 | UNUSED(y); | ||
| 1302 | UNUSED(nb); | ||
| 1303 | ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); | ||
| 1304 | #endif | ||
| 1305 | } | ||
| 1306 | |||
| 1307 | void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { | ||
| 1308 | assert(nrc == 1); | ||
| 1309 | UNUSED(nrc); | ||
| 1310 | UNUSED(bx); | ||
| 1311 | UNUSED(by); | ||
| 1312 | UNUSED(bs); | ||
| 1313 | |||
| 1314 | const block_tq2_0 * GGML_RESTRICT x = vx; | ||
| 1315 | const block_q8_K * GGML_RESTRICT y = vy; | ||
| 1316 | |||
| 1317 | const int nb = n / QK_K; | ||
| 1318 | |||
| 1319 | #if defined(__ARM_NEON) | ||
| 1320 | float sumf = 0.0f; | ||
| 1321 | |||
| 1322 | const uint8x16_t m3 = vdupq_n_u8(3); | ||
| 1323 | |||
| 1324 | for (int i = 0; i < nb; ++i) { | ||
| 1325 | #if defined(__ARM_FEATURE_DOTPROD) | ||
| 1326 | int32x4_t sumi0 = vdupq_n_s32(0); | ||
| 1327 | int32x4_t sumi1 = vdupq_n_s32(0); | ||
| 1328 | #else | ||
| 1329 | int16x8_t sumi0 = vdupq_n_s16(0); | ||
| 1330 | int16x8_t sumi1 = vdupq_n_s16(0); | ||
| 1331 | #endif | ||
| 1332 | |||
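| | // TQ2_0 packs four 2-bit values per byte; shifts by 0/2/4/6 plus the m3 mask | ||
| | // unpack 32 source bytes into 128 quants per iteration, matching 128 bytes of y | ||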
| 1333 | for (size_t j = 0; j < sizeof(x->qs); j += 32) { | ||
| 1334 | uint8x16_t qx0 = vld1q_u8(x[i].qs + j); | ||
| 1335 | uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16); | ||
| 1336 | uint8x16_t qx2 = vshrq_n_u8(qx0, 2); | ||
| 1337 | uint8x16_t qx3 = vshrq_n_u8(qx1, 2); | ||
| 1338 | uint8x16_t qx4 = vshrq_n_u8(qx0, 4); | ||
| 1339 | uint8x16_t qx5 = vshrq_n_u8(qx1, 4); | ||
| 1340 | uint8x16_t qx6 = vshrq_n_u8(qx0, 6); | ||
| 1341 | uint8x16_t qx7 = vshrq_n_u8(qx1, 6); | ||
| 1342 | |||
| 1343 | int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3)); | ||
| 1344 | int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3)); | ||
| 1345 | int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3)); | ||
| 1346 | int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3)); | ||
| 1347 | int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3)); | ||
| 1348 | int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3)); | ||
| 1349 | int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3)); | ||
| 1350 | int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3)); | ||
| 1351 | |||
| 1352 | const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 + 0); | ||
| 1353 | const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 + 16); | ||
| 1354 | const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 + 32); | ||
| 1355 | const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 + 48); | ||
| 1356 | const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 + 64); | ||
| 1357 | const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 + 80); | ||
| 1358 | const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96); | ||
| 1359 | const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112); | ||
| 1360 | |||
| 1361 | #if defined(__ARM_FEATURE_DOTPROD) | ||
| 1362 | sumi0 = vdotq_s32(sumi0, sqx0, qy0); | ||
| 1363 | sumi1 = vdotq_s32(sumi1, sqx1, qy1); | ||
| 1364 | sumi0 = vdotq_s32(sumi0, sqx2, qy2); | ||
| 1365 | sumi1 = vdotq_s32(sumi1, sqx3, qy3); | ||
| 1366 | sumi0 = vdotq_s32(sumi0, sqx4, qy4); | ||
| 1367 | sumi1 = vdotq_s32(sumi1, sqx5, qy5); | ||
| 1368 | sumi0 = vdotq_s32(sumi0, sqx6, qy6); | ||
| 1369 | sumi1 = vdotq_s32(sumi1, sqx7, qy7); | ||
| 1370 | #else | ||
| 1371 | sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); | ||
| 1372 | sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); | ||
| 1373 | sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); | ||
| 1374 | sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); | ||
| 1375 | sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); | ||
| 1376 | sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); | ||
| 1377 | sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); | ||
| 1378 | sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); | ||
| 1379 | sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); | ||
| 1380 | sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); | ||
| 1381 | sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); | ||
| 1382 | sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); | ||
| 1383 | sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6)); | ||
| 1384 | sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6)); | ||
| 1385 | sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7)); | ||
| 1386 | sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7)); | ||
| 1387 | #endif | ||
| 1388 | } | ||
| 1389 | |||
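| | // same correction as in tq1_0: subtracting the bsums maps {0, 1, 2} to {-1, 0, 1} | ||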
| 1390 | const int16x8_t ysum0 = vld1q_s16(y[i].bsums); | ||
| 1391 | const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8); | ||
| 1392 | |||
| 1393 | const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; | ||
| 1394 | |||
| 1395 | #if defined(__ARM_FEATURE_DOTPROD) | ||
| 1396 | sumi0 = vaddq_s32(sumi0, sumi1); | ||
| 1397 | sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1))); | ||
| 1398 | |||
| 1399 | sumf += d * (float) vaddvq_s32(sumi0); | ||
| 1400 | #else | ||
| 1401 | sumi0 = vaddq_s16(sumi0, sumi1); | ||
| 1402 | sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1)); | ||
| 1403 | |||
| 1404 | sumf += d * (float) vaddlvq_s16(sumi0); | ||
| 1405 | #endif | ||
| 1406 | } | ||
| 1407 | |||
| 1408 | *s = sumf; | ||
| 1409 | |||
| 1410 | #else | ||
| 1411 | UNUSED(x); | ||
| 1412 | UNUSED(y); | ||
| 1413 | UNUSED(nb); | ||
| 1414 | ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); | ||
| 1415 | #endif | ||
| 1416 | } | ||
| 1417 | |||
| 1418 | void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { | ||
| 1419 | assert(nrc == 1); | ||
| 1420 | UNUSED(nrc); | ||
| 1421 | UNUSED(bx); | ||
| 1422 | UNUSED(by); | ||
| 1423 | UNUSED(bs); | ||
| 1424 | |||
| 1425 | const block_q2_K * GGML_RESTRICT x = vx; | ||
| 1426 | const block_q8_K * GGML_RESTRICT y = vy; | ||
| 1427 | |||
| 1428 | const int nb = n / QK_K; | ||
| 1429 | |||
| 1430 | #ifdef __ARM_FEATURE_SVE | ||
| 1431 | const int vector_length = svcntb()*8; | ||
| 1432 | const svuint8_t m3s = svdup_n_u8(0x3); | ||
| 1433 | const svuint32_t m4s = svdup_n_u32(0xF); | ||
| 1434 | const svint32_t vzero_sv = svdup_n_s32(0); | ||
| 1435 | svfloat32_t acc_sum = svdup_n_f32(0); | ||
| 1436 | svbool_t pred_s32 = svptrue_pat_b32(SV_VL4); | ||
| 1437 | |||
| 1438 | switch (vector_length) { | ||
| 1439 | case 128: | ||
| 1440 | for (int i = 0; i < nb; ++i) { | ||
| 1441 | const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); | ||
| 1442 | svfloat32_t d_broad = svdup_n_f32((float32_t)d); | ||
| 1443 | const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); | ||
| 1444 | svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin); | ||
| 1445 | |||
| 1446 | const uint8_t * GGML_RESTRICT q2 = x[i].qs; | ||
| 1447 | const int8_t * GGML_RESTRICT q8_sv = y[i].qs; | ||
| 1448 | const uint8_t * GGML_RESTRICT sc = x[i].scales; | ||
| 1449 | |||
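| | // each scale byte packs a 4-bit scale (low nibble) and a 4-bit min (high nibble); | ||
| | // svld1ub widens 4 bytes at a time into 32-bit lanes, and the mins are dotted | ||
| | // with the bsums to form the -dmin correction before the main dot products | ||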
| 1450 | svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc); | ||
| 1451 | const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4)); | ||
| 1452 | |||
| 1453 | mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+4); | ||
| 1454 | const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4)); | ||
| 1455 | |||
| 1456 | svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums); | ||
| 1457 | svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+4); | ||
| 1458 | |||
| 1459 | const svint32_t s0 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_2, q8sums_sv_2)); | ||
| 1460 | |||
| 1461 | mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+8); | ||
| 1462 | const svint32_t mins_sv_3 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4)); | ||
| 1463 | |||
| 1464 | mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+12); | ||
| 1465 | const svint32_t mins_sv_4 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4)); | ||
| 1466 | |||
| 1467 | q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums+8); | ||
| 1468 | q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+12); | ||
| 1469 | |||
| 1470 | svint32_t s1 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_3, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_4, q8sums_sv_2)); | ||
| 1471 | |||
| 1472 | svfloat32_t temp = svcvt_f32_s32_x(svptrue_b32(), svadd_s32_x(svptrue_b32(), s0, s1)); | ||
| 1473 | |||
| 1474 | acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, temp, dmin_broad); | ||
| 1475 | |||
| 1476 | svint32_t sumi1 = svdup_n_s32(0); | ||
| 1477 | |||
| 1478 | { | ||
| 1479 | const svuint8_t q2bits_1 = svld1_u8(svptrue_b8(), q2); | ||
| 1480 | svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_1, m3s)); | ||
| 1481 | svint8_t q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; | ||
| 1482 | const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc), m4s)); | ||
| 1483 | |||
| 1484 | sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 0)); | ||
| 1485 | |||
| 1486 | const svuint8_t q2bits_3 = svld1_u8(svptrue_b8(), q2+16); | ||
| 1487 | q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_3, m3s)); | ||
| 1488 | q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; | ||
| 1489 | |||
| 1490 | sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 1)); | ||
| 1491 | |||
| 1492 | q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 2), m3s)); | ||
| 1493 | q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; | ||
| 1494 | |||
| 1495 | sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 2)); | ||
| 1496 | |||
| 1497 | q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 2), m3s)); | ||
| 1498 | q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; | ||
| 1499 | |||
| 1500 | sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 3)); | ||
| 1501 | |||
| 1502 | |||
| 1503 | const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+4), m4s)); | ||
| 1504 | |||
| 1505 | q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 4), m3s)); | ||
| 1506 | q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; | ||
| 1507 | |||
| 1508 | sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 0)); | ||
| 1509 | |||
| 1510 | q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 4), m3s)); | ||
| 1511 | q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; | ||
| 1512 | |||
| 1513 | sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 1)); | ||
| 1514 | |||
| 1515 | q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 6), m3s)); | ||
| 1516 | q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; | ||
| 1517 | |||
| 1518 | sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 2)); | ||
| 1519 | |||
| 1520 | q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 6), m3s)); | ||
| 1521 | q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; | ||
| 1522 | |||
| 1523 | sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 3)); | ||
| 1524 | |||
| 1525 | //------------------------------- | ||
| 1526 | |||
| 1527 | q2 += 32; | ||
| 1528 | const svint32_t scales_sv_2 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+8), m4s)); | ||
| 1529 | const svuint8_t q2bits_2 = svld1_u8(svptrue_b8(), q2); | ||
| 1530 | |||
| 1531 | q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_2, m3s)); | ||
| 1532 | q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; | ||
| 1533 | |||
| 1534 | sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 0)); | ||
| 1535 | |||
| 1536 | const svuint8_t q2bits_4 = svld1_u8(svptrue_b8(), q2+16); | ||
| 1537 | q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_4, m3s)); | ||
| 1538 | q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; | ||
| 1539 | |||
| 1540 | sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 1)); | ||
| 1541 | |||
| 1542 | |||
| 1543 | q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 2), m3s)); | ||
| 1544 | q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; | ||
| 1545 | |||
| 1546 | sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 2)); | ||
| 1547 | |||
| 1548 | q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 2), m3s)); | ||
| 1549 | q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; | ||
| 1550 | |||
| 1551 | sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 3)); | ||
| 1552 | |||
| 1553 | |||
| 1554 | const svint32_t scales_sv_3 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+12), m4s)); | ||
| 1555 | |||
| 1556 | q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 4), m3s)); | ||
| 1557 | q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; | ||
| 1558 | |||
| 1559 | sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 0)); | ||
| 1560 | |||
| 1561 | q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 4), m3s)); | ||
| 1562 | q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; | ||
| 1563 | |||
| 1564 | sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 1)); | ||
| 1565 | |||
| 1566 | |||
| 1567 | |||
| 1568 | q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 6), m3s)); | ||
| 1569 | q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; | ||
| 1570 | |||
| 1571 | sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 2)); | ||
| 1572 | |||
| 1573 | q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 6), m3s)); | ||
| 1574 | q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; | ||
| 1575 | |||
| 1576 | sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 3)); | ||
| 1577 | } | ||
| 1578 | acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, svcvt_f32_s32_x(svptrue_b32(), sumi1), d_broad); | ||
| 1579 | } | ||
| 1580 | *s = svaddv_f32(svptrue_b32(), acc_sum); | ||
| 1581 | break; | ||
| 1582 | |||
| 1583 | case 256: | ||
| 1584 | case 512: | ||
| 1585 | for (int i = 0; i < nb; ++i) { | ||
| 1586 | const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); | ||
| 1587 | svfloat32_t d_broad = svdup_n_f32((float32_t)d); | ||
| 1588 | const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); | ||
| 1589 | svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin); | ||
| 1590 | |||
| 1591 | const uint8_t * GGML_RESTRICT q2 = x[i].qs; | ||
| 1592 | const int8_t * GGML_RESTRICT q8_sv = y[i].qs; | ||
| 1593 | const uint8_t * GGML_RESTRICT sc = x[i].scales; | ||
| 1594 | |||
| 1595 | const svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc); sc += 8; | ||
| 1596 | const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, m4s)); | ||
| 1597 | const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, 4)); | ||
| 1598 | svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums); | ||
| 1599 | |||
| 1600 | const svuint32_t mins_and_scales_sve_1 = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc); | ||
| 1601 | const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, m4s)); | ||
| 1602 | const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, 4)); | ||
| 1603 | |||
| 1604 | svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums+8); | ||
| 1605 | |||
| 1606 | svfloat32_t temp = svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_2, q8sums_sv_2))); | ||
| 1607 | |||
| 1608 | acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, temp, dmin_broad); | ||
| 1609 | |||
| 1610 | svint32_t sumi1 = svdup_n_s32(0); | ||
| 1611 | |||
| 1612 | { | ||
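| | // each 32-byte load spans two 16-byte sub-blocks, so the 8 product lanes need | ||
| | // two scales: svsel places scale 2k in lanes 0-3 and scale 2k+1 in lanes 4-7 | ||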
| 1613 | const svuint8_t q2bits_1 = svld1_u8(svptrue_pat_b8(SV_VL32), q2); | ||
| 1614 | svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_1, m3s)); | ||
| 1615 | svint8_t q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; | ||
| 1616 | |||
| 1617 | svint32_t scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 0), svdup_lane_s32(scales_sv, 1)); | ||
| 1618 | sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1); | ||
| 1619 | |||
| 1620 | q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 2), m3s)); | ||
| 1621 | q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; | ||
| 1622 | |||
| 1623 | svint32_t scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 2), svdup_lane_s32(scales_sv, 3)); | ||
| 1624 | sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(svdup_n_s32(0), q2bytes_sv, q8bytes_sv), scale_2); | ||
| 1625 | |||
| 1626 | q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 4), m3s)); | ||
| 1627 | q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; | ||
| 1628 | |||
| 1629 | scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 4), svdup_lane_s32(scales_sv, 5)); | ||
| 1630 | sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1); | ||
| 1631 | |||
| 1632 | q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 6), m3s)); | ||
| 1633 | q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; | ||
| 1634 | |||
| 1635 | scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 6), svdup_lane_s32(scales_sv, 7)); | ||
| 1636 | sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2); | ||
| 1637 | |||
| 1638 | q2 += 32; | ||
| 1639 | |||
| 1640 | const svuint8_t q2bits_2 = svld1_u8(svptrue_pat_b8(SV_VL32), q2); | ||
| 1641 | q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_2, m3s)); | ||
| 1642 | q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; | ||
| 1643 | |||
| 1644 | scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 0), svdup_lane_s32(scales_sv_1, 1)); | ||
| 1645 | sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1); | ||
| 1646 | |||
| 1647 | q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 2), m3s)); | ||
| 1648 | q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; | ||
| 1649 | |||
| 1650 | scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 2), svdup_lane_s32(scales_sv_1, 3)); | ||
| 1651 | sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2); | ||
| 1652 | |||
| 1653 | q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 4), m3s)); | ||
| 1654 | q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; | ||
| 1655 | |||
| 1656 | scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 4), svdup_lane_s32(scales_sv_1, 5)); | ||
| 1657 | sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1); | ||
| 1658 | |||
| 1659 | q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 6), m3s)); | ||
| 1660 | q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; | ||
| 1661 | |||
| 1662 | scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 6), svdup_lane_s32(scales_sv_1, 7)); | ||
| 1663 | sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2); | ||
| 1664 | } | ||
| 1665 | acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), sumi1), d_broad); | ||
| 1666 | } | ||
| 1667 | *s = svaddv_f32(svptrue_pat_b32(SV_VL8), acc_sum); | ||
| 1668 | break; | ||
| 1669 | |||
| 1670 | default: | ||
| 1671 | assert(false && "Unsupported vector length"); | ||
| 1672 | break; | ||
| 1673 | } | ||
| 1674 | |||
| 1675 | #elif __ARM_NEON | ||
| 1676 | const uint8x16_t m3 = vdupq_n_u8(0x3); | ||
| 1677 | const uint8x16_t m4 = vdupq_n_u8(0xF); | ||
| 1678 | |||
| 1679 | const int32x4_t vzero = vdupq_n_s32(0); | ||
| 1680 | |||
| 1681 | ggml_int8x16x2_t q2bytes; | ||
| 1682 | uint8_t aux[16]; | ||
| 1683 | |||
| 1684 | float sum = 0; | ||
| 1685 | |||
| 1686 | for (int i = 0; i < nb; ++i) { | ||
| 1687 | const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); | ||
| 1688 | const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); | ||
| 1689 | |||
| 1690 | const uint8_t * GGML_RESTRICT q2 = x[i].qs; | ||
| 1691 | const int8_t * GGML_RESTRICT q8 = y[i].qs; | ||
| 1692 | const uint8_t * GGML_RESTRICT sc = x[i].scales; | ||
| 1693 | |||
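| | // one 16-byte load gives all 16 scale bytes: low nibbles are the scales (stored | ||
| | // to aux for scalar indexing), high nibbles are the mins, widened to 16 bits | ||
| | // for the dot against the bsums | ||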
| 1694 | const uint8x16_t mins_and_scales = vld1q_u8(sc); | ||
| 1695 | const uint8x16_t scales = vandq_u8(mins_and_scales, m4); | ||
| 1696 | vst1q_u8(aux, scales); | ||
| 1697 | |||
| 1698 | const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4); | ||
| 1699 | const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums); | ||
| 1700 | const ggml_int16x8x2_t mins16 = {{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}}; | ||
| 1701 | const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])), | ||
| 1702 | vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0]))); | ||
| 1703 | const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])), | ||
| 1704 | vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1]))); | ||
| 1705 | sum += dmin * vaddvq_s32(vaddq_s32(s0, s1)); | ||
| 1706 | |||
| 1707 | int isum = 0; | ||
| 1708 | int is = 0; | ||
| 1709 | |||
| 1710 | // We use this macro instead of a function call because for some reason | ||
| 1711 | // the code runs 2-3% slower, even if the function is declared inline | ||
| 1712 | #define MULTIPLY_ACCUM_WITH_SCALE(index)\ | ||
| 1713 | isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\ | ||
| 1714 | isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)]; | ||
| 1715 | |||
| 1716 | #define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\ | ||
| 1717 | q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\ | ||
| 1718 | q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\ | ||
| 1719 | q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\ | ||
| 1720 | MULTIPLY_ACCUM_WITH_SCALE((index)); | ||
| 1721 | |||
| 1722 | for (int j = 0; j < QK_K/128; ++j) { | ||
| 1723 | const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32; | ||
| 1724 | |||
| 1725 | ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; | ||
| 1726 | q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3)); | ||
| 1727 | q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3)); | ||
| 1728 | |||
| 1729 | MULTIPLY_ACCUM_WITH_SCALE(0); | ||
| 1730 | |||
| 1731 | SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2); | ||
| 1732 | SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4); | ||
| 1733 | SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6); | ||
| 1734 | |||
| 1735 | is += 8; | ||
| 1736 | } | ||
| 1737 | |||
| 1738 | sum += d * isum; | ||
| 1739 | } | ||
| 1740 | |||
| 1741 | *s = sum; | ||
| 1742 | |||
| 1743 | #else | ||
| 1744 | UNUSED(x); | ||
| 1745 | UNUSED(y); | ||
| 1746 | UNUSED(nb); | ||
| 1747 | ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); | ||
| 1748 | #endif | ||
| 1749 | } | ||
| 1750 | |||
| 1751 | void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { | ||
| 1752 | assert(n % QK_K == 0); | ||
| 1753 | assert(nrc == 1); | ||
| 1754 | UNUSED(nrc); | ||
| 1755 | UNUSED(bx); | ||
| 1756 | UNUSED(by); | ||
| 1757 | UNUSED(bs); | ||
| 1758 | |||
| 1759 | const uint32_t kmask1 = 0x03030303; | ||
| 1760 | const uint32_t kmask2 = 0x0f0f0f0f; | ||
| 1761 | |||
| 1762 | const block_q3_K * GGML_RESTRICT x = vx; | ||
| 1763 | const block_q8_K * GGML_RESTRICT y = vy; | ||
| 1764 | |||
| 1765 | const int nb = n / QK_K; | ||
| 1766 | |||
| 1767 | #if defined(__ARM_FEATURE_SVE) | ||
| 1768 | |||
| 1769 | uint32_t aux[3]; | ||
| 1770 | uint32_t utmp[4]; | ||
| 1771 | |||
| 1772 | const int8_t m32 = 32; | ||
| 1773 | const int vector_length = svcntb()*8; | ||
| 1774 | const svuint8_t m3b_sv = svdup_n_u8(0x3); | ||
| 1775 | const svint32_t vzero_sv = svdup_n_s32(0); | ||
| 1776 | |||
| 1777 | const svuint8_t m0_sv = svdup_n_u8(1); | ||
| 1778 | const svuint8_t m1_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 1); | ||
| 1779 | const svuint8_t m2_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 2); | ||
| 1780 | const svuint8_t m3_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 3); | ||
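| | // m0..m3 isolate the four hmask bits consumed per 128-quant pass; below, | ||
| | // (mK & ~qh) scaled to the value 4 is subtracted, so quants whose high bit is | ||
| | // clear decode as q - 4, giving the signed range -4..3 | ||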
| 1781 | |||
| 1782 | float sum = 0; | ||
| 1783 | |||
| 1784 | for (int i = 0; i < nb; ++i) { | ||
| 1785 | |||
| 1786 | const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); | ||
| 1787 | |||
| 1788 | const uint8_t * GGML_RESTRICT q3_sv = x[i].qs; | ||
| 1789 | const uint8_t * GGML_RESTRICT qh_sv = x[i].hmask; | ||
| 1790 | const int8_t * GGML_RESTRICT q8_sv = y[i].qs; | ||
| 1791 | |||
| 1792 | // Set up scales | ||
| 1793 | memcpy(aux, x[i].scales, 12); | ||
| 1794 | utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); | ||
| 1795 | utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); | ||
| 1796 | utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); | ||
| 1797 | utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); | ||
| 1798 | |||
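| | // the 12 scale bytes pack sixteen 6-bit scales; the kmask shuffles above unpack | ||
| | // them into 16 bytes, then 32 is subtracted to center them around zero | ||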
| 1799 | int8_t * scale = (int8_t *)utmp; | ||
| 1800 | |||
| 1801 | for (int j = 0; j < 16; ++j) scale[j] -= m32; | ||
| 1802 | |||
| 1803 | switch (vector_length) { | ||
| 1804 | case 128: | ||
| 1805 | { | ||
| 1806 | svuint8_t qhbits_sv_1 = svld1_u8(svptrue_b8(), qh_sv); | ||
| 1807 | svuint8_t qhbits_sv_2 = svld1_u8(svptrue_b8(), qh_sv+16); | ||
| 1808 | svuint8_t q3h_sv; | ||
| 1809 | |||
| 1810 | svint32_t sumi1_1 = svdup_n_s32(0); | ||
| 1811 | svint8_t q3bytes_sv; | ||
| 1812 | |||
| 1813 | for (int j = 0; j < QK_K/128; ++j) { | ||
| 1814 | |||
| 1815 | const svuint8_t q3bits_sv = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16; | ||
| 1816 | const svuint8_t q3bits_sv_1 = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16; | ||
| 1817 | svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; | ||
| 1818 | svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; | ||
| 1819 | |||
| 1820 | q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_1), 2); | ||
| 1821 | q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv)); | ||
| 1822 | |||
| 1823 | sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0])); | ||
| 1824 | |||
| 1825 | q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_2), 2); | ||
| 1826 | q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv_1, m3b_sv)), svreinterpret_s8_u8(q3h_sv)); | ||
| 1827 | |||
| 1828 | sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1])); | ||
| 1829 | |||
| 1830 | q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; | ||
| 1831 | q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; | ||
| 1832 | |||
| 1833 | q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_1), 1); | ||
| 1834 | q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); | ||
| 1835 | |||
| 1836 | sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2])); | ||
| 1837 | |||
| 1838 | q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_2), 1); | ||
| 1839 | q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); | ||
| 1840 | |||
| 1841 | sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3])); | ||
| 1842 | |||
| 1843 | |||
| 1844 | scale += 4; | ||
| 1845 | q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; | ||
| 1846 | q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; | ||
| 1847 | |||
| 1848 | q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_1); | ||
| 1849 | q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); | ||
| 1850 | |||
| 1851 | sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0])); | ||
| 1852 | |||
| 1853 | q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_2); | ||
| 1854 | q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); | ||
| 1855 | |||
| 1856 | sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1])); | ||
| 1857 | |||
| 1858 | |||
| 1859 | q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; | ||
| 1860 | q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; | ||
| 1861 | |||
| 1862 | q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_1), 1); | ||
| 1863 | q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); | ||
| 1864 | |||
| 1865 | sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2])); | ||
| 1866 | |||
| 1867 | q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_2), 1); | ||
| 1868 | q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); | ||
| 1869 | |||
| 1870 | sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3])); | ||
| 1871 | |||
| 1872 | if (j == 0) { | ||
| 1873 | qhbits_sv_1 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_1, 4); | ||
| 1874 | qhbits_sv_2 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_2, 4); | ||
| 1875 | } | ||
| 1876 | |||
| 1877 | scale += 4; | ||
| 1878 | } | ||
| 1879 | |||
| 1880 | sum += d * (svaddv_s32(svptrue_b32(), sumi1_1)); | ||
| 1881 | } break; | ||
| 1882 | case 256: | ||
| 1883 | case 512: | ||
| 1884 | { | ||
| 1885 | svuint8_t qhbits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), qh_sv); | ||
| 1886 | svuint8_t q3h_sv; | ||
| 1887 | |||
| 1888 | svint32_t sumi1_1 = svdup_n_s32(0); | ||
| 1889 | svint8_t q3bytes_sv; | ||
| 1890 | |||
| 1891 | for (int j = 0; j < QK_K/128; ++j) { | ||
| 1892 | |||
| 1893 | const svuint8_t q3bits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), q3_sv); q3_sv += 32; | ||
| 1894 | svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; | ||
| 1895 | svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; | ||
| 1896 | |||
| 1897 | q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m0_sv, qhbits_sv), 2); | ||
| 1898 | q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv)); | ||
| 1899 | |||
| 1900 | |||
| 1901 | svint32_t scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1])); | ||
| 1902 | sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1); | ||
| 1903 | |||
| 1904 | q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m1_sv, qhbits_sv), 1); | ||
| 1905 | q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); | ||
| 1906 | |||
| 1907 | scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3])); | ||
| 1908 | sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1); | ||
| 1909 | |||
| 1910 | scale += 4; | ||
| 1911 | q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; | ||
| 1912 | q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; | ||
| 1913 | |||
| 1914 | q3h_sv = svbic_u8_x(svptrue_pat_b8(SV_VL32), m2_sv, qhbits_sv); | ||
| 1915 | q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); | ||
| 1916 | |||
| 1917 | scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1])); | ||
| 1918 | sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1); | ||
| 1919 | |||
| 1920 | q3h_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m3_sv, qhbits_sv), 1); | ||
| 1921 | q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); | ||
| 1922 | |||
| 1923 | scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3])); | ||
| 1924 | sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1); | ||
| 1925 | |||
| 1926 | if (j == 0) { | ||
| 1927 | qhbits_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), qhbits_sv, 4); | ||
| 1928 | } | ||
| 1929 | |||
| 1930 | scale += 4; | ||
| 1931 | } | ||
| 1932 | |||
| 1933 | sum += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), sumi1_1)); | ||
| 1934 | } break; | ||
| 1935 | default: | ||
| 1936 | assert(false && "Unsupported vector length"); | ||
| 1937 | break; | ||
| 1938 | } | ||
| 1939 | } | ||
| 1940 | *s = sum; | ||
| 1941 | |||
| 1942 | #elif __ARM_NEON | ||
| 1943 | |||
| 1944 | uint32_t aux[3]; | ||
| 1945 | uint32_t utmp[4]; | ||
| 1946 | |||
| 1947 | const uint8x16_t m3b = vdupq_n_u8(0x3); | ||
| 1948 | const int32x4_t vzero = vdupq_n_s32(0); | ||
| 1949 | |||
| 1950 | const uint8x16_t m0 = vdupq_n_u8(1); | ||
| 1951 | const uint8x16_t m1 = vshlq_n_u8(m0, 1); | ||
| 1952 | const uint8x16_t m2 = vshlq_n_u8(m0, 2); | ||
| 1953 | const uint8x16_t m3 = vshlq_n_u8(m0, 3); | ||
| 1954 | const int8_t m32 = 32; | ||
| 1955 | |||
| 1956 | ggml_int8x16x4_t q3bytes; | ||
| 1957 | |||
| 1958 | float sum = 0; | ||
| 1959 | |||
| 1960 | for (int i = 0; i < nb; ++i) { | ||
| 1961 | |||
| 1962 | const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); | ||
| 1963 | |||
| 1964 | const uint8_t * GGML_RESTRICT q3 = x[i].qs; | ||
| 1965 | const uint8_t * GGML_RESTRICT qh = x[i].hmask; | ||
| 1966 | const int8_t * GGML_RESTRICT q8 = y[i].qs; | ||
| 1967 | |||
| 1968 | ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); | ||
| 1969 | |||
| 1970 | ggml_uint8x16x4_t q3h; | ||
| 1971 | |||
| 1972 | int32_t isum = 0; | ||
| 1973 | |||
| 1974 | // Set up scales | ||
| 1975 | memcpy(aux, x[i].scales, 12); | ||
| 1976 | utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); | ||
| 1977 | utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); | ||
| 1978 | utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); | ||
| 1979 | utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); | ||
| 1980 | |||
| 1981 | int8_t * scale = (int8_t *)utmp; | ||
| 1982 | for (int j = 0; j < 16; ++j) scale[j] -= m32; | ||
| 1983 | |||
| 1984 | for (int j = 0; j < QK_K/128; ++j) { | ||
| 1985 | |||
| 1986 | const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32; | ||
| 1987 | const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64; | ||
| 1988 | const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64; | ||
| 1989 | |||
| 1990 | q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2); | ||
| 1991 | q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2); | ||
| 1992 | q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1); | ||
| 1993 | q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1); | ||
| 1994 | |||
| 1995 | q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0])); | ||
| 1996 | q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1])); | ||
| 1997 | q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2])); | ||
| 1998 | q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3])); | ||
| 1999 | |||
| 2000 | isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0]; | ||
| 2001 | isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1]; | ||
| 2002 | isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2]; | ||
| 2003 | isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3]; | ||
| 2004 | |||
| 2005 | scale += 4; | ||
| 2006 | |||
| 2007 | q3h.val[0] = vbicq_u8(m2, qhbits.val[0]); | ||
| 2008 | q3h.val[1] = vbicq_u8(m2, qhbits.val[1]); | ||
| 2009 | q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1); | ||
| 2010 | q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1); | ||
| 2011 | |||
| 2012 | q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0])); | ||
| 2013 | q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1])); | ||
| 2014 | q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2])); | ||
| 2015 | q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3])); | ||
| 2016 | |||
| 2017 | isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0]; | ||
| 2018 | isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1]; | ||
| 2019 | isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2]; | ||
| 2020 | isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3]; | ||
| 2021 | |||
| 2022 | scale += 4; | ||
| 2023 | |||
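| | // after the first 128 quants the remaining hmask bits sit in the upper nibble | ||
| | // of each byte; shift them down for the second pass | ||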
| 2024 | if (j == 0) { | ||
| 2025 | qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4); | ||
| 2026 | qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4); | ||
| 2027 | } | ||
| 2028 | |||
| 2029 | } | ||
| 2030 | sum += d * isum; | ||
| 2031 | |||
| 2032 | } | ||
| 2033 | |||
| 2034 | *s = sum; | ||
| 2035 | |||
| 2036 | #else | ||
| 2037 | UNUSED(kmask1); | ||
| 2038 | UNUSED(kmask2); | ||
| 2039 | UNUSED(x); | ||
| 2040 | UNUSED(y); | ||
| 2041 | UNUSED(nb); | ||
| 2042 | ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); | ||
| 2043 | #endif | ||
| 2044 | |||
| 2045 | } | ||
| 2046 | |||
| 2047 | #ifdef __ARM_FEATURE_SVE | ||
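| | // q4_K packs eight 6-bit scales and eight 6-bit mins into 12 bytes; this helper | ||
| | // unpacks them into 32-bit lanes for the mmla kernel, mirroring the scalar | ||
| | // kmask1/kmask2/kmask3 unpack used by the nrc == 1 path | ||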
| 2048 | static inline svuint32_t ggml_decode_q4scales_and_mins_for_mmla(const uint32_t * vx_scales) { | ||
| 2049 | const svbool_t pg_all = svptrue_pat_b32(SV_VL4); | ||
| 2050 | const svbool_t pg_false = svpfalse_b(); // 0x0000 | ||
| 2051 | const svbool_t pg_lo_8 = svwhilelt_b8_s32(0, 8); // 0x00ff | ||
| 2052 | const svbool_t pg_odd = svzip1_b32(pg_false, pg_lo_8); | ||
| 2053 | |||
| 2054 | svuint32_t vutmp_hi, vutmp_lo; | ||
| 2055 | svuint32_t vx01 = svld1_u32(pg_lo_8, vx_scales); | ||
| 2056 | vutmp_hi = svzip1_u32(vx01, vx01); | ||
| 2057 | vutmp_hi = svlsr_n_u32_m(pg_odd, vutmp_hi, 2); | ||
| 2058 | vutmp_hi = svreinterpret_u32_u64(svand_n_u64_x(pg_all, svreinterpret_u64_u32(vutmp_hi), UINT64_C(0x303030303f3f3f3f))); | ||
| 2059 | const svuint32_t vx2 = svdup_u32(vx_scales[2]); | ||
| 2060 | vutmp_lo = svlsr_u32_x(pg_all, vx2, svreinterpret_u32_s32(svindex_s32(-2, 2))); | ||
| 2061 | vutmp_lo = svand_n_u32_z(pg_odd, vutmp_lo, UINT32_C(0x0f0f0f0f)); | ||
| 2062 | svuint32_t vutmp = svorr_u32_z(pg_all, vutmp_hi, vutmp_lo); | ||
| 2063 | return vutmp; | ||
| 2064 | } | ||
| 2065 | #endif | ||
| 2066 | |||
| 2067 | void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { | ||
| 2068 | assert(n % QK_K == 0); | ||
| 2069 | #ifdef __ARM_FEATURE_MATMUL_INT8 | ||
| 2070 | assert((nrc == 2) || (nrc == 1)); | ||
| 2071 | #else | ||
| 2072 | assert(nrc == 1); | ||
| 2073 | #endif | ||
| 2074 | UNUSED(nrc); | ||
| 2075 | UNUSED(bx); | ||
| 2076 | UNUSED(by); | ||
| 2077 | UNUSED(bs); | ||
| 2078 | |||
| 2079 | const block_q4_K * GGML_RESTRICT x = vx; | ||
| 2080 | const block_q8_K * GGML_RESTRICT y = vy; | ||
| 2081 | |||
| 2082 | const int nb = n / QK_K; | ||
| 2083 | |||
| 2084 | static const uint32_t kmask1 = 0x3f3f3f3f; | ||
| 2085 | static const uint32_t kmask2 = 0x0f0f0f0f; | ||
| 2086 | static const uint32_t kmask3 = 0x03030303; | ||
| 2087 | |||
| 2088 | uint32_t utmp[4]; | ||
| 2089 | #ifdef __ARM_FEATURE_SVE | ||
| 2090 | const int vector_length = ggml_cpu_get_sve_cnt()*8; | ||
| 2091 | #endif | ||
| 2092 | |||
| 2093 | #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) | ||
| 2094 | if (nrc == 2) { | ||
| 2095 | svbool_t pg32_2 = svptrue_pat_b32(SV_VL2); | ||
| 2096 | |||
| 2097 | const block_q4_K * GGML_RESTRICT vx0 = vx; | ||
| 2098 | const block_q8_K * GGML_RESTRICT vy0 = vy; | ||
| 2099 | const block_q4_K * GGML_RESTRICT vx1 = (const block_q4_K *) ((const uint8_t*)vx + bx); | ||
| 2100 | const block_q8_K * GGML_RESTRICT vy1 = (const block_q8_K *) ((const uint8_t*)vy + by); | ||
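| | // nrc == 2: process two rows of x against two rows of y at once; the 64-bit | ||
| | // zips below interleave the rows so each svmmla accumulates a 2x2 tile of | ||
| | // dot products per call | ||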
| 2101 | |||
| 2102 | union { | ||
| 2103 | uint32_t u32[8]; | ||
| 2104 | uint64_t u64[4]; | ||
| 2105 | } new_utmp; | ||
| 2106 | |||
| 2107 | svfloat32_t sumf1 = svdup_n_f32(0); | ||
| 2108 | |||
| 2109 | switch (vector_length) { | ||
| 2110 | case 128: | ||
| 2111 | { | ||
| 2112 | svbool_t pg_false = svpfalse_b(); | ||
| 2113 | svbool_t pg_lo_8 = svwhilelt_b8_s32(0, 8); | ||
| 2114 | svbool_t vmins_mask1 = svzip1_b32(pg_lo_8, pg_false); | ||
| 2115 | svbool_t vmins_mask2 = svzip1_b32(pg_false, pg_lo_8); | ||
| 2116 | svbool_t pg128_all = svptrue_pat_b8(SV_VL16); | ||
| 2117 | for (int i = 0; i < nb; ++i) { | ||
| 2118 | svfloat32_t vy_d = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)); | ||
| 2119 | svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d))); | ||
| 2120 | svfloat32_t svsuper_block_scales = svmul_f32_x(pg128_all, vy_d, vx_d); | ||
| 2121 | svfloat32_t vx_dmins = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].dmin)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].dmin))); | ||
| 2122 | svfloat32_t vy_dmins = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)); | ||
| 2123 | svfloat32_t svdmins = svmul_n_f32_x(pg128_all, svmul_f32_x(pg128_all, vy_dmins, vx_dmins), -1); | ||
| 2124 | const uint8_t * GGML_RESTRICT q4_0 = vx0[i].qs; | ||
| 2125 | const int8_t * GGML_RESTRICT q8_0 = vy0[i].qs; | ||
| 2126 | const uint8_t * GGML_RESTRICT q4_1 = vx1[i].qs; | ||
| 2127 | const int8_t * GGML_RESTRICT q8_1 = vy1[i].qs; | ||
| 2128 | svint16_t lo = svld1_s16(pg128_all, vy0[i].bsums + 0); | ||
| 2129 | svint16_t hi = svld1_s16(pg128_all, vy0[i].bsums + 8); | ||
| 2130 | svint16_t sum_tmp1 = svuzp1_s16(lo, hi); | ||
| 2131 | svint16_t sum_tmp2 = svuzp2_s16(lo, hi); | ||
| 2132 | svint16_t svq8sums_0 = svadd_s16_x(pg128_all, sum_tmp1, sum_tmp2); | ||
| 2133 | lo = svld1_s16(pg128_all, vy1[i].bsums + 0); | ||
| 2134 | hi = svld1_s16(pg128_all, vy1[i].bsums + 8); | ||
| 2135 | sum_tmp1 = svuzp1(lo, hi); | ||
| 2136 | sum_tmp2 = svuzp2(lo, hi); | ||
| 2137 | svint16_t svq8sums_1 = svadd_s16_x(pg128_all, sum_tmp1, sum_tmp2); | ||
| 2138 | svuint32_t decoded_scales0 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx0[i].scales); | ||
| 2139 | svuint32_t decoded_scales1 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx1[i].scales); | ||
| 2140 | svuint32x2_t decoded_scales = svcreate2_u32(decoded_scales0, decoded_scales1); | ||
| 2141 | svst2_u32(pg128_all, new_utmp.u32, decoded_scales); | ||
| 2142 | svint16_t svmins8_0 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u32(svuzp1_u32(svld1_u32(vmins_mask1, new_utmp.u32+4), svdup_n_u32(0))))); | ||
| 2143 | svint16_t svmins8_1 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u32(svuzp2_u32(svld1_u32(vmins_mask2, new_utmp.u32+4), svdup_n_u32(0))))); | ||
| 2144 | svint32_t svsumfs_tmp1 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_0, svmins8_0)); | ||
| 2145 | svint32_t svsumfs_tmp2 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_0, svmins8_1)); | ||
| 2146 | svint32_t svsumfs_tmp3 = svtrn1_s32(svsumfs_tmp1, svsumfs_tmp2); | ||
| 2147 | svint32_t svsumfs_tmp4 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_1, svmins8_0)); | ||
| 2148 | svint32_t svsumfs_tmp5 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_1, svmins8_1)); | ||
| 2149 | svint32_t svsumfs_tmp6 = svtrn1_s32(svsumfs_tmp4, svsumfs_tmp5); | ||
| 2150 | svint32_t svsumfs_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(svsumfs_tmp3), svreinterpret_s64_s32(svsumfs_tmp6))); | ||
| 2151 | svint32_t svsumfs_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(svsumfs_tmp3), svreinterpret_s64_s32(svsumfs_tmp6))); | ||
| 2152 | svint32_t svsumfs_tmp = svadd_s32_x(pg128_all, svsumfs_tmp7, svsumfs_tmp8); | ||
| 2153 | svint32_t svscales, sumi1, sumi2; | ||
| 2154 | svint32_t acc_sumif1 = svdup_n_s32(0); | ||
| 2155 | svint32_t acc_sumif2 = svdup_n_s32(0); | ||
| 2156 | svint8_t q4bytes_0_l, q4bytes_0_h, q4bytes_1_l, q4bytes_1_h, l0, l1, l2, l3, | ||
| 2157 | q8bytes_0_h, q8bytes_0_l, q8bytes_1_h, q8bytes_1_l, r0, r1, r2, r3; | ||
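| | // keep this loop rolled; unrolling is presumably a net loss here (register pressure) | ||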
| 2158 | #pragma GCC unroll 1 | ||
| 2159 | for (int j = 0; j < QK_K/64; ++j) { | ||
| 2160 | q4bytes_0_l = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0), 0xf)); | ||
| 2161 | q4bytes_1_l = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1), 0xf)); | ||
| 2162 | q4bytes_0_h = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0+16), 0xf)); | ||
| 2163 | q4bytes_1_h = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1+16), 0xf)); | ||
| 2164 | l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l))); | ||
| 2165 | l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l))); | ||
| 2166 | l2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h))); | ||
| 2167 | l3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h))); | ||
| 2168 | q8bytes_0_h = svld1_s8(pg128_all, q8_0); | ||
| 2169 | q8bytes_1_h = svld1_s8(pg128_all, q8_1); | ||
| 2170 | q8bytes_0_l = svld1_s8(pg128_all, q8_0+16); | ||
| 2171 | q8bytes_1_l = svld1_s8(pg128_all, q8_1+16); | ||
| 2172 | r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h))); | ||
| 2173 | r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h))); | ||
| 2174 | r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l))); | ||
| 2175 | r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l))); | ||
| 2176 | sumi1 = svmmla_s32(svmmla_s32(svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), r2, l2), r3, l3); | ||
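| | // new_utmp holds the two rows' decoded scales interleaved; the 64-bit shift | ||
| | // pair below isolates this sub-block's scale byte for row 0 and row 1 in | ||
| | // alternating 32-bit lanes, matching the 2x2 mmla tile layout | ||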
| 2177 | svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg128_all, svlsl_n_u32_x(pg128_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-1)), 24)); | ||
| 2178 | acc_sumif1 = svmla_s32_x(pg128_all, acc_sumif1, svscales, sumi1); | ||
| 2179 | |||
| 2180 | q4bytes_0_l = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0), 4)); | ||
| 2181 | q4bytes_1_l = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1), 4)); | ||
| 2182 | q4bytes_0_h = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0+16), 4)); | ||
| 2183 | q4bytes_1_h = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1+16), 4)); | ||
| 2184 | l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l))); | ||
| 2185 | l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l))); | ||
| 2186 | l2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h))); | ||
| 2187 | l3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h))); | ||
| 2188 | q8bytes_0_h = svld1_s8(pg128_all, q8_0+32); | ||
| 2189 | q8bytes_1_h = svld1_s8(pg128_all, q8_1+32); | ||
| 2190 | q8bytes_0_l = svld1_s8(pg128_all, q8_0+48); | ||
| 2191 | q8bytes_1_l = svld1_s8(pg128_all, q8_1+48); | ||
| 2192 | r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h))); | ||
| 2193 | r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h))); | ||
| 2194 | r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l))); | ||
| 2195 | r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l))); | ||
| 2196 | sumi2 = svmmla_s32(svmmla_s32(svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), r2, l2), r3, l3); | ||
| 2197 | svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg128_all, svlsl_n_u32_x(pg128_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-2)), 24)); | ||
| 2198 | acc_sumif2 = svmla_s32_x(pg128_all, acc_sumif2, svscales, sumi2); | ||
| 2199 | q4_0 += 32; q4_1 += 32; q8_0 += 64; q8_1 += 64; | ||
| 2200 | } | ||
| 2201 | sumf1 = svmla_f32_x(pg128_all, | ||
| 2202 | svmla_f32_x(pg128_all, | ||
| 2203 | sumf1, | ||
| 2204 | svcvt_f32_x(pg128_all, | ||
| 2205 | svadd_s32_x(pg128_all, acc_sumif1, acc_sumif2)), | ||
| 2206 | svsuper_block_scales), | ||
| 2207 | svdmins, | ||
| 2208 | svcvt_f32_s32_x(pg128_all, svsumfs_tmp)); | ||
| 2209 | } // end of for nb | ||
| 2210 | } // end of case 128 | ||
| 2211 | break; | ||
| 2212 | case 256: | ||
| 2213 | case 512: | ||
| 2214 | { | ||
| 2215 | const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4); | ||
| 2216 | const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16); | ||
| 2217 | const svbool_t pg256_all = svptrue_pat_b8(SV_ALL); | ||
| 2218 | for (int i = 0; i < nb; ++i) { | ||
| 2219 | const uint8_t * GGML_RESTRICT q4_0 = vx0[i].qs; | ||
| 2220 | const int8_t * GGML_RESTRICT q8_0 = vy0[i].qs; | ||
| 2221 | const uint8_t * GGML_RESTRICT q4_1 = vx1[i].qs; | ||
| 2222 | const int8_t * GGML_RESTRICT q8_1 = vy1[i].qs; | ||
| 2223 | svint32_t svscales, sumi1, sumi2; | ||
| 2224 | svint32_t acc_sumif1 = svdup_n_s32(0); | ||
| 2225 | svint32_t acc_sumif2 = svdup_n_s32(0); | ||
| 2226 | svint8_t l0, l1, l2, l3, r0, r1, r2, r3; | ||
| 2227 | svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d))); | ||
| 2228 | svfloat64_t vy_d_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d))); | ||
| 2229 | svfloat32_t vy_d = svreinterpret_f32_f64(svuzp1_f64(vy_d_tmp, vy_d_tmp)); | ||
| 2230 | svfloat32_t svsuper_block_scales = svmul_f32_z(pg32_4, vy_d, vx_d); | ||
| 2231 | svfloat32_t vx_dmins = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].dmin)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].dmin))); | ||
| 2232 | svfloat64_t vy_dmins_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d))); | ||
| 2233 | svfloat32_t vy_dmins = svreinterpret_f32_f64(svuzp1_f64(vy_dmins_tmp, vy_dmins_tmp)); | ||
| 2234 | svfloat32_t svdmins = svmul_n_f32_x(pg32_4, svmul_f32_x(pg32_4, vx_dmins, vy_dmins), -1); | ||
| 2235 | svint16_t rc1 = svuzp1_s16(svld1_s16(pg256_all, vy0[i].bsums), svld1_s16(pg256_all, vy1[i].bsums)); | ||
| 2236 | svint16_t rc2 = svuzp2_s16(svld1_s16(pg256_all, vy0[i].bsums), svld1_s16(pg256_all, vy1[i].bsums)); | ||
| 2237 | svint16_t svq8sums = svadd_s16_x(pg256_all, rc1, rc2); | ||
| 2238 | svuint32_t decoded_scales0 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx0[i].scales); | ||
| 2239 | svuint32_t decoded_scales1 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx1[i].scales); | ||
| 2240 | svuint32x2_t decoded_scales = svcreate2_u32(decoded_scales0, decoded_scales1); | ||
| 2241 | svst2_u32(pg8_16, new_utmp.u32, decoded_scales); | ||
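| | // svst2 interleaves the two rows' decoded values 32-bit-wise into new_utmp, | ||
| | // leaving the packed scales in u64[0..1] and the packed mins in u64[2..3] | ||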
| 2242 | svint16_t new_svq8sums_0 = svreinterpret_s16_u64(svtrn1_u64(svreinterpret_u64_s16(svq8sums), svreinterpret_u64_s16(svq8sums))); | ||
| 2243 | svint16_t new_svq8sums_1 = svreinterpret_s16_u64(svtrn2_u64(svreinterpret_u64_s16(svq8sums), svreinterpret_u64_s16(svq8sums))); | ||
| 2244 | svuint64_t new_mins_0 = svdup_u64(new_utmp.u64[2]); | ||
| 2245 | svuint64_t new_mins_1 = svdup_u64(new_utmp.u64[3]); | ||
| 2246 | svint16_t new_svmins8_0 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u64(new_mins_0))); | ||
| 2247 | svint16_t new_svmins8_1 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u64(new_mins_1))); | ||
| 2248 | svint64_t dot_prod_0 = svdot_s64(svdup_s64(0), new_svmins8_0, new_svq8sums_0); | ||
| 2249 | svint64_t dot_prod_1 = svdot_s64(dot_prod_0, new_svmins8_1, new_svq8sums_1); | ||
| 2250 | svfloat32_t converted_dot_prod_1 = svcvt_f32_s64_x(pg256_all, dot_prod_1); | ||
| 2251 | svfloat32_t svsumfs_tmp = svuzp1_f32(converted_dot_prod_1, converted_dot_prod_1); | ||
| 2252 | |||
| 2253 | #pragma GCC unroll 1 | ||
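| | // 256-bit route: a single load covers the whole 32-byte q4 row of the | ||
| | // sub-block, so low and high nibbles are consumed in the same iteration | ||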
| 2254 | for (int j = 0; j < QK_K/64; ++j) { | ||
| 2255 | svuint8_t q4bytes_0 = svand_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_0), 0xf); | ||
| 2256 | svuint8_t q4bytes_1 = svand_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_1), 0xf); | ||
| 2257 | svuint8_t q4bytes_2 = svlsr_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_0), 4); | ||
| 2258 | svuint8_t q4bytes_3 = svlsr_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_1), 4); | ||
| 2259 | l0 = svreinterpret_s8_u64(svzip1_u64(svreinterpret_u64_u8(q4bytes_0), svreinterpret_u64_u8(q4bytes_1))); | ||
| 2260 | l1 = svreinterpret_s8_u64(svzip2_u64(svreinterpret_u64_u8(q4bytes_0), svreinterpret_u64_u8(q4bytes_1))); | ||
| 2261 | l2 = svreinterpret_s8_u64(svzip1_u64(svreinterpret_u64_u8(q4bytes_2), svreinterpret_u64_u8(q4bytes_3))); | ||
| 2262 | l3 = svreinterpret_s8_u64(svzip2_u64(svreinterpret_u64_u8(q4bytes_2), svreinterpret_u64_u8(q4bytes_3))); | ||
| 2263 | svint8_t q8bytes_0 = svld1_s8(pg256_all, q8_0); | ||
| 2264 | svint8_t q8bytes_1 = svld1_s8(pg256_all, q8_1); | ||
| 2265 | svint8_t q8bytes_2 = svld1_s8(pg256_all, q8_0+32); | ||
| 2266 | svint8_t q8bytes_3 = svld1_s8(pg256_all, q8_1+32); | ||
| 2267 | r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1))); | ||
| 2268 | r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1))); | ||
| 2269 | r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_2), svreinterpret_s64_s8(q8bytes_3))); | ||
| 2270 | r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_2), svreinterpret_s64_s8(q8bytes_3))); | ||
| 2271 | sumi1 = svmmla(svmmla(svdup_n_s32(0), r0, l0), r1, l1); | ||
| 2272 | svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg256_all, svlsl_n_u32_x(pg256_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-1)), 24)); | ||
| 2273 | acc_sumif1 = svmla_s32_x(pg256_all, acc_sumif1, svscales, sumi1); | ||
| 2274 | sumi2 = svmmla(svmmla(svdup_n_s32(0), r2, l2), r3, l3); | ||
| 2275 | svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg256_all, svlsl_n_u32_x(pg256_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-2)), 24)); | ||
| 2276 | acc_sumif2 = svmla_s32_x(pg256_all, acc_sumif2, svscales, sumi2); | ||
| 2277 | q4_0 += 32; q4_1 += 32; q8_0 += 64; q8_1 += 64; | ||
| 2278 | } | ||
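| | // each 128-bit half of the accumulator holds a partial copy of the same 2x2 | ||
| | // tile (built from different byte groups); rotate by four lanes and add to merge | ||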
| 2279 | svint32_t acc_sumif = svadd_s32_x(pg256_all, acc_sumif1, acc_sumif2); | ||
| 2280 | svint32_t swap_acc_sumif = svext_s32(acc_sumif, acc_sumif, 4); | ||
| 2281 | acc_sumif = svadd_s32_x(pg32_4, acc_sumif, swap_acc_sumif); | ||
| 2282 | sumf1 = svmla_f32_x(pg32_4, | ||
| 2283 | svmla_f32_x(pg32_4, | ||
| 2284 | sumf1, | ||
| 2285 | svcvt_f32_x(pg32_4, acc_sumif), | ||
| 2286 | svsuper_block_scales), | ||
| 2287 | svdmins, | ||
| 2288 | svsumfs_tmp); | ||
| 2289 | } // end of for nb | ||
| 2290 | } // end of case 256-512 | ||
| 2291 | break; | ||
| 2292 | default: | ||
| 2293 | assert(false && "Unsupported vector length"); | ||
| 2294 | break; | ||
| 2295 | } | ||
| 2296 | |||
| 2297 | svst1_f32(pg32_2, s, sumf1); | ||
| 2298 | svst1_f32(pg32_2, s + bs, svreinterpret_f32_u8(svext_u8(svreinterpret_u8_f32(sumf1), svdup_n_u8(0), 8))); | ||
| 2299 | |||
| 2300 | return; | ||
| 2301 | } | ||
| 2302 | #elif defined(__ARM_FEATURE_MATMUL_INT8) | ||
| 2303 | if (nrc == 2) { | ||
| 2304 | const block_q4_K * GGML_RESTRICT x0 = x; | ||
| 2305 | const block_q4_K * GGML_RESTRICT x1 = (const block_q4_K *) ((const uint8_t *)vx + bx); | ||
| 2306 | const block_q8_K * GGML_RESTRICT y0 = y; | ||
| 2307 | const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by); | ||
| 2308 | |||
| 2309 | const uint8x16_t m4b = vdupq_n_u8(0x0f); | ||
| 2310 | |||
| 2311 | float32x4_t vfsum = vdupq_n_f32(0.0f); | ||
| 2312 | |||
| 2313 | for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) { | ||
| 2314 | const uint8_t * GGML_RESTRICT qx0 = x0->qs; | ||
| 2315 | const uint8_t * GGML_RESTRICT qx1 = x1->qs; | ||
| 2316 | const int8_t * GGML_RESTRICT qy0 = y0->qs; | ||
| 2317 | const int8_t * GGML_RESTRICT qy1 = y1->qs; | ||
| 2318 | |||
| 2319 | // decode scales and mins | ||
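| | // q4_K packs 8 6-bit scales and 8 6-bit mins into 12 bytes; the kmask | ||
| | // constants peel out four values at a time. Roughly, in scalar form: | ||
| | //   scale[j<4]  = sc[j] & 63;   min[j<4] = sc[j+4] & 63; | ||
| | //   scale[j>=4] = (sc[j+4] & 0xf) | ((sc[j-4] >> 6) << 4);  (mins analogous) | ||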
| 2320 | int8_t x0_scales[8], x1_scales[8]; | ||
| 2321 | int16x8_t x0_mins, x1_mins; | ||
| 2322 | { | ||
| 2323 | uint32_t scales_mins[3]; | ||
| 2324 | memcpy(scales_mins, x0->scales, 12); | ||
| 2325 | const uint32_t mins_0_3 = scales_mins[1] & kmask1; | ||
| 2326 | const uint32_t mins_4_7 = ((scales_mins[2] >> 4) & kmask2) | (((scales_mins[1] >> 6) & kmask3) << 4); | ||
| 2327 | const uint32x2_t mins = {mins_0_3, mins_4_7}; | ||
| 2328 | x0_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins))); | ||
| 2329 | uint32_t scales[2]; | ||
| 2330 | scales[0] = scales_mins[0] & kmask1; // scales 0~3 | ||
| 2331 | scales[1] = (scales_mins[2] & kmask2) | (((scales_mins[0] >> 6) & kmask3) << 4); // scales 4~7 | ||
| 2332 | memcpy(x0_scales, scales, 8); | ||
| 2333 | } | ||
| 2334 | { | ||
| 2335 | uint32_t scales_mins[3]; | ||
| 2336 | memcpy(scales_mins, x1->scales, 12); | ||
| 2337 | const uint32_t mins_0_3 = scales_mins[1] & kmask1; | ||
| 2338 | const uint32_t mins_4_7 = ((scales_mins[2] >> 4) & kmask2) | (((scales_mins[1] >> 6) & kmask3) << 4); | ||
| 2339 | const uint32x2_t mins = {mins_0_3, mins_4_7}; | ||
| 2340 | x1_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins))); | ||
| 2341 | uint32_t scales[2]; | ||
| 2342 | scales[0] = scales_mins[0] & kmask1; // scales 0~3 | ||
| 2343 | scales[1] = (scales_mins[2] & kmask2) | (((scales_mins[0] >> 6) & kmask3) << 4); // scales 4~7 | ||
| 2344 | memcpy(x1_scales, scales, 8); | ||
| 2345 | } | ||
| 2346 | |||
| 2347 | int32x4_t visum = {0}; | ||
| 2348 | |||
| 2349 | // process 64 data points per iteration, 256 data points in total | ||
| 2350 | for (int j = 0; j < QK_K / 64; ++j, qx0 += 32, qx1 += 32, qy0 += 64, qy1 += 64) { | ||
| 2351 | const int8x16x4_t vy0 = vld1q_s8_x4(qy0); | ||
| 2352 | const int8x16x4_t vy1 = vld1q_s8_x4(qy1); | ||
| 2353 | |||
| 2354 | int8x16_t vx0[4], vx1[4]; | ||
| 2355 | { | ||
| 2356 | const uint8x16x2_t vv = vld1q_u8_x2(qx0); | ||
| 2357 | vx0[0] = vreinterpretq_s8_u8(vandq_u8(vv.val[0], m4b)); | ||
| 2358 | vx0[1] = vreinterpretq_s8_u8(vandq_u8(vv.val[1], m4b)); | ||
| 2359 | vx0[2] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[0], 4)); | ||
| 2360 | vx0[3] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[1], 4)); | ||
| 2361 | } | ||
| 2362 | { | ||
| 2363 | const uint8x16x2_t vv = vld1q_u8_x2(qx1); | ||
| 2364 | vx1[0] = vreinterpretq_s8_u8(vandq_u8(vv.val[0], m4b)); | ||
| 2365 | vx1[1] = vreinterpretq_s8_u8(vandq_u8(vv.val[1], m4b)); | ||
| 2366 | vx1[2] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[0], 4)); | ||
| 2367 | vx1[3] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[1], 4)); | ||
| 2368 | } | ||
| 2369 | |||
| 2370 | // process 32 data points (sharing the same block scale) per iteration | ||
| 2371 | for (int k = 0; k < 2; ++k) { | ||
| 2372 | const int blk = j * 2 + k; | ||
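| | // vmmlaq_s32 below yields the tile {x0*y0, x0*y1, x1*y0, x1*y1}, so the | ||
| | // scale vector repeats each x row's scale twice | ||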
| 2373 | const int32x4_t block_scale = { | ||
| 2374 | x0_scales[blk], | ||
| 2375 | x0_scales[blk], | ||
| 2376 | x1_scales[blk], | ||
| 2377 | x1_scales[blk], | ||
| 2378 | }; | ||
| 2379 | |||
| 2380 | int32x4_t vr = {0}; | ||
| 2381 | for (int l = 0; l < 2; ++l) { | ||
| 2382 | const int idx = k * 2 + l; | ||
| 2383 | const int64x2_t vx0_s64 = vreinterpretq_s64_s8(vx0[idx]); | ||
| 2384 | const int64x2_t vx1_s64 = vreinterpretq_s64_s8(vx1[idx]); | ||
| 2385 | const int64x2_t vy0_s64 = vreinterpretq_s64_s8(vy0.val[idx]); | ||
| 2386 | const int64x2_t vy1_s64 = vreinterpretq_s64_s8(vy1.val[idx]); | ||
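| | // interleave the 64-bit halves of the two rows so each vmmlaq_s32 sees a | ||
| | // 2x8 row pair from x against a 2x8 row pair from y | ||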
| 2387 | const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vx0_s64, vx1_s64)); | ||
| 2388 | const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vx0_s64, vx1_s64)); | ||
| 2389 | const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vy0_s64, vy1_s64)); | ||
| 2390 | const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vy0_s64, vy1_s64)); | ||
| 2391 | vr = vmmlaq_s32(vr, vx_l, vy_l); | ||
| 2392 | vr = vmmlaq_s32(vr, vx_h, vy_h); | ||
| 2393 | } | ||
| 2394 | // apply the block scale; this will NOT overflow: | ||
| 2395 | // block_scale * sum_256(int4*int8) needs at most 8+8+4+8 = 28 bits | ||
| 2396 | visum = vmlaq_s32(visum, vr, block_scale); | ||
| 2397 | } | ||
| 2398 | } | ||
| 2399 | |||
| 2400 | // adjust bias, apply superblock scale | ||
| 2401 | { | ||
| 2402 | int32_t bias[4]; | ||
| 2403 | // no obvious uplift from the SVE 16-bit sdot, so just use NEON multiply-add | ||
| 2404 | const int16x8_t y0_sums = vpaddq_s16(vld1q_s16(y0->bsums), vld1q_s16(y0->bsums+8)); | ||
| 2405 | const int16x8_t y1_sums = vpaddq_s16(vld1q_s16(y1->bsums), vld1q_s16(y1->bsums+8)); | ||
| 2406 | bias[0] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y0_sums), vget_low_s16(x0_mins)), | ||
| 2407 | vmull_s16(vget_high_s16(y0_sums), vget_high_s16(x0_mins)))); | ||
| 2408 | bias[1] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y1_sums), vget_low_s16(x0_mins)), | ||
| 2409 | vmull_s16(vget_high_s16(y1_sums), vget_high_s16(x0_mins)))); | ||
| 2410 | bias[2] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y0_sums), vget_low_s16(x1_mins)), | ||
| 2411 | vmull_s16(vget_high_s16(y0_sums), vget_high_s16(x1_mins)))); | ||
| 2412 | bias[3] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y1_sums), vget_low_s16(x1_mins)), | ||
| 2413 | vmull_s16(vget_high_s16(y1_sums), vget_high_s16(x1_mins)))); | ||
| 2414 | const float32x4_t dmins = { | ||
| 2415 | GGML_CPU_FP16_TO_FP32(x0->dmin) * y0->d, | ||
| 2416 | GGML_CPU_FP16_TO_FP32(x0->dmin) * y1->d, | ||
| 2417 | GGML_CPU_FP16_TO_FP32(x1->dmin) * y0->d, | ||
| 2418 | GGML_CPU_FP16_TO_FP32(x1->dmin) * y1->d, | ||
| 2419 | }; | ||
| 2420 | vfsum = vmlsq_f32(vfsum, vcvtq_f32_s32(vld1q_s32(bias)), dmins); | ||
| 2421 | |||
| 2422 | const float32x4_t superblock_scale = { | ||
| 2423 | GGML_CPU_FP16_TO_FP32(x0->d) * y0->d, | ||
| 2424 | GGML_CPU_FP16_TO_FP32(x0->d) * y1->d, | ||
| 2425 | GGML_CPU_FP16_TO_FP32(x1->d) * y0->d, | ||
| 2426 | GGML_CPU_FP16_TO_FP32(x1->d) * y1->d, | ||
| 2427 | }; | ||
| 2428 | vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale); | ||
| 2429 | } | ||
| 2430 | } | ||
| 2431 | |||
| 2432 | // vfsum = ABCD -> ACBD | ||
| 2433 | // AC -> s, BD -> (s+bs) | ||
| 2434 | vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2)); | ||
| 2435 | vst1_f32(s, vget_low_f32 (vfsum)); | ||
| 2436 | vst1_f32(s + bs, vget_high_f32(vfsum)); | ||
| 2437 | |||
| 2438 | return; | ||
| 2439 | } | ||
| 2440 | #endif | ||
| 2441 | |||
| 2442 | #ifdef __ARM_FEATURE_SVE | ||
| 2443 | float sumf = 0; | ||
| 2444 | for (int i = 0; i < nb; ++i) { | ||
| 2445 | |||
| 2446 | const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); | ||
| 2447 | const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); | ||
| 2448 | |||
| 2449 | const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); | ||
| 2450 | |||
| 2451 | memcpy(utmp, x[i].scales, K_SCALE_SIZE); | ||
| 2452 | |||
| 2453 | uint32x2_t mins8 = { 0 }; | ||
| 2454 | mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0); | ||
| 2455 | mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1); | ||
| 2456 | |||
| 2457 | utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); | ||
| 2458 | utmp[0] &= kmask1; | ||
| 2459 | |||
| 2460 | const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); | ||
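| | // each min contributes min * (sum of that sub-block's q8 values); scale the | ||
| | // total by dmin and subtract this bias up front | ||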
| 2461 | const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), | ||
| 2462 | vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); | ||
| 2463 | sumf -= dmin * vaddvq_s32(prod); | ||
| 2464 | |||
| 2465 | const uint8_t * scales = (const uint8_t *)utmp; | ||
| 2466 | |||
| 2467 | const uint8_t * GGML_RESTRICT q4 = x[i].qs; | ||
| 2468 | const int8_t * GGML_RESTRICT q8 = y[i].qs; | ||
| 2469 | |||
| 2470 | const svuint8_t m4b = svdup_n_u8(0xf); | ||
| 2471 | const svint32_t mzero = svdup_n_s32(0); | ||
| 2472 | svint32_t sumi1 = svdup_n_s32(0); | ||
| 2473 | svint32_t sumi1_1 = svdup_n_s32(0); | ||
| 2474 | svint32_t sumi1_2 = svdup_n_s32(0); | ||
| 2475 | svint32_t sumi2 = svdup_n_s32(0); | ||
| 2476 | svint32_t sumi2_1 = svdup_n_s32(0); | ||
| 2477 | svint32_t sumi2_2 = svdup_n_s32(0); | ||
| 2478 | switch (vector_length) { | ||
| 2479 | case 128: | ||
| 2480 | { | ||
| 2481 | for (int j = 0; j < QK_K/64; ++j) { | ||
| 2482 | svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b)); | ||
| 2483 | svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; | ||
| 2484 | sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]); | ||
| 2485 | q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b)); | ||
| 2486 | q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; | ||
| 2487 | sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]); | ||
| 2488 | |||
| 2489 | q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4)); | ||
| 2490 | q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; | ||
| 2491 | sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]); | ||
| 2492 | q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4)); | ||
| 2493 | q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; | ||
| 2494 | sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]); | ||
| 2495 | q4 += 32; | ||
| 2496 | } | ||
| 2497 | sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2); | ||
| 2498 | sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2); | ||
| 2499 | sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2))); | ||
| 2500 | } break; | ||
| 2501 | case 256: | ||
| 2502 | case 512: | ||
| 2503 | { | ||
| 2504 | for (int j = 0; j < QK_K/64; ++j) { | ||
| 2505 | const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32; | ||
| 2506 | svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b)); | ||
| 2507 | svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32; | ||
| 2508 | sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]); | ||
| 2509 | |||
| 2510 | q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4)); | ||
| 2511 | q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32; | ||
| 2512 | sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]); | ||
| 2513 | } | ||
| 2514 | sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2))); | ||
| 2515 | } break; | ||
| 2516 | default: | ||
| 2517 | assert(false && "Unsupported vector length"); | ||
| 2518 | break; | ||
| 2519 | } | ||
| 2520 | } | ||
| 2521 | *s = sumf; | ||
| 2522 | #elif defined __ARM_NEON | ||
| 2523 | const uint8x16_t m4b = vdupq_n_u8(0xf); | ||
| 2524 | const int32x4_t mzero = vdupq_n_s32(0); | ||
| 2525 | |||
| 2526 | ggml_int8x16x2_t q4bytes; | ||
| 2527 | ggml_int8x16x2_t q8bytes; | ||
| 2528 | |||
| 2529 | float sumf = 0; | ||
| 2530 | |||
| 2531 | for (int i = 0; i < nb; ++i) { | ||
| 2532 | |||
| 2533 | const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); | ||
| 2534 | const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); | ||
| 2535 | |||
| 2536 | const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); | ||
| 2537 | |||
| 2538 | memcpy(utmp, x[i].scales, 12); | ||
| 2539 | |||
| 2540 | uint32x2_t mins8 = { 0 }; | ||
| 2541 | mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0); | ||
| 2542 | mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1); | ||
| 2543 | |||
| 2544 | utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); | ||
| 2545 | utmp[0] &= kmask1; | ||
| 2546 | |||
| 2547 | const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); | ||
| 2548 | const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), | ||
| 2549 | vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); | ||
| 2550 | sumf -= dmin * vaddvq_s32(prod); | ||
| 2551 | |||
| 2552 | const uint8_t * scales = (const uint8_t *)utmp; | ||
| 2553 | |||
| 2554 | const uint8_t * GGML_RESTRICT q4 = x[i].qs; | ||
| 2555 | const int8_t * GGML_RESTRICT q8 = y[i].qs; | ||
| 2556 | |||
| 2557 | int32_t sumi1 = 0; | ||
| 2558 | int32_t sumi2 = 0; | ||
| 2559 | |||
| 2560 | for (int j = 0; j < QK_K/64; ++j) { | ||
| 2561 | const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32; | ||
| 2562 | |||
| 2563 | q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; | ||
| 2564 | q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); | ||
| 2565 | q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); | ||
| 2566 | |||
| 2567 | const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); | ||
| 2568 | sumi1 += vaddvq_s32(p1) * scales[2*j+0]; | ||
| 2569 | |||
| 2570 | q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; | ||
| 2571 | q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); | ||
| 2572 | q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); | ||
| 2573 | |||
| 2574 | const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); | ||
| 2575 | |||
| 2576 | sumi2 += vaddvq_s32(p2) * scales[2*j+1]; | ||
| 2577 | } | ||
| 2578 | |||
| 2579 | sumf += d * (sumi1 + sumi2); | ||
| 2580 | |||
| 2581 | } | ||
| 2582 | |||
| 2583 | *s = sumf; | ||
| 2584 | |||
| 2585 | #else | ||
| 2586 | UNUSED(x); | ||
| 2587 | UNUSED(y); | ||
| 2588 | UNUSED(nb); | ||
| 2589 | UNUSED(kmask1); | ||
| 2590 | UNUSED(kmask2); | ||
| 2591 | UNUSED(kmask3); | ||
| 2592 | UNUSED(utmp); | ||
| 2593 | ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); | ||
| 2594 | #endif | ||
| 2595 | } | ||
| 2596 | |||
| 2597 | void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { | ||
| 2598 | assert(n % QK_K == 0); | ||
| 2599 | assert(nrc == 1); | ||
| 2600 | UNUSED(nrc); | ||
| 2601 | UNUSED(bx); | ||
| 2602 | UNUSED(by); | ||
| 2603 | UNUSED(bs); | ||
| 2604 | |||
| 2605 | const block_q5_K * GGML_RESTRICT x = vx; | ||
| 2606 | const block_q8_K * GGML_RESTRICT y = vy; | ||
| 2607 | |||
| 2608 | const int nb = n / QK_K; | ||
| 2609 | |||
| 2610 | static const uint32_t kmask1 = 0x3f3f3f3f; | ||
| 2611 | static const uint32_t kmask2 = 0x0f0f0f0f; | ||
| 2612 | static const uint32_t kmask3 = 0x03030303; | ||
| 2613 | |||
| 2614 | uint32_t utmp[4]; | ||
| 2615 | |||
| 2616 | |||
| 2617 | #ifdef __ARM_NEON | ||
| 2618 | const uint8x16_t m4b = vdupq_n_u8(0xf); | ||
| 2619 | const uint8x16_t mone = vdupq_n_u8(1); | ||
| 2620 | const uint8x16_t mtwo = vdupq_n_u8(2); | ||
| 2621 | const int32x4_t mzero = vdupq_n_s32(0); | ||
| 2622 | |||
| 2623 | ggml_int8x16x4_t q5bytes; | ||
| 2624 | |||
| 2625 | float sumf = 0; | ||
| 2626 | |||
| 2627 | for (int i = 0; i < nb; ++i) { | ||
| 2628 | |||
| 2629 | const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); | ||
| 2630 | const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); | ||
| 2631 | |||
| 2632 | const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); | ||
| 2633 | |||
| 2634 | memcpy(utmp, x[i].scales, 12); | ||
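| | // rearrange the 12 packed bytes so utmp[0..1] hold the 8 6-bit scales and | ||
| | // utmp[2..3] the 8 6-bit mins | ||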
| 2635 | utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); | ||
| 2636 | const uint32_t uaux = utmp[1] & kmask1; | ||
| 2637 | utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); | ||
| 2638 | utmp[2] = uaux; | ||
| 2639 | utmp[0] &= kmask1; | ||
| 2640 | |||
| 2641 | const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8); | ||
| 2642 | const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8)); | ||
| 2643 | const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), | ||
| 2644 | vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); | ||
| 2645 | int32_t sumi_mins = vaddvq_s32(prod); | ||
| 2646 | |||
| 2647 | const uint8_t * scales = (const uint8_t *)utmp; | ||
| 2648 | |||
| 2649 | const uint8_t * GGML_RESTRICT q5 = x[i].qs; | ||
| 2650 | const uint8_t * GGML_RESTRICT qh = x[i].qh; | ||
| 2651 | const int8_t * GGML_RESTRICT q8 = y[i].qs; | ||
| 2652 | |||
| 2653 | ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); | ||
| 2654 | |||
| 2655 | ggml_uint8x16x4_t q5h; | ||
| 2656 | |||
| 2657 | int32_t sumi = 0; | ||
| 2658 | |||
| 2659 | for (int j = 0; j < QK_K/64; ++j) { | ||
| 2660 | |||
| 2661 | const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32; | ||
| 2662 | const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; | ||
| 2663 | |||
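| | // qh supplies bit 4 of each 5-bit value: bit 0 of qh (<<4) and bit 1 (<<3) | ||
| | // serve the current 64 values, then qh shifts right by 2 for the next 64 | ||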
| 2664 | q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); | ||
| 2665 | q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); | ||
| 2666 | q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3); | ||
| 2667 | q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3); | ||
| 2668 | qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2); | ||
| 2669 | qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2); | ||
| 2670 | |||
| 2671 | q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0])); | ||
| 2672 | q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1])); | ||
| 2673 | q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2])); | ||
| 2674 | q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3])); | ||
| 2675 | |||
| 2676 | sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++; | ||
| 2677 | sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++; | ||
| 2678 | } | ||
| 2679 | |||
| 2680 | sumf += d * sumi - dmin * sumi_mins; | ||
| 2681 | } | ||
| 2682 | |||
| 2683 | *s = sumf; | ||
| 2684 | |||
| 2685 | #else | ||
| 2686 | UNUSED(x); | ||
| 2687 | UNUSED(y); | ||
| 2688 | UNUSED(nb); | ||
| 2689 | UNUSED(kmask1); | ||
| 2690 | UNUSED(kmask2); | ||
| 2691 | UNUSED(kmask3); | ||
| 2692 | UNUSED(utmp); | ||
| 2693 | ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); | ||
| 2694 | #endif | ||
| 2695 | } | ||
| 2696 | |||
| 2697 | void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { | ||
| 2698 | assert(n % QK_K == 0); | ||
| 2699 | #ifdef __ARM_FEATURE_MATMUL_INT8 | ||
| 2700 | assert((nrc == 2) || (nrc == 1)); | ||
| 2701 | #else | ||
| 2702 | assert(nrc == 1); | ||
| 2703 | #endif | ||
| 2704 | UNUSED(nrc); | ||
| 2705 | UNUSED(bx); | ||
| 2706 | UNUSED(by); | ||
| 2707 | UNUSED(bs); | ||
| 2708 | |||
| 2709 | const block_q6_K * GGML_RESTRICT x = vx; | ||
| 2710 | const block_q8_K * GGML_RESTRICT y = vy; | ||
| 2711 | |||
| 2712 | const int nb = n / QK_K; | ||
| 2713 | |||
| 2714 | #ifdef __ARM_FEATURE_SVE | ||
| 2715 | const int vector_length = ggml_cpu_get_sve_cnt()*8; | ||
| 2716 | #endif | ||
| 2717 | #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) | ||
| 2718 | if (nrc == 2) { | ||
| 2719 | const svbool_t pg32_2 = svptrue_pat_b32(SV_VL2); | ||
| 2720 | |||
| 2721 | svfloat32_t sum = svdup_n_f32(0); | ||
| 2722 | |||
| 2723 | const block_q6_K * GGML_RESTRICT vx0 = vx; | ||
| 2724 | const block_q8_K * GGML_RESTRICT vy0 = vy; | ||
| 2725 | const block_q6_K * GGML_RESTRICT vx1 = (const block_q6_K *) ((const uint8_t*)vx + bx); | ||
| 2726 | const block_q8_K * GGML_RESTRICT vy1 = (const block_q8_K *) ((const uint8_t*)vy + by); | ||
| 2727 | |||
| 2728 | switch (vector_length) { | ||
| 2729 | case 128: | ||
| 2730 | { | ||
| 2731 | const svbool_t pg128_all = svptrue_pat_b8(SV_ALL); | ||
| 2732 | for (int i = 0; i < nb; ++i) { | ||
| 2733 | const uint8_t * GGML_RESTRICT ql0 = vx0[i].ql; | ||
| 2734 | const uint8_t * GGML_RESTRICT qh0 = vx0[i].qh; | ||
| 2735 | const uint8_t * GGML_RESTRICT ql1 = vx1[i].ql; | ||
| 2736 | const uint8_t * GGML_RESTRICT qh1 = vx1[i].qh; | ||
| 2737 | const int8_t * GGML_RESTRICT q80 = vy0[i].qs; | ||
| 2738 | const int8_t * GGML_RESTRICT q81 = vy1[i].qs; | ||
| 2739 | |||
| 2740 | const int8_t * GGML_RESTRICT scale0 = vx0[i].scales; | ||
| 2741 | const int8_t * GGML_RESTRICT scale1 = vx1[i].scales; | ||
| 2742 | |||
| 2743 | svfloat32_t vy_d = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)); | ||
| 2744 | svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d))); | ||
| 2745 | svfloat32_t svsuper_block_scales = svmul_f32_x(pg128_all, vy_d, vx_d); | ||
| 2746 | // bias term: q8 bsums dot q6 scales, 128-bit route | ||
| 2747 | const svint16_t q8sums_01 = svld1_s16(pg128_all, vy0[i].bsums); | ||
| 2748 | const svint16_t q8sums_02 = svld1_s16(pg128_all, vy0[i].bsums + 8); | ||
| 2749 | const svint16_t q8sums_11 = svld1_s16(pg128_all, vy1[i].bsums); | ||
| 2750 | const svint16_t q8sums_12 = svld1_s16(pg128_all, vy1[i].bsums + 8); | ||
| 2751 | const svint64x2_t q6scales_0_tmp = svld2_s64(pg128_all, (const int64_t *)scale0); | ||
| 2752 | const svint16_t q6scales_01 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_0_tmp, 0))); | ||
| 2753 | const svint16_t q6scales_02 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_0_tmp, 1))); | ||
| 2754 | const svint64x2_t q6scales_1_tmp = svld2_s64(pg128_all, (const int64_t *)scale1); | ||
| 2755 | const svint16_t q6scales_11 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_1_tmp, 0))); | ||
| 2756 | const svint16_t q6scales_12 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_1_tmp, 1))); | ||
| 2757 | const svint64_t prod = svdup_n_s64(0); | ||
| 2758 | |||
| 2759 | svint32_t isum_tmp1 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_01, q6scales_01), q8sums_02, q6scales_02)); | ||
| 2760 | svint32_t isum_tmp2 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_01, q6scales_11), q8sums_02, q6scales_12)); | ||
| 2761 | svint32_t isum_tmp3 = svtrn1_s32(isum_tmp1, isum_tmp2); | ||
| 2762 | svint32_t isum_tmp4 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_11, q6scales_01), q8sums_12, q6scales_02)); | ||
| 2763 | svint32_t isum_tmp5 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_11, q6scales_11), q8sums_12, q6scales_12)); | ||
| 2764 | svint32_t isum_tmp6 = svtrn1_s32(isum_tmp4, isum_tmp5); | ||
| 2765 | svint32_t isum_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(isum_tmp3), svreinterpret_s64_s32(isum_tmp6))); | ||
| 2766 | svint32_t isum_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(isum_tmp3), svreinterpret_s64_s32(isum_tmp6))); | ||
| 2767 | svint32_t svisum_mins = svadd_s32_x(pg128_all, isum_tmp7, isum_tmp8); | ||
| 2768 | |||
| 2769 | // process mmla | ||
| 2770 | svint8_t l0, l1, r0, r1; | ||
| 2771 | svint32_t isum_tmp = svdup_n_s32(0); | ||
| 2772 | for (int j = 0; j < QK_K/128; ++j) { | ||
| 2773 | for (int k = 0; k < 8; ++k) { | ||
| 2774 | svuint8_t qhbits_0 = svld1_u8(pg128_all, qh0+16*(k%2)); | ||
| 2775 | svuint8_t qhbits_1 = svld1_u8(pg128_all, qh1+16*(k%2)); | ||
| 2776 | svuint8_t q6bits_0 = svld1_u8(pg128_all, ql0+16*(k%4)); | ||
| 2777 | svuint8_t q6bits_1 = svld1_u8(pg128_all, ql1+16*(k%4)); | ||
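| | // each q6 value is 4 low bits from ql plus a 2-bit field from qh: ql_pos picks | ||
| | // the low or high nibble, qh_pos the 2-bit field, and the field is moved to | ||
| | // bits 4..5 by a power-of-two multiply (field below bit 4) or a right shift, | ||
| | // i.e. q6 = nibble(ql) | (qh_bits << 4), with the -32 bias applied later | ||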
| 2778 | const int ql_pos = (k/4)*4; | ||
| 2779 | svuint8_t q6bytes_0_lo = (ql_pos < 4) ? svand_n_u8_x(pg128_all, q6bits_0, 0xf) : svlsr_n_u8_x(pg128_all, q6bits_0, 4); | ||
| 2780 | svuint8_t q6bytes_1_lo = (ql_pos < 4) ? svand_n_u8_x(pg128_all, q6bits_1, 0xf) : svlsr_n_u8_x(pg128_all, q6bits_1, 4); | ||
| 2781 | const int qh_pos = (k/2)*2; | ||
| 2782 | svuint8_t q6bytes_0_hi = svand_n_u8_x(pg128_all, qhbits_0, 0x3 << qh_pos); | ||
| 2783 | svuint8_t q6bytes_1_hi = svand_n_u8_x(pg128_all, qhbits_1, 0x3 << qh_pos); | ||
| 2784 | svint8_t q6bytes_0, q6bytes_1; | ||
| 2785 | if (qh_pos <= 4) { | ||
| 2786 | q6bytes_0 = svreinterpret_s8_u8(svmla_n_u8_x(pg128_all, q6bytes_0_lo, q6bytes_0_hi, 1 << (4 - qh_pos))); | ||
| 2787 | q6bytes_1 = svreinterpret_s8_u8(svmla_n_u8_x(pg128_all, q6bytes_1_lo, q6bytes_1_hi, 1 << (4 - qh_pos))); | ||
| 2788 | } else { | ||
| 2789 | q6bytes_0 = svreinterpret_s8_u8(svorr_u8_x(pg128_all, q6bytes_0_lo, svlsr_n_u8_x(pg128_all, q6bytes_0_hi, (qh_pos - 4)))); | ||
| 2790 | q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg128_all, q6bytes_1_lo, svlsr_n_u8_x(pg128_all, q6bytes_1_hi, (qh_pos - 4)))); | ||
| 2791 | } | ||
| 2792 | svint8_t q8bytes_0 = svld1_s8(pg128_all, q80+16*(k%8)); | ||
| 2793 | svint8_t q8bytes_1 = svld1_s8(pg128_all, q81+16*(k%8)); | ||
| 2794 | l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1))); | ||
| 2795 | l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1))); | ||
| 2796 | r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1))); | ||
| 2797 | r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1))); | ||
| 2798 | svint32_t svscale = svzip1_s32(svdup_n_s32(scale0[k]), svdup_n_s32(scale1[k])); | ||
| 2799 | isum_tmp = svmla_s32_x(pg128_all, isum_tmp, svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), svscale); | ||
| 2800 | } | ||
| 2801 | qh0 += 32; qh1 += 32; | ||
| 2802 | ql0 += 64; ql1 += 64; | ||
| 2803 | q80 += 128; q81 += 128; | ||
| 2804 | scale0 += 8; scale1 += 8; | ||
| 2805 | } | ||
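| | // fold in the -32 offset: the q6 values above are still biased by +32, so | ||
| | // subtract 32 * (bsums dot scales) before converting to float | ||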
| 2806 | sum = svmla_f32_x(pg128_all, sum, | ||
| 2807 | svcvt_f32_x(pg128_all, svmla_s32_x(pg128_all, isum_tmp, | ||
| 2808 | svisum_mins, svdup_n_s32(-32))), | ||
| 2809 | svsuper_block_scales); | ||
| 2810 | } | ||
| 2811 | } // end of case 128 | ||
| 2812 | break; | ||
| 2813 | case 256: | ||
| 2814 | case 512: | ||
| 2815 | { | ||
| 2816 | const svbool_t pg256_all = svptrue_pat_b8(SV_ALL); | ||
| 2817 | const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4); | ||
| 2818 | for (int i = 0; i < nb; ++i) { | ||
| 2819 | const uint8_t * GGML_RESTRICT ql0 = vx0[i].ql; | ||
| 2820 | const uint8_t * GGML_RESTRICT qh0 = vx0[i].qh; | ||
| 2821 | const uint8_t * GGML_RESTRICT ql1 = vx1[i].ql; | ||
| 2822 | const uint8_t * GGML_RESTRICT qh1 = vx1[i].qh; | ||
| 2823 | const int8_t * GGML_RESTRICT q80 = vy0[i].qs; | ||
| 2824 | const int8_t * GGML_RESTRICT q81 = vy1[i].qs; | ||
| 2825 | |||
| 2826 | const int8_t * GGML_RESTRICT scale0 = vx0[i].scales; | ||
| 2827 | const int8_t * GGML_RESTRICT scale1 = vx1[i].scales; | ||
| 2828 | svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d))); | ||
| 2829 | svfloat64_t vy_d_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d))); | ||
| 2830 | svfloat32_t vy_d = svreinterpret_f32_f64(svuzp1_f64(vy_d_tmp, vy_d_tmp)); | ||
| 2831 | svfloat32_t svsuper_block_scales = svmul_f32_x(pg32_4, vy_d, vx_d); | ||
| 2832 | // bias term: q8 bsums dot q6 scales, 256-bit route | ||
| 2833 | const svint16_t q8sums_0 = svld1_s16(pg256_all, vy0[i].bsums); | ||
| 2834 | const svint16_t q8sums_1 = svld1_s16(pg256_all, vy1[i].bsums); | ||
| 2835 | const svint16_t q6scales_0 = svunpklo_s16(svld1_s8(pg256_all, scale0)); | ||
| 2836 | const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(pg256_all, scale1)); | ||
| 2837 | const svint64_t prod = svdup_n_s64(0); | ||
| 2838 | svint32_t isum_tmp1 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_0, q6scales_0)); | ||
| 2839 | svint32_t isum_tmp2 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_0, q6scales_1)); | ||
| 2840 | svint32_t isum_tmp3 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_1, q6scales_0)); | ||
| 2841 | svint32_t isum_tmp4 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_1, q6scales_1)); | ||
| 2842 | svint32_t isum_tmp5 = svtrn1_s32(isum_tmp1, isum_tmp2); | ||
| 2843 | svint32_t isum_tmp6 = svtrn1_s32(isum_tmp3, isum_tmp4); | ||
| 2844 | svint32_t isum_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(isum_tmp5), svreinterpret_s64_s32(isum_tmp6))); | ||
| 2845 | svint32_t isum_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(isum_tmp5), svreinterpret_s64_s32(isum_tmp6))); | ||
| 2846 | svint32_t isum_tmp9 = svadd_s32_x(pg256_all, isum_tmp7, isum_tmp8); | ||
| 2847 | svint32_t isum_tmp10 = svreinterpret_s32_u8(svext_u8(svreinterpret_u8_s32(isum_tmp9), svreinterpret_u8_s32(isum_tmp9), 16)); | ||
| 2848 | svint32_t svisum_mins = svadd_s32_z(pg32_4, isum_tmp9, isum_tmp10); | ||
| 2849 | |||
| 2850 | // process mmla | ||
| 2851 | svint8_t l0, l1, r0, r1; | ||
| 2852 | svint32_t isum_tmp = svdup_n_s32(0); | ||
| 2853 | for (int j = 0; j < QK_K/128; ++j) { | ||
| 2854 | for (int k = 0; k < 8; k+=2) { // process 2 blocks per iteration | ||
| 2855 | svuint8_t qhbits_0 = svld1_u8(pg256_all, qh0); | ||
| 2856 | svuint8_t qhbits_1 = svld1_u8(pg256_all, qh1); | ||
| 2857 | svuint8_t q6bits_0 = svld1_u8(pg256_all, ql0+32*((k%4)/2)); | ||
| 2858 | svuint8_t q6bits_1 = svld1_u8(pg256_all, ql1+32*((k%4)/2)); | ||
| 2859 | const int ql_pos = (k/4)*4; | ||
| 2860 | svuint8_t q6bytes_0_lo = (ql_pos < 4) ? svand_n_u8_x(pg256_all, q6bits_0, 0xf) : svlsr_n_u8_x(pg256_all, q6bits_0, 4); | ||
| 2861 | svuint8_t q6bytes_1_lo = (ql_pos < 4) ? svand_n_u8_x(pg256_all, q6bits_1, 0xf) : svlsr_n_u8_x(pg256_all, q6bits_1, 4); | ||
| 2862 | const int qh_pos = (k/2)*2; | ||
| 2863 | svuint8_t q6bytes_0_hi = svand_n_u8_x(pg256_all, qhbits_0, 0x3 << qh_pos); | ||
| 2864 | svuint8_t q6bytes_1_hi = svand_n_u8_x(pg256_all, qhbits_1, 0x3 << qh_pos); | ||
| 2865 | svint8_t q6bytes_0, q6bytes_1; | ||
| 2866 | if (qh_pos <= 4) { | ||
| 2867 | q6bytes_0 = svreinterpret_s8_u8(svmla_n_u8_x(pg256_all, q6bytes_0_lo, q6bytes_0_hi, 1 << (4 - qh_pos))); | ||
| 2868 | q6bytes_1 = svreinterpret_s8_u8(svmla_n_u8_x(pg256_all, q6bytes_1_lo, q6bytes_1_hi, 1 << (4 - qh_pos))); | ||
| 2869 | } else { | ||
| 2870 | q6bytes_0 = svreinterpret_s8_u8(svorr_u8_x(pg256_all, q6bytes_0_lo, svlsr_n_u8_x(pg256_all, q6bytes_0_hi, (qh_pos - 4)))); | ||
| 2871 | q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg256_all, q6bytes_1_lo, svlsr_n_u8_x(pg256_all, q6bytes_1_hi, (qh_pos - 4)))); | ||
| 2872 | } | ||
| 2873 | svint8_t q8bytes_0 = svld1_s8(pg256_all, q80+32*(k/2)); | ||
| 2874 | svint8_t q8bytes_1 = svld1_s8(pg256_all, q81+32*(k/2)); | ||
| 2875 | l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1))); | ||
| 2876 | l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1))); | ||
| 2877 | r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1))); | ||
| 2878 | r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1))); | ||
| 2879 | svint32_t svscale0 = svzip1_s32(svdup_n_s32(scale0[k]), svdup_n_s32(scale1[k])); | ||
| 2880 | svint32_t svscale1 = svzip1_s32(svdup_n_s32(scale0[k+1]), svdup_n_s32(scale1[k+1])); | ||
| 2881 | isum_tmp = svmla_s32_x(pg256_all, isum_tmp, svmmla_s32(svdup_n_s32(0), r0, l0), svscale0); | ||
| 2882 | isum_tmp = svmla_s32_x(pg256_all, isum_tmp, svmmla_s32(svdup_n_s32(0), r1, l1), svscale1); | ||
| 2883 | } | ||
| 2884 | qh0 += 32; qh1 += 32; | ||
| 2885 | ql0 += 64; ql1 += 64; | ||
| 2886 | q80 += 128; q81 += 128; | ||
| 2887 | scale0 += 8; scale1 += 8; | ||
| 2888 | } // end of for | ||
| 2889 | svint32_t swap_isum_tmp = svext_s32(isum_tmp, isum_tmp, 4); | ||
| 2890 | isum_tmp = svadd_s32_x(pg32_4, isum_tmp, swap_isum_tmp); | ||
| 2891 | sum = svmla_f32_x(pg32_4, sum, | ||
| 2892 | svcvt_f32_x(pg32_4, svmla_s32_x(pg32_4, isum_tmp, | ||
| 2893 | svisum_mins, svdup_n_s32(-32))), | ||
| 2894 | svsuper_block_scales); | ||
| 2895 | } | ||
| 2896 | } // end of case 256-512 | ||
| 2897 | break; | ||
| 2898 | default: | ||
| 2899 | assert(false && "Unsupported vector length"); | ||
| 2900 | break; | ||
| 2901 | } // end of switch | ||
| 2902 | |||
| 2903 | svst1_f32(pg32_2, s, sum); | ||
| 2904 | svst1_f32(pg32_2, s + bs, svreinterpret_f32_u8(svext_u8(svreinterpret_u8_f32(sum), svdup_n_u8(0), 8))); | ||
| 2905 | |||
| 2906 | return; | ||
| 2907 | } | ||
| 2908 | #elif defined(__ARM_FEATURE_MATMUL_INT8) | ||
| 2909 | if (nrc == 2) { | ||
| 2910 | const block_q6_K * GGML_RESTRICT x0 = x; | ||
| 2911 | const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx); | ||
| 2912 | const block_q8_K * GGML_RESTRICT y0 = y; | ||
| 2913 | const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by); | ||
| 2914 | |||
| 2915 | float32x4_t vfsum = vdupq_n_f32(0.0f); | ||
| 2916 | |||
| 2917 | for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) { | ||
| 2918 | const uint8_t * GGML_RESTRICT ql0 = x0->ql; | ||
| 2919 | const uint8_t * GGML_RESTRICT ql1 = x1->ql; | ||
| 2920 | const uint8_t * GGML_RESTRICT qh0 = x0->qh; | ||
| 2921 | const uint8_t * GGML_RESTRICT qh1 = x1->qh; | ||
| 2922 | const int8_t * GGML_RESTRICT qy0 = y0->qs; | ||
| 2923 | const int8_t * GGML_RESTRICT qy1 = y1->qs; | ||
| 2924 | |||
| 2925 | const uint8x16_t mone = vdupq_n_u8(0x30); | ||
| 2926 | const uint8x16_t m4b = vdupq_n_u8(0x0f); | ||
| 2927 | |||
| 2928 | int32x4_t visum = vdupq_n_s32(0); | ||
| 2929 | |||
| 2930 | // process 8 blocks per iteration, 16 blocks in total | ||
| 2931 | for (int j = 0; j < 2; ++j, qh0 += 32, ql0 += 64, qh1 += 32, ql1 += 64) { | ||
| 2932 | int8x16_t vx0[8], vx1[8]; | ||
| 2933 | |||
| 2934 | // de-quantize vx0[8] | ||
| 2935 | { | ||
| 2936 | const uint8x16x2_t qh_bits = vld1q_u8_x2(qh0); | ||
| 2937 | const uint8x16x4_t ql_bits = vld1q_u8_x4(ql0); | ||
| 2938 | |||
| 2939 | uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4)); | ||
| 2940 | uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4)); | ||
| 2941 | uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2)); | ||
| 2942 | uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2)); | ||
| 2943 | |||
| 2944 | vx0[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0)); | ||
| 2945 | vx0[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1)); | ||
| 2946 | vx0[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2)); | ||
| 2947 | vx0[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3)); | ||
| 2948 | |||
| 2949 | q6h_0 = vandq_u8(mone, qh_bits.val[0]); | ||
| 2950 | q6h_1 = vandq_u8(mone, qh_bits.val[1]); | ||
| 2951 | q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2)); | ||
| 2952 | q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2)); | ||
| 2953 | |||
| 2954 | vx0[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0)); | ||
| 2955 | vx0[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1)); | ||
| 2956 | vx0[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2)); | ||
| 2957 | vx0[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3)); | ||
| 2958 | } | ||
| 2959 | |||
| 2960 | // de-quantize vx1[8] | ||
| 2961 | { | ||
| 2962 | const uint8x16x2_t qh_bits = vld1q_u8_x2(qh1); | ||
| 2963 | const uint8x16x4_t ql_bits = vld1q_u8_x4(ql1); | ||
| 2964 | |||
| 2965 | uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4)); | ||
| 2966 | uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4)); | ||
| 2967 | uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2)); | ||
| 2968 | uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2)); | ||
| 2969 | |||
| 2970 | vx1[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0)); | ||
| 2971 | vx1[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1)); | ||
| 2972 | vx1[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2)); | ||
| 2973 | vx1[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3)); | ||
| 2974 | |||
| 2975 | q6h_0 = vandq_u8(mone, qh_bits.val[0]); | ||
| 2976 | q6h_1 = vandq_u8(mone, qh_bits.val[1]); | ||
| 2977 | q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2)); | ||
| 2978 | q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2)); | ||
| 2979 | |||
| 2980 | vx1[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0)); | ||
| 2981 | vx1[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1)); | ||
| 2982 | vx1[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2)); | ||
| 2983 | vx1[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3)); | ||
| 2984 | } | ||
| 2985 | |||
| 2986 | // process 16 elements (one block sharing the same scale) per iteration | ||
| 2987 | // - vx = concat(ql, qh) - 32 (the -32 bias is folded in later via vibias) | ||
| 2988 | // - vr = smmla(vx, vy) | ||
| 2989 | for (int k = 0; k < 8; ++k) { | ||
| 2990 | const int blk = j * 8 + k; | ||
| 2991 | |||
| 2992 | const int8x16_t vy0 = vld1q_s8(qy0); | ||
| 2993 | const int8x16_t vy1 = vld1q_s8(qy1); | ||
| 2994 | qy0 += 16; | ||
| 2995 | qy1 += 16; | ||
| 2996 | |||
| 2997 | const int32x4_t block_scale = { | ||
| 2998 | x0->scales[blk], | ||
| 2999 | x0->scales[blk], | ||
| 3000 | x1->scales[blk], | ||
| 3001 | x1->scales[blk], | ||
| 3002 | }; | ||
| 3003 | |||
| 3004 | // calculate the four cross dot-products of the 2x2 tile at once with smmla | ||
| 3005 | const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k]))); | ||
| 3006 | const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k]))); | ||
| 3007 | const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1))); | ||
| 3008 | const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1))); | ||
| 3009 | int32x4_t vr = vdupq_n_s32(0); | ||
| 3010 | vr = vmmlaq_s32(vr, vx_l, vy_l); | ||
| 3011 | vr = vmmlaq_s32(vr, vx_h, vy_h); | ||
| 3012 | |||
| 3013 | // apply the block scale; this will NOT overflow: | ||
| 3014 | // block_scale * sum_256(int6*int8) needs at most 8+8+6+8 = 30 bits | ||
| 3015 | visum = vmlaq_s32(visum, vr, block_scale); | ||
| 3016 | } | ||
| 3017 | } | ||
| 3018 | |||
| 3019 | // adjust bias, apply superblock scale | ||
| 3020 | { | ||
| 3021 | int32_t bias[4]; | ||
| 3022 | // NEON has no int16 dot product, so fall back to separate multiply and add | ||
| 3023 | const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums); | ||
| 3024 | const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums); | ||
| 3025 | |||
| 3026 | int8x16_t scales_s8 = vld1q_s8(x0->scales); | ||
| 3027 | const int16x8x2_t q6scales0 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}}; | ||
| 3028 | scales_s8 = vld1q_s8(x1->scales); | ||
| 3029 | const int16x8x2_t q6scales1 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}}; | ||
| 3030 | |||
| 3031 | int32x4_t prod; | ||
| 3032 | prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales0.val[0])), | ||
| 3033 | vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales0.val[0]))), | ||
| 3034 | vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales0.val[1])), | ||
| 3035 | vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales0.val[1])))); | ||
| 3036 | bias[0] = vaddvq_s32(prod); | ||
| 3037 | prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales0.val[0])), | ||
| 3038 | vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales0.val[0]))), | ||
| 3039 | vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales0.val[1])), | ||
| 3040 | vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales0.val[1])))); | ||
| 3041 | bias[1] = vaddvq_s32(prod); | ||
| 3042 | prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales1.val[0])), | ||
| 3043 | vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales1.val[0]))), | ||
| 3044 | vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales1.val[1])), | ||
| 3045 | vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales1.val[1])))); | ||
| 3046 | bias[2] = vaddvq_s32(prod); | ||
| 3047 | prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales1.val[0])), | ||
| 3048 | vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales1.val[0]))), | ||
| 3049 | vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales1.val[1])), | ||
| 3050 | vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1])))); | ||
| 3051 | bias[3] = vaddvq_s32(prod); | ||
| 3052 | |||
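| | // q6 weights are stored with a +32 offset; 32 * (bsums dot scales) is the | ||
| | // exact correction to subtract from the integer accumulator | ||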
| 3053 | const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32); | ||
| 3054 | |||
| 3055 | const float32x4_t superblock_scale = { | ||
| 3056 | GGML_CPU_FP16_TO_FP32(x0->d) * y0->d, | ||
| 3057 | GGML_CPU_FP16_TO_FP32(x0->d) * y1->d, | ||
| 3058 | GGML_CPU_FP16_TO_FP32(x1->d) * y0->d, | ||
| 3059 | GGML_CPU_FP16_TO_FP32(x1->d) * y1->d, | ||
| 3060 | }; | ||
| 3061 | |||
| 3062 | visum = vsubq_s32(visum, vibias); | ||
| 3063 | vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale); | ||
| 3064 | } | ||
| 3065 | } | ||
| 3066 | |||
| 3067 | // vfsum = ABCD -> ACBD | ||
| 3068 | // AC -> s, BD -> (s+bs) | ||
| 3069 | vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2)); | ||
| 3070 | vst1_f32(s, vget_low_f32 (vfsum)); | ||
| 3071 | vst1_f32(s + bs, vget_high_f32(vfsum)); | ||
| 3072 | |||
| 3073 | return; | ||
| 3074 | } | ||
| 3075 | #endif | ||
| 3076 | |||
| 3077 | #ifdef __ARM_FEATURE_SVE | ||
| 3078 | float sum = 0; | ||
| 3079 | svuint8_t m4b = svdup_n_u8(0xf); | ||
| 3080 | svint32_t vzero = svdup_n_s32(0); | ||
| 3081 | svuint8_t mone = svdup_n_u8(0x30); | ||
| 3082 | svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4; | ||
| 3083 | svuint8_t q6h_1, q6h_2, q6h_3, q6h_4; | ||
| 3084 | |||
| 3085 | for (int i = 0; i < nb; ++i) { | ||
| 3086 | const float d_all = GGML_CPU_FP16_TO_FP32(x[i].d); | ||
| 3087 | |||
| 3088 | const uint8_t * GGML_RESTRICT q6 = x[i].ql; | ||
| 3089 | const uint8_t * GGML_RESTRICT qh = x[i].qh; | ||
| 3090 | const int8_t * GGML_RESTRICT q8 = y[i].qs; | ||
| 3091 | |||
| 3092 | const int8_t * GGML_RESTRICT scale = x[i].scales; | ||
| 3093 | |||
| 3094 | const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8); | ||
| 3095 | const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums); | ||
| 3096 | const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8); | ||
| 3097 | const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale)); | ||
| 3098 | const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8)); | ||
| 3099 | const svint64_t prod = svdup_n_s64(0); | ||
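| | // 16-bit to 64-bit sdot: each 64-bit lane accumulates four bsum * scale | ||
| | // products; the horizontal add yields the bias term for the -32 offset | ||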
| 3100 | int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1), | ||
| 3101 | svdot_s64(prod, q8sums_2, q6scales_2))); | ||
| 3102 | int32_t isum = 0; | ||
| 3103 | |||
| 3104 | switch (vector_length) { | ||
| 3105 | case 128: | ||
| 3106 | { | ||
| 3107 | const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4); | ||
| 3108 | const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16); | ||
| 3109 | svint32_t isum_tmp = svdup_n_s32(0); | ||
| 3110 | for (int j = 0; j < QK_K/128; ++j) { | ||
| 3111 | svuint8_t qhbits_1 = svld1_u8(pg8_16, qh); | ||
| 3112 | svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16); | ||
| 3113 | qh += 32; | ||
| 3114 | svuint8_t q6bits_1 = svld1_u8(pg8_16, q6); | ||
| 3115 | svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16); | ||
| 3116 | svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32); | ||
| 3117 | svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48); | ||
| 3118 | q6 += 64; | ||
| 3119 | svint8_t q8bytes_1 = svld1_s8(pg8_16, q8); | ||
| 3120 | svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16); | ||
| 3121 | svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32); | ||
| 3122 | svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48); | ||
| 3123 | q8 += 64; | ||
| 3124 | |||
| 3125 | q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4)); | ||
| 3126 | q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4)); | ||
| 3127 | q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2)); | ||
| 3128 | q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2)); | ||
| 3129 | q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1)); | ||
| 3130 | q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2)); | ||
| 3131 | q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3)); | ||
| 3132 | q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4)); | ||
| 3133 | isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]); | ||
| 3134 | isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]); | ||
| 3135 | isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]); | ||
| 3136 | isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]); | ||
| 3137 | |||
| 3138 | scale += 4; | ||
| 3139 | q8bytes_1 = svld1_s8(pg8_16, q8); | ||
| 3140 | q8bytes_2 = svld1_s8(pg8_16, q8+16); | ||
| 3141 | q8bytes_3 = svld1_s8(pg8_16, q8+32); | ||
| 3142 | q8bytes_4 = svld1_s8(pg8_16, q8+48); | ||
| 3143 | q8 += 64; | ||
| 3144 | |||
| 3145 | q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1); | ||
| 3146 | q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2); | ||
| 3147 | q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2)); | ||
| 3148 | q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2)); | ||
| 3149 | q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1)); | ||
| 3150 | q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2)); | ||
| 3151 | q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3)); | ||
| 3152 | q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4)); | ||
| 3153 | isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]); | ||
| 3154 | isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]); | ||
| 3155 | isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]); | ||
| 3156 | isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]); | ||
| 3157 | scale += 4; | ||
| 3158 | } | ||
| 3159 | isum += svaddv_s32(pg32_4, isum_tmp); | ||
| 3160 | sum += d_all * y[i].d * (isum - 32 * isum_mins); | ||
| 3161 | } | ||
| 3162 | break; | ||
| 3163 | case 256: | ||
| 3164 | case 512: | ||
| 3165 | { | ||
| 3166 | const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2); | ||
| 3167 | const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8); | ||
| 3168 | const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32); | ||
| 3169 | svint32_t isum_tmp = svdup_n_s32(0); | ||
| 3170 | for (int j = 0; j < QK_K/128; j++) { | ||
| 3171 | svuint8_t qhbits_1 = svld1_u8(pg8_32, qh); | ||
| 3172 | qh += 32; | ||
| 3173 | svuint8_t q6bits_1 = svld1_u8(pg8_32, q6); | ||
| 3174 | svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32); | ||
| 3175 | q6 += 64; | ||
| 3176 | svint8_t q8bytes_1 = svld1_s8(pg8_32, q8); | ||
| 3177 | svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32); | ||
| 3178 | svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64); | ||
| 3179 | svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96); | ||
| 3180 | q8 += 128; | ||
| 3181 | q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4)); | ||
| 3182 | q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2)); | ||
| 3183 | q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1); | ||
| 3184 | q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2)); | ||
| 3185 | q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1)); | ||
| 3186 | q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2)); | ||
| 3187 | q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3)); | ||
| 3188 | q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4)); | ||
| 3189 | |||
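| | // each int8 scale covers 16 weights, i.e. four 4-byte dot lanes: duplicate it | ||
| | // 4x with two zips, then sign-extend to one int32 per lane for svmla below | ||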
| 3190 | svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale); | ||
| 3191 | scale_lane_1_tmp = svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp); | ||
| 3192 | scale_lane_1_tmp = svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp); | ||
| 3193 | svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2); | ||
| 3194 | scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp); | ||
| 3195 | scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp); | ||
| 3196 | svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4); | ||
| 3197 | scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp); | ||
| 3198 | scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp); | ||
| 3199 | svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6); | ||
| 3200 | scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp); | ||
| 3201 | scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp); | ||
| 3202 | svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp)); | ||
| 3203 | svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp)); | ||
| 3204 | svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp)); | ||
| 3205 | svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp)); | ||
| 3206 | |||
| 3207 | isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1); | ||
| 3208 | isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2); | ||
| 3209 | isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3); | ||
| 3210 | isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4); | ||
| 3211 | scale += 8; | ||
| 3212 | } | ||
| 3213 | isum += svaddv_s32(pg32_8, isum_tmp); | ||
| 3214 | sum += d_all * y[i].d * (isum - 32 * isum_mins); | ||
| 3215 | } | ||
| 3216 | break; | ||
| 3217 | default: | ||
| 3218 | assert(false && "Unsupported vector length"); | ||
| 3219 | break; | ||
| 3220 | } | ||
| 3221 | } | ||
| 3222 | |||
| 3223 | *s = sum; | ||
| 3224 | |||
| 3225 | #elif __ARM_NEON | ||
| 3226 | float sum = 0; | ||
| 3227 | |||
| 3228 | const uint8x16_t m4b = vdupq_n_u8(0xF); | ||
| 3229 | const int32x4_t vzero = vdupq_n_s32(0); | ||
| 3230 | // note: no m32s constant is needed - the -32 bias of the 6-bit values is folded into isum_mins below | ||
| 3231 | |||
| 3232 | const uint8x16_t mone = vdupq_n_u8(3); | ||
| 3233 | |||
| 3234 | ggml_int8x16x4_t q6bytes; | ||
| 3235 | ggml_uint8x16x4_t q6h; | ||
| 3236 | |||
| 3237 | for (int i = 0; i < nb; ++i) { | ||
| 3238 | |||
| 3239 | const float d_all = GGML_CPU_FP16_TO_FP32(x[i].d); | ||
| 3240 | |||
| 3241 | const uint8_t * GGML_RESTRICT q6 = x[i].ql; | ||
| 3242 | const uint8_t * GGML_RESTRICT qh = x[i].qh; | ||
| 3243 | const int8_t * GGML_RESTRICT q8 = y[i].qs; | ||
| 3244 | |||
| 3245 | const int8_t * GGML_RESTRICT scale = x[i].scales; | ||
| 3246 | |||
| 3247 | const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums); | ||
| 3248 | const int8x16_t scales = vld1q_s8(scale); | ||
| 3249 | const ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}}; | ||
| 3250 | |||
| 3251 | const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])), | ||
| 3252 | vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))), | ||
| 3253 | vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])), | ||
| 3254 | vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1])))); | ||
| 3255 | int32_t isum_mins = vaddvq_s32(prod); | ||
| 3256 | |||
| 3257 | int32_t isum = 0; | ||
| 3258 | |||
| 3259 | for (int j = 0; j < QK_K/128; ++j) { | ||
| 3260 | |||
| 3261 | ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32; | ||
| 3262 | ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64; | ||
| 3263 | ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; | ||
| 3264 | |||
| 3265 | q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); | ||
| 3266 | q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); | ||
| 3267 | uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2); | ||
| 3268 | q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); | ||
| 3269 | shifted = vshrq_n_u8(qhbits.val[1], 2); | ||
| 3270 | q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); | ||
| 3271 | |||
| 3272 | // the raw [0,63] values are used directly; the -32 bias is applied once per block via isum_mins | ||
| 3276 | q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])); | ||
| 3277 | q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])); | ||
| 3278 | q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])); | ||
| 3279 | q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])); | ||
| 3280 | |||
| 3281 | isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + | ||
| 3282 | vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + | ||
| 3283 | vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + | ||
| 3284 | vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; | ||
| 3285 | |||
| 3286 | scale += 4; | ||
| 3287 | |||
| 3288 | q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; | ||
| 3289 | |||
| 3290 | shifted = vshrq_n_u8(qhbits.val[0], 4); | ||
| 3291 | q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4); | ||
| 3292 | shifted = vshrq_n_u8(qhbits.val[1], 4); | ||
| 3293 | q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); | ||
| 3294 | shifted = vshrq_n_u8(qhbits.val[0], 6); | ||
| 3295 | q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); | ||
| 3296 | shifted = vshrq_n_u8(qhbits.val[1], 6); | ||
| 3297 | q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); | ||
| 3298 | |||
| 3299 | // as above, the -32 bias is deferred to the isum_mins correction | ||
| 3303 | q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])); | ||
| 3304 | q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])); | ||
| 3305 | q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])); | ||
| 3306 | q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])); | ||
| 3307 | |||
| 3308 | isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + | ||
| 3309 | vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + | ||
| 3310 | vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + | ||
| 3311 | vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; | ||
| 3312 | scale += 4; | ||
| 3313 | } | ||
| 3315 | sum += d_all * y[i].d * (isum - 32 * isum_mins); | ||
| 3316 | |||
| 3317 | } | ||
| 3318 | *s = sum; | ||
| 3319 | #else | ||
| 3320 | UNUSED(x); | ||
| 3321 | UNUSED(y); | ||
| 3322 | UNUSED(nb); | ||
| 3323 | ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); | ||
| 3324 | #endif | ||
| 3325 | } | ||
| 3326 | |||
| 3327 | #if defined (__ARM_NEON) | ||
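| | // sign patterns for the iq2/iq3 "xxs" quants: 128 entries of 8 signs each; | ||
| | // 7 bits are stored per group and the 8th sign is a parity bit that keeps the | ||
| | // number of -1s even (hence "keven"); the kernels view this as packed uint64s | ||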
| 3328 | static const int8_t keven_signs_q2xs[1024] = { | ||
| 3329 | 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, | ||
| 3330 | 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, | ||
| 3331 | 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1, | ||
| 3332 | 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1, | ||
| 3333 | 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, | ||
| 3334 | 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1, | ||
| 3335 | 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1, | ||
| 3336 | 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, | ||
| 3337 | 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1, | ||
| 3338 | 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1, | ||
| 3339 | 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1, | ||
| 3340 | 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1, | ||
| 3341 | 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1, | ||
| 3342 | 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1, | ||
| 3343 | 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1, | ||
| 3344 | 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1, | ||
| 3345 | 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1, | ||
| 3346 | 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1, | ||
| 3347 | 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1, | ||
| 3348 | 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1, | ||
| 3349 | 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1, | ||
| 3350 | 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1, | ||
| 3351 | 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, | ||
| 3352 | 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1, | ||
| 3353 | 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1, | ||
| 3354 | 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1, | ||
| 3355 | 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, | ||
| 3356 | 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1, | ||
| 3357 | 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1, | ||
| 3358 | 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1, | ||
| 3359 | 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1, | ||
| 3360 | 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, | ||
| 3361 | }; | ||
| 3362 | #endif | ||
| 3363 | |||
| 3364 | void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { | ||
| 3365 | assert(n % QK_K == 0); | ||
| 3366 | assert(nrc == 1); | ||
| 3367 | UNUSED(nrc); | ||
| 3368 | UNUSED(bx); | ||
| 3369 | UNUSED(by); | ||
| 3370 | UNUSED(bs); | ||
| 3371 | |||
| 3372 | const block_iq2_xxs * GGML_RESTRICT x = vx; | ||
| 3373 | const block_q8_K * GGML_RESTRICT y = vy; | ||
| 3374 | |||
| 3375 | const int nb = n / QK_K; | ||
| 3376 | |||
| 3377 | #if defined(__ARM_NEON) | ||
| 3378 | |||
| 3379 | const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; | ||
| 3380 | |||
| 3381 | uint32_t aux32[4]; | ||
| 3382 | const uint8_t * aux8 = (const uint8_t *)aux32; | ||
| 3383 | |||
| 3384 | ggml_int8x16x4_t q2u; | ||
| 3385 | ggml_int8x16x4_t q2s; | ||
| 3386 | ggml_int8x16x4_t q8b; | ||
| 3387 | |||
| 3388 | float sumf = 0; | ||
| 3389 | for (int i = 0; i < nb; ++i) { | ||
| 3390 | const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; | ||
| 3391 | const uint16_t * GGML_RESTRICT q2 = x[i].qs; | ||
| 3392 | const int8_t * GGML_RESTRICT q8 = y[i].qs; | ||
| 3393 | float sumf1 = 0, sumf2 = 0; | ||
| 3394 | for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { | ||
| 3395 | q8b = ggml_vld1q_s8_x4(q8); q8 += 64; | ||
| 3396 | memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; | ||
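| | // of the four copied words, aux32[0]/aux32[2] hold the grid-index bytes and | ||
| | // aux32[1]/aux32[3] hold four 7-bit sign indices plus a 4-bit scale on top | ||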
| 3397 | q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1]))); | ||
| 3398 | q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3]))); | ||
| 3399 | q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9]))); | ||
| 3400 | q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11]))); | ||
| 3401 | q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127)))); | ||
| 3402 | q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127)))); | ||
| 3403 | q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 7) & 127)))); | ||
| 3404 | q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127)))); | ||
| 3405 | q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]); | ||
| 3406 | q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]); | ||
| 3407 | q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]); | ||
| 3408 | q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]); | ||
| 3409 | const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]); | ||
| 3410 | const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]); | ||
| 3411 | sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28)); | ||
| 3412 | sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28)); | ||
| 3413 | } | ||
| 3414 | sumf += d*(sumf1 + sumf2); | ||
| 3415 | } | ||
| 3416 | *s = 0.25f * sumf; | ||
| 3417 | |||
| 3418 | #else | ||
| 3419 | UNUSED(x); | ||
| 3420 | UNUSED(y); | ||
| 3421 | UNUSED(nb); | ||
| 3422 | ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); | ||
| 3423 | #endif | ||
| 3424 | } | ||
| 3425 | |||
| 3426 | void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { | ||
| 3427 | assert(n % QK_K == 0); | ||
| 3428 | assert(nrc == 1); | ||
| 3429 | UNUSED(nrc); | ||
| 3430 | UNUSED(bx); | ||
| 3431 | UNUSED(by); | ||
| 3432 | UNUSED(bs); | ||
| 3433 | |||
| 3434 | const block_iq2_xs * GGML_RESTRICT x = vx; | ||
| 3435 | const block_q8_K * GGML_RESTRICT y = vy; | ||
| 3436 | |||
| 3437 | const int nb = n / QK_K; | ||
| 3438 | |||
| 3439 | #if defined(__ARM_NEON) | ||
| 3440 | |||
| 3441 | const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; | ||
| 3442 | |||
| 3443 | ggml_int8x16x4_t q2u; | ||
| 3444 | ggml_int8x16x4_t q2s; | ||
| 3445 | ggml_int8x16x4_t q8b; | ||
| 3446 | |||
| 3447 | int32x4x4_t scales32; | ||
| 3448 | |||
| 3449 | float sumf = 0; | ||
| 3450 | for (int i = 0; i < nb; ++i) { | ||
| 3451 | const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; | ||
| 3452 | const uint16_t * GGML_RESTRICT q2 = x[i].qs; | ||
| 3453 | const int8_t * GGML_RESTRICT q8 = y[i].qs; | ||
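| | // unpack the sixteen 4-bit block scales, map s -> 2*s + 1, and widen them to | ||
| | // one int32 per dot lane so they can be applied with vmlaq_s32 | ||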
| 3454 | const uint8x8_t scales8 = vld1_u8(x[i].scales); | ||
| 3455 | const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf)); | ||
| 3456 | const uint8x8_t scales_h = vshr_n_u8(scales8, 4); | ||
| 3457 | uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h)); | ||
| 3458 | scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1)); | ||
| 3459 | const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales)); | ||
| 3460 | const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales)); | ||
| 3461 | scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1))); | ||
| 3462 | scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1))); | ||
| 3463 | scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2))); | ||
| 3464 | scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2))); | ||
| 3465 | int32x4_t sumi = vdupq_n_s32(0); | ||
| 3466 | for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { | ||
| 3467 | q8b = ggml_vld1q_s8_x4(q8); q8 += 64; | ||
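| | // each uint16 of q2: the low 9 bits index the 512-entry grid, the high | ||
| | // 7 bits index the even-parity sign table | ||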
| 3468 | q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511)))); | ||
| 3469 | q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511)))); | ||
| 3470 | q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511)))); | ||
| 3471 | q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[6] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[7] & 511)))); | ||
| 3472 | q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[0] >> 9))), vld1_s8((const void *)(signs64 + (q2[1] >> 9)))); | ||
| 3473 | q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[2] >> 9))), vld1_s8((const void *)(signs64 + (q2[3] >> 9)))); | ||
| 3474 | q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[4] >> 9))), vld1_s8((const void *)(signs64 + (q2[5] >> 9)))); | ||
| 3475 | q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[6] >> 9))), vld1_s8((const void *)(signs64 + (q2[7] >> 9)))); | ||
| 3476 | q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]); | ||
| 3477 | q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]); | ||
| 3478 | q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]); | ||
| 3479 | q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]); | ||
| 3480 | const int32x4_t p1 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]); | ||
| 3481 | const int32x4_t p2 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]); | ||
| 3482 | const int32x4_t p3 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]); | ||
| 3483 | const int32x4_t p4 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]); | ||
| 3484 | const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4)); | ||
| 3485 | sumi = vmlaq_s32(sumi, p, scales32.val[ib64]); | ||
| 3486 | q2 += 8; | ||
| 3487 | } | ||
| 3488 | sumf += d*vaddvq_s32(sumi); | ||
| 3489 | } | ||
| 3490 | *s = 0.125f * sumf; | ||
| 3491 | |||
| 3492 | #else | ||
| 3493 | UNUSED(x); | ||
| 3494 | UNUSED(y); | ||
| 3495 | UNUSED(nb); | ||
| 3496 | ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); | ||
| 3497 | #endif | ||
| 3498 | } | ||
| 3499 | |||
| 3500 | void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { | ||
| 3501 | assert(n % QK_K == 0); | ||
| 3502 | assert(nrc == 1); | ||
| 3503 | UNUSED(nrc); | ||
| 3504 | UNUSED(bx); | ||
| 3505 | UNUSED(by); | ||
| 3506 | UNUSED(bs); | ||
| 3507 | |||
| 3508 | const block_iq2_s * GGML_RESTRICT x = vx; | ||
| 3509 | const block_q8_K * GGML_RESTRICT y = vy; | ||
| 3510 | |||
| 3511 | const int nb = n / QK_K; | ||
| 3512 | |||
| 3513 | #if defined(__ARM_NEON) | ||
| 3514 | |||
| 3515 | static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, | ||
| 3516 | 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 | ||
| 3517 | }; | ||
| 3518 | |||
| 3519 | static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; | ||
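| | // mask1 broadcasts each of the four packed sign bytes across 8 lanes; ANDing | ||
| | // with mask2's per-lane bit and comparing equal expands sign bit i to 0x00/0xFF | ||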
| 3520 | |||
| 3521 | const ggml_uint8x16x2_t mask1 = ggml_vld1q_u8_x2(k_mask1); | ||
| 3522 | const uint8x16_t mask2 = vld1q_u8(k_mask2); | ||
| 3523 | const uint8x16_t m1 = vdupq_n_u8(1); | ||
| 3524 | const int32x4_t vzero = vdupq_n_s32(0); | ||
| 3525 | |||
| 3526 | uint8x16x2_t vs; | ||
| 3527 | ggml_int8x16x4_t q2s; | ||
| 3528 | ggml_int8x16x4_t q8b; | ||
| 3529 | |||
| 3530 | float sumf = 0; | ||
| 3531 | for (int i = 0; i < nb; ++i) { | ||
| 3532 | |||
| 3533 | const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; | ||
| 3534 | |||
| 3535 | const uint8_t * GGML_RESTRICT qs = x[i].qs; | ||
| 3536 | const uint8_t * GGML_RESTRICT qh = x[i].qh; | ||
| 3537 | const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8); | ||
| 3538 | const int8_t * GGML_RESTRICT q8 = y[i].qs; | ||
| 3539 | |||
| 3540 | int sumi1 = 0, sumi2 = 0; | ||
| 3541 | for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { | ||
| 3542 | q8b = ggml_vld1q_s8_x4(q8); q8 += 64; | ||
| 3543 | q2s.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[0] | ((qh[ib32+0] << 8) & 0x300)))), | ||
| 3544 | vld1_s8((const int8_t *)(iq2s_grid + (qs[1] | ((qh[ib32+0] << 6) & 0x300))))); | ||
| 3545 | q2s.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[2] | ((qh[ib32+0] << 4) & 0x300)))), | ||
| 3546 | vld1_s8((const int8_t *)(iq2s_grid + (qs[3] | ((qh[ib32+0] << 2) & 0x300))))); | ||
| 3547 | q2s.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[4] | ((qh[ib32+1] << 8) & 0x300)))), | ||
| 3548 | vld1_s8((const int8_t *)(iq2s_grid + (qs[5] | ((qh[ib32+1] << 6) & 0x300))))); | ||
| 3549 | q2s.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[6] | ((qh[ib32+1] << 4) & 0x300)))), | ||
| 3550 | vld1_s8((const int8_t *)(iq2s_grid + (qs[7] | ((qh[ib32+1] << 2) & 0x300))))); | ||
| 3551 | qs += 8; | ||
| 3552 | |||
| 3553 | vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16))); | ||
| 3554 | vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); | ||
| 3555 | vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); | ||
| 3556 | vs.val[0] = vceqq_u8(vs.val[0], mask2); | ||
| 3557 | vs.val[1] = vceqq_u8(vs.val[1], mask2); | ||
| 3558 | |||
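| | // vs is 0xFF where a sign bit was set; OR with 1 turns that into a -1/+1 multiplier | ||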
| 3559 | q2s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[0]); | ||
| 3560 | q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]); | ||
| 3561 | |||
| 3562 | vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16))); | ||
| 3563 | vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); | ||
| 3564 | vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); | ||
| 3565 | vs.val[0] = vceqq_u8(vs.val[0], mask2); | ||
| 3566 | vs.val[1] = vceqq_u8(vs.val[1], mask2); | ||
| 3567 | |||
| 3568 | signs += 4; | ||
| 3569 | |||
| 3570 | q2s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[2]); | ||
| 3571 | q2s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[3]); | ||
| 3572 | |||
| 3573 | const int32x4_t p1 = ggml_vdotq_s32(vzero, q2s.val[0], q8b.val[0]); | ||
| 3574 | const int32x4_t p2 = ggml_vdotq_s32(vzero, q2s.val[1], q8b.val[1]); | ||
| 3575 | const int32x4_t p3 = ggml_vdotq_s32(vzero, q2s.val[2], q8b.val[2]); | ||
| 3576 | const int32x4_t p4 = ggml_vdotq_s32(vzero, q2s.val[3], q8b.val[3]); | ||
| 3577 | |||
| 3578 | sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32+0] & 0xf)); | ||
| 3579 | sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32+0] >> 4)); | ||
| 3580 | sumi1 += vaddvq_s32(p3) * (1 + 2*(x[i].scales[ib32+1] & 0xf)); | ||
| 3581 | sumi2 += vaddvq_s32(p4) * (1 + 2*(x[i].scales[ib32+1] >> 4)); | ||
| 3582 | } | ||
| 3583 | sumf += d*(sumi1 + sumi2); | ||
| 3584 | } | ||
| 3585 | |||
| 3586 | *s = 0.125f * sumf; | ||
| 3587 | |||
| 3588 | #else | ||
| 3589 | UNUSED(x); | ||
| 3590 | UNUSED(y); | ||
| 3591 | UNUSED(nb); | ||
| 3592 | ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); | ||
| 3593 | #endif | ||
| 3594 | |||
| 3595 | } | ||
| 3596 | |||
| 3597 | void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { | ||
| 3598 | assert(n % QK_K == 0); | ||
| 3599 | assert(nrc == 1); | ||
| 3600 | UNUSED(nrc); | ||
| 3601 | UNUSED(bx); | ||
| 3602 | UNUSED(by); | ||
| 3603 | UNUSED(bs); | ||
| 3604 | |||
| 3605 | const block_iq3_xxs * GGML_RESTRICT x = vx; | ||
| 3606 | const block_q8_K * GGML_RESTRICT y = vy; | ||
| 3607 | |||
| 3608 | const int nb = n / QK_K; | ||
| 3609 | |||
| 3610 | #if defined(__ARM_NEON) | ||
| 3611 | |||
| 3612 | const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; | ||
| 3613 | |||
| 3614 | uint32_t aux32[2]; | ||
| 3615 | |||
| 3616 | ggml_int8x16x4_t q3s; | ||
| 3617 | ggml_int8x16x4_t q8b; | ||
| 3618 | |||
| 3619 | float sumf = 0; | ||
| 3620 | for (int i = 0; i < nb; ++i) { | ||
| 3621 | const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; | ||
| 3622 | const uint8_t * GGML_RESTRICT q3 = x[i].qs; | ||
| 3623 | const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4; | ||
| 3624 | const int8_t * GGML_RESTRICT q8 = y[i].qs; | ||
| 3625 | float sumf1 = 0, sumf2 = 0; | ||
| 3626 | for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { | ||
| 3627 | q8b = ggml_vld1q_s8_x4(q8); q8 += 64; | ||
| 3628 | memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t); | ||
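| | // q3 bytes index the 3-bit grid; gas supplies the 7-bit sign indices and the | ||
| | // 4-bit block scale, packed into two 32-bit words per 64 weights | ||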
| 3629 | const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]); | ||
| 3630 | const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]); | ||
| 3631 | const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]); | ||
| 3632 | const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]); | ||
| 3633 | q3 += 16; | ||
| 3634 | q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 7) & 127)))); | ||
| 3635 | q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127)))); | ||
| 3636 | q3s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127)))); | ||
| 3637 | q3s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127)))); | ||
| 3638 | q3s.val[0] = vmulq_s8(q3s.val[0], vreinterpretq_s8_u32(aux32x4_0)); | ||
| 3639 | q3s.val[1] = vmulq_s8(q3s.val[1], vreinterpretq_s8_u32(aux32x4_1)); | ||
| 3640 | q3s.val[2] = vmulq_s8(q3s.val[2], vreinterpretq_s8_u32(aux32x4_2)); | ||
| 3641 | q3s.val[3] = vmulq_s8(q3s.val[3], vreinterpretq_s8_u32(aux32x4_3)); | ||
| 3642 | const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]); | ||
| 3643 | const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]); | ||
| 3644 | sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[0] >> 28)); | ||
| 3645 | sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[1] >> 28)); | ||
| 3646 | } | ||
| 3647 | sumf += d*(sumf1 + sumf2); | ||
| 3648 | } | ||
| 3649 | *s = 0.5f * sumf; | ||
| 3650 | |||
| 3651 | #else | ||
| 3652 | UNUSED(x); | ||
| 3653 | UNUSED(y); | ||
| 3654 | UNUSED(nb); | ||
| 3655 | ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); | ||
| 3656 | #endif | ||
| 3657 | } | ||
| 3658 | |||
| 3659 | void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { | ||
| 3660 | assert(n % QK_K == 0); | ||
| 3661 | assert(nrc == 1); | ||
| 3662 | UNUSED(nrc); | ||
| 3663 | UNUSED(bx); | ||
| 3664 | UNUSED(by); | ||
| 3665 | UNUSED(bs); | ||
| 3666 | |||
| 3667 | const block_iq3_s * GGML_RESTRICT x = vx; | ||
| 3668 | const block_q8_K * GGML_RESTRICT y = vy; | ||
| 3669 | |||
| 3670 | const int nb = n / QK_K; | ||
| 3671 | |||
| 3672 | #if defined(__ARM_NEON) | ||
| 3673 | |||
| 3674 | typedef union { | ||
| 3675 | uint16x8_t vec_index; | ||
| 3676 | uint16_t index[8]; | ||
| 3677 | } vec_index_t; | ||
| 3678 | |||
| 3679 | static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, | ||
| 3680 | 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 | ||
| 3681 | }; | ||
| 3682 | |||
| 3683 | static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; | ||
| 3684 | |||
| 3685 | static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1}; | ||
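| | // shifting qh left by {8..1} per lane and masking with 256 places qh bit i at | ||
| | // bit 8 of index lane i, extending each 8-bit qs index to 9 bits | ||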
| 3686 | |||
| 3687 | const ggml_uint8x16x2_t mask1 = ggml_vld1q_u8_x2(k_mask1); | ||
| 3688 | const uint8x16_t mask2 = vld1q_u8(k_mask2); | ||
| 3689 | |||
| 3690 | const int16x8_t hshift = vld1q_s16(k_shift); | ||
| 3691 | const uint16x8_t m256 = vdupq_n_u16(256); | ||
| 3692 | const uint8x16_t m1 = vdupq_n_u8(1); | ||
| 3693 | |||
| 3694 | uint8x16x2_t vs; | ||
| 3695 | ggml_int8x16x4_t q3s; | ||
| 3696 | ggml_int8x16x4_t q8b; | ||
| 3697 | vec_index_t idx; | ||
| 3698 | |||
| 3699 | uint32_t scales32[2]; | ||
| 3700 | const uint8_t * scales8 = (const uint8_t *)scales32; | ||
| 3701 | |||
| 3702 | float sumf = 0; | ||
| 3703 | for (int i = 0; i < nb; ++i) { | ||
| 3704 | const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; | ||
| 3705 | const uint8_t * GGML_RESTRICT qs = x[i].qs; | ||
| 3706 | const uint8_t * GGML_RESTRICT qh = x[i].qh; | ||
| 3707 | const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs; | ||
| 3708 | const int8_t * GGML_RESTRICT q8 = y[i].qs; | ||
| 3709 | |||
| 3710 | memcpy(scales32, x[i].scales, 4); | ||
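| | // eight 4-bit scales: bytes 0-3 receive the low nibbles, bytes 4-7 the high | ||
| | // nibbles, each mapped to 2*s + 1 | ||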
| 3711 | scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101; | ||
| 3712 | scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101; | ||
| 3713 | |||
| 3714 | int sumi1 = 0, sumi2 = 0; | ||
| 3715 | for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { | ||
| 3716 | q8b = ggml_vld1q_s8_x4(q8); q8 += 64; | ||
| 3717 | |||
| 3718 | const uint8x16_t idx_l = vld1q_u8(qs); qs += 16; | ||
| 3719 | idx.vec_index = vorrq_u16(vmovl_u8(vget_low_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+0]), hshift), m256)); | ||
| 3720 | const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]], | ||
| 3721 | iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]); | ||
| 3722 | const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]], | ||
| 3723 | iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]); | ||
| 3724 | idx.vec_index = vorrq_u16(vmovl_u8(vget_high_u8(idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), hshift), m256)); | ||
| 3725 | const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]], | ||
| 3726 | iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]); | ||
| 3727 | const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]], | ||
| 3728 | iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]); | ||
| 3729 | |||
| 3731 | vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16))); | ||
| 3732 | vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); | ||
| 3733 | vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); | ||
| 3734 | vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1); | ||
| 3735 | vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1); | ||
| 3736 | |||
| 3737 | q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0)); | ||
| 3738 | q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1)); | ||
| 3739 | |||
| 3740 | vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16))); | ||
| 3741 | vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); | ||
| 3742 | vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); | ||
| 3743 | vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1); | ||
| 3744 | vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1); | ||
| 3745 | |||
| 3746 | signs += 4; | ||
| 3747 | |||
| 3748 | q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_2)); | ||
| 3749 | q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_3)); | ||
| 3750 | |||
| 3751 | const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]); | ||
| 3752 | const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]); | ||
| 3753 | |||
| 3754 | sumi1 += vaddvq_s32(p1) * scales8[ib32/2+0]; | ||
| 3755 | sumi2 += vaddvq_s32(p2) * scales8[ib32/2+4]; | ||
| 3756 | } | ||
| 3757 | sumf += d*(sumi1 + sumi2); | ||
| 3758 | } | ||
| 3759 | *s = sumf; | ||
| 3760 | |||
| 3761 | #else | ||
| 3762 | UNUSED(x); | ||
| 3763 | UNUSED(y); | ||
| 3764 | UNUSED(nb); | ||
| 3765 | ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); | ||
| 3766 | #endif | ||
| 3767 | } | ||
| 3768 | |||
| 3769 | void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { | ||
| 3770 | assert(n % QK_K == 0); | ||
| 3771 | assert(nrc == 1); | ||
| 3772 | UNUSED(nrc); | ||
| 3773 | UNUSED(bx); | ||
| 3774 | UNUSED(by); | ||
| 3775 | UNUSED(bs); | ||
| 3776 | |||
| 3777 | const block_iq1_s * GGML_RESTRICT x = vx; | ||
| 3778 | const block_q8_K * GGML_RESTRICT y = vy; | ||
| 3779 | |||
| 3780 | const int nb = n / QK_K; | ||
| 3781 | |||
| 3782 | #if defined __ARM_NEON | ||
| 3783 | |||
| 3784 | ggml_int8x16x4_t q1b; | ||
| 3785 | ggml_int8x16x4_t q8b; | ||
| 3786 | |||
| 3787 | float sumf = 0; | ||
| 3788 | for (int i = 0; i < nb; ++i) { | ||
| 3789 | |||
| 3790 | const int8_t * q8 = y[i].qs; | ||
| 3791 | const uint8_t * qs = x[i].qs; | ||
| 3792 | const uint16_t * qh = x[i].qh; | ||
| 3793 | |||
| 3794 | int sumi1 = 0, sumi2 = 0, sumi3 = 0; | ||
| 3795 | |||
| 3796 | for (int ib = 0; ib < QK_K/32; ib += 2) { | ||
| 3797 | |||
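| | // grid index = 8 bits from qs plus 3 high bits from qh (mask 0x700), selecting | ||
| | // one of the 2048 iq1s_grid entries per group of 8 weights | ||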
| 3798 | q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[ib+0] << 8) & 0x700)))), | ||
| 3799 | vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[ib+0] << 5) & 0x700))))); | ||
| 3800 | q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[ib+0] << 2) & 0x700)))), | ||
| 3801 | vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[ib+0] >> 1) & 0x700))))); | ||
| 3802 | q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[ib+1] << 8) & 0x700)))), | ||
| 3803 | vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[ib+1] << 5) & 0x700))))); | ||
| 3804 | q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[ib+1] << 2) & 0x700)))), | ||
| 3805 | vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[ib+1] >> 1) & 0x700))))); | ||
| 3806 | qs += 8; | ||
| 3807 | |||
| 3808 | q8b = ggml_vld1q_s8_x4(q8); q8 += 64; | ||
| 3809 | |||
| 3810 | const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[0], q8b.val[0]), q1b.val[1], q8b.val[1]); | ||
| 3811 | const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[2], q8b.val[2]), q1b.val[3], q8b.val[3]); | ||
| 3812 | |||
| 3813 | const int ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; | ||
| 3814 | const int ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; | ||
| 3815 | sumi1 += vaddvq_s32(p1) * ls1; | ||
| 3816 | sumi2 += vaddvq_s32(p2) * ls2; | ||
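| | // bit 15 of qh gives the sign of the per-block IQ1S_DELTA shift; since the | ||
| | // shift is constant within a block, its dot-product contribution is just | ||
| | // delta * sum(q8), taken from the precomputed bsums | ||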
| 3817 | sumi3 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * ls1 * (qh[ib+0] & 0x8000 ? -1 : 1) | ||
| 3818 | + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * ls2 * (qh[ib+1] & 0x8000 ? -1 : 1); | ||
| 3819 | |||
| 3820 | } | ||
| 3821 | |||
| 3822 | sumf += y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d) * (sumi1 + sumi2 + IQ1S_DELTA * sumi3); | ||
| 3823 | } | ||
| 3824 | |||
| 3825 | *s = sumf; | ||
| 3826 | |||
| 3827 | #else | ||
| 3828 | UNUSED(x); | ||
| 3829 | UNUSED(y); | ||
| 3830 | UNUSED(nb); | ||
| 3831 | ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); | ||
| 3832 | #endif | ||
| 3833 | } | ||
| 3834 | |||
| 3835 | void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { | ||
| 3836 | assert(n % QK_K == 0); | ||
| 3837 | assert(nrc == 1); | ||
| 3838 | UNUSED(nrc); | ||
| 3839 | UNUSED(bx); | ||
| 3840 | UNUSED(by); | ||
| 3841 | UNUSED(bs); | ||
| 3842 | |||
| 3843 | const block_iq1_m * GGML_RESTRICT x = vx; | ||
| 3844 | const block_q8_K * GGML_RESTRICT y = vy; | ||
| 3845 | |||
| 3846 | const int nb = n / QK_K; | ||
| 3847 | |||
| 3848 | iq1m_scale_t scale; | ||
| 3849 | |||
| 3850 | #if defined __ARM_NEON | ||
| 3851 | const int32x4_t mask = vdupq_n_s32(0x7); | ||
| 3852 | const int32x4_t mone = vdupq_n_s32(1); | ||
| 3853 | const int32x4_t mzero = vdupq_n_s32(0); | ||
| 3854 | |||
| 3855 | ggml_int8x16x4_t deltas; | ||
| 3856 | deltas.val[0] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(+1)); | ||
| 3857 | deltas.val[1] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(+1)); | ||
| 3858 | deltas.val[2] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(-1)); | ||
| 3859 | deltas.val[3] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(-1)); | ||
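| | // deltas.val[i] holds the +/-1 pattern for a pair of 8-weight halves; the two | ||
| | // delta bits per qh byte (extracted into aux32 below) select which one | ||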
| 3860 | |||
| 3861 | ggml_int8x16x4_t q1b; | ||
| 3862 | ggml_int8x16x4_t q8b; | ||
| 3863 | |||
| 3864 | uint32_t aux32; | ||
| 3865 | const uint8_t * aux8 = (const uint8_t *)&aux32; | ||
| 3866 | |||
| 3867 | float sumf = 0; | ||
| 3868 | for (int i = 0; i < nb; ++i) { | ||
| 3869 | |||
| 3870 | const int8_t * q8 = y[i].qs; | ||
| 3871 | const uint8_t * qs = x[i].qs; | ||
| 3872 | const uint8_t * qh = x[i].qh; | ||
| 3873 | const uint16_t * sc = (const uint16_t *)x[i].scales; | ||
| 3874 | |||
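| | // reassemble the fp16 super-block scale from the top nibble of each of the | ||
| | // four 16-bit scale words | ||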
| 3875 | scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); | ||
| 3876 | |||
| 3877 | int32x4_t sumi1 = mzero; | ||
| 3878 | int32x4_t sumi2 = mzero; | ||
| 3879 | |||
| 3880 | for (int ib = 0; ib < QK_K/32; ib += 2) { | ||
| 3881 | |||
| 3882 | q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700)))), | ||
| 3883 | vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 4) & 0x700))))); | ||
| 3884 | q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[1] << 8) & 0x700)))), | ||
| 3885 | vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[1] << 4) & 0x700))))); | ||
| 3886 | q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[2] << 8) & 0x700)))), | ||
| 3887 | vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[2] << 4) & 0x700))))); | ||
| 3888 | q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[3] << 8) & 0x700)))), | ||
| 3889 | vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[3] << 4) & 0x700))))); | ||
| 3890 | |||
| 3891 | q8b = ggml_vld1q_s8_x4(q8); q8 += 64; | ||
| 3892 | |||
| 3893 | const int32x4_t p1 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(mzero, q1b.val[1], q8b.val[1])); | ||
| 3894 | const int32x4_t p2 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(mzero, q1b.val[3], q8b.val[3])); | ||
| 3895 | const int32x4_t p12 = vpaddq_s32(p1, p2); | ||
| 3896 | |||
| 3897 | const uint32_t * qh32 = (const uint32_t *)qh; // qh is 4-byte aligned, so a 32-bit load is safe | ||
| 3898 | aux32 = ((qh32[0] >> 3) & 0x01010101) | ((qh32[0] >> 6) & 0x02020202); | ||
| 3899 | |||
| 3900 | const int32x4_t p3 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[0]], q8b.val[0]), ggml_vdotq_s32(mzero, deltas.val[aux8[1]], q8b.val[1])); | ||
| 3901 | const int32x4_t p4 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[2]], q8b.val[2]), ggml_vdotq_s32(mzero, deltas.val[aux8[3]], q8b.val[3])); | ||
| 3902 | const int32x4_t p34 = vpaddq_s32(p3, p4); | ||
| 3903 | |||
| 3904 | int32x4_t scales_4 = ggml_vld1q_u32(sc[ib/2] >> 0, sc[ib/2] >> 3, sc[ib/2] >> 6, sc[ib/2] >> 9); | ||
| 3905 | |||
| 3906 | scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone); | ||
| 3907 | |||
| 3908 | sumi1 = vmlaq_s32(sumi1, scales_4, p12); | ||
| 3909 | sumi2 = vmlaq_s32(sumi2, scales_4, p34); | ||
| 3910 | |||
| 3911 | qs += 8; qh += 4; | ||
| 3912 | |||
| 3913 | } | ||
| 3914 | |||
| 3915 | sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2)); | ||
| 3916 | } | ||
| 3917 | |||
| 3918 | *s = sumf; | ||
| 3919 | |||
| 3920 | #else | ||
| 3921 | UNUSED(x); | ||
| 3922 | UNUSED(y); | ||
| 3923 | UNUSED(nb); | ||
| 3924 | UNUSED(scale); | ||
| 3925 | ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); | ||
| 3926 | #endif | ||
| 3927 | } | ||
| 3928 | |||
| 3929 | void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { | ||
| 3930 | assert(nrc == 1); | ||
| 3931 | UNUSED(nrc); | ||
| 3932 | UNUSED(bx); | ||
| 3933 | UNUSED(by); | ||
| 3934 | UNUSED(bs); | ||
| 3935 | assert(n % QK4_NL == 0); | ||
| 3936 | static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same"); | ||
| 3937 | |||
| 3938 | const block_iq4_nl * GGML_RESTRICT x = vx; | ||
| 3939 | const block_q8_0 * GGML_RESTRICT y = vy; | ||
| 3940 | |||
| 3941 | const int nb = n / QK4_NL; | ||
| 3942 | |||
| 3943 | int ib = 0; | ||
| 3944 | float sumf = 0; | ||
| 3945 | |||
| 3946 | #if defined __ARM_NEON | ||
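| | // kvalues_iq4nl is the 16-entry non-linear codebook; ggml_vqtbl1q_s8 maps each | ||
| | // 4-bit index directly to its int8 value | ||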
| 3947 | const int8x16_t values = vld1q_s8(kvalues_iq4nl); | ||
| 3948 | const uint8x16_t m4b = vdupq_n_u8(0x0f); | ||
| 3949 | uint8x16x2_t q4bits; | ||
| 3950 | int8x16x4_t q4b; | ||
| 3951 | int8x16x4_t q8b; | ||
| 3952 | int32x4_t prod_1, prod_2; | ||
| 3953 | |||
| 3954 | for (; ib + 1 < nb; ib += 2) { | ||
| 3955 | |||
| 3956 | q4bits.val[0] = vld1q_u8(x[ib + 0].qs); | ||
| 3957 | q4bits.val[1] = vld1q_u8(x[ib + 1].qs); | ||
| 3958 | q8b.val[0] = vld1q_s8(y[ib + 0].qs); | ||
| 3959 | q8b.val[1] = vld1q_s8(y[ib + 0].qs + 16); | ||
| 3960 | q8b.val[2] = vld1q_s8(y[ib + 1].qs); | ||
| 3961 | q8b.val[3] = vld1q_s8(y[ib + 1].qs + 16); | ||
| 3962 | |||
| 3963 | q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b)); | ||
| 3964 | q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4)); | ||
| 3965 | q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b)); | ||
| 3966 | q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4)); | ||
| 3967 | |||
| 3968 | prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]); | ||
| 3969 | prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]); | ||
| 3970 | |||
| 3971 | sumf += | ||
| 3972 | GGML_CPU_FP16_TO_FP32(x[ib+0].d) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) + | ||
| 3973 | GGML_CPU_FP16_TO_FP32(x[ib+1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2); | ||
| 3974 | } | ||
| 3975 | |||
| 3976 | #endif | ||
| 3977 | for (; ib < nb; ++ib) { | ||
| 3978 | const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_CPU_FP16_TO_FP32(x[ib].d); | ||
| 3979 | int sumi1 = 0, sumi2 = 0; | ||
| 3980 | for (int j = 0; j < QK4_NL/2; ++j) { | ||
| 3981 | sumi1 += y[ib].qs[j+ 0] * kvalues_iq4nl[x[ib].qs[j] & 0xf]; | ||
| 3982 | sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4]; | ||
| 3983 | } | ||
| 3984 | sumf += d * (sumi1 + sumi2); | ||
| 3985 | } | ||
| 3986 | *s = sumf; | ||
| 3987 | } | ||
| 3988 | |||
| 3989 | void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { | ||
| 3990 | assert(nrc == 1); | ||
| 3991 | UNUSED(nrc); | ||
| 3992 | UNUSED(bx); | ||
| 3993 | UNUSED(by); | ||
| 3994 | UNUSED(bs); | ||
| 3995 | assert(n % QK_K == 0); | ||
| 3996 | |||
| 3997 | const block_iq4_xs * GGML_RESTRICT x = vx; | ||
| 3998 | const block_q8_K * GGML_RESTRICT y = vy; | ||
| 3999 | |||
| 4000 | const int nb = n / QK_K; | ||
| 4001 | |||
| 4002 | #if defined __ARM_NEON | ||
| 4003 | const int8x16_t values = vld1q_s8(kvalues_iq4nl); | ||
| 4004 | const uint8x16_t m4b = vdupq_n_u8(0x0f); | ||
| 4005 | ggml_uint8x16x2_t q4bits; | ||
| 4006 | ggml_int8x16x4_t q4b; | ||
| 4007 | ggml_int8x16x4_t q8b; | ||
| 4008 | int32x4_t prod_1, prod_2; | ||
| 4009 | |||
| 4010 | float sumf = 0; | ||
| 4011 | |||
| 4012 | for (int ibl = 0; ibl < nb; ++ibl) { | ||
| 4013 | |||
| 4014 | const int8_t * q8 = y[ibl].qs; | ||
| 4015 | const uint8_t * q4 = x[ibl].qs; | ||
| 4016 | uint16_t h = x[ibl].scales_h; | ||
| 4017 | |||
| 4018 | int sumi1 = 0, sumi2 = 0; | ||
| 4019 | for (int ib = 0; ib < QK_K/64; ++ib) { | ||
| 4020 | |||
| 4021 | q4bits = ggml_vld1q_u8_x2(q4); q4 += 32; | ||
| 4022 | q8b = ggml_vld1q_s8_x4(q8); q8 += 64; | ||
| 4023 | |||
| 4024 | q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b)); | ||
| 4025 | q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4)); | ||
| 4026 | q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b)); | ||
| 4027 | q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4)); | ||
| 4028 | |||
| 4029 | prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]); | ||
| 4030 | prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]); | ||
| 4031 | |||
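| | // 6-bit block scales: low 4 bits from scales_l, top 2 bits from scales_h, | ||
| | // biased by -32 | ||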
| 4032 | int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32; | ||
| 4033 | int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32; | ||
| 4034 | h >>= 4; | ||
| 4035 | sumi1 += vaddvq_s32(prod_1) * ls1; | ||
| 4036 | sumi2 += vaddvq_s32(prod_2) * ls2; | ||
| 4037 | |||
| 4038 | } | ||
| 4039 | |||
| 4040 | sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2); | ||
| 4041 | } | ||
| 4042 | |||
| 4043 | *s = sumf; | ||
| 4044 | |||
| 4045 | #else | ||
| 4046 | UNUSED(x); | ||
| 4047 | UNUSED(y); | ||
| 4048 | UNUSED(nb); | ||
| 4049 | ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); | ||
| 4050 | #endif | ||
| 4051 | } | ||
| 4052 | |||
