diff options
Diffstat (limited to 'vectordb.c')
| -rw-r--r-- | vectordb.c | 87 |
1 files changed, 56 insertions, 31 deletions
@@ -5,19 +5,9 @@ #include "llama.h" #include "vectordb.h" -#include "nonstd.h" - -#define VDB_MAGIC 0x31424456u /* "VDB1" */ -#define VDB_VERSION 1u - -typedef struct { - uint32_t magic; - uint32_t version; - uint32_t embed_size; - uint32_t max_text; - uint32_t count; -} VdbFileHeader; +// Returns cosine similarity in range [-1, 1] (approx). +// https://en.wikipedia.org/wiki/Cosine_similarity static float cosine_similarity(float *a, float *b, int n) { float dot = 0, norm_a = 0, norm_b = 0; for (int i = 0; i < n; i++) { @@ -29,10 +19,10 @@ static float cosine_similarity(float *a, float *b, int n) { } static void embed_text(struct llama_context *ctx, const char *text, float *out) { - llama_token tokens[512]; + llama_token tokens[VDB_TOKENS]; const struct llama_model *model = llama_get_model(ctx); const struct llama_vocab *vocab = llama_model_get_vocab(model); - int n_tokens = llama_tokenize(vocab, text, strlen(text), tokens, 512, true, true); + int n_tokens = llama_tokenize(vocab, text, strlen(text), tokens, VDB_TOKENS, true, true); if (n_tokens < 0) { return; } @@ -56,7 +46,7 @@ void vdb_free(VectorDB *db) { void vdb_add_document(VectorDB *db, const char *text) { if (db->count >= VDB_MAX_DOCS) { - log_message(stdout, LOG_INFO, "Vector database full"); + printf("Vector database full\n"); return; } @@ -64,7 +54,7 @@ void vdb_add_document(VectorDB *db, const char *text) { strncpy(doc->text, text, VDB_MAX_TEXT - 1); doc->text[VDB_MAX_TEXT - 1] = 0; - log_message(stdout, LOG_INFO, "Embedding doc %d...", db->count); + printf("Embedding doc %d...\n", db->count); embed_text(db->embed_ctx, text, doc->embedding); } @@ -96,10 +86,10 @@ void vdb_search(VectorDB *db, float *query, int top_k, int *results) { } } -int vdb_save(const VectorDB *db, const char *path) { +VectorDBErrorCode vdb_save(const VectorDB *db, const char *path) { FILE *fp = fopen(path, "wb"); if (!fp) { - return 1; + return VDB_OPEN_ERR; } VdbFileHeader header = { @@ -112,50 +102,54 @@ int vdb_save(const VectorDB *db, const char *path) { if (fwrite(&header, sizeof(header), 1, fp) != 1) { fclose(fp); - return 2; + return VDB_HEADER_WRITE_ERR; } if (db->count > 0) { size_t wrote = fwrite(db->docs, sizeof(VectorDoc), (size_t)db->count, fp); if (wrote != (size_t)db->count) { fclose(fp); - return 3; + return VDB_DOC_WRITE_ERR; } } if (fclose(fp) != 0) { - return 4; + return VDB_CLOSE_ERR; } - return 0; + return VDB_SUCCESS; } -int vdb_load(VectorDB *db, const char *path) { +VectorDBErrorCode vdb_load(VectorDB *db, const char *path) { struct llama_context *ctx = db->embed_ctx; FILE *fp = fopen(path, "rb"); if (!fp) { - return -1; + int open_err = errno; + fprintf(stderr, "vdb_load: open failed: %s\n", strerror(open_err)); + return VDB_OPEN_ERR; } VdbFileHeader header = {0}; if (fread(&header, sizeof(header), 1, fp) != 1) { + int read_err = errno; + fprintf(stderr, "vdb_load: header read failed: %s\n", strerror(read_err)); fclose(fp); - return -2; + return VDB_HEADER_READ_ERR; } if (header.magic != VDB_MAGIC || header.version != VDB_VERSION) { fclose(fp); - return -3; + return VDB_MAGIC_MISMATCH_ERR; } if (header.embed_size != VDB_EMBED_SIZE || header.max_text != VDB_MAX_TEXT) { fclose(fp); - return -4; + return VDB_EMBED_MISMATCH_ERR; } if (header.count > VDB_MAX_DOCS) { fclose(fp); - return -5; + return VDB_COUNT_TOO_LARGE_ERR; } memset(db, 0, sizeof(VectorDB)); @@ -165,14 +159,45 @@ int vdb_load(VectorDB *db, const char *path) { if (db->count > 0) { size_t read = fread(db->docs, sizeof(VectorDoc), (size_t)db->count, fp); if (read != (size_t)db->count) { + int read_err = errno; + fprintf(stderr, "vdb_load: doc read failed: %s\n", strerror(read_err)); fclose(fp); - return -6; + return VDB_DOC_READ_ERR; } } if (fclose(fp) != 0) { - return -7; + int close_err = errno; + fprintf(stderr, "vdb_load: close failed: %s\n", strerror(close_err)); + return VDB_CLOSE_ERR; } - return 0; + return VDB_SUCCESS; +} + +const char* vdb_error(VectorDBErrorCode err) { + switch (err) { + case VDB_SUCCESS: + return "Success."; + case VDB_OPEN_ERR: + return "Failed to open file."; + case VDB_CLOSE_ERR: + return "Failed to close file."; + case VDB_HEADER_WRITE_ERR: + return "Failed to write header."; + case VDB_HEADER_READ_ERR: + return "Failed to read header."; + case VDB_MAGIC_MISMATCH_ERR: + return "Header magic/version mismatch."; + case VDB_EMBED_MISMATCH_ERR: + return "Header embed/max_text mismatch."; + case VDB_COUNT_TOO_LARGE_ERR: + return "Header count too large."; + case VDB_DOC_WRITE_ERR: + return "Failed to write documents."; + case VDB_DOC_READ_ERR: + return "Failed to read documents."; + default: + return "Unknown error."; + } } |
