|
diff --git a/game.c b/game.c
|
|
|
1 |
#include <getopt.h> |
| 1 |
#include <stdio.h> |
2 |
#include <stdio.h> |
|
|
3 |
#include <stdlib.h> |
|
|
4 |
#include <string.h> |
|
|
5 |
#include <strings.h> |
| 2 |
|
6 |
|
| 3 |
#define TB_IMPL |
7 |
#define TB_IMPL |
| 4 |
#include "termbox2.h" |
8 |
#include "termbox2.h" |
| ... |
| 6 |
#define NONSTD_IMPLEMENTATION |
10 |
#define NONSTD_IMPLEMENTATION |
| 7 |
#include "nonstd.h" |
11 |
#include "nonstd.h" |
| 8 |
|
12 |
|
|
|
13 |
#include "llama.h" |
|
|
14 |
#include "models.h" |
|
|
15 |
#include "vectordb.h" |
| 9 |
#include "maps.h" |
16 |
#include "maps.h" |
| 10 |
|
17 |
|
| 11 |
#define MIN_W 40 |
18 |
#define MIN_W 40 |
| ... |
| 61 |
char input[128]; |
68 |
char input[128]; |
| 62 |
int input_len; |
69 |
int input_len; |
| 63 |
int npc_index; |
70 |
int npc_index; |
|
|
71 |
const char *npc_name; |
| 64 |
DialogEntry entries[DIALOG_HISTORY_MAX]; |
72 |
DialogEntry entries[DIALOG_HISTORY_MAX]; |
| 65 |
int entry_count; |
73 |
int entry_count; |
| 66 |
} Dialog; |
74 |
} Dialog; |
| 67 |
|
75 |
|
|
|
76 |
typedef struct { |
|
|
77 |
const ModelConfig *model_cfg; |
|
|
78 |
struct llama_model *model; |
|
|
79 |
struct llama_model *embed_model; |
|
|
80 |
struct llama_context *embed_ctx; |
|
|
81 |
VectorDB *npc_dbs; |
|
|
82 |
int *npc_db_loaded; |
|
|
83 |
int verbose; |
|
|
84 |
} GameRuntime; |
|
|
85 |
|
|
|
86 |
/*
 * Log callback that drops every message: none of its arguments are used.
 * Presumably registered with the llama logging hook so library output does
 * not disturb the terminal UI (registration is not visible here).
 */
static void llama_log_callback(enum ggml_log_level level, const char *text, void *user_data) {
    (void)user_data;
    (void)text;
    (void)level;
}
|
|
91 |
|
| 68 |
static int clamp(int value, int min, int max); |
92 |
static int clamp(int value, int min, int max); |
| 69 |
|
93 |
|
|
|
94 |
/*
 * Writes the command-line usage summary to stdout.
 *
 * prog: program name substituted into the "Usage:" line.
 */
static void show_help(const char *prog) {
    static const char options_text[] =
        "Options:\n"
        " -m, --model <name> Specify model to use (default: first model)\n"
        " -e, --embed-model <name> Specify model to use for embeddings\n"
        " -v, --verbose Enable verbose logging\n"
        " -h, --help Show this help message\n";

    printf("Usage: %s [OPTIONS]\n", prog);
    fputs(options_text, stdout);
}
|
|
102 |
|
| 70 |
static void draw_border(int x, int y, int w, int h, uintattr_t fg) { |
103 |
static void draw_border(int x, int y, int w, int h, uintattr_t fg) { |
| 71 |
int ix; |
104 |
int ix; |
| 72 |
int iy; |
105 |
int iy; |
| ... |
| 87 |
} |
120 |
} |
| 88 |
|
121 |
|
| 89 |
static void draw_border_bg(int x, int y, int w, int h, uintattr_t fg, |
122 |
static void draw_border_bg(int x, int y, int w, int h, uintattr_t fg, |
| 90 |
uintattr_t bg) { |
123 |
uintattr_t bg) { |
| 91 |
int ix; |
124 |
int ix; |
| 92 |
int iy; |
125 |
int iy; |
| 93 |
|
126 |
|
| ... |
| 107 |
} |
140 |
} |
| 108 |
|
141 |
|
| 109 |
static void get_layout(int w, int h, int *map_x, int *map_y, int *map_w, |
142 |
static void get_layout(int w, int h, int *map_x, int *map_y, int *map_w, |
| 110 |
int *map_h, int *side_x, int *side_y, int *side_w, int *side_h, |
143 |
int *map_h, int *side_x, int *side_y, int *side_w, int *side_h, |
| 111 |
int *msg1_y, int *msg2_y) { |
144 |
int *msg1_y, int *msg2_y) { |
| 112 |
*map_x = 0; |
145 |
*map_x = 0; |
| 113 |
*map_y = 0; |
146 |
*map_y = 0; |
| 114 |
*map_w = w - SIDEBAR_W; |
147 |
*map_w = w - SIDEBAR_W; |
| ... |
| 226 |
} |
259 |
} |
| 227 |
|
260 |
|
| 228 |
static void update_camera(const Map *map, int view_w, int view_h, |
261 |
static void update_camera(const Map *map, int view_w, int view_h, |
| 229 |
const Player *player, int *cam_x, int *cam_y) { |
262 |
const Player *player, int *cam_x, int *cam_y) { |
| 230 |
int max_cam_x; |
263 |
int max_cam_x; |
| 231 |
int max_cam_y; |
264 |
int max_cam_y; |
| 232 |
int margin_x; |
265 |
int margin_x; |
| ... |
| 267 |
} |
300 |
} |
| 268 |
|
301 |
|
| 269 |
static void draw_map(const Map *map, int map_x, int map_y, int view_w, |
302 |
static void draw_map(const Map *map, int map_x, int map_y, int view_w, |
| 270 |
int view_h, const Player *player, int cam_x, int cam_y) { |
303 |
int view_h, const Player *player, int cam_x, int cam_y) { |
| 271 |
int ix; |
304 |
int ix; |
| 272 |
int iy; |
305 |
int iy; |
| 273 |
|
306 |
|
| ... |
| 296 |
} |
329 |
} |
| 297 |
|
330 |
|
| 298 |
if (player->x >= cam_x && player->x < cam_x + view_w && player->y >= cam_y |
331 |
if (player->x >= cam_x && player->x < cam_x + view_w && player->y >= cam_y |
| 299 |
&& player->y < cam_y + view_h) { |
332 |
&& player->y < cam_y + view_h) { |
| 300 |
int sx = map_x + (player->x - cam_x); |
333 |
int sx = map_x + (player->x - cam_x); |
| 301 |
int sy = map_y + (player->y - cam_y); |
334 |
int sy = map_y + (player->y - cam_y); |
| 302 |
tb_set_cell(sx, sy, '@', COLOR_GREEN_256 | TB_BOLD, TB_DEFAULT); |
335 |
tb_set_cell(sx, sy, '@', COLOR_GREEN_256 | TB_BOLD, TB_DEFAULT); |
| ... |
| 324 |
filled = (inner_w * value) / max; |
357 |
filled = (inner_w * value) / max; |
| 325 |
tb_set_cell(x, y, '[', COLOR_WHITE_256, TB_DEFAULT); |
358 |
tb_set_cell(x, y, '[', COLOR_WHITE_256, TB_DEFAULT); |
| 326 |
for (ix = 0; ix < inner_w; ix++) { |
359 |
for (ix = 0; ix < inner_w; ix++) { |
| 327 |
uintattr_t fg = ix < filled ? COLOR_GREEN_256 : COLOR_WHITE_256; |
360 |
uintattr_t fg = ix < filled ? COLOR_GREEN_256 : COLOR_WHITE_256; |
| 328 |
uint32_t ch = ix < filled ? '=' : ' '; |
361 |
uint32_t ch = ix < filled ? '=' : ' '; |
| 329 |
tb_set_cell(x + 1 + ix, y, ch, fg, TB_DEFAULT); |
362 |
tb_set_cell(x + 1 + ix, y, ch, fg, TB_DEFAULT); |
| 330 |
} |
363 |
} |
| ... |
| 389 |
status_msg = message ? message : ""; |
422 |
status_msg = message ? message : ""; |
| 390 |
} |
423 |
} |
| 391 |
|
424 |
|
| 392 |
static void copy_truncated(char *dst, size_t dst_size, const char *src, int max_chars) { |
425 |
static int draw_wrapped(int x, int y, int max_lines, int box_w, uintattr_t fg, |
| 393 |
int i = 0; |
426 |
uintattr_t bg, const char *prefix, const char *text) { |
| 394 |
if (dst_size == 0) { |
427 |
if (max_lines <= 0 || box_w <= 0 || text == NULL) { |
| 395 |
return; |
428 |
return 0; |
|
|
429 |
} |
|
|
430 |
int lines = 0; |
|
|
431 |
int prefix_len = prefix ? (int)strlen(prefix) : 0; |
|
|
432 |
if (prefix_len < 0) { |
|
|
433 |
prefix_len = 0; |
|
|
434 |
} |
|
|
435 |
int avail = box_w - 4 - prefix_len; |
|
|
436 |
if (avail < 1) { |
|
|
437 |
return 0; |
|
|
438 |
} |
|
|
439 |
char pad[64]; |
|
|
440 |
int pad_len = prefix_len < (int)sizeof(pad) - 1 ? prefix_len : (int)sizeof(pad) - 1; |
|
|
441 |
for (int i = 0; i < pad_len; i++) { |
|
|
442 |
pad[i] = ' '; |
|
|
443 |
} |
|
|
444 |
pad[pad_len] = '\0'; |
|
|
445 |
const char *p = text; |
|
|
446 |
while (*p != '\0' && lines < max_lines) { |
|
|
447 |
while (*p == ' ') { |
|
|
448 |
p++; |
|
|
449 |
} |
|
|
450 |
int line_len = 0; |
|
|
451 |
int last_space = -1; |
|
|
452 |
for (int i = 0; i < avail && p[i] != '\0'; i++) { |
|
|
453 |
if (p[i] == '\n') { |
|
|
454 |
line_len = i; |
|
|
455 |
break; |
|
|
456 |
} |
|
|
457 |
if (p[i] == ' ') { |
|
|
458 |
last_space = i; |
|
|
459 |
} |
|
|
460 |
line_len = i + 1; |
|
|
461 |
} |
|
|
462 |
if (line_len == 0) { |
|
|
463 |
break; |
|
|
464 |
} |
|
|
465 |
int cut = line_len; |
|
|
466 |
if (cut == avail && p[cut] != '\0' && last_space > 0) { |
|
|
467 |
cut = last_space; |
|
|
468 |
} |
|
|
469 |
char buf[512]; |
|
|
470 |
int copy_len = cut < (int)sizeof(buf) - 1 ? cut : (int)sizeof(buf) - 1; |
|
|
471 |
memcpy(buf, p, (size_t)copy_len); |
|
|
472 |
buf[copy_len] = '\0'; |
|
|
473 |
while (copy_len > 0 && buf[copy_len - 1] == ' ') { |
|
|
474 |
buf[copy_len - 1] = '\0'; |
|
|
475 |
copy_len--; |
|
|
476 |
} |
|
|
477 |
const char *line_prefix = (lines == 0) ? (prefix ? prefix : "") : pad; |
|
|
478 |
tb_printf(x, y + lines, fg, bg, "%s%s", line_prefix, buf); |
|
|
479 |
lines++; |
|
|
480 |
p += cut; |
|
|
481 |
if (*p == '\n') { |
|
|
482 |
p++; |
|
|
483 |
} |
|
|
484 |
} |
|
|
485 |
return lines; |
|
|
486 |
} |
|
|
487 |
|
|
|
488 |
static int count_wrapped_lines(int box_w, const char *prefix, const char *text) { |
|
|
489 |
if (box_w <= 0 || text == NULL) { |
|
|
490 |
return 0; |
| 396 |
} |
491 |
} |
| 397 |
if (max_chars < 0) { |
492 |
int prefix_len = prefix ? (int)strlen(prefix) : 0; |
| 398 |
max_chars = 0; |
493 |
if (prefix_len < 0) { |
|
|
494 |
prefix_len = 0; |
| 399 |
} |
495 |
} |
| 400 |
while (i < max_chars && src[i] != '\0' && i < (int)dst_size - 1) { |
496 |
int avail = box_w - 4 - prefix_len; |
| 401 |
dst[i] = src[i]; |
497 |
if (avail < 1) { |
| 402 |
i++; |
498 |
return 0; |
| 403 |
} |
499 |
} |
| 404 |
dst[i] = '\0'; |
500 |
int lines = 0; |
|
|
501 |
const char *p = text; |
|
|
502 |
while (*p != '\0') { |
|
|
503 |
while (*p == ' ') { |
|
|
504 |
p++; |
|
|
505 |
} |
|
|
506 |
int line_len = 0; |
|
|
507 |
int last_space = -1; |
|
|
508 |
for (int i = 0; i < avail && p[i] != '\0'; i++) { |
|
|
509 |
if (p[i] == '\n') { |
|
|
510 |
line_len = i; |
|
|
511 |
break; |
|
|
512 |
} |
|
|
513 |
if (p[i] == ' ') { |
|
|
514 |
last_space = i; |
|
|
515 |
} |
|
|
516 |
line_len = i + 1; |
|
|
517 |
} |
|
|
518 |
if (line_len == 0) { |
|
|
519 |
break; |
|
|
520 |
} |
|
|
521 |
int cut = line_len; |
|
|
522 |
if (cut == avail && p[cut] != '\0' && last_space > 0) { |
|
|
523 |
cut = last_space; |
|
|
524 |
} |
|
|
525 |
lines++; |
|
|
526 |
p += cut; |
|
|
527 |
if (*p == '\n') { |
|
|
528 |
p++; |
|
|
529 |
} |
|
|
530 |
} |
|
|
531 |
return lines; |
| 405 |
} |
532 |
} |
| 406 |
|
533 |
|
| 407 |
/*
 * Opens the dialog for the given NPC and clears the pending input line.
 *
 * npc_name is stored by pointer, not copied, so it must stay valid while the
 * dialog is open. The entry history is not touched here.
 */
static void dialog_open(Dialog *dialog, int npc_index, const char *npc_name) {
    dialog->npc_name = npc_name;
    dialog->npc_index = npc_index;
    dialog->input[0] = '\0';
    dialog->input_len = 0;
    dialog->open = 1;
}
| 413 |
|
541 |
|
| 414 |
/*
 * Closes the dialog and detaches it from its NPC (index -1, no name).
 * The input buffer and entry history are left as they are.
 */
static void dialog_close(Dialog *dialog) {
    dialog->npc_name = NULL;
    dialog->npc_index = -1;
    dialog->open = 0;
}
| 418 |
|
547 |
|
| 419 |
static void dialog_append(Dialog *dialog, uint32_t ch) { |
548 |
static void dialog_append(Dialog *dialog, uint32_t ch) { |
| ... |
| 435 |
dialog->input[dialog->input_len] = '\0'; |
564 |
dialog->input[dialog->input_len] = '\0'; |
| 436 |
} |
565 |
} |
| 437 |
|
566 |
|
| 438 |
static void dialog_submit(Dialog *dialog, const GameMap *game_map) { |
567 |
/*
 * Advances *text past any leading run of spaces, tabs, newlines and
 * carriage returns. The string itself is not modified.
 */
static void trim_leading(char **text) {
    *text += strspn(*text, " \t\n\r");
}
|
|
572 |
|
|
|
573 |
/*
 * Advances *text past leading quote characters (", ' and `). After each
 * quote, any whitespace (space, tab, newline, carriage return) that follows
 * it is skipped as well. Whitespace before the first quote is not skipped.
 */
static void trim_leading_punct(char **text) {
    char *cursor = *text;
    while (*cursor != '\0' && strchr("\"'`", *cursor) != NULL) {
        cursor++;
        cursor += strspn(cursor, " \t\n\r");
    }
    *text = cursor;
}
|
|
579 |
|
|
|
580 |
/*
 * Removes trailing spaces, tabs, newlines and carriage returns from `text`
 * in place by overwriting each of them with a terminator.
 */
static void trim_trailing(char *text) {
    for (size_t end = strlen(text); end > 0; end--) {
        switch (text[end - 1]) {
        case ' ':
        case '\t':
        case '\n':
        case '\r':
            text[end - 1] = '\0';
            continue; /* next iteration of the for loop */
        default:
            return;
        }
    }
}
|
|
591 |
|
|
|
592 |
/*
 * If *text begins with `prefix` (compared case-insensitively), advances
 * *text past the prefix and past any whitespace (space, tab, newline,
 * carriage return) that follows it. Otherwise *text is left unchanged.
 */
static void strip_any_prefix(char **text, const char *prefix) {
    const size_t prefix_len = strlen(prefix);
    if (strncasecmp(*text, prefix, prefix_len) != 0) {
        return;
    }
    char *rest = *text + prefix_len;
    rest += strspn(rest, " \t\n\r");
    *text = rest;
}
|
|
598 |
|
|
|
599 |
|
|
|
600 |
/*
 * Cleans up raw model output in place so it can be shown as NPC speech.
 *
 * Steps, in order:
 *   1. skip leading whitespace and quote characters;
 *   2. strip one leading "Answer:", "NPC:", "Context:" and "System:" label
 *      (case-insensitive, checked in that order) and a leading "<context>";
 *   3. cut the text at the first "<system", "<|", "</s>" or "###"
 *      ("<system" also covers "<system-reminder>");
 *   4. strip leading speaker labels made of the NPC's name followed by
 *      ':', '-' or ',' (e.g. "Bob: ..."), repeatedly;
 *   5. move the remaining text to the start of the buffer and trim trailing
 *      whitespace.
 *
 * The name is only removed when one of those separators follows it directly.
 * That keeps replies that merely start with the same letters ("Al" vs
 * "Also ...") and replies that consist of the name itself ("Bob").
 *
 * reply: NUL-terminated, writable buffer; may be NULL.
 * name:  NPC name used for step 4; NULL or "" disables that step.
 * Returns `reply` (possibly now an empty string), or NULL if reply was NULL.
 */
static char *sanitize_reply(char *reply, const char *name) {
    static const char *const echoed_labels[] = { "Answer:", "NPC:", "Context:", "System:" };
    static const char *const cut_markers[] = { "<system", "<|", "</s>", "###" };

    if (reply == NULL) {
        return NULL;
    }

    char *start = reply;
    trim_leading(&start);
    trim_leading_punct(&start);

    for (size_t i = 0; i < sizeof(echoed_labels) / sizeof(echoed_labels[0]); i++) {
        strip_any_prefix(&start, echoed_labels[i]);
    }
    if (strncmp(start, "<context>", 9) == 0) {
        start += 9;
        trim_leading(&start);
    }

    /* Each marker truncates at its first occurrence; the earliest one wins. */
    for (size_t i = 0; i < sizeof(cut_markers) / sizeof(cut_markers[0]); i++) {
        char *hit = strstr(start, cut_markers[i]);
        if (hit != NULL) {
            *hit = '\0';
        }
    }

    if (name != NULL && name[0] != '\0') {
        const size_t name_len = strlen(name);
        for (;;) {
            if (strncasecmp(start, name, name_len) != 0) {
                break;
            }
            /* A match guarantees start[name_len] is inside the string. */
            const char after = start[name_len];
            if (after != ':' && after != '-' && after != ',') {
                break;
            }
            start += name_len;
            while (*start == ':' || *start == '-' || *start == ',') {
                start++;
            }
            trim_leading(&start);
            trim_leading_punct(&start);
        }
    }

    if (start != reply) {
        memmove(reply, start, strlen(start) + 1);
    }
    trim_trailing(reply);
    return reply;
}
|
|
655 |
|
|
|
656 |
/*
 * Finds the first occurrence of the NUL-terminated `needle` within the first
 * `n` bytes of `buf` (which need not be NUL-terminated).
 *
 * Returns the byte offset of the match, or -1 when there is none, when the
 * needle is empty, when n <= 0, or when the needle is longer than n.
 */
static int find_substr_offset(const char *buf, int n, const char *needle) {
    const int needle_len = (int)strlen(needle);
    if (needle_len <= 0 || n <= 0 || needle_len > n) {
        return -1;
    }

    const int last_start = n - needle_len;
    for (int start = 0; start <= last_start; start++) {
        if (memcmp(buf + start, needle, (size_t)needle_len) == 0) {
            return start;
        }
    }
    return -1;
}
|
|
675 |
|
|
|
676 |
/*
 * Returns the offset within the first `n` bytes of `buf` at which generated
 * text should be cut: the smallest offset of the first '\n' or of any stop
 * marker listed below. Returns `n` when none of them occurs.
 */
static int find_stop_offset(const char *buf, int n) {
    static const char *const stop_markers[] = {
        "</s>", "<system-reminder>", "<system", "<|",
        "###",  "System:",           "User:",   "Assistant:",
    };

    int stop_at = n;
    if (n > 0) {
        const char *newline = memchr(buf, '\n', (size_t)n);
        if (newline != NULL) {
            stop_at = (int)(newline - buf);
        }
    }

    for (size_t k = 0; k < sizeof(stop_markers) / sizeof(stop_markers[0]); k++) {
        const int off = find_substr_offset(buf, n, stop_markers[k]);
        if (off >= 0 && off < stop_at) {
            stop_at = off;
        }
    }
    return stop_at;
}
|
|
718 |
|
|
|
719 |
/*
 * Appends the shared "Context / Question" section of a prompt to `sb`:
 *
 *   Context:
 *   NPC Name: <npc_name>      (only when npc_name is non-empty)
 *   <context>                 (only when context is non-empty)
 *
 *   Question:
 *   <question>                (empty when question is NULL)
 */
static void append_prompt_context(stringb *sb, const char *npc_name, const char *context,
                                  const char *question) {
    const int has_name = npc_name != NULL && npc_name[0] != '\0';
    const int has_context = context != NULL && context[0] != '\0';
    const char *question_text = question != NULL ? question : "";

    sb_append_cstr(sb, "Context:\n");
    if (has_name) {
        sb_append_cstr(sb, "NPC Name: ");
        sb_append_cstr(sb, npc_name);
        sb_append_cstr(sb, "\n");
    }
    if (has_context) {
        sb_append_cstr(sb, context);
    }
    sb_append_cstr(sb, "\nQuestion:\n");
    sb_append_cstr(sb, question_text);
}
|
|
733 |
|
|
|
734 |
static char *build_prompt(const ModelConfig *cfg, const char *system, const char *npc_name, |
|
|
735 |
const char *context, const char *question) { |
|
|
736 |
stringb full = {0}; |
|
|
737 |
sb_init(&full, 0); |
|
|
738 |
|
|
|
739 |
switch (cfg->prompt_style) { |
|
|
740 |
case PROMPT_STYLE_T5: |
|
|
741 |
sb_append_cstr(&full, "instruction: "); |
|
|
742 |
sb_append_cstr(&full, system ? system : ""); |
|
|
743 |
sb_append_cstr(&full, "\nquestion: "); |
|
|
744 |
sb_append_cstr(&full, question ? question : ""); |
|
|
745 |
sb_append_cstr(&full, "\ncontext:\n"); |
|
|
746 |
if (npc_name && npc_name[0] != '\0') { |
|
|
747 |
sb_append_cstr(&full, "NPC Name: "); |
|
|
748 |
sb_append_cstr(&full, npc_name); |
|
|
749 |
sb_append_cstr(&full, "\n"); |
|
|
750 |
} |
|
|
751 |
if (context && context[0] != '\0') { |
|
|
752 |
sb_append_cstr(&full, context); |
|
|
753 |
} |
|
|
754 |
sb_append_cstr(&full, "\nanswer:"); |
|
|
755 |
break; |
|
|
756 |
case PROMPT_STYLE_CHAT: |
|
|
757 |
sb_append_cstr(&full, "System:\n"); |
|
|
758 |
sb_append_cstr(&full, system ? system : ""); |
|
|
759 |
sb_append_cstr(&full, "\nUser:\n"); |
|
|
760 |
append_prompt_context(&full, npc_name, context, question); |
|
|
761 |
sb_append_cstr(&full, "\nAssistant:"); |
|
|
762 |
break; |
|
|
763 |
case PROMPT_STYLE_PLAIN: |
|
|
764 |
default: |
|
|
765 |
sb_append_cstr(&full, "System:\n"); |
|
|
766 |
sb_append_cstr(&full, system ? system : ""); |
|
|
767 |
sb_append_cstr(&full, "\n"); |
|
|
768 |
append_prompt_context(&full, npc_name, context, question); |
|
|
769 |
sb_append_cstr(&full, "\nAnswer:"); |
|
|
770 |
break; |
|
|
771 |
} |
|
|
772 |
|
|
|
773 |
return full.data; |
|
|
774 |
} |
|
|
775 |
|
|
|
776 |
static char *generate_npc_reply(const GameRuntime *runtime, const GameMap *game_map, |
|
|
777 |
int npc_index, const char *prompt) { |
|
|
778 |
if (runtime == NULL || prompt == NULL) { |
|
|
779 |
return NULL; |
|
|
780 |
} |
|
|
781 |
const char *fallback = "Demo reply: The old ruins are north of here."; |
|
|
782 |
const char *npc_name = NULL; |
|
|
783 |
if (game_map && npc_index >= 0 && npc_index < 10) { |
|
|
784 |
const char *npc_reply = game_map->npcs[npc_index].reply; |
|
|
785 |
npc_name = game_map->npcs[npc_index].name; |
|
|
786 |
if (npc_reply && npc_reply[0] != '\0') { |
|
|
787 |
fallback = npc_reply; |
|
|
788 |
} |
|
|
789 |
} |
|
|
790 |
|
|
|
791 |
if (runtime->model == NULL || runtime->model_cfg == NULL || runtime->embed_ctx == NULL |
|
|
792 |
|| runtime->npc_dbs == NULL || runtime->npc_db_loaded == NULL) { |
|
|
793 |
return strdup(fallback); |
|
|
794 |
} |
|
|
795 |
if (npc_index < 0 || npc_index >= 10 || runtime->npc_db_loaded[npc_index] == 0) { |
|
|
796 |
return strdup(fallback); |
|
|
797 |
} |
|
|
798 |
|
|
|
799 |
VectorDB *db = &runtime->npc_dbs[npc_index]; |
|
|
800 |
float query[VDB_EMBED_SIZE]; |
|
|
801 |
int results[5]; |
|
|
802 |
for (int i = 0; i < 5; i++) { |
|
|
803 |
results[i] = -1; |
|
|
804 |
} |
|
|
805 |
vdb_embed_query(db, prompt, query); |
|
|
806 |
vdb_search(db, query, 5, results); |
|
|
807 |
|
|
|
808 |
size_t context_cap = 1024; |
|
|
809 |
size_t context_len = 0; |
|
|
810 |
char *context = (char *)malloc(context_cap); |
|
|
811 |
if (context == NULL) { |
|
|
812 |
return strdup(fallback); |
|
|
813 |
} |
|
|
814 |
context[0] = '\0'; |
|
|
815 |
if (runtime->verbose) { |
|
|
816 |
fprintf(stderr, "[npc] question: %s\n", prompt); |
|
|
817 |
} |
|
|
818 |
for (int i = 0; i < 5; i++) { |
|
|
819 |
if (results[i] < 0) { |
|
|
820 |
continue; |
|
|
821 |
} |
|
|
822 |
const char *text = db->docs[results[i]].text; |
|
|
823 |
if (runtime->verbose) { |
|
|
824 |
fprintf(stderr, "[npc] context[%d]: %s\n", i, text); |
|
|
825 |
} |
|
|
826 |
char header[32]; |
|
|
827 |
int header_len = snprintf(header, sizeof(header), "Snippet %d:\n", i + 1); |
|
|
828 |
size_t text_len = strlen(text); |
|
|
829 |
size_t need = context_len + (size_t)header_len + text_len + 2; |
|
|
830 |
if (need > context_cap) { |
|
|
831 |
while (need > context_cap) { |
|
|
832 |
context_cap *= 2; |
|
|
833 |
} |
|
|
834 |
char *next = (char *)realloc(context, context_cap); |
|
|
835 |
if (next == NULL) { |
|
|
836 |
free(context); |
|
|
837 |
return strdup(fallback); |
|
|
838 |
} |
|
|
839 |
context = next; |
|
|
840 |
} |
|
|
841 |
if (header_len > 0) { |
|
|
842 |
memcpy(context + context_len, header, (size_t)header_len); |
|
|
843 |
context_len += (size_t)header_len; |
|
|
844 |
} |
|
|
845 |
memcpy(context + context_len, text, text_len); |
|
|
846 |
context_len += text_len; |
|
|
847 |
context[context_len++] = '\n'; |
|
|
848 |
context[context_len] = '\0'; |
|
|
849 |
} |
|
|
850 |
|
|
|
851 |
const char *system_prompt = "You are a helpful NPC. Speak in first person. " |
|
|
852 |
"Use only the provided context. If the context does not contain the answer, say \"I don't know.\" " |
|
|
853 |
"If asked your name, answer with the NPC Name from the context. " |
|
|
854 |
"Do not mention context, system messages, or prompts. Reply with one short sentence."; |
|
|
855 |
|
|
|
856 |
char *full_prompt = build_prompt(runtime->model_cfg, system_prompt, npc_name, context, prompt); |
|
|
857 |
if (full_prompt == NULL) { |
|
|
858 |
free(context); |
|
|
859 |
return strdup(fallback); |
|
|
860 |
} |
|
|
861 |
free(context); |
|
|
862 |
|
|
|
863 |
if (runtime->verbose) { |
|
|
864 |
printf(">> %s\n", full_prompt); |
|
|
865 |
} |
|
|
866 |
|
|
|
867 |
const struct llama_vocab *vocab = llama_model_get_vocab(runtime->model); |
|
|
868 |
int n_prompt = -llama_tokenize(vocab, full_prompt, strlen(full_prompt), NULL, 0, true, true); |
|
|
869 |
llama_token *prompt_tokens = (llama_token *)malloc((size_t)n_prompt * sizeof(llama_token)); |
|
|
870 |
if (prompt_tokens == NULL) { |
|
|
871 |
free(full_prompt); |
|
|
872 |
return strdup(fallback); |
|
|
873 |
} |
|
|
874 |
if (llama_tokenize(vocab, full_prompt, strlen(full_prompt), prompt_tokens, n_prompt, true, true) < 0) { |
|
|
875 |
free(full_prompt); |
|
|
876 |
free(prompt_tokens); |
|
|
877 |
return strdup(fallback); |
|
|
878 |
} |
|
|
879 |
|
|
|
880 |
struct llama_context_params ctx_params = llama_context_default_params(); |
|
|
881 |
ctx_params.n_ctx = runtime->model_cfg->n_ctx; |
|
|
882 |
ctx_params.n_batch = runtime->model_cfg->n_batch; |
|
|
883 |
ctx_params.embeddings = false; |
|
|
884 |
|
|
|
885 |
struct llama_context *ctx = llama_init_from_model(runtime->model, ctx_params); |
|
|
886 |
if (ctx == NULL) { |
|
|
887 |
free(full_prompt); |
|
|
888 |
free(prompt_tokens); |
|
|
889 |
return strdup(fallback); |
|
|
890 |
} |
|
|
891 |
|
|
|
892 |
struct llama_sampler_chain_params sparams = llama_sampler_chain_default_params(); |
|
|
893 |
struct llama_sampler *smpl = llama_sampler_chain_init(sparams); |
|
|
894 |
if (runtime->model_cfg->top_k > 0) { |
|
|
895 |
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(runtime->model_cfg->top_k)); |
|
|
896 |
} |
|
|
897 |
if (runtime->model_cfg->top_p > 0.0f && runtime->model_cfg->top_p < 1.0f) { |
|
|
898 |
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(runtime->model_cfg->top_p, 1)); |
|
|
899 |
} |
|
|
900 |
if (runtime->model_cfg->min_p > 0.0f) { |
|
|
901 |
llama_sampler_chain_add(smpl, llama_sampler_init_min_p(runtime->model_cfg->min_p, 1)); |
|
|
902 |
} |
|
|
903 |
llama_sampler_chain_add(smpl, llama_sampler_init_penalties( |
|
|
904 |
runtime->model_cfg->repeat_last_n, |
|
|
905 |
runtime->model_cfg->repeat_penalty, |
|
|
906 |
runtime->model_cfg->freq_penalty, |
|
|
907 |
runtime->model_cfg->presence_penalty)); |
|
|
908 |
llama_sampler_chain_add(smpl, llama_sampler_init_temp(runtime->model_cfg->temperature)); |
|
|
909 |
llama_sampler_chain_add(smpl, llama_sampler_init_dist(runtime->model_cfg->seed)); |
|
|
910 |
|
|
|
911 |
struct llama_batch batch = llama_batch_get_one(prompt_tokens, n_prompt); |
|
|
912 |
|
|
|
913 |
if (llama_model_has_encoder(runtime->model)) { |
|
|
914 |
if (llama_encode(ctx, batch)) { |
|
|
915 |
llama_sampler_free(smpl); |
|
|
916 |
free(full_prompt); |
|
|
917 |
free(prompt_tokens); |
|
|
918 |
llama_free(ctx); |
|
|
919 |
return strdup(fallback); |
|
|
920 |
} |
|
|
921 |
llama_token decoder_start = llama_model_decoder_start_token(runtime->model); |
|
|
922 |
if (decoder_start == LLAMA_TOKEN_NULL) { |
|
|
923 |
decoder_start = llama_vocab_bos(vocab); |
|
|
924 |
} |
|
|
925 |
batch = llama_batch_get_one(&decoder_start, 1); |
|
|
926 |
} |
|
|
927 |
|
|
|
928 |
int n_pos = 0; |
|
|
929 |
llama_token new_token_id; |
|
|
930 |
size_t out_cap = 256; |
|
|
931 |
size_t out_len = 0; |
|
|
932 |
char *out = (char *)malloc(out_cap); |
|
|
933 |
if (out == NULL) { |
|
|
934 |
llama_sampler_free(smpl); |
|
|
935 |
free(full_prompt); |
|
|
936 |
free(prompt_tokens); |
|
|
937 |
llama_free(ctx); |
|
|
938 |
return strdup(fallback); |
|
|
939 |
} |
|
|
940 |
out[0] = '\0'; |
|
|
941 |
int n_predict = runtime->model_cfg->n_predict > 0 ? runtime->model_cfg->n_predict : 64; |
|
|
942 |
if (n_predict > 64) { |
|
|
943 |
n_predict = 64; |
|
|
944 |
} |
|
|
945 |
while (n_pos + batch.n_tokens < n_prompt + n_predict) { |
|
|
946 |
if (llama_decode(ctx, batch)) { |
|
|
947 |
break; |
|
|
948 |
} |
|
|
949 |
n_pos += batch.n_tokens; |
|
|
950 |
new_token_id = llama_sampler_sample(smpl, ctx, -1); |
|
|
951 |
if (llama_vocab_is_eog(vocab, new_token_id)) { |
|
|
952 |
break; |
|
|
953 |
} |
|
|
954 |
char buf[128]; |
|
|
955 |
int n = llama_token_to_piece(vocab, new_token_id, buf, sizeof(buf), 0, true); |
|
|
956 |
if (n < 0) { |
|
|
957 |
break; |
|
|
958 |
} |
|
|
959 |
int stop_at = find_stop_offset(buf, n); |
|
|
960 |
if (out_len == 0 && stop_at == 0 && n > 0 && buf[0] == '\n') { |
|
|
961 |
batch = llama_batch_get_one(&new_token_id, 1); |
|
|
962 |
continue; |
|
|
963 |
} |
|
|
964 |
if (out_len + (size_t)stop_at + 1 > out_cap) { |
|
|
965 |
while (out_len + (size_t)stop_at + 1 > out_cap) { |
|
|
966 |
out_cap *= 2; |
|
|
967 |
} |
|
|
968 |
char *next = (char *)realloc(out, out_cap); |
|
|
969 |
if (next == NULL) { |
|
|
970 |
break; |
|
|
971 |
} |
|
|
972 |
out = next; |
|
|
973 |
} |
|
|
974 |
memcpy(out + out_len, buf, (size_t)stop_at); |
|
|
975 |
out_len += (size_t)stop_at; |
|
|
976 |
out[out_len] = '\0'; |
|
|
977 |
if (stop_at != n) { |
|
|
978 |
break; |
|
|
979 |
} |
|
|
980 |
batch = llama_batch_get_one(&new_token_id, 1); |
|
|
981 |
} |
|
|
982 |
|
|
|
983 |
llama_sampler_free(smpl); |
|
|
984 |
free(full_prompt); |
|
|
985 |
free(prompt_tokens); |
|
|
986 |
llama_free(ctx); |
|
|
987 |
|
|
|
988 |
if (out_len == 0) { |
|
|
989 |
free(out); |
|
|
990 |
return strdup(fallback); |
|
|
991 |
} |
|
|
992 |
return out; |
|
|
993 |
} |
|
|
994 |
|
|
|
995 |
static void dialog_submit(Dialog *dialog, const GameMap *game_map, const GameRuntime *runtime) { |
| 439 |
if (dialog->input_len == 0) { |
996 |
if (dialog->input_len == 0) { |
| 440 |
return; |
997 |
return; |
| 441 |
} |
998 |
} |
| 442 |
{ |
999 |
{ |
| 443 |
const char *demo = "Demo reply: The old ruins are north of here."; |
1000 |
const char *npc_name = NULL; |
| 444 |
const char *reply = demo; |
1001 |
char *reply = generate_npc_reply(runtime, game_map, dialog->npc_index, dialog->input); |
|
|
1002 |
const char *fallback = ""; |
| 445 |
if (game_map && dialog->npc_index >= 0 && dialog->npc_index < 10) { |
1003 |
if (game_map && dialog->npc_index >= 0 && dialog->npc_index < 10) { |
| 446 |
const char *npc_reply = game_map->npcs[dialog->npc_index].reply; |
1004 |
npc_name = game_map->npcs[dialog->npc_index].name; |
| 447 |
if (npc_reply && npc_reply[0] != '\0') { |
1005 |
fallback = game_map->npcs[dialog->npc_index].reply; |
| 448 |
reply = npc_reply; |
1006 |
if (fallback == NULL) { |
|
|
1007 |
fallback = ""; |
| 449 |
} |
1008 |
} |
| 450 |
} |
1009 |
} |
|
|
1010 |
reply = sanitize_reply(reply, npc_name); |
|
|
1011 |
if (reply == NULL || reply[0] == '\0') { |
|
|
1012 |
free(reply); |
|
|
1013 |
reply = NULL; |
|
|
1014 |
} |
|
|
1015 |
const char *reply_text = reply != NULL ? reply : fallback; |
| 451 |
if (dialog->entry_count >= DIALOG_HISTORY_MAX) { |
1016 |
if (dialog->entry_count >= DIALOG_HISTORY_MAX) { |
| 452 |
for (int i = 1; i < DIALOG_HISTORY_MAX; i++) { |
1017 |
for (int i = 1; i < DIALOG_HISTORY_MAX; i++) { |
| 453 |
dialog->entries[i - 1] = dialog->entries[i]; |
1018 |
dialog->entries[i - 1] = dialog->entries[i]; |
| ... |
| 457 |
snprintf(dialog->entries[dialog->entry_count].prompt, |
1022 |
snprintf(dialog->entries[dialog->entry_count].prompt, |
| 458 |
sizeof(dialog->entries[dialog->entry_count].prompt), "%s", dialog->input); |
1023 |
sizeof(dialog->entries[dialog->entry_count].prompt), "%s", dialog->input); |
| 459 |
snprintf(dialog->entries[dialog->entry_count].response, |
1024 |
snprintf(dialog->entries[dialog->entry_count].response, |
| 460 |
sizeof(dialog->entries[dialog->entry_count].response), "%s", reply); |
1025 |
sizeof(dialog->entries[dialog->entry_count].response), "%s", reply_text); |
| 461 |
dialog->entry_count++; |
1026 |
dialog->entry_count++; |
|
|
1027 |
free(reply); |
| 462 |
} |
1028 |
} |
| 463 |
dialog->input_len = 0; |
1029 |
dialog->input_len = 0; |
| 464 |
dialog->input[0] = '\0'; |
1030 |
dialog->input[0] = '\0'; |
| ... |
| 479 |
} |
1045 |
} |
| 480 |
|
1046 |
|
| 481 |
static void render(const Map *map, const Player *player, int *cam_x, |
1047 |
static void render(const Map *map, const Player *player, int *cam_x, |
| 482 |
int *cam_y, int *out_view_w, int *out_view_h, const Dialog *dialog) { |
1048 |
int *cam_y, int *out_view_w, int *out_view_h, const Dialog *dialog) { |
| 483 |
int w; |
1049 |
int w; |
| 484 |
int h; |
1050 |
int h; |
| 485 |
int map_x; |
1051 |
int map_x; |
| ... |
| 594 |
if (max_text < 0) { |
1160 |
if (max_text < 0) { |
| 595 |
max_text = 0; |
1161 |
max_text = 0; |
| 596 |
} |
1162 |
} |
| 597 |
int max_entries = max_lines / 2; |
1163 |
int start = dialog->entry_count; |
| 598 |
int start = dialog->entry_count - max_entries; |
|
|
| 599 |
if (start < 0) { |
1164 |
if (start < 0) { |
| 600 |
start = 0; |
1165 |
start = 0; |
| 601 |
} |
1166 |
} |
|
|
1167 |
int used_lines = 0; |
|
|
1168 |
for (int i = dialog->entry_count - 1; i >= 0; i--) { |
|
|
1169 |
const char *prompt_text = dialog->entries[i].prompt; |
|
|
1170 |
const char *response_text = dialog->entries[i].response; |
|
|
1171 |
const char *name = dialog->npc_name && dialog->npc_name[0] != '\0' ? dialog->npc_name : "NPC"; |
|
|
1172 |
char prefix_you[16]; |
|
|
1173 |
char prefix_npc[64]; |
|
|
1174 |
snprintf(prefix_you, sizeof(prefix_you), "You: "); |
|
|
1175 |
snprintf(prefix_npc, sizeof(prefix_npc), "%s: ", name); |
|
|
1176 |
int need = count_wrapped_lines(box_w, prefix_you, prompt_text) |
|
|
1177 |
+ count_wrapped_lines(box_w, prefix_npc, response_text); |
|
|
1178 |
if (used_lines + need > max_lines && used_lines > 0) { |
|
|
1179 |
break; |
|
|
1180 |
} |
|
|
1181 |
used_lines += need; |
|
|
1182 |
start = i; |
|
|
1183 |
if (used_lines >= max_lines) { |
|
|
1184 |
break; |
|
|
1185 |
} |
|
|
1186 |
} |
| 602 |
for (int i = start; i < dialog->entry_count && line + 1 <= max_lines; i++) { |
1187 |
for (int i = start; i < dialog->entry_count && line + 1 <= max_lines; i++) { |
| 603 |
char prompt_buf[128]; |
1188 |
const char *prompt_text = dialog->entries[i].prompt; |
| 604 |
char response_buf[256]; |
1189 |
const char *response_text = dialog->entries[i].response; |
| 605 |
copy_truncated(prompt_buf, sizeof(prompt_buf), dialog->entries[i].prompt, max_text); |
1190 |
const char *name = dialog->npc_name && dialog->npc_name[0] != '\0' ? dialog->npc_name : "NPC"; |
| 606 |
copy_truncated(response_buf, sizeof(response_buf), dialog->entries[i].response, max_text); |
1191 |
char prefix_you[16]; |
| 607 |
if (line < max_lines) { |
1192 |
char prefix_npc[64]; |
| 608 |
tb_printf(box_x + 2, log_y + line, COLOR_WHITE_256, 19, "You: %s", prompt_buf); |
1193 |
snprintf(prefix_you, sizeof(prefix_you), "You: "); |
| 609 |
line++; |
1194 |
snprintf(prefix_npc, sizeof(prefix_npc), "%s: ", name); |
|
|
1195 |
int used = draw_wrapped(box_x + 2, log_y + line, max_lines - line, box_w, |
|
|
1196 |
COLOR_WHITE_256, 19, prefix_you, prompt_text); |
|
|
1197 |
line += used; |
|
|
1198 |
if (line >= max_lines) { |
|
|
1199 |
break; |
| 610 |
} |
1200 |
} |
| 611 |
if (line < max_lines) { |
1201 |
used = draw_wrapped(box_x + 2, log_y + line, max_lines - line, box_w, |
| 612 |
tb_printf(box_x + 2, log_y + line, COLOR_GREEN_256, 19, "NPC: %s", response_buf); |
1202 |
COLOR_GREEN_256, 19, prefix_npc, response_text); |
| 613 |
line++; |
1203 |
line += used; |
|
|
1204 |
if (line >= max_lines) { |
|
|
1205 |
break; |
| 614 |
} |
1206 |
} |
| 615 |
} |
1207 |
} |
| 616 |
|
1208 |
|
| ... |
| 641 |
return value; |
1233 |
return value; |
| 642 |
} |
1234 |
} |
| 643 |
|
1235 |
|
| 644 |
int main(void) { |
1236 |
int main(int argc, char **argv) { |
|
|
1237 |
const char *model_name = NULL; |
|
|
1238 |
const char *embed_model_name = NULL; |
|
|
1239 |
const ModelConfig *model_cfg = NULL; |
|
|
1240 |
struct llama_model *embed_model = NULL; |
|
|
1241 |
struct llama_model *gen_model = NULL; |
|
|
1242 |
struct llama_context *embed_ctx = NULL; |
|
|
1243 |
int tb_ready = 0; |
|
|
1244 |
int llama_ready = 0; |
|
|
1245 |
int exit_code = 0; |
|
|
1246 |
int verbose = 0; |
|
|
1247 |
|
|
|
1248 |
static struct option long_options[] = { |
|
|
1249 |
{"model", required_argument, 0, 'm'}, |
|
|
1250 |
{"embed-model", required_argument, 0, 'e'}, |
|
|
1251 |
{"verbose", no_argument, 0, 'v'}, |
|
|
1252 |
{"help", no_argument, 0, 'h'}, |
|
|
1253 |
{0, 0, 0, 0} |
|
|
1254 |
}; |
|
|
1255 |
|
|
|
1256 |
int opt; |
|
|
1257 |
int option_index = 0; |
|
|
1258 |
while ((opt = getopt_long(argc, argv, "m:e:vh", long_options, &option_index)) != -1) { |
|
|
1259 |
switch (opt) { |
|
|
1260 |
case 'm': |
|
|
1261 |
model_name = optarg; |
|
|
1262 |
break; |
|
|
1263 |
case 'e': |
|
|
1264 |
embed_model_name = optarg; |
|
|
1265 |
break; |
|
|
1266 |
case 'v': |
|
|
1267 |
verbose = 1; |
|
|
1268 |
break; |
|
|
1269 |
case 'h': |
|
|
1270 |
show_help(argv[0]); |
|
|
1271 |
return 0; |
|
|
1272 |
default: |
|
|
1273 |
fprintf(stderr, "Usage: %s [-m model] [-v] [-h]\n", argv[0]); |
|
|
1274 |
return 1; |
|
|
1275 |
} |
|
|
1276 |
} |
|
|
1277 |
|
|
|
1278 |
if (model_name != NULL) { |
|
|
1279 |
model_cfg = get_model_by_name(model_name); |
|
|
1280 |
if (model_cfg == NULL) { |
|
|
1281 |
fprintf(stderr, "Unknown model '%s'\n", model_name); |
|
|
1282 |
return 1; |
|
|
1283 |
} |
|
|
1284 |
} else { |
|
|
1285 |
model_cfg = &models[0]; |
|
|
1286 |
} |
|
|
1287 |
|
| 645 |
Player player = {0}; |
1288 |
Player player = {0}; |
| 646 |
array(GameMap) maps; |
1289 |
array(GameMap) maps; |
| 647 |
GameMap map1 = {0}; |
1290 |
GameMap map1 = {0}; |
| 648 |
GameMap *current_map = NULL; |
1291 |
GameMap *current_map = NULL; |
|
|
1292 |
VectorDB *npc_dbs = NULL; |
|
|
1293 |
int *npc_db_loaded = NULL; |
| 649 |
int running = 1; |
1294 |
int running = 1; |
| 650 |
int view_w = 0; |
1295 |
int view_w = 0; |
| 651 |
int view_h = 0; |
1296 |
int view_h = 0; |
| 652 |
int cam_x = 0; |
1297 |
int cam_x = 0; |
| 653 |
int cam_y = 0; |
1298 |
int cam_y = 0; |
| 654 |
Dialog dialog = {0}; |
1299 |
Dialog dialog = {0}; |
|
|
1300 |
GameRuntime runtime = {0}; |
| 655 |
|
1301 |
|
| 656 |
player_init(&player); |
1302 |
player_init(&player); |
| 657 |
array_init(maps); |
1303 |
array_init(maps); |
| ... |
| 660 |
current_map = &maps.data[0]; |
1306 |
current_map = &maps.data[0]; |
| 661 |
map_init(¤t_map->map, current_map->data, current_map->len); |
1307 |
map_init(¤t_map->map, current_map->data, current_map->len); |
| 662 |
|
1308 |
|
|
|
1309 |
if (verbose == 0) { |
|
|
1310 |
llama_log_set(llama_log_callback, NULL); |
|
|
1311 |
} |
|
|
1312 |
|
|
|
1313 |
npc_dbs = (VectorDB *)calloc(10, sizeof(VectorDB)); |
|
|
1314 |
npc_db_loaded = (int *)calloc(10, sizeof(int)); |
|
|
1315 |
if (npc_dbs == NULL || npc_db_loaded == NULL) { |
|
|
1316 |
fprintf(stderr, "Failed to allocate NPC vector databases\n"); |
|
|
1317 |
exit_code = 1; |
|
|
1318 |
goto cleanup; |
|
|
1319 |
} |
|
|
1320 |
|
|
|
1321 |
llama_backend_init(); |
|
|
1322 |
ggml_backend_load_all(); |
|
|
1323 |
llama_ready = 1; |
|
|
1324 |
const ModelConfig *embed_cfg = NULL; |
|
|
1325 |
if (embed_model_name != NULL) { |
|
|
1326 |
embed_cfg = get_model_by_name(embed_model_name); |
|
|
1327 |
if (embed_cfg == NULL) { |
|
|
1328 |
fprintf(stderr, "Unknown embedding model '%s'\n", embed_model_name); |
|
|
1329 |
exit_code = 1; |
|
|
1330 |
goto cleanup; |
|
|
1331 |
} |
|
|
1332 |
} else if (model_cfg->embed_model_name != NULL) { |
|
|
1333 |
embed_cfg = get_model_by_name(model_cfg->embed_model_name); |
|
|
1334 |
} |
|
|
1335 |
if (embed_cfg == NULL) { |
|
|
1336 |
embed_cfg = model_cfg; |
|
|
1337 |
} |
|
|
1338 |
|
|
|
1339 |
struct llama_model_params gen_params = llama_model_default_params(); |
|
|
1340 |
gen_params.n_gpu_layers = model_cfg->n_gpu_layers; |
|
|
1341 |
gen_params.use_mmap = model_cfg->use_mmap; |
|
|
1342 |
gen_model = llama_model_load_from_file(model_cfg->filepath, gen_params); |
|
|
1343 |
if (gen_model == NULL) { |
|
|
1344 |
fprintf(stderr, "Unable to load generation model\n"); |
|
|
1345 |
exit_code = 1; |
|
|
1346 |
goto cleanup; |
|
|
1347 |
} |
|
|
1348 |
|
|
|
1349 |
struct llama_model_params embed_params = llama_model_default_params(); |
|
|
1350 |
embed_params.n_gpu_layers = embed_cfg->n_gpu_layers; |
|
|
1351 |
embed_params.use_mmap = embed_cfg->use_mmap; |
|
|
1352 |
embed_model = llama_model_load_from_file(embed_cfg->filepath, embed_params); |
|
|
1353 |
if (embed_model == NULL) { |
|
|
1354 |
fprintf(stderr, "Unable to load embedding model\n"); |
|
|
1355 |
exit_code = 1; |
|
|
1356 |
goto cleanup; |
|
|
1357 |
} |
|
|
1358 |
|
|
|
1359 |
struct llama_context_params cparams = llama_context_default_params(); |
|
|
1360 |
cparams.n_ctx = embed_cfg->n_ctx; |
|
|
1361 |
cparams.n_batch = embed_cfg->n_batch; |
|
|
1362 |
cparams.embeddings = true; |
|
|
1363 |
embed_ctx = llama_init_from_model(embed_model, cparams); |
|
|
1364 |
if (embed_ctx == NULL) { |
|
|
1365 |
fprintf(stderr, "Failed to create embedding context\n"); |
|
|
1366 |
exit_code = 1; |
|
|
1367 |
goto cleanup; |
|
|
1368 |
} |
|
|
1369 |
|
|
|
1370 |
for (int i = 0; i < 10; i++) { |
|
|
1371 |
const char *vdb_path = current_map->npcs[i].vdb_path; |
|
|
1372 |
if (vdb_path == NULL || vdb_path[0] == '\0') { |
|
|
1373 |
continue; |
|
|
1374 |
} |
|
|
1375 |
vdb_init(&npc_dbs[i], embed_ctx); |
|
|
1376 |
VectorDBErrorCode vdb_rc = vdb_load(&npc_dbs[i], vdb_path); |
|
|
1377 |
if (vdb_rc != VDB_SUCCESS) { |
|
|
1378 |
fprintf(stderr, "Failed to load vector database %s: %s\n", vdb_path, vdb_error(vdb_rc)); |
|
|
1379 |
vdb_free(&npc_dbs[i]); |
|
|
1380 |
continue; |
|
|
1381 |
} |
|
|
1382 |
npc_db_loaded[i] = 1; |
|
|
1383 |
} |
|
|
1384 |
|
|
|
1385 |
runtime.model_cfg = model_cfg; |
|
|
1386 |
runtime.model = gen_model; |
|
|
1387 |
runtime.embed_model = embed_model; |
|
|
1388 |
runtime.embed_ctx = embed_ctx; |
|
|
1389 |
runtime.npc_dbs = npc_dbs; |
|
|
1390 |
runtime.npc_db_loaded = npc_db_loaded; |
|
|
1391 |
runtime.verbose = verbose; |
|
|
1392 |
|
| 663 |
if (tb_init() != TB_OK) { |
1393 |
if (tb_init() != TB_OK) { |
| 664 |
fprintf(stderr, "Failed to init termbox.\n"); |
1394 |
fprintf(stderr, "Failed to init termbox.\n"); |
| 665 |
return 1; |
1395 |
exit_code = 1; |
|
|
1396 |
goto cleanup; |
| 666 |
} |
1397 |
} |
|
|
1398 |
tb_ready = 1; |
| 667 |
|
1399 |
|
| 668 |
tb_set_input_mode(TB_INPUT_ESC); |
1400 |
tb_set_input_mode(TB_INPUT_ESC); |
| 669 |
tb_set_output_mode(TB_OUTPUT_256); |
1401 |
tb_set_output_mode(TB_OUTPUT_256); |
| ... |
| 678 |
if (ev.key == TB_KEY_ESC) { |
1410 |
if (ev.key == TB_KEY_ESC) { |
| 679 |
dialog_close(&dialog); |
1411 |
dialog_close(&dialog); |
| 680 |
} else if (ev.key == TB_KEY_ENTER) { |
1412 |
} else if (ev.key == TB_KEY_ENTER) { |
| 681 |
dialog_submit(&dialog, current_map); |
1413 |
dialog_submit(&dialog, current_map, &runtime); |
| 682 |
} else if (ev.key == TB_KEY_BACKSPACE || ev.key == TB_KEY_BACKSPACE2) { |
1414 |
} else if (ev.key == TB_KEY_BACKSPACE || ev.key == TB_KEY_BACKSPACE2) { |
| 683 |
dialog_backspace(&dialog); |
1415 |
dialog_backspace(&dialog); |
| 684 |
} else if (ev.ch) { |
1416 |
} else if (ev.ch) { |
| ... |
| 692 |
u32 target = map_get(¤t_map->map, player.x, next_y); |
1424 |
u32 target = map_get(¤t_map->map, player.x, next_y); |
| 693 |
int npc_index = npc_index_from_tile(target); |
1425 |
int npc_index = npc_index_from_tile(target); |
| 694 |
if (target == 'N' || npc_index >= 0) { |
1426 |
if (target == 'N' || npc_index >= 0) { |
| 695 |
dialog_open(&dialog, npc_index); |
1427 |
const char *npc_name = current_map && npc_index >= 0 && npc_index < 10 |
|
|
1428 |
? current_map->npcs[npc_index].name |
|
|
1429 |
: NULL; |
|
|
1430 |
dialog_open(&dialog, npc_index, npc_name); |
| 696 |
update_npc_status(current_map, npc_index); |
1431 |
update_npc_status(current_map, npc_index); |
| 697 |
} else if (map_is_walkable(¤t_map->map, player.x, next_y)) { |
1432 |
} else if (map_is_walkable(¤t_map->map, player.x, next_y)) { |
| 698 |
player.y = next_y; |
1433 |
player.y = next_y; |
| ... |
| 702 |
u32 target = map_get(¤t_map->map, player.x, next_y); |
1437 |
u32 target = map_get(¤t_map->map, player.x, next_y); |
| 703 |
int npc_index = npc_index_from_tile(target); |
1438 |
int npc_index = npc_index_from_tile(target); |
| 704 |
if (target == 'N' || npc_index >= 0) { |
1439 |
if (target == 'N' || npc_index >= 0) { |
| 705 |
dialog_open(&dialog, npc_index); |
1440 |
const char *npc_name = current_map && npc_index >= 0 && npc_index < 10 |
|
|
1441 |
? current_map->npcs[npc_index].name |
|
|
1442 |
: NULL; |
|
|
1443 |
dialog_open(&dialog, npc_index, npc_name); |
| 706 |
update_npc_status(current_map, npc_index); |
1444 |
update_npc_status(current_map, npc_index); |
| 707 |
} else if (map_is_walkable(¤t_map->map, player.x, next_y)) { |
1445 |
} else if (map_is_walkable(¤t_map->map, player.x, next_y)) { |
| 708 |
player.y = next_y; |
1446 |
player.y = next_y; |
| ... |
| 712 |
u32 target = map_get(¤t_map->map, next_x, player.y); |
1450 |
u32 target = map_get(¤t_map->map, next_x, player.y); |
| 713 |
int npc_index = npc_index_from_tile(target); |
1451 |
int npc_index = npc_index_from_tile(target); |
| 714 |
if (target == 'N' || npc_index >= 0) { |
1452 |
if (target == 'N' || npc_index >= 0) { |
| 715 |
dialog_open(&dialog, npc_index); |
1453 |
const char *npc_name = current_map && npc_index >= 0 && npc_index < 10 |
|
|
1454 |
? current_map->npcs[npc_index].name |
|
|
1455 |
: NULL; |
|
|
1456 |
dialog_open(&dialog, npc_index, npc_name); |
| 716 |
update_npc_status(current_map, npc_index); |
1457 |
update_npc_status(current_map, npc_index); |
| 717 |
} else if (map_is_walkable(¤t_map->map, next_x, player.y)) { |
1458 |
} else if (map_is_walkable(¤t_map->map, next_x, player.y)) { |
| 718 |
player.x = next_x; |
1459 |
player.x = next_x; |
| ... |
| 722 |
u32 target = map_get(¤t_map->map, next_x, player.y); |
1463 |
u32 target = map_get(¤t_map->map, next_x, player.y); |
| 723 |
int npc_index = npc_index_from_tile(target); |
1464 |
int npc_index = npc_index_from_tile(target); |
| 724 |
if (target == 'N' || npc_index >= 0) { |
1465 |
if (target == 'N' || npc_index >= 0) { |
| 725 |
dialog_open(&dialog, npc_index); |
1466 |
const char *npc_name = current_map && npc_index >= 0 && npc_index < 10 |
|
|
1467 |
? current_map->npcs[npc_index].name |
|
|
1468 |
: NULL; |
|
|
1469 |
dialog_open(&dialog, npc_index, npc_name); |
| 726 |
update_npc_status(current_map, npc_index); |
1470 |
update_npc_status(current_map, npc_index); |
| 727 |
} else if (map_is_walkable(¤t_map->map, next_x, player.y)) { |
1471 |
} else if (map_is_walkable(¤t_map->map, next_x, player.y)) { |
| 728 |
player.x = next_x; |
1472 |
player.x = next_x; |
| ... |
| 742 |
} |
1486 |
} |
| 743 |
} |
1487 |
} |
| 744 |
|
1488 |
|
|
|
1489 |
cleanup: |
| 745 |
player_free(&player); |
1490 |
player_free(&player); |
| 746 |
for (size_t i = 0; i < maps.length; i++) { |
1491 |
for (size_t i = 0; i < maps.length; i++) { |
| 747 |
map_free(&maps.data[i].map); |
1492 |
map_free(&maps.data[i].map); |
| 748 |
} |
1493 |
} |
| 749 |
array_free(maps); |
1494 |
array_free(maps); |
| 750 |
tb_shutdown(); |
1495 |
if (tb_ready) { |
| 751 |
return 0; |
1496 |
tb_shutdown(); |
|
|
1497 |
} |
|
|
1498 |
for (int i = 0; i < 10; i++) { |
|
|
1499 |
if (npc_db_loaded && npc_db_loaded[i]) { |
|
|
1500 |
vdb_free(&npc_dbs[i]); |
|
|
1501 |
} |
|
|
1502 |
} |
|
|
1503 |
free(npc_db_loaded); |
|
|
1504 |
free(npc_dbs); |
|
|
1505 |
if (embed_ctx != NULL) { |
|
|
1506 |
llama_free(embed_ctx); |
|
|
1507 |
} |
|
|
1508 |
if (embed_model != NULL) { |
|
|
1509 |
llama_model_free(embed_model); |
|
|
1510 |
} |
|
|
1511 |
if (gen_model != NULL) { |
|
|
1512 |
llama_model_free(gen_model); |
|
|
1513 |
} |
|
|
1514 |
if (llama_ready) { |
|
|
1515 |
llama_backend_free(); |
|
|
1516 |
} |
|
|
1517 |
return exit_code; |
| 752 |
} |
1518 |
} |
|
diff --git a/npc.c b/npc.c
|
| 1 |
#include "llama.h" |
1 |
#include "llama.h" |
| 2 |
#include "vectordb.h" |
2 |
#include "vectordb.h" |
| 3 |
#include "models.h" |
3 |
#include "models.h" |
| 4 |
#include "models.h" |
|
|
| 5 |
|
4 |
|
| 6 |
#define NONSTD_IMPLEMENTATION |
5 |
#define NONSTD_IMPLEMENTATION |
| 7 |
#include "nonstd.h" |
6 |
#include "nonstd.h" |
| ... |
| 31 |
printf("Usage: %s [OPTIONS]\n", prog); |
30 |
printf("Usage: %s [OPTIONS]\n", prog); |
| 32 |
printf("Options:\n"); |
31 |
printf("Options:\n"); |
| 33 |
printf(" -m, --model <name> Specify model to use (default: first model)\n"); |
32 |
printf(" -m, --model <name> Specify model to use (default: first model)\n"); |
|
|
33 |
printf(" -e, --embed-model <name> Specify model to use for embeddings\n"); |
| 34 |
printf(" -p, --prompt <text> Specify prompt text (default: \"What is 2+2?\")\n"); |
34 |
printf(" -p, --prompt <text> Specify prompt text (default: \"What is 2+2?\")\n"); |
| 35 |
printf(" -c, --context <file> Specify vector database file (.vdb)\n"); |
35 |
printf(" -c, --context <file> Specify vector database file (.vdb)\n"); |
| 36 |
printf(" -l, --list Lists all available models\n"); |
36 |
printf(" -l, --list Lists all available models\n"); |
| ... |
| 48 |
return strcmp(path + (len - ext_len), ext) == 0; |
48 |
return strcmp(path + (len - ext_len), ext) == 0; |
| 49 |
} |
49 |
} |
| 50 |
|
50 |
|
| 51 |
static int execute_prompt_with_context(const ModelConfig *cfg, const char *prompt, const char *context, int n_predict) { |
51 |
static void append_prompt_context(stringb *sb, const char *context, const char *question) { |
|
|
52 |
sb_append_cstr(sb, "Context:\n"); |
|
|
53 |
if (context && context[0] != '\0') { |
|
|
54 |
sb_append_cstr(sb, context); |
|
|
55 |
} |
|
|
56 |
sb_append_cstr(sb, "\nQuestion:\n"); |
|
|
57 |
sb_append_cstr(sb, question ? question : ""); |
|
|
58 |
} |
|
|
59 |
|
|
|
60 |
static char *build_prompt(const ModelConfig *cfg, const char *system, const char *context, |
|
|
61 |
const char *question) { |
|
|
62 |
stringb full = {0}; |
|
|
63 |
sb_init(&full, 0); |
|
|
64 |
|
|
|
65 |
switch (cfg->prompt_style) { |
|
|
66 |
case PROMPT_STYLE_T5: |
|
|
67 |
sb_append_cstr(&full, "instruction: "); |
|
|
68 |
sb_append_cstr(&full, system ? system : ""); |
|
|
69 |
sb_append_cstr(&full, "\nquestion: "); |
|
|
70 |
sb_append_cstr(&full, question ? question : ""); |
|
|
71 |
sb_append_cstr(&full, "\ncontext:\n"); |
|
|
72 |
if (context && context[0] != '\0') { |
|
|
73 |
sb_append_cstr(&full, context); |
|
|
74 |
} |
|
|
75 |
sb_append_cstr(&full, "\nanswer:"); |
|
|
76 |
break; |
|
|
77 |
case PROMPT_STYLE_CHAT: |
|
|
78 |
sb_append_cstr(&full, "System:\n"); |
|
|
79 |
sb_append_cstr(&full, system ? system : ""); |
|
|
80 |
sb_append_cstr(&full, "\nUser:\n"); |
|
|
81 |
append_prompt_context(&full, context, question); |
|
|
82 |
sb_append_cstr(&full, "\nAssistant:"); |
|
|
83 |
break; |
|
|
84 |
case PROMPT_STYLE_PLAIN: |
|
|
85 |
default: |
|
|
86 |
sb_append_cstr(&full, "System:\n"); |
|
|
87 |
sb_append_cstr(&full, system ? system : ""); |
|
|
88 |
sb_append_cstr(&full, "\n"); |
|
|
89 |
append_prompt_context(&full, context, question); |
|
|
90 |
sb_append_cstr(&full, "\nAnswer:"); |
|
|
91 |
break; |
|
|
92 |
} |
|
|
93 |
|
|
|
94 |
return full.data; |
|
|
95 |
} |
|
|
96 |
|
|
|
97 |
static int execute_prompt_with_context(const ModelConfig *cfg, const char *prompt, |
|
|
98 |
const char *context, int n_predict) { |
| 52 |
if (cfg == NULL) { |
99 |
if (cfg == NULL) { |
| 53 |
log_message(stderr, LOG_ERROR, "Model config is missing"); |
100 |
log_message(stderr, LOG_ERROR, "Model config is missing"); |
| 54 |
return 1; |
101 |
return 1; |
| ... |
| 76 |
|
123 |
|
| 77 |
const struct llama_vocab *vocab = llama_model_get_vocab(model); |
124 |
const struct llama_vocab *vocab = llama_model_get_vocab(model); |
| 78 |
|
125 |
|
| 79 |
const char *context_prefix = "Context:\n"; |
126 |
const char *system_text = system_prefix; |
| 80 |
const char *prompt_prefix = "\n\nQuestion:\n"; |
127 |
if (strncmp(system_prefix, "System:", 7) == 0) { |
| 81 |
const char *answer_prefix = "\n\nAnswer:\n"; |
128 |
system_text = system_prefix + 7; |
| 82 |
size_t context_len = context ? strlen(context) : 0; |
129 |
while (*system_text == ' ' || *system_text == '\n' || *system_text == '\r') { |
| 83 |
size_t prompt_len = strlen(prompt); |
130 |
system_text++; |
| 84 |
size_t full_len = strlen(system_prefix) + strlen(context_prefix) + context_len + strlen(prompt_prefix) + prompt_len + strlen(answer_prefix) + 1; |
131 |
} |
| 85 |
char *full_prompt = (char *)malloc(full_len); |
132 |
} |
|
|
133 |
|
|
|
134 |
char *full_prompt = build_prompt(cfg, system_text, context, prompt); |
| 86 |
if (full_prompt == NULL) { |
135 |
if (full_prompt == NULL) { |
| 87 |
log_message(stderr, LOG_ERROR, "Failed to allocate prompt buffer"); |
136 |
log_message(stderr, LOG_ERROR, "Failed to build prompt"); |
| 88 |
free(system_prefix); |
137 |
free(system_prefix); |
| 89 |
llama_model_free(model); |
138 |
llama_model_free(model); |
| 90 |
return 1; |
139 |
return 1; |
| 91 |
} |
140 |
} |
| 92 |
snprintf(full_prompt, full_len, "%s%s%s%s%s", system_prefix, context_prefix, context ? context : "", prompt_prefix, prompt); |
|
|
| 93 |
strncat(full_prompt, answer_prefix, full_len - strlen(full_prompt) - 1); |
|
|
| 94 |
|
141 |
|
| 95 |
int n_prompt = -llama_tokenize(vocab, full_prompt, strlen(full_prompt), NULL, 0, true, true); |
142 |
int n_prompt = -llama_tokenize(vocab, full_prompt, strlen(full_prompt), NULL, 0, true, true); |
| 96 |
llama_token *prompt_tokens = (llama_token *)malloc((size_t)n_prompt * sizeof(llama_token)); |
143 |
llama_token *prompt_tokens = (llama_token *)malloc((size_t)n_prompt * sizeof(llama_token)); |
| ... |
| 127 |
|
174 |
|
| 128 |
struct llama_sampler_chain_params sparams = llama_sampler_chain_default_params(); |
175 |
struct llama_sampler_chain_params sparams = llama_sampler_chain_default_params(); |
| 129 |
struct llama_sampler *smpl = llama_sampler_chain_init(sparams); |
176 |
struct llama_sampler *smpl = llama_sampler_chain_init(sparams); |
|
|
177 |
if (cfg->top_k > 0) { |
|
|
178 |
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(cfg->top_k)); |
|
|
179 |
} |
|
|
180 |
if (cfg->top_p > 0.0f && cfg->top_p < 1.0f) { |
|
|
181 |
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(cfg->top_p, 1)); |
|
|
182 |
} |
|
|
183 |
if (cfg->min_p > 0.0f) { |
|
|
184 |
llama_sampler_chain_add(smpl, llama_sampler_init_min_p(cfg->min_p, 1)); |
|
|
185 |
} |
|
|
186 |
llama_sampler_chain_add(smpl, llama_sampler_init_penalties( |
|
|
187 |
cfg->repeat_last_n, |
|
|
188 |
cfg->repeat_penalty, |
|
|
189 |
cfg->freq_penalty, |
|
|
190 |
cfg->presence_penalty)); |
| 130 |
llama_sampler_chain_add(smpl, llama_sampler_init_temp(cfg->temperature)); |
191 |
llama_sampler_chain_add(smpl, llama_sampler_init_temp(cfg->temperature)); |
| 131 |
llama_sampler_chain_add(smpl, llama_sampler_init_min_p(cfg->min_p, 1)); |
|
|
| 132 |
llama_sampler_chain_add(smpl, llama_sampler_init_dist(cfg->seed)); |
192 |
llama_sampler_chain_add(smpl, llama_sampler_init_dist(cfg->seed)); |
| 133 |
|
193 |
|
| 134 |
struct llama_batch batch = llama_batch_get_one(prompt_tokens, n_prompt); |
194 |
struct llama_batch batch = llama_batch_get_one(prompt_tokens, n_prompt); |
| ... |
| 191 |
log_message(stderr, LOG_ERROR, "Failed to convert token to piece"); |
251 |
log_message(stderr, LOG_ERROR, "Failed to convert token to piece"); |
| 192 |
break; |
252 |
break; |
| 193 |
} |
253 |
} |
| 194 |
int stop_at = n; |
254 |
if (out_len == 0 && n > 0 && buf[0] == '\n') { |
| 195 |
for (int i = 0; i < n; i++) { |
255 |
batch = llama_batch_get_one(&new_token_id, 1); |
| 196 |
if (buf[i] == '\n') { |
256 |
continue; |
| 197 |
stop_at = i; |
|
|
| 198 |
break; |
|
|
| 199 |
} |
|
|
| 200 |
} |
257 |
} |
| 201 |
if (out_len + (size_t)stop_at + 1 > out_cap) { |
258 |
if (out_len + (size_t)n + 1 > out_cap) { |
| 202 |
while (out_len + (size_t)stop_at + 1 > out_cap) { |
259 |
while (out_len + (size_t)n + 1 > out_cap) { |
| 203 |
out_cap *= 2; |
260 |
out_cap *= 2; |
| 204 |
} |
261 |
} |
| 205 |
char *next = (char *)realloc(out, out_cap); |
262 |
char *next = (char *)realloc(out, out_cap); |
| ... |
| 209 |
} |
266 |
} |
| 210 |
out = next; |
267 |
out = next; |
| 211 |
} |
268 |
} |
| 212 |
memcpy(out + out_len, buf, (size_t)stop_at); |
269 |
memcpy(out + out_len, buf, (size_t)n); |
| 213 |
out_len += (size_t)stop_at; |
270 |
out_len += (size_t)n; |
| 214 |
out[out_len] = '\0'; |
271 |
out[out_len] = '\0'; |
| 215 |
|
|
|
| 216 |
if (stop_at != n) { |
|
|
| 217 |
break; |
|
|
| 218 |
} |
|
|
| 219 |
|
272 |
|
| 220 |
batch = llama_batch_get_one(&new_token_id, 1); |
273 |
batch = llama_batch_get_one(&new_token_id, 1); |
| 221 |
} |
274 |
} |
| ... |
| 241 |
const char *prompt = NULL; |
294 |
const char *prompt = NULL; |
| 242 |
const char *context_file = NULL; |
295 |
const char *context_file = NULL; |
| 243 |
int verbose = 0; |
296 |
int verbose = 0; |
|
|
297 |
const char *embed_model_name = NULL; |
| 244 |
|
298 |
|
| 245 |
int n_predict = 64; |
299 |
int n_predict = 0; |
| 246 |
|
300 |
|
| 247 |
static struct option long_options[] = { |
301 |
static struct option long_options[] = { |
| 248 |
{"model", required_argument, 0, 'm'}, |
302 |
{"model", required_argument, 0, 'm'}, |
| 249 |
{"prompt", required_argument, 0, 'p'}, |
303 |
{"prompt", required_argument, 0, 'p'}, |
| 250 |
{"context", required_argument, 0, 'c'}, |
304 |
{"context", required_argument, 0, 'c'}, |
|
|
305 |
{"embed-model", required_argument, 0, 'e'}, |
| 251 |
{"list", no_argument, 0, 'l'}, |
306 |
{"list", no_argument, 0, 'l'}, |
| 252 |
{"verbose", no_argument, 0, 'v'}, |
307 |
{"verbose", no_argument, 0, 'v'}, |
| 253 |
{"help", no_argument, 0, 'h'}, |
308 |
{"help", no_argument, 0, 'h'}, |
| ... |
| 256 |
|
311 |
|
| 257 |
int opt; |
312 |
int opt; |
| 258 |
int option_index = 0; |
313 |
int option_index = 0; |
| 259 |
while ((opt = getopt_long(argc, argv, "m:p:c:lvh", long_options, &option_index)) != -1) { |
314 |
while ((opt = getopt_long(argc, argv, "m:p:c:e:lvh", long_options, &option_index)) != -1) { |
| 260 |
switch (opt) { |
315 |
switch (opt) { |
| 261 |
case 'm': |
316 |
case 'm': |
| 262 |
model_name = optarg; |
317 |
model_name = optarg; |
| ... |
| 266 |
break; |
321 |
break; |
| 267 |
case 'c': |
322 |
case 'c': |
| 268 |
context_file = optarg; |
323 |
context_file = optarg; |
|
|
324 |
break; |
|
|
325 |
case 'e': |
|
|
326 |
embed_model_name = optarg; |
| 269 |
break; |
327 |
break; |
| 270 |
case 'v': |
328 |
case 'v': |
| 271 |
verbose = 1; |
329 |
verbose = 1; |
| ... |
| 320 |
cfg = &models[0]; |
378 |
cfg = &models[0]; |
| 321 |
} |
379 |
} |
| 322 |
|
380 |
|
| 323 |
struct llama_model *model = llama_model_load_from_file(cfg->filepath, llama_model_default_params()); |
381 |
const ModelConfig *embed_cfg = NULL; |
|
|
382 |
if (embed_model_name != NULL) { |
|
|
383 |
embed_cfg = get_model_by_name(embed_model_name); |
|
|
384 |
if (embed_cfg == NULL) { |
|
|
385 |
log_message(stderr, LOG_ERROR, "Unknown embedding model '%s'", embed_model_name); |
|
|
386 |
llama_backend_free(); |
|
|
387 |
return 1; |
|
|
388 |
} |
|
|
389 |
} else if (cfg->embed_model_name != NULL) { |
|
|
390 |
embed_cfg = get_model_by_name(cfg->embed_model_name); |
|
|
391 |
} |
|
|
392 |
if (embed_cfg == NULL) { |
|
|
393 |
embed_cfg = cfg; |
|
|
394 |
} |
|
|
395 |
|
|
|
396 |
if (n_predict <= 0) { |
|
|
397 |
n_predict = cfg->n_predict > 0 ? cfg->n_predict : 128; |
|
|
398 |
} |
|
|
399 |
|
|
|
400 |
struct llama_model_params embed_params = llama_model_default_params(); |
|
|
401 |
embed_params.n_gpu_layers = embed_cfg->n_gpu_layers; |
|
|
402 |
embed_params.use_mmap = embed_cfg->use_mmap; |
|
|
403 |
struct llama_model *model = llama_model_load_from_file(embed_cfg->filepath, embed_params); |
| 324 |
if (model == NULL) { |
404 |
if (model == NULL) { |
| 325 |
log_message(stderr, LOG_ERROR, "Unable to load embedding model"); |
405 |
log_message(stderr, LOG_ERROR, "Unable to load embedding model"); |
| 326 |
llama_backend_free(); |
406 |
llama_backend_free(); |
| ... |
| 328 |
} |
408 |
} |
| 329 |
|
409 |
|
| 330 |
struct llama_context_params cparams = llama_context_default_params(); |
410 |
struct llama_context_params cparams = llama_context_default_params(); |
|
|
411 |
cparams.n_ctx = embed_cfg->n_ctx; |
|
|
412 |
cparams.n_batch = embed_cfg->n_batch; |
| 331 |
cparams.embeddings = true; |
413 |
cparams.embeddings = true; |
| 332 |
|
414 |
|
| 333 |
struct llama_context *embed_ctx = llama_init_from_model(model, cparams); |
415 |
struct llama_context *embed_ctx = llama_init_from_model(model, cparams); |
| ... |
| 350 |
} |
432 |
} |
| 351 |
|
433 |
|
| 352 |
float query[VDB_EMBED_SIZE]; |
434 |
float query[VDB_EMBED_SIZE]; |
| 353 |
int results[3]; |
435 |
int results[5]; |
|
|
436 |
for (int i = 0; i < 5; i++) { |
|
|
437 |
results[i] = -1; |
|
|
438 |
} |
| 354 |
|
439 |
|
| 355 |
vdb_embed_query(&db, prompt, query); |
440 |
vdb_embed_query(&db, prompt, query); |
| 356 |
vdb_search(&db, query, 3, results); |
441 |
vdb_search(&db, query, 5, results); |
| 357 |
|
442 |
|
| 358 |
size_t context_cap = 1024; |
443 |
size_t context_cap = 1024; |
| 359 |
size_t context_len = 0; |
444 |
size_t context_len = 0; |
| ... |
| 367 |
} |
452 |
} |
| 368 |
context[0] = '\0'; |
453 |
context[0] = '\0'; |
| 369 |
|
454 |
|
| 370 |
for (int i = 0; i < 3; i++) { |
455 |
for (int i = 0; i < 5; i++) { |
| 371 |
if (results[i] < 0) { |
456 |
if (results[i] < 0) { |
| 372 |
continue; |
457 |
continue; |
| 373 |
} |
458 |
} |
| 374 |
const char *text = db.docs[results[i]].text; |
459 |
const char *text = db.docs[results[i]].text; |
|
|
460 |
char header[32]; |
|
|
461 |
int header_len = snprintf(header, sizeof(header), "Snippet %d:\n", i + 1); |
| 375 |
size_t text_len = strlen(text); |
462 |
size_t text_len = strlen(text); |
| 376 |
size_t need = context_len + text_len + 2; |
463 |
size_t need = context_len + (size_t)header_len + text_len + 2; |
| 377 |
if (need > context_cap) { |
464 |
if (need > context_cap) { |
| 378 |
while (need > context_cap) { |
465 |
while (need > context_cap) { |
| 379 |
context_cap *= 2; |
466 |
context_cap *= 2; |
| ... |
| 388 |
return 1; |
475 |
return 1; |
| 389 |
} |
476 |
} |
| 390 |
context = next; |
477 |
context = next; |
|
|
478 |
} |
|
|
479 |
if (header_len > 0) { |
|
|
480 |
memcpy(context + context_len, header, (size_t)header_len); |
|
|
481 |
context_len += (size_t)header_len; |
| 391 |
} |
482 |
} |
| 392 |
memcpy(context + context_len, text, text_len); |
483 |
memcpy(context + context_len, text, text_len); |
| 393 |
context_len += text_len; |
484 |
context_len += text_len; |
| ... |