1#include <stdio.h>
  2#include <string.h>
  3#include <math.h>
  4#include <stdint.h>
  5#include <errno.h>
  6
  7#include "llama.h"
  8#include "vectordb.h"
  9
 10// Returns cosine similarity in range [-1, 1] (approx).
 11// https://en.wikipedia.org/wiki/Cosine_similarity
 12static float cosine_similarity(float *a, float *b, int n) {
 13	float dot = 0, norm_a = 0, norm_b = 0;
 14	for (int i = 0; i < n; i++) {
 15		dot += a[i] * b[i];
 16		norm_a += a[i] * a[i];
 17		norm_b += b[i] * b[i];
 18	}
 19	return dot / (sqrtf(norm_a) * sqrtf(norm_b) + 1e-8f);
 20}
 21
 22static void embed_text(struct llama_context *ctx, const char *text, float *out) {
 23	llama_token tokens[VDB_TOKENS];
 24	const struct llama_model *model = llama_get_model(ctx);
 25	const struct llama_vocab *vocab = llama_model_get_vocab(model);
 26	int n_tokens = llama_tokenize(vocab, text, strlen(text), tokens, VDB_TOKENS, true, true);
 27	if (n_tokens < 0) {
 28		return;
 29	}
 30
 31	struct llama_batch batch = llama_batch_get_one(tokens, n_tokens);
 32	llama_decode(ctx, batch);
 33
 34	const float *emb = llama_get_embeddings(ctx);
 35	memcpy(out, emb, sizeof(float) * VDB_EMBED_SIZE);
 36
 37}
 38
 39void vdb_init(VectorDB *db, struct llama_context *embed_ctx) {
 40	memset(db, 0, sizeof(VectorDB));
 41	db->embed_ctx = embed_ctx;
 42}
 43
 44void vdb_free(VectorDB *db) {
 45	(void)db;
 46}
 47
 48void vdb_add_document(VectorDB *db, const char *text) {
 49	if (db->count >= VDB_MAX_DOCS) {
 50		printf("Vector database full\n");
 51		return;
 52	}
 53
 54	VectorDoc *doc = &db->docs[db->count++];
 55	strncpy(doc->text, text, VDB_MAX_TEXT - 1);
 56	doc->text[VDB_MAX_TEXT - 1] = 0;
 57
 58	printf("Embedding doc %d...\n", db->count);
 59	embed_text(db->embed_ctx, text, doc->embedding);
 60}
 61
 62void vdb_embed_query(VectorDB *db, const char *text, float *out_embedding) {
 63	embed_text(db->embed_ctx, text, out_embedding);
 64}
 65
 66void vdb_search(VectorDB *db, float *query, int top_k, int *results) {
 67	float best_scores[top_k];
 68	for (int i = 0; i < top_k; i++) {
 69		best_scores[i] = -1.0f;
 70		results[i] = -1;
 71	}
 72
 73	for (int i = 0; i < db->count; i++) {
 74		float score = cosine_similarity(query, db->docs[i].embedding, VDB_EMBED_SIZE);
 75
 76		for (int j = 0; j < top_k; j++) {
 77			if (score > best_scores[j]) {
 78				for (int k = top_k - 1; k > j; k--) {
 79					best_scores[k] = best_scores[k - 1];
 80					results[k] = results[k - 1];
 81				}
 82				best_scores[j] = score;
 83				results[j] = i;
 84				break;
 85			}
 86		}
 87	}
 88}
 89
 90VectorDBErrorCode vdb_save(const VectorDB *db, const char *path) {
 91	FILE *fp = fopen(path, "wb");
 92	if (!fp) {
 93		return VDB_OPEN_ERR;
 94	}
 95
 96	VdbFileHeader header = {
 97		.magic = VDB_MAGIC,
 98		.version = VDB_VERSION,
 99		.embed_size = VDB_EMBED_SIZE,
100		.max_text = VDB_MAX_TEXT,
101		.count = (uint32_t)db->count,
102	};
103
104	if (fwrite(&header, sizeof(header), 1, fp) != 1) {
105		fclose(fp);
106		return VDB_HEADER_WRITE_ERR;
107	}
108
109	if (db->count > 0) {
110		size_t wrote = fwrite(db->docs, sizeof(VectorDoc), (size_t)db->count, fp);
111		if (wrote != (size_t)db->count) {
112			fclose(fp);
113			return VDB_DOC_WRITE_ERR;
114		}
115	}
116
117	if (fclose(fp) != 0) {
118		return VDB_CLOSE_ERR;
119	}
120
121	return VDB_SUCCESS;
122}
123
124VectorDBErrorCode vdb_load(VectorDB *db, const char *path) {
125	struct llama_context *ctx = db->embed_ctx;
126	FILE *fp = fopen(path, "rb");
127	if (!fp) {
128		int open_err = errno;
129		fprintf(stderr, "vdb_load: open failed: %s\n", strerror(open_err));
130		return VDB_OPEN_ERR;
131	}
132
133	VdbFileHeader header = {0};
134	if (fread(&header, sizeof(header), 1, fp) != 1) {
135		int read_err = errno;
136		fprintf(stderr, "vdb_load: header read failed: %s\n", strerror(read_err));
137		fclose(fp);
138		return VDB_HEADER_READ_ERR;
139	}
140
141	if (header.magic != VDB_MAGIC || header.version != VDB_VERSION) {
142		fclose(fp);
143		return VDB_MAGIC_MISMATCH_ERR;
144	}
145
146	if (header.embed_size != VDB_EMBED_SIZE || header.max_text != VDB_MAX_TEXT) {
147		fclose(fp);
148		return VDB_EMBED_MISMATCH_ERR;
149	}
150
151	if (header.count > VDB_MAX_DOCS) {
152		fclose(fp);
153		return VDB_COUNT_TOO_LARGE_ERR;
154	}
155
156	memset(db, 0, sizeof(VectorDB));
157	db->embed_ctx = ctx;
158	db->count = (int)header.count;
159
160	if (db->count > 0) {
161		size_t read = fread(db->docs, sizeof(VectorDoc), (size_t)db->count, fp);
162		if (read != (size_t)db->count) {
163			int read_err = errno;
164			fprintf(stderr, "vdb_load: doc read failed: %s\n", strerror(read_err));
165			fclose(fp);
166			return VDB_DOC_READ_ERR;
167		}
168	}
169
170	if (fclose(fp) != 0) {
171		int close_err = errno;
172		fprintf(stderr, "vdb_load: close failed: %s\n", strerror(close_err));
173		return VDB_CLOSE_ERR;
174	}
175
176	return VDB_SUCCESS;
177}
178
179const char* vdb_error(VectorDBErrorCode err) {
180	switch (err) {
181		case VDB_SUCCESS:
182			return "Success.";
183		case VDB_OPEN_ERR:
184			return "Failed to open file.";
185		case VDB_CLOSE_ERR:
186			return "Failed to close file.";
187		case VDB_HEADER_WRITE_ERR:
188			return "Failed to write header.";
189		case VDB_HEADER_READ_ERR:
190			return "Failed to read header.";
191		case VDB_MAGIC_MISMATCH_ERR:
192			return "Header magic/version mismatch.";
193		case VDB_EMBED_MISMATCH_ERR:
194			return "Header embed/max_text mismatch.";
195		case VDB_COUNT_TOO_LARGE_ERR:
196			return "Header count too large.";
197		case VDB_DOC_WRITE_ERR:
198			return "Failed to write documents.";
199		case VDB_DOC_READ_ERR:
200			return "Failed to read documents.";
201		default:
202			return "Unknown error.";
203	}
204}