#include "llama.h"
#include "vectordb.h"
#include "models.h"

#define NONSTD_IMPLEMENTATION
#include "nonstd.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <getopt.h>

#include "prompts/lotr.h"

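/* No-op log callback: installed when --verbose is off so that llama.cpp's
 * default logging to stderr stays quiet during generation. */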
static void llama_log_callback(enum ggml_log_level level, const char *text, void *user_data) {
    (void)level;
    (void)user_data;
    (void)text;
}

static void list_available_models(void) {
    printf("Model list:\n");
    ModelConfig model;
    static_foreach(ModelConfig, model, models) {
        printf(" - %s [ctx: %d, temp: %f]\n", model.name, model.n_ctx, model.temperature);
    }
}

static void show_help(const char *prog) {
    printf("Usage: %s [OPTIONS]\n", prog);
    printf("Options:\n");
    printf("  -m, --model <name>        Model to use (required)\n");
    printf("  -e, --embed-model <name>  Model to use for embeddings (default: the model's configured embedding model, or the model itself)\n");
    printf("  -p, --prompt <text>       Prompt text (required)\n");
    printf("  -c, --context <file>      Vector database file (.vdb, required)\n");
    printf("  -l, --list                List all available models\n");
    printf("  -v, --verbose             Enable verbose logging\n");
    printf("  -h, --help                Show this help message\n");
}

static int has_vdb_extension(const char *path) {
    size_t len = strlen(path);
    const char *ext = ".vdb";
    size_t ext_len = strlen(ext);
    if (len < ext_len) {
        return 0;
    }
    return strcmp(path + (len - ext_len), ext) == 0;
}

static void append_prompt_context(stringb *sb, const char *context, const char *question) {
    sb_append_cstr(sb, "Context:\n");
    if (context && context[0] != '\0') {
        sb_append_cstr(sb, context);
    }
    sb_append_cstr(sb, "\nQuestion:\n");
    sb_append_cstr(sb, question ? question : "");
}

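/* Assembles the full prompt for the configured style: T5-style models get an
 * "instruction/question/context/answer" layout, chat models get
 * "System/User/Assistant" turns, and plain models get a simple labeled block.
 * The caller owns the returned buffer and must free() it. */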
static char *build_prompt(const ModelConfig *cfg, const char *system, const char *context,
                          const char *question) {
    stringb full = {0};
    sb_init(&full, 0);

    switch (cfg->prompt_style) {
    case PROMPT_STYLE_T5:
        sb_append_cstr(&full, "instruction: ");
        sb_append_cstr(&full, system ? system : "");
        sb_append_cstr(&full, "\nquestion: ");
        sb_append_cstr(&full, question ? question : "");
        sb_append_cstr(&full, "\ncontext:\n");
        if (context && context[0] != '\0') {
            sb_append_cstr(&full, context);
        }
        sb_append_cstr(&full, "\nanswer:");
        break;
    case PROMPT_STYLE_CHAT:
        sb_append_cstr(&full, "System:\n");
        sb_append_cstr(&full, system ? system : "");
        sb_append_cstr(&full, "\nUser:\n");
        append_prompt_context(&full, context, question);
        sb_append_cstr(&full, "\nAssistant:");
        break;
    case PROMPT_STYLE_PLAIN:
    default:
        sb_append_cstr(&full, "System:\n");
        sb_append_cstr(&full, system ? system : "");
        sb_append_cstr(&full, "\n");
        append_prompt_context(&full, context, question);
        sb_append_cstr(&full, "\nAnswer:");
        break;
    }

    return full.data;
}

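/* Runs generation: loads the generation model, builds the final prompt from
 * the embedded system text, the retrieved context, and the user question,
 * then streams sampled tokens until end-of-generation or the n_predict cap. */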
static int execute_prompt_with_context(const ModelConfig *cfg, const char *prompt,
                                       const char *context, int n_predict) {
    if (cfg == NULL) {
        log_message(stderr, LOG_ERROR, "Model config is missing");
        return 1;
    }

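    /* The system prompt is compiled into the binary via prompts/lotr.h
     * (an xxd-style byte array, going by the *_txt/*_txt_len names);
     * copy it into a NUL-terminated buffer before use. */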
    char *system_prefix = (char *)malloc(prompts_lotr_txt_len + 1);
    if (system_prefix == NULL) {
        log_message(stderr, LOG_ERROR, "Failed to allocate system prompt");
        return 1;
    }
    memcpy(system_prefix, prompts_lotr_txt, prompts_lotr_txt_len);
    system_prefix[prompts_lotr_txt_len] = '\0';

    ggml_backend_load_all();

    struct llama_model_params model_params = llama_model_default_params();
    model_params.n_gpu_layers = cfg->n_gpu_layers;
    model_params.use_mmap = cfg->use_mmap;

    struct llama_model *model = llama_model_load_from_file(cfg->filepath, model_params);
    if (model == NULL) {
        log_message(stderr, LOG_ERROR, "Unable to load model from %s", cfg->filepath);
        free(system_prefix);
        return 1;
    }

    const struct llama_vocab *vocab = llama_model_get_vocab(model);

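    /* If the embedded prompt file itself starts with a "System:" label, strip
     * it here so build_prompt does not emit the header twice. */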
    const char *system_text = system_prefix;
    if (strncmp(system_prefix, "System:", 7) == 0) {
        system_text = system_prefix + 7;
        while (*system_text == ' ' || *system_text == '\n' || *system_text == '\r') {
            system_text++;
        }
    }

    char *full_prompt = build_prompt(cfg, system_text, context, prompt);
    if (full_prompt == NULL) {
        log_message(stderr, LOG_ERROR, "Failed to build prompt");
        free(system_prefix);
        llama_model_free(model);
        return 1;
    }

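    /* Two-pass tokenization: calling llama_tokenize with a NULL buffer returns
     * the negative of the required token count, so negate it to size the
     * allocation, then tokenize again into the real buffer. */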
    int n_prompt = -llama_tokenize(vocab, full_prompt, strlen(full_prompt), NULL, 0, true, true);
    llama_token *prompt_tokens = (llama_token *)malloc((size_t)n_prompt * sizeof(llama_token));
    if (prompt_tokens == NULL) {
        log_message(stderr, LOG_ERROR, "Failed to allocate prompt tokens");
        free(full_prompt);
        free(system_prefix);
        llama_model_free(model);
        return 1;
    }
    if (llama_tokenize(vocab, full_prompt, strlen(full_prompt), prompt_tokens, n_prompt, true, true) < 0) {
        log_message(stderr, LOG_ERROR, "Failed to tokenize prompt");
        free(full_prompt);
        free(prompt_tokens);
        free(system_prefix);
        llama_model_free(model);
        return 1;
    }

    struct llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx = cfg->n_ctx;
    ctx_params.n_batch = cfg->n_batch;
    ctx_params.embeddings = cfg->embeddings;

    struct llama_context *ctx = llama_init_from_model(model, ctx_params);
    if (ctx == NULL) {
        log_message(stderr, LOG_ERROR, "Failed to create llama_context");
        free(full_prompt);
        free(prompt_tokens);
        free(system_prefix);
        llama_model_free(model);
        return 1;
    }

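    /* Sampler chain, applied in order: top-k, top-p, and min-p filtering
     * (each only when configured), then repetition/frequency/presence
     * penalties, temperature scaling, and finally a seeded draw from the
     * resulting distribution. */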
    struct llama_sampler_chain_params sparams = llama_sampler_chain_default_params();
    struct llama_sampler *smpl = llama_sampler_chain_init(sparams);
    if (cfg->top_k > 0) {
        llama_sampler_chain_add(smpl, llama_sampler_init_top_k(cfg->top_k));
    }
    if (cfg->top_p > 0.0f && cfg->top_p < 1.0f) {
        llama_sampler_chain_add(smpl, llama_sampler_init_top_p(cfg->top_p, 1));
    }
    if (cfg->min_p > 0.0f) {
        llama_sampler_chain_add(smpl, llama_sampler_init_min_p(cfg->min_p, 1));
    }
    llama_sampler_chain_add(smpl, llama_sampler_init_penalties(
        cfg->repeat_last_n,
        cfg->repeat_penalty,
        cfg->freq_penalty,
        cfg->presence_penalty));
    llama_sampler_chain_add(smpl, llama_sampler_init_temp(cfg->temperature));
    llama_sampler_chain_add(smpl, llama_sampler_init_dist(cfg->seed));

    struct llama_batch batch = llama_batch_get_one(prompt_tokens, n_prompt);

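    /* Encoder-decoder models (e.g. T5) need an explicit encode pass over the
     * prompt; decoding then starts from the model's decoder start token,
     * falling back to BOS when none is defined. */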
    if (llama_model_has_encoder(model)) {
        if (llama_encode(ctx, batch)) {
            log_message(stderr, LOG_ERROR, "Failed to encode prompt");
            llama_sampler_free(smpl);
            free(full_prompt);
            free(prompt_tokens);
            free(system_prefix);
            llama_free(ctx);
            llama_model_free(model);
            return 1;
        }

        llama_token decoder_start = llama_model_decoder_start_token(model);
        if (decoder_start == LLAMA_TOKEN_NULL) {
            decoder_start = llama_vocab_bos(vocab);
        }
        batch = llama_batch_get_one(&decoder_start, 1);
    }

    printf(">> Prompt: %s\n", prompt);
    printf(">> Response: ");
    fflush(stdout);

    int n_pos = 0;
    llama_token new_token_id;
    size_t out_cap = 256;
    size_t out_len = 0;
    char *out = (char *)malloc(out_cap);
    if (out == NULL) {
        log_message(stderr, LOG_ERROR, "Failed to allocate output buffer");
        free(full_prompt);
        free(prompt_tokens);
        free(system_prefix);
        llama_sampler_free(smpl);
        llama_free(ctx);
        llama_model_free(model);
        return 1;
    }
    out[0] = '\0';

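    /* Generation loop: decode the pending batch, sample the next token, and
     * stop on an end-of-generation token or once n_predict tokens have been
     * produced. Leading newline tokens are skipped so the response does not
     * start with blank lines; everything else is appended to a growable
     * output buffer that is printed when the loop ends. */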
    while (n_pos + batch.n_tokens < n_prompt + n_predict) {
        if (llama_decode(ctx, batch)) {
            log_message(stderr, LOG_ERROR, "Failed to decode");
            break;
        }

        n_pos += batch.n_tokens;
        new_token_id = llama_sampler_sample(smpl, ctx, -1);
        if (llama_vocab_is_eog(vocab, new_token_id)) {
            break;
        }

        char buf[128];
        int n = llama_token_to_piece(vocab, new_token_id, buf, sizeof(buf), 0, true);
        if (n < 0) {
            log_message(stderr, LOG_ERROR, "Failed to convert token to piece");
            break;
        }
        if (out_len == 0 && n > 0 && buf[0] == '\n') {
            batch = llama_batch_get_one(&new_token_id, 1);
            continue;
        }
        if (out_len + (size_t)n + 1 > out_cap) {
            while (out_len + (size_t)n + 1 > out_cap) {
                out_cap *= 2;
            }
            char *next = (char *)realloc(out, out_cap);
            if (next == NULL) {
                log_message(stderr, LOG_ERROR, "Failed to grow output buffer");
                break;
            }
            out = next;
        }
        memcpy(out + out_len, buf, (size_t)n);
        out_len += (size_t)n;
        out[out_len] = '\0';

        batch = llama_batch_get_one(&new_token_id, 1);
    }

    printf("%s\n", out);

    free(full_prompt);
    free(prompt_tokens);
    free(system_prefix);
    free(out);

    llama_sampler_free(smpl);
    llama_free(ctx);
    llama_model_free(model);

    return 0;
}

int main(int argc, char **argv) {
    set_log_level(LOG_DEBUG);

    const char *model_name = NULL;
    const char *prompt = NULL;
    const char *context_file = NULL;
    int verbose = 0;
    const char *embed_model_name = NULL;

    int n_predict = 0;

    static struct option long_options[] = {
        {"model", required_argument, 0, 'm'},
        {"prompt", required_argument, 0, 'p'},
        {"context", required_argument, 0, 'c'},
        {"embed-model", required_argument, 0, 'e'},
        {"list", no_argument, 0, 'l'},
        {"verbose", no_argument, 0, 'v'},
        {"help", no_argument, 0, 'h'},
        {0, 0, 0, 0}
    };

    int opt;
    int option_index = 0;
    while ((opt = getopt_long(argc, argv, "m:p:c:e:lvh", long_options, &option_index)) != -1) {
        switch (opt) {
        case 'm':
            model_name = optarg;
            break;
        case 'p':
            prompt = optarg;
            break;
        case 'c':
            context_file = optarg;
            break;
        case 'e':
            embed_model_name = optarg;
            break;
        case 'v':
            verbose = 1;
            break;
        case 'l':
            list_available_models();
            return 0;
        case 'h':
            show_help(argv[0]);
            return 0;
        default:
            fprintf(stderr, "Usage: %s -m <model> -p <prompt> -c <context.vdb> [-e embed-model] [-l] [-v] [-h]\n", argv[0]);
            return 1;
        }
    }

    if (verbose == 0) {
        llama_log_set(llama_log_callback, NULL);
    }

    if (prompt == NULL) {
        log_message(stderr, LOG_ERROR, "Prompt must be provided. Exiting...");
        return 1;
    }

    if (model_name == NULL) {
        log_message(stderr, LOG_ERROR, "Model must be provided. Exiting...");
        return 1;
    }

    if (context_file == NULL) {
        log_message(stderr, LOG_ERROR, "Context .vdb file must be provided. Exiting...");
        return 1;
    }

    if (!has_vdb_extension(context_file)) {
        log_message(stderr, LOG_ERROR, "Context file must be a .vdb vector database");
        return 1;
    }

    llama_backend_init();

    const ModelConfig *cfg = get_model_by_name(model_name);
    if (cfg == NULL) {
        log_message(stderr, LOG_ERROR, "Unknown model '%s'", model_name);
        llama_backend_free();
        return 1;
    }

    const ModelConfig *embed_cfg = NULL;
    if (embed_model_name != NULL) {
        embed_cfg = get_model_by_name(embed_model_name);
        if (embed_cfg == NULL) {
            log_message(stderr, LOG_ERROR, "Unknown embedding model '%s'", embed_model_name);
            llama_backend_free();
            return 1;
        }
    } else if (cfg->embed_model_name != NULL) {
        embed_cfg = get_model_by_name(cfg->embed_model_name);
    }
    if (embed_cfg == NULL) {
        embed_cfg = cfg;
    }

    if (n_predict <= 0) {
        n_predict = cfg->n_predict > 0 ? cfg->n_predict : 128;
    }

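    /* Load the embedding model and an embeddings-enabled context for
     * retrieval only; both are freed again before the generation model is
     * loaded in execute_prompt_with_context(). */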
    struct llama_model_params embed_params = llama_model_default_params();
    embed_params.n_gpu_layers = embed_cfg->n_gpu_layers;
    embed_params.use_mmap = embed_cfg->use_mmap;
    struct llama_model *model = llama_model_load_from_file(embed_cfg->filepath, embed_params);
    if (model == NULL) {
        log_message(stderr, LOG_ERROR, "Unable to load embedding model");
        llama_backend_free();
        return 1;
    }

    struct llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx = embed_cfg->n_ctx;
    cparams.n_batch = embed_cfg->n_batch;
    cparams.embeddings = true;

    struct llama_context *embed_ctx = llama_init_from_model(model, cparams);
    if (embed_ctx == NULL) {
        log_message(stderr, LOG_ERROR, "Failed to create embedding context");
        llama_model_free(model);
        llama_backend_free();
        return 1;
    }

    VectorDB db = {0};
    vdb_init(&db, embed_ctx);
    VectorDBErrorCode vdb_rc = vdb_load(&db, context_file);
    if (vdb_rc != VDB_SUCCESS) {
        log_message(stderr, LOG_ERROR, "Failed to load vector database %s: %s", context_file, vdb_error(vdb_rc));
        llama_free(embed_ctx);
        llama_model_free(model);
        llama_backend_free();
        return 1;
    }

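    /* Embed the user prompt with the embedding model and fetch the five
     * nearest snippets from the vector database; unused result slots stay -1. */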
    float query[VDB_EMBED_SIZE];
    int results[5];
    for (int i = 0; i < 5; i++) {
        results[i] = -1;
    }

    vdb_embed_query(&db, prompt, query);
    vdb_search(&db, query, 5, results);

    size_t context_cap = 1024;
    size_t context_len = 0;
    char *context = (char *)malloc(context_cap);
    if (context == NULL) {
        log_message(stderr, LOG_ERROR, "Failed to allocate context buffer");
        llama_free(embed_ctx);
        llama_model_free(model);
        llama_backend_free();
        return 1;
    }
    context[0] = '\0';

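    /* Concatenate the retrieved documents into a "Snippet N:" labeled context
     * block, doubling the buffer as needed; this text becomes the Context
     * section of the final prompt. */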
    for (int i = 0; i < 5; i++) {
        if (results[i] < 0) {
            continue;
        }
        const char *text = db.docs[results[i]].text;
        char header[32];
        int header_len = snprintf(header, sizeof(header), "Snippet %d:\n", i + 1);
        size_t text_len = strlen(text);
        size_t need = context_len + (size_t)header_len + text_len + 2;
        if (need > context_cap) {
            while (need > context_cap) {
                context_cap *= 2;
            }
            char *next = (char *)realloc(context, context_cap);
            if (next == NULL) {
                log_message(stderr, LOG_ERROR, "Failed to grow context buffer");
                free(context);
                llama_free(embed_ctx);
                llama_model_free(model);
                llama_backend_free();
                return 1;
            }
            context = next;
        }
        if (header_len > 0) {
            memcpy(context + context_len, header, (size_t)header_len);
            context_len += (size_t)header_len;
        }
        memcpy(context + context_len, text, text_len);
        context_len += text_len;
        context[context_len++] = '\n';
        context[context_len] = '\0';
    }

    llama_free(embed_ctx);
    llama_model_free(model);

    int rc = execute_prompt_with_context(cfg, prompt, context, n_predict);
    free(context);
    llama_backend_free();
    return rc;
}