1#include "llama-quant.h"
   2#include "llama-impl.h"
   3#include "llama-model.h"
   4#include "llama-model-loader.h"
   5
   6#include <algorithm>
   7#include <cmath>
   8#include <cstring>
   9#include <cinttypes>
  10#include <fstream>
  11#include <mutex>
  12#include <regex>
  13#include <thread>
  14#include <unordered_map>
  15
  16// Quantization types. Changes to this struct must be replicated in quantize.cpp
  17struct tensor_quantization {
  18    std::string name;
  19    ggml_type quant = GGML_TYPE_COUNT;
  20};
  21
  22static void zeros(std::ofstream & file, size_t n) {
  23    char zero = 0;
  24    for (size_t i = 0; i < n; ++i) {
  25        file.write(&zero, 1);
  26    }
  27}
  28
  29static std::string remap_layer(const std::string & orig_name, const std::vector<int> & prune, std::map<int, std::string> & mapped, int & next_id) {
  30    if (prune.empty()) {
  31        return orig_name;
  32    }
  33
  34    static const std::regex pattern(R"(blk\.(\d+)\.)");
  35    if (std::smatch match; std::regex_search(orig_name, match, pattern)) {
  36        const int blk = std::stoi(match[1]);
  37        std::string new_name = orig_name;
  38
  39        if (mapped.count(blk)) {
  40            // Already mapped, do nothing
  41        } else if (std::find(prune.begin(), prune.end(), blk) != prune.end()) {
  42            mapped[blk] = "";
  43        } else if (blk < prune.front()) {
  44            mapped[blk] = std::to_string(blk);
  45            next_id = blk + 1;
  46        } else {
  47            mapped[blk] = std::to_string(next_id);
  48            ++next_id;
  49        }
  50
  51        return mapped[blk].empty() ? mapped[blk] : new_name.replace(match.position(1), match.length(1), mapped[blk]);
  52    }
  53
  54    return orig_name;
  55}
  56
  57static std::string remap_imatrix (const std::string & orig_name, const std::map<int, std::string> & mapped) {
  58    if (mapped.empty()) {
  59        return orig_name;
  60    }
  61
  62    static const std::regex pattern(R"(blk\.(\d+)\.)");
  63    if (std::smatch match; std::regex_search(orig_name, match, pattern)) {
  64        const std::string blk(match[1]);
  65        std::string new_name = orig_name;
  66
  67        for (const auto & p : mapped) {
  68            if (p.second == blk) {
  69                LLAMA_LOG_DEBUG("(blk.%d imatrix) ", p.first);
  70                return new_name.replace(match.position(1), match.length(1), std::to_string(p.first));
  71            }
  72        }
  73        GGML_ABORT("\n%s: imatrix mapping error for %s\n", __func__, orig_name.c_str());
  74    }
  75
  76    return orig_name;
  77}
  78
  79struct quantize_state_impl {
  80    const llama_model                 & model;
  81    const llama_model_quantize_params * params;
  82
  83    int n_attention_wv = 0;
  84    int n_ffn_down     = 0;
  85    int n_ffn_gate     = 0;
  86    int n_ffn_up       = 0;
  87    int i_attention_wv = 0;
  88    int i_ffn_down     = 0;
  89    int i_ffn_gate     = 0;
  90    int i_ffn_up       = 0;
  91
  92    int n_k_quantized = 0;
  93    int n_fallback    = 0;
  94
  95    bool has_imatrix = false;
  96
  97    // used to figure out if a model shares tok_embd with the output weight
  98    bool has_output = false;
  99
 100    quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params)
 101        : model(model)
 102        , params(params)
 103        {}
 104};
 105
 106static void llama_tensor_dequantize_impl(
 107    ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
 108    const size_t nelements, const int nthread
 109) {
 110    if (output.size() < nelements) {
 111        output.resize(nelements);
 112    }
 113    float * f32_output = (float *) output.data();
 114
 115    const ggml_type_traits * qtype = ggml_get_type_traits(tensor->type);
 116    if (ggml_is_quantized(tensor->type)) {
 117        if (qtype->to_float == NULL) {
 118            throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor->type)));
 119        }
 120    } else if (tensor->type != GGML_TYPE_F16 &&
 121               tensor->type != GGML_TYPE_BF16) {
 122        throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor->type)));
 123    }
 124
 125    if (nthread < 2) {
 126        if (tensor->type == GGML_TYPE_F16) {
 127            ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor->data, f32_output, nelements);
 128        } else if (tensor->type == GGML_TYPE_BF16) {
 129            ggml_bf16_to_fp32_row((ggml_bf16_t *)tensor->data, f32_output, nelements);
 130        } else if (ggml_is_quantized(tensor->type)) {
 131            qtype->to_float(tensor->data, f32_output, nelements);
 132        } else {
 133            GGML_ABORT("fatal error"); // unreachable
 134        }
 135        return;
 136    }
 137
 138    size_t block_size;
 139    if (tensor->type == GGML_TYPE_F16 ||
 140        tensor->type == GGML_TYPE_BF16) {
 141        block_size = 1;
 142    } else {
 143        block_size = (size_t)ggml_blck_size(tensor->type);
 144    }
 145
 146    size_t block_size_bytes = ggml_type_size(tensor->type);
 147
 148    GGML_ASSERT(nelements % block_size == 0);
 149    size_t nblocks = nelements / block_size;
 150    size_t blocks_per_thread = nblocks / nthread;
 151    size_t spare_blocks = nblocks - (blocks_per_thread * nthread); // if blocks aren't divisible by thread count
 152
 153    size_t in_buff_offs = 0;
 154    size_t out_buff_offs = 0;
 155
 156    for (int tnum = 0; tnum < nthread; tnum++) {
 157        size_t thr_blocks = blocks_per_thread + (tnum == nthread - 1 ? spare_blocks : 0); // num blocks for this thread
 158        size_t thr_elems = thr_blocks * block_size; // number of elements for this thread
 159        size_t thr_block_bytes = thr_blocks * block_size_bytes; // number of input bytes for this thread
 160
 161        auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) {
 162            if (typ == GGML_TYPE_F16) {
 163                ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels);
 164            } else if (typ == GGML_TYPE_BF16) {
 165                ggml_bf16_to_fp32_row((ggml_bf16_t *)inbuf, outbuf, nels);
 166            } else {
 167                qtype->to_float(inbuf, outbuf, nels);
 168            }
 169        };
 170        workers.emplace_back(compute, tensor->type, (uint8_t *) tensor->data + in_buff_offs, f32_output + out_buff_offs, thr_elems);
 171        in_buff_offs += thr_block_bytes;
 172        out_buff_offs += thr_elems;
 173    }
 174    for (auto & w : workers) { w.join(); }
 175    workers.clear();
 176}
 177
// Choose the storage type for `tensor` when quantizing a model to `ftype`.
//
// `new_type` is the type proposed by the caller (presumably the default type of `ftype` - TODO
// confirm against the caller); the name-based rules below may replace it. Which rule applies depends
// on the tensor name, `ftype`, the architecture/hparams of qs.model (n_gqa(), n_expert, model type)
// and qs.has_imatrix.
// User overrides:
//   - qs.params->output_tensor_type (when < GGML_TYPE_COUNT) is used for the output tensor, and for
//     token_embd when the model has no separate output tensor (!qs.has_output).
//   - qs.params->token_embedding_type is used for token_embd/per_layer_token_embd, except for
//     MXFP4_MOE, whose branch is checked first.
//
// Side effects: increments qs.i_attention_wv / i_ffn_down / i_ffn_gate / i_ffn_up for the
// corresponding tensors. Rules that depend on the layer position therefore depend on the order in
// which tensors are passed in.
// Throws std::runtime_error when, for a model with more than one expert, the layer index of an ffn
// tensor cannot be parsed from its name or is out of range.
static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
    const std::string name = ggml_get_name(tensor);

    // TODO: avoid hardcoded tensor names - use the TN_* constants
    const llm_arch arch = qs.model.arch;
    const auto       tn = LLM_TN(arch);

    // true for the first eighth and the last eighth of the layers, and for every third layer in between
    auto use_more_bits = [](int i_layer, int n_layers) -> bool {
        return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2;
    };
    const int n_expert = std::max(1, (int)qs.model.hparams.n_expert);
    // returns {layer index, layer count}; for multi-expert models the index is parsed from the
    // tensor name instead of trusting the running counter passed in as `i_layer`
    auto layer_info = [n_expert] (int i_layer, int n_layer, const char * name) {
        if (n_expert > 1) {
            // Believe it or not, "experts" in the FFN of Mixtral-8x7B are not consecutive, but occasionally randomly
            // sprinkled in the model. Hence, simply dividing i_ffn_down by n_expert does not work
            // for getting the current layer as I initially thought, and we need to resort to parsing the
            // tensor name.
            if (sscanf(name, "blk.%d.", &i_layer) != 1) {
                throw std::runtime_error(format("Failed to determine layer for tensor %s", name));
            }
            if (i_layer < 0 || i_layer >= n_layer) {
                throw std::runtime_error(format("Bad layer %d for tensor %s. Must be in [0, %d)", i_layer, name, n_layer));
            }
        }
        return std::make_pair(i_layer, n_layer);
    };

    // for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings
    // with the quantization of the output tensor
    if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) {
        if (qs.params->output_tensor_type < GGML_TYPE_COUNT) {
            new_type = qs.params->output_tensor_type;
        } else {
            // row length vs. block size of the proposed type: a non-multiple falls back to Q8_0 below
            const int64_t nx = tensor->ne[0];
            const int64_t qk_k = ggml_blck_size(new_type);

            if (ftype == LLAMA_FTYPE_MOSTLY_MXFP4_MOE) {
                new_type = GGML_TYPE_Q8_0;
            }
            else if (arch == LLM_ARCH_FALCON || nx % qk_k != 0) {
                new_type = GGML_TYPE_Q8_0;
            }
            else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
                     ftype == LLAMA_FTYPE_MOSTLY_IQ1_S   || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S  || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M   ||
                     ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
                new_type = GGML_TYPE_Q5_K;
            }
            // any other proposed type except Q8_0 becomes Q6_K
            else if (new_type != GGML_TYPE_Q8_0) {
                new_type = GGML_TYPE_Q6_K;
            }
        }
    } else if (ftype == LLAMA_FTYPE_MOSTLY_MXFP4_MOE) {
        // MoE   tensors -> MXFP4
        // other tensors -> Q8_0
        // (a tensor counts as MoE here when its third dimension ne[2] is > 1)
        if (tensor->ne[2] > 1) {
            new_type = GGML_TYPE_MXFP4;
        } else {
            new_type = GGML_TYPE_Q8_0;
        }
    } else if (name == "token_embd.weight" || name == "per_layer_token_embd.weight") {
        if (qs.params->token_embedding_type < GGML_TYPE_COUNT) {
            new_type = qs.params->token_embedding_type;
        } else {
            if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS ||
                ftype == LLAMA_FTYPE_MOSTLY_IQ1_S   || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
                new_type = GGML_TYPE_Q2_K;
            }
            else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) {
                new_type = GGML_TYPE_IQ3_S;
            }
            else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
                new_type = GGML_TYPE_IQ3_S;
            }
            else if (ftype == LLAMA_FTYPE_MOSTLY_TQ1_0 || ftype == LLAMA_FTYPE_MOSTLY_TQ2_0) {
                new_type = GGML_TYPE_Q4_K;
            }
        }
    // IQ1_* / IQ2_* ftypes: per-tensor overrides for attn_v, attn_k (8 experts), the first eighth of
    // ffn_down tensors (by running count), and attn_output
    } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
               ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M    || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
        if (name.find("attn_v.weight") != std::string::npos) {
            if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K;
            else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
            ++qs.i_attention_wv;
        }
        else if (qs.model.hparams.n_expert == 8 && name.find("attn_k.weight") != std::string::npos) {
            new_type = GGML_TYPE_Q4_K;
        }
        else if (name.find("ffn_down") != std::string::npos) {
            if (qs.i_ffn_down < qs.n_ffn_down/8) {
                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
            }
            ++qs.i_ffn_down;
        }
        else if (name.find("attn_output.weight") != std::string::npos) {
            if (qs.model.hparams.n_expert == 8) {
                new_type = GGML_TYPE_Q5_K;
            } else {
                if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) new_type = GGML_TYPE_IQ2_XXS;
                else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) new_type = GGML_TYPE_IQ3_S;
            }
        }
    // attn_v: per-ftype overrides; qs.i_attention_wv is the number of attn_v tensors seen before this one
    } else if (name.find("attn_v.weight") != std::string::npos) {
        if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) {
            new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && qs.model.hparams.n_gqa() >= 4) {
            new_type = GGML_TYPE_Q4_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
            new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : !qs.has_imatrix ? GGML_TYPE_IQ3_S : GGML_TYPE_IQ3_XXS;
        }
        else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S) && qs.model.hparams.n_gqa() >= 4) {
            new_type = GGML_TYPE_Q4_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
            new_type = GGML_TYPE_Q4_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
            new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
        else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && qs.model.hparams.n_gqa() >= 4) {
            new_type = GGML_TYPE_Q5_K;
        }
        else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
                use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) new_type = GGML_TYPE_Q6_K;
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && qs.i_attention_wv < 4) new_type = GGML_TYPE_Q5_K;
        if (qs.model.type == LLM_TYPE_70B) {
            // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is
            // 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with
            // nearly negligible increase in model size by quantizing this tensor with more bits:
            if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K;
        }
        if (qs.model.hparams.n_expert == 8) {
            // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
            // TODO: explore better strategies
            new_type = GGML_TYPE_Q8_0;
        }
        ++qs.i_attention_wv;
    } else if (name.find("attn_k.weight") != std::string::npos) {
        if (qs.model.hparams.n_expert == 8) {
            // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
            // TODO: explore better strategies
            new_type = GGML_TYPE_Q8_0;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
            new_type = GGML_TYPE_IQ3_XXS;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
            new_type = GGML_TYPE_IQ2_S;
        }
    } else if (name.find("attn_q.weight") != std::string::npos) {
        if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
            new_type = GGML_TYPE_IQ3_XXS;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
            new_type = GGML_TYPE_IQ2_S;
        }
    // ffn_down: overrides depend on the layer position (i_layer / n_layer from layer_info)
    } else if (name.find("ffn_down") != std::string::npos) {
        auto info = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str());
        int i_layer = info.first, n_layer = info.second;
        if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S) {
            if (i_layer < n_layer/8) new_type = GGML_TYPE_Q4_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS && !qs.has_imatrix) {
            new_type = i_layer < n_layer/8 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
            new_type = i_layer < n_layer/16 ? GGML_TYPE_Q5_K
                     : arch != LLM_ARCH_FALCON || use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q4_K
                     : GGML_TYPE_Q3_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M && (i_layer < n_layer/8 ||
                    (qs.model.hparams.n_expert == 8 && use_more_bits(i_layer, n_layer)))) {
            new_type = GGML_TYPE_Q4_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) {
            new_type = arch == LLM_ARCH_FALCON ? GGML_TYPE_Q4_K : GGML_TYPE_Q5_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) {
            if (arch == LLM_ARCH_FALCON) {
                new_type = i_layer < n_layer/16 ? GGML_TYPE_Q6_K :
                           use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
            } else {
                if (use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
            }
        }
        else if (i_layer < n_layer/8 && (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && !qs.has_imatrix) {
            new_type = GGML_TYPE_Q5_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && arch != LLM_ARCH_FALCON && i_layer < n_layer/8) {
            new_type = GGML_TYPE_Q5_K;
        }
        else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || ftype == LLAMA_FTYPE_MOSTLY_Q5_0)
                && qs.has_imatrix && i_layer < n_layer/8) {
            // Guard against craziness in the first few ffn_down layers that can happen even with imatrix for Q4_0/Q5_0.
            // We only do it when an imatrix is provided because a) we want to make sure that one can always get the
            // same quantization as before imatrix stuff, and b) Q4_1/Q5_1 do go crazy on ffn_down without an imatrix.
            new_type = ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ? GGML_TYPE_Q4_1 : GGML_TYPE_Q5_1;
        }
        ++qs.i_ffn_down;
    } else if (name.find("attn_output.weight") != std::string::npos) {
        if (arch != LLM_ARCH_FALCON) {
            if (qs.model.hparams.n_expert == 8) {
                if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K   || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
                    ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M  || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL  ||
                    ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M  || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S  ||
                    ftype == LLAMA_FTYPE_MOSTLY_IQ3_M  || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) {
                    new_type = GGML_TYPE_Q5_K;
                }
            } else {
                if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K   ) new_type = GGML_TYPE_Q3_K;
                else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) new_type = GGML_TYPE_IQ3_S;
                else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M ) new_type = GGML_TYPE_Q4_K;
                else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L ) new_type = GGML_TYPE_Q5_K;
                else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M  ) new_type = GGML_TYPE_Q4_K;
            }
        } else {
            if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K;
        }
    }
    else if (name.find("attn_qkv.weight") != std::string::npos) {
        if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
            new_type = GGML_TYPE_Q4_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K;
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
    }
    // ffn_gate / ffn_up: IQ3_XS uses IQ3_XXS for the middle layers (outside the first and last eighth)
    else if (name.find("ffn_gate") != std::string::npos) {
        auto info = layer_info(qs.i_ffn_gate, qs.n_ffn_gate, name.c_str());
        int i_layer = info.first, n_layer = info.second;
        if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
            new_type = GGML_TYPE_IQ3_XXS;
        }
        ++qs.i_ffn_gate;
    }
    else if (name.find("ffn_up") != std::string::npos) {
        auto info = layer_info(qs.i_ffn_up, qs.n_ffn_up, name.c_str());
        int i_layer = info.first, n_layer = info.second;
        if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
            new_type = GGML_TYPE_IQ3_XXS;
        }
        ++qs.i_ffn_up;
    }

    return new_type;
}
 427
 428static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector<std::thread> & workers, const int nthread) {
 429    if (nthread < 2) {
 430        // single-thread
 431        size_t new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix);
 432        if (!ggml_validate_row_data(new_type, new_data, new_size)) {
 433            throw std::runtime_error("quantized data validation failed");
 434        }
 435        return new_size;
 436    }
 437
 438    std::mutex mutex;
 439    int64_t counter = 0;
 440    size_t new_size = 0;
 441    bool valid = true;
 442    auto compute = [&mutex, &counter, &new_size, &valid, new_type, f32_data, new_data, chunk_size,
 443            nrows, n_per_row, imatrix]() {
 444        const int64_t nrows_per_chunk = chunk_size / n_per_row;
 445        size_t local_size = 0;
 446        while (true) {
 447            std::unique_lock<std::mutex> lock(mutex);
 448            int64_t first_row = counter; counter += nrows_per_chunk;
 449            if (first_row >= nrows) {
 450                if (local_size > 0) {
 451                    new_size += local_size;
 452                }
 453                break;
 454            }
 455            lock.unlock();
 456            const int64_t this_nrow = std::min(nrows - first_row, nrows_per_chunk);
 457            size_t this_size = ggml_quantize_chunk(new_type, f32_data, new_data, first_row * n_per_row, this_nrow, n_per_row, imatrix);
 458            local_size += this_size;
 459
 460            // validate the quantized data
 461            const size_t row_size  = ggml_row_size(new_type, n_per_row);
 462            void * this_data = (char *) new_data + first_row * row_size;
 463            if (!ggml_validate_row_data(new_type, this_data, this_size)) {
 464                std::unique_lock<std::mutex> lock(mutex);
 465                valid = false;
 466                break;
 467            }
 468        }
 469    };
 470    for (int it = 0; it < nthread - 1; ++it) {
 471        workers.emplace_back(compute);
 472    }
 473    compute();
 474    for (auto & w : workers) { w.join(); }
 475    workers.clear();
 476    if (!valid) {
 477        throw std::runtime_error("quantized data validation failed");
 478    }
 479    return new_size;
 480}
 481
 482static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
 483    ggml_type default_type;
 484    llama_ftype ftype = params->ftype;
 485
 486    switch (params->ftype) {
 487        case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break;
 488        case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break;
 489        case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break;
 490        case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break;
 491        case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break;
 492        case LLAMA_FTYPE_MOSTLY_F16:  default_type = GGML_TYPE_F16;  break;
 493        case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
 494        case LLAMA_FTYPE_ALL_F32:     default_type = GGML_TYPE_F32;  break;
 495
 496        case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: default_type = GGML_TYPE_MXFP4; break;
 497
 498        // K-quants
 499        case LLAMA_FTYPE_MOSTLY_Q2_K_S:
 500        case LLAMA_FTYPE_MOSTLY_Q2_K:    default_type = GGML_TYPE_Q2_K;    break;
 501        case LLAMA_FTYPE_MOSTLY_IQ3_XS:  default_type = GGML_TYPE_IQ3_S;   break;
 502        case LLAMA_FTYPE_MOSTLY_Q3_K_S:
 503        case LLAMA_FTYPE_MOSTLY_Q3_K_M:
 504        case LLAMA_FTYPE_MOSTLY_Q3_K_L:  default_type = GGML_TYPE_Q3_K;    break;
 505        case LLAMA_FTYPE_MOSTLY_Q4_K_S:
 506        case LLAMA_FTYPE_MOSTLY_Q4_K_M:  default_type = GGML_TYPE_Q4_K;    break;
 507        case LLAMA_FTYPE_MOSTLY_Q5_K_S:
 508        case LLAMA_FTYPE_MOSTLY_Q5_K_M:  default_type = GGML_TYPE_Q5_K;    break;
 509        case LLAMA_FTYPE_MOSTLY_Q6_K:    default_type = GGML_TYPE_Q6_K;    break;
 510        case LLAMA_FTYPE_MOSTLY_TQ1_0:   default_type = GGML_TYPE_TQ1_0;   break;
 511        case LLAMA_FTYPE_MOSTLY_TQ2_0:   default_type = GGML_TYPE_TQ2_0;   break;
 512        case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break;
 513        case LLAMA_FTYPE_MOSTLY_IQ2_XS:  default_type = GGML_TYPE_IQ2_XS;  break;
 514        case LLAMA_FTYPE_MOSTLY_IQ2_S:   default_type = GGML_TYPE_IQ2_XS;  break;
 515        case LLAMA_FTYPE_MOSTLY_IQ2_M:   default_type = GGML_TYPE_IQ2_S;   break;
 516        case LLAMA_FTYPE_MOSTLY_IQ3_XXS: default_type = GGML_TYPE_IQ3_XXS; break;
 517        case LLAMA_FTYPE_MOSTLY_IQ1_S:   default_type = GGML_TYPE_IQ1_S;   break;
 518        case LLAMA_FTYPE_MOSTLY_IQ1_M:   default_type = GGML_TYPE_IQ1_M;   break;
 519        case LLAMA_FTYPE_MOSTLY_IQ4_NL:  default_type = GGML_TYPE_IQ4_NL;  break;
 520        case LLAMA_FTYPE_MOSTLY_IQ4_XS:  default_type = GGML_TYPE_IQ4_XS;  break;
 521        case LLAMA_FTYPE_MOSTLY_IQ3_S:   default_type = GGML_TYPE_IQ3_S;   break;
 522        case LLAMA_FTYPE_MOSTLY_IQ3_M:   default_type = GGML_TYPE_IQ3_S;   break;
 523
 524        default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
 525    }
 526
 527    int nthread = params->nthread;
 528
 529    if (nthread <= 0) {
 530        nthread = std::thread::hardware_concurrency();
 531    }
 532
 533    // mmap consistently increases speed on Linux, and also increases speed on Windows with
 534    // hot cache. It may cause a slowdown on macOS, possibly related to free memory.
 535#if defined(__linux__) || defined(_WIN32)
 536    constexpr bool use_mmap = true;
 537#else
 538    constexpr bool use_mmap = false;
 539#endif
 540
 541    llama_model_kv_override * kv_overrides = nullptr;
 542    if (params->kv_overrides) {
 543        auto * v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
 544        kv_overrides = v->data();
 545    }
 546
 547    std::vector<std::string> splits = {};
 548    llama_model_loader ml(fname_inp, splits, use_mmap, /*use_direct_io*/ false, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr);
 549    ml.init_mappings(false); // no prefetching
 550
 551    llama_model model(llama_model_default_params());
 552
 553    model.load_arch   (ml);
 554    model.load_hparams(ml);
 555    model.load_stats  (ml);
 556
 557    quantize_state_impl qs(model, params);
 558
 559    if (params->only_copy) {
 560        ftype = ml.ftype;
 561    }
 562    const std::unordered_map<std::string, std::vector<float>> * imatrix_data = nullptr;
 563    if (params->imatrix) {
 564        imatrix_data = static_cast<const std::unordered_map<std::string, std::vector<float>>*>(params->imatrix);
 565        if (imatrix_data) {
 566            LLAMA_LOG_INFO("================================ Have weights data with %d entries\n",int(imatrix_data->size()));
 567            qs.has_imatrix = true;
 568            // check imatrix for nans or infs
 569            for (const auto & kv : *imatrix_data) {
 570                for (float f : kv.second) {
 571                    if (!std::isfinite(f)) {
 572                        throw std::runtime_error(format("imatrix contains non-finite value %f\n", f));
 573                    }
 574                }
 575            }
 576        }
 577    }
 578
 579    const size_t align = GGUF_DEFAULT_ALIGNMENT;
 580    gguf_context_ptr ctx_out { gguf_init_empty() };
 581
 582    std::vector<int> prune_list = {};
 583    if (params->prune_layers) {
 584        prune_list = *static_cast<const std::vector<int> *>(params->prune_layers);
 585    }
 586
 587    // copy the KV pairs from the input file
 588    gguf_set_kv     (ctx_out.get(), ml.meta.get());
 589    gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
 590    gguf_set_val_u32(ctx_out.get(), "general.file_type", ftype); // TODO: use LLM_KV
 591
 592    // Remove split metadata
 593    gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str());
 594    gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str());
 595    gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str());
 596
 597    if (params->kv_overrides) {
 598        const std::vector<llama_model_kv_override> & overrides = *(const std::vector<llama_model_kv_override> *)params->kv_overrides;
 599        for (const auto & o : overrides) {
 600            if (o.key[0] == 0) break;
 601            if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) {
 602                gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64);
 603            } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
 604                // Setting type to UINT32. See https://github.com/ggml-org/llama.cpp/pull/14182 for context
 605                gguf_set_val_u32(ctx_out.get(), o.key, (uint32_t)std::abs(o.val_i64));
 606            } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
 607                gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool);
 608            } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
 609                gguf_set_val_str(ctx_out.get(), o.key, o.val_str);
 610            } else {
 611                LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key);
 612            }
 613        }
 614    }
 615
 616    std::map<int, std::string> mapped;
 617    int blk_id = 0;
 618
 619    // make a list of weights
 620    std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
 621    tensors.reserve(ml.weights_map.size());
 622    for (const auto & it : ml.weights_map) {
 623        const std::string remapped_name(remap_layer(it.first, prune_list, mapped, blk_id));
 624        if (remapped_name.empty()) {
 625            LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str());
 626            continue;
 627        }
 628
 629        if (remapped_name != it.first) {
 630            ggml_set_name(it.second.tensor, remapped_name.c_str());
 631            LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor));
 632        }
 633        tensors.push_back(&it.second);
 634    }
 635    if (!prune_list.empty()) {
 636        gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_BLOCK_COUNT).c_str(), blk_id);
 637    }
 638
 639    // keep_split requires that the weights are sorted by split index
 640    if (params->keep_split) {
 641        std::sort(tensors.begin(), tensors.end(), [](const llama_model_loader::llama_tensor_weight * a, const llama_model_loader::llama_tensor_weight * b) {
 642            if (a->idx == b->idx) {
 643                return a->offs < b->offs;
 644            }
 645            return a->idx < b->idx;
 646        });
 647    }
 648
 649    for (const auto * it : tensors) {
 650        const struct ggml_tensor * tensor = it->tensor;
 651
 652        const std::string name = ggml_get_name(tensor);
 653
 654        // TODO: avoid hardcoded tensor names - use the TN_* constants
 655        if (name.find("attn_v.weight")   != std::string::npos ||
 656            name.find("attn_qkv.weight") != std::string::npos ||
 657            name.find("attn_kv_b.weight")!= std::string::npos) {
 658            ++qs.n_attention_wv;
 659        } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
 660            qs.has_output = true;
 661        }
 662    }
 663
 664    qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
 665
 666    size_t total_size_org = 0;
 667    size_t total_size_new = 0;
 668
 669    std::vector<std::thread> workers;
 670    workers.reserve(nthread);
 671
 672    int idx = 0;
 673
 674    std::vector<no_init<uint8_t>> read_data;
 675    std::vector<no_init<uint8_t>> work;
 676    std::vector<no_init<float>> f32_conv_buf;
 677
 678    uint16_t n_split = 1;
 679
 680    // Assume split index is continuous
 681    if (params->keep_split) {
 682        for (const auto * it : tensors) {
 683            n_split = std::max(uint16_t(it->idx + 1), n_split);
 684        }
 685    }
 686    std::vector<gguf_context_ptr> ctx_outs(n_split);
 687    ctx_outs[0] = std::move(ctx_out);
 688
 689    // populate the original tensors so we get an initial meta data
 690    for (const auto * it : tensors) {
 691        uint16_t i_split = params->keep_split ? it->idx : 0;
 692        ggml_tensor * tensor = it->tensor;
 693        if (!ctx_outs[i_split]) {
 694            ctx_outs[i_split].reset(gguf_init_empty());
 695        }
 696        gguf_add_tensor(ctx_outs[i_split].get(), tensor);
 697    }
 698
 699    // Set split info if needed
 700    if (n_split > 1) {
 701        for (size_t i = 0; i < ctx_outs.size(); ++i) {
 702            gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i);
 703            gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split);
 704            gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), (int32_t)tensors.size());
 705        }
 706    }
 707
 708    int cur_split = -1;
 709    std::ofstream fout;
 710    auto close_ofstream = [&]() {
 711        // Write metadata and close file handler
 712        if (fout.is_open()) {
 713            fout.seekp(0);
 714            std::vector<uint8_t> data(gguf_get_meta_size(ctx_outs[cur_split].get()));
 715            gguf_get_meta_data(ctx_outs[cur_split].get(), data.data());
 716            fout.write((const char *) data.data(), data.size());
 717            fout.close();
 718        }
 719    };
 720    auto new_ofstream = [&](int index) {
 721        cur_split = index;
 722        GGML_ASSERT(ctx_outs[cur_split] && "Find uninitialized gguf_context");
 723        std::string fname = fname_out;
 724        if (params->keep_split) {
 725            std::vector<char> split_path(llama_path_max(), 0);
 726            llama_split_path(split_path.data(), split_path.size(), fname_out.c_str(), cur_split, n_split);
 727            fname = std::string(split_path.data());
 728        }
 729
 730        fout = std::ofstream(fname, std::ios::binary);
 731        fout.exceptions(std::ofstream::failbit); // fail fast on write errors
 732        const size_t meta_size = gguf_get_meta_size(ctx_outs[cur_split].get());
 733        // placeholder for the meta data
 734        ::zeros(fout, meta_size);
 735    };
 736
 737    const auto tn = LLM_TN(model.arch);
 738    new_ofstream(0);
 739    for (const auto * it : tensors) {
 740        const auto & weight = *it;
 741        ggml_tensor * tensor = weight.tensor;
 742        if (weight.idx != cur_split && params->keep_split) {
 743            close_ofstream();
 744            new_ofstream(weight.idx);
 745        }
 746
 747        const std::string name = ggml_get_name(tensor);
 748
 749        if (!ml.use_mmap) {
 750            if (read_data.size() < ggml_nbytes(tensor)) {
 751                read_data.resize(ggml_nbytes(tensor));
 752            }
 753            tensor->data = read_data.data();
 754        }
 755        ml.load_data_for(tensor);
 756
 757        LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ",
 758               ++idx, ml.n_tensors,
 759               ggml_get_name(tensor),
 760               llama_format_tensor_shape(tensor).c_str(),
 761               ggml_type_name(tensor->type));
 762
 763        // This used to be a regex, but <regex> has an extreme cost to compile times.
 764        bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
 765
 766        // quantize only 2D and 3D tensors (experts)
 767        quantize &= (ggml_n_dims(tensor) >= 2);
 768
 769        // do not quantize norm tensors
 770        quantize &= name.find("_norm.weight") == std::string::npos;
 771
 772        quantize &= params->quantize_output_tensor || name != "output.weight";
 773        quantize &= !params->only_copy;
 774
 775        // do not quantize expert gating tensors
 776        // NOTE: can't use LLM_TN here because the layer number is not known
 777        quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;
 778
 779        // these are very small (e.g. 4x4)
 780        quantize &= name.find("altup")  == std::string::npos;
 781        quantize &= name.find("laurel") == std::string::npos;
 782
 783        // these are not too big so keep them as it is
 784        quantize &= name.find("per_layer_model_proj") == std::string::npos;
 785
 786        // do not quantize positional embeddings and token types (BERT)
 787        quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD,    "weight");
 788        quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight");
 789
 790        // do not quantize Mamba /Kimi's small conv1d weights
 791        // NOTE: can't use LLM_TN here because the layer number is not known
 792        quantize &= name.find("ssm_conv1d") == std::string::npos;
 793        quantize &= name.find("shortconv.conv.weight") == std::string::npos;
 794
 795        // do not quantize RWKV's small yet 2D weights
 796        quantize &= name.find("time_mix_first.weight") == std::string::npos;
 797        quantize &= name.find("time_mix_w0.weight") == std::string::npos;
 798        quantize &= name.find("time_mix_w1.weight") == std::string::npos;
 799        quantize &= name.find("time_mix_w2.weight") == std::string::npos;
 800        quantize &= name.find("time_mix_v0.weight") == std::string::npos;
 801        quantize &= name.find("time_mix_v1.weight") == std::string::npos;
 802        quantize &= name.find("time_mix_v2.weight") == std::string::npos;
 803        quantize &= name.find("time_mix_a0.weight") == std::string::npos;
 804        quantize &= name.find("time_mix_a1.weight") == std::string::npos;
 805        quantize &= name.find("time_mix_a2.weight") == std::string::npos;
 806        quantize &= name.find("time_mix_g1.weight") == std::string::npos;
 807        quantize &= name.find("time_mix_g2.weight") == std::string::npos;
 808        quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
 809        quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
 810        quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
 811
 812        // do not quantize relative position bias (T5)
 813        quantize &= name.find("attn_rel_b.weight") == std::string::npos;
 814
 815        // do not quantize specific multimodal tensors
 816        quantize &= name.find(".position_embd.") == std::string::npos;
 817
 818        ggml_type new_type;
 819        void * new_data;
 820        size_t new_size;
 821
 822        if (quantize) {
 823            new_type = default_type;
 824
 825            // get more optimal quantization type based on the tensor shape, layer, etc.
 826            if (!params->pure && ggml_is_quantized(default_type)) {
 827                // if the user provided tensor types - use those
 828                bool manual = false;
 829                if (params->tensor_types) {
 830                    const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
 831                    const std::string tensor_name(tensor->name);
 832                    for (const auto & [tname, qtype] : tensor_types) {
 833                        if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) {
 834                            if  (qtype != new_type) {
 835                                LLAMA_LOG_WARN("(manual override: %s -> %s) ", ggml_type_name(new_type), ggml_type_name(qtype));
 836                                new_type = qtype; // if two or more types are specified for the same tensor, the last match wins
 837                                manual = true;
 838                                break;
 839                            }
 840                        }
 841                    }
 842                }
 843
 844                // if not manual - use the standard logic for choosing the quantization type based on the selected mixture
 845                if (!manual) {
 846                    new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
 847                }
 848
 849                // incompatible tensor shapes are handled here - fallback to a compatible type
 850                {
 851                    bool convert_incompatible_tensor = false;
 852
 853                    const int64_t nx = tensor->ne[0];
 854                    const int64_t ny = tensor->ne[1];
 855                    const int64_t qk_k = ggml_blck_size(new_type);
 856
 857                    if (nx % qk_k != 0) {
 858                        LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type));
 859                        convert_incompatible_tensor = true;
 860                    } else {
 861                        ++qs.n_k_quantized;
 862                    }
 863
 864                    if (convert_incompatible_tensor) {
 865                        switch (new_type) {
 866                            case GGML_TYPE_TQ1_0:
 867                            case GGML_TYPE_TQ2_0:  new_type = GGML_TYPE_Q4_0; break;  // TODO: use a symmetric type instead
 868                            case GGML_TYPE_IQ2_XXS:
 869                            case GGML_TYPE_IQ2_XS:
 870                            case GGML_TYPE_IQ2_S:
 871                            case GGML_TYPE_IQ3_XXS:
 872                            case GGML_TYPE_IQ3_S:
 873                            case GGML_TYPE_IQ1_S:
 874                            case GGML_TYPE_IQ1_M:
 875                            case GGML_TYPE_Q2_K:
 876                            case GGML_TYPE_Q3_K:
 877                            case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break;
 878                            case GGML_TYPE_Q4_K:   new_type = GGML_TYPE_Q5_0;   break;
 879                            case GGML_TYPE_Q5_K:   new_type = GGML_TYPE_Q5_1;   break;
 880                            case GGML_TYPE_Q6_K:   new_type = GGML_TYPE_Q8_0;   break;
 881                            default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
 882                        }
 883                        if (tensor->ne[0] % ggml_blck_size(new_type) != 0) {
 884                            new_type = GGML_TYPE_F16;
 885                        }
 886                        LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type));
 887                        ++qs.n_fallback;
 888                    }
 889                }
 890            }
 891            if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
 892                new_type = params->token_embedding_type;
 893            }
 894            if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) {
 895                new_type = params->output_tensor_type;
 896            }
 897
 898            // If we've decided to quantize to the same type the tensor is already
 899            // in then there's nothing to do.
 900            quantize = tensor->type != new_type;
 901        }
 902
 903        if (!quantize) {
 904            new_type = tensor->type;
 905            new_data = tensor->data;
 906            new_size = ggml_nbytes(tensor);
 907            LLAMA_LOG_INFO("size = %8.3f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0);
 908        } else {
 909            const int64_t nelements = ggml_nelements(tensor);
 910
 911            const float * imatrix = nullptr;
 912            if (imatrix_data) {
 913                auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped));
 914                if (it == imatrix_data->end()) {
 915                    LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
 916                } else {
 917                    if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) {
 918                        imatrix = it->second.data();
 919                    } else {
 920                        LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__,
 921                                int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name);
 922
 923                        // this can happen when quantizing an old mixtral model with split tensors with a new incompatible imatrix
 924                        // this is a significant error and it may be good idea to abort the process if this happens,
 925                        // since many people will miss the error and not realize that most of the model is being quantized without an imatrix
 926                        // tok_embd should be ignored in this case, since it always causes this warning
 927                        if (name != tn(LLM_TENSOR_TOKEN_EMBD, "weight")) {
 928                            throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s",
 929                                    int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name));
 930                        }
 931                    }
 932                }
 933            }
 934            if ((new_type == GGML_TYPE_IQ2_XXS ||
 935                 new_type == GGML_TYPE_IQ2_XS  ||
 936                 new_type == GGML_TYPE_IQ2_S   ||
 937                 new_type == GGML_TYPE_IQ1_S   ||
 938                (new_type == GGML_TYPE_IQ1_M && strcmp(tensor->name, "token_embd.weight") && strcmp(tensor->name, "output.weight"))  ||
 939                (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0)) && !imatrix) {
 940                LLAMA_LOG_ERROR("\n\n============================================================\n");
 941                LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name);
 942                LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n");
 943                LLAMA_LOG_ERROR("============================================================\n\n");
 944                throw std::runtime_error(format("Missing importance matrix for tensor %s in a very low-bit quantization", tensor->name));
 945            }
 946
 947            float * f32_data;
 948
 949            if (tensor->type == GGML_TYPE_F32) {
 950                f32_data = (float *) tensor->data;
 951            } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) {
 952                throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type)));
 953            } else {
 954                llama_tensor_dequantize_impl(tensor, f32_conv_buf, workers, nelements, nthread);
 955                f32_data = (float *) f32_conv_buf.data();
 956            }
 957
 958            LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type));
 959            fflush(stdout);
 960
 961            if (work.size() < (size_t)nelements * 4) {
 962                work.resize(nelements * 4); // upper bound on size
 963            }
 964            new_data = work.data();
 965
 966            const int64_t n_per_row = tensor->ne[0];
 967            const int64_t nrows = tensor->ne[1];
 968
 969            static const int64_t min_chunk_size = 32 * 512;
 970            const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row));
 971
 972            const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1];
 973            const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size;
 974            const int64_t nthread_use = nthread > 1 ? std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1;
 975
 976            // quantize each expert separately since they have different importance matrices
 977            new_size = 0;
 978            for (int64_t i03 = 0; i03 < tensor->ne[2]; ++i03) {
 979                const float * f32_data_03 = f32_data + i03 * nelements_matrix;
 980                void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows;
 981                const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr;
 982
 983                new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use);
 984
 985                // TODO: temporary sanity check that the F16 -> MXFP4 is lossless
 986#if 0
 987                if (new_type == GGML_TYPE_MXFP4) {
 988                    auto * x = f32_data_03;
 989
 990                    //LLAMA_LOG_INFO("nrows = %d, n_per_row = %d\n", nrows, n_per_row);
 991                    std::vector<float> deq(nrows*n_per_row);
 992                    const ggml_type_traits * qtype = ggml_get_type_traits(new_type);
 993                    qtype->to_float(new_data_03, deq.data(), deq.size());
 994
 995                    double err = 0.0f;
 996                    for (int i = 0; i < (int) deq.size(); ++i) {
 997                        err += fabsf(deq[i] - x[i]);
 998                        //if (fabsf(deq[i] - x[i]) > 0.00001 && i < 256) {
 999                        if (deq[i] != x[i]) {
1000                            LLAMA_LOG_INFO("deq[%d] = %f, x[%d] = %f\n", i, deq[i], i, x[i]);
1001                        }
1002                    }
1003                    //LLAMA_LOG_INFO("err = %f\n", err);
1004                    GGML_ASSERT(err == 0.00000);
1005                }
1006#endif
1007            }
1008            LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
1009        }
1010        total_size_org += ggml_nbytes(tensor);
1011        total_size_new += new_size;
1012
1013        // update the gguf meta data as we go
1014        gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type);
1015        GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size);
1016        gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data);
1017
1018        // write tensor data + padding
1019        fout.write((const char *) new_data, new_size);
1020        zeros(fout, GGML_PAD(new_size, align) - new_size);
1021    }
1022    close_ofstream();
1023
1024    LLAMA_LOG_INFO("%s: model size  = %8.2f MiB\n", __func__, total_size_org/1024.0/1024.0);
1025    LLAMA_LOG_INFO("%s: quant size  = %8.2f MiB\n", __func__, total_size_new/1024.0/1024.0);
1026
1027    if (qs.n_fallback > 0) {
1028        LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) required fallback quantization\n",
1029                __func__, qs.n_fallback, qs.n_k_quantized + qs.n_fallback);
1030    }
1031}
1032
1033//
1034// interface implementation
1035//
1036
1037llama_model_quantize_params llama_model_quantize_default_params() {
1038    llama_model_quantize_params result = {
1039        /*.nthread                     =*/ 0,
1040        /*.ftype                       =*/ LLAMA_FTYPE_MOSTLY_Q5_1,
1041        /*.output_tensor_type          =*/ GGML_TYPE_COUNT,
1042        /*.token_embedding_type        =*/ GGML_TYPE_COUNT,
1043        /*.allow_requantize            =*/ false,
1044        /*.quantize_output_tensor      =*/ true,
1045        /*.only_copy                   =*/ false,
1046        /*.pure                        =*/ false,
1047        /*.keep_split                  =*/ false,
1048        /*.imatrix                     =*/ nullptr,
1049        /*.kv_overrides                =*/ nullptr,
1050        /*.tensor_type                 =*/ nullptr,
1051        /*.prune_layers                =*/ nullptr
1052    };
1053
1054    return result;
1055}
1056
1057uint32_t llama_model_quantize(
1058        const char * fname_inp,
1059        const char * fname_out,
1060        const llama_model_quantize_params * params) {
1061    try {
1062        llama_model_quantize_impl(fname_inp, fname_out, params);
1063    } catch (const std::exception & err) {
1064        LLAMA_LOG_ERROR("%s: failed to quantize: %s\n", __func__, err.what());
1065        return 1;
1066    }
1067
1068    return 0;
1069}