1#include "arg.h"
 2#include "common.h"
 3#include "log.h"
 4#include "llama.h"
 5
 6#include <cmath>
 7#include <cstdio>
 8#include <cstring>
 9#include <ctime>
10#include <vector>
11
12#if defined(_MSC_VER)
13#pragma warning(disable: 4244 4267)  // possible loss of data
14#endif
15
int main(int argc, char ** argv) {
    common_params params;
    params.escape = false;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_FINETUNE)) {
        return 1;
    }

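    // training needs writable weights and OUT_PROD support, so override any
    // incompatible defaults (mmap'ed weights are read-only, f16 lacks OUT_PROD)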
    if (params.use_mmap) {
        LOG_INF("%s: force disabling memory mapping because it would result in read-only pointers to the weights\n",
                __func__);
        params.use_mmap = false;
    }
    if (params.cache_type_k != GGML_TYPE_F32) {
        LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
        params.cache_type_k = GGML_TYPE_F32;
    }
    if (params.cache_type_v != GGML_TYPE_F32) {
        LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
        params.cache_type_v = GGML_TYPE_F32;
    }

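    // initialize the common state, the llama backend, and the NUMA strategy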
    common_init();
    llama_backend_init();
    llama_numa_init(params.numa);

    // load the model and apply lora adapter, if any
    auto llama_init = common_init_from_params(params);

    auto * model = llama_init->model();
    auto * ctx   = llama_init->context();

    if (model == NULL) {
        LOG_ERR("%s: unable to load model\n", __func__);
        return 1;
    }

    // print system information
    {
        LOG_INF("\n");
        LOG_INF("%s\n", common_params_get_system_info(params).c_str());
    }

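    // tokenize the training text and build a dataset over it; the third argument
    // is presumably the stride between consecutive context-sized windows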
    std::vector<llama_token> tokens  = common_tokenize(ctx, params.prompt, true);
    ggml_opt_dataset_t       dataset = common_opt_dataset_init(ctx, tokens, llama_n_ctx(ctx) / 2);

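    // log the effective hyperparameters; the -period value is computed as
    // n_batch / n_ubatch, i.e. the number of ubatches per logical batch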
    struct lr_opt & lr = params.lr;
    LOG_INF("-optimizer %s -lr0 %.2g -wd %.2g -lr-min %.2g -min-epochs %.2g -epochs %u -period %.2g -val %.2g\n",
            ggml_opt_optimizer_name(params.optimizer), (double) lr.lr0, (double) lr.wd, (double) lr.lr_min, (double) lr.decay_epochs,
            (unsigned) lr.epochs, (double) params.n_batch / params.n_ubatch, (double) params.val_split);

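    // train all model parameters; the optimizer settings (e.g. learning rate)
    // come from params.lr via the common_opt_lr_pars callback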
    struct llama_opt_params lopt_params{
        /*n_ctx_train     =*/0,  // 0 == use the context size of ctx
        /*param_filter    =*/llama_opt_param_filter_all,
        /*param_filter_ud =*/nullptr,
        /*get_opt_pars    =*/common_opt_lr_pars,
        /*get_opt_pars_ud =*/&params.lr,
        /*optimizer_type  =*/params.optimizer,
    };
    llama_opt_init(ctx, model, lopt_params);

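    // the first (1.0 - val_split) fraction of the dataset is used for training,
    // the remainder for validation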
    const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - params.val_split);

    ggml_opt_result_t result_train = ggml_opt_result_init();
    ggml_opt_result_t result_eval  = ggml_opt_result_init();

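    // each call to llama_opt_epoch runs one full epoch: training on the batches
    // before idata_split, evaluation on the batches after it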
    for (lr.epoch = 0; lr.epoch < lr.epochs; ++lr.epoch) {
        llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split,
                        ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
        fprintf(stderr, "\n");

        ggml_opt_result_reset(result_train);
        ggml_opt_result_reset(result_eval);
    }
    ggml_opt_result_free(result_train);
    ggml_opt_result_free(result_eval);

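    // write the fine-tuned model back to disk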
    llama_model_save_to_file(model, params.out_file.c_str());

    llama_backend_free();

    return 0;
}