1#include "llama-model-saver.h"
  2
  3#include "gguf.h"
  4
  5#include "llama.h"
  6#include "llama-hparams.h"
  7#include "llama-model.h"
  8#include "llama-vocab.h"
  9
 10#include <string>
 11
// Construct a saver for the given model; creates an empty GGUF context into
// which KV pairs and tensors are accumulated until save() writes them out.
// The model reference must outlive the saver (only a reference is stored).
llama_model_saver::llama_model_saver(const struct llama_model & model) : model(model), llm_kv(model.arch) {
    gguf_ctx = gguf_init_empty();
}
 15
// Release the GGUF context owned by this saver.
llama_model_saver::~llama_model_saver() {
    gguf_free(gguf_ctx);
}
 19
 20void llama_model_saver::add_kv(const enum llm_kv key, const uint32_t value) {
 21    gguf_set_val_u32(gguf_ctx, llm_kv(key).c_str(), value);
 22}
 23
 24void llama_model_saver::add_kv(const enum llm_kv key, const int32_t value) {
 25    gguf_set_val_i32(gguf_ctx, llm_kv(key).c_str(), value);
 26}
 27
 28void llama_model_saver::add_kv(const enum llm_kv key, const float value) {
 29    gguf_set_val_f32(gguf_ctx, llm_kv(key).c_str(), value);
 30}
 31
 32void llama_model_saver::add_kv(const enum llm_kv key, const bool value) {
 33    gguf_set_val_bool(gguf_ctx, llm_kv(key).c_str(), value);
 34}
 35
 36void llama_model_saver::add_kv(const enum llm_kv key, const char * value) {
 37    gguf_set_val_str(gguf_ctx, llm_kv(key).c_str(), value);
 38}
 39
// Deliberately unreachable overload for single char values. It exists only so
// that the per-layer collapse (`add_kv(key, value[0])`) in the templated
// add_kv() below compiles when Container is std::string (value_type == char);
// calling it is a fatal error.
[[noreturn]]
void llama_model_saver::add_kv(const enum llm_kv key, const char value) {
    GGML_UNUSED(key);
    GGML_UNUSED(value);
    GGML_ABORT("fatal error"); // this should never be called, only needed to make the template below compile
}
 46
 47template <typename Container>
 48void llama_model_saver::add_kv(const enum llm_kv key, const Container & value, const bool per_layer) {
 49    const size_t n_values = per_layer ? size_t(model.hparams.n_layer) : value.size();
 50    GGML_ASSERT(n_values <= value.size());
 51
 52    if (n_values == 0) {
 53        return;
 54    }
 55
 56    if (per_layer) {
 57        bool all_values_the_same = true;
 58        for (size_t i = 1; i < n_values; ++i) {
 59            if (value[i] != value[0]) {
 60                all_values_the_same = false;
 61                break;
 62            }
 63        }
 64        if (all_values_the_same) {
 65            add_kv(key, value[0]);
 66            return;
 67        }
 68    }
 69
 70    if (std::is_same<typename Container::value_type, uint8_t>::value) {
 71        gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT8, value.data(), n_values);
 72    } else if (std::is_same<typename Container::value_type, int8_t>::value) {
 73        gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT8, value.data(), n_values);
 74    } else if (std::is_same<typename Container::value_type, uint32_t>::value) {
 75        gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT32, value.data(), n_values);
 76    } else if (std::is_same<typename Container::value_type, int32_t>::value) {
 77        gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT32, value.data(), n_values);
 78    } else if (std::is_same<typename Container::value_type, float>::value) {
 79        gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_FLOAT32, value.data(), n_values);
 80    } else if (std::is_same<Container, std::string>::value) {
 81        gguf_set_val_str(gguf_ctx, llm_kv(key).c_str(), reinterpret_cast<const char *>(value.data()));
 82    } else {
 83        GGML_ABORT("fatal error");
 84    }
 85}
 86
 87void llama_model_saver::add_kv(const enum llm_kv key, const std::vector<std::string> & value) {
 88    std::vector<const char *> tmp(value.size());
 89    for (size_t i = 0; i < value.size(); ++i) {
 90        tmp[i] = value[i].c_str();
 91    }
 92    gguf_set_arr_str(gguf_ctx, llm_kv(key).c_str(), tmp.data(), tmp.size());
 93}
 94
 95void llama_model_saver::add_tensor(const struct ggml_tensor * tensor) {
 96    if (!tensor) {
 97        return;
 98    }
 99    if (gguf_find_tensor(gguf_ctx, tensor->name) >= 0) {
100        GGML_ASSERT(std::string(tensor->name) == "rope_freqs.weight"); // FIXME
101        return;
102    }
103    gguf_add_tensor(gguf_ctx, tensor);
104}
105
// Collect the model's hyperparameters and vocabulary and add them to the GGUF
// context as KV pairs. Commented-out add_kv calls mark keys that are currently
// not written (value not tracked by llama_model, deprecated, or TODO).
void llama_model_saver::add_kv_from_model() {
    const llama_hparams & hparams = model.hparams;
    const llama_vocab   & vocab   = model.vocab;

    // gather per-token data up front so it can be written as flat arrays below
    const int32_t n_vocab = vocab.n_tokens();
    std::vector<std::string> tokens(n_vocab);
    std::vector<float>       scores(n_vocab);
    std::vector<int32_t>     token_types(n_vocab);

    for (int32_t id = 0; id < n_vocab; ++id) {
        const llama_vocab::token_data & token_data = vocab.get_token_data(id);

        tokens[id] = token_data.text;
        scores[id] = token_data.score;

        // NOTE(review): attr is compared against single attribute values; a
        // token whose attr combines multiple LLAMA_TOKEN_ATTR_* flags falls
        // through to UNDEFINED — confirm this is intended
        switch(token_data.attr) {
            case LLAMA_TOKEN_ATTR_UNKNOWN:      token_types[id] = LLAMA_TOKEN_TYPE_UNKNOWN;      break;
            case LLAMA_TOKEN_ATTR_UNUSED:       token_types[id] = LLAMA_TOKEN_TYPE_UNUSED;       break;
            case LLAMA_TOKEN_ATTR_NORMAL:       token_types[id] = LLAMA_TOKEN_TYPE_NORMAL;       break;
            case LLAMA_TOKEN_ATTR_CONTROL:      token_types[id] = LLAMA_TOKEN_TYPE_CONTROL;      break;
            case LLAMA_TOKEN_ATTR_USER_DEFINED: token_types[id] = LLAMA_TOKEN_TYPE_USER_DEFINED; break;
            case LLAMA_TOKEN_ATTR_BYTE:         token_types[id] = LLAMA_TOKEN_TYPE_BYTE;         break;
            case LLAMA_TOKEN_ATTR_UNDEFINED:
            default:                            token_types[id] = LLAMA_TOKEN_TYPE_UNDEFINED;    break;
        }
    }

    // general metadata
    // add_kv(LLM_KV_GENERAL_TYPE,                      ???);
    add_kv(LLM_KV_GENERAL_ARCHITECTURE,              model.arch_name());
    // add_kv(LLM_KV_GENERAL_QUANTIZATION_VERSION,      ???);
    // add_kv(LLM_KV_GENERAL_ALIGNMENT,                 ???);
    add_kv(LLM_KV_GENERAL_NAME,                      model.name);
    // add_kv(LLM_KV_GENERAL_AUTHOR,                    ???);
    // add_kv(LLM_KV_GENERAL_VERSION,                   ???);
    // add_kv(LLM_KV_GENERAL_URL,                       ???);
    // add_kv(LLM_KV_GENERAL_DESCRIPTION,               ???);
    // add_kv(LLM_KV_GENERAL_LICENSE,                   ???);
    // add_kv(LLM_KV_GENERAL_SOURCE_URL,                ???);
    // add_kv(LLM_KV_GENERAL_SOURCE_HF_REPO,            ???);

    // model hyperparameters
    add_kv(LLM_KV_VOCAB_SIZE,                        vocab.n_tokens());
    add_kv(LLM_KV_CONTEXT_LENGTH,                    hparams.n_ctx_train);
    add_kv(LLM_KV_EMBEDDING_LENGTH,                  hparams.n_embd);
    if (hparams.n_embd_out_impl > 0) {
        add_kv(LLM_KV_EMBEDDING_LENGTH_OUT,          hparams.n_embd_out_impl);
    }
    add_kv(LLM_KV_BLOCK_COUNT,                       hparams.n_layer);
    add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT,         hparams.n_layer_dense_lead);
    add_kv(LLM_KV_FEED_FORWARD_LENGTH,               hparams.n_ff_arr, true);
    add_kv(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        hparams.n_ff_exp);
    // NOTE(review): the same hparams.n_ff_exp is written for both the expert
    // and the shared-expert FF length — confirm there is no separate
    // shared-expert field that should be used here
    add_kv(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
    add_kv(LLM_KV_USE_PARALLEL_RESIDUAL,             hparams.use_par_res);
    // add_kv(LLM_KV_TENSOR_DATA_LAYOUT,                ???);
    add_kv(LLM_KV_EXPERT_COUNT,                      hparams.n_expert);
    add_kv(LLM_KV_EXPERT_USED_COUNT,                 hparams.n_expert_used);
    add_kv(LLM_KV_EXPERT_SHARED_COUNT,               hparams.n_expert_shared);
    add_kv(LLM_KV_EXPERT_WEIGHTS_SCALE,              hparams.expert_weights_scale);
    add_kv(LLM_KV_POOLING_TYPE,                      uint32_t(hparams.pooling_type));
    add_kv(LLM_KV_LOGIT_SCALE,                       hparams.f_logit_scale);
    add_kv(LLM_KV_DECODER_START_TOKEN_ID,            hparams.dec_start_token_id);
    add_kv(LLM_KV_ATTN_LOGIT_SOFTCAPPING,            hparams.f_attn_logit_softcapping);
    add_kv(LLM_KV_FINAL_LOGIT_SOFTCAPPING,           hparams.f_final_logit_softcapping);
    add_kv(LLM_KV_SWIN_NORM,                         hparams.swin_norm);
    add_kv(LLM_KV_RESCALE_EVERY_N_LAYERS,            hparams.rescale_every_n_layers);
    add_kv(LLM_KV_TIME_MIX_EXTRA_DIM,                hparams.time_mix_extra_dim);
    add_kv(LLM_KV_TIME_DECAY_EXTRA_DIM,              hparams.time_decay_extra_dim);
    add_kv(LLM_KV_RESIDUAL_SCALE,                    hparams.f_residual_scale);
    add_kv(LLM_KV_EMBEDDING_SCALE,                   hparams.f_embedding_scale);

    // attention hyperparameters (per-layer arrays are collapsed to a scalar
    // when uniform, see the templated add_kv)
    add_kv(LLM_KV_ATTENTION_HEAD_COUNT,              hparams.n_head_arr, true);
    add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV,           hparams.n_head_kv_arr, true);
    add_kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS,          hparams.f_max_alibi_bias);
    add_kv(LLM_KV_ATTENTION_CLAMP_KQV,               hparams.f_clamp_kqv);
    add_kv(LLM_KV_ATTENTION_KEY_LENGTH,              hparams.n_embd_head_k);
    add_kv(LLM_KV_ATTENTION_VALUE_LENGTH,            hparams.n_embd_head_v);
    add_kv(LLM_KV_ATTENTION_LAYERNORM_EPS,           hparams.f_norm_eps);
    add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,       hparams.f_norm_rms_eps);
    add_kv(LLM_KV_ATTENTION_CAUSAL,                  hparams.causal_attn);
    add_kv(LLM_KV_ATTENTION_Q_LORA_RANK,             hparams.n_lora_q);
    add_kv(LLM_KV_ATTENTION_KV_LORA_RANK,            hparams.n_lora_kv);
    add_kv(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,  hparams.n_rel_attn_bkts);
    add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW,          hparams.n_swa);
    add_kv(LLM_KV_ATTENTION_SCALE,                   hparams.f_attention_scale);

    // RoPE hyperparameters; a stored scaling factor of 0.0f encodes
    // "no scaling" (training freq scale of exactly 1.0f)
    const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 0.0f : 1.0f/hparams.rope_freq_scale_train;

    add_kv(LLM_KV_ROPE_DIMENSION_COUNT,              hparams.n_rot);
    add_kv(LLM_KV_ROPE_FREQ_BASE,                    hparams.rope_freq_base_train);
    // add_kv(LLM_KV_ROPE_SCALE_LINEAR,                 rope_scaling_factor); // old name
    add_kv(LLM_KV_ROPE_SCALING_TYPE,                 llama_rope_scaling_type_name(hparams.rope_scaling_type_train));
    add_kv(LLM_KV_ROPE_SCALING_FACTOR,               rope_scaling_factor);
    add_kv(LLM_KV_ROPE_SCALING_ATTN_FACTOR,          hparams.rope_attn_factor);
    add_kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,         hparams.n_ctx_orig_yarn);
    add_kv(LLM_KV_ROPE_SCALING_FINETUNED,            hparams.rope_finetuned);
    add_kv(LLM_KV_ROPE_SCALING_YARN_LOG_MUL,         hparams.rope_yarn_log_mul);

    // TODO: implement split file support
    // add_kv(LLM_KV_SPLIT_NO,                          ???);
    // add_kv(LLM_KV_SPLIT_COUNT,                       ???);
    // add_kv(LLM_KV_SPLIT_TENSORS_COUNT,               ???);

    // SSM (state-space model) hyperparameters
    add_kv(LLM_KV_SSM_INNER_SIZE,                    hparams.ssm_d_inner);
    add_kv(LLM_KV_SSM_CONV_KERNEL,                   hparams.ssm_d_conv);
    add_kv(LLM_KV_SSM_STATE_SIZE,                    hparams.ssm_d_state);
    add_kv(LLM_KV_SSM_TIME_STEP_RANK,                hparams.ssm_dt_rank);
    add_kv(LLM_KV_SSM_DT_B_C_RMS,                    hparams.ssm_dt_b_c_rms);

    add_kv(LLM_KV_WKV_HEAD_SIZE,                     hparams.wkv_head_size);

    // tokenizer metadata
    add_kv(LLM_KV_TOKENIZER_MODEL,                   vocab.get_tokenizer_model());
    add_kv(LLM_KV_TOKENIZER_PRE,                     vocab.get_tokenizer_pre());
    add_kv(LLM_KV_TOKENIZER_LIST,                    tokens);
    add_kv(LLM_KV_TOKENIZER_TOKEN_TYPE,              token_types);
    add_kv(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT,        vocab.n_token_types());
    add_kv(LLM_KV_TOKENIZER_SCORES,                  scores);
    add_kv(LLM_KV_TOKENIZER_MERGES,                  vocab.get_bpe_merges());
    // FIXME llama_token is type i32 but when reading in a GGUF file u32 is expected, not an issue for writing though
    add_kv(LLM_KV_TOKENIZER_BOS_ID,                  uint32_t(vocab.token_bos()));
    add_kv(LLM_KV_TOKENIZER_EOS_ID,                  uint32_t(vocab.token_eos()));
    add_kv(LLM_KV_TOKENIZER_EOT_ID,                  uint32_t(vocab.token_eot()));
    add_kv(LLM_KV_TOKENIZER_EOM_ID,                  uint32_t(vocab.token_eom()));
    add_kv(LLM_KV_TOKENIZER_UNK_ID,                  uint32_t(vocab.token_unk()));
    add_kv(LLM_KV_TOKENIZER_SEP_ID,                  uint32_t(vocab.token_sep()));
    add_kv(LLM_KV_TOKENIZER_PAD_ID,                  uint32_t(vocab.token_pad()));
    // add_kv(LLM_KV_TOKENIZER_CLS_ID,                  uint32_t(vocab.token_bos())); // deprecated
    // add_kv(LLM_KV_TOKENIZER_MASK_ID,                 ???);
    add_kv(LLM_KV_TOKENIZER_ADD_BOS,                 vocab.get_add_bos());
    add_kv(LLM_KV_TOKENIZER_ADD_EOS,                 vocab.get_add_eos());
    add_kv(LLM_KV_TOKENIZER_ADD_SEP,                 vocab.get_add_sep());
    add_kv(LLM_KV_TOKENIZER_ADD_PREFIX,              vocab.get_add_space_prefix());
    add_kv(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,         vocab.get_remove_extra_whitespaces());
    add_kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,    vocab.get_precompiled_charsmap());
    // add_kv(LLM_KV_TOKENIZER_HF_JSON,                 ???);
    // add_kv(LLM_KV_TOKENIZER_RWKV,                    ???);
    add_kv(LLM_KV_TOKENIZER_FIM_PRE_ID,              uint32_t(vocab.token_fim_pre()));
    add_kv(LLM_KV_TOKENIZER_FIM_SUF_ID,              uint32_t(vocab.token_fim_suf()));
    add_kv(LLM_KV_TOKENIZER_FIM_MID_ID,              uint32_t(vocab.token_fim_mid()));
    add_kv(LLM_KV_TOKENIZER_FIM_PAD_ID,              uint32_t(vocab.token_fim_pad()));
    add_kv(LLM_KV_TOKENIZER_FIM_REP_ID,              uint32_t(vocab.token_fim_rep()));
    add_kv(LLM_KV_TOKENIZER_FIM_SEP_ID,              uint32_t(vocab.token_fim_sep()));

    // TODO: implement LoRA support
    // add_kv(LLM_KV_ADAPTER_TYPE,                      ???);
    // add_kv(LLM_KV_ADAPTER_LORA_ALPHA,                ???);

    // deprecated
    // add_kv(LLM_KV_TOKENIZER_PREFIX_ID,               ???);
    // add_kv(LLM_KV_TOKENIZER_SUFFIX_ID,               ???);
    // add_kv(LLM_KV_TOKENIZER_MIDDLE_ID,               ???);
}
256
257void llama_model_saver::add_tensors_from_model() {
258    if (std::string(model.output->name) != std::string(model.tok_embd->name)) {
259        add_tensor(model.tok_embd); // some models use the same tensor for tok_embd and output
260    }
261    add_tensor(model.type_embd);
262    add_tensor(model.pos_embd);
263    add_tensor(model.tok_norm);
264    add_tensor(model.tok_norm_b);
265    add_tensor(model.output_norm);
266    add_tensor(model.output_norm_b);
267    add_tensor(model.output);
268    add_tensor(model.output_b);
269    add_tensor(model.output_norm_enc);
270    add_tensor(model.cls);
271    add_tensor(model.cls_b);
272    add_tensor(model.cls_out);
273    add_tensor(model.cls_out_b);
274
275    for (const struct llama_layer & layer : model.layers) {
276        for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
277            add_tensor(reinterpret_cast<const struct ggml_tensor * const *>(&layer)[i]);
278        }
279    }
280}
281
// Write the accumulated metadata and tensors to a GGUF file at path_model.
void llama_model_saver::save(const std::string & path_model) {
    // NOTE(review): the `false` presumably means "not metadata-only", i.e.
    // tensor data is written too — confirm against the gguf_write_to_file API
    gguf_write_to_file(gguf_ctx, path_model.c_str(), false);
}
285