Diffstat (limited to 'npc.c')
| -rw-r--r-- | npc.c | 159 |
1 file changed, 125 insertions, 34 deletions
@@ -1,7 +1,6 @@
 #include "llama.h"
 #include "vectordb.h"
 #include "models.h"
-#include "models.h"
 
 #define NONSTD_IMPLEMENTATION
 #include "nonstd.h"
@@ -31,6 +30,7 @@ static void show_help(const char *prog) {
     printf("Usage: %s [OPTIONS]\n", prog);
     printf("Options:\n");
     printf(" -m, --model <name> Specify model to use (default: first model)\n");
+    printf(" -e, --embed-model <name> Specify model to use for embeddings\n");
     printf(" -p, --prompt <text> Specify prompt text (default: \"What is 2+2?\")\n");
     printf(" -c, --context <file> Specify vector database file (.vdb)\n");
     printf(" -l, --list Lists all available models\n");
@@ -48,7 +48,54 @@ static int has_vdb_extension(const char *path) {
     return strcmp(path + (len - ext_len), ext) == 0;
 }
 
-static int execute_prompt_with_context(const ModelConfig *cfg, const char *prompt, const char *context, int n_predict) {
+static void append_prompt_context(stringb *sb, const char *context, const char *question) {
+    sb_append_cstr(sb, "Context:\n");
+    if (context && context[0] != '\0') {
+        sb_append_cstr(sb, context);
+    }
+    sb_append_cstr(sb, "\nQuestion:\n");
+    sb_append_cstr(sb, question ? question : "");
+}
+
+static char *build_prompt(const ModelConfig *cfg, const char *system, const char *context,
+                          const char *question) {
+    stringb full = {0};
+    sb_init(&full, 0);
+
+    switch (cfg->prompt_style) {
+    case PROMPT_STYLE_T5:
+        sb_append_cstr(&full, "instruction: ");
+        sb_append_cstr(&full, system ? system : "");
+        sb_append_cstr(&full, "\nquestion: ");
+        sb_append_cstr(&full, question ? question : "");
+        sb_append_cstr(&full, "\ncontext:\n");
+        if (context && context[0] != '\0') {
+            sb_append_cstr(&full, context);
+        }
+        sb_append_cstr(&full, "\nanswer:");
+        break;
+    case PROMPT_STYLE_CHAT:
+        sb_append_cstr(&full, "System:\n");
+        sb_append_cstr(&full, system ? system : "");
+        sb_append_cstr(&full, "\nUser:\n");
+        append_prompt_context(&full, context, question);
+        sb_append_cstr(&full, "\nAssistant:");
+        break;
+    case PROMPT_STYLE_PLAIN:
+    default:
+        sb_append_cstr(&full, "System:\n");
+        sb_append_cstr(&full, system ? system : "");
+        sb_append_cstr(&full, "\n");
+        append_prompt_context(&full, context, question);
+        sb_append_cstr(&full, "\nAnswer:");
+        break;
+    }
+
+    return full.data;
+}
+
+static int execute_prompt_with_context(const ModelConfig *cfg, const char *prompt,
+                                       const char *context, int n_predict) {
     if (cfg == NULL) {
         log_message(stderr, LOG_ERROR, "Model config is missing");
         return 1;
@@ -76,21 +123,21 @@ static int execute_prompt_with_context(const ModelConfig *cfg, const char *promp
 
     const struct llama_vocab *vocab = llama_model_get_vocab(model);
 
-    const char *context_prefix = "Context:\n";
-    const char *prompt_prefix = "\n\nQuestion:\n";
-    const char *answer_prefix = "\n\nAnswer:\n";
-    size_t context_len = context ? strlen(context) : 0;
-    size_t prompt_len = strlen(prompt);
-    size_t full_len = strlen(system_prefix) + strlen(context_prefix) + context_len + strlen(prompt_prefix) + prompt_len + strlen(answer_prefix) + 1;
-    char *full_prompt = (char *)malloc(full_len);
+    const char *system_text = system_prefix;
+    if (strncmp(system_prefix, "System:", 7) == 0) {
+        system_text = system_prefix + 7;
+        while (*system_text == ' ' || *system_text == '\n' || *system_text == '\r') {
+            system_text++;
+        }
+    }
+
+    char *full_prompt = build_prompt(cfg, system_text, context, prompt);
     if (full_prompt == NULL) {
-        log_message(stderr, LOG_ERROR, "Failed to allocate prompt buffer");
+        log_message(stderr, LOG_ERROR, "Failed to build prompt");
         free(system_prefix);
         llama_model_free(model);
         return 1;
     }
-    snprintf(full_prompt, full_len, "%s%s%s%s%s", system_prefix, context_prefix, context ? context : "", prompt_prefix, prompt);
-    strncat(full_prompt, answer_prefix, full_len - strlen(full_prompt) - 1);
 
     int n_prompt = -llama_tokenize(vocab, full_prompt, strlen(full_prompt), NULL, 0, true, true);
     llama_token *prompt_tokens = (llama_token *)malloc((size_t)n_prompt * sizeof(llama_token));
@@ -127,8 +174,21 @@ static int execute_prompt_with_context(const ModelConfig *cfg, const char *promp
 
     struct llama_sampler_chain_params sparams = llama_sampler_chain_default_params();
     struct llama_sampler *smpl = llama_sampler_chain_init(sparams);
+    if (cfg->top_k > 0) {
+        llama_sampler_chain_add(smpl, llama_sampler_init_top_k(cfg->top_k));
+    }
+    if (cfg->top_p > 0.0f && cfg->top_p < 1.0f) {
+        llama_sampler_chain_add(smpl, llama_sampler_init_top_p(cfg->top_p, 1));
+    }
+    if (cfg->min_p > 0.0f) {
+        llama_sampler_chain_add(smpl, llama_sampler_init_min_p(cfg->min_p, 1));
+    }
+    llama_sampler_chain_add(smpl, llama_sampler_init_penalties(
+        cfg->repeat_last_n,
+        cfg->repeat_penalty,
+        cfg->freq_penalty,
+        cfg->presence_penalty));
     llama_sampler_chain_add(smpl, llama_sampler_init_temp(cfg->temperature));
-    llama_sampler_chain_add(smpl, llama_sampler_init_min_p(cfg->min_p, 1));
     llama_sampler_chain_add(smpl, llama_sampler_init_dist(cfg->seed));
 
     struct llama_batch batch = llama_batch_get_one(prompt_tokens, n_prompt);
@@ -191,15 +251,12 @@ static int execute_prompt_with_context(const ModelConfig *cfg, const char *promp
             log_message(stderr, LOG_ERROR, "Failed to convert token to piece");
             break;
         }
-        int stop_at = n;
-        for (int i = 0; i < n; i++) {
-            if (buf[i] == '\n') {
-                stop_at = i;
-                break;
-            }
+        if (out_len == 0 && n > 0 && buf[0] == '\n') {
+            batch = llama_batch_get_one(&new_token_id, 1);
+            continue;
         }
-        if (out_len + (size_t)stop_at + 1 > out_cap) {
-            while (out_len + (size_t)stop_at + 1 > out_cap) {
+        if (out_len + (size_t)n + 1 > out_cap) {
+            while (out_len + (size_t)n + 1 > out_cap) {
                 out_cap *= 2;
             }
             char *next = (char *)realloc(out, out_cap);
@@ -209,14 +266,10 @@ static int execute_prompt_with_context(const ModelConfig *cfg, const char *promp
            }
            out = next;
        }
-        memcpy(out + out_len, buf, (size_t)stop_at);
-        out_len += (size_t)stop_at;
+        memcpy(out + out_len, buf, (size_t)n);
+        out_len += (size_t)n;
        out[out_len] = '\0';
 
-        if (stop_at != n) {
-            break;
-        }
-
        batch = llama_batch_get_one(&new_token_id, 1);
    }
 
@@ -241,13 +294,15 @@ int main(int argc, char **argv) {
     const char *prompt = NULL;
     const char *context_file = NULL;
     int verbose = 0;
+    const char *embed_model_name = NULL;
 
-    int n_predict = 64;
+    int n_predict = 0;
 
     static struct option long_options[] = {
         {"model", required_argument, 0, 'm'},
         {"prompt", required_argument, 0, 'p'},
         {"context", required_argument, 0, 'c'},
+        {"embed-model", required_argument, 0, 'e'},
         {"list", no_argument, 0, 'l'},
         {"verbose", no_argument, 0, 'v'},
         {"help", no_argument, 0, 'h'},
@@ -256,7 +311,7 @@ int main(int argc, char **argv) {
     int opt;
     int option_index = 0;
 
-    while ((opt = getopt_long(argc, argv, "m:p:c:lvh", long_options, &option_index)) != -1) {
+    while ((opt = getopt_long(argc, argv, "m:p:c:e:lvh", long_options, &option_index)) != -1) {
         switch (opt) {
         case 'm':
             model_name = optarg;
@@ -267,6 +322,9 @@ int main(int argc, char **argv) {
         case 'c':
             context_file = optarg;
             break;
+        case 'e':
+            embed_model_name = optarg;
+            break;
         case 'v':
             verbose = 1;
             break;
@@ -320,7 +378,29 @@ int main(int argc, char **argv) {
         cfg = &models[0];
     }
 
-    struct llama_model *model = llama_model_load_from_file(cfg->filepath, llama_model_default_params());
+    const ModelConfig *embed_cfg = NULL;
+    if (embed_model_name != NULL) {
+        embed_cfg = get_model_by_name(embed_model_name);
+        if (embed_cfg == NULL) {
+            log_message(stderr, LOG_ERROR, "Unknown embedding model '%s'", embed_model_name);
+            llama_backend_free();
+            return 1;
+        }
+    } else if (cfg->embed_model_name != NULL) {
+        embed_cfg = get_model_by_name(cfg->embed_model_name);
+    }
+    if (embed_cfg == NULL) {
+        embed_cfg = cfg;
+    }
+
+    if (n_predict <= 0) {
+        n_predict = cfg->n_predict > 0 ? cfg->n_predict : 128;
+    }
+
+    struct llama_model_params embed_params = llama_model_default_params();
+    embed_params.n_gpu_layers = embed_cfg->n_gpu_layers;
+    embed_params.use_mmap = embed_cfg->use_mmap;
+    struct llama_model *model = llama_model_load_from_file(embed_cfg->filepath, embed_params);
     if (model == NULL) {
         log_message(stderr, LOG_ERROR, "Unable to load embedding model");
         llama_backend_free();
@@ -328,6 +408,8 @@
     }
 
     struct llama_context_params cparams = llama_context_default_params();
+    cparams.n_ctx = embed_cfg->n_ctx;
+    cparams.n_batch = embed_cfg->n_batch;
     cparams.embeddings = true;
 
     struct llama_context *embed_ctx = llama_init_from_model(model, cparams);
@@ -350,10 +432,13 @@
     }
 
     float query[VDB_EMBED_SIZE];
-    int results[3];
+    int results[5];
+    for (int i = 0; i < 5; i++) {
+        results[i] = -1;
+    }
 
     vdb_embed_query(&db, prompt, query);
-    vdb_search(&db, query, 3, results);
+    vdb_search(&db, query, 5, results);
 
     size_t context_cap = 1024;
     size_t context_len = 0;
@@ -367,13 +452,15 @@
     }
     context[0] = '\0';
-    for (int i = 0; i < 3; i++) {
+    for (int i = 0; i < 5; i++) {
         if (results[i] < 0) {
             continue;
         }
         const char *text = db.docs[results[i]].text;
+        char header[32];
+        int header_len = snprintf(header, sizeof(header), "Snippet %d:\n", i + 1);
         size_t text_len = strlen(text);
-        size_t need = context_len + text_len + 2;
+        size_t need = context_len + (size_t)header_len + text_len + 2;
         if (need > context_cap) {
             while (need > context_cap) {
                 context_cap *= 2;
             }
@@ -389,6 +476,10 @@
             }
             context = next;
         }
+        if (header_len > 0) {
+            memcpy(context + context_len, header, (size_t)header_len);
+            context_len += (size_t)header_len;
+        }
         memcpy(context + context_len, text, text_len);
         context_len += text_len;
         context[context_len++] = '\n';
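
The changes above read several ModelConfig fields (the prompt style, the sampling knobs, context sizes, and the per-model embedding fallback) whose declaration lives in models.h and is not part of this diff. As a rough sketch of what that struct would need to expose for the new npc.c to compile, with types assumed from how each field is used above:

#include <stdbool.h>
#include <stdint.h>

/* Sketch only: models.h is not shown in this diff, so the real declaration may
 * differ. Field names are exactly the ones npc.c reads above; the types are
 * assumptions based on the llama.cpp calls they are passed to. */
typedef enum {
    PROMPT_STYLE_PLAIN,
    PROMPT_STYLE_CHAT,
    PROMPT_STYLE_T5,
} PromptStyle;

typedef struct {
    const char *filepath;         /* GGUF path passed to llama_model_load_from_file() */
    const char *embed_model_name; /* per-model default for the new -e/--embed-model flag */
    PromptStyle prompt_style;     /* selects the build_prompt() layout */

    /* sampler chain */
    int      top_k;
    float    top_p;
    float    min_p;
    float    temperature;
    int32_t  repeat_last_n;
    float    repeat_penalty;
    float    freq_penalty;
    float    presence_penalty;
    uint32_t seed;

    /* model / context setup and generation length */
    int  n_predict;    /* preferred generation length; main() falls back to 128 when unset */
    int  n_ctx;
    int  n_batch;
    int  n_gpu_layers;
    bool use_mmap;
} ModelConfig;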
