diff options
Diffstat (limited to 'llama.cpp/ggml/src/ggml-cpu/arch/wasm')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c | 1221 |
 1 file changed, 1221 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c b/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c new file mode 100644 index 0000000..74a359e --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c | |||
| @@ -0,0 +1,1221 @@ | |||
| 1 | #define GGML_COMMON_IMPL_C | ||
| 2 | #include "ggml-common.h" | ||
| 3 | #include "ggml-quants.h" | ||
| 4 | #include "ggml-impl.h" | ||
| 5 | #include "ggml-cpu.h" | ||
| 6 | #include "simd-mappings.h" | ||
| 7 | |||
| 8 | #include "../../quants.h" | ||
| 9 | #include "../../ggml-cpu-impl.h" | ||
| 10 | |||
| 11 | #include <math.h> | ||
| 12 | #include <string.h> | ||
| 13 | #include <assert.h> | ||
| 14 | #include <float.h> | ||
| 15 | #include <stdlib.h> // for qsort | ||
| 16 | #include <stdio.h> // for GGML_ASSERT | ||
| 17 | |||
| 18 | #define GROUP_MAX_EPS 1e-15f | ||
| 19 | #define GROUP_MAX_EPS_IQ3_XXS 1e-8f | ||
| 20 | #define GROUP_MAX_EPS_IQ2_S 1e-8f | ||
| 21 | #define GROUP_MAX_EPS_IQ1_M 1e-7f | ||
| 22 | #define GROUP_MAX_EPS_IQ1_S 1e-12f | ||
| 23 | |||
| 24 | #define UNUSED GGML_UNUSED | ||
| 25 | |||
| 26 | #if defined(__wasm_simd128__) | ||
| 27 | #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s | ||
| 28 | #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s) | ||
| 29 | #define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s) | ||
| 30 | #define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s) | ||
| 31 | #define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s) | ||
| 32 | #define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s) | ||
| 33 | #define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s) | ||
| 34 | #define B8(c,s ) B7(c,s, c), B7(c,s, s) | ||
| 35 | |||
| 36 | // precomputed tables for expanding 8bits to 8 bytes: | ||
| 37 | static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4 | ||
| 38 | static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4 | ||
| 39 | #endif | ||
| 40 | |||
// Quantize one row of floats into q8_0 blocks: per 32 values, one fp16 scale `d`
// plus 32 int8 quants such that value ~= d * q. k must be a multiple of QK8_0 (32).
void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(QK8_0 == 32);
    assert(k % QK8_0 == 0);
    const int nb = k / QK8_0;

    block_q8_0 * GGML_RESTRICT y = vy;

#if defined __wasm_simd128__
    for (int i = 0; i < nb; i++) {
        v128_t srcv [8];
        v128_t asrcv[8];
        v128_t amaxv[8];

        // load the 32 floats of the block and take their absolute values
        for (int j = 0; j < 8; j++) srcv[j]  = wasm_v128_load(x + i*32 + 4*j);
        for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);

        // pairwise tree reduction: 8 vectors -> 4 -> 2 -> 1 (result in amaxv[0])
        for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
        for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
        for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);

        // horizontal max across the 4 remaining lanes -> block-wide absolute max
        const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
                                   wasm_f32x4_extract_lane(amaxv[0], 1)),
                               MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
                                   wasm_f32x4_extract_lane(amaxv[0], 3)));

        // scale so the max magnitude maps to 127; id is the inverse scale
        // (0 for an all-zero block, which then quantizes every lane to 0)
        const float d = amax / ((1 << 7) - 1);
        const float id = d ? 1.0f/d : 0.0f;

        y[i].d = GGML_CPU_FP32_TO_FP16(d);

        for (int j = 0; j < 8; j++) {
            const v128_t v  = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
            // NOTE(review): trunc_sat rounds toward zero; the scalar reference path
            // (quantize_row_q8_0_ref) may round to nearest instead, so SIMD and scalar
            // outputs can differ by 1 in the last place — confirm this is intended.
            const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);

            y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
            y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
            y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
            y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
        }
    }
#else
    GGML_UNUSED(nb);
    // scalar
    quantize_row_q8_0_ref(x, y, k);
#endif
}
| 87 | |||
// Quantize one row of floats into q8_1 blocks: like q8_0 (fp16 scale + 32 int8 quants)
// but additionally stores s = d * sum(quants), which lets the q4_1/q5_1 dot products
// fold in the asymmetric offset term without re-summing the activations.
void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(k % QK8_1 == 0);
    const int nb = k / QK8_1;

    block_q8_1 * GGML_RESTRICT y = vy;
#if defined __wasm_simd128__
    for (int i = 0; i < nb; i++) {
        v128_t srcv [8];
        v128_t asrcv[8];
        v128_t amaxv[8];

        // load the 32 floats of the block and take their absolute values
        for (int j = 0; j < 8; j++) srcv[j]  = wasm_v128_load(x + i*32 + 4*j);
        for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);

        // pairwise tree reduction: 8 vectors -> 4 -> 2 -> 1 (result in amaxv[0])
        for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
        for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
        for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);

        // horizontal max across the 4 remaining lanes -> block-wide absolute max
        const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
                                   wasm_f32x4_extract_lane(amaxv[0], 1)),
                               MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
                                   wasm_f32x4_extract_lane(amaxv[0], 3)));

        // scale so the max magnitude maps to 127; id is the inverse scale
        const float d = amax / ((1 << 7) - 1);
        const float id = d ? 1.0f/d : 0.0f;

        y[i].d = GGML_CPU_FP32_TO_FP16(d);

        // running i32 sum of the quantized values, for the `s` field below
        v128_t accv = wasm_i32x4_splat(0);

        for (int j = 0; j < 8; j++) {
            const v128_t v  = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
            const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);

            y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
            y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
            y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
            y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);

            accv = wasm_i32x4_add(accv, vi);
        }

        // s = d * sum of all 32 quants (horizontal add of the 4 accumulator lanes)
        y[i].s = GGML_CPU_FP32_TO_FP16(
                d * (wasm_i32x4_extract_lane(accv, 0) +
                     wasm_i32x4_extract_lane(accv, 1) +
                     wasm_i32x4_extract_lane(accv, 2) +
                     wasm_i32x4_extract_lane(accv, 3)));
    }
#else
    GGML_UNUSED(nb);
    // scalar
    quantize_row_q8_1_ref(x, y, k);
#endif
}
| 142 | |||
| 143 | //===================================== Q8_K ============================================== | ||
| 144 | |||
// Quantize one row into q8_K super-blocks: QK_K int8 quants, a float scale `d`, and
// 16-element partial sums `bsums` (one per 16 quants) consumed by the q*_K dot kernels.
void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
#ifdef __wasm_simd128__
    assert(k % QK_K == 0);
    const int64_t nb = k / QK_K;
    block_q8_K * GGML_RESTRICT yc = y; // Cast to proper type

    for (int i = 0; i < nb; i++) {
        const float * x_block = x + i * QK_K;

        // vector min/max over the whole super-block
        v128_t min_vec = wasm_v128_load(x_block);
        v128_t max_vec = min_vec;

        for (int j = 4; j < QK_K; j += 4) {
            v128_t x_vec = wasm_v128_load(x_block + j);
            max_vec = wasm_f32x4_pmax(max_vec, x_vec);
            min_vec = wasm_f32x4_pmin(min_vec, x_vec);
        }
        // horizontal reductions via lane shuffles: lane 0 ends up holding the result
        max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 2, 3, 0, 1));
        max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 1, 0, 3, 2));
        min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 2, 3, 0, 1));
        min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 1, 0, 3, 2));
        float max = wasm_f32x4_extract_lane(max_vec, 0);
        float min = wasm_f32x4_extract_lane(min_vec, 0);
        // amax is the SIGNED value of largest magnitude (may be negative); together with
        // the negated iscale below this maps that extreme to -127/+127 consistently
        float amax = -min > max ? min : max;

        if (amax == 0.0f) {
            // all-zero block: zero scale and zero quants (bsums left as-is by this path)
            yc[i].d = 0.0f;
            const v128_t zero = wasm_i8x16_splat(0);
            for (int j = 0; j < QK_K; j += 16) {
                wasm_v128_store(yc[i].qs + j, zero);
            }
            continue;
        }

        const float iscale = -127.0f / amax;
        const v128_t scale_vec = wasm_f32x4_splat(iscale);

        // Process 16 elements per iteration
        for (int j = 0, jb = 0; j < QK_K; j += 16, jb++) {
            // Load and quantize 16 floats
            v128_t x0 = wasm_v128_load(x_block + j);
            v128_t x1 = wasm_v128_load(x_block + j + 4);
            v128_t x2 = wasm_v128_load(x_block + j + 8);
            v128_t x3 = wasm_v128_load(x_block + j + 12);

            // round to nearest before the saturating truncation
            v128_t q0 = wasm_f32x4_nearest(wasm_f32x4_mul(x0, scale_vec));
            v128_t q1 = wasm_f32x4_nearest(wasm_f32x4_mul(x1, scale_vec));
            v128_t q2 = wasm_f32x4_nearest(wasm_f32x4_mul(x2, scale_vec));
            v128_t q3 = wasm_f32x4_nearest(wasm_f32x4_mul(x3, scale_vec));

            // Convert to i32 with saturation
            v128_t i0 = wasm_i32x4_trunc_sat_f32x4(q0);
            v128_t i1 = wasm_i32x4_trunc_sat_f32x4(q1);
            v128_t i2 = wasm_i32x4_trunc_sat_f32x4(q2);
            v128_t i3 = wasm_i32x4_trunc_sat_f32x4(q3);

            // Pack into 16 i8 values (narrowing saturates to the int8 range)
            v128_t i8 = wasm_i8x16_narrow_i16x8(
                wasm_i16x8_narrow_i32x4(i0, i1),
                wasm_i16x8_narrow_i32x4(i2, i3)
            );
            wasm_v128_store(yc[i].qs + j, i8);

            // Calculate bsums using SIMD: widen i8 -> i16 -> i32, then horizontal add
            v128_t sum16 = wasm_i16x8_add(
                wasm_i16x8_extend_low_i8x16(i8),
                wasm_i16x8_extend_high_i8x16(i8)
            );
            v128_t sum32 = wasm_i32x4_add(
                wasm_i32x4_extend_low_i16x8(sum16),
                wasm_i32x4_extend_high_i16x8(sum16)
            );
            sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 2, 3, 0, 1));
            sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 1, 0, 3, 2));
            yc[i].bsums[jb] = wasm_i32x4_extract_lane(sum32, 0);
        }

        // stored scale is the inverse of the quantization scale
        yc[i].d = 1.0f / iscale;
    }
#else
    quantize_row_q8_K_ref(x, y, k);
#endif
}
| 228 | |||
| 229 | |||
| 230 | //===================================== Dot products ================================= | ||
| 231 | |||
// Dot product of a q4_0 row (4-bit weights stored two-per-byte, offset by 8, one fp16
// scale per 32 weights) with a q8_0 row. The SIMD path handles two blocks per
// iteration; the scalar loop after the #endif doubles as the remainder handler and,
// when WASM SIMD is unavailable, as the full fallback (ib stays 0 in that case).
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;

    assert(n % qk == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q4_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

    int ib = 0;
    float sumf = 0;

#if defined __wasm_simd128__
    v128_t sumv = wasm_f32x4_splat(0.0f);

    const v128_t m4b = wasm_i8x16_splat(0x0F);
    const v128_t s8b = wasm_i8x16_splat(0x8);

    for (; ib + 1 < nb; ib += 2) {
        const block_q4_0 * GGML_RESTRICT x0 = &x[ib];
        const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];
        const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
        const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

        // Load and process x0: split each byte into low/high nibbles, then subtract
        // the 8 offset to recover signed weights in [-8, 7]
        v128_t v0_0 = wasm_v128_load(x0->qs);
        v128_t v0_0l = wasm_v128_and(v0_0, m4b);
        v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
        v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
        v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);

        // Load y0 vectors
        v128_t y0_l = wasm_v128_load(y0->qs);
        v128_t y0_h = wasm_v128_load(y0->qs + 16);

        // Extend to i16x8 and compute dot products
        v128_t dx0l = wasm_i16x8_extend_low_i8x16(v0_0ls);
        v128_t dx0h = wasm_i16x8_extend_high_i8x16(v0_0ls);
        v128_t dx0hl = wasm_i16x8_extend_low_i8x16(v0_0hs);
        v128_t dx0hh = wasm_i16x8_extend_high_i8x16(v0_0hs);

        v128_t dy0ll = wasm_i16x8_extend_low_i8x16(y0_l);
        v128_t dy0lh = wasm_i16x8_extend_high_i8x16(y0_l);
        v128_t dy0hl = wasm_i16x8_extend_low_i8x16(y0_h);
        v128_t dy0hh = wasm_i16x8_extend_high_i8x16(y0_h);

        // low nibbles pair with y[0..15], high nibbles with y[16..31] (q4_0 layout)
        v128_t dp0 = wasm_i32x4_add(
            wasm_i32x4_add(
                wasm_i32x4_dot_i16x8(dx0l, dy0ll),
                wasm_i32x4_dot_i16x8(dx0h, dy0lh)
            ),
            wasm_i32x4_add(
                wasm_i32x4_dot_i16x8(dx0hl, dy0hl),
                wasm_i32x4_dot_i16x8(dx0hh, dy0hh)
            )
        );

        // Load and process x1 (same as x0, second block of the pair)
        v128_t v0_1 = wasm_v128_load(x1->qs);
        v128_t v0_1l = wasm_v128_and(v0_1, m4b);
        v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
        v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
        v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);

        // Load y1 vectors
        v128_t y1_l = wasm_v128_load(y1->qs);
        v128_t y1_h = wasm_v128_load(y1->qs + 16);

        // Extend to i16x8 and compute dot products
        v128_t dx1l = wasm_i16x8_extend_low_i8x16(v0_1ls);
        v128_t dx1h = wasm_i16x8_extend_high_i8x16(v0_1ls);
        v128_t dx1hl = wasm_i16x8_extend_low_i8x16(v0_1hs);
        v128_t dx1hh = wasm_i16x8_extend_high_i8x16(v0_1hs);

        v128_t dy1ll = wasm_i16x8_extend_low_i8x16(y1_l);
        v128_t dy1lh = wasm_i16x8_extend_high_i8x16(y1_l);
        v128_t dy1hl = wasm_i16x8_extend_low_i8x16(y1_h);
        v128_t dy1hh = wasm_i16x8_extend_high_i8x16(y1_h);

        v128_t dp1 = wasm_i32x4_add(
            wasm_i32x4_add(
                wasm_i32x4_dot_i16x8(dx1l, dy1ll),
                wasm_i32x4_dot_i16x8(dx1h, dy1lh)
            ),
            wasm_i32x4_add(
                wasm_i32x4_dot_i16x8(dx1hl, dy1hl),
                wasm_i32x4_dot_i16x8(dx1hh, dy1hh)
            )
        );

        // Accumulate results with scaling (product of the two fp16 block scales)
        float scale0 = GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d);
        float scale1 = GGML_CPU_FP16_TO_FP32(x1->d) * GGML_CPU_FP16_TO_FP32(y1->d);

        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp0), wasm_f32x4_splat(scale0)));
        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp1), wasm_f32x4_splat(scale1)));
    }

    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);

#endif
    // scalar remainder (or full computation when SIMD is not compiled in)
    for (; ib < nb; ++ib) {
        int sumi0 = 0;
        int sumi1 = 0;

        for (int j = 0; j < qk/2; ++j) {
            const int v0 = (x[ib].qs[j] & 0x0F) - 8;
            const int v1 = (x[ib].qs[j] >>   4) - 8;

            sumi0 += (v0 * y[ib].qs[j]);
            sumi1 += (v1 * y[ib].qs[j + qk/2]);
        }

        int sumi = sumi0 + sumi1;
        sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);
    }

    *s = sumf;
}
| 357 | |||
// Dot product of a q5_0 row with a q8_0 row. q5_0 stores the low 4 bits of each
// weight packed two-per-byte in qs, and the 5th (high) bit of all 32 weights in the
// 32-bit qh field. table_b2b_1 expands each qh byte into 8 bytes holding (!bit) << 4,
// so a single i8 subtract both injects the high bit and applies the -16 offset.
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;

    int ib = 0;
    float sumf = 0;

    assert(n % qk == 0);
    assert(qk == QK5_0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q5_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

#if defined __wasm_simd128__
    v128_t sumv = wasm_f32x4_splat(0.0f);

    uint32_t qh_;
    uint64_t tmp[4];

    // TODO: check if unrolling this is better
    for (; ib < nb; ++ib) {
        const block_q5_0 * GGML_RESTRICT x0 = &x[ib];
        const block_q8_0 * GGML_RESTRICT y0 = &y[ib];

        const v128_t m4b  = wasm_i8x16_splat(0x0F);

        // extract the 5th bit: expand each byte of qh to 8 table bytes
        memcpy(&qh_, x0->qh, sizeof(qh_));

        tmp[0] = table_b2b_1[(qh_ >>  0) & 0xFF];
        tmp[1] = table_b2b_1[(qh_ >>  8) & 0xFF];
        tmp[2] = table_b2b_1[(qh_ >> 16) & 0xFF];
        tmp[3] = table_b2b_1[(qh_ >> 24)       ];

        const v128_t qhl = wasm_v128_load(tmp + 0);
        const v128_t qhh = wasm_v128_load(tmp + 2);

        const v128_t v0 = wasm_v128_load(x0->qs);

        // 4-bit -> 8-bit
        const v128_t v0l = wasm_v128_and (v0, m4b);
        const v128_t v0h = wasm_u8x16_shr(v0, 4);

        // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
        const v128_t v0lf = wasm_i8x16_sub(v0l, qhl);
        const v128_t v0hf = wasm_i8x16_sub(v0h, qhh);

        // load y
        const v128_t v1l = wasm_v128_load(y0->qs);
        const v128_t v1h = wasm_v128_load(y0->qs + 16);

        // int8x16 -> int16x8
        const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
        const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
        const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
        const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);

        const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
        const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
        const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
        const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);

        // dot product: four i16x8 dots summed, converted to f32, scaled by d_x * d_y
        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
                        wasm_i32x4_add(
                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
                                           wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
                                           wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
                    wasm_f32x4_splat(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d))));
    }

    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);

    *s = sumf;
#else
    UNUSED(nb);
    UNUSED(ib);
    UNUSED(sumf);
    UNUSED(x);
    UNUSED(y);
    ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
| 448 | |||
// Dot product of a q5_1 row with a q8_1 row. Like q5_0 but the weights are unsigned
// with a per-block offset m: here table_b2b_0 expands each qh byte to bytes holding
// (bit) << 4, OR'd into the nibbles, and the offset contribution m * sum(y quants) is
// accumulated separately in `summs` using the precomputed y->s field.
void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_1;
    const int nb = n / qk;

    int ib = 0;
    float sumf = 0;

    assert(n % qk == 0);
    assert(qk == QK5_1);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q5_1 * GGML_RESTRICT x = vx;
    const block_q8_1 * GGML_RESTRICT y = vy;

#if defined __wasm_simd128__
    v128_t sumv = wasm_f32x4_splat(0.0f);

    float summs = 0.0f;

    uint32_t qh_;
    uint64_t tmp[4];

    // TODO: check if unrolling this is better
    for (; ib < nb; ++ib) {
        const block_q5_1 * GGML_RESTRICT x0 = &x[ib];
        const block_q8_1 * GGML_RESTRICT y0 = &y[ib];

        // offset term: m * sum(y quants) folded in via the q8_1 `s` field
        summs += GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s);

        const v128_t m4b = wasm_i8x16_splat(0x0F);

        // extract the 5th bit: expand each byte of qh to 8 table bytes
        memcpy(&qh_, x0->qh, sizeof(qh_));

        tmp[0] = table_b2b_0[(qh_ >>  0) & 0xFF];
        tmp[1] = table_b2b_0[(qh_ >>  8) & 0xFF];
        tmp[2] = table_b2b_0[(qh_ >> 16) & 0xFF];
        tmp[3] = table_b2b_0[(qh_ >> 24)       ];

        const v128_t qhl = wasm_v128_load(tmp + 0);
        const v128_t qhh = wasm_v128_load(tmp + 2);

        const v128_t v0 = wasm_v128_load(x0->qs);

        // 4-bit -> 8-bit
        const v128_t v0l = wasm_v128_and (v0, m4b);
        const v128_t v0h = wasm_u8x16_shr(v0, 4);

        // add high bit
        const v128_t v0lf = wasm_v128_or(v0l, qhl);
        const v128_t v0hf = wasm_v128_or(v0h, qhh);

        // load y
        const v128_t v1l = wasm_v128_load(y0->qs);
        const v128_t v1h = wasm_v128_load(y0->qs + 16);

        // int8x16 -> int16x8
        const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
        const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
        const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
        const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);

        const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
        const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
        const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
        const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);

        // dot product: four i16x8 dots summed, converted to f32, scaled by d_x * d_y
        sumv = wasm_f32x4_add(sumv,
                wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add(
                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
                                           wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
                                           wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
                    wasm_f32x4_splat(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d))));
    }

    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;

    *s = sumf;
#else
    UNUSED(nb);
    UNUSED(ib);
    UNUSED(sumf);
    UNUSED(x);
    UNUSED(y);
    ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
| 543 | |||
| 544 | void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { | ||
| 545 | const int qk = QK8_0; | ||
| 546 | const int nb = n / qk; | ||
| 547 | |||
| 548 | assert(n % qk == 0); | ||
| 549 | assert(nrc == 1); | ||
| 550 | UNUSED(nrc); | ||
| 551 | UNUSED(bx); | ||
| 552 | UNUSED(by); | ||
| 553 | UNUSED(bs); | ||
| 554 | |||
| 555 | const block_q8_0 * GGML_RESTRICT x = vx; | ||
| 556 | const block_q8_0 * GGML_RESTRICT y = vy; | ||
| 557 | |||
| 558 | int ib = 0; | ||
| 559 | float sumf = 0; | ||
| 560 | |||
| 561 | #if defined __wasm_simd128__ | ||
| 562 | v128_t sumv = wasm_f32x4_splat(0.0f); | ||
| 563 | |||
| 564 | for (; ib < nb; ++ib) { | ||
| 565 | const block_q8_0 * GGML_RESTRICT x0 = &x[ib]; | ||
| 566 | const block_q8_0 * GGML_RESTRICT y0 = &y[ib]; | ||
| 567 | |||
| 568 | const v128_t x0_0 = wasm_v128_load(x0->qs); | ||
| 569 | const v128_t x0_1 = wasm_v128_load(x0->qs + 16); | ||
| 570 | const v128_t y0_0 = wasm_v128_load(y0->qs); | ||
| 571 | const v128_t y0_1 = wasm_v128_load(y0->qs + 16); | ||
| 572 | |||
| 573 | // Extend 8-bit to 16-bit | ||
| 574 | const v128_t x0_0l = wasm_i16x8_extend_low_i8x16(x0_0); | ||
| 575 | const v128_t x0_0h = wasm_i16x8_extend_high_i8x16(x0_0); | ||
| 576 | const v128_t x0_1l = wasm_i16x8_extend_low_i8x16(x0_1); | ||
| 577 | const v128_t x0_1h = wasm_i16x8_extend_high_i8x16(x0_1); | ||
| 578 | |||
| 579 | const v128_t y0_0l = wasm_i16x8_extend_low_i8x16(y0_0); | ||
| 580 | const v128_t y0_0h = wasm_i16x8_extend_high_i8x16(y0_0); | ||
| 581 | const v128_t y0_1l = wasm_i16x8_extend_low_i8x16(y0_1); | ||
| 582 | const v128_t y0_1h = wasm_i16x8_extend_high_i8x16(y0_1); | ||
| 583 | |||
| 584 | // Compute dot products | ||
| 585 | const v128_t dx0_0 = wasm_i32x4_dot_i16x8(x0_0l, y0_0l); | ||
| 586 | const v128_t dx0_1 = wasm_i32x4_dot_i16x8(x0_0h, y0_0h); | ||
| 587 | const v128_t dx1_0 = wasm_i32x4_dot_i16x8(x0_1l, y0_1l); | ||
| 588 | const v128_t dx1_1 = wasm_i32x4_dot_i16x8(x0_1h, y0_1h); | ||
| 589 | |||
| 590 | // Sum all dot products | ||
| 591 | const v128_t sum_dots = wasm_i32x4_add(wasm_i32x4_add(dx0_0, dx0_1), wasm_i32x4_add(dx1_0, dx1_1)); | ||
| 592 | |||
| 593 | // Convert to float and accumulate | ||
| 594 | const float scale = GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d); | ||
| 595 | sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(sum_dots), wasm_f32x4_splat(scale))); | ||
| 596 | } | ||
| 597 | |||
| 598 | sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + | ||
| 599 | wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); | ||
| 600 | |||
| 601 | *s = sumf; | ||
| 602 | #else | ||
| 603 | UNUSED(nb); | ||
| 604 | UNUSED(x); | ||
| 605 | UNUSED(y); | ||
| 606 | UNUSED(ib); | ||
| 607 | UNUSED(sumf); | ||
| 608 | ggml_vec_dot_q8_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); | ||
| 609 | #endif | ||
| 610 | } | ||
| 611 | |||
// Dot product of a q2_K super-block row with a q8_K row. Each q2_K scale byte packs a
// 4-bit scale (low nibble) and a 4-bit min (high nibble); the result per super-block is
//   dall * sum(scale_j * dot_j) - dmin * sum(min_j * bsum_j)
// where bsum_j are the precomputed 16-element partial sums stored in the q8_K block.
void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q2_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined __wasm_simd128__
    float sumf = 0;

    for (int i = 0; i < nb; ++i) {
        const uint8_t * q2 = x[i].qs;
        const int8_t * q8 = y[i].qs;
        const uint8_t * sc = x[i].scales;

        // Vectorized summs calculation: dot of the 16 mins (high nibbles of the
        // scale bytes) with the 16 bsums of the q8_K block
        v128_t summs_vec = wasm_i32x4_splat(0);
        {
            v128_t sc_vec = wasm_v128_load(sc);
            v128_t sc_upper = wasm_u8x16_shr(sc_vec, 4);

            v128_t sc_low = wasm_u16x8_extend_low_u8x16(sc_upper);
            v128_t sc_high = wasm_u16x8_extend_high_u8x16(sc_upper);

            v128_t bsums1 = wasm_v128_load(&y[i].bsums[0]);
            v128_t bsums2 = wasm_v128_load(&y[i].bsums[8]);

            summs_vec = wasm_i32x4_add(
                wasm_i32x4_add(wasm_i32x4_dot_i16x8(sc_low, bsums1),
                               wasm_i32x4_dot_i16x8(sc_high, bsums2)),
                summs_vec
            );

            // horizontal sum of the 4 lanes via shuffles
            summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 2, 3, 0, 1));
            summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 1, 0, 3, 2));
        }
        int32_t summs = wasm_i32x4_extract_lane(summs_vec, 0);

        // Vectorized isum calculation: the 2-bit quants are stored 4 deep per byte,
        // unpacked by shifting 0/2/4/6 and masking 0x03; each 16-quant group is scaled
        // by the low nibble of its scale byte
        int32_t isum = 0;
        const uint8_t * sc_ptr = sc;
        const int k_iters = QK_K/128;

        for (int k = 0; k < k_iters; ++k) {
            v128_t isum_vec = wasm_i32x4_splat(0);
            int shift = 0;

            for (int j = 0; j < 4; ++j) {
                // low nibbles of two consecutive scale bytes: scales for the two
                // 16-quant halves processed in this iteration
                const int d0 = (sc_ptr[0] & 0xF);
                const int d1 = (sc_ptr[1] & 0xF);
                sc_ptr += 2;

                // Process first 16 elements
                v128_t q2_0 = wasm_v128_load(q2);
                v128_t q8_0 = wasm_v128_load(q8);
                v128_t q2_shift_0 = wasm_u8x16_shr(q2_0, shift);
                v128_t q2_bits_0 = wasm_v128_and(q2_shift_0, wasm_i8x16_splat(0x03));

                // Process next 16 elements
                v128_t q2_1 = wasm_v128_load(q2 + 16);
                v128_t q8_1 = wasm_v128_load(q8 + 16);
                v128_t q2_shift_1 = wasm_u8x16_shr(q2_1, shift);
                v128_t q2_bits_1 = wasm_v128_and(q2_shift_1, wasm_i8x16_splat(0x03));

                // Calculate dot products
                v128_t p0 = wasm_i32x4_dot_i16x8(
                    wasm_i16x8_extend_low_i8x16(q8_0),
                    wasm_i16x8_extend_low_i8x16(q2_bits_0)
                );
                v128_t p1 = wasm_i32x4_dot_i16x8(
                    wasm_i16x8_extend_high_i8x16(q8_0),
                    wasm_i16x8_extend_high_i8x16(q2_bits_0)
                );
                v128_t p2 = wasm_i32x4_dot_i16x8(
                    wasm_i16x8_extend_low_i8x16(q8_1),
                    wasm_i16x8_extend_low_i8x16(q2_bits_1)
                );
                v128_t p3 = wasm_i32x4_dot_i16x8(
                    wasm_i16x8_extend_high_i8x16(q8_1),
                    wasm_i16x8_extend_high_i8x16(q2_bits_1)
                );

                // Accumulate scaled results
                v128_t scaled = wasm_i32x4_add(
                    wasm_i32x4_mul(wasm_i32x4_add(p0, p1), wasm_i32x4_splat(d0)),
                    wasm_i32x4_mul(wasm_i32x4_add(p2, p3), wasm_i32x4_splat(d1))
                );

                isum_vec = wasm_i32x4_add(isum_vec, scaled);
                q8 += 32;
                shift += 2;
            }
            q2 += 32;

            // Horizontal sum of isum_vec
            isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 2, 3, 0, 1));
            isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 1, 0, 3, 2));
            isum += wasm_i32x4_extract_lane(isum_vec, 0);
        }

        // combine: dall scales the quant dot, dmin scales the min/bsum correction
        const float dall = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
        const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
        sumf += dall * isum - dmin * summs;
    }

    *s = sumf;

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
| 731 | |||
// Dot product of one row of q3_K super-blocks against q8_K blocks, WASM SIMD128 path.
//
// q3_K stores each weight as 3 bits: the 2 low bits packed in x[i].qs and the
// high bit in x[i].hmask, plus 16 packed 6-bit block scales in x[i].scales
// (layout defined in ggml-common.h). q8_K holds int8 quants with a per-block
// float scale y[i].d.
//
// Result: *s = sum over blocks of (x.d * y.d) * sum_j (scale_j - 32) * dot(q3_j, q8_j).
// Falls back to the generic scalar kernel when SIMD128 is unavailable.
void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);   // this kernel only handles a single row/column pair
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    // masks for decoding the 12-byte packed scale layout into 16 6-bit scales
    const uint32_t kmask1 = 0x03030303;
    const uint32_t kmask2 = 0x0f0f0f0f;

    const block_q3_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;  // number of super-blocks

#if defined __wasm_simd128__
    int8_t aux8[QK_K];     // one super-block of fully unpacked signed quants
    float sums[8] = {0};   // SIMD lane accumulators; only sums[0..3] are written, sums[4..7] stay 0
    uint32_t auxs[4];      // scratch for scale unpacking

    float sumf = 0;
    for (int i = 0; i < nb; ++i) {
        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
        const uint8_t * GGML_RESTRICT hm = x[i].hmask;
        const int8_t * GGML_RESTRICT q8 = y[i].qs;

        // Unpack the 3-bit quants into aux8 with SIMD.
        // For each group of 32 values: take 2 bits from q3 (at the current
        // shift), then subtract 4 unless the matching hmask bit (selected by
        // the walking bit `m`) is set — i.e. the high bit adds +4.
        int8_t * a = aux8;
        uint8_t m = 1;
        for (int j = 0; j < QK_K; j += 128) {
            for (int shift = 0; shift <= 6; shift += 2) {
                v128_t v_m = wasm_i8x16_splat(m);
                for (int l = 0; l < 32; l += 16) {
                    v128_t v_q3 = wasm_v128_load(q3 + l);
                    // arithmetic shift is safe here: the result is masked to 2 bits
                    v128_t v_shift = wasm_i8x16_shr(v_q3, shift);
                    v128_t v_low2 = wasm_v128_and(v_shift, wasm_i8x16_splat(0x03));

                    v128_t v_hm = wasm_v128_load(hm + l);
                    v128_t v_mask = wasm_v128_and(v_hm, v_m);
                    v_mask = wasm_i8x16_ne(v_mask, wasm_i8x16_splat(0));  // lane = all-ones where hmask bit set

                    // subtract 4 only in lanes where the hmask bit is clear
                    v_low2 = wasm_i8x16_sub(v_low2, wasm_v128_and(wasm_i8x16_splat(4), wasm_v128_not(v_mask)));
                    wasm_v128_store(a + l, v_low2);
                }
                a += 32;
                m <<= 1;  // next hmask bit plane
            }
            q3 += 32;
        }

        // Decode the 16 packed 6-bit scales from the 12-byte field into auxs.
        memcpy(auxs, x[i].scales, 12);
        uint32_t tmp = auxs[2];
        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
        const int8_t * scales = (const int8_t *)auxs;

        // SIMD dot product: one 6-bit scale per 16 quants, accumulated in i32 lanes.
        v128_t v_acc0 = wasm_i32x4_splat(0);
        v128_t v_acc1 = wasm_i32x4_splat(0);
        a = aux8;
        for (int j = 0; j < QK_K/16; ++j) {
            // scales are stored with a +32 bias; remove it here
            const v128_t v_scale = wasm_i16x8_splat(scales[j] - 32);

            // Process 16 elements per iteration (2 x 8-wide i16 multiplies).
            for (int k = 0; k < 2; ++k) {
                const v128_t v_q8 = wasm_i16x8_load8x8(q8);
                const v128_t v_a = wasm_i16x8_load8x8(a);

                // q8 in [-128,127], a in [-4,3], scale-32 in [-32,31] -> product fits i16
                v128_t v_prod = wasm_i16x8_mul(v_q8, v_a);
                v_prod = wasm_i16x8_mul(v_prod, v_scale);

                v_acc0 = wasm_i32x4_add(v_acc0, wasm_i32x4_extend_low_i16x8(v_prod));
                v_acc1 = wasm_i32x4_add(v_acc1, wasm_i32x4_extend_high_i16x8(v_prod));

                q8 += 8;
                a += 8;
            }
        }

        // Scale the integer accumulators by the combined fp scale for this super-block.
        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
        const v128_t v_d = wasm_f32x4_splat(d);
        v128_t v_sum = wasm_f32x4_add(
            wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc0), v_d),
            wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc1), v_d)
        );

        // Accumulate into sums vector
        wasm_v128_store(sums, wasm_f32x4_add(wasm_v128_load(sums), v_sum));
    }

    // Horizontal sum of the lane accumulators (sums[4..7] are zero).
    v128_t v_sum = wasm_f32x4_add(wasm_v128_load(sums), wasm_v128_load(sums + 4));
    sumf = wasm_f32x4_extract_lane(v_sum, 0) +
           wasm_f32x4_extract_lane(v_sum, 1) +
           wasm_f32x4_extract_lane(v_sum, 2) +
           wasm_f32x4_extract_lane(v_sum, 3);

    *s = sumf;

#else
    UNUSED(kmask1);
    UNUSED(kmask2);
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif

}
| 846 | |||
// Dot product of one row of q4_K super-blocks against q8_K blocks, WASM SIMD128 path.
//
// q4_K packs two 4-bit weights per byte in x[i].qs, with 8 pairs of 6-bit
// (scale, min) values packed into the 12-byte x[i].scales field (layout from
// ggml-common.h). The per-pair mins are applied via y[i].bsums (precomputed
// per-16 sums of the q8 quants), so the inner loop only needs the unsigned
// nibble products:
//   *s = sum_i [ d * sum_j scale_j * dot(nibbles_j, q8_j)  -  dmin * sum_j min_j * bsum_j ]
// Falls back to the generic scalar kernel when SIMD128 is unavailable.
void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);   // single row/column pair only
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q4_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;  // number of super-blocks

    // masks for decoding the 12-byte packed (scale, min) layout
    static const uint32_t kmask1 = 0x3f3f3f3f;
    static const uint32_t kmask2 = 0x0f0f0f0f;
    static const uint32_t kmask3 = 0x03030303;

    uint32_t utmp[4];  // decoded scales in utmp[0..1], mins in utmp[2..3]

#if defined __wasm_simd128__
    const uint8_t * scales = (const uint8_t*)&utmp[0];  // byte view of the 8 decoded scales
    float sumf = 0;

    for (int i = 0; i < nb; ++i) {
        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); // Corrected sign

        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
        const int8_t * GGML_RESTRICT q8 = y[i].qs;

        // Unpack the 6-bit scales/mins: after this, utmp[0..1] hold the 8
        // scales and utmp[2..3] the 8 mins, one byte each.
        memcpy(utmp, x[i].scales, 12);
        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
        const uint32_t uaux = utmp[1] & kmask1;
        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
        utmp[2] = uaux;
        utmp[0] &= kmask1;

        // Min correction: each min multiplies the precomputed sum of 32 q8
        // values (two adjacent bsums entries), subtracted below as dmin*sumi.
        int32_t sumi = 0;
        const int16_t * GGML_RESTRICT q8sums = y[i].bsums;
        const uint8_t * m = (const uint8_t *)&utmp[2];
        for (int j = 0; j < 16; j += 2) {
            sumi += (q8sums[j] + q8sums[j+1]) * m[j/2];
        }
        sumf -= dmin * sumi;

        int32_t sumi1 = 0;  // low-nibble contribution (scales[2*j])
        int32_t sumi2 = 0;  // high-nibble contribution (scales[2*j+1])

        for (int j = 0; j < QK_K/64; ++j) {
            // Load 64 4-bit weights (32 bytes)
            const v128_t q4x0 = wasm_v128_load(q4);
            const v128_t q4x1 = wasm_v128_load(q4 + 16);
            q4 += 32;

            // Split into low/high nibbles; low nibbles pair with q8x0/q8x1,
            // high nibbles with q8x2/q8x3 (the next 32 q8 values).
            const v128_t q4l0 = wasm_v128_and(q4x0, wasm_i8x16_splat(0x0F));
            const v128_t q4h0 = wasm_u8x16_shr(q4x0, 4);
            const v128_t q4l1 = wasm_v128_and(q4x1, wasm_i8x16_splat(0x0F));
            const v128_t q4h1 = wasm_u8x16_shr(q4x1, 4);

            // Load 64 8-bit values (64 bytes)
            const v128_t q8x0 = wasm_v128_load(q8);
            const v128_t q8x1 = wasm_v128_load(q8 + 16);
            const v128_t q8x2 = wasm_v128_load(q8 + 32);
            const v128_t q8x3 = wasm_v128_load(q8 + 48);
            q8 += 64;

            // Low nibble products: widen i8 -> i16, then pairwise dot into i32 lanes.
            v128_t vacc1 = wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_low_i8x16(q4l0),
                wasm_i16x8_extend_low_i8x16(q8x0)
            );
            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_high_i8x16(q4l0),
                wasm_i16x8_extend_high_i8x16(q8x0)
            ));
            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_low_i8x16(q4l1),
                wasm_i16x8_extend_low_i8x16(q8x1)
            ));
            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_high_i8x16(q4l1),
                wasm_i16x8_extend_high_i8x16(q8x1)
            ));

            // High nibble products
            v128_t vacc2 = wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_low_i8x16(q4h0),
                wasm_i16x8_extend_low_i8x16(q8x2)
            );
            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_high_i8x16(q4h0),
                wasm_i16x8_extend_high_i8x16(q8x2)
            ));
            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_low_i8x16(q4h1),
                wasm_i16x8_extend_low_i8x16(q8x3)
            ));
            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_high_i8x16(q4h1),
                wasm_i16x8_extend_high_i8x16(q8x3)
            ));

            // Horizontal-reduce each accumulator and apply the per-32 scale.
            int32_t vacc1_sum = wasm_i32x4_extract_lane(vacc1, 0) + wasm_i32x4_extract_lane(vacc1, 1) +
                                wasm_i32x4_extract_lane(vacc1, 2) + wasm_i32x4_extract_lane(vacc1, 3);
            sumi1 += vacc1_sum * scales[2*j];

            int32_t vacc2_sum = wasm_i32x4_extract_lane(vacc2, 0) + wasm_i32x4_extract_lane(vacc2, 1) +
                                wasm_i32x4_extract_lane(vacc2, 2) + wasm_i32x4_extract_lane(vacc2, 3);
            sumi2 += vacc2_sum * scales[2*j+1];
        }

        sumf += d * (sumi1 + sumi2);
    }

    *s = sumf;

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    UNUSED(kmask1);
    UNUSED(kmask2);
    UNUSED(kmask3);
    UNUSED(utmp);
    ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
| 978 | |||
// Dot product of one row of q5_K super-blocks against q8_K blocks, WASM SIMD128 path.
//
// q5_K stores each weight as 5 bits: 4 bits packed two-per-byte in x[i].qs and
// the 5th (high) bit in the 32-byte x[i].qh bit-plane array, with 8 pairs of
// 6-bit (scale, min) values packed into the 12-byte x[i].scales field (layout
// from ggml-common.h). As in the q4_K kernel, the mins are applied through
// y[i].bsums so the SIMD loop only processes the unsigned 5-bit quants:
//   *s = sum_i [ d * sum_j scale_j * dot(q5_j, q8_j)  -  dmin * sum_j min_j * bsum_j ]
// Falls back to the generic scalar kernel when SIMD128 is unavailable.
void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);   // single row/column pair only
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q5_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;  // number of super-blocks

    // masks for decoding the 12-byte packed (scale, min) layout
    static const uint32_t kmask1 = 0x3f3f3f3f;
    static const uint32_t kmask2 = 0x0f0f0f0f;
    static const uint32_t kmask3 = 0x03030303;

    uint32_t utmp[4];  // decoded scales in utmp[0..1], mins in utmp[2..3]

#if defined __wasm_simd128__
    //const uint8_t * scales = (const uint8_t*)&utmp[0];
    float sumf = 0;

    for (int i = 0; i < nb; ++i) {
        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); // Fixed sign

        const uint8_t * GGML_RESTRICT q5 = x[i].qs;
        const uint8_t * GGML_RESTRICT qh = x[i].qh;
        const int8_t * GGML_RESTRICT q8 = y[i].qs;

        // Unpack the 6-bit scales/mins: utmp[0..1] = 8 scales, utmp[2..3] = 8 mins.
        memcpy(utmp, x[i].scales, 12);
        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
        const uint32_t uaux = utmp[1] & kmask1;
        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
        utmp[2] = uaux;
        utmp[0] &= kmask1;

        // Min correction via the precomputed per-16 q8 sums.
        int32_t sumi_mins = 0;
        const int16_t * GGML_RESTRICT q8sums = y[i].bsums;
        const uint8_t * m = (const uint8_t *)&utmp[2];
        for (int j = 0; j < 16; j += 2) {
            sumi_mins += (q8sums[j] + q8sums[j+1]) * m[j/2];
        }
        sumf -= dmin * sumi_mins; // Correct subtraction

        // High-bit planes for the whole super-block; bit pair (2j, 2j+1)
        // belongs to the j-th group of 64 quants.
        v128_t qh0 = wasm_v128_load(qh);
        v128_t qh1 = wasm_v128_load(qh + 16);
        const uint8_t * sc = (const uint8_t *)utmp;  // byte view of the 8 decoded scales

        int32_t sumi = 0;

        for (int j = 0; j < QK_K/64; ++j) {
            // Select this group's two high-bit planes.
            const int shift = j * 2;
            v128_t qh_shift0 = wasm_u8x16_shr(qh0, shift);
            v128_t qh_shift1 = wasm_u8x16_shr(qh1, shift);

            // Move each high bit into bit 4 so it ORs on top of a 4-bit nibble:
            // bit0 -> low-nibble group, bit1 -> high-nibble group.
            v128_t qh_low0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x01)), 4);
            v128_t qh_high0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x02)), 3);
            v128_t qh_low1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x01)), 4);
            v128_t qh_high1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x02)), 3);

            v128_t q5_0 = wasm_v128_load(q5);
            v128_t q5_1 = wasm_v128_load(q5 + 16);
            q5 += 32;

            // Combine nibble + high bit into full unsigned 5-bit quants.
            v128_t q5l_0 = wasm_v128_or(wasm_v128_and(q5_0, wasm_i8x16_splat(0x0F)), qh_low0);
            v128_t q5h_0 = wasm_v128_or(wasm_u8x16_shr(q5_0, 4), qh_high0);
            v128_t q5l_1 = wasm_v128_or(wasm_v128_and(q5_1, wasm_i8x16_splat(0x0F)), qh_low1);
            v128_t q5h_1 = wasm_v128_or(wasm_u8x16_shr(q5_1, 4), qh_high1);

            v128_t q8_0 = wasm_v128_load(q8);
            v128_t q8_1 = wasm_v128_load(q8 + 16);
            v128_t q8_2 = wasm_v128_load(q8 + 32);
            v128_t q8_3 = wasm_v128_load(q8 + 48);
            q8 += 64;

            // Process low quants (first 32 q8 values): widen to i16, pairwise dot into i32.
            v128_t pl0 = wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_low_i8x16(q5l_0),
                wasm_i16x8_extend_low_i8x16(q8_0)
            );
            pl0 = wasm_i32x4_add(pl0, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_high_i8x16(q5l_0),
                wasm_i16x8_extend_high_i8x16(q8_0)
            ));
            v128_t pl1 = wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_low_i8x16(q5l_1),
                wasm_i16x8_extend_low_i8x16(q8_1)
            );
            pl1 = wasm_i32x4_add(pl1, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_high_i8x16(q5l_1),
                wasm_i16x8_extend_high_i8x16(q8_1)
            ));
            v128_t sum_low = wasm_i32x4_add(pl0, pl1);

            // Process high quants (next 32 q8 values).
            v128_t ph0 = wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_low_i8x16(q5h_0),
                wasm_i16x8_extend_low_i8x16(q8_2)
            );
            ph0 = wasm_i32x4_add(ph0, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_high_i8x16(q5h_0),
                wasm_i16x8_extend_high_i8x16(q8_2)
            ));
            v128_t ph1 = wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_low_i8x16(q5h_1),
                wasm_i16x8_extend_low_i8x16(q8_3)
            );
            ph1 = wasm_i32x4_add(ph1, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_high_i8x16(q5h_1),
                wasm_i16x8_extend_high_i8x16(q8_3)
            ));
            v128_t sum_high = wasm_i32x4_add(ph0, ph1);

            // Horizontal-reduce and apply the per-32 scales.
            int32_t sl = wasm_i32x4_extract_lane(sum_low, 0) + wasm_i32x4_extract_lane(sum_low, 1) +
                         wasm_i32x4_extract_lane(sum_low, 2) + wasm_i32x4_extract_lane(sum_low, 3);
            int32_t sh = wasm_i32x4_extract_lane(sum_high, 0) + wasm_i32x4_extract_lane(sum_high, 1) +
                         wasm_i32x4_extract_lane(sum_high, 2) + wasm_i32x4_extract_lane(sum_high, 3);

            sumi += sl * sc[2*j] + sh * sc[2*j+1];
        }

        sumf += d * sumi;
    }

    *s = sumf;

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    UNUSED(kmask1);
    UNUSED(kmask2);
    UNUSED(kmask3);
    UNUSED(utmp);
    ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
| 1121 | |||
// Dot product of one row of q6_K super-blocks against q8_K blocks, WASM SIMD128 path.
//
// q6_K stores each weight as 6 bits: 4 low bits in x[i].ql and 2 high bits in
// x[i].qh, recentered by -32 to a signed value, with one int8 scale per 16
// weights in x[i].scales (layout from ggml-common.h):
//   *s = sum_i (x.d * y.d) * sum_j scale_j * dot(q6_j, q8_j)
// The unpacking stage is scalar; the per-16 dot products use SIMD.
// Falls back to the generic scalar kernel when SIMD128 is unavailable.
void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);   // single row/column pair only
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q6_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;  // number of super-blocks

#if defined __wasm_simd128__
    int8_t aux8[QK_K] __attribute__((aligned(16)));      // unpacked signed 6-bit quants
    int32_t aux32[8] __attribute__((aligned(16))) = {0}; // spill buffer for the i32 lane accumulators
    float sums[8] __attribute__((aligned(16))) = {0};    // per-lane float accumulators

    for (int i = 0; i < nb; ++i) {
        // Unpack 6-bit quantized data into aux8 (scalar): per 128 weights,
        // four groups of 32 combine a nibble of ql with a 2-bit field of qh,
        // then subtract the 32 bias.
        const uint8_t * GGML_RESTRICT q4 = x[i].ql;
        const uint8_t * GGML_RESTRICT qh = x[i].qh;
        int8_t * a = aux8;
        for (int j = 0; j < QK_K; j += 128) {
            for (int l = 0; l < 32; ++l) {
                a[l +  0] = (int8_t)((q4[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
                a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
                a[l + 64] = (int8_t)((q4[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
                a[l + 96] = (int8_t)((q4[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
            }
            a  += 128;
            q4 += 64;
            qh += 32;
        }

        const int8_t * GGML_RESTRICT a_ptr = aux8;
        const int8_t * GGML_RESTRICT q8 = y[i].qs;
        v128_t acc0 = wasm_i32x4_splat(0);
        v128_t acc1 = wasm_i32x4_splat(0);

        // One signed scale per 16 weights.
        for (int j = 0; j < QK_K/16; ++j) {
            const int scale = x[i].scales[j];
            const v128_t vscale = wasm_i32x4_splat(scale);

            // Load 16 elements from a and q8
            const v128_t a_vec = wasm_v128_load(a_ptr);
            const v128_t q8_vec = wasm_v128_load(q8);

            // Process low 8 elements: widen i8 -> i16, multiply, widen products to i32.
            v128_t a_low = wasm_i16x8_extend_low_i8x16(a_vec);
            v128_t q8_low = wasm_i16x8_extend_low_i8x16(q8_vec);
            v128_t prod_low = wasm_i16x8_mul(a_low, q8_low);
            v128_t prod_lo_lo = wasm_i32x4_extend_low_i16x8(prod_low);
            v128_t prod_lo_hi = wasm_i32x4_extend_high_i16x8(prod_low);

            // Process high 8 elements
            v128_t a_high = wasm_i16x8_extend_high_i8x16(a_vec);
            v128_t q8_high = wasm_i16x8_extend_high_i8x16(q8_vec);
            v128_t prod_high = wasm_i16x8_mul(a_high, q8_high);
            v128_t prod_hi_lo = wasm_i32x4_extend_low_i16x8(prod_high);
            v128_t prod_hi_hi = wasm_i32x4_extend_high_i16x8(prod_high);

            // Scale in i32 (product*scale could overflow i16) and accumulate.
            prod_lo_lo = wasm_i32x4_mul(prod_lo_lo, vscale);
            prod_lo_hi = wasm_i32x4_mul(prod_lo_hi, vscale);
            prod_hi_lo = wasm_i32x4_mul(prod_hi_lo, vscale);
            prod_hi_hi = wasm_i32x4_mul(prod_hi_hi, vscale);

            acc0 = wasm_i32x4_add(acc0, wasm_i32x4_add(prod_lo_lo, prod_hi_lo));
            acc1 = wasm_i32x4_add(acc1, wasm_i32x4_add(prod_lo_hi, prod_hi_hi));

            a_ptr += 16;
            q8 += 16;
        }

        // Spill the integer accumulators and fold into the float sums with
        // this super-block's combined scale.
        wasm_v128_store(&aux32[0], acc0);
        wasm_v128_store(&aux32[4], acc1);

        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
        for (int l = 0; l < 8; ++l) {
            sums[l] += d * aux32[l];
        }
    }

    // Final horizontal reduction of the 8 lane sums.
    float sumf = 0;
    for (int l = 0; l < 8; ++l) {
        sumf += sums[l];
    }
    *s = sumf;

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
| 1221 | |||
