|
diff --git a/vectordb.c b/vectordb.c
|
| ... |
| 5 |
|
5 |
|
| 6 |
#include "llama.h" |
6 |
#include "llama.h" |
| 7 |
#include "vectordb.h" |
7 |
#include "vectordb.h" |
| 8 |
#include "nonstd.h" |
|
|
| 9 |
|
8 |
|
| 10 |
#define VDB_MAGIC 0x31424456u /* "VDB1" */ |
9 |
// Returns cosine similarity in range [-1, 1] (approx). |
| 11 |
#define VDB_VERSION 1u |
10 |
// https://en.wikipedia.org/wiki/Cosine_similarity |
| 12 |
|
|
|
| 13 |
typedef struct { |
|
|
| 14 |
uint32_t magic; |
|
|
| 15 |
uint32_t version; |
|
|
| 16 |
uint32_t embed_size; |
|
|
| 17 |
uint32_t max_text; |
|
|
| 18 |
uint32_t count; |
|
|
| 19 |
} VdbFileHeader; |
|
|
| 20 |
|
|
|
| 21 |
static float cosine_similarity(float *a, float *b, int n) { |
11 |
static float cosine_similarity(float *a, float *b, int n) { |
| 22 |
float dot = 0, norm_a = 0, norm_b = 0; |
12 |
float dot = 0, norm_a = 0, norm_b = 0; |
| 23 |
for (int i = 0; i < n; i++) { |
13 |
for (int i = 0; i < n; i++) { |
| ... |
| 29 |
} |
19 |
} |
| 30 |
|
20 |
|
| 31 |
static void embed_text(struct llama_context *ctx, const char *text, float *out) { |
21 |
static void embed_text(struct llama_context *ctx, const char *text, float *out) { |
| 32 |
llama_token tokens[512]; |
22 |
llama_token tokens[VDB_TOKENS]; |
| 33 |
const struct llama_model *model = llama_get_model(ctx); |
23 |
const struct llama_model *model = llama_get_model(ctx); |
| 34 |
const struct llama_vocab *vocab = llama_model_get_vocab(model); |
24 |
const struct llama_vocab *vocab = llama_model_get_vocab(model); |
| 35 |
int n_tokens = llama_tokenize(vocab, text, strlen(text), tokens, 512, true, true); |
25 |
int n_tokens = llama_tokenize(vocab, text, strlen(text), tokens, VDB_TOKENS, true, true); |
| 36 |
if (n_tokens < 0) { |
26 |
if (n_tokens < 0) { |
| 37 |
return; |
27 |
return; |
| 38 |
} |
28 |
} |
| ... |
| 56 |
|
46 |
|
| 57 |
void vdb_add_document(VectorDB *db, const char *text) { |
47 |
void vdb_add_document(VectorDB *db, const char *text) { |
| 58 |
if (db->count >= VDB_MAX_DOCS) { |
48 |
if (db->count >= VDB_MAX_DOCS) { |
| 59 |
log_message(stdout, LOG_INFO, "Vector database full"); |
49 |
printf("Vector database full\n"); |
| 60 |
return; |
50 |
return; |
| 61 |
} |
51 |
} |
| 62 |
|
52 |
|
| ... |
| 64 |
strncpy(doc->text, text, VDB_MAX_TEXT - 1); |
54 |
strncpy(doc->text, text, VDB_MAX_TEXT - 1); |
| 65 |
doc->text[VDB_MAX_TEXT - 1] = 0; |
55 |
doc->text[VDB_MAX_TEXT - 1] = 0; |
| 66 |
|
56 |
|
| 67 |
log_message(stdout, LOG_INFO, "Embedding doc %d...", db->count); |
57 |
printf("Embedding doc %d...\n", db->count); |
| 68 |
embed_text(db->embed_ctx, text, doc->embedding); |
58 |
embed_text(db->embed_ctx, text, doc->embedding); |
| 69 |
} |
59 |
} |
| 70 |
|
60 |
|
| ... |
| 96 |
} |
86 |
} |
| 97 |
} |
87 |
} |
| 98 |
|
88 |
|
| 99 |
int vdb_save(const VectorDB *db, const char *path) { |
89 |
VectorDBErrorCode vdb_save(const VectorDB *db, const char *path) { |
| 100 |
FILE *fp = fopen(path, "wb"); |
90 |
FILE *fp = fopen(path, "wb"); |
| 101 |
if (!fp) { |
91 |
if (!fp) { |
| 102 |
return 1; |
92 |
return VDB_OPEN_ERR; |
| 103 |
} |
93 |
} |
| 104 |
|
94 |
|
| 105 |
VdbFileHeader header = { |
95 |
VdbFileHeader header = { |
| ... |
| 112 |
|
102 |
|
| 113 |
if (fwrite(&header, sizeof(header), 1, fp) != 1) { |
103 |
if (fwrite(&header, sizeof(header), 1, fp) != 1) { |
| 114 |
fclose(fp); |
104 |
fclose(fp); |
| 115 |
return 2; |
105 |
return VDB_HEADER_WRITE_ERR; |
| 116 |
} |
106 |
} |
| 117 |
|
107 |
|
| 118 |
if (db->count > 0) { |
108 |
if (db->count > 0) { |
| 119 |
size_t wrote = fwrite(db->docs, sizeof(VectorDoc), (size_t)db->count, fp); |
109 |
size_t wrote = fwrite(db->docs, sizeof(VectorDoc), (size_t)db->count, fp); |
| 120 |
if (wrote != (size_t)db->count) { |
110 |
if (wrote != (size_t)db->count) { |
| 121 |
fclose(fp); |
111 |
fclose(fp); |
| 122 |
return 3; |
112 |
return VDB_DOC_WRITE_ERR; |
| 123 |
} |
113 |
} |
| 124 |
} |
114 |
} |
| 125 |
|
115 |
|
| 126 |
if (fclose(fp) != 0) { |
116 |
if (fclose(fp) != 0) { |
| 127 |
return 4; |
117 |
return VDB_CLOSE_ERR; |
| 128 |
} |
118 |
} |
| 129 |
|
119 |
|
| 130 |
return 0; |
120 |
return VDB_SUCCESS; |
| 131 |
} |
121 |
} |
| 132 |
|
122 |
|
| 133 |
int vdb_load(VectorDB *db, const char *path) { |
123 |
VectorDBErrorCode vdb_load(VectorDB *db, const char *path) { |
| 134 |
struct llama_context *ctx = db->embed_ctx; |
124 |
struct llama_context *ctx = db->embed_ctx; |
| 135 |
FILE *fp = fopen(path, "rb"); |
125 |
FILE *fp = fopen(path, "rb"); |
| 136 |
if (!fp) { |
126 |
if (!fp) { |
| 137 |
return -1; |
127 |
int open_err = errno; |
|
|
128 |
fprintf(stderr, "vdb_load: open failed: %s\n", strerror(open_err)); |
|
|
129 |
return VDB_OPEN_ERR; |
| 138 |
} |
130 |
} |
| 139 |
|
131 |
|
| 140 |
VdbFileHeader header = {0}; |
132 |
VdbFileHeader header = {0}; |
| 141 |
if (fread(&header, sizeof(header), 1, fp) != 1) { |
133 |
if (fread(&header, sizeof(header), 1, fp) != 1) { |
|
|
134 |
int read_err = errno; |
|
|
135 |
fprintf(stderr, "vdb_load: header read failed: %s\n", strerror(read_err)); |
| 142 |
fclose(fp); |
136 |
fclose(fp); |
| 143 |
return -2; |
137 |
return VDB_HEADER_READ_ERR; |
| 144 |
} |
138 |
} |
| 145 |
|
139 |
|
| 146 |
if (header.magic != VDB_MAGIC || header.version != VDB_VERSION) { |
140 |
if (header.magic != VDB_MAGIC || header.version != VDB_VERSION) { |
| 147 |
fclose(fp); |
141 |
fclose(fp); |
| 148 |
return -3; |
142 |
return VDB_MAGIC_MISMATCH_ERR; |
| 149 |
} |
143 |
} |
| 150 |
|
144 |
|
| 151 |
if (header.embed_size != VDB_EMBED_SIZE || header.max_text != VDB_MAX_TEXT) { |
145 |
if (header.embed_size != VDB_EMBED_SIZE || header.max_text != VDB_MAX_TEXT) { |
| 152 |
fclose(fp); |
146 |
fclose(fp); |
| 153 |
return -4; |
147 |
return VDB_EMBED_MISMATCH_ERR; |
| 154 |
} |
148 |
} |
| 155 |
|
149 |
|
| 156 |
if (header.count > VDB_MAX_DOCS) { |
150 |
if (header.count > VDB_MAX_DOCS) { |
| 157 |
fclose(fp); |
151 |
fclose(fp); |
| 158 |
return -5; |
152 |
return VDB_COUNT_TOO_LARGE_ERR; |
| 159 |
} |
153 |
} |
| 160 |
|
154 |
|
| 161 |
memset(db, 0, sizeof(VectorDB)); |
155 |
memset(db, 0, sizeof(VectorDB)); |
| ... |
| 165 |
if (db->count > 0) { |
159 |
if (db->count > 0) { |
| 166 |
size_t read = fread(db->docs, sizeof(VectorDoc), (size_t)db->count, fp); |
160 |
size_t read = fread(db->docs, sizeof(VectorDoc), (size_t)db->count, fp); |
| 167 |
if (read != (size_t)db->count) { |
161 |
if (read != (size_t)db->count) { |
|
|
162 |
int read_err = errno; |
|
|
163 |
fprintf(stderr, "vdb_load: doc read failed: %s\n", strerror(read_err)); |
| 168 |
fclose(fp); |
164 |
fclose(fp); |
| 169 |
return -6; |
165 |
return VDB_DOC_READ_ERR; |
| 170 |
} |
166 |
} |
| 171 |
} |
167 |
} |
| 172 |
|
168 |
|
| 173 |
if (fclose(fp) != 0) { |
169 |
if (fclose(fp) != 0) { |
| 174 |
return -7; |
170 |
int close_err = errno; |
|
|
171 |
fprintf(stderr, "vdb_load: close failed: %s\n", strerror(close_err)); |
|
|
172 |
return VDB_CLOSE_ERR; |
| 175 |
} |
173 |
} |
| 176 |
|
174 |
|
| 177 |
return 0; |
175 |
return VDB_SUCCESS; |
|
|
176 |
} |
|
|
177 |
|
|
|
178 |
/**
 * Translate a VectorDBErrorCode into a short, human-readable message.
 *
 * Every returned pointer refers to a string literal, so the caller must not
 * modify or free it.
 *
 * @param err  Status code to describe.
 * @return     The message for the given code, or "Unknown error." for any
 *             value that has no dedicated case below.
 */
const char *vdb_error(VectorDBErrorCode err) {
    switch (err) {
    case VDB_SUCCESS:
        return "Success.";
    case VDB_OPEN_ERR:
        return "Failed to open file.";
    case VDB_CLOSE_ERR:
        return "Failed to close file.";
    case VDB_HEADER_WRITE_ERR:
        return "Failed to write header.";
    case VDB_HEADER_READ_ERR:
        return "Failed to read header.";
    case VDB_MAGIC_MISMATCH_ERR:
        return "Header magic/version mismatch.";
    case VDB_EMBED_MISMATCH_ERR:
        return "Header embed/max_text mismatch.";
    case VDB_COUNT_TOO_LARGE_ERR:
        return "Header count too large.";
    case VDB_DOC_WRITE_ERR:
        return "Failed to write documents.";
    case VDB_DOC_READ_ERR:
        return "Failed to read documents.";
    default:
        /* Values outside the enum (e.g. from a cast) land here. */
        return "Unknown error.";
    }
}
|
diff --git a/vectordb.h b/vectordb.h
|
| ... |
| 3 |
|
3 |
|
| 4 |
#include "llama.h" |
4 |
#include "llama.h" |
| 5 |
|
5 |
|
| 6 |
#define VDB_MAX_DOCS 1000 |
6 |
#include <errno.h> |
| 7 |
#define VDB_EMBED_SIZE 768 |
7 |
|
| 8 |
#define VDB_MAX_TEXT 1024 |
8 |
#define VDB_MAX_DOCS 1000 |
|
|
9 |
#define VDB_EMBED_SIZE 768 |
|
|
10 |
#define VDB_MAX_TEXT 1024 |
|
|
11 |
|
|
|
12 |
#define VDB_MAGIC 0x31424456u /* "VDB1" */ |
|
|
13 |
#define VDB_VERSION 1u |
|
|
14 |
#define VDB_TOKENS 512 |
| 9 |
|
15 |
|
| 10 |
typedef struct { |
16 |
typedef struct { |
| 11 |
float embedding[VDB_EMBED_SIZE]; |
17 |
float embedding[VDB_EMBED_SIZE]; |
| ... |
| 18 |
struct llama_context *embed_ctx; |
24 |
struct llama_context *embed_ctx; |
| 19 |
} VectorDB; |
25 |
} VectorDB; |
| 20 |
|
26 |
|
|
|
27 |
/**
 * Header record stored at the start of a saved vector database file.
 *
 * vdb_save() writes this struct with a single fwrite() before the document
 * records, and vdb_load() reads it back the same way, so the on-disk layout
 * is the in-memory layout of this struct.
 */
typedef struct {
    uint32_t magic;      /* file signature; vdb_load() requires VDB_MAGIC */
    uint32_t version;    /* format version; vdb_load() requires VDB_VERSION */
    uint32_t embed_size; /* must equal VDB_EMBED_SIZE for the file to load */
    uint32_t max_text;   /* must equal VDB_MAX_TEXT for the file to load */
    uint32_t count;      /* document count; vdb_load() rejects values above VDB_MAX_DOCS */
} VdbFileHeader;
|
|
34 |
|
|
|
35 |
/**
 * Status codes returned by vdb_save() and vdb_load().
 *
 * VDB_SUCCESS is zero; every failure code is a distinct non-zero value.
 * Use vdb_error() to obtain a printable description of a code.
 */
typedef enum {
    VDB_SUCCESS = 0,                /* operation completed */
    VDB_OPEN_ERR = 9001,            /* fopen() on the database path failed */
    VDB_CLOSE_ERR = 9002,           /* fclose() reported an error */
    VDB_HEADER_WRITE_ERR = 9003,    /* the file header could not be written */
    VDB_HEADER_READ_ERR = 9004,     /* the file header could not be read */
    VDB_MAGIC_MISMATCH_ERR = 9005,  /* header magic or version is not the expected one */
    VDB_EMBED_MISMATCH_ERR = 9006,  /* header embed_size or max_text differs from this build */
    VDB_COUNT_TOO_LARGE_ERR = 9007, /* header count exceeds VDB_MAX_DOCS */
    VDB_DOC_WRITE_ERR = 9008,       /* fewer document records were written than requested */
    VDB_DOC_READ_ERR = 9009,        /* fewer document records were read than the header announced */
} VectorDBErrorCode;
|
|
47 |
|
| 21 |
void vdb_init(VectorDB *db, struct llama_context *embed_ctx); |
48 |
void vdb_init(VectorDB *db, struct llama_context *embed_ctx); |
| 22 |
void vdb_free(VectorDB *db); |
49 |
void vdb_free(VectorDB *db); |
| 23 |
|
50 |
|
| ... |
| 26 |
void vdb_embed_query(VectorDB *db, const char *text, float *out_embedding); |
53 |
void vdb_embed_query(VectorDB *db, const char *text, float *out_embedding); |
| 27 |
void vdb_search(VectorDB *db, float *query_embedding, int top_k, int *results); |
54 |
void vdb_search(VectorDB *db, float *query_embedding, int top_k, int *results); |
| 28 |
|
55 |
|
| 29 |
int vdb_save(const VectorDB *db, const char *path); |
56 |
VectorDBErrorCode vdb_save(const VectorDB *db, const char *path); |
| 30 |
int vdb_load(VectorDB *db, const char *path); |
57 |
VectorDBErrorCode vdb_load(VectorDB *db, const char *path); |
|
|
58 |
|
|
|
59 |
const char* vdb_error(VectorDBErrorCode err); |
| 31 |
|
60 |
|
| 32 |
#endif |
61 |
#endif |