#include "llama.h"
#include <cstdio>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

static void print_usage(int, char ** argv) {
    printf("\nexample usage:\n");
    printf("\n    %s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n", argv[0]);
    printf("\n");
}
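
// Example build/run (an assumption, not part of the original source: this file
// is built as an example target in a llama.cpp checkout, and the exact target
// and binary name may differ):
//
//   cmake -B build && cmake --build build
//   ./build/bin/llama-simple-chat -m model.gguf -c 2048 -ngl 99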

int main(int argc, char ** argv) {
    std::string model_path;
    int ngl = 99;
    int n_ctx = 2048;

    // parse command line arguments
    for (int i = 1; i < argc; i++) {
        try {
            if (strcmp(argv[i], "-m") == 0) {
                if (i + 1 < argc) {
                    model_path = argv[++i];
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else if (strcmp(argv[i], "-c") == 0) {
                if (i + 1 < argc) {
                    n_ctx = std::stoi(argv[++i]);
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else if (strcmp(argv[i], "-ngl") == 0) {
                if (i + 1 < argc) {
                    ngl = std::stoi(argv[++i]);
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else {
                print_usage(argc, argv);
                return 1;
            }
        } catch (const std::exception & e) {
            fprintf(stderr, "error: %s\n", e.what());
            print_usage(argc, argv);
            return 1;
        }
    }
    if (model_path.empty()) {
        print_usage(argc, argv);
        return 1;
    }

    // only print errors
    llama_log_set([](enum ggml_log_level level, const char * text, void * /* user_data */) {
        if (level >= GGML_LOG_LEVEL_ERROR) {
            fprintf(stderr, "%s", text);
        }
    }, nullptr);

    // load dynamic backends
    ggml_backend_load_all();
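    // note: ggml_backend_load_all() probes for dynamically loadable backend
    // libraries (CUDA, Metal, Vulkan, ...); without a GPU backend available,
    // n_gpu_layers has no effect and the model simply runs on the CPU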

    // initialize the model
    llama_model_params model_params = llama_model_default_params();
    model_params.n_gpu_layers = ngl;

    llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params);
    if (!model) {
        fprintf(stderr, "%s: error: unable to load model\n", __func__);
        return 1;
    }

    const llama_vocab * vocab = llama_model_get_vocab(model);

    // initialize the context
    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx = n_ctx;
    // size the batch to the full context so an entire prompt can be decoded in one call
    ctx_params.n_batch = n_ctx;

    llama_context * ctx = llama_init_from_model(model, ctx_params);
    if (!ctx) {
        fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__);
        return 1;
    }

    // initialize the sampler
    llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(smpl, llama_sampler_init_min_p(0.05f, 1));
    llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
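
    // the chain filters logits in order: min-p drops tokens below 5% of the
    // top token's probability, temperature rescales the rest, and dist draws
    // the final sample. A minimal deterministic alternative (an assumption,
    // not part of the original example) would be a single greedy sampler:
    //
    //     llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params());
    //     llama_sampler_chain_add(smpl, llama_sampler_init_greedy());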

    // helper function to evaluate a prompt and generate a response
    auto generate = [&](const std::string & prompt) {
        std::string response;

        // the memory is empty only before the first turn, so add special
        // tokens (e.g. BOS) to the prompt only then
        const bool is_first = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) == -1;

        // tokenize the prompt
        const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
        std::vector<llama_token> prompt_tokens(n_prompt_tokens);
        if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first, true) < 0) {
            GGML_ABORT("failed to tokenize the prompt\n");
        }
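
        // note: when called with a NULL buffer, llama_tokenize() returns the
        // negative of the required token count, which is why the first call
        // above is negated to size the vector before the second call fills it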

        // prepare a batch for the prompt
        llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
        llama_token new_token_id;
        while (true) {
            // check if we have enough space in the context to evaluate this batch
            int n_ctx = llama_n_ctx(ctx);
            int n_ctx_used = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) + 1;
            if (n_ctx_used + batch.n_tokens > n_ctx) {
                printf("\033[0m\n");
                fprintf(stderr, "context size exceeded\n");
                exit(0);
            }
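            // note: this example simply exits when the context fills up; a
            // longer-running application might instead evict old tokens from
            // the memory (e.g. with llama_memory_seq_rm()) to make room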

            int ret = llama_decode(ctx, batch);
            if (ret != 0) {
                GGML_ABORT("failed to decode, ret = %d\n", ret);
            }

            // sample the next token
            new_token_id = llama_sampler_sample(smpl, ctx, -1);

            // is it an end of generation?
            if (llama_vocab_is_eog(vocab, new_token_id)) {
                break;
            }

            // convert the token to a string, print it and add it to the response
            char buf[256];
            int n = llama_token_to_piece(vocab, new_token_id, buf, sizeof(buf), 0, true);
            if (n < 0) {
                GGML_ABORT("failed to convert token to piece\n");
            }
            std::string piece(buf, n);
            printf("%s", piece.c_str());
            fflush(stdout);
            response += piece;

            // prepare the next batch with the sampled token
            batch = llama_batch_get_one(&new_token_id, 1);
        }

        return response;
    };
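
    // note: generate() streams the reply to stdout and returns it as a string;
    // because the context's memory (KV cache) persists across calls, each call
    // only has to decode the tokens that are new since the previous turn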

    std::vector<llama_chat_message> messages;
    // buffer for the formatted conversation; sized to the context and grown on demand
    std::vector<char> formatted(llama_n_ctx(ctx));
    int prev_len = 0;
    while (true) {
        // get user input
        printf("\033[32m> \033[0m");
        std::string user;
        std::getline(std::cin, user);

        // an empty line exits the chat loop
        if (user.empty()) {
            break;
        }

        const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);

        // add the user input to the message list and format it
        messages.push_back({"user", strdup(user.c_str())});
        int new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
        if (new_len > (int)formatted.size()) {
            formatted.resize(new_len);
            new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
        }
        if (new_len < 0) {
            fprintf(stderr, "failed to apply the chat template\n");
            return 1;
        }

        // remove previous messages to obtain the prompt to generate the response
        std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);
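        // note: the template renders the entire conversation, but everything
        // up to prev_len is already in the context's memory from earlier
        // turns, so only the newly appended part is decoded here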

        // generate a response
        printf("\033[33m");
        std::string response = generate(prompt);
        printf("\n\033[0m");

        // add the response to the messages
        messages.push_back({"assistant", strdup(response.c_str())});
        // re-measure the formatted conversation without the assistant prefix
        // (add_ass = false) so the next turn's prompt starts after this response
        prev_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), false, nullptr, 0);
        if (prev_len < 0) {
            fprintf(stderr, "failed to apply the chat template\n");
            return 1;
        }
    }

    // free resources
    for (auto & msg : messages) {
        free(const_cast<char *>(msg.content));
    }
    llama_sampler_free(smpl);
    llama_free(ctx);
    llama_model_free(model);

    return 0;
}