/*
 * path: root/context.c
 * blob: 66b8cc272d41f50130ad77b00fe99ae0f0472fe0
 */
#include "llama.h"
#include "vectordb.h"
#include "models.h"

#define NONSTD_IMPLEMENTATION
#include "nonstd.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <getopt.h>

/*
 * Log sink that throws every message away.
 *
 * main() registers it through llama_log_set() when --verbose was not
 * requested. None of the arguments are inspected.
 */
static void llama_log_callback(enum ggml_log_level level, const char *text, void *user_data) {
	/* The void-casts mark each parameter as intentionally unused. */
	(void)text;
	(void)user_data;
	(void)level;
}

/*
 * Print the list of configured models to stdout.
 *
 * Writes a "Model list:" heading followed by one line per model showing its
 * name, n_ctx and temperature fields.
 *
 * The parameter list is spelled (void) so the definition also serves as a
 * prototype; an empty () leaves the parameters unspecified before C23.
 */
static void list_available_models(void) {
	printf("Model list:\n");
	ModelConfig model;
	/*
	 * NOTE(review): static_foreach is defined in nonstd.h (not visible here);
	 * presumably it visits every element of the models array and binds each
	 * one to `model` in turn -- confirm against the macro definition.
	 */
	static_foreach(ModelConfig, model, models) {
		printf(" - %s [ctx: %d, temp: %f]\n", model.name, model.n_ctx, model.temperature);
	}
}

/*
 * Write the command-line usage summary to stdout.
 *
 * prog: program name substituted into the "Usage:" line (main() passes
 *       argv[0]).
 */
static void show_help(const char *prog) {
	/* One entry per option, in the order they are displayed. */
	static const char *const option_lines[] = {
		"  -m, --model <name>    Specify model to use (default: first model)\n",
		"  -i, --in <file>       Specify input context file\n",
		"  -o, --out <file>      Specify output vector database file\n",
		"  -l, --list            Lists all available models\n",
		"  -v, --verbose         Enable verbose logging\n",
		"  -h, --help            Show this help message\n",
	};
	const size_t line_count = sizeof option_lines / sizeof option_lines[0];

	printf("Usage: %s [OPTIONS]\n", prog);
	fputs("Options:\n", stdout);
	for (size_t i = 0; i < line_count; i++) {
		fputs(option_lines[i], stdout);
	}
}

int main(int argc, char **argv) {
	set_log_level(LOG_DEBUG);

	const char *model_name = NULL;
	const char *in_file = NULL;
	const char *out_file = NULL;
	int list_models = 0;
	int verbose = 0;

	static struct option long_options[] = {
		{"model", required_argument, 0, 'm'},
		{"in", required_argument, 0, 'i'},
		{"out", required_argument, 0, 'o'},
		{"list", no_argument, 0, 'l'},
		{"verbose", no_argument, 0, 'v'},
		{"help", no_argument, 0, 'h'},
		{0, 0, 0, 0}
	};

	int opt;
	int option_index = 0;
	while ((opt = getopt_long(argc, argv, "m:i:o:lvh", long_options, &option_index)) != -1) {
		switch (opt) {
			case 'm':
				model_name = optarg;
				break;
			case 'i':
				in_file = optarg;
				break;
			case 'o':
				out_file = optarg;
				break;
			case 'l':
				list_models = 1;
				break;
			case 'v':
				verbose = 1;
				break;
			case 'h':
				show_help(argv[0]);
				return 0;
			default:
				fprintf(stderr, "Usage: %s [-m model] [-i file] [-o file] [-lvh]\n", argv[0]);
				return 1;
		}
	}

	if (verbose == 0) {
		llama_log_set(llama_log_callback, NULL);
	}

	if (list_models == 1) {
		list_available_models();
		return 0;
	}

	if (in_file == NULL) {
		log_message(stderr, LOG_ERROR, "Input context file must be provided. Exiting...");
		return 1;
	}

	if (out_file == NULL) {
		log_message(stderr, LOG_ERROR, "Output vector context file must be provided. Exiting...");
		return 1;
	}

	llama_backend_init();

	const ModelConfig *cfg = NULL;
	if (model_name != NULL) {
		cfg = get_model_by_name(model_name);
		if (cfg == NULL) {
			log_message(stderr, LOG_ERROR, "Unknown model '%s'", model_name);
			llama_backend_free();
			return 1;
		}
	} else {
		cfg = &models[0];
	}

	struct llama_model_params model_params = llama_model_default_params();
	model_params.n_gpu_layers = cfg->n_gpu_layers;
	model_params.use_mmap = cfg->use_mmap;
	struct llama_model *model = llama_model_load_from_file(cfg->filepath, model_params);
	if (model == NULL) {
		log_message(stderr, LOG_ERROR, "Unable to load embedding model");
		llama_backend_free();
		return 1;
	}

	struct llama_context_params cparams = llama_context_default_params();
	cparams.n_ctx = cfg->n_ctx;
	cparams.n_batch = cfg->n_batch;
	cparams.embeddings = true;

	struct llama_context *embed_ctx = llama_init_from_model(model, cparams);
	if (embed_ctx == NULL) {
		log_message(stderr, LOG_ERROR, "Failed to create embedding context");
		llama_model_free(model);
		llama_backend_free();
		return 1;
	}

	FILE *context_fp = fopen(in_file, "r");
	if (context_fp == NULL) {
		log_message(stderr, LOG_ERROR, "Unable to open context file %s", in_file);
		return 1;
	}

	VectorDB db;
	vdb_init(&db, embed_ctx);

	char line[1024];
	while (fgets(line, sizeof(line), context_fp) != NULL) {
		size_t len = strlen(line);
		while (len > 0 && (line[len - 1] == '\n' || line[len - 1] == '\r')) {
			line[len - 1] = '\0';
			len--;
		}
		if (len == 0) {
			continue;
		}
		vdb_add_document(&db, line);
	}

	VectorDBErrorCode vdb_rc = vdb_save(&db, out_file);
	if (vdb_rc != VDB_SUCCESS) {
		log_message(stderr, LOG_ERROR, "Something went wrong saving file %s: %s", out_file, vdb_error(vdb_rc));
		fclose(context_fp);
		return 1;
	}

	log_message(stdout, LOG_INFO, "Context vector database file %s successfully written", out_file);
	fclose(context_fp);
	return 0;
}