#define GGML_COMMON_IMPL_CPP
#define GGML_COMMON_DECL_CPP
#include "ggml-common.h"
#include "ggml-backend-impl.h"

#include "ggml-impl.h"
#include "ggml-cpu.h"
#include "ggml-cpu-impl.h"
#include "simd-mappings.h"
#include "traits.h"

#include "arch-fallback.h"

#include <cmath>
#include <cstring>
#include <cassert>
#include <cstdio>  // for GGML_ASSERT

#include "repack.h"

#if defined(__GNUC__)
#pragma GCC diagnostic ignored "-Woverlength-strings"
#endif

#define UNUSED GGML_UNUSED

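// Fast round-to-nearest: adding 1.5 * 2^23 (12582912.0f) pins the sum's exponent so the
// low mantissa bits hold the biased integer, which the mask and subtraction below recover.
// Only valid for |fval| <= 2^22 - 1, hence the assert; e.g. nearest_int(1.3f) == 1.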
static inline int nearest_int(float fval) {
    assert(fabsf(fval) <= 4194303.f);
    float val = fval + 12582912.f;
    int i; memcpy(&i, &val, sizeof(int));
    return (i & 0x007fffff) - 0x00400000;
}

// Functions to create the interleaved data layout formats

// interleave 4 block_q4_0s in blocks of blck_size_interleave
// returns an interleaved block_q4_0x4
// in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks
// first, then interleave quants from 4 block_q4_0s in blocks of blck_size_interleave
//
// - in                   : an array of block_q4_0 pointers
// - blck_size_interleave : the block_q4_0 quants bytes are interleaved in blocks of
//                          blck_size_interleave bytes
// - xor_mask             : the mask to convert the nibbles in block_q4_0 quants bytes
//                          from bias offset form to pure sign form (this saves subtract
//                          operations during unpacking)
//

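// A minimal sketch of the interleaving described above (illustrative only, kept out of
// the build with #if 0; the helper name and scalar loops are for exposition and assume
// blck_size_interleave divides QK4_0 / 2):
#if 0
static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) {
    block_q4_0x4 out;

    // deltas of the 4 source blocks come first
    for (int i = 0; i < 4; i++) {
        out.d[i] = in[i].d;
    }

    // then the quant bytes, interleaved across the 4 blocks in groups of
    // blck_size_interleave bytes, each byte XORed into pure sign form
    const int end = QK4_0 * 2 / blck_size_interleave;
    for (int i = 0; i < end; ++i) {
        int src_id     = i % 4;
        int src_offset = (i / 4) * blck_size_interleave;
        int dst_offset = i * blck_size_interleave;
        for (unsigned int j = 0; j < blck_size_interleave; ++j) {
            out.qs[dst_offset + j] = in[src_id].qs[src_offset + j] ^ xor_mask;
        }
    }
    return out;
}
#endif
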
extern "C" {

void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(QK8_0 == 32);
    assert(k % QK8_0 == 0);
    const int nb = k / QK8_0;

    block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;

    // scalar
    const int blck_size_interleave = 4;
    float srcv[4][QK8_0];
    float id[4];

    for (int i = 0; i < nb; i++) {
        for (int row_iter = 0; row_iter < 4; row_iter++) {
            float amax = 0.0f; // absolute max

            for (int j = 0; j < QK8_0; j++) {
                srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
                amax = MAX(amax, fabsf(srcv[row_iter][j]));
            }

            const float d = amax / ((1 << 7) - 1);
            id[row_iter] = d ? 1.0f / d : 0.0f;

            y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d);
        }

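        // Interleave the 4 quantized rows: output byte j takes element src_offset of row
        // src_id, so the rows alternate in groups of blck_size_interleave bytes.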
        for (int j = 0; j < QK8_0 * 4; j++) {
            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
            src_offset += (j % blck_size_interleave);

            float x0 = srcv[src_id][src_offset] * id[src_id];
            y[i].qs[j] = roundf(x0);
        }
    }
}

void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(QK8_0 == 32);
    assert(k % QK8_0 == 0);
    const int nb = k / QK8_0;

    block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;

    // scalar
    const int blck_size_interleave = 8;
    float srcv[4][QK8_0];
    float id[4];

    for (int i = 0; i < nb; i++) {
        for (int row_iter = 0; row_iter < 4; row_iter++) {
            float amax = 0.0f; // absolute max

            for (int j = 0; j < QK8_0; j++) {
                srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
                amax = MAX(amax, fabsf(srcv[row_iter][j]));
            }

            const float d = amax / ((1 << 7) - 1);
            id[row_iter] = d ? 1.0f / d : 0.0f;

            y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d);
        }

        for (int j = 0; j < QK8_0 * 4; j++) {
            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
            src_offset += (j % blck_size_interleave);

            float x0 = srcv[src_id][src_offset] * id[src_id];
            y[i].qs[j] = roundf(x0);
        }
    }
}


void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(QK_K == 256);
    assert(k % QK_K == 0);
    const int nb = k / QK_K;

    block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;

    // scalar
    const int blck_size_interleave = 4;
    float srcv[4][QK_K];
    float iscale[4];

    for (int i = 0; i < nb; i++) {
        for (int row_iter = 0; row_iter < 4; row_iter++) {
            float amax = 0.0f; // absolute max
            float max = 0;

            for (int j = 0; j < QK_K; j++) {
                srcv[row_iter][j] = x[row_iter * k + i * QK_K + j];
                // Update the maximum value of the corresponding super block
                if (amax < fabsf(srcv[row_iter][j])) {
                    amax = fabsf(srcv[row_iter][j]);
                    max = srcv[row_iter][j];
                }
            }

            iscale[row_iter] = amax ? -127.f/max : 0;

            y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
        }

        for (int j = 0; j < QK_K / 4; j++) {
            y[i].bsums[j] = 0;
        }

        // Quants values are interleaved in sequences of four bytes from the corresponding super blocks
        // Bsums values are interleaved in sequences of four bsums from each super block taken for interleaving
        // i.e. the first four bsums from the first super block, followed by the first four bsums from the second super block, and so on
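        // Here src_id = (j >> 2) & 3 is the source row and b = j >> 6 the bsum slot within
        // it (16 quants per bsum); the interleaved slot (b / 4) * 16 + src_id * 4 + (b % 4)
        // is what the bit arithmetic below computes.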
        for (int j = 0; j < QK_K * 4; j++) {
            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
            int src_id     = (j % (4 * blck_size_interleave)) / blck_size_interleave;
            src_offset += (j % blck_size_interleave);
            int index = (((j & 15) >> 2) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);

            float x0 = srcv[src_id][src_offset] * iscale[src_id];
            y[i].qs[j] = nearest_int(x0);
            y[i].bsums[index] += y[i].qs[j];
        }
    }
}

void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(QK_K == 256);
    assert(k % QK_K == 0);
    const int nb = k / QK_K;

    block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;

    // scalar
    const int blck_size_interleave = 8;
    float srcv[4][QK_K];
    float iscale[4];

    for (int i = 0; i < nb; i++) {
        for (int row_iter = 0; row_iter < 4; row_iter++) {
            float amax = 0.0f; // absolute max
            float max = 0;

            for (int j = 0; j < QK_K; j++) {
                srcv[row_iter][j] = x[row_iter * k + i * QK_K + j];
                // Update the maximum value of the corresponding super block
                if (amax < fabsf(srcv[row_iter][j])) {
                    amax = fabsf(srcv[row_iter][j]);
                    max = srcv[row_iter][j];
                }
            }

            iscale[row_iter] = amax ? -127.f/max : 0;

            y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
        }

        for (int j = 0; j < QK_K / 4; j++) {
            y[i].bsums[j] = 0;
        }

        // Quants values are interleaved in sequences of eight bytes from the corresponding super blocks
        // Bsums values are interleaved in sequences of four bsums from each super block taken for interleaving
        // i.e. the first four bsums from the first super block, followed by the first four bsums from the second super block, and so on
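        // As above, with src_id = (j >> 3) & 3 for the 8-byte interleave; b = j >> 6 and
        // the interleaved bsum slot is again (b / 4) * 16 + src_id * 4 + (b % 4).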
        for (int j = 0; j < QK_K * 4; j++) {
            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
            int src_id     = (j % (4 * blck_size_interleave)) / blck_size_interleave;
            src_offset += (j % blck_size_interleave);
            int index = (((j & 31) >> 3) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);

            float x0 = srcv[src_id][src_offset] * iscale[src_id];
            y[i].qs[j] = nearest_int(x0);
            y[i].bsums[index] += y[i].qs[j];
        }
    }
}

} // extern "C"

template <int64_t INTER_SIZE, ggml_type PARAM_TYPE>
void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row);

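// Dispatch helpers: INTER_SIZE selects the interleave width and PARAM_TYPE the
// destination block format; the non-_generic names resolve to arch-specific kernels
// where available (see arch-fallback.h).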
template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
    assert(nrow == 4);
    UNUSED(nrow);
    ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row);
}

template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
    assert(nrow == 4);
    UNUSED(nrow);
    ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);
}

template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
    assert(nrow == 4);
    UNUSED(nrow);
    ggml_quantize_mat_q8_K_4x4(x, vy, n_per_row);
}

template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
    assert(nrow == 4);
    UNUSED(nrow);
    ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row);
}

template <int M, int N>
static void ggml_gemv_q6_K_NxM_q8_K_generic_impl(int                        n,
                                                 float * GGML_RESTRICT      s,
                                                 size_t                     bs,
                                                 const void * GGML_RESTRICT vx,
                                                 const void * GGML_RESTRICT vy,
                                                 int                        nr,
                                                 int                        nc) {
    constexpr int blocklen          = M;
    constexpr int ncols_interleaved = N;
    const int     qk                = QK_K;
    const int     nb                = n / qk;
    const int     blocks_per_half   = 64 / blocklen;

    assert(n % qk == 0);
    assert(nc % ncols_interleaved == 0);

    UNUSED(bs);
    UNUSED(nr);

    float sumf[8];

    const block_q8_K * a_ptr = (const block_q8_K *) vy;
    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);

        for (int j = 0; j < ncols_interleaved; j++) {
            sumf[j] = 0.0f;
        }

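        // q6_K packs values v and v + 64 of each 256-value super block into one ql byte
        // (low/high nibble) and their two high bits into qh (four values per byte);
        // base_l/base_h walk those paired positions, and the final - 32 recenters the
        // 6-bit value into signed range.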
        for (int l = 0; l < nb; l++) {
            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
                const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen;
                const int base_h = base_l + 64;

                const int scale_idx_l = base_l / 16;
                const int scale_idx_h = base_h / 16;

                const int qh_shift_l = ((base_l % 128) / 32) * 2;
                const int qh_shift_h = ((base_h % 128) / 32) * 2;

                const int qh_half_l = (base_l / 128) * 32;
                const int qh_half_h = (base_h / 128) * 32;

                for (int j = 0; j < ncols_interleaved; j++) {
                    const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j];
                    const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j];

                    int sumi_l = 0;
                    int sumi_h = 0;

                    for (int i = 0; i < blocklen; i++) {
                        const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i;
                        const int l_4    = b_ptr[l].ql[ql_pos] & 0xF;
                        const int hi_4   = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;

                        const int qh_idx_l    = qh_half_l + ((base_l + i) % 32);
                        const int qh_chunk_l  = qh_idx_l / blocklen;
                        const int qh_pos_l    = qh_idx_l % blocklen;
                        const int qh_offset_l = qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l;
                        const int hi_2_l      = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;

                        const int qh_idx_h    = qh_half_h + ((base_h + i) % 32);
                        const int qh_chunk_h  = qh_idx_h / blocklen;
                        const int qh_pos_h    = qh_idx_h % blocklen;
                        const int qh_offset_h = qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h;
                        const int hi_2_h      = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;

                        const int q_l = ((hi_2_l << 4) | l_4) - 32;
                        const int q_h = ((hi_2_h << 4) | hi_4) - 32;

                        const int8_t a_l = a_ptr[l].qs[base_l + i];
                        const int8_t a_h = a_ptr[l].qs[base_h + i];

                        sumi_l += q_l * a_l;
                        sumi_h += q_h * a_h;
                    }

                    sumf[j] +=
                        (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
                }
            }
        }

        for (int j = 0; j < ncols_interleaved; j++) {
            s[x * ncols_interleaved + j] = sumf[j];
        }
    }
}

template <int M, int N>
static void ggml_gemm_q6_K_NxM_q8_K_generic_impl(int                        n,
                                                 float * GGML_RESTRICT      s,
                                                 size_t                     bs,
                                                 const void * GGML_RESTRICT vx,
                                                 const void * GGML_RESTRICT vy,
                                                 int                        nr,
                                                 int                        nc) {
    constexpr int blocklen          = M;
    constexpr int ncols_interleaved = N;
    const int     qk                = QK_K;
    const int     nb                = n / qk;
    const int     blocks_per_half   = 64 / blocklen;
    const int     q8_half_stride    = 512;
    const int     q8_low_high_step  = 256;
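    // block_q8_Kx4 interleaves the 4 activation rows, so each 128-value half of a
    // super block spans 4 * 128 = 512 bytes (q8_half_stride) and the matching +64
    // values sit 64 * 4 = 256 bytes further on (q8_low_high_step).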

    assert(n % qk == 0);
    assert(nr % 4 == 0);
    assert(nc % ncols_interleaved == 0);

    UNUSED(bs);

    float sumf[4][8];

    for (int y = 0; y < nr / 4; y++) {
        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
        for (int x = 0; x < nc / ncols_interleaved; x++) {
            const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);

            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumf[m][j] = 0.0f;
                }
            }

            for (int l = 0; l < nb; l++) {
                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
                    const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen;
                    const int base_h = base_l + 64;

                    const int scale_idx_l = base_l / 16;
                    const int scale_idx_h = base_h / 16;

                    const int qh_shift_l = ((base_l % 128) / 32) * 2;
                    const int qh_shift_h = ((base_h % 128) / 32) * 2;

                    const int qh_half_l = (base_l / 128) * 32;
                    const int qh_half_h = (base_h / 128) * 32;

                    const int q8_base = (k / blocks_per_half) * q8_half_stride + (k % blocks_per_half) * (blocklen * 4);

                    for (int m = 0; m < 4; m++) {
                        for (int j = 0; j < ncols_interleaved; j++) {
                            const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j];
                            const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j];

                            int sumi_l = 0;
                            int sumi_h = 0;

                            for (int i = 0; i < blocklen; i++) {
                                const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i;
                                const int l_4    = b_ptr[l].ql[ql_pos] & 0xF;
                                const int hi_4   = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;

                                const int qh_idx_l   = qh_half_l + ((base_l + i) % 32);
                                const int qh_chunk_l = qh_idx_l / blocklen;
                                const int qh_pos_l   = qh_idx_l % blocklen;
                                const int qh_offset_l =
                                    qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l;
                                const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;

                                const int qh_idx_h   = qh_half_h + ((base_h + i) % 32);
                                const int qh_chunk_h = qh_idx_h / blocklen;
                                const int qh_pos_h   = qh_idx_h % blocklen;
                                const int qh_offset_h =
                                    qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h;
                                const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;

                                const int q_l = ((hi_2_l << 4) | l_4) - 32;
                                const int q_h = ((hi_2_h << 4) | hi_4) - 32;

                                const int8_t q8_l = a_ptr[l].qs[q8_base + m * blocklen + i];
                                const int8_t q8_h = a_ptr[l].qs[q8_base + m * blocklen + i + q8_low_high_step];

                                sumi_l += q_l * q8_l;
                                sumi_h += q_h * q8_h;
                            }

                            sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) *
                                          a_ptr[l].d[m];
                        }
                    }
                }
            }

            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
                }
            }
        }
    }
}

extern "C" {

void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 4;
    const int blocklen = 4;

    assert(nr == 1);
    assert(n % qk == 0);
    assert(nc % ncols_interleaved == 0);

    UNUSED(s);
    UNUSED(bs);
    UNUSED(vx);
    UNUSED(vy);
    UNUSED(nr);
    UNUSED(nc);
    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

    float sumf[4];
    int sumi;

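    // The repacked nibbles are already in two's-complement (sign) form thanks to the
    // xor_mask applied at repack time, so (int8_t) (q << 4) and (q & 0xF0) yield each
    // 4-bit value scaled by 16; the >> 4 after the products removes that scaling.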
    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);

        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
        for (int l = 0; l < nb; l++) {
            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumi = 0;
                    for (int i = 0; i < blocklen; ++i) {
                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
                    }
                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
                }
            }
        }
        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
    }
}

void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 4;
    const int blocklen = 8;

    assert (n % qk == 0);
    assert (nc % ncols_interleaved == 0);

    UNUSED(s);
    UNUSED(bs);
    UNUSED(vx);
    UNUSED(vy);
    UNUSED(nr);
    UNUSED(nc);
    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

    float sumf[4];
    int sumi;

    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);

        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
        for (int l = 0; l < nb; l++) {
            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumi = 0;
                    for (int i = 0; i < blocklen; ++i) {
                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
                    }
                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
                }
            }
        }
        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
    }
}

void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 8;
    const int blocklen = 8;

    assert (n % qk == 0);
    assert (nc % ncols_interleaved == 0);

    UNUSED(s);
    UNUSED(bs);
    UNUSED(vx);
    UNUSED(vy);
    UNUSED(nr);
    UNUSED(nc);
    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

    float sumf[8];
    int sumi;

    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);

        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
        for (int l = 0; l < nb; l++) {
            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumi = 0;
                    for (int i = 0; i < blocklen; ++i) {
                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
                    }
                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
                }
            }
        }
        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
    }
}

void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK_K;
    const int nb = n / qk;
    const int ncols_interleaved = 8;
    const int blocklen = 4;
    static const uint32_t kmask1 = 0x3f3f3f3f;
    static const uint32_t kmask2 = 0x0f0f0f0f;
    static const uint32_t kmask3 = 0x03030303;

    assert (n % qk == 0);
    assert (nc % ncols_interleaved == 0);

    UNUSED(bs);
    UNUSED(nr);

    float sumf[8];
    float sum_minf[8];
    uint32_t utmp[32];
    int sumi1;
    int sumi2;
    int sumi;

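    // Each 12-byte group of packed 6-bit scales/mins is expanded by the kmask shuffle
    // below into a 16-byte utmp group: 8 scale bytes followed by 8 min bytes.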
    const block_q8_K * a_ptr = (const block_q8_K *) vy;
    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);

        for (int j = 0; j < ncols_interleaved; j++) {
            sumf[j] = 0.0;
            sum_minf[j] = 0.0;
        }
        for (int l = 0; l < nb; l++) {
            for (int sb = 0; sb < 8; sb++) {
                memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
                utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
                const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
                utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
                utmp[sb * 4 + 2] = uaux_0;
                utmp[sb * 4 + 0] &= kmask1;
            }
            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
                uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32;
                uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16;
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumi1 = 0;
                    sumi2 = 0;
                    sumi = 0;
                    for (int i = 0; i < blocklen; ++i) {
                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
                        sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i]);
                        sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i + 32]);
                        sumi1 = sumi1 * scales_0[j];
                        sumi2 = sumi2 * scales_1[j];
                        sumi += sumi1 + sumi2;
                    }
                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
                }
            }
            for (int sb = 0; sb < 8; sb++) {
                uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
                for (int j = 0; j < ncols_interleaved; j++) {
                    sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
                }
            }
        }
        for (int j = 0; j < ncols_interleaved; j++) {
            s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
        }
    }
}

void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK_K;
    const int nb = n / qk;
    const int ncols_interleaved = 8;
    const int blocklen = 8;
    static const uint32_t kmask1 = 0x3f3f3f3f;
    static const uint32_t kmask2 = 0x0f0f0f0f;
    static const uint32_t kmask3 = 0x03030303;

    assert (n % qk == 0);
    assert (nc % ncols_interleaved == 0);

    UNUSED(bs);
    UNUSED(nr);

    float sumf[8];
    float sum_minf[8];
    uint32_t utmp[32];
    int sumi1;
    int sumi2;
    int sumi;

    const block_q8_K * a_ptr = (const block_q8_K *) vy;
    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);

        for (int j = 0; j < ncols_interleaved; j++) {
            sumf[j] = 0.0;
            sum_minf[j] = 0.0;
        }
        for (int l = 0; l < nb; l++) {
            for (int sb = 0; sb < 8; sb++) {
                memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
                utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
                const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
                utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
                utmp[sb * 4 + 2] = uaux_0;
                utmp[sb * 4 + 0] &= kmask1;
            }
            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
                uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32;
                uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16;
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumi1 = 0;
                    sumi2 = 0;
                    sumi = 0;
                    for (int i = 0; i < blocklen; ++i) {
                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
                        sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i]);
                        sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i + 32]);
                        sumi1 = sumi1 * scales_0[j];
                        sumi2 = sumi2 * scales_1[j];
                        sumi += sumi1 + sumi2;
                    }
                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
                }
            }
            for (int sb = 0; sb < 8; sb++) {
                uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
                for (int j = 0; j < ncols_interleaved; j++) {
                    sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
                }
            }
        }
        for (int j = 0; j < ncols_interleaved; j++) {
            s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
        }
    }
}

void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK_K;
    const int nb = n / qk;
    const int ncols_interleaved = 8;
    const int blocklen = 8;

    assert (n % qk == 0);
    assert (nc % ncols_interleaved == 0);

    UNUSED(s);
    UNUSED(bs);
    UNUSED(vx);
    UNUSED(vy);
    UNUSED(nr);
    UNUSED(nc);
    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

    float sumf[8];
    float sum_minf[8];
    int sumi1, sumi2, sumi3, sumi4;
    int sumi;

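    // q2_K stores one byte per 16-value group: low nibble = 4-bit scale, high nibble =
    // 4-bit min; the 2-bit quants are unpacked four at a time from each qs byte below.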
    const block_q8_K * a_ptr = (const block_q8_K *) vy;
    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb);
        for (int j = 0; j < ncols_interleaved; j++) {
            sumf[j] = 0.0;
            sum_minf[j] = 0.0;
        }
        for (int l = 0; l < nb; l++) {
            for (int k = 0; k < (qk / (4 * blocklen)); k++) {
                const uint8_t * scales_0 = b_ptr[l].scales + (k / 4) * 64;
                const uint8_t * scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16;
                const uint8_t * scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32;
                const uint8_t * scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48;
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumi1 = 0;
                    sumi2 = 0;
                    sumi3 = 0;
                    sumi4 = 0;
                    sumi = 0;
                    int offset = ((k / 2) % 2) + j * 2;
                    for (int i = 0; i < blocklen; ++i) {
                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 3);
                        const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2) & 3);
                        const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4) & 3);
                        const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6) & 3);
                        sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i]);
                        sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 32]);
                        sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 64]);
                        sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 96]);

                        sumi1 = sumi1 * (scales_0[offset] & 0xF);
                        sumi2 = sumi2 * (scales_1[offset] & 0xF);
                        sumi3 = sumi3 * (scales_2[offset] & 0xF);
                        sumi4 = sumi4 * (scales_3[offset] & 0xF);
                        sumi += sumi1 + sumi2 + sumi3 + sumi4;
                    }
                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
                }
            }
            for (int sb = 0; sb < 8; sb++) {
                const uint8_t * mins = b_ptr[l].scales + sb * 16;
                for (int j = 0; j < ncols_interleaved; j++) {
                    sum_minf[j] += ((mins[j * 2] >> 4) * a_ptr[l].bsums[sb * 2] + (mins[j * 2 + 1] >> 4) * a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
                }
            }
        }
        for (int j = 0; j < ncols_interleaved; j++) {
            s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
        }
    }
}

void ggml_gemv_q5_K_8x8_q8_K_generic(int                        n,
                                     float * GGML_RESTRICT      s,
                                     size_t                     bs,
                                     const void * GGML_RESTRICT vx,
                                     const void * GGML_RESTRICT vy,
                                     int                        nr,
                                     int                        nc) {
    const int             qk                = QK_K;
    const int             nb                = n / qk;
    const int             ncols_interleaved = 8;
    const int             blocklen          = 8;
    static const uint32_t kmask1            = 0x3f3f3f3f;
    static const uint32_t kmask2            = 0x0f0f0f0f;
    static const uint32_t kmask3            = 0x03030303;

    assert(n % qk == 0);
    assert(nc % ncols_interleaved == 0);

    UNUSED(bs);
    UNUSED(nr);

    float    sumf[8];
    float    sum_minf[8];
    uint32_t utmp[32];
    int      sumi1;
    int      sumi2;
    int      sumi;

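    // q5_K extends the q4_K nibbles with one high bit per value taken from qh (h0/h1
    // below); scales and mins unpack exactly as in the q4_K kernels.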
    const block_q8_K * a_ptr = (const block_q8_K *) vy;
    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);

        for (int j = 0; j < ncols_interleaved; j++) {
            sumf[j]     = 0.0;
            sum_minf[j] = 0.0;
        }
        for (int l = 0; l < nb; l++) {
            for (int sb = 0; sb < 8; sb++) {
                memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
                utmp[sb * 4 + 3]      = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
                const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
                utmp[sb * 4 + 1]      = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
                utmp[sb * 4 + 2]      = uaux_0;
                utmp[sb * 4 + 0] &= kmask1;
            }
            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
                uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32;
                uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16;

                const int qh_shift = (k / 4) * 2;
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumi1 = 0;
                    sumi2 = 0;
                    sumi  = 0;
                    for (int i = 0; i < blocklen; ++i) {
                        const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;

                        const int qh_idx      = (k * 8 + i) % 32;
                        const int qh_chunk    = qh_idx / 8;
                        const int qh_pos      = qh_idx % 8;
                        const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos;

                        const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
                        const uint8_t h0     = (qh_val >> qh_shift) & 1;
                        const uint8_t h1     = (qh_val >> (qh_shift + 1)) & 1;

                        const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
                        const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));

                        const int q8_offset = (k >> 2) * 64 + (k % 4) * blocklen + i;

                        sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
                        sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]);
                        sumi1 = sumi1 * scales_0[j];
                        sumi2 = sumi2 * scales_1[j];
                        sumi += sumi1 + sumi2;
                    }
                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
                }
            }
            for (int sb = 0; sb < 8; sb++) {
                uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
                for (int j = 0; j < ncols_interleaved; j++) {
                    sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) *
                                   GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
                }
            }
        }
        for (int j = 0; j < ncols_interleaved; j++) {
            s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
        }
    }
}


void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemv_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
}

void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemv_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
}

void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 4;
    const int blocklen = 4;

    assert(nr == 1);
    assert(n % qk == 0);
    assert(nc % ncols_interleaved == 0);

    UNUSED(bs);
    UNUSED(nr);

    float sumf[4];
    int sumi;

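    // iq4_nl maps each 4-bit index through the kvalues_iq4nl non-linear codebook, so
    // unlike the q4_0 kernels no sign trick or >> 4 rescale is needed.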
    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);

        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
        for (int l = 0; l < nb; l++) {
            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumi = 0;
                    for (int i = 0; i < blocklen; ++i) {
                        const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
                        const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
                    }
                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
                }
            }
        }
        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
    }
}

void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 8;
    const int blocklen = 8;

    assert(nr == 1);
    assert(n % qk == 0);
    assert(nc % ncols_interleaved == 0);

    UNUSED(bs);
    UNUSED(nr);

    float sumf[8];
    int sumi;

    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb);

        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
        for (int l = 0; l < nb; l++) {
            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumi = 0;
                    for (int i = 0; i < blocklen; ++i) {
                        const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
                        const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
                    }
                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
                }
            }
        }
        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
    }
}

void ggml_gemv_q8_0_4x4_q8_0_generic(int                        n,
                                     float * GGML_RESTRICT      s,
                                     size_t                     bs,
                                     const void * GGML_RESTRICT vx,
                                     const void * GGML_RESTRICT vy,
                                     int                        nr,
                                     int                        nc) {
    const int qk                = QK8_0;
    const int nb                = n / qk;
    const int ncols_interleaved = 4;
    const int blocklen          = 4;

    assert(nr == 1);
    assert(n % qk == 0);
    assert(nc % ncols_interleaved == 0);

    UNUSED(bs);
    UNUSED(nr);

    float sumf[4];
    int   sumi;

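    // Plain int8 dot product: interleaved block_q8_0x4 weights against one q8_0
    // activation row, with a full qk / blocklen loop since there are no nibbles to split.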
    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);

        for (int j = 0; j < ncols_interleaved; j++) {
            sumf[j] = 0.0;
        }
        for (int l = 0; l < nb; l++) {
            for (int k = 0; k < (qk / blocklen); k++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumi = 0;
                    for (int i = 0; i < blocklen; ++i) {
                        const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
                        sumi += v0 * a_ptr[l].qs[k * blocklen + i];
                    }
                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
                }
            }
        }
        for (int j = 0; j < ncols_interleaved; j++) {
            s[x * ncols_interleaved + j] = sumf[j];
        }
    }
}

void ggml_gemv_q8_0_4x8_q8_0_generic(int                        n,
                                     float * GGML_RESTRICT      s,
                                     size_t                     bs,
                                     const void * GGML_RESTRICT vx,
                                     const void * GGML_RESTRICT vy,
                                     int                        nr,
                                     int                        nc) {
    const int qk                = QK8_0;
    const int nb                = n / qk;
    const int ncols_interleaved = 4;
    const int blocklen          = 8;

    assert(nr == 1);
    assert(n % qk == 0);
    assert(nc % ncols_interleaved == 0);

    UNUSED(bs);
    UNUSED(nr);

    float sumf[4];
    int   sumi;

    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);

        for (int j = 0; j < ncols_interleaved; j++) {
            sumf[j] = 0.0;
        }
        for (int l = 0; l < nb; l++) {
            for (int k = 0; k < (qk / blocklen); k++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumi = 0;
                    for (int i = 0; i < blocklen; ++i) {
                        const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
                        sumi += v0 * a_ptr[l].qs[k * blocklen + i];
                    }
                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
                }
            }
        }
        for (int j = 0; j < ncols_interleaved; j++) {
            s[x * ncols_interleaved + j] = sumf[j];
        }
    }
}

void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 4;
    const int blocklen = 4;

    assert (n % qk == 0);
    assert (nr % 4 == 0);
    assert (nc % ncols_interleaved == 0);

    UNUSED(s);
    UNUSED(bs);
    UNUSED(vx);
    UNUSED(vy);
    UNUSED(nr);
    UNUSED(nc);
    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

    {
        float sumf[4][4];
        int sumi;

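        // The activations come from ggml_quantize_mat_q8_0_4x4_generic above: the 4 rows
        // are interleaved in groups of blocklen bytes, hence the k * 4 * blocklen +
        // m * blocklen indexing, with the high-nibble operand qk / 2 * 4 bytes further on.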
1103        for (int y = 0; y < nr / 4; y++) {
1104            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
1105            for (int x = 0; x < nc / ncols_interleaved; x++) {
1106                const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
1107                for (int m = 0; m < 4; m++) {
1108                    for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
1109                }
1110                for (int l = 0; l < nb; l++) {
1111                    for (int k = 0; k < (qk / (2 * blocklen)); k++) {
1112                        for (int m = 0; m < 4; m++) {
1113                            for (int j = 0; j < ncols_interleaved; j++) {
1114                                sumi = 0;
1115                                for (int i = 0; i < blocklen; ++i) {
1116                                    const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
1117                                    const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
1118                                    sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
1119                                            (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
1120                                }
1121                                sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
1122                            }
1123                        }
1124                    }
1125                }
1126                for (int m = 0; m < 4; m++) {
1127                    for (int j = 0; j < ncols_interleaved; j++)
1128                        s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
1129                }
1130            }
1131        }
1132    }
1133}
1134
1135void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
1136    const int qk = QK8_0;
1137    const int nb = n / qk;
1138    const int ncols_interleaved = 4;
1139    const int blocklen = 8;
1140
1141    assert (n % qk == 0);
1142    assert (nr % 4 == 0);
1143    assert (nc % ncols_interleaved == 0);
1144
1145    UNUSED(s);
1146    UNUSED(bs);
1147    UNUSED(vx);
1148    UNUSED(vy);
1149    UNUSED(nr);
1150    UNUSED(nc);
1151    UNUSED(nb);
1152    UNUSED(ncols_interleaved);
1153    UNUSED(blocklen);
1154
1155    float sumf[4][4];
1156    int sumi;
1157
1158    for (int y = 0; y < nr / 4; y++) {
1159        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
1160        for (int x = 0; x < nc / ncols_interleaved; x++) {
1161            const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
1162            for (int m = 0; m < 4; m++) {
1163                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
1164            }
1165            for (int l = 0; l < nb; l++) {
1166                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
1167                    for (int m = 0; m < 4; m++) {
1168                        for (int j = 0; j < ncols_interleaved; j++) {
1169                            sumi = 0;
1170                            for (int i = 0; i < blocklen; ++i) {
1171                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
1172                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
1173                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
1174                                        (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
1175                            }
1176                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
1177                        }
1178                    }
1179                }
1180            }
1181            for (int m = 0; m < 4; m++) {
1182                for (int j = 0; j < ncols_interleaved; j++)
1183                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
1184            }
1185        }
1186    }
1187}
1188
1189void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
1190    const int qk = QK8_0;
1191    const int nb = n / qk;
1192    const int ncols_interleaved = 8;
1193    const int blocklen = 8;
1194
1195    assert (n % qk == 0);
1196    assert (nr % 4 == 0);
1197    assert (nc % ncols_interleaved == 0);
1198
1199    UNUSED(s);
1200    UNUSED(bs);
1201    UNUSED(vx);
1202    UNUSED(vy);
1203    UNUSED(nr);
1204    UNUSED(nc);
1205    UNUSED(nb);
1206    UNUSED(ncols_interleaved);
1207    UNUSED(blocklen);
1208
1209    float sumf[4][8];
1210    int sumi;
1211
1212    for (int y = 0; y < nr / 4; y++) {
1213        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
1214        for (int x = 0; x < nc / ncols_interleaved; x++) {
1215            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
1216            for (int m = 0; m < 4; m++) {
1217                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
1218            }
1219            for (int l = 0; l < nb; l++) {
1220                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
1221                    for (int m = 0; m < 4; m++) {
1222                        for (int j = 0; j < ncols_interleaved; j++) {
1223                            sumi = 0;
1224                            for (int i = 0; i < blocklen; ++i) {
1225                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
1226                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
1227                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
1228                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
1229                            }
1230                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
1231                        }
1232                    }
1233                }
1234            }
1235            for (int m = 0; m < 4; m++) {
1236                for (int j = 0; j < ncols_interleaved; j++)
1237                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
1238            }
1239        }
1240    }
1241}

void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK_K;
    const int nb = n / qk;
    const int ncols_interleaved = 8;
    const int blocklen = 4;
    static const uint32_t kmask1 = 0x3f3f3f3f;
    static const uint32_t kmask2 = 0x0f0f0f0f;
    static const uint32_t kmask3 = 0x03030303;

    assert (n % qk == 0);
    assert (nr % 4 == 0);
    assert (nc % ncols_interleaved == 0);

    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

    float sumf[4][8];
    float sum_minf[4][8];
    uint32_t utmp[32];
    int sumi1;
    int sumi2;
    int sumi;

    for (int y = 0; y < nr / 4; y++) {
        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
        for (int x = 0; x < nc / ncols_interleaved; x++) {
            const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumf[m][j] = 0.0;
                    sum_minf[m][j] = 0.0;
                }
            }
            for (int l = 0; l < nb; l++) {
                for (int sb = 0; sb < 8; sb++) {
                    memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
                    utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
                    const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
                    utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
                    utmp[sb * 4 + 2] = uaux_0;
                    utmp[sb * 4 + 0] &= kmask1;
                }
                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
                    uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32;
                    uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16;
                    for (int m = 0; m < 4; m++) {
                        for (int j = 0; j < ncols_interleaved; j++) {
                            sumi1 = 0;
                            sumi2 = 0;
                            sumi = 0;
                            for (int i = 0; i < blocklen; ++i) {
                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
                                sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i]);
                                sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i + 128]);
                                sumi1 = sumi1 * scales_0[j];
                                sumi2 = sumi2 * scales_1[j];
                                sumi += sumi1 + sumi2;
                            }
                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
                        }
                    }
                }
                for (int sb = 0; sb < 8; sb++) {
                    uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
                    for (int m = 0; m < 4; m++) {
                        const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
                        for (int j = 0; j < ncols_interleaved; j++) {
                            sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
                        }
                    }
                }
            }
            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
                }
            }
        }
    }
}
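
// Illustrative sketch (not used by the kernels): the scale/min layout consumed by
// the q4_K kernels above. Each 12-byte group in the interleaved scales array packs
// 8 six-bit scales and 8 six-bit mins (one per column) in the same bit layout as
// block_q4_K. The kmask1/kmask2/kmask3 shuffle expands four values per 32-bit word;
// this is the equivalent one-value-at-a-time extraction for position j in 0..7.
static inline void example_q4_K_scale_min(const uint8_t * sc12, int j, uint8_t * scale, uint8_t * min) {
    if (j < 4) {
        *scale = sc12[j] & 63;
        *min   = sc12[j + 4] & 63;
    } else {
        *scale = (sc12[j + 4] & 0x0F) | ((sc12[j - 4] >> 6) << 4);
        *min   = (sc12[j + 4] >>   4) | ((sc12[j]     >> 6) << 4);
    }
}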

void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK_K;
    const int nb = n / qk;
    const int ncols_interleaved = 8;
    const int blocklen = 8;
    static const uint32_t kmask1 = 0x3f3f3f3f;
    static const uint32_t kmask2 = 0x0f0f0f0f;
    static const uint32_t kmask3 = 0x03030303;

    assert (n % qk == 0);
    assert (nr % 4 == 0);
    assert (nc % ncols_interleaved == 0);

    UNUSED(bs);

    float sumf[4][8];
    float sum_minf[4][8];
    uint32_t utmp[32];
    int sumi1;
    int sumi2;
    int sumi;

    for (int y = 0; y < nr / 4; y++) {
        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
        for (int x = 0; x < nc / ncols_interleaved; x++) {
            const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumf[m][j] = 0.0;
                    sum_minf[m][j] = 0.0;
                }
            }
            for (int l = 0; l < nb; l++) {
                for (int sb = 0; sb < 8; sb++) {
                    memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
                    utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
                    const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
                    utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
                    utmp[sb * 4 + 2] = uaux_0;
                    utmp[sb * 4 + 0] &= kmask1;
                }
                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
                    uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32;
                    uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16;
                    for (int m = 0; m < 4; m++) {
                        for (int j = 0; j < ncols_interleaved; j++) {
                            sumi1 = 0;
                            sumi2 = 0;
                            sumi = 0;
                            for (int i = 0; i < blocklen; ++i) {
                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
                                sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i]);
                                sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]);
                                sumi1 = sumi1 * scales_0[j];
                                sumi2 = sumi2 * scales_1[j];
                                sumi += sumi1 + sumi2;
                            }
                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
                        }
                    }
                }
                for (int sb = 0; sb < 8; sb++) {
                    uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
                    for (int m = 0; m < 4; m++) {
                        const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
                        for (int j = 0; j < ncols_interleaved; j++) {
                            sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
                        }
                    }
                }
            }
            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
                }
            }
        }
    }
}

void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK_K;
    const int nb = n / qk;
    const int ncols_interleaved = 8;
    const int blocklen = 8;

    assert (n % qk == 0);
    assert (nr % 4 == 0);
    assert (nc % ncols_interleaved == 0);

    UNUSED(s);
    UNUSED(bs);
    UNUSED(vx);
    UNUSED(vy);
    UNUSED(nr);
    UNUSED(nc);
    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

    float sumf[4][8];
    float sum_minf[4][8];
    int sumi1, sumi2, sumi3, sumi4;
    int sumi;

    for (int y = 0; y < nr / 4; y++) {
        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
        for (int x = 0; x < nc / ncols_interleaved; x++) {
            const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb);
            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumf[m][j] = 0.0;
                    sum_minf[m][j] = 0.0;
                }
            }
            for (int l = 0; l < nb; l++) {
                for (int k = 0; k < (qk / (4 * blocklen)); k++) {
                    const uint8_t * scales_0 = b_ptr[l].scales + (k / 4) * 64;
                    const uint8_t * scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16;
                    const uint8_t * scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32;
                    const uint8_t * scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48;
                    for (int m = 0; m < 4; m++) {
                        for (int j = 0; j < ncols_interleaved; j++) {
                            sumi1 = 0;
                            sumi2 = 0;
                            sumi3 = 0;
                            sumi4 = 0;
                            sumi = 0;
                            int offset = ((k / 2) % 2) + j * 2;
                            for (int i = 0; i < blocklen; ++i) {
                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 3);
                                const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2) & 3);
                                const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4) & 3);
                                const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6) & 3);
                                sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i]);
                                sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]);
                                sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 256]);
                                sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 384]);
                                sumi1 = sumi1 * (scales_0[offset] & 0xF);
                                sumi2 = sumi2 * (scales_1[offset] & 0xF);
                                sumi3 = sumi3 * (scales_2[offset] & 0xF);
                                sumi4 = sumi4 * (scales_3[offset] & 0xF);
                                sumi += sumi1 + sumi2 + sumi3 + sumi4;
                            }
                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
                        }
                    }
                }
                for (int sb = 0; sb < 8; sb++) {
                    const uint8_t * mins = b_ptr[l].scales + sb * 16;
                    for (int m = 0; m < 4; m++) {
                        const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
                        for (int j = 0; j < ncols_interleaved; j++) {
                            int mins_prod = ((mins[j * 2] >> 4) * bsums[0] + (mins[(j * 2) + 1] >> 4) * bsums[1]);
                            sum_minf[m][j] += (mins_prod) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
                        }
                    }
                }
            }

            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
                }
            }
        }
    }
}
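
// Illustrative sketch (not used by the kernels): the q2_K bit layout used above. A
// quant byte packs four 2-bit values (one per 64-element chunk of the super-block),
// and each interleaved scale byte packs a 4-bit scale in its low nibble and a 4-bit
// min in its high nibble - hence the `& 0xF` in the dot product and the `>> 4` in
// the bsums-based min correction.
static inline void example_unpack_q2_K_byte(uint8_t b, int v[4]) {
    for (int t = 0; t < 4; t++) {
        v[t] = (b >> (2 * t)) & 3; // unsigned 2-bit quant; the min is subtracted separately
    }
}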

void ggml_gemm_q5_K_8x8_q8_K_generic(int                        n,
                                     float * GGML_RESTRICT      s,
                                     size_t                     bs,
                                     const void * GGML_RESTRICT vx,
                                     const void * GGML_RESTRICT vy,
                                     int                        nr,
                                     int                        nc) {
    const int qk                = QK_K;
    const int nb                = n / qk;
    const int ncols_interleaved = 8;
    const int blocklen          = 8;

    constexpr uint32_t kmask1 = 0x3f3f3f3f;
    constexpr uint32_t kmask2 = 0x0f0f0f0f;
    constexpr uint32_t kmask3 = 0x03030303;

    assert(n % qk == 0);
    assert(nr % 4 == 0);
    assert(nc % ncols_interleaved == 0);

    float    sumf[4][8];
    float    sum_minf[4][8];
    uint32_t utmp[32];
    int      sumi1;
    int      sumi2;
    int      sumi;

    for (int y = 0; y < nr / 4; y++) {
        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
        for (int x = 0; x < nc / ncols_interleaved; x++) {
            const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumf[m][j]     = 0.0;
                    sum_minf[m][j] = 0.0;
                }
            }
            for (int l = 0; l < nb; l++) {
                for (int sb = 0; sb < 8; sb++) {
                    memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
                    utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
                    const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
                    utmp[sb * 4 + 1]      = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
                    utmp[sb * 4 + 2]      = uaux_0;
                    utmp[sb * 4 + 0] &= kmask1;
                }
                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
                    uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32;
                    uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16;

                    const int qh_shift = (k / 4) * 2;
                    for (int m = 0; m < 4; m++) {
                        for (int j = 0; j < ncols_interleaved; j++) {
                            sumi1 = 0;
                            sumi2 = 0;
                            sumi  = 0;
                            for (int i = 0; i < blocklen; ++i) {
                                const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;

                                const int qh_idx      = (k * 8 + i) % 32;
                                const int qh_chunk    = qh_idx / 8;
                                const int qh_pos      = qh_idx % 8;
                                const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos;

                                const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
                                const uint8_t h0     = (qh_val >> qh_shift) & 1;
                                const uint8_t h1     = (qh_val >> (qh_shift + 1)) & 1;

                                const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
                                const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));

                                const int q8_offset = (k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i;

                                sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
                                sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]);
                                sumi1 = sumi1 * scales_0[j];
                                sumi2 = sumi2 * scales_1[j];
                                sumi += sumi1 + sumi2;
                            }
                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
                        }
                    }
                }
                for (int sb = 0; sb < 8; sb++) {
                    uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
                    for (int m = 0; m < 4; m++) {
                        const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
                        for (int j = 0; j < ncols_interleaved; j++) {
                            sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) *
                                              GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
                        }
                    }
                }
            }
            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
                }
            }
        }
    }
}
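
// Illustrative sketch (not used by the kernels): the 5-bit reconstruction performed
// above. A q5_K quant is a 4-bit low part from qs plus one high bit from qh, giving
// an unsigned value in [0, 31]; `which` selects the low (0) or high (1) nibble and
// qh_shift selects the bit pair for the current chunk, as in the kernel.
static inline int example_q5_K_value(uint8_t qs_byte, uint8_t qh_byte, int which, int qh_shift) {
    const uint8_t lo = which == 0 ? (uint8_t) (qs_byte & 0xF) : (uint8_t) (qs_byte >> 4);
    const uint8_t hi = (qh_byte >> (qh_shift + which)) & 1;
    return (int8_t) (lo | (hi << 4)); // in [0, 31]; the sub-block min is subtracted separately
}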

void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemm_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
}

void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemm_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
}

void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 4;
    const int blocklen = 4;

    assert (n % qk == 0);
    assert (nr % 4 == 0);
    assert (nc % ncols_interleaved == 0);

    UNUSED(s);
    UNUSED(bs);
    UNUSED(vx);
    UNUSED(vy);
    UNUSED(nr);
    UNUSED(nc);
    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

    {
        float sumf[4][4];
        int sumi;

        for (int y = 0; y < nr / 4; y++) {
            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
            for (int x = 0; x < nc / ncols_interleaved; x++) {
                const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
                for (int m = 0; m < 4; m++) {
                    for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
                }
                for (int l = 0; l < nb; l++) {
                    for (int k = 0; k < (qk / (2 * blocklen)); k++) {
                        for (int m = 0; m < 4; m++) {
                            for (int j = 0; j < ncols_interleaved; j++) {
                                sumi = 0;
                                for (int i = 0; i < blocklen; ++i) {
                                    const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
                                    const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
                                    sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
                                             (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
                                }
                                sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
                            }
                        }
                    }
                }
                for (int m = 0; m < 4; m++) {
                    for (int j = 0; j < ncols_interleaved; j++)
                        s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
                }
            }
        }
    }
}

void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 8;
    const int blocklen = 8;

    assert(n % qk == 0);
    assert(nr % 4 == 0);
    assert(nc % ncols_interleaved == 0);

    float sumf[4][8];
    int sumi;

    for (int y = 0; y < nr / 4; y++) {
        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
        for (int x = 0; x < nc / ncols_interleaved; x++) {
            const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb);
            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
            }
            for (int l = 0; l < nb; l++) {
                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
                    for (int m = 0; m < 4; m++) {
                        for (int j = 0; j < ncols_interleaved; j++) {
                            sumi = 0;
                            for (int i = 0; i < blocklen; ++i) {
                                const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
                                const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
                            }
                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
                        }
                    }
                }
            }
            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++)
                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
            }
        }
    }
}
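
// Illustrative sketch (not called by the kernels): IQ4_NL stores 4-bit indices into
// the non-linear kvalues_iq4nl codebook from ggml-common.h, so decoding one packed
// byte is two table lookups rather than a sign extension.
static inline void example_unpack_iq4_nl_byte(uint8_t b, int * lo, int * hi) {
    *lo = kvalues_iq4nl[b & 0x0F]; // low-nibble index into the codebook
    *hi = kvalues_iq4nl[b >> 4];   // high-nibble index into the codebook
}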

void ggml_gemm_q8_0_4x4_q8_0_generic(int                        n,
                                     float * GGML_RESTRICT      s,
                                     size_t                     bs,
                                     const void * GGML_RESTRICT vx,
                                     const void * GGML_RESTRICT vy,
                                     int                        nr,
                                     int                        nc) {
    const int qk                = QK8_0;
    const int nb                = n / qk;
    const int ncols_interleaved = 4;
    const int blocklen          = 4;

    assert(n % qk == 0);
    assert(nr % 4 == 0);
    assert(nc % ncols_interleaved == 0);

    float sumf[4][4];
    int   sumi;

    for (int y = 0; y < nr / 4; y++) {
        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
        for (int x = 0; x < nc / ncols_interleaved; x++) {
            const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumf[m][j] = 0.0;
                }
            }
            for (int l = 0; l < nb; l++) {
                for (int k = 0; k < (qk / blocklen); k++) {
                    for (int m = 0; m < 4; m++) {
                        for (int j = 0; j < ncols_interleaved; j++) {
                            sumi = 0;
                            for (int i = 0; i < blocklen; ++i) {
                                const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
                                sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i];
                            }
                            sumf[m][j] +=
                                sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
                        }
                    }
                }
            }
            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
                }
            }
        }
    }
}

void ggml_gemm_q8_0_4x8_q8_0_generic(int                        n,
                                     float * GGML_RESTRICT      s,
                                     size_t                     bs,
                                     const void * GGML_RESTRICT vx,
                                     const void * GGML_RESTRICT vy,
                                     int                        nr,
                                     int                        nc) {
    const int qk                = QK8_0;
    const int nb                = n / qk;
    const int ncols_interleaved = 4;
    const int blocklen          = 8;

    assert(n % qk == 0);
    assert(nr % 4 == 0);
    assert(nc % ncols_interleaved == 0);

    float sumf[4][4];
    int   sumi;

    for (int y = 0; y < nr / 4; y++) {
        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
        for (int x = 0; x < nc / ncols_interleaved; x++) {
            const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumf[m][j] = 0.0;
                }
            }
            for (int l = 0; l < nb; l++) {
                for (int k = 0; k < (qk / blocklen); k++) {
                    for (int m = 0; m < 4; m++) {
                        for (int j = 0; j < ncols_interleaved; j++) {
                            sumi = 0;
                            for (int i = 0; i < blocklen; ++i) {
                                const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
                                sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i];
                            }
                            sumf[m][j] +=
                                sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
                        }
                    }
                }
            }
            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
                }
            }
        }
    }
}

} // extern "C"

static block_q8_0x4 make_block_q8_0x4(block_q8_0 * in, unsigned int blck_size_interleave) {
    block_q8_0x4 out;

    for (int i = 0; i < 4; i++) {
        out.d[i] = in[i].d;
    }

    const int end = QK8_0 * 4 / blck_size_interleave;
    for (int i = 0; i < end; ++i) {
        int src_id     = i % 4;
        int src_offset = (i / 4) * blck_size_interleave;
        int dst_offset = i * blck_size_interleave;
        memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], blck_size_interleave);
    }
    return out;
}

static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
    block_q4_0x4 out;

    for (int i = 0; i < 4; i++) {
        out.d[i] = in[i].d;
    }

    const int end = QK4_0 * 2 / blck_size_interleave;

    if (blck_size_interleave == 8) {
        const uint64_t xor_mask = 0x8888888888888888ULL;
        for (int i = 0; i < end; ++i) {
            int src_id = i % 4;
            int src_offset = (i / 4) * blck_size_interleave;
            int dst_offset = i * blck_size_interleave;

            uint64_t elems;
            // Using memcpy to avoid unaligned memory accesses
            memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
            elems ^= xor_mask;
            memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
        }
    } else if (blck_size_interleave == 4) {
        const uint32_t xor_mask = 0x88888888;
        for (int i = 0; i < end; ++i) {
            int src_id = i % 4;
            int src_offset = (i / 4) * blck_size_interleave;
            int dst_offset = i * blck_size_interleave;

            uint32_t elems;
            memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint32_t));
            elems ^= xor_mask;
            memcpy(&out.qs[dst_offset], &elems, sizeof(uint32_t));
        }
    } else {
        GGML_ASSERT(false);
    }

    return out;
}

// interleave 8 block_q4_0s in blocks of blck_size_interleave
// returns an interleaved block_q4_0x8
// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks
// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave
static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave) {
    block_q4_0x8 out;

    for (int i = 0; i < 8; i++) {
        out.d[i] = in[i].d;
    }

    const int end = QK4_0 * 4 / blck_size_interleave;
    const uint64_t xor_mask = 0x8888888888888888ULL;

    for (int i = 0; i < end; ++i) {
        int src_id = i % 8;
        int src_offset = (i / 8) * blck_size_interleave;
        int dst_offset = i * blck_size_interleave;

        uint64_t elems;
        memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
        elems ^= xor_mask;
        memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
    }

    return out;
}
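
// Illustrative sketch: what the xor_mask above does to a single byte. A q4_0 nibble
// is stored biased (value + 8, in 0..15); flipping bit 3 of every nibble turns that
// into the two's-complement encoding of the value itself, so the GEMV/GEMM kernels
// can sign-extend instead of subtracting 8 per nibble.
static inline uint8_t example_q4_0_bias_to_sign(uint8_t byte) {
    return byte ^ 0x88; // e.g. biased 0x00 (both nibbles -8) becomes signed 0x88 (-8, -8)
}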

static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_interleave) {
    block_q4_Kx8 out;
    // Delta (scale) and dmin values of the eight Q4_K structures are copied onto the output interleaved structure
    for (int i = 0; i < 8; i++) {
        out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
    }

    for (int i = 0; i < 8; i++) {
        out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
    }

    const int end = QK_K * 4 / blck_size_interleave;

    // Interleave Q4_K quants by taking 8 bytes at a time
    for (int i = 0; i < end; ++i) {
        int src_id = i % 8;
        int src_offset = (i / 8) * blck_size_interleave;
        int dst_offset = i * blck_size_interleave;

        uint64_t elems;
        memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
        memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
    }

    // The logic below unpacks and rearranges the scales and mins of the eight Q4_K structures
    // Each Q4_K structure packs 8 scales and 8 mins into 12 bytes (6 bits per value)
    // The output Q4_Kx8 structure stores them in 96 bytes
    // Every 12 bytes holds the scales and mins of the corresponding sub-block from each Q4_K structure
    // For example, the first 12 bytes hold 8 scales and 8 mins - one for the first sub-block of each Q4_K structure
    uint8_t s[8], m[8];

    for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 8; j++) {
            s[j] = in[j].scales[i] & 63;
            m[j] = in[j].scales[i + 4] & 63;
        }

        out.scales[i * 12]      = (s[0] & 63) + ((s[4] & 48) << 2);
        out.scales[i * 12 + 1]  = (s[1] & 63) + ((s[5] & 48) << 2);
        out.scales[i * 12 + 2]  = (s[2] & 63) + ((s[6] & 48) << 2);
        out.scales[i * 12 + 3]  = (s[3] & 63) + ((s[7] & 48) << 2);
        out.scales[i * 12 + 4]  = (m[0] & 63) + ((m[4] & 48) << 2);
        out.scales[i * 12 + 5]  = (m[1] & 63) + ((m[5] & 48) << 2);
        out.scales[i * 12 + 6]  = (m[2] & 63) + ((m[6] & 48) << 2);
        out.scales[i * 12 + 7]  = (m[3] & 63) + ((m[7] & 48) << 2);
        out.scales[i * 12 + 8]  = (s[4] & 15) + ((m[4] & 15) << 4);
        out.scales[i * 12 + 9]  = (s[5] & 15) + ((m[5] & 15) << 4);
        out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4);
        out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4);
    }

    for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 8; j++) {
            s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i + 8] & 15);
            m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i + 8] & 240) >> 4);
        }

        out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2);
        out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2);
        out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2);
        out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2);
        out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2);
        out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2);
        out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2);
        out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2);
        out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4);
        out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4);
        out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4);
        out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4);
    }

    return out;
}

static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_interleave) {
    block_q2_Kx8 out;

    // Delta (scale) and dmin values of the eight Q2_K structures are copied onto the output interleaved structure
    for (int i = 0; i < 8; i++) {
        out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
    }

    for (int i = 0; i < 8; i++) {
        out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
    }

    const int end = QK_K * 2 / blck_size_interleave;

    // Interleave Q2_K quants by taking 8 bytes at a time
    for (int i = 0; i < end; ++i) {
        int src_id = i % 8;
        int src_offset = (i / 8) * blck_size_interleave;
        int dst_offset = i * blck_size_interleave;

        uint64_t elems;
        memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
        memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
    }

    // The logic below unpacks and rearranges the scales and mins of the eight Q2_K structures
    // Each Q2_K structure packs 16 scales and 16 mins into 16 bytes (4 bits per value)
    // The output Q2_Kx8 structure stores them in 128 bytes
    // Every 16 bytes holds the scale/min bytes of the corresponding pair of sub-blocks from each Q2_K structure
    // For example, the first 16 bytes hold the scale/min bytes of sub-blocks 0 and 1 from each of the eight Q2_K structures

    for (int i = 0; i < 128; i++) {
        // Index of the source Q2_K super-block
        int src1 = (i % 16) / 2;
        // Index of the scale byte within that block
        int src2 = ((i / 16) * 2) + (i % 2);

        out.scales[i] = in[src1].scales[src2];
    }
    return out;
}
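
// Illustrative sketch: the forward index mapping implemented by the loop above.
// Output byte i of the 128-byte interleaved scales array takes scale byte
// (i / 16) * 2 + (i % 2) of source block (i % 16) / 2, i.e. the blocks are walked
// in pairs of scale bytes.
static inline uint8_t example_q2_Kx8_scale(const block_q2_K * in, int i) {
    const int src_block = (i % 16) / 2;           // which of the 8 source blocks
    const int src_scale = (i / 16) * 2 + (i % 2); // which of its 16 scale bytes
    return in[src_block].scales[src_scale];
}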

static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_interleave) {
    block_q5_Kx8 out;
    // Delta (scale) and dmin values of the eight Q5_K structures are copied onto the output interleaved structure
    for (int i = 0; i < 8; i++) {
        out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
    }

    for (int i = 0; i < 8; i++) {
        out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
    }

    const int end = QK_K * 4 / blck_size_interleave;

    // Interleave Q5_K quants by taking 8 bytes at a time
    for (int i = 0; i < end; ++i) {
        int src_id     = i % 8;
        int src_offset = (i / 8) * blck_size_interleave;
        int dst_offset = i * blck_size_interleave;

        uint64_t elems;
        memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
        memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
    }

    // Repeat for the high bits 8 bytes at a time as well, since
    // the high bits are interleaved in Q5_K and the index is
    // qh_idx = (qs_idx % 32);
    // qh_val = qh[qh_idx] >> (qs_idx / 32);
    // so the same interleave keeps qs and qh in step
    for (int i = 0; i < end / 4; ++i) {
        int src_id     = i % 8;
        int src_offset = (i / 8) * blck_size_interleave;
        int dst_offset = i * blck_size_interleave;

        uint64_t elems;
        memcpy(&elems, &in[src_id].qh[src_offset], sizeof(uint64_t));
        memcpy(&out.qh[dst_offset], &elems, sizeof(uint64_t));
    }

    // The logic below is copied over from Q4_K
    // The point is to unpack all the scales and mins of each sub-block every time 12 bytes are loaded
    // Each Q5_K structure packs 8 scales and 8 mins into 12 bytes (6 bits per value)
    // The output Q5_Kx8 structure stores them in 96 bytes
    // Every 12 bytes holds the scales and mins of the corresponding sub-block from each Q5_K structure
    // For example, the first 12 bytes hold 8 scales and 8 mins - one for the first sub-block of each Q5_K structure
    uint8_t s[8], m[8];

    for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 8; j++) {
            s[j] = in[j].scales[i] & 63;
            m[j] = in[j].scales[i + 4] & 63;
        }

        out.scales[i * 12]      = (s[0] & 63) + ((s[4] & 48) << 2);
        out.scales[i * 12 + 1]  = (s[1] & 63) + ((s[5] & 48) << 2);
        out.scales[i * 12 + 2]  = (s[2] & 63) + ((s[6] & 48) << 2);
        out.scales[i * 12 + 3]  = (s[3] & 63) + ((s[7] & 48) << 2);
        out.scales[i * 12 + 4]  = (m[0] & 63) + ((m[4] & 48) << 2);
        out.scales[i * 12 + 5]  = (m[1] & 63) + ((m[5] & 48) << 2);
        out.scales[i * 12 + 6]  = (m[2] & 63) + ((m[6] & 48) << 2);
        out.scales[i * 12 + 7]  = (m[3] & 63) + ((m[7] & 48) << 2);
        out.scales[i * 12 + 8]  = (s[4] & 15) + ((m[4] & 15) << 4);
        out.scales[i * 12 + 9]  = (s[5] & 15) + ((m[5] & 15) << 4);
        out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4);
        out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4);
    }

    for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 8; j++) {
            s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i + 8] & 15);
            m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i + 8] & 240) >> 4);
        }

        out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2);
        out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2);
        out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2);
        out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2);
        out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2);
        out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2);
        out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2);
        out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2);
        out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4);
        out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4);
        out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4);
        out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4);
    }

    return out;
}

static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_interleave) {
    block_q6_Kx8  out;
    constexpr int n_blocks = 8;  // Kx8
    for (int i = 0; i < n_blocks; i++) {
        out.d[i] = in[i].d;
    }

    const int end_ls = QK_K * 4 / blck_size_interleave;
    // Interleave Q6_K quants by taking blck_size_interleave bytes at a time
    for (int i = 0; i < end_ls; ++i) {
        int src_id     = i % n_blocks;
        int src_offset = (i / n_blocks) * blck_size_interleave;
        int dst_offset = i * blck_size_interleave;

        uint64_t elem_ls;
        memcpy(&elem_ls, &in[src_id].ql[src_offset], blck_size_interleave);
        memcpy(&out.ql[dst_offset], &elem_ls, blck_size_interleave);
    }

    // Interleave the high bits using the same chunk size as the low bits
    const int end_hs = end_ls / 2;
    for (int i = 0; i < end_hs; ++i) {
        int src_id     = i % n_blocks;
        int src_offset = (i / n_blocks) * blck_size_interleave;
        int dst_offset = i * blck_size_interleave;

        uint64_t elem_hs;
        memcpy(&elem_hs, &in[src_id].qh[src_offset], blck_size_interleave);
        memcpy(&out.qh[dst_offset], &elem_hs, blck_size_interleave);
    }

    // The logic below rearranges the scales of the eight Q6_K structures
    // The output Q6_Kx8 structure interleaves the 8-bit scales in the same fashion as the quants
    // A Q6_K structure has an 8-bit scale per 16 elements -> 16 scales
    // scales: [0 bl0 0 bl1 ... 0 bl7][1 bl0 ... 1 bl7] ... [15 bl0 ... 15 bl7]  (bl = block)
    constexpr int n_scales = QK_K / 16;

    for (int i = 0; i < n_blocks; i++) {
        for (int j = 0; j < n_scales; j++) {
            out.scales[j * n_blocks + i] = in[i].scales[j];
        }
    }

    return out;
}
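
// Illustrative sketch (assuming scales in block_q6_Kx8 is the int8_t array
// described above): scale j of source block i lands at scales[j * 8 + i], so the
// eight blocks' j-th scales sit next to each other, mirroring the quant interleave.
static inline int8_t example_q6_Kx8_scale(const block_q6_Kx8 * out, int block, int scale) {
    return out->scales[scale * 8 + block];
}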

static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
    GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
    GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
    constexpr int nrows_interleaved = 4;

    block_q4_0x4 * dst = (block_q4_0x4 *)t->data;
    const block_q4_0 * src = (const block_q4_0 *)data;
    block_q4_0 dst_tmp[4];
    int nrow = ggml_nrows(t);
    int nblocks = t->ne[0] / QK4_0;

    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));

    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
        return -1;
    }

    for (int b = 0; b < nrow; b += nrows_interleaved) {
        for (int64_t x = 0; x < nblocks; x++) {
            for (int i = 0; i < nrows_interleaved; i++) {
                dst_tmp[i] = src[x + i * nblocks];
            }
            *dst++ = make_block_q4_0x4(dst_tmp, interleave_block);
        }
        src += nrows_interleaved * nblocks;
    }
    return 0;

    GGML_UNUSED(data_size);
}
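
// Illustrative sketch of the gather pattern shared by the repack_* functions in
// this file: source rows are contiguous (nblocks blocks per row), so block column x
// of row (b + i) lives at src[x + i * nblocks]. Each group of nrows_interleaved
// rows is collected block-column by block-column before the make_block_*
// interleave. (Hypothetical standalone helper, not called by the functions here.)
static inline const block_q4_0 * example_gather_src(const block_q4_0 * src, int row_in_group, int64_t x, int nblocks) {
    return &src[x + row_in_group * nblocks];
}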

static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
    GGML_ASSERT(t->type == GGML_TYPE_Q4_K);
    GGML_ASSERT(interleave_block == 8 || interleave_block == 4);
    constexpr int nrows_interleaved = 8;

    block_q4_Kx8 * dst = (block_q4_Kx8 *)t->data;
    const block_q4_K * src = (const block_q4_K *)data;
    block_q4_K dst_tmp[8];
    int nrow = ggml_nrows(t);
    int nblocks = t->ne[0] / QK_K;

    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K));

    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
        return -1;
    }

    for (int b = 0; b < nrow; b += nrows_interleaved) {
        for (int64_t x = 0; x < nblocks; x++) {
            for (int i = 0; i < nrows_interleaved; i++) {
                dst_tmp[i] = src[x + i * nblocks];
            }
            *dst++ = make_block_q4_Kx8(dst_tmp, interleave_block);
        }
        src += nrows_interleaved * nblocks;
    }
    return 0;

    GGML_UNUSED(data_size);
}

static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
    GGML_ASSERT(t->type == GGML_TYPE_Q2_K);
    GGML_ASSERT(interleave_block == 8);
    constexpr int nrows_interleaved = 8;

    block_q2_Kx8 * dst = (block_q2_Kx8 *)t->data;
    const block_q2_K * src = (const block_q2_K *)data;
    block_q2_K dst_tmp[8];
    int nrow = ggml_nrows(t);
    int nblocks = t->ne[0] / QK_K;

    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K));

    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
        return -1;
    }

    for (int b = 0; b < nrow; b += nrows_interleaved) {
        for (int64_t x = 0; x < nblocks; x++) {
            for (int i = 0; i < nrows_interleaved; i++) {
                dst_tmp[i] = src[x + i * nblocks];
            }
            *dst++ = make_block_q2_Kx8(dst_tmp, interleave_block);
        }
        src += nrows_interleaved * nblocks;
    }
    return 0;

    GGML_UNUSED(data_size);
}

static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor *       t,
                                    int                        interleave_block,
                                    const void * GGML_RESTRICT data,
                                    size_t                     data_size) {
    GGML_ASSERT(t->type == GGML_TYPE_Q5_K);
    GGML_ASSERT(interleave_block == 8);
    constexpr int nrows_interleaved = 8;

    block_q5_Kx8 *     dst = (block_q5_Kx8 *) t->data;
    const block_q5_K * src = (const block_q5_K *) data;
    block_q5_K         dst_tmp[8];
    int                nrow    = ggml_nrows(t);
    int                nblocks = t->ne[0] / QK_K;

    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_K));

    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
        return -1;
    }

    for (int b = 0; b < nrow; b += nrows_interleaved) {
        for (int64_t x = 0; x < nblocks; x++) {
            for (int i = 0; i < nrows_interleaved; i++) {
                dst_tmp[i] = src[x + i * nblocks];
            }
            *dst++ = make_block_q5_Kx8(dst_tmp, interleave_block);
        }
        src += nrows_interleaved * nblocks;
    }
    return 0;
}

static int repack_q6_K_to_q6_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
    GGML_ASSERT(t->type == GGML_TYPE_Q6_K);
    GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
    constexpr int nrows_interleaved = 8;

    block_q6_Kx8 * dst = (block_q6_Kx8 *)t->data;
    const block_q6_K * src = (const block_q6_K *) data;
    block_q6_K dst_tmp[8];
    int nrow = ggml_nrows(t);
    int nblocks = t->ne[0] / QK_K;

    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q6_K));

    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
        return -1;
    }

    for (int b = 0; b < nrow; b += nrows_interleaved) {
        for (int64_t x = 0; x < nblocks; x++) {
            for (int i = 0; i < nrows_interleaved; i++) {
                dst_tmp[i] = src[x + i * nblocks];
            }
            *dst++ = make_block_q6_Kx8(dst_tmp, interleave_block);
        }
        src += nrows_interleaved * nblocks;
    }
    return 0;
}

static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
    GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
    GGML_ASSERT(interleave_block == 8);
    constexpr int nrows_interleaved = 8;

    block_q4_0x8 * dst = (block_q4_0x8 *)t->data;
    const block_q4_0 * src = (const block_q4_0 *)data;
    block_q4_0 dst_tmp[8];
    int nrow = ggml_nrows(t);
    int nblocks = t->ne[0] / QK4_0;

    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));

    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
        return -1;
    }

    for (int b = 0; b < nrow; b += nrows_interleaved) {
        for (int64_t x = 0; x < nblocks; x++) {
            for (int i = 0; i < nrows_interleaved; i++) {
                dst_tmp[i] = src[x + i * nblocks];
            }
            *dst++ = make_block_q4_0x8(dst_tmp, interleave_block);
        }
        src += nrows_interleaved * nblocks;
    }
    return 0;

    GGML_UNUSED(data_size);
}

static int repack_q8_0_to_q8_0_4_bl(struct ggml_tensor *       t,
                                    int                        interleave_block,
                                    const void * GGML_RESTRICT data,
                                    size_t                     data_size) {
    GGML_ASSERT(t->type == GGML_TYPE_Q8_0);
    GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
    constexpr int nrows_interleaved = 4;

    block_q8_0x4 *     dst = (block_q8_0x4 *) t->data;
    const block_q8_0 * src = (const block_q8_0 *) data;
    block_q8_0         dst_tmp[4];
    int                nrow    = ggml_nrows(t);
    int                nblocks = t->ne[0] / QK8_0;

    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0));

    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
        return -1;
    }

    for (int b = 0; b < nrow; b += nrows_interleaved) {
        for (int64_t x = 0; x < nblocks; x++) {
            for (int i = 0; i < nrows_interleaved; i++) {
                dst_tmp[i] = src[x + i * nblocks];
            }
            *dst++ = make_block_q8_0x4(dst_tmp, interleave_block);
        }
        src += nrows_interleaved * nblocks;
    }
    return 0;
}

static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) {
    block_iq4_nlx4 out;

    for (int i = 0; i < 4; i++) {
        out.d[i] = in[i].d;
    }

    const int end = QK4_NL * 2 / blck_size_interleave;

    // TODO: this branch seems wrong
    //if (blck_size_interleave == 8) {
    //    for (int i = 0; i < end; ++i) {
    //        int src_id = i % 4;
    //        int src_offset = (i / 4) * blck_size_interleave;
    //        int dst_offset = i * blck_size_interleave;

    //        // Using memcpy to avoid unaligned memory accesses
    //        memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
    //    }
    //} else
    if (blck_size_interleave == 4) {
        for (int i = 0; i < end; ++i) {
            int src_id = i % 4;
            int src_offset = (i / 4) * blck_size_interleave;
            int dst_offset = i * blck_size_interleave;

            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t));
        }
    } else {
        GGML_ASSERT(false);
    }

    return out;
}

static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
    GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
    GGML_ASSERT(interleave_block == 4);

    const block_iq4_nl   * src = (const block_iq4_nl   *)data;
          block_iq4_nlx4 * dst = (      block_iq4_nlx4 *)t->data;

    block_iq4_nl dst_tmp[4];

    int nrow = ggml_nrows(t);
    int nrows_interleaved = 4;
    int nblocks = t->ne[0] / QK4_NL;

    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));

    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
        return -1;
    }

    for (int b = 0; b < nrow; b += nrows_interleaved) {
        for (int64_t x = 0; x < nblocks; x++) {
            for (int i = 0; i < nrows_interleaved; i++) {
                dst_tmp[i] = src[x + i * nblocks];
            }
            *dst++ = make_block_iq4_nlx4(dst_tmp, interleave_block);
        }
        src += nrows_interleaved * nblocks;
    }
    return 0;

    GGML_UNUSED(data_size);
}

static block_iq4_nlx8 make_block_iq4_nlx8(block_iq4_nl * in, unsigned int blck_size_interleave) {
    block_iq4_nlx8 out;

    for (int i = 0; i < 8; i++) {
        out.d[i] = in[i].d;
    }

    const int end = QK4_NL * 4 / blck_size_interleave;

    if (blck_size_interleave == 8) {
        for (int i = 0; i < end; ++i) {
            int src_id = i % 8;
            int src_offset = (i / 8) * blck_size_interleave;
            int dst_offset = i * blck_size_interleave;

            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
        }
    } else {
        GGML_ASSERT(false);
    }

    return out;
}

static int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
    GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
    GGML_ASSERT(interleave_block == 8);

    const block_iq4_nl   * src = (const block_iq4_nl   *)data;
          block_iq4_nlx8 * dst = (      block_iq4_nlx8 *)t->data;

    block_iq4_nl dst_tmp[8];

    int nrow = ggml_nrows(t);
    int nrows_interleaved = 8;
    int nblocks = t->ne[0] / QK4_NL;

    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));

    if (t->ne[1] % nrows_interleaved != 0) {
        return -1;
    }

    for (int b = 0; b < nrow; b += nrows_interleaved) {
        for (int64_t x = 0; x < nblocks; x++) {
            for (int i = 0; i < nrows_interleaved; i++) {
                dst_tmp[i] = src[x + i * nblocks];
            }
            *dst++ = make_block_iq4_nlx8(dst_tmp, interleave_block);
        }
        src += nrows_interleaved * nblocks;
    }
    return 0;

    GGML_UNUSED(data_size);
}

namespace ggml::cpu::repack {
// repack
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
int repack(struct ggml_tensor *, const void *, size_t);

// TODO: generalise.
template <> int repack<block_q4_0, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q4_0_to_q4_0_4_bl(t, 4, data, data_size);
}

template <> int repack<block_q4_0, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q4_0_to_q4_0_4_bl(t, 8, data, data_size);
}

template <> int repack<block_q4_0, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size);
}

template <> int repack<block_q4_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size);
}

template <> int repack<block_q4_K, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q4_K_to_q4_K_8_bl(t, 4, data, data_size);
}

template <> int repack<block_q2_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
}

template <> int repack<block_q5_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size);
}

template <> int repack<block_q6_K, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q6_K_to_q6_K_8_bl(t, 4, data, data_size);
}

template <> int repack<block_q6_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q6_K_to_q6_K_8_bl(t, 8, data, data_size);
}

template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
}

// TODO: needs to be revisited
//template <> int repack<block_iq4_nl, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
//    return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size);
//}

template <> int repack<block_iq4_nl, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size);
}

template <> int repack<block_q8_0, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q8_0_to_q8_0_4_bl(t, 4, data, data_size);
}

template <> int repack<block_q8_0, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q8_0_to_q8_0_4_bl(t, 8, data, data_size);
}
2557
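// Parameter convention shared by the gemv/gemm wrappers below (mirroring the
// ggml_gemv_*/ggml_gemm_* kernels; see the call sites in
// forward_mul_mat_one_chunk): n is the row length in elements, s the float
// output, bs the output stride in floats, vx the repacked weights, vy the
// quantized activations, and nr/nc the number of rows/columns processed.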
// gemv
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
void gemv(int, float *, size_t, const void *, const void *, int, int);

template <> void gemv<block_q4_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}

template <> void gemv<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
}

template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}

template <>
void gemv<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int          n,
                                            float *      s,
                                            size_t       bs,
                                            const void * vx,
                                            const void * vy,
                                            int          nr,
                                            int          nc) {
    ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}

template <> void gemv<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}

template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}

template <> void gemv<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}

template <> void gemv<block_q6_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_q6_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}

template <> void gemv<block_q6_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}

template <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}

template <> void gemv<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}

template <> void gemv<block_q8_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}

template <> void gemv<block_q8_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
}

// gemm
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
void gemm(int, float *, size_t, const void *, const void *, int, int);

template <> void gemm<block_q4_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}

template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
}

template <>
void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int          n,
                                            float *      s,
                                            size_t       bs,
                                            const void * vx,
                                            const void * vy,
                                            int          nr,
                                            int          nc) {
    ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}

template <> void gemm<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}

template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}

template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}

template <> void gemm<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}

template <> void gemm<block_q6_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q6_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}

template <> void gemm<block_q6_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}

template <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}

template <> void gemm<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}

template <> void gemm<block_q8_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}

template <> void gemm<block_q8_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
}

class tensor_traits_base : public ggml::cpu::tensor_traits {
  public:
    virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
};

template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE> class tensor_traits : public tensor_traits_base {

    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
        // not really GGML_TYPE_Q8_0, but it has the same size.
        switch (op->op) {
            case GGML_OP_MUL_MAT:
                {
                    size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
                    return true;
                }
            case GGML_OP_MUL_MAT_ID:
                {
                    size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
                    size = GGML_PAD(size, sizeof(int64_t)); // + padding for the next block.

                    const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert
                    const int64_t ne12 = op->src[1]->ne[2]; // n_tokens

                    const size_t sizeof_mmid_row_mapping = sizeof(int64_t);

                    // one row count per expert plus up to ne12 row mappings per expert
                    size += sizeof_mmid_row_mapping*ne02*(ne12 + 1);

                    return true;
                }
            default:
                // GGML_ABORT("fatal error");
                break;
        }
        return false;
    }

    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
        switch (op->op) {
            case GGML_OP_MUL_MAT:
                forward_mul_mat(params, op);
                return true;
            case GGML_OP_MUL_MAT_ID:
                forward_mul_mat_id(params, op);
                return true;
            default:
                // GGML_ABORT("fatal error");
                break;
        }
        return false;
    }

    void forward_mul_mat_one_chunk(ggml_compute_params * params,
                                   ggml_tensor *         op,
                                   int64_t               src0_start,
                                   int64_t               src0_end,
                                   int64_t               src1_start,
                                   int64_t               src1_end) {
        const ggml_tensor * src0 = op->src[0];
        const ggml_tensor * src1 = op->src[1];
        ggml_tensor *       dst  = op;

        GGML_TENSOR_BINARY_OP_LOCALS

        const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);

        GGML_ASSERT(ne03 == 1 && ne13 == 1);
        GGML_ASSERT(ne12 % ne02 == 0);
        const int64_t r2 = ne12 / ne02;

        const int64_t i12 = src1_start / ne1;
        const int64_t i11 = src1_start - i12 * ne1;

        // Determine batch index
        const int64_t i02 = i12 / r2;

        const int64_t i1 = i11;
        const int64_t i2 = i12;

        const char * src0_ptr = (const char *) src0->data + i02 * nb02;
        const char * src1_ptr = (const char *) params->wdata + (i11 + i12 * ne11) * src1_col_stride;
        char *       dst_ptr  = ((char *) dst->data + (i1 * nb1 + i2 * nb2));

        const int64_t nrows = src1_end - src1_start;
        const int64_t ncols = src0_end - src0_start;

        GGML_ASSERT(src1_ptr + src1_col_stride * nrows <= (const char *) params->wdata + params->wsize);

        // Use gemm for the largest multiple of 4 rows of src1.
        if (nrows > 3) {
            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr) + src0_start, nb1 / nb0,
                                                             src0_ptr + src0_start * nb01, src1_ptr,
                                                             nrows - (nrows % 4), ncols);
        }
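        // Tail: the remaining (nrows % 4) rows are handled one at a time by gemv.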
        for (int iter = nrows - (nrows % 4); iter < nrows; iter++) {
            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr + (iter * nb1)) + src0_start,
                                                             ne01, src0_ptr + src0_start * nb01,
                                                             src1_ptr + (src1_col_stride * iter), 1 /* nrows */, ncols);
        }
    }

    void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) {
        const ggml_tensor * src0 = op->src[0];
        const ggml_tensor * src1 = op->src[1];
        ggml_tensor *       dst  = op;

        GGML_TENSOR_BINARY_OP_LOCALS

        const int ith = params->ith;
        const int nth = params->nth;

        GGML_ASSERT(ne0 == ne01);
        GGML_ASSERT(ne1 == ne11);
        GGML_ASSERT(ne2 == ne12);
        GGML_ASSERT(ne3 == ne13);

        // dst cannot be transposed or permuted
        GGML_ASSERT(nb0 == sizeof(float));
        GGML_ASSERT(nb0 <= nb1);
        GGML_ASSERT(nb1 <= nb2);
        GGML_ASSERT(nb2 <= nb3);

        // TODO: General batched mul mat for 4D tensors
        // Currently only supports 3D tensors
        GGML_ASSERT(ne03 == 1);
        GGML_ASSERT(ne13 == 1);
        GGML_ASSERT(ne3 == 1);

        GGML_ASSERT(src1->type == GGML_TYPE_F32);

        GGML_ASSERT(ggml_n_dims(op->src[0]) == 2);
        // GGML_ASSERT(ggml_n_dims(op->src[1]) == 2);

        char *       wdata = static_cast<char *>(params->wdata);
        const size_t nbw1  = ggml_row_size(PARAM_TYPE, ne10);
        const size_t nbw2  = nbw1 * ne11;

        assert(params->wsize >= nbw2 * ne12);

        const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;

        // INFO: quantization is done plane by plane to keep the chunking logic simple.
        // Flattening dimensions that are not a multiple of INTER_SIZE would require
        // extra handling, depending on how the planes are broadcast.
        for (int64_t i12 = 0; i12 < ne12; i12++) {
            char * data_ptr  = (char *) src1->data + i12 * nb12;
            char * wdata_ptr = wdata + i12 * nbw2;

            for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
                ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) (data_ptr + i11 * nb11),
                                                            (void *) (wdata_ptr + i11 * nbw1), 4, ne10);
            }

            const int64_t i11_processed = ne11 - ne11 % 4;
            for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
                from_float((float *) (data_ptr + i11 * nb11), (void *) (wdata_ptr + i11 * nbw1), ne10);
            }
        }

        // disable for NUMA
        const bool disable_chunking = ggml_is_numa();

        // 4x chunks per thread
        const int64_t nr0 = ggml_nrows(op->src[0]);

        int     nth_scaled  = nth * 4;
        int64_t chunk_size0 = (nr0 + nth_scaled - 1) / nth_scaled;
        int64_t nchunk0     = (nr0 + chunk_size0 - 1) / chunk_size0;
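        // e.g. nr0 = 4096 rows and nth = 8 threads: nth_scaled = 32,
        // chunk_size0 = ceil(4096/32) = 128 and nchunk0 = 32 chunks.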

        // src1 is chunked only in full planes.
        // If it were flattened, dimensions that are not a multiple of the q8 INTER_SIZE
        // would have to be routed through GEMV.
        // Setting nchunk1 = ne12 also leaves the chunking of models without 3D tensors
        // unchanged, so their performance is unaffected.
        int64_t nchunk1 = ne12;

        // Ensure minimum chunk size to avoid alignment issues with high thread counts
        // Minimum chunk size should be at least NB_COLS to prevent overlapping chunks after alignment
        const int64_t min_chunk_size = NB_COLS;
        if (nchunk0 > 0 && (nr0 / nchunk0) < min_chunk_size && nr0 >= min_chunk_size) {
            nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;
        }

        int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
        // Only increase nchunk0 to nth if it won't make chunks too small
        if (nth == 1 || ((nchunk0 < nth || disable_chunking) && (nr0 + nth - 1) / nth >= min_chunk_size)) {
            nchunk0 = nth;
            dr0 = (nr0 + nchunk0 - 1) / nchunk0;
        }

        // Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
        // This prevents creating too many tiny chunks that could overlap after alignment
        const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size;
        nchunk0                  = MIN(nchunk0, max_nchunk);

        if (ith == 0) {
            // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
            ggml_threadpool_chunk_set(params->threadpool, nth);
        }

        ggml_barrier(params->threadpool);

        // The first chunk comes from our thread_id, the rest will get auto-assigned.
        int current_chunk = ith;

        while (current_chunk < nchunk0 * nchunk1) {
            const int64_t ith0 = current_chunk % nchunk0;
            const int64_t ith1 = current_chunk / nchunk0;

            int64_t src0_start = dr0 * ith0;
            int64_t src0_end   = MIN(src0_start + dr0, nr0);

            // full-plane range for src1
            int64_t src1_start = ith1 * ne11;
            int64_t src1_end = (ith1 + 1) * ne11;

            // Align boundaries to NB_COLS - round up to ensure all data is included
            // The chunk size limiting above ensures chunks are large enough to prevent overlaps
            src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
            src0_end   = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
            src0_end   = MIN(src0_end, ne01);
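            // e.g. with NB_COLS = 8, a boundary of 13 rounds up to 16, so every
            // chunk starts and ends on a full group of interleaved columns.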

            // Skip chunks that became empty after alignment and move on to the next one.
            if (src0_start >= src0_end) {
                current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
                continue;
            }

            forward_mul_mat_one_chunk(params, dst, src0_start, src0_end, src1_start, src1_end);

            current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
        }
    }

    void forward_mul_mat_id(ggml_compute_params * params, ggml_tensor * op) {
        const ggml_tensor * src0 = op->src[0];
        const ggml_tensor * src1 = op->src[1];
        const ggml_tensor * ids  = op->src[2];
        ggml_tensor *       dst  = op;

        GGML_TENSOR_BINARY_OP_LOCALS

        const int ith = params->ith;
        const int nth = params->nth;

        const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;

        // we don't support permuted src0 or src1
        GGML_ASSERT(nb00 == ggml_type_size(src0->type));
        GGML_ASSERT(nb10 == ggml_type_size(src1->type));

        // dst cannot be transposed or permuted
        GGML_ASSERT(nb0 == sizeof(float));
        GGML_ASSERT(nb0 <= nb1);
        GGML_ASSERT(nb1 <= nb2);
        GGML_ASSERT(nb2 <= nb3);

        GGML_ASSERT(ne03 == 1);
        GGML_ASSERT(ne13 == 1);
        GGML_ASSERT(ne3  == 1);

        GGML_ASSERT(src1->type == GGML_TYPE_F32);

        // row groups
        const int n_ids = ids->ne[0]; // n_expert_used
        const int n_as  = ne02;       // n_expert

        const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
        const size_t nbw2 = nbw1*ne11;
        const size_t nbw3 = nbw2*ne12;

        struct mmid_row_mapping {
            int32_t i1;
            int32_t i2;
        };
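        // i1 = expert slot within a token's ids row, i2 = token index; the two
        // int32_t fields pack into the int64_t-sized slots reserved by work_size().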

        GGML_ASSERT(params->wsize >=
                (GGML_PAD(nbw3, sizeof(int64_t)) +
                 n_as*(ne12 + 1)*sizeof(mmid_row_mapping))
                );

        auto * wdata          = (char *)params->wdata;
        auto * wdata_src1_end = (char *)wdata + GGML_PAD(nbw3, sizeof(int64_t));

        // total of [n_as][ne12 + 1] elements of type mmid_row_mapping (2*int32_t = int64_t)
        auto * matrix_row_counts = (int64_t *) (wdata_src1_end);                                        // [n_as]
        struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]

        // src1: float32 => param type
        for (int64_t i12 = 0; i12 < ne12; ++i12) {
            for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
                from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
                           (void *)               (wdata + i12 * nbw2 + i11 * nbw1),
                           ne10);
            }
        }

#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ne12 + (i1)]

        if (ith == 0) {
            // initialize matrix_row_counts
            memset(matrix_row_counts, 0, n_as * sizeof(int64_t));

            // group rows by src0 matrix
            for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
                for (int32_t id = 0; id < n_ids; ++id) {
                    const int32_t i02 =
                        *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);

                    GGML_ASSERT(i02 >= 0 && i02 < n_as);

                    MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 };
                    matrix_row_counts[i02] += 1;
                }
            }
        }

        ggml_barrier(params->threadpool);

        // compute each matrix multiplication in sequence
        for (int cur_a = 0; cur_a < n_as; ++cur_a) {
            const int64_t cne1 = matrix_row_counts[cur_a];

            if (cne1 == 0) {
                continue;
            }

            const auto * src0_cur = (const char *) src0->data + cur_a*nb02;

            //const int64_t nr0 = ne01; // src0 rows
            const int64_t nr1 = cne1; // src1 rows

            int64_t src0_cur_start = (ith * ne01) / nth;
            int64_t src0_cur_end   = ((ith + 1) * ne01) / nth;

            // Align boundaries to NB_COLS - round up to ensure all data is included
            src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
            src0_cur_end   = (src0_cur_end   % NB_COLS) ? src0_cur_end   + NB_COLS - (src0_cur_end   % NB_COLS) : src0_cur_end;
            if (src0_cur_end > ne01) {
                src0_cur_end = ne01;
            }

            if (src0_cur_start >= src0_cur_end) {
                return;
            }

            for (int ir1 = 0; ir1 < nr1; ir1++) {
                struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);

                const int id = row_mapping.i1;  // selected expert index

                const int64_t i11 = id % ne11;
                const int64_t i12 = row_mapping.i2;  // row index in src1

                const int64_t i1 = id;               // selected expert index
                const int64_t i2 = i12;              // row

                const auto * src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);

                gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(
                    ne00, (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
                    src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start);
            }
        }
#undef MMID_MATRIX_ROW
    }

    int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
        GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
                       (int) NB_COLS, (int) INTER_SIZE);
        return ggml::cpu::repack::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);
    }
};

}  // namespace ggml::cpu::repack

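// Select the widest interleaved layout the current CPU supports for the
// tensor's type; returns nullptr (no repacking) when the row count does not
// divide evenly into the interleave group or no suitable instruction set is
// available.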
static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(const struct ggml_tensor * cur) {
    // instance for Q4
    static const ggml::cpu::repack::tensor_traits<block_q4_0, 4, 4, GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
    static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 4, GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
    static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;

    // instance for Q4_K
    static const ggml::cpu::repack::tensor_traits<block_q4_K, 4, 8, GGML_TYPE_Q8_K> q4_K_8x4_q8_K;
    static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;

    // instance for Q5_K
    static const ggml::cpu::repack::tensor_traits<block_q5_K, 8, 8, GGML_TYPE_Q8_K> q5_K_8x8_q8_K;

    // instance for Q6_K
    static const ggml::cpu::repack::tensor_traits<block_q6_K, 4, 8, GGML_TYPE_Q8_K> q6_K_8x4_q8_K;
    static const ggml::cpu::repack::tensor_traits<block_q6_K, 8, 8, GGML_TYPE_Q8_K> q6_K_8x8_q8_K;

    // instance for Q2
    static const ggml::cpu::repack::tensor_traits<block_q2_K, 8, 8, GGML_TYPE_Q8_K> q2_K_8x8_q8_K;

    // instance for IQ4
    static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
    static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0> iq4_nl_8x8_q8_0;

    // instance for Q8_0
    static const ggml::cpu::repack::tensor_traits<block_q8_0, 4, 4, GGML_TYPE_Q8_0> q8_0_4x4_q8_0;
    static const ggml::cpu::repack::tensor_traits<block_q8_0, 8, 4, GGML_TYPE_Q8_0> q8_0_4x8_q8_0;

    if (cur->type == GGML_TYPE_Q4_0) {
        if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)
            || (ggml_cpu_has_riscv_v() && (ggml_cpu_get_rvv_vlen() >= QK4_0))) {
            if (cur->ne[1] % 8 == 0) {
                return &q4_0_8x8_q8_0;
            }
        }
        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
            if (cur->ne[1] % 4 == 0) {
                return &q4_0_4x8_q8_0;
            }
        }
        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
            if (cur->ne[1] % 4 == 0) {
                return &q4_0_4x4_q8_0;
            }
        }
    } else if (cur->type == GGML_TYPE_Q4_K) {
        if (ggml_cpu_has_avx2()) {
            if (cur->ne[1] % 8 == 0) {
                return &q4_K_8x8_q8_K;
            }
        }
        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
            if (cur->ne[1] % 8 == 0) {
                return &q4_K_8x8_q8_K;
            }
        }
        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
            if (cur->ne[1] % 8 == 0) {
                return &q4_K_8x4_q8_K;
            }
        }
    } else if (cur->type == GGML_TYPE_Q2_K) {
        if (ggml_cpu_has_avx512()) {
            if (cur->ne[1] % 8 == 0) {
                return &q2_K_8x8_q8_K;
            }
        }
    } else if (cur->type == GGML_TYPE_Q5_K) {
        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
            if (cur->ne[1] % 8 == 0) {
                return &q5_K_8x8_q8_K;
            }
        }
    } else if (cur->type == GGML_TYPE_Q6_K) {
        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
            if (cur->ne[1] % 8 == 0) {
                return &q6_K_8x8_q8_K;
            }
        }
        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
            if (cur->ne[1] % 8 == 0) {
                return &q6_K_8x4_q8_K;
            }
        }
    } else if (cur->type == GGML_TYPE_IQ4_NL) {
        if (ggml_cpu_has_avx2()) {
            if (cur->ne[1] % 8 == 0) {
                return &iq4_nl_8x8_q8_0;
            }
        }
        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
            if (cur->ne[1] % 4 == 0) {
                return &iq4_nl_4x4_q8_0;
            }
        }
    } else if (cur->type == GGML_TYPE_Q8_0) {
        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
            if (cur->ne[1] % 4 == 0) {
                return &q8_0_4x8_q8_0;
            }
        }
        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
            if (cur->ne[1] % 4 == 0) {
                return &q8_0_4x4_q8_0;
            }
        }
    }

    return nullptr;
}

static enum ggml_status ggml_backend_cpu_repack_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
    tensor->extra = (void *) const_cast<ggml::cpu::tensor_traits *>(ggml_repack_get_optimal_repack_type(tensor));

    GGML_UNUSED(buffer);
    return GGML_STATUS_SUCCESS;
}

static void ggml_backend_cpu_repack_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
                                                       const void * data, size_t offset, size_t size) {
    GGML_ASSERT(offset == 0);
    GGML_ASSERT(size == ggml_nbytes(tensor));

    auto tensor_traits = (ggml::cpu::repack::tensor_traits_base *) tensor->extra;
    auto OK            = tensor_traits->repack(tensor, data, size);

    GGML_ASSERT(OK == 0);
    GGML_UNUSED(buffer);
}

static const char * ggml_backend_cpu_repack_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    return "CPU_REPACK";

    GGML_UNUSED(buft);
}

static ggml_backend_buffer_t ggml_backend_cpu_repack_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);

    if (buffer == nullptr) {
        return nullptr;
    }

    buffer->buft              = buft;
    buffer->iface.init_tensor = ggml_backend_cpu_repack_buffer_init_tensor;
    buffer->iface.set_tensor  = ggml_backend_cpu_repack_buffer_set_tensor;
    buffer->iface.get_tensor  = nullptr;
    buffer->iface.cpy_tensor  = nullptr;
    return buffer;
}

static size_t ggml_backend_cpu_repack_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    return TENSOR_ALIGNMENT;

    GGML_UNUSED(buft);
}


namespace ggml::cpu::repack {
class extra_buffer_type : ggml::cpu::extra_buffer_type {
    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
        if (    op->op == GGML_OP_MUL_MAT &&
                op->src[0]->buffer &&
                (ggml_n_dims(op->src[0]) == 2) &&
                op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type() &&
                ggml_repack_get_optimal_repack_type(op->src[0])
                ) {
            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                return false;
            }
            if (op->src[1]->type == GGML_TYPE_F32) {
                return true;
            }
            //if (op->src[1]->type == GGML_TYPE_Q8_0) {
            //    return true;
            //}
            // may be possible if Q8_0 packed...
        } else if (op->op == GGML_OP_MUL_MAT_ID
                && op->src[0]->buffer
                && (ggml_n_dims(op->src[0]) == 3)
                && op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()
                && ggml_repack_get_optimal_repack_type(op->src[0])
                ) {
            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                return false;
            }
            if (op->src[1]->type == GGML_TYPE_F32) {
                return true;
            }
            //if (op->src[1]->type == GGML_TYPE_Q8_0) {
            //    return true;
            //}
        }
        return false;
    }

    ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) {
            if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()) {
                return (ggml::cpu::tensor_traits *) op->src[0]->extra;
            }
        }
        return nullptr;
    }
};
}  // namespace ggml::cpu::repack


ggml_backend_buffer_type_t ggml_backend_cpu_repack_buffer_type(void) {
    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_repack = {
        /* .iface    = */ {
                           /* .get_name         = */ ggml_backend_cpu_repack_buffer_type_get_name,
                           /* .alloc_buffer     = */ ggml_backend_cpu_repack_buffer_type_alloc_buffer,
                           /* .get_alignment    = */ ggml_backend_cpu_repack_buffer_type_get_alignment,
                           /* .get_max_size     = */ nullptr,  // defaults to SIZE_MAX
                           /* .get_alloc_size   = */ nullptr,  // defaults to ggml_nbytes
                           /* .is_host          = */ nullptr,
                           },
        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
        /* .context = */ new ggml::cpu::repack::extra_buffer_type(),
    };

    return &ggml_backend_cpu_buffer_type_repack;
}
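
// A minimal usage sketch (not part of this file; assumes the standard
// ggml-backend API): tensors allocated in this buffer type get their
// tensor_traits attached in init_tensor and are repacked on upload.
//
//   ggml_backend_buffer_type_t buft = ggml_backend_cpu_repack_buffer_type();
//   ggml_backend_buffer_t      buf  = ggml_backend_buft_alloc_buffer(buft, size);
//   // ggml_backend_tensor_set() on a tensor placed in `buf` routes through
//   // ggml_backend_cpu_repack_buffer_set_tensor(), which converts the
//   // row-major blocks into the interleaved layout in place.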