#include "common.h"
#include "arg.h"
#include "console.h"
// #include "log.h"

#include "server-context.h"
#include "server-task.h"

#include <atomic>
#include <fstream>
#include <thread>
#include <signal.h>

#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#   define NOMINMAX
#endif
#include <windows.h>
#endif

const char * LLAMA_ASCII_LOGO = R"(
▄▄ ▄▄
██ ██
██ ██  ▀▀█▄ ███▄███▄  ▀▀█▄    ▄████ ████▄ ████▄
██ ██ ▄█▀██ ██ ██ ██ ▄█▀██    ██    ██ ██ ██ ██
██ ██ ▀█▄██ ██ ██ ██ ▀█▄██ ██ ▀████ ████▀ ████▀
                                    ██    ██
                                    ▀▀    ▀▀
)";

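// set by the SIGINT handler below and polled by the generation loop, so that
// a single Ctrl+C cancels the in-flight completion instead of killing the process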
static std::atomic<bool> g_is_interrupted = false;
static bool should_stop() {
    return g_is_interrupted.load();
}

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
static void signal_handler(int) {
    if (g_is_interrupted.load()) {
        // second Ctrl+C - exit immediately
        // make sure to clear colors before exiting (not using LOG or console.cpp here to avoid deadlock)
        fprintf(stdout, "\033[0m\n");
        fflush(stdout);
        std::exit(130);
    }
    g_is_interrupted.store(true);
}
#endif

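// holds the server context plus the per-session state of the interactive CLI:
// the chat history, any media files loaded for the conversation, and the
// default task parameters derived from the command-line arguments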
struct cli_context {
    server_context ctx_server;
    json messages = json::array();
    std::vector<raw_buffer> input_files;
    task_params defaults;

    // flag that controls the "loading" animation
    std::atomic<bool> loading_show;

    cli_context(const common_params & params) {
        defaults.sampling    = params.sampling;
        defaults.speculative = params.speculative;
        defaults.n_keep      = params.n_keep;
        defaults.n_predict   = params.n_predict;
        defaults.antiprompt  = params.antiprompt;

        defaults.stream = true; // make sure we always use streaming mode
        defaults.timings_per_token = true; // in order to get timings even when we cancel mid-way
        // defaults.return_progress = true; // TODO: show progress
    }

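    // build a completion task from the current chat history, stream partial
    // results to the console (rendering reasoning content separately), and
    // return the accumulated assistant text; timings are reported via out_timings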
    std::string generate_completion(result_timings & out_timings) {
        server_response_reader rd = ctx_server.get_response_reader();
        auto chat_params = format_chat();
        {
            // TODO: reduce some copies here in the future
            server_task task = server_task(SERVER_TASK_TYPE_COMPLETION);
            task.id         = rd.get_new_id();
            task.index      = 0;
            task.params     = defaults;           // copy
            task.cli_prompt = chat_params.prompt; // copy
            task.cli_files  = input_files;        // copy
            task.cli        = true;

            // chat template settings
            task.params.chat_parser_params = common_chat_parser_params(chat_params);
            task.params.chat_parser_params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
            if (!chat_params.parser.empty()) {
                task.params.chat_parser_params.parser.load(chat_params.parser);
            }

            rd.post_task({std::move(task)});
        }

        // wait for first result
        console::spinner::start();
        server_task_result_ptr result = rd.next(should_stop);

        console::spinner::stop();
        std::string curr_content;
        bool is_thinking = false;

        while (result) {
            if (should_stop()) {
                break;
            }
            if (result->is_error()) {
                json err_data = result->to_json();
                if (err_data.contains("message")) {
                    console::error("Error: %s\n", err_data["message"].get<std::string>().c_str());
                } else {
                    console::error("Error: %s\n", err_data.dump().c_str());
                }
                return curr_content;
            }
            auto res_partial = dynamic_cast<server_task_result_cmpl_partial *>(result.get());
            if (res_partial) {
                out_timings = std::move(res_partial->timings);
                for (const auto & diff : res_partial->oaicompat_msg_diffs) {
                    if (!diff.content_delta.empty()) {
                        if (is_thinking) {
                            console::log("\n[End thinking]\n\n");
                            console::set_display(DISPLAY_TYPE_RESET);
                            is_thinking = false;
                        }
                        curr_content += diff.content_delta;
                        console::log("%s", diff.content_delta.c_str());
                        console::flush();
                    }
                    if (!diff.reasoning_content_delta.empty()) {
                        console::set_display(DISPLAY_TYPE_REASONING);
                        if (!is_thinking) {
                            console::log("[Start thinking]\n");
                        }
                        is_thinking = true;
                        console::log("%s", diff.reasoning_content_delta.c_str());
                        console::flush();
                    }
                }
            }
            auto res_final = dynamic_cast<server_task_result_cmpl_final *>(result.get());
            if (res_final) {
                out_timings = std::move(res_final->timings);
                break;
            }
            result = rd.next(should_stop);
        }
        g_is_interrupted.store(false);
        // server_response_reader automatically cancels pending tasks upon destruction
        return curr_content;
    }

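    // read a local file into the request: media files are buffered and stand in
    // for the multimodal marker in the prompt, text files are inlined verbatim;
    // returns an empty string if the file cannot be opened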
    // TODO: support remote files in the future (http, https, etc)
    std::string load_input_file(const std::string & fname, bool is_media) {
        std::ifstream file(fname, std::ios::binary);
        if (!file) {
            return "";
        }
        if (is_media) {
            raw_buffer buf;
            buf.assign((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
            input_files.push_back(std::move(buf));
            return mtmd_default_marker();
        } else {
            std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
            return content;
        }
    }

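    // render the accumulated messages with the model's chat template,
    // producing the final prompt string and the matching parser settings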
    common_chat_params format_chat() {
        auto meta = ctx_server.get_meta();
        auto & chat_params = meta.chat_params;

        common_chat_templates_inputs inputs;
        inputs.messages              = common_chat_msgs_parse_oaicompat(messages);
        inputs.tools                 = {}; // TODO
        inputs.tool_choice           = COMMON_CHAT_TOOL_CHOICE_NONE;
        inputs.json_schema           = ""; // TODO
        inputs.grammar               = ""; // TODO
        inputs.use_jinja             = chat_params.use_jinja;
        inputs.parallel_tool_calls   = false;
        inputs.add_generation_prompt = true;
        inputs.enable_thinking       = chat_params.enable_thinking;

        // Apply chat template to the list of messages
        return common_chat_templates_apply(chat_params.tmpls.get(), inputs);
    }
};

int main(int argc, char ** argv) {
    common_params params;

    params.verbosity = LOG_LEVEL_ERROR; // by default, less verbose logs

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_CLI)) {
        return 1;
    }

    // TODO: maybe support it later?
    if (params.conversation_mode == COMMON_CONVERSATION_MODE_DISABLED) {
        console::error("--no-conversation is not supported by llama-cli\n");
        console::error("please use llama-completion instead\n");
        return 1;
    }

    common_init();

    // struct that contains llama context and inference
    cli_context ctx_cli(params);

    llama_backend_init();
    llama_numa_init(params.numa);

    // TODO: avoid using atexit() here by making `console` a singleton
    console::init(params.simple_io, params.use_color);
    atexit([]() { console::cleanup(); });

    console::set_display(DISPLAY_TYPE_RESET);

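    // install the Ctrl+C handler: the first press interrupts the current
    // generation, a second press exits the process immediately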
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
    struct sigaction sigint_action;
    sigint_action.sa_handler = signal_handler;
    sigemptyset (&sigint_action.sa_mask);
    sigint_action.sa_flags = 0;
    sigaction(SIGINT, &sigint_action, NULL);
    sigaction(SIGTERM, &sigint_action, NULL);
#elif defined (_WIN32)
    auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
        return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
    };
    SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif

    console::log("\nLoading model... "); // followed by loading animation
    console::spinner::start();
    if (!ctx_cli.ctx_server.load_model(params)) {
        console::spinner::stop();
        console::error("\nFailed to load the model\n");
        return 1;
    }

    console::spinner::stop();
    console::log("\n");

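    // run the server loop on a background thread; the main thread below
    // handles user input and prints the streamed results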
    std::thread inference_thread([&ctx_cli]() {
        ctx_cli.ctx_server.start_loop();
    });

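    // query the model metadata to advertise only the input modalities
    // (vision, audio) that the loaded model actually supports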
    auto inf = ctx_cli.ctx_server.get_meta();
    std::string modalities = "text";
    if (inf.has_inp_image) {
        modalities += ", vision";
    }
    if (inf.has_inp_audio) {
        modalities += ", audio";
    }

    if (!params.system_prompt.empty()) {
        ctx_cli.messages.push_back({
            {"role",    "system"},
            {"content", params.system_prompt}
        });
    }

    console::log("\n");
    console::log("%s\n", LLAMA_ASCII_LOGO);
    console::log("build      : %s\n", inf.build_info.c_str());
    console::log("model      : %s\n", inf.model_name.c_str());
    console::log("modalities : %s\n", modalities.c_str());
    if (!params.system_prompt.empty()) {
        console::log("using custom system prompt\n");
    }
    console::log("\n");
    console::log("available commands:\n");
    console::log("  /exit or Ctrl+C     stop or exit\n");
    console::log("  /regen              regenerate the last response\n");
    console::log("  /clear              clear the chat history\n");
276    console::log("  /read               add a text file\n");
    if (inf.has_inp_image) {
        console::log("  /image <file>       add an image file\n");
    }
    if (inf.has_inp_audio) {
        console::log("  /audio <file>       add an audio file\n");
    }
    console::log("\n");

    // interactive loop
    std::string cur_msg;
    while (true) {
        std::string buffer;
        console::set_display(DISPLAY_TYPE_USER_INPUT);
        if (params.prompt.empty()) {
            console::log("\n> ");
            std::string line;
            bool another_line = true;
            do {
                another_line = console::readline(line, params.multiline_input);
                buffer += line;
            } while (another_line);
        } else {
            // process input prompt from args
            for (auto & fname : params.image) {
                std::string marker = ctx_cli.load_input_file(fname, true);
                if (marker.empty()) {
                    console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
                    break;
                }
                console::log("Loaded media from '%s'\n", fname.c_str());
                cur_msg += marker;
            }
            buffer = params.prompt;
            if (buffer.size() > 500) {
                console::log("\n> %s ... (truncated)\n", buffer.substr(0, 500).c_str());
            } else {
                console::log("\n> %s\n", buffer.c_str());
            }
            params.prompt.clear(); // only use it once
        }
        console::set_display(DISPLAY_TYPE_RESET);
        console::log("\n");

        if (should_stop()) {
            g_is_interrupted.store(false);
            break;
        }

        // remove trailing newline
        if (!buffer.empty() && buffer.back() == '\n') {
            buffer.pop_back();
        }

        // skip empty messages
        if (buffer.empty()) {
            continue;
        }

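        // after /regen the last assistant message has been removed and the
        // existing history is resent, so no new user message is appended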
        bool add_user_msg = true;

        // process commands
        if (string_starts_with(buffer, "/exit")) {
            break;
        } else if (string_starts_with(buffer, "/regen")) {
            if (ctx_cli.messages.size() >= 2) {
                size_t last_idx = ctx_cli.messages.size() - 1;
                ctx_cli.messages.erase(last_idx);
                add_user_msg = false;
            } else {
                console::error("No message to regenerate.\n");
                continue;
            }
        } else if (string_starts_with(buffer, "/clear")) {
            ctx_cli.messages.clear();
            ctx_cli.input_files.clear();
            console::log("Chat history cleared.\n");
            continue;
        } else if (
                (string_starts_with(buffer, "/image ") && inf.has_inp_image) ||
                (string_starts_with(buffer, "/audio ") && inf.has_inp_audio)) {
            // just in case (bad copy-paste for example), we strip all trailing/leading spaces
            std::string fname = string_strip(buffer.substr(7));
            std::string marker = ctx_cli.load_input_file(fname, true);
            if (marker.empty()) {
                console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
                continue;
            }
            cur_msg += marker;
            console::log("Loaded media from '%s'\n", fname.c_str());
            continue;
        } else if (string_starts_with(buffer, "/read ")) {
            std::string fname = string_strip(buffer.substr(6));
            std::string marker = ctx_cli.load_input_file(fname, false);
            if (marker.empty()) {
                console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
                continue;
            }
            cur_msg += marker;
            console::log("Loaded text from '%s'\n", fname.c_str());
            continue;
        } else {
            // not a command
            cur_msg += buffer;
        }

        // generate response
        if (add_user_msg) {
            ctx_cli.messages.push_back({
                {"role",    "user"},
                {"content", cur_msg}
            });
            cur_msg.clear();
        }
        result_timings timings;
        std::string assistant_content = ctx_cli.generate_completion(timings);
        ctx_cli.messages.push_back({
            {"role",    "assistant"},
            {"content", assistant_content}
        });
        console::log("\n");

        if (params.show_timings) {
            console::set_display(DISPLAY_TYPE_INFO);
            console::log("\n");
            console::log("[ Prompt: %.1f t/s | Generation: %.1f t/s ]\n", timings.prompt_per_second, timings.predicted_per_second);
            console::set_display(DISPLAY_TYPE_RESET);
        }

        if (params.single_turn) {
            break;
        }
    }

    console::set_display(DISPLAY_TYPE_RESET);

    console::log("\nExiting...\n");
    ctx_cli.ctx_server.terminate();
    inference_thread.join();

    // bump the log level to display timings
    common_log_set_verbosity_thold(LOG_LEVEL_INFO);
    llama_memory_breakdown_print(ctx_cli.ctx_server.get_llama_context());

    return 0;
}