summaryrefslogtreecommitdiff
path: root/vectordb.c
diff options
context:
space:
mode:
Diffstat (limited to 'vectordb.c')
-rw-r--r--vectordb.c87
1 files changed, 56 insertions, 31 deletions
diff --git a/vectordb.c b/vectordb.c
index b6fae64..3812ecb 100644
--- a/vectordb.c
+++ b/vectordb.c
@@ -5,19 +5,9 @@
#include "llama.h"
#include "vectordb.h"
-#include "nonstd.h"
-
-#define VDB_MAGIC 0x31424456u /* "VDB1" */
-#define VDB_VERSION 1u
-
-typedef struct {
- uint32_t magic;
- uint32_t version;
- uint32_t embed_size;
- uint32_t max_text;
- uint32_t count;
-} VdbFileHeader;
+// Returns cosine similarity in range [-1, 1] (approx).
+// https://en.wikipedia.org/wiki/Cosine_similarity
static float cosine_similarity(float *a, float *b, int n) {
float dot = 0, norm_a = 0, norm_b = 0;
for (int i = 0; i < n; i++) {
@@ -29,10 +19,10 @@ static float cosine_similarity(float *a, float *b, int n) {
}
static void embed_text(struct llama_context *ctx, const char *text, float *out) {
- llama_token tokens[512];
+ llama_token tokens[VDB_TOKENS];
const struct llama_model *model = llama_get_model(ctx);
const struct llama_vocab *vocab = llama_model_get_vocab(model);
- int n_tokens = llama_tokenize(vocab, text, strlen(text), tokens, 512, true, true);
+ int n_tokens = llama_tokenize(vocab, text, strlen(text), tokens, VDB_TOKENS, true, true);
if (n_tokens < 0) {
return;
}
@@ -56,7 +46,7 @@ void vdb_free(VectorDB *db) {
void vdb_add_document(VectorDB *db, const char *text) {
if (db->count >= VDB_MAX_DOCS) {
- log_message(stdout, LOG_INFO, "Vector database full");
+ printf("Vector database full\n");
return;
}
@@ -64,7 +54,7 @@ void vdb_add_document(VectorDB *db, const char *text) {
strncpy(doc->text, text, VDB_MAX_TEXT - 1);
doc->text[VDB_MAX_TEXT - 1] = 0;
- log_message(stdout, LOG_INFO, "Embedding doc %d...", db->count);
+ printf("Embedding doc %d...\n", db->count);
embed_text(db->embed_ctx, text, doc->embedding);
}
@@ -96,10 +86,10 @@ void vdb_search(VectorDB *db, float *query, int top_k, int *results) {
}
}
-int vdb_save(const VectorDB *db, const char *path) {
+VectorDBErrorCode vdb_save(const VectorDB *db, const char *path) {
FILE *fp = fopen(path, "wb");
if (!fp) {
- return 1;
+ return VDB_OPEN_ERR;
}
VdbFileHeader header = {
@@ -112,50 +102,54 @@ int vdb_save(const VectorDB *db, const char *path) {
if (fwrite(&header, sizeof(header), 1, fp) != 1) {
fclose(fp);
- return 2;
+ return VDB_HEADER_WRITE_ERR;
}
if (db->count > 0) {
size_t wrote = fwrite(db->docs, sizeof(VectorDoc), (size_t)db->count, fp);
if (wrote != (size_t)db->count) {
fclose(fp);
- return 3;
+ return VDB_DOC_WRITE_ERR;
}
}
if (fclose(fp) != 0) {
- return 4;
+ return VDB_CLOSE_ERR;
}
- return 0;
+ return VDB_SUCCESS;
}
-int vdb_load(VectorDB *db, const char *path) {
+VectorDBErrorCode vdb_load(VectorDB *db, const char *path) {
struct llama_context *ctx = db->embed_ctx;
FILE *fp = fopen(path, "rb");
if (!fp) {
- return -1;
+ int open_err = errno;
+ fprintf(stderr, "vdb_load: open failed: %s\n", strerror(open_err));
+ return VDB_OPEN_ERR;
}
VdbFileHeader header = {0};
if (fread(&header, sizeof(header), 1, fp) != 1) {
+ int read_err = errno;
+ fprintf(stderr, "vdb_load: header read failed: %s\n", strerror(read_err));
fclose(fp);
- return -2;
+ return VDB_HEADER_READ_ERR;
}
if (header.magic != VDB_MAGIC || header.version != VDB_VERSION) {
fclose(fp);
- return -3;
+ return VDB_MAGIC_MISMATCH_ERR;
}
if (header.embed_size != VDB_EMBED_SIZE || header.max_text != VDB_MAX_TEXT) {
fclose(fp);
- return -4;
+ return VDB_EMBED_MISMATCH_ERR;
}
if (header.count > VDB_MAX_DOCS) {
fclose(fp);
- return -5;
+ return VDB_COUNT_TOO_LARGE_ERR;
}
memset(db, 0, sizeof(VectorDB));
@@ -165,14 +159,45 @@ int vdb_load(VectorDB *db, const char *path) {
if (db->count > 0) {
size_t read = fread(db->docs, sizeof(VectorDoc), (size_t)db->count, fp);
if (read != (size_t)db->count) {
+ int read_err = errno;
+ fprintf(stderr, "vdb_load: doc read failed: %s\n", strerror(read_err));
fclose(fp);
- return -6;
+ return VDB_DOC_READ_ERR;
}
}
if (fclose(fp) != 0) {
- return -7;
+ int close_err = errno;
+ fprintf(stderr, "vdb_load: close failed: %s\n", strerror(close_err));
+ return VDB_CLOSE_ERR;
}
- return 0;
+ return VDB_SUCCESS;
+}
+
+const char* vdb_error(VectorDBErrorCode err) {
+ switch (err) {
+ case VDB_SUCCESS:
+ return "Success.";
+ case VDB_OPEN_ERR:
+ return "Failed to open file.";
+ case VDB_CLOSE_ERR:
+ return "Failed to close file.";
+ case VDB_HEADER_WRITE_ERR:
+ return "Failed to write header.";
+ case VDB_HEADER_READ_ERR:
+ return "Failed to read header.";
+ case VDB_MAGIC_MISMATCH_ERR:
+ return "Header magic/version mismatch.";
+ case VDB_EMBED_MISMATCH_ERR:
+ return "Header embed/max_text mismatch.";
+ case VDB_COUNT_TOO_LARGE_ERR:
+ return "Header count too large.";
+ case VDB_DOC_WRITE_ERR:
+ return "Failed to write documents.";
+ case VDB_DOC_READ_ERR:
+ return "Failed to read documents.";
+ default:
+ return "Unknown error.";
+ }
}