diff options
Diffstat (limited to 'examples/redis-unstable/modules/vector-sets/vset.c')
| -rw-r--r-- | examples/redis-unstable/modules/vector-sets/vset.c | 2587 |
1 files changed, 0 insertions, 2587 deletions
diff --git a/examples/redis-unstable/modules/vector-sets/vset.c b/examples/redis-unstable/modules/vector-sets/vset.c deleted file mode 100644 index 500f8e9..0000000 --- a/examples/redis-unstable/modules/vector-sets/vset.c +++ /dev/null | |||
| @@ -1,2587 +0,0 @@ | |||
| 1 | /* Redis implementation for vector sets. The data structure itself | ||
| 2 | * is implemented in hnsw.c. | ||
| 3 | * | ||
| 4 | * Copyright (c) 2009-Present, Redis Ltd. | ||
| 5 | * All rights reserved. | ||
| 6 | * | ||
| 7 | * Licensed under your choice of (a) the Redis Source Available License 2.0 | ||
| 8 | * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the | ||
| 9 | * GNU Affero General Public License v3 (AGPLv3). | ||
| 10 | * Originally authored by: Salvatore Sanfilippo. | ||
| 11 | * | ||
| 12 | * ======================== Understand threading model ========================= | ||
| 13 | * This code implements threaded operations for two of the commands: | ||
| 14 | * | ||
| 15 | * 1. VSIM, by default. | ||
| 16 | * 2. VADD, if the CAS option is specified. | ||
| 17 | * | ||
| 18 | * Note that even if the second operation, VADD, is a write operation, only | ||
| 19 | * the neighbors collection for the new node is performed in a thread: then, | ||
| 20 | * the actual insert is performed in the reply callback VADD_CASReply(), | ||
| 21 | * which is executed in the main thread. | ||
| 22 | * | ||
| 23 | * Threaded operations need us to protect various operations with mutexes, | ||
| 24 | * even if a certain degree of protection is already provided by the HNSW | ||
| 25 | * library. Here are a few very important things about this implementation | ||
| 26 | * and the way locking is performed. | ||
| 27 | * | ||
| 28 | * 1. All the write operations are performed in the main Redis thread: | ||
| 29 | * this also includes the VADD_CASReply() callback, that is called by Redis | ||
| 30 | * internals only in the context of the main thread. However the HNSW | ||
| 31 | * library allows background threads in hnsw_search() (VSIM) to modify | ||
| 32 | * nodes metadata to speedup search (to understand if a node was already | ||
| 33 | * visited), but this only happens after acquiring a specific lock | ||
| 34 | * for a given "read slot". | ||
| 35 | * | ||
| 36 | * 2. We use a global lock for each Vector Set object, called "in_use". This | ||
| 37 | * lock is a read-write lock, and is acquired in read mode by all the | ||
| 38 | * threads that perform reads in the background. It is only acquired in | ||
| 39 | * write mode by vectorSetWaitAllBackgroundClients(): the function acquires | ||
| 40 | * the lock and immediately releases it, with the effect of waiting all the | ||
| 41 | * background threads still running from ending their execution. | ||
| 42 | * | ||
| 43 | * Note that no thread can be spawned, since we only call | ||
| 44 | * vectorSetWaitAllBackgroundClients() from the main Redis thread, that | ||
| 45 | * is also the only thread spawning other threads. | ||
| 46 | * | ||
| 47 | * vectorSetWaitAllBackgroundClients() is used in two ways: | ||
| 48 | * A) When we need to delete a vector set because of (DEL) or other | ||
| 49 | * operations destroying the object, we need to wait that all the | ||
| 50 | * background threads working with this object finished their work. | ||
| 51 | * B) When we modify the HNSW nodes bypassing the normal locking | ||
| 52 | * provided by the HNSW library. This only happens when we update | ||
| 53 | * an existing node attribute so far, in VSETATTR and when we call | ||
| 54 | * VADD to update a node with the SETATTR option. | ||
| 55 | * | ||
| 56 | * 3. Often during read operations performed by Redis commands in the | ||
| 57 | * main thread (VCARD, VEMB, VRANDMEMBER, ...) we don't acquire any | ||
| 58 | * lock at all. The commands run in the main Redis thread, we can only | ||
| 59 | * have, at the same time, background reads against the same data | ||
| 60 | * structure. Note that VSIM_thread() and VADD_thread() still modify the | ||
| 61 | * read slot metadata, that is node->visited_epoch[slot], but as long as | ||
| 62 | * our read commands running in the main thread don't need to use | ||
| 63 | * hnsw_search() or other HNSW functions using the visited epochs slots | ||
| 64 | * we are safe. | ||
| 65 | * | ||
| 66 | * 4. There is a race from the moment we create a thread, passing the | ||
| 67 | * vector set object, to the moment the thread can actually lock the | ||
| 68 | * object via the in_use_lock lock: as the thread starts, in the meanwhile | ||
| 69 | * a DEL/expire could trigger and remove the object. For this reason | ||
| 70 | * we use an atomic counter that protects our object for this small | ||
| 71 | * time in vectorSetWaitAllBackgroundClients(). This prevents removal | ||
| 72 | * of objects that are about to be taken by threads. | ||
| 73 | * | ||
| 74 | * Note that other competing solutions could be used to fix the problem | ||
| 75 | * but have their set of issues, however they are worth documenting here | ||
| 76 | * and evaluating in the future: | ||
| 77 | * | ||
| 78 | * A. Using a conditional variable we could "wait" for the thread to | ||
| 79 | * acquire the lock. However this means waiting before returning | ||
| 80 | * to the event loop, and would make the command execution slower. | ||
| 81 | * B. We could use again an atomic variable, like we did, but this time | ||
| 82 | * as a refcount for the object, with a vsetAcquire() vsetRelease(). | ||
| 83 | * In this case, the command could retain the object in the main thread | ||
| 84 | * before starting the thread, and the thread, after the work is done, | ||
| 85 | * could release it. This way sometimes the object would be freed by | ||
| 86 | * the thread, and while right now it may be safe to do the kind of resource | ||
| 87 | * deallocation that vectorSetReleaseObject() does, given that the | ||
| 88 | * Redis Modules API is not always thread safe this solution may not | ||
| 89 | * be future-proof. However it is worth evaluating it better in the | ||
| 90 | * future. | ||
| 91 | * C. We could use the "B" solution but instead of freeing the object | ||
| 92 | * in the thread, in this specific case we could just put it into a | ||
| 93 | * list and defer it for later freeing (for instance in the reply | ||
| 94 | * callback), so that the object is always freed in the main thread. | ||
| 95 | * This would require a list of objects to free. | ||
| 96 | * | ||
| 97 | * However the current solution's only disadvantage is the potential busy | ||
| 98 | * loop, but this busy loop in practical terms will almost never do | ||
| 99 | * much: to trigger it, a number of circumstances must happen: deleting | ||
| 100 | * Vector Set keys while using them, hitting the small window needed to | ||
| 101 | * start the thread and read-lock the mutex. | ||
| 102 | */ | ||
| 103 | |||
| 104 | #define _DEFAULT_SOURCE | ||
| 105 | #define _USE_MATH_DEFINES | ||
| 106 | #define _POSIX_C_SOURCE 200809L | ||
| 107 | |||
| 108 | #include "../../src/redismodule.h" | ||
| 109 | #include <stdio.h> | ||
| 110 | #include <stdlib.h> | ||
| 111 | #include <ctype.h> | ||
| 112 | #include <string.h> | ||
| 113 | #include <strings.h> | ||
| 114 | #include <stdint.h> | ||
| 115 | #include <math.h> | ||
| 116 | #include <pthread.h> | ||
| 117 | #include <stdatomic.h> | ||
| 118 | #include "hnsw.h" | ||
| 119 | #include "vset_config.h" | ||
| 120 | |||
| 121 | // We inline directly the expression implementation here so that building | ||
| 122 | // the module is trivial. | ||
| 123 | #include "expr.c" | ||
| 124 | |||
| 125 | static RedisModuleType *VectorSetType; | ||
| 126 | static uint64_t VectorSetTypeNextId = 0; | ||
| 127 | |||
| 128 | // Default EF value if not specified during creation. | ||
| 129 | #define VSET_DEFAULT_C_EF 200 | ||
| 130 | |||
| 131 | // Default EF value if not specified during search. | ||
| 132 | #define VSET_DEFAULT_SEARCH_EF 100 | ||
| 133 | |||
| 134 | // Default num elements returned by VSIM. | ||
| 135 | #define VSET_DEFAULT_COUNT 10 | ||
| 136 | |||
| 137 | /* ========================== Internal data structure ====================== */ | ||
| 138 | |||
| 139 | /* Our abstract data type needs a dual representation similar to Redis | ||
| 140 | * sorted set: the proximity graph, and also a element -> graph-node map | ||
| 141 | * that will allow us to perform deletions and other operations that have | ||
| 142 | * as input the element itself. */ | ||
| 143 | struct vsetObject { | ||
| 144 | HNSW *hnsw; // Proximity graph. | ||
| 145 | RedisModuleDict *dict; // Element -> node mapping. | ||
| 146 | float *proj_matrix; // Random projection matrix, NULL if no projection | ||
| 147 | uint32_t proj_input_size; // Input dimension after projection. | ||
| 148 | // Output dimension is implicit in | ||
| 149 | // hnsw->vector_dim. | ||
| 150 | pthread_rwlock_t in_use_lock; // Lock needed to destroy the object safely. | ||
| 151 | uint64_t id; // Unique ID used by threaded VADD to know the | ||
| 152 | // object is still the same. | ||
| 153 | uint64_t numattribs; // Number of nodes associated with an attribute. | ||
| 154 | atomic_int thread_creation_pending; // Number of threads that are currently | ||
| 155 | // pending to lock the object. | ||
| 156 | }; | ||
| 157 | |||
| 158 | /* Each node has two associated values: the associated string (the item | ||
| 159 | * in the set) and potentially a JSON string, that is, the attributes, used | ||
| 160 | * for hybrid search with the VSIM FILTER option. */ | ||
| 161 | struct vsetNodeVal { | ||
| 162 | RedisModuleString *item; | ||
| 163 | RedisModuleString *attrib; | ||
| 164 | }; | ||
| 165 | |||
/* Return the population count (Hamming weight) of 'n', that is, the number
 * of bits set to 1. Portable: no compiler builtins are used.
 *
 * Each iteration clears the lowest set bit, so the loop runs once per
 * set bit. */
static inline uint32_t bit_count(uint32_t n) {
    uint32_t ones = 0;
    for (; n != 0; n &= n - 1) {
        ones++;
    }
    return ones;
}
| 177 | |||
/* Create a Hadamard-based projection matrix for dimensionality reduction.
 *
 * The matrix has 'output_dim' rows and 'input_dim' columns, stored row by
 * row in a flat array. Entries are in {-1, +1}, scaled by
 * 1/sqrt(input_dim) for normalization:
 *
 *   matrix[i][j] = (popcount(i & j) is even ? +1 : -1) / sqrt(input_dim)
 *
 * The returned buffer is allocated with RedisModule_Alloc(); the caller
 * owns it. The allocation result is not checked here: as noted elsewhere
 * in this file, Redis crashes on out of memory.
 *
 * Note that compared to other approaches (random gaussian weights), what
 * we have here is deterministic, which means that our replicas will have
 * the same set of weights. Also this approach seems to work much better
 * in practice, and the distances between elements are better guaranteed.
 *
 * Note that we still save the projection matrix in the RDB file, because
 * in the future we may change the weights generation, and we want
 * everything to be backward compatible. */
float *createProjectionMatrix(uint32_t input_dim, uint32_t output_dim) {
    const size_t cols = input_dim;
    const size_t rows = output_dim;
    float *matrix = RedisModule_Alloc(sizeof(float) * cols * rows);

    /* Scale factor to normalize the projection. */
    const float scale = 1.0f / sqrt(input_dim);

    /* Fill the matrix using the Hadamard pattern. */
    for (uint32_t i = 0; i < output_dim; i++) {
        /* The row offset is computed as size_t: a 32 bit product
         * i * input_dim could wrap around for large dimensions. */
        float *row = matrix + (size_t)i * cols;

        for (uint32_t j = 0; j < input_dim; j++) {
            /* If the count of 1-bits in the bitwise AND of i and j is
             * even, the value is 1, otherwise -1. */
            int value = (bit_count(i & j) % 2 == 0) ? 1 : -1;

            /* Store the scaled value. */
            row[j] = value * scale;
        }
    }
    return matrix;
}
| 215 | |||
/* Apply the projection matrix to the input vector: a plain matrix-vector
 * product.
 *
 * 'proj_matrix' is read as 'output_dim' rows of 'input_dim' floats each,
 * and 'input' as 'input_dim' floats. Returns a newly allocated vector of
 * 'output_dim' floats, obtained with RedisModule_Alloc(): the caller is
 * responsible for freeing it. */
float *applyProjection(const float *input, const float *proj_matrix,
                       uint32_t input_dim, uint32_t output_dim)
{
    float *output = RedisModule_Alloc(sizeof(float) * output_dim);

    for (uint32_t i = 0; i < output_dim; i++) {
        /* The row offset is computed as size_t: a 32 bit product
         * i * input_dim could wrap around for large dimensions. */
        const float *row = proj_matrix + (size_t)i * input_dim;
        float sum = 0.0f;
        for (uint32_t j = 0; j < input_dim; j++) {
            sum += row[j] * input[j];
        }
        output[i] = sum;
    }
    return output;
}
| 232 | |||
| 233 | /* Create the vector as HNSW+Dictionary combined data structure. */ | ||
| 234 | struct vsetObject *createVectorSetObject(unsigned int dim, uint32_t quant_type, uint32_t hnsw_M) { | ||
| 235 | struct vsetObject *o; | ||
| 236 | o = RedisModule_Alloc(sizeof(*o)); | ||
| 237 | |||
| 238 | o->id = VectorSetTypeNextId++; | ||
| 239 | o->hnsw = hnsw_new(dim,quant_type,hnsw_M); | ||
| 240 | if (!o->hnsw) { // May fail because of mutex creation. | ||
| 241 | RedisModule_Free(o); | ||
| 242 | return NULL; | ||
| 243 | } | ||
| 244 | |||
| 245 | o->dict = RedisModule_CreateDict(NULL); | ||
| 246 | o->proj_matrix = NULL; | ||
| 247 | o->proj_input_size = 0; | ||
| 248 | o->numattribs = 0; | ||
| 249 | o->thread_creation_pending = 0; | ||
| 250 | RedisModule_Assert(pthread_rwlock_init(&o->in_use_lock,NULL) == 0); | ||
| 251 | return o; | ||
| 252 | } | ||
| 253 | |||
| 254 | void vectorSetReleaseNodeValue(void *v) { | ||
| 255 | struct vsetNodeVal *nv = v; | ||
| 256 | RedisModule_FreeString(NULL,nv->item); | ||
| 257 | if (nv->attrib) RedisModule_FreeString(NULL,nv->attrib); | ||
| 258 | RedisModule_Free(nv); | ||
| 259 | } | ||
| 260 | |||
| 261 | /* Free the vector set object. */ | ||
| 262 | void vectorSetReleaseObject(struct vsetObject *o) { | ||
| 263 | if (!o) return; | ||
| 264 | if (o->hnsw) hnsw_free(o->hnsw,vectorSetReleaseNodeValue); | ||
| 265 | if (o->dict) RedisModule_FreeDict(NULL,o->dict); | ||
| 266 | if (o->proj_matrix) RedisModule_Free(o->proj_matrix); | ||
| 267 | pthread_rwlock_destroy(&o->in_use_lock); | ||
| 268 | RedisModule_Free(o); | ||
| 269 | } | ||
| 270 | |||
| 271 | /* Wait for all the threads performing operations on this | ||
| 272 | * index to terminate their work (locking for write will | ||
| 273 | * wait for all the other threads). | ||
| 274 | * | ||
| 275 | * if 'for_del' is set to 1, we also wait for all the pending threads | ||
| 276 | * that still didn't acquire the lock to finish their work. This | ||
| 277 | * is useful only if we are going to call this function to delete | ||
| 278 | * the object, and not if we want to just to modify it. */ | ||
| 279 | void vectorSetWaitAllBackgroundClients(struct vsetObject *vset, int for_del) { | ||
| 280 | if (for_del) { | ||
| 281 | // If we are going to destroy the object, after this call, let's | ||
| 282 | // wait for threads that are being created and still didn't had | ||
| 283 | // a chance to acquire the lock. | ||
| 284 | while (vset->thread_creation_pending > 0); | ||
| 285 | } | ||
| 286 | RedisModule_Assert(pthread_rwlock_wrlock(&vset->in_use_lock) == 0); | ||
| 287 | pthread_rwlock_unlock(&vset->in_use_lock); | ||
| 288 | } | ||
| 289 | |||
| 290 | /* Return a string representing the quantization type name of a vector set. */ | ||
| 291 | const char *vectorSetGetQuantName(struct vsetObject *o) { | ||
| 292 | switch(o->hnsw->quant_type) { | ||
| 293 | case HNSW_QUANT_NONE: return "f32"; | ||
| 294 | case HNSW_QUANT_Q8: return "int8"; | ||
| 295 | case HNSW_QUANT_BIN: return "bin"; | ||
| 296 | default: return "unknown"; | ||
| 297 | } | ||
| 298 | } | ||
| 299 | |||
| 300 | /* Insert the specified element into the Vector Set. | ||
| 301 | * If update is '1', the existing node will be updated. | ||
| 302 | * | ||
| 303 | * Returns 1 if the element was added, or 0 if the element was already there | ||
| 304 | * and was just updated. */ | ||
| 305 | int vectorSetInsert(struct vsetObject *o, float *vec, int8_t *qvec, float qrange, RedisModuleString *val, RedisModuleString *attrib, int update, int ef) | ||
| 306 | { | ||
| 307 | hnswNode *node = RedisModule_DictGet(o->dict,val,NULL); | ||
| 308 | if (node != NULL) { | ||
| 309 | if (update) { | ||
| 310 | /* Wait for clients in the background: background VSIM | ||
| 311 | * operations touch the nodes attributes we are going | ||
| 312 | * to touch. */ | ||
| 313 | vectorSetWaitAllBackgroundClients(o,0); | ||
| 314 | |||
| 315 | struct vsetNodeVal *nv = node->value; | ||
| 316 | /* Pass NULL as value-free function. We want to reuse | ||
| 317 | * the old value. */ | ||
| 318 | hnsw_delete_node(o->hnsw, node, NULL); | ||
| 319 | node = hnsw_insert(o->hnsw,vec,qvec,qrange,0,nv,ef); | ||
| 320 | RedisModule_Assert(node != NULL); | ||
| 321 | RedisModule_DictReplace(o->dict,val,node); | ||
| 322 | |||
| 323 | /* If attrib != NULL, the user wants that in case of an update we | ||
| 324 | * update the attribute as well (otherwise it remains as it was). | ||
| 325 | * Note that the order of operations is conceinved so that it | ||
| 326 | * works in case the old attrib and the new attrib pointer is the | ||
| 327 | * same. */ | ||
| 328 | if (attrib) { | ||
| 329 | // Empty attribute string means: unset the attribute during | ||
| 330 | // the update. | ||
| 331 | size_t attrlen; | ||
| 332 | RedisModule_StringPtrLen(attrib,&attrlen); | ||
| 333 | if (attrlen != 0) { | ||
| 334 | RedisModule_RetainString(NULL,attrib); | ||
| 335 | o->numattribs++; | ||
| 336 | } else { | ||
| 337 | attrib = NULL; | ||
| 338 | } | ||
| 339 | |||
| 340 | if (nv->attrib) { | ||
| 341 | o->numattribs--; | ||
| 342 | RedisModule_FreeString(NULL,nv->attrib); | ||
| 343 | } | ||
| 344 | nv->attrib = attrib; | ||
| 345 | } | ||
| 346 | } | ||
| 347 | return 0; | ||
| 348 | } | ||
| 349 | |||
| 350 | struct vsetNodeVal *nv = RedisModule_Alloc(sizeof(*nv)); | ||
| 351 | nv->item = val; | ||
| 352 | nv->attrib = attrib; | ||
| 353 | node = hnsw_insert(o->hnsw,vec,qvec,qrange,0,nv,ef); | ||
| 354 | if (node == NULL) { | ||
| 355 | // XXX Technically in Redis-land we don't have out of memory, as we | ||
| 356 | // crash on OOM. However the HNSW library may fail for error in the | ||
| 357 | // locking libc call. Probably impossible in practical terms. | ||
| 358 | RedisModule_Free(nv); | ||
| 359 | return 0; | ||
| 360 | } | ||
| 361 | if (attrib != NULL) o->numattribs++; | ||
| 362 | RedisModule_DictSet(o->dict,val,node); | ||
| 363 | RedisModule_RetainString(NULL,val); | ||
| 364 | if (attrib) RedisModule_RetainString(NULL,attrib); | ||
| 365 | return 1; | ||
| 366 | } | ||
| 367 | |||
| 368 | /* Parse vector from FP32 blob or VALUES format, with optional REDUCE. | ||
| 369 | * Format: [REDUCE dim] FP32|VALUES ... | ||
| 370 | * Returns allocated vector and sets dimension in *dim. | ||
| 371 | * If reduce_dim is not NULL, sets it to the requested reduction dimension. | ||
| 372 | * Returns NULL on parsing error. | ||
| 373 | * | ||
| 374 | * The function sets as a reference *consumed_args, so that the caller | ||
| 375 | * knows how many arguments we consumed in order to parse the input | ||
| 376 | * vector. Remaining arguments are often command options. */ | ||
| 377 | float *parseVector(RedisModuleString **argv, int argc, int start_idx, | ||
| 378 | size_t *dim, uint32_t *reduce_dim, int *consumed_args) | ||
| 379 | { | ||
| 380 | int consumed = 0; // Arguments consumed | ||
| 381 | |||
| 382 | /* Check for REDUCE option first. */ | ||
| 383 | if (reduce_dim) *reduce_dim = 0; | ||
| 384 | if (reduce_dim && argc > start_idx + 2 && | ||
| 385 | !strcasecmp(RedisModule_StringPtrLen(argv[start_idx],NULL),"REDUCE")) | ||
| 386 | { | ||
| 387 | long long rdim; | ||
| 388 | if (RedisModule_StringToLongLong(argv[start_idx+1],&rdim) | ||
| 389 | != REDISMODULE_OK || rdim <= 0) | ||
| 390 | { | ||
| 391 | return NULL; | ||
| 392 | } | ||
| 393 | if (reduce_dim) *reduce_dim = rdim; | ||
| 394 | start_idx += 2; // Skip REDUCE and its argument. | ||
| 395 | consumed += 2; | ||
| 396 | } | ||
| 397 | |||
| 398 | /* Now parse the vector format as before. */ | ||
| 399 | float *vec = NULL; | ||
| 400 | const char *vec_format = RedisModule_StringPtrLen(argv[start_idx],NULL); | ||
| 401 | |||
| 402 | if (!strcasecmp(vec_format,"FP32")) { | ||
| 403 | if (argc < start_idx + 2) return NULL; // Need FP32 + vector + value. | ||
| 404 | size_t vec_raw_len; | ||
| 405 | const char *blob = | ||
| 406 | RedisModule_StringPtrLen(argv[start_idx+1],&vec_raw_len); | ||
| 407 | |||
| 408 | // Must be 4 bytes per component. | ||
| 409 | if (vec_raw_len % 4 || vec_raw_len < 4) return NULL; | ||
| 410 | *dim = vec_raw_len/4; | ||
| 411 | |||
| 412 | vec = RedisModule_Alloc(vec_raw_len); | ||
| 413 | if (!vec) return NULL; | ||
| 414 | memcpy(vec,blob,vec_raw_len); | ||
| 415 | consumed += 2; | ||
| 416 | } else if (!strcasecmp(vec_format,"VALUES")) { | ||
| 417 | if (argc < start_idx + 2) return NULL; // Need at least the dimension. | ||
| 418 | long long vdim; // Vector dimension passed by the user. | ||
| 419 | if (RedisModule_StringToLongLong(argv[start_idx+1],&vdim) | ||
| 420 | != REDISMODULE_OK || vdim < 1) return NULL; | ||
| 421 | |||
| 422 | // Check that all the arguments are available. | ||
| 423 | if (argc < start_idx + 2 + vdim) return NULL; | ||
| 424 | |||
| 425 | *dim = vdim; | ||
| 426 | vec = RedisModule_Alloc(sizeof(float) * vdim); | ||
| 427 | if (!vec) return NULL; | ||
| 428 | |||
| 429 | for (int j = 0; j < vdim; j++) { | ||
| 430 | double val; | ||
| 431 | if (RedisModule_StringToDouble(argv[start_idx+2+j],&val) | ||
| 432 | != REDISMODULE_OK) | ||
| 433 | { | ||
| 434 | RedisModule_Free(vec); | ||
| 435 | return NULL; | ||
| 436 | } | ||
| 437 | vec[j] = val; | ||
| 438 | } | ||
| 439 | consumed += vdim + 2; | ||
| 440 | } else { | ||
| 441 | return NULL; // Unknown format. | ||
| 442 | } | ||
| 443 | |||
| 444 | if (consumed_args) *consumed_args = consumed; | ||
| 445 | return vec; | ||
| 446 | } | ||
| 447 | |||
| 448 | /* ========================== Commands implementation ======================= */ | ||
| 449 | |||
| 450 | /* VADD thread handling the "CAS" version of the command, that is | ||
| 451 | * performed blocking the client, accumulating here, in the thread, the | ||
| 452 | * set of potential candidates, and later inserting the element in the | ||
| 453 | * key (if it still exists, and if it is still the *same* vector set) | ||
| 454 | * in the Reply callback. */ | ||
| 455 | void *VADD_thread(void *arg) { | ||
| 456 | pthread_detach(pthread_self()); | ||
| 457 | |||
| 458 | void **targ = (void**)arg; | ||
| 459 | RedisModuleBlockedClient *bc = targ[0]; | ||
| 460 | struct vsetObject *vset = targ[1]; | ||
| 461 | float *vec = targ[3]; | ||
| 462 | int ef = (uint64_t)targ[6]; | ||
| 463 | |||
| 464 | /* Lock the object and signal that we are no longer pending | ||
| 465 | * the lock acquisition. */ | ||
| 466 | RedisModule_Assert(pthread_rwlock_rdlock(&vset->in_use_lock) == 0); | ||
| 467 | vset->thread_creation_pending--; | ||
| 468 | |||
| 469 | /* Look for candidates... */ | ||
| 470 | InsertContext *ic = hnsw_prepare_insert(vset->hnsw, vec, NULL, 0, 0, ef); | ||
| 471 | targ[5] = ic; // Pass the context to the reply callback. | ||
| 472 | |||
| 473 | /* Unblock the client so that our read reply will be invoked. */ | ||
| 474 | pthread_rwlock_unlock(&vset->in_use_lock); | ||
| 475 | RedisModule_BlockedClientMeasureTimeEnd(bc); | ||
| 476 | RedisModule_UnblockClient(bc,targ); // Use targ as privdata. | ||
| 477 | return NULL; | ||
| 478 | } | ||
| 479 | |||
| 480 | /* Reply callback for CAS variant of VADD. | ||
| 481 | * Note: this is called in the main thread, in the background thread | ||
| 482 | * we just do the read operation of gathering the neighbors. */ | ||
| 483 | int VADD_CASReply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { | ||
| 484 | (void)argc; | ||
| 485 | RedisModule_AutoMemory(ctx); /* Use automatic memory management. */ | ||
| 486 | |||
| 487 | int retval = REDISMODULE_OK; | ||
| 488 | void **targ = (void**)RedisModule_GetBlockedClientPrivateData(ctx); | ||
| 489 | uint64_t vset_id = (unsigned long) targ[2]; | ||
| 490 | float *vec = targ[3]; | ||
| 491 | RedisModuleString *val = targ[4]; | ||
| 492 | InsertContext *ic = targ[5]; | ||
| 493 | int ef = (uint64_t)targ[6]; | ||
| 494 | RedisModuleString *attrib = targ[7]; | ||
| 495 | RedisModule_Free(targ); | ||
| 496 | |||
| 497 | /* Open the key: there are no guarantees it still exists, or contains | ||
| 498 | * a vector set, or even the SAME vector set. */ | ||
| 499 | RedisModuleKey *key = RedisModule_OpenKey(ctx,argv[1], | ||
| 500 | REDISMODULE_READ|REDISMODULE_WRITE); | ||
| 501 | int type = RedisModule_KeyType(key); | ||
| 502 | struct vsetObject *vset = NULL; | ||
| 503 | |||
| 504 | if (type != REDISMODULE_KEYTYPE_EMPTY && | ||
| 505 | RedisModule_ModuleTypeGetType(key) == VectorSetType) | ||
| 506 | { | ||
| 507 | vset = RedisModule_ModuleTypeGetValue(key); | ||
| 508 | // Same vector set? | ||
| 509 | if (vset->id != vset_id) vset = NULL; | ||
| 510 | |||
| 511 | /* Also, if the element was already inserted, we just pretend | ||
| 512 | * the other insert won. We don't even start a threaded VADD | ||
| 513 | * if this was an update, since the deletion of the element itself | ||
| 514 | * in order to perform the update would invalidate the CAS state. */ | ||
| 515 | if (vset && RedisModule_DictGet(vset->dict,val,NULL) != NULL) | ||
| 516 | vset = NULL; | ||
| 517 | } | ||
| 518 | |||
| 519 | if (vset == NULL) { | ||
| 520 | /* If the object does not match the start of the operation, we | ||
| 521 | * just pretend the VADD was performed BEFORE the key was deleted | ||
| 522 | * or replaced. We return success but don't do anything. */ | ||
| 523 | hnsw_free_insert_context(ic); | ||
| 524 | } else { | ||
| 525 | /* Otherwise try to insert the new element with the neighbors | ||
| 526 | * collected in background. If we fail, do it synchronously again | ||
| 527 | * from scratch. */ | ||
| 528 | |||
| 529 | // First: allocate the dual-ported value for the node. | ||
| 530 | struct vsetNodeVal *nv = RedisModule_Alloc(sizeof(*nv)); | ||
| 531 | nv->item = val; | ||
| 532 | nv->attrib = attrib; | ||
| 533 | |||
| 534 | /* Then: insert the node in the HNSW data structure. Note that | ||
| 535 | * 'ic' could be NULL in case hnsw_prepare_insert() failed because of | ||
| 536 | * locking failure (likely impossible in practical terms). */ | ||
| 537 | hnswNode *newnode; | ||
| 538 | if (ic == NULL || | ||
| 539 | (newnode = hnsw_try_commit_insert(vset->hnsw, ic, nv)) == NULL) | ||
| 540 | { | ||
| 541 | /* If we are here, the CAS insert failed. We need to insert | ||
| 542 | * again with full locking for neighbors selection and | ||
| 543 | * actual insertion. This time we can't fail: */ | ||
| 544 | newnode = hnsw_insert(vset->hnsw, vec, NULL, 0, 0, nv, ef); | ||
| 545 | RedisModule_Assert(newnode != NULL); | ||
| 546 | } | ||
| 547 | RedisModule_DictSet(vset->dict,val,newnode); | ||
| 548 | val = NULL; // Don't free it later. | ||
| 549 | attrib = NULL; // Don't free it later. | ||
| 550 | |||
| 551 | RedisModule_ReplicateVerbatim(ctx); | ||
| 552 | } | ||
| 553 | |||
| 554 | // Whatever happens is a success... :D | ||
| 555 | RedisModule_ReplyWithBool(ctx,1); | ||
| 556 | if (val) RedisModule_FreeString(ctx,val); // Not added? Free it. | ||
| 557 | if (attrib) RedisModule_FreeString(ctx,attrib); // Not added? Free it. | ||
| 558 | RedisModule_Free(vec); | ||
| 559 | return retval; | ||
| 560 | } | ||
| 561 | |||
| 562 | /* VADD key [REDUCE dim] FP32|VALUES vector value [CAS] [NOQUANT] [BIN] [Q8] | ||
| 563 | * [M count] */ | ||
| 564 | int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { | ||
| 565 | RedisModule_AutoMemory(ctx); /* Use automatic memory management. */ | ||
| 566 | |||
| 567 | if (argc < 5) return RedisModule_WrongArity(ctx); | ||
| 568 | |||
| 569 | /* Parse vector with optional REDUCE */ | ||
| 570 | size_t dim = 0; | ||
| 571 | uint32_t reduce_dim = 0; | ||
| 572 | int consumed_args; | ||
| 573 | int cas = 0; // Threaded check-and-set style insert. | ||
| 574 | long long ef = VSET_DEFAULT_C_EF; // HNSW creation time EF for new nodes. | ||
| 575 | long long hnsw_create_M = HNSW_DEFAULT_M; // HNSW creation default M value. | ||
| 576 | float *vec = parseVector(argv, argc, 2, &dim, &reduce_dim, &consumed_args); | ||
| 577 | RedisModuleString *attrib = NULL; // Attributes if passed via ATTRIB. | ||
| 578 | if (!vec) | ||
| 579 | return RedisModule_ReplyWithError(ctx,"ERR invalid vector specification"); | ||
| 580 | |||
| 581 | /* Missing element string at the end? */ | ||
| 582 | if (argc-2-consumed_args < 1) { | ||
| 583 | RedisModule_Free(vec); | ||
| 584 | return RedisModule_WrongArity(ctx); | ||
| 585 | } | ||
| 586 | |||
| 587 | /* Parse options after the element string. */ | ||
| 588 | uint32_t quant_type = HNSW_QUANT_Q8; // Default quantization type. | ||
| 589 | |||
| 590 | for (int j = 2 + consumed_args + 1; j < argc; j++) { | ||
| 591 | const char *opt = RedisModule_StringPtrLen(argv[j], NULL); | ||
| 592 | if (!strcasecmp(opt, "CAS")) { | ||
| 593 | cas = 1; | ||
| 594 | } else if (!strcasecmp(opt, "EF") && j+1 < argc) { | ||
| 595 | if (RedisModule_StringToLongLong(argv[j+1], &ef) | ||
| 596 | != REDISMODULE_OK || ef <= 0 || ef > 1000000) | ||
| 597 | { | ||
| 598 | RedisModule_Free(vec); | ||
| 599 | return RedisModule_ReplyWithError(ctx, "ERR invalid EF"); | ||
| 600 | } | ||
| 601 | j++; // skip argument. | ||
| 602 | } else if (!strcasecmp(opt, "M") && j+1 < argc) { | ||
| 603 | if (RedisModule_StringToLongLong(argv[j+1], &hnsw_create_M) | ||
| 604 | != REDISMODULE_OK || hnsw_create_M < HNSW_MIN_M || | ||
| 605 | hnsw_create_M > HNSW_MAX_M) | ||
| 606 | { | ||
| 607 | RedisModule_Free(vec); | ||
| 608 | return RedisModule_ReplyWithError(ctx, "ERR invalid M"); | ||
| 609 | } | ||
| 610 | j++; // skip argument. | ||
| 611 | } else if (!strcasecmp(opt, "SETATTR") && j+1 < argc) { | ||
| 612 | attrib = argv[j+1]; | ||
| 613 | j++; // skip argument. | ||
| 614 | } else if (!strcasecmp(opt, "NOQUANT")) { | ||
| 615 | quant_type = HNSW_QUANT_NONE; | ||
| 616 | } else if (!strcasecmp(opt, "BIN")) { | ||
| 617 | quant_type = HNSW_QUANT_BIN; | ||
| 618 | } else if (!strcasecmp(opt, "Q8")) { | ||
| 619 | quant_type = HNSW_QUANT_Q8; | ||
| 620 | } else { | ||
| 621 | RedisModule_Free(vec); | ||
| 622 | return RedisModule_ReplyWithError(ctx,"ERR invalid option after element"); | ||
| 623 | } | ||
| 624 | } | ||
| 625 | |||
| 626 | /* Drop CAS if this is a replica and we are getting the command from the | ||
| 627 | * replication link: we want to add/delete items in the same order as | ||
| 628 | * the master, while with CAS the timing would be different. | ||
| 629 | * | ||
| 630 | * Also for Lua scripts and MULTI/EXEC, we want to run the command | ||
| 631 | * on the main thread. */ | ||
| 632 | if (RedisModule_GetContextFlags(ctx) & | ||
| 633 | (REDISMODULE_CTX_FLAGS_REPLICATED| | ||
| 634 | REDISMODULE_CTX_FLAGS_LUA| | ||
| 635 | REDISMODULE_CTX_FLAGS_MULTI)) | ||
| 636 | { | ||
| 637 | cas = 0; | ||
| 638 | } | ||
| 639 | |||
| 640 | if (VSGlobalConfig.forceSingleThreadExec) { | ||
| 641 | cas = 0; | ||
| 642 | } | ||
| 643 | |||
| 644 | /* Open/create key */ | ||
| 645 | RedisModuleKey *key = RedisModule_OpenKey(ctx,argv[1], | ||
| 646 | REDISMODULE_READ|REDISMODULE_WRITE); | ||
| 647 | int type = RedisModule_KeyType(key); | ||
| 648 | if (type != REDISMODULE_KEYTYPE_EMPTY && | ||
| 649 | RedisModule_ModuleTypeGetType(key) != VectorSetType) | ||
| 650 | { | ||
| 651 | RedisModule_Free(vec); | ||
| 652 | return RedisModule_ReplyWithError(ctx,REDISMODULE_ERRORMSG_WRONGTYPE); | ||
| 653 | } | ||
| 654 | |||
| 655 | /* Get the correct value argument based on format and REDUCE */ | ||
| 656 | RedisModuleString *val = argv[2 + consumed_args]; | ||
| 657 | |||
| 658 | /* Create or get existing vector set */ | ||
| 659 | struct vsetObject *vset; | ||
| 660 | if (type == REDISMODULE_KEYTYPE_EMPTY) { | ||
| 661 | cas = 0; /* Do synchronous insert at creation, otherwise the | ||
| 662 | * key would be left empty until the threaded part | ||
| 663 | * does not return. It's also pointless to try try | ||
| 664 | * doing threaded first element insertion. */ | ||
| 665 | vset = createVectorSetObject(reduce_dim ? reduce_dim : dim, quant_type, hnsw_create_M); | ||
| 666 | if (vset == NULL) { | ||
| 667 | // We can't fail for OOM in Redis, but the mutex initialization | ||
| 668 | // at least theoretically COULD fail. Likely this code path | ||
| 669 | // is not reachable in practical terms. | ||
| 670 | RedisModule_Free(vec); | ||
| 671 | return RedisModule_ReplyWithError(ctx, | ||
| 672 | "ERR unable to create a Vector Set: system resources issue?"); | ||
| 673 | } | ||
| 674 | |||
| 675 | /* Initialize projection if requested */ | ||
| 676 | if (reduce_dim) { | ||
| 677 | vset->proj_matrix = createProjectionMatrix(dim, reduce_dim); | ||
| 678 | vset->proj_input_size = dim; | ||
| 679 | |||
| 680 | /* Project the vector */ | ||
| 681 | float *projected = applyProjection(vec, vset->proj_matrix, | ||
| 682 | dim, reduce_dim); | ||
| 683 | RedisModule_Free(vec); | ||
| 684 | vec = projected; | ||
| 685 | } | ||
| 686 | RedisModule_ModuleTypeSetValue(key,VectorSetType,vset); | ||
| 687 | } else { | ||
| 688 | vset = RedisModule_ModuleTypeGetValue(key); | ||
| 689 | |||
| 690 | if (vset->hnsw->quant_type != quant_type) { | ||
| 691 | RedisModule_Free(vec); | ||
| 692 | return RedisModule_ReplyWithError(ctx, | ||
| 693 | "ERR asked quantization mismatch with existing vector set"); | ||
| 694 | } | ||
| 695 | |||
| 696 | if (vset->hnsw->M != hnsw_create_M) { | ||
| 697 | RedisModule_Free(vec); | ||
| 698 | return RedisModule_ReplyWithError(ctx, | ||
| 699 | "ERR asked M value mismatch with existing vector set"); | ||
| 700 | } | ||
| 701 | |||
| 702 | if ((vset->proj_matrix == NULL && vset->hnsw->vector_dim != dim) || | ||
| 703 | (vset->proj_matrix && vset->hnsw->vector_dim != reduce_dim)) | ||
| 704 | { | ||
| 705 | RedisModule_Free(vec); | ||
| 706 | return RedisModule_ReplyWithErrorFormat(ctx, | ||
| 707 | "ERR Vector dimension mismatch - got %d but set has %d", | ||
| 708 | (int)dim, (int)vset->hnsw->vector_dim); | ||
| 709 | } | ||
| 710 | |||
| 711 | /* Check REDUCE compatibility */ | ||
| 712 | if (reduce_dim) { | ||
| 713 | if (!vset->proj_matrix) { | ||
| 714 | RedisModule_Free(vec); | ||
| 715 | return RedisModule_ReplyWithError(ctx, | ||
| 716 | "ERR cannot add projection to existing set without projection"); | ||
| 717 | } | ||
| 718 | if (reduce_dim != vset->hnsw->vector_dim) { | ||
| 719 | RedisModule_Free(vec); | ||
| 720 | return RedisModule_ReplyWithError(ctx, | ||
| 721 | "ERR projection dimension mismatch with existing set"); | ||
| 722 | } | ||
| 723 | } | ||
| 724 | |||
| 725 | /* Apply projection if needed */ | ||
| 726 | if (vset->proj_matrix) { | ||
| 727 | /* Ensure input dimension matches the projection matrix's expected input dimension */ | ||
| 728 | if (dim != vset->proj_input_size) { | ||
| 729 | RedisModule_Free(vec); | ||
| 730 | return RedisModule_ReplyWithErrorFormat(ctx, | ||
| 731 | "ERR Input dimension mismatch for projection - got %d but projection expects %d", | ||
| 732 | (int)dim, (int)vset->proj_input_size); | ||
| 733 | } | ||
| 734 | |||
| 735 | float *projected = applyProjection(vec, vset->proj_matrix, | ||
| 736 | vset->proj_input_size, | ||
| 737 | vset->hnsw->vector_dim); | ||
| 738 | RedisModule_Free(vec); | ||
| 739 | vec = projected; | ||
| 740 | dim = vset->hnsw->vector_dim; | ||
| 741 | } | ||
| 742 | } | ||
| 743 | |||
| 744 | /* For existing keys don't do CAS updates. For how things work now, the | ||
| 745 | * CAS state would be invalidated by the deletion before adding back. */ | ||
| 746 | if (cas && RedisModule_DictGet(vset->dict,val,NULL) != NULL) | ||
| 747 | cas = 0; | ||
| 748 | |||
| 749 | /* Here depending on the CAS option we directly insert in a blocking | ||
| 750 | * way, or use a thread to do candidate neighbors selection and only | ||
| 751 | * later, in the reply callback, actually add the element. */ | ||
| 752 | if (cas) { | ||
| 753 | RedisModuleBlockedClient *bc = RedisModule_BlockClient(ctx,VADD_CASReply,NULL,NULL,0); | ||
| 754 | pthread_t tid; | ||
| 755 | void **targ = RedisModule_Alloc(sizeof(void*)*8); | ||
| 756 | targ[0] = bc; | ||
| 757 | targ[1] = vset; | ||
| 758 | targ[2] = (void*)(unsigned long)vset->id; | ||
| 759 | targ[3] = vec; | ||
| 760 | targ[4] = val; | ||
| 761 | targ[5] = NULL; // Used later for insertion context. | ||
| 762 | targ[6] = (void*)(unsigned long)ef; | ||
| 763 | targ[7] = attrib; | ||
| 764 | RedisModule_RetainString(ctx,val); | ||
| 765 | if (attrib) RedisModule_RetainString(ctx,attrib); | ||
| 766 | RedisModule_BlockedClientMeasureTimeStart(bc); | ||
| 767 | vset->thread_creation_pending++; | ||
| 768 | if (pthread_create(&tid,NULL,VADD_thread,targ) != 0) { | ||
| 769 | vset->thread_creation_pending--; | ||
| 770 | RedisModule_AbortBlock(bc); | ||
| 771 | RedisModule_Free(targ); | ||
| 772 | RedisModule_FreeString(ctx,val); | ||
| 773 | if (attrib) RedisModule_FreeString(ctx,attrib); | ||
| 774 | |||
| 775 | // Fall back to synchronous insert, see later in the code. | ||
| 776 | } else { | ||
| 777 | return REDISMODULE_OK; | ||
| 778 | } | ||
| 779 | } | ||
| 780 | |||
| 781 | /* Insert vector synchronously: we reach this place even | ||
| 782 | * if cas was true but thread creation failed. */ | ||
| 783 | int added = vectorSetInsert(vset,vec,NULL,0,val,attrib,1,ef); | ||
| 784 | RedisModule_Free(vec); | ||
| 785 | |||
| 786 | RedisModule_ReplyWithBool(ctx,added); | ||
| 787 | if (added) RedisModule_ReplicateVerbatim(ctx); | ||
| 788 | return REDISMODULE_OK; | ||
| 789 | } | ||
| 790 | |||
| 791 | /* HNSW callback to filter items according to a predicate function | ||
| 792 | * (our FILTER expression in this case). */ | ||
| 793 | int vectorSetFilterCallback(void *value, void *privdata) { | ||
| 794 | exprstate *expr = privdata; | ||
| 795 | struct vsetNodeVal *nv = value; | ||
| 796 | if (nv->attrib == NULL) return 0; // No attributes? No match. | ||
| 797 | size_t json_len; | ||
| 798 | char *json = (char*)RedisModule_StringPtrLen(nv->attrib,&json_len); | ||
| 799 | return exprRun(expr,json,json_len); | ||
| 800 | } | ||
| 801 | |||
| 802 | /* Common path for the execution of the VSIM command both threaded and | ||
| 803 | * not threaded. Note that 'ctx' may be normal context of a thread safe | ||
| 804 | * context obtained from a blocked client. The locking that is specific | ||
| 805 | * to the vset object is handled by the caller, however the function | ||
| 806 | * handles the HNSW locking explicitly. */ | ||
| 807 | void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset, | ||
| 808 | float *vec, unsigned long count, float epsilon, unsigned long withscores, | ||
| 809 | unsigned long withattribs, unsigned long ef, exprstate *filter_expr, | ||
| 810 | unsigned long filter_ef, int ground_truth) | ||
| 811 | { | ||
| 812 | /* In our scan, we can't just collect 'count' elements as | ||
| 813 | * if count is small we would explore the graph in an insufficient | ||
| 814 | * way to provide enough recall. | ||
| 815 | * | ||
| 816 | * If the user didn't asked for a specific exploration, we use | ||
| 817 | * VSET_DEFAULT_SEARCH_EF as minimum, or we match count if count | ||
| 818 | * is greater than that. Otherwise the minumim will be the specified | ||
| 819 | * EF argument. */ | ||
| 820 | if (ef == 0) ef = VSET_DEFAULT_SEARCH_EF; | ||
| 821 | if (count > ef) ef = count; | ||
| 822 | |||
| 823 | int slot = hnsw_acquire_read_slot(vset->hnsw); | ||
| 824 | if (ef > vset->hnsw->node_count) ef = vset->hnsw->node_count; | ||
| 825 | |||
| 826 | /* Perform search */ | ||
| 827 | hnswNode **neighbors = RedisModule_Alloc(sizeof(hnswNode*)*ef); | ||
| 828 | float *distances = RedisModule_Alloc(sizeof(float)*ef); | ||
| 829 | unsigned int found; | ||
| 830 | if (ground_truth) { | ||
| 831 | found = hnsw_ground_truth_with_filter(vset->hnsw, vec, ef, neighbors, | ||
| 832 | distances, slot, 0, | ||
| 833 | filter_expr ? vectorSetFilterCallback : NULL, | ||
| 834 | filter_expr); | ||
| 835 | } else { | ||
| 836 | if (filter_expr == NULL) { | ||
| 837 | found = hnsw_search(vset->hnsw, vec, ef, neighbors, | ||
| 838 | distances, slot, 0); | ||
| 839 | } else { | ||
| 840 | found = hnsw_search_with_filter(vset->hnsw, vec, ef, neighbors, | ||
| 841 | distances, slot, 0, vectorSetFilterCallback, | ||
| 842 | filter_expr, filter_ef); | ||
| 843 | } | ||
| 844 | } | ||
| 845 | |||
| 846 | /* Return results */ | ||
| 847 | int resp3 = RedisModule_GetContextFlags(ctx) & REDISMODULE_CTX_FLAGS_RESP3; | ||
| 848 | int reply_with_map = resp3 && (withscores || withattribs); | ||
| 849 | |||
| 850 | if (reply_with_map) | ||
| 851 | RedisModule_ReplyWithMap(ctx, REDISMODULE_POSTPONED_LEN); | ||
| 852 | else | ||
| 853 | RedisModule_ReplyWithArray(ctx, REDISMODULE_POSTPONED_LEN); | ||
| 854 | |||
| 855 | long long arraylen = 0; | ||
| 856 | for (unsigned int i = 0; i < found && i < count; i++) { | ||
| 857 | if (distances[i]/2 > epsilon) break; | ||
| 858 | struct vsetNodeVal *nv = neighbors[i]->value; | ||
| 859 | RedisModule_ReplyWithString(ctx, nv->item); | ||
| 860 | arraylen++; | ||
| 861 | |||
| 862 | /* If the user asked for multiple properties at the same time using | ||
| 863 | * the RESP3 protocol, we wrap the value of the map into an N-items | ||
| 864 | * array. Two for now, since we have just two properties that can be | ||
| 865 | * requested. | ||
| 866 | * | ||
| 867 | * So in the case of RESP2 we will just have the flat reply: | ||
| 868 | * item, score, attribute. For RESP3 instead item -> [score, attribute] | ||
| 869 | */ | ||
| 870 | if (resp3 && withscores && withattribs) | ||
| 871 | RedisModule_ReplyWithArray(ctx,2); | ||
| 872 | |||
| 873 | if (withscores) { | ||
| 874 | /* The similarity score is provided in a 0-1 range. */ | ||
| 875 | RedisModule_ReplyWithDouble(ctx, 1.0 - distances[i]/2.0); | ||
| 876 | } | ||
| 877 | if (withattribs) { | ||
| 878 | /* Return the attributes as well, if any. */ | ||
| 879 | if (nv->attrib) | ||
| 880 | RedisModule_ReplyWithString(ctx, nv->attrib); | ||
| 881 | else | ||
| 882 | RedisModule_ReplyWithNull(ctx); | ||
| 883 | } | ||
| 884 | } | ||
| 885 | hnsw_release_read_slot(vset->hnsw,slot); | ||
| 886 | |||
| 887 | if (reply_with_map) { | ||
| 888 | RedisModule_ReplySetMapLength(ctx, arraylen); | ||
| 889 | } else { | ||
| 890 | int items_per_ele = 1+withattribs+withscores; | ||
| 891 | RedisModule_ReplySetArrayLength(ctx, arraylen * items_per_ele); | ||
| 892 | } | ||
| 893 | |||
| 894 | RedisModule_Free(vec); | ||
| 895 | RedisModule_Free(neighbors); | ||
| 896 | RedisModule_Free(distances); | ||
| 897 | if (filter_expr) exprFree(filter_expr); | ||
| 898 | } | ||
| 899 | |||
| 900 | /* VSIM thread handling the blocked client request. */ | ||
| 901 | void *VSIM_thread(void *arg) { | ||
| 902 | pthread_detach(pthread_self()); | ||
| 903 | |||
| 904 | // Extract arguments. | ||
| 905 | void **targ = (void**)arg; | ||
| 906 | RedisModuleBlockedClient *bc = targ[0]; | ||
| 907 | struct vsetObject *vset = targ[1]; | ||
| 908 | float *vec = targ[2]; | ||
| 909 | unsigned long count = (unsigned long)targ[3]; | ||
| 910 | float epsilon = *((float*)targ[4]); | ||
| 911 | unsigned long withscores = (unsigned long)targ[5]; | ||
| 912 | unsigned long withattribs = (unsigned long)targ[6]; | ||
| 913 | unsigned long ef = (unsigned long)targ[7]; | ||
| 914 | exprstate *filter_expr = targ[8]; | ||
| 915 | unsigned long filter_ef = (unsigned long)targ[9]; | ||
| 916 | unsigned long ground_truth = (unsigned long)targ[10]; | ||
| 917 | RedisModule_Free(targ[4]); | ||
| 918 | RedisModule_Free(targ); | ||
| 919 | |||
| 920 | /* Lock the object and signal that we are no longer pending | ||
| 921 | * the lock acquisition. */ | ||
| 922 | RedisModule_Assert(pthread_rwlock_rdlock(&vset->in_use_lock) == 0); | ||
| 923 | vset->thread_creation_pending--; | ||
| 924 | |||
| 925 | // Accumulate reply in a thread safe context: no contention. | ||
| 926 | RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(bc); | ||
| 927 | |||
| 928 | // Run the query. | ||
| 929 | VSIM_execute(ctx, vset, vec, count, epsilon, withscores, withattribs, ef, filter_expr, filter_ef, ground_truth); | ||
| 930 | pthread_rwlock_unlock(&vset->in_use_lock); | ||
| 931 | |||
| 932 | // Cleanup. | ||
| 933 | RedisModule_FreeThreadSafeContext(ctx); | ||
| 934 | RedisModule_BlockedClientMeasureTimeEnd(bc); | ||
| 935 | RedisModule_UnblockClient(bc,NULL); | ||
| 936 | return NULL; | ||
| 937 | } | ||
| 938 | |||
| 939 | /* VSIM key [ELE|FP32|VALUES] <vector or ele> [WITHSCORES] [WITHATTRIBS] [COUNT num] [EPSILON eps] [EF exploration-factor] [FILTER expression] [FILTER-EF exploration-factor] */ | ||
| 940 | int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { | ||
| 941 | RedisModule_AutoMemory(ctx); | ||
| 942 | |||
| 943 | /* Basic argument check: need at least key and vector specification | ||
| 944 | * method. */ | ||
| 945 | if (argc < 4) return RedisModule_WrongArity(ctx); | ||
| 946 | |||
| 947 | /* Defaults */ | ||
| 948 | int withscores = 0; | ||
| 949 | int withattribs = 0; | ||
| 950 | long long count = VSET_DEFAULT_COUNT; /* New default value */ | ||
| 951 | long long ef = 0; /* Exploration factor (see HNSW paper) */ | ||
| 952 | double epsilon = 2.0; /* Max cosine distance */ | ||
| 953 | long long ground_truth = 0; /* Linear scan instead of HNSW search? */ | ||
| 954 | int no_thread = 0; /* NOTHREAD option: exec on main thread. */ | ||
| 955 | |||
| 956 | /* Things computed later. */ | ||
| 957 | long long filter_ef = 0; | ||
| 958 | exprstate *filter_expr = NULL; | ||
| 959 | |||
| 960 | /* Get key and vector type */ | ||
| 961 | RedisModuleString *key = argv[1]; | ||
| 962 | const char *vectorType = RedisModule_StringPtrLen(argv[2], NULL); | ||
| 963 | |||
| 964 | /* Get vector set */ | ||
| 965 | RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key, REDISMODULE_READ); | ||
| 966 | int type = RedisModule_KeyType(keyptr); | ||
| 967 | if (type == REDISMODULE_KEYTYPE_EMPTY) | ||
| 968 | return RedisModule_ReplyWithEmptyArray(ctx); | ||
| 969 | |||
| 970 | if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType) | ||
| 971 | return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); | ||
| 972 | |||
| 973 | struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr); | ||
| 974 | |||
| 975 | /* Vector parsing stage */ | ||
| 976 | float *vec = NULL; | ||
| 977 | size_t dim = 0; | ||
| 978 | int vector_args = 0; /* Number of args consumed by vector specification */ | ||
| 979 | |||
| 980 | if (!strcasecmp(vectorType, "ELE")) { | ||
| 981 | /* Get vector from existing element */ | ||
| 982 | RedisModuleString *ele = argv[3]; | ||
| 983 | hnswNode *node = RedisModule_DictGet(vset->dict, ele, NULL); | ||
| 984 | if (!node) { | ||
| 985 | return RedisModule_ReplyWithError(ctx, "ERR element not found in set"); | ||
| 986 | } | ||
| 987 | vec = RedisModule_Alloc(sizeof(float) * vset->hnsw->vector_dim); | ||
| 988 | hnsw_get_node_vector(vset->hnsw,node,vec); | ||
| 989 | dim = vset->hnsw->vector_dim; | ||
| 990 | vector_args = 2; /* ELE + element name */ | ||
| 991 | } else { | ||
| 992 | /* Parse vector. */ | ||
| 993 | int consumed_args; | ||
| 994 | |||
| 995 | vec = parseVector(argv, argc, 2, &dim, NULL, &consumed_args); | ||
| 996 | if (!vec) { | ||
| 997 | return RedisModule_ReplyWithError(ctx, | ||
| 998 | "ERR invalid vector specification"); | ||
| 999 | } | ||
| 1000 | vector_args = consumed_args; | ||
| 1001 | |||
| 1002 | /* Apply projection if the set uses it, with the exception | ||
| 1003 | * of ELE type, that will already have the right dimension. */ | ||
| 1004 | if (vset->proj_matrix && dim != vset->hnsw->vector_dim) { | ||
| 1005 | /* Ensure input dimension matches the projection matrix's expected input dimension */ | ||
| 1006 | if (dim != vset->proj_input_size) { | ||
| 1007 | RedisModule_Free(vec); | ||
| 1008 | return RedisModule_ReplyWithErrorFormat(ctx, | ||
| 1009 | "ERR Input dimension mismatch for projection - got %d but projection expects %d", | ||
| 1010 | (int)dim, (int)vset->proj_input_size); | ||
| 1011 | } | ||
| 1012 | |||
| 1013 | float *projected = applyProjection(vec, vset->proj_matrix, | ||
| 1014 | vset->proj_input_size, | ||
| 1015 | vset->hnsw->vector_dim); | ||
| 1016 | RedisModule_Free(vec); | ||
| 1017 | vec = projected; | ||
| 1018 | dim = vset->hnsw->vector_dim; | ||
| 1019 | } | ||
| 1020 | |||
| 1021 | /* Count consumed arguments */ | ||
| 1022 | if (!strcasecmp(vectorType, "FP32")) { | ||
| 1023 | vector_args = 2; /* FP32 + vector blob */ | ||
| 1024 | } else if (!strcasecmp(vectorType, "VALUES")) { | ||
| 1025 | long long vdim; | ||
| 1026 | if (RedisModule_StringToLongLong(argv[3], &vdim) != REDISMODULE_OK) { | ||
| 1027 | RedisModule_Free(vec); | ||
| 1028 | return RedisModule_ReplyWithError(ctx, "ERR invalid vector dimension"); | ||
| 1029 | } | ||
| 1030 | vector_args = 2 + vdim; /* VALUES + dim + values */ | ||
| 1031 | } else { | ||
| 1032 | RedisModule_Free(vec); | ||
| 1033 | return RedisModule_ReplyWithError(ctx, | ||
| 1034 | "ERR vector type must be ELE, FP32 or VALUES"); | ||
| 1035 | } | ||
| 1036 | } | ||
| 1037 | |||
| 1038 | /* Check vector dimension matches set */ | ||
| 1039 | if (dim != vset->hnsw->vector_dim) { | ||
| 1040 | RedisModule_Free(vec); | ||
| 1041 | return RedisModule_ReplyWithErrorFormat(ctx, | ||
| 1042 | "ERR Vector dimension mismatch - got %d but set has %d", | ||
| 1043 | (int)dim, (int)vset->hnsw->vector_dim); | ||
| 1044 | } | ||
| 1045 | |||
| 1046 | /* Parse optional arguments - start after vector specification */ | ||
| 1047 | int j = 2 + vector_args; | ||
| 1048 | while (j < argc) { | ||
| 1049 | const char *opt = RedisModule_StringPtrLen(argv[j], NULL); | ||
| 1050 | if (!strcasecmp(opt, "WITHSCORES")) { | ||
| 1051 | withscores = 1; | ||
| 1052 | j++; | ||
| 1053 | } else if (!strcasecmp(opt, "WITHATTRIBS")) { | ||
| 1054 | withattribs = 1; | ||
| 1055 | j++; | ||
| 1056 | } else if (!strcasecmp(opt, "TRUTH")) { | ||
| 1057 | ground_truth = 1; | ||
| 1058 | j++; | ||
| 1059 | } else if (!strcasecmp(opt, "NOTHREAD")) { | ||
| 1060 | no_thread = 1; | ||
| 1061 | j++; | ||
| 1062 | } else if (!strcasecmp(opt, "COUNT") && j+1 < argc) { | ||
| 1063 | if (RedisModule_StringToLongLong(argv[j+1], &count) | ||
| 1064 | != REDISMODULE_OK || count <= 0) | ||
| 1065 | { | ||
| 1066 | RedisModule_Free(vec); | ||
| 1067 | return RedisModule_ReplyWithError(ctx, "ERR invalid COUNT"); | ||
| 1068 | } | ||
| 1069 | j += 2; | ||
| 1070 | } else if (!strcasecmp(opt, "EPSILON") && j+1 < argc) { | ||
| 1071 | if (RedisModule_StringToDouble(argv[j+1], &epsilon) != | ||
| 1072 | REDISMODULE_OK || epsilon <= 0) | ||
| 1073 | { | ||
| 1074 | RedisModule_Free(vec); | ||
| 1075 | return RedisModule_ReplyWithError(ctx, "ERR invalid EPSILON"); | ||
| 1076 | } | ||
| 1077 | j += 2; | ||
| 1078 | } else if (!strcasecmp(opt, "EF") && j+1 < argc) { | ||
| 1079 | if (RedisModule_StringToLongLong(argv[j+1], &ef) != | ||
| 1080 | REDISMODULE_OK || ef <= 0 || ef > 1000000) | ||
| 1081 | { | ||
| 1082 | RedisModule_Free(vec); | ||
| 1083 | return RedisModule_ReplyWithError(ctx, "ERR invalid EF"); | ||
| 1084 | } | ||
| 1085 | j += 2; | ||
| 1086 | } else if (!strcasecmp(opt, "FILTER-EF") && j+1 < argc) { | ||
| 1087 | if (RedisModule_StringToLongLong(argv[j+1], &filter_ef) != | ||
| 1088 | REDISMODULE_OK || filter_ef <= 0) | ||
| 1089 | { | ||
| 1090 | RedisModule_Free(vec); | ||
| 1091 | return RedisModule_ReplyWithError(ctx, "ERR invalid FILTER-EF"); | ||
| 1092 | } | ||
| 1093 | j += 2; | ||
| 1094 | } else if (!strcasecmp(opt, "FILTER") && j+1 < argc) { | ||
| 1095 | RedisModuleString *exprarg = argv[j+1]; | ||
| 1096 | size_t exprlen; | ||
| 1097 | char *exprstr = (char*)RedisModule_StringPtrLen(exprarg,&exprlen); | ||
| 1098 | int errpos; | ||
| 1099 | filter_expr = exprCompile(exprstr,&errpos); | ||
| 1100 | if (filter_expr == NULL) { | ||
| 1101 | if ((size_t)errpos >= exprlen) errpos = 0; | ||
| 1102 | RedisModule_Free(vec); | ||
| 1103 | return RedisModule_ReplyWithErrorFormat(ctx, | ||
| 1104 | "ERR syntax error in FILTER expression near: %s", | ||
| 1105 | exprstr+errpos); | ||
| 1106 | } | ||
| 1107 | j += 2; | ||
| 1108 | } else { | ||
| 1109 | RedisModule_Free(vec); | ||
| 1110 | return RedisModule_ReplyWithError(ctx, | ||
| 1111 | "ERR syntax error in VSIM command"); | ||
| 1112 | } | ||
| 1113 | } | ||
| 1114 | |||
| 1115 | int threaded_request = 1; // Run on a thread, by default. | ||
| 1116 | if (filter_ef == 0) filter_ef = count * 100; // Max filter visited nodes. | ||
| 1117 | |||
| 1118 | /* Disable threaded for MULTI/EXEC and Lua, or if explicitly | ||
| 1119 | * requested by the user via the NOTHREAD option. */ | ||
| 1120 | if (no_thread || VSGlobalConfig.forceSingleThreadExec || | ||
| 1121 | (RedisModule_GetContextFlags(ctx) & | ||
| 1122 | (REDISMODULE_CTX_FLAGS_LUA | REDISMODULE_CTX_FLAGS_MULTI))) | ||
| 1123 | { | ||
| 1124 | threaded_request = 0; | ||
| 1125 | } | ||
| 1126 | |||
| 1127 | if (threaded_request) { | ||
| 1128 | /* Note: even if we create one thread per request, the underlying | ||
| 1129 | * HNSW library has a fixed number of slots for the threads, as it's | ||
| 1130 | * defined in HNSW_MAX_THREADS (beware that if you increase it, | ||
| 1131 | * every node will use more memory). This means that while this request | ||
| 1132 | * is threaded, and will NOT block Redis, it may end waiting for a | ||
| 1133 | * free slot if all the HNSW_MAX_THREADS slots are used. */ | ||
| 1134 | RedisModuleBlockedClient *bc = RedisModule_BlockClient(ctx,NULL,NULL,NULL,0); | ||
| 1135 | pthread_t tid; | ||
| 1136 | void **targ = RedisModule_Alloc(sizeof(void*)*11); | ||
| 1137 | targ[0] = bc; | ||
| 1138 | targ[1] = vset; | ||
| 1139 | targ[2] = vec; | ||
| 1140 | targ[3] = (void*)count; | ||
| 1141 | targ[4] = RedisModule_Alloc(sizeof(float)); | ||
| 1142 | *((float*)targ[4]) = epsilon; | ||
| 1143 | targ[5] = (void*)(unsigned long)withscores; | ||
| 1144 | targ[6] = (void*)(unsigned long)withattribs; | ||
| 1145 | targ[7] = (void*)(unsigned long)ef; | ||
| 1146 | targ[8] = (void*)filter_expr; | ||
| 1147 | targ[9] = (void*)(unsigned long)filter_ef; | ||
| 1148 | targ[10] = (void*)(unsigned long)ground_truth; | ||
| 1149 | RedisModule_BlockedClientMeasureTimeStart(bc); | ||
| 1150 | vset->thread_creation_pending++; | ||
| 1151 | if (pthread_create(&tid,NULL,VSIM_thread,targ) != 0) { | ||
| 1152 | vset->thread_creation_pending--; | ||
| 1153 | RedisModule_AbortBlock(bc); | ||
| 1154 | RedisModule_Free(targ[4]); | ||
| 1155 | RedisModule_Free(targ); | ||
| 1156 | VSIM_execute(ctx, vset, vec, count, epsilon, withscores, withattribs, ef, filter_expr, filter_ef, ground_truth); | ||
| 1157 | } | ||
| 1158 | } else { | ||
| 1159 | VSIM_execute(ctx, vset, vec, count, epsilon, withscores, withattribs, ef, filter_expr, filter_ef, ground_truth); | ||
| 1160 | } | ||
| 1161 | |||
| 1162 | return REDISMODULE_OK; | ||
| 1163 | } | ||
| 1164 | |||
| 1165 | /* VDIM <key>: return the dimension of vectors in the vector set. */ | ||
| 1166 | int VDIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { | ||
| 1167 | RedisModule_AutoMemory(ctx); | ||
| 1168 | |||
| 1169 | if (argc != 2) return RedisModule_WrongArity(ctx); | ||
| 1170 | |||
| 1171 | RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ); | ||
| 1172 | int type = RedisModule_KeyType(key); | ||
| 1173 | |||
| 1174 | if (type == REDISMODULE_KEYTYPE_EMPTY) | ||
| 1175 | return RedisModule_ReplyWithError(ctx, "ERR key does not exist"); | ||
| 1176 | |||
| 1177 | if (RedisModule_ModuleTypeGetType(key) != VectorSetType) | ||
| 1178 | return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); | ||
| 1179 | |||
| 1180 | struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key); | ||
| 1181 | return RedisModule_ReplyWithLongLong(ctx, vset->hnsw->vector_dim); | ||
| 1182 | } | ||
| 1183 | |||
| 1184 | /* VCARD <key>: return cardinality (num of elements) of the vector set. */ | ||
| 1185 | int VCARD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { | ||
| 1186 | RedisModule_AutoMemory(ctx); | ||
| 1187 | |||
| 1188 | if (argc != 2) return RedisModule_WrongArity(ctx); | ||
| 1189 | |||
| 1190 | RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ); | ||
| 1191 | int type = RedisModule_KeyType(key); | ||
| 1192 | |||
| 1193 | if (type == REDISMODULE_KEYTYPE_EMPTY) | ||
| 1194 | return RedisModule_ReplyWithLongLong(ctx, 0); | ||
| 1195 | |||
| 1196 | if (RedisModule_ModuleTypeGetType(key) != VectorSetType) | ||
| 1197 | return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); | ||
| 1198 | |||
| 1199 | struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key); | ||
| 1200 | return RedisModule_ReplyWithLongLong(ctx, vset->hnsw->node_count); | ||
| 1201 | } | ||
| 1202 | |||
| 1203 | /* VREM key element | ||
| 1204 | * Remove an element from a vector set. | ||
| 1205 | * Returns 1 if the element was found and removed, 0 if not found. */ | ||
| 1206 | int VREM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { | ||
| 1207 | RedisModule_AutoMemory(ctx); /* Use automatic memory management. */ | ||
| 1208 | |||
| 1209 | if (argc != 3) return RedisModule_WrongArity(ctx); | ||
| 1210 | |||
| 1211 | /* Get key and value */ | ||
| 1212 | RedisModuleString *key = argv[1]; | ||
| 1213 | RedisModuleString *element = argv[2]; | ||
| 1214 | |||
| 1215 | /* Open key */ | ||
| 1216 | RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key, | ||
| 1217 | REDISMODULE_READ|REDISMODULE_WRITE); | ||
| 1218 | int type = RedisModule_KeyType(keyptr); | ||
| 1219 | |||
| 1220 | /* Handle non-existing key or wrong type */ | ||
| 1221 | if (type == REDISMODULE_KEYTYPE_EMPTY) { | ||
| 1222 | return RedisModule_ReplyWithBool(ctx, 0); | ||
| 1223 | } | ||
| 1224 | if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType) { | ||
| 1225 | return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); | ||
| 1226 | } | ||
| 1227 | |||
| 1228 | /* Get vector set from key */ | ||
| 1229 | struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr); | ||
| 1230 | |||
| 1231 | /* Find the node for this element */ | ||
| 1232 | hnswNode *node = RedisModule_DictGet(vset->dict, element, NULL); | ||
| 1233 | if (!node) { | ||
| 1234 | return RedisModule_ReplyWithBool(ctx, 0); | ||
| 1235 | } | ||
| 1236 | |||
| 1237 | /* Remove from dictionary */ | ||
| 1238 | RedisModule_DictDel(vset->dict, element, NULL); | ||
| 1239 | |||
| 1240 | /* Remove from HNSW graph using the high-level API that handles | ||
| 1241 | * locking and cleanup. We pass RedisModule_FreeString as the value | ||
| 1242 | * free function since the strings were retained at insertion time. */ | ||
| 1243 | struct vsetNodeVal *nv = node->value; | ||
| 1244 | if (nv->attrib != NULL) vset->numattribs--; | ||
| 1245 | RedisModule_Assert(hnsw_delete_node(vset->hnsw, node, vectorSetReleaseNodeValue) == 1); | ||
| 1246 | |||
| 1247 | /* Destroy empty vector set. */ | ||
| 1248 | if (RedisModule_DictSize(vset->dict) == 0) { | ||
| 1249 | RedisModule_DeleteKey(keyptr); | ||
| 1250 | } | ||
| 1251 | |||
| 1252 | /* Reply and propagate the command */ | ||
| 1253 | RedisModule_ReplyWithBool(ctx, 1); | ||
| 1254 | RedisModule_ReplicateVerbatim(ctx); | ||
| 1255 | return REDISMODULE_OK; | ||
| 1256 | } | ||
| 1257 | |||
| 1258 | /* VEMB key element | ||
| 1259 | * Returns the embedding vector associated with an element, or NIL if not | ||
| 1260 | * found. The vector is returned in the same format it was added, but the | ||
| 1261 | * return value will have some lack of precision due to quantization and | ||
| 1262 | * normalization of vectors. Also, if items were added using REDUCE, the | ||
| 1263 | * reduced vector is returned instead. */ | ||
| 1264 | int VEMB_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { | ||
| 1265 | RedisModule_AutoMemory(ctx); | ||
| 1266 | int raw_output = 0; // RAW option. | ||
| 1267 | |||
| 1268 | if (argc < 3) return RedisModule_WrongArity(ctx); | ||
| 1269 | |||
| 1270 | /* Parse arguments. */ | ||
| 1271 | for (int j = 3; j < argc; j++) { | ||
| 1272 | const char *opt = RedisModule_StringPtrLen(argv[j], NULL); | ||
| 1273 | if (!strcasecmp(opt,"raw")) { | ||
| 1274 | raw_output = 1; | ||
| 1275 | } else { | ||
| 1276 | return RedisModule_ReplyWithError(ctx,"ERR invalid option"); | ||
| 1277 | } | ||
| 1278 | } | ||
| 1279 | |||
| 1280 | /* Get key and element. */ | ||
| 1281 | RedisModuleString *key = argv[1]; | ||
| 1282 | RedisModuleString *element = argv[2]; | ||
| 1283 | |||
| 1284 | /* Open key. */ | ||
| 1285 | RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key, REDISMODULE_READ); | ||
| 1286 | int type = RedisModule_KeyType(keyptr); | ||
| 1287 | |||
| 1288 | /* Handle non-existing key and key of wrong type. */ | ||
| 1289 | if (type == REDISMODULE_KEYTYPE_EMPTY) { | ||
| 1290 | return RedisModule_ReplyWithNull(ctx); | ||
| 1291 | } else if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType) { | ||
| 1292 | return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); | ||
| 1293 | } | ||
| 1294 | |||
| 1295 | /* Lookup the node about the specified element. */ | ||
| 1296 | struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr); | ||
| 1297 | hnswNode *node = RedisModule_DictGet(vset->dict, element, NULL); | ||
| 1298 | if (!node) { | ||
| 1299 | return RedisModule_ReplyWithNull(ctx); | ||
| 1300 | } | ||
| 1301 | |||
| 1302 | if (raw_output) { | ||
| 1303 | int output_qrange = vset->hnsw->quant_type == HNSW_QUANT_Q8; | ||
| 1304 | RedisModule_ReplyWithArray(ctx, 3+output_qrange); | ||
| 1305 | RedisModule_ReplyWithSimpleString(ctx, vectorSetGetQuantName(vset)); | ||
| 1306 | RedisModule_ReplyWithStringBuffer(ctx, node->vector, hnsw_quants_bytes(vset->hnsw)); | ||
| 1307 | RedisModule_ReplyWithDouble(ctx, node->l2); | ||
| 1308 | if (output_qrange) RedisModule_ReplyWithDouble(ctx, node->quants_range); | ||
| 1309 | } else { | ||
| 1310 | /* Get the vector associated with the node. */ | ||
| 1311 | float *vec = RedisModule_Alloc(sizeof(float) * vset->hnsw->vector_dim); | ||
| 1312 | hnsw_get_node_vector(vset->hnsw, node, vec); // May dequantize/denorm. | ||
| 1313 | |||
| 1314 | /* Return as array of doubles. */ | ||
| 1315 | RedisModule_ReplyWithArray(ctx, vset->hnsw->vector_dim); | ||
| 1316 | for (uint32_t i = 0; i < vset->hnsw->vector_dim; i++) | ||
| 1317 | RedisModule_ReplyWithDouble(ctx, vec[i]); | ||
| 1318 | RedisModule_Free(vec); | ||
| 1319 | } | ||
| 1320 | return REDISMODULE_OK; | ||
| 1321 | } | ||
| 1322 | |||
| 1323 | /* VSETATTR key element json | ||
| 1324 | * Set or remove the JSON attribute associated with an element. | ||
| 1325 | * Setting an empty string removes the attribute. | ||
| 1326 | * The command returns one if the attribute was actually updated or | ||
| 1327 | * zero if there is no key or element. */ | ||
| 1328 | int VSETATTR_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { | ||
| 1329 | RedisModule_AutoMemory(ctx); | ||
| 1330 | |||
| 1331 | if (argc != 4) return RedisModule_WrongArity(ctx); | ||
| 1332 | |||
| 1333 | RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], | ||
| 1334 | REDISMODULE_READ|REDISMODULE_WRITE); | ||
| 1335 | int type = RedisModule_KeyType(key); | ||
| 1336 | |||
| 1337 | if (type == REDISMODULE_KEYTYPE_EMPTY) | ||
| 1338 | return RedisModule_ReplyWithBool(ctx, 0); | ||
| 1339 | |||
| 1340 | if (RedisModule_ModuleTypeGetType(key) != VectorSetType) | ||
| 1341 | return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); | ||
| 1342 | |||
| 1343 | struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key); | ||
| 1344 | hnswNode *node = RedisModule_DictGet(vset->dict, argv[2], NULL); | ||
| 1345 | if (!node) | ||
| 1346 | return RedisModule_ReplyWithBool(ctx, 0); | ||
| 1347 | |||
| 1348 | struct vsetNodeVal *nv = node->value; | ||
| 1349 | RedisModuleString *new_attr = argv[3]; | ||
| 1350 | |||
| 1351 | /* Background VSIM operations use the node attributes, so | ||
| 1352 | * wait for background operations before messing with them. */ | ||
| 1353 | vectorSetWaitAllBackgroundClients(vset,0); | ||
| 1354 | |||
| 1355 | /* Set or delete the attribute based on the fact it's an empty | ||
| 1356 | * string or not. */ | ||
| 1357 | size_t attrlen; | ||
| 1358 | RedisModule_StringPtrLen(new_attr, &attrlen); | ||
| 1359 | if (attrlen == 0) { | ||
| 1360 | // If we had an attribute before, decrease the count and free it. | ||
| 1361 | if (nv->attrib) { | ||
| 1362 | vset->numattribs--; | ||
| 1363 | RedisModule_FreeString(NULL, nv->attrib); | ||
| 1364 | nv->attrib = NULL; | ||
| 1365 | } | ||
| 1366 | } else { | ||
| 1367 | // If we didn't have an attribute before, increase the count. | ||
| 1368 | // Otherwise free the old one. | ||
| 1369 | if (nv->attrib) { | ||
| 1370 | RedisModule_FreeString(NULL, nv->attrib); | ||
| 1371 | } else { | ||
| 1372 | vset->numattribs++; | ||
| 1373 | } | ||
| 1374 | // Set new attribute. | ||
| 1375 | RedisModule_RetainString(NULL, new_attr); | ||
| 1376 | nv->attrib = new_attr; | ||
| 1377 | } | ||
| 1378 | |||
| 1379 | RedisModule_ReplyWithBool(ctx, 1); | ||
| 1380 | RedisModule_ReplicateVerbatim(ctx); | ||
| 1381 | return REDISMODULE_OK; | ||
| 1382 | } | ||
| 1383 | |||
| 1384 | /* VGETATTR key element | ||
| 1385 | * Get the JSON attribute associated with an element. | ||
| 1386 | * Returns NIL if the element has no attribute or doesn't exist. */ | ||
| 1387 | int VGETATTR_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { | ||
| 1388 | RedisModule_AutoMemory(ctx); | ||
| 1389 | |||
| 1390 | if (argc != 3) return RedisModule_WrongArity(ctx); | ||
| 1391 | |||
| 1392 | RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ); | ||
| 1393 | int type = RedisModule_KeyType(key); | ||
| 1394 | |||
| 1395 | if (type == REDISMODULE_KEYTYPE_EMPTY) | ||
| 1396 | return RedisModule_ReplyWithNull(ctx); | ||
| 1397 | |||
| 1398 | if (RedisModule_ModuleTypeGetType(key) != VectorSetType) | ||
| 1399 | return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); | ||
| 1400 | |||
| 1401 | struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key); | ||
| 1402 | hnswNode *node = RedisModule_DictGet(vset->dict, argv[2], NULL); | ||
| 1403 | if (!node) | ||
| 1404 | return RedisModule_ReplyWithNull(ctx); | ||
| 1405 | |||
| 1406 | struct vsetNodeVal *nv = node->value; | ||
| 1407 | if (!nv->attrib) | ||
| 1408 | return RedisModule_ReplyWithNull(ctx); | ||
| 1409 | |||
| 1410 | return RedisModule_ReplyWithString(ctx, nv->attrib); | ||
| 1411 | } | ||
| 1412 | |||
| 1413 | /* ============================== Reflection ================================ */ | ||
| 1414 | |||
| 1415 | /* VLINKS key element [WITHSCORES] | ||
| 1416 | * Returns the neighbors of an element at each layer in the HNSW graph. | ||
| 1417 | * Reply is an array of arrays, where each nested array represents one level | ||
| 1418 | * of neighbors, from highest level to level 0. If WITHSCORES is specified, | ||
| 1419 | * each neighbor is followed by its distance from the element. */ | ||
| 1420 | int VLINKS_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { | ||
| 1421 | RedisModule_AutoMemory(ctx); | ||
| 1422 | |||
| 1423 | if (argc < 3 || argc > 4) return RedisModule_WrongArity(ctx); | ||
| 1424 | |||
| 1425 | RedisModuleString *key = argv[1]; | ||
| 1426 | RedisModuleString *element = argv[2]; | ||
| 1427 | |||
| 1428 | /* Parse WITHSCORES option. */ | ||
| 1429 | int withscores = 0; | ||
| 1430 | if (argc == 4) { | ||
| 1431 | const char *opt = RedisModule_StringPtrLen(argv[3], NULL); | ||
| 1432 | if (strcasecmp(opt, "WITHSCORES") != 0) { | ||
| 1433 | return RedisModule_WrongArity(ctx); | ||
| 1434 | } | ||
| 1435 | withscores = 1; | ||
| 1436 | } | ||
| 1437 | |||
| 1438 | RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key, REDISMODULE_READ); | ||
| 1439 | int type = RedisModule_KeyType(keyptr); | ||
| 1440 | |||
| 1441 | /* Handle non-existing key or wrong type. */ | ||
| 1442 | if (type == REDISMODULE_KEYTYPE_EMPTY) | ||
| 1443 | return RedisModule_ReplyWithNull(ctx); | ||
| 1444 | |||
| 1445 | if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType) | ||
| 1446 | return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); | ||
| 1447 | |||
| 1448 | /* Find the node for this element. */ | ||
| 1449 | struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr); | ||
| 1450 | hnswNode *node = RedisModule_DictGet(vset->dict, element, NULL); | ||
| 1451 | if (!node) | ||
| 1452 | return RedisModule_ReplyWithNull(ctx); | ||
| 1453 | |||
| 1454 | /* Reply with array of arrays, one per level. */ | ||
| 1455 | RedisModule_ReplyWithArray(ctx, node->level + 1); | ||
| 1456 | |||
| 1457 | /* For each level, from highest to lowest: */ | ||
| 1458 | for (int i = node->level; i >= 0; i--) { | ||
| 1459 | /* Reply with array of neighbors at this level. */ | ||
| 1460 | if (withscores) | ||
| 1461 | RedisModule_ReplyWithMap(ctx,node->layers[i].num_links); | ||
| 1462 | else | ||
| 1463 | RedisModule_ReplyWithArray(ctx,node->layers[i].num_links); | ||
| 1464 | |||
| 1465 | /* Add each neighbor's element value to the array. */ | ||
| 1466 | for (uint32_t j = 0; j < node->layers[i].num_links; j++) { | ||
| 1467 | struct vsetNodeVal *nv = node->layers[i].links[j]->value; | ||
| 1468 | RedisModule_ReplyWithString(ctx, nv->item); | ||
| 1469 | if (withscores) { | ||
| 1470 | float distance = hnsw_distance(vset->hnsw, node, node->layers[i].links[j]); | ||
| 1471 | /* Convert distance to similarity score to match | ||
| 1472 | * VSIM behavior.*/ | ||
| 1473 | float similarity = 1.0 - distance/2.0; | ||
| 1474 | RedisModule_ReplyWithDouble(ctx, similarity); | ||
| 1475 | } | ||
| 1476 | } | ||
| 1477 | } | ||
| 1478 | return REDISMODULE_OK; | ||
| 1479 | } | ||
| 1480 | |||
| 1481 | /* VINFO key | ||
| 1482 | * Returns information about a vector set, both visible and hidden | ||
| 1483 | * features of the HNSW data structure. */ | ||
| 1484 | int VINFO_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { | ||
| 1485 | RedisModule_AutoMemory(ctx); | ||
| 1486 | |||
| 1487 | if (argc != 2) return RedisModule_WrongArity(ctx); | ||
| 1488 | |||
| 1489 | RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ); | ||
| 1490 | int type = RedisModule_KeyType(key); | ||
| 1491 | |||
| 1492 | if (type == REDISMODULE_KEYTYPE_EMPTY) | ||
| 1493 | return RedisModule_ReplyWithNullArray(ctx); | ||
| 1494 | |||
| 1495 | if (RedisModule_ModuleTypeGetType(key) != VectorSetType) | ||
| 1496 | return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); | ||
| 1497 | |||
| 1498 | struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key); | ||
| 1499 | |||
| 1500 | /* Reply with hash */ | ||
| 1501 | RedisModule_ReplyWithMap(ctx, 9); | ||
| 1502 | |||
| 1503 | /* Quantization type */ | ||
| 1504 | RedisModule_ReplyWithSimpleString(ctx, "quant-type"); | ||
| 1505 | RedisModule_ReplyWithSimpleString(ctx, vectorSetGetQuantName(vset)); | ||
| 1506 | |||
| 1507 | /* HNSW M value */ | ||
| 1508 | RedisModule_ReplyWithSimpleString(ctx, "hnsw-m"); | ||
| 1509 | RedisModule_ReplyWithLongLong(ctx, vset->hnsw->M); | ||
| 1510 | |||
| 1511 | /* Vector dimensionality. */ | ||
| 1512 | RedisModule_ReplyWithSimpleString(ctx, "vector-dim"); | ||
| 1513 | RedisModule_ReplyWithLongLong(ctx, vset->hnsw->vector_dim); | ||
| 1514 | |||
| 1515 | /* Original input dimension before projection. | ||
| 1516 | * This is zero for vector sets without a random projection matrix. */ | ||
| 1517 | RedisModule_ReplyWithSimpleString(ctx, "projection-input-dim"); | ||
| 1518 | RedisModule_ReplyWithLongLong(ctx, vset->proj_input_size); | ||
| 1519 | |||
| 1520 | /* Number of elements. */ | ||
| 1521 | RedisModule_ReplyWithSimpleString(ctx, "size"); | ||
| 1522 | RedisModule_ReplyWithLongLong(ctx, vset->hnsw->node_count); | ||
| 1523 | |||
| 1524 | /* Max level of HNSW. */ | ||
| 1525 | RedisModule_ReplyWithSimpleString(ctx, "max-level"); | ||
| 1526 | RedisModule_ReplyWithLongLong(ctx, vset->hnsw->max_level); | ||
| 1527 | |||
| 1528 | /* Number of nodes with attributes. */ | ||
| 1529 | RedisModule_ReplyWithSimpleString(ctx, "attributes-count"); | ||
| 1530 | RedisModule_ReplyWithLongLong(ctx, vset->numattribs); | ||
| 1531 | |||
| 1532 | /* Vector set ID. */ | ||
| 1533 | RedisModule_ReplyWithSimpleString(ctx, "vset-uid"); | ||
| 1534 | RedisModule_ReplyWithLongLong(ctx, vset->id); | ||
| 1535 | |||
| 1536 | /* HNSW max node ID. */ | ||
| 1537 | RedisModule_ReplyWithSimpleString(ctx, "hnsw-max-node-uid"); | ||
| 1538 | RedisModule_ReplyWithLongLong(ctx, vset->hnsw->last_id); | ||
| 1539 | |||
| 1540 | return REDISMODULE_OK; | ||
| 1541 | } | ||
| 1542 | |||
| 1543 | /* VRANDMEMBER key [count] | ||
| 1544 | * Return random members from a vector set. | ||
| 1545 | * | ||
| 1546 | * Without count: returns a single random member. | ||
| 1547 | * With positive count: N unique random members (no duplicates). | ||
| 1548 | * With negative count: N random members (with possible duplicates). | ||
| 1549 | * | ||
| 1550 | * If the key doesn't exist, returns NULL if count is not given, or | ||
| 1551 | * an empty array if a count was given. */ | ||
| 1552 | int VRANDMEMBER_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { | ||
| 1553 | RedisModule_AutoMemory(ctx); /* Use automatic memory management. */ | ||
| 1554 | |||
| 1555 | /* Check arguments. */ | ||
| 1556 | if (argc != 2 && argc != 3) return RedisModule_WrongArity(ctx); | ||
| 1557 | |||
| 1558 | /* Parse optional count argument. */ | ||
| 1559 | long long count = 1; /* Default is to return a single element. */ | ||
| 1560 | int with_count = (argc == 3); | ||
| 1561 | |||
| 1562 | if (with_count) { | ||
| 1563 | if (RedisModule_StringToLongLong(argv[2], &count) != REDISMODULE_OK) { | ||
| 1564 | return RedisModule_ReplyWithError(ctx, | ||
| 1565 | "ERR COUNT value is not an integer"); | ||
| 1566 | } | ||
| 1567 | /* Count = 0 is a special case, return empty array */ | ||
| 1568 | if (count == 0) { | ||
| 1569 | return RedisModule_ReplyWithEmptyArray(ctx); | ||
| 1570 | } | ||
| 1571 | } | ||
| 1572 | |||
| 1573 | /* Open key. */ | ||
| 1574 | RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ); | ||
| 1575 | int type = RedisModule_KeyType(key); | ||
| 1576 | |||
| 1577 | /* Handle non-existing key. */ | ||
| 1578 | if (type == REDISMODULE_KEYTYPE_EMPTY) { | ||
| 1579 | if (!with_count) { | ||
| 1580 | return RedisModule_ReplyWithNull(ctx); | ||
| 1581 | } else { | ||
| 1582 | return RedisModule_ReplyWithEmptyArray(ctx); | ||
| 1583 | } | ||
| 1584 | } | ||
| 1585 | |||
| 1586 | /* Check key type. */ | ||
| 1587 | if (RedisModule_ModuleTypeGetType(key) != VectorSetType) { | ||
| 1588 | return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); | ||
| 1589 | } | ||
| 1590 | |||
| 1591 | /* Get vector set from key. */ | ||
| 1592 | struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key); | ||
| 1593 | uint64_t set_size = vset->hnsw->node_count; | ||
| 1594 | |||
| 1595 | /* No elements in the set? */ | ||
| 1596 | if (set_size == 0) { | ||
| 1597 | if (!with_count) { | ||
| 1598 | return RedisModule_ReplyWithNull(ctx); | ||
| 1599 | } else { | ||
| 1600 | return RedisModule_ReplyWithEmptyArray(ctx); | ||
| 1601 | } | ||
| 1602 | } | ||
| 1603 | |||
| 1604 | /* Case 1: No count specified: return a single element. */ | ||
| 1605 | if (!with_count) { | ||
| 1606 | hnswNode *random_node = hnsw_random_node(vset->hnsw, 0); | ||
| 1607 | if (random_node) { | ||
| 1608 | struct vsetNodeVal *nv = random_node->value; | ||
| 1609 | return RedisModule_ReplyWithString(ctx, nv->item); | ||
| 1610 | } else { | ||
| 1611 | return RedisModule_ReplyWithNull(ctx); | ||
| 1612 | } | ||
| 1613 | } | ||
| 1614 | |||
| 1615 | /* Case 2: COUNT option given, return an array of elements. */ | ||
| 1616 | int allow_duplicates = (count < 0); | ||
| 1617 | long long abs_count = (count < 0) ? -count : count; | ||
| 1618 | |||
| 1619 | /* Cap the count to the set size if we are not allowing duplicates. */ | ||
| 1620 | if (!allow_duplicates && abs_count > (long long)set_size) | ||
| 1621 | abs_count = set_size; | ||
| 1622 | |||
| 1623 | /* Prepare reply. */ | ||
| 1624 | RedisModule_ReplyWithArray(ctx, abs_count); | ||
| 1625 | |||
| 1626 | if (allow_duplicates) { | ||
| 1627 | /* Simple case: With duplicates, just pick random nodes | ||
| 1628 | * abs_count times. */ | ||
| 1629 | for (long long i = 0; i < abs_count; i++) { | ||
| 1630 | hnswNode *random_node = hnsw_random_node(vset->hnsw,0); | ||
| 1631 | struct vsetNodeVal *nv = random_node->value; | ||
| 1632 | RedisModule_ReplyWithString(ctx, nv->item); | ||
| 1633 | } | ||
| 1634 | } else { | ||
| 1635 | /* Case where count is positive: we need unique elements. | ||
| 1636 | * But, if the user asked for many elements, selecting so | ||
| 1637 | * many (> 20%) random nodes may be too expansive: we just start | ||
| 1638 | * from a random element and follow the next link. | ||
| 1639 | * | ||
| 1640 | * Otherwisem for the <= 20% case, a dictionary is used to | ||
| 1641 | * reject duplicates. */ | ||
| 1642 | int use_dict = (abs_count <= set_size * 0.2); | ||
| 1643 | |||
| 1644 | if (use_dict) { | ||
| 1645 | RedisModuleDict *returned = RedisModule_CreateDict(ctx); | ||
| 1646 | |||
| 1647 | long long returned_count = 0; | ||
| 1648 | while (returned_count < abs_count) { | ||
| 1649 | hnswNode *random_node = hnsw_random_node(vset->hnsw, 0); | ||
| 1650 | struct vsetNodeVal *nv = random_node->value; | ||
| 1651 | |||
| 1652 | /* Check if we've already returned this element. */ | ||
| 1653 | if (RedisModule_DictGet(returned, nv->item, NULL) == NULL) { | ||
| 1654 | /* Mark as returned and add to results. */ | ||
| 1655 | RedisModule_DictSet(returned, nv->item, (void*)1); | ||
| 1656 | RedisModule_ReplyWithString(ctx, nv->item); | ||
| 1657 | returned_count++; | ||
| 1658 | } | ||
| 1659 | } | ||
| 1660 | RedisModule_FreeDict(ctx, returned); | ||
| 1661 | } else { | ||
| 1662 | /* For large samples, get a random starting node and walk | ||
| 1663 | * the list. | ||
| 1664 | * | ||
| 1665 | * IMPORTANT: doing so does not really generate random | ||
| 1666 | * elements: it's just a linear scan, but we have no choices. | ||
| 1667 | * If we generate too many random elements, more and more would | ||
| 1668 | * fail the check of being novel (not yet collected in the set | ||
| 1669 | * to return) if the % of elements to emit is too large, we would | ||
| 1670 | * spend too much CPU. */ | ||
| 1671 | hnswNode *start_node = hnsw_random_node(vset->hnsw, 0); | ||
| 1672 | hnswNode *current = start_node; | ||
| 1673 | |||
| 1674 | long long returned_count = 0; | ||
| 1675 | while (returned_count < abs_count) { | ||
| 1676 | if (current == NULL) { | ||
| 1677 | /* Restart from head if we hit the end. */ | ||
| 1678 | current = vset->hnsw->head; | ||
| 1679 | } | ||
| 1680 | struct vsetNodeVal *nv = current->value; | ||
| 1681 | RedisModule_ReplyWithString(ctx, nv->item); | ||
| 1682 | returned_count++; | ||
| 1683 | current = current->next; | ||
| 1684 | } | ||
| 1685 | } | ||
| 1686 | } | ||
| 1687 | return REDISMODULE_OK; | ||
| 1688 | } | ||
| 1689 | |||
| 1690 | /* VISMEMBER key element | ||
| 1691 | * Check if an element exists in a vector set. | ||
| 1692 | * Returns 1 if the element exists, 0 if not. */ | ||
| 1693 | int VISMEMBER_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { | ||
| 1694 | RedisModule_AutoMemory(ctx); | ||
| 1695 | if (argc != 3) return RedisModule_WrongArity(ctx); | ||
| 1696 | |||
| 1697 | RedisModuleString *key = argv[1]; | ||
| 1698 | RedisModuleString *element = argv[2]; | ||
| 1699 | |||
| 1700 | /* Open key. */ | ||
| 1701 | RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key, REDISMODULE_READ); | ||
| 1702 | int type = RedisModule_KeyType(keyptr); | ||
| 1703 | |||
| 1704 | /* Handle non-existing key or wrong type. */ | ||
| 1705 | if (type == REDISMODULE_KEYTYPE_EMPTY) { | ||
| 1706 | /* An element of a non existing key does not exist, like | ||
| 1707 | * SISMEMBER & similar. */ | ||
| 1708 | return RedisModule_ReplyWithBool(ctx, 0); | ||
| 1709 | } | ||
| 1710 | if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType) { | ||
| 1711 | return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); | ||
| 1712 | } | ||
| 1713 | |||
| 1714 | /* Get the object and test membership via the dictionary in constant | ||
| 1715 | * time (assuming a member of average size). */ | ||
| 1716 | struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr); | ||
| 1717 | hnswNode *node = RedisModule_DictGet(vset->dict, element, NULL); | ||
| 1718 | return RedisModule_ReplyWithBool(ctx, node != NULL); | ||
| 1719 | } | ||
| 1720 | |||
/* One boundary of a lexicographical range, as parsed from a command
 * argument by vsetParseRangeOp(). */
struct vsetRangeOp {
    /* 1 for an inclusive boundary ("[" prefix), 0 for an exclusive
     * one ("(" prefix). */
    int incl;
    /* 1 if the boundary is the special "-" (minimum) element. */
    int min;
    /* 1 if the boundary is the special "+" (maximum) element. */
    int max;
    /* Bytes of the boundary element, NULL for the min/max cases. */
    unsigned char *ele;
    /* Number of bytes in 'ele'. */
    size_t ele_len;
};
| 1729 | |||
| 1730 | /* Parse a range specification like "[foo" or "(bar" or "-" or "+". | ||
| 1731 | * Returns 1 on success, 0 on error. */ | ||
| 1732 | int vsetParseRangeOp(RedisModuleString *arg, struct vsetRangeOp *op) { | ||
| 1733 | size_t len; | ||
| 1734 | const char *str = RedisModule_StringPtrLen(arg, &len); | ||
| 1735 | |||
| 1736 | if (len == 0) return 0; | ||
| 1737 | |||
| 1738 | /* Initialize the structure. */ | ||
| 1739 | op->incl = 0; | ||
| 1740 | op->min = 0; | ||
| 1741 | op->max = 0; | ||
| 1742 | op->ele = NULL; | ||
| 1743 | op->ele_len = 0; | ||
| 1744 | |||
| 1745 | /* Check for special cases "-" and "+". */ | ||
| 1746 | if (len == 1 && str[0] == '-') { | ||
| 1747 | op->min = 1; | ||
| 1748 | return 1; | ||
| 1749 | } | ||
| 1750 | if (len == 1 && str[0] == '+') { | ||
| 1751 | op->max = 1; | ||
| 1752 | return 1; | ||
| 1753 | } | ||
| 1754 | |||
| 1755 | /* Otherwise, must start with ( or [. */ | ||
| 1756 | if (str[0] == '[') { | ||
| 1757 | op->incl = 1; | ||
| 1758 | } else if (str[0] == '(') { | ||
| 1759 | op->incl = 0; | ||
| 1760 | } else { | ||
| 1761 | return 0; /* Invalid format. */ | ||
| 1762 | } | ||
| 1763 | |||
| 1764 | /* Extract the string part after the bracket. */ | ||
| 1765 | if (len > 1) { | ||
| 1766 | op->ele = (unsigned char *)(str + 1); | ||
| 1767 | op->ele_len = len - 1; | ||
| 1768 | } else { | ||
| 1769 | return 0; /* Just a bracket with no string. */ | ||
| 1770 | } | ||
| 1771 | |||
| 1772 | return 1; | ||
| 1773 | } | ||
| 1774 | |||
| 1775 | /* Check if the current element is within the range defined by the end operator. | ||
| 1776 | * Returns 1 if the element is within range, 0 if it has passed the end. */ | ||
| 1777 | int vsetIsElementInRange(const void *ele, size_t ele_len, struct vsetRangeOp *end_op) { | ||
| 1778 | /* If end is "+", element is always in range. */ | ||
| 1779 | if (end_op->max) return 1; | ||
| 1780 | |||
| 1781 | /* Compare current element with end boundary. */ | ||
| 1782 | size_t minlen = ele_len < end_op->ele_len ? ele_len : end_op->ele_len; | ||
| 1783 | int cmp = memcmp(ele, end_op->ele, minlen); | ||
| 1784 | |||
| 1785 | if (cmp == 0) { | ||
| 1786 | /* If equal up to minlen, shorter string is smaller. */ | ||
| 1787 | if (ele_len < end_op->ele_len) { | ||
| 1788 | cmp = -1; | ||
| 1789 | } else if (ele_len > end_op->ele_len) { | ||
| 1790 | cmp = 1; | ||
| 1791 | } | ||
| 1792 | } | ||
| 1793 | |||
| 1794 | /* Check based on inclusive/exclusive. */ | ||
| 1795 | if (end_op->incl) { | ||
| 1796 | return cmp <= 0; /* Inclusive: element <= end. */ | ||
| 1797 | } else { | ||
| 1798 | return cmp < 0; /* Exclusive: element < end. */ | ||
| 1799 | } | ||
| 1800 | } | ||
| 1801 | |||
| 1802 | /* VRANGE key start end [count] | ||
| 1803 | * Returns elements in the lexicographical range [start, end] | ||
| 1804 | * | ||
| 1805 | * Elements must be specified in one of the following forms: | ||
| 1806 | * | ||
| 1807 | * [myelement | ||
| 1808 | * (myelement | ||
| 1809 | * + | ||
| 1810 | * - | ||
| 1811 | * | ||
| 1812 | * Elements starting with [ are inclusive, so "myelement" would be | ||
| 1813 | * returned if present in the set. Elements starting with ( are exclusive | ||
| 1814 | * ranges instead. The special - and + elements mean the minimum and maximum | ||
| 1815 | * possible element (inclusive), so "VRANGE key - +" will return everything | ||
| 1816 | * (depending on COUNT of course). The special - element can be used only | ||
| 1817 | * as starting element, the special + element only as ending element. */ | ||
| 1818 | int VRANGE_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { | ||
| 1819 | RedisModule_AutoMemory(ctx); | ||
| 1820 | |||
| 1821 | /* Check arguments. */ | ||
| 1822 | if (argc < 4 || argc > 5) return RedisModule_WrongArity(ctx); | ||
| 1823 | |||
| 1824 | /* Parse COUNT if provided. */ | ||
| 1825 | long long count = -1; /* Default: return all elements. */ | ||
| 1826 | if (argc == 5) { | ||
| 1827 | if (RedisModule_StringToLongLong(argv[4], &count) != REDISMODULE_OK) { | ||
| 1828 | return RedisModule_ReplyWithError(ctx, "ERR invalid COUNT value"); | ||
| 1829 | } | ||
| 1830 | } | ||
| 1831 | |||
| 1832 | /* Parse range operators. */ | ||
| 1833 | struct vsetRangeOp start_op, end_op; | ||
| 1834 | if (!vsetParseRangeOp(argv[2], &start_op)) { | ||
| 1835 | return RedisModule_ReplyWithError(ctx, "ERR invalid start range format"); | ||
| 1836 | } | ||
| 1837 | if (!vsetParseRangeOp(argv[3], &end_op)) { | ||
| 1838 | return RedisModule_ReplyWithError(ctx, "ERR invalid end range format"); | ||
| 1839 | } | ||
| 1840 | |||
| 1841 | /* Validate: "-" can only be first arg, "+" can only be second. */ | ||
| 1842 | if (start_op.max || end_op.min) { | ||
| 1843 | return RedisModule_ReplyWithError(ctx, | ||
| 1844 | "ERR '-' can only be used as first argument, '+' only as second"); | ||
| 1845 | } | ||
| 1846 | |||
| 1847 | /* Open the key. */ | ||
| 1848 | RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ); | ||
| 1849 | int type = RedisModule_KeyType(key); | ||
| 1850 | |||
| 1851 | if (type == REDISMODULE_KEYTYPE_EMPTY) { | ||
| 1852 | return RedisModule_ReplyWithEmptyArray(ctx); | ||
| 1853 | } | ||
| 1854 | |||
| 1855 | if (RedisModule_ModuleTypeGetType(key) != VectorSetType) { | ||
| 1856 | return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); | ||
| 1857 | } | ||
| 1858 | |||
| 1859 | struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key); | ||
| 1860 | |||
| 1861 | /* Start the iterator. */ | ||
| 1862 | RedisModuleDictIter *iter; | ||
| 1863 | if (start_op.min) { | ||
| 1864 | /* Start from the beginning. */ | ||
| 1865 | iter = RedisModule_DictIteratorStartC(vset->dict, "^", NULL, 0); | ||
| 1866 | } else { | ||
| 1867 | /* Start from the specified element. */ | ||
| 1868 | const char *op = start_op.incl ? ">=" : ">"; | ||
| 1869 | iter = RedisModule_DictIteratorStartC(vset->dict, op, start_op.ele, start_op.ele_len); | ||
| 1870 | } | ||
| 1871 | |||
| 1872 | /* Collect results. */ | ||
| 1873 | RedisModule_ReplyWithArray(ctx, REDISMODULE_POSTPONED_LEN); | ||
| 1874 | long long returned = 0; | ||
| 1875 | |||
| 1876 | void *key_data; | ||
| 1877 | size_t key_len; | ||
| 1878 | while ((key_data = RedisModule_DictNextC(iter, &key_len, NULL)) != NULL) { | ||
| 1879 | /* Check if we've collected enough elements. */ | ||
| 1880 | if (count >= 0 && returned >= count) break; | ||
| 1881 | |||
| 1882 | /* Check if we've passed the end range. */ | ||
| 1883 | if (!vsetIsElementInRange(key_data, key_len, &end_op)) break; | ||
| 1884 | |||
| 1885 | /* Add this element to the result. */ | ||
| 1886 | RedisModule_ReplyWithStringBuffer(ctx, key_data, key_len); | ||
| 1887 | returned++; | ||
| 1888 | } | ||
| 1889 | |||
| 1890 | RedisModule_ReplySetArrayLength(ctx, returned); | ||
| 1891 | |||
| 1892 | /* Cleanup. */ | ||
| 1893 | RedisModule_DictIteratorStop(iter); | ||
| 1894 | |||
| 1895 | return REDISMODULE_OK; | ||
| 1896 | } | ||
| 1897 | |||
| 1898 | /* ============================== vset type methods ========================= */ | ||
| 1899 | |||
| 1900 | #define SAVE_FLAG_HAS_PROJMATRIX (1<<0) | ||
| 1901 | #define SAVE_FLAG_HAS_ATTRIBS (1<<1) | ||
| 1902 | |||
| 1903 | /* Save object to RDB */ | ||
| 1904 | void VectorSetRdbSave(RedisModuleIO *rdb, void *value) { | ||
| 1905 | struct vsetObject *vset = value; | ||
| 1906 | RedisModule_SaveUnsigned(rdb, vset->hnsw->vector_dim); | ||
| 1907 | RedisModule_SaveUnsigned(rdb, vset->hnsw->node_count); | ||
| 1908 | |||
| 1909 | uint32_t hnsw_config = (vset->hnsw->quant_type & 0xff) | | ||
| 1910 | ((vset->hnsw->M & 0xffff) << 8); | ||
| 1911 | RedisModule_SaveUnsigned(rdb, hnsw_config); | ||
| 1912 | |||
| 1913 | uint32_t save_flags = 0; | ||
| 1914 | if (vset->proj_matrix) save_flags |= SAVE_FLAG_HAS_PROJMATRIX; | ||
| 1915 | if (vset->numattribs != 0) save_flags |= SAVE_FLAG_HAS_ATTRIBS; | ||
| 1916 | RedisModule_SaveUnsigned(rdb, save_flags); | ||
| 1917 | |||
| 1918 | /* Save projection matrix if present */ | ||
| 1919 | if (vset->proj_matrix) { | ||
| 1920 | uint32_t input_dim = vset->proj_input_size; | ||
| 1921 | uint32_t output_dim = vset->hnsw->vector_dim; | ||
| 1922 | RedisModule_SaveUnsigned(rdb, input_dim); | ||
| 1923 | // Output dim is the same as the first value saved | ||
| 1924 | // above, so we don't save it. | ||
| 1925 | |||
| 1926 | // Save projection matrix as binary blob | ||
| 1927 | size_t matrix_size = sizeof(float) * input_dim * output_dim; | ||
| 1928 | RedisModule_SaveStringBuffer(rdb, (const char *)vset->proj_matrix, matrix_size); | ||
| 1929 | } | ||
| 1930 | |||
| 1931 | hnswNode *node = vset->hnsw->head; | ||
| 1932 | while(node) { | ||
| 1933 | struct vsetNodeVal *nv = node->value; | ||
| 1934 | RedisModule_SaveString(rdb, nv->item); | ||
| 1935 | if (vset->numattribs) { | ||
| 1936 | if (nv->attrib) | ||
| 1937 | RedisModule_SaveString(rdb, nv->attrib); | ||
| 1938 | else | ||
| 1939 | RedisModule_SaveStringBuffer(rdb, "", 0); | ||
| 1940 | } | ||
| 1941 | hnswSerNode *sn = hnsw_serialize_node(vset->hnsw,node); | ||
| 1942 | RedisModule_SaveStringBuffer(rdb, (const char *)sn->vector, sn->vector_size); | ||
| 1943 | RedisModule_SaveUnsigned(rdb, sn->params_count); | ||
| 1944 | for (uint32_t j = 0; j < sn->params_count; j++) | ||
| 1945 | RedisModule_SaveUnsigned(rdb, sn->params[j]); | ||
| 1946 | hnsw_free_serialized_node(sn); | ||
| 1947 | node = node->next; | ||
| 1948 | } | ||
| 1949 | } | ||
| 1950 | |||
| 1951 | /* Load object from RDB. Recover from recoverable errors (read errors) | ||
| 1952 | * by performing cleanup. */ | ||
| 1953 | void *VectorSetRdbLoad(RedisModuleIO *rdb, int encver) { | ||
| 1954 | if (encver != 0) return NULL; // Invalid version | ||
| 1955 | |||
| 1956 | uint32_t dim = RedisModule_LoadUnsigned(rdb); | ||
| 1957 | uint64_t elements = RedisModule_LoadUnsigned(rdb); | ||
| 1958 | uint32_t hnsw_config = RedisModule_LoadUnsigned(rdb); | ||
| 1959 | if (RedisModule_IsIOError(rdb)) return NULL; | ||
| 1960 | uint32_t quant_type = hnsw_config & 0xff; | ||
| 1961 | uint32_t hnsw_m = (hnsw_config >> 8) & 0xffff; | ||
| 1962 | |||
| 1963 | /* Check that the quantization type is correct. Otherwise | ||
| 1964 | * return ASAP signaling the error. */ | ||
| 1965 | if (quant_type != HNSW_QUANT_NONE && | ||
| 1966 | quant_type != HNSW_QUANT_Q8 && | ||
| 1967 | quant_type != HNSW_QUANT_BIN) return NULL; | ||
| 1968 | |||
| 1969 | if (hnsw_m == 0) hnsw_m = 16; // Default, useful for RDB files predating | ||
| 1970 | // this configuration parameter: it was fixed | ||
| 1971 | // to 16. | ||
| 1972 | struct vsetObject *vset = createVectorSetObject(dim,quant_type,hnsw_m); | ||
| 1973 | RedisModule_Assert(vset != NULL); | ||
| 1974 | |||
| 1975 | /* Load projection matrix if present */ | ||
| 1976 | uint32_t save_flags = RedisModule_LoadUnsigned(rdb); | ||
| 1977 | if (RedisModule_IsIOError(rdb)) goto ioerr; | ||
| 1978 | int has_projection = save_flags & SAVE_FLAG_HAS_PROJMATRIX; | ||
| 1979 | int has_attribs = save_flags & SAVE_FLAG_HAS_ATTRIBS; | ||
| 1980 | if (has_projection) { | ||
| 1981 | uint32_t input_dim = RedisModule_LoadUnsigned(rdb); | ||
| 1982 | if (RedisModule_IsIOError(rdb)) goto ioerr; | ||
| 1983 | uint32_t output_dim = dim; | ||
| 1984 | size_t matrix_size = sizeof(float) * input_dim * output_dim; | ||
| 1985 | |||
| 1986 | vset->proj_matrix = RedisModule_Alloc(matrix_size); | ||
| 1987 | vset->proj_input_size = input_dim; | ||
| 1988 | |||
| 1989 | // Load projection matrix as a binary blob | ||
| 1990 | char *matrix_blob = RedisModule_LoadStringBuffer(rdb, NULL); | ||
| 1991 | if (matrix_blob == NULL) goto ioerr; | ||
| 1992 | memcpy(vset->proj_matrix, matrix_blob, matrix_size); | ||
| 1993 | RedisModule_Free(matrix_blob); | ||
| 1994 | } | ||
| 1995 | |||
| 1996 | while(elements--) { | ||
| 1997 | // Load associated string element. | ||
| 1998 | RedisModuleString *ele = RedisModule_LoadString(rdb); | ||
| 1999 | if (RedisModule_IsIOError(rdb)) goto ioerr; | ||
| 2000 | RedisModuleString *attrib = NULL; | ||
| 2001 | if (has_attribs) { | ||
| 2002 | attrib = RedisModule_LoadString(rdb); | ||
| 2003 | if (RedisModule_IsIOError(rdb)) { | ||
| 2004 | RedisModule_FreeString(NULL,ele); | ||
| 2005 | goto ioerr; | ||
| 2006 | } | ||
| 2007 | size_t attrlen; | ||
| 2008 | RedisModule_StringPtrLen(attrib,&attrlen); | ||
| 2009 | if (attrlen == 0) { | ||
| 2010 | RedisModule_FreeString(NULL,attrib); | ||
| 2011 | attrib = NULL; | ||
| 2012 | } | ||
| 2013 | } | ||
| 2014 | size_t vector_len; | ||
| 2015 | void *vector = RedisModule_LoadStringBuffer(rdb, &vector_len); | ||
| 2016 | if (RedisModule_IsIOError(rdb)) { | ||
| 2017 | RedisModule_FreeString(NULL,ele); | ||
| 2018 | if (attrib) RedisModule_FreeString(NULL,attrib); | ||
| 2019 | goto ioerr; | ||
| 2020 | } | ||
| 2021 | uint32_t vector_bytes = hnsw_quants_bytes(vset->hnsw); | ||
| 2022 | if (vector_len != vector_bytes) { | ||
| 2023 | RedisModule_LogIOError(rdb,"warning", | ||
| 2024 | "Mismatching vector dimension"); | ||
| 2025 | RedisModule_FreeString(NULL,ele); | ||
| 2026 | if (attrib) RedisModule_FreeString(NULL,attrib); | ||
| 2027 | RedisModule_Free(vector); | ||
| 2028 | goto ioerr; | ||
| 2029 | } | ||
| 2030 | |||
| 2031 | // Load node parameters back. | ||
| 2032 | uint32_t params_count = RedisModule_LoadUnsigned(rdb); | ||
| 2033 | if (RedisModule_IsIOError(rdb)) { | ||
| 2034 | RedisModule_FreeString(NULL,ele); | ||
| 2035 | if (attrib) RedisModule_FreeString(NULL,attrib); | ||
| 2036 | RedisModule_Free(vector); | ||
| 2037 | goto ioerr; | ||
| 2038 | } | ||
| 2039 | |||
| 2040 | uint64_t *params = RedisModule_Alloc(params_count*sizeof(uint64_t)); | ||
| 2041 | for (uint32_t j = 0; j < params_count; j++) { | ||
| 2042 | // Ignore loading errors here: handled at the end of the loop. | ||
| 2043 | params[j] = RedisModule_LoadUnsigned(rdb); | ||
| 2044 | } | ||
| 2045 | if (RedisModule_IsIOError(rdb)) { | ||
| 2046 | RedisModule_FreeString(NULL,ele); | ||
| 2047 | if (attrib) RedisModule_FreeString(NULL,attrib); | ||
| 2048 | RedisModule_Free(vector); | ||
| 2049 | RedisModule_Free(params); | ||
| 2050 | goto ioerr; | ||
| 2051 | } | ||
| 2052 | |||
| 2053 | struct vsetNodeVal *nv = RedisModule_Alloc(sizeof(*nv)); | ||
| 2054 | nv->item = ele; | ||
| 2055 | nv->attrib = attrib; | ||
| 2056 | hnswNode *node = hnsw_insert_serialized(vset->hnsw, vector, params, params_count, nv); | ||
| 2057 | if (node == NULL) { | ||
| 2058 | RedisModule_LogIOError(rdb,"warning", | ||
| 2059 | "Vector set node index loading error"); | ||
| 2060 | vectorSetReleaseNodeValue(nv); | ||
| 2061 | RedisModule_Free(vector); | ||
| 2062 | RedisModule_Free(params); | ||
| 2063 | goto ioerr; | ||
| 2064 | } | ||
| 2065 | if (nv->attrib) vset->numattribs++; | ||
| 2066 | RedisModule_DictSet(vset->dict,ele,node); | ||
| 2067 | RedisModule_Free(vector); | ||
| 2068 | RedisModule_Free(params); | ||
| 2069 | } | ||
| 2070 | |||
| 2071 | uint64_t salt[2]; | ||
| 2072 | RedisModule_GetRandomBytes((unsigned char*)salt,sizeof(salt)); | ||
| 2073 | if (!hnsw_deserialize_index(vset->hnsw, salt[0], salt[1])) goto ioerr; | ||
| 2074 | |||
| 2075 | return vset; | ||
| 2076 | |||
| 2077 | ioerr: | ||
| 2078 | /* We want to recover from I/O errors and free the partially allocated | ||
| 2079 | * data structure to support diskless replication. */ | ||
| 2080 | vectorSetReleaseObject(vset); | ||
| 2081 | return NULL; | ||
| 2082 | } | ||
| 2083 | |||
| 2084 | /* Calculate memory usage */ | ||
| 2085 | size_t VectorSetMemUsage(const void *value) { | ||
| 2086 | const struct vsetObject *vset = value; | ||
| 2087 | size_t size = sizeof(*vset); | ||
| 2088 | |||
| 2089 | /* Account for HNSW index base structure */ | ||
| 2090 | size += sizeof(HNSW); | ||
| 2091 | |||
| 2092 | /* Account for projection matrix if present */ | ||
| 2093 | if (vset->proj_matrix) { | ||
| 2094 | /* For the matrix size, we need the input dimension. We can get it | ||
| 2095 | * from the first node if the set is not empty. */ | ||
| 2096 | uint32_t input_dim = vset->proj_input_size; | ||
| 2097 | uint32_t output_dim = vset->hnsw->vector_dim; | ||
| 2098 | size += sizeof(float) * input_dim * output_dim; | ||
| 2099 | } | ||
| 2100 | |||
| 2101 | /* Account for each node's memory usage. */ | ||
| 2102 | hnswNode *node = vset->hnsw->head; | ||
| 2103 | if (node == NULL) return size; | ||
| 2104 | |||
| 2105 | /* Base node structure. */ | ||
| 2106 | size += sizeof(*node) * vset->hnsw->node_count; | ||
| 2107 | |||
| 2108 | /* Vector storage. */ | ||
| 2109 | uint64_t vec_storage = hnsw_quants_bytes(vset->hnsw); | ||
| 2110 | size += vec_storage * vset->hnsw->node_count; | ||
| 2111 | |||
| 2112 | /* Layers array. We use 1.33 as average nodes layers count. */ | ||
| 2113 | uint64_t layers_storage = sizeof(hnswNodeLayer) * vset->hnsw->node_count; | ||
| 2114 | layers_storage = layers_storage * 4 / 3; // 1.33 times. | ||
| 2115 | size += layers_storage; | ||
| 2116 | |||
| 2117 | /* All the nodes have layer 0 links. */ | ||
| 2118 | uint64_t level0_links = node->layers[0].max_links; | ||
| 2119 | uint64_t other_levels_links = level0_links/2; | ||
| 2120 | size += sizeof(hnswNode*) * level0_links * vset->hnsw->node_count; | ||
| 2121 | |||
| 2122 | /* Add the 0.33 remaining part, but upper layers have less links. */ | ||
| 2123 | size += (sizeof(hnswNode*) * other_levels_links * vset->hnsw->node_count)/3; | ||
| 2124 | |||
| 2125 | /* Associated string value and attributres. | ||
| 2126 | * Use Redis Module API to get string size, and guess that all the | ||
| 2127 | * elements have similar size as the first few. */ | ||
| 2128 | size_t items_scanned = 0, items_size = 0; | ||
| 2129 | size_t attribs_scanned = 0, attribs_size = 0; | ||
| 2130 | int scan_effort = 20; | ||
| 2131 | while(scan_effort > 0 && node) { | ||
| 2132 | struct vsetNodeVal *nv = node->value; | ||
| 2133 | items_size += RedisModule_MallocSizeString(nv->item); | ||
| 2134 | items_scanned++; | ||
| 2135 | if (nv->attrib) { | ||
| 2136 | attribs_size += RedisModule_MallocSizeString(nv->attrib); | ||
| 2137 | attribs_scanned++; | ||
| 2138 | } | ||
| 2139 | scan_effort--; | ||
| 2140 | node = node->next; | ||
| 2141 | } | ||
| 2142 | |||
| 2143 | /* Add the memory usage due to items. */ | ||
| 2144 | if (items_scanned) | ||
| 2145 | size += items_size / items_scanned * vset->hnsw->node_count; | ||
| 2146 | |||
| 2147 | /* Add memory usage due to attributres. */ | ||
| 2148 | if (attribs_scanned == 0) { | ||
| 2149 | /* We were not lucky enough to find a single attribute in the | ||
| 2150 | * first few items? Let's use a fixed arbitrary value. */ | ||
| 2151 | attribs_scanned = 1; | ||
| 2152 | attribs_size = 64; | ||
| 2153 | } | ||
| 2154 | size += attribs_size / attribs_scanned * vset->numattribs; | ||
| 2155 | |||
| 2156 | /* Account for dictionary overhead - this is an approximation. */ | ||
| 2157 | size += RedisModule_DictSize(vset->dict) * (sizeof(void*) * 2); | ||
| 2158 | |||
| 2159 | return size; | ||
| 2160 | } | ||
| 2161 | |||
/* Type 'free' callback: dispose of a whole vector set.
 *
 * vectorSetWaitAllBackgroundClients() is called on the object first, then
 * the object itself is released. */
void VectorSetFree(void *value) {
    vectorSetWaitAllBackgroundClients((struct vsetObject *)value,1);
    vectorSetReleaseObject(value);
}
| 2169 | |||
| 2170 | /* Add object digest to the digest context */ | ||
| 2171 | void VectorSetDigest(RedisModuleDigest *md, void *value) { | ||
| 2172 | struct vsetObject *vset = value; | ||
| 2173 | |||
| 2174 | /* Add consistent order-independent hash of all vectors */ | ||
| 2175 | hnswNode *node = vset->hnsw->head; | ||
| 2176 | |||
| 2177 | /* Hash the vector dimension and number of nodes. */ | ||
| 2178 | RedisModule_DigestAddLongLong(md, vset->hnsw->node_count); | ||
| 2179 | RedisModule_DigestAddLongLong(md, vset->hnsw->vector_dim); | ||
| 2180 | RedisModule_DigestEndSequence(md); | ||
| 2181 | |||
| 2182 | while(node) { | ||
| 2183 | struct vsetNodeVal *nv = node->value; | ||
| 2184 | /* Hash each vector component */ | ||
| 2185 | RedisModule_DigestAddStringBuffer(md, node->vector, hnsw_quants_bytes(vset->hnsw)); | ||
| 2186 | /* Hash the associated value */ | ||
| 2187 | size_t len; | ||
| 2188 | const char *str = RedisModule_StringPtrLen(nv->item, &len); | ||
| 2189 | RedisModule_DigestAddStringBuffer(md, (char*)str, len); | ||
| 2190 | if (nv->attrib) { | ||
| 2191 | str = RedisModule_StringPtrLen(nv->attrib, &len); | ||
| 2192 | RedisModule_DigestAddStringBuffer(md, (char*)str, len); | ||
| 2193 | } | ||
| 2194 | node = node->next; | ||
| 2195 | RedisModule_DigestEndSequence(md); | ||
| 2196 | } | ||
| 2197 | } | ||
| 2198 | |||
| 2199 | // int VectorSets_InitModuleConfig(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { | ||
| 2200 | int VectorSets_InitModuleConfig(RedisModuleCtx *ctx) { | ||
| 2201 | if (RegisterModuleConfig(ctx) == REDISMODULE_ERR) { | ||
| 2202 | RedisModule_Log(ctx, "warning", "Error registering module configuration"); | ||
| 2203 | return REDISMODULE_ERR; | ||
| 2204 | } | ||
| 2205 | // Load default values | ||
| 2206 | if (RedisModule_LoadDefaultConfigs(ctx) == REDISMODULE_ERR) { | ||
| 2207 | RedisModule_Log(ctx, "warning", "Error loading default module configuration"); | ||
| 2208 | return REDISMODULE_ERR; | ||
| 2209 | } else { | ||
| 2210 | RedisModule_Log(ctx, "verbose", "Successfully loaded default module configuration"); | ||
| 2211 | } | ||
| 2212 | if (RedisModule_LoadConfigs(ctx) == REDISMODULE_ERR) { | ||
| 2213 | RedisModule_Log(ctx, "warning", "Error loading user module configuration"); | ||
| 2214 | return REDISMODULE_ERR; | ||
| 2215 | } else { | ||
| 2216 | RedisModule_Log(ctx, "verbose", "Successfully loaded user module configuration"); | ||
| 2217 | } | ||
| 2218 | return REDISMODULE_OK; | ||
| 2219 | } | ||
| 2220 | |||
| 2221 | /* This function must be present on each Redis module. It is used in order to | ||
| 2222 | * register the commands into the Redis server. */ | ||
| 2223 | int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { | ||
| 2224 | REDISMODULE_NOT_USED(argv); | ||
| 2225 | REDISMODULE_NOT_USED(argc); | ||
| 2226 | |||
| 2227 | if (RedisModule_Init(ctx,"vectorset",1,REDISMODULE_APIVER_1) | ||
| 2228 | == REDISMODULE_ERR) return REDISMODULE_ERR; | ||
| 2229 | |||
| 2230 | if (VectorSets_InitModuleConfig(ctx) == REDISMODULE_ERR) { | ||
| 2231 | return REDISMODULE_ERR; | ||
| 2232 | } | ||
| 2233 | |||
| 2234 | RedisModule_SetModuleOptions(ctx, REDISMODULE_OPTIONS_HANDLE_IO_ERRORS|REDISMODULE_OPTIONS_HANDLE_REPL_ASYNC_LOAD); | ||
| 2235 | |||
| 2236 | RedisModuleTypeMethods tm = { | ||
| 2237 | .version = REDISMODULE_TYPE_METHOD_VERSION, | ||
| 2238 | .rdb_load = VectorSetRdbLoad, | ||
| 2239 | .rdb_save = VectorSetRdbSave, | ||
| 2240 | .aof_rewrite = NULL, | ||
| 2241 | .mem_usage = VectorSetMemUsage, | ||
| 2242 | .free = VectorSetFree, | ||
| 2243 | .digest = VectorSetDigest | ||
| 2244 | }; | ||
| 2245 | |||
| 2246 | VectorSetType = RedisModule_CreateDataType(ctx,"vectorset",0,&tm); | ||
| 2247 | if (VectorSetType == NULL) return REDISMODULE_ERR; | ||
| 2248 | |||
| 2249 | // Register command VADD | ||
| 2250 | if (RedisModule_CreateCommand(ctx,"VADD", | ||
| 2251 | VADD_RedisCommand,"write deny-oom",1,1,1) == REDISMODULE_ERR) | ||
| 2252 | return REDISMODULE_ERR; | ||
| 2253 | |||
| 2254 | RedisModuleCommand *vadd_cmd = RedisModule_GetCommand(ctx, "VADD"); | ||
| 2255 | if (vadd_cmd == NULL) return REDISMODULE_ERR; | ||
| 2256 | |||
| 2257 | RedisModuleCommandArg vadd_args[] = { | ||
| 2258 | { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 }, | ||
| 2259 | { .name = "reduce", .type = REDISMODULE_ARG_TYPE_BLOCK, .token = "REDUCE", .flags = REDISMODULE_CMD_ARG_OPTIONAL, | ||
| 2260 | .subargs = (RedisModuleCommandArg[]) { | ||
| 2261 | { .name = "dim", .type = REDISMODULE_ARG_TYPE_INTEGER }, | ||
| 2262 | { .name = NULL } | ||
| 2263 | } | ||
| 2264 | }, | ||
| 2265 | { .name = "format", .type = REDISMODULE_ARG_TYPE_ONEOF, .subargs = (RedisModuleCommandArg[]) { | ||
| 2266 | { .name = "fp32", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "FP32" }, | ||
| 2267 | { .name = "values", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "VALUES" }, | ||
| 2268 | { .name = NULL } | ||
| 2269 | } | ||
| 2270 | }, | ||
| 2271 | { .name = "vector", .type = REDISMODULE_ARG_TYPE_STRING }, | ||
| 2272 | { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING }, | ||
| 2273 | { .name = "cas", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "CAS", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, | ||
| 2274 | { .name = "quant_type", .type = REDISMODULE_ARG_TYPE_ONEOF, .flags = REDISMODULE_CMD_ARG_OPTIONAL, .subargs = (RedisModuleCommandArg[]) { | ||
| 2275 | { .name = "noquant", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "NOQUANT" }, | ||
| 2276 | { .name = "bin", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "BIN" }, | ||
| 2277 | { .name = "q8", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "Q8" }, | ||
| 2278 | { .name = NULL } | ||
| 2279 | } | ||
| 2280 | }, | ||
| 2281 | { .name = "build-exploration-factor", .type = REDISMODULE_ARG_TYPE_INTEGER, .token = "EF", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, | ||
| 2282 | { .name = "attributes", .type = REDISMODULE_ARG_TYPE_STRING, .token = "SETATTR", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, | ||
| 2283 | { .name = "numlinks", .type = REDISMODULE_ARG_TYPE_INTEGER, .token = "M", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, | ||
| 2284 | { .name = NULL } | ||
| 2285 | }; | ||
| 2286 | RedisModuleCommandInfo vadd_info = { | ||
| 2287 | .version = REDISMODULE_COMMAND_INFO_VERSION, | ||
| 2288 | .summary = "Add one or more elements to a vector set, or update its vector if it already exists", | ||
| 2289 | .since = "8.0.0", | ||
| 2290 | .arity = -5, | ||
| 2291 | .args = vadd_args, | ||
| 2292 | }; | ||
| 2293 | if (RedisModule_SetCommandInfo(vadd_cmd, &vadd_info) == REDISMODULE_ERR) return REDISMODULE_ERR; | ||
| 2294 | |||
| 2295 | // Register command VREM | ||
| 2296 | if (RedisModule_CreateCommand(ctx,"VREM", | ||
| 2297 | VREM_RedisCommand,"write",1,1,1) == REDISMODULE_ERR) | ||
| 2298 | return REDISMODULE_ERR; | ||
| 2299 | |||
| 2300 | RedisModuleCommand *vrem_cmd = RedisModule_GetCommand(ctx, "VREM"); | ||
| 2301 | if (vrem_cmd == NULL) return REDISMODULE_ERR; | ||
| 2302 | |||
| 2303 | RedisModuleCommandArg vrem_args[] = { | ||
| 2304 | { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 }, | ||
| 2305 | { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING }, | ||
| 2306 | { .name = NULL } | ||
| 2307 | }; | ||
| 2308 | RedisModuleCommandInfo vrem_info = { | ||
| 2309 | .version = REDISMODULE_COMMAND_INFO_VERSION, | ||
| 2310 | .summary = "Remove an element from a vector set", | ||
| 2311 | .since = "8.0.0", | ||
| 2312 | .arity = 3, | ||
| 2313 | .args = vrem_args, | ||
| 2314 | }; | ||
| 2315 | if (RedisModule_SetCommandInfo(vrem_cmd, &vrem_info) == REDISMODULE_ERR) return REDISMODULE_ERR; | ||
| 2316 | |||
| 2317 | // Register command VSIM | ||
| 2318 | if (RedisModule_CreateCommand(ctx,"VSIM", | ||
| 2319 | VSIM_RedisCommand,"readonly",1,1,1) == REDISMODULE_ERR) | ||
| 2320 | return REDISMODULE_ERR; | ||
| 2321 | |||
| 2322 | RedisModuleCommand *vsim_cmd = RedisModule_GetCommand(ctx, "VSIM"); | ||
| 2323 | if (vsim_cmd == NULL) return REDISMODULE_ERR; | ||
| 2324 | |||
| 2325 | RedisModuleCommandArg vsim_args[] = { | ||
| 2326 | { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 }, | ||
| 2327 | { .name = "format", .type = REDISMODULE_ARG_TYPE_ONEOF, .subargs = (RedisModuleCommandArg[]) { | ||
| 2328 | { .name = "ele", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "ELE" }, | ||
| 2329 | { .name = "fp32", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "FP32" }, | ||
| 2330 | { .name = "values", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "VALUES" }, | ||
| 2331 | { .name = NULL } | ||
| 2332 | } | ||
| 2333 | }, | ||
| 2334 | { .name = "vector_or_element", .type = REDISMODULE_ARG_TYPE_STRING }, | ||
| 2335 | { .name = "withscores", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "WITHSCORES", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, | ||
| 2336 | { .name = "withattribs", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "WITHATTRIBS", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, | ||
| 2337 | { .name = "count", .type = REDISMODULE_ARG_TYPE_INTEGER, .token = "COUNT", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, | ||
| 2338 | { .name = "max_distance", .type = REDISMODULE_ARG_TYPE_DOUBLE, .token = "EPSILON", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, | ||
| 2339 | { .name = "search-exploration-factor", .type = REDISMODULE_ARG_TYPE_INTEGER, .token = "EF", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, | ||
| 2340 | { .name = "expression", .type = REDISMODULE_ARG_TYPE_STRING, .token = "FILTER", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, | ||
| 2341 | { .name = "max-filtering-effort", .type = REDISMODULE_ARG_TYPE_INTEGER, .token = "FILTER-EF", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, | ||
| 2342 | { .name = "truth", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "TRUTH", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, | ||
| 2343 | { .name = "nothread", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "NOTHREAD", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, | ||
| 2344 | { .name = NULL } | ||
| 2345 | }; | ||
| 2346 | RedisModuleCommandInfo vsim_info = { | ||
| 2347 | .version = REDISMODULE_COMMAND_INFO_VERSION, | ||
| 2348 | .summary = "Return elements by vector similarity", | ||
| 2349 | .since = "8.0.0", | ||
| 2350 | .arity = -4, | ||
| 2351 | .args = vsim_args, | ||
| 2352 | }; | ||
| 2353 | if (RedisModule_SetCommandInfo(vsim_cmd, &vsim_info) == REDISMODULE_ERR) return REDISMODULE_ERR; | ||
| 2354 | |||
| 2355 | // Register command VDIM | ||
| 2356 | if (RedisModule_CreateCommand(ctx, "VDIM", | ||
| 2357 | VDIM_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR) | ||
| 2358 | return REDISMODULE_ERR; | ||
| 2359 | |||
| 2360 | RedisModuleCommand *vdim_cmd = RedisModule_GetCommand(ctx, "VDIM"); | ||
| 2361 | if (vdim_cmd == NULL) return REDISMODULE_ERR; | ||
| 2362 | |||
| 2363 | RedisModuleCommandArg vdim_args[] = { | ||
| 2364 | { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 }, | ||
| 2365 | { .name = NULL } | ||
| 2366 | }; | ||
| 2367 | RedisModuleCommandInfo vdim_info = { | ||
| 2368 | .version = REDISMODULE_COMMAND_INFO_VERSION, | ||
| 2369 | .summary = "Return the dimension of vectors in the vector set", | ||
| 2370 | .since = "8.0.0", | ||
| 2371 | .arity = 2, | ||
| 2372 | .args = vdim_args, | ||
| 2373 | }; | ||
| 2374 | if (RedisModule_SetCommandInfo(vdim_cmd, &vdim_info) == REDISMODULE_ERR) return REDISMODULE_ERR; | ||
| 2375 | |||
| 2376 | // Register command VCARD | ||
| 2377 | if (RedisModule_CreateCommand(ctx, "VCARD", | ||
| 2378 | VCARD_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR) | ||
| 2379 | return REDISMODULE_ERR; | ||
| 2380 | |||
| 2381 | RedisModuleCommand *vcard_cmd = RedisModule_GetCommand(ctx, "VCARD"); | ||
| 2382 | if (vcard_cmd == NULL) return REDISMODULE_ERR; | ||
| 2383 | |||
| 2384 | RedisModuleCommandArg vcard_args[] = { | ||
| 2385 | { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 }, | ||
| 2386 | { .name = NULL } | ||
| 2387 | }; | ||
| 2388 | RedisModuleCommandInfo vcard_info = { | ||
| 2389 | .version = REDISMODULE_COMMAND_INFO_VERSION, | ||
| 2390 | .summary = "Return the number of elements in a vector set", | ||
| 2391 | .since = "8.0.0", | ||
| 2392 | .arity = 2, | ||
| 2393 | .args = vcard_args, | ||
| 2394 | }; | ||
| 2395 | if (RedisModule_SetCommandInfo(vcard_cmd, &vcard_info) == REDISMODULE_ERR) return REDISMODULE_ERR; | ||
| 2396 | |||
| 2397 | // Register command VEMB | ||
| 2398 | if (RedisModule_CreateCommand(ctx, "VEMB", | ||
| 2399 | VEMB_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR) | ||
| 2400 | return REDISMODULE_ERR; | ||
| 2401 | |||
| 2402 | RedisModuleCommand *vemb_cmd = RedisModule_GetCommand(ctx, "VEMB"); | ||
| 2403 | if (vemb_cmd == NULL) return REDISMODULE_ERR; | ||
| 2404 | |||
| 2405 | RedisModuleCommandArg vemb_args[] = { | ||
| 2406 | { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 }, | ||
| 2407 | { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING }, | ||
| 2408 | { .name = "raw", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "RAW", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, | ||
| 2409 | { .name = NULL } | ||
| 2410 | }; | ||
| 2411 | RedisModuleCommandInfo vemb_info = { | ||
| 2412 | .version = REDISMODULE_COMMAND_INFO_VERSION, | ||
| 2413 | .summary = "Return the vector associated with an element", | ||
| 2414 | .since = "8.0.0", | ||
| 2415 | .arity = -3, | ||
| 2416 | .args = vemb_args, | ||
| 2417 | }; | ||
| 2418 | if (RedisModule_SetCommandInfo(vemb_cmd, &vemb_info) == REDISMODULE_ERR) return REDISMODULE_ERR; | ||
| 2419 | |||
| 2420 | // Register command VLINKS | ||
| 2421 | if (RedisModule_CreateCommand(ctx, "VLINKS", | ||
| 2422 | VLINKS_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR) | ||
| 2423 | return REDISMODULE_ERR; | ||
| 2424 | |||
| 2425 | RedisModuleCommand *vlinks_cmd = RedisModule_GetCommand(ctx, "VLINKS"); | ||
| 2426 | if (vlinks_cmd == NULL) return REDISMODULE_ERR; | ||
| 2427 | |||
| 2428 | RedisModuleCommandArg vlinks_args[] = { | ||
| 2429 | { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 }, | ||
| 2430 | { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING }, | ||
| 2431 | { .name = "withscores", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "WITHSCORES", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, | ||
| 2432 | { .name = NULL } | ||
| 2433 | }; | ||
| 2434 | RedisModuleCommandInfo vlinks_info = { | ||
| 2435 | .version = REDISMODULE_COMMAND_INFO_VERSION, | ||
| 2436 | .summary = "Return the neighbors of an element at each layer in the HNSW graph", | ||
| 2437 | .since = "8.0.0", | ||
| 2438 | .arity = -3, | ||
| 2439 | .args = vlinks_args, | ||
| 2440 | }; | ||
| 2441 | if (RedisModule_SetCommandInfo(vlinks_cmd, &vlinks_info) == REDISMODULE_ERR) return REDISMODULE_ERR; | ||
| 2442 | |||
| 2443 | // Register command VINFO | ||
| 2444 | if (RedisModule_CreateCommand(ctx, "VINFO", | ||
| 2445 | VINFO_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR) | ||
| 2446 | return REDISMODULE_ERR; | ||
| 2447 | |||
| 2448 | RedisModuleCommand *vinfo_cmd = RedisModule_GetCommand(ctx, "VINFO"); | ||
| 2449 | if (vinfo_cmd == NULL) return REDISMODULE_ERR; | ||
| 2450 | |||
| 2451 | RedisModuleCommandArg vinfo_args[] = { | ||
| 2452 | { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 }, | ||
| 2453 | { .name = NULL } | ||
| 2454 | }; | ||
| 2455 | RedisModuleCommandInfo vinfo_info = { | ||
| 2456 | .version = REDISMODULE_COMMAND_INFO_VERSION, | ||
| 2457 | .summary = "Return information about a vector set", | ||
| 2458 | .since = "8.0.0", | ||
| 2459 | .arity = 2, | ||
| 2460 | .args = vinfo_args, | ||
| 2461 | }; | ||
| 2462 | if (RedisModule_SetCommandInfo(vinfo_cmd, &vinfo_info) == REDISMODULE_ERR) return REDISMODULE_ERR; | ||
| 2463 | |||
| 2464 | // Register command VSETATTR | ||
| 2465 | if (RedisModule_CreateCommand(ctx, "VSETATTR", | ||
| 2466 | VSETATTR_RedisCommand, "write fast", 1, 1, 1) == REDISMODULE_ERR) | ||
| 2467 | return REDISMODULE_ERR; | ||
| 2468 | |||
| 2469 | RedisModuleCommand *vsetattr_cmd = RedisModule_GetCommand(ctx, "VSETATTR"); | ||
| 2470 | if (vsetattr_cmd == NULL) return REDISMODULE_ERR; | ||
| 2471 | |||
| 2472 | RedisModuleCommandArg vsetattr_args[] = { | ||
| 2473 | { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 }, | ||
| 2474 | { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING }, | ||
| 2475 | { .name = "json", .type = REDISMODULE_ARG_TYPE_STRING }, | ||
| 2476 | { .name = NULL } | ||
| 2477 | }; | ||
| 2478 | RedisModuleCommandInfo vsetattr_info = { | ||
| 2479 | .version = REDISMODULE_COMMAND_INFO_VERSION, | ||
| 2480 | .summary = "Associate or remove the JSON attributes of elements", | ||
| 2481 | .since = "8.0.0", | ||
| 2482 | .arity = 4, | ||
| 2483 | .args = vsetattr_args, | ||
| 2484 | }; | ||
| 2485 | if (RedisModule_SetCommandInfo(vsetattr_cmd, &vsetattr_info) == REDISMODULE_ERR) return REDISMODULE_ERR; | ||
| 2486 | |||
| 2487 | // Register command VGETATTR | ||
| 2488 | if (RedisModule_CreateCommand(ctx, "VGETATTR", | ||
| 2489 | VGETATTR_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR) | ||
| 2490 | return REDISMODULE_ERR; | ||
| 2491 | |||
| 2492 | RedisModuleCommand *vgetattr_cmd = RedisModule_GetCommand(ctx, "VGETATTR"); | ||
| 2493 | if (vgetattr_cmd == NULL) return REDISMODULE_ERR; | ||
| 2494 | |||
| 2495 | RedisModuleCommandArg vgetattr_args[] = { | ||
| 2496 | { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 }, | ||
| 2497 | { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING }, | ||
| 2498 | { .name = NULL } | ||
| 2499 | }; | ||
| 2500 | RedisModuleCommandInfo vgetattr_info = { | ||
| 2501 | .version = REDISMODULE_COMMAND_INFO_VERSION, | ||
| 2502 | .summary = "Retrieve the JSON attributes of elements", | ||
| 2503 | .since = "8.0.0", | ||
| 2504 | .arity = 3, | ||
| 2505 | .args = vgetattr_args, | ||
| 2506 | }; | ||
| 2507 | if (RedisModule_SetCommandInfo(vgetattr_cmd, &vgetattr_info) == REDISMODULE_ERR) return REDISMODULE_ERR; | ||
| 2508 | |||
| 2509 | // Register command VRANDMEMBER | ||
| 2510 | if (RedisModule_CreateCommand(ctx, "VRANDMEMBER", | ||
| 2511 | VRANDMEMBER_RedisCommand, "readonly", 1, 1, 1) == REDISMODULE_ERR) | ||
| 2512 | return REDISMODULE_ERR; | ||
| 2513 | |||
| 2514 | RedisModuleCommand *vrandmember_cmd = RedisModule_GetCommand(ctx, "VRANDMEMBER"); | ||
| 2515 | if (vrandmember_cmd == NULL) return REDISMODULE_ERR; | ||
| 2516 | |||
| 2517 | RedisModuleCommandArg vrandmember_args[] = { | ||
| 2518 | { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 }, | ||
| 2519 | { .name = "count", .type = REDISMODULE_ARG_TYPE_INTEGER, .flags = REDISMODULE_CMD_ARG_OPTIONAL }, | ||
| 2520 | { .name = NULL } | ||
| 2521 | }; | ||
| 2522 | RedisModuleCommandInfo vrandmember_info = { | ||
| 2523 | .version = REDISMODULE_COMMAND_INFO_VERSION, | ||
| 2524 | .summary = "Return one or multiple random members from a vector set", | ||
| 2525 | .since = "8.0.0", | ||
| 2526 | .arity = -2, | ||
| 2527 | .args = vrandmember_args, | ||
| 2528 | }; | ||
| 2529 | if (RedisModule_SetCommandInfo(vrandmember_cmd, &vrandmember_info) == REDISMODULE_ERR) return REDISMODULE_ERR; | ||
| 2530 | |||
| 2531 | // Register command VISMEMBER | ||
| 2532 | if (RedisModule_CreateCommand(ctx, "VISMEMBER", | ||
| 2533 | VISMEMBER_RedisCommand, "readonly", 1, 1, 1) == REDISMODULE_ERR) | ||
| 2534 | return REDISMODULE_ERR; | ||
| 2535 | |||
| 2536 | RedisModuleCommand *vismember_cmd = RedisModule_GetCommand(ctx, "VISMEMBER"); | ||
| 2537 | if (vismember_cmd == NULL) return REDISMODULE_ERR; | ||
| 2538 | |||
| 2539 | RedisModuleCommandArg vismember_args[] = { | ||
| 2540 | { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 }, | ||
| 2541 | { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING }, | ||
| 2542 | { .name = NULL } | ||
| 2543 | }; | ||
| 2544 | RedisModuleCommandInfo vismember_info = { | ||
| 2545 | .version = REDISMODULE_COMMAND_INFO_VERSION, | ||
| 2546 | .summary = "Check if an element exists in a vector set", | ||
| 2547 | .since = "8.2.0", | ||
| 2548 | .arity = 3, | ||
| 2549 | .args = vismember_args, | ||
| 2550 | }; | ||
| 2551 | if (RedisModule_SetCommandInfo(vismember_cmd, &vismember_info) == REDISMODULE_ERR) return REDISMODULE_ERR; | ||
| 2552 | |||
| 2553 | // Register command VRANGE | ||
| 2554 | if (RedisModule_CreateCommand(ctx, "VRANGE", | ||
| 2555 | VRANGE_RedisCommand, "readonly", 1, 1, 1) == REDISMODULE_ERR) | ||
| 2556 | return REDISMODULE_ERR; | ||
| 2557 | |||
| 2558 | RedisModuleCommand *vrange_cmd = RedisModule_GetCommand(ctx, "VRANGE"); | ||
| 2559 | if (vrange_cmd == NULL) return REDISMODULE_ERR; | ||
| 2560 | |||
| 2561 | RedisModuleCommandArg vrange_args[] = { | ||
| 2562 | { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 }, | ||
| 2563 | { .name = "start", .type = REDISMODULE_ARG_TYPE_STRING }, | ||
| 2564 | { .name = "end", .type = REDISMODULE_ARG_TYPE_STRING }, | ||
| 2565 | { .name = "count", .type = REDISMODULE_ARG_TYPE_INTEGER, .flags = REDISMODULE_CMD_ARG_OPTIONAL }, | ||
| 2566 | { .name = NULL } | ||
| 2567 | }; | ||
| 2568 | RedisModuleCommandInfo vrange_info = { | ||
| 2569 | .version = REDISMODULE_COMMAND_INFO_VERSION, | ||
| 2570 | .summary = "Return vector set elements in a lex range", | ||
| 2571 | .since = "8.4.0", | ||
| 2572 | .arity = -4, | ||
| 2573 | .args = vrange_args, | ||
| 2574 | }; | ||
| 2575 | if (RedisModule_SetCommandInfo(vrange_cmd, &vrange_info) == REDISMODULE_ERR) return REDISMODULE_ERR; | ||
| 2576 | |||
| 2577 | // Set the allocator for the HNSW library, so that memory tracking | ||
| 2578 | // is correct in Redis. | ||
| 2579 | hnsw_set_allocator(RedisModule_Free, RedisModule_Alloc, | ||
| 2580 | RedisModule_Realloc); | ||
| 2581 | |||
| 2582 | return REDISMODULE_OK; | ||
| 2583 | } | ||
| 2584 | |||
| 2585 | int VectorSets_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { | ||
| 2586 | return RedisModule_OnLoad(ctx, argv, argc); | ||
| 2587 | } | ||
