summaryrefslogtreecommitdiff
path: root/examples/redis-unstable/modules/vector-sets/w2v.c
diff options
context:
space:
mode:
Diffstat (limited to 'examples/redis-unstable/modules/vector-sets/w2v.c')
-rw-r--r--examples/redis-unstable/modules/vector-sets/w2v.c539
1 files changed, 539 insertions, 0 deletions
diff --git a/examples/redis-unstable/modules/vector-sets/w2v.c b/examples/redis-unstable/modules/vector-sets/w2v.c
new file mode 100644
index 0000000..bcf6338
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/w2v.c
@@ -0,0 +1,539 @@
1/*
2 * HNSW (Hierarchical Navigable Small World) Implementation
3 * Based on the paper by Yu. A. Malkov, D. A. Yashunin
4 *
5 * Copyright (c) 2009-Present, Redis Ltd.
6 * All rights reserved.
7 *
8 * Licensed under your choice of (a) the Redis Source Available License 2.0
9 * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
10 * GNU Affero General Public License v3 (AGPLv3).
11 * Originally authored by: Salvatore Sanfilippo
12 */
13
14#define _DEFAULT_SOURCE
15#define _USE_MATH_DEFINES
16#define _POSIX_C_SOURCE 200809L
17
18#include <stdio.h>
19#include <stdlib.h>
20#include <string.h>
21#include <strings.h>
22#include <sys/time.h>
23#include <time.h>
24#include <stdint.h>
25#include <pthread.h>
26#include <stdatomic.h>
27#include <math.h>
28
29#include "hnsw.h"
30
31/* Get current time in milliseconds */
32uint64_t ms_time(void) {
33 struct timeval tv;
34 gettimeofday(&tv, NULL);
35 return (uint64_t)tv.tv_sec * 1000 + (tv.tv_usec / 1000);
36}
37
38/* Implementation of the recall test with random vectors. */
39void test_recall(HNSW *index, int ef) {
40 const int num_test_vectors = 10000;
41 const int k = 100; // Number of nearest neighbors to find.
42 if (ef < k) ef = k;
43
44 // Add recall distribution counters (2% bins from 0-100%).
45 int recall_bins[50] = {0};
46
47 // Create array to store vectors for mixing.
48 int num_source_vectors = 1000; // Enough, since we mix them.
49 float **source_vectors = malloc(sizeof(float*) * num_source_vectors);
50 if (!source_vectors) {
51 printf("Failed to allocate memory for source vectors\n");
52 return;
53 }
54
55 // Allocate memory for each source vector.
56 for (int i = 0; i < num_source_vectors; i++) {
57 source_vectors[i] = malloc(sizeof(float) * 300);
58 if (!source_vectors[i]) {
59 printf("Failed to allocate memory for source vector %d\n", i);
60 // Clean up already allocated vectors.
61 for (int j = 0; j < i; j++) free(source_vectors[j]);
62 free(source_vectors);
63 return;
64 }
65 }
66
67 /* Populate source vectors from the index, we just scan the
68 * first N items. */
69 int source_count = 0;
70 hnswNode *current = index->head;
71 while (current && source_count < num_source_vectors) {
72 hnsw_get_node_vector(index, current, source_vectors[source_count]);
73 source_count++;
74 current = current->next;
75 }
76
77 if (source_count < num_source_vectors) {
78 printf("Warning: Only found %d nodes for source vectors\n",
79 source_count);
80 num_source_vectors = source_count;
81 }
82
83 // Allocate memory for test vector.
84 float *test_vector = malloc(sizeof(float) * 300);
85 if (!test_vector) {
86 printf("Failed to allocate memory for test vector\n");
87 for (int i = 0; i < num_source_vectors; i++) {
88 free(source_vectors[i]);
89 }
90 free(source_vectors);
91 return;
92 }
93
94 // Allocate memory for results.
95 hnswNode **hnsw_results = malloc(sizeof(hnswNode*) * ef);
96 hnswNode **linear_results = malloc(sizeof(hnswNode*) * ef);
97 float *hnsw_distances = malloc(sizeof(float) * ef);
98 float *linear_distances = malloc(sizeof(float) * ef);
99
100 if (!hnsw_results || !linear_results || !hnsw_distances || !linear_distances) {
101 printf("Failed to allocate memory for results\n");
102 if (hnsw_results) free(hnsw_results);
103 if (linear_results) free(linear_results);
104 if (hnsw_distances) free(hnsw_distances);
105 if (linear_distances) free(linear_distances);
106 for (int i = 0; i < num_source_vectors; i++) free(source_vectors[i]);
107 free(source_vectors);
108 free(test_vector);
109 return;
110 }
111
112 // Initialize random seed.
113 srand(time(NULL));
114
115 // Perform recall test.
116 printf("\nPerforming recall test with EF=%d on %d random vectors...\n",
117 ef, num_test_vectors);
118 double total_recall = 0.0;
119
120 for (int t = 0; t < num_test_vectors; t++) {
121 // Create a random vector by mixing 3 existing vectors.
122 float weights[3] = {0.0};
123 int src_indices[3] = {0};
124
125 // Generate random weights.
126 float weight_sum = 0.0;
127 for (int i = 0; i < 3; i++) {
128 weights[i] = (float)rand() / RAND_MAX;
129 weight_sum += weights[i];
130 src_indices[i] = rand() % num_source_vectors;
131 }
132
133 // Normalize weights.
134 for (int i = 0; i < 3; i++) weights[i] /= weight_sum;
135
136 // Mix vectors.
137 memset(test_vector, 0, sizeof(float) * 300);
138 for (int i = 0; i < 3; i++) {
139 for (int j = 0; j < 300; j++) {
140 test_vector[j] +=
141 weights[i] * source_vectors[src_indices[i]][j];
142 }
143 }
144
145 // Perform HNSW search with the specified EF parameter.
146 int slot = hnsw_acquire_read_slot(index);
147 int hnsw_found = hnsw_search(index, test_vector, ef, hnsw_results, hnsw_distances, slot, 0);
148
149 // Perform linear search (ground truth).
150 int linear_found = hnsw_ground_truth_with_filter(index, test_vector, ef, linear_results, linear_distances, slot, 0, NULL, NULL);
151 hnsw_release_read_slot(index, slot);
152
153 // Calculate recall for this query (intersection size / k).
154 if (hnsw_found > k) hnsw_found = k;
155 if (linear_found > k) linear_found = k;
156 int intersection_count = 0;
157 for (int i = 0; i < linear_found; i++) {
158 for (int j = 0; j < hnsw_found; j++) {
159 if (linear_results[i] == hnsw_results[j]) {
160 intersection_count++;
161 break;
162 }
163 }
164 }
165
166 double recall = (double)intersection_count / linear_found;
167 total_recall += recall;
168
169 // Add to distribution bins (2% steps)
170 int bin_index = (int)(recall * 50);
171 if (bin_index >= 50) bin_index = 49; // Handle 100% recall case
172 recall_bins[bin_index]++;
173
174 // Show progress.
175 if ((t+1) % 1000 == 0 || t == num_test_vectors-1) {
176 printf("Processed %d/%d queries, current avg recall: %.2f%%\n",
177 t+1, num_test_vectors, (total_recall / (t+1)) * 100);
178 }
179 }
180
181 // Calculate and print final average recall.
182 double avg_recall = (total_recall / num_test_vectors) * 100;
183 printf("\nRecall Test Results:\n");
184 printf("Average recall@%d (EF=%d): %.2f%%\n", k, ef, avg_recall);
185
186 // Print recall distribution histogram.
187 printf("\nRecall Distribution (2%% bins):\n");
188 printf("================================\n");
189
190 // Find the maximum bin count for scaling.
191 int max_count = 0;
192 for (int i = 0; i < 50; i++) {
193 if (recall_bins[i] > max_count) max_count = recall_bins[i];
194 }
195
196 // Scale factor for histogram (max 50 chars wide)
197 const int max_bars = 50;
198 double scale = (max_count > max_bars) ? (double)max_bars / max_count : 1.0;
199
200 // Print the histogram.
201 for (int i = 0; i < 50; i++) {
202 int bar_len = (int)(recall_bins[i] * scale);
203 printf("%3d%%-%-3d%% | %-6d |", i*2, (i+1)*2, recall_bins[i]);
204 for (int j = 0; j < bar_len; j++) printf("#");
205 printf("\n");
206 }
207
208 // Cleanup.
209 free(hnsw_results);
210 free(linear_results);
211 free(hnsw_distances);
212 free(linear_distances);
213 free(test_vector);
214 for (int i = 0; i < num_source_vectors; i++) free(source_vectors[i]);
215 free(source_vectors);
216}
217
218/* Example usage in main() */
219int w2v_single_thread(int m_param, int quantization, uint64_t numele, int massdel, int self_recall, int recall_ef) {
220 /* Create index */
221 HNSW *index = hnsw_new(300, quantization, m_param);
222 float v[300];
223 uint16_t wlen;
224
225 FILE *fp = fopen("word2vec.bin","rb");
226 if (fp == NULL) {
227 perror("word2vec.bin file missing");
228 exit(1);
229 }
230 unsigned char header[8];
231 if (fread(header,8,1,fp) <= 0) { // Skip header
232 perror("Unexpected EOF");
233 exit(1);
234 }
235
236 uint64_t id = 0;
237 uint64_t start_time = ms_time();
238 char *word = NULL;
239 hnswNode *search_node = NULL;
240
241 while(id < numele) {
242 if (fread(&wlen,2,1,fp) == 0) break;
243 word = malloc(wlen+1);
244 if (fread(word,wlen,1,fp) <= 0) {
245 perror("unexpected EOF");
246 exit(1);
247 }
248 word[wlen] = 0;
249 if (fread(v,300*sizeof(float),1,fp) <= 0) {
250 perror("unexpected EOF");
251 exit(1);
252 }
253
254 // Plain API that acquires a write lock for the whole time.
255 hnswNode *added = hnsw_insert(index, v, NULL, 0, id++, word, 200);
256
257 if (!strcmp(word,"banana")) search_node = added;
258 if (!(id % 10000)) printf("%llu added\n", (unsigned long long)id);
259 }
260 uint64_t elapsed = ms_time() - start_time;
261 fclose(fp);
262
263 printf("%llu words added (%llu words/sec), last word: %s\n",
264 (unsigned long long)index->node_count,
265 (unsigned long long)id*1000/elapsed, word);
266
267 /* Search query */
268 if (search_node == NULL) search_node = index->head;
269 hnsw_get_node_vector(index,search_node,v);
270 hnswNode *neighbors[10];
271 float distances[10];
272
273 int found, j;
274 start_time = ms_time();
275 for (j = 0; j < 20000; j++)
276 found = hnsw_search(index, v, 10, neighbors, distances, 0, 0);
277 elapsed = ms_time() - start_time;
278 printf("%d searches performed (%llu searches/sec), nodes found: %d\n",
279 j, (unsigned long long)j*1000/elapsed, found);
280
281 if (found > 0) {
282 printf("Found %d neighbors:\n", found);
283 for (int i = 0; i < found; i++) {
284 printf("Node ID: %llu, distance: %f, word: %s\n",
285 (unsigned long long)neighbors[i]->id,
286 distances[i], (char*)neighbors[i]->value);
287 }
288 }
289
290 // Self-recall test (ability to find the node by its own vector).
291 if (self_recall) {
292 hnsw_print_stats(index);
293 hnsw_test_graph_recall(index,200,0);
294 }
295
296 // Recall test with random vectors.
297 if (recall_ef > 0) {
298 test_recall(index, recall_ef);
299 }
300
301 uint64_t connected_nodes;
302 int reciprocal_links;
303 hnsw_validate_graph(index, &connected_nodes, &reciprocal_links);
304
305 if (massdel) {
306 int remove_perc = 95;
307 printf("\nRemoving %d%% of nodes...\n", remove_perc);
308 uint64_t initial_nodes = index->node_count;
309
310 hnswNode *current = index->head;
311 while (current && index->node_count > initial_nodes*(100-remove_perc)/100) {
312 hnswNode *next = current->next;
313 hnsw_delete_node(index,current,free);
314 current = next;
315 // In order to don't remove only contiguous nodes, from time
316 // skip a node.
317 if (current && !(random() % remove_perc)) current = current->next;
318 }
319 printf("%llu nodes left\n", (unsigned long long)index->node_count);
320
321 // Test again.
322 hnsw_validate_graph(index, &connected_nodes, &reciprocal_links);
323 hnsw_test_graph_recall(index,200,0);
324 }
325
326 hnsw_free(index,free);
327 return 0;
328}
329
330struct threadContext {
331 pthread_mutex_t FileAccessMutex;
332 uint64_t numele;
333 _Atomic uint64_t SearchesDone;
334 _Atomic uint64_t id;
335 FILE *fp;
336 HNSW *index;
337 float *search_vector;
338};
339
340// Note that in practical terms inserting with many concurrent threads
341// may be *slower* and not faster, because there is a lot of
342// contention. So this is more a robustness test than anything else.
343//
344// The optimistic commit API goal is actually to exploit the ability to
345// add faster when there are many concurrent reads.
346void *threaded_insert(void *ctxptr) {
347 struct threadContext *ctx = ctxptr;
348 char *word;
349 float v[300];
350 uint16_t wlen;
351
352 while(1) {
353 pthread_mutex_lock(&ctx->FileAccessMutex);
354 if (fread(&wlen,2,1,ctx->fp) == 0) break;
355 pthread_mutex_unlock(&ctx->FileAccessMutex);
356 word = malloc(wlen+1);
357 if (fread(word,wlen,1,ctx->fp) <= 0) {
358 perror("Unexpected EOF");
359 exit(1);
360 }
361
362 word[wlen] = 0;
363 if (fread(v,300*sizeof(float),1,ctx->fp) <= 0) {
364 perror("Unexpected EOF");
365 exit(1);
366 }
367
368 // Check-and-set API that performs the costly scan for similar
369 // nodes concurrently with other read threads, and finally
370 // applies the check if the graph wasn't modified.
371 InsertContext *ic;
372 uint64_t next_id = ctx->id++;
373 ic = hnsw_prepare_insert(ctx->index, v, NULL, 0, next_id, 200);
374 if (hnsw_try_commit_insert(ctx->index, ic, word) == NULL) {
375 // This time try locking since the start.
376 hnsw_insert(ctx->index, v, NULL, 0, next_id, word, 200);
377 }
378
379 if (next_id >= ctx->numele) break;
380 if (!((next_id+1) % 10000))
381 printf("%llu added\n", (unsigned long long)next_id+1);
382 }
383 return NULL;
384}
385
386void *threaded_search(void *ctxptr) {
387 struct threadContext *ctx = ctxptr;
388
389 /* Search query */
390 hnswNode *neighbors[10];
391 float distances[10];
392 int found = 0;
393 uint64_t last_id = 0;
394
395 while(ctx->id < 1000000) {
396 int slot = hnsw_acquire_read_slot(ctx->index);
397 found = hnsw_search(ctx->index, ctx->search_vector, 10, neighbors, distances, slot, 0);
398 hnsw_release_read_slot(ctx->index,slot);
399 last_id = ++ctx->id;
400 }
401
402 if (found > 0 && last_id == 1000000) {
403 printf("Found %d neighbors:\n", found);
404 for (int i = 0; i < found; i++) {
405 printf("Node ID: %llu, distance: %f, word: %s\n",
406 (unsigned long long)neighbors[i]->id,
407 distances[i], (char*)neighbors[i]->value);
408 }
409 }
410 return NULL;
411}
412
413int w2v_multi_thread(int m_param, int numthreads, int quantization, uint64_t numele) {
414 /* Create index */
415 struct threadContext ctx;
416
417 ctx.index = hnsw_new(300, quantization, m_param);
418
419 ctx.fp = fopen("word2vec.bin","rb");
420 if (ctx.fp == NULL) {
421 perror("word2vec.bin file missing");
422 exit(1);
423 }
424
425 unsigned char header[8];
426 if (fread(header,8,1,ctx.fp) <= 0) { // Skip header
427 perror("Unexpected EOF");
428 exit(1);
429 }
430 pthread_mutex_init(&ctx.FileAccessMutex,NULL);
431
432 uint64_t start_time = ms_time();
433 ctx.id = 0;
434 ctx.numele = numele;
435 pthread_t threads[numthreads];
436 for (int j = 0; j < numthreads; j++)
437 pthread_create(&threads[j], NULL, threaded_insert, &ctx);
438
439 // Wait for all the threads to terminate adding items.
440 for (int j = 0; j < numthreads; j++)
441 pthread_join(threads[j],NULL);
442
443 uint64_t elapsed = ms_time() - start_time;
444 fclose(ctx.fp);
445
446 // Obtain the last word.
447 hnswNode *node = ctx.index->head;
448 char *word = node->value;
449
450 // We will search this last inserted word in the next test.
451 // Let's save its embedding.
452 ctx.search_vector = malloc(sizeof(float)*300);
453 hnsw_get_node_vector(ctx.index,node,ctx.search_vector);
454
455 printf("%llu words added (%llu words/sec), last word: %s\n",
456 (unsigned long long)ctx.index->node_count,
457 (unsigned long long)ctx.id*1000/elapsed, word);
458
459 /* Search query */
460 start_time = ms_time();
461 ctx.id = 0; // We will use this atomic field to stop at N queries done.
462
463 for (int j = 0; j < numthreads; j++)
464 pthread_create(&threads[j], NULL, threaded_search, &ctx);
465
466 // Wait for all the threads to terminate searching.
467 for (int j = 0; j < numthreads; j++)
468 pthread_join(threads[j],NULL);
469
470 elapsed = ms_time() - start_time;
471 printf("%llu searches performed (%llu searches/sec)\n",
472 (unsigned long long)ctx.id,
473 (unsigned long long)ctx.id*1000/elapsed);
474
475 hnsw_print_stats(ctx.index);
476 uint64_t connected_nodes;
477 int reciprocal_links;
478 hnsw_validate_graph(ctx.index, &connected_nodes, &reciprocal_links);
479 printf("%llu connected nodes. Links all reciprocal: %d\n",
480 (unsigned long long)connected_nodes, reciprocal_links);
481 hnsw_free(ctx.index,free);
482 return 0;
483}
484
485int main(int argc, char **argv) {
486 int quantization = HNSW_QUANT_NONE;
487 int numthreads = 0;
488 uint64_t numele = 20000;
489 int m_param = 0; // Default value (0 means use HNSW_DEFAULT_M)
490
491 /* This you can enable in single thread mode for testing: */
492 int massdel = 0; // If true, does the mass deletion test.
493 int self_recall = 0; // If true, does the self-recall test.
494 int recall_ef = 0; // If not 0, does the recall test with this EF value.
495
496 for (int j = 1; j < argc; j++) {
497 int moreargs = argc-j-1;
498
499 if (!strcasecmp(argv[j],"--quant")) {
500 quantization = HNSW_QUANT_Q8;
501 } else if (!strcasecmp(argv[j],"--bin")) {
502 quantization = HNSW_QUANT_BIN;
503 } else if (!strcasecmp(argv[j],"--mass-del")) {
504 massdel = 1;
505 } else if (!strcasecmp(argv[j],"--self-recall")) {
506 self_recall = 1;
507 } else if (moreargs >= 1 && !strcasecmp(argv[j],"--recall")) {
508 recall_ef = atoi(argv[j+1]);
509 j++;
510 } else if (moreargs >= 1 && !strcasecmp(argv[j],"--threads")) {
511 numthreads = atoi(argv[j+1]);
512 j++;
513 } else if (moreargs >= 1 && !strcasecmp(argv[j],"--numele")) {
514 numele = strtoll(argv[j+1],NULL,0);
515 j++;
516 if (numele < 1) numele = 1;
517 } else if (moreargs >= 1 && !strcasecmp(argv[j],"--m")) {
518 m_param = atoi(argv[j+1]);
519 j++;
520 } else if (!strcasecmp(argv[j],"--help")) {
521 printf("%s [--quant] [--bin] [--thread <count>] [--numele <count>] [--m <count>] [--mass-del] [--self-recall] [--recall <ef>]\n", argv[0]);
522 exit(0);
523 } else {
524 printf("Unrecognized option or wrong number of arguments: %s\n", argv[j]);
525 exit(1);
526 }
527 }
528
529 if (quantization == HNSW_QUANT_NONE) {
530 printf("You can enable quantization with --quant\n");
531 }
532
533 if (numthreads > 0) {
534 w2v_multi_thread(m_param, numthreads, quantization, numele);
535 } else {
536 printf("Single thread execution. Use --threads 4 for concurrent API\n");
537 w2v_single_thread(m_param, quantization, numele, massdel, self_recall, recall_ef);
538 }
539}