1#include <stdio.h>
2#include <string.h>
3#include <math.h>
4#include <stdint.h>
5#include <errno.h>
6
7#include "llama.h"
8#include "vectordb.h"
9
10// Returns cosine similarity in range [-1, 1] (approx).
11// https://en.wikipedia.org/wiki/Cosine_similarity
12static float cosine_similarity(float *a, float *b, int n) {
13 float dot = 0, norm_a = 0, norm_b = 0;
14 for (int i = 0; i < n; i++) {
15 dot += a[i] * b[i];
16 norm_a += a[i] * a[i];
17 norm_b += b[i] * b[i];
18 }
19 return dot / (sqrtf(norm_a) * sqrtf(norm_b) + 1e-8f);
20}
21
22static void embed_text(struct llama_context *ctx, const char *text, float *out) {
23 llama_token tokens[VDB_TOKENS];
24 const struct llama_model *model = llama_get_model(ctx);
25 const struct llama_vocab *vocab = llama_model_get_vocab(model);
26 int n_tokens = llama_tokenize(vocab, text, strlen(text), tokens, VDB_TOKENS, true, true);
27 if (n_tokens < 0) {
28 return;
29 }
30
31 struct llama_batch batch = llama_batch_get_one(tokens, n_tokens);
32 llama_decode(ctx, batch);
33
34 const float *emb = llama_get_embeddings(ctx);
35 memcpy(out, emb, sizeof(float) * VDB_EMBED_SIZE);
36
37}
38
39void vdb_init(VectorDB *db, struct llama_context *embed_ctx) {
40 memset(db, 0, sizeof(VectorDB));
41 db->embed_ctx = embed_ctx;
42}
43
44void vdb_free(VectorDB *db) {
45 (void)db;
46}
47
48void vdb_add_document(VectorDB *db, const char *text) {
49 if (db->count >= VDB_MAX_DOCS) {
50 printf("Vector database full\n");
51 return;
52 }
53
54 VectorDoc *doc = &db->docs[db->count++];
55 strncpy(doc->text, text, VDB_MAX_TEXT - 1);
56 doc->text[VDB_MAX_TEXT - 1] = 0;
57
58 printf("Embedding doc %d...\n", db->count);
59 embed_text(db->embed_ctx, text, doc->embedding);
60}
61
62void vdb_embed_query(VectorDB *db, const char *text, float *out_embedding) {
63 embed_text(db->embed_ctx, text, out_embedding);
64}
65
66void vdb_search(VectorDB *db, float *query, int top_k, int *results) {
67 float best_scores[top_k];
68 for (int i = 0; i < top_k; i++) {
69 best_scores[i] = -1.0f;
70 results[i] = -1;
71 }
72
73 for (int i = 0; i < db->count; i++) {
74 float score = cosine_similarity(query, db->docs[i].embedding, VDB_EMBED_SIZE);
75
76 for (int j = 0; j < top_k; j++) {
77 if (score > best_scores[j]) {
78 for (int k = top_k - 1; k > j; k--) {
79 best_scores[k] = best_scores[k - 1];
80 results[k] = results[k - 1];
81 }
82 best_scores[j] = score;
83 results[j] = i;
84 break;
85 }
86 }
87 }
88}
89
90VectorDBErrorCode vdb_save(const VectorDB *db, const char *path) {
91 FILE *fp = fopen(path, "wb");
92 if (!fp) {
93 return VDB_OPEN_ERR;
94 }
95
96 VdbFileHeader header = {
97 .magic = VDB_MAGIC,
98 .version = VDB_VERSION,
99 .embed_size = VDB_EMBED_SIZE,
100 .max_text = VDB_MAX_TEXT,
101 .count = (uint32_t)db->count,
102 };
103
104 if (fwrite(&header, sizeof(header), 1, fp) != 1) {
105 fclose(fp);
106 return VDB_HEADER_WRITE_ERR;
107 }
108
109 if (db->count > 0) {
110 size_t wrote = fwrite(db->docs, sizeof(VectorDoc), (size_t)db->count, fp);
111 if (wrote != (size_t)db->count) {
112 fclose(fp);
113 return VDB_DOC_WRITE_ERR;
114 }
115 }
116
117 if (fclose(fp) != 0) {
118 return VDB_CLOSE_ERR;
119 }
120
121 return VDB_SUCCESS;
122}
123
124VectorDBErrorCode vdb_load(VectorDB *db, const char *path) {
125 struct llama_context *ctx = db->embed_ctx;
126 FILE *fp = fopen(path, "rb");
127 if (!fp) {
128 int open_err = errno;
129 fprintf(stderr, "vdb_load: open failed: %s\n", strerror(open_err));
130 return VDB_OPEN_ERR;
131 }
132
133 VdbFileHeader header = {0};
134 if (fread(&header, sizeof(header), 1, fp) != 1) {
135 int read_err = errno;
136 fprintf(stderr, "vdb_load: header read failed: %s\n", strerror(read_err));
137 fclose(fp);
138 return VDB_HEADER_READ_ERR;
139 }
140
141 if (header.magic != VDB_MAGIC || header.version != VDB_VERSION) {
142 fclose(fp);
143 return VDB_MAGIC_MISMATCH_ERR;
144 }
145
146 if (header.embed_size != VDB_EMBED_SIZE || header.max_text != VDB_MAX_TEXT) {
147 fclose(fp);
148 return VDB_EMBED_MISMATCH_ERR;
149 }
150
151 if (header.count > VDB_MAX_DOCS) {
152 fclose(fp);
153 return VDB_COUNT_TOO_LARGE_ERR;
154 }
155
156 memset(db, 0, sizeof(VectorDB));
157 db->embed_ctx = ctx;
158 db->count = (int)header.count;
159
160 if (db->count > 0) {
161 size_t read = fread(db->docs, sizeof(VectorDoc), (size_t)db->count, fp);
162 if (read != (size_t)db->count) {
163 int read_err = errno;
164 fprintf(stderr, "vdb_load: doc read failed: %s\n", strerror(read_err));
165 fclose(fp);
166 return VDB_DOC_READ_ERR;
167 }
168 }
169
170 if (fclose(fp) != 0) {
171 int close_err = errno;
172 fprintf(stderr, "vdb_load: close failed: %s\n", strerror(close_err));
173 return VDB_CLOSE_ERR;
174 }
175
176 return VDB_SUCCESS;
177}
178
179const char* vdb_error(VectorDBErrorCode err) {
180 switch (err) {
181 case VDB_SUCCESS:
182 return "Success.";
183 case VDB_OPEN_ERR:
184 return "Failed to open file.";
185 case VDB_CLOSE_ERR:
186 return "Failed to close file.";
187 case VDB_HEADER_WRITE_ERR:
188 return "Failed to write header.";
189 case VDB_HEADER_READ_ERR:
190 return "Failed to read header.";
191 case VDB_MAGIC_MISMATCH_ERR:
192 return "Header magic/version mismatch.";
193 case VDB_EMBED_MISMATCH_ERR:
194 return "Header embed/max_text mismatch.";
195 case VDB_COUNT_TOO_LARGE_ERR:
196 return "Header count too large.";
197 case VDB_DOC_WRITE_ERR:
198 return "Failed to write documents.";
199 case VDB_DOC_READ_ERR:
200 return "Failed to read documents.";
201 default:
202 return "Unknown error.";
203 }
204}