#include "common.h"
//#include "log.h" // TODO: start using log.h
#include "llama.h"

#include <cerrno>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <sstream>
#include <string>
#include <vector>
#include <iostream> // TODO: remove me

#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#include <shellapi.h>   // For CommandLineToArgvW
#endif
 17
 18static void print_usage_information(const char * argv0) {
 19    printf("usage: %s [options]\n\n", argv0);
 20    printf("The tokenize program tokenizes a prompt using a given model,\n");
 21    printf("and prints the resulting tokens to standard output.\n\n");
 22    printf("It needs a model file, a prompt, and optionally other flags\n");
 23    printf("to control the behavior of the tokenizer.\n\n");
 24    printf("    The possible options are:\n");
 25    printf("\n");
 26    printf("    -h, --help                           print this help and exit\n");
 27    printf("    -m MODEL_PATH, --model MODEL_PATH    path to model.\n");
 28    printf("    --ids                                if given, only print numerical token IDs, and not token strings.\n");
 29    printf("                                         The output format looks like [1, 2, 3], i.e. parseable by Python.\n");
 30    printf("    -f PROMPT_FNAME, --file PROMPT_FNAME read prompt from a file.\n");
 31    printf("    -p PROMPT, --prompt PROMPT           read prompt from the argument.\n");
 32    printf("    --stdin                              read prompt from standard input.\n");
 33    printf("    --no-bos                             do not ever add a BOS token to the prompt, even if normally the model uses a BOS token.\n");
 34    printf("    --no-escape                          do not escape input (such as \\n, \\t, etc.).\n");
 35    printf("    --no-parse-special                   do not parse control tokens.\n");
 36    printf("    --log-disable                        disable logs. Makes stderr quiet when loading the model.\n");
 37    printf("    --show-count                         print the total number of tokens.\n");
 38}
 39
 40static void llama_log_callback_null(ggml_log_level level, const char * text, void * user_data) {
 41    (void) level;
 42    (void) text;
 43    (void) user_data;
 44}
 45
 46static std::string read_prompt_from_file(const char * filepath, bool & success) {
 47    success = false;
 48
 49    std::ifstream in(filepath, std::ios::binary);
 50    if (!in) {
 51        fprintf(stderr, "%s: could not open file '%s' for reading: %s\n", __func__, filepath, strerror(errno));
 52        return std::string();
 53    }
 54    // do not assume the file is seekable (e.g. /dev/stdin)
 55    std::stringstream buffer;
 56    buffer << in.rdbuf();
 57    if (in.fail()) {
 58        fprintf(stderr, "%s: could not read the entire file '%s': %s\n", __func__, filepath, strerror(errno));
 59        return std::string();
 60    }
 61
 62    success = true;
 63    return buffer.str();
 64}
 65
 66//
 67// Function: ingest_args(...) -> vector<string>
 68//
 69//  Takes argc and argv arguments, and converts them to a vector of UTF-8 encoded
 70//  strings, as an STL vector<string>.
 71//
 72//  In particular, it handles character encoding shenanigans on Windows.
 73//
 74// Note: raw_argc and raw_argv are not actually read at all on Windows.
 75//       On Windows we call GetCommandLineW to get the arguments in wchar_t
 76//       format, ignoring the regular argc/argv arguments to main().
 77//
 78// TODO: potential opportunity to roll common stuff into common/console.cpp
 79//       in relation to Windows wchar_t shenanigans.
 80static std::vector<std::string> ingest_args(int raw_argc, char ** raw_argv) {
 81    std::vector<std::string> argv;
 82
 83    // Handle Windows, if given non-ASCII arguments.
 84    // We convert wchar_t arguments into UTF-8 char* on this platform.
 85    // Lets you invoke 'tokenize' on Windows cmd.exe with non-ASCII characters
 86    // without throwing tantrums.
 87#if defined(_WIN32)
 88    int argc;
 89    const LPWSTR cmdline_wargv = GetCommandLineW();
 90    LPWSTR * wargv = CommandLineToArgvW(cmdline_wargv, &argc);
 91
 92    // silence unused arg warnings
 93    (void) raw_argc;
 94    (void) raw_argv;
 95
 96    for (int i = 0; i < argc; ++i) {
 97        int length_needed = WideCharToMultiByte(CP_UTF8, 0, wargv[i], wcslen(wargv[i]), 0, 0, NULL, NULL);
 98        char * output_buf = (char *) calloc(length_needed+1, sizeof(char));
 99        GGML_ASSERT(output_buf);
100
101        WideCharToMultiByte(CP_UTF8, 0, wargv[i], wcslen(wargv[i]), output_buf, length_needed, NULL, NULL);
102        output_buf[length_needed] = '\0';
103
104        argv.push_back(output_buf);
105        free(output_buf);
106    }
107
108    LocalFree((HLOCAL) wargv);
109#else
110    int argc = raw_argc;
111    for (int i = 0; i < argc; ++i) {
112        argv.push_back(raw_argv[i]);
113    }
114#endif
115
116    GGML_ASSERT((unsigned int) argc == argv.size());
117
118    return argv;
119}
120
//
// Function: write_utf8_cstr_to_stdout(const char *) -> <writes to stdout>
//
// Writes a NUL-terminated UTF-8 string to standard output. On Windows, when
// stdout is a console, the string is converted to UTF-16 and written with
// WriteConsoleW, so it displays correctly even if the user has not set a
// Unicode code page in cmd.exe. When stdout is not a console (e.g. redirected
// to a file) the bytes are written unchanged with printf.
//
// invalid_utf8 is always reset to false first. For Windows console output, if
// str is not valid UTF-8, invalid_utf8 is set to true and a human-readable hex
// dump of the bytes (e.g. "<ff fe>") is written instead. Any other conversion
// failure aborts the program.
//
// On non-Windows systems the string is simply printf()'d and invalid UTF-8 is
// never reported.
static void write_utf8_cstr_to_stdout(const char * str, bool & invalid_utf8) {
    invalid_utf8 = false;

#if defined(_WIN32)
    // Are we in a console?
    HANDLE hConsole = GetStdHandle(STD_OUTPUT_HANDLE);
    DWORD dwMode = 0;

    // According to Microsoft docs:
    // "WriteConsole fails if it is used with a standard handle that is redirected to a file."
    // Also according to the docs, you can use GetConsoleMode to check for that.
    if (hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(hConsole, &dwMode)) {
        printf("%s", str);
        return;
    }

    const int str_len = (int) strlen(str);

    // MultiByteToWideChar reports an error for empty input; that is not
    // invalid UTF-8, so just write nothing.
    if (str_len == 0) {
        return;
    }

    const int length_needed = MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, str, str_len, NULL, 0);
    if (length_needed == 0) {
        const DWORD err = GetLastError();
        if (err == ERROR_NO_UNICODE_TRANSLATION) {
            invalid_utf8 = true;
            printf("<");
            for (int i = 0; i < str_len; ++i) {
                if (i > 0) {
                    printf(" ");
                }
                printf("%02x", (unsigned char) str[i]);
            }
            printf(">");
            return;
        }
        GGML_ABORT("MultiByteToWideChar() failed in an unexpected way.");
    }

    std::wstring wstr((size_t) length_needed, L'\0');
    MultiByteToWideChar(CP_UTF8, 0, str, str_len, &wstr[0], length_needed);

    // printf() output goes through the CRT's stdout buffer, while WriteConsoleW
    // writes to the console directly. Flush first so text printed earlier
    // (e.g. the token id prefix in main) is not reordered after this string.
    fflush(stdout);
    WriteConsoleW(hConsole, wstr.c_str(), (DWORD) length_needed, NULL, NULL);
#else
    // TODO: reporting invalid_utf8 would be useful on non-Windows too.
    // printf will silently just write bad unicode.
    printf("%s", str);
#endif
}
185
186int main(int raw_argc, char ** raw_argv) {
187    const std::vector<std::string> argv = ingest_args(raw_argc, raw_argv);
188    const int argc = argv.size();
189
190    if (argc <= 1) {
191        print_usage_information(argv[0].c_str());
192        return 1;
193    }
194
195    //////
196    // Read out all the command line arguments.
197    //////
198
199    // variables where to put any arguments we see.
200    bool printing_ids = false;
201    bool no_bos = false;
202    bool no_escape = false;
203    bool no_parse_special = false;
204    bool disable_logging = false;
205    bool show_token_count = false;
206    const char * model_path = NULL;
207    const char * prompt_path = NULL;
208    const char * prompt_arg = NULL;
209
210    // track which arguments were explicitly given
211    // used for sanity checking down the line
212    bool model_path_set = false;
213    bool prompt_path_set = false;
214    bool prompt_set = false;
215    bool stdin_set = false;
216
217    int iarg = 1;
218    for (; iarg < argc; ++iarg) {
219        std::string arg{argv[iarg]};
220        if (arg == "-h" || arg == "--help") {
221            print_usage_information(argv[0].c_str());
222            return 0;
223        }
224        else if (arg == "--ids") {
225            printing_ids = true;
226        }
227        else if (arg == "-m" || arg == "--model") {
228            if (model_path_set) {
229                fprintf(stderr, "Error: -m or --model specified multiple times.\n");
230                return 1;
231            }
232            model_path = argv[++iarg].c_str();
233            model_path_set = true;
234        }
235        else if (arg == "--no-bos") {
236            no_bos = true;
237        }
238        else if (arg == "--no-escape") {
239            no_escape = true;
240        }
241        else if (arg == "--no-parse-special") {
242            no_parse_special = true;
243        }
244        else if (arg == "-p" || arg == "--prompt") {
245            if (prompt_set) {
246                fprintf(stderr, "Error: -p or --prompt specified multiple times.\n");
247                return 1;
248            }
249            prompt_arg = argv[++iarg].c_str();
250            prompt_set = true;
251        }
252        else if (arg == "-f" || arg == "--file") {
253            if (prompt_path_set) {
254                fprintf(stderr, "Error: -f or --file specified multiple times.\n");
255                return 1;
256            }
257            prompt_path = argv[++iarg].c_str();
258            prompt_path_set = true;
259        }
260        else if (arg == "--stdin") {
261            stdin_set = true;
262        }
263        else if (arg == "--log-disable") {
264            disable_logging = true;
265        }
266        else if (arg == "--show-count") {
267            show_token_count = true;
268        }
269        else {
270            fprintf(stderr, "Error: unknown option '%s'\n", argv[iarg].c_str());
271            return 1;
272        }
273    }
274
275    //////
276    // Sanity check the command line arguments.
277    //////
278
279    // Check that we have the required stuff set.
280    if (model_path_set && model_path == NULL) {
281        fprintf(stderr, "Error: --model requires an argument.\n");
282        return 1;
283    }
284    if (!model_path_set) {
285        fprintf(stderr, "Error: must specify --model.\n");
286        return 1;
287    }
288    if (prompt_path_set && prompt_path == NULL) {
289        fprintf(stderr, "Error: --file requires an argument.\n");
290        return 1;
291    }
292    if (prompt_set && prompt_arg == NULL) {
293        fprintf(stderr, "Error: --prompt requires an argument.\n");
294        return 1;
295    }
296    const int prompts_set = !!(prompt_path_set) + !!(prompt_set) + !!(stdin_set);
297    if (prompts_set > 1) {
298        fprintf(stderr, "Error: --stdin, --file and --prompt are mutually exclusive.\n");
299        return 1;
300    }
301    // Must have some prompt.
302    if (prompts_set == 0) {
303        fprintf(stderr, "Error: must specify one of: --stdin, --file or --prompt.\n");
304        return 1;
305    }
306
307    GGML_ASSERT(model_path);
308    GGML_ASSERT(prompt_path || prompt_arg || stdin_set);
309
310    //////
311    // Figure out where will the prompt come from.
312    //////
313
314    std::string prompt;
315    if (prompt_path_set) {
316        bool success = false;
317        prompt = read_prompt_from_file(prompt_path, success);
318        if (!success) {
319            return 1;
320        }
321    } else if (prompt_set) {
322        prompt = prompt_arg;
323    } else {
324        GGML_ASSERT(stdin_set);
325        // we read stdin *after* loading model (early exit if model cannot
326        // be loaded, which can be a nicer user experience)
327    }
328
329    //////
330    // Start actually doing the tokenizing stuff.
331    //////
332
333    if (disable_logging) {
334        llama_log_set(llama_log_callback_null, NULL);
335    }
336
337    llama_backend_init();
338
339    llama_model_params model_params = llama_model_default_params();
340    model_params.vocab_only = true;
341    llama_model * model = llama_model_load_from_file(model_path, model_params);
342    if (!model) {
343        fprintf(stderr, "Error: could not load model from file '%s'.\n", model_path);
344        return 1;
345    }
346
347    const llama_vocab * vocab = llama_model_get_vocab(model);
348
349    llama_context_params ctx_params = llama_context_default_params();
350    llama_context * ctx = llama_init_from_model(model, ctx_params);
351    if (!ctx) {
352        fprintf(stderr, "Error: could not create context.\n");
353        return 1;
354    }
355
356    // read entire prompt from stdin?
357    if (stdin_set) {
358        GGML_ASSERT(!prompt_path_set && !prompt_set);
359
360        std::stringstream stdin_buffer;
361        stdin_buffer << std::cin.rdbuf();
362        if (std::cin.fail()) {
363            fprintf(stderr, "Error: could not read the entire standard input.\n");
364            return 1;
365        }
366
367        prompt = stdin_buffer.str();
368    }
369
370    const bool model_wants_add_bos = llama_vocab_get_add_bos(vocab);
371    const bool add_bos = model_wants_add_bos && !no_bos;
372    const bool parse_special = !no_parse_special;
373    const bool escape = !no_escape;
374
375    if (escape) {
376        string_process_escapes(prompt);
377    }
378
379    std::vector<llama_token> tokens;
380    tokens = common_tokenize(vocab, prompt, add_bos, parse_special);
381
382    if (printing_ids) {
383        printf("[");
384    }
385
386    for (int i = 0; i < (int) tokens.size(); i++) {
387        if (printing_ids) {
388            if (i > 0) {
389                printf(", ");
390            }
391            printf("%d", tokens[i]);
392        } else {
393            bool invalid_utf8 = false;
394            printf("%6d -> '", tokens[i]);
395            write_utf8_cstr_to_stdout(common_token_to_piece(ctx, tokens[i]).c_str(), invalid_utf8);
396            if (invalid_utf8) {
397                printf("' (utf-8 decode failure)\n");
398            } else {
399                printf("'\n");
400            }
401        }
402    }
403
404    if (printing_ids) {
405        printf("]\n");
406    }
407
408    if (show_token_count) {
409        printf("Total number of tokens: %zu\n", tokens.size());
410    }
411    // silence valgrind
412    llama_free(ctx);
413    llama_model_free(model);
414
415    return 0;
416}