#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <vector>

static void print_usage(int, char ** argv) {
    LOG("\nexample usage:\n");
    LOG("\n    %s -m model.gguf -c 2048 -b 2048 -ub 512 -npp 128,256,512 -ntg 128,256 -npl 1,2,4,8,16,32 [-pps]\n", argv[0]);
    LOG("\n");
}

int main(int argc, char ** argv) {
    common_params params;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_BENCH, print_usage)) {
        return 1;
    }

    common_init();

    int is_pp_shared   = params.is_pp_shared;
    int is_tg_separate = params.is_tg_separate;

    std::vector<int> n_pp = params.n_pp;
    std::vector<int> n_tg = params.n_tg;
    std::vector<int> n_pl = params.n_pl;

    // init LLM

    llama_backend_init();
    llama_numa_init(params.numa);

    // initialize the model

    llama_model_params model_params = common_model_params_to_llama(params);

    llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params);

    if (model == NULL) {
        fprintf(stderr, "%s: error: unable to load model\n", __func__);
        return 1;
    }

    llama_context_params ctx_params = common_context_params_to_llama(params);

    // ensure enough sequences are available
    ctx_params.n_seq_max = n_pl.empty() ? 1 : *std::max_element(n_pl.begin(), n_pl.end());

    llama_context * ctx = llama_init_from_model(model, ctx_params);

    if (ctx == NULL) {
        fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__);
        llama_model_free(model);
        return 1;
    }

    const llama_vocab * vocab   = llama_model_get_vocab(model);
    const int32_t       n_vocab = llama_vocab_n_tokens(vocab);

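    // draw uniformly random token ids - the benchmark measures decode throughput,
    // so the actual token content is irrelevant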
    const auto get_token_rand = [n_vocab]() -> llama_token {
        return std::rand() % n_vocab;
    };

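    // handle to the context's memory (KV cache), used below for the sequence copy/remove/clear calls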
    auto * mem = llama_get_memory(ctx);

    const int32_t n_kv_max = llama_n_ctx(ctx);

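    // a single reusable batch sized to the full context (no embeddings, at most 1 seq id per token)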
    llama_batch batch = llama_batch_init(n_kv_max, 0, 1);

    // decode in batches of ctx_params.n_batch tokens
    auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch, bool synchronize) {
        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

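            // a view into the parent batch - the pointers alias its arrays (embd is unused), nothing is copied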
            llama_batch batch_view = {
                n_tokens,
                batch.token    + i,
                nullptr,
                batch.pos      + i,
                batch.n_seq_id + i,
                batch.seq_id   + i,
                batch.logits   + i,
            };

            const int ret = llama_decode(ctx, batch_view);
            if (ret != 0) {
                LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
                return false;
            }

            if (synchronize) {
                llama_synchronize(ctx);
            }
        }

        return true;
    };

    // warm up
    {
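        // run a small throwaway decode first so that one-time setup costs (e.g. lazy
        // buffer allocations in the backend) do not leak into the measurements below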
        for (int i = 0; i < 16; ++i) {
            common_batch_add(batch, get_token_rand(), i, { 0 }, false);
        }

        if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
            LOG_ERR("%s: llama_decode() failed\n", __func__);
            llama_free(ctx);
            llama_model_free(model);
            return 1;
        }
    }

    if (!params.batched_bench_output_jsonl) {
        LOG("\n");
        LOG("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, is_tg_separate = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, int(params.flash_attn_type), is_pp_shared, is_tg_separate, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
        LOG("\n");
        LOG("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
        LOG("|%6s-|-%6s-|-%4s-|-%6s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|\n", "------", "------", "----", "------", "--------", "--------", "--------", "--------", "--------", "--------");
    }

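    // sweep all combinations of prompt size (pp), tokens to generate (tg) and number of parallel sequences (pl)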
    for (        int i_pp = 0; i_pp < (int) n_pp.size(); ++i_pp) {
        for (    int i_tg = 0; i_tg < (int) n_tg.size(); ++i_tg) {
            for (int i_pl = 0; i_pl < (int) n_pl.size(); ++i_pl) {
                const int pp = n_pp[i_pp];
                const int tg = n_tg[i_tg];
                const int pl = n_pl[i_pl];

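                // KV cells required for this combination: a shared prompt occupies pp cells once
                // in a unified cache and pl*pp cells otherwise; generation always needs pl*tg cells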
                const int n_ctx_req = is_pp_shared ? (params.kv_unified ? pp : pl*pp) + pl*tg : pl*(pp + tg);

                if (n_ctx_req > n_kv_max) {
                    continue;
                }

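                // build the prompt-processing batch: with a shared prompt only sequence 0 is filled
                // (and copied to the others later); logits are requested only for the last token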
                common_batch_clear(batch);

                for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
                    for (int i = 0; i < pp; ++i) {
                        common_batch_add(batch, get_token_rand(), i, { j }, i == pp - 1);
                    }
                }

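                // start every measurement from an empty KV cache; 'false' clears only the
                // metadata and keeps the underlying buffers allocated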
                llama_memory_clear(mem, false);

                const auto t_pp_start = ggml_time_us();

                if (!decode_helper(ctx, batch, ctx_params.n_batch, false)) {
                    LOG_ERR("%s: llama_decode() failed\n", __func__);
                    llama_free(ctx);
                    llama_model_free(model);
                    return 1;
                }

                llama_synchronize(ctx);

                const auto t_pp_end = ggml_time_us();

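                // with a shared prompt, only sequence 0 holds the prompt so far - replicate
                // its KV data to the remaining sequences before generation starts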
                if (is_pp_shared) {
                    for (int32_t i = 1; i < pl; ++i) {
                        llama_memory_seq_cp(mem, 0, i, -1, -1);
                    }

                    if (!params.kv_unified) {
                        // run one dummy token to apply the memory copy
                        common_batch_clear(batch);
                        common_batch_add(batch, get_token_rand(), pp + 0, { 0 }, true);
                        if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
                            LOG_ERR("%s: llama_decode() failed\n", __func__);
                            llama_free(ctx);
                            llama_model_free(model);
                            return 1;
                        }
                        llama_memory_seq_rm(mem, 0, pp, -1);
                    }
                }

                const auto t_tg_start = ggml_time_us();

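                // two generation modes: decode each sequence to completion one after another
                // (is_tg_separate), or advance all sequences together, one batch per position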
                if (is_tg_separate) {
                    // decode pattern:
                    // 0 0 0 ... 1 1 1 ... 2 2 2 ... 3 3 3 ...
                    for (int j = 0; j < pl; ++j) {
                        for (int i = 0; i < tg; ++i) {
                            common_batch_clear(batch);

                            common_batch_add(batch, get_token_rand(), pp + i, { j }, true);

                            if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
                                LOG_ERR("%s: llama_decode() failed\n", __func__);
                                llama_free(ctx);
                                llama_model_free(model);
                                return 1;
                            }
                        }
                    }
                } else {
                    // decode pattern:
                    // 0123 0123 0123 ...
                    for (int i = 0; i < tg; ++i) {
                        common_batch_clear(batch);

                        for (int j = 0; j < pl; ++j) {
                            common_batch_add(batch, get_token_rand(), pp + i, { j }, true);
                        }

                        if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
                            LOG_ERR("%s: llama_decode() failed\n", __func__);
                            llama_free(ctx);
                            llama_model_free(model);
                            return 1;
                        }
                    }
                }

                const auto t_tg_end = ggml_time_us();

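                // throughput accounting matches the work done: a shared prompt counts its pp
                // tokens once, independent prompts count pl*pp; generation counts pl*tg tokens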
                const int32_t n_kv = n_ctx_req;

                const float t_pp = (t_pp_end - t_pp_start) / 1000000.0f;
                const float t_tg = (t_tg_end - t_tg_start) / 1000000.0f;
                const float t    = t_pp + t_tg;

                const float speed_pp = is_pp_shared ? pp / t_pp : pl*pp / t_pp;
                const float speed_tg = pl*tg / t_tg;
                const float speed    = ((is_pp_shared ? pp : pl*pp) + pl*tg) / t;

                if (params.batched_bench_output_jsonl) {
                    LOG(
                        "{\"n_kv_max\": %d, \"n_batch\": %d, \"n_ubatch\": %d, \"flash_attn\": %d, \"is_pp_shared\": %d, \"is_tg_separate\": %d, \"n_gpu_layers\": %d, \"n_threads\": %u, \"n_threads_batch\": %u, "
                        "\"pp\": %d, \"tg\": %d, \"pl\": %d, \"n_kv\": %d, \"t_pp\": %f, \"speed_pp\": %f, \"t_tg\": %f, \"speed_tg\": %f, \"t\": %f, \"speed\": %f}\n",
                        n_kv_max, params.n_batch, params.n_ubatch, int(params.flash_attn_type), is_pp_shared, is_tg_separate, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch,
                        pp, tg, pl, n_kv, t_pp, speed_pp, t_tg, speed_tg, t, speed
                    );
                } else {
                    LOG("|%6d | %6d | %4d | %6d | %8.3f | %8.2f | %8.3f | %8.2f | %8.3f | %8.2f |\n", pp, tg, pl, n_kv, t_pp, speed_pp, t_tg, speed_tg, t, speed);
                }
            }
        }
    }

    LOG("\n");
    llama_perf_context_print(ctx);

    llama_batch_free(batch);

    llama_free(ctx);
    llama_model_free(model);

    llama_backend_free();

    return 0;
}