1#define GGML_COMMON_IMPL_C
   2#include "ggml-common.h"
   3
   4#include "ggml-quants.h"
   5#include "ggml-impl.h"
   6#include "ggml-cpu/ggml-cpu-impl.h"
   7#include "ggml-cpu.h"
   8
   9#include <math.h>
  10#include <string.h>
  11#include <assert.h>
  12#include <float.h>
  13#include <stdlib.h> // for qsort
  14#include <stdio.h>  // for GGML_ASSERT
  15
// "All zero" threshold for a quantization group. When a group's maximum
// (absolute) value is below it, the group is emitted as all-zero codes with
// a zero scale instead of deriving a scale from near-zero data.
#define GROUP_MAX_EPS 1e-15f
// Type-specific variants of the threshold above; presumably consumed by the
// matching IQ3_XXS / IQ2_S / IQ1_M / IQ1_S quantizers (not visible here).
#define GROUP_MAX_EPS_IQ3_XXS 1e-8f
#define GROUP_MAX_EPS_IQ2_S 1e-8f
#define GROUP_MAX_EPS_IQ1_M 1e-7f
#define GROUP_MAX_EPS_IQ1_S 1e-12f

// Short local alias for GGML_UNUSED.
#define UNUSED GGML_UNUSED
  23
  24static inline int best_index_int8(int n, const int8_t * val, float x) {
  25    if (x <= val[0]) return 0;
  26    if (x >= val[n-1]) return n-1;
  27    int ml = 0, mu = n-1;
  28    while (mu-ml > 1) {
  29        int mav = (ml+mu)/2;
  30        if (x < val[mav]) mu = mav; else ml = mav;
  31    }
  32    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
  33}
  34
  35// reference implementation for deterministic creation of model files
  36void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k) {
  37    static const int qk = QK4_0;
  38
  39    assert(k % qk == 0);
  40
  41    const int nb = k / qk;
  42
  43    for (int i = 0; i < nb; i++) {
  44        float amax = 0.0f; // absolute max
  45        float max  = 0.0f;
  46
  47        for (int j = 0; j < qk; j++) {
  48            const float v = x[i*qk + j];
  49            if (amax < fabsf(v)) {
  50                amax = fabsf(v);
  51                max  = v;
  52            }
  53        }
  54
  55        const float d  = max / -8;
  56        const float id = d ? 1.0f/d : 0.0f;
  57
  58        y[i].d = GGML_FP32_TO_FP16(d);
  59
  60        for (int j = 0; j < qk/2; ++j) {
  61            const float x0 = x[i*qk + 0    + j]*id;
  62            const float x1 = x[i*qk + qk/2 + j]*id;
  63
  64            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
  65            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
  66
  67            y[i].qs[j]  = xi0;
  68            y[i].qs[j] |= xi1 << 4;
  69        }
  70    }
  71}
  72
  73void quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k) {
  74    const int qk = QK4_1;
  75
  76    assert(k % qk == 0);
  77
  78    const int nb = k / qk;
  79
  80    for (int i = 0; i < nb; i++) {
  81        float min = FLT_MAX;
  82        float max = -FLT_MAX;
  83
  84        for (int j = 0; j < qk; j++) {
  85            const float v = x[i*qk + j];
  86
  87            if (v < min) min = v;
  88            if (v > max) max = v;
  89        }
  90
  91        const float d  = (max - min) / ((1 << 4) - 1);
  92        const float id = d ? 1.0f/d : 0.0f;
  93
  94        y[i].d = GGML_FP32_TO_FP16(d);
  95        y[i].m = GGML_FP32_TO_FP16(min);
  96
  97        for (int j = 0; j < qk/2; ++j) {
  98            const float x0 = (x[i*qk + 0    + j] - min)*id;
  99            const float x1 = (x[i*qk + qk/2 + j] - min)*id;
 100
 101            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
 102            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
 103
 104            y[i].qs[j]  = xi0;
 105            y[i].qs[j] |= xi1 << 4;
 106        }
 107    }
 108}
 109
 110void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k) {
 111    static const int qk = QK5_0;
 112
 113    assert(k % qk == 0);
 114
 115    const int nb = k / qk;
 116
 117    for (int i = 0; i < nb; i++) {
 118        float amax = 0.0f; // absolute max
 119        float max  = 0.0f;
 120
 121        for (int j = 0; j < qk; j++) {
 122            const float v = x[i*qk + j];
 123            if (amax < fabsf(v)) {
 124                amax = fabsf(v);
 125                max  = v;
 126            }
 127        }
 128
 129        const float d  = max / -16;
 130        const float id = d ? 1.0f/d : 0.0f;
 131
 132        y[i].d = GGML_FP32_TO_FP16(d);
 133
 134        uint32_t qh = 0;
 135
 136        for (int j = 0; j < qk/2; ++j) {
 137            const float x0 = x[i*qk + 0    + j]*id;
 138            const float x1 = x[i*qk + qk/2 + j]*id;
 139
 140            const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
 141            const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
 142
 143            y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
 144
 145            // get the 5-th bit and store it in qh at the right position
 146            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
 147            qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
 148        }
 149
 150        memcpy(&y[i].qh, &qh, sizeof(qh));
 151    }
 152}
 153
 154void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k) {
 155    const int qk = QK5_1;
 156
 157    assert(k % qk == 0);
 158
 159    const int nb = k / qk;
 160
 161    for (int i = 0; i < nb; i++) {
 162        float min = FLT_MAX;
 163        float max = -FLT_MAX;
 164
 165        for (int j = 0; j < qk; j++) {
 166            const float v = x[i*qk + j];
 167
 168            if (v < min) min = v;
 169            if (v > max) max = v;
 170        }
 171
 172        const float d  = (max - min) / ((1 << 5) - 1);
 173        const float id = d ? 1.0f/d : 0.0f;
 174
 175        y[i].d = GGML_FP32_TO_FP16(d);
 176        y[i].m = GGML_FP32_TO_FP16(min);
 177
 178        uint32_t qh = 0;
 179
 180        for (int j = 0; j < qk/2; ++j) {
 181            const float x0 = (x[i*qk + 0    + j] - min)*id;
 182            const float x1 = (x[i*qk + qk/2 + j] - min)*id;
 183
 184            const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
 185            const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
 186
 187            y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
 188
 189            // get the 5-th bit and store it in qh at the right position
 190            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
 191            qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
 192        }
 193
 194        memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
 195    }
 196}
 197
 198// reference implementation for deterministic creation of model files
 199void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k) {
 200    assert(k % QK8_0 == 0);
 201    const int nb = k / QK8_0;
 202
 203    for (int i = 0; i < nb; i++) {
 204        float amax = 0.0f; // absolute max
 205
 206        for (int j = 0; j < QK8_0; j++) {
 207            const float v = x[i*QK8_0 + j];
 208            amax = MAX(amax, fabsf(v));
 209        }
 210
 211        const float d = amax / ((1 << 7) - 1);
 212        const float id = d ? 1.0f/d : 0.0f;
 213
 214        y[i].d = GGML_FP32_TO_FP16(d);
 215
 216        for (int j = 0; j < QK8_0; ++j) {
 217            const float x0 = x[i*QK8_0 + j]*id;
 218
 219            y[i].qs[j] = roundf(x0);
 220        }
 221    }
 222}
 223
 224// reference implementation for deterministic creation of model files
 225void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k) {
 226    assert(QK8_1 == 32);
 227    assert(k % QK8_1 == 0);
 228    const int nb = k / QK8_1;
 229
 230    for (int i = 0; i < nb; i++) {
 231        float amax = 0.0f; // absolute max
 232
 233        for (int j = 0; j < QK8_1; j++) {
 234            const float v = x[i*QK8_1 + j];
 235            amax = MAX(amax, fabsf(v));
 236        }
 237
 238        const float d = amax / ((1 << 7) - 1);
 239        const float id = d ? 1.0f/d : 0.0f;
 240
 241        y[i].d = GGML_FP32_TO_FP16(d);
 242
 243        int sum = 0;
 244
 245        for (int j = 0; j < QK8_1/2; ++j) {
 246            const float v0 = x[i*QK8_1           + j]*id;
 247            const float v1 = x[i*QK8_1 + QK8_1/2 + j]*id;
 248
 249            y[i].qs[          j] = roundf(v0);
 250            y[i].qs[QK8_1/2 + j] = roundf(v1);
 251
 252            sum += y[i].qs[          j];
 253            sum += y[i].qs[QK8_1/2 + j];
 254        }
 255
 256        y[i].s = GGML_FP32_TO_FP16(sum*d);
 257    }
 258}
 259
 260static inline int best_index_mxfp4(float x, float e) {
 261    int best_index = 0;
 262    float best_err = fabsf(kvalues_mxfp4[0]*e - x);
 263    for (int i = 1; i < 16; i++) {
 264        float err = fabsf(kvalues_mxfp4[i]*e - x);
 265        if (err < best_err) {
 266            best_index = i;
 267            best_err = err;
 268        }
 269    }
 270    return best_index;
 271}
 272
 273void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) {
 274    static const int qk = QK_MXFP4;
 275
 276    assert(k % qk == 0);
 277
 278    const int nb = k / qk;
 279
 280    for (int i = 0; i < nb; i++) {
 281        float amax = 0.0f; // absolute max
 282
 283        for (int j = 0; j < qk; j++) {
 284            const float v = x[i*qk + j];
 285
 286            if (amax < fabsf(v)) {
 287                amax = fabsf(v);
 288            }
 289        }
 290
 291        const uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 2 + 127) : 0;
 292
 293        const float d = GGML_E8M0_TO_FP32_HALF(e);
 294
 295        y[i].e = e;
 296
 297        for (int j = 0; j < qk/2; ++j) {
 298            const uint8_t x0 = best_index_mxfp4(x[i*qk + 0    + j], d);
 299            const uint8_t x1 = best_index_mxfp4(x[i*qk + qk/2 + j], d);
 300
 301            y[i].qs[j]  = x0;
 302            y[i].qs[j] |= x1 << 4;
 303        }
 304    }
 305}
 306
 307void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
 308    static const int qk = QK4_0;
 309
 310    assert(k % qk == 0);
 311
 312    const int nb = k / qk;
 313
 314    for (int i = 0; i < nb; i++) {
 315        const float d = GGML_FP16_TO_FP32(x[i].d);
 316
 317        for (int j = 0; j < qk/2; ++j) {
 318            const int x0 = (x[i].qs[j] & 0x0F) - 8;
 319            const int x1 = (x[i].qs[j] >>   4) - 8;
 320
 321            y[i*qk + j + 0   ] = x0*d;
 322            y[i*qk + j + qk/2] = x1*d;
 323        }
 324    }
 325}
 326
 327void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
 328    static const int qk = QK4_1;
 329
 330    assert(k % qk == 0);
 331
 332    const int nb = k / qk;
 333
 334    for (int i = 0; i < nb; i++) {
 335        const float d = GGML_FP16_TO_FP32(x[i].d);
 336        const float m = GGML_FP16_TO_FP32(x[i].m);
 337
 338        for (int j = 0; j < qk/2; ++j) {
 339            const int x0 = (x[i].qs[j] & 0x0F);
 340            const int x1 = (x[i].qs[j] >>   4);
 341
 342            y[i*qk + j + 0   ] = x0*d + m;
 343            y[i*qk + j + qk/2] = x1*d + m;
 344        }
 345    }
 346}
 347
 348void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
 349    static const int qk = QK5_0;
 350
 351    assert(k % qk == 0);
 352
 353    const int nb = k / qk;
 354
 355    for (int i = 0; i < nb; i++) {
 356        const float d = GGML_FP16_TO_FP32(x[i].d);
 357
 358        uint32_t qh;
 359        memcpy(&qh, x[i].qh, sizeof(qh));
 360
 361        for (int j = 0; j < qk/2; ++j) {
 362            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
 363            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
 364
 365            const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
 366            const int32_t x1 = ((x[i].qs[j] >>   4) | xh_1) - 16;
 367
 368            y[i*qk + j + 0   ] = x0*d;
 369            y[i*qk + j + qk/2] = x1*d;
 370        }
 371    }
 372}
 373
 374void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
 375    static const int qk = QK5_1;
 376
 377    assert(k % qk == 0);
 378
 379    const int nb = k / qk;
 380
 381    for (int i = 0; i < nb; i++) {
 382        const float d = GGML_FP16_TO_FP32(x[i].d);
 383        const float m = GGML_FP16_TO_FP32(x[i].m);
 384
 385        uint32_t qh;
 386        memcpy(&qh, x[i].qh, sizeof(qh));
 387
 388        for (int j = 0; j < qk/2; ++j) {
 389            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
 390            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
 391
 392            const int x0 = (x[i].qs[j] & 0x0F) | xh_0;
 393            const int x1 = (x[i].qs[j] >>   4) | xh_1;
 394
 395            y[i*qk + j + 0   ] = x0*d + m;
 396            y[i*qk + j + qk/2] = x1*d + m;
 397        }
 398    }
 399}
 400
 401void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
 402    static const int qk = QK8_0;
 403
 404    assert(k % qk == 0);
 405
 406    const int nb = k / qk;
 407
 408    for (int i = 0; i < nb; i++) {
 409        const float d = GGML_FP16_TO_FP32(x[i].d);
 410
 411        for (int j = 0; j < qk; ++j) {
 412            y[i*qk + j] = x[i].qs[j]*d;
 413        }
 414    }
 415}
 416
 417void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
 418    static const int qk = QK_MXFP4;
 419
 420    assert(k % qk == 0);
 421
 422    const int nb = k / qk;
 423
 424    for (int i = 0; i < nb; i++) {
 425        const float d = GGML_E8M0_TO_FP32_HALF(x[i].e);
 426
 427        for (int j = 0; j < qk/2; ++j) {
 428            const int8_t x0 = kvalues_mxfp4[x[i].qs[j] & 0x0F];
 429            const int8_t x1 = kvalues_mxfp4[x[i].qs[j] >>   4];
 430
 431            y[i*qk + j + 0   ] = x0*d;
 432            y[i*qk + j + qk/2] = x1*d;
 433        }
 434    }
 435}
 436
 437//
 438// 2-6 bit quantization in super-blocks
 439//
 440
 441//
 442// ===================== Helper functions
 443//
 444static inline int nearest_int(float fval) {
 445    assert(fabsf(fval) <= 4194303.f);
 446    float val = fval + 12582912.f;
 447    int i; memcpy(&i, &val, sizeof(int));
 448    return (i & 0x007fffff) - 0x00400000;
 449}
 450
// Symmetric quantization of n floats to signed levels l in [-nmax, nmax-1],
// stored in L with an offset of +nmax (so L[i] = l + nmax).
// The intended reconstruction is x[i] ~= scale * (L[i] - nmax), where scale
// is the return value. The value of largest magnitude (max) maps to
// l = -nmax, via iscale = -nmax / max.
// If every |x[i]| < GROUP_MAX_EPS, L is zeroed and 0 is returned.
//
// Weights for the least-squares fits:
//  - qw, when non-NULL, supplies the per-element weights;
//  - otherwise |rmse_type| selects them: 1: x^2, 2: 1, 3: |x|, else sqrt|x|.
//
// Modes:
//  - rmse_type == 0: plain rounding; returns 1/iscale.
//  - rmse_type  < 0: rounds once, fits the weighted scale sumlx/suml2 and
//    returns 0.5*(fit + 1/iscale), or 1/iscale if suml2 is not positive.
//  - rmse_type  > 0: additionally tries iscale = -(nmax + 0.1*is)/max for
//    is in [-9, 9] \ {0}. It keeps the candidate with the largest
//    sumlx^2/suml2, which minimizes the weighted squared error at its optimal
//    scale, and returns that candidate's fitted scale.
static float make_qx_quants(int n, int nmax, const float * GGML_RESTRICT x, int8_t * GGML_RESTRICT L, int rmse_type,
        const float * GGML_RESTRICT qw) {
    float max = 0;
    float amax = 0;
    for (int i = 0; i < n; ++i) {
        float ax = fabsf(x[i]);
        if (ax > amax) { amax = ax; max = x[i]; }
    }
    if (amax < GROUP_MAX_EPS) { // all zero
        for (int i = 0; i < n; ++i) {
            L[i] = 0;
        }
        return 0.f;
    }
    float iscale = -nmax / max;
    if (rmse_type == 0) {
        for (int i = 0; i < n; ++i) {
            int l = nearest_int(iscale * x[i]);
            L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
        }
        return 1/iscale;
    }
    // Negative rmse_type: same weighting as its absolute value, but skip the search.
    bool return_early = false;
    if (rmse_type < 0) {
        rmse_type = -rmse_type;
        return_early = true;
    }
    // Weighted sums for the least-squares scale: scale = sum(w*x*l) / sum(w*l*l).
    float sumlx = 0;
    float suml2 = 0;
#ifdef HAVE_BUGGY_APPLE_LINKER
    // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
    for (volatile int i = 0; i < n; ++i) {
#else
    for (int i = 0; i < n; ++i) {
#endif
        int l = nearest_int(iscale * x[i]);
        l = MAX(-nmax, MIN(nmax-1, l));
        L[i] = l + nmax;
        float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
        sumlx += w*x[i]*l;
        suml2 += w*l*l;
    }
    float scale = suml2 ? sumlx/suml2 : 0.0f;
    if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;
    // best = sumlx^2/suml2; a larger value means a smaller weighted error.
    float best = scale * sumlx;
    for (int is = -9; is <= 9; ++is) {
        if (is == 0) {
            continue;
        }
        iscale = -(nmax + 0.1f*is) / max;
        sumlx = suml2 = 0;
        for (int i = 0; i < n; ++i) {
            int l = nearest_int(iscale * x[i]);
            l = MAX(-nmax, MIN(nmax-1, l));
            float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
            sumlx += w*x[i]*l;
            suml2 += w*l*l;
        }
        // Compare sumlx^2/suml2 against best without dividing.
        if (suml2 > 0 && sumlx*sumlx > best*suml2) {
            for (int i = 0; i < n; ++i) {
                int l = nearest_int(iscale * x[i]);
                L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
            }
            scale = sumlx/suml2; best = scale*sumlx;
        }
    }
    return scale;
}
 519
// Symmetric quantization of n floats to signed levels in [-nmax, nmax-1],
// stored in L with an offset of +nmax.
// The intended reconstruction is x[i] ~= scale * (L[i] - nmax), where scale
// is the return value. If every |x[i]| < GROUP_MAX_EPS, L is zeroed and 0 is
// returned.
//  - do_rmse == false: plain rounding with iscale = -nmax / max (max = the
//    signed value of largest magnitude); returns 1/iscale.
//  - do_rmse == true: starts from that rounding, then runs up to 5
//    coordinate-descent passes with weights x[i]^2. Each pass proposes a new
//    level per element and accepts it only if sumlx^2/suml2 increases (i.e.
//    the weighted squared error at the optimal scale drops). Returns the
//    least-squares scale sumlx/suml2 (0 if suml2 is not positive).
static float make_q3_quants(int n, int nmax, const float * GGML_RESTRICT x, int8_t * GGML_RESTRICT L, bool do_rmse) {
    float max = 0;
    float amax = 0;
    for (int i = 0; i < n; ++i) {
        float ax = fabsf(x[i]);
        if (ax > amax) { amax = ax; max = x[i]; }
    }
    if (amax < GROUP_MAX_EPS) { // all zero
        for (int i = 0; i < n; ++i) { L[i] = 0; }
        return 0.f;
    }
    float iscale = -nmax / max;
    if (do_rmse) {
        // During the refinement L holds signed levels; the +nmax offset is applied at the end.
        float sumlx = 0;
        float suml2 = 0;
        for (int i = 0; i < n; ++i) {
            int l = nearest_int(iscale * x[i]);
            l = MAX(-nmax, MIN(nmax-1, l));
            L[i] = l;
            float w = x[i]*x[i];
            sumlx += w*x[i]*l;
            suml2 += w*l*l;
        }
        for (int itry = 0; itry < 5; ++itry) {
            int n_changed = 0;
            for (int i = 0; i < n; ++i) {
                float w = x[i]*x[i];
                // Sums with element i removed; the best level for i given the others' scale.
                float slx = sumlx - w*x[i]*L[i];
                if (slx > 0) {
                    float sl2 = suml2 - w*L[i]*L[i];
                    int new_l = nearest_int(x[i] * sl2 / slx);
                    new_l = MAX(-nmax, MIN(nmax-1, new_l));
                    if (new_l != L[i]) {
                        slx += w*x[i]*new_l;
                        sl2 += w*new_l*new_l;
                        // Accept if slx^2/sl2 > sumlx^2/suml2 (cross-multiplied).
                        if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) {
                            L[i] = new_l; sumlx = slx; suml2 = sl2;
                            ++n_changed;
                        }
                    }
                }
            }
            if (!n_changed) {
                break;
            }
        }
        for (int i = 0; i < n; ++i) {
            L[i] += nmax;
        }
        return suml2 > 0.0f ? sumlx / suml2 : 0.0f;
    }
    for (int i = 0; i < n; ++i) {
        int l = nearest_int(iscale * x[i]);
        l = MAX(-nmax, MIN(nmax-1, l));
        L[i] = l + nmax;
    }
    return 1/iscale;
}
 578
// Asymmetric quantization of n floats to levels L[i] in [0, nmax] with
// x[i] ~= scale * L[i] - *the_min. Returns scale and writes the offset to
// *the_min. The internal min is clamped to <= 0, so *the_min >= 0.
// The fit runs up to ntry rounds; each round:
//  1. re-quantizes with the current (iscale, min);
//  2. refits scale = sum((x - min)*L) / sum(L^2);
//  3. relaxes min toward the mean residual:
//     min = alpha*min + (1 - alpha)*mean(x - scale*L).
// It stops early once a round changes no L[i].
// If all x are equal, L is zeroed, *the_min = 0 and 0 is returned.
// NOTE(review): the change test in the first round compares against whatever
// L held on entry, which the caller may not have initialised; this only
// affects the early-exit decision, not the values written.
static float make_qkx1_quants(int n, int nmax, const float * GGML_RESTRICT x, uint8_t * GGML_RESTRICT L, float * GGML_RESTRICT the_min,
        int ntry, float alpha) {
    float min = x[0];
    float max = x[0];
    for (int i = 1; i < n; ++i) {
        if (x[i] < min) min = x[i];
        if (x[i] > max) max = x[i];
    }
    if (max == min) {
        for (int i = 0; i < n; ++i) L[i] = 0;
        *the_min = 0;
        return 0.f;
    }
    if (min > 0) min = 0;
    float iscale = nmax/(max - min);
    float scale = 1/iscale;
    for (int itry = 0; itry < ntry; ++itry) {
        float sumlx = 0; int suml2 = 0;
        bool did_change = false;
        for (int i = 0; i < n; ++i) {
            int l = nearest_int(iscale*(x[i] - min));
            l = MAX(0, MIN(nmax, l));
            if (l != L[i]) {
                L[i] = l;
                did_change = true;
            }
            sumlx += (x[i] - min)*l;
            suml2 += l*l;
        }
        scale = sumlx/suml2;
        // Mean residual after removing the scaled levels drives the new min.
        float sum = 0;
        for (int i = 0; i < n; ++i) {
            sum += x[i] - scale*L[i];
        }
        min = alpha*min + (1 - alpha)*sum/n;
        if (min > 0) min = 0;
        iscale = 1/scale;
        if (!did_change) break;
    }
    *the_min = -min;
    return scale;
}
 621
// Weighted asymmetric quantization of n floats to levels L[i] in [0, nmax]
// with x[i] ~= scale * L[i] - *the_min. Returns scale and writes the
// non-negative offset to *the_min (the fitted min is kept <= 0).
//   weights: per-element weights; dereferenced unconditionally, so it must
//            be non-NULL.
//   Laux:    scratch buffer of at least n entries.
//   error metric: weighted |diff| if use_mad, else weighted diff^2.
// Procedure:
//  - Initial quantization uses iscale = nmax/(max - min). If nstep < 1 it
//    stops there.
//  - Otherwise, for is = 0..nstep, it tries
//    iscale = (nmax + rmin + rdelta*is)/(max - min).
//  - For each candidate it solves the 2x2 weighted least-squares problem
//    x ~= scale*l + min. If that min comes out positive, it refits with
//    min = 0 instead.
//  - It keeps the candidate with the lowest weighted error.
// Degenerate case: when max == min after clamping min to <= 0 (all x equal
// and non-positive), L is zeroed, *the_min = -min and 0 is returned.
static float make_qkx2_quants(int n, int nmax, const float * GGML_RESTRICT x, const float * GGML_RESTRICT weights,
        uint8_t * GGML_RESTRICT L, float * GGML_RESTRICT the_min, uint8_t * GGML_RESTRICT Laux,
        float rmin, float rdelta, int nstep, bool use_mad) {
    float min = x[0];
    float max = x[0];
    float sum_w = weights[0];
    float sum_x = sum_w * x[0];
#ifdef HAVE_BUGGY_APPLE_LINKER
    // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
    for (volatile int i = 1; i < n; ++i) {
#else
    for (int i = 1; i < n; ++i) {
#endif
        if (x[i] < min) min = x[i];
        if (x[i] > max) max = x[i];
        float w = weights[i];
        sum_w += w;
        sum_x += w * x[i];
    }
    if (min > 0) min = 0;
    if (max == min) {
        for (int i = 0; i < n; ++i) L[i] = 0;
        *the_min = -min;
        return 0.f;
    }
    float iscale = nmax/(max - min);
    float scale = 1/iscale;
    float best_error = 0;
    for (int i = 0; i < n; ++i) {
        int l = nearest_int(iscale*(x[i] - min));
        L[i] = MAX(0, MIN(nmax, l));
        float diff = scale * L[i] + min - x[i];
        diff = use_mad ? fabsf(diff) : diff * diff;
        float w = weights[i];
        best_error += w * diff;
    }
    if (nstep < 1) {
        *the_min = -min;
        return scale;
    }
    for (int is = 0; is <= nstep; ++is) {
        iscale = (rmin + rdelta*is + nmax)/(max - min);
        float sum_l = 0, sum_l2 = 0, sum_xl = 0;
        for (int i = 0; i < n; ++i) {
            int l = nearest_int(iscale*(x[i] - min));
            l = MAX(0, MIN(nmax, l));
            Laux[i] = l;
            float w = weights[i];
            sum_l += w*l;
            sum_l2 += w*l*l;
            sum_xl += w*l*x[i];
        }
        // Determinant of the normal equations for x ~= this_scale*l + this_min.
        float D = sum_w * sum_l2 - sum_l * sum_l;
        if (D > 0) {
            float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
            float this_min   = (sum_l2 * sum_x - sum_l * sum_xl)/D;
            // A positive min is not allowed: refit the scale alone with min = 0.
            if (this_min > 0) {
                this_min = 0;
                this_scale = sum_xl / sum_l2;
            }
            float cur_error = 0;
            for (int i = 0; i < n; ++i) {
                float diff = this_scale * Laux[i] + this_min - x[i];
                diff = use_mad ? fabsf(diff) : diff * diff;
                float w = weights[i];
                cur_error += w * diff;
            }
            if (cur_error < best_error) {
                for (int i = 0; i < n; ++i) {
                    L[i] = Laux[i];
                }
                best_error = cur_error;
                scale = this_scale;
                min = this_min;
            }
        }
    }
    *the_min = -min;
    return scale;
}
 702
 703static inline void get_scale_min_k4(int j, const uint8_t * GGML_RESTRICT q, uint8_t * GGML_RESTRICT d, uint8_t * GGML_RESTRICT m) {
 704    if (j < 4) {
 705        *d = q[j] & 63; *m = q[j + 4] & 63;
 706    } else {
 707        *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
 708        *m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
 709    }
 710}
 711
 712//========================- 2-bit (de)-quantization
 713
// Q2_K reference quantizer. Each super-block of QK_K floats is split into
// QK_K/16 sub-blocks of 16 values. Each sub-block gets:
//  - 2-bit codes L in [0, 3];
//  - a 4-bit scale index (low nibble of scales[j]);
//  - a 4-bit min index (high nibble of scales[j]).
// Those indices are relative to the super-block fp16 factors d and dmin.
// Reconstruction (see dequantize_row_q2_K):
//   d*(scales[j] & 0xF) * q - dmin*(scales[j] >> 4)
// Steps:
//  1. fit (scale, min) per sub-block with make_qkx2_quants, weights |x|;
//  2. quantize the scales and mins to 4 bits relative to their maxima;
//  3. re-quantize the codes against the rounded scale/min;
//  4. pack 4 codes per byte.
void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k) {
    assert(k % QK_K == 0);
    const int nb = k / QK_K;

    uint8_t L[QK_K];
    uint8_t Laux[16];
    float   weights[16];
    float mins[QK_K/16];
    float scales[QK_K/16];

    // Largest 4-bit index used for the sub-block scales and mins.
    const float q4scale = 15.f;

    for (int i = 0; i < nb; i++) {
        float max_scale = 0; // as we are deducting the min, scales are always positive
        float max_min = 0;
        for (int j = 0; j < QK_K/16; ++j) {
            for (int l = 0; l < 16; ++l) weights[l] = fabsf(x[16*j + l]);
            scales[j] = make_qkx2_quants(16, 3, x + 16*j, weights, L + 16*j, &mins[j], Laux, -0.5f, 0.1f, 15, true);
            float scale = scales[j];
            if (scale > max_scale) {
                max_scale = scale;
            }
            float min = mins[j];
            if (min > max_min) {
                max_min = min;
            }
        }

        if (max_scale > 0) {
            float iscale = q4scale/max_scale;
            for (int j = 0; j < QK_K/16; ++j) {
                int l = nearest_int(iscale*scales[j]);
                y[i].scales[j] = l;
            }
            y[i].d = GGML_FP32_TO_FP16(max_scale/q4scale);
        } else {
            for (int j = 0; j < QK_K/16; ++j) y[i].scales[j] = 0;
            y[i].d = GGML_FP32_TO_FP16(0.f);
        }
        if (max_min > 0) {
            float iscale = q4scale/max_min;
            for (int j = 0; j < QK_K/16; ++j) {
                int l = nearest_int(iscale*mins[j]);
                y[i].scales[j] |= (l << 4);
            }
            y[i].dmin = GGML_FP32_TO_FP16(max_min/q4scale);
        } else {
            y[i].dmin = GGML_FP32_TO_FP16(0.f);
        }
        // Re-quantize codes against the scale/min actually stored (after fp16 and 4-bit rounding).
        // Sub-blocks whose effective scale is 0 keep the codes from make_qkx2_quants.
        for (int j = 0; j < QK_K/16; ++j) {
            const float d = GGML_FP16_TO_FP32(y[i].d) * (y[i].scales[j] & 0xF);
            if (!d) continue;
            const float dm = GGML_FP16_TO_FP32(y[i].dmin) * (y[i].scales[j] >> 4);
            for (int ii = 0; ii < 16; ++ii) {
                int l = nearest_int((x[16*j + ii] + dm)/d);
                l = MAX(0, MIN(3, l));
                L[16*j + ii] = l;
            }
        }

        // Pack: within each 128-value chunk, byte l holds elements l, l+32, l+64, l+96
        // at bit offsets 0, 2, 4, 6.
        for (int j = 0; j < QK_K; j += 128) {
            for (int l = 0; l < 32; ++l) {
                y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
            }
        }

        x += QK_K;
    }
}
 783
 784void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
 785    assert(k % QK_K == 0);
 786    const int nb = k / QK_K;
 787
 788    for (int i = 0; i < nb; i++) {
 789
 790        const float d = GGML_FP16_TO_FP32(x[i].d);
 791        const float min = GGML_FP16_TO_FP32(x[i].dmin);
 792
 793        const uint8_t * q = x[i].qs;
 794
 795        int is = 0;
 796        float dl, ml;
 797        for (int n = 0; n < QK_K; n += 128) {
 798            int shift = 0;
 799            for (int j = 0; j < 4; ++j) {
 800
 801                uint8_t sc = x[i].scales[is++];
 802                dl = d * (sc & 0xF); ml = min * (sc >> 4);
 803                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
 804
 805                sc = x[i].scales[is++];
 806                dl = d * (sc & 0xF); ml = min * (sc >> 4);
 807                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
 808
 809                shift += 2;
 810            }
 811            q += 32;
 812        }
 813    }
 814}
 815
// Weighted asymmetric quantization of n floats to levels L[i] in [0, nmax]
// with x[i] ~= scale * L[i] - *the_min. Returns scale and writes the
// non-negative offset to *the_min (the fitted min is kept <= 0).
// Same search as make_qkx2_quants, with two differences:
//  - weights may be NULL, in which case w = x[i]^2 is used;
//  - the degenerate test is max <= min (after clamping min to <= 0).
//    In that case L is zeroed, *the_min = -min and 0 is returned.
// Laux is scratch of at least n entries. The error metric is the weighted
// |diff| if use_mad, else the weighted diff^2.
static float make_qkx3_quants(int n, int nmax, const float * GGML_RESTRICT x, const float * GGML_RESTRICT weights,
        uint8_t * GGML_RESTRICT L, float * GGML_RESTRICT the_min, uint8_t * GGML_RESTRICT Laux,
        float rmin, float rdelta, int nstep, bool use_mad) {
    float min = x[0];
    float max = x[0];
    float sum_w = weights ? weights[0] : x[0]*x[0];
    float sum_x = sum_w * x[0];
#ifdef HAVE_BUGGY_APPLE_LINKER
    // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
    for (volatile int i = 1; i < n; ++i) {
#else
    for (int i = 1; i < n; ++i) {
#endif
        if (x[i] < min) min = x[i];
        if (x[i] > max) max = x[i];
        float w = weights ? weights[i] : x[i]*x[i];
        sum_w += w;
        sum_x += w * x[i];
    }
    if (min > 0) {
        min = 0;
    }
    if (max <= min) {
        memset(L, 0, n);
        *the_min = -min;
        return 0.f;
    }
    float iscale = nmax/(max - min);
    float scale = 1/iscale;
    float best_mad = 0;
    for (int i = 0; i < n; ++i) {
        int l = nearest_int(iscale*(x[i] - min));
        L[i] = MAX(0, MIN(nmax, l));
        float diff = scale * L[i] + min - x[i];
        diff = use_mad ? fabsf(diff) : diff*diff;
        float w = weights ? weights[i] : x[i]*x[i];
        best_mad += w * diff;
    }
    if (nstep < 1) {
        *the_min = -min;
        return scale;
    }
    for (int is = 0; is <= nstep; ++is) {
        iscale = (rmin + rdelta*is + nmax)/(max - min);
        float sum_l = 0, sum_l2 = 0, sum_xl = 0;
        for (int i = 0; i < n; ++i) {
            int l = nearest_int(iscale*(x[i] - min));
            l = MAX(0, MIN(nmax, l));
            Laux[i] = l;
            float w = weights ? weights[i] : x[i]*x[i];
            sum_l  += w*l;
            sum_l2 += w*l*l;
            sum_xl += w*l*x[i];
        }
        // Closed-form weighted least squares for x ~= this_scale*l + this_min.
        float D = sum_w * sum_l2 - sum_l * sum_l;
        if (D > 0) {
            float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
            float this_min   = (sum_l2 * sum_x - sum_l * sum_xl)/D;
            // A positive min is not allowed: refit the scale alone with min = 0.
            if (this_min > 0) {
                this_min = 0;
                this_scale = sum_xl / sum_l2;
            }
            float mad = 0;
            for (int i = 0; i < n; ++i) {
                float diff = this_scale * Laux[i] + this_min - x[i];
                diff = use_mad ? fabsf(diff) : diff*diff;
                float w = weights ? weights[i] : x[i]*x[i];
                mad += w * diff;
            }
            if (mad < best_mad) {
                for (int i = 0; i < n; ++i) {
                    L[i] = Laux[i];
                }
                best_mad = mad;
                scale = this_scale;
                min = this_min;
            }
        }
    }
    *the_min = -min;
    return scale;
}
 898
// Scale-only ("positive") quantization of n floats to levels L[i] <= nmax
// with x[i] ~= scale * L[i] and no offset. It minimizes the
// quant_weights-weighted squared error; quant_weights is dereferenced
// unconditionally, so it must be non-NULL.
// If max(x) < GROUP_MAX_EPS, L is zeroed and 0 is returned.
// Steps:
//  1. choose the best of iscale = nmax/max and (nmax + 0.1*is)/max for
//     is in [-4, 4] \ {0}, by weighted MSE;
//  2. re-quantize with that iscale;
//  3. run up to 5 coordinate-descent passes accepting level changes that
//     increase sumlx^2/suml2;
//  4. return the least-squares scale sumlx/suml2 (0 if suml2 <= 0).
// NOTE(review): only the upper bound nmax is enforced. A negative x[i] would
// round to a negative level and wrap when stored in the uint8_t L, so callers
// are presumably passing non-negative values (e.g. sub-block scales).
// Confirm before reusing this function elsewhere.
static float make_qp_quants(int n, int nmax, const float * GGML_RESTRICT x, uint8_t * GGML_RESTRICT L, const float * quant_weights) {
    float max = 0;
    for (int i = 0; i < n; ++i) {
        max = MAX(max, x[i]);
    }
    if (max < GROUP_MAX_EPS) { // all zero
        for (int i = 0; i < n; ++i) { L[i] = 0; }
        return 0.f;
    }
    float iscale = nmax / max;
    for (int i = 0; i < n; ++i) {
        L[i] = nearest_int(iscale * x[i]);
    }
    float scale = 1/iscale;
    float best_mse = 0;
    for (int i = 0; i < n; ++i) {
        float diff = x[i] - scale*L[i];
        float w = quant_weights[i];
        best_mse += w*diff*diff;
    }
    // Try small perturbations of the inverse scale; keep the lowest weighted MSE.
    for (int is = -4; is <= 4; ++is) {
        if (is == 0) continue;
        float iscale_is = (0.1f*is + nmax)/max;
        float scale_is = 1/iscale_is;
        float mse = 0;
        for (int i = 0; i < n; ++i) {
            int l = nearest_int(iscale_is*x[i]);
            l = MIN(nmax, l);
            float diff = x[i] - scale_is*l;
            float w = quant_weights[i];
            mse += w*diff*diff;
        }
        if (mse < best_mse) {
            best_mse = mse;
            iscale = iscale_is;
        }
    }
    float sumlx = 0;
    float suml2 = 0;
    for (int i = 0; i < n; ++i) {
        int l = nearest_int(iscale * x[i]);
        l = MIN(nmax, l);
        L[i] = l;
        float w = quant_weights[i];
        sumlx += w*x[i]*l;
        suml2 += w*l*l;
    }
    for (int itry = 0; itry < 5; ++itry) {
        int n_changed = 0;
        for (int i = 0; i < n; ++i) {
            float w = quant_weights[i];
            // Sums with element i removed; propose the level that best fits the rest.
            float slx = sumlx - w*x[i]*L[i];
            float sl2 = suml2 - w*L[i]*L[i];
            if (slx > 0 && sl2 > 0) {
                int new_l = nearest_int(x[i] * sl2 / slx);
                new_l = MIN(nmax, new_l);
                if (new_l != L[i]) {
                    slx += w*x[i]*new_l;
                    sl2 += w*new_l*new_l;
                    // Accept if slx^2/sl2 > sumlx^2/suml2 (cross-multiplied).
                    if (slx*slx*suml2 > sumlx*sumlx*sl2) {
                        L[i] = new_l; sumlx = slx; suml2 = sl2;
                        ++n_changed;
                    }
                }
            }
        }
        if (!n_changed) {
            break;
        }
    }
    return suml2 > 0.0f ? sumlx / suml2 : 0.0f;
}
 971
 972static void quantize_row_q2_K_impl(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int k, const float * GGML_RESTRICT quant_weights) {
 973    GGML_ASSERT(quant_weights);
 974    assert(k % QK_K == 0);
 975    const int nb = k / QK_K;
 976    const bool requantize = true;
 977
 978    uint8_t L[QK_K];
 979    uint8_t Laux[16];
 980    float mins[QK_K/16];
 981    float scales[QK_K/16];
 982    float sw[QK_K/16];
 983    float weight[16];
 984    uint8_t Ls[QK_K/16], Lm[QK_K/16];
 985
 986    for (int i = 0; i < nb; i++) {
 987        memset(sw, 0, QK_K/16*sizeof(float));
 988        float sumx2 = 0;
 989        for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j];
 990        float sigma2 = sumx2/QK_K;
 991        for (int j = 0; j < QK_K/16; ++j) {
 992            const float * GGML_RESTRICT qw = quant_weights + QK_K * i + 16*j;
 993            for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
 994            for (int l = 0; l < QK_K/16; ++l) sw[j] += weight[l];
 995            scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
 996        }
 997
 998        float dm, mm;
 999        dm  = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
1000        mm  = make_qp_quants(QK_K/16, 15, mins,   Lm, sw);
1001
1002        y[i].d    = GGML_FP32_TO_FP16(dm);
1003        y[i].dmin = GGML_FP32_TO_FP16(mm);
1004        dm        = GGML_FP16_TO_FP32(y[i].d);
1005        mm        = GGML_FP16_TO_FP32(y[i].dmin);
1006
1007        for (int j = 0; j < QK_K/16; ++j) {
1008            y[i].scales[j] = Ls[j] | (Lm[j] << 4);
1009        }
1010
1011        if (requantize) {
1012            for (int j = 0; j < QK_K/16; ++j) {
1013                const float d = dm * (y[i].scales[j] & 0xF);
1014                if (!d) continue;
1015                const float m = mm * (y[i].scales[j] >> 4);
1016                for (int ii = 0; ii < 16; ++ii) {
1017                    int l = nearest_int((x[16*j + ii] + m)/d);
1018                    l = MAX(0, MIN(3, l));
1019                    L[16*j + ii] = l;
1020                }
1021            }
1022        }
1023
1024        for (int j = 0; j < QK_K; j += 128) {
1025            for (int l = 0; l < 32; ++l) {
1026                y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
1027            }
1028        }
1029
1030        x += QK_K;
1031    }
1032}
1033
1034size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
1035    size_t row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row);
1036    if (!quant_weights) {
1037        quantize_row_q2_K_ref(src, dst, (int64_t)nrow*n_per_row);
1038    }
1039    else {
1040        char * qrow = (char *)dst;
1041        for (int64_t row = 0; row < nrow; ++row) {
1042            quantize_row_q2_K_impl(src, (block_q2_K*)qrow, n_per_row, quant_weights);
1043            src += n_per_row;
1044            qrow += row_size;
1045        }
1046    }
1047    return nrow * row_size;
1048}
1049
1050//========================= 3-bit (de)-quantization
1051
// Reference quantization (no importance weights) of k floats into Q3_K blocks.
// k must be a multiple of QK_K; x advances by QK_K values per super-block.
//
// Per super-block:
//   1. make_q3_quants(16, 4, ...) quantizes each 16-value sub-block. It returns
//      a signed sub-block scale and writes codes into L.
//   2. The sub-block scales become 6-bit codes (clamped value + 32), relative
//      to the super-block scale d = -max_scale/32. max_scale is the sub-block
//      scale with the largest magnitude.
//   3. Every value is re-quantized to [-4, 3] against its decoded scale and
//      stored with a +4 bias (so in [0, 7]).
//   4. Bit 2 of each code goes into hmask; the low two bits go into qs.
void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k) {
    assert(k % QK_K == 0);
    const int nb = k / QK_K;

    int8_t L[QK_K];
    float scales[QK_K / 16];

    for (int i = 0; i < nb; i++) {

        // Find the sub-block scale with the largest magnitude, keeping its sign.
        float max_scale = 0;
        float amax = 0;
        for (int j = 0; j < QK_K/16; ++j) {
            scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true);
            float scale = fabsf(scales[j]);
            if (scale > amax) {
                amax = scale; max_scale = scales[j];
            }
        }

        // Pack the 6-bit scale codes into the 12-byte scales field:
        //   low 4 bits of code j  -> low nibble of scales[j] (j < 8),
        //                            or high nibble of scales[j-8] (j >= 8);
        //   high 2 bits of code j -> bits 2*(j/4) and 2*(j/4)+1 of scales[8 + j%4].
        // The field is zeroed first because the packing ORs bits in.
        memset(y[i].scales, 0, 12);
        if (max_scale) {
            float iscale = -32.f/max_scale;
            for (int j = 0; j < QK_K/16; ++j) {
                int8_t l = nearest_int(iscale*scales[j]);
                l = MAX(-32, MIN(31, l)) + 32;
                if (j < 8) {
                    y[i].scales[j] = l & 0xF;
                } else {
                    y[i].scales[j-8] |= ((l & 0xF) << 4);
                }
                l >>= 4;
                y[i].scales[j%4 + 8] |= (l << (2*(j/4)));
            }
            y[i].d = GGML_FP32_TO_FP16(1/iscale);
        } else {
            y[i].d = GGML_FP32_TO_FP16(0.f);
        }

        // Decode each scale exactly as stored (fp16 d times the signed 6-bit
        // code) and re-quantize its sub-block against it. Sub-blocks whose
        // decoded scale is 0 keep the codes written by make_q3_quants.
        int8_t sc;
        for (int j = 0; j < QK_K/16; ++j) {
            sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;
            sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;
            float d = GGML_FP16_TO_FP32(y[i].d) * sc;
            if (!d) {
                continue;
            }
            for (int ii = 0; ii < 16; ++ii) {
                int l = nearest_int(x[16*j + ii]/d);
                l = MAX(-4, MIN(3, l));
                L[16*j + ii] = l + 4;
            }
        }

        memset(y[i].hmask, 0, QK_K/8);
        // Bit 2 of code j goes into bit j / (QK_K/8) of hmask[j % (QK_K/8)].
        // So the first QK_K/8 codes use bit 0, the next QK_K/8 use bit 1, etc.
        // Codes with that bit set are reduced to their low two bits.
        int m = 0;
        uint8_t hm = 1;
        for (int j = 0; j < QK_K; ++j) {
            if (L[j] > 3) {
                y[i].hmask[m] |= hm;
                L[j] -= 4;
            }
            if (++m == QK_K/8) {
                m = 0; hm <<= 1;
            }
        }
        // Pack four 2-bit codes per byte. Within each 128-value chunk, qs byte l
        // holds values l, l+32, l+64, l+96 in bits 0-1, 2-3, 4-5, 6-7.
        for (int j = 0; j < QK_K; j += 128) {
            for (int l = 0; l < 32; ++l) {
                y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
            }
        }

        x += QK_K;
    }
}
1127
1128void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
1129    assert(k % QK_K == 0);
1130    const int nb = k / QK_K;
1131
1132    const uint32_t kmask1 = 0x03030303;
1133    const uint32_t kmask2 = 0x0f0f0f0f;
1134
1135    uint32_t aux[4];
1136    const int8_t * scales = (const int8_t*)aux;
1137
1138    for (int i = 0; i < nb; i++) {
1139
1140        const float d_all = GGML_FP16_TO_FP32(x[i].d);
1141
1142        const uint8_t * GGML_RESTRICT q = x[i].qs;
1143        const uint8_t * GGML_RESTRICT hm = x[i].hmask;
1144        uint8_t m = 1;
1145
1146        memcpy(aux, x[i].scales, 12);
1147        uint32_t tmp = aux[2];
1148        aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
1149        aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
1150        aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
1151        aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
1152
1153        int is = 0;
1154        float dl;
1155        for (int n = 0; n < QK_K; n += 128) {
1156            int shift = 0;
1157            for (int j = 0; j < 4; ++j) {
1158
1159                dl = d_all * (scales[is++] - 32);
1160                for (int l = 0; l < 16; ++l) {
1161                    *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4));
1162                }
1163
1164                dl = d_all * (scales[is++] - 32);
1165                for (int l = 0; l < 16; ++l) {
1166                    *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4));
1167                }
1168
1169                shift += 2;
1170                m <<= 1;
1171            }
1172            q += 32;
1173        }
1174
1175    }
1176}
1177
// Quantize one row of n_per_row floats (a multiple of QK_K) into Q3_K blocks,
// optionally guided by an importance matrix (quant_weights may be NULL).
//
// Differences from quantize_row_q3_K_ref:
//   - Per-value weights are qw * sqrt(sigma2 + x^2), where sigma2 is twice the
//     mean of x^2 over the super-block. Without quant_weights they are x^2.
//   - Each 16-value sub-block is quantized with make_qx_quants(16, 4, ...)
//     using those weights.
//   - The sub-block scales are quantized jointly with
//     make_qx_quants(QK_K/16, 32, ...), weighted by each sub-block's total
//     weight, instead of relative to the largest scale.
// The scale packing, the re-quantization to [-4, 3], and the hmask/qs layout
// are the same as in the reference quantizer.
static void quantize_row_q3_K_impl(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t n_per_row, const float * GGML_RESTRICT quant_weights) {
    assert(n_per_row % QK_K == 0);
    const int nb = n_per_row / QK_K;

    int8_t L[QK_K];
    float scales[QK_K / 16];
    float weight[16];
    float sw[QK_K / 16];
    int8_t Ls[QK_K / 16];

    for (int i = 0; i < nb; i++) {

        float sumx2 = 0;
        for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j];
        float sigma2 = 2*sumx2/QK_K;

        for (int j = 0; j < QK_K/16; ++j) {
            if (quant_weights) {
                const float * qw = quant_weights + QK_K * i + 16*j;
                for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j+l]*x[16*j+l]);
            } else {
                for (int l = 0; l < 16; ++l) weight[l] = x[16*j+l]*x[16*j+l];
            }
            // sw[j] is this sub-block's total weight; it weights its scale below.
            float sumw = 0;
            for (int l = 0; l < 16; ++l) sumw += weight[l];
            sw[j] = sumw;

            scales[j] = make_qx_quants(16, 4, x + 16*j, L + 16*j, 1, weight);

        }

        // The scale packing below ORs bits in, so the field starts zeroed.
        memset(y[i].scales, 0, 12);

        // Quantize the sub-block scales to 6-bit codes sharing d_block.
        // NOTE(review): Ls is assumed to hold codes biased by +32 (in [0, 63]),
        // since the decode loop below subtracts 32. Confirm against make_qx_quants.
        float d_block = make_qx_quants(QK_K/16, 32, scales, Ls, 1, sw);
        // Same layout as the reference quantizer: low nibble in scales[j] /
        // high nibble of scales[j-8], top 2 bits at bit 2*(j/4) of scales[8 + j%4].
        for (int j = 0; j < QK_K/16; ++j) {
            int l = Ls[j];
            if (j < 8) {
                y[i].scales[j] = l & 0xF;
            } else {
                y[i].scales[j-8] |= ((l & 0xF) << 4);
            }
            l >>= 4;
            y[i].scales[j%4 + 8] |= (l << (2*(j/4)));
        }
        y[i].d = GGML_FP32_TO_FP16(d_block);

        // Re-quantize each sub-block to [-4, 3] (stored +4) against the scale as
        // stored. Sub-blocks with a zero decoded scale keep make_qx_quants' codes.
        int8_t sc;
        for (int j = 0; j < QK_K/16; ++j) {
            sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;
            sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;
            float d = GGML_FP16_TO_FP32(y[i].d) * sc;
            if (!d) {
                continue;
            }
            for (int ii = 0; ii < 16; ++ii) {
                int l = nearest_int(x[16*j + ii]/d);
                l = MAX(-4, MIN(3, l));
                L[16*j + ii] = l + 4;
            }
        }

        memset(y[i].hmask, 0, QK_K/8);
        // Bit 2 of code j goes into bit j / (QK_K/8) of hmask[j % (QK_K/8)].
        // So the first QK_K/8 codes use bit 0, the next QK_K/8 use bit 1, etc.
        int m = 0;
        uint8_t hm = 1;
        for (int j = 0; j < QK_K; ++j) {
            if (L[j] > 3) {
                y[i].hmask[m] |= hm;
                L[j] -= 4;
            }
            if (++m == QK_K/8) {
                m = 0; hm <<= 1;
            }
        }
        // Four 2-bit codes per byte: qs byte l of each 128-value chunk holds
        // values l, l+32, l+64, l+96.
        for (int j = 0; j < QK_K; j += 128) {
            for (int l = 0; l < 32; ++l) {
                y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
            }
        }

        x += QK_K;
    }
}
1261
1262size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
1263    size_t row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row);
1264    if (!quant_weights) {
1265        quantize_row_q3_K_ref(src, dst, (int64_t)nrow*n_per_row);
1266    }
1267    else {
1268        char * qrow = (char *)dst;
1269        for (int64_t row = 0; row < nrow; ++row) {
1270            quantize_row_q3_K_impl(src, (block_q3_K*)qrow, n_per_row, quant_weights);
1271            src += n_per_row;
1272            qrow += row_size;
1273        }
1274    }
1275    return nrow * row_size;
1276}
1277
1278// ====================== 4-bit (de)-quantization
1279
// Reference quantization (no importance weights) of k floats into Q4_K blocks.
// k must be a multiple of QK_K. Each super-block has QK_K/32 sub-blocks of 32.
//
// Per super-block:
//   - Each sub-block gets weights av_x + |x|, with av_x the RMS of the
//     sub-block. make_qkx2_quants quantizes it to 4-bit codes with its own
//     scale and min.
//   - Scales and mins are re-expressed as 6-bit codes relative to
//     d = max_scale/63 and dmin = max_min/63.
//   - Values are re-quantized to [0, 15] against the stored scales, so that
//     x ~= d*sc*q - dmin*m (the form dequantize_row_q4_K reconstructs).
//   - Two codes are packed per byte: in each 64-value chunk, qs[l] holds
//     value l in the low nibble and value l+32 in the high nibble.
void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k) {
    assert(k % QK_K == 0);
    const int nb = k / QK_K;

    uint8_t L[QK_K];
    uint8_t Laux[32];
    float   weights[32];
    float mins[QK_K/32];
    float scales[QK_K/32];

    for (int i = 0; i < nb; i++) {
        float max_scale = 0; // as we are deducting the min, scales are always positive
        float max_min = 0;
        for (int j = 0; j < QK_K/32; ++j) {
            //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
            float sum_x2 = 0;
            for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
            float av_x = sqrtf(sum_x2/32);
            for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
            scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);
            float scale = scales[j];
            if (scale > max_scale) {
                max_scale = scale;
            }
            float min = mins[j];
            if (min > max_min) {
                max_min = min;
            }
        }

        // Encode scales/mins as 6-bit codes (0..63) relative to their maxima.
        // Packing into the 12-byte scales field:
        //   j < 4 : scales[j] = scale code, scales[j+4] = min code (bits 0-5);
        //   j >= 4: scales[j+4] = low nibble of scale code | low nibble of min code << 4.
        //           The top 2 bits of the scale/min codes go into bits 6-7 of
        //           scales[j-4] / scales[j]. Those bytes were written when j < 4,
        //           which is why the loop must run in ascending j.
        float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
        float inv_min   = max_min   > 0 ? 63.f/max_min   : 0.f;
        for (int j = 0; j < QK_K/32; ++j) {
            uint8_t ls = nearest_int(inv_scale*scales[j]);
            uint8_t lm = nearest_int(inv_min*mins[j]);
            ls = MIN(63, ls);
            lm = MIN(63, lm);
            if (j < 4) {
                y[i].scales[j] = ls;
                y[i].scales[j+4] = lm;
            } else {
                y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
                y[i].scales[j-4] |= ((ls >> 4) << 6);
                y[i].scales[j-0] |= ((lm >> 4) << 6);
            }
        }
        y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
        y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);

        // Re-quantize each sub-block against the scale/min read back from the
        // packed field via get_scale_min_k4. Sub-blocks with a zero effective
        // scale keep make_qkx2_quants' codes.
        uint8_t sc, m;
        for (int j = 0; j < QK_K/32; ++j) {
            get_scale_min_k4(j, y[i].scales, &sc, &m);
            const float d = GGML_FP16_TO_FP32(y[i].d) * sc;
            if (!d) continue;
            const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
            for (int ii = 0; ii < 32; ++ii) {
                int l = nearest_int((x[32*j + ii] + dm)/d);
                l = MAX(0, MIN(15, l));
                L[32*j + ii] = l;
            }
        }

        // Two 4-bit codes per byte: values l and l+32 of each 64-value chunk.
        uint8_t * q = y[i].qs;
        for (int j = 0; j < QK_K; j += 64) {
            for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
            q += 32;
        }

        x += QK_K;
    }
}
1351
1352void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
1353    assert(k % QK_K == 0);
1354    const int nb = k / QK_K;
1355
1356    for (int i = 0; i < nb; i++) {
1357        const uint8_t * q = x[i].qs;
1358
1359        const float d   = GGML_FP16_TO_FP32(x[i].d);
1360        const float min = GGML_FP16_TO_FP32(x[i].dmin);
1361
1362        int is = 0;
1363        uint8_t sc, m;
1364        for (int j = 0; j < QK_K; j += 64) {
1365            get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
1366            const float d1 = d * sc; const float m1 = min * m;
1367            get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
1368            const float d2 = d * sc; const float m2 = min * m;
1369            for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
1370            for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l]  >> 4) - m2;
1371            q += 32; is += 2;
1372        }
1373    }
1374}
1375
// Quantize one row of n_per_row floats (a multiple of QK_K) into Q4_K blocks,
// optionally guided by an importance matrix (quant_weights may be NULL).
//
// Per super-block:
//   - sigma2 is twice the mean of x^2 over the super-block.
//   - Per-value weights are qw * sqrt(sigma2 + x^2). Without quant_weights they
//     are sqrt(sigma2) + |x|.
//   - make_qkx3_quants quantizes each 32-value sub-block to 4-bit codes with
//     its own scale and min.
//   - make_qp_quants quantizes the sub-block scales and mins to 6-bit codes,
//     weighted by each sub-block's total weight. The packing is the same as in
//     quantize_row_q4_K_ref.
//   - Values are re-quantized to [0, 15] against the stored scales, then packed
//     two per byte.
static void quantize_row_q4_K_impl(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) {
    assert(n_per_row % QK_K == 0);
    const int64_t nb = n_per_row / QK_K;

    uint8_t L[QK_K];
    uint8_t Laux[32];
    uint8_t Ls[QK_K/32];
    uint8_t Lm[QK_K/32];
    float   weights[32];
    float   sw[QK_K/32];
    float   mins[QK_K/32];
    float   scales[QK_K/32];

    for (int i = 0; i < nb; i++) {

        float sum_x2 = 0;
        for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l];
        float sigma2 = 2*sum_x2/QK_K;
        float av_x = sqrtf(sigma2);

        for (int j = 0; j < QK_K/32; ++j) {
            if (quant_weights) {
                const float * qw = quant_weights + QK_K*i + 32*j;
                for (int l = 0; l < 32; ++l) weights[l] = qw[l] * sqrtf(sigma2 + x[32*j + l]*x[32*j + l]);
            } else {
                for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
            }
            // Total weight of the sub-block, used to weight its scale/min below.
            float sumw = 0;
            for (int l = 0; l < 32; ++l) sumw += weights[l];
            sw[j] = sumw;
            scales[j] = make_qkx3_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
        }

        // 6-bit codes for the sub-block scales and mins, sharing d_block / m_block.
        float d_block = make_qp_quants(QK_K/32, 63, scales, Ls, sw);
        float m_block = make_qp_quants(QK_K/32, 63, mins,   Lm, sw);
        // j < 4: codes stored whole in scales[j] / scales[j+4].
        // j >= 4: low nibbles in scales[j+4]; top 2 bits in bits 6-7 of
        // scales[j-4] (scale) and scales[j] (min). Ascending j is required.
        for (int j = 0; j < QK_K/32; ++j) {
            uint8_t ls = Ls[j];
            uint8_t lm = Lm[j];
            if (j < 4) {
                y[i].scales[j] = ls;
                y[i].scales[j+4] = lm;
            } else {
                y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
                y[i].scales[j-4] |= ((ls >> 4) << 6);
                y[i].scales[j-0] |= ((lm >> 4) << 6);
            }
        }
        y[i].d = GGML_FP32_TO_FP16(d_block);
        y[i].dmin = GGML_FP32_TO_FP16(m_block);

        // Re-quantize against the scale/min as stored (fp16 d/dmin times the codes).
        // Sub-blocks with a zero effective scale keep make_qkx3_quants' codes.
        uint8_t sc, m;
        for (int j = 0; j < QK_K/32; ++j) {
            get_scale_min_k4(j, y[i].scales, &sc, &m);
            const float d = GGML_FP16_TO_FP32(y[i].d) * sc;
            if (!d) continue;
            const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
            for (int ii = 0; ii < 32; ++ii) {
                int l = nearest_int((x[32*j + ii] + dm)/d);
                l = MAX(0, MIN(15, l));
                L[32*j + ii] = l;
            }
        }
        // Two 4-bit codes per byte: values l and l+32 of each 64-value chunk.
        uint8_t * q = y[i].qs;
        for (int j = 0; j < QK_K; j += 64) {
            for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
            q += 32;
        }

        x += QK_K;

    }
}
1448
1449size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
1450    size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row);
1451    if (!quant_weights) {
1452        quantize_row_q4_K_ref(src, dst, (int64_t)nrow*n_per_row);
1453    }
1454    else {
1455        char * qrow = (char *)dst;
1456        for (int64_t row = 0; row < nrow; ++row) {
1457            quantize_row_q4_K_impl(src, (block_q4_K*)qrow, n_per_row, quant_weights);
1458            src += n_per_row;
1459            qrow += row_size;
1460        }
1461    }
1462    return nrow * row_size;
1463}
1464
1465// ====================== 5-bit (de)-quantization
1466
// Reference quantization (no importance weights) of k floats into Q5_K blocks.
// k must be a multiple of QK_K. Each super-block has QK_K/32 sub-blocks of 32.
//
// Per super-block:
//   - Each sub-block gets weights av_x + |x|, with av_x the RMS of the
//     sub-block. make_qkx2_quants quantizes it to 5-bit codes (nmax 31) with
//     its own scale and min.
//   - Scales and mins become 6-bit codes relative to max_scale/63 and
//     max_min/63, using the same packing as Q4_K.
//   - Values are re-quantized to [0, 31] against the stored scales.
//   - The low 4 bits of each code go into qs (two per byte) and the fifth bit
//     into qh. For the c-th 64-value chunk, value l uses bit 2c of qh[l] and
//     value l+32 uses bit 2c+1.
void quantize_row_q5_K_ref(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k) {
    assert(k % QK_K == 0);
    const int64_t nb = k / QK_K;

    uint8_t L[QK_K];
    float mins[QK_K/32];
    float scales[QK_K/32];
    float weights[32];
    uint8_t Laux[32];

    for (int i = 0; i < nb; i++) {
        float max_scale = 0; // as we are deducting the min, scales are always positive
        float max_min = 0;
        for (int j = 0; j < QK_K/32; ++j) {
            //scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
            float sum_x2 = 0;
            for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
            float av_x = sqrtf(sum_x2/32);
            for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
            scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false);
            float scale = scales[j];
            if (scale > max_scale) {
                max_scale = scale;
            }
            float min = mins[j];
            if (min > max_min) {
                max_min = min;
            }
        }

        // 6-bit scale/min codes. For j < 4 they are stored whole in scales[j] /
        // scales[j+4]. For j >= 4 the low nibbles go into scales[j+4] and the top
        // 2 bits into bits 6-7 of scales[j-4] (scale) and scales[j] (min).
        float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
        float inv_min   = max_min   > 0 ? 63.f/max_min   : 0.f;
        for (int j = 0; j < QK_K/32; ++j) {
            uint8_t ls = nearest_int(inv_scale*scales[j]);
            uint8_t lm = nearest_int(inv_min*mins[j]);
            ls = MIN(63, ls);
            lm = MIN(63, lm);
            if (j < 4) {
                y[i].scales[j] = ls;
                y[i].scales[j+4] = lm;
            } else {
                y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
                y[i].scales[j-4] |= ((ls >> 4) << 6);
                y[i].scales[j-0] |= ((lm >> 4) << 6);
            }
        }
        y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
        y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);

        // Re-quantize against the stored scale/min. Sub-blocks with a zero
        // effective scale keep make_qkx2_quants' codes.
        uint8_t sc, m;
        for (int j = 0; j < QK_K/32; ++j) {
            get_scale_min_k4(j, y[i].scales, &sc, &m);
            const float d = GGML_FP16_TO_FP32(y[i].d) * sc;
            if (!d) continue;
            const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
            for (int ii = 0; ii < 32; ++ii) {
                int l = nearest_int((x[32*j + ii] + dm)/d);
                l = MAX(0, MIN(31, l));
                L[32*j + ii] = l;
            }
        }

        // Split the 5-bit codes: the fifth bit goes into qh (ORed in, hence the
        // memset) and the low nibbles are packed two per qs byte.
        uint8_t * GGML_RESTRICT qh = y[i].qh;
        uint8_t * GGML_RESTRICT ql = y[i].qs;
        memset(qh, 0, QK_K/8);

        uint8_t m1 = 1, m2 = 2;
        for (int n = 0; n < QK_K; n += 64) {
            for (int j = 0; j < 32; ++j) {
                int l1 = L[n + j];
                if (l1 > 15) {
                    l1 -= 16; qh[j] |= m1;
                }
                int l2 = L[n + j + 32];
                if (l2 > 15) {
                    l2 -= 16; qh[j] |= m2;
                }
                ql[j] = l1 | (l2 << 4);
            }
            m1 <<= 2; m2 <<= 2;
            ql += 32;
        }

        x += QK_K;
    }
}
1553
1554void dequantize_row_q5_K(const block_q5_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
1555    assert(k % QK_K == 0);
1556    const int64_t nb = k / QK_K;
1557
1558    for (int i = 0; i < nb; i++) {
1559        const uint8_t * ql = x[i].qs;
1560        const uint8_t * qh = x[i].qh;
1561
1562        const float d = GGML_FP16_TO_FP32(x[i].d);
1563        const float min = GGML_FP16_TO_FP32(x[i].dmin);
1564
1565        int is = 0;
1566        uint8_t sc, m;
1567        uint8_t u1 = 1, u2 = 2;
1568        for (int j = 0; j < QK_K; j += 64) {
1569            get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
1570            const float d1 = d * sc; const float m1 = min * m;
1571            get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
1572            const float d2 = d * sc; const float m2 = min * m;
1573            for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
1574            for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l]  >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
1575            ql += 32; is += 2;
1576            u1 <<= 2; u2 <<= 2;
1577        }
1578    }
1579}
1580
// Quantize one row of n_per_row floats (a multiple of QK_K) into Q5_K blocks,
// optionally guided by an importance matrix (quant_weights may be NULL).
//
// Per super-block:
//   - sigma2 is twice the mean of x^2 over the super-block.
//   - Per-value weights are qw * sqrt(sigma2 + x^2). Without quant_weights they
//     are sqrt(sigma2) + |x|.
//   - make_qkx3_quants quantizes each 32-value sub-block to 5-bit codes
//     (nmax 31) with its own scale and min.
//   - make_qp_quants quantizes the sub-block scales and mins to 6-bit codes,
//     weighted by each sub-block's total weight.
//   - Values are re-quantized to [0, 31] against the stored scales. They are
//     then split into qs nibbles and qh high bits as in quantize_row_q5_K_ref.
static void quantize_row_q5_K_impl(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) {
    assert(n_per_row % QK_K == 0);
    const int64_t nb = n_per_row / QK_K;

    uint8_t L[QK_K];
    uint8_t Laux[32];
    uint8_t Ls[QK_K/32];
    uint8_t Lm[QK_K/32];
    float   mins[QK_K/32];
    float   scales[QK_K/32];
    float   sw[QK_K/32];
    float   weights[32];

    for (int i = 0; i < nb; i++) {

        float sum_x2 = 0;
        for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l];
        float sigma2 = 2*sum_x2/QK_K;
        float av_x = sqrtf(sigma2);

        for (int j = 0; j < QK_K/32; ++j) {
            if (quant_weights) {
                const float * qw = quant_weights + QK_K*i + 32*j;
                for (int l = 0; l < 32; ++l) weights[l] = qw[l] * sqrtf(sigma2 + x[32*j + l]*x[32*j + l]);
            } else {
                for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
            }
            // Total weight of the sub-block, used to weight its scale/min below.
            float sumw = 0;
            for (int l = 0; l < 32; ++l) sumw += weights[l];
            sw[j] = sumw;

            scales[j] = make_qkx3_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
        }

        // 6-bit codes for the sub-block scales and mins, sharing d_block / m_block.
        float d_block = make_qp_quants(QK_K/32, 63, scales, Ls, sw);
        float m_block = make_qp_quants(QK_K/32, 63, mins,   Lm, sw);

        // Same packing as Q4_K. The MIN(63, ...) clamps are a safety net: the
        // codes should already be within nmax = 63 from make_qp_quants.
        for (int j = 0; j < QK_K/32; ++j) {
            uint8_t ls = Ls[j];
            uint8_t lm = Lm[j];
            ls = MIN(63, ls);
            lm = MIN(63, lm);
            if (j < 4) {
                y[i].scales[j] = ls;
                y[i].scales[j+4] = lm;
            } else {
                y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
                y[i].scales[j-4] |= ((ls >> 4) << 6);
                y[i].scales[j-0] |= ((lm >> 4) << 6);
            }
        }
        y[i].d = GGML_FP32_TO_FP16(d_block);
        y[i].dmin = GGML_FP32_TO_FP16(m_block);

        // Re-quantize against the stored scale/min. Sub-blocks with a zero
        // effective scale keep make_qkx3_quants' codes.
        uint8_t sc, m;
        for (int j = 0; j < QK_K/32; ++j) {
            get_scale_min_k4(j, y[i].scales, &sc, &m);
            const float d = GGML_FP16_TO_FP32(y[i].d) * sc;
            if (!d) continue;
            const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
            for (int ii = 0; ii < 32; ++ii) {
                int l = nearest_int((x[32*j + ii] + dm)/d);
                l = MAX(0, MIN(31, l));
                L[32*j + ii] = l;
            }
        }

        // Fifth bit into qh (ORed in, hence the memset): for chunk c, value l uses
        // bit 2c of qh[l] and value l+32 uses bit 2c+1. Low nibbles go into qs.
        uint8_t * GGML_RESTRICT qh = y[i].qh;
        uint8_t * GGML_RESTRICT ql = y[i].qs;
        memset(qh, 0, QK_K/8);

        uint8_t m1 = 1, m2 = 2;
        for (int n = 0; n < QK_K; n += 64) {
            for (int j = 0; j < 32; ++j) {
                int l1 = L[n + j];
                if (l1 > 15) {
                    l1 -= 16; qh[j] |= m1;
                }
                int l2 = L[n + j + 32];
                if (l2 > 15) {
                    l2 -= 16; qh[j] |= m2;
                }
                ql[j] = l1 | (l2 << 4);
            }
            m1 <<= 2; m2 <<= 2;
            ql += 32;
        }

        x += QK_K;

    }
}
1673
1674size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
1675    size_t row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row);
1676    if (!quant_weights) {
1677        quantize_row_q5_K_ref(src, dst, (int64_t)nrow*n_per_row);
1678    }
1679    else {
1680        char * qrow = (char *)dst;
1681        for (int64_t row = 0; row < nrow; ++row) {
1682            quantize_row_q5_K_impl(src, (block_q5_K*)qrow, n_per_row, quant_weights);
1683            src += n_per_row;
1684            qrow += row_size;
1685        }
1686    }
1687    return nrow * row_size;
1688}
1689
1690// ====================== 6-bit (de)-quantization
1691
// Reference quantization (no importance weights) of k floats into Q6_K blocks.
// k must be a multiple of QK_K. Each super-block has QK_K/16 sub-blocks of 16.
//
// Per super-block:
//   - make_qx_quants(16, 32, ...) gives each sub-block a signed scale.
//   - The super-block scale is d = -max_scale/128, where max_scale is the
//     sub-block scale with the largest magnitude. Each sub-block scale becomes
//     a signed 8-bit code, at most 127.
//   - Values are re-quantized to [-32, 31] against the stored scales and kept
//     with a +32 bias (6 bits).
//   - The low 4 bits go into ql and the top 2 bits into qh.
// A super-block whose scales are all below GROUP_MAX_EPS is written as all zeros.
void quantize_row_q6_K_ref(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k) {
    assert(k % QK_K == 0);
    const int64_t nb = k / QK_K;

    int8_t L[QK_K];
    float   scales[QK_K/16];

    for (int i = 0; i < nb; i++) {

        float max_scale = 0;
        float max_abs_scale = 0;

        // Find the sub-block scale with the largest magnitude, keeping its sign.
        for (int ib = 0; ib < QK_K/16; ++ib) {

            const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, NULL);
            scales[ib] = scale;

            const float abs_scale = fabsf(scale);
            if (abs_scale > max_abs_scale) {
                max_abs_scale = abs_scale;
                max_scale = scale;
            }

        }

        // Effectively all-zero super-block: emit a zeroed block and move on.
        if (max_abs_scale < GROUP_MAX_EPS) {
            memset(&y[i], 0, sizeof(block_q6_K));
            y[i].d = GGML_FP32_TO_FP16(0.f);
            x += QK_K;
            continue;
        }

        // max_scale maps to code -128; every other scale falls in [-128, 127].
        float iscale = -128.f/max_scale;
        y[i].d = GGML_FP32_TO_FP16(1/iscale);
        for (int ib = 0; ib < QK_K/16; ++ib) {
            y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib]));
        }

        // Re-quantize against the scale as stored (fp16 d times the 8-bit code).
        // Sub-blocks with a zero decoded scale keep make_qx_quants' codes.
        for (int j = 0; j < QK_K/16; ++j) {
            float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j];
            if (!d) {
                continue;
            }
            for (int ii = 0; ii < 16; ++ii) {
                int l = nearest_int(x[16*j + ii]/d);
                l = MAX(-32, MIN(31, l));
                L[16*j + ii] = l + 32;
            }
        }

        // Packing per 128-value chunk:
        //   ql[l]    = low nibble of value l    | low nibble of value l+64 << 4
        //   ql[l+32] = low nibble of value l+32 | low nibble of value l+96 << 4
        //   qh[l]    = top 2 bits of values l, l+32, l+64, l+96 in bits 0-1, 2-3, 4-5, 6-7
        uint8_t * GGML_RESTRICT ql = y[i].ql;
        uint8_t * GGML_RESTRICT qh = y[i].qh;
        for (int j = 0; j < QK_K; j += 128) {
            for (int l = 0; l < 32; ++l) {
                const uint8_t q1 = L[j + l +  0] & 0xF;
                const uint8_t q2 = L[j + l + 32] & 0xF;
                const uint8_t q3 = L[j + l + 64] & 0xF;
                const uint8_t q4 = L[j + l + 96] & 0xF;
                ql[l+ 0] = q1 | (q3 << 4);
                ql[l+32] = q2 | (q4 << 4);
                qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6);
            }
            ql += 64;
            qh += 32;
        }

        x += QK_K;
    }
}
1761
1762void dequantize_row_q6_K(const block_q6_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
1763    assert(k % QK_K == 0);
1764    const int64_t nb = k / QK_K;
1765
1766    for (int i = 0; i < nb; i++) {
1767        const float d = GGML_FP16_TO_FP32(x[i].d);
1768
1769        const uint8_t * GGML_RESTRICT ql = x[i].ql;
1770        const uint8_t * GGML_RESTRICT qh = x[i].qh;
1771        const int8_t  * GGML_RESTRICT sc = x[i].scales;
1772
1773        for (int n = 0; n < QK_K; n += 128) {
1774            for (int l = 0; l < 32; ++l) {
1775                int is = l/16;
1776                const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
1777                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
1778                const int8_t q3 = (int8_t)((ql[l +  0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
1779                const int8_t q4 = (int8_t)((ql[l + 32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
1780                y[l +  0] = d * sc[is + 0] * q1;
1781                y[l + 32] = d * sc[is + 2] * q2;
1782                y[l + 64] = d * sc[is + 4] * q3;
1783                y[l + 96] = d * sc[is + 6] * q4;
1784            }
1785            y  += 128;
1786            ql += 64;
1787            qh += 32;
1788            sc += 8;
1789        }
1790    }
1791}
1792
// Quantize one row of n_per_row floats (a multiple of QK_K) into Q6_K blocks.
//
// For each QK_K block:
//   1. every 16-value sub-block gets a float scale from make_qx_quants
//      (nmax = 32), weighted by quant_weights when it is non-NULL;
//   2. the sub-block scales are stored as int8 relative to one fp16 block
//      scale y[i].d;
//   3. the values are re-quantized against the rounded scales to 6-bit levels
//      and packed: low 4 bits into ql, high 2 bits into qh.
//
// quant_weights, when given, is indexed like the row itself: element
// QK_K*i + 16*ib + j weights the same element of the row. quantize_q6_K passes
// the same pointer for every row.
static void quantize_row_q6_K_impl(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) {
    assert(n_per_row % QK_K == 0);
    const int64_t nb = n_per_row / QK_K;

    int8_t L[QK_K];          // per-element levels of the current block
    float   scales[QK_K/16]; // float scale of each 16-value sub-block
    //float   weights[16];

    for (int i = 0; i < nb; i++) {

        //float sum_x2 = 0;
        //for (int j = 0; j < QK_K; ++j) sum_x2 += x[j]*x[j];
        //float sigma2 = sum_x2/QK_K;

        float max_scale = 0;     // signed sub-block scale with the largest magnitude
        float max_abs_scale = 0;

        // Pass 1: choose a scale (and provisional levels) per 16-value sub-block.
        for (int ib = 0; ib < QK_K/16; ++ib) {

            float scale;
            if (quant_weights) {
                const float * qw = quant_weights + QK_K*i + 16*ib;
                //for (int j = 0; j < 16; ++j) weights[j] = qw[j] * sqrtf(sigma2 + x[16*ib + j]*x[16*ib + j]);
                //scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, weights);
                scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, qw);
            } else {
                scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, NULL);
            }
            scales[ib] = scale;

            const float abs_scale = fabsf(scale);
            if (abs_scale > max_abs_scale) {
                max_abs_scale = abs_scale;
                max_scale = scale;
            }

        }

        // Every sub-block scale is (near) zero: emit an all-zero block.
        if (max_abs_scale < GROUP_MAX_EPS) {
            memset(&y[i], 0, sizeof(block_q6_K));
            y[i].d = GGML_FP32_TO_FP16(0.f);
            x += QK_K;
            continue;
        }

        // The largest-magnitude scale maps to exactly -128; MIN(127, ...) clamps
        // opposite-sign scales that would otherwise round to +128, so all fit int8.
        float iscale = -128.f/max_scale;
        y[i].d = GGML_FP32_TO_FP16(1/iscale);
        for (int ib = 0; ib < QK_K/16; ++ib) {
            y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib]));
        }

        // Pass 2: re-quantize against the scale the decoder will actually use
        // (fp16 d times int8 sub-scale), storing levels offset by +32 (0..63).
        // Sub-blocks whose scale rounded to 0 keep the levels from pass 1; they
        // dequantize to 0 regardless.
        for (int j = 0; j < QK_K/16; ++j) {
            float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j];
            if (!d) {
                continue;
            }
            for (int ii = 0; ii < 16; ++ii) {
                int l = nearest_int(x[16*j + ii]/d);
                l = MAX(-32, MIN(31, l));
                L[16*j + ii] = l + 32;
            }
        }

        // Pack each 128-value half: ql[l] holds the low nibbles of values l and
        // l+64, ql[l+32] those of l+32 and l+96; qh[l] holds the top 2 bits of
        // values l, l+32, l+64, l+96 at bit offsets 0, 2, 4, 6.
        uint8_t * GGML_RESTRICT ql = y[i].ql;
        uint8_t * GGML_RESTRICT qh = y[i].qh;
        for (int j = 0; j < QK_K; j += 128) {
            for (int l = 0; l < 32; ++l) {
                const uint8_t q1 = L[j + l +  0] & 0xF;
                const uint8_t q2 = L[j + l + 32] & 0xF;
                const uint8_t q3 = L[j + l + 64] & 0xF;
                const uint8_t q4 = L[j + l + 96] & 0xF;
                ql[l+ 0] = q1 | (q3 << 4);
                ql[l+32] = q2 | (q4 << 4);
                qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6);
            }
            ql += 64;
            qh += 32;
        }

        x += QK_K;

    }
}
1876
1877size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
1878    size_t row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row);
1879    if (!quant_weights) {
1880        quantize_row_q6_K_ref(src, dst, (int64_t)nrow*n_per_row);
1881    }
1882    else {
1883        char * qrow = (char *)dst;
1884        for (int64_t row = 0; row < nrow; ++row) {
1885            quantize_row_q6_K_impl(src, (block_q6_K*)qrow, n_per_row, quant_weights);
1886            src += n_per_row;
1887            qrow += row_size;
1888        }
1889    }
1890    return nrow * row_size;
1891}
1892
1893static void quantize_row_q4_0_impl(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) {
1894    static_assert(QK4_0 == 32, "QK4_0 must be 32");
1895
1896    if (!quant_weights) {
1897        quantize_row_q4_0_ref(x, y, n_per_row);
1898        return;
1899    }
1900
1901    float weight[QK4_0];
1902    int8_t L[QK4_0];
1903
1904    float sum_x2 = 0;
1905    for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
1906    float sigma2 = sum_x2/n_per_row;
1907
1908    const int64_t nb = n_per_row/QK4_0;
1909    for (int ib = 0; ib < nb; ++ib) {
1910        const float * xb = x + QK4_0 * ib;
1911        const float * qw = quant_weights + QK4_0 * ib;
1912        for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
1913        float d = make_qx_quants(QK4_0, 8, xb, L, 1, weight);
1914        y[ib].d = GGML_FP32_TO_FP16(d);
1915        for (int j = 0; j < 16; ++j) {
1916            y[ib].qs[j] = L[j] | (L[j+16] << 4);
1917        }
1918    }
1919}
1920
1921size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
1922    if (!quant_weights) {
1923        quantize_row_q4_0_ref(src, dst, (int64_t)nrow*n_per_row);
1924        return nrow * ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
1925    }
1926    size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
1927    char * qrow = (char *)dst;
1928    for (int64_t row = 0; row < nrow; ++row) {
1929        quantize_row_q4_0_impl(src, (block_q4_0*)qrow, n_per_row, quant_weights);
1930        src += n_per_row;
1931        qrow += row_size;
1932    }
1933    return nrow * row_size;
1934}
1935
1936static void quantize_row_q4_1_impl(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) {
1937    static_assert(QK4_1 == 32, "QK4_1 must be 32");
1938
1939    if (!quant_weights) {
1940        quantize_row_q4_1_ref(x, y, n_per_row);
1941        return;
1942    }
1943
1944    float weight[QK4_1];
1945    uint8_t L[QK4_1], Laux[QK4_1];
1946
1947    float sum_x2 = 0;
1948    for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
1949    float sigma2 = sum_x2/n_per_row;
1950
1951    const int64_t nb = n_per_row/QK4_1;
1952    for (int ib = 0; ib < nb; ++ib) {
1953        const float * xb = x + QK4_1 * ib;
1954        const float * qw = quant_weights + QK4_1 * ib;
1955        for (int j = 0; j < QK4_1; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
1956        float min;
1957        float d = make_qkx3_quants(QK4_1, 15, xb, weight, L, &min, Laux, -0.9f, 0.05f, 36, false);
1958        y[ib].d = GGML_FP32_TO_FP16(d);
1959        y[ib].m = GGML_FP32_TO_FP16(-min);
1960        for (int j = 0; j < 16; ++j) {
1961            y[ib].qs[j] = L[j] | (L[j+16] << 4);
1962        }
1963    }
1964}
1965
1966size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
1967    if (!quant_weights) {
1968        quantize_row_q4_1_ref(src, dst, (int64_t)nrow*n_per_row);
1969        return nrow * ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
1970    }
1971    size_t row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
1972    char * qrow = (char *)dst;
1973    for (int64_t row = 0; row < nrow; ++row) {
1974        quantize_row_q4_1_impl(src, (block_q4_1*)qrow, n_per_row, quant_weights);
1975        src += n_per_row;
1976        qrow += row_size;
1977    }
1978    return nrow * row_size;
1979}
1980
1981static void quantize_row_q5_0_impl(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) {
1982    static_assert(QK5_0 == 32, "QK5_0 must be 32");
1983
1984    if (!quant_weights) {
1985        quantize_row_q5_0_ref(x, y, n_per_row);
1986        return;
1987    }
1988
1989    float weight[QK5_0];
1990    int8_t L[QK5_0];
1991
1992    float sum_x2 = 0;
1993    for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
1994    float sigma2 = sum_x2/n_per_row;
1995
1996    const int64_t nb = n_per_row/QK5_0;
1997    for (int ib = 0; ib < nb; ++ib) {
1998        const float * xb = x + QK5_0 * ib;
1999        const float * qw = quant_weights + QK5_0 * ib;
2000        for (int j = 0; j < QK5_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
2001        float d = make_qx_quants(QK5_0, 16, xb, L, 1, weight);
2002        y[ib].d = GGML_FP32_TO_FP16(d);
2003
2004        uint32_t qh = 0;
2005
2006        for (int j = 0; j < 16; ++j) {
2007            const uint8_t xi0 = L[j];
2008            const uint8_t xi1 = L[j+16];
2009            y[ib].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
2010
2011            // get the 5-th bit and store it in qh at the right position
2012            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
2013            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
2014        }
2015
2016        memcpy(&y[ib].qh, &qh, sizeof(qh));
2017    }
2018}
2019
2020size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
2021    if (!quant_weights) {
2022        quantize_row_q5_0_ref(src, dst, (int64_t)nrow*n_per_row);
2023        return nrow * ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
2024    }
2025    size_t row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
2026    char * qrow = (char *)dst;
2027    for (int64_t row = 0; row < nrow; ++row) {
2028        quantize_row_q5_0_impl(src, (block_q5_0*)qrow, n_per_row, quant_weights);
2029        src += n_per_row;
2030        qrow += row_size;
2031    }
2032    return nrow * row_size;
2033}
2034
2035static void quantize_row_q5_1_impl(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) {
2036    static_assert(QK5_1 == 32, "QK5_1 must be 32");
2037
2038    if (!quant_weights) {
2039        quantize_row_q5_1_ref(x, y, n_per_row);
2040        return;
2041    }
2042
2043    float weight[QK5_1];
2044    uint8_t L[QK5_1], Laux[QK5_1];
2045
2046    float sum_x2 = 0;
2047    for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
2048    float sigma2 = sum_x2/n_per_row;
2049
2050    const int64_t nb = n_per_row/QK5_1;
2051    for (int ib = 0; ib < nb; ++ib) {
2052        const float * xb = x + QK5_1 * ib;
2053        const float * qw = quant_weights + QK5_1 * ib;
2054        for (int j = 0; j < QK5_1; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
2055        float min;
2056        float d = make_qkx3_quants(QK5_1, 31, xb, weight, L, &min, Laux, -0.9f, 0.05f, 36, false);
2057        y[ib].d = GGML_FP32_TO_FP16(d);
2058        y[ib].m = GGML_FP32_TO_FP16(-min);
2059
2060        uint32_t qh = 0;
2061        for (int j = 0; j < 16; ++j) {
2062            const uint8_t xi0 = L[j];
2063            const uint8_t xi1 = L[j+16];
2064            y[ib].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
2065            // get the 5-th bit and store it in qh at the right position
2066            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
2067            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
2068        }
2069        memcpy(&y[ib].qh, &qh, sizeof(qh));
2070    }
2071}
2072
2073size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
2074    if (!quant_weights) {
2075        quantize_row_q5_1_ref(src, dst, (int64_t)nrow*n_per_row);
2076        return nrow * ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
2077    }
2078    size_t row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
2079    char * qrow = (char *)dst;
2080    for (int64_t row = 0; row < nrow; ++row) {
2081        quantize_row_q5_1_impl(src, (block_q5_1*)qrow, n_per_row, quant_weights);
2082        src += n_per_row;
2083        qrow += row_size;
2084    }
2085    return nrow * row_size;
2086}
2087
2088size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
2089    (void)quant_weights; // not used
2090    const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, n_per_row);
2091    quantize_row_q8_0_ref(src, dst, (int64_t)nrow*n_per_row);
2092    return nrow * row_size;
2093}
2094
2095size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
2096    GGML_UNUSED(quant_weights);
2097    quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row);
2098    return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
2099}
2100
2101// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
2102
// Reference quantization of k floats (a multiple of QK_K) to TQ1_0.
//
// Each QK_K block stores d = max|x| as fp16 and rounds every value x/d to a
// trit in {-1, 0, 1}, kept as {0, 1, 2}. Trits are packed base-3: five per
// byte in qs, four per byte in qh. A byte holding trits t0..t4 (t0 first)
// stores ceil(256 * (81*t0 + 27*t1 + 9*t2 + 3*t3 + t4) / 243), i.e. the base-3
// number rescaled onto 0..255, so dequantize_row_tq1_0 can recover trit n as
// (((byte * 3^n) mod 256) * 3) >> 8.
void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) {
    assert(k % QK_K == 0);
    const int64_t nb = k / QK_K;

    for (int64_t i = 0; i < nb; i++) {
        float amax = 0.0f; // absolute max

        for (int j = 0; j < QK_K; j++) {
            const float v = x[j];
            amax = MAX(amax, fabsf(v));
        }

        const float d = amax;
        const float id = d ? 1.0f/d : 0.0f; // all-zero block: every value rounds to trit 0

        y[i].d = GGML_FP32_TO_FP16(d);

        // 5 elements per byte, along 32 bytes
        // Covers the largest multiple of 32 bytes of qs; within a 32-byte group,
        // byte m holds the values x[m], x[m+32], ..., x[m+4*32] (first = most
        // significant trit).
        for (size_t j = 0; j < sizeof(y->qs) - sizeof(y->qs) % 32; j += 32) {
            for (size_t m = 0; m < 32; ++m) {
                uint8_t q = 0;
                for (size_t n = 0; n < 5; ++n) {
                    int xi = lroundf(x[m + n*32] * id) + 1; // -1, 0, 1 -> 0, 1, 2
                    q *= 3;
                    q += xi;
                }
                // ceiling division (243 == pow(3, 5))
                // ((uint16_t)q promotes to int, so q*256 cannot overflow)
                q = ((uint16_t)q * 256 + (243 - 1)) / 243;
                y[i].qs[j + m] = q;
            }
            x += 5*32;
        }
        // along 16 bytes
        // Remaining qs bytes, same scheme with 16-byte groups (values 16 apart).
        for (size_t j = sizeof(y->qs) - sizeof(y->qs) % 32; j < sizeof(y->qs); j += 16) {
            for (size_t m = 0; m < 16; ++m) {
                uint8_t q = 0;
                for (size_t n = 0; n < 5; ++n) {
                    int xi = lroundf(x[m + n*16] * id) + 1; // -1, 0, 1 -> 0, 1, 2
                    q *= 3;
                    q += xi;
                }
                // ceiling division (243 == pow(3, 5))
                q = ((uint16_t)q * 256 + (243 - 1)) / 243;
                y[i].qs[j + m] = q;
            }
            x += 5*16;
        }
        // 4 elements per byte
        // qh byte j holds the values x[j + m*sizeof(qh)] for m = 0..3.
        for (size_t j = 0; j < sizeof(y->qh); ++j) {
            uint8_t q = 0;
            for (size_t m = 0; m < 4; ++m) {
                // -1, 0, 1 -> 0, 1, 2
                int xi = lroundf(x[j + m*sizeof(y->qh)] * id) + 1;
                q *= 3;
                q += xi;
            }
            // shift the first value to the most significant trit
            q *= 3;
            // ceiling division (243 == pow(3, 5))
            q = ((uint16_t)q * 256 + (243 - 1)) / 243;
            y[i].qh[j] = q;
        }
        x += 4*sizeof(y->qh);
    }
}
2168
2169void quantize_row_tq2_0_ref(const float * GGML_RESTRICT x, block_tq2_0 * GGML_RESTRICT y, int64_t k) {
2170    assert(k % QK_K == 0);
2171    const int64_t nb = k / QK_K;
2172
2173    for (int64_t i = 0; i < nb; i++) {
2174        float amax = 0.0f; // absolute max
2175
2176        for (int j = 0; j < QK_K; j++) {
2177            const float v = x[j];
2178            amax = MAX(amax, fabsf(v));
2179        }
2180
2181        const float d = amax;
2182        const float id = d ? 1.0f/d : 0.0f;
2183
2184        y[i].d = GGML_FP32_TO_FP16(d);
2185
2186        for (size_t j = 0; j < sizeof(y->qs); j += 32) {
2187            for (size_t m = 0; m < 32; ++m) {
2188                uint8_t q = 0;
2189                for (size_t n = 0; n < 4; ++n) {
2190                    // -1, 0, 1 -> 0, 1, 2
2191                    int xi = lroundf(x[m + n*32] * id) + 1;
2192                    q += (xi & 3) << (2*n);
2193                }
2194                y[i].qs[j + m] = q;
2195            }
2196            x += 4*32;
2197        }
2198    }
2199}
2200
2201size_t quantize_tq1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
2202    (void)quant_weights; // not used
2203    const size_t row_size = ggml_row_size(GGML_TYPE_TQ1_0, n_per_row);
2204    quantize_row_tq1_0_ref(src, dst, (int64_t)nrow*n_per_row);
2205    return nrow * row_size;
2206}
2207
2208size_t quantize_tq2_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
2209    (void)quant_weights; // not used
2210    const size_t row_size = ggml_row_size(GGML_TYPE_TQ2_0, n_per_row);
2211    quantize_row_tq2_0_ref(src, dst, (int64_t)nrow*n_per_row);
2212    return nrow * row_size;
2213}
2214
// Dequantize k values (a multiple of QK_K) from TQ1_0 blocks into y.
//
// Each byte stores its trits as a base-3 number scaled onto 0..255 (most
// significant trit first). Multiplying the byte by 3^n (mod 256, via the
// uint8_t store) moves trit n to the top, and ((q * 3) >> 8) then extracts it
// as 0, 1 or 2. Each output is (trit - 1) * d.
void dequantize_row_tq1_0(const block_tq1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
    assert(k % QK_K == 0);
    const int64_t nb = k / QK_K;

    const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};

    for (int64_t i = 0; i < nb; ++i) {

        const float d = GGML_FP16_TO_FP32(x[i].d);

        // 32-byte groups of qs, 5 trits per byte: outputs are emitted trit by
        // trit, so byte m of the group contributes outputs m, m+32, ..., m+4*32.
        for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) {
            for (size_t n = 0; n < 5; ++n) {
                for (size_t m = 0; m < 32; ++m) {
                    uint8_t q = x[i].qs[j + m] * pow3[n];
                    int16_t xi = ((uint16_t) q * 3) >> 8;
                    *y++ = (float) (xi - 1) * d;
                }
            }
        }
        // remaining qs bytes in 16-byte groups, same scheme
        for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) {
            for (size_t n = 0; n < 5; ++n) {
                for (size_t m = 0; m < 16; ++m) {
                    uint8_t q = x[i].qs[j + m] * pow3[n];
                    int16_t xi = ((uint16_t) q * 3) >> 8;
                    *y++ = (float) (xi - 1) * d;
                }
            }
        }

        // qh: 4 trits per byte (the 5th, least significant trit is unused)
        for (size_t n = 0; n < 4; ++n) {
            for (size_t j = 0; j < sizeof(x->qh); ++j) {
                uint8_t q = x[i].qh[j] * pow3[n];
                int16_t xi = ((uint16_t) q * 3) >> 8;
                *y++ = (float) (xi - 1) * d;
            }
        }
    }
}
2253
2254void dequantize_row_tq2_0(const block_tq2_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
2255    assert(k % QK_K == 0);
2256    const int64_t nb = k / QK_K;
2257
2258    for (int64_t i = 0; i < nb; ++i) {
2259
2260        const float d = GGML_FP16_TO_FP32(x[i].d);
2261
2262        for (size_t j = 0; j < sizeof(x->qs); j += 32) {
2263            for (size_t l = 0; l < 4; ++l) {
2264                for (size_t m = 0; m < 32; ++m) {
2265                    int8_t q = (x[i].qs[j + m] >> (l*2)) & 3;
2266                    *y++ = (float) (q - 1) * d;
2267                }
2268            }
2269        }
2270    }
2271}
2272
2273// ====================== "True" 2-bit (de)-quantization
2274
// Dequantize k values (a multiple of QK_K) from IQ2_XXS blocks into y.
//
// Every 32 outputs are described by 8 bytes of qs, read as two uint32 words:
//   aux32[0]: four 8-bit indices into iq2xxs_grid (one per 8 outputs),
//   aux32[1]: four 7-bit sign-pattern indices (bits 0..27) and a 4-bit
//             sub-block scale s (bits 28..31), giving db = d * (0.5 + s) / 4.
// Each grid entry is read as 8 unsigned magnitudes; ksigns_iq2xs expands a
// 7-bit index into an 8-bit sign mask tested with kmask_iq2xs[j].
// NOTE(review): aux32[1] is interpreted in native byte order; presumably the
// format assumes little-endian storage.
void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
    assert(k % QK_K == 0);
    const int64_t nb = k / QK_K;

    uint32_t aux32[2];
    const uint8_t * aux8 = (const uint8_t *)aux32; // byte view of aux32[0] for the grid indices

    for (int i = 0; i < nb; i++) {

        const float d = GGML_FP16_TO_FP32(x[i].d);

        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
            // memcpy avoids unaligned / type-punned loads from qs
            memcpy(aux32, x[i].qs + 4*ib32, 2*sizeof(uint32_t));
            const float db = d * (0.5f + (aux32[1] >> 28)) * 0.25f;
            for (int l = 0; l < 4; ++l) {
                const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
                const uint8_t  signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
                for (int j = 0; j < 8; ++j) {
                    y[j] = db * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                }
                y += 8;
            }
        }
    }
}
2300
2301// ====================== 2.3125 bpw (de)-quantization
2302
2303void dequantize_row_iq2_xs(const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
2304    assert(k % QK_K == 0);
2305    const int64_t nb = k / QK_K;
2306
2307    float db[2];
2308
2309    for (int i = 0; i < nb; i++) {
2310
2311        const float d = GGML_FP16_TO_FP32(x[i].d);
2312
2313        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
2314            db[0] = d * (0.5f + (x[i].scales[ib32] & 0xf)) * 0.25f;
2315            db[1] = d * (0.5f + (x[i].scales[ib32] >>  4)) * 0.25f;
2316            for (int l = 0; l < 4; ++l) {
2317                const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (x[i].qs[4*ib32 + l] & 511));
2318                const uint8_t  signs = ksigns_iq2xs[x[i].qs[4*ib32 + l] >> 9];
2319                for (int j = 0; j < 8; ++j) {
2320                    y[j] = db[l/2] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
2321                }
2322                y += 8;
2323            }
2324        }
2325    }
2326}
2327
2328// ====================== 2.5625 bpw (de)-quantization
2329
// Dequantize k values (a multiple of QK_K) from IQ2_S blocks into y.
//
// Layout per block as read here: the first QK_K/8 bytes of qs are the low 8
// bits of the grid indices (one per 8 outputs); the following bytes of qs are
// explicit 8-bit sign masks (one per 8 outputs). qh[ib32] supplies 2 extra
// index bits for each of the 4 indices of a 32-output group, making 10-bit
// indices into iq2s_grid. Each scale byte carries two 4-bit sub-scales
// (first / last 16 outputs), applied as d * (0.5 + s) / 4.
void dequantize_row_iq2_s(const block_iq2_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
    assert(k % QK_K == 0);
    const int64_t nb = k / QK_K;

    float db[2];

    for (int i = 0; i < nb; i++) {

        const float d = GGML_FP16_TO_FP32(x[i].d);
        const uint8_t * qs = x[i].qs;
        const uint8_t * qh = x[i].qh;
        const uint8_t * signs = qs + QK_K/8;

        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
            db[0] = d * (0.5f + (x[i].scales[ib32] & 0xf)) * 0.25f;
            db[1] = d * (0.5f + (x[i].scales[ib32] >>  4)) * 0.25f;
            for (int l = 0; l < 4; ++l) {
                const float dl = db[l/2];
                // bits 2l and 2l+1 of qh[ib32] become bits 8 and 9 of the index
                const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));
                for (int j = 0; j < 8; ++j) {
                    y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
                }
                y += 8;
            }
            qs += 4;
            signs += 4;
        }
    }
}
2359
2360// ====================== 3.0625 bpw (de)-quantization
2361
// Dequantize k values (a multiple of QK_K) from IQ3_XXS blocks into y.
//
// The first QK_K/4 bytes of qs are 8-bit indices into iq3xxs_grid, one per 4
// outputs. They are followed by one 32-bit word per 32 outputs: four 7-bit
// sign-pattern indices (bits 0..27, one per 8 outputs) and a 4-bit sub-scale
// s (bits 28..31), giving db = d * (0.5 + s) / 2.
// NOTE(review): the 32-bit word is interpreted in native byte order;
// presumably the format assumes little-endian storage.
void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
    assert(k % QK_K == 0);
    const int64_t nb = k / QK_K;

    uint32_t aux32;

    for (int i = 0; i < nb; i++) {

        const float d = GGML_FP16_TO_FP32(x[i].d);
        const uint8_t * qs = x[i].qs;
        const uint8_t * scales_and_signs = qs + QK_K/4;

        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
            memcpy(&aux32, scales_and_signs + 4*ib32, sizeof(uint32_t));
            const float db = d * (0.5f + (aux32 >> 28)) * 0.5f;
            for (int l = 0; l < 4; ++l) {
                const uint8_t  signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
                // two grid entries of 4 magnitudes each cover 8 outputs
                const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + qs[2*l+0]);
                const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + qs[2*l+1]);
                for (int j = 0; j < 4; ++j) {
                    y[j+0] = db * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
                    y[j+4] = db * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
                }
                y += 8;
            }
            qs += 8;
        }
    }
}
2391
2392// ====================== 3.3125 bpw (de)-quantization
2393
// Dequantize k values (a multiple of QK_K) from IQ3_S blocks into y.
//
// Processes two 32-output groups per iteration. Each scale byte holds one
// 4-bit scale per group, applied as d * (1 + 2*s). Grid indices into
// iq3s_grid are 9 bits: 8 bits from qs plus one bit from qh (one qh byte per
// group, two bits per 8 outputs). Signs are explicit, one byte per 8 outputs.
// Each grid entry yields 4 magnitudes.
void dequantize_row_iq3_s(const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
    assert(k % QK_K == 0);
    const int64_t nb = k / QK_K;

    for (int i = 0; i < nb; i++) {

        const float d = GGML_FP16_TO_FP32(x[i].d);
        const uint8_t * qs = x[i].qs;
        const uint8_t * qh = x[i].qh;
        const uint8_t * signs = x[i].signs;

        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
            const float db1 = d * (1 + 2*(x[i].scales[ib32/2] & 0xf));
            const float db2 = d * (1 + 2*(x[i].scales[ib32/2] >>  4));
            // first group: 9th index bit is bit 2l (grid1) / 2l+1 (grid2) of qh[0]
            for (int l = 0; l < 4; ++l) {
                const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
                const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
                for (int j = 0; j < 4; ++j) {
                    y[j+0] = db1 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
                    y[j+4] = db1 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
                }
                y += 8;
            }
            qs += 8;
            signs += 4;
            // second group: same scheme with qh[1] and the high-nibble scale
            for (int l = 0; l < 4; ++l) {
                const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)));
                const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)));
                for (int j = 0; j < 4; ++j) {
                    y[j+0] = db2 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
                    y[j+4] = db2 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
                }
                y += 8;
            }
            qh += 2;
            qs += 8;
            signs += 4;
        }
    }
}
2434
2435// ====================== 1.5625 bpw (de)-quantization
2436
2437void dequantize_row_iq1_s(const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
2438    assert(k % QK_K == 0);
2439    const int64_t nb = k / QK_K;
2440
2441    for (int i = 0; i < nb; i++) {
2442
2443        const float d = GGML_FP16_TO_FP32(x[i].d);
2444        const uint8_t  * qs = x[i].qs;
2445        const uint16_t * qh = x[i].qh;
2446
2447        for (int ib = 0; ib < QK_K/32; ++ib) {
2448            const float dl = d * (2*((qh[ib] >> 12) & 7) + 1);
2449            const float delta = qh[ib] & 0x8000 ? -IQ1S_DELTA : IQ1S_DELTA;
2450            for (int l = 0; l < 4; ++l) {
2451                const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8)));
2452                for (int j = 0; j < 8; ++j) {
2453                    y[j] = dl * (grid[j] + delta);
2454                }
2455                y += 8;
2456            }
2457            qs += 4;
2458        }
2459    }
2460}
2461
// Dequantize k values (a multiple of QK_K) from IQ1_M blocks into y.
//
// The block's fp16 scale has no field of its own: its 16 bits are the top
// nibbles of the four 16-bit scale words (sc[0] -> bits 0..3, sc[1] -> 4..7,
// sc[2] -> 8..11, sc[3] -> 12..15). The remaining bits of the scale words hold
// 3-bit sub-scales s, two per 32-output group (first / last 16 outputs),
// applied as d * (2*s + 1). Grid indices into iq1s_grid are 11 bits: a qs
// byte plus the low 3 bits of a qh nibble; bit 3 of that nibble picks the sign
// of the IQ1S_DELTA shift.
// NOTE(review): scales is read through a uint16_t pointer; this assumes the
// field is 2-byte aligned and in little-endian order — confirm.
void dequantize_row_iq1_m(const block_iq1_m * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
    assert(k % QK_K == 0);
    const int64_t nb = k / QK_K;

    float delta[4];
    uint16_t idx[4];

    iq1m_scale_t scale;

    for (int i = 0; i < nb; i++) {

        const uint16_t * sc = (const uint16_t *)x[i].scales;
        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
        const float d = GGML_FP16_TO_FP32(scale.f16);

        const uint8_t * qs = x[i].qs;
        const uint8_t * qh = x[i].qh;

        for (int ib = 0; ib < QK_K/32; ++ib) {
            // two 3-bit sub-scales per group, packed 6 bits per group in sc[ib/2]
            const float dl1 = d * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1);
            const float dl2 = d * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1);

            // qh[0] serves lookups 0,1 (low / high nibble), qh[1] lookups 2,3
            idx[0] = qs[0] | ((qh[0] << 8) & 0x700);
            idx[1] = qs[1] | ((qh[0] << 4) & 0x700);
            idx[2] = qs[2] | ((qh[1] << 8) & 0x700);
            idx[3] = qs[3] | ((qh[1] << 4) & 0x700);
            delta[0] = qh[0] & 0x08 ? -IQ1S_DELTA : IQ1S_DELTA;
            delta[1] = qh[0] & 0x80 ? -IQ1S_DELTA : IQ1S_DELTA;
            delta[2] = qh[1] & 0x08 ? -IQ1S_DELTA : IQ1S_DELTA;
            delta[3] = qh[1] & 0x80 ? -IQ1S_DELTA : IQ1S_DELTA;
            for (int l = 0; l < 2; ++l) {
                const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
                for (int j = 0; j < 8; ++j) {
                    y[j] = dl1 * (grid[j] + delta[l]);
                }
                y += 8;
            }
            for (int l = 2; l < 4; ++l) {
                const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
                for (int j = 0; j < 8; ++j) {
                    y[j] = dl2 * (grid[j] + delta[l]);
                }
                y += 8;
            }
            qs += 4;
            qh += 2;
        }
    }
}
2511
2512void dequantize_row_iq4_nl(const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
2513    assert(k % QK4_NL == 0);
2514    const int64_t nb = k / QK4_NL;
2515
2516    for (int i = 0; i < nb; i++) {
2517
2518        const uint8_t * qs = x[i].qs;
2519
2520        const float d = GGML_FP16_TO_FP32(x[i].d);
2521        for (int j = 0; j < QK4_NL/2; ++j) {
2522            y[j+       0] = d * kvalues_iq4nl[qs[j] & 0xf];
2523            y[j+QK4_NL/2] = d * kvalues_iq4nl[qs[j] >>  4];
2524        }
2525        y  += QK4_NL;
2526        qs += QK4_NL/2;
2527    }
2528}
2529
2530void dequantize_row_iq4_xs(const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
2531    assert(k % QK_K == 0);
2532    const int64_t nb = k / QK_K;
2533
2534    for (int i = 0; i < nb; i++) {
2535
2536        const uint8_t * qs = x[i].qs;
2537
2538        const float d = GGML_FP16_TO_FP32(x[i].d);
2539
2540        for (int ib = 0; ib < QK_K/32; ++ib) {
2541            const int ls = ((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4);
2542            const float dl = d * (ls - 32);
2543            for (int j = 0; j < 16; ++j) {
2544                y[j+ 0] = dl * kvalues_iq4nl[qs[j] & 0xf];
2545                y[j+16] = dl * kvalues_iq4nl[qs[j] >>  4];
2546            }
2547            y  += 32;
2548            qs += 16;
2549        }
2550    }
2551}
2552
2553//===================================== Q8_K ==============================================
2554
2555void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k) {
2556    assert(k % QK_K == 0);
2557    const int64_t nb = k / QK_K;
2558
2559    for (int i = 0; i < nb; i++) {
2560
2561        float max = 0;
2562        float amax = 0;
2563        for (int j = 0; j < QK_K; ++j) {
2564            float ax = fabsf(x[j]);
2565            if (ax > amax) {
2566                amax = ax; max = x[j];
2567            }
2568        }
2569        if (!amax) {
2570            y[i].d = 0;
2571            memset(y[i].qs, 0, QK_K);
2572            x += QK_K;
2573            continue;
2574        }
2575        //const float iscale = -128.f/max;
2576        // We need this change for IQ2_XXS, else the AVX implementation becomes very awkward
2577        const float iscale = -127.f/max;
2578        for (int j = 0; j < QK_K; ++j) {
2579            int v = nearest_int(iscale*x[j]);
2580            y[i].qs[j] = MIN(127, v);
2581        }
2582        for (int j = 0; j < QK_K/16; ++j) {
2583            int sum = 0;
2584            for (int ii = 0; ii < 16; ++ii) {
2585                sum += y[i].qs[j*16 + ii];
2586            }
2587            y[i].bsums[j] = sum;
2588        }
2589        y[i].d = 1/iscale;
2590        x += QK_K;
2591    }
2592}
2593
2594void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
2595    assert(k % QK_K == 0);
2596    const int64_t nb = k / QK_K;
2597
2598    for (int i = 0; i < nb; i++) {
2599        for (int j = 0; j < QK_K; ++j) {
2600            *y++ = x[i].d * x[i].qs[j];
2601        }
2602    }
2603}
2604
2605// ================================ IQ2 quantization =============================================
2606
// Lookup tables for one grid family, built on first use by iq2xs_init_impl()
// (which returns early once `grid` is non-NULL).
typedef struct {
    uint64_t * grid;       // grid_size entries; each packs 8 int8_t grid coordinates
    int      * map;        // packed 16-bit index -> position in grid, or -1 if not a grid point
    uint16_t * neighbours; // NOTE(review): filled later in iq2xs_init_impl — presumably neighbour lists for off-grid indices; confirm
} iq2_entry_t;

// One slot per grid family, selected by iq2_data_index():
// 0 = IQ2_XXS, 1 = IQ2_XS, 2 = IQ1_S / IQ1_M, 3 = IQ2_S. All tables start unset.
static iq2_entry_t iq2_data[4] = {
    [0] = { .grid = NULL, .map = NULL, .neighbours = NULL },
    [1] = { .grid = NULL, .map = NULL, .neighbours = NULL },
    [2] = { .grid = NULL, .map = NULL, .neighbours = NULL },
    [3] = { .grid = NULL, .map = NULL, .neighbours = NULL },
};
2619
2620static inline int iq2_data_index(enum ggml_type type) {
2621    GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S);
2622    return type == GGML_TYPE_IQ2_XXS ? 0 :
2623           type == GGML_TYPE_IQ2_XS  ? 1 :
2624           type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? 2 : 3;
2625}
2626
2627static inline int iq2_grid_size(enum ggml_type type) {
2628    GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S);
2629    return type == GGML_TYPE_IQ2_XXS ? 256 :
2630           type == GGML_TYPE_IQ2_XS  ? 512 :
2631           type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? NGRID_IQ1S : 1024;
2632}
2633
// Comparator with a qsort-compatible signature that orders pairs of ints
// lexicographically: by element [0] first, then by element [1].
// Returns -1, 0 or 1.
static int iq2_compare_func(const void * left, const void * right) {
    const int * a = (const int *)left;
    const int * b = (const int *)right;
    if (a[0] != b[0]) {
        return a[0] < b[0] ? -1 : 1;
    }
    if (a[1] != b[1]) {
        return a[1] < b[1] ? -1 : 1;
    }
    return 0;
}
2639
2640void iq2xs_init_impl(enum ggml_type type) {
2641    const int gindex = iq2_data_index(type);
2642    const int grid_size = iq2_grid_size(type);
2643    if (iq2_data[gindex].grid) {
2644        return;
2645    }
2646    static const uint16_t kgrid_2bit_256[256] = {
2647            0,     2,     5,     8,    10,    17,    20,    32,    34,    40,    42,    65,    68,    80,    88,    97,
2648          100,   128,   130,   138,   162,   257,   260,   272,   277,   320,   388,   408,   512,   514,   546,   642,
2649         1025,  1028,  1040,  1057,  1060,  1088,  1090,  1096,  1120,  1153,  1156,  1168,  1188,  1280,  1282,  1288,
2650         1312,  1350,  1385,  1408,  1425,  1545,  1552,  1600,  1668,  1700,  2048,  2053,  2056,  2068,  2088,  2113,
2651         2116,  2128,  2130,  2184,  2308,  2368,  2562,  2580,  4097,  4100,  4112,  4129,  4160,  4192,  4228,  4240,
2652         4245,  4352,  4360,  4384,  4432,  4442,  4480,  4644,  4677,  5120,  5128,  5152,  5157,  5193,  5248,  5400,
2653         5474,  5632,  5654,  6145,  6148,  6160,  6208,  6273,  6400,  6405,  6560,  6737,  8192,  8194,  8202,  8260,
2654         8289,  8320,  8322,  8489,  8520,  8704,  8706,  9217,  9220,  9232,  9280,  9302,  9472,  9537,  9572,  9872,
2655        10248, 10272, 10388, 10820, 16385, 16388, 16400, 16408, 16417, 16420, 16448, 16456, 16470, 16480, 16513, 16516,
2656        16528, 16640, 16672, 16737, 16768, 16773, 16897, 16912, 16968, 16982, 17000, 17408, 17416, 17440, 17536, 17561,
2657        17682, 17700, 17920, 18433, 18436, 18448, 18496, 18501, 18688, 18776, 18785, 18818, 19013, 19088, 20480, 20488,
2658        20497, 20505, 20512, 20608, 20616, 20740, 20802, 20900, 21137, 21648, 21650, 21770, 22017, 22100, 22528, 22545,
2659        22553, 22628, 22848, 23048, 24580, 24592, 24640, 24680, 24832, 24917, 25112, 25184, 25600, 25605, 25872, 25874,
2660        25988, 26690, 32768, 32770, 32778, 32833, 32898, 33028, 33048, 33088, 33297, 33793, 33796, 33808, 33813, 33856,
2661        33888, 34048, 34118, 34196, 34313, 34368, 34400, 34818, 35076, 35345, 36868, 36880, 36900, 36928, 37025, 37142,
2662        37248, 37445, 37888, 37922, 37956, 38225, 39041, 39200, 40962, 41040, 41093, 41225, 41472, 42008, 43088, 43268,
2663    };
2664    static const uint16_t kgrid_2bit_512[512] = {
2665            0,     2,     5,     8,    10,    17,    20,    22,    25,    32,    34,    37,    40,    65,    68,    70,
2666           73,    80,    82,    85,    88,    97,   100,   128,   130,   133,   136,   145,   148,   153,   160,   257,
2667          260,   262,   265,   272,   274,   277,   280,   282,   289,   292,   320,   322,   325,   328,   337,   340,
2668          352,   360,   385,   388,   400,   512,   514,   517,   520,   529,   532,   544,   577,   580,   592,   597,
2669          640,   650,  1025,  1028,  1030,  1033,  1040,  1042,  1045,  1048,  1057,  1060,  1088,  1090,  1093,  1096,
2670         1105,  1108,  1110,  1120,  1153,  1156,  1168,  1280,  1282,  1285,  1288,  1297,  1300,  1312,  1345,  1348,
2671         1360,  1377,  1408,  1537,  1540,  1552,  1574,  1600,  1602,  1668,  2048,  2050,  2053,  2056,  2058,  2065,
2672         2068,  2080,  2085,  2113,  2116,  2128,  2136,  2176,  2208,  2218,  2305,  2308,  2320,  2368,  2433,  2441,
2673         2560,  2592,  2600,  2710,  2720,  4097,  4100,  4102,  4105,  4112,  4114,  4117,  4120,  4129,  4132,  4160,
2674         4162,  4165,  4168,  4177,  4180,  4192,  4202,  4225,  4228,  4240,  4352,  4354,  4357,  4360,  4369,  4372,
2675         4384,  4417,  4420,  4432,  4480,  4500,  4502,  4609,  4612,  4614,  4624,  4672,  4704,  5120,  5122,  5125,
2676         5128,  5137,  5140,  5152,  5185,  5188,  5193,  5200,  5220,  5248,  5377,  5380,  5392,  5440,  5632,  5652,
2677         5705,  6145,  6148,  6160,  6162,  6208,  6228,  6278,  6400,  6405,  6502,  6737,  6825,  8192,  8194,  8197,
2678         8200,  8202,  8209,  8212,  8224,  8257,  8260,  8272,  8320,  8352,  8449,  8452,  8464,  8512,  8520,  8549,
2679         8704,  8738,  8832,  8872,  9217,  9220,  9232,  9257,  9280,  9472,  9537,  9554,  9625,  9729,  9754,  9894,
2680        10240, 10248, 10250, 10272, 10325, 10376, 10402, 10600, 10640, 10760, 10784, 10882, 10888, 10890, 16385, 16388,
2681        16390, 16393, 16400, 16402, 16405, 16408, 16417, 16420, 16448, 16450, 16453, 16456, 16458, 16465, 16468, 16480,
2682        16485, 16513, 16516, 16528, 16640, 16642, 16645, 16648, 16657, 16660, 16672, 16705, 16708, 16720, 16768, 16773,
2683        16802, 16897, 16900, 16912, 16914, 16937, 16960, 17408, 17410, 17413, 17416, 17425, 17428, 17433, 17440, 17473,
2684        17476, 17488, 17536, 17556, 17665, 17668, 17680, 17700, 17728, 17818, 17920, 17930, 17988, 18000, 18433, 18436,
2685        18448, 18496, 18501, 18516, 18530, 18688, 18705, 18756, 18768, 18793, 18948, 20480, 20482, 20485, 20488, 20497,
2686        20500, 20512, 20520, 20545, 20548, 20560, 20608, 20737, 20740, 20752, 20757, 20800, 20802, 20992, 21060, 21162,
2687        21505, 21508, 21520, 21537, 21568, 21600, 21633, 21665, 21760, 21768, 21888, 21896, 22049, 22120, 22177, 22528,
2688        22548, 22593, 22608, 22681, 22810, 22848, 22850, 23173, 24577, 24580, 24592, 24640, 24660, 24674, 24710, 24745,
2689        24832, 25124, 25162, 25234, 25600, 25622, 25872, 25920, 25925, 26020, 26625, 26730, 26917, 27142, 27220, 27234,
2690        32768, 32770, 32773, 32776, 32785, 32788, 32800, 32810, 32833, 32836, 32848, 32896, 32898, 32936, 32938, 33025,
2691        33028, 33030, 33040, 33088, 33105, 33113, 33280, 33312, 33408, 33410, 33440, 33448, 33793, 33796, 33808, 33810,
2692        33813, 33856, 33888, 33929, 34048, 34116, 34213, 34328, 34410, 34816, 34824, 34853, 34906, 34944, 34946, 34984,
2693        35078, 35362, 35456, 35464, 35478, 35496, 36865, 36868, 36880, 36928, 36950, 36996, 37120, 37154, 37220, 37462,
2694        37513, 37888, 37893, 37956, 37968, 37976, 38185, 38288, 38290, 38465, 38993, 39078, 39241, 39445, 39520, 40960,
2695        40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048,
2696        42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690,
2697    };
2698    static const uint16_t kgrid_1bit_2048[NGRID_IQ1S] = {
2699            0,     2,     5,     8,    10,    17,    21,    32,    34,    40,    42,    69,    81,    84,    86,   101,
2700          128,   130,   136,   138,   149,   160,   162,   168,   170,   260,   261,   273,   276,   278,   281,   282,
2701          293,   321,   326,   329,   338,   341,   346,   353,   356,   358,   360,   389,   401,   404,   406,   421,
2702          512,   514,   520,   522,   533,   544,   546,   552,   554,   581,   593,   601,   612,   617,   640,   642,
2703          648,   650,   657,   661,   665,   672,   674,   680,   682,  1041,  1044,  1046,  1061,  1089,  1097,  1109,
2704         1114,  1124,  1125,  1169,  1177,  1189,  1281,  1284,  1285,  1286,  1301,  1304,  1306,  1321,  1344,  1349,
2705         1354,  1360,  1361,  1364,  1365,  1366,  1369,  1376,  1378,  1381,  1384,  1386,  1409,  1425,  1429,  1432,
2706         1434,  1441,  1444,  1445,  1446,  1449,  1556,  1561,  1601,  1604,  1616,  1618,  1621,  1624,  1632,  1633,
2707         1638,  1641,  1669,  1681,  1684,  1689,  2048,  2050,  2056,  2058,  2069,  2080,  2082,  2088,  2090,  2117,
2708         2129,  2134,  2149,  2176,  2178,  2184,  2186,  2197,  2208,  2210,  2216,  2218,  2309,  2321,  2324,  2329,
2709         2340,  2341,  2369,  2384,  2385,  2389,  2401,  2404,  2409,  2449,  2452,  2454,  2457,  2469,  2560,  2562,
2710         2568,  2570,  2581,  2592,  2594,  2600,  2602,  2629,  2641,  2649,  2657,  2661,  2688,  2690,  2693,  2696,
2711         2698,  2709,  2720,  2722,  2728,  2730,  4112,  4113,  4116,  4121,  4132,  4133,  4161,  4164,  4176,  4181,
2712         4184,  4193,  4196,  4197,  4201,  4241,  4244,  4246,  4257,  4261,  4353,  4356,  4358,  4361,  4368,  4370,
2713         4373,  4376,  4385,  4388,  4393,  4421,  4426,  4432,  4433,  4434,  4436,  4437,  4438,  4441,  4448,  4453,
2714         4484,  4498,  4501,  4513,  4516,  4625,  4628,  4630,  4645,  4672,  4678,  4681,  4690,  4693,  4696,  4698,
2715         4708,  4710,  4741,  4753,  4756,  4758,  4773,  5121,  5126,  5129,  5140,  5141,  5144,  5145,  5153,  5158,
2716         5185,  5189,  5190,  5192,  5194,  5201,  5204,  5205,  5206,  5209,  5218,  5221,  5224,  5252,  5257,  5264,
2717         5268,  5269,  5272,  5273,  5274,  5281,  5284,  5285,  5289,  5378,  5381,  5386,  5393,  5396,  5397,  5398,
2718         5401,  5408,  5410,  5413,  5416,  5418,  5441,  5444,  5445,  5446,  5457,  5458,  5460,  5461,  5462,  5465,
2719         5466,  5473,  5476,  5477,  5478,  5481,  5504,  5506,  5508,  5509,  5512,  5514,  5520,  5521,  5524,  5525,
2720         5526,  5529,  5530,  5536,  5538,  5541,  5633,  5636,  5637,  5638,  5653,  5654,  5656,  5658,  5665,  5670,
2721         5696,  5698,  5700,  5701,  5704,  5706,  5713,  5717,  5718,  5720,  5721,  5729,  5732,  5733,  5736,  5737,
2722         5738,  5766,  5770,  5778,  5781,  5796,  5801,  6161,  6166,  6181,  6209,  6212,  6214,  6217,  6224,  6229,
2723         6232,  6234,  6240,  6241,  6244,  6246,  6249,  6277,  6289,  6292,  6309,  6416,  6418,  6421,  6426,  6433,
2724         6437,  6466,  6468,  6469,  6472,  6481,  6484,  6485,  6486,  6489,  6490,  6496,  6501,  6506,  6537,  6545,
2725         6546,  6549,  6552,  6561,  6566,  6569,  6665,  6678,  6692,  6694,  6724,  6726,  6729,  6736,  6738,  6741,
2726         6744,  6753,  6758,  6761,  6789,  6801,  6806,  6810,  8192,  8194,  8200,  8202,  8213,  8224,  8226,  8229,
2727         8232,  8234,  8261,  8273,  8281,  8289,  8293,  8320,  8322,  8328,  8330,  8341,  8352,  8354,  8357,  8360,
2728         8362,  8453,  8465,  8468,  8473,  8485,  8514,  8516,  8521,  8533,  8536,  8538,  8545,  8548,  8549,  8550,
2729         8581,  8592,  8598,  8601,  8613,  8705,  8712,  8714,  8721,  8725,  8736,  8738,  8744,  8746,  8773,  8785,
2730         8790,  8793,  8805,  8833,  8840,  8842,  8849,  8853,  8864,  8866,  8872,  8874,  9221,  9236,  9238,  9241,
2731         9253,  9284,  9285,  9286,  9289,  9298,  9301,  9304,  9306,  9318,  9349,  9361,  9364,  9369,  9377,  9381,
2732         9481,  9493,  9505,  9513,  9536,  9541,  9544,  9553,  9556,  9557,  9561,  9570,  9573,  9576,  9609,  9616,
2733         9620,  9621,  9624,  9626,  9633,  9636,  9638,  9641,  9733,  9744,  9746,  9753,  9765,  9793,  9801,  9813,
2734         9824,  9825,  9833,  9860,  9862,  9872,  9882, 10240, 10242, 10248, 10250, 10261, 10272, 10274, 10280, 10282,
2735        10309, 10321, 10324, 10341, 10368, 10370, 10376, 10378, 10400, 10402, 10408, 10410, 10505, 10513, 10516, 10521,
2736        10533, 10566, 10569, 10578, 10581, 10593, 10596, 10598, 10601, 10629, 10640, 10646, 10649, 10660, 10661, 10752,
2737        10754, 10760, 10762, 10784, 10786, 10792, 10794, 10821, 10833, 10838, 10841, 10853, 10880, 10882, 10888, 10890,
2738        10901, 10912, 10914, 10920, 10922, 16389, 16401, 16406, 16421, 16457, 16466, 16469, 16472, 16474, 16481, 16484,
2739        16486, 16532, 16537, 16545, 16550, 16640, 16641, 16644, 16646, 16649, 16658, 16661, 16662, 16664, 16666, 16673,
2740        16678, 16681, 16709, 16712, 16714, 16721, 16724, 16725, 16726, 16729, 16730, 16741, 16744, 16746, 16769, 16772,
2741        16774, 16784, 16786, 16789, 16800, 16801, 16802, 16901, 16913, 16916, 16918, 16933, 16961, 16978, 16981, 16986,
2742        16996, 17001, 17033, 17044, 17061, 17409, 17429, 17433, 17449, 17477, 17480, 17482, 17489, 17492, 17493, 17494,
2743        17505, 17506, 17509, 17512, 17514, 17537, 17542, 17545, 17552, 17554, 17557, 17568, 17569, 17577, 17665, 17666,
2744        17669, 17674, 17681, 17684, 17685, 17686, 17689, 17696, 17701, 17706, 17729, 17732, 17733, 17734, 17737, 17744,
2745        17745, 17748, 17749, 17750, 17752, 17753, 17761, 17764, 17765, 17766, 17769, 17794, 17796, 17797, 17800, 17809,
2746        17812, 17813, 17814, 17817, 17818, 17829, 17832, 17834, 17921, 17925, 17929, 17940, 17941, 17944, 17946, 17953,
2747        17956, 17961, 17984, 17986, 17989, 17992, 18000, 18001, 18002, 18005, 18006, 18009, 18018, 18021, 18024, 18049,
2748        18053, 18058, 18068, 18069, 18081, 18084, 18086, 18437, 18449, 18453, 18458, 18469, 18498, 18505, 18512, 18517,
2749        18520, 18529, 18532, 18534, 18537, 18565, 18577, 18580, 18582, 18585, 18597, 18689, 18693, 18694, 18698, 18704,
2750        18708, 18709, 18712, 18721, 18724, 18726, 18752, 18757, 18762, 18769, 18770, 18772, 18773, 18774, 18777, 18784,
2751        18786, 18789, 18790, 18794, 18822, 18825, 18834, 18837, 18838, 18840, 18849, 18852, 18854, 18857, 18966, 19012,
2752        19014, 19017, 19029, 19032, 19034, 19044, 19049, 19092, 19109, 20481, 20484, 20485, 20486, 20489, 20498, 20501,
2753        20506, 20513, 20516, 20521, 20544, 20549, 20552, 20561, 20564, 20565, 20566, 20569, 20581, 20584, 20614, 20617,
2754        20629, 20632, 20640, 20641, 20646, 20649, 20741, 20744, 20745, 20746, 20753, 20756, 20757, 20758, 20760, 20761,
2755        20768, 20773, 20774, 20776, 20778, 20801, 20804, 20805, 20806, 20809, 20816, 20817, 20818, 20820, 20821, 20822,
2756        20824, 20825, 20826, 20833, 20836, 20837, 20838, 20841, 20866, 20869, 20881, 20884, 20885, 20886, 20889, 20896,
2757        20901, 20906, 20993, 20998, 21010, 21013, 21018, 21025, 21028, 21058, 21061, 21066, 21073, 21076, 21077, 21078,
2758        21081, 21090, 21093, 21125, 21136, 21138, 21141, 21145, 21146, 21156, 21508, 21509, 21521, 21524, 21525, 21526,
2759        21528, 21529, 21537, 21541, 21544, 21546, 21569, 21572, 21573, 21574, 21577, 21578, 21584, 21585, 21588, 21589,
2760        21590, 21592, 21593, 21594, 21601, 21602, 21604, 21605, 21606, 21609, 21632, 21640, 21642, 21649, 21652, 21653,
2761        21654, 21657, 21665, 21668, 21669, 21674, 21761, 21762, 21764, 21765, 21766, 21769, 21776, 21777, 21778, 21780,
2762        21781, 21782, 21785, 21786, 21793, 21796, 21797, 21798, 21801, 21824, 21825, 21826, 21828, 21829, 21830, 21832,
2763        21833, 21840, 21841, 21842, 21844, 21845, 21846, 21848, 21849, 21850, 21856, 21857, 21860, 21861, 21862, 21864,
2764        21865, 21866, 21889, 21892, 21893, 21897, 21898, 21904, 21905, 21908, 21909, 21910, 21912, 21913, 21921, 21924,
2765        21925, 21926, 21929, 22016, 22017, 22018, 22020, 22022, 22024, 22025, 22033, 22036, 22037, 22040, 22041, 22048,
2766        22049, 22050, 22052, 22053, 22054, 22056, 22057, 22081, 22085, 22086, 22088, 22089, 22090, 22096, 22097, 22098,
2767        22100, 22101, 22102, 22104, 22105, 22106, 22113, 22116, 22117, 22121, 22146, 22149, 22150, 22152, 22153, 22154,
2768        22161, 22165, 22170, 22178, 22181, 22182, 22184, 22185, 22532, 22533, 22534, 22537, 22544, 22549, 22552, 22561,
2769        22570, 22597, 22600, 22602, 22609, 22612, 22613, 22614, 22616, 22617, 22624, 22626, 22628, 22629, 22658, 22665,
2770        22672, 22674, 22677, 22680, 22689, 22697, 22785, 22786, 22789, 22794, 22801, 22804, 22805, 22806, 22809, 22821,
2771        22849, 22852, 22853, 22854, 22857, 22864, 22865, 22866, 22868, 22869, 22870, 22872, 22873, 22874, 22881, 22884,
2772        22885, 22886, 22889, 22913, 22917, 22921, 22929, 22932, 22933, 22934, 22936, 22937, 22949, 23044, 23048, 23061,
2773        23066, 23072, 23077, 23078, 23081, 23109, 23112, 23113, 23121, 23125, 23126, 23128, 23129, 23138, 23141, 23144,
2774        23146, 23169, 23178, 23186, 23189, 23190, 23192, 23194, 23201, 24581, 24596, 24598, 24601, 24613, 24644, 24656,
2775        24661, 24662, 24664, 24666, 24673, 24676, 24678, 24681, 24705, 24726, 24741, 24833, 24836, 24838, 24841, 24850,
2776        24853, 24865, 24866, 24870, 24873, 24901, 24905, 24913, 24917, 24918, 24921, 24933, 24934, 24938, 24964, 24970,
2777        24978, 24981, 24993, 24998, 25001, 25105, 25110, 25113, 25152, 25153, 25158, 25173, 25174, 25176, 25184, 25221,
2778        25233, 25238, 25253, 25617, 25618, 25621, 25622, 25626, 25633, 25638, 25641, 25664, 25666, 25669, 25672, 25674,
2779        25681, 25684, 25685, 25686, 25689, 25690, 25696, 25698, 25701, 25732, 25733, 25737, 25744, 25746, 25748, 25749,
2780        25750, 25752, 25754, 25761, 25764, 25769, 25861, 25864, 25866, 25873, 25877, 25878, 25881, 25924, 25925, 25926,
2781        25929, 25936, 25937, 25940, 25941, 25942, 25945, 25953, 25956, 25957, 25958, 25961, 25990, 25993, 25994, 26001,
2782        26005, 26006, 26009, 26010, 26018, 26021, 26022, 26024, 26114, 26121, 26133, 26144, 26150, 26152, 26153, 26176,
2783        26181, 26184, 26186, 26193, 26196, 26197, 26198, 26200, 26202, 26208, 26213, 26216, 26240, 26242, 26245, 26250,
2784        26260, 26262, 26264, 26265, 26272, 26276, 26278, 26282, 26646, 26649, 26661, 26689, 26706, 26709, 26714, 26721,
2785        26729, 26757, 26769, 26776, 26790, 26881, 26884, 26896, 26901, 26913, 26916, 26918, 26921, 26944, 26945, 26949,
2786        26950, 26952, 26961, 26964, 26965, 26966, 26969, 26976, 26981, 26986, 27010, 27012, 27018, 27029, 27041, 27044,
2787        27045, 27049, 27153, 27158, 27160, 27201, 27204, 27209, 27216, 27221, 27224, 27226, 27236, 27237, 27241, 27270,
2788        27284, 27288, 27290, 27302, 32768, 32770, 32776, 32778, 32800, 32802, 32808, 32810, 32837, 32848, 32849, 32852,
2789        32854, 32857, 32869, 32896, 32898, 32904, 32906, 32917, 32928, 32930, 32936, 32938, 33029, 33041, 33044, 33046,
2790        33049, 33061, 33089, 33092, 33097, 33104, 33106, 33109, 33110, 33112, 33113, 33124, 33126, 33129, 33157, 33161,
2791        33172, 33174, 33177, 33189, 33280, 33282, 33288, 33290, 33301, 33312, 33314, 33320, 33322, 33361, 33364, 33369,
2792        33381, 33408, 33410, 33416, 33418, 33429, 33440, 33442, 33448, 33450, 33812, 33817, 33857, 33860, 33873, 33877,
2793        33882, 33889, 33892, 33897, 33940, 33945, 34049, 34057, 34066, 34069, 34074, 34086, 34089, 34112, 34113, 34117,
2794        34120, 34129, 34132, 34133, 34134, 34137, 34138, 34149, 34150, 34152, 34154, 34177, 34180, 34182, 34185, 34192,
2795        34194, 34197, 34200, 34214, 34321, 34326, 34329, 34341, 34369, 34372, 34377, 34378, 34384, 34389, 34393, 34394,
2796        34401, 34406, 34410, 34437, 34449, 34458, 34468, 34816, 34818, 34824, 34826, 34837, 34848, 34850, 34856, 34858,
2797        34881, 34885, 34897, 34900, 34905, 34917, 34921, 34944, 34946, 34952, 34954, 34965, 34976, 34978, 34984, 34986,
2798        35077, 35078, 35089, 35092, 35094, 35109, 35137, 35140, 35142, 35145, 35152, 35154, 35157, 35162, 35169, 35172,
2799        35205, 35222, 35225, 35237, 35328, 35330, 35336, 35338, 35349, 35360, 35362, 35368, 35370, 35397, 35409, 35412,
2800        35414, 35456, 35458, 35464, 35466, 35477, 35488, 35490, 35496, 35498, 36869, 36881, 36886, 36888, 36889, 36901,
2801        36929, 36934, 36937, 36949, 36952, 36954, 36969, 36970, 36997, 37009, 37012, 37014, 37017, 37029, 37121, 37124,
2802        37126, 37129, 37136, 37141, 37144, 37146, 37153, 37156, 37158, 37161, 37184, 37189, 37200, 37201, 37204, 37205,
2803        37206, 37209, 37218, 37221, 37252, 37254, 37266, 37269, 37272, 37281, 37284, 37286, 37289, 37381, 37393, 37396,
2804        37401, 37413, 37444, 37446, 37449, 37456, 37458, 37461, 37464, 37478, 37481, 37509, 37524, 37526, 37545, 37889,
2805        37892, 37894, 37904, 37909, 37912, 37926, 37952, 37962, 37969, 37972, 37973, 37974, 37976, 37977, 37984, 37985,
2806        37986, 37989, 38020, 38022, 38034, 38036, 38037, 38040, 38049, 38057, 38144, 38149, 38152, 38154, 38160, 38161,
2807        38164, 38165, 38166, 38169, 38177, 38181, 38185, 38186, 38209, 38212, 38213, 38214, 38217, 38224, 38225, 38226,
2808        38228, 38229, 38230, 38232, 38233, 38234, 38241, 38244, 38245, 38246, 38249, 38273, 38277, 38280, 38289, 38290,
2809        38292, 38293, 38294, 38297, 38298, 38304, 38306, 38309, 38312, 38314, 38401, 38404, 38416, 38421, 38425, 38432,
2810        38438, 38441, 38469, 38472, 38473, 38481, 38482, 38485, 38486, 38489, 38501, 38504, 38530, 38532, 38537, 38538,
2811        38546, 38548, 38549, 38564, 38566, 38569, 38917, 38934, 38937, 38949, 38977, 38982, 38992, 38994, 38997, 38998,
2812        39002, 39012, 39013, 39045, 39057, 39062, 39065, 39077, 39172, 39174, 39177, 39184, 39186, 39189, 39192, 39194,
2813        39200, 39201, 39204, 39206, 39232, 39234, 39237, 39240, 39242, 39249, 39252, 39253, 39254, 39257, 39266, 39269,
2814        39270, 39274, 39297, 39300, 39312, 39314, 39317, 39322, 39329, 39334, 39429, 39445, 39461, 39492, 39494, 39497,
2815        39504, 39509, 39512, 39521, 39557, 39569, 39572, 39573, 39574, 40960, 40962, 40968, 40970, 40981, 40992, 40994,
2816        41000, 41002, 41029, 41041, 41044, 41046, 41049, 41088, 41090, 41096, 41098, 41109, 41120, 41122, 41128, 41130,
2817        41221, 41225, 41233, 41236, 41238, 41241, 41242, 41286, 41289, 41297, 41301, 41304, 41306, 41313, 41316, 41349,
2818        41360, 41362, 41366, 41369, 41474, 41480, 41482, 41488, 41497, 41506, 41512, 41514, 41541, 41553, 41558, 41561,
2819        41573, 41600, 41602, 41608, 41610, 41621, 41632, 41634, 41640, 41642, 42009, 42021, 42049, 42052, 42064, 42068,
2820        42069, 42072, 42074, 42081, 42085, 42086, 42088, 42089, 42117, 42246, 42249, 42256, 42258, 42261, 42264, 42278,
2821        42281, 42306, 42309, 42321, 42324, 42325, 42326, 42329, 42341, 42346, 42369, 42372, 42373, 42374, 42377, 42386,
2822        42389, 42392, 42501, 42513, 42518, 42522, 42529, 42533, 42564, 42566, 42570, 42578, 42581, 42582, 42584, 42592,
2823        42594, 42630, 42640, 42645, 42646, 42649, 42657, 42660, 42662, 43008, 43010, 43016, 43018, 43040, 43042, 43048,
2824        43050, 43089, 43092, 43094, 43097, 43136, 43138, 43144, 43146, 43157, 43168, 43170, 43176, 43178, 43269, 43284,
2825        43289, 43297, 43301, 43329, 43344, 43349, 43354, 43361, 43366, 43369, 43408, 43414, 43520, 43522, 43528, 43530,
2826        43552, 43554, 43560, 43562, 43601, 43604, 43606, 43648, 43650, 43656, 43658, 43669, 43680, 43682, 43688, 43690,
2827    };
2828    static const uint16_t kgrid_2bit_1024[1024] = {
2829            0,     2,     5,     8,    10,    17,    20,    22,    25,    32,    34,    37,    40,    65,    68,    70,
2830           73,    80,    82,    85,    88,    97,   100,   102,   105,   128,   130,   133,   136,   145,   148,   160,
2831          165,   170,   257,   260,   262,   265,   272,   274,   277,   280,   289,   292,   320,   322,   325,   328,
2832          337,   340,   342,   345,   352,   357,   360,   385,   388,   400,   402,   405,   417,   420,   512,   514,
2833          517,   520,   529,   532,   544,   554,   577,   580,   582,   585,   592,   597,   640,   645,   650,   660,
2834          674,  1025,  1028,  1030,  1033,  1040,  1042,  1045,  1048,  1057,  1060,  1062,  1065,  1088,  1090,  1093,
2835         1096,  1098,  1105,  1108,  1110,  1113,  1120,  1122,  1125,  1153,  1156,  1158,  1161,  1168,  1173,  1176,
2836         1185,  1188,  1280,  1282,  1285,  1288,  1290,  1297,  1300,  1302,  1305,  1312,  1317,  1320,  1345,  1348,
2837         1350,  1353,  1360,  1362,  1365,  1368,  1377,  1380,  1408,  1410,  1413,  1416,  1425,  1428,  1440,  1537,
2838         1540,  1542,  1545,  1552,  1557,  1600,  1605,  1608,  1617,  1620,  1632,  1665,  1668,  1680,  2048,  2050,
2839         2053,  2056,  2065,  2068,  2070,  2073,  2080,  2085,  2090,  2113,  2116,  2118,  2121,  2128,  2130,  2133,
2840         2136,  2145,  2148,  2176,  2181,  2196,  2218,  2305,  2308,  2320,  2322,  2325,  2328,  2337,  2368,  2373,
2841         2376,  2385,  2388,  2400,  2433,  2448,  2560,  2577,  2580,  2594,  2600,  2602,  2640,  2713,  4097,  4100,
2842         4102,  4105,  4112,  4114,  4117,  4120,  4129,  4132,  4134,  4160,  4162,  4165,  4168,  4177,  4180,  4182,
2843         4185,  4192,  4194,  4197,  4200,  4225,  4228,  4230,  4240,  4245,  4248,  4257,  4260,  4352,  4354,  4357,
2844         4360,  4362,  4369,  4372,  4374,  4377,  4384,  4386,  4389,  4392,  4417,  4420,  4422,  4425,  4432,  4434,
2845         4437,  4440,  4449,  4452,  4480,  4482,  4485,  4488,  4497,  4500,  4609,  4612,  4617,  4624,  4629,  4641,
2846         4644,  4672,  4677,  4689,  4692,  4737,  4740,  4752,  5120,  5122,  5125,  5128,  5137,  5140,  5142,  5145,
2847         5152,  5157,  5160,  5185,  5188,  5190,  5193,  5200,  5202,  5205,  5208,  5217,  5220,  5248,  5250,  5253,
2848         5256,  5265,  5268,  5280,  5377,  5380,  5382,  5385,  5392,  5394,  5397,  5400,  5409,  5412,  5440,  5442,
2849         5445,  5448,  5457,  5460,  5472,  5505,  5508,  5520,  5632,  5637,  5640,  5649,  5652,  5664,  5697,  5700,
2850         5712,  5760,  5802,  6145,  6148,  6150,  6153,  6160,  6165,  6168,  6177,  6208,  6210,  6213,  6216,  6225,
2851         6228,  6240,  6273,  6276,  6400,  6402,  6405,  6408,  6417,  6420,  6432,  6465,  6468,  6480,  6505,  6562,
2852         6660,  6672,  6720,  6742,  8192,  8194,  8197,  8200,  8209,  8212,  8214,  8217,  8224,  8229,  8234,  8257,
2853         8260,  8272,  8274,  8277,  8292,  8320,  8330,  8340,  8362,  8449,  8452,  8464,  8466,  8469,  8481,  8512,
2854         8514,  8517,  8529,  8532,  8544,  8577,  8580,  8592,  8704,  8714,  8738,  8744,  8746,  8772,  8784,  8840,
2855         8842,  8872,  9217,  9220,  9222,  9225,  9232,  9237,  9240,  9249,  9252,  9280,  9282,  9285,  9288,  9297,
2856         9300,  9312,  9345,  9348,  9360,  9472,  9477,  9480,  9489,  9492,  9504,  9537,  9540,  9552,  9574,  9600,
2857         9729,  9732,  9744,  9792,  9817, 10240, 10245, 10257, 10260, 10305, 10308, 10320, 10378, 10410, 10497, 10500,
2858        10512, 10645, 10762, 10786, 10852, 10888, 10890, 16385, 16388, 16390, 16393, 16400, 16402, 16405, 16408, 16410,
2859        16417, 16420, 16422, 16448, 16450, 16453, 16456, 16458, 16465, 16468, 16470, 16473, 16480, 16482, 16485, 16513,
2860        16516, 16528, 16533, 16536, 16545, 16548, 16640, 16642, 16645, 16648, 16657, 16660, 16662, 16665, 16672, 16674,
2861        16677, 16705, 16708, 16710, 16713, 16720, 16722, 16725, 16728, 16737, 16740, 16768, 16770, 16773, 16776, 16785,
2862        16788, 16800, 16897, 16900, 16912, 16914, 16917, 16920, 16932, 16960, 16965, 16968, 16977, 16980, 16992, 17025,
2863        17028, 17408, 17410, 17413, 17416, 17418, 17425, 17428, 17430, 17433, 17440, 17442, 17445, 17448, 17473, 17476,
2864        17478, 17481, 17488, 17490, 17493, 17496, 17505, 17508, 17536, 17538, 17541, 17544, 17553, 17556, 17568, 17665,
2865        17668, 17670, 17673, 17680, 17682, 17685, 17688, 17697, 17700, 17728, 17730, 17733, 17736, 17745, 17748, 17760,
2866        17770, 17793, 17796, 17808, 17920, 17922, 17925, 17928, 17937, 17940, 17952, 17985, 17988, 18000, 18048, 18085,
2867        18433, 18436, 18441, 18448, 18450, 18453, 18456, 18465, 18468, 18496, 18498, 18501, 18504, 18513, 18516, 18528,
2868        18564, 18576, 18688, 18690, 18693, 18696, 18705, 18708, 18720, 18753, 18756, 18768, 18816, 18838, 18945, 18948,
2869        18960, 19008, 20480, 20482, 20485, 20488, 20497, 20500, 20502, 20505, 20512, 20514, 20517, 20520, 20545, 20548,
2870        20550, 20553, 20560, 20562, 20565, 20568, 20577, 20580, 20608, 20610, 20613, 20616, 20625, 20628, 20737, 20740,
2871        20742, 20745, 20752, 20754, 20757, 20760, 20769, 20772, 20800, 20802, 20805, 20808, 20817, 20820, 20832, 20865,
2872        20868, 20880, 20992, 20997, 21000, 21009, 21012, 21024, 21057, 21060, 21072, 21097, 21120, 21505, 21508, 21510,
2873        21513, 21520, 21522, 21525, 21528, 21537, 21540, 21568, 21570, 21573, 21576, 21585, 21588, 21600, 21633, 21636,
2874        21648, 21760, 21762, 21765, 21768, 21777, 21780, 21792, 21825, 21828, 21840, 21888, 22017, 22020, 22032, 22054,
2875        22080, 22528, 22530, 22533, 22536, 22545, 22548, 22560, 22593, 22596, 22608, 22618, 22656, 22785, 22788, 22800,
2876        22848, 23040, 23065, 23173, 23208, 24577, 24580, 24582, 24592, 24594, 24597, 24600, 24609, 24612, 24640, 24645,
2877        24648, 24657, 24660, 24672, 24708, 24720, 24832, 24834, 24837, 24840, 24849, 24852, 24864, 24897, 24900, 24912,
2878        24960, 24985, 25092, 25104, 25152, 25174, 25249, 25600, 25605, 25608, 25617, 25620, 25632, 25665, 25668, 25680,
2879        25728, 25857, 25860, 25872, 25920, 25930, 25960, 26002, 26112, 26260, 26625, 26628, 26640, 26725, 26776, 26880,
2880        26922, 27202, 27297, 32768, 32770, 32773, 32776, 32785, 32788, 32793, 32800, 32805, 32833, 32836, 32848, 32850,
2881        32853, 32856, 32865, 32896, 32901, 32913, 32916, 33025, 33028, 33033, 33040, 33042, 33045, 33048, 33057, 33060,
2882        33088, 33090, 33093, 33096, 33105, 33108, 33153, 33156, 33168, 33193, 33280, 33285, 33290, 33297, 33300, 33345,
2883        33348, 33360, 33793, 33796, 33798, 33801, 33808, 33810, 33813, 33816, 33825, 33856, 33858, 33861, 33864, 33873,
2884        33876, 33888, 33921, 33924, 33936, 34048, 34050, 34053, 34056, 34065, 34068, 34080, 34113, 34116, 34128, 34176,
2885        34186, 34305, 34308, 34320, 34345, 34368, 34816, 34821, 34833, 34836, 34881, 34884, 34896, 34978, 35073, 35076,
2886        35136, 35173, 35362, 35416, 35418, 35458, 35490, 36865, 36868, 36873, 36880, 36882, 36885, 36888, 36900, 36928,
2887        36930, 36933, 36936, 36945, 36948, 36960, 36993, 36996, 37008, 37120, 37125, 37137, 37140, 37185, 37188, 37200,
2888        37210, 37377, 37380, 37392, 37440, 37542, 37888, 37890, 37893, 37896, 37905, 37908, 37920, 37953, 37956, 37968,
2889        38016, 38038, 38145, 38148, 38160, 38208, 38296, 38305, 38400, 38470, 38500, 38913, 38916, 38928, 38950, 38976,
2890        39081, 39168, 39241, 39250, 39568, 40960, 40965, 40970, 40980, 40994, 41002, 41025, 41028, 41040, 41122, 41130,
2891        41280, 41317, 41474, 41482, 41506, 41512, 41514, 41602, 41608, 41610, 41640, 41985, 41988, 42000, 42048, 42121,
2892        42148, 42240, 42265, 42577, 43018, 43048, 43170, 43348, 43398, 43528, 43530, 43552, 43554, 43560, 43656, 43690,
2893    };
2894
2895    const int kmap_size = 43692;
2896    //const int nwant = type == GGML_TYPE_IQ1_S ? 3 : 2;
2897    const int nwant = type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? 3 : type == GGML_TYPE_IQ2_S ? 1 : 2;
2898    const uint16_t * kgrid = type == GGML_TYPE_IQ2_XXS ? kgrid_2bit_256 :
2899                             type == GGML_TYPE_IQ2_XS  ? kgrid_2bit_512 :
2900                             type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? kgrid_1bit_2048 : kgrid_2bit_1024;
2901    uint64_t * kgrid_q2xs;
2902    int      * kmap_q2xs;
2903    uint16_t * kneighbors_q2xs;
2904
2905    //printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size);
2906    uint64_t * the_grid = (uint64_t *)malloc(grid_size*sizeof(uint64_t));
2907    for (int k = 0; k < grid_size; ++k) {
2908        int8_t * pos = (int8_t *)(the_grid + k);
2909        for (int i = 0; i < 8; ++i) {
2910            int l = (kgrid[k] >> 2*i) & 0x3;
2911            pos[i] = 2*l + 1;
2912        }
2913    }
2914    kgrid_q2xs = the_grid;
2915    iq2_data[gindex].grid = the_grid;
2916    kmap_q2xs = (int *)malloc(kmap_size*sizeof(int));
2917    iq2_data[gindex].map = kmap_q2xs;
2918    for (int i = 0; i < kmap_size; ++i) kmap_q2xs[i] = -1;
2919    uint64_t aux64;
2920    uint8_t * aux8 = (uint8_t *)&aux64;
2921    for (int i = 0; i < grid_size; ++i) {
2922        aux64 = kgrid_q2xs[i];
2923        uint16_t index = 0;
2924        for (int k=0; k<8; ++k) {
2925            uint16_t q = (aux8[k] - 1)/2;
2926            index |= (q << 2*k);
2927        }
2928        kmap_q2xs[index] = i;
2929    }
2930    int8_t pos[8];
2931    int * dist2 = (int *)malloc(2*grid_size*sizeof(int));
2932    int num_neighbors = 0, num_not_in_map = 0;
2933    for (int i = 0; i < kmap_size; ++i) {
2934        if (kmap_q2xs[i] >= 0) continue;
2935        ++num_not_in_map;
2936        for (int k = 0; k < 8; ++k) {
2937            int l = (i >> 2*k) & 0x3;
2938            pos[k] = 2*l + 1;
2939        }
2940        for (int j = 0; j < grid_size; ++j) {
2941            const int8_t * pg = (const int8_t *)(kgrid_q2xs + j);
2942            int d2 = 0;
2943            for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
2944            dist2[2*j+0] = d2;
2945            dist2[2*j+1] = j;
2946        }
2947        qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func);
2948        int n = 0; int d2 = dist2[0];
2949        int nhave = 1;
2950        for (int j = 0; j < grid_size; ++j) {
2951            if (dist2[2*j] > d2) {
2952                if (nhave == nwant) break;
2953                d2 = dist2[2*j];
2954                ++nhave;
2955            }
2956            ++n;
2957        }
2958        num_neighbors += n;
2959    }
2960    //printf("%s: %d neighbours in total\n", __func__, num_neighbors);
2961    kneighbors_q2xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t));
2962    iq2_data[gindex].neighbours = kneighbors_q2xs;
2963    int counter = 0;
2964    for (int i = 0; i < kmap_size; ++i) {
2965        if (kmap_q2xs[i] >= 0) continue;
2966        for (int k = 0; k < 8; ++k) {
2967            int l = (i >> 2*k) & 0x3;
2968            pos[k] = 2*l + 1;
2969        }
2970        for (int j = 0; j < grid_size; ++j) {
2971            const int8_t * pg = (const int8_t *)(kgrid_q2xs + j);
2972            int d2 = 0;
2973            for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
2974            dist2[2*j+0] = d2;
2975            dist2[2*j+1] = j;
2976        }
2977        qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func);
2978        kmap_q2xs[i] = -(counter + 1);
2979        int d2 = dist2[0];
2980        uint16_t * start = &kneighbors_q2xs[counter++];
2981        int n = 0, nhave = 1;
2982        for (int j = 0; j < grid_size; ++j) {
2983            if (dist2[2*j] > d2) {
2984                if (nhave == nwant) break;
2985                d2 = dist2[2*j];
2986                ++nhave;
2987            }
2988            kneighbors_q2xs[counter++] = dist2[2*j+1];
2989            ++n;
2990        }
2991        *start = n;
2992    }
2993    free(dist2);
2994}
2995
2996void iq2xs_free_impl(enum ggml_type type) {
2997    GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S);
2998    const int gindex = iq2_data_index(type);
2999    if (iq2_data[gindex].grid) {
3000        free(iq2_data[gindex].grid);       iq2_data[gindex].grid = NULL;
3001        free(iq2_data[gindex].map);        iq2_data[gindex].map  = NULL;
3002        free(iq2_data[gindex].neighbours); iq2_data[gindex].neighbours = NULL;
3003    }
3004}
3005
3006static int iq2_find_best_neighbour(const uint16_t * GGML_RESTRICT neighbours, const uint64_t * GGML_RESTRICT grid,
3007        const float * GGML_RESTRICT xval, const float * GGML_RESTRICT weight, float scale, int8_t * GGML_RESTRICT L) {
3008    int num_neighbors = neighbours[0];
3009    GGML_ASSERT(num_neighbors > 0);
3010    float best_d2 = FLT_MAX;
3011    int grid_index = -1;
3012    for (int j = 1; j <= num_neighbors; ++j) {
3013        const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
3014        float d2 = 0;
3015        for (int i = 0; i < 8; ++i) {
3016            float q = pg[i];
3017            float diff = scale*q - xval[i];
3018            d2 += weight[i]*diff*diff;
3019        }
3020        if (d2 < best_d2) {
3021            best_d2 = d2; grid_index = neighbours[j];
3022        }
3023    }
3024    GGML_ASSERT(grid_index >= 0);
3025    const int8_t * pg = (const int8_t *)(grid + grid_index);
3026    for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2;
3027    return grid_index;
3028}
3029
3030static void quantize_row_iq2_xxs_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t n, const float * GGML_RESTRICT quant_weights) {
3031
3032    const int gindex = iq2_data_index(GGML_TYPE_IQ2_XXS);
3033
3034    const uint64_t * kgrid_q2xs      = iq2_data[gindex].grid;
3035    const int      * kmap_q2xs       = iq2_data[gindex].map;
3036    const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
3037
3038    GGML_ASSERT(quant_weights   && "missing quantization weights");
3039    GGML_ASSERT(kgrid_q2xs      && "forgot to call ggml_quantize_init()?");
3040    GGML_ASSERT(kmap_q2xs       && "forgot to call ggml_quantize_init()?");
3041    GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
3042    GGML_ASSERT(n%QK_K == 0);
3043
3044    const int kMaxQ = 3;
3045
3046    const int64_t nbl = n/QK_K;
3047
3048    block_iq2_xxs * y = vy;
3049
3050    float scales[QK_K/32];
3051    float weight[32];
3052    float xval[32];
3053    int8_t L[32];
3054    int8_t Laux[32];
3055    float  waux[32];
3056    uint8_t block_signs[4];
3057    uint32_t q2[2*(QK_K/32)];
3058
3059    for (int ibl = 0; ibl < nbl; ++ibl) {
3060
3061        y[ibl].d = GGML_FP32_TO_FP16(0.f);
3062        memset(q2, 0, QK_K/4);
3063
3064        float max_scale = 0;
3065
3066        const float * xbl = x + QK_K*ibl;
3067        float sumx2 = 0;
3068        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
3069        float sigma2 = sumx2/QK_K;
3070
3071        for (int ib = 0; ib < QK_K/32; ++ib) {
3072            const float * xb = xbl + 32*ib;
3073            const float * qw = quant_weights + QK_K*ibl + 32*ib;
3074            for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
3075            for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]);
3076            for (int k = 0; k < 4; ++k) {
3077                int nflip = 0;
3078                uint8_t s = 0;
3079                for (int i = 0; i < 8; ++i) {
3080                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
3081                    else {
3082                        xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
3083                    }
3084                }
3085                if (nflip%2) {
3086                    int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
3087                    for (int i = 1; i < 8; ++i) {
3088                        float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
3089                        if (ax < min) {
3090                            min = ax; imin = i;
3091                        }
3092                    }
3093                    xval[8*k+imin] = -xval[8*k+imin];
3094                    s ^= (1 << imin);
3095                }
3096                block_signs[k] = s & 127;
3097            }
3098            float max = xval[0];
3099            for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);
3100            if (max < GROUP_MAX_EPS) {
3101                scales[ib] = 0;
3102                memset(L, 0, 32);
3103                continue;
3104            }
3105            float scale = make_qp_quants(32, kMaxQ+1, xval, (uint8_t*)L, weight);
3106            float eff_max = scale*kMaxQ;
3107            float best = 0;
3108            for (int is = -6; is <= 6; ++is) {
3109                float id = (2*kMaxQ-1+is*0.1f)/eff_max;
3110                float this_scale = 1/id;
3111                for (int k = 0; k < 4; ++k) {
3112                    for (int i = 0; i < 8; ++i) {
3113                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
3114                        Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l));
3115                    }
3116                    uint16_t u = 0;
3117                    for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
3118                    int grid_index = kmap_q2xs[u];
3119                    if (grid_index < 0) {
3120                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
3121                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
3122                    }
3123                }
3124                float sumqx = 0, sumq2 = 0;
3125                for (int i = 0; i < 32; ++i) {
3126                    float w = weight[i];
3127                    float q = 2*Laux[i] + 1;
3128                    sumqx += w*xval[i]*q;
3129                    sumq2 += w*q*q;
3130                }
3131                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
3132                    scale = sumqx/sumq2; best = scale*sumqx;
3133                    memcpy(L, Laux, 32);
3134                }
3135            }
3136            if (scale > 0) {
3137                float id = 1/scale;
3138                for (int k = 0; k < 4; ++k) {
3139                    uint16_t u = 0;
3140                    for (int i = 0; i < 8; ++i) {
3141                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
3142                        l = MAX(0, MIN(kMaxQ-1, l));
3143                        u |= (l << 2*i);
3144                    }
3145                    int grid_index = kmap_q2xs[u];
3146                    if (grid_index < 0) {
3147                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
3148                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k);
3149                    }
3150                    const int8_t * pg = (const int8_t *)(kgrid_q2xs + grid_index);
3151                    for (int i = 0; i < 8; ++i) L[8*k+i] = (pg[i] - 1)/2;
3152                }
3153                float sumqx = 0, sumq2 = 0;
3154                for (int i = 0; i < 32; ++i) {
3155                    float w = weight[i];
3156                    float q = 2*L[i] + 1;
3157                    sumqx += w*xval[i]*q;
3158                    sumq2 += w*q*q;
3159                }
3160                if (sumq2 > 0) scale = sumqx/sumq2;
3161            }
3162            if (scale < 0) {
3163                // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)
3164                // and correspondingly flip quant signs.
3165                scale = -scale;
3166                for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127;
3167            }
3168            for (int k = 0; k < 4; ++k) {
3169                uint16_t u = 0;
3170                for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i);
3171                int grid_index = kmap_q2xs[u];
3172                if (grid_index < 0) {
3173                    printf("Oops: found point %u not on grid:", u);
3174                    for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]);
3175                    printf("\n");
3176                    GGML_ABORT("fatal error");
3177                }
3178                q2[2*ib+0] |= ((uint32_t) grid_index << 8*k);
3179                q2[2*ib+1] |= (block_signs[k] << 7*k);
3180            }
3181            GGML_ASSERT(scale >= 0);
3182            scales[ib] = scale;
3183            max_scale = MAX(max_scale, scale);
3184        }
3185
3186        if (!max_scale) {
3187            memset(y[ibl].qs, 0, QK_K/4);
3188            continue;
3189        }
3190
3191        float d = max_scale/31;
3192        y[ibl].d = GGML_FP32_TO_FP16(d);
3193        float id = 1/d;
3194        for (int ib = 0; ib < QK_K/32; ++ib) {
3195            int l = nearest_int(0.5f*(id*scales[ib]-1));
3196            l = MAX(0, MIN(15, l));
3197            q2[2*ib+1] |= ((uint32_t)l << 28);
3198        }
3199        memcpy(y[ibl].qs, q2, QK_K/4);
3200    }
3201}
3202
3203static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t n, const float * GGML_RESTRICT quant_weights) {
3204
3205    const int gindex = iq2_data_index(GGML_TYPE_IQ2_XS);
3206
3207    const uint64_t * kgrid_q2xs      = iq2_data[gindex].grid;
3208    const int      * kmap_q2xs       = iq2_data[gindex].map;
3209    const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
3210
3211    GGML_ASSERT(quant_weights   && "missing quantization weights");
3212    GGML_ASSERT(kmap_q2xs       && "forgot to call ggml_quantize_init()?");
3213    GGML_ASSERT(kgrid_q2xs      && "forgot to call ggml_quantize_init()?");
3214    GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
3215    GGML_ASSERT(n%QK_K == 0);
3216
3217    const int kMaxQ = 3;
3218
3219    const int64_t nbl = n/QK_K;
3220
3221    block_iq2_xs * y = vy;
3222
3223    float scales[QK_K/16];
3224    float weight[16];
3225    float xval[16];
3226    int8_t L[16];
3227    int8_t Laux[16];
3228    float  waux[16];
3229    bool   is_on_grid[2];
3230    bool   is_on_grid_aux[2];
3231    uint8_t block_signs[2];
3232    uint16_t q2[2*(QK_K/16)];
3233
3234    for (int ibl = 0; ibl < nbl; ++ibl) {
3235
3236        y[ibl].d = GGML_FP32_TO_FP16(0.f);
3237        memset(q2, 0, QK_K/4);
3238        memset(y[ibl].scales, 0, QK_K/32);
3239
3240        float max_scale = 0;
3241
3242        const float * xbl = x + QK_K*ibl;
3243        float sumx2 = 0;
3244        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
3245        float sigma2 = sumx2/QK_K;
3246
3247        for (int ib = 0; ib < QK_K/16; ++ib) {
3248            const float * xb = xbl + 16*ib;
3249            const float * qw = quant_weights + QK_K*ibl + 16*ib;
3250            for (int i = 0; i < 16; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
3251            for (int i = 0; i < 16; ++i) waux[i] = sqrtf(weight[i]);
3252            for (int k = 0; k < 2; ++k) {
3253                int nflip = 0;
3254                uint8_t s = 0;
3255                for (int i = 0; i < 8; ++i) {
3256                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
3257                    else {
3258                        xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
3259                    }
3260                }
3261                if (nflip%2) {
3262                    int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
3263                    for (int i = 1; i < 8; ++i) {
3264                        float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
3265                        if (ax < min) {
3266                            min = ax; imin = i;
3267                        }
3268                    }
3269                    xval[8*k+imin] = -xval[8*k+imin];
3270                    s ^= (1 << imin);
3271                }
3272                block_signs[k] = s & 127;
3273            }
3274            float max = xval[0];
3275            for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]);
3276            if (max < GROUP_MAX_EPS) {
3277                scales[ib] = 0;
3278                memset(L, 0, 16);
3279                continue;
3280            }
3281            float best = 0;
3282            float scale = max/(2*kMaxQ-1);
3283            is_on_grid[0] = is_on_grid[1] = true;
3284            for (int is = -9; is <= 9; ++is) {
3285                float id = (2*kMaxQ-1+is*0.1f)/max;
3286                float this_scale = 1/id;
3287                for (int k = 0; k < 2; ++k) {
3288                    for (int i = 0; i < 8; ++i) {
3289                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
3290                        Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l));
3291                    }
3292                    uint16_t u = 0;
3293                    for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
3294                    int grid_index = kmap_q2xs[u];
3295                    is_on_grid_aux[k] = true;
3296                    if (grid_index < 0) {
3297                        is_on_grid_aux[k] = false;
3298                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
3299                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
3300                    }
3301                }
3302                float sumqx = 0, sumq2 = 0;
3303                for (int i = 0; i < 16; ++i) {
3304                    float w = weight[i];
3305                    float q = 2*Laux[i] + 1;
3306                    sumqx += w*xval[i]*q;
3307                    sumq2 += w*q*q;
3308                }
3309                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
3310                    scale = sumqx/sumq2; best = scale*sumqx;
3311                    for (int i = 0; i < 16; ++i) L[i] = Laux[i];
3312                    for (int k = 0; k <  2; ++k) is_on_grid[k] = is_on_grid_aux[k];
3313                }
3314            }
3315            int n_not_ongrid = 0;
3316            for (int k = 0; k < 2; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
3317            if (n_not_ongrid > 0 && scale > 0) {
3318                float id = 1/scale;
3319                for (int k = 0; k < 2; ++k) {
3320                    if (is_on_grid[k]) continue;
3321                    uint16_t u = 0;
3322                    for (int i = 0; i < 8; ++i) {
3323                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
3324                        l = MAX(0, MIN(kMaxQ-1, l));
3325                        u |= (l << 2*i);
3326                        L[8*k + i] = l;
3327                    }
3328                    int grid_index = kmap_q2xs[u];
3329                    if (grid_index < 0) {
3330                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
3331                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k);
3332                    }
3333                }
3334                float sumqx = 0, sumq2 = 0;
3335                for (int i = 0; i < 16; ++i) {
3336                    float w = weight[i];
3337                    float q = 2*L[i] + 1;
3338                    sumqx += w*xval[i]*q;
3339                    sumq2 += w*q*q;
3340                }
3341                if (sumq2 > 0) scale = sumqx/sumq2;
3342            }
3343            if (scale < 0) {
3344                scale = -scale;
3345                for (int k = 0; k < 2; ++k) block_signs[k] = (~block_signs[k]) & 127;
3346            }
3347            for (int k = 0; k < 2; ++k) {
3348                uint16_t u = 0;
3349                for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i);
3350                int grid_index = kmap_q2xs[u];
3351                if (grid_index < 0) {
3352                    printf("Oops: found point %u not on grid:", u);
3353                    for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]);
3354                    printf("\n");
3355                    GGML_ABORT("fatal error");
3356                }
3357                q2[2*ib+k] = grid_index | (block_signs[k] << 9);
3358            }
3359            GGML_ASSERT(scale >= 0);
3360            scales[ib] = scale;
3361            max_scale = MAX(max_scale, scale);
3362        }
3363
3364        if (!max_scale) {
3365            memset(y[ibl].qs, 0, QK_K/4);
3366            continue;
3367        }
3368
3369        float d = max_scale/31;
3370        y[ibl].d = GGML_FP32_TO_FP16(d);
3371        float id = 1/d;
3372        for (int ib = 0; ib < QK_K/16; ++ib) {
3373            int l = nearest_int(0.5f*(id*scales[ib]-1));
3374            l = MAX(0, MIN(15, l));
3375            if (ib%2 == 0) y[ibl].scales[ib/2] = l;
3376            else y[ibl].scales[ib/2] |= (l << 4);
3377        }
3378        memcpy(y[ibl].qs, q2, QK_K/4);
3379
3380    }
3381}
3382
3383size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
3384    GGML_ASSERT(n_per_row%QK_K == 0);
3385    int64_t nblock = n_per_row/QK_K;
3386    char * qrow = (char *)dst;
3387    for (int64_t row = 0; row < nrow; ++row) {
3388        quantize_row_iq2_xxs_impl(src, qrow, n_per_row, quant_weights);
3389        src += n_per_row;
3390        qrow += nblock*sizeof(block_iq2_xxs);
3391    }
3392    return nrow * nblock * sizeof(block_iq2_xxs);
3393}
3394
3395size_t quantize_iq2_xs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
3396    GGML_ASSERT(n_per_row%QK_K == 0);
3397    int64_t nblock = n_per_row/QK_K;
3398    char * qrow = (char *)dst;
3399    for (int64_t row = 0; row < nrow; ++row) {
3400        quantize_row_iq2_xs_impl(src, qrow, n_per_row, quant_weights);
3401        src += n_per_row;
3402        qrow += nblock*sizeof(block_iq2_xs);
3403    }
3404    return nrow * nblock * sizeof(block_iq2_xs);
3405}
3406
3407//
3408// ============================================= 3-bit using D4 lattice
3409//
3410
// Lookup tables for one IQ3 grid size, built by iq3xs_init_impl and released by iq3xs_free_impl.
typedef struct {
    uint32_t * grid;        // grid points; each uint32_t packs four int8 coordinates
    int      * map;         // level code -> grid index (>= 0) or -(offset+1) into `neighbours`
    uint16_t * neighbours;  // per off-grid code: a count followed by that many grid indices
} iq3_entry_t;

// One slot per supported grid size (index from iq3_data_index); all tables start out unset.
static iq3_entry_t iq3_data[2] = {
    [0] = { .grid = NULL, .map = NULL, .neighbours = NULL },
    [1] = { .grid = NULL, .map = NULL, .neighbours = NULL },
};
3421
// Maps a supported IQ3 grid size to its iq3_data slot: 256 -> 0, 512 -> 1 (other sizes assert).
static inline int iq3_data_index(int grid_size) {
    GGML_ASSERT(grid_size == 256 || grid_size == 512);
    return grid_size != 256;
}
3427
// qsort comparator for (distance, index) int pairs: orders by the first int, then by the second.
static int iq3_compare_func(const void * left, const void * right) {
    const int * a = (const int *)left;
    const int * b = (const int *)right;
    if (a[0] != b[0]) {
        return a[0] < b[0] ? -1 : 1;
    }
    if (a[1] != b[1]) {
        return a[1] < b[1] ? -1 : 1;
    }
    return 0;
}
3433
3434void iq3xs_init_impl(int grid_size) {
3435    const int gindex = iq3_data_index(grid_size);
3436    if (iq3_data[gindex].grid) {
3437        return;
3438    }
3439    static const uint16_t kgrid_256[256] = {
3440            0,     2,     4,     9,    11,    15,    16,    18,    25,    34,    59,    61,    65,    67,    72,    74,
3441           81,    85,    88,    90,    97,   108,   120,   128,   130,   132,   137,   144,   146,   153,   155,   159,
3442          169,   175,   189,   193,   199,   200,   202,   213,   248,   267,   287,   292,   303,   315,   317,   321,
3443          327,   346,   362,   413,   436,   456,   460,   462,   483,   497,   513,   515,   520,   522,   529,   531,
3444          536,   538,   540,   551,   552,   576,   578,   585,   592,   594,   641,   643,   648,   650,   657,   664,
3445          698,   704,   706,   720,   729,   742,   758,   769,   773,   808,   848,   852,   870,   889,   901,   978,
3446          992,  1024,  1026,  1033,  1035,  1040,  1042,  1046,  1049,  1058,  1089,  1091,  1093,  1096,  1098,  1105,
3447         1112,  1139,  1143,  1144,  1152,  1154,  1161,  1167,  1168,  1170,  1183,  1184,  1197,  1217,  1224,  1228,
3448         1272,  1276,  1309,  1323,  1347,  1367,  1377,  1404,  1473,  1475,  1486,  1509,  1537,  1544,  1546,  1553,
3449         1555,  1576,  1589,  1594,  1600,  1602,  1616,  1625,  1636,  1638,  1665,  1667,  1672,  1685,  1706,  1722,
3450         1737,  1755,  1816,  1831,  1850,  1856,  1862,  1874,  1901,  1932,  1950,  1971,  2011,  2032,  2052,  2063,
3451         2077,  2079,  2091,  2095,  2172,  2192,  2207,  2208,  2224,  2230,  2247,  2277,  2308,  2345,  2356,  2389,
3452         2403,  2424,  2501,  2504,  2506,  2520,  2570,  2593,  2616,  2624,  2630,  2646,  2669,  2700,  2714,  2746,
3453         2754,  2795,  2824,  2835,  2839,  2874,  2882,  2905,  2984,  3028,  3042,  3092,  3108,  3110,  3124,  3153,
3454         3185,  3215,  3252,  3288,  3294,  3364,  3397,  3434,  3483,  3523,  3537,  3587,  3589,  3591,  3592,  3610,
3455         3626,  3670,  3680,  3722,  3749,  3754,  3776,  3789,  3803,  3824,  3857,  3873,  3904,  3906,  3924,  3992,
3456    };
3457    static const uint16_t kgrid_512[512] = {
3458            0,     1,     2,     5,     7,     8,     9,    10,    12,    14,    16,    17,    21,    27,    32,    34,
3459           37,    39,    41,    43,    48,    50,    57,    60,    63,    64,    65,    66,    68,    72,    73,    77,
3460           80,    83,    87,    89,    93,   100,   113,   117,   122,   128,   129,   133,   135,   136,   139,   142,
3461          145,   149,   152,   156,   162,   165,   167,   169,   171,   184,   187,   195,   201,   205,   208,   210,
3462          217,   219,   222,   228,   232,   234,   247,   249,   253,   256,   267,   271,   273,   276,   282,   288,
3463          291,   297,   312,   322,   324,   336,   338,   342,   347,   353,   357,   359,   374,   379,   390,   393,
3464          395,   409,   426,   441,   448,   450,   452,   464,   466,   470,   475,   488,   492,   512,   513,   514,
3465          516,   520,   521,   523,   525,   527,   528,   530,   537,   540,   542,   556,   558,   561,   570,   576,
3466          577,   579,   582,   584,   588,   593,   600,   603,   609,   616,   618,   632,   638,   640,   650,   653,
3467          655,   656,   660,   666,   672,   675,   685,   688,   698,   705,   708,   711,   712,   715,   721,   727,
3468          728,   732,   737,   754,   760,   771,   773,   778,   780,   793,   795,   802,   806,   808,   812,   833,
3469          840,   843,   849,   856,   858,   873,   912,   916,   919,   932,   934,   961,   963,   968,   970,   977,
3470          989,   993,  1010,  1016,  1024,  1025,  1027,  1029,  1031,  1032,  1034,  1036,  1038,  1041,  1043,  1047,
3471         1048,  1050,  1057,  1059,  1061,  1064,  1066,  1079,  1080,  1083,  1085,  1088,  1090,  1096,  1099,  1103,
3472         1106,  1109,  1113,  1116,  1122,  1129,  1153,  1156,  1159,  1169,  1171,  1176,  1183,  1185,  1195,  1199,
3473         1209,  1212,  1216,  1218,  1221,  1225,  1234,  1236,  1241,  1243,  1250,  1256,  1270,  1281,  1287,  1296,
3474         1299,  1306,  1309,  1313,  1338,  1341,  1348,  1353,  1362,  1375,  1376,  1387,  1400,  1408,  1410,  1415,
3475         1425,  1453,  1457,  1477,  1481,  1494,  1496,  1507,  1512,  1538,  1545,  1547,  1549,  1551,  1554,  1561,
3476         1563,  1565,  1570,  1572,  1575,  1577,  1587,  1593,  1601,  1603,  1605,  1612,  1617,  1619,  1632,  1648,
3477         1658,  1662,  1664,  1674,  1680,  1690,  1692,  1704,  1729,  1736,  1740,  1745,  1747,  1751,  1752,  1761,
3478         1763,  1767,  1773,  1787,  1795,  1801,  1806,  1810,  1817,  1834,  1840,  1844,  1857,  1864,  1866,  1877,
3479         1882,  1892,  1902,  1915,  1934,  1953,  1985,  1987,  2000,  2002,  2013,  2048,  2052,  2058,  2064,  2068,
3480         2071,  2074,  2081,  2088,  2104,  2114,  2119,  2121,  2123,  2130,  2136,  2141,  2147,  2153,  2157,  2177,
3481         2179,  2184,  2189,  2193,  2203,  2208,  2223,  2226,  2232,  2244,  2249,  2251,  2256,  2258,  2265,  2269,
3482         2304,  2306,  2324,  2335,  2336,  2361,  2373,  2375,  2385,  2418,  2443,  2460,  2480,  2504,  2509,  2520,
3483         2531,  2537,  2562,  2568,  2572,  2578,  2592,  2596,  2599,  2602,  2614,  2620,  2625,  2627,  2629,  2634,
3484         2641,  2650,  2682,  2688,  2697,  2707,  2712,  2718,  2731,  2754,  2759,  2760,  2775,  2788,  2793,  2805,
3485         2811,  2817,  2820,  2832,  2842,  2854,  2890,  2902,  2921,  2923,  2978,  3010,  3012,  3026,  3081,  3083,
3486         3085,  3097,  3099,  3120,  3136,  3152,  3159,  3188,  3210,  3228,  3234,  3245,  3250,  3256,  3264,  3276,
3487         3281,  3296,  3349,  3363,  3378,  3392,  3395,  3420,  3440,  3461,  3488,  3529,  3531,  3584,  3588,  3591,
3488         3600,  3602,  3614,  3616,  3628,  3634,  3650,  3657,  3668,  3683,  3685,  3713,  3716,  3720,  3726,  3729,
3489         3736,  3753,  3778,  3802,  3805,  3819,  3841,  3845,  3851,  3856,  3880,  3922,  3938,  3970,  3993,  4032,
3490    };
3491
3492    const int kmap_size = 4096;
3493    const int nwant = grid_size == 256 ? 2 : 3;
3494    const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512;
3495    uint32_t * kgrid_q3xs;
3496    int      * kmap_q3xs;
3497    uint16_t * kneighbors_q3xs;
3498
3499    //printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size);
3500    uint32_t * the_grid = (uint32_t *)malloc(grid_size*sizeof(uint32_t));
3501    for (int k = 0; k < grid_size; ++k) {
3502        int8_t * pos = (int8_t *)(the_grid + k);
3503        for (int i = 0; i < 4; ++i) {
3504            int l = (kgrid[k] >> 3*i) & 0x7;
3505            pos[i] = 2*l + 1;
3506        }
3507    }
3508    kgrid_q3xs = the_grid;
3509    iq3_data[gindex].grid = the_grid;
3510    kmap_q3xs = (int *)malloc(kmap_size*sizeof(int));
3511    iq3_data[gindex].map = kmap_q3xs;
3512    for (int i = 0; i < kmap_size; ++i) kmap_q3xs[i] = -1;
3513    uint32_t aux32;
3514    uint8_t * aux8 = (uint8_t *)&aux32;
3515    for (int i = 0; i < grid_size; ++i) {
3516        aux32 = kgrid_q3xs[i];
3517        uint16_t index = 0;
3518        for (int k=0; k<4; ++k) {
3519            uint16_t q = (aux8[k] - 1)/2;
3520            index |= (q << 3*k);
3521        }
3522        kmap_q3xs[index] = i;
3523    }
3524    int8_t pos[4];
3525    int * dist2 = (int *)malloc(2*grid_size*sizeof(int));
3526    int num_neighbors = 0, num_not_in_map = 0;
3527    for (int i = 0; i < kmap_size; ++i) {
3528        if (kmap_q3xs[i] >= 0) continue;
3529        ++num_not_in_map;
3530        for (int k = 0; k < 4; ++k) {
3531            int l = (i >> 3*k) & 0x7;
3532            pos[k] = 2*l + 1;
3533        }
3534        for (int j = 0; j < grid_size; ++j) {
3535            const int8_t * pg = (const int8_t *)(kgrid_q3xs + j);
3536            int d2 = 0;
3537            for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
3538            dist2[2*j+0] = d2;
3539            dist2[2*j+1] = j;
3540        }
3541        qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func);
3542        int n = 0; int d2 = dist2[0];
3543        int nhave = 1;
3544        for (int j = 0; j < grid_size; ++j) {
3545            if (dist2[2*j] > d2) {
3546                if (nhave == nwant) break;
3547                d2 = dist2[2*j];
3548                ++nhave;
3549            }
3550            ++n;
3551        }
3552        num_neighbors += n;
3553    }
3554    //printf("%s: %d neighbours in total\n", __func__, num_neighbors);
3555    kneighbors_q3xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t));
3556    iq3_data[gindex].neighbours = kneighbors_q3xs;
3557    int counter = 0;
3558    for (int i = 0; i < kmap_size; ++i) {
3559        if (kmap_q3xs[i] >= 0) continue;
3560        for (int k = 0; k < 4; ++k) {
3561            int l = (i >> 3*k) & 0x7;
3562            pos[k] = 2*l + 1;
3563        }
3564        for (int j = 0; j < grid_size; ++j) {
3565            const int8_t * pg = (const int8_t *)(kgrid_q3xs + j);
3566            int d2 = 0;
3567            for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
3568            dist2[2*j+0] = d2;
3569            dist2[2*j+1] = j;
3570        }
3571        qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func);
3572        kmap_q3xs[i] = -(counter + 1);
3573        int d2 = dist2[0];
3574        uint16_t * start = &kneighbors_q3xs[counter++];
3575        int n = 0, nhave = 1;
3576        for (int j = 0; j < grid_size; ++j) {
3577            if (dist2[2*j] > d2) {
3578                if (nhave == nwant) break;
3579                d2 = dist2[2*j];
3580                ++nhave;
3581            }
3582            kneighbors_q3xs[counter++] = dist2[2*j+1];
3583            ++n;
3584        }
3585        *start = n;
3586    }
3587    free(dist2);
3588}
3589
3590void iq3xs_free_impl(int grid_size) {
3591    GGML_ASSERT(grid_size == 256 || grid_size == 512);
3592    const int gindex = iq3_data_index(grid_size);
3593    if (iq3_data[gindex].grid) {
3594        free(iq3_data[gindex].grid);       iq3_data[gindex].grid = NULL;
3595        free(iq3_data[gindex].map);        iq3_data[gindex].map  = NULL;
3596        free(iq3_data[gindex].neighbours); iq3_data[gindex].neighbours = NULL;
3597    }
3598}
3599
3600static int iq3_find_best_neighbour(const uint16_t * GGML_RESTRICT neighbours, const uint32_t * GGML_RESTRICT grid,
3601        const float * GGML_RESTRICT xval, const float * GGML_RESTRICT weight, float scale, int8_t * GGML_RESTRICT L) {
3602    int num_neighbors = neighbours[0];
3603    GGML_ASSERT(num_neighbors > 0);
3604    float best_d2 = FLT_MAX;
3605    int grid_index = -1;
3606    for (int j = 1; j <= num_neighbors; ++j) {
3607        const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
3608        float d2 = 0;
3609        for (int i = 0; i < 4; ++i) {
3610            float q = pg[i];
3611            float diff = scale*q - xval[i];
3612            d2 += weight[i]*diff*diff;
3613        }
3614        if (d2 < best_d2) {
3615            best_d2 = d2; grid_index = neighbours[j];
3616        }
3617    }
3618    GGML_ASSERT(grid_index >= 0);
3619    const int8_t * pg = (const int8_t *)(grid + grid_index);
3620    for (int i = 0; i < 4; ++i) L[i] = (pg[i] - 1)/2;
3621    return grid_index;
3622}
3623
3624static void quantize_row_iq3_xxs_impl(int grid_size, const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t n,
3625        const float * GGML_RESTRICT quant_weights) {
3626
3627    const int gindex = iq3_data_index(grid_size);
3628
3629    const uint32_t * kgrid_q3xs      = iq3_data[gindex].grid;
3630    const int      * kmap_q3xs       = iq3_data[gindex].map;
3631    const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours;
3632
3633    //GGML_ASSERT(quant_weights   && "missing quantization weights");
3634    GGML_ASSERT(kgrid_q3xs      && "forgot to call ggml_quantize_init()?");
3635    GGML_ASSERT(kmap_q3xs       && "forgot to call ggml_quantize_init()?");
3636    GGML_ASSERT(kneighbors_q3xs && "forgot to call ggml_quantize_init()?");
3637    GGML_ASSERT(n%QK_K == 0);
3638
3639    const int kMaxQ = 8;
3640
3641    const int64_t nbl = n/QK_K;
3642
3643    ggml_fp16_t * dh;
3644    uint8_t * qs;
3645    int block_size;
3646    if (grid_size == 256) {
3647        block_iq3_xxs * y = vy;
3648        dh = &y->d;
3649        qs = y->qs;
3650        block_size = sizeof(block_iq3_xxs);
3651    } else {
3652        block_iq3_s * y = vy;
3653        dh = &y->d;
3654        qs = y->qs;
3655        block_size = sizeof(block_iq3_s);
3656    }
3657    int quant_size = block_size - sizeof(ggml_fp16_t);
3658
3659    float scales[QK_K/32];
3660    float weight[32];
3661    float xval[32];
3662    int8_t L[32];
3663    int8_t Laux[32];
3664    float  waux[32];
3665    bool   is_on_grid[8];
3666    bool   is_on_grid_aux[8];
3667    uint8_t block_signs[8];
3668    uint8_t q3[3*(QK_K/8)+QK_K/32];
3669    uint32_t * scales_and_signs = (uint32_t *)(q3 + QK_K/4);
3670    uint8_t  * qh = q3 + 3*(QK_K/8);
3671
3672    for (int ibl = 0; ibl < nbl; ++ibl) {
3673
3674        dh[0] = GGML_FP32_TO_FP16(0.f);
3675        memset(q3, 0, 3*QK_K/8+QK_K/32);
3676
3677        float max_scale = 0;
3678
3679        const float * xbl = x + QK_K*ibl;
3680        float sumx2 = 0;
3681        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
3682        float sigma2 = 2*sumx2/QK_K;
3683
3684        for (int ib = 0; ib < QK_K/32; ++ib) {
3685            const float * xb = xbl + 32*ib;
3686            if (quant_weights) {
3687                const float * qw = quant_weights + QK_K*ibl + 32*ib;
3688                for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
3689            } else {
3690                for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];
3691            }
3692            for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]);
3693            for (int k = 0; k < 4; ++k) {
3694                int nflip = 0;
3695                uint8_t s = 0;
3696                for (int i = 0; i < 8; ++i) {
3697                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
3698                    else {
3699                        xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
3700                    }
3701                }
3702                if (nflip%2) {
3703                    int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
3704                    for (int i = 1; i < 8; ++i) {
3705                        float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
3706                        if (ax < min) {
3707                            min = ax; imin = i;
3708                        }
3709                    }
3710                    xval[8*k+imin] = -xval[8*k+imin];
3711                    s ^= (1 << imin);
3712                }
3713                block_signs[k] = s & 127;
3714            }
3715            float max = xval[0];
3716            for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);
3717            if (max < GROUP_MAX_EPS_IQ3_XXS) {
3718                scales[ib] = 0;
3719                memset(L, 0, 32);
3720                continue;
3721            }
3722            float best = 0;
3723            float scale = max/(2*kMaxQ-1);
3724            for (int k = 0; k < 8; ++k) is_on_grid[k] = true;
3725            for (int is = -15; is <= 15; ++is) {
3726                float id = (2*kMaxQ-1+is*0.2f)/max;
3727                float this_scale = 1/id;
3728                for (int k = 0; k < 8; ++k) {
3729                    for (int i = 0; i < 4; ++i) {
3730                        int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
3731                        Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l));
3732                    }
3733                    uint16_t u = 0;
3734                    for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i);
3735                    int grid_index = kmap_q3xs[u];
3736                    is_on_grid_aux[k] = true;
3737                    if (grid_index < 0) {
3738                        is_on_grid_aux[k] = false;
3739                        const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
3740                        grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k);
3741                    }
3742                }
3743                float sumqx = 0, sumq2 = 0;
3744                for (int i = 0; i < 32; ++i) {
3745                    float w = weight[i];
3746                    float q = 2*Laux[i] + 1;
3747                    sumqx += w*xval[i]*q;
3748                    sumq2 += w*q*q;
3749                }
3750                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
3751                    scale = sumqx/sumq2; best = scale*sumqx;
3752                    for (int i = 0; i < 32; ++i) L[i] = Laux[i];
3753                    for (int k = 0; k <  8; ++k) is_on_grid[k] = is_on_grid_aux[k];
3754                }
3755            }
3756            int n_not_ongrid = 0;
3757            for (int k = 0; k < 8; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
3758            if (n_not_ongrid > 0 && scale > 0) {
3759                float id = 1/scale;
3760                for (int k = 0; k < 8; ++k) {
3761                    if (is_on_grid[k]) continue;
3762                    uint16_t u = 0;
3763                    for (int i = 0; i < 4; ++i) {
3764                        int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
3765                        l = MAX(0, MIN(kMaxQ-1, l));
3766                        u |= (l << 3*i);
3767                    }
3768                    int grid_index = kmap_q3xs[u];
3769                    if (grid_index < 0) {
3770                        const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
3771                        grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k);
3772                    }
3773                    const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index);
3774                    for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2;
3775                }
3776                float sumqx = 0, sumq2 = 0;
3777                for (int i = 0; i < 32; ++i) {
3778                    float w = weight[i];
3779                    float q = 2*L[i] + 1;
3780                    sumqx += w*xval[i]*q;
3781                    sumq2 += w*q*q;
3782                }
3783                if (sumq2 > 0) scale = sumqx/sumq2;
3784            }
3785            if (scale < 0) {
3786                // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)
3787                // and correspondingly flip quant signs.
3788                scale = -scale;
3789                for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127;
3790            }
3791            for (int k = 0; k < 8; ++k) {
3792                uint16_t u = 0;
3793                for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i);
3794                int grid_index = kmap_q3xs[u];
3795                if (grid_index < 0) {
3796                    printf("Oops: found point %u not on grid:", u);
3797                    for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]);
3798                    printf("\n");
3799                    GGML_ABORT("fatal error");
3800                }
3801                if (grid_size == 256) {
3802                    q3[8*ib+k] = grid_index;
3803                } else {
3804                    q3[8*ib+k] = grid_index & 255;
3805                    qh[ib] |= ((grid_index >> 8) << k);
3806                }
3807
3808            }
3809            scales_and_signs[ib] = block_signs[0] | (block_signs[1] << 7) | (block_signs[2] << 14) | (block_signs[3] << 21);
3810            GGML_ASSERT(scale >= 0);
3811            scales[ib] = scale;
3812            max_scale = MAX(max_scale, scale);
3813        }
3814
3815        if (!max_scale) {
3816            memset(qs, 0, quant_size);
3817            dh += block_size/sizeof(ggml_fp16_t);
3818            qs += block_size;
3819            continue;
3820        }
3821
3822        float d = max_scale/31;
3823        dh[0] = GGML_FP32_TO_FP16(d * 1.0125f);  // small improvement via this fudge factor
3824        float id = 1/d;
3825        for (int ib = 0; ib < QK_K/32; ++ib) {
3826            int l = nearest_int(0.5f*(id*scales[ib]-1));
3827            l = MAX(0, MIN(15, l));
3828            scales_and_signs[ib] |= ((uint32_t)l << 28);
3829        }
3830        memcpy(qs, q3, quant_size);
3831
3832        dh += block_size/sizeof(ggml_fp16_t);
3833        qs += block_size;
3834
3835    }
3836}
3837
3838size_t quantize_iq3_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
3839    GGML_ASSERT(n_per_row%QK_K == 0);
3840    int64_t nblock = n_per_row/QK_K;
3841    char * qrow = (char *)dst;
3842    for (int64_t row = 0; row < nrow; ++row) {
3843        quantize_row_iq3_xxs_impl(256, src, qrow, n_per_row, quant_weights);
3844        src += n_per_row;
3845        qrow += nblock*sizeof(block_iq3_xxs);
3846    }
3847    return nrow * nblock * sizeof(block_iq3_xxs);
3848}
3849
3850void quantize_row_iq3_xxs_ref(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k) {
3851    assert(k % QK_K == 0);
3852    quantize_row_iq3_xxs_impl(256, x, y, k, NULL);
3853}
3854
3855static void quantize_row_iq3_s_impl(int block_size, const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int n,
3856        const float * GGML_RESTRICT quant_weights,
3857        float   * scales,
3858        float   * weight,
3859        float   * xval,
3860        int8_t  * L,
3861        int8_t  * Laux,
3862        float   * waux,
3863        bool    * is_on_grid,
3864        bool    * is_on_grid_aux,
3865        uint8_t * block_signs) {
3866
3867    const int gindex = iq3_data_index(512);
3868
3869    const uint32_t * kgrid_q3xs      = iq3_data[gindex].grid;
3870    const int      * kmap_q3xs       = iq3_data[gindex].map;
3871    const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours;
3872
3873    //GGML_ASSERT(quant_weights   && "missing quantization weights");
3874    GGML_ASSERT(kgrid_q3xs      && "forgot to call ggml_quantize_init()?");
3875    GGML_ASSERT(kmap_q3xs       && "forgot to call ggml_quantize_init()?");
3876    GGML_ASSERT(kneighbors_q3xs && "forgot to call ggml_quantize_init()?");
3877    GGML_ASSERT(n%QK_K == 0);
3878
3879    const int kMaxQ = 8;
3880
3881    const int64_t nbl = n/QK_K;
3882
3883    block_iq3_s * y = vy;
3884
3885    const int bs4 = block_size/4;
3886    const int bs8 = block_size/8;
3887
3888    for (int ibl = 0; ibl < nbl; ++ibl) {
3889
3890        memset(&y[ibl], 0, sizeof(block_iq3_s));
3891        y[ibl].d = GGML_FP32_TO_FP16(0.f);
3892
3893        uint8_t * qs = y[ibl].qs;
3894        uint8_t * qh = y[ibl].qh;
3895        uint8_t * signs = y[ibl].signs;
3896
3897        float max_scale = 0;
3898
3899        const float * xbl = x + QK_K*ibl;
3900        float sumx2 = 0;
3901        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
3902        float sigma2 = 2*sumx2/QK_K;
3903
3904        for (int ib = 0; ib < QK_K/block_size; ++ib) {
3905            const float * xb = xbl + block_size*ib;
3906            if (quant_weights) {
3907                const float * qw = quant_weights + QK_K*ibl + block_size*ib;
3908                for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
3909            } else {
3910                for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];
3911            }
3912            for (int i = 0; i < block_size; ++i) waux[i] = sqrtf(weight[i]);
3913            for (int k = 0; k < bs8; ++k) {
3914                uint8_t s = 0;
3915                for (int i = 0; i < 8; ++i) {
3916                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
3917                    else {
3918                        xval[8*k + i] = -xb[8*k + i]; s |= (1 << i);
3919                    }
3920                }
3921                block_signs[k] = s;
3922            }
3923            float max = xval[0];
3924            for (int i = 1; i < block_size; ++i) max = MAX(max, xval[i]);
3925            if (!max) {
3926                scales[ib] = 0;
3927                continue;
3928            }
3929            float best = 0;
3930            float scale = max/(2*kMaxQ-1);
3931            for (int k = 0; k < bs4; ++k) is_on_grid[k] = false;
3932            for (int is = -9; is <= 9; ++is) {
3933                float id = (2*kMaxQ-1+is*0.2f)/max;
3934                float this_scale = 1/id;
3935                for (int k = 0; k < bs4; ++k) {
3936                    for (int i = 0; i < 4; ++i) {
3937                        int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
3938                        Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l));
3939                    }
3940                    uint16_t u = 0;
3941                    for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i);
3942                    int grid_index = kmap_q3xs[u];
3943                    is_on_grid_aux[k] = true;
3944                    if (grid_index < 0) {
3945                        is_on_grid_aux[k] = false;
3946                        const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
3947                        grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k);
3948                    }
3949                }
3950                float sumqx = 0, sumq2 = 0;
3951                for (int i = 0; i < block_size; ++i) {
3952                    float w = weight[i];
3953                    float q = 2*Laux[i] + 1;
3954                    sumqx += w*xval[i]*q;
3955                    sumq2 += w*q*q;
3956                }
3957                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
3958                    scale = sumqx/sumq2; best = scale*sumqx;
3959                    for (int i = 0; i < block_size; ++i) L[i] = Laux[i];
3960                    for (int k = 0; k < bs4; ++k) is_on_grid[k] = is_on_grid_aux[k];
3961                }
3962            }
3963            int n_not_ongrid = 0;
3964            for (int k = 0; k < bs4; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
3965            if (n_not_ongrid > 0 && scale > 0) {
3966                float id = 1/scale;
3967                for (int k = 0; k < bs4; ++k) {
3968                    //if (is_on_grid[k]) continue;
3969                    uint16_t u = 0;
3970                    for (int i = 0; i < 4; ++i) {
3971                        int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
3972                        l = MAX(0, MIN(kMaxQ-1, l));
3973                        u |= (l << 3*i);
3974                    }
3975                    int grid_index = kmap_q3xs[u];
3976                    if (grid_index < 0) {
3977                        const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
3978                        grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k);
3979                    }
3980                    const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index);
3981                    for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2;
3982                }
3983                float sumqx = 0, sumq2 = 0;
3984                for (int i = 0; i < block_size; ++i) {
3985                    float w = weight[i];
3986                    float q = 2*L[i] + 1;
3987                    sumqx += w*xval[i]*q;
3988                    sumq2 += w*q*q;
3989                }
3990                if (sumq2 > 0) scale = sumqx/sumq2;
3991            }
3992            if (scale < 0) {
3993                // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)
3994                // and correspondingly flip quant signs.
3995                scale = -scale;
3996                for (int k = 0; k < bs8; ++k) block_signs[k] = ~block_signs[k];
3997            }
3998            for (int k = 0; k < bs4; ++k) {
3999                uint16_t u = 0;
4000                for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i);
4001                int grid_index = kmap_q3xs[u];
4002                if (grid_index < 0) {
4003                    printf("Oops: found point %u not on grid:", u);
4004                    for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]);
4005                    printf("\n");
4006                    GGML_ABORT("fatal error");
4007                }
4008                qs[k] = grid_index & 255;
4009                qh[(ib*bs4+k)/8] |= ((grid_index >> 8) << ((ib*bs4+k)%8));
4010            }
4011            qs += bs4;
4012            for (int k = 0; k < bs8; ++k) signs[k] = block_signs[k];
4013            signs += bs8;
4014            GGML_ASSERT(scale >= 0);
4015            scales[ib] = scale;
4016            max_scale = MAX(max_scale, scale);
4017        }
4018
4019        if (!max_scale) {
4020            continue;
4021        }
4022
4023        float d = max_scale/31;
4024        y[ibl].d = GGML_FP32_TO_FP16(d * 1.033f);
4025        float id = 1/d;
4026        for (int ib = 0; ib < QK_K/block_size; ib += 2) {
4027            int l1 = nearest_int(0.5f*(id*scales[ib+0]-1));
4028            l1 = MAX(0, MIN(15, l1));
4029            int l2 = nearest_int(0.5f*(id*scales[ib+1]-1));
4030            l2 = MAX(0, MIN(15, l2));
4031            y[ibl].scales[ib/2] = l1 | (l2 << 4);
4032        }
4033
4034    }
4035}
4036
4037#define IQ3S_BLOCK_SIZE 32
4038size_t quantize_iq3_s(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
4039    GGML_ASSERT(n_per_row%QK_K == 0);
4040    int64_t nblock = n_per_row/QK_K;
4041    float scales[QK_K/IQ3S_BLOCK_SIZE];
4042    float weight[IQ3S_BLOCK_SIZE];
4043    float xval[IQ3S_BLOCK_SIZE];
4044    int8_t L[IQ3S_BLOCK_SIZE];
4045    int8_t Laux[IQ3S_BLOCK_SIZE];
4046    float  waux[IQ3S_BLOCK_SIZE];
4047    bool   is_on_grid[IQ3S_BLOCK_SIZE/4];
4048    bool   is_on_grid_aux[IQ3S_BLOCK_SIZE/4];
4049    uint8_t block_signs[IQ3S_BLOCK_SIZE/8];
4050    char * qrow = (char *)dst;
4051    for (int64_t row = 0; row < nrow; ++row) {
4052        quantize_row_iq3_s_impl(IQ3S_BLOCK_SIZE, src, qrow, n_per_row, quant_weights,
4053                scales, weight, xval, L, Laux, waux, is_on_grid, is_on_grid_aux, block_signs);
4054        src += n_per_row;
4055        qrow += nblock*sizeof(block_iq3_s);
4056    }
4057    return nrow * nblock * sizeof(block_iq3_s);
4058}
4059
4060void quantize_row_iq3_s_ref(const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k) {
4061    assert(k % QK_K == 0);
4062    quantize_iq3_s(x, y, 1, k, NULL);
4063}
4064
4065
4066// =================================== 1.5 bpw ===================================================
4067
4068static int iq1_find_best_neighbour(const uint16_t * GGML_RESTRICT neighbours, const uint64_t * GGML_RESTRICT grid,
4069        const float * GGML_RESTRICT xval, const float * GGML_RESTRICT weight, float * scale, int8_t * GGML_RESTRICT L, int ngrid) {
4070    int num_neighbors = neighbours[0];
4071    GGML_ASSERT(num_neighbors > 0);
4072    float best_score = -FLT_MAX;
4073    int grid_index = -1;
4074    for (int j = 1; j <= num_neighbors; ++j) {
4075        const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
4076        float sumqx = 0, sumq2 = 0;
4077        for (int i = 0; i < 8; ++i) {
4078            float q = (pg[i] - 3)/2;
4079            float w = weight[i];
4080            sumqx += w*q*xval[i];
4081            sumq2 += w*q*q;
4082        }
4083        if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
4084            *scale = sumqx/sumq2; best_score = *scale * sumqx;
4085            grid_index = neighbours[j];
4086        }
4087    }
4088    if (grid_index < 0) {
4089        for (int i = 0; i < ngrid; ++i) {
4090            const int8_t * grid_i = (const int8_t *)(grid + i);
4091            float sumqx = 0, sumq2 = 0;
4092            for (int j = 0; j < 8; ++j) {
4093                float w = weight[j];
4094                float q = (grid_i[j] - 3)/2;
4095                sumqx += w*q*xval[j];
4096                sumq2 += w*q*q;
4097            }
4098            if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
4099                *scale = sumqx/sumq2; best_score = *scale*sumqx;
4100                grid_index = i;
4101            }
4102        }
4103    }
4104    if (grid_index < 0) {
4105        printf("Oops, did not find grid point\n");
4106        printf("Have %d neighbours\n", num_neighbors);
4107        for (int j = 1; j <= num_neighbors; ++j) {
4108            const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
4109            float sumqx = 0, sumq2 = 0;
4110            for (int i = 0; i < 8; ++i) {
4111                float q = (pg[i] - 3)/2;
4112                float w = weight[i];
4113                sumqx += w*q*xval[i];
4114                sumq2 += w*q*q;
4115            }
4116            printf("    neighbour %d: sumqx = %g sumq2 = %g\n", j, (double)sumqx, (double)sumq2);
4117        }
4118    }
4119    GGML_ASSERT(grid_index >= 0);
4120    //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
4121    *scale *= 1.05f;  // This is a fudge factor. Don't ask me why it improves the result.
4122    //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
4123    const int8_t * pg = (const int8_t *)(grid + grid_index);
4124    for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2;
4125    return grid_index;
4126}
4127
4128static int iq1_find_best_neighbour2(const uint16_t * GGML_RESTRICT neighbours, const uint64_t * GGML_RESTRICT grid,
4129        const float * GGML_RESTRICT xval, const float * GGML_RESTRICT weight, float scale, const float * GGML_RESTRICT xg, int8_t * GGML_RESTRICT L, int ngrid) {
4130    int num_neighbors = neighbours[0];
4131    GGML_ASSERT(num_neighbors > 0);
4132    float best_score = FLT_MAX;
4133    int grid_index = -1;
4134    for (int j = 1; j <= num_neighbors; ++j) {
4135        const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
4136        float d2 = 0;
4137        for (int i = 0; i < 8; ++i) {
4138            float q = xg[(pg[i] - 1)/2];
4139            float w = weight[i];
4140            float diff = scale*q - xval[i];
4141            d2 += w*diff*diff;
4142        }
4143        if (d2 < best_score) {
4144            best_score = d2;
4145            grid_index = neighbours[j];
4146        }
4147    }
4148    if (grid_index < 0) {
4149        for (int i = 0; i < ngrid; ++i) {
4150            const int8_t * grid_i = (const int8_t *)(grid + i);
4151            float d2 = 0;
4152            for (int j = 0; j < 8; ++j) {
4153                float w = weight[j];
4154                float q = xg[(grid_i[j] - 1)/2];
4155                float diff = scale*q - xval[i];
4156                d2 += w*diff*diff;
4157            }
4158            if (d2 < best_score) {
4159                best_score = d2;
4160                grid_index = i;
4161            }
4162        }
4163    }
4164    if (grid_index < 0) {
4165        printf("Oops, did not find grid point\n");
4166        printf("Have %d neighbours\n", num_neighbors);
4167        for (int j = 1; j <= num_neighbors; ++j) {
4168            const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
4169            float sumqx = 0, sumq2 = 0;
4170            for (int i = 0; i < 8; ++i) {
4171                float q = xg[(pg[i] - 1)/2];
4172                float w = weight[i];
4173                sumqx += w*q*xval[i];
4174                sumq2 += w*q*q;
4175            }
4176            printf("    neighbour %d: sumqx = %g sumq2 = %g\n", j, (double)sumqx, (double)sumq2);
4177        }
4178    }
4179    GGML_ASSERT(grid_index >= 0);
4180    const int8_t * pg = (const int8_t *)(grid + grid_index);
4181    for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2;
4182    return grid_index;
4183}
4184
// qsort comparator: orders records by the float stored at their start, ascending.
// Returns 0 when neither value is less than the other (including NaN comparisons).
static int iq1_sort_helper(const void * left, const void * right) {
    const float a = *(const float *)left;
    const float b = *(const float *)right;
    if (a < b) {
        return -1;
    }
    if (a > b) {
        return 1;
    }
    return 0;
}
4190
// Sub-block length (number of weights sharing one entry of `scales`) used by quantize_row_iq1_s_impl.
#define IQ1S_BLOCK_SIZE 32
// NOTE(review): presumably the analogous sub-block length for the IQ1_M quantizer; not used in this excerpt.
#define IQ1M_BLOCK_SIZE 16
4193static void quantize_row_iq1_s_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t n, const float * GGML_RESTRICT quant_weights,
4194        float    * scales,
4195        float    * weight,
4196        float    * sumx,
4197        float    * sumw,
4198        float    * pairs,
4199        int8_t   * L,
4200        uint16_t * index,
4201        int8_t   * shifts) {
4202
4203    const int gindex = iq2_data_index(GGML_TYPE_IQ1_S);
4204
4205    const uint64_t * kgrid_q2xs      = iq2_data[gindex].grid;
4206    const int      * kmap_q2xs       = iq2_data[gindex].map;
4207    const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
4208
4209    GGML_ASSERT(quant_weights   && "missing quantization weights");
4210    GGML_ASSERT(kgrid_q2xs      && "forgot to call ggml_quantize_init()?");
4211    GGML_ASSERT(kmap_q2xs       && "forgot to call ggml_quantize_init()?");
4212    GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
4213    GGML_ASSERT(n%QK_K == 0);
4214
4215    block_iq1_s * y = vy;
4216
4217    const int64_t nbl = n/QK_K;
4218
4219    const int block_size = IQ1S_BLOCK_SIZE;
4220
4221    const float x_p[3] = {-1 + IQ1S_DELTA,  IQ1S_DELTA, 1 + IQ1S_DELTA};
4222    const float x_m[3] = {-1 - IQ1S_DELTA, -IQ1S_DELTA, 1 - IQ1S_DELTA};
4223
4224
4225    int * idx = (int *)(pairs + 1);
4226
4227    for (int ibl = 0; ibl < nbl; ++ibl) {
4228
4229        y[ibl].d = GGML_FP32_TO_FP16(0.f);
4230        memset(y[ibl].qs, 0, QK_K/8);
4231        memset(y[ibl].qh, 0, QK_K/16);
4232
4233        float max_scale = 0;
4234
4235        const float * xbl = x + QK_K*ibl;
4236        float sumx2 = 0;
4237        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
4238        float sigma2 = 2*sumx2/QK_K;
4239
4240        for (int ib = 0; ib < QK_K/block_size; ++ib) {
4241            const float * xb = xbl + block_size*ib;
4242            const float * qw = quant_weights + QK_K*ibl + block_size*ib;
4243            for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
4244            float max = fabsf(xb[0]);
4245            for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i]));
4246            if (max < GROUP_MAX_EPS_IQ1_S) {
4247                scales[ib] = 0;
4248                memset(L, 1, block_size);
4249                continue;
4250            }
4251            // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem.
4252            // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two
4253            // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights
4254            // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and
4255            // Wi = sum[weight[j], j = 0...i], and use these to quckly get get the optimum scale
4256            // for each possible and score for each split.
4257            for (int j = 0; j < block_size; ++j) {
4258                pairs[2*j] = xb[j];
4259                idx[2*j] = j;
4260            }
4261            qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper);
4262            {
4263                sumx[0] = sumw[0] = 0;
4264                for (int j = 0; j < block_size; ++j) {
4265                    int i = idx[2*j];
4266                    sumx[j+1] = sumx[j] + weight[i]*xb[i];
4267                    sumw[j+1] = sumw[j] + weight[i];
4268                }
4269            }
4270            float best_score = -FLT_MAX, scale = max;
4271            int besti1 = -1, besti2 = -1, best_shift = 0;
4272            for (int i1 = 0; i1 <= block_size; ++i1) {
4273                for (int i2 = i1; i2 <= block_size; ++i2) {
4274                    float sumqx = (sumx[i1] - sumx[0])*x_p[0] + (sumx[i2] - sumx[i1])*x_p[1] + (sumx[block_size] - sumx[i2])*x_p[2];
4275                    float sumq2 = (sumw[i1] - sumw[0])*x_p[0]*x_p[0] + (sumw[i2] - sumw[i1])*x_p[1]*x_p[1] + (sumw[block_size] - sumw[i2])*x_p[2]*x_p[2];
4276                    if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
4277                        scale = sumqx/sumq2; best_score = scale*sumqx;
4278                        besti1 = i1; besti2 = i2; best_shift = 1;
4279                    }
4280                    sumqx = (sumx[i1] - sumx[0])*x_m[0] + (sumx[i2] - sumx[i1])*x_m[1] + (sumx[block_size] - sumx[i2])*x_m[2];
4281                    sumq2 = (sumw[i1] - sumw[0])*x_m[0]*x_m[0] + (sumw[i2] - sumw[i1])*x_m[1]*x_m[1] + (sumw[block_size] - sumw[i2])*x_m[2]*x_m[2];
4282                    if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
4283                        scale = sumqx/sumq2; best_score = scale*sumqx;
4284                        besti1 = i1; besti2 = i2; best_shift = -1;
4285                    }
4286                }
4287            }
4288            GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_shift != 0);
4289            for (int j =      0; j < besti1; ++j) L[idx[2*j]] = 0;
4290            for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;
4291            for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2;
4292            if (scale < 0) {
4293                for (int j = 0; j < block_size; ++j) L[j] = 2 - L[j];
4294                scale = -scale; best_shift = -best_shift;
4295            }
4296            bool all_on_grid = true;
4297            const float * xx = best_shift == 1 ? x_p : x_m;
4298            for (int k = 0; k < block_size/8; ++k) {
4299                uint16_t u = 0;
4300                for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j);
4301                int grid_index = kmap_q2xs[u];
4302                if (grid_index < 0) {
4303                    all_on_grid = false;
4304                    const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
4305                    grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, xx, L + 8*k, NGRID_IQ1S);
4306                    GGML_ASSERT(grid_index >= 0);
4307                }
4308                index[k] = grid_index;
4309            }
4310            if (!all_on_grid) {
4311                float sumqx = 0, sumq2 = 0;
4312                for (int k = 0; k < block_size/8; ++k) {
4313                    const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]);
4314                    for (int j = 0; j < 8; ++j) {
4315                        float w = weight[8*k + j];
4316                        float q = xx[(pg[j] - 1)/2];
4317                        sumqx += w*q*xb[8*k+j];
4318                        sumq2 += w*q*q;
4319                    }
4320                }
4321                if (sumqx > 0 && sumq2 > 0) scale = sumqx/sumq2;
4322            }
4323            uint16_t h = 0;
4324            for (int k = 0; k < block_size/8; ++k) {
4325                y[ibl].qs[(block_size/8)*ib + k] = index[k] & 255;
4326                h |= (index[k] >> 8) << 3*k;
4327            }
4328            y[ibl].qh[ib] = h;
4329            GGML_ASSERT(scale >= 0);
4330            scales[ib] = scale;
4331            shifts[ib] = best_shift;
4332            max_scale = MAX(max_scale, scale);
4333        }
4334
4335        if (!max_scale) {
4336            continue;
4337        }
4338
4339        float d = max_scale/15;
4340        y[ibl].d = GGML_FP32_TO_FP16(d*1.125f); // 1.125f is another fudge factor. Don't ask me why it is needed.
4341        float id = 1/d;
4342        for (int ib = 0; ib < QK_K/block_size; ++ib) {
4343            int l = nearest_int(0.5f*(id*scales[ib]-1));
4344            l = MAX(0, MIN(7, l));
4345            if (shifts[ib] == -1) l |= 8;
4346            y[ibl].qh[ib] |= (l << 12);
4347        }
4348    }
4349}
4350
4351size_t quantize_iq1_s(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
4352    GGML_ASSERT(n_per_row%QK_K == 0);
4353    float  scales[QK_K/IQ1S_BLOCK_SIZE];
4354    float  weight[IQ1S_BLOCK_SIZE];
4355    int8_t L[IQ1S_BLOCK_SIZE];
4356    float  sumx[IQ1S_BLOCK_SIZE+1];
4357    float  sumw[IQ1S_BLOCK_SIZE+1];
4358    float  pairs[2*IQ1S_BLOCK_SIZE];
4359    uint16_t index[IQ1S_BLOCK_SIZE/8];
4360    int8_t shifts[QK_K/IQ1S_BLOCK_SIZE];
4361    int64_t nblock = n_per_row/QK_K;
4362    char * qrow = (char *)dst;
4363    for (int64_t row = 0; row < nrow; ++row) {
4364        quantize_row_iq1_s_impl(src, qrow, n_per_row, quant_weights, scales, weight, sumx, sumw, pairs, L, index, shifts);
4365        src += n_per_row;
4366        qrow += nblock*sizeof(block_iq1_s);
4367    }
4368    return nrow * nblock * sizeof(block_iq1_s);
4369}
4370
4371static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t n, const float * GGML_RESTRICT quant_weights,
4372        float    * scales,
4373        float    * weight,
4374        float    * pairs,
4375        int8_t   * L,
4376        uint16_t * index,
4377        int8_t   * shifts) {
4378
4379    const int gindex = iq2_data_index(GGML_TYPE_IQ1_M);
4380
4381    const uint64_t * kgrid_q2xs      = iq2_data[gindex].grid;
4382    const int      * kmap_q2xs       = iq2_data[gindex].map;
4383    const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
4384
4385    //GGML_ASSERT(quant_weights   && "missing quantization weights");
4386    GGML_ASSERT(kgrid_q2xs      && "forgot to call ggml_quantize_init()?");
4387    GGML_ASSERT(kmap_q2xs       && "forgot to call ggml_quantize_init()?");
4388    GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
4389    GGML_ASSERT(n%QK_K == 0);
4390
4391    block_iq1_m * y = vy;
4392
4393    const int64_t nbl = n/QK_K;
4394
4395    const int block_size = IQ1M_BLOCK_SIZE;
4396
4397    const float x_p[3] = {-1 + IQ1M_DELTA,  IQ1M_DELTA, 1 + IQ1M_DELTA};
4398    const float x_m[3] = {-1 - IQ1M_DELTA, -IQ1M_DELTA, 1 - IQ1M_DELTA};
4399    const uint8_t masks[4] = {0x00, 0x80, 0x08, 0x88};
4400
4401    int * idx = (int *)(pairs + 1);
4402
4403    float sumqx[4], sumq2[4];
4404
4405    iq1m_scale_t s;
4406    const float * xx;
4407
4408    for (int ibl = 0; ibl < nbl; ++ibl) {
4409        memset(y[ibl].qs, 0, QK_K/8);
4410        memset(y[ibl].qh, 0, QK_K/16);
4411        memset(y[ibl].scales, 0, QK_K/32);
4412
4413        float max_scale = 0;
4414
4415        const float * xbl = x + QK_K*ibl;
4416        float sumx2 = 0;
4417        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
4418        float sigma2 = 2*sumx2/QK_K;
4419
4420        for (int ib = 0; ib < QK_K/block_size; ++ib) {
4421            const float * xb = xbl + block_size*ib;
4422            if (quant_weights) {
4423                const float * qw = quant_weights + QK_K*ibl + block_size*ib;
4424                for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
4425            } else {
4426                for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];
4427            }
4428            float max = fabsf(xb[0]);
4429            for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i]));
4430            if (max < GROUP_MAX_EPS_IQ1_M) {
4431                scales[ib] = 0;
4432                memset(L, 1, block_size);
4433                continue;
4434            }
4435            // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem.
4436            // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two
4437            // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights
4438            // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and
4439            // Wi = sum[weight[j], j = 0...i], and use these to quckly get get the optimum scale
4440            // for each possible and score for each split.
4441            for (int j = 0; j < block_size; ++j) {
4442                pairs[2*j] = xb[j];
4443                idx[2*j] = j;
4444            }
4445            qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper);
4446            float best_score = -FLT_MAX, scale = max;
4447            int besti1 = -1, besti2 = -1, best_k = -1;
4448            // 0: +, +
4449            // 1: +, -
4450            // 2: -, +
4451            // 3: -, -
4452            for (int i1 = 0; i1 <= block_size; ++i1) {
4453                for (int i2 = i1; i2 <= block_size; ++i2) {
4454                    memset(sumqx, 0, 4*sizeof(float));
4455                    memset(sumq2, 0, 4*sizeof(float));
4456                    for (int j = 0; j < i1; ++j) {
4457                        int i = idx[2*j];
4458                        if (i < block_size/2) {
4459                            sumqx[0] += weight[i]*x_p[0]*xb[i];
4460                            sumqx[1] += weight[i]*x_p[0]*xb[i];
4461                            sumqx[2] += weight[i]*x_m[0]*xb[i];
4462                            sumqx[3] += weight[i]*x_m[0]*xb[i];
4463                            sumq2[0] += weight[i]*x_p[0]*x_p[0];
4464                            sumq2[1] += weight[i]*x_p[0]*x_p[0];
4465                            sumq2[2] += weight[i]*x_m[0]*x_m[0];
4466                            sumq2[3] += weight[i]*x_m[0]*x_m[0];
4467                        } else {
4468                            sumqx[0] += weight[i]*x_p[0]*xb[i];
4469                            sumqx[2] += weight[i]*x_p[0]*xb[i];
4470                            sumqx[1] += weight[i]*x_m[0]*xb[i];
4471                            sumqx[3] += weight[i]*x_m[0]*xb[i];
4472                            sumq2[0] += weight[i]*x_p[0]*x_p[0];
4473                            sumq2[2] += weight[i]*x_p[0]*x_p[0];
4474                            sumq2[1] += weight[i]*x_m[0]*x_m[0];
4475                            sumq2[3] += weight[i]*x_m[0]*x_m[0];
4476                        }
4477                    }
4478                    for (int j = i1; j < i2; ++j) {
4479                        int i = idx[2*j];
4480                        if (i < block_size/2) {
4481                            sumqx[0] += weight[i]*x_p[1]*xb[i];
4482                            sumqx[1] += weight[i]*x_p[1]*xb[i];
4483                            sumqx[2] += weight[i]*x_m[1]*xb[i];
4484                            sumqx[3] += weight[i]*x_m[1]*xb[i];
4485                            sumq2[0] += weight[i]*x_p[1]*x_p[1];
4486                            sumq2[1] += weight[i]*x_p[1]*x_p[1];
4487                            sumq2[2] += weight[i]*x_m[1]*x_m[1];
4488                            sumq2[3] += weight[i]*x_m[1]*x_m[1];
4489                        } else {
4490                            sumqx[0] += weight[i]*x_p[1]*xb[i];
4491                            sumqx[2] += weight[i]*x_p[1]*xb[i];
4492                            sumqx[1] += weight[i]*x_m[1]*xb[i];
4493                            sumqx[3] += weight[i]*x_m[1]*xb[i];
4494                            sumq2[0] += weight[i]*x_p[1]*x_p[1];
4495                            sumq2[2] += weight[i]*x_p[1]*x_p[1];
4496                            sumq2[1] += weight[i]*x_m[1]*x_m[1];
4497                            sumq2[3] += weight[i]*x_m[1]*x_m[1];
4498                        }
4499                    }
4500                    for (int j = i2; j < block_size; ++j) {
4501                        int i = idx[2*j];
4502                        if (i < block_size/2) {
4503                            sumqx[0] += weight[i]*x_p[2]*xb[i];
4504                            sumqx[1] += weight[i]*x_p[2]*xb[i];
4505                            sumqx[2] += weight[i]*x_m[2]*xb[i];
4506                            sumqx[3] += weight[i]*x_m[2]*xb[i];
4507                            sumq2[0] += weight[i]*x_p[2]*x_p[2];
4508                            sumq2[1] += weight[i]*x_p[2]*x_p[2];
4509                            sumq2[2] += weight[i]*x_m[2]*x_m[2];
4510                            sumq2[3] += weight[i]*x_m[2]*x_m[2];
4511                        } else {
4512                            sumqx[0] += weight[i]*x_p[2]*xb[i];
4513                            sumqx[2] += weight[i]*x_p[2]*xb[i];
4514                            sumqx[1] += weight[i]*x_m[2]*xb[i];
4515                            sumqx[3] += weight[i]*x_m[2]*xb[i];
4516                            sumq2[0] += weight[i]*x_p[2]*x_p[2];
4517                            sumq2[2] += weight[i]*x_p[2]*x_p[2];
4518                            sumq2[1] += weight[i]*x_m[2]*x_m[2];
4519                            sumq2[3] += weight[i]*x_m[2]*x_m[2];
4520                        }
4521                    }
4522                    for (int k = 0; k < 4; ++k) {
4523                        if (sumq2[k] > 0 && sumqx[k]*sumqx[k] > best_score*sumq2[k]) {
4524                            scale = sumqx[k]/sumq2[k]; best_score = scale*sumqx[k];
4525                            besti1 = i1; besti2 = i2; best_k = k;
4526                        }
4527                    }
4528                }
4529            }
4530            GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_k >= 0);
4531            for (int j =      0; j < besti1; ++j) L[idx[2*j]] = 0;
4532            for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;
4533            for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2;
4534            if (scale < 0) {
4535                for (int j = 0; j < block_size; ++j) L[j] = 2 - L[j];
4536                scale = -scale;
4537                best_k = best_k == 0 ? 3 : best_k == 1 ? 2 : best_k == 2 ? 1 : 0;
4538            }
4539            bool all_on_grid = true;
4540            for (int k = 0; k < block_size/8; ++k) {
4541                if (k == 0) xx = best_k < 2 ? x_p : x_m;
4542                else xx = best_k%2 == 0 ? x_p : x_m;
4543                uint16_t u = 0;
4544                for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j);
4545                int grid_index = kmap_q2xs[u];
4546                if (grid_index < 0) {
4547                    all_on_grid = false;
4548                    const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
4549                    grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, xx, L + 8*k, NGRID_IQ1S);
4550                    GGML_ASSERT(grid_index >= 0);
4551                }
4552                index[k] = grid_index;
4553            }
4554            if (!all_on_grid) {
4555                float sumqx_f = 0, sumq2_f = 0;
4556                for (int k = 0; k < block_size/8; ++k) {
4557                    if (k == 0) xx = best_k < 2 ? x_p : x_m;
4558                    else xx = best_k%2 == 0 ? x_p : x_m;
4559                    const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]);
4560                    for (int j = 0; j < 8; ++j) {
4561                        float w = weight[8*k + j];
4562                        float q = xx[(pg[j] - 1)/2];
4563                        sumqx_f += w*q*xb[8*k+j];
4564                        sumq2_f += w*q*q;
4565                    }
4566                }
4567                if (sumqx_f > 0 && sumq2_f > 0) scale = sumqx_f/sumq2_f;
4568            }
4569            y[ibl].qs[2*ib + 0] = index[0] & 255;
4570            y[ibl].qs[2*ib + 1] = index[1] & 255;
4571            y[ibl].qh[ib] = (index[0] >> 8) | ((index[1] >> 8) << 4);
4572            GGML_ASSERT(scale >= 0);
4573            scales[ib] = scale;
4574            shifts[ib] = best_k;
4575            max_scale = MAX(max_scale, scale);
4576        }
4577
4578        if (!max_scale) {
4579            continue;
4580        }
4581
4582        uint16_t * sc = (uint16_t *)y[ibl].scales;
4583        float d = max_scale/15;
4584        float id = 1/d;
4585        float sumqx_f = 0, sumq2_f = 0;
4586        for (int ib = 0; ib < QK_K/block_size; ++ib) {
4587            int l = nearest_int(0.5f*(id*scales[ib+0]-1));
4588            l = MAX(0, MIN(7, l));
4589            sc[ib/4] |= (l << 3*(ib%4));
4590            y[ibl].qh[ib] |= masks[shifts[ib]];
4591            const float * xb = xbl + block_size*ib;
4592            if (quant_weights) {
4593                const float * qw = quant_weights + QK_K*ibl + block_size*ib;
4594                for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
4595            } else {
4596                for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];
4597            }
4598            for (int k = 0; k < block_size/8; ++k) {
4599                if (k == 0) xx = shifts[ib] < 2 ? x_p : x_m;
4600                else xx = shifts[ib]%2 == 0 ? x_p : x_m;
4601                const int8_t * pg = (const int8_t *)(kgrid_q2xs + y[ibl].qs[2*ib+k] + ((y[ibl].qh[ib] << (8 - 4*k)) & 0x700));
4602                for (int j = 0; j < 8; ++j) {
4603                    float w = weight[8*k + j];
4604                    float q = xx[(pg[j] - 1)/2]*(2*l+1);
4605                    sumqx_f += w*q*xb[8*k+j];
4606                    sumq2_f += w*q*q;
4607                }
4608            }
4609        }
4610        if (sumq2_f > 0) d = sumqx_f/sumq2_f;
4611        s.f16 = GGML_FP32_TO_FP16(d*1.1125f); // 1.1125f is another fudge factor. Don't ask me why it is needed.
4612        sc[0] |= ((s.u16 & 0x000f) << 12);
4613        sc[1] |= ((s.u16 & 0x00f0) <<  8);
4614        sc[2] |= ((s.u16 & 0x0f00) <<  4);
4615        sc[3] |= ((s.u16 & 0xf000) <<  0);
4616    }
4617}
4618
4619size_t quantize_iq1_m(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
4620    GGML_ASSERT(n_per_row%QK_K == 0);
4621    float  scales[QK_K/IQ1M_BLOCK_SIZE];
4622    float  weight[IQ1M_BLOCK_SIZE];
4623    int8_t L[IQ1M_BLOCK_SIZE];
4624    float  pairs[2*IQ1M_BLOCK_SIZE];
4625    uint16_t index[IQ1M_BLOCK_SIZE/8];
4626    int8_t shifts[QK_K/IQ1M_BLOCK_SIZE];
4627    int64_t nblock = n_per_row/QK_K;
4628    char * qrow = (char *)dst;
4629    for (int64_t row = 0; row < nrow; ++row) {
4630        quantize_row_iq1_m_impl(src, qrow, n_per_row, quant_weights, scales, weight, pairs, L, index, shifts);
4631        src += n_per_row;
4632        qrow += nblock*sizeof(block_iq1_m);
4633    }
4634    return nrow * nblock * sizeof(block_iq1_m);
4635}
4636
4637// ============================ 4-bit non-linear quants
4638
4639static void quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * GGML_RESTRICT x,
4640        ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
4641        float * scales, float * weight, uint8_t * L,
4642        const int8_t * values,
4643        const float * quant_weights,
4644        const int ntry) {
4645
4646    float sigma2 = 0;
4647    for (int j = 0; j < super_block_size; ++j) sigma2 += x[j]*x[j];
4648    sigma2 *= 2.f/super_block_size;
4649
4650    memset(q4, 0, super_block_size/2);
4651    dh[0] = GGML_FP32_TO_FP16(0.f);
4652
4653    float max_scale = 0, amax_scale = 0;
4654    for (int ib = 0; ib < super_block_size/block_size; ++ib) {
4655        const float * xb = x + ib*block_size;
4656        uint8_t * Lb = L + ib*block_size;
4657        if (quant_weights) {
4658            const float * qw = quant_weights + ib*block_size;
4659            for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
4660        } else {
4661            for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j];
4662        }
4663        float amax = 0, max = 0;
4664        for (int j = 0; j < block_size; ++j) {
4665            float ax = fabsf(xb[j]);
4666            if (ax > amax) {
4667                amax = ax; max = xb[j];
4668            }
4669        }
4670        if (amax < GROUP_MAX_EPS) {
4671            scales[ib] = 0;
4672            continue;
4673        }
4674        float d = ntry > 0 ? -max/values[0] : max/values[0];
4675        float id = 1/d;
4676        float sumqx = 0, sumq2 = 0;
4677        for (int j = 0; j < block_size; ++j) {
4678            float al = id*xb[j];
4679            int l = best_index_int8(16, values, al);
4680            Lb[j] = l;
4681            float q = values[l];
4682            float w = weight[j];
4683            sumqx += w*q*xb[j];
4684            sumq2 += w*q*q;
4685        }
4686        d = sumqx/sumq2;
4687        float best = d*sumqx;
4688        for (int itry = -ntry; itry <= ntry; ++itry) {
4689            id = (itry + values[0])/max;
4690            sumqx = sumq2 = 0;
4691            for (int j = 0; j < block_size; ++j) {
4692                float al = id*xb[j];
4693                int l = best_index_int8(16, values, al);
4694                float q = values[l];
4695                float w = weight[j];
4696                sumqx += w*q*xb[j];
4697                sumq2 += w*q*q;
4698            }
4699            if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
4700                d = sumqx/sumq2; best = d * sumqx;
4701            }
4702        }
4703        scales[ib] = d;
4704        float abs_d = fabsf(d);
4705        if (abs_d > amax_scale) {
4706            amax_scale = abs_d; max_scale = d;
4707        }
4708    }
4709
4710    if (super_block_size/block_size > 1) {
4711        int nb = super_block_size/block_size;
4712        memset(scales_h, 0, ((nb+7)/8)*sizeof(uint16_t));
4713        float d = -max_scale/32;
4714        dh[0] = GGML_FP32_TO_FP16(d);
4715        float id = d ? 1/d : 0.f;
4716        for (int ib = 0; ib < super_block_size/block_size; ++ib) {
4717            int l = nearest_int(id*scales[ib]);
4718            l = MAX(-32, MIN(31, l));
4719            float dl = d * l;
4720            float idl = dl ? 1/dl : 0.f;
4721            uint8_t * Lb = L + ib*block_size;
4722            const float * xb = x + ib*block_size;
4723            for (int j = 0; j < block_size; ++j) {
4724                Lb[j] = best_index_int8(16, values, idl*xb[j]);
4725            }
4726            l += 32;
4727            uint8_t l_l = l & 0xf;
4728            uint8_t l_h = l >>  4;
4729            if (ib%2 == 0) scales_l[ib/2] = l_l;
4730            else scales_l[ib/2] |= (l_l << 4);
4731            scales_h[ib/8] |= (l_h << 2*(ib%8));
4732        }
4733    } else {
4734        dh[0] = GGML_FP32_TO_FP16(scales[0]);
4735        if (ntry > 0) {
4736            float id = scales[0] ? 1/scales[0] : 0;
4737            for (int j = 0; j < super_block_size; ++j) {
4738                L[j] = best_index_int8(16, values, id*x[j]);
4739            }
4740        }
4741    }
4742
4743    for (int i = 0; i < super_block_size/32; ++i) {
4744        for (int j = 0; j < 16; ++j) {
4745            q4[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4);
4746        }
4747    }
4748}
4749
4750size_t quantize_iq4_nl(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
4751    GGML_ASSERT(n_per_row%QK4_NL == 0);
4752    int64_t nblock = n_per_row/QK4_NL;
4753    char * qrow = (char *)dst;
4754    uint8_t L[QK4_NL];
4755    float weight[QK4_NL];
4756    uint16_t unused_h;
4757    uint8_t * unused_l = NULL;
4758    float scale;
4759    for (int64_t row = 0; row < nrow; ++row) {
4760        block_iq4_nl * iq4 = (block_iq4_nl *)qrow;
4761        for (int ibl = 0; ibl < nblock; ++ibl) {
4762            const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL;
4763            quantize_row_iq4_nl_impl(QK4_NL, 32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l,
4764                    &scale, weight, L, kvalues_iq4nl, qw, 7);
4765        }
4766        src += n_per_row;
4767        qrow += nblock*sizeof(block_iq4_nl);
4768    }
4769    return nrow * nblock * sizeof(block_iq4_nl);
4770}
4771
4772//void quantize_row_iq4_nl_ref(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
4773void quantize_row_iq4_nl_ref(const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k) {
4774    GGML_ASSERT(k%QK4_NL == 0);
4775    int64_t nblock = k/QK4_NL;
4776    uint8_t L[QK4_NL];
4777    float weight[QK4_NL];
4778    uint16_t unused_h;
4779    uint8_t * unused_l = NULL;
4780    float scale;
4781    block_iq4_nl * iq4 = y;
4782    for (int ibl = 0; ibl < nblock; ++ibl) {
4783        quantize_row_iq4_nl_impl(QK4_NL, 32, x + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l,
4784                &scale, weight, L, kvalues_iq4nl, NULL, -1);
4785    }
4786}
4787
4788size_t quantize_iq4_xs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
4789    GGML_ASSERT(n_per_row%QK_K == 0);
4790    int64_t nblock = n_per_row/QK_K;
4791    char * qrow = (char *)dst;
4792    uint8_t L[QK_K];
4793    float weight[32];
4794    float scales[QK_K/32];
4795    for (int64_t row = 0; row < nrow; ++row) {
4796        block_iq4_xs * iq4 = (block_iq4_xs *)qrow;
4797        for (int ibl = 0; ibl < nblock; ++ibl) {
4798            const float * qw = quant_weights ? quant_weights + QK_K*ibl : NULL;
4799            quantize_row_iq4_nl_impl(QK_K, 32, src + QK_K*ibl, &iq4[ibl].d, iq4[ibl].qs, &iq4[ibl].scales_h, iq4[ibl].scales_l,
4800                    scales, weight, L, kvalues_iq4nl, qw, 7);
4801        }
4802        src += n_per_row;
4803        qrow += nblock*sizeof(block_iq4_xs);
4804    }
4805    return nrow * nblock * sizeof(block_iq4_xs);
4806}
4807
4808void quantize_row_iq4_xs_ref(const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k) {
4809    assert(k % QK_K == 0);
4810    quantize_iq4_xs(x, y, 1, k, NULL);
4811}
4812
4813// =============================== 2.5625 bpw
4814
// Quantizes `n` floats from `x` (n a multiple of QK_K) into block_iq2_s
// super-blocks at `vy`.
//
// Each super-block is processed as QK_K/16 sub-blocks of 16 values, each made
// of two groups of 8. Signs are stored separately, one bit per value, in
// qs[QK_K/8 + group]. Magnitudes are mapped to levels 0..kMaxQ-1 (decoded as
// 2*l+1), and each group's levels are snapped onto the IQ2_S grid (low 8 bits
// of the grid index in qs, high bits in qh). A per-sub-block scale is fitted
// by weighted least squares. Sub-block scales are stored as 4-bit values l
// (scale ~ (2*l+1)*d) against the fp16 super-block scale y[].d.
//
// quant_weights: optional per-element importance weights (may be NULL).
// Aborts (GGML_ABORT) if a group is still off the grid when packing.
static void quantize_row_iq2_s_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t n, const float * GGML_RESTRICT quant_weights) {

    const int gindex = iq2_data_index(GGML_TYPE_IQ2_S);

    const uint64_t * kgrid_q2xs      = iq2_data[gindex].grid;
    const int      * kmap_q2xs       = iq2_data[gindex].map;
    const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;

    GGML_ASSERT(kmap_q2xs       && "forgot to call ggml_quantize_init()?");
    GGML_ASSERT(kgrid_q2xs      && "forgot to call ggml_quantize_init()?");
    GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
    GGML_ASSERT(n%QK_K == 0);

    // Number of magnitude levels per value (levels 0..kMaxQ-1).
    const int kMaxQ = 3;

    const int64_t nbl = n/QK_K;

    block_iq2_s * y = vy;

    // Scratch: sub-block scales for one super-block, then per-sub-block weights,
    // absolute values, chosen and candidate levels, sqrt-weights, on-grid flags
    // and sign masks for the two groups of 8.
    float scales[QK_K/16];
    float weight[16];
    float xval[16];
    int8_t L[16];
    int8_t Laux[16];
    float  waux[16];
    bool   is_on_grid[2];
    bool   is_on_grid_aux[2];
    uint8_t block_signs[2];

    for (int ibl = 0; ibl < nbl; ++ibl) {

        memset(&y[ibl], 0, sizeof(block_iq2_s));
        y[ibl].d = GGML_FP32_TO_FP16(0.f);

        float max_scale = 0;

        const float * xbl = x + QK_K*ibl;
        float sumx2 = 0;
        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
        // Twice the mean square of the super-block; used to shape the weights.
        float sigma2 = 2*sumx2/QK_K;

        for (int ib = 0; ib < QK_K/16; ++ib) {
            const float * xb = xbl + 16*ib;
            if (quant_weights) {
                const float * qw = quant_weights + QK_K*ibl + 16*ib;
                for (int i = 0; i < 16; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
            } else {
                // No importance weights: x^2 plus a floor derived from sigma2.
                for (int i = 0; i < 16; ++i) weight[i] = 0.25f*sigma2 + xb[i]*xb[i];
            }
            // sqrt(weight) is what the grid-neighbour search receives.
            for (int i = 0; i < 16; ++i) waux[i] = sqrtf(weight[i]);
            // Split each group of 8 into absolute values and a sign bitmask.
            for (int k = 0; k < 2; ++k) {
                uint8_t s = 0;
                for (int i = 0; i < 8; ++i) {
                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
                    else {
                        xval[8*k + i] = -xb[8*k + i]; s |= (1 << i);
                    }
                }
                block_signs[k] = s;
            }
            float max = xval[0];
            for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]);
            // Near-zero sub-block: scale 0; its qs/qh/sign bytes stay zeroed.
            if (max < GROUP_MAX_EPS_IQ2_S) {
                scales[ib] = 0;
                continue;
            }
            float best = 0;
            float scale = max/(2*kMaxQ-1);
            is_on_grid[0] = is_on_grid[1] = true;
            // Try 19 inverse scales around (2*kMaxQ-1)/max; keep the levels with
            // the best weighted least-squares score sumqx^2/sumq2.
            for (int is = -9; is <= 9; ++is) {
                float id = (2*kMaxQ-1+is*0.1f)/max;
                float this_scale = 1/id;
                for (int k = 0; k < 2; ++k) {
                    for (int i = 0; i < 8; ++i) {
                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
                        Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l));
                    }
                    uint16_t u = 0;
                    for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
                    int grid_index = kmap_q2xs[u];
                    is_on_grid_aux[k] = true;
                    if (grid_index < 0) {
                        // Off-grid group: search its neighbour list. NOTE(review):
                        // presumably this overwrites Laux + 8*k with the chosen grid
                        // point; the returned index itself is not used here.
                        is_on_grid_aux[k] = false;
                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
                    }
                }
                float sumqx = 0, sumq2 = 0;
                for (int i = 0; i < 16; ++i) {
                    float w = weight[i];
                    float q = 2*Laux[i] + 1;
                    sumqx += w*xval[i]*q;
                    sumq2 += w*q*q;
                }
                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
                    scale = sumqx/sumq2; best = scale*sumqx;
                    for (int i = 0; i < 16; ++i) L[i] = Laux[i];
                    for (int k = 0; k <  2; ++k) is_on_grid[k] = is_on_grid_aux[k];
                }
            }
            // The winning candidate had off-grid groups: requantize those groups at
            // the final scale, snap them to the grid again, and refit the scale.
            int n_not_ongrid = 0;
            for (int k = 0; k < 2; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
            if (n_not_ongrid > 0 && scale > 0) {
                float id = 1/scale;
                for (int k = 0; k < 2; ++k) {
                    if (is_on_grid[k]) continue;
                    uint16_t u = 0;
                    for (int i = 0; i < 8; ++i) {
                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
                        l = MAX(0, MIN(kMaxQ-1, l));
                        u |= (l << 2*i);
                        L[8*k + i] = l;
                    }
                    int grid_index = kmap_q2xs[u];
                    if (grid_index < 0) {
                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k);
                    }
                }
                float sumqx = 0, sumq2 = 0;
                for (int i = 0; i < 16; ++i) {
                    float w = weight[i];
                    float q = 2*L[i] + 1;
                    sumqx += w*xval[i]*q;
                    sumq2 += w*q*q;
                }
                if (sumq2 > 0) scale = sumqx/sumq2;
            }
            // A negative fitted scale is stored as |scale| with all signs flipped.
            if (scale < 0) {
                scale = -scale;
                for (int k = 0; k < 2; ++k) block_signs[k] = ~block_signs[k];
            }
            // Pack grid indices (low 8 bits in qs, 2 high bits per group in qh) and
            // the sign masks; every group must be on the grid by now.
            for (int k = 0; k < 2; ++k) {
                uint16_t u = 0;
                for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i);
                int grid_index = kmap_q2xs[u];
                if (grid_index < 0) {
                    printf("Oops: found point %u not on grid:", u);
                    for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]);
                    printf("\n");
                    GGML_ABORT("fatal error");
                }
                const int i8 = 2*ib + k;
                y[ibl].qs[i8] = grid_index & 255;
                y[ibl].qh[i8/4] |= ((grid_index >> 8) << 2*(i8%4));
                y[ibl].qs[QK_K/8 + i8] = block_signs[k];
            }
            GGML_ASSERT(scale >= 0);
            scales[ib] = scale;
            max_scale = MAX(max_scale, scale);
        }

        // Every sub-block was near zero: leave the zeroed super-block (d == 0).
        if (!max_scale) {
            continue;
        }

        // Encode each sub-block scale as a 4-bit l (scale ~ (2*l+1)*d), two per
        // byte. NOTE(review): 0.9875f is an unexplained correction factor on the
        // stored d; presumably tuned empirically.
        float d = max_scale/31;
        y[ibl].d = GGML_FP32_TO_FP16(d * 0.9875f);
        float id = 1/d;
        for (int ib = 0; ib < QK_K/16; ++ib) {
            int l = nearest_int(0.5f*(id*scales[ib]-1));
            l = MAX(0, MIN(15, l));
            if (ib%2 == 0) y[ibl].scales[ib/2] = l;
            else y[ibl].scales[ib/2] |= (l << 4);
        }
    }
}
4982
4983size_t quantize_iq2_s(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
4984    GGML_ASSERT(n_per_row%QK_K == 0);
4985    int64_t nblock = n_per_row/QK_K;
4986    char * qrow = (char *)dst;
4987    for (int64_t row = 0; row < nrow; ++row) {
4988        quantize_row_iq2_s_impl(src, qrow, n_per_row, quant_weights);
4989        src += n_per_row;
4990        qrow += nblock*sizeof(block_iq2_s);
4991    }
4992    return nrow * nblock * sizeof(block_iq2_s);
4993}
4994
4995void quantize_row_iq2_s_ref(const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k) {
4996    assert(k % QK_K == 0);
4997    quantize_iq2_s(x, y, 1, k, NULL);
4998}
4999
5000// =============================== data validation
5001
// Return true when `f` is finite. Otherwise print a diagnostic naming index `i` to stderr
// and return false.
//
// The parameter is a double so that F64 data is checked without narrowing. Converting a
// finite double larger than FLT_MAX to float is undefined behavior in ISO C, and inf under
// IEEE/Annex F, so it would falsely reject valid F64 values. float arguments (F32 data,
// Q8_K scales) promote to double exactly, so their results are unchanged.
static bool validate_float(double f, size_t i) {
    if (isinf(f)) {
        fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i);
        return false;
    }

    if (isnan(f)) {
        fprintf(stderr, "ggml_validate_row_data: found nan value at block %zu\n", i);
        return false;
    }

    return true;
}
5015
5016static bool isinf_fp16(ggml_fp16_t f) {
5017    return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) == 0;
5018}
5019
5020static bool isnan_fp16(ggml_fp16_t f) {
5021    return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) != 0;
5022}
5023
5024static bool validate_fp16(ggml_fp16_t f, size_t i) {
5025    if (isinf_fp16(f)) {
5026        fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i);
5027        return false;
5028    }
5029
5030    if (isnan_fp16(f)) {
5031        fprintf(stderr, "ggml_validate_row_data: found nan value at block %zu\n", i);
5032        return false;
5033    }
5034
5035    return true;
5036}
5037
// Return true unless the 8-bit block exponent `e` is 0xff, which is rejected
// (presumably the reserved NaN encoding of E8M0 per the OCP MX spec — confirm).
// On rejection, print a diagnostic naming block `i` to stderr.
static bool validate_e_e8m0(uint8_t e, size_t i) {
    const bool is_reserved = (e == 0xff);
    if (!is_reserved) {
        return true;
    }
    fprintf(stderr, "ggml_validate_row_data: found invalid e value %d at block %zu\n", e, i);
    return false;
}
5046
// Check the fp16 field `d` of each of the `nb` blocks of type `type` stored at `data`.
// Makes the *enclosing function* return false on the first inf/NaN.
// Wrapped in do { } while (0) so the expansion is a single statement (safe under an
// unbraced if/else) and the cursor `q` does not leak into the caller's scope.
#define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \
    do { \
        const type * q = (const type *) (data); \
        for (size_t i = 0; i < (nb); ++i) { \
            if (!validate_fp16(q[i].d, i)) { \
                return false; \
            } \
        } \
    } while (0)
5054
// Check two fp16 fields of each of the `nb` blocks of type `type` stored at `data`.
// `d` and `m` are the field names to check (callers pass e.g. `d, m` or `d, dmin`).
// Makes the *enclosing function* return false on the first inf/NaN in either field.
// Wrapped in do { } while (0) so the expansion is a single statement and `q` stays local.
#define VALIDATE_ROW_DATA_DM_F16_IMPL(type, data, nb, d, m) \
    do { \
        const type * q = (const type *) (data); \
        for (size_t i = 0; i < (nb); ++i) { \
            if (!validate_fp16(q[i].d, i) || !validate_fp16(q[i].m, i)) { \
                return false; \
            } \
        } \
    } while (0)
5062
// Check the 8-bit E8M0 exponent field `e` of each of the `nb` blocks of type `type`
// stored at `data`. Makes the *enclosing function* return false on the first invalid one.
// Wrapped in do { } while (0) so the expansion is a single statement and `q` stays local.
#define VALIDATE_ROW_DATA_E_E8M0_IMPL(type, data, nb) \
    do { \
        const type * q = (const type *) (data); \
        for (size_t i = 0; i < (nb); ++i) { \
            if (!validate_e_e8m0(q[i].e, i)) { \
                return false; \
            } \
        } \
    } while (0)
5070
// Check the first `nr` entries of the fp16 array field `d` of each of the `nb` blocks of
// type `type` stored at `data`. Makes the *enclosing function* return false on the first
// inf/NaN. Wrapped in do { } while (0) so the expansion is a single statement and `q` stays local.
#define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \
    do { \
        const type * q = (const type *) (data); \
        for (size_t i = 0; i < (nb); ++i) { \
            for (size_t j = 0; j < (nr); ++j) { \
                if (!validate_fp16(q[i].d[j], i)) { \
                    return false; \
                } \
            } \
        } \
    } while (0)
5080
// Check `nbytes` bytes at `data`, interpreted as ggml type `type`, for values that must
// not appear in tensor data:
//   - F16 / BF16 / F32 / F64: every element is checked for inf and NaN.
//   - block-quantized types: only the per-block scale field(s) are checked (fp16 `d`,
//     plus `m`/`dmin` where present; Q8_K's `d` via validate_float; IQ1_M's packed scale;
//     MXFP4's E8M0 exponent `e`). The quantized payload itself is not inspected.
//   - I8 / I16 / I32 / I64: nothing to check.
// Returns true if the data passes. Otherwise prints a diagnostic to stderr and returns
// false. Most types stop at the first bad value; BF16 counts all NaNs/infs first.
// Also rejects: an out-of-range `type`, any type without a case in the switch below, and
// an `nbytes` that is not a multiple of ggml_type_size(type).
bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes) {
    // reject out-of-range enum values before passing them to ggml_type_size()/ggml_type_name()
    if (type < 0 || type >= GGML_TYPE_COUNT) {
        fprintf(stderr, "%s: invalid type %d\n", __func__, type);
        return false;
    }

    // the buffer must hold a whole number of elements/blocks of this type
    if (nbytes % ggml_type_size(type) != 0) {
        fprintf(stderr, "%s: invalid size %zu for type %s (type size = %zu)\n", __func__, nbytes, ggml_type_name(type), ggml_type_size(type));
        return false;
    }

    // number of entries below: nb indexes the per-type element (float types) or block struct
    const size_t nb = nbytes/ggml_type_size(type);

    switch (type) {
        case GGML_TYPE_BF16:
            {
                // with the sign bit masked off: exponent all ones (0x7f80) plus a nonzero
                // mantissa is NaN, zero mantissa is inf. All values are counted before reporting.
                int nans = 0;
                int infs = 0;
                const unsigned short * f = (const unsigned short *) data;
                for (size_t i = 0; i < nb; ++i) {
                    nans += (f[i] & 0x7fff) > 0x7f80;
                    infs += (f[i] & 0x7fff) == 0x7f80;
                }
                if (nans) {
                    fprintf(stderr, "%s: found %d NaNs in row of %zu BF16 values\n", __func__, nans, nb);
                    return false;
                }
                if (infs) {
                    fprintf(stderr, "%s: found %d infinities in row of %zu BF16 values\n", __func__, infs, nb);
                    return false;
                }
            } break;
        case GGML_TYPE_F16:
            {
                // Fast path: test a vector of halves for an all-ones exponent (0x7c00).
                // A hit means the group contains an inf or NaN, so the scalar re-check inside
                // must return false; GGML_UNREACHABLE marks that. The trailing scalar loop
                // handles the tail, or everything when no SIMD path is compiled in.
                // NOTE(review): the reported index is the element index, though the message says "block".
                const ggml_fp16_t * f = (const ggml_fp16_t *) data;
                size_t i = 0;
#if defined(__AVX2__)
                for (; i + 15 < nb; i += 16) {
                    __m256i v = _mm256_loadu_si256((const __m256i *)(f + i));
                    __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi16(0x7c00));
                    __m256i cmp = _mm256_cmpeq_epi16(vexp, _mm256_set1_epi16(0x7c00));
                    int mask = _mm256_movemask_epi8(cmp);
                    if (mask) {
                        for (size_t j = 0; j < 16; ++j) {
                            if (!validate_fp16(f[i + j], i + j)) {
                                return false;
                            }
                        }
                        GGML_UNREACHABLE();
                    }
                }
#elif defined(__ARM_NEON)
                for (; i + 7 < nb; i += 8) {
                    uint16x8_t v = vld1q_u16(f + i);
                    uint16x8_t vexp = vandq_u16(v, vdupq_n_u16(0x7c00));
                    uint16x8_t cmp = vceqq_u16(vexp, vdupq_n_u16(0x7c00));
                    // narrow each 16-bit compare lane to 8 bits so all 8 lanes fit in one 64-bit value
                    uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(cmp, 4)), 0);
                    if (mask) {
                        for (size_t j = 0; j < 8; ++j) {
                            if (!validate_fp16(f[i + j], i + j)) {
                                return false;
                            }
                        }
                        GGML_UNREACHABLE();
                    }
                }
#endif
                for (; i < nb; ++i) {
                    if (!validate_fp16(f[i], i)) {
                        return false;
                    }
                }
            } break;
        case GGML_TYPE_F32:
            {
                // Same scheme as F16: floats are reinterpreted as 32-bit integers and tested for an
                // all-ones exponent (0x7f800000); any hit is re-checked with validate_float, which
                // must then fail.
                const float * f = (const float *) data;
                size_t i = 0;
#if defined(__AVX2__)
                for (; i + 7 < nb; i += 8) {
                    __m256i v = _mm256_loadu_si256((const __m256i *)(f + i));
                    __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi32(0x7f800000));
                    __m256i cmp = _mm256_cmpeq_epi32(vexp, _mm256_set1_epi32(0x7f800000));
                    int mask = _mm256_movemask_epi8(cmp);
                    if (mask) {
                        for (size_t j = 0; j < 8; ++j) {
                            if (!validate_float(f[i + j], i + j)) {
                                return false;
                            }
                        }
                        GGML_UNREACHABLE();
                    }
                }
#elif defined(__ARM_NEON)
                for (; i + 3 < nb; i += 4) {
                    uint32x4_t v = vld1q_u32((const uint32_t *)f + i);
                    uint32x4_t vexp = vandq_u32(v, vdupq_n_u32(0x7f800000));
                    uint32x4_t cmp = vceqq_u32(vexp, vdupq_n_u32(0x7f800000));
                    // narrow each 32-bit compare lane to 16 bits so all 4 lanes fit in one 64-bit value
                    uint64_t mask = vget_lane_u64(vreinterpret_u64_u16(vshrn_n_u32(cmp, 8)), 0);
                    if (mask) {
                        for (size_t j = 0; j < 4; ++j) {
                            if (!validate_float(f[i + j], i + j)) {
                                return false;
                            }
                        }
                        GGML_UNREACHABLE();
                    }
                }
#endif
                for (; i < nb; ++i) {
                    if (!validate_float(f[i], i)) {
                        return false;
                    }
                }
            } break;
        case GGML_TYPE_F64:
            {
                // NOTE(review): each double goes through validate_float. Confirm that its parameter
                // is a double; otherwise finite values beyond FLT_MAX are narrowed to inf and rejected.
                const double * f = (const double *) data;
                for (size_t i = 0; i < nb; ++i) {
                    if (!validate_float(f[i], i)) {
                        return false;
                    }
                }
            } break;
        case GGML_TYPE_Q4_0:
            {
                VALIDATE_ROW_DATA_D_F16_IMPL(block_q4_0, data, nb);
            } break;
        case GGML_TYPE_Q4_1:
            {
                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_1, data, nb, d, m);
            } break;
        case GGML_TYPE_Q5_0:
            {
                VALIDATE_ROW_DATA_D_F16_IMPL(block_q5_0, data, nb);
            } break;
        case GGML_TYPE_Q5_1:
            {
                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_1, data, nb, d, m);
            } break;
        case GGML_TYPE_Q8_0:
            {
                VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb);
            } break;
        case GGML_TYPE_MXFP4:
            {
                VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb);
            } break;
        case GGML_TYPE_Q2_K:
            {
                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin);
            } break;
        case GGML_TYPE_Q3_K:
            {
                VALIDATE_ROW_DATA_D_F16_IMPL(block_q3_K, data, nb);
            } break;
        case GGML_TYPE_Q4_K:
            {
                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d, dmin);
            } break;
        case GGML_TYPE_Q5_K:
            {
                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_K, data, nb, d, dmin);
            } break;
        case GGML_TYPE_Q6_K:
            {
                VALIDATE_ROW_DATA_D_F16_IMPL(block_q6_K, data, nb);
            } break;
        case GGML_TYPE_Q8_K:
            {
                // Q8_K's scale `d` is checked with validate_float rather than validate_fp16
                const block_q8_K * q = (const block_q8_K *) data;
                for (size_t i = 0; i < nb; ++i) {
                    if (!validate_float(q[i].d, i)) {
                        return false;
                    }
                }
            } break;
        case GGML_TYPE_TQ1_0:
            {
                VALIDATE_ROW_DATA_D_F16_IMPL(block_tq1_0, data, nb);
            } break;
        case GGML_TYPE_TQ2_0:
            {
                VALIDATE_ROW_DATA_D_F16_IMPL(block_tq2_0, data, nb);
            } break;
        case GGML_TYPE_IQ1_S:
            {
                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb);
            } break;
        case GGML_TYPE_IQ1_M:
            {
                // IQ1_M has no `d` field: its fp16 scale is split across the top 4 bits of the
                // first four uint16 words of `scales`. Reassemble those nibbles and validate them.
                const block_iq1_m * q = (const block_iq1_m *) data;
                for (size_t i = 0; i < nb; ++i) {
                    iq1m_scale_t scale;
                    const uint16_t * sc = (const uint16_t *)q[i].scales;
                    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
                    if (!validate_fp16(scale.f16, i)) {
                        return false;
                    }
                }
            } break;
        case GGML_TYPE_IQ2_XXS:
            {
                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xxs, data, nb);
            } break;
        case GGML_TYPE_IQ2_XS:
            {
                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xs, data, nb);
            } break;
        case GGML_TYPE_IQ2_S:
            {
                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_s, data, nb);
            } break;
        case GGML_TYPE_IQ3_XXS:
            {
                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_xxs, data, nb);
            } break;

        case GGML_TYPE_IQ3_S:
            {
                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_s, data, nb);
            } break;
        case GGML_TYPE_IQ4_XS:
            {
                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_xs, data, nb);
            } break;
        case GGML_TYPE_IQ4_NL:
            {
                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
            } break;

        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:
        case GGML_TYPE_I64:
            // nothing to validate
            break;
        default:
            {
                // any in-range type without a case above is rejected as unsupported
                fprintf(stderr, "%s: invalid type %d\n", __func__, type);
                return false;
            }
    }

    return true;
}