aboutsummaryrefslogtreecommitdiff
path: root/examples/redis-unstable/modules/vector-sets/vset.c
diff options
context:
space:
mode:
Diffstat (limited to 'examples/redis-unstable/modules/vector-sets/vset.c')
-rw-r--r--examples/redis-unstable/modules/vector-sets/vset.c2587
1 files changed, 0 insertions, 2587 deletions
diff --git a/examples/redis-unstable/modules/vector-sets/vset.c b/examples/redis-unstable/modules/vector-sets/vset.c
deleted file mode 100644
index 500f8e9..0000000
--- a/examples/redis-unstable/modules/vector-sets/vset.c
+++ /dev/null
@@ -1,2587 +0,0 @@
1/* Redis implementation for vector sets. The data structure itself
2 * is implemented in hnsw.c.
3 *
4 * Copyright (c) 2009-Present, Redis Ltd.
5 * All rights reserved.
6 *
7 * Licensed under your choice of (a) the Redis Source Available License 2.0
8 * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
9 * GNU Affero General Public License v3 (AGPLv3).
10 * Originally authored by: Salvatore Sanfilippo.
11 *
12 * ======================== Understand threading model =========================
13 * This code implements threaded operations for two of the commands:
14 *
15 * 1. VSIM, by default.
16 * 2. VADD, if the CAS option is specified.
17 *
18 * Note that even if the second operation, VADD, is a write operation, only
19 * the neighbors collection for the new node is performed in a thread: then,
20 * the actual insert is performed in the reply callback VADD_CASReply(),
21 * which is executed in the main thread.
22 *
23 * Threaded operations need us to protect various operations with mutexes,
24 * even if a certain degree of protection is already provided by the HNSW
25 * library. Here are a few very important things about this implementation
26 * and the way locking is performed.
27 *
28 * 1. All the write operations are performed in the main Redis thread:
29 * this also include VADD_CASReply() callback, that is called by Redis
30 * internals only in the context of the main thread. However the HNSW
31 * library allows background threads in hnsw_search() (VSIM) to modify
32 * nodes metadata to speedup search (to understand if a node was already
33 * visited), but this only happens after acquiring a specific lock
34 * for a given "read slot".
35 *
36 * 2. We use a global lock for each Vector Set object, called "in_use". This
37 * lock is a read-write lock, and is acquired in read mode by all the
38 * threads that perform reads in the background. It is only acquired in
39 * write mode by vectorSetWaitAllBackgroundClients(): the function acquires
40 * the lock and immediately releases it, with the effect of waiting all the
41 * background threads still running from ending their execution.
42 *
43 * Note that no thread can be spawned, since we only call
44 * vectorSetWaitAllBackgroundClients() from the main Redis thread, that
45 * is also the only thread spawning other threads.
46 *
47 * vectorSetWaitAllBackgroundClients() is used in two ways:
48 * A) When we need to delete a vector set because of (DEL) or other
49 * operations destroying the object, we need to wait that all the
50 * background threads working with this object finished their work.
51 * B) When we modify the HNSW nodes bypassing the normal locking
52 * provided by the HNSW library. This only happens when we update
53 * an existing node attribute so far, in VSETATTR and when we call
54 * VADD to update a node with the SETATTR option.
55 *
56 * 3. Often during read operations performed by Redis commands in the
57 * main thread (VCARD, VEMB, VRANDMEMBER, ...) we don't acquire any
58 * lock at all. The commands run in the main Redis thread, we can only
59 * have, at the same time, background reads against the same data
60 * structure. Note that VSIM_thread() and VADD_thread() still modify the
61 * read slot metadata, that is node->visited_epoch[slot], but as long as
62 * our read commands running in the main thread don't need to use
63 * hnsw_search() or other HNSW functions using the visited epochs slots
64 * we are safe.
65 *
66 * 4. There is a race from the moment we create a thread, passing the
67 * vector set object, to the moment the thread can actually lock the
68 * object with the in_use_lock rwlock: as the thread starts, in the meanwhile
69 * a DEL/expire could trigger and remove the object. For this reason
70 * we use an atomic counter that protects our object for this small
71 * time in vectorSetWaitAllBackgroundClients(). This prevents removal
72 * of objects that are about to be taken by threads.
73 *
74 * Note that other competing solutions could be used to fix the problem
75 * but have their set of issues, however they are worth documenting here
76 * and evaluating in the future:
77 *
78 * A. Using a conditional variable we could "wait" for the thread to
79 * acquire the lock. However this means waiting before returning
80 * to the event loop, and would make the command execution slower.
81 * B. We could use again an atomic variable, like we did, but this time
82 * as a refcount for the object, with a vsetAcquire() vsetRelease().
83 * In this case, the command could retain the object in the main thread
84 * before starting the thread, and the thread, after the work is done,
85 * could release it. This way sometimes the object would be freed by
86 * the thread, and while right now it may be safe to do the kind of
87 * resource deallocation that vectorSetReleaseObject() does, given that
88 * the Redis Modules API is not always thread safe this solution may not
89 * be future-proof. However it is worth evaluating it better in the
90 * future.
91 * C. We could use the "B" solution but instead of freeing the object
92 * in the thread, in this specific case we could just put it into a
93 * list and defer it for later freeing (for instance in the reply
94 * callback), so that the object is always freed in the main thread.
95 * This would require a list of objects to free.
96 *
97 * However the only disadvantage of the current solution is the potential
98 * busy loop, but this busy loop in practical terms will almost never do
99 * much: to trigger it, a number of circumstances must happen: deleting
100 * Vector Set keys while using them, hitting the small window needed to
101 * start the thread and read-lock the mutex.
102 */
103
104#define _DEFAULT_SOURCE
105#define _USE_MATH_DEFINES
106#define _POSIX_C_SOURCE 200809L
107
108#include "../../src/redismodule.h"
109#include <stdio.h>
110#include <stdlib.h>
111#include <ctype.h>
112#include <string.h>
113#include <strings.h>
114#include <stdint.h>
115#include <math.h>
116#include <pthread.h>
117#include <stdatomic.h>
118#include "hnsw.h"
119#include "vset_config.h"
120
121// We inline directly the expression implementation here so that building
122// the module is trivial.
123#include "expr.c"
124
/* Module data type handle for Vector Set keys. It is compared against
 * RedisModule_ModuleTypeGetType() to validate a key, and passed to
 * RedisModule_ModuleTypeSetValue() when a new set is stored. */
static RedisModuleType *VectorSetType;

/* Next value handed out as vsetObject->id by createVectorSetObject().
 * It is incremented every time an object creation is attempted. */
static uint64_t VectorSetTypeNextId = 0;

// Default EF value if not specified during creation.
#define VSET_DEFAULT_C_EF 200

// Default EF value if not specified during search.
#define VSET_DEFAULT_SEARCH_EF 100

// Default num elements returned by VSIM.
#define VSET_DEFAULT_COUNT 10
136
137/* ========================== Internal data structure ====================== */
138
/* Our abstract data type needs a dual representation similar to Redis
 * sorted set: the proximity graph, and also a element -> graph-node map
 * that will allow us to perform deletions and other operations that have
 * as input the element itself. */
struct vsetObject {
    HNSW *hnsw;                 // Proximity graph.
    RedisModuleDict *dict;      // Element -> node mapping.
    float *proj_matrix;         // Projection matrix, NULL if no projection.
    uint32_t proj_input_size;   // Dimension of the vectors BEFORE projection
                                // (what the user must provide as input).
                                // Output dimension is implicit in
                                // hnsw->vector_dim.
    pthread_rwlock_t in_use_lock;   // Lock needed to destroy the object safely.
    uint64_t id;                // Unique ID used by threaded VADD to know the
                                // object is still the same.
    uint64_t numattribs;        // Number of nodes associated with an attribute.
    atomic_int thread_creation_pending; // Number of threads that are currently
                                        // pending to lock the object.
};
157
/* Each node has two associated values: the associated string (the item
 * in the set) and potentially a JSON string, that is, the attributes, used
 * for hybrid search with the VSIM FILTER option. */
struct vsetNodeVal {
    RedisModuleString *item;    // The element itself. Always set.
    RedisModuleString *attrib;  // Attributes string, or NULL if the element
                                // has no attributes.
};
165
/* Return the population count (Hamming weight) of 'n', that is, how many
 * of its bits are set to one. Plain portable C: no compiler builtins are
 * used. */
static inline uint32_t bit_count(uint32_t n) {
    uint32_t ones = 0;
    /* n & (n-1) clears the lowest set bit, so the loop body runs exactly
     * once per bit set in the original value. */
    for (; n != 0; n &= n - 1) ones++;
    return ones;
}
177
/* Create a Hadamard-based projection matrix for dimensionality reduction.
 * Uses {-1, +1} entries with a pattern based on bit operations:
 *
 *   matrix[i][j] = (popcount(i & j) is even) ? +1 : -1
 *
 * Every entry is then scaled by 1/sqrt(input_dim) for normalization.
 *
 * The matrix is stored row-major as 'output_dim' rows of 'input_dim'
 * floats, in a buffer obtained from RedisModule_Alloc() that the caller
 * owns. The function performs no NULL check on the allocation.
 *
 * Note that compared to other approaches (random gaussian weights), what
 * we have here is deterministic, it means that our replicas will have
 * the same set of weights. Also this approach seems to work much better
 * in practice, and the distances between elements are better guaranteed.
 *
 * Note that we still save the projection matrix in the RDB file, because
 * in the future we may change the weights generation, and we want everything
 * to be backward compatible. */
float *createProjectionMatrix(uint32_t input_dim, uint32_t output_dim) {
    float *matrix = RedisModule_Alloc(sizeof(float) * input_dim * output_dim);

    /* Scale factor to normalize the projection. */
    const float scale = 1.0f / sqrt(input_dim);

    /* Fill the matrix using Hadamard pattern. */
    for (uint32_t i = 0; i < output_dim; i++) {
        for (uint32_t j = 0; j < input_dim; j++) {
            /* Position in the flattened matrix. Computed as size_t: the
             * product of two 32 bit dimensions does not always fit in
             * 32 bits, and a wrapped index would write the wrong cell. */
            size_t pos = (size_t)i * input_dim + j;

            /* Hadamard pattern: use bit operations to determine sign
             * If the count of 1-bits in the bitwise AND of i and j is even,
             * the value is 1, otherwise -1. */
            int value = (bit_count(i & j) % 2 == 0) ? 1 : -1;

            /* Store the scaled value. */
            matrix[pos] = value * scale;
        }
    }
    return matrix;
}
215
/* Project 'input' (of 'input_dim' components) through 'proj_matrix' and
 * return the result as a newly allocated vector of 'output_dim' floats.
 *
 * 'proj_matrix' is read as a row-major matrix of 'output_dim' rows, each
 * made of 'input_dim' floats: output[i] is the dot product between row i
 * and the input vector. The returned buffer comes from RedisModule_Alloc()
 * and is owned by the caller. */
float *applyProjection(const float *input, const float *proj_matrix,
                       uint32_t input_dim, uint32_t output_dim)
{
    float *output = RedisModule_Alloc(sizeof(float) * output_dim);

    for (uint32_t i = 0; i < output_dim; i++) {
        /* Row offset computed as size_t so that it can't wrap around
         * 32 bits when the matrix is very large. */
        const float *row = proj_matrix + (size_t)i * input_dim;
        float sum = 0.0f;
        for (uint32_t j = 0; j < input_dim; j++) {
            sum += row[j] * input[j];
        }
        output[i] = sum;
    }
    return output;
}
232
233/* Create the vector as HNSW+Dictionary combined data structure. */
234struct vsetObject *createVectorSetObject(unsigned int dim, uint32_t quant_type, uint32_t hnsw_M) {
235 struct vsetObject *o;
236 o = RedisModule_Alloc(sizeof(*o));
237
238 o->id = VectorSetTypeNextId++;
239 o->hnsw = hnsw_new(dim,quant_type,hnsw_M);
240 if (!o->hnsw) { // May fail because of mutex creation.
241 RedisModule_Free(o);
242 return NULL;
243 }
244
245 o->dict = RedisModule_CreateDict(NULL);
246 o->proj_matrix = NULL;
247 o->proj_input_size = 0;
248 o->numattribs = 0;
249 o->thread_creation_pending = 0;
250 RedisModule_Assert(pthread_rwlock_init(&o->in_use_lock,NULL) == 0);
251 return o;
252}
253
254void vectorSetReleaseNodeValue(void *v) {
255 struct vsetNodeVal *nv = v;
256 RedisModule_FreeString(NULL,nv->item);
257 if (nv->attrib) RedisModule_FreeString(NULL,nv->attrib);
258 RedisModule_Free(nv);
259}
260
261/* Free the vector set object. */
262void vectorSetReleaseObject(struct vsetObject *o) {
263 if (!o) return;
264 if (o->hnsw) hnsw_free(o->hnsw,vectorSetReleaseNodeValue);
265 if (o->dict) RedisModule_FreeDict(NULL,o->dict);
266 if (o->proj_matrix) RedisModule_Free(o->proj_matrix);
267 pthread_rwlock_destroy(&o->in_use_lock);
268 RedisModule_Free(o);
269}
270
271/* Wait for all the threads performing operations on this
272 * index to terminate their work (locking for write will
273 * wait for all the other threads).
274 *
275 * if 'for_del' is set to 1, we also wait for all the pending threads
276 * that still didn't acquire the lock to finish their work. This
277 * is useful only if we are going to call this function to delete
278 * the object, and not if we want to just to modify it. */
279void vectorSetWaitAllBackgroundClients(struct vsetObject *vset, int for_del) {
280 if (for_del) {
281 // If we are going to destroy the object, after this call, let's
282 // wait for threads that are being created and still didn't had
283 // a chance to acquire the lock.
284 while (vset->thread_creation_pending > 0);
285 }
286 RedisModule_Assert(pthread_rwlock_wrlock(&vset->in_use_lock) == 0);
287 pthread_rwlock_unlock(&vset->in_use_lock);
288}
289
290/* Return a string representing the quantization type name of a vector set. */
291const char *vectorSetGetQuantName(struct vsetObject *o) {
292 switch(o->hnsw->quant_type) {
293 case HNSW_QUANT_NONE: return "f32";
294 case HNSW_QUANT_Q8: return "int8";
295 case HNSW_QUANT_BIN: return "bin";
296 default: return "unknown";
297 }
298}
299
300/* Insert the specified element into the Vector Set.
301 * If update is '1', the existing node will be updated.
302 *
303 * Returns 1 if the element was added, or 0 if the element was already there
304 * and was just updated. */
305int vectorSetInsert(struct vsetObject *o, float *vec, int8_t *qvec, float qrange, RedisModuleString *val, RedisModuleString *attrib, int update, int ef)
306{
307 hnswNode *node = RedisModule_DictGet(o->dict,val,NULL);
308 if (node != NULL) {
309 if (update) {
310 /* Wait for clients in the background: background VSIM
311 * operations touch the nodes attributes we are going
312 * to touch. */
313 vectorSetWaitAllBackgroundClients(o,0);
314
315 struct vsetNodeVal *nv = node->value;
316 /* Pass NULL as value-free function. We want to reuse
317 * the old value. */
318 hnsw_delete_node(o->hnsw, node, NULL);
319 node = hnsw_insert(o->hnsw,vec,qvec,qrange,0,nv,ef);
320 RedisModule_Assert(node != NULL);
321 RedisModule_DictReplace(o->dict,val,node);
322
323 /* If attrib != NULL, the user wants that in case of an update we
324 * update the attribute as well (otherwise it remains as it was).
325 * Note that the order of operations is conceinved so that it
326 * works in case the old attrib and the new attrib pointer is the
327 * same. */
328 if (attrib) {
329 // Empty attribute string means: unset the attribute during
330 // the update.
331 size_t attrlen;
332 RedisModule_StringPtrLen(attrib,&attrlen);
333 if (attrlen != 0) {
334 RedisModule_RetainString(NULL,attrib);
335 o->numattribs++;
336 } else {
337 attrib = NULL;
338 }
339
340 if (nv->attrib) {
341 o->numattribs--;
342 RedisModule_FreeString(NULL,nv->attrib);
343 }
344 nv->attrib = attrib;
345 }
346 }
347 return 0;
348 }
349
350 struct vsetNodeVal *nv = RedisModule_Alloc(sizeof(*nv));
351 nv->item = val;
352 nv->attrib = attrib;
353 node = hnsw_insert(o->hnsw,vec,qvec,qrange,0,nv,ef);
354 if (node == NULL) {
355 // XXX Technically in Redis-land we don't have out of memory, as we
356 // crash on OOM. However the HNSW library may fail for error in the
357 // locking libc call. Probably impossible in practical terms.
358 RedisModule_Free(nv);
359 return 0;
360 }
361 if (attrib != NULL) o->numattribs++;
362 RedisModule_DictSet(o->dict,val,node);
363 RedisModule_RetainString(NULL,val);
364 if (attrib) RedisModule_RetainString(NULL,attrib);
365 return 1;
366}
367
368/* Parse vector from FP32 blob or VALUES format, with optional REDUCE.
369 * Format: [REDUCE dim] FP32|VALUES ...
370 * Returns allocated vector and sets dimension in *dim.
371 * If reduce_dim is not NULL, sets it to the requested reduction dimension.
372 * Returns NULL on parsing error.
373 *
374 * The function sets as a reference *consumed_args, so that the caller
375 * knows how many arguments we consumed in order to parse the input
376 * vector. Remaining arguments are often command options. */
377float *parseVector(RedisModuleString **argv, int argc, int start_idx,
378 size_t *dim, uint32_t *reduce_dim, int *consumed_args)
379{
380 int consumed = 0; // Arguments consumed
381
382 /* Check for REDUCE option first. */
383 if (reduce_dim) *reduce_dim = 0;
384 if (reduce_dim && argc > start_idx + 2 &&
385 !strcasecmp(RedisModule_StringPtrLen(argv[start_idx],NULL),"REDUCE"))
386 {
387 long long rdim;
388 if (RedisModule_StringToLongLong(argv[start_idx+1],&rdim)
389 != REDISMODULE_OK || rdim <= 0)
390 {
391 return NULL;
392 }
393 if (reduce_dim) *reduce_dim = rdim;
394 start_idx += 2; // Skip REDUCE and its argument.
395 consumed += 2;
396 }
397
398 /* Now parse the vector format as before. */
399 float *vec = NULL;
400 const char *vec_format = RedisModule_StringPtrLen(argv[start_idx],NULL);
401
402 if (!strcasecmp(vec_format,"FP32")) {
403 if (argc < start_idx + 2) return NULL; // Need FP32 + vector + value.
404 size_t vec_raw_len;
405 const char *blob =
406 RedisModule_StringPtrLen(argv[start_idx+1],&vec_raw_len);
407
408 // Must be 4 bytes per component.
409 if (vec_raw_len % 4 || vec_raw_len < 4) return NULL;
410 *dim = vec_raw_len/4;
411
412 vec = RedisModule_Alloc(vec_raw_len);
413 if (!vec) return NULL;
414 memcpy(vec,blob,vec_raw_len);
415 consumed += 2;
416 } else if (!strcasecmp(vec_format,"VALUES")) {
417 if (argc < start_idx + 2) return NULL; // Need at least the dimension.
418 long long vdim; // Vector dimension passed by the user.
419 if (RedisModule_StringToLongLong(argv[start_idx+1],&vdim)
420 != REDISMODULE_OK || vdim < 1) return NULL;
421
422 // Check that all the arguments are available.
423 if (argc < start_idx + 2 + vdim) return NULL;
424
425 *dim = vdim;
426 vec = RedisModule_Alloc(sizeof(float) * vdim);
427 if (!vec) return NULL;
428
429 for (int j = 0; j < vdim; j++) {
430 double val;
431 if (RedisModule_StringToDouble(argv[start_idx+2+j],&val)
432 != REDISMODULE_OK)
433 {
434 RedisModule_Free(vec);
435 return NULL;
436 }
437 vec[j] = val;
438 }
439 consumed += vdim + 2;
440 } else {
441 return NULL; // Unknown format.
442 }
443
444 if (consumed_args) *consumed_args = consumed;
445 return vec;
446}
447
448/* ========================== Commands implementation ======================= */
449
450/* VADD thread handling the "CAS" version of the command, that is
451 * performed blocking the client, accumulating here, in the thread, the
452 * set of potential candidates, and later inserting the element in the
453 * key (if it still exists, and if it is still the *same* vector set)
454 * in the Reply callback. */
455void *VADD_thread(void *arg) {
456 pthread_detach(pthread_self());
457
458 void **targ = (void**)arg;
459 RedisModuleBlockedClient *bc = targ[0];
460 struct vsetObject *vset = targ[1];
461 float *vec = targ[3];
462 int ef = (uint64_t)targ[6];
463
464 /* Lock the object and signal that we are no longer pending
465 * the lock acquisition. */
466 RedisModule_Assert(pthread_rwlock_rdlock(&vset->in_use_lock) == 0);
467 vset->thread_creation_pending--;
468
469 /* Look for candidates... */
470 InsertContext *ic = hnsw_prepare_insert(vset->hnsw, vec, NULL, 0, 0, ef);
471 targ[5] = ic; // Pass the context to the reply callback.
472
473 /* Unblock the client so that our read reply will be invoked. */
474 pthread_rwlock_unlock(&vset->in_use_lock);
475 RedisModule_BlockedClientMeasureTimeEnd(bc);
476 RedisModule_UnblockClient(bc,targ); // Use targ as privdata.
477 return NULL;
478}
479
480/* Reply callback for CAS variant of VADD.
481 * Note: this is called in the main thread, in the background thread
482 * we just do the read operation of gathering the neighbors. */
483int VADD_CASReply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
484 (void)argc;
485 RedisModule_AutoMemory(ctx); /* Use automatic memory management. */
486
487 int retval = REDISMODULE_OK;
488 void **targ = (void**)RedisModule_GetBlockedClientPrivateData(ctx);
489 uint64_t vset_id = (unsigned long) targ[2];
490 float *vec = targ[3];
491 RedisModuleString *val = targ[4];
492 InsertContext *ic = targ[5];
493 int ef = (uint64_t)targ[6];
494 RedisModuleString *attrib = targ[7];
495 RedisModule_Free(targ);
496
497 /* Open the key: there are no guarantees it still exists, or contains
498 * a vector set, or even the SAME vector set. */
499 RedisModuleKey *key = RedisModule_OpenKey(ctx,argv[1],
500 REDISMODULE_READ|REDISMODULE_WRITE);
501 int type = RedisModule_KeyType(key);
502 struct vsetObject *vset = NULL;
503
504 if (type != REDISMODULE_KEYTYPE_EMPTY &&
505 RedisModule_ModuleTypeGetType(key) == VectorSetType)
506 {
507 vset = RedisModule_ModuleTypeGetValue(key);
508 // Same vector set?
509 if (vset->id != vset_id) vset = NULL;
510
511 /* Also, if the element was already inserted, we just pretend
512 * the other insert won. We don't even start a threaded VADD
513 * if this was an update, since the deletion of the element itself
514 * in order to perform the update would invalidate the CAS state. */
515 if (vset && RedisModule_DictGet(vset->dict,val,NULL) != NULL)
516 vset = NULL;
517 }
518
519 if (vset == NULL) {
520 /* If the object does not match the start of the operation, we
521 * just pretend the VADD was performed BEFORE the key was deleted
522 * or replaced. We return success but don't do anything. */
523 hnsw_free_insert_context(ic);
524 } else {
525 /* Otherwise try to insert the new element with the neighbors
526 * collected in background. If we fail, do it synchronously again
527 * from scratch. */
528
529 // First: allocate the dual-ported value for the node.
530 struct vsetNodeVal *nv = RedisModule_Alloc(sizeof(*nv));
531 nv->item = val;
532 nv->attrib = attrib;
533
534 /* Then: insert the node in the HNSW data structure. Note that
535 * 'ic' could be NULL in case hnsw_prepare_insert() failed because of
536 * locking failure (likely impossible in practical terms). */
537 hnswNode *newnode;
538 if (ic == NULL ||
539 (newnode = hnsw_try_commit_insert(vset->hnsw, ic, nv)) == NULL)
540 {
541 /* If we are here, the CAS insert failed. We need to insert
542 * again with full locking for neighbors selection and
543 * actual insertion. This time we can't fail: */
544 newnode = hnsw_insert(vset->hnsw, vec, NULL, 0, 0, nv, ef);
545 RedisModule_Assert(newnode != NULL);
546 }
547 RedisModule_DictSet(vset->dict,val,newnode);
548 val = NULL; // Don't free it later.
549 attrib = NULL; // Don't free it later.
550
551 RedisModule_ReplicateVerbatim(ctx);
552 }
553
554 // Whatever happens is a success... :D
555 RedisModule_ReplyWithBool(ctx,1);
556 if (val) RedisModule_FreeString(ctx,val); // Not added? Free it.
557 if (attrib) RedisModule_FreeString(ctx,attrib); // Not added? Free it.
558 RedisModule_Free(vec);
559 return retval;
560}
561
562/* VADD key [REDUCE dim] FP32|VALUES vector value [CAS] [NOQUANT] [BIN] [Q8]
563 * [M count] */
564int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
565 RedisModule_AutoMemory(ctx); /* Use automatic memory management. */
566
567 if (argc < 5) return RedisModule_WrongArity(ctx);
568
569 /* Parse vector with optional REDUCE */
570 size_t dim = 0;
571 uint32_t reduce_dim = 0;
572 int consumed_args;
573 int cas = 0; // Threaded check-and-set style insert.
574 long long ef = VSET_DEFAULT_C_EF; // HNSW creation time EF for new nodes.
575 long long hnsw_create_M = HNSW_DEFAULT_M; // HNSW creation default M value.
576 float *vec = parseVector(argv, argc, 2, &dim, &reduce_dim, &consumed_args);
577 RedisModuleString *attrib = NULL; // Attributes if passed via ATTRIB.
578 if (!vec)
579 return RedisModule_ReplyWithError(ctx,"ERR invalid vector specification");
580
581 /* Missing element string at the end? */
582 if (argc-2-consumed_args < 1) {
583 RedisModule_Free(vec);
584 return RedisModule_WrongArity(ctx);
585 }
586
587 /* Parse options after the element string. */
588 uint32_t quant_type = HNSW_QUANT_Q8; // Default quantization type.
589
590 for (int j = 2 + consumed_args + 1; j < argc; j++) {
591 const char *opt = RedisModule_StringPtrLen(argv[j], NULL);
592 if (!strcasecmp(opt, "CAS")) {
593 cas = 1;
594 } else if (!strcasecmp(opt, "EF") && j+1 < argc) {
595 if (RedisModule_StringToLongLong(argv[j+1], &ef)
596 != REDISMODULE_OK || ef <= 0 || ef > 1000000)
597 {
598 RedisModule_Free(vec);
599 return RedisModule_ReplyWithError(ctx, "ERR invalid EF");
600 }
601 j++; // skip argument.
602 } else if (!strcasecmp(opt, "M") && j+1 < argc) {
603 if (RedisModule_StringToLongLong(argv[j+1], &hnsw_create_M)
604 != REDISMODULE_OK || hnsw_create_M < HNSW_MIN_M ||
605 hnsw_create_M > HNSW_MAX_M)
606 {
607 RedisModule_Free(vec);
608 return RedisModule_ReplyWithError(ctx, "ERR invalid M");
609 }
610 j++; // skip argument.
611 } else if (!strcasecmp(opt, "SETATTR") && j+1 < argc) {
612 attrib = argv[j+1];
613 j++; // skip argument.
614 } else if (!strcasecmp(opt, "NOQUANT")) {
615 quant_type = HNSW_QUANT_NONE;
616 } else if (!strcasecmp(opt, "BIN")) {
617 quant_type = HNSW_QUANT_BIN;
618 } else if (!strcasecmp(opt, "Q8")) {
619 quant_type = HNSW_QUANT_Q8;
620 } else {
621 RedisModule_Free(vec);
622 return RedisModule_ReplyWithError(ctx,"ERR invalid option after element");
623 }
624 }
625
626 /* Drop CAS if this is a replica and we are getting the command from the
627 * replication link: we want to add/delete items in the same order as
628 * the master, while with CAS the timing would be different.
629 *
630 * Also for Lua scripts and MULTI/EXEC, we want to run the command
631 * on the main thread. */
632 if (RedisModule_GetContextFlags(ctx) &
633 (REDISMODULE_CTX_FLAGS_REPLICATED|
634 REDISMODULE_CTX_FLAGS_LUA|
635 REDISMODULE_CTX_FLAGS_MULTI))
636 {
637 cas = 0;
638 }
639
640 if (VSGlobalConfig.forceSingleThreadExec) {
641 cas = 0;
642 }
643
644 /* Open/create key */
645 RedisModuleKey *key = RedisModule_OpenKey(ctx,argv[1],
646 REDISMODULE_READ|REDISMODULE_WRITE);
647 int type = RedisModule_KeyType(key);
648 if (type != REDISMODULE_KEYTYPE_EMPTY &&
649 RedisModule_ModuleTypeGetType(key) != VectorSetType)
650 {
651 RedisModule_Free(vec);
652 return RedisModule_ReplyWithError(ctx,REDISMODULE_ERRORMSG_WRONGTYPE);
653 }
654
655 /* Get the correct value argument based on format and REDUCE */
656 RedisModuleString *val = argv[2 + consumed_args];
657
658 /* Create or get existing vector set */
659 struct vsetObject *vset;
660 if (type == REDISMODULE_KEYTYPE_EMPTY) {
661 cas = 0; /* Do synchronous insert at creation, otherwise the
662 * key would be left empty until the threaded part
663 * does not return. It's also pointless to try try
664 * doing threaded first element insertion. */
665 vset = createVectorSetObject(reduce_dim ? reduce_dim : dim, quant_type, hnsw_create_M);
666 if (vset == NULL) {
667 // We can't fail for OOM in Redis, but the mutex initialization
668 // at least theoretically COULD fail. Likely this code path
669 // is not reachable in practical terms.
670 RedisModule_Free(vec);
671 return RedisModule_ReplyWithError(ctx,
672 "ERR unable to create a Vector Set: system resources issue?");
673 }
674
675 /* Initialize projection if requested */
676 if (reduce_dim) {
677 vset->proj_matrix = createProjectionMatrix(dim, reduce_dim);
678 vset->proj_input_size = dim;
679
680 /* Project the vector */
681 float *projected = applyProjection(vec, vset->proj_matrix,
682 dim, reduce_dim);
683 RedisModule_Free(vec);
684 vec = projected;
685 }
686 RedisModule_ModuleTypeSetValue(key,VectorSetType,vset);
687 } else {
688 vset = RedisModule_ModuleTypeGetValue(key);
689
690 if (vset->hnsw->quant_type != quant_type) {
691 RedisModule_Free(vec);
692 return RedisModule_ReplyWithError(ctx,
693 "ERR asked quantization mismatch with existing vector set");
694 }
695
696 if (vset->hnsw->M != hnsw_create_M) {
697 RedisModule_Free(vec);
698 return RedisModule_ReplyWithError(ctx,
699 "ERR asked M value mismatch with existing vector set");
700 }
701
702 if ((vset->proj_matrix == NULL && vset->hnsw->vector_dim != dim) ||
703 (vset->proj_matrix && vset->hnsw->vector_dim != reduce_dim))
704 {
705 RedisModule_Free(vec);
706 return RedisModule_ReplyWithErrorFormat(ctx,
707 "ERR Vector dimension mismatch - got %d but set has %d",
708 (int)dim, (int)vset->hnsw->vector_dim);
709 }
710
711 /* Check REDUCE compatibility */
712 if (reduce_dim) {
713 if (!vset->proj_matrix) {
714 RedisModule_Free(vec);
715 return RedisModule_ReplyWithError(ctx,
716 "ERR cannot add projection to existing set without projection");
717 }
718 if (reduce_dim != vset->hnsw->vector_dim) {
719 RedisModule_Free(vec);
720 return RedisModule_ReplyWithError(ctx,
721 "ERR projection dimension mismatch with existing set");
722 }
723 }
724
725 /* Apply projection if needed */
726 if (vset->proj_matrix) {
727 /* Ensure input dimension matches the projection matrix's expected input dimension */
728 if (dim != vset->proj_input_size) {
729 RedisModule_Free(vec);
730 return RedisModule_ReplyWithErrorFormat(ctx,
731 "ERR Input dimension mismatch for projection - got %d but projection expects %d",
732 (int)dim, (int)vset->proj_input_size);
733 }
734
735 float *projected = applyProjection(vec, vset->proj_matrix,
736 vset->proj_input_size,
737 vset->hnsw->vector_dim);
738 RedisModule_Free(vec);
739 vec = projected;
740 dim = vset->hnsw->vector_dim;
741 }
742 }
743
744 /* For existing keys don't do CAS updates. For how things work now, the
745 * CAS state would be invalidated by the deletion before adding back. */
746 if (cas && RedisModule_DictGet(vset->dict,val,NULL) != NULL)
747 cas = 0;
748
749 /* Here depending on the CAS option we directly insert in a blocking
750 * way, or use a thread to do candidate neighbors selection and only
751 * later, in the reply callback, actually add the element. */
752 if (cas) {
753 RedisModuleBlockedClient *bc = RedisModule_BlockClient(ctx,VADD_CASReply,NULL,NULL,0);
754 pthread_t tid;
755 void **targ = RedisModule_Alloc(sizeof(void*)*8);
756 targ[0] = bc;
757 targ[1] = vset;
758 targ[2] = (void*)(unsigned long)vset->id;
759 targ[3] = vec;
760 targ[4] = val;
761 targ[5] = NULL; // Used later for insertion context.
762 targ[6] = (void*)(unsigned long)ef;
763 targ[7] = attrib;
764 RedisModule_RetainString(ctx,val);
765 if (attrib) RedisModule_RetainString(ctx,attrib);
766 RedisModule_BlockedClientMeasureTimeStart(bc);
767 vset->thread_creation_pending++;
768 if (pthread_create(&tid,NULL,VADD_thread,targ) != 0) {
769 vset->thread_creation_pending--;
770 RedisModule_AbortBlock(bc);
771 RedisModule_Free(targ);
772 RedisModule_FreeString(ctx,val);
773 if (attrib) RedisModule_FreeString(ctx,attrib);
774
775 // Fall back to synchronous insert, see later in the code.
776 } else {
777 return REDISMODULE_OK;
778 }
779 }
780
781 /* Insert vector synchronously: we reach this place even
782 * if cas was true but thread creation failed. */
783 int added = vectorSetInsert(vset,vec,NULL,0,val,attrib,1,ef);
784 RedisModule_Free(vec);
785
786 RedisModule_ReplyWithBool(ctx,added);
787 if (added) RedisModule_ReplicateVerbatim(ctx);
788 return REDISMODULE_OK;
789}
790
791/* HNSW callback to filter items according to a predicate function
792 * (our FILTER expression in this case). */
793int vectorSetFilterCallback(void *value, void *privdata) {
794 exprstate *expr = privdata;
795 struct vsetNodeVal *nv = value;
796 if (nv->attrib == NULL) return 0; // No attributes? No match.
797 size_t json_len;
798 char *json = (char*)RedisModule_StringPtrLen(nv->attrib,&json_len);
799 return exprRun(expr,json,json_len);
800}
801
802/* Common path for the execution of the VSIM command both threaded and
803 * not threaded. Note that 'ctx' may be normal context of a thread safe
804 * context obtained from a blocked client. The locking that is specific
805 * to the vset object is handled by the caller, however the function
806 * handles the HNSW locking explicitly. */
807void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset,
808 float *vec, unsigned long count, float epsilon, unsigned long withscores,
809 unsigned long withattribs, unsigned long ef, exprstate *filter_expr,
810 unsigned long filter_ef, int ground_truth)
811{
812 /* In our scan, we can't just collect 'count' elements as
813 * if count is small we would explore the graph in an insufficient
814 * way to provide enough recall.
815 *
816 * If the user didn't asked for a specific exploration, we use
817 * VSET_DEFAULT_SEARCH_EF as minimum, or we match count if count
818 * is greater than that. Otherwise the minumim will be the specified
819 * EF argument. */
820 if (ef == 0) ef = VSET_DEFAULT_SEARCH_EF;
821 if (count > ef) ef = count;
822
823 int slot = hnsw_acquire_read_slot(vset->hnsw);
824 if (ef > vset->hnsw->node_count) ef = vset->hnsw->node_count;
825
826 /* Perform search */
827 hnswNode **neighbors = RedisModule_Alloc(sizeof(hnswNode*)*ef);
828 float *distances = RedisModule_Alloc(sizeof(float)*ef);
829 unsigned int found;
830 if (ground_truth) {
831 found = hnsw_ground_truth_with_filter(vset->hnsw, vec, ef, neighbors,
832 distances, slot, 0,
833 filter_expr ? vectorSetFilterCallback : NULL,
834 filter_expr);
835 } else {
836 if (filter_expr == NULL) {
837 found = hnsw_search(vset->hnsw, vec, ef, neighbors,
838 distances, slot, 0);
839 } else {
840 found = hnsw_search_with_filter(vset->hnsw, vec, ef, neighbors,
841 distances, slot, 0, vectorSetFilterCallback,
842 filter_expr, filter_ef);
843 }
844 }
845
846 /* Return results */
847 int resp3 = RedisModule_GetContextFlags(ctx) & REDISMODULE_CTX_FLAGS_RESP3;
848 int reply_with_map = resp3 && (withscores || withattribs);
849
850 if (reply_with_map)
851 RedisModule_ReplyWithMap(ctx, REDISMODULE_POSTPONED_LEN);
852 else
853 RedisModule_ReplyWithArray(ctx, REDISMODULE_POSTPONED_LEN);
854
855 long long arraylen = 0;
856 for (unsigned int i = 0; i < found && i < count; i++) {
857 if (distances[i]/2 > epsilon) break;
858 struct vsetNodeVal *nv = neighbors[i]->value;
859 RedisModule_ReplyWithString(ctx, nv->item);
860 arraylen++;
861
862 /* If the user asked for multiple properties at the same time using
863 * the RESP3 protocol, we wrap the value of the map into an N-items
864 * array. Two for now, since we have just two properties that can be
865 * requested.
866 *
867 * So in the case of RESP2 we will just have the flat reply:
868 * item, score, attribute. For RESP3 instead item -> [score, attribute]
869 */
870 if (resp3 && withscores && withattribs)
871 RedisModule_ReplyWithArray(ctx,2);
872
873 if (withscores) {
874 /* The similarity score is provided in a 0-1 range. */
875 RedisModule_ReplyWithDouble(ctx, 1.0 - distances[i]/2.0);
876 }
877 if (withattribs) {
878 /* Return the attributes as well, if any. */
879 if (nv->attrib)
880 RedisModule_ReplyWithString(ctx, nv->attrib);
881 else
882 RedisModule_ReplyWithNull(ctx);
883 }
884 }
885 hnsw_release_read_slot(vset->hnsw,slot);
886
887 if (reply_with_map) {
888 RedisModule_ReplySetMapLength(ctx, arraylen);
889 } else {
890 int items_per_ele = 1+withattribs+withscores;
891 RedisModule_ReplySetArrayLength(ctx, arraylen * items_per_ele);
892 }
893
894 RedisModule_Free(vec);
895 RedisModule_Free(neighbors);
896 RedisModule_Free(distances);
897 if (filter_expr) exprFree(filter_expr);
898}
899
900/* VSIM thread handling the blocked client request. */
901void *VSIM_thread(void *arg) {
902 pthread_detach(pthread_self());
903
904 // Extract arguments.
905 void **targ = (void**)arg;
906 RedisModuleBlockedClient *bc = targ[0];
907 struct vsetObject *vset = targ[1];
908 float *vec = targ[2];
909 unsigned long count = (unsigned long)targ[3];
910 float epsilon = *((float*)targ[4]);
911 unsigned long withscores = (unsigned long)targ[5];
912 unsigned long withattribs = (unsigned long)targ[6];
913 unsigned long ef = (unsigned long)targ[7];
914 exprstate *filter_expr = targ[8];
915 unsigned long filter_ef = (unsigned long)targ[9];
916 unsigned long ground_truth = (unsigned long)targ[10];
917 RedisModule_Free(targ[4]);
918 RedisModule_Free(targ);
919
920 /* Lock the object and signal that we are no longer pending
921 * the lock acquisition. */
922 RedisModule_Assert(pthread_rwlock_rdlock(&vset->in_use_lock) == 0);
923 vset->thread_creation_pending--;
924
925 // Accumulate reply in a thread safe context: no contention.
926 RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(bc);
927
928 // Run the query.
929 VSIM_execute(ctx, vset, vec, count, epsilon, withscores, withattribs, ef, filter_expr, filter_ef, ground_truth);
930 pthread_rwlock_unlock(&vset->in_use_lock);
931
932 // Cleanup.
933 RedisModule_FreeThreadSafeContext(ctx);
934 RedisModule_BlockedClientMeasureTimeEnd(bc);
935 RedisModule_UnblockClient(bc,NULL);
936 return NULL;
937}
938
939/* VSIM key [ELE|FP32|VALUES] <vector or ele> [WITHSCORES] [WITHATTRIBS] [COUNT num] [EPSILON eps] [EF exploration-factor] [FILTER expression] [FILTER-EF exploration-factor] */
940int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
941 RedisModule_AutoMemory(ctx);
942
943 /* Basic argument check: need at least key and vector specification
944 * method. */
945 if (argc < 4) return RedisModule_WrongArity(ctx);
946
947 /* Defaults */
948 int withscores = 0;
949 int withattribs = 0;
950 long long count = VSET_DEFAULT_COUNT; /* New default value */
951 long long ef = 0; /* Exploration factor (see HNSW paper) */
952 double epsilon = 2.0; /* Max cosine distance */
953 long long ground_truth = 0; /* Linear scan instead of HNSW search? */
954 int no_thread = 0; /* NOTHREAD option: exec on main thread. */
955
956 /* Things computed later. */
957 long long filter_ef = 0;
958 exprstate *filter_expr = NULL;
959
960 /* Get key and vector type */
961 RedisModuleString *key = argv[1];
962 const char *vectorType = RedisModule_StringPtrLen(argv[2], NULL);
963
964 /* Get vector set */
965 RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key, REDISMODULE_READ);
966 int type = RedisModule_KeyType(keyptr);
967 if (type == REDISMODULE_KEYTYPE_EMPTY)
968 return RedisModule_ReplyWithEmptyArray(ctx);
969
970 if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType)
971 return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
972
973 struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr);
974
975 /* Vector parsing stage */
976 float *vec = NULL;
977 size_t dim = 0;
978 int vector_args = 0; /* Number of args consumed by vector specification */
979
980 if (!strcasecmp(vectorType, "ELE")) {
981 /* Get vector from existing element */
982 RedisModuleString *ele = argv[3];
983 hnswNode *node = RedisModule_DictGet(vset->dict, ele, NULL);
984 if (!node) {
985 return RedisModule_ReplyWithError(ctx, "ERR element not found in set");
986 }
987 vec = RedisModule_Alloc(sizeof(float) * vset->hnsw->vector_dim);
988 hnsw_get_node_vector(vset->hnsw,node,vec);
989 dim = vset->hnsw->vector_dim;
990 vector_args = 2; /* ELE + element name */
991 } else {
992 /* Parse vector. */
993 int consumed_args;
994
995 vec = parseVector(argv, argc, 2, &dim, NULL, &consumed_args);
996 if (!vec) {
997 return RedisModule_ReplyWithError(ctx,
998 "ERR invalid vector specification");
999 }
1000 vector_args = consumed_args;
1001
1002 /* Apply projection if the set uses it, with the exception
1003 * of ELE type, that will already have the right dimension. */
1004 if (vset->proj_matrix && dim != vset->hnsw->vector_dim) {
1005 /* Ensure input dimension matches the projection matrix's expected input dimension */
1006 if (dim != vset->proj_input_size) {
1007 RedisModule_Free(vec);
1008 return RedisModule_ReplyWithErrorFormat(ctx,
1009 "ERR Input dimension mismatch for projection - got %d but projection expects %d",
1010 (int)dim, (int)vset->proj_input_size);
1011 }
1012
1013 float *projected = applyProjection(vec, vset->proj_matrix,
1014 vset->proj_input_size,
1015 vset->hnsw->vector_dim);
1016 RedisModule_Free(vec);
1017 vec = projected;
1018 dim = vset->hnsw->vector_dim;
1019 }
1020
1021 /* Count consumed arguments */
1022 if (!strcasecmp(vectorType, "FP32")) {
1023 vector_args = 2; /* FP32 + vector blob */
1024 } else if (!strcasecmp(vectorType, "VALUES")) {
1025 long long vdim;
1026 if (RedisModule_StringToLongLong(argv[3], &vdim) != REDISMODULE_OK) {
1027 RedisModule_Free(vec);
1028 return RedisModule_ReplyWithError(ctx, "ERR invalid vector dimension");
1029 }
1030 vector_args = 2 + vdim; /* VALUES + dim + values */
1031 } else {
1032 RedisModule_Free(vec);
1033 return RedisModule_ReplyWithError(ctx,
1034 "ERR vector type must be ELE, FP32 or VALUES");
1035 }
1036 }
1037
1038 /* Check vector dimension matches set */
1039 if (dim != vset->hnsw->vector_dim) {
1040 RedisModule_Free(vec);
1041 return RedisModule_ReplyWithErrorFormat(ctx,
1042 "ERR Vector dimension mismatch - got %d but set has %d",
1043 (int)dim, (int)vset->hnsw->vector_dim);
1044 }
1045
1046 /* Parse optional arguments - start after vector specification */
1047 int j = 2 + vector_args;
1048 while (j < argc) {
1049 const char *opt = RedisModule_StringPtrLen(argv[j], NULL);
1050 if (!strcasecmp(opt, "WITHSCORES")) {
1051 withscores = 1;
1052 j++;
1053 } else if (!strcasecmp(opt, "WITHATTRIBS")) {
1054 withattribs = 1;
1055 j++;
1056 } else if (!strcasecmp(opt, "TRUTH")) {
1057 ground_truth = 1;
1058 j++;
1059 } else if (!strcasecmp(opt, "NOTHREAD")) {
1060 no_thread = 1;
1061 j++;
1062 } else if (!strcasecmp(opt, "COUNT") && j+1 < argc) {
1063 if (RedisModule_StringToLongLong(argv[j+1], &count)
1064 != REDISMODULE_OK || count <= 0)
1065 {
1066 RedisModule_Free(vec);
1067 return RedisModule_ReplyWithError(ctx, "ERR invalid COUNT");
1068 }
1069 j += 2;
1070 } else if (!strcasecmp(opt, "EPSILON") && j+1 < argc) {
1071 if (RedisModule_StringToDouble(argv[j+1], &epsilon) !=
1072 REDISMODULE_OK || epsilon <= 0)
1073 {
1074 RedisModule_Free(vec);
1075 return RedisModule_ReplyWithError(ctx, "ERR invalid EPSILON");
1076 }
1077 j += 2;
1078 } else if (!strcasecmp(opt, "EF") && j+1 < argc) {
1079 if (RedisModule_StringToLongLong(argv[j+1], &ef) !=
1080 REDISMODULE_OK || ef <= 0 || ef > 1000000)
1081 {
1082 RedisModule_Free(vec);
1083 return RedisModule_ReplyWithError(ctx, "ERR invalid EF");
1084 }
1085 j += 2;
1086 } else if (!strcasecmp(opt, "FILTER-EF") && j+1 < argc) {
1087 if (RedisModule_StringToLongLong(argv[j+1], &filter_ef) !=
1088 REDISMODULE_OK || filter_ef <= 0)
1089 {
1090 RedisModule_Free(vec);
1091 return RedisModule_ReplyWithError(ctx, "ERR invalid FILTER-EF");
1092 }
1093 j += 2;
1094 } else if (!strcasecmp(opt, "FILTER") && j+1 < argc) {
1095 RedisModuleString *exprarg = argv[j+1];
1096 size_t exprlen;
1097 char *exprstr = (char*)RedisModule_StringPtrLen(exprarg,&exprlen);
1098 int errpos;
1099 filter_expr = exprCompile(exprstr,&errpos);
1100 if (filter_expr == NULL) {
1101 if ((size_t)errpos >= exprlen) errpos = 0;
1102 RedisModule_Free(vec);
1103 return RedisModule_ReplyWithErrorFormat(ctx,
1104 "ERR syntax error in FILTER expression near: %s",
1105 exprstr+errpos);
1106 }
1107 j += 2;
1108 } else {
1109 RedisModule_Free(vec);
1110 return RedisModule_ReplyWithError(ctx,
1111 "ERR syntax error in VSIM command");
1112 }
1113 }
1114
1115 int threaded_request = 1; // Run on a thread, by default.
1116 if (filter_ef == 0) filter_ef = count * 100; // Max filter visited nodes.
1117
1118 /* Disable threaded for MULTI/EXEC and Lua, or if explicitly
1119 * requested by the user via the NOTHREAD option. */
1120 if (no_thread || VSGlobalConfig.forceSingleThreadExec ||
1121 (RedisModule_GetContextFlags(ctx) &
1122 (REDISMODULE_CTX_FLAGS_LUA | REDISMODULE_CTX_FLAGS_MULTI)))
1123 {
1124 threaded_request = 0;
1125 }
1126
1127 if (threaded_request) {
1128 /* Note: even if we create one thread per request, the underlying
1129 * HNSW library has a fixed number of slots for the threads, as it's
1130 * defined in HNSW_MAX_THREADS (beware that if you increase it,
1131 * every node will use more memory). This means that while this request
1132 * is threaded, and will NOT block Redis, it may end waiting for a
1133 * free slot if all the HNSW_MAX_THREADS slots are used. */
1134 RedisModuleBlockedClient *bc = RedisModule_BlockClient(ctx,NULL,NULL,NULL,0);
1135 pthread_t tid;
1136 void **targ = RedisModule_Alloc(sizeof(void*)*11);
1137 targ[0] = bc;
1138 targ[1] = vset;
1139 targ[2] = vec;
1140 targ[3] = (void*)count;
1141 targ[4] = RedisModule_Alloc(sizeof(float));
1142 *((float*)targ[4]) = epsilon;
1143 targ[5] = (void*)(unsigned long)withscores;
1144 targ[6] = (void*)(unsigned long)withattribs;
1145 targ[7] = (void*)(unsigned long)ef;
1146 targ[8] = (void*)filter_expr;
1147 targ[9] = (void*)(unsigned long)filter_ef;
1148 targ[10] = (void*)(unsigned long)ground_truth;
1149 RedisModule_BlockedClientMeasureTimeStart(bc);
1150 vset->thread_creation_pending++;
1151 if (pthread_create(&tid,NULL,VSIM_thread,targ) != 0) {
1152 vset->thread_creation_pending--;
1153 RedisModule_AbortBlock(bc);
1154 RedisModule_Free(targ[4]);
1155 RedisModule_Free(targ);
1156 VSIM_execute(ctx, vset, vec, count, epsilon, withscores, withattribs, ef, filter_expr, filter_ef, ground_truth);
1157 }
1158 } else {
1159 VSIM_execute(ctx, vset, vec, count, epsilon, withscores, withattribs, ef, filter_expr, filter_ef, ground_truth);
1160 }
1161
1162 return REDISMODULE_OK;
1163}
1164
1165/* VDIM <key>: return the dimension of vectors in the vector set. */
1166int VDIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
1167 RedisModule_AutoMemory(ctx);
1168
1169 if (argc != 2) return RedisModule_WrongArity(ctx);
1170
1171 RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ);
1172 int type = RedisModule_KeyType(key);
1173
1174 if (type == REDISMODULE_KEYTYPE_EMPTY)
1175 return RedisModule_ReplyWithError(ctx, "ERR key does not exist");
1176
1177 if (RedisModule_ModuleTypeGetType(key) != VectorSetType)
1178 return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
1179
1180 struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key);
1181 return RedisModule_ReplyWithLongLong(ctx, vset->hnsw->vector_dim);
1182}
1183
1184/* VCARD <key>: return cardinality (num of elements) of the vector set. */
1185int VCARD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
1186 RedisModule_AutoMemory(ctx);
1187
1188 if (argc != 2) return RedisModule_WrongArity(ctx);
1189
1190 RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ);
1191 int type = RedisModule_KeyType(key);
1192
1193 if (type == REDISMODULE_KEYTYPE_EMPTY)
1194 return RedisModule_ReplyWithLongLong(ctx, 0);
1195
1196 if (RedisModule_ModuleTypeGetType(key) != VectorSetType)
1197 return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
1198
1199 struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key);
1200 return RedisModule_ReplyWithLongLong(ctx, vset->hnsw->node_count);
1201}
1202
1203/* VREM key element
1204 * Remove an element from a vector set.
1205 * Returns 1 if the element was found and removed, 0 if not found. */
1206int VREM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
1207 RedisModule_AutoMemory(ctx); /* Use automatic memory management. */
1208
1209 if (argc != 3) return RedisModule_WrongArity(ctx);
1210
1211 /* Get key and value */
1212 RedisModuleString *key = argv[1];
1213 RedisModuleString *element = argv[2];
1214
1215 /* Open key */
1216 RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key,
1217 REDISMODULE_READ|REDISMODULE_WRITE);
1218 int type = RedisModule_KeyType(keyptr);
1219
1220 /* Handle non-existing key or wrong type */
1221 if (type == REDISMODULE_KEYTYPE_EMPTY) {
1222 return RedisModule_ReplyWithBool(ctx, 0);
1223 }
1224 if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType) {
1225 return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
1226 }
1227
1228 /* Get vector set from key */
1229 struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr);
1230
1231 /* Find the node for this element */
1232 hnswNode *node = RedisModule_DictGet(vset->dict, element, NULL);
1233 if (!node) {
1234 return RedisModule_ReplyWithBool(ctx, 0);
1235 }
1236
1237 /* Remove from dictionary */
1238 RedisModule_DictDel(vset->dict, element, NULL);
1239
1240 /* Remove from HNSW graph using the high-level API that handles
1241 * locking and cleanup. We pass RedisModule_FreeString as the value
1242 * free function since the strings were retained at insertion time. */
1243 struct vsetNodeVal *nv = node->value;
1244 if (nv->attrib != NULL) vset->numattribs--;
1245 RedisModule_Assert(hnsw_delete_node(vset->hnsw, node, vectorSetReleaseNodeValue) == 1);
1246
1247 /* Destroy empty vector set. */
1248 if (RedisModule_DictSize(vset->dict) == 0) {
1249 RedisModule_DeleteKey(keyptr);
1250 }
1251
1252 /* Reply and propagate the command */
1253 RedisModule_ReplyWithBool(ctx, 1);
1254 RedisModule_ReplicateVerbatim(ctx);
1255 return REDISMODULE_OK;
1256}
1257
1258/* VEMB key element
1259 * Returns the embedding vector associated with an element, or NIL if not
1260 * found. The vector is returned in the same format it was added, but the
1261 * return value will have some lack of precision due to quantization and
1262 * normalization of vectors. Also, if items were added using REDUCE, the
1263 * reduced vector is returned instead. */
1264int VEMB_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
1265 RedisModule_AutoMemory(ctx);
1266 int raw_output = 0; // RAW option.
1267
1268 if (argc < 3) return RedisModule_WrongArity(ctx);
1269
1270 /* Parse arguments. */
1271 for (int j = 3; j < argc; j++) {
1272 const char *opt = RedisModule_StringPtrLen(argv[j], NULL);
1273 if (!strcasecmp(opt,"raw")) {
1274 raw_output = 1;
1275 } else {
1276 return RedisModule_ReplyWithError(ctx,"ERR invalid option");
1277 }
1278 }
1279
1280 /* Get key and element. */
1281 RedisModuleString *key = argv[1];
1282 RedisModuleString *element = argv[2];
1283
1284 /* Open key. */
1285 RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key, REDISMODULE_READ);
1286 int type = RedisModule_KeyType(keyptr);
1287
1288 /* Handle non-existing key and key of wrong type. */
1289 if (type == REDISMODULE_KEYTYPE_EMPTY) {
1290 return RedisModule_ReplyWithNull(ctx);
1291 } else if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType) {
1292 return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
1293 }
1294
1295 /* Lookup the node about the specified element. */
1296 struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr);
1297 hnswNode *node = RedisModule_DictGet(vset->dict, element, NULL);
1298 if (!node) {
1299 return RedisModule_ReplyWithNull(ctx);
1300 }
1301
1302 if (raw_output) {
1303 int output_qrange = vset->hnsw->quant_type == HNSW_QUANT_Q8;
1304 RedisModule_ReplyWithArray(ctx, 3+output_qrange);
1305 RedisModule_ReplyWithSimpleString(ctx, vectorSetGetQuantName(vset));
1306 RedisModule_ReplyWithStringBuffer(ctx, node->vector, hnsw_quants_bytes(vset->hnsw));
1307 RedisModule_ReplyWithDouble(ctx, node->l2);
1308 if (output_qrange) RedisModule_ReplyWithDouble(ctx, node->quants_range);
1309 } else {
1310 /* Get the vector associated with the node. */
1311 float *vec = RedisModule_Alloc(sizeof(float) * vset->hnsw->vector_dim);
1312 hnsw_get_node_vector(vset->hnsw, node, vec); // May dequantize/denorm.
1313
1314 /* Return as array of doubles. */
1315 RedisModule_ReplyWithArray(ctx, vset->hnsw->vector_dim);
1316 for (uint32_t i = 0; i < vset->hnsw->vector_dim; i++)
1317 RedisModule_ReplyWithDouble(ctx, vec[i]);
1318 RedisModule_Free(vec);
1319 }
1320 return REDISMODULE_OK;
1321}
1322
1323/* VSETATTR key element json
1324 * Set or remove the JSON attribute associated with an element.
1325 * Setting an empty string removes the attribute.
1326 * The command returns one if the attribute was actually updated or
1327 * zero if there is no key or element. */
1328int VSETATTR_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
1329 RedisModule_AutoMemory(ctx);
1330
1331 if (argc != 4) return RedisModule_WrongArity(ctx);
1332
1333 RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1],
1334 REDISMODULE_READ|REDISMODULE_WRITE);
1335 int type = RedisModule_KeyType(key);
1336
1337 if (type == REDISMODULE_KEYTYPE_EMPTY)
1338 return RedisModule_ReplyWithBool(ctx, 0);
1339
1340 if (RedisModule_ModuleTypeGetType(key) != VectorSetType)
1341 return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
1342
1343 struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key);
1344 hnswNode *node = RedisModule_DictGet(vset->dict, argv[2], NULL);
1345 if (!node)
1346 return RedisModule_ReplyWithBool(ctx, 0);
1347
1348 struct vsetNodeVal *nv = node->value;
1349 RedisModuleString *new_attr = argv[3];
1350
1351 /* Background VSIM operations use the node attributes, so
1352 * wait for background operations before messing with them. */
1353 vectorSetWaitAllBackgroundClients(vset,0);
1354
1355 /* Set or delete the attribute based on the fact it's an empty
1356 * string or not. */
1357 size_t attrlen;
1358 RedisModule_StringPtrLen(new_attr, &attrlen);
1359 if (attrlen == 0) {
1360 // If we had an attribute before, decrease the count and free it.
1361 if (nv->attrib) {
1362 vset->numattribs--;
1363 RedisModule_FreeString(NULL, nv->attrib);
1364 nv->attrib = NULL;
1365 }
1366 } else {
1367 // If we didn't have an attribute before, increase the count.
1368 // Otherwise free the old one.
1369 if (nv->attrib) {
1370 RedisModule_FreeString(NULL, nv->attrib);
1371 } else {
1372 vset->numattribs++;
1373 }
1374 // Set new attribute.
1375 RedisModule_RetainString(NULL, new_attr);
1376 nv->attrib = new_attr;
1377 }
1378
1379 RedisModule_ReplyWithBool(ctx, 1);
1380 RedisModule_ReplicateVerbatim(ctx);
1381 return REDISMODULE_OK;
1382}
1383
1384/* VGETATTR key element
1385 * Get the JSON attribute associated with an element.
1386 * Returns NIL if the element has no attribute or doesn't exist. */
1387int VGETATTR_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
1388 RedisModule_AutoMemory(ctx);
1389
1390 if (argc != 3) return RedisModule_WrongArity(ctx);
1391
1392 RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ);
1393 int type = RedisModule_KeyType(key);
1394
1395 if (type == REDISMODULE_KEYTYPE_EMPTY)
1396 return RedisModule_ReplyWithNull(ctx);
1397
1398 if (RedisModule_ModuleTypeGetType(key) != VectorSetType)
1399 return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
1400
1401 struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key);
1402 hnswNode *node = RedisModule_DictGet(vset->dict, argv[2], NULL);
1403 if (!node)
1404 return RedisModule_ReplyWithNull(ctx);
1405
1406 struct vsetNodeVal *nv = node->value;
1407 if (!nv->attrib)
1408 return RedisModule_ReplyWithNull(ctx);
1409
1410 return RedisModule_ReplyWithString(ctx, nv->attrib);
1411}
1412
1413/* ============================== Reflection ================================ */
1414
1415/* VLINKS key element [WITHSCORES]
1416 * Returns the neighbors of an element at each layer in the HNSW graph.
1417 * Reply is an array of arrays, where each nested array represents one level
1418 * of neighbors, from highest level to level 0. If WITHSCORES is specified,
1419 * each neighbor is followed by its distance from the element. */
1420int VLINKS_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
1421 RedisModule_AutoMemory(ctx);
1422
1423 if (argc < 3 || argc > 4) return RedisModule_WrongArity(ctx);
1424
1425 RedisModuleString *key = argv[1];
1426 RedisModuleString *element = argv[2];
1427
1428 /* Parse WITHSCORES option. */
1429 int withscores = 0;
1430 if (argc == 4) {
1431 const char *opt = RedisModule_StringPtrLen(argv[3], NULL);
1432 if (strcasecmp(opt, "WITHSCORES") != 0) {
1433 return RedisModule_WrongArity(ctx);
1434 }
1435 withscores = 1;
1436 }
1437
1438 RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key, REDISMODULE_READ);
1439 int type = RedisModule_KeyType(keyptr);
1440
1441 /* Handle non-existing key or wrong type. */
1442 if (type == REDISMODULE_KEYTYPE_EMPTY)
1443 return RedisModule_ReplyWithNull(ctx);
1444
1445 if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType)
1446 return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
1447
1448 /* Find the node for this element. */
1449 struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr);
1450 hnswNode *node = RedisModule_DictGet(vset->dict, element, NULL);
1451 if (!node)
1452 return RedisModule_ReplyWithNull(ctx);
1453
1454 /* Reply with array of arrays, one per level. */
1455 RedisModule_ReplyWithArray(ctx, node->level + 1);
1456
1457 /* For each level, from highest to lowest: */
1458 for (int i = node->level; i >= 0; i--) {
1459 /* Reply with array of neighbors at this level. */
1460 if (withscores)
1461 RedisModule_ReplyWithMap(ctx,node->layers[i].num_links);
1462 else
1463 RedisModule_ReplyWithArray(ctx,node->layers[i].num_links);
1464
1465 /* Add each neighbor's element value to the array. */
1466 for (uint32_t j = 0; j < node->layers[i].num_links; j++) {
1467 struct vsetNodeVal *nv = node->layers[i].links[j]->value;
1468 RedisModule_ReplyWithString(ctx, nv->item);
1469 if (withscores) {
1470 float distance = hnsw_distance(vset->hnsw, node, node->layers[i].links[j]);
1471 /* Convert distance to similarity score to match
1472 * VSIM behavior.*/
1473 float similarity = 1.0 - distance/2.0;
1474 RedisModule_ReplyWithDouble(ctx, similarity);
1475 }
1476 }
1477 }
1478 return REDISMODULE_OK;
1479}
1480
1481/* VINFO key
1482 * Returns information about a vector set, both visible and hidden
1483 * features of the HNSW data structure. */
1484int VINFO_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
1485 RedisModule_AutoMemory(ctx);
1486
1487 if (argc != 2) return RedisModule_WrongArity(ctx);
1488
1489 RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ);
1490 int type = RedisModule_KeyType(key);
1491
1492 if (type == REDISMODULE_KEYTYPE_EMPTY)
1493 return RedisModule_ReplyWithNullArray(ctx);
1494
1495 if (RedisModule_ModuleTypeGetType(key) != VectorSetType)
1496 return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
1497
1498 struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key);
1499
1500 /* Reply with hash */
1501 RedisModule_ReplyWithMap(ctx, 9);
1502
1503 /* Quantization type */
1504 RedisModule_ReplyWithSimpleString(ctx, "quant-type");
1505 RedisModule_ReplyWithSimpleString(ctx, vectorSetGetQuantName(vset));
1506
1507 /* HNSW M value */
1508 RedisModule_ReplyWithSimpleString(ctx, "hnsw-m");
1509 RedisModule_ReplyWithLongLong(ctx, vset->hnsw->M);
1510
1511 /* Vector dimensionality. */
1512 RedisModule_ReplyWithSimpleString(ctx, "vector-dim");
1513 RedisModule_ReplyWithLongLong(ctx, vset->hnsw->vector_dim);
1514
1515 /* Original input dimension before projection.
1516 * This is zero for vector sets without a random projection matrix. */
1517 RedisModule_ReplyWithSimpleString(ctx, "projection-input-dim");
1518 RedisModule_ReplyWithLongLong(ctx, vset->proj_input_size);
1519
1520 /* Number of elements. */
1521 RedisModule_ReplyWithSimpleString(ctx, "size");
1522 RedisModule_ReplyWithLongLong(ctx, vset->hnsw->node_count);
1523
1524 /* Max level of HNSW. */
1525 RedisModule_ReplyWithSimpleString(ctx, "max-level");
1526 RedisModule_ReplyWithLongLong(ctx, vset->hnsw->max_level);
1527
1528 /* Number of nodes with attributes. */
1529 RedisModule_ReplyWithSimpleString(ctx, "attributes-count");
1530 RedisModule_ReplyWithLongLong(ctx, vset->numattribs);
1531
1532 /* Vector set ID. */
1533 RedisModule_ReplyWithSimpleString(ctx, "vset-uid");
1534 RedisModule_ReplyWithLongLong(ctx, vset->id);
1535
1536 /* HNSW max node ID. */
1537 RedisModule_ReplyWithSimpleString(ctx, "hnsw-max-node-uid");
1538 RedisModule_ReplyWithLongLong(ctx, vset->hnsw->last_id);
1539
1540 return REDISMODULE_OK;
1541}
1542
1543/* VRANDMEMBER key [count]
1544 * Return random members from a vector set.
1545 *
1546 * Without count: returns a single random member.
1547 * With positive count: N unique random members (no duplicates).
1548 * With negative count: N random members (with possible duplicates).
1549 *
1550 * If the key doesn't exist, returns NULL if count is not given, or
1551 * an empty array if a count was given. */
1552int VRANDMEMBER_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
1553 RedisModule_AutoMemory(ctx); /* Use automatic memory management. */
1554
1555 /* Check arguments. */
1556 if (argc != 2 && argc != 3) return RedisModule_WrongArity(ctx);
1557
1558 /* Parse optional count argument. */
1559 long long count = 1; /* Default is to return a single element. */
1560 int with_count = (argc == 3);
1561
1562 if (with_count) {
1563 if (RedisModule_StringToLongLong(argv[2], &count) != REDISMODULE_OK) {
1564 return RedisModule_ReplyWithError(ctx,
1565 "ERR COUNT value is not an integer");
1566 }
1567 /* Count = 0 is a special case, return empty array */
1568 if (count == 0) {
1569 return RedisModule_ReplyWithEmptyArray(ctx);
1570 }
1571 }
1572
1573 /* Open key. */
1574 RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ);
1575 int type = RedisModule_KeyType(key);
1576
1577 /* Handle non-existing key. */
1578 if (type == REDISMODULE_KEYTYPE_EMPTY) {
1579 if (!with_count) {
1580 return RedisModule_ReplyWithNull(ctx);
1581 } else {
1582 return RedisModule_ReplyWithEmptyArray(ctx);
1583 }
1584 }
1585
1586 /* Check key type. */
1587 if (RedisModule_ModuleTypeGetType(key) != VectorSetType) {
1588 return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
1589 }
1590
1591 /* Get vector set from key. */
1592 struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key);
1593 uint64_t set_size = vset->hnsw->node_count;
1594
1595 /* No elements in the set? */
1596 if (set_size == 0) {
1597 if (!with_count) {
1598 return RedisModule_ReplyWithNull(ctx);
1599 } else {
1600 return RedisModule_ReplyWithEmptyArray(ctx);
1601 }
1602 }
1603
1604 /* Case 1: No count specified: return a single element. */
1605 if (!with_count) {
1606 hnswNode *random_node = hnsw_random_node(vset->hnsw, 0);
1607 if (random_node) {
1608 struct vsetNodeVal *nv = random_node->value;
1609 return RedisModule_ReplyWithString(ctx, nv->item);
1610 } else {
1611 return RedisModule_ReplyWithNull(ctx);
1612 }
1613 }
1614
1615 /* Case 2: COUNT option given, return an array of elements. */
1616 int allow_duplicates = (count < 0);
1617 long long abs_count = (count < 0) ? -count : count;
1618
1619 /* Cap the count to the set size if we are not allowing duplicates. */
1620 if (!allow_duplicates && abs_count > (long long)set_size)
1621 abs_count = set_size;
1622
1623 /* Prepare reply. */
1624 RedisModule_ReplyWithArray(ctx, abs_count);
1625
1626 if (allow_duplicates) {
1627 /* Simple case: With duplicates, just pick random nodes
1628 * abs_count times. */
1629 for (long long i = 0; i < abs_count; i++) {
1630 hnswNode *random_node = hnsw_random_node(vset->hnsw,0);
1631 struct vsetNodeVal *nv = random_node->value;
1632 RedisModule_ReplyWithString(ctx, nv->item);
1633 }
1634 } else {
1635 /* Case where count is positive: we need unique elements.
1636 * But, if the user asked for many elements, selecting so
1637 * many (> 20%) random nodes may be too expansive: we just start
1638 * from a random element and follow the next link.
1639 *
1640 * Otherwisem for the <= 20% case, a dictionary is used to
1641 * reject duplicates. */
1642 int use_dict = (abs_count <= set_size * 0.2);
1643
1644 if (use_dict) {
1645 RedisModuleDict *returned = RedisModule_CreateDict(ctx);
1646
1647 long long returned_count = 0;
1648 while (returned_count < abs_count) {
1649 hnswNode *random_node = hnsw_random_node(vset->hnsw, 0);
1650 struct vsetNodeVal *nv = random_node->value;
1651
1652 /* Check if we've already returned this element. */
1653 if (RedisModule_DictGet(returned, nv->item, NULL) == NULL) {
1654 /* Mark as returned and add to results. */
1655 RedisModule_DictSet(returned, nv->item, (void*)1);
1656 RedisModule_ReplyWithString(ctx, nv->item);
1657 returned_count++;
1658 }
1659 }
1660 RedisModule_FreeDict(ctx, returned);
1661 } else {
1662 /* For large samples, get a random starting node and walk
1663 * the list.
1664 *
1665 * IMPORTANT: doing so does not really generate random
1666 * elements: it's just a linear scan, but we have no choices.
1667 * If we generate too many random elements, more and more would
1668 * fail the check of being novel (not yet collected in the set
1669 * to return) if the % of elements to emit is too large, we would
1670 * spend too much CPU. */
1671 hnswNode *start_node = hnsw_random_node(vset->hnsw, 0);
1672 hnswNode *current = start_node;
1673
1674 long long returned_count = 0;
1675 while (returned_count < abs_count) {
1676 if (current == NULL) {
1677 /* Restart from head if we hit the end. */
1678 current = vset->hnsw->head;
1679 }
1680 struct vsetNodeVal *nv = current->value;
1681 RedisModule_ReplyWithString(ctx, nv->item);
1682 returned_count++;
1683 current = current->next;
1684 }
1685 }
1686 }
1687 return REDISMODULE_OK;
1688}
1689
1690/* VISMEMBER key element
1691 * Check if an element exists in a vector set.
1692 * Returns 1 if the element exists, 0 if not. */
1693int VISMEMBER_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
1694 RedisModule_AutoMemory(ctx);
1695 if (argc != 3) return RedisModule_WrongArity(ctx);
1696
1697 RedisModuleString *key = argv[1];
1698 RedisModuleString *element = argv[2];
1699
1700 /* Open key. */
1701 RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key, REDISMODULE_READ);
1702 int type = RedisModule_KeyType(keyptr);
1703
1704 /* Handle non-existing key or wrong type. */
1705 if (type == REDISMODULE_KEYTYPE_EMPTY) {
1706 /* An element of a non existing key does not exist, like
1707 * SISMEMBER & similar. */
1708 return RedisModule_ReplyWithBool(ctx, 0);
1709 }
1710 if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType) {
1711 return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
1712 }
1713
1714 /* Get the object and test membership via the dictionary in constant
1715 * time (assuming a member of average size). */
1716 struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr);
1717 hnswNode *node = RedisModule_DictGet(vset->dict, element, NULL);
1718 return RedisModule_ReplyWithBool(ctx, node != NULL);
1719}
1720
/* One endpoint of a lexicographical range, as accepted by VRANGE.
 *
 * An endpoint is either one of the two special markers ("-" or "+"), in
 * which case 'ele' is NULL, or an explicit element prefixed by "[" or "(". */
struct vsetRangeOp {
    int incl;           /* Non-zero for "[" (bound included), zero for "(". */
    int min;            /* Non-zero when the endpoint is the "-" marker. */
    int max;            /* Non-zero when the endpoint is the "+" marker. */
    unsigned char *ele; /* Element bytes; NULL for the "-" / "+" markers. */
    size_t ele_len;     /* Number of bytes pointed to by 'ele'. */
};
1729
1730/* Parse a range specification like "[foo" or "(bar" or "-" or "+".
1731 * Returns 1 on success, 0 on error. */
1732int vsetParseRangeOp(RedisModuleString *arg, struct vsetRangeOp *op) {
1733 size_t len;
1734 const char *str = RedisModule_StringPtrLen(arg, &len);
1735
1736 if (len == 0) return 0;
1737
1738 /* Initialize the structure. */
1739 op->incl = 0;
1740 op->min = 0;
1741 op->max = 0;
1742 op->ele = NULL;
1743 op->ele_len = 0;
1744
1745 /* Check for special cases "-" and "+". */
1746 if (len == 1 && str[0] == '-') {
1747 op->min = 1;
1748 return 1;
1749 }
1750 if (len == 1 && str[0] == '+') {
1751 op->max = 1;
1752 return 1;
1753 }
1754
1755 /* Otherwise, must start with ( or [. */
1756 if (str[0] == '[') {
1757 op->incl = 1;
1758 } else if (str[0] == '(') {
1759 op->incl = 0;
1760 } else {
1761 return 0; /* Invalid format. */
1762 }
1763
1764 /* Extract the string part after the bracket. */
1765 if (len > 1) {
1766 op->ele = (unsigned char *)(str + 1);
1767 op->ele_len = len - 1;
1768 } else {
1769 return 0; /* Just a bracket with no string. */
1770 }
1771
1772 return 1;
1773}
1774
1775/* Check if the current element is within the range defined by the end operator.
1776 * Returns 1 if the element is within range, 0 if it has passed the end. */
1777int vsetIsElementInRange(const void *ele, size_t ele_len, struct vsetRangeOp *end_op) {
1778 /* If end is "+", element is always in range. */
1779 if (end_op->max) return 1;
1780
1781 /* Compare current element with end boundary. */
1782 size_t minlen = ele_len < end_op->ele_len ? ele_len : end_op->ele_len;
1783 int cmp = memcmp(ele, end_op->ele, minlen);
1784
1785 if (cmp == 0) {
1786 /* If equal up to minlen, shorter string is smaller. */
1787 if (ele_len < end_op->ele_len) {
1788 cmp = -1;
1789 } else if (ele_len > end_op->ele_len) {
1790 cmp = 1;
1791 }
1792 }
1793
1794 /* Check based on inclusive/exclusive. */
1795 if (end_op->incl) {
1796 return cmp <= 0; /* Inclusive: element <= end. */
1797 } else {
1798 return cmp < 0; /* Exclusive: element < end. */
1799 }
1800}
1801
1802/* VRANGE key start end [count]
1803 * Returns elements in the lexicographical range [start, end]
1804 *
1805 * Elements must be specified in one of the following forms:
1806 *
1807 * [myelement
1808 * (myelement
1809 * +
1810 * -
1811 *
1812 * Elements starting with [ are inclusive, so "myelement" would be
1813 * returned if present in the set. Elements starting with ( are exclusive
1814 * ranges instead. The special - and + elements mean the minimum and maximum
1815 * possible element (inclusive), so "VRANGE key - +" will return everything
1816 * (depending on COUNT of course). The special - element can be used only
1817 * as starting element, the special + element only as ending element. */
1818int VRANGE_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
1819 RedisModule_AutoMemory(ctx);
1820
1821 /* Check arguments. */
1822 if (argc < 4 || argc > 5) return RedisModule_WrongArity(ctx);
1823
1824 /* Parse COUNT if provided. */
1825 long long count = -1; /* Default: return all elements. */
1826 if (argc == 5) {
1827 if (RedisModule_StringToLongLong(argv[4], &count) != REDISMODULE_OK) {
1828 return RedisModule_ReplyWithError(ctx, "ERR invalid COUNT value");
1829 }
1830 }
1831
1832 /* Parse range operators. */
1833 struct vsetRangeOp start_op, end_op;
1834 if (!vsetParseRangeOp(argv[2], &start_op)) {
1835 return RedisModule_ReplyWithError(ctx, "ERR invalid start range format");
1836 }
1837 if (!vsetParseRangeOp(argv[3], &end_op)) {
1838 return RedisModule_ReplyWithError(ctx, "ERR invalid end range format");
1839 }
1840
1841 /* Validate: "-" can only be first arg, "+" can only be second. */
1842 if (start_op.max || end_op.min) {
1843 return RedisModule_ReplyWithError(ctx,
1844 "ERR '-' can only be used as first argument, '+' only as second");
1845 }
1846
1847 /* Open the key. */
1848 RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ);
1849 int type = RedisModule_KeyType(key);
1850
1851 if (type == REDISMODULE_KEYTYPE_EMPTY) {
1852 return RedisModule_ReplyWithEmptyArray(ctx);
1853 }
1854
1855 if (RedisModule_ModuleTypeGetType(key) != VectorSetType) {
1856 return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
1857 }
1858
1859 struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key);
1860
1861 /* Start the iterator. */
1862 RedisModuleDictIter *iter;
1863 if (start_op.min) {
1864 /* Start from the beginning. */
1865 iter = RedisModule_DictIteratorStartC(vset->dict, "^", NULL, 0);
1866 } else {
1867 /* Start from the specified element. */
1868 const char *op = start_op.incl ? ">=" : ">";
1869 iter = RedisModule_DictIteratorStartC(vset->dict, op, start_op.ele, start_op.ele_len);
1870 }
1871
1872 /* Collect results. */
1873 RedisModule_ReplyWithArray(ctx, REDISMODULE_POSTPONED_LEN);
1874 long long returned = 0;
1875
1876 void *key_data;
1877 size_t key_len;
1878 while ((key_data = RedisModule_DictNextC(iter, &key_len, NULL)) != NULL) {
1879 /* Check if we've collected enough elements. */
1880 if (count >= 0 && returned >= count) break;
1881
1882 /* Check if we've passed the end range. */
1883 if (!vsetIsElementInRange(key_data, key_len, &end_op)) break;
1884
1885 /* Add this element to the result. */
1886 RedisModule_ReplyWithStringBuffer(ctx, key_data, key_len);
1887 returned++;
1888 }
1889
1890 RedisModule_ReplySetArrayLength(ctx, returned);
1891
1892 /* Cleanup. */
1893 RedisModule_DictIteratorStop(iter);
1894
1895 return REDISMODULE_OK;
1896}
1897
1898/* ============================== vset type methods ========================= */
1899
1900#define SAVE_FLAG_HAS_PROJMATRIX (1<<0)
1901#define SAVE_FLAG_HAS_ATTRIBS (1<<1)
1902
1903/* Save object to RDB */
1904void VectorSetRdbSave(RedisModuleIO *rdb, void *value) {
1905 struct vsetObject *vset = value;
1906 RedisModule_SaveUnsigned(rdb, vset->hnsw->vector_dim);
1907 RedisModule_SaveUnsigned(rdb, vset->hnsw->node_count);
1908
1909 uint32_t hnsw_config = (vset->hnsw->quant_type & 0xff) |
1910 ((vset->hnsw->M & 0xffff) << 8);
1911 RedisModule_SaveUnsigned(rdb, hnsw_config);
1912
1913 uint32_t save_flags = 0;
1914 if (vset->proj_matrix) save_flags |= SAVE_FLAG_HAS_PROJMATRIX;
1915 if (vset->numattribs != 0) save_flags |= SAVE_FLAG_HAS_ATTRIBS;
1916 RedisModule_SaveUnsigned(rdb, save_flags);
1917
1918 /* Save projection matrix if present */
1919 if (vset->proj_matrix) {
1920 uint32_t input_dim = vset->proj_input_size;
1921 uint32_t output_dim = vset->hnsw->vector_dim;
1922 RedisModule_SaveUnsigned(rdb, input_dim);
1923 // Output dim is the same as the first value saved
1924 // above, so we don't save it.
1925
1926 // Save projection matrix as binary blob
1927 size_t matrix_size = sizeof(float) * input_dim * output_dim;
1928 RedisModule_SaveStringBuffer(rdb, (const char *)vset->proj_matrix, matrix_size);
1929 }
1930
1931 hnswNode *node = vset->hnsw->head;
1932 while(node) {
1933 struct vsetNodeVal *nv = node->value;
1934 RedisModule_SaveString(rdb, nv->item);
1935 if (vset->numattribs) {
1936 if (nv->attrib)
1937 RedisModule_SaveString(rdb, nv->attrib);
1938 else
1939 RedisModule_SaveStringBuffer(rdb, "", 0);
1940 }
1941 hnswSerNode *sn = hnsw_serialize_node(vset->hnsw,node);
1942 RedisModule_SaveStringBuffer(rdb, (const char *)sn->vector, sn->vector_size);
1943 RedisModule_SaveUnsigned(rdb, sn->params_count);
1944 for (uint32_t j = 0; j < sn->params_count; j++)
1945 RedisModule_SaveUnsigned(rdb, sn->params[j]);
1946 hnsw_free_serialized_node(sn);
1947 node = node->next;
1948 }
1949}
1950
1951/* Load object from RDB. Recover from recoverable errors (read errors)
1952 * by performing cleanup. */
1953void *VectorSetRdbLoad(RedisModuleIO *rdb, int encver) {
1954 if (encver != 0) return NULL; // Invalid version
1955
1956 uint32_t dim = RedisModule_LoadUnsigned(rdb);
1957 uint64_t elements = RedisModule_LoadUnsigned(rdb);
1958 uint32_t hnsw_config = RedisModule_LoadUnsigned(rdb);
1959 if (RedisModule_IsIOError(rdb)) return NULL;
1960 uint32_t quant_type = hnsw_config & 0xff;
1961 uint32_t hnsw_m = (hnsw_config >> 8) & 0xffff;
1962
1963 /* Check that the quantization type is correct. Otherwise
1964 * return ASAP signaling the error. */
1965 if (quant_type != HNSW_QUANT_NONE &&
1966 quant_type != HNSW_QUANT_Q8 &&
1967 quant_type != HNSW_QUANT_BIN) return NULL;
1968
1969 if (hnsw_m == 0) hnsw_m = 16; // Default, useful for RDB files predating
1970 // this configuration parameter: it was fixed
1971 // to 16.
1972 struct vsetObject *vset = createVectorSetObject(dim,quant_type,hnsw_m);
1973 RedisModule_Assert(vset != NULL);
1974
1975 /* Load projection matrix if present */
1976 uint32_t save_flags = RedisModule_LoadUnsigned(rdb);
1977 if (RedisModule_IsIOError(rdb)) goto ioerr;
1978 int has_projection = save_flags & SAVE_FLAG_HAS_PROJMATRIX;
1979 int has_attribs = save_flags & SAVE_FLAG_HAS_ATTRIBS;
1980 if (has_projection) {
1981 uint32_t input_dim = RedisModule_LoadUnsigned(rdb);
1982 if (RedisModule_IsIOError(rdb)) goto ioerr;
1983 uint32_t output_dim = dim;
1984 size_t matrix_size = sizeof(float) * input_dim * output_dim;
1985
1986 vset->proj_matrix = RedisModule_Alloc(matrix_size);
1987 vset->proj_input_size = input_dim;
1988
1989 // Load projection matrix as a binary blob
1990 char *matrix_blob = RedisModule_LoadStringBuffer(rdb, NULL);
1991 if (matrix_blob == NULL) goto ioerr;
1992 memcpy(vset->proj_matrix, matrix_blob, matrix_size);
1993 RedisModule_Free(matrix_blob);
1994 }
1995
1996 while(elements--) {
1997 // Load associated string element.
1998 RedisModuleString *ele = RedisModule_LoadString(rdb);
1999 if (RedisModule_IsIOError(rdb)) goto ioerr;
2000 RedisModuleString *attrib = NULL;
2001 if (has_attribs) {
2002 attrib = RedisModule_LoadString(rdb);
2003 if (RedisModule_IsIOError(rdb)) {
2004 RedisModule_FreeString(NULL,ele);
2005 goto ioerr;
2006 }
2007 size_t attrlen;
2008 RedisModule_StringPtrLen(attrib,&attrlen);
2009 if (attrlen == 0) {
2010 RedisModule_FreeString(NULL,attrib);
2011 attrib = NULL;
2012 }
2013 }
2014 size_t vector_len;
2015 void *vector = RedisModule_LoadStringBuffer(rdb, &vector_len);
2016 if (RedisModule_IsIOError(rdb)) {
2017 RedisModule_FreeString(NULL,ele);
2018 if (attrib) RedisModule_FreeString(NULL,attrib);
2019 goto ioerr;
2020 }
2021 uint32_t vector_bytes = hnsw_quants_bytes(vset->hnsw);
2022 if (vector_len != vector_bytes) {
2023 RedisModule_LogIOError(rdb,"warning",
2024 "Mismatching vector dimension");
2025 RedisModule_FreeString(NULL,ele);
2026 if (attrib) RedisModule_FreeString(NULL,attrib);
2027 RedisModule_Free(vector);
2028 goto ioerr;
2029 }
2030
2031 // Load node parameters back.
2032 uint32_t params_count = RedisModule_LoadUnsigned(rdb);
2033 if (RedisModule_IsIOError(rdb)) {
2034 RedisModule_FreeString(NULL,ele);
2035 if (attrib) RedisModule_FreeString(NULL,attrib);
2036 RedisModule_Free(vector);
2037 goto ioerr;
2038 }
2039
2040 uint64_t *params = RedisModule_Alloc(params_count*sizeof(uint64_t));
2041 for (uint32_t j = 0; j < params_count; j++) {
2042 // Ignore loading errors here: handled at the end of the loop.
2043 params[j] = RedisModule_LoadUnsigned(rdb);
2044 }
2045 if (RedisModule_IsIOError(rdb)) {
2046 RedisModule_FreeString(NULL,ele);
2047 if (attrib) RedisModule_FreeString(NULL,attrib);
2048 RedisModule_Free(vector);
2049 RedisModule_Free(params);
2050 goto ioerr;
2051 }
2052
2053 struct vsetNodeVal *nv = RedisModule_Alloc(sizeof(*nv));
2054 nv->item = ele;
2055 nv->attrib = attrib;
2056 hnswNode *node = hnsw_insert_serialized(vset->hnsw, vector, params, params_count, nv);
2057 if (node == NULL) {
2058 RedisModule_LogIOError(rdb,"warning",
2059 "Vector set node index loading error");
2060 vectorSetReleaseNodeValue(nv);
2061 RedisModule_Free(vector);
2062 RedisModule_Free(params);
2063 goto ioerr;
2064 }
2065 if (nv->attrib) vset->numattribs++;
2066 RedisModule_DictSet(vset->dict,ele,node);
2067 RedisModule_Free(vector);
2068 RedisModule_Free(params);
2069 }
2070
2071 uint64_t salt[2];
2072 RedisModule_GetRandomBytes((unsigned char*)salt,sizeof(salt));
2073 if (!hnsw_deserialize_index(vset->hnsw, salt[0], salt[1])) goto ioerr;
2074
2075 return vset;
2076
2077ioerr:
2078 /* We want to recover from I/O errors and free the partially allocated
2079 * data structure to support diskless replication. */
2080 vectorSetReleaseObject(vset);
2081 return NULL;
2082}
2083
2084/* Calculate memory usage */
2085size_t VectorSetMemUsage(const void *value) {
2086 const struct vsetObject *vset = value;
2087 size_t size = sizeof(*vset);
2088
2089 /* Account for HNSW index base structure */
2090 size += sizeof(HNSW);
2091
2092 /* Account for projection matrix if present */
2093 if (vset->proj_matrix) {
2094 /* For the matrix size, we need the input dimension. We can get it
2095 * from the first node if the set is not empty. */
2096 uint32_t input_dim = vset->proj_input_size;
2097 uint32_t output_dim = vset->hnsw->vector_dim;
2098 size += sizeof(float) * input_dim * output_dim;
2099 }
2100
2101 /* Account for each node's memory usage. */
2102 hnswNode *node = vset->hnsw->head;
2103 if (node == NULL) return size;
2104
2105 /* Base node structure. */
2106 size += sizeof(*node) * vset->hnsw->node_count;
2107
2108 /* Vector storage. */
2109 uint64_t vec_storage = hnsw_quants_bytes(vset->hnsw);
2110 size += vec_storage * vset->hnsw->node_count;
2111
2112 /* Layers array. We use 1.33 as average nodes layers count. */
2113 uint64_t layers_storage = sizeof(hnswNodeLayer) * vset->hnsw->node_count;
2114 layers_storage = layers_storage * 4 / 3; // 1.33 times.
2115 size += layers_storage;
2116
2117 /* All the nodes have layer 0 links. */
2118 uint64_t level0_links = node->layers[0].max_links;
2119 uint64_t other_levels_links = level0_links/2;
2120 size += sizeof(hnswNode*) * level0_links * vset->hnsw->node_count;
2121
2122 /* Add the 0.33 remaining part, but upper layers have less links. */
2123 size += (sizeof(hnswNode*) * other_levels_links * vset->hnsw->node_count)/3;
2124
2125 /* Associated string value and attributres.
2126 * Use Redis Module API to get string size, and guess that all the
2127 * elements have similar size as the first few. */
2128 size_t items_scanned = 0, items_size = 0;
2129 size_t attribs_scanned = 0, attribs_size = 0;
2130 int scan_effort = 20;
2131 while(scan_effort > 0 && node) {
2132 struct vsetNodeVal *nv = node->value;
2133 items_size += RedisModule_MallocSizeString(nv->item);
2134 items_scanned++;
2135 if (nv->attrib) {
2136 attribs_size += RedisModule_MallocSizeString(nv->attrib);
2137 attribs_scanned++;
2138 }
2139 scan_effort--;
2140 node = node->next;
2141 }
2142
2143 /* Add the memory usage due to items. */
2144 if (items_scanned)
2145 size += items_size / items_scanned * vset->hnsw->node_count;
2146
2147 /* Add memory usage due to attributres. */
2148 if (attribs_scanned == 0) {
2149 /* We were not lucky enough to find a single attribute in the
2150 * first few items? Let's use a fixed arbitrary value. */
2151 attribs_scanned = 1;
2152 attribs_size = 64;
2153 }
2154 size += attribs_size / attribs_scanned * vset->numattribs;
2155
2156 /* Account for dictionary overhead - this is an approximation. */
2157 size += RedisModule_DictSize(vset->dict) * (sizeof(void*) * 2);
2158
2159 return size;
2160}
2161
/* Free callback of the vector set type: dispose of the whole object.
 *
 * vectorSetWaitAllBackgroundClients() is called before the object is handed
 * to vectorSetReleaseObject(), presumably so that no background client is
 * still using the set while it is torn down (both helpers are defined
 * elsewhere in this file). */
void VectorSetFree(void *value) {
    vectorSetWaitAllBackgroundClients((struct vsetObject *)value, 1);
    vectorSetReleaseObject(value);
}
2169
2170/* Add object digest to the digest context */
2171void VectorSetDigest(RedisModuleDigest *md, void *value) {
2172 struct vsetObject *vset = value;
2173
2174 /* Add consistent order-independent hash of all vectors */
2175 hnswNode *node = vset->hnsw->head;
2176
2177 /* Hash the vector dimension and number of nodes. */
2178 RedisModule_DigestAddLongLong(md, vset->hnsw->node_count);
2179 RedisModule_DigestAddLongLong(md, vset->hnsw->vector_dim);
2180 RedisModule_DigestEndSequence(md);
2181
2182 while(node) {
2183 struct vsetNodeVal *nv = node->value;
2184 /* Hash each vector component */
2185 RedisModule_DigestAddStringBuffer(md, node->vector, hnsw_quants_bytes(vset->hnsw));
2186 /* Hash the associated value */
2187 size_t len;
2188 const char *str = RedisModule_StringPtrLen(nv->item, &len);
2189 RedisModule_DigestAddStringBuffer(md, (char*)str, len);
2190 if (nv->attrib) {
2191 str = RedisModule_StringPtrLen(nv->attrib, &len);
2192 RedisModule_DigestAddStringBuffer(md, (char*)str, len);
2193 }
2194 node = node->next;
2195 RedisModule_DigestEndSequence(md);
2196 }
2197}
2198
2199// int VectorSets_InitModuleConfig(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
2200int VectorSets_InitModuleConfig(RedisModuleCtx *ctx) {
2201 if (RegisterModuleConfig(ctx) == REDISMODULE_ERR) {
2202 RedisModule_Log(ctx, "warning", "Error registering module configuration");
2203 return REDISMODULE_ERR;
2204 }
2205 // Load default values
2206 if (RedisModule_LoadDefaultConfigs(ctx) == REDISMODULE_ERR) {
2207 RedisModule_Log(ctx, "warning", "Error loading default module configuration");
2208 return REDISMODULE_ERR;
2209 } else {
2210 RedisModule_Log(ctx, "verbose", "Successfully loaded default module configuration");
2211 }
2212 if (RedisModule_LoadConfigs(ctx) == REDISMODULE_ERR) {
2213 RedisModule_Log(ctx, "warning", "Error loading user module configuration");
2214 return REDISMODULE_ERR;
2215 } else {
2216 RedisModule_Log(ctx, "verbose", "Successfully loaded user module configuration");
2217 }
2218 return REDISMODULE_OK;
2219}
2220
2221/* This function must be present on each Redis module. It is used in order to
2222 * register the commands into the Redis server. */
2223int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
2224 REDISMODULE_NOT_USED(argv);
2225 REDISMODULE_NOT_USED(argc);
2226
2227 if (RedisModule_Init(ctx,"vectorset",1,REDISMODULE_APIVER_1)
2228 == REDISMODULE_ERR) return REDISMODULE_ERR;
2229
2230 if (VectorSets_InitModuleConfig(ctx) == REDISMODULE_ERR) {
2231 return REDISMODULE_ERR;
2232 }
2233
2234 RedisModule_SetModuleOptions(ctx, REDISMODULE_OPTIONS_HANDLE_IO_ERRORS|REDISMODULE_OPTIONS_HANDLE_REPL_ASYNC_LOAD);
2235
2236 RedisModuleTypeMethods tm = {
2237 .version = REDISMODULE_TYPE_METHOD_VERSION,
2238 .rdb_load = VectorSetRdbLoad,
2239 .rdb_save = VectorSetRdbSave,
2240 .aof_rewrite = NULL,
2241 .mem_usage = VectorSetMemUsage,
2242 .free = VectorSetFree,
2243 .digest = VectorSetDigest
2244 };
2245
2246 VectorSetType = RedisModule_CreateDataType(ctx,"vectorset",0,&tm);
2247 if (VectorSetType == NULL) return REDISMODULE_ERR;
2248
2249 // Register command VADD
2250 if (RedisModule_CreateCommand(ctx,"VADD",
2251 VADD_RedisCommand,"write deny-oom",1,1,1) == REDISMODULE_ERR)
2252 return REDISMODULE_ERR;
2253
2254 RedisModuleCommand *vadd_cmd = RedisModule_GetCommand(ctx, "VADD");
2255 if (vadd_cmd == NULL) return REDISMODULE_ERR;
2256
2257 RedisModuleCommandArg vadd_args[] = {
2258 { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 },
2259 { .name = "reduce", .type = REDISMODULE_ARG_TYPE_BLOCK, .token = "REDUCE", .flags = REDISMODULE_CMD_ARG_OPTIONAL,
2260 .subargs = (RedisModuleCommandArg[]) {
2261 { .name = "dim", .type = REDISMODULE_ARG_TYPE_INTEGER },
2262 { .name = NULL }
2263 }
2264 },
2265 { .name = "format", .type = REDISMODULE_ARG_TYPE_ONEOF, .subargs = (RedisModuleCommandArg[]) {
2266 { .name = "fp32", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "FP32" },
2267 { .name = "values", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "VALUES" },
2268 { .name = NULL }
2269 }
2270 },
2271 { .name = "vector", .type = REDISMODULE_ARG_TYPE_STRING },
2272 { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING },
2273 { .name = "cas", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "CAS", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
2274 { .name = "quant_type", .type = REDISMODULE_ARG_TYPE_ONEOF, .flags = REDISMODULE_CMD_ARG_OPTIONAL, .subargs = (RedisModuleCommandArg[]) {
2275 { .name = "noquant", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "NOQUANT" },
2276 { .name = "bin", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "BIN" },
2277 { .name = "q8", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "Q8" },
2278 { .name = NULL }
2279 }
2280 },
2281 { .name = "build-exploration-factor", .type = REDISMODULE_ARG_TYPE_INTEGER, .token = "EF", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
2282 { .name = "attributes", .type = REDISMODULE_ARG_TYPE_STRING, .token = "SETATTR", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
2283 { .name = "numlinks", .type = REDISMODULE_ARG_TYPE_INTEGER, .token = "M", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
2284 { .name = NULL }
2285 };
2286 RedisModuleCommandInfo vadd_info = {
2287 .version = REDISMODULE_COMMAND_INFO_VERSION,
2288 .summary = "Add one or more elements to a vector set, or update its vector if it already exists",
2289 .since = "8.0.0",
2290 .arity = -5,
2291 .args = vadd_args,
2292 };
2293 if (RedisModule_SetCommandInfo(vadd_cmd, &vadd_info) == REDISMODULE_ERR) return REDISMODULE_ERR;
2294
2295 // Register command VREM
2296 if (RedisModule_CreateCommand(ctx,"VREM",
2297 VREM_RedisCommand,"write",1,1,1) == REDISMODULE_ERR)
2298 return REDISMODULE_ERR;
2299
2300 RedisModuleCommand *vrem_cmd = RedisModule_GetCommand(ctx, "VREM");
2301 if (vrem_cmd == NULL) return REDISMODULE_ERR;
2302
2303 RedisModuleCommandArg vrem_args[] = {
2304 { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 },
2305 { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING },
2306 { .name = NULL }
2307 };
2308 RedisModuleCommandInfo vrem_info = {
2309 .version = REDISMODULE_COMMAND_INFO_VERSION,
2310 .summary = "Remove an element from a vector set",
2311 .since = "8.0.0",
2312 .arity = 3,
2313 .args = vrem_args,
2314 };
2315 if (RedisModule_SetCommandInfo(vrem_cmd, &vrem_info) == REDISMODULE_ERR) return REDISMODULE_ERR;
2316
2317 // Register command VSIM
2318 if (RedisModule_CreateCommand(ctx,"VSIM",
2319 VSIM_RedisCommand,"readonly",1,1,1) == REDISMODULE_ERR)
2320 return REDISMODULE_ERR;
2321
2322 RedisModuleCommand *vsim_cmd = RedisModule_GetCommand(ctx, "VSIM");
2323 if (vsim_cmd == NULL) return REDISMODULE_ERR;
2324
2325 RedisModuleCommandArg vsim_args[] = {
2326 { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 },
2327 { .name = "format", .type = REDISMODULE_ARG_TYPE_ONEOF, .subargs = (RedisModuleCommandArg[]) {
2328 { .name = "ele", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "ELE" },
2329 { .name = "fp32", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "FP32" },
2330 { .name = "values", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "VALUES" },
2331 { .name = NULL }
2332 }
2333 },
2334 { .name = "vector_or_element", .type = REDISMODULE_ARG_TYPE_STRING },
2335 { .name = "withscores", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "WITHSCORES", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
2336 { .name = "withattribs", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "WITHATTRIBS", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
2337 { .name = "count", .type = REDISMODULE_ARG_TYPE_INTEGER, .token = "COUNT", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
2338 { .name = "max_distance", .type = REDISMODULE_ARG_TYPE_DOUBLE, .token = "EPSILON", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
2339 { .name = "search-exploration-factor", .type = REDISMODULE_ARG_TYPE_INTEGER, .token = "EF", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
2340 { .name = "expression", .type = REDISMODULE_ARG_TYPE_STRING, .token = "FILTER", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
2341 { .name = "max-filtering-effort", .type = REDISMODULE_ARG_TYPE_INTEGER, .token = "FILTER-EF", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
2342 { .name = "truth", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "TRUTH", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
2343 { .name = "nothread", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "NOTHREAD", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
2344 { .name = NULL }
2345 };
2346 RedisModuleCommandInfo vsim_info = {
2347 .version = REDISMODULE_COMMAND_INFO_VERSION,
2348 .summary = "Return elements by vector similarity",
2349 .since = "8.0.0",
2350 .arity = -4,
2351 .args = vsim_args,
2352 };
2353 if (RedisModule_SetCommandInfo(vsim_cmd, &vsim_info) == REDISMODULE_ERR) return REDISMODULE_ERR;
2354
2355 // Register command VDIM
2356 if (RedisModule_CreateCommand(ctx, "VDIM",
2357 VDIM_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR)
2358 return REDISMODULE_ERR;
2359
2360 RedisModuleCommand *vdim_cmd = RedisModule_GetCommand(ctx, "VDIM");
2361 if (vdim_cmd == NULL) return REDISMODULE_ERR;
2362
2363 RedisModuleCommandArg vdim_args[] = {
2364 { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 },
2365 { .name = NULL }
2366 };
2367 RedisModuleCommandInfo vdim_info = {
2368 .version = REDISMODULE_COMMAND_INFO_VERSION,
2369 .summary = "Return the dimension of vectors in the vector set",
2370 .since = "8.0.0",
2371 .arity = 2,
2372 .args = vdim_args,
2373 };
2374 if (RedisModule_SetCommandInfo(vdim_cmd, &vdim_info) == REDISMODULE_ERR) return REDISMODULE_ERR;
2375
2376 // Register command VCARD
2377 if (RedisModule_CreateCommand(ctx, "VCARD",
2378 VCARD_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR)
2379 return REDISMODULE_ERR;
2380
2381 RedisModuleCommand *vcard_cmd = RedisModule_GetCommand(ctx, "VCARD");
2382 if (vcard_cmd == NULL) return REDISMODULE_ERR;
2383
2384 RedisModuleCommandArg vcard_args[] = {
2385 { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 },
2386 { .name = NULL }
2387 };
2388 RedisModuleCommandInfo vcard_info = {
2389 .version = REDISMODULE_COMMAND_INFO_VERSION,
2390 .summary = "Return the number of elements in a vector set",
2391 .since = "8.0.0",
2392 .arity = 2,
2393 .args = vcard_args,
2394 };
2395 if (RedisModule_SetCommandInfo(vcard_cmd, &vcard_info) == REDISMODULE_ERR) return REDISMODULE_ERR;
2396
2397 // Register command VEMB
2398 if (RedisModule_CreateCommand(ctx, "VEMB",
2399 VEMB_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR)
2400 return REDISMODULE_ERR;
2401
2402 RedisModuleCommand *vemb_cmd = RedisModule_GetCommand(ctx, "VEMB");
2403 if (vemb_cmd == NULL) return REDISMODULE_ERR;
2404
2405 RedisModuleCommandArg vemb_args[] = {
2406 { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 },
2407 { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING },
2408 { .name = "raw", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "RAW", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
2409 { .name = NULL }
2410 };
2411 RedisModuleCommandInfo vemb_info = {
2412 .version = REDISMODULE_COMMAND_INFO_VERSION,
2413 .summary = "Return the vector associated with an element",
2414 .since = "8.0.0",
2415 .arity = -3,
2416 .args = vemb_args,
2417 };
2418 if (RedisModule_SetCommandInfo(vemb_cmd, &vemb_info) == REDISMODULE_ERR) return REDISMODULE_ERR;
2419
2420 // Register command VLINKS
2421 if (RedisModule_CreateCommand(ctx, "VLINKS",
2422 VLINKS_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR)
2423 return REDISMODULE_ERR;
2424
2425 RedisModuleCommand *vlinks_cmd = RedisModule_GetCommand(ctx, "VLINKS");
2426 if (vlinks_cmd == NULL) return REDISMODULE_ERR;
2427
2428 RedisModuleCommandArg vlinks_args[] = {
2429 { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 },
2430 { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING },
2431 { .name = "withscores", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "WITHSCORES", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
2432 { .name = NULL }
2433 };
2434 RedisModuleCommandInfo vlinks_info = {
2435 .version = REDISMODULE_COMMAND_INFO_VERSION,
2436 .summary = "Return the neighbors of an element at each layer in the HNSW graph",
2437 .since = "8.0.0",
2438 .arity = -3,
2439 .args = vlinks_args,
2440 };
2441 if (RedisModule_SetCommandInfo(vlinks_cmd, &vlinks_info) == REDISMODULE_ERR) return REDISMODULE_ERR;
2442
2443 // Register command VINFO
2444 if (RedisModule_CreateCommand(ctx, "VINFO",
2445 VINFO_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR)
2446 return REDISMODULE_ERR;
2447
2448 RedisModuleCommand *vinfo_cmd = RedisModule_GetCommand(ctx, "VINFO");
2449 if (vinfo_cmd == NULL) return REDISMODULE_ERR;
2450
2451 RedisModuleCommandArg vinfo_args[] = {
2452 { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 },
2453 { .name = NULL }
2454 };
2455 RedisModuleCommandInfo vinfo_info = {
2456 .version = REDISMODULE_COMMAND_INFO_VERSION,
2457 .summary = "Return information about a vector set",
2458 .since = "8.0.0",
2459 .arity = 2,
2460 .args = vinfo_args,
2461 };
2462 if (RedisModule_SetCommandInfo(vinfo_cmd, &vinfo_info) == REDISMODULE_ERR) return REDISMODULE_ERR;
2463
2464 // Register command VSETATTR
2465 if (RedisModule_CreateCommand(ctx, "VSETATTR",
2466 VSETATTR_RedisCommand, "write fast", 1, 1, 1) == REDISMODULE_ERR)
2467 return REDISMODULE_ERR;
2468
2469 RedisModuleCommand *vsetattr_cmd = RedisModule_GetCommand(ctx, "VSETATTR");
2470 if (vsetattr_cmd == NULL) return REDISMODULE_ERR;
2471
2472 RedisModuleCommandArg vsetattr_args[] = {
2473 { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 },
2474 { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING },
2475 { .name = "json", .type = REDISMODULE_ARG_TYPE_STRING },
2476 { .name = NULL }
2477 };
2478 RedisModuleCommandInfo vsetattr_info = {
2479 .version = REDISMODULE_COMMAND_INFO_VERSION,
2480 .summary = "Associate or remove the JSON attributes of elements",
2481 .since = "8.0.0",
2482 .arity = 4,
2483 .args = vsetattr_args,
2484 };
2485 if (RedisModule_SetCommandInfo(vsetattr_cmd, &vsetattr_info) == REDISMODULE_ERR) return REDISMODULE_ERR;
2486
2487 // Register command VGETATTR
2488 if (RedisModule_CreateCommand(ctx, "VGETATTR",
2489 VGETATTR_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR)
2490 return REDISMODULE_ERR;
2491
2492 RedisModuleCommand *vgetattr_cmd = RedisModule_GetCommand(ctx, "VGETATTR");
2493 if (vgetattr_cmd == NULL) return REDISMODULE_ERR;
2494
2495 RedisModuleCommandArg vgetattr_args[] = {
2496 { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 },
2497 { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING },
2498 { .name = NULL }
2499 };
2500 RedisModuleCommandInfo vgetattr_info = {
2501 .version = REDISMODULE_COMMAND_INFO_VERSION,
2502 .summary = "Retrieve the JSON attributes of elements",
2503 .since = "8.0.0",
2504 .arity = 3,
2505 .args = vgetattr_args,
2506 };
2507 if (RedisModule_SetCommandInfo(vgetattr_cmd, &vgetattr_info) == REDISMODULE_ERR) return REDISMODULE_ERR;
2508
2509 // Register command VRANDMEMBER
2510 if (RedisModule_CreateCommand(ctx, "VRANDMEMBER",
2511 VRANDMEMBER_RedisCommand, "readonly", 1, 1, 1) == REDISMODULE_ERR)
2512 return REDISMODULE_ERR;
2513
2514 RedisModuleCommand *vrandmember_cmd = RedisModule_GetCommand(ctx, "VRANDMEMBER");
2515 if (vrandmember_cmd == NULL) return REDISMODULE_ERR;
2516
2517 RedisModuleCommandArg vrandmember_args[] = {
2518 { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 },
2519 { .name = "count", .type = REDISMODULE_ARG_TYPE_INTEGER, .flags = REDISMODULE_CMD_ARG_OPTIONAL },
2520 { .name = NULL }
2521 };
2522 RedisModuleCommandInfo vrandmember_info = {
2523 .version = REDISMODULE_COMMAND_INFO_VERSION,
2524 .summary = "Return one or multiple random members from a vector set",
2525 .since = "8.0.0",
2526 .arity = -2,
2527 .args = vrandmember_args,
2528 };
2529 if (RedisModule_SetCommandInfo(vrandmember_cmd, &vrandmember_info) == REDISMODULE_ERR) return REDISMODULE_ERR;
2530
2531 // Register command VISMEMBER
2532 if (RedisModule_CreateCommand(ctx, "VISMEMBER",
2533 VISMEMBER_RedisCommand, "readonly", 1, 1, 1) == REDISMODULE_ERR)
2534 return REDISMODULE_ERR;
2535
2536 RedisModuleCommand *vismember_cmd = RedisModule_GetCommand(ctx, "VISMEMBER");
2537 if (vismember_cmd == NULL) return REDISMODULE_ERR;
2538
2539 RedisModuleCommandArg vismember_args[] = {
2540 { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 },
2541 { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING },
2542 { .name = NULL }
2543 };
2544 RedisModuleCommandInfo vismember_info = {
2545 .version = REDISMODULE_COMMAND_INFO_VERSION,
2546 .summary = "Check if an element exists in a vector set",
2547 .since = "8.2.0",
2548 .arity = 3,
2549 .args = vismember_args,
2550 };
2551 if (RedisModule_SetCommandInfo(vismember_cmd, &vismember_info) == REDISMODULE_ERR) return REDISMODULE_ERR;
2552
2553 // Register command VRANGE
2554 if (RedisModule_CreateCommand(ctx, "VRANGE",
2555 VRANGE_RedisCommand, "readonly", 1, 1, 1) == REDISMODULE_ERR)
2556 return REDISMODULE_ERR;
2557
2558 RedisModuleCommand *vrange_cmd = RedisModule_GetCommand(ctx, "VRANGE");
2559 if (vrange_cmd == NULL) return REDISMODULE_ERR;
2560
2561 RedisModuleCommandArg vrange_args[] = {
2562 { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 },
2563 { .name = "start", .type = REDISMODULE_ARG_TYPE_STRING },
2564 { .name = "end", .type = REDISMODULE_ARG_TYPE_STRING },
2565 { .name = "count", .type = REDISMODULE_ARG_TYPE_INTEGER, .flags = REDISMODULE_CMD_ARG_OPTIONAL },
2566 { .name = NULL }
2567 };
2568 RedisModuleCommandInfo vrange_info = {
2569 .version = REDISMODULE_COMMAND_INFO_VERSION,
2570 .summary = "Return vector set elements in a lex range",
2571 .since = "8.4.0",
2572 .arity = -4,
2573 .args = vrange_args,
2574 };
2575 if (RedisModule_SetCommandInfo(vrange_cmd, &vrange_info) == REDISMODULE_ERR) return REDISMODULE_ERR;
2576
2577 // Set the allocator for the HNSW library, so that memory tracking
2578 // is correct in Redis.
2579 hnsw_set_allocator(RedisModule_Free, RedisModule_Alloc,
2580 RedisModule_Realloc);
2581
2582 return REDISMODULE_OK;
2583}
2584
2585int VectorSets_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
2586 return RedisModule_OnLoad(ctx, argv, argc);
2587}