// A basic application simulating a server with multiple clients.
// The clients submit requests to the server and they are processed in parallel.
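//
// example invocation (binary name and exact flags are assumed from a typical
// llama.cpp build; check `--help` for the options available in your build):
//
//   llama-parallel -m model.gguf -np 8 -ns 64 -cb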

#include "arg.h"
#include "common.h"
#include "sampling.h"
#include "log.h"
#include "llama.h"

#include <cmath>
#include <cstdio>
#include <cctype>
#include <sstream>
#include <string>
#include <vector>
#include <ctime>
#include <algorithm>

// trim whitespace from the beginning and end of a string
static std::string trim(const std::string & str) {
    size_t start = 0;
    size_t end = str.size();

    while (start < end && std::isspace(static_cast<unsigned char>(str[start]))) {
        start += 1;
    }

    while (end > start && std::isspace(static_cast<unsigned char>(str[end - 1]))) {
        end -= 1;
    }

    return str.substr(start, end - start);
}

static std::string k_system =
R"(Transcript of a never ending dialog, where the User interacts with an Assistant.
The Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.

User:
Recommend a nice restaurant in the area.
Assistant:
I recommend the restaurant "The Golden Duck". It is a 5 star restaurant with a great view of the city. The food is delicious and the service is excellent. The prices are reasonable and the portions are generous. The restaurant is located at 123 Main Street, New York, NY 10001. The phone number is (212) 555-1234. The hours are Monday through Friday from 11:00 am to 10:00 pm. The restaurant is closed on Saturdays and Sundays.
User:
Who is Richard Feynman?
Assistant:
Richard Feynman was an American physicist who is best known for his work in quantum mechanics and particle physics. He was awarded the Nobel Prize in Physics in 1965 for his contributions to the development of quantum electrodynamics. He was a popular lecturer and author, and he wrote several books, including "Surely You're Joking, Mr. Feynman!" and "What Do You Care What Other People Think?".
)";

static std::vector<std::string> k_questions = {
    "What is the tallest mountain in the world?",
    "Who was the first person to win two Nobel Prizes?",
    "Which country invented paper?",
    "What organ is primarily responsible for pumping blood throughout the body?",
    "Which planet is known for its prominent ring system?",
    "Who directed the movie 'Inception'?",
    "What is the freezing point of water in Fahrenheit?",
    "Which animal is known to have the longest lifespan?",
    "What language has the most native speakers worldwide?",
    "What is the capital city of Canada?",
    "Who is credited with inventing the World Wide Web?",
    "Which metal is liquid at room temperature?",
    "What is the term for an animal that eats both plants and meat?",
    "Who painted 'The Starry Night'?",
    "What gas do humans exhale that plants use for photosynthesis?",
    "What year did World War II end?",
    "Which continent has the most countries?",
    "Who wrote the novel 'Frankenstein'?",
    "What does DNA stand for?",
    "What is the main ingredient in traditional Japanese miso soup?"
};

static std::vector<std::string> k_answers = {
    "The tallest mountain in the world is Mount Everest.",
    "Marie Curie was the first person to win two Nobel Prizes.",
    "Paper was invented in China.",
    "The heart is the organ responsible for pumping blood.",
    "Saturn is known for its prominent ring system.",
    "Christopher Nolan directed the movie 'Inception'.",
    "The freezing point of water in Fahrenheit is 32°F.",
    "The bowhead whale is known to have the longest lifespan among mammals.",
    "Mandarin Chinese has the most native speakers in the world.",
    "The capital city of Canada is Ottawa.",
    "Tim Berners-Lee is credited with inventing the World Wide Web.",
    "Mercury is the metal that is liquid at room temperature.",
    "An animal that eats both plants and meat is called an omnivore.",
    "'The Starry Night' was painted by Vincent van Gogh.",
    "Humans exhale carbon dioxide, which plants use in photosynthesis.",
    "World War II ended in 1945.",
    "Africa is the continent with the most countries.",
    "The novel 'Frankenstein' was written by Mary Shelley.",
    "DNA stands for Deoxyribonucleic Acid.",
    "The main ingredient in traditional Japanese miso soup is fermented soybean paste."
};

static std::vector<std::string> k_prompts = {
    "What is the meaning of life?",
    "Tell me an interesting fact about llamas.",
    "What is the best way to cook a steak?",
    "Are you familiar with the Special Theory of Relativity and can you explain it to me?",
    "Recommend some interesting books to read.",
    "What is the best way to learn a new language?",
    "How to get a job at Google?",
    "If you could have any superpower, what would it be?",
    "I want to learn how to play the piano. What would be the best way to do it?",
};

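// per-client state: the request currently assigned to this slot, its sampler,
// timing counters and the accumulated prompt/response text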
struct client {
    ~client() {
        if (smpl) {
            common_sampler_free(smpl);
        }
    }

    int32_t id = 0;

    llama_seq_id seq_id = -1;

    llama_token sampled;

    int64_t t_start_prompt;
    int64_t t_start_gen;

    int32_t n_past    = 0;
    int32_t n_prompt  = 0;
    int32_t n_decoded = 0;
    int32_t i_batch   = -1;

    std::string input;
    std::string prompt;
    std::string response;

    struct common_sampler * smpl = nullptr;
};

static void print_date_time() {
    std::time_t current_time = std::time(nullptr);
    std::tm * local_time = std::localtime(&current_time);
    char buffer[80];
    strftime(buffer, sizeof(buffer), "%Y-%m-%d %H:%M:%S", local_time);

    LOG_INF("\n");
    LOG_INF("\033[35mrun parameters as of %s\033[0m\n", buffer);
    LOG_INF("\n");
}

// split a string into tokens on the given delimiter character
static std::vector<std::string> split_string(const std::string & input, char delimiter) {
    std::vector<std::string> tokens;
    std::istringstream stream(input);
    std::string token;
    while (std::getline(stream, token, delimiter)) {
        tokens.push_back(token);
    }
    return tokens;
}

int main(int argc, char ** argv) {
    srand(1234);

    common_params params;

    params.n_predict = 128;
    params.n_junk = 1;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) {
        return 1;
    }

    common_init();

    // number of simultaneous "clients" to simulate
    const int32_t n_clients = params.n_parallel;

    // dedicate one sequence to the system prompt
    params.n_parallel += 1;

    // requests to simulate
    const int32_t n_seq = params.n_sequences;

    // insert new requests as soon as the previous one is done
    const bool cont_batching = params.cont_batching;

    // whether the system prompt is shared across sequences in the KV cache
    const bool is_sp_shared = params.is_pp_shared;

    // extra text to insert in each client's prompt in order to make it larger
    const int32_t n_junk = std::max(1, params.n_junk);

    // signed seed, use negative values to indicate different seeds for the different clients
    const int32_t & sseed = params.sampling.seed;

    // init llama.cpp
    llama_backend_init();
    llama_numa_init(params.numa);

    // load the target model
    auto llama_init = common_init_from_params(params);

    auto * model = llama_init->model();
    auto * ctx   = llama_init->context();

    auto * mem = llama_get_memory(ctx);

    const llama_vocab * vocab = llama_model_get_vocab(model);

    // load the prompts from an external file if there are any
    if (params.prompt.empty()) {
        LOG_INF("\033[32mNo external prompts provided, proceeding with built-in defaults.\033[0m\n");
    } else {
        // print each line of the external prompt file and copy it into k_prompts
        int index = 0;
        LOG_INF("\033[32mNow printing the external prompt file %s\033[0m\n\n", params.prompt_file.c_str());

        std::vector<std::string> prompts = split_string(params.prompt, '\n');
        for (const auto & prompt : prompts) {
            k_prompts.resize(index + 1);
            k_prompts[index] = prompt;
            index++;
            LOG_INF("%3d prompt: %s\n", index, prompt.c_str());
        }
    }

    LOG_INF("\n\n");

    const int n_ctx = llama_n_ctx(ctx);

    if (sseed >= 0) {
        LOG_INF("%s: initializing all samplers with the same RNG seed: %d (use a negative seed to have different seeds)\n", __func__, sseed);
    } else {
        LOG_INF("%s: initializing samplers with different RNG seeds, starting from %d\n", __func__, sseed);
    }

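    // one sampler per client; with a negative seed, each client gets a distinct
    // seed because params.sampling.seed is decremented after every init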
    std::vector<client> clients(n_clients);
    for (size_t i = 0; i < clients.size(); ++i) {
        auto & client = clients[i];
        client.id = i;
        client.smpl = common_sampler_init(model, params.sampling);

        if (sseed < 0) {
            params.sampling.seed--;
        }
    }

    std::vector<llama_token> tokens_system;

    tokens_system = common_tokenize(ctx, k_system, true);
    const int32_t n_tokens_system = tokens_system.size();

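    // id of the next request to hand out; note that the client KV sequences are
    // 1..n_clients, since sequence 0 is reserved for the system prompt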
    llama_seq_id g_seq_id = 0;

    // the max batch size is as large as the context to handle cases where we get a very long input prompt from
    // multiple users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
    llama_batch batch = llama_batch_init(n_ctx, 0, 1);

    int32_t n_total_prompt = 0;
    int32_t n_total_gen    = 0;
    int32_t n_cache_miss   = 0;

    const auto t_main_start = ggml_time_us();

    LOG_INF("%s: Simulating parallel requests from clients:\n", __func__);
    LOG_INF("%s: n_parallel = %d, n_sequences = %d, cont_batching = %d, system tokens = %d\n", __func__, n_clients, n_seq, cont_batching, n_tokens_system);
    LOG_INF("\n");

    if (is_sp_shared) {
        LOG_INF("%s: Evaluating the system prompt ...\n", __func__);

        for (int32_t i = 0; i < n_tokens_system; ++i) {
            common_batch_add(batch, tokens_system[i], i, { 0 }, false);
        }

        if (llama_decode(ctx, batch) != 0) {
            LOG_ERR("%s: llama_decode() failed\n", __func__);
            return 1;
        }

        // assign the system KV cache to all parallel sequences
        for (int32_t i = 1; i <= n_clients; ++i) {
            llama_memory_seq_cp(mem, 0, i, -1, -1);
        }

        LOG_INF("\n");
    }

    LOG_INF("Processing requests ...\n\n");

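    // main loop: each iteration assembles one batch holding a single new token for every
    // active client, plus the full prompt of any newly inserted request, and then decodes
    // it in chunks of at most params.n_batch tokens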
    while (true) {
        common_batch_clear(batch);

        // decode any currently ongoing sequences
        for (auto & client : clients) {
            if (client.seq_id == -1) {
                continue;
            }

            client.i_batch = batch.n_tokens;

            common_batch_add(batch, client.sampled, client.n_past++, { client.id + 1 }, true);

            client.n_decoded += 1;
        }

        if (batch.n_tokens == 0) {
            // all sequences have ended - clear the entire KV cache
            for (int i = 1; i <= n_clients; ++i) {
                llama_memory_seq_rm(mem, i, -1, -1);
                // but keep the system prompt
                llama_memory_seq_cp(mem, 0, i, -1, -1);
            }

            LOG_INF("%s: clearing the KV cache\n", __func__);
        }

        // insert new sequences for decoding
        if (cont_batching || batch.n_tokens == 0) {
            for (auto & client : clients) {
                if (client.seq_id == -1 && g_seq_id < n_seq) {
                    client.seq_id = g_seq_id;

                    client.t_start_prompt = ggml_time_us();
                    client.t_start_gen    = 0;

                    client.input    = k_prompts[rand() % k_prompts.size()];
                    client.response = "";

                    // construct the prompt:
                    // [system prompt] + [junk] + [user prompt]
                    client.n_past = 0;
                    client.prompt = "";
                    if (is_sp_shared) {
                        client.n_past = n_tokens_system;
                    } else {
                        client.prompt += k_system;
                    }

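                    // pad the prompt with a random number of question/answer pairs
                    // (up to n_junk - 1) to simulate clients with longer histories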
                    const int n_junk_cur = rand() % n_junk;

                    for (int i = 0; i < n_junk_cur; ++i) {
                        const int r = rand() % k_questions.size();
                        client.prompt += "User:\n" + k_questions[r] + "\nAssistant:\n " + k_answers[r] + "\n";
                    }
                    client.prompt += "User:\n" + client.input + "\nAssistant:\n";

                    common_sampler_reset(client.smpl);

                    // do not prepend BOS because we have a system prompt!
                    std::vector<llama_token> tokens_prompt;
                    tokens_prompt = common_tokenize(ctx, client.prompt, false);

                    for (size_t i = 0; i < tokens_prompt.size(); ++i) {
                        common_batch_add(batch, tokens_prompt[i], client.n_past++, { client.id + 1 }, false);
                    }

                    // extract the logits only for the last token
                    if (batch.n_tokens > 0) {
                        batch.logits[batch.n_tokens - 1] = true;
                    }

                    client.n_prompt  = tokens_prompt.size();
                    client.n_decoded = 0;
                    client.i_batch   = batch.n_tokens - 1;

                    LOG_INF("\033[31mClient %3d, seq %4d, junk = %4d, prompt = %d, started decoding ...\033[0m\n", client.id, client.seq_id, n_junk_cur, client.n_prompt);

                    g_seq_id += 1;

                    // insert new requests one-by-one
                    //if (cont_batching) {
                    //    break;
                    //}
                }
            }
        }

        if (batch.n_tokens == 0) {
            break;
        }

        // process in chunks of params.n_batch
        int32_t n_batch = params.n_batch;

        int32_t i_next = 0;

        for (int32_t i = 0; i < batch.n_tokens; i = i_next) {
            // experiment: process in powers of 2
            //if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) {
            //    n_batch /= 2;
            //    i -= n_batch;
            //    continue;
            //}

            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

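            // construct a view into the global batch for this chunk; the fields mirror
            // llama_batch: token ids, embeddings (null, since we pass token ids),
            // positions, per-token sequence id counts, sequence ids, and logits flags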
            llama_batch batch_view = {
                n_tokens,
                batch.token    + i,
                nullptr,
                batch.pos      + i,
                batch.n_seq_id + i,
                batch.seq_id   + i,
                batch.logits   + i,
            };

            const int ret = llama_decode(ctx, batch_view);
            if (ret != 0) {
                if (n_batch == 1 || ret < 0) {
                    // if you get here, it means the KV cache is full - try increasing it via the context size
                    LOG_ERR("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
                    return 1;
                }

                LOG_WRN("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2);

                n_cache_miss += 1;

                // retry with half the batch size to try to find a free slot in the KV cache
                n_batch /= 2;

                continue;
            }

            LOG_DBG("%s : decoded batch of %d tokens\n", __func__, n_tokens);

            // move the head of the batch forward by the number of tokens we just processed
            i_next = i + n_tokens;

            // on successful decode, restore the original batch size
            n_batch = params.n_batch;

            for (auto & client : clients) {
                if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) {
                    continue;
                }

                //printf("client %d, seq %d, token %d, pos %d, batch %d\n",
                //        client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);

                const llama_token id = common_sampler_sample(client.smpl, ctx, client.i_batch - i);

                common_sampler_accept(client.smpl, id, true);

                if (client.n_decoded == 1) {
                    // start measuring generation time after the first token to make sure all concurrent clients
                    // have their prompt already processed
                    client.t_start_gen = ggml_time_us();
                }

                const std::string token_str = common_token_to_piece(ctx, id);

                client.response += token_str;
                client.sampled = id;

                //printf("client %d, seq %d, token %d, pos %d, batch %d: %s\n",
                //        client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());

                if (client.n_decoded > 2 &&
                    (llama_vocab_is_eog(vocab, id) ||
                     (params.n_predict > 0 && client.n_decoded >= params.n_predict) ||
                     client.response.find("User:") != std::string::npos)) {
                    // basic reverse prompt
                    const size_t pos = client.response.find("User:");
                    if (pos != std::string::npos) {
                        client.response = client.response.substr(0, pos);
                    }

                    // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
                    llama_memory_seq_rm(mem,    client.id + 1, -1, -1);
                    llama_memory_seq_cp(mem, 0, client.id + 1, -1, -1);

                    const auto t_main_end = ggml_time_us();

                    LOG_INF("\033[31mClient %3d, seq %3d/%3d, prompt %4d t, response %4d t, time %5.2f s, speed %5.2f t/s, cache miss %d \033[0m \n\nInput:    %s\n\033[35mResponse: %s\033[0m\n\n",
                            client.id, client.seq_id, n_seq, client.n_prompt, client.n_decoded,
                            (t_main_end - client.t_start_prompt) / 1e6,
                            (double) (client.n_prompt + client.n_decoded) / (t_main_end - client.t_start_prompt) * 1e6,
                            n_cache_miss,
                            ::trim(client.input).c_str(),
                            ::trim(client.response).c_str());

                    n_total_prompt += client.n_prompt;
                    n_total_gen    += client.n_decoded;

                    client.seq_id = -1;
                }

                client.i_batch = -1;
            }
        }
    }
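
    // all requests have been processed - print the aggregate statistics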
    const auto t_main_end = ggml_time_us();

    print_date_time();

    LOG_INF("%s: n_parallel = %d, n_sequences = %d, cont_batching = %d, system tokens = %d\n", __func__, n_clients, n_seq, cont_batching, n_tokens_system);
    if (params.prompt_file.empty()) {
        params.prompt_file = "used built-in defaults";
    }
    LOG_INF("External prompt file: \033[32m%s\033[0m\n", params.prompt_file.c_str());
    LOG_INF("Model and path used:  \033[32m%s\033[0m\n\n", params.model.path.c_str());

    LOG_INF("Total prompt tokens: %6d, speed: %5.2f t/s\n", n_total_prompt, (double) (n_total_prompt              ) / (t_main_end - t_main_start) * 1e6);
    LOG_INF("Total gen tokens:    %6d, speed: %5.2f t/s\n", n_total_gen,    (double) (n_total_gen                 ) / (t_main_end - t_main_start) * 1e6);
    LOG_INF("Total speed (AVG):   %6s  speed: %5.2f t/s\n", "",             (double) (n_total_prompt + n_total_gen) / (t_main_end - t_main_start) * 1e6);
    LOG_INF("Cache misses:        %6d\n", n_cache_miss);

    LOG_INF("\n");

    // TODO: print sampling/grammar timings for all clients
    llama_perf_context_print(ctx);

    llama_batch_free(batch);

    llama_backend_free();

    LOG("\n\n");

    return 0;
}