Update readme

Author Mitja Felicijan <mitja.felicijan@gmail.com> 2026-02-18 01:58:42 +0100
Committer Mitja Felicijan <mitja.felicijan@gmail.com> 2026-02-18 01:58:42 +0100
Commit d964db648a06a553a52c49aea82047e0d352dbd0 (patch)
-rw-r--r-- prompt.c 493
1 files changed, 0 insertions, 493 deletions
diff --git a/prompt.c b/prompt.c
1
#include "llama.h"
  
2
#include "vectordb.h"
  
3
#include "models.h"
  
4
  
  
5
#include <stdio.h>
  
6
#include <stdlib.h>
  
7
#include <string.h>
  
8
#include <getopt.h>
  
9
#include <ctype.h>
  
10
  
  
11
/*
 * Limits for the lexical overlap check:
 *   MAX_TOKENS    - maximum number of distinct words collected from a text.
 *   MAX_TOKEN_LEN - per-word buffer size, including the terminating NUL.
 */
enum {
    MAX_TOKENS = 512,
    MAX_TOKEN_LEN = 32
};
  
13
  
  
14
/*
 * Placeholder for engine state; nothing in this file uses it yet.
 * ISO C requires a struct to declare at least one member (an empty struct is a
 * GNU extension), so a reserved byte keeps this definition portable.
 */
typedef struct {
    char reserved;
} Engine;
  
17
  
  
18
/* Fixed answer printed when the prompt or the generated answer shares no content word with the context. */
static const char *const refusal_text = "I don't have that information.";
  
19
  
  
20
/*
 * Log sink that drops every message. main() installs it through llama_log_set()
 * when --verbose is not given.
 */
static void llama_log_callback(enum ggml_log_level level, const char *text, void *user_data) {
    /* Intentionally a no-op; the casts only silence unused-parameter warnings. */
    (void)text;
    (void)user_data;
    (void)level;
}
  
25
  
  
26
/*
 * Report whether the first `len` bytes of `token` spell exactly one of the
 * English stopwords listed below.
 *
 * Returns 1 on a match, 0 otherwise. Only `len` bytes of `token` are examined,
 * so `token` does not need to be NUL-terminated at `len`.
 */
static int is_stopword(const char *token, size_t len) {
    static const char *const stopwords[] = {
        "a", "an", "the", "is", "are", "was", "were", "of", "to", "in", "on",
        "for", "with", "and", "or", "not", "if", "then", "else", "from", "by",
        "as", "at", "it", "its", "this", "that", "these", "those", "who", "what",
        "when", "where", "why", "how", "which", "about", "into", "over", "under",
        "be", "been", "being", "do", "does", "did", "but", "so", "than",
        NULL /* sentinel */
    };

    for (const char *const *word = stopwords; *word != NULL; ++word) {
        /* Length check first, so a stopword that is merely a prefix never matches. */
        if (strlen(*word) != len) {
            continue;
        }
        if (strncmp(*word, token, len) == 0) {
            return 1;
        }
    }
    return 0;
}
  
41
  
  
42
/*
 * Return 1 if `token` equals (strcmp) any of the first `count` entries of
 * `tokens`, 0 otherwise. A non-positive `count` always yields 0.
 */
static int token_exists(char tokens[MAX_TOKENS][MAX_TOKEN_LEN], int count, const char *token) {
    int idx = count;
    /* Scan from the most recently added entry backwards; order does not affect the result. */
    while (idx-- > 0) {
        if (!strcmp(tokens[idx], token)) {
            return 1;
        }
    }
    return 0;
}
  
50
  
  
51
static int collect_tokens(const char *text, char tokens[MAX_TOKENS][MAX_TOKEN_LEN]) {
  
52
    int count = 0;
  
53
    char buf[MAX_TOKEN_LEN];
  
54
    int len = 0;
  
55
    for (const unsigned char *p = (const unsigned char *)text; ; p++) {
  
56
        if (isalnum(*p)) {
  
57
            if (len < MAX_TOKEN_LEN - 1) {
  
58
                buf[len++] = (char)tolower(*p);
  
59
            }
  
60
        } else {
  
61
            if (len > 0) {
  
62
                buf[len] = '\0';
  
63
                if (len >= 4 && !is_stopword(buf, (size_t)len)) {
  
64
                    if (!token_exists(tokens, count, buf) && count < MAX_TOKENS) {
  
65
                        memcpy(tokens[count], buf, (size_t)len + 1);
  
66
                        tokens[count][MAX_TOKEN_LEN - 1] = '\0';
  
67
                        count++;
  
68
                    }
  
69
                }
  
70
                len = 0;
  
71
            }
  
72
            if (*p == '\0') {
  
73
                break;
  
74
            }
  
75
        }
  
76
    }
  
77
    return count;
  
78
}
  
79
  
  
80
static int has_overlap(const char *a, const char *b) {
  
81
    if (a == NULL || b == NULL) {
  
82
        return 0;
  
83
    }
  
84
    char tokens[MAX_TOKENS][MAX_TOKEN_LEN];
  
85
    int token_count = collect_tokens(b, tokens);
  
86
    if (token_count == 0) {
  
87
        return 0;
  
88
    }
  
89
    char buf[MAX_TOKEN_LEN];
  
90
    int len = 0;
  
91
    for (const unsigned char *p = (const unsigned char *)a; ; p++) {
  
92
        if (isalnum(*p)) {
  
93
            if (len < MAX_TOKEN_LEN - 1) {
  
94
                buf[len++] = (char)tolower(*p);
  
95
            }
  
96
        } else {
  
97
            if (len > 0) {
  
98
                buf[len] = '\0';
  
99
                if (len >= 4 && !is_stopword(buf, (size_t)len)) {
  
100
                    if (token_exists(tokens, token_count, buf)) {
  
101
                        return 1;
  
102
                    }
  
103
                }
  
104
                len = 0;
  
105
            }
  
106
            if (*p == '\0') {
  
107
                break;
  
108
            }
  
109
        }
  
110
    }
  
111
    return 0;
  
112
}
  
113
  
  
114
static int execute_prompt(const char *model_name, const char *prompt, const char *context, int n_predict) {
  
115
    const model_config *cfg = NULL;
  
116
    if (model_name != NULL) {
  
117
        cfg = get_model_by_name(model_name);
  
118
        if (cfg == NULL) {
  
119
            fprintf(stderr, "Error: unknown model '%s'\n", model_name);
  
120
            return 1;
  
121
        }
  
122
    } else {
  
123
        cfg = &models[0];
  
124
    }
  
125
  
  
126
    if (!has_overlap(prompt, context)) {
  
127
        printf("------------ Prompt: %s\n", prompt);
  
128
        printf("------------ Response: %s\n", refusal_text);
  
129
        return 0;
  
130
    }
  
131
  
  
132
    ggml_backend_load_all();
  
133
  
  
134
    struct llama_model_params model_params = llama_model_default_params();
  
135
    model_params.n_gpu_layers = cfg->n_gpu_layers;
  
136
    model_params.use_mmap = cfg->use_mmap;
  
137
  
  
138
    struct llama_model *model = llama_model_load_from_file(cfg->filepath, model_params);
  
139
    if (model == NULL) {
  
140
        fprintf(stderr, "Error: unable to load model from %s\n", cfg->filepath);
  
141
        return 1;
  
142
    }
  
143
  
  
144
    const struct llama_vocab *vocab = llama_model_get_vocab(model);
  
145
  
  
146
    const char *system_prefix = "System: Answer using only the Context. If the answer is not explicitly stated in Context, respond exactly: I don't have that information.\n\n";
  
147
    const char *context_prefix = "Context:\n";
  
148
    const char *prompt_prefix = "\n\nQuestion:\n";
  
149
    const char *answer_prefix = "\n\nAnswer:\n";
  
150
    size_t context_len = context ? strlen(context) : 0;
  
151
    size_t prompt_len = strlen(prompt);
  
152
    size_t full_len = strlen(system_prefix) + strlen(context_prefix) + context_len + strlen(prompt_prefix) + prompt_len + strlen(answer_prefix) + 1;
  
153
    char *full_prompt = (char *)malloc(full_len);
  
154
    if (full_prompt == NULL) {
  
155
        fprintf(stderr, "Error: failed to allocate prompt buffer\n");
  
156
        llama_model_free(model);
  
157
        return 1;
  
158
    }
  
159
    snprintf(full_prompt, full_len, "%s%s%s%s%s", system_prefix, context_prefix, context ? context : "", prompt_prefix, prompt);
  
160
    strncat(full_prompt, answer_prefix, full_len - strlen(full_prompt) - 1);
  
161
  
  
162
    int n_prompt = -llama_tokenize(vocab, full_prompt, strlen(full_prompt), NULL, 0, true, true);
  
163
    llama_token *prompt_tokens = (llama_token *)malloc(n_prompt * sizeof(llama_token));
  
164
    if (llama_tokenize(vocab, full_prompt, strlen(full_prompt), prompt_tokens, n_prompt, true, true) < 0) {
  
165
        fprintf(stderr, "Error: failed to tokenize the prompt\n");
  
166
        free(full_prompt);
  
167
        free(prompt_tokens);
  
168
        llama_model_free(model);
  
169
        return 1;
  
170
    }
  
171
  
  
172
    struct llama_context_params ctx_params = llama_context_default_params();
  
173
    ctx_params.n_ctx = cfg->n_ctx;
  
174
    ctx_params.n_batch = cfg->n_batch;
  
175
    ctx_params.embeddings = cfg->embeddings;
  
176
  
  
177
    struct llama_context *ctx = llama_init_from_model(model, ctx_params);
  
178
    if (ctx == NULL) {
  
179
        fprintf(stderr, "Error: failed to create the llama_context\n");
  
180
        free(full_prompt);
  
181
        free(prompt_tokens);
  
182
        llama_model_free(model);
  
183
        return 1;
  
184
    }
  
185
  
  
186
    struct llama_sampler_chain_params sparams = llama_sampler_chain_default_params();
  
187
    struct llama_sampler *smpl = llama_sampler_chain_init(sparams);
  
188
    llama_sampler_chain_add(smpl, llama_sampler_init_temp(cfg->temperature));
  
189
    llama_sampler_chain_add(smpl, llama_sampler_init_min_p(cfg->min_p, 1));
  
190
    llama_sampler_chain_add(smpl, llama_sampler_init_dist(cfg->seed));
  
191
  
  
192
    struct llama_batch batch = llama_batch_get_one(prompt_tokens, n_prompt);
  
193
  
  
194
    if (llama_model_has_encoder(model)) {
  
195
        if (llama_encode(ctx, batch)) {
  
196
            fprintf(stderr, "Error: failed to encode prompt\n");
  
197
            llama_sampler_free(smpl);
  
198
            free(full_prompt);
  
199
            free(prompt_tokens);
  
200
            llama_free(ctx);
  
201
            llama_model_free(model);
  
202
            return 1;
  
203
        }
  
204
  
  
205
        llama_token decoder_start = llama_model_decoder_start_token(model);
  
206
        if (decoder_start == LLAMA_TOKEN_NULL) {
  
207
            decoder_start = llama_vocab_bos(vocab);
  
208
        }
  
209
        batch = llama_batch_get_one(&decoder_start, 1);
  
210
    }
  
211
  
  
212
    printf("------------ Prompt: %s\n", prompt);
  
213
    printf("------------ Response: ");
  
214
    fflush(stdout);
  
215
  
  
216
    int n_pos = 0;
  
217
    llama_token new_token_id;
  
218
    size_t out_cap = 256;
  
219
    size_t out_len = 0;
  
220
    char *out = (char *)malloc(out_cap);
  
221
    if (out == NULL) {
  
222
        fprintf(stderr, "Error: failed to allocate output buffer\n");
  
223
        free(full_prompt);
  
224
        free(prompt_tokens);
  
225
        llama_sampler_free(smpl);
  
226
        llama_free(ctx);
  
227
        llama_model_free(model);
  
228
        return 1;
  
229
    }
  
230
    out[0] = '\0';
  
231
  
  
232
    while (n_pos + batch.n_tokens < n_prompt + n_predict) {
  
233
        if (llama_decode(ctx, batch)) {
  
234
            fprintf(stderr, "Error: failed to decode\n");
  
235
            break;
  
236
        }
  
237
  
  
238
        n_pos += batch.n_tokens;
  
239
  
  
240
        new_token_id = llama_sampler_sample(smpl, ctx, -1);
  
241
  
  
242
        if (llama_vocab_is_eog(vocab, new_token_id)) {
  
243
            break;
  
244
        }
  
245
  
  
246
        char buf[128];
  
247
        int n = llama_token_to_piece(vocab, new_token_id, buf, sizeof(buf), 0, true);
  
248
        if (n < 0) {
  
249
            fprintf(stderr, "Error: failed to convert token to piece\n");
  
250
            break;
  
251
        }
  
252
        int stop_at = n;
  
253
        for (int i = 0; i < n; i++) {
  
254
            if (buf[i] == '\n') {
  
255
                stop_at = i;
  
256
                break;
  
257
            }
  
258
        }
  
259
        if (out_len + (size_t)stop_at + 1 > out_cap) {
  
260
            while (out_len + (size_t)stop_at + 1 > out_cap) {
  
261
                out_cap *= 2;
  
262
            }
  
263
            char *next = (char *)realloc(out, out_cap);
  
264
            if (next == NULL) {
  
265
                fprintf(stderr, "Error: failed to grow output buffer\n");
  
266
                break;
  
267
            }
  
268
            out = next;
  
269
        }
  
270
        memcpy(out + out_len, buf, (size_t)stop_at);
  
271
        out_len += (size_t)stop_at;
  
272
        out[out_len] = '\0';
  
273
  
  
274
        if (stop_at != n) {
  
275
            break;
  
276
        }
  
277
  
  
278
        batch = llama_batch_get_one(&new_token_id, 1);
  
279
    }
  
280
  
  
281
    if (!has_overlap(out, context)) {
  
282
        strcpy(out, refusal_text);
  
283
        out_len = strlen(out);
  
284
    }
  
285
  
  
286
    printf("%s\n", out);
  
287
  
  
288
    free(full_prompt);
  
289
    free(prompt_tokens);
  
290
    free(out);
  
291
    llama_sampler_free(smpl);
  
292
    llama_free(ctx);
  
293
    llama_model_free(model);
  
294
  
  
295
    return 0;
  
296
}
  
297
  
  
298
static char *generate_context(const char *model_name, const char *context_file, const char *prompt) {
  
299
    FILE *context_fp = fopen(context_file, "r");
  
300
    if (context_fp == NULL) {
  
301
        fprintf(stderr, "Error: unable to open context file %s\n", context_file);
  
302
        return NULL;
  
303
    }
  
304
  
  
305
    llama_backend_init();
  
306
  
  
307
    const model_config *cfg = NULL;
  
308
    if (model_name != NULL) {
  
309
        cfg = get_model_by_name(model_name);
  
310
        if (cfg == NULL) {
  
311
            fprintf(stderr, "Error: unknown model '%s'\n", model_name);
  
312
            fclose(context_fp);
  
313
            llama_backend_free();
  
314
            return NULL;
  
315
        }
  
316
    } else {
  
317
        cfg = &models[0];
  
318
    }
  
319
  
  
320
    /* struct llama_model *model = llama_load_model_from_file(cfg->filepath, llama_model_default_params()); */
  
321
    struct llama_model *model = llama_model_load_from_file(cfg->filepath, llama_model_default_params());
  
322
    if (model == NULL) {
  
323
        fprintf(stderr, "Error: unable to load embedding model\n");
  
324
        fclose(context_fp);
  
325
        llama_backend_free();
  
326
        return NULL;
  
327
    }
  
328
  
  
329
    struct llama_context_params cparams = llama_context_default_params();
  
330
    cparams.embeddings = true;
  
331
  
  
332
    /* struct llama_context *embed_ctx = llama_new_context_with_model(model, cparams); */
  
333
    struct llama_context *embed_ctx = llama_init_from_model(model, cparams);
  
334
    if (embed_ctx == NULL) {
  
335
        fprintf(stderr, "Error: failed to create embedding context\n");
  
336
        llama_model_free(model);
  
337
        fclose(context_fp);
  
338
        llama_backend_free();
  
339
        return NULL;
  
340
    }
  
341
  
  
342
    VectorDB db;
  
343
    vdb_init(&db, embed_ctx);
  
344
  
  
345
    char line[1024];
  
346
    while (fgets(line, sizeof(line), context_fp) != NULL) {
  
347
        size_t len = strlen(line);
  
348
        while (len > 0 && (line[len - 1] == '\n' || line[len - 1] == '\r')) {
  
349
            line[len - 1] = '\0';
  
350
            len--;
  
351
        }
  
352
        if (len == 0) {
  
353
            continue;
  
354
        }
  
355
        vdb_add_document(&db, line);
  
356
    }
  
357
  
  
358
    float query[VDB_EMBED_SIZE];
  
359
    int results[3];
  
360
  
  
361
    vdb_embed_query(&db, prompt, query);
  
362
    vdb_search(&db, query, 3, results);
  
363
  
  
364
    size_t context_cap = 1024;
  
365
    size_t context_len = 0;
  
366
    char *context = (char *)malloc(context_cap);
  
367
    if (context == NULL) {
  
368
        fprintf(stderr, "Error: failed to allocate context buffer\n");
  
369
        fclose(context_fp);
  
370
        llama_free(embed_ctx);
  
371
        llama_model_free(model);
  
372
        llama_backend_free();
  
373
        return NULL;
  
374
    }
  
375
    context[0] = '\0';
  
376
  
  
377
    for (int i = 0; i < 3; i++) {
  
378
        if (results[i] < 0) {
  
379
            continue;
  
380
        }
  
381
        const char *text = db.docs[results[i]].text;
  
382
        size_t text_len = strlen(text);
  
383
        size_t need = context_len + text_len + 2;
  
384
        if (need > context_cap) {
  
385
            while (need > context_cap) {
  
386
                context_cap *= 2;
  
387
            }
  
388
            char *next = (char *)realloc(context, context_cap);
  
389
            if (next == NULL) {
  
390
                fprintf(stderr, "Error: failed to grow context buffer\n");
  
391
                free(context);
  
392
                fclose(context_fp);
  
393
                llama_free(embed_ctx);
  
394
                llama_model_free(model);
  
395
                llama_backend_free();
  
396
                return NULL;
  
397
            }
  
398
            context = next;
  
399
        }
  
400
        memcpy(context + context_len, text, text_len);
  
401
        context_len += text_len;
  
402
        context[context_len++] = '\n';
  
403
        context[context_len] = '\0';
  
404
    }
  
405
  
  
406
    fclose(context_fp);
  
407
    llama_free(embed_ctx);
  
408
    llama_model_free(model);
  
409
    llama_backend_free();
  
410
  
  
411
    return context;
  
412
}
  
413
  
  
414
/*
 * Print command-line usage to stdout. `prog` is the program name shown in the
 * usage line. A prompt and a context file are both required by main().
 */
static void show_help(const char *prog) {
    printf("Usage: %s [OPTIONS]\n", prog);
    printf("Options:\n");
    printf("  -m, --model <name>    Specify model to use (default: first model)\n");
    printf("  -p, --prompt <text>   Specify prompt text (required)\n");
    printf("  -b, --build <file>    Specify context file\n");
    printf("  -c, --context <file>  Specify context file (required)\n");
    printf("  -v, --verbose         Enable verbose logging\n");
    printf("  -h, --help            Show this help message\n");
}
  
424
  
  
425
int main(int argc, char **argv) {
  
426
	/* Engine engine = {}; */
  
427
  
  
428
  
  
429
    const char *model_name = NULL;
  
430
    const char *prompt = NULL;
  
431
    const char *context_file = NULL;
  
432
	int verbose = 0;
  
433
    
  
434
    int n_predict = 64;
  
435
  
  
436
    static struct option long_options[] = {
  
437
        {"model", required_argument, 0, 'm'},
  
438
        {"prompt", required_argument, 0, 'p'},
  
439
        {"context", required_argument, 0, 'c'},
  
440
        {"build", required_argument, 0, 'b'},
  
441
        {"verbose", no_argument, 0, 'v'},
  
442
        {"help", no_argument, 0, 'h'},
  
443
        {0, 0, 0, 0}
  
444
    };
  
445
  
  
446
    int opt;
  
447
    int option_index = 0;
  
448
    while ((opt = getopt_long(argc, argv, "m:p:c:vh", long_options, &option_index)) != -1) {
  
449
        switch (opt) {
  
450
            case 'm':
  
451
                model_name = optarg;
  
452
                break;
  
453
            case 'p':
  
454
                prompt = optarg;
  
455
                break;
  
456
            case 'c':
  
457
                context_file = optarg;
  
458
                break;
  
459
            case 'v':
  
460
                verbose = 1;
  
461
                break;
  
462
            case 'h':
  
463
                show_help(argv[0]);
  
464
                return 0;
  
465
            default:
  
466
                fprintf(stderr, "Usage: %s [-m model] [-p prompt] [-h]\n", argv[0]);
  
467
                return 1;
  
468
        }
  
469
    }
  
470
  
  
471
	if (verbose == 0) {
  
472
		llama_log_set(llama_log_callback, NULL);
  
473
	}
  
474
  
  
475
    if (prompt == NULL) {
  
476
		printf("Prompt must be provided. Exiting...");
  
477
		return 1;
  
478
    }
  
479
  
  
480
    if (context_file == NULL) {
  
481
		printf("Context file must be provided. Exiting...");
  
482
		return 1;
  
483
    }
  
484
  
  
485
    char *context = generate_context(model_name, context_file, prompt);
  
486
    if (context == NULL) {
  
487
        return 1;
  
488
    }
  
489
  
  
490
    int rc = execute_prompt(model_name, prompt, context, n_predict);
  
491
    free(context);
  
492
    return rc;
  
493
}