#include "llama.h"
#include "common.h"
#include "console.h"

#include <cstdio>
#include <fstream>
#include <map>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>
 11
 12//static const std::map<std::string, std::vector<llama_token>> & k_tests() {
 13//    static std::map<std::string, std::vector<llama_token>> _k_tests = {
 14//        { ""                      , {  }, },
 15//        { " "                     , {     220, }, },
 16//        { "  "                    , {     256, }, },
 17//        { "   "                   , {     262, }, },
 18//        { "\t"                    , {     197, }, },
 19//        { "\n"                    , {     198, }, },
 20//        { "\n\n"                  , {     271, }, },
 21//        { "\n\n\n"                , {    1432, }, },
 22//        { "\t\n"                  , {    1602, }, },
 23//        { "Hello world"           , {    9906,   1917, }, },
 24//        { " Hello world"          , {   22691,   1917, }, },
 25//        { "Hello World"           , {    9906,   4435, }, },
 26//        { " Hello World"          , {   22691,   4435, }, },
 27//        { " Hello World!"         , {   22691,   4435,      0, }, },
 28//        { "Hello, world!"         , {    9906,     11,   1917,      0, }, },
 29//        { " Hello, world!"        , {   22691,     11,   1917,      0, }, },
 30//        { " this is 🦙.cpp"        , {     420,    374,  11410,     99,    247,     13,  11055, }, },
 31//        { "w048 7tuijk dsdfhu"    , {      86,  23904,    220,     22,     83,   2005,  42908,  11729,   3013,  17156, }, },
 32//        { "нещо на Български"     , {   79862, 102118,  13373,  64571,  34694,   3114, 112203,  80112, }, },
 33//        { "កាន់តែពិសេសអាចខលចេញ"   , {   21549,    222,  98629,    241,  45358,    233,  21549,    237,  45358,    224,  21549,    244,  21549,    115,  21549,    253,  45358,    223,  21549,    253,  21549,     95,  98629,    227,  21549,    223,  21549,    249,  21549,    227,  45358,    223,  21549,    231, }, },
 34//        { "🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)", {    9468,    248,    222,    320,   8416,      8,  27623,    114, 102470,   9468,    234,    104,  31643,    320,  36773, 100166,  98634,      8,  26602,    227,    320,   3323,  43465,    430,    706,   1202,   1866,   4037,      8, }, },
 35//        { "Hello"                 , {    9906, }, },
 36//        { " Hello"                , {   22691, }, },
 37//        { "  Hello"               , {     220,  22691, }, },
 38//        { "   Hello"              , {     256,  22691, }, },
 39//        { "    Hello"             , {     262,  22691, }, },
 40//        { "    Hello\n    Hello"  , {     262,  22691,    198,    262,  22691, }, },
 41//        { " ("                    , {     320, }, },
 42//        { "\n ="                  , {     198,    284, }, },
 43//        { "' era"                 , {       6,  11639, }, },
 44//        { "Hello, y'all! How are you 😁 ?我想在apple工作1314151天~", {    9906,     11,    379,  65948,      0,   2650,    527,    499,  27623,    223,    949,  37046, 101067,  19000,  23182, 102301,   9263,  18136,     16,  36827,  21909, }, },
 45//        { "3"                     , {      18, }, },
 46//        { "33"                    , {    1644, }, },
 47//        { "333"                   , {    8765, }, },
 48//        { "3333"                  , {    8765,     18, }, },
 49//        { "33333"                 , {    8765,   1644, }, },
 50//        { "333333"                , {    8765,   8765, }, },
 51//        { "3333333"               , {    8765,   8765,     18, }, },
 52//        { "33333333"              , {    8765,   8765,   1644, }, },
 53//        { "333333333"             , {    8765,   8765,   8765, }, },
 54//    };
 55//
 56//    return _k_tests;
 57//}
 58
 59using llama_tests = std::map<std::string, std::vector<llama_token>>;
 60
 61static llama_tests read_tests(const std::string & fname_inp, const std::string & fname_out) {
 62    llama_tests tests;
 63
 64    std::ifstream ifs_inp(fname_inp);
 65    if (!ifs_inp) {
 66        fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_inp.c_str());
 67        return tests;
 68    }
 69
 70    std::string sraw((std::istreambuf_iterator<char>(ifs_inp)), std::istreambuf_iterator<char>());
 71
 72    std::ifstream ifs_out(fname_out);
 73    if (!ifs_out) {
 74        fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_out.c_str());
 75        return tests;
 76    }
 77
 78    std::vector<std::string> sout;
 79    for (std::string line; std::getline(ifs_out, line);) {
 80        sout.push_back(line);
 81    }
 82
 83    const std::string sep = "\n__ggml_vocab_test__\n";
 84
 85    std::vector<std::string> sinp;
 86
 87    size_t pos = 0;
 88    while (pos < sraw.size()) {
 89        const size_t next = sraw.find(sep, pos);
 90        if (next == std::string::npos) {
 91            sinp.push_back(sraw.substr(pos));
 92            break;
 93        }
 94        sinp.push_back(sraw.substr(pos, next - pos));
 95        pos = next + sep.size();
 96    }
 97
 98    if (sinp.size() != sout.size()) {
 99        fprintf(stderr, "%s : error: input and output files have different number of tests\n", __func__);
100        return tests;
101    }
102
103    for (size_t i = 0; i < sinp.size(); ++i) {
104        const std::string & s = sinp[i];
105        const std::string & o = string_strip(sout[i]);
106
107        std::vector<llama_token> toks;
108
109        size_t pos = 0;
110        while (pos < o.size()) {
111            size_t next = o.find(' ', pos);
112            if (next == std::string::npos) {
113                next = o.size();
114            }
115            const std::string stok = o.substr(pos, next - pos);
116            toks.push_back(std::stoi(stok));
117            pos = next + 1;
118        }
119
120        tests[s] = toks;
121    }
122
123    return tests;
124}
125
126int main(int argc, char **argv) {
127    if (argc < 2) {
128        fprintf(stderr, "Usage: %s vocab-file [text-file]\n", argv[0]);
129        return 1;
130    }
131
132    const std::string fname = argv[1];
133
134    const std::string fname_inp = fname + ".inp";
135    const std::string fname_out = fname + ".out";
136
137    std::string fname_text;
138    if (argc > 2) {
139        fname_text = argv[2];
140    }
141
142    fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());
143
144    llama_model * model;
145    llama_context * ctx;
146
147    llama_backend_init();
148
149    // load the vocab
150    {
151        auto mparams = llama_model_default_params();
152
153        mparams.vocab_only = true;
154
155        model = llama_model_load_from_file(fname.c_str(), mparams);
156
157        if (model == NULL) {
158            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
159            return 1;
160        }
161
162        auto cparams = llama_context_default_params();
163
164        ctx = llama_init_from_model(model, cparams);
165
166        if (ctx == NULL) {
167            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
168            llama_model_free(model);
169            return 1;
170        }
171    }
172
173#ifdef _WIN32
174    // We need this for unicode console support
175    console::init(false, false);
176    atexit([]() { console::cleanup(); });
177#endif
178
179    bool success = true;
180
181    const auto k_tests = [&]() -> llama_tests {
182        if (!fname_text.empty()) {
183            return {};
184        }
185
186        const auto res = read_tests(fname_inp, fname_out);
187
188        if (res.empty()) {
189            fprintf(stderr, "%s : error: no tests found\n", __func__);
190            exit(1);
191        }
192
193        return res;
194    }();
195
196    const bool add_special = false;
197
198    // multi-threaded tokenization
199    const int nthread = std::thread::hardware_concurrency();
200    std::vector<std::thread> threads(nthread);
201
202    for (int i = 0; i < nthread; i++) {
203        threads[i] = std::thread([&, i]() {
204            for (const auto & test_kv : k_tests) {
205                const std::vector<llama_token> res = common_tokenize(ctx, test_kv.first, add_special, false);
206
207                // here only print the result of the first thread
208                // because the other threads are running the same tests
209                if (i != 0) {
210                    continue;
211                }
212
213                printf("\n");
214                printf("src: '%s'\n", test_kv.first.c_str());
215                printf("res: '%s'\n", common_detokenize(ctx, res).c_str());
216                printf("tok: ");
217                for (const auto & tok : res) {
218                    printf("%d ", tok);
219                }
220                printf("\n");
221
222                bool correct = res.size() == test_kv.second.size();
223                for (int i = 0; i < (int) res.size() && correct; ++i) {
224                    if (test_kv.second[i] != res[i]) {
225                        correct = false;
226                    }
227                }
228
229                if (!correct) {
230                    fprintf(stderr, "%s : failed test:    '%s'\n", __func__, test_kv.first.c_str());
231                    fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
232                        common_detokenize(ctx, res).c_str(),
233                        common_detokenize(ctx, test_kv.second).c_str());
234                    fprintf(stderr, "%s : expected tokens: ", __func__);
235                    for (const auto & t : test_kv.second) {
236                        fprintf(stderr, "%6d '%s', ", t, common_token_to_piece(ctx, t).c_str());
237                    }
238                    fprintf(stderr, "\n");
239                    fprintf(stderr, "%s : got tokens:      ", __func__);
240                    for (const auto & t : res) {
241                        fprintf(stderr, "%6d '%s', ", t, common_token_to_piece(ctx, t).c_str());
242                    }
243                    fprintf(stderr, "\n");
244
245                    success = false;
246                }
247            }
248        });
249    }
250
251    for (int i = 0; i < nthread; i++) {
252        threads[i].join();
253    }
254
255    // single threaded tokenization
256    if (!fname_text.empty()) {
257        fprintf(stderr, "%s : tokenizing: '%s'\n", __func__, fname_text.c_str());
258
259        std::string text;
260        {
261            std::ifstream ifs(fname_text);
262            if (!ifs) {
263                fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_text.c_str());
264                return 1;
265            }
266            text = std::string(std::istreambuf_iterator<char>(ifs), std::istreambuf_iterator<char>());
267        }
268
269        fprintf(stderr, "%s : text size: %zu\n", __func__, text.size());
270
271        std::vector<llama_token> res;
272
273        {
274            const auto t_start = ggml_time_us();
275
276            res = common_tokenize(ctx, text, add_special, false);
277
278            const auto t_end = ggml_time_us();
279
280            fprintf(stderr, "%s : tokenized in %.3f ms (cpp)\n", __func__, (t_end - t_start) / 1000.0);
281        }
282
283        fprintf(stderr, "%s : tokens: %zu\n", __func__, res.size());
284
285        {
286            const std::string fname_out = fname_text + ".tokcpp";
287
288            std::ofstream ofs(fname_out);
289            if (!ofs) {
290                fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_out.c_str());
291                return 1;
292            }
293
294            for (const auto & tok : res) {
295                //ofs << tok << " '" << string_strip(llama_detokenize(ctx, std::vector<int>{tok})) << "'" << std::endl;
296                ofs << tok << "\n";
297            }
298        }
299
300        fprintf(stderr, "%s : tokens written to '%s'\n", __func__, (fname_text + ".tokcpp").c_str());
301    }
302
303    llama_free(ctx);
304    llama_model_free(model);
305
306    llama_backend_free();
307
308    printf("\n");
309    printf("Tests %s\n", success ? "passed" : "failed");
310
311    return success ? 0 : 3;
312}