summaryrefslogtreecommitdiff
path: root/examples/redis-unstable/modules/vector-sets/w2v.c
diff options
context:
space:
mode:
Diffstat (limited to 'examples/redis-unstable/modules/vector-sets/w2v.c')
-rw-r--r--examples/redis-unstable/modules/vector-sets/w2v.c539
1 files changed, 539 insertions, 0 deletions
diff --git a/examples/redis-unstable/modules/vector-sets/w2v.c b/examples/redis-unstable/modules/vector-sets/w2v.c
new file mode 100644
index 0000000..bcf6338
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/w2v.c
@@ -0,0 +1,539 @@
+/*
+ * HNSW (Hierarchical Navigable Small World) Implementation
+ * Based on the paper by Yu. A. Malkov, D. A. Yashunin
+ *
+ * Copyright (c) 2009-Present, Redis Ltd.
+ * All rights reserved.
+ *
+ * Licensed under your choice of (a) the Redis Source Available License 2.0
+ * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
+ * GNU Affero General Public License v3 (AGPLv3).
+ * Originally authored by: Salvatore Sanfilippo
+ */
+
+#define _DEFAULT_SOURCE
+#define _USE_MATH_DEFINES
+#define _POSIX_C_SOURCE 200809L
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <strings.h>
+#include <sys/time.h>
+#include <time.h>
+#include <stdint.h>
+#include <pthread.h>
+#include <stdatomic.h>
+#include <math.h>
+
+#include "hnsw.h"
+
+/* Get current time in milliseconds */
+uint64_t ms_time(void) {
+ struct timeval tv;
+ gettimeofday(&tv, NULL);
+ return (uint64_t)tv.tv_sec * 1000 + (tv.tv_usec / 1000);
+}
+
+/* Implementation of the recall test with random vectors. */
+void test_recall(HNSW *index, int ef) {
+ const int num_test_vectors = 10000;
+ const int k = 100; // Number of nearest neighbors to find.
+ if (ef < k) ef = k;
+
+ // Add recall distribution counters (2% bins from 0-100%).
+ int recall_bins[50] = {0};
+
+ // Create array to store vectors for mixing.
+ int num_source_vectors = 1000; // Enough, since we mix them.
+ float **source_vectors = malloc(sizeof(float*) * num_source_vectors);
+ if (!source_vectors) {
+ printf("Failed to allocate memory for source vectors\n");
+ return;
+ }
+
+ // Allocate memory for each source vector.
+ for (int i = 0; i < num_source_vectors; i++) {
+ source_vectors[i] = malloc(sizeof(float) * 300);
+ if (!source_vectors[i]) {
+ printf("Failed to allocate memory for source vector %d\n", i);
+ // Clean up already allocated vectors.
+ for (int j = 0; j < i; j++) free(source_vectors[j]);
+ free(source_vectors);
+ return;
+ }
+ }
+
+ /* Populate source vectors from the index, we just scan the
+ * first N items. */
+ int source_count = 0;
+ hnswNode *current = index->head;
+ while (current && source_count < num_source_vectors) {
+ hnsw_get_node_vector(index, current, source_vectors[source_count]);
+ source_count++;
+ current = current->next;
+ }
+
+ if (source_count < num_source_vectors) {
+ printf("Warning: Only found %d nodes for source vectors\n",
+ source_count);
+ num_source_vectors = source_count;
+ }
+
+ // Allocate memory for test vector.
+ float *test_vector = malloc(sizeof(float) * 300);
+ if (!test_vector) {
+ printf("Failed to allocate memory for test vector\n");
+ for (int i = 0; i < num_source_vectors; i++) {
+ free(source_vectors[i]);
+ }
+ free(source_vectors);
+ return;
+ }
+
+ // Allocate memory for results.
+ hnswNode **hnsw_results = malloc(sizeof(hnswNode*) * ef);
+ hnswNode **linear_results = malloc(sizeof(hnswNode*) * ef);
+ float *hnsw_distances = malloc(sizeof(float) * ef);
+ float *linear_distances = malloc(sizeof(float) * ef);
+
+ if (!hnsw_results || !linear_results || !hnsw_distances || !linear_distances) {
+ printf("Failed to allocate memory for results\n");
+ if (hnsw_results) free(hnsw_results);
+ if (linear_results) free(linear_results);
+ if (hnsw_distances) free(hnsw_distances);
+ if (linear_distances) free(linear_distances);
+ for (int i = 0; i < num_source_vectors; i++) free(source_vectors[i]);
+ free(source_vectors);
+ free(test_vector);
+ return;
+ }
+
+ // Initialize random seed.
+ srand(time(NULL));
+
+ // Perform recall test.
+ printf("\nPerforming recall test with EF=%d on %d random vectors...\n",
+ ef, num_test_vectors);
+ double total_recall = 0.0;
+
+ for (int t = 0; t < num_test_vectors; t++) {
+ // Create a random vector by mixing 3 existing vectors.
+ float weights[3] = {0.0};
+ int src_indices[3] = {0};
+
+ // Generate random weights.
+ float weight_sum = 0.0;
+ for (int i = 0; i < 3; i++) {
+ weights[i] = (float)rand() / RAND_MAX;
+ weight_sum += weights[i];
+ src_indices[i] = rand() % num_source_vectors;
+ }
+
+ // Normalize weights.
+ for (int i = 0; i < 3; i++) weights[i] /= weight_sum;
+
+ // Mix vectors.
+ memset(test_vector, 0, sizeof(float) * 300);
+ for (int i = 0; i < 3; i++) {
+ for (int j = 0; j < 300; j++) {
+ test_vector[j] +=
+ weights[i] * source_vectors[src_indices[i]][j];
+ }
+ }
+
+ // Perform HNSW search with the specified EF parameter.
+ int slot = hnsw_acquire_read_slot(index);
+ int hnsw_found = hnsw_search(index, test_vector, ef, hnsw_results, hnsw_distances, slot, 0);
+
+ // Perform linear search (ground truth).
+ int linear_found = hnsw_ground_truth_with_filter(index, test_vector, ef, linear_results, linear_distances, slot, 0, NULL, NULL);
+ hnsw_release_read_slot(index, slot);
+
+ // Calculate recall for this query (intersection size / k).
+ if (hnsw_found > k) hnsw_found = k;
+ if (linear_found > k) linear_found = k;
+ int intersection_count = 0;
+ for (int i = 0; i < linear_found; i++) {
+ for (int j = 0; j < hnsw_found; j++) {
+ if (linear_results[i] == hnsw_results[j]) {
+ intersection_count++;
+ break;
+ }
+ }
+ }
+
+ double recall = (double)intersection_count / linear_found;
+ total_recall += recall;
+
+ // Add to distribution bins (2% steps)
+ int bin_index = (int)(recall * 50);
+ if (bin_index >= 50) bin_index = 49; // Handle 100% recall case
+ recall_bins[bin_index]++;
+
+ // Show progress.
+ if ((t+1) % 1000 == 0 || t == num_test_vectors-1) {
+ printf("Processed %d/%d queries, current avg recall: %.2f%%\n",
+ t+1, num_test_vectors, (total_recall / (t+1)) * 100);
+ }
+ }
+
+ // Calculate and print final average recall.
+ double avg_recall = (total_recall / num_test_vectors) * 100;
+ printf("\nRecall Test Results:\n");
+ printf("Average recall@%d (EF=%d): %.2f%%\n", k, ef, avg_recall);
+
+ // Print recall distribution histogram.
+ printf("\nRecall Distribution (2%% bins):\n");
+ printf("================================\n");
+
+ // Find the maximum bin count for scaling.
+ int max_count = 0;
+ for (int i = 0; i < 50; i++) {
+ if (recall_bins[i] > max_count) max_count = recall_bins[i];
+ }
+
+ // Scale factor for histogram (max 50 chars wide)
+ const int max_bars = 50;
+ double scale = (max_count > max_bars) ? (double)max_bars / max_count : 1.0;
+
+ // Print the histogram.
+ for (int i = 0; i < 50; i++) {
+ int bar_len = (int)(recall_bins[i] * scale);
+ printf("%3d%%-%-3d%% | %-6d |", i*2, (i+1)*2, recall_bins[i]);
+ for (int j = 0; j < bar_len; j++) printf("#");
+ printf("\n");
+ }
+
+ // Cleanup.
+ free(hnsw_results);
+ free(linear_results);
+ free(hnsw_distances);
+ free(linear_distances);
+ free(test_vector);
+ for (int i = 0; i < num_source_vectors; i++) free(source_vectors[i]);
+ free(source_vectors);
+}
+
+/* Example usage in main() */
+int w2v_single_thread(int m_param, int quantization, uint64_t numele, int massdel, int self_recall, int recall_ef) {
+ /* Create index */
+ HNSW *index = hnsw_new(300, quantization, m_param);
+ float v[300];
+ uint16_t wlen;
+
+ FILE *fp = fopen("word2vec.bin","rb");
+ if (fp == NULL) {
+ perror("word2vec.bin file missing");
+ exit(1);
+ }
+ unsigned char header[8];
+ if (fread(header,8,1,fp) <= 0) { // Skip header
+ perror("Unexpected EOF");
+ exit(1);
+ }
+
+ uint64_t id = 0;
+ uint64_t start_time = ms_time();
+ char *word = NULL;
+ hnswNode *search_node = NULL;
+
+ while(id < numele) {
+ if (fread(&wlen,2,1,fp) == 0) break;
+ word = malloc(wlen+1);
+ if (fread(word,wlen,1,fp) <= 0) {
+ perror("unexpected EOF");
+ exit(1);
+ }
+ word[wlen] = 0;
+ if (fread(v,300*sizeof(float),1,fp) <= 0) {
+ perror("unexpected EOF");
+ exit(1);
+ }
+
+ // Plain API that acquires a write lock for the whole time.
+ hnswNode *added = hnsw_insert(index, v, NULL, 0, id++, word, 200);
+
+ if (!strcmp(word,"banana")) search_node = added;
+ if (!(id % 10000)) printf("%llu added\n", (unsigned long long)id);
+ }
+ uint64_t elapsed = ms_time() - start_time;
+ fclose(fp);
+
+ printf("%llu words added (%llu words/sec), last word: %s\n",
+ (unsigned long long)index->node_count,
+ (unsigned long long)id*1000/elapsed, word);
+
+ /* Search query */
+ if (search_node == NULL) search_node = index->head;
+ hnsw_get_node_vector(index,search_node,v);
+ hnswNode *neighbors[10];
+ float distances[10];
+
+ int found, j;
+ start_time = ms_time();
+ for (j = 0; j < 20000; j++)
+ found = hnsw_search(index, v, 10, neighbors, distances, 0, 0);
+ elapsed = ms_time() - start_time;
+ printf("%d searches performed (%llu searches/sec), nodes found: %d\n",
+ j, (unsigned long long)j*1000/elapsed, found);
+
+ if (found > 0) {
+ printf("Found %d neighbors:\n", found);
+ for (int i = 0; i < found; i++) {
+ printf("Node ID: %llu, distance: %f, word: %s\n",
+ (unsigned long long)neighbors[i]->id,
+ distances[i], (char*)neighbors[i]->value);
+ }
+ }
+
+ // Self-recall test (ability to find the node by its own vector).
+ if (self_recall) {
+ hnsw_print_stats(index);
+ hnsw_test_graph_recall(index,200,0);
+ }
+
+ // Recall test with random vectors.
+ if (recall_ef > 0) {
+ test_recall(index, recall_ef);
+ }
+
+ uint64_t connected_nodes;
+ int reciprocal_links;
+ hnsw_validate_graph(index, &connected_nodes, &reciprocal_links);
+
+ if (massdel) {
+ int remove_perc = 95;
+ printf("\nRemoving %d%% of nodes...\n", remove_perc);
+ uint64_t initial_nodes = index->node_count;
+
+ hnswNode *current = index->head;
+ while (current && index->node_count > initial_nodes*(100-remove_perc)/100) {
+ hnswNode *next = current->next;
+ hnsw_delete_node(index,current,free);
+ current = next;
+ // In order to don't remove only contiguous nodes, from time
+ // skip a node.
+ if (current && !(random() % remove_perc)) current = current->next;
+ }
+ printf("%llu nodes left\n", (unsigned long long)index->node_count);
+
+ // Test again.
+ hnsw_validate_graph(index, &connected_nodes, &reciprocal_links);
+ hnsw_test_graph_recall(index,200,0);
+ }
+
+ hnsw_free(index,free);
+ return 0;
+}
+
+struct threadContext {
+ pthread_mutex_t FileAccessMutex;
+ uint64_t numele;
+ _Atomic uint64_t SearchesDone;
+ _Atomic uint64_t id;
+ FILE *fp;
+ HNSW *index;
+ float *search_vector;
+};
+
+// Note that in practical terms inserting with many concurrent threads
+// may be *slower* and not faster, because there is a lot of
+// contention. So this is more a robustness test than anything else.
+//
+// The optimistic commit API goal is actually to exploit the ability to
+// add faster when there are many concurrent reads.
+void *threaded_insert(void *ctxptr) {
+ struct threadContext *ctx = ctxptr;
+ char *word;
+ float v[300];
+ uint16_t wlen;
+
+ while(1) {
+ pthread_mutex_lock(&ctx->FileAccessMutex);
+ if (fread(&wlen,2,1,ctx->fp) == 0) break;
+ pthread_mutex_unlock(&ctx->FileAccessMutex);
+ word = malloc(wlen+1);
+ if (fread(word,wlen,1,ctx->fp) <= 0) {
+ perror("Unexpected EOF");
+ exit(1);
+ }
+
+ word[wlen] = 0;
+ if (fread(v,300*sizeof(float),1,ctx->fp) <= 0) {
+ perror("Unexpected EOF");
+ exit(1);
+ }
+
+ // Check-and-set API that performs the costly scan for similar
+ // nodes concurrently with other read threads, and finally
+ // applies the check if the graph wasn't modified.
+ InsertContext *ic;
+ uint64_t next_id = ctx->id++;
+ ic = hnsw_prepare_insert(ctx->index, v, NULL, 0, next_id, 200);
+ if (hnsw_try_commit_insert(ctx->index, ic, word) == NULL) {
+ // This time try locking since the start.
+ hnsw_insert(ctx->index, v, NULL, 0, next_id, word, 200);
+ }
+
+ if (next_id >= ctx->numele) break;
+ if (!((next_id+1) % 10000))
+ printf("%llu added\n", (unsigned long long)next_id+1);
+ }
+ return NULL;
+}
+
+void *threaded_search(void *ctxptr) {
+ struct threadContext *ctx = ctxptr;
+
+ /* Search query */
+ hnswNode *neighbors[10];
+ float distances[10];
+ int found = 0;
+ uint64_t last_id = 0;
+
+ while(ctx->id < 1000000) {
+ int slot = hnsw_acquire_read_slot(ctx->index);
+ found = hnsw_search(ctx->index, ctx->search_vector, 10, neighbors, distances, slot, 0);
+ hnsw_release_read_slot(ctx->index,slot);
+ last_id = ++ctx->id;
+ }
+
+ if (found > 0 && last_id == 1000000) {
+ printf("Found %d neighbors:\n", found);
+ for (int i = 0; i < found; i++) {
+ printf("Node ID: %llu, distance: %f, word: %s\n",
+ (unsigned long long)neighbors[i]->id,
+ distances[i], (char*)neighbors[i]->value);
+ }
+ }
+ return NULL;
+}
+
+int w2v_multi_thread(int m_param, int numthreads, int quantization, uint64_t numele) {
+ /* Create index */
+ struct threadContext ctx;
+
+ ctx.index = hnsw_new(300, quantization, m_param);
+
+ ctx.fp = fopen("word2vec.bin","rb");
+ if (ctx.fp == NULL) {
+ perror("word2vec.bin file missing");
+ exit(1);
+ }
+
+ unsigned char header[8];
+ if (fread(header,8,1,ctx.fp) <= 0) { // Skip header
+ perror("Unexpected EOF");
+ exit(1);
+ }
+ pthread_mutex_init(&ctx.FileAccessMutex,NULL);
+
+ uint64_t start_time = ms_time();
+ ctx.id = 0;
+ ctx.numele = numele;
+ pthread_t threads[numthreads];
+ for (int j = 0; j < numthreads; j++)
+ pthread_create(&threads[j], NULL, threaded_insert, &ctx);
+
+ // Wait for all the threads to terminate adding items.
+ for (int j = 0; j < numthreads; j++)
+ pthread_join(threads[j],NULL);
+
+ uint64_t elapsed = ms_time() - start_time;
+ fclose(ctx.fp);
+
+ // Obtain the last word.
+ hnswNode *node = ctx.index->head;
+ char *word = node->value;
+
+ // We will search this last inserted word in the next test.
+ // Let's save its embedding.
+ ctx.search_vector = malloc(sizeof(float)*300);
+ hnsw_get_node_vector(ctx.index,node,ctx.search_vector);
+
+ printf("%llu words added (%llu words/sec), last word: %s\n",
+ (unsigned long long)ctx.index->node_count,
+ (unsigned long long)ctx.id*1000/elapsed, word);
+
+ /* Search query */
+ start_time = ms_time();
+ ctx.id = 0; // We will use this atomic field to stop at N queries done.
+
+ for (int j = 0; j < numthreads; j++)
+ pthread_create(&threads[j], NULL, threaded_search, &ctx);
+
+ // Wait for all the threads to terminate searching.
+ for (int j = 0; j < numthreads; j++)
+ pthread_join(threads[j],NULL);
+
+ elapsed = ms_time() - start_time;
+ printf("%llu searches performed (%llu searches/sec)\n",
+ (unsigned long long)ctx.id,
+ (unsigned long long)ctx.id*1000/elapsed);
+
+ hnsw_print_stats(ctx.index);
+ uint64_t connected_nodes;
+ int reciprocal_links;
+ hnsw_validate_graph(ctx.index, &connected_nodes, &reciprocal_links);
+ printf("%llu connected nodes. Links all reciprocal: %d\n",
+ (unsigned long long)connected_nodes, reciprocal_links);
+ hnsw_free(ctx.index,free);
+ return 0;
+}
+
+int main(int argc, char **argv) {
+ int quantization = HNSW_QUANT_NONE;
+ int numthreads = 0;
+ uint64_t numele = 20000;
+ int m_param = 0; // Default value (0 means use HNSW_DEFAULT_M)
+
+ /* This you can enable in single thread mode for testing: */
+ int massdel = 0; // If true, does the mass deletion test.
+ int self_recall = 0; // If true, does the self-recall test.
+ int recall_ef = 0; // If not 0, does the recall test with this EF value.
+
+ for (int j = 1; j < argc; j++) {
+ int moreargs = argc-j-1;
+
+ if (!strcasecmp(argv[j],"--quant")) {
+ quantization = HNSW_QUANT_Q8;
+ } else if (!strcasecmp(argv[j],"--bin")) {
+ quantization = HNSW_QUANT_BIN;
+ } else if (!strcasecmp(argv[j],"--mass-del")) {
+ massdel = 1;
+ } else if (!strcasecmp(argv[j],"--self-recall")) {
+ self_recall = 1;
+ } else if (moreargs >= 1 && !strcasecmp(argv[j],"--recall")) {
+ recall_ef = atoi(argv[j+1]);
+ j++;
+ } else if (moreargs >= 1 && !strcasecmp(argv[j],"--threads")) {
+ numthreads = atoi(argv[j+1]);
+ j++;
+ } else if (moreargs >= 1 && !strcasecmp(argv[j],"--numele")) {
+ numele = strtoll(argv[j+1],NULL,0);
+ j++;
+ if (numele < 1) numele = 1;
+ } else if (moreargs >= 1 && !strcasecmp(argv[j],"--m")) {
+ m_param = atoi(argv[j+1]);
+ j++;
+ } else if (!strcasecmp(argv[j],"--help")) {
+ printf("%s [--quant] [--bin] [--thread <count>] [--numele <count>] [--m <count>] [--mass-del] [--self-recall] [--recall <ef>]\n", argv[0]);
+ exit(0);
+ } else {
+ printf("Unrecognized option or wrong number of arguments: %s\n", argv[j]);
+ exit(1);
+ }
+ }
+
+ if (quantization == HNSW_QUANT_NONE) {
+ printf("You can enable quantization with --quant\n");
+ }
+
+ if (numthreads > 0) {
+ w2v_multi_thread(m_param, numthreads, quantization, numele);
+ } else {
+ printf("Single thread execution. Use --threads 4 for concurrent API\n");
+ w2v_single_thread(m_param, quantization, numele, massdel, self_recall, recall_ef);
+ }
+}