1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
|
#include <stdio.h>
#include <string.h>
#include <math.h>
#include <stdint.h>
#include <errno.h>
#include "llama.h"
#include "vectordb.h"
// Returns cosine similarity in range [-1, 1] (approx).
// https://en.wikipedia.org/wiki/Cosine_similarity
static float cosine_similarity(float *a, float *b, int n) {
float dot = 0, norm_a = 0, norm_b = 0;
for (int i = 0; i < n; i++) {
dot += a[i] * b[i];
norm_a += a[i] * a[i];
norm_b += b[i] * b[i];
}
return dot / (sqrtf(norm_a) * sqrtf(norm_b) + 1e-8f);
}
// Compute the embedding of `text` using `ctx` and store it in `out`.
//
// ctx:  llama context used for tokenizing (via its model's vocab) and decoding.
// text: NUL-terminated input string.
// out:  destination with room for VDB_EMBED_SIZE floats.
//
// The text is tokenized into at most VDB_TOKENS tokens, decoded as a single
// batch, and VDB_EMBED_SIZE floats from llama_get_embeddings() are copied to
// `out`.
//
// The function has no status return, so on any failure it prints a message to
// stderr and leaves `out` filled with zeros instead of leaving it unwritten or
// copying whatever embedding the context produced on a previous call.
//
// NOTE(review): assumes the context's embedding output holds at least
// VDB_EMBED_SIZE floats — confirm against the model configuration.
static void embed_text(struct llama_context *ctx, const char *text, float *out) {
    // Give `out` a defined value up front; every early return below keeps it.
    memset(out, 0, sizeof(float) * VDB_EMBED_SIZE);

    llama_token tokens[VDB_TOKENS];
    const struct llama_model *model = llama_get_model(ctx);
    const struct llama_vocab *vocab = llama_model_get_vocab(model);

    int n_tokens = llama_tokenize(vocab, text, strlen(text), tokens, VDB_TOKENS, true, true);
    // llama.h documents a negative return as failure (e.g. the text needs more
    // than VDB_TOKENS tokens). Zero tokens cannot be decoded either.
    if (n_tokens <= 0) {
        fprintf(stderr, "embed_text: tokenization failed or produced no tokens (%d)\n", n_tokens);
        return;
    }

    struct llama_batch batch = llama_batch_get_one(tokens, n_tokens);
    int decode_rc = llama_decode(ctx, batch);
    if (decode_rc != 0) {
        // Without a successful decode the context may still hold an older
        // embedding; do not copy it.
        fprintf(stderr, "embed_text: llama_decode failed (%d)\n", decode_rc);
        return;
    }

    const float *emb = llama_get_embeddings(ctx);
    if (emb == NULL) {
        fprintf(stderr, "embed_text: no embeddings available from context\n");
        return;
    }

    memcpy(out, emb, sizeof(float) * VDB_EMBED_SIZE);
}
// Reset `db` to an empty database that embeds text with `embed_ctx`.
//
// db:        database to initialise; its previous contents are discarded.
// embed_ctx: llama context stored in db->embed_ctx for later embedding calls.
void vdb_init(VectorDB *db, struct llama_context *embed_ctx) {
    // Zero the whole structure first so the count and every document slot
    // start out cleared, then record the embedding context.
    memset(db, 0, sizeof *db);
    db->embed_ctx = embed_ctx;
}
// Release resources held by `db`.
//
// Currently a no-op: the body only discards its argument. In particular the
// llama context stored in db->embed_ctx is not freed here.
void vdb_free(VectorDB *db) {
(void)db;
}
// Append a document to the database and compute its embedding.
//
// db:   database to append to.
// text: NUL-terminated document text. At most VDB_MAX_TEXT - 1 bytes are
//       stored; the embedding is computed from the full, untruncated `text`.
//
// If the database already holds VDB_MAX_DOCS documents, a message is printed
// to stdout and nothing is added.
void vdb_add_document(VectorDB *db, const char *text) {
    if (db->count >= VDB_MAX_DOCS) {
        printf("Vector database full\n");
        return;
    }

    // Claim the next free slot. The count is bumped before the progress
    // message so the message shows the 1-based document number.
    VectorDoc *slot = &db->docs[db->count];
    db->count += 1;

    // strncpy does not terminate on truncation, so terminate explicitly.
    strncpy(slot->text, text, VDB_MAX_TEXT - 1);
    slot->text[VDB_MAX_TEXT - 1] = '\0';

    printf("Embedding doc %d...\n", db->count);
    embed_text(db->embed_ctx, text, slot->embedding);
}
// Compute the embedding of a query string with the database's llama context.
//
// db:            database whose embed_ctx is used.
// text:          NUL-terminated query text.
// out_embedding: destination with room for VDB_EMBED_SIZE floats.
//
// embed_text() returns no status, so this function cannot report whether the
// embedding succeeded.
void vdb_embed_query(VectorDB *db, const char *text, float *out_embedding) {
    struct llama_context *ctx = db->embed_ctx;
    embed_text(ctx, text, out_embedding);
}
// Find the documents most similar to `query` by cosine similarity.
//
// db:      database to search.
// query:   query embedding with VDB_EMBED_SIZE floats.
// top_k:   number of result slots requested.
// results: output array with room for top_k ints. On return it holds document
//          indices ordered from most to least similar; unused slots (fewer
//          than top_k documents available) are set to -1.
//
// Ties keep the document with the lower index first because a later document
// only displaces an entry when its score is strictly greater. A document whose
// score is NaN is never selected.
//
// If top_k <= 0 the function returns without touching `results`.
void vdb_search(VectorDB *db, float *query, int top_k, int *results) {
    // A zero or negative-length VLA is undefined behaviour; nothing to do anyway.
    if (top_k <= 0) {
        return;
    }

    for (int i = 0; i < top_k; i++) {
        results[i] = -1;
    }

    // At most db->count slots can ever be filled, so the score table does not
    // need to be larger than that. This also bounds the stack usage of the VLA
    // by the number of stored documents instead of the caller's top_k.
    const int k = top_k < db->count ? top_k : db->count;
    if (k <= 0) {
        return;
    }

    float best_scores[k];
    for (int i = 0; i < k; i++) {
        // -INFINITY (not -1.0f) so that a document scoring exactly -1 still
        // beats an empty slot.
        best_scores[i] = -INFINITY;
    }

    for (int i = 0; i < db->count; i++) {
        const float score = cosine_similarity(query, db->docs[i].embedding, VDB_EMBED_SIZE);

        // Insertion into the sorted (descending) table of the best k scores.
        for (int j = 0; j < k; j++) {
            if (score > best_scores[j]) {
                // Shift lower-ranked entries down one place; the last one falls off.
                for (int m = k - 1; m > j; m--) {
                    best_scores[m] = best_scores[m - 1];
                    results[m] = results[m - 1];
                }
                best_scores[j] = score;
                results[j] = i;
                break;
            }
        }
    }
}
// Write the database to `path` in binary form.
//
// The file consists of a VdbFileHeader (magic, version, embed size, max text
// length, document count) followed by the first db->count VectorDoc records
// written as raw structs.
//
// Returns VDB_SUCCESS, or the error code of the first step that failed:
// VDB_OPEN_ERR, VDB_HEADER_WRITE_ERR, VDB_DOC_WRITE_ERR or VDB_CLOSE_ERR.
VectorDBErrorCode vdb_save(const VectorDB *db, const char *path) {
    FILE *out = fopen(path, "wb");
    if (out == NULL) {
        return VDB_OPEN_ERR;
    }

    VdbFileHeader header = {
        .magic = VDB_MAGIC,
        .version = VDB_VERSION,
        .embed_size = VDB_EMBED_SIZE,
        .max_text = VDB_MAX_TEXT,
        .count = (uint32_t)db->count,
    };

    VectorDBErrorCode status = VDB_SUCCESS;
    const size_t n_docs = (size_t)db->count;

    if (fwrite(&header, sizeof(header), 1, out) != 1) {
        status = VDB_HEADER_WRITE_ERR;
    } else if (db->count > 0 && fwrite(db->docs, sizeof(VectorDoc), n_docs, out) != n_docs) {
        status = VDB_DOC_WRITE_ERR;
    }

    // The stream is always closed; a close failure is only reported when the
    // writes themselves succeeded, so the earlier error code takes priority.
    const int close_rc = fclose(out);
    if (status == VDB_SUCCESS && close_rc != 0) {
        status = VDB_CLOSE_ERR;
    }
    return status;
}
// Report on stderr why an fread() in vdb_load came up short: a stream error
// (with the errno text captured by the caller) or a file that ended early.
static void vdb_log_short_read(FILE *fp, const char *what, int saved_errno) {
    if (ferror(fp)) {
        fprintf(stderr, "vdb_load: %s read failed: %s\n", what, strerror(saved_errno));
    } else {
        // A short read at end-of-file is not an I/O error, so errno is stale
        // here and must not be printed.
        fprintf(stderr, "vdb_load: %s read failed: unexpected end of file\n", what);
    }
}

// Load a database previously written by vdb_save() from `path` into `db`.
//
// db:   destination. Its embed_ctx is preserved; everything else is replaced
//       once the header has been validated. If validation fails, `db` is left
//       untouched.
// path: file to read.
//
// The header must match VDB_MAGIC, VDB_VERSION, VDB_EMBED_SIZE and
// VDB_MAX_TEXT, and its count must not exceed VDB_MAX_DOCS.
//
// Returns VDB_SUCCESS, or VDB_OPEN_ERR, VDB_HEADER_READ_ERR,
// VDB_MAGIC_MISMATCH_ERR, VDB_EMBED_MISMATCH_ERR, VDB_COUNT_TOO_LARGE_ERR,
// VDB_DOC_READ_ERR or VDB_CLOSE_ERR. On VDB_DOC_READ_ERR the database is left
// empty (count == 0) rather than claiming documents that were not read.
VectorDBErrorCode vdb_load(VectorDB *db, const char *path) {
    struct llama_context *ctx = db->embed_ctx;

    FILE *fp = fopen(path, "rb");
    if (!fp) {
        int open_err = errno;
        fprintf(stderr, "vdb_load: open failed: %s\n", strerror(open_err));
        return VDB_OPEN_ERR;
    }

    VdbFileHeader header = {0};
    if (fread(&header, sizeof(header), 1, fp) != 1) {
        int read_err = errno;
        vdb_log_short_read(fp, "header", read_err);
        fclose(fp);
        return VDB_HEADER_READ_ERR;
    }
    if (header.magic != VDB_MAGIC || header.version != VDB_VERSION) {
        fclose(fp);
        return VDB_MAGIC_MISMATCH_ERR;
    }
    if (header.embed_size != VDB_EMBED_SIZE || header.max_text != VDB_MAX_TEXT) {
        fclose(fp);
        return VDB_EMBED_MISMATCH_ERR;
    }
    if (header.count > VDB_MAX_DOCS) {
        fclose(fp);
        return VDB_COUNT_TOO_LARGE_ERR;
    }

    // Header accepted: start from an empty database that keeps the caller's
    // embedding context. The count is only set after the documents are in.
    memset(db, 0, sizeof(VectorDB));
    db->embed_ctx = ctx;

    const size_t want = (size_t)header.count;
    if (want > 0) {
        size_t got = fread(db->docs, sizeof(VectorDoc), want, fp);
        if (got != want) {
            int read_err = errno;
            vdb_log_short_read(fp, "doc", read_err);
            fclose(fp);
            // Discard the partially read records so the database is
            // consistently empty instead of half-loaded.
            memset(db, 0, sizeof(VectorDB));
            db->embed_ctx = ctx;
            return VDB_DOC_READ_ERR;
        }
    }

    // The records come straight from disk; make sure every text field is
    // NUL-terminated even if the file was corrupted or crafted. Files written
    // by vdb_save() already have a 0 byte here, so valid data is unchanged.
    for (size_t i = 0; i < want; i++) {
        db->docs[i].text[VDB_MAX_TEXT - 1] = '\0';
    }
    db->count = (int)header.count;

    if (fclose(fp) != 0) {
        int close_err = errno;
        fprintf(stderr, "vdb_load: close failed: %s\n", strerror(close_err));
        return VDB_CLOSE_ERR;
    }
    return VDB_SUCCESS;
}
// Map a VectorDBErrorCode to a human-readable message.
//
// Returns a pointer to a string literal (never NULL, must not be freed).
// Codes without a dedicated message yield "Unknown error.".
const char* vdb_error(VectorDBErrorCode err) {
    // Start with the fallback text and overwrite it for every known code.
    const char *message = "Unknown error.";

    switch (err) {
    case VDB_SUCCESS:
        message = "Success.";
        break;
    case VDB_OPEN_ERR:
        message = "Failed to open file.";
        break;
    case VDB_CLOSE_ERR:
        message = "Failed to close file.";
        break;
    case VDB_HEADER_WRITE_ERR:
        message = "Failed to write header.";
        break;
    case VDB_HEADER_READ_ERR:
        message = "Failed to read header.";
        break;
    case VDB_MAGIC_MISMATCH_ERR:
        message = "Header magic/version mismatch.";
        break;
    case VDB_EMBED_MISMATCH_ERR:
        message = "Header embed/max_text mismatch.";
        break;
    case VDB_COUNT_TOO_LARGE_ERR:
        message = "Header count too large.";
        break;
    case VDB_DOC_WRITE_ERR:
        message = "Failed to write documents.";
        break;
    case VDB_DOC_READ_ERR:
        message = "Failed to read documents.";
        break;
    default:
        break;
    }

    return message;
}
|