#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"
#include "sampling.h"

#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

static void print_usage(int, char ** argv) {
    LOG("\nexample usage:\n");
    LOG("\n %s -m model.gguf -p \"Hello my name is\" -n 32 -np 4\n", argv[0]);
    LOG("\n");
}

int main(int argc, char ** argv) {
    common_params params;

    params.prompt = "Hello my name is";
    params.n_predict = 32;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_BATCHED, print_usage)) {
        return 1;
    }

    common_init();

    // number of parallel sequences to decode
    int n_parallel = params.n_parallel;

    // total length of the sequences including the prompt
    int n_predict = params.n_predict;

    // init LLM

    llama_backend_init();
    llama_numa_init(params.numa);

    // initialize the model

    llama_model_params model_params = common_model_params_to_llama(params);

    llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params);

    if (model == NULL) {
        LOG_ERR("%s: error: unable to load model\n", __func__);
        return 1;
    }

    const llama_vocab * vocab = llama_model_get_vocab(model);

    // tokenize the prompt

    std::vector<llama_token> tokens_list;
    tokens_list = common_tokenize(vocab, params.prompt, true);

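    // required KV cache size: the prompt is stored once (shared by all sequences) and
    // each of the n_parallel sequences then generates (n_predict - n_prompt) new tokens;
    // e.g. (hypothetical numbers) an 8-token prompt with n_predict = 32 and n_parallel = 4
    // needs 8 + (32 - 8)*4 = 104 cells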
    const int n_kv_req = tokens_list.size() + (n_predict - tokens_list.size())*n_parallel;

    // initialize the context

    llama_context_params ctx_params = common_context_params_to_llama(params);

    ctx_params.n_ctx   = n_kv_req;
    ctx_params.n_batch = std::max(n_predict, n_parallel);

    auto sparams = llama_sampler_chain_default_params();
    sparams.no_perf = false;

    std::vector<llama_sampler_seq_config> sampler_configs;

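    // build one sampler chain per parallel sequence so that each stream samples with its
    // own state (top-k -> top-p -> temperature -> dist)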
    for (int32_t i = 0; i < n_parallel; ++i) {
        llama_sampler * smpl = llama_sampler_chain_init(sparams);

        llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
        llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));
        llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
        llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));

        sampler_configs.push_back({ i, smpl });
    }

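    // when backend sampling is requested, hand the per-sequence samplers to the context
    // at creation time via ctx_params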
    if (params.sampling.backend_sampling) {
        ctx_params.samplers   = sampler_configs.data();
        ctx_params.n_samplers = sampler_configs.size();
    }

    llama_context * ctx = llama_init_from_model(model, ctx_params);

    if (ctx == NULL) {
        LOG_ERR("%s: error: failed to create the llama_context\n", __func__);
        return 1;
    }

    const int n_ctx = llama_n_ctx(ctx);

    LOG_INF("\n%s: n_predict = %d, n_ctx = %d, n_batch = %u, n_parallel = %d, n_kv_req = %d\n", __func__, n_predict, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req);

    // make sure the KV cache is big enough to hold all the prompt and generated tokens
    if (n_kv_req > n_ctx) {
        LOG_ERR("%s: error: n_kv_req (%d) > n_ctx (%d), the context is not big enough to hold the required KV cache\n", __func__, n_kv_req, n_ctx);
        LOG_ERR("%s: either reduce n_parallel or increase n_ctx\n", __func__);
        return 1;
    }

    // print the prompt token-by-token

    LOG("\n");

    for (auto id : tokens_list) {
        LOG("%s", common_token_to_piece(ctx, id).c_str());
    }

    // create a llama_batch
    // we use this object to submit token data for decoding
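    // the batch must be able to hold either the full prompt (submitted once for all
    // sequences) or one new token per sequence, whichever is larger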
    llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t) n_parallel), 0, n_parallel);

    std::vector<llama_seq_id> seq_ids(n_parallel, 0);
    for (int32_t i = 0; i < n_parallel; ++i) {
        seq_ids[i] = i;
    }

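    // the prompt tokens are added with all sequence ids attached, so the prompt is
    // decoded only once and its KV cache entries are shared by all parallel sequences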
    // evaluate the initial prompt
    for (size_t i = 0; i < tokens_list.size(); ++i) {
        common_batch_add(batch, tokens_list[i], i, seq_ids, false);
    }
    GGML_ASSERT(batch.n_tokens == (int) tokens_list.size());

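    // for encoder-decoder models, run the prompt through the encoder first and start
    // decoding from the decoder start token (falling back to BOS if none is defined)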
    if (llama_model_has_encoder(model)) {
        if (llama_encode(ctx, batch)) {
            LOG_ERR("%s : failed to eval\n", __func__);
            return 1;
        }

        llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
        if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
            decoder_start_token_id = llama_vocab_bos(vocab);
        }

        common_batch_clear(batch);
        common_batch_add(batch, decoder_start_token_id, 0, seq_ids, false);
    }

    // llama_decode will output logits only for the last token of the prompt
    batch.logits[batch.n_tokens - 1] = true;

    if (llama_decode(ctx, batch) != 0) {
        LOG_ERR("%s: llama_decode() failed\n", __func__);
        return 1;
    }

    //// assign the system KV cache to all parallel sequences
    //// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
    //for (int32_t i = 1; i < n_parallel; ++i) {
    //    llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
    //}
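    // (the copy above is not needed here: the prompt tokens were already added to the
    //  batch with all sequence ids)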

    if (n_parallel > 1) {
        LOG("\n\n%s: generating %d sequences ...\n", __func__, n_parallel);
    }

    // main loop

    // we will store the parallel decoded sequences in this vector
    std::vector<std::string> streams(n_parallel);

    // remember the batch index of the last token for each parallel sequence
    // we need this to determine which logits to sample from
    std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);

    int n_cur    = batch.n_tokens;
    int n_decode = 0;

    const auto t_main_start = ggml_time_us();

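    // n_cur is the position of the next token to generate; keep decoding until every
    // stream has either hit an end-of-generation token or reached n_predict tokens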
    while (n_cur <= n_predict) {
        // prepare the next batch
        common_batch_clear(batch);

        // sample the next token for each parallel sequence / stream
        for (int32_t i = 0; i < n_parallel; ++i) {
            if (i_batch[i] < 0) {
                // the stream has already finished
                continue;
            }

            const llama_token new_token_id = llama_sampler_sample(sampler_configs[i].sampler, ctx, i_batch[i]);

            // is it an end of generation? -> mark the stream as finished
            if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) {
                i_batch[i] = -1;
                LOG("\n");
                if (n_parallel > 1) {
                    LOG_INF("%s: stream %d finished at n_cur = %d\n", __func__, i, n_cur);
                }

                continue;
            }

            // if there is only one stream, we print immediately to stdout
            if (n_parallel == 1) {
                LOG("%s", common_token_to_piece(ctx, new_token_id).c_str());
            }

            streams[i] += common_token_to_piece(ctx, new_token_id);

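            // remember where this sequence's new token lands in the next batch so we
            // know which row of logits to sample from on the next iteration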
            i_batch[i] = batch.n_tokens;

            // push this new token for next evaluation
            common_batch_add(batch, new_token_id, n_cur, { i }, true);

            n_decode += 1;
        }

        // all streams are finished
        if (batch.n_tokens == 0) {
            break;
        }

        n_cur += 1;

        // evaluate the current batch with the transformer model
        const int ret = llama_decode(ctx, batch);
        if (ret != 0) {
            LOG_ERR("%s : failed to eval, return code %d\n", __func__, ret);
            return 1;
        }
    }

    if (n_parallel > 1) {
        LOG("\n");

        for (int32_t i = 0; i < n_parallel; ++i) {
            LOG("sequence %d:\n\n%s%s\n\n", i, params.prompt.c_str(), streams[i].c_str());
        }
    }

    const auto t_main_end = ggml_time_us();

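    // ggml_time_us() reports microseconds, hence the division by 1e6 to get seconds and t/s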
    LOG_INF("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
            __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));

    LOG("\n");
    llama_perf_sampler_print(sampler_configs[0].sampler);
    llama_perf_context_print(ctx);

    fprintf(stderr, "\n");

    llama_batch_free(batch);

    for (auto & sampler_config : sampler_configs) {
        llama_sampler_free(sampler_config.sampler);
    }

    llama_free(ctx);
    llama_model_free(model);

    llama_backend_free();

    return 0;
}