summaryrefslogtreecommitdiff
path: root/examples/redis-unstable/modules
diff options
context:
space:
mode:
Diffstat (limited to 'examples/redis-unstable/modules')
-rw-r--r--examples/redis-unstable/modules/Makefile90
-rw-r--r--examples/redis-unstable/modules/common.mk51
-rw-r--r--examples/redis-unstable/modules/redisbloom/Makefile6
-rw-r--r--examples/redis-unstable/modules/redisearch/Makefile7
-rw-r--r--examples/redis-unstable/modules/redisjson/Makefile11
-rw-r--r--examples/redis-unstable/modules/redistimeseries/Makefile6
-rw-r--r--examples/redis-unstable/modules/vector-sets/.gitignore11
-rw-r--r--examples/redis-unstable/modules/vector-sets/Makefile87
-rw-r--r--examples/redis-unstable/modules/vector-sets/README.md733
-rw-r--r--examples/redis-unstable/modules/vector-sets/commands.json446
-rw-r--r--examples/redis-unstable/modules/vector-sets/examples/cli-tool/.gitignore1
-rw-r--r--examples/redis-unstable/modules/vector-sets/examples/cli-tool/README.md44
-rwxr-xr-xexamples/redis-unstable/modules/vector-sets/examples/cli-tool/cli.py160
-rw-r--r--examples/redis-unstable/modules/vector-sets/examples/glove-100/README3
-rw-r--r--examples/redis-unstable/modules/vector-sets/examples/glove-100/insert.py56
-rw-r--r--examples/redis-unstable/modules/vector-sets/examples/glove-100/recall.py87
-rw-r--r--examples/redis-unstable/modules/vector-sets/examples/movies/.gitignore2
-rw-r--r--examples/redis-unstable/modules/vector-sets/examples/movies/README30
-rw-r--r--examples/redis-unstable/modules/vector-sets/examples/movies/insert.py57
-rw-r--r--examples/redis-unstable/modules/vector-sets/expr.c959
-rw-r--r--examples/redis-unstable/modules/vector-sets/fastjson.c441
-rw-r--r--examples/redis-unstable/modules/vector-sets/fastjson_test.c406
-rw-r--r--examples/redis-unstable/modules/vector-sets/hnsw.c2999
-rw-r--r--examples/redis-unstable/modules/vector-sets/hnsw.h189
-rw-r--r--examples/redis-unstable/modules/vector-sets/mixer.h106
-rwxr-xr-xexamples/redis-unstable/modules/vector-sets/test.py294
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/basic_commands.py21
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/basic_similarity.py35
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/concurrent_vadd_cas_del_vsim.py156
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/concurrent_vsim_and_del.py48
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/debug_digest.py39
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/deletion.py173
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/dimension_validation.py67
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/epsilon.py77
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/evict_empty.py27
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/filter_expr.py242
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/filter_int.py668
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/large_scale.py56
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/memory_usage.py36
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/node_update.py85
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/persistence.py86
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/reduce.py71
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/replication.py92
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/threading_config.py249
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/vadd_cas.py98
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/vemb.py41
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/vismember.py47
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/vrand-ping-pong.py35
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/vrandmember.py55
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/vrange.py113
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/vsim_limit_efsearch.py32
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/with.py214
-rw-r--r--examples/redis-unstable/modules/vector-sets/vset.c2587
-rw-r--r--examples/redis-unstable/modules/vector-sets/vset_config.c51
-rw-r--r--examples/redis-unstable/modules/vector-sets/vset_config.h24
-rw-r--r--examples/redis-unstable/modules/vector-sets/w2v.c539
56 files changed, 13346 insertions, 0 deletions
diff --git a/examples/redis-unstable/modules/Makefile b/examples/redis-unstable/modules/Makefile
new file mode 100644
index 0000000..11fcdc3
--- /dev/null
+++ b/examples/redis-unstable/modules/Makefile
@@ -0,0 +1,90 @@
1
# Top-level orchestrator for building the out-of-tree Redis modules.
# Each SUBDIR contains a thin Makefile that includes ../common.mk.

SUBDIRS = redisjson redistimeseries redisbloom redisearch

# Run the requested target in every module subdirectory.
# The "|| exit 1" aborts on the first failing sub-make; without it the loop
# keeps going and the overall exit status reflects only the last directory.
define submake
	for dir in $(SUBDIRS); do $(MAKE) -C $$dir $(1) || exit 1; done
endef

all: prepare_source
	$(call submake,$@)

get_source:
	$(call submake,$@)

prepare_source: get_source handle-werrors setup_environment

clean:
	$(call submake,$@)

distclean: clean_environment
	$(call submake,$@)

pristine:
	$(call submake,$@)

install:
	$(call submake,$@)

setup_environment: install-rust handle-werrors

clean_environment: uninstall-rust

# Keep all of the Rust stuff in one place.
# Downloads the standalone Rust installer matching the host architecture and
# libc flavour (glibc vs musl), verifies its SHA-256, and installs it
# system-wide. Only runs when INSTALL_RUST_TOOLCHAIN=yes.
install-rust:
ifeq ($(INSTALL_RUST_TOOLCHAIN),yes)
	@RUST_VERSION=1.92.0; \
	ARCH="$$(uname -m)"; \
	if ldd --version 2>&1 | grep -q musl; then LIBC_TYPE="musl"; else LIBC_TYPE="gnu"; fi; \
	echo "Detected architecture: $${ARCH} and libc: $${LIBC_TYPE}"; \
	case "$${ARCH}" in \
	'x86_64') \
		if [ "$${LIBC_TYPE}" = "musl" ]; then \
			RUST_INSTALLER="rust-$${RUST_VERSION}-x86_64-unknown-linux-musl"; \
			RUST_SHA256="e9b977ef480504d3beef0a095b470945af79e8496bcbb0e4a9d4601d6d72dd91"; \
		else \
			RUST_INSTALLER="rust-$${RUST_VERSION}-x86_64-unknown-linux-gnu"; \
			RUST_SHA256="d2ccef59dd9f7439f2c694948069f789a044dc1addcc0803613232af8f88ee0c"; \
		fi ;; \
	'aarch64') \
		if [ "$${LIBC_TYPE}" = "musl" ]; then \
			RUST_INSTALLER="rust-$${RUST_VERSION}-aarch64-unknown-linux-musl"; \
			RUST_SHA256="0c35fca319d01c3d55345d44dbe66c987858c45e262a77c4db679e9869b0af73"; \
		else \
			RUST_INSTALLER="rust-$${RUST_VERSION}-aarch64-unknown-linux-gnu"; \
			RUST_SHA256="3e383f8b4fca710d0600d0c1de97b78281672be2cda6575ecbe1c183a12e3822"; \
		fi ;; \
	*) echo >&2 "Unsupported architecture: '$${ARCH}'"; exit 1 ;; \
	esac; \
	echo "Downloading and installing Rust standalone installer: $${RUST_INSTALLER}"; \
	wget --quiet -O $${RUST_INSTALLER}.tar.xz https://static.rust-lang.org/dist/$${RUST_INSTALLER}.tar.xz; \
	echo "$${RUST_SHA256} $${RUST_INSTALLER}.tar.xz" | sha256sum -c --quiet || { echo "Rust standalone installer checksum failed!"; exit 1; }; \
	tar -xf $${RUST_INSTALLER}.tar.xz; \
	(cd $${RUST_INSTALLER} && ./install.sh); \
	rm -rf $${RUST_INSTALLER}
endif

# Remove the toolchain installed by install-rust. A no-op warning is printed
# when the standalone uninstall script is absent.
uninstall-rust:
ifeq ($(INSTALL_RUST_TOOLCHAIN),yes)
	@if [ -x "/usr/local/lib/rustlib/uninstall.sh" ]; then \
		echo "Uninstalling Rust using uninstall.sh script"; \
		rm -rf ~/.cargo; \
		/usr/local/lib/rustlib/uninstall.sh; \
	else \
		echo "WARNING: Rust toolchain not found or uninstall script is missing."; \
	fi
endif

# Strip -Werror from the modules' build files so that warnings produced by
# newer compilers do not break the build. Only runs when DISABLE_WERRORS=yes.
handle-werrors: get_source
ifeq ($(DISABLE_WERRORS),yes)
	@echo "Disabling -Werror for all modules"
	@for dir in $(SUBDIRS); do \
		echo "Processing $$dir"; \
		find $$dir/src -type f \
			\( -name "Makefile" \
			-o -name "*.mk" \
			-o -name "CMakeLists.txt" \) \
			-exec sed -i 's/-Werror//g' {} +; \
	done
endif

.PHONY: all get_source prepare_source clean distclean pristine install $(SUBDIRS) setup_environment clean_environment install-rust uninstall-rust handle-werrors
diff --git a/examples/redis-unstable/modules/common.mk b/examples/redis-unstable/modules/common.mk
new file mode 100644
index 0000000..7e6a975
--- /dev/null
+++ b/examples/redis-unstable/modules/common.mk
@@ -0,0 +1,51 @@
PREFIX ?= /usr/local
INSTALL_DIR ?= $(DESTDIR)$(PREFIX)/lib/redis/modules
INSTALL ?= install

# This logic *partially* follows the current module build system. It is a bit awkward and
# should be changed if/when the modules' build process is refactored.

# Map `uname -m` output to the architecture names the modules' build system
# uses in its output directory layout.
ARCH_MAP_x86_64 := x64
ARCH_MAP_i386 := x86
ARCH_MAP_i686 := x86
ARCH_MAP_aarch64 := arm64v8
ARCH_MAP_arm64 := arm64v8

OS := $(shell uname -s | tr '[:upper:]' '[:lower:]')
ARCH := $(ARCH_MAP_$(shell uname -m))
ifeq ($(ARCH),)
  $(error Unrecognized CPU architecture $(shell uname -m))
endif

# e.g. "linux-x64-release" — the per-variant output directory name.
FULL_VARIANT := $(OS)-$(ARCH)-release

# Common rules for all modules, based on per-module configuration.
# The including Makefile must define SRC_DIR, MODULE_VERSION, MODULE_REPO
# and TARGET_MODULE before including this file.

all: $(TARGET_MODULE)

# Build the module inside its own source tree, then copy the resulting .so
# next to this Makefile for convenience.
$(TARGET_MODULE): get_source
	$(MAKE) -C $(SRC_DIR)
	cp $@ ./

get_source: $(SRC_DIR)/.prepared

# Stamp file: clone the pinned module version exactly once.
$(SRC_DIR)/.prepared:
	mkdir -p $(SRC_DIR)
	git clone --recursive --depth 1 --branch $(MODULE_VERSION) $(MODULE_REPO) $(SRC_DIR)
	touch $@

clean:
	-$(MAKE) -C $(SRC_DIR) clean
	-rm -f ./*.so

distclean: clean
	-$(MAKE) -C $(SRC_DIR) distclean

pristine:
	-rm -rf $(SRC_DIR)

install: $(TARGET_MODULE)
	mkdir -p $(INSTALL_DIR)
	$(INSTALL) -m 0755 -D $(TARGET_MODULE) $(INSTALL_DIR)

.PHONY: all get_source clean distclean pristine install
diff --git a/examples/redis-unstable/modules/redisbloom/Makefile b/examples/redis-unstable/modules/redisbloom/Makefile
new file mode 100644
index 0000000..32bab9a
--- /dev/null
+++ b/examples/redis-unstable/modules/redisbloom/Makefile
@@ -0,0 +1,6 @@
# RedisBloom module: pinned version, cloned and built via ../common.mk.
SRC_DIR = src
MODULE_VERSION = v8.5.90
MODULE_REPO = https://github.com/redisbloom/redisbloom
# NOTE: recursive assignment (=) is required here — FULL_VARIANT is only
# defined later, inside the included common.mk.
TARGET_MODULE = $(SRC_DIR)/bin/$(FULL_VARIANT)/redisbloom.so

include ../common.mk
diff --git a/examples/redis-unstable/modules/redisearch/Makefile b/examples/redis-unstable/modules/redisearch/Makefile
new file mode 100644
index 0000000..1672d74
--- /dev/null
+++ b/examples/redis-unstable/modules/redisearch/Makefile
@@ -0,0 +1,7 @@
# RediSearch module: pinned version, cloned and built via ../common.mk.
SRC_DIR = src
MODULE_VERSION = v8.5.90
MODULE_REPO = https://github.com/redisearch/redisearch
# NOTE: recursive assignment (=) is required here — FULL_VARIANT is only
# defined later, inside the included common.mk.
TARGET_MODULE = $(SRC_DIR)/bin/$(FULL_VARIANT)/search-community/redisearch.so

include ../common.mk
7
diff --git a/examples/redis-unstable/modules/redisjson/Makefile b/examples/redis-unstable/modules/redisjson/Makefile
new file mode 100644
index 0000000..a3c49c5
--- /dev/null
+++ b/examples/redis-unstable/modules/redisjson/Makefile
@@ -0,0 +1,11 @@
# RedisJSON module: pinned version, cloned and built via ../common.mk.
SRC_DIR = src
MODULE_VERSION = v8.5.90
MODULE_REPO = https://github.com/redisjson/redisjson
# NOTE: recursive assignment (=) is required here — FULL_VARIANT is only
# defined later, inside the included common.mk.
TARGET_MODULE = $(SRC_DIR)/bin/$(FULL_VARIANT)/rejson.so

include ../common.mk

# Pre-fetch the Rust crate dependencies once. The stamp must be touched,
# otherwise `cargo fetch` re-runs on every make invocation; it also depends
# on .prepared so the source tree is cloned first (parallel-safe).
$(SRC_DIR)/.cargo_fetched: $(SRC_DIR)/.prepared
	cd $(SRC_DIR) && cargo fetch
	touch $@

get_source: $(SRC_DIR)/.cargo_fetched
diff --git a/examples/redis-unstable/modules/redistimeseries/Makefile b/examples/redis-unstable/modules/redistimeseries/Makefile
new file mode 100644
index 0000000..fac3c3c
--- /dev/null
+++ b/examples/redis-unstable/modules/redistimeseries/Makefile
@@ -0,0 +1,6 @@
# RedisTimeSeries module: pinned version, cloned and built via ../common.mk.
SRC_DIR = src
MODULE_VERSION = v8.5.90
MODULE_REPO = https://github.com/redistimeseries/redistimeseries
# NOTE: recursive assignment (=) is required here — FULL_VARIANT is only
# defined later, inside the included common.mk.
TARGET_MODULE = $(SRC_DIR)/bin/$(FULL_VARIANT)/redistimeseries.so

include ../common.mk
diff --git a/examples/redis-unstable/modules/vector-sets/.gitignore b/examples/redis-unstable/modules/vector-sets/.gitignore
new file mode 100644
index 0000000..c72b1b8
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/.gitignore
@@ -0,0 +1,11 @@
# Python bytecode caches (test scripts)
__pycache__
# local scratch directory
misc
# build artifacts (module objects and shared libraries)
*.so
*.xo
*.o
.DS_Store
# w2v example binary and its dataset
w2v
word2vec.bin
TODO
*.txt
*.rdb
diff --git a/examples/redis-unstable/modules/vector-sets/Makefile b/examples/redis-unstable/modules/vector-sets/Makefile
new file mode 100644
index 0000000..f8c05c9
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/Makefile
@@ -0,0 +1,87 @@
# Compiler settings
CC = cc

# Optional sanitizer build: invoke as `make SANITIZER=address|undefined|thread`.
ifdef SANITIZER
ifeq ($(SANITIZER),address)
    SAN=-fsanitize=address
else
ifeq ($(SANITIZER),undefined)
    SAN=-fsanitize=undefined
else
ifeq ($(SANITIZER),thread)
    SAN=-fsanitize=thread
else
    $(error "unknown sanitizer=${SANITIZER}")
endif
endif
endif
endif

CFLAGS = -O2 -Wall -Wextra -g $(SAN) -std=c11
LDFLAGS = -lm $(SAN)

# Detect OS
uname_S := $(shell sh -c 'uname -s 2>/dev/null || echo not')
uname_M := $(shell sh -c 'uname -m 2>/dev/null || echo not')

# Shared library compile flags for linux / osx
ifeq ($(uname_S),Linux)
    SHOBJ_CFLAGS ?= -W -Wall -fno-common -g -ggdb -std=c11 -O2
    SHOBJ_LDFLAGS ?= -shared
# ARM (32- and 64-bit) needs libatomic linked in explicitly.
ifneq (,$(findstring armv,$(uname_M)))
    SHOBJ_LDFLAGS += -latomic
endif
ifneq (,$(findstring aarch64,$(uname_M)))
    SHOBJ_LDFLAGS += -latomic
endif
else
    SHOBJ_CFLAGS ?= -W -Wall -dynamic -fno-common -g -ggdb -std=c11 -O3
    SHOBJ_LDFLAGS ?= -bundle -undefined dynamic_lookup
endif

# OS X 11.x doesn't have /usr/lib/libSystem.dylib and needs an explicit setting.
ifeq ($(uname_S),Darwin)
ifeq ("$(wildcard /usr/lib/libSystem.dylib)","")
LIBS = -L /Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib -lsystem
endif
endif

# Example sources / objects
SRCS = hnsw.c w2v.c vset_config.c
OBJS = $(SRCS:.c=.o)

TARGET = w2v
MODULE = vset.so

# Default target: build both the shared module and the w2v example binary.
all: $(TARGET) $(MODULE)

# Position-independent objects for the shared module.
# (Pattern rule replaces the legacy `.c.xo` suffix rule.)
%.xo: %.c
	$(CC) -I. $(CFLAGS) $(SHOBJ_CFLAGS) -fPIC -c $< -o $@

# vset.xo must be rebuilt when the Redis module header or expr.c change
# (expr.c is listed as a prerequisite because it is pulled into vset.c).
vset.xo: ../../src/redismodule.h expr.c

vset.so: vset.xo hnsw.xo vset_config.xo
	$(CC) -o $@ $^ $(SHOBJ_LDFLAGS) $(LIBS) $(SAN) -lc

# Example linking rule
$(TARGET): $(OBJS)
	$(CC) $(OBJS) $(LDFLAGS) -o $(TARGET)

# Compilation rule for object files
%.o: %.c
	$(CC) $(CFLAGS) -c $< -o $@

# Standalone test binary for the expression engine.
expr-test: expr.c fastjson.c fastjson_test.c
	$(CC) $(CFLAGS) expr.c -o expr-test -DTEST_MAIN -lm

# Clean rule (also removes expr-test, which was previously left behind)
clean:
	rm -f $(TARGET) $(OBJS) expr-test *.xo *.so

# Declare phony targets
.PHONY: all clean
diff --git a/examples/redis-unstable/modules/vector-sets/README.md b/examples/redis-unstable/modules/vector-sets/README.md
new file mode 100644
index 0000000..be87ff3
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/README.md
@@ -0,0 +1,733 @@
1**IMPORTANT:** *Please note that this is a merged module, it's part of the Redis binary now, and you don't need to build it and load it into Redis. Compiling Redis version 8 or greater will result in having the Vector Sets commands available. However, you could compile this module as a shared library in order to load it in older versions of Redis.*
2
3This module implements Vector Sets for Redis, a new Redis data type similar
4to Sorted Sets but having string elements associated to a vector instead of
5a score. The fundamental goal of Vector Sets is to make possible adding items,
6and later get a subset of the added items that are the most similar to a
7specified vector (often a learned embedding), or the most similar to the vector
8of an element that is already part of the Vector Set.
9
10Moreover, Vector sets implement optional filtered search capabilities: it is possible to associate attributes to all or to a subset of elements in the set, and then, using the `FILTER` option of the `VSIM` command, to ask for items similar to a given vector but also passing a filter specified as a simple mathematical expression (Like `".year > 1950"` or similar). This means that **you can have vector similarity and scalar filters at the same time**.
11
12## Installation
13
14**WARNING:** If you are running **Redis 8.0 RC1 or greater** you don't need to install anything, just compile Redis, and the Vector Sets commands will be part of the default install. Otherwise to test Vector Sets with older Redis versions follow the following instructions.
15
16Build with:
17
18 make
19
20Then load the module with the following command line, or by inserting the needed directives in the `redis.conf` file.
21
22 ./redis-server --loadmodule vset.so
23
24To run tests, I suggest using this:
25
26 ./redis-server --save "" --enable-debug-command yes
27
28Then execute the tests with:
29
30 ./test.py
31
32## Reference of available commands
33
34**VADD: add items into a vector set**
35
36 VADD key [REDUCE dim] FP32|VALUES vector element [CAS] [NOQUANT | Q8 | BIN]
37 [EF build-exploration-factor] [SETATTR <attributes>] [M <numlinks>]
38
39Add a new element into the vector set specified by the key.
40The vector can be provided as FP32 blob of values, or as floating point
41numbers as strings, prefixed by the number of elements (3 in the example):
42
43 VADD mykey VALUES 3 0.1 1.2 0.5 my-element
44
45Meaning of the options:
46
47`REDUCE` implements random projection, in order to reduce the
48dimensionality of the vector. The projection matrix is saved and reloaded
49along with the vector set. **Please note that** the `REDUCE` option must be passed immediately before the vector, like in `REDUCE 50 VALUES ...`.
50
51`CAS` performs the operation partially using threads, in a
52check-and-set style. The neighbor candidates collection, which is slow, is
53performed in the background, while the command is executed in the main thread.
54
55`NOQUANT` forces the vector to be created (in the first VADD call to a given key) without integer 8 quantization, which is otherwise the default.
56
57`BIN` forces the vector to use binary quantization instead of int8. This is much faster and uses less memory, but has impacts on the recall quality.
58
59`Q8` forces the vector to use signed 8 bit quantization. This is the default, and the option only exists in order to make sure to check at insertion time if the vector set is of the same format.
60
61`EF` plays a role in the effort made to find good candidates when connecting the new node to the existing HNSW graph. The default is 200. Using a larger value, may help to have a better recall. To improve the recall it is also possible to increase `EF` during `VSIM` searches.
62
63`SETATTR` associates attributes to the newly created entry or update the entry attributes (if it already exists). It is the same as calling the `VSETATTR` attribute separately, so please check the documentation of that command in the filtered search section of this documentation.
64
65`M` defaults to 16 and is the HNSW famous `M` parameters. It is the maximum number of connections that each node of the graph have with other nodes: more connections mean more memory, but a better ability to explore the graph. Nodes at layer zero (every node exists at least at layer zero) have `M*2` connections, while the other layers only have `M` connections. This means that, for instance, an `M` of 64 will use at least 1024 bytes of memory for each node! That is, `64 links * 2 times * 8 bytes pointers`, and even more, since on average each node has something like 1.33 layers (but the other layers have just `M` connections, instead of `M*2`). If you don't have a recall quality problem, the default is fine, and uses a limited amount of memory.
66
67**VSIM: return elements by vector similarity**
68
69 VSIM key [ELE|FP32|VALUES] <vector or element> [WITHSCORES] [WITHATTRIBS] [COUNT num] [EPSILON delta] [EF search-exploration-factor] [FILTER expression] [FILTER-EF max-filtering-effort] [TRUTH] [NOTHREAD]
70
71The command returns similar vectors, for simplicity (and verbosity) in the following example, instead of providing a vector using FP32 or VALUES (like in `VADD`), we will ask for elements having a vector similar to a given element already in the sorted set:
72
73 > VSIM word_embeddings ELE apple
74 1) "apple"
75 2) "apples"
76 3) "pear"
77 4) "fruit"
78 5) "berry"
79 6) "pears"
80 7) "strawberry"
81 8) "peach"
82 9) "potato"
83 10) "grape"
84
85It is possible to specify a `COUNT` and also to get the similarity score (from 1 to 0, where 1 is identical, 0 is opposite vector) between the query and the returned items.
86
87 > VSIM word_embeddings ELE apple WITHSCORES COUNT 3
88 1) "apple"
89 2) "0.9998867657923256"
90 3) "apples"
91 4) "0.8598527610301971"
92 5) "pear"
93 6) "0.8226882219314575"
94
95It is also possible to specify an `EPSILON`, that is a floating point number between 0 and 1 in order to only return elements that have a distance that is no further than the specified one. In vector sets, the returned elements have a similarity score (when compared to the query vector) that is between 1 and 0, where 1 means identical, 0 opposite vectors. If for instance the `EPSILON` option is specified with an argument of 0.2, it means that we will get only elements that have a similarity of 0.8 or better (a distance < 0.2). This is useful when a large `COUNT` is specified, yet we don't want elements that are too far away from our query vector.
96
97The `EF` argument is the exploration factor: the higher it is, the slower the command becomes, but the better the index is explored to find nodes that are near to our query. Sensible values are from 50 to 1000.
98
99The `TRUTH` option forces the command to perform a linear scan of all the entries inside the set, without using the graph search inside the HNSW, so it returns the best matching elements (the perfect result set) that can be used in order to easily calculate the recall. Of course the linear scan is `O(N)`, so it is much slower than the `log(N)` (considering a small `COUNT`) provided by the HNSW index.
100
101The `NOTHREAD` option forces the command to execute the search on the data structure in the main thread. Normally `VSIM` spawns a thread instead. This may be useful for benchmarking purposes, or when we work with extremely small vector sets and don't want to pay the cost of spawning a thread. It is possible that in the future this option will be automatically used by Redis when we detect small vector sets. Note that this option blocks the server for all the time needed to complete the command, so it is a source of potential latency issues: if you are in doubt, never use it.
102
103The `WITHSCORES` option returns, for each returned element, a floating point number representing how near the element is from the query, as a similarity between 0 and 1, where 0 means the vectors are opposite, and 1 means they are pointing exactly in the same direction (maximum similarity).
104
105The `WITHATTRIBS` option returns, for each element, the JSON attribute associated with the element, or NULL for the elements missing an attribute.
106
107For `FILTER` and `FILTER-EF` options, please check the filtered search section of this documentation.
108
109Note that when `WITHSCORES` and `WITHATTRIBS` are provided at the same time, the RESP2 reply guarantees that the returned elements are always in the sequence *ele*,*score*,*attribs*, while RESP3 replies will be in the form *ele > score|attrib* when just one is provided, or *ele -> [score,attrib]* when both are provided, that is, when both options are used and RESP3 is used the score and attribute will be a two-items array associated to the element key.
110
111**VDIM: return the dimension of the vectors inside the vector set**
112
113 VDIM keyname
114
115Example:
116
117 > VDIM word_embeddings
118 (integer) 300
119
120Note that in the case of vectors that were populated using the `REDUCE`
121option, for random projection, the vector set will report the size of
122the projected (reduced) dimension. Yet the user should perform all the
123queries using full-size vectors.
124
125**VCARD: return the number of elements in a vector set**
126
127 VCARD key
128
129Example:
130
131 > VCARD word_embeddings
132 (integer) 3000000
133
134
135**VREM: remove elements from vector set**
136
137 VREM key element
138
139Example:
140
141 > VADD vset VALUES 3 1 0 1 bar
142 (integer) 1
143 > VREM vset bar
144 (integer) 1
145 > VREM vset bar
146 (integer) 0
147
148VREM does not perform tombstone / logical deletion, but will actually reclaim
149the memory from the vector set, so it is safe to add and remove elements
150in a vector set in the context of long running applications that continuously
151update the same index.
152
153**VEMB: return the approximated vector of an element**
154
155 VEMB key element
156
157Example:
158
159 > VEMB word_embeddings SQL
160 1) "0.18208661675453186"
161 2) "0.08535309880971909"
162 3) "0.1365649551153183"
163 4) "-0.16501599550247192"
164 5) "0.14225517213344574"
165 ... 295 more elements ...
166
167Because vector sets perform insertion time normalization and optional
168quantization, the returned vector could be approximated. `VEMB` will take
169care to de-quantize and de-normalize the vector before returning it.
170
171It is possible to ask VEMB to return raw data, that is, the internal representation used by the vector: fp32, int8, or a bitmap for binary quantization. This behavior is triggered by the `RAW` option of VEMB:
172
173 VEMB word_embedding apple RAW
174
175In this case the return value of the command is an array of three or more elements:
1761. The name of the quantization used, that is one of: "fp32", "bin", "q8".
1772. A string blob containing the raw data, 4 bytes fp32 floats for fp32, a bitmap for binary quants, or int8 bytes array for q8 quants.
1783. A float representing the l2 of the vector before normalization. You need to multiply by this vector if you want to de-normalize the value for any reason.
179
180For q8 quantization, an additional element is also returned: the quantization
181range, so the integers from -127 to 127 represent (normalized) components
182in the range `-range`, `+range`.
183
184**VISMEMBER: test if a given element already exists**
185
186This command will return 1 (or true) if the specified element is already in the vector set, otherwise 0 (or false) is returned.
187
188 VISMEMBER key element
189
190As with other existence check Redis commands, if the key does not exist it is considered as if it was empty, thus the element is reported as non existing.
191
192**VRANGE: return elements in a lexicographical range**
193
194 VRANGE key start end count
195
196The `VRANGE` command has many different use cases, but its main goal is to
197provide a stateless iterator for the elements inside a vector set: that is,
198it allows to retrieve all the elements inside a vector set in small amounts
199for each call, without an explicit cursor, and with guarantees about what
200the user will miss in case the vector set is changing (elements added and/or
201removed) during the iteration.
202
203The command usage is straightforward:
204
205```
206> VRANGE word_embeddings_int8 [Redis + 10
207 1) "Redis"
208 2) "Rediscover"
209 3) "Rediscover_Ashland"
210 4) "Rediscover_Northern_Ireland"
211 5) "Rediscovered"
212 6) "Rediscovered_Bookshop"
213 7) "Rediscovering"
214 8) "Rediscovering_God"
215 9) "Rediscovering_Lost"
21610) "Rediscovers"
217```
218
219The above command returns 10 (or less, if less are available in the specified range) elements from "Redis" (inclusive) to the maximum possible element. The comparison is performed byte by byte, as `memcmp()` would do, in this way the elements have a total order. The start and end range can be either a string, prefixed by `[` or `(` (the prefix is mandatory) to tell the command if the range is inclusive or exclusive, or can be the special symbols `-` and `+` that mean the minimum and maximum elements.
220
221So for instance if I want to iterate all the elements, ten elements for each call, I'll proceed as such:
222
223```
224> VRANGE mykey - + 10
225 1) "a"
226 2) "a-league"
227 3) "a."
228 4) "a.d."
229 5) "a.k.a."
230 6) "a.m."
231 7) "a1"
232 8) "a2"
233 9) "a3"
23410) "a7"
235```
236
237This will give me the first 10 elements. Then I want the next ten elements
238starting from the last element in the previous result, but *excluding* it,
239so the next range will use the `(` prefix with the last element of the
240previous call, that was `"a7"`:
241
242```
243> VRANGE mykey (a7 + 10
244 1) "a930913"
245 2) "aa"
246 3) "aaa"
247 4) "aaron"
248 5) "ab"
249 6) "aba"
250 7) "abandon"
251 8) "abandoned"
252 9) "abandoning"
25310) "abandonment"
254```
255
256And so forth.
257
258The command count is mandatory, however a negative count means to return all the elements in the set. This means that `VRANGE mykey - + -1` will return every element. Of course, iterating like that means that it is possible to block the server for a long time.
259
260The command time complexity is O(1) to seek to the element (considering the element would be of reasonable size), since we use a Radix Tree in the underlying implementation, plus the time to yield "M" elements. So if M is small, each call is just executed in constant time. However the iteration of a total set (via multiple calls) of N elements is O(N). Basically: this command, with a small count, will never produce latency issues in the Redis server.
261
262In case the elements are changing continuously as the set is iterated, the guarantees are very simple: each range will produce exactly the elements that were present in the range in the moment the `VRANGE` command was executed. In other words, an iteration performed in this way is *guaranteed* to return all the elements that stayed within the vector set from the start to the end of the iteration. Elements removed or added in the meantime may be returned or not depending on the moment they were added or removed.
263
264**VLINKS: introspection command that shows neighbors for a node**
265
266 VLINKS key element [WITHSCORES]
267
268The command reports the neighbors for each level.
269
270**VINFO: introspection command that shows info about a vector set**
271
272 VINFO key
273
274Example:
275
276 > VINFO word_embeddings
277 1) quant-type
278 2) int8
279 3) vector-dim
280 4) (integer) 300
281 5) size
282 6) (integer) 3000000
283 7) max-level
284 8) (integer) 12
285 9) vset-uid
286 10) (integer) 1
287 11) hnsw-max-node-uid
288 12) (integer) 3000000
289
290**VSETATTR: associate or remove the JSON attributes of elements**
291
292 VSETATTR key element "{... json ...}"
293
294Each element of a vector set can be optionally associated with a JSON string
295in order to use the `FILTER` option of `VSIM` to filter elements by scalars
296(see the filtered search section for more information). This command can set,
297update (if already set) or delete (if you set to an empty string) the
298associated JSON attributes of an element.
299
300The command returns 0 if the element or the key don't exist, without
301raising an error, otherwise 1 is returned, and the element attributes
302are set or updated.
303
304**VGETATTR: retrieve the JSON attributes of elements**
305
306 VGETATTR key element
307
308The command returns the JSON attribute associated with an element, or
309null if there is no element associated, or no element at all, or no key.
310
311**VRANDMEMBER: return random members from a vector set**
312
313 VRANDMEMBER key [count]
314
315Return one or more random elements from a vector set.
316
317The semantics of this command are similar to Redis's native SRANDMEMBER command:
318
319- When called without count, returns a single random element from the set, as a single string (no array reply).
320- When called with a positive count, returns up to count distinct random elements (no duplicates).
321- When called with a negative count, returns count random elements, potentially with duplicates.
322- If the count value is larger than the set size (and positive), only the entire set is returned.
323
324If the key doesn't exist, returns a Null reply if count is not given, or an empty array if a count is provided.
325
326Examples:
327
328 > VADD vset VALUES 3 1 0 0 elem1
329 (integer) 1
330 > VADD vset VALUES 3 0 1 0 elem2
331 (integer) 1
332 > VADD vset VALUES 3 0 0 1 elem3
333 (integer) 1
334
335 # Return a single random element
336 > VRANDMEMBER vset
337 "elem2"
338
339 # Return 2 distinct random elements
340 > VRANDMEMBER vset 2
341 1) "elem1"
342 2) "elem3"
343
344 # Return 3 random elements with possible duplicates
345 > VRANDMEMBER vset -3
346 1) "elem2"
347 2) "elem2"
348 3) "elem1"
349
350 # Return more elements than in the set (returns all elements)
351 > VRANDMEMBER vset 10
352 1) "elem1"
353 2) "elem2"
354 3) "elem3"
355
356 # When key doesn't exist
357 > VRANDMEMBER nonexistent
358 (nil)
359 > VRANDMEMBER nonexistent 3
360 (empty array)
361
362This command is particularly useful for:
363
3641. Selecting random samples from a vector set for testing or training.
3652. Performance testing by retrieving random elements for subsequent similarity searches.
366
367When the user asks for unique elements (positive count), the implementation optimizes for two scenarios:
368- For small sample sizes (less than 20% of the set size), it uses a dictionary to avoid duplicates, and performs a real random walk inside the graph.
369- For large sample sizes (more than 20% of the set size), it starts from a random node and sequentially traverses the internal list, providing faster performances but not really "random" elements.
370
371The command has `O(N)` worst-case time complexity when requesting many unique elements (it uses linear scanning), or `O(M*log(N))` complexity when the user asks for `M` random elements in a vector set of `N` elements, with `M` much smaller than `N`.
372
373# Filtered search
374
375Each element of the vector set can be associated with a set of attributes specified as a JSON blob:
376
377 > VADD vset VALUES 3 1 1 1 a SETATTR '{"year": 1950}'
378 (integer) 1
379 > VADD vset VALUES 3 -1 -1 -1 b SETATTR '{"year": 1951}'
380 (integer) 1
381
382Specifying an attribute with the `SETATTR` option of `VADD` is exactly equivalent to adding an element and then setting (or updating, if already set) the attributes JSON string. Also the symmetrical `VGETATTR` command returns the attribute associated to a given element.
383
384 > VADD vset VALUES 3 0 1 0 c
385 (integer) 1
386 > VSETATTR vset c '{"year": 1952}'
387 (integer) 1
388 > VGETATTR vset c
389 "{\"year\": 1952}"
390
391At this point, I may use the FILTER option of VSIM to only ask for the subset of elements that are verified by my expression:
392
393 > VSIM vset VALUES 3 0 0 0 FILTER '.year > 1950'
394 1) "c"
395 2) "b"
396
397The items will be returned again in order of similarity (most similar first), but only the items with the year field matching the expression are returned.
398
399The expressions are similar to what you would write inside the `if` statement of JavaScript or other familiar programming languages: you can use `and`, `or`, the obvious math operators like `+`, `-`, `/`, `>=`, `<`, ... and so forth (see the expressions section for more info). The selectors of the JSON object attributes start with a dot followed by the name of the key inside the JSON objects.
400
401Elements with invalid JSON or not having a given specified field **are considered as not matching** the expression, but will not generate any error at runtime.
402
403## FILTER expressions capabilities
404
405FILTER expressions allow you to perform complex filtering on vector similarity results using a JavaScript-like syntax. The expression is evaluated against each element's JSON attributes, with only elements that satisfy the expression being included in the results.
406
407### Expression Syntax
408
409Expressions support the following operators and capabilities:
410
4111. **Arithmetic operators**: `+`, `-`, `*`, `/`, `%` (modulo), `**` (exponentiation)
4122. **Comparison operators**: `>`, `>=`, `<`, `<=`, `==`, `!=`
4133. **Logical operators**: `and`/`&&`, `or`/`||`, `!`/`not`
4144. **Containment operator**: `in`
4155. **Parentheses** for grouping: `(...)`
416
417### Selector Notation
418
419Attributes are accessed using dot notation:
420
421- `.year` references the "year" attribute
422- `.movie.year` would **NOT** reference the "year" field inside a "movie" object, only keys that are at the first level of the JSON object are accessible.
423
424### JSON and expressions data types
425
426Expressions can work with:
427
428- Numbers (double precision floats)
429- Strings (enclosed in single or double quotes)
430- Booleans (no native type: they are represented as 1 for true, 0 for false)
431- Arrays (for use with the `in` operator: `value in [1, 2, 3]`)
432
433JSON attributes are converted in this way:
434
435- Numbers will be converted to numbers.
436- Strings to strings.
437- Booleans to 0 or 1 number.
438- Arrays to tuples (for "in" operator), but only if composed of just numbers and strings.
439
440Any other type is ignored, and accessing it will make the expression evaluate to false.
441
442### The IN operator
443
444The `IN` operator works in two ways, it can test for membership in an array, like in:
445
446 5 in [1, 2, 3]
447 "foo" in [1, "foo", "bar"]
448
449But can also check for substrings, in case the A and B operators are both strings.
450
451 "foo" in "barfoobar" # Will evaluate to true
452 "zap" in "foobar" # Will evaluate to false
453
454### Examples
455
456```
457# Find items from the 1980s
458VSIM movies VALUES 3 0.5 0.8 0.2 FILTER '.year >= 1980 and .year < 1990'
459
460# Find action movies with high ratings
461VSIM movies VALUES 3 0.5 0.8 0.2 FILTER '.genre == "action" and .rating > 8.0'
462
463# Find movies directed by either Spielberg or Nolan
464VSIM movies VALUES 3 0.5 0.8 0.2 FILTER '.director in ["Spielberg", "Nolan"]'
465
466# Complex condition with numerical operations
467VSIM movies VALUES 3 0.5 0.8 0.2 FILTER '(.year - 2000) ** 2 < 100 and .rating / 2 > 4'
468```
469
470### Error Handling
471
472Elements with any of the following conditions are considered not matching:
473- Missing the queried JSON attribute
474- Having invalid JSON in their attributes
475- Having a JSON value that cannot be converted to the expected type
476
477This behavior allows you to safely filter on optional attributes without generating errors.
478
479### FILTER effort
480
481The `FILTER-EF` option controls the maximum effort spent when filtering vector search results.
482
483When performing vector similarity search with filtering, Vector Sets perform the standard similarity search as they apply the filter expression to each node. Since many results might be filtered out, Vector Sets may need to examine a lot more candidates than the requested `COUNT` to ensure sufficient matching results are returned. Actually, if the elements matching the filter are very rare, or if there are fewer matching elements than the specified count, this would trigger a full scan of the HNSW graph.
484
485For this reason, by default, the maximum effort is limited to a reasonable amount of nodes explored.
486
487### Modifying the FILTER effort
488
4891. By default, Vector Sets will explore up to `COUNT * 100` candidates to find matching results.
4902. You can control this exploration with the `FILTER-EF` parameter.
4913. A higher `FILTER-EF` value increases the chances of finding all relevant matches at the cost of increased processing time.
4924. A `FILTER-EF` of zero will explore as many nodes as needed in order to actually return the number of elements specified by `COUNT`.
4935. Even when a high `FILTER-EF` value is specified **the implementation will do a lot less work** if the elements passing the filter are very common, because of the early stop conditions of the HNSW implementation (once the specified amount of elements is reached and the quality check of the other candidates trigger an early stop).
494
495```
496VSIM key [ELE|FP32|VALUES] <vector or element> COUNT 10 FILTER '.year > 2000' FILTER-EF 500
497```
498
499In this example, Vector Sets will examine up to 500 potential nodes. Of course if count is reached before exploring 500 nodes, and the quality checks show that it is not possible to make progresses on similarity, the search is ended sooner.
500
501### Performance Considerations
502
503- If you have highly selective filters (few items match), use a higher `FILTER-EF`, or just design your application in order to handle a result set that is smaller than the requested count. Note that anyway the additional elements may be too distant from the query vector.
504- For less selective filters, the default should be sufficient.
505- Very selective filters with low `FILTER-EF` values may return fewer items than requested.
506- Extremely high values may impact performance without significantly improving results.
507
508The optimal `FILTER-EF` value depends on:
5091. The selectivity of your filter.
5102. The distribution of your data.
5113. The required recall quality.
512
513A good practice is to start with the default and increase if needed when you observe fewer results than expected.
514
515### Testing a large-ish data set
516
517To really see how things work at scale, you can [download](https://antirez.com/word2vec_with_attribs.rdb) the following dataset:
518
519 wget https://antirez.com/word2vec_with_attribs.rdb
520
521It contains the 3 million words in Word2Vec having as attribute a JSON with just the length of the word. Because of the length distribution of words in large amounts of texts, where longer words become less and less common, this is ideal to check how filtering behaves with a filter verifying as true with less and less elements in a vector set.
522
523For instance:
524
525 > VSIM word_embeddings_bin ele "pasta" FILTER ".len == 6"
526 1) "pastas"
527 2) "rotini"
528 3) "gnocci"
529 4) "panino"
530 5) "salads"
531 6) "breads"
532 7) "salame"
533 8) "sauces"
534 9) "cheese"
535 10) "fritti"
536
537This will easily retrieve the desired amount of items (`COUNT` is 10 by default) since there are many items of length 6. However:
538
539 > VSIM word_embeddings_bin ele "pasta" FILTER ".len == 33"
540 1) "skinless_boneless_chicken_breasts"
541 2) "boneless_skinless_chicken_breasts"
542 3) "Boneless_skinless_chicken_breasts"
543
544This time even if we asked for 10 items, we only get 3, since the default filter effort will be `10*100 = 1000`. We can tune this giving the effort in an explicit way, with the risk of our query being slower, of course:
545
546 > VSIM word_embeddings_bin ele "pasta" FILTER ".len == 33" FILTER-EF 10000
547 1) "skinless_boneless_chicken_breasts"
548 2) "boneless_skinless_chicken_breasts"
549 3) "Boneless_skinless_chicken_breasts"
550 4) "mozzarella_feta_provolone_cheddar"
551 5) "Greatfood.com_R_www.greatfood.com"
552 6) "Pepperidge_Farm_Goldfish_crackers"
553 7) "Prosecuted_Mobsters_Rebuilt_Dying"
554 8) "Crispy_Snacker_Sandwiches_Popcorn"
555 9) "risultati_delle_partite_disputate"
556 10) "Peppermint_Mocha_Twist_Gingersnap"
557
558This time we get all the ten items, even if the last one will be quite far from our query vector. We encourage to experiment with this test dataset in order to understand better the dynamics of the implementation and the natural tradeoffs of filtered search.
559
560**Keep in mind** that by default, Redis Vector Sets will try to avoid a likely very useless huge scan of the HNSW graph, and will be more happy to return few or no elements at all, since this is almost always what the user actually wants in the context of retrieving *similar* items to the query.
561
562# Single Instance Scalability and Latency
563
564Vector Sets implement a threading model that allows Redis to handle many concurrent requests: by default `VSIM` is always threaded, and `VADD` is not (but can be partially threaded using the `CAS` option). This section explains how the threading and locking mechanisms work, and what to expect in terms of performance.
565
566## Threading Model
567
568- The `VSIM` command runs in a separate thread by default, allowing Redis to continue serving other commands.
569- A maximum of 32 threads can run concurrently (defined by `HNSW_MAX_THREADS`).
570- When this limit is reached, additional `VSIM` requests are queued - Redis remains responsive, no latency event is generated.
571- The `VADD` command with the `CAS` option also leverages threading for the computation-heavy candidate search phase, but the insertion itself is performed in the main thread. `VADD` always runs in a sub-millisecond time, so this is not a source of latency, but having too many hundreds of writes per second can be challenging to handle with a single instance. Please, look at the next section about multiple instances scalability.
572- Commands run within Lua scripts, MULTI/EXEC blocks, or from replication are executed in the main thread to ensure consistency.
573
574```
575> VSIM vset VALUES 3 1 1 1 FILTER '.year > 2000' # This runs in a thread.
576> VADD vset VALUES 3 1 1 1 element CAS # Candidate search runs in a thread.
577```
578
579## Locking Mechanism
580
581Vector Sets use a read/write locking mechanism to coordinate access:
582
583- Reads (`VSIM`, `VEMB`, etc.) acquire a read lock, allowing multiple concurrent reads.
584- Writes (`VADD`, `VREM`, etc.) acquire a write lock, temporarily blocking all reads.
585- When a write lock is requested while reads are in progress, the write operation waits for all reads to complete.
586- Once a write lock is granted, all reads are blocked until the write completes.
587- Each thread has a dedicated slot for tracking visited nodes during graph traversal, avoiding contention. This improves performances but limits the maximum number of concurrent threads, since each node has a memory cost proportional to the number of slots.
588
589## DEL latency
590
591Deleting a very large vector set (millions of elements) can cause latency spikes, as deletion rebuilds connections between nodes. This may change in the future.
592The deletion latency is most noticeable when using `DEL` on a key containing a large vector set or when the key expires.
593
594## Performance Characteristics
595
596- Search operations (`VSIM`) scale almost linearly with the number of CPU cores available, up to the thread limit. You can expect a Vector Set composed of millions of items associated with components of dimension 300, with the default int8 quantization, to deliver around 50k VSIM operations per second in a single host.
597- Insertion operations (`VADD`) are more computationally expensive than searches, and can't be threaded: expect much lower throughput, in the range of a few thousands inserts per second.
598- Binary quantization offers significantly faster search performance at the cost of some recall quality, while int8 quantization, the default, seems to have very small impacts on recall quality, while it significantly improves performances and space efficiency.
599- The `EF` parameter has a major impact on both search quality and performance - higher values mean better recall but slower searches.
600- Graph traversal time scales logarithmically with the number of elements, making Vector Sets efficient even with millions of vectors
601
602## Loading / Saving performances
603
604Vector Sets are able to serialize on disk the graph structure as it is in memory, so loading back the data does not need to rebuild the HNSW graph. This means that Redis can load millions of items per minute. For instance 3 million items with 300 components vectors can be loaded back into memory into around 15 seconds.
605
606# Scaling vector sets to multiple instances
607
608The fundamental way vector sets can be scaled to very large data sets
609and to many Redis instances is that a given very large set of vectors
610can be partitioned into N different Redis keys, that can also live into
611different Redis instances.
612
613For instance, I could add my elements into `key0`, `key1`, `key2`, by hashing
614the item in some way, like doing `crc32(item)%3`, effectively splitting
615the dataset into three different parts. However once I want all the vectors
616of my dataset near to a given query vector, I could simply perform the
617`VSIM` command against all the three keys, merging the results by
618score (so the commands must be called using the `WITHSCORES` option) on
619the client side: once the union of the results are ordered by the
620similarity score, the query is equivalent to having a single key `key1+2+3`
621containing all the items.
622
623There are a few interesting facts to note about this pattern:
624
6251. It is possible to have a logical sorted set that is as big as the sum of all the Redis instances we are using.
6262. Deletion operations remain simple, we can hash the key and select the key where our item belongs.
6273. However, even if I use 10 different Redis instances, I'm not going to reach 10x the **read** operations per second, compared to using a single server: for each logical query, I need to query all the instances. Yet, smaller graphs are faster to navigate, so there is some win even from the point of view of CPU usage.
6284. Insertions, so **write** queries, will be scaled linearly: I can add N items against N instances at the same time, splitting the insertion load evenly. This is very important since vector sets, being based on HNSW data structures, are slower to add items than to query similar items, by a very big factor.
6295. While it cannot always guarantee the best results, with proper timeout management this system may be considered *highly available*, since if a subset of the N instances are reachable, I'll still be able to return similar items to my query vector.
630
631Notably, this pattern can be implemented in a way that avoids paying the sum of the round trip time with all the servers: it is possible to send the queries at the same time to all the instances, so that the latency will be equal to the slowest reply out of the N server queries.
632
633# Optimizing memory usage
634
635Vector Sets, or better, HNSWs, the underlying data structure used by Vector Sets, combined with the features provided by the Vector Sets themselves (quantization, random projection, filtering, ...) form an implementation that has a non-trivial space of parameters that can be tuned. Despite the complexity of the implementation and of vector similarity problems, here is a list of simple ideas that can drive the user to pick the best settings:
636
637* 8 bit quantization (the default) is almost always a win. It reduces the memory usage of vectors by a factor of 4, yet the performance penalty in terms of recall is minimal. It also reduces insertion and search time by around 2 times or more.
638* Binary quantization is much more extreme: it makes vector sets a lot faster, but increases the recall error in a sensible way, for instance from 95% to 80% if all the parameters remain the same. Yet, the speedup is really big, and the memory usage of vectors, compared to full precision vectors, is 32 times smaller.
639* Vector memory usage is not the only contributor to the high per-entry memory usage of Vector Sets: nodes contain, on average, `M*2 + M*0.33` pointers, where M is by default 16 (but can be tuned in `VADD`, see the `M` option). Also each node has the string item and the optional JSON attributes: those should be as small as possible in order to avoid contributing more to the memory usage.
640* The `M` parameter should be increased to 32 or more only when a near perfect recall is really needed.
641* It is possible to gain space (less memory usage) sacrificing time (more CPU time) by using a low `M` (the default of 16, for instance) and a high `EF` (the effort parameter of `VSIM`) in order to scan the graph more deeply.
642* When memory usage is a serious concern, and there is the suspicion that the vectors we are storing don't contain as much information - at least for our use case - to justify the number of components they feature, random projection (the `REDUCE` option of `VADD`) could be tested to see if dimensionality reduction is possible with acceptable precision loss.
643
644## Random projection tradeoffs
645
646Sometimes learned vectors are not as information dense as we could guess, that
647is there are components having similar meanings in the space, and components
648having values that don't really represent features that matter in our use case.
649
650At the same time, certain vectors are very big, 1024 components or more. In these cases, it is possible to use the random projection feature of Redis Vector Sets in order to reduce both space (less RAM used) and time (more operations per second). The feature is accessible via the `REDUCE` option of the `VADD` command. However, keep in mind that you need to test how much reduction impacts the performances of your vectors in terms of recall and quality of the results you get back.
651
652## What is a random projection?
653
654The concept of Random Projection is relatively simple to grasp. For instance, a projection that turns a 100 components vector into a 10 components vector will perform a different linear transformation between the 100 components and each of the target 10 components. Please note that *each of the target components* will get some random amount of all the 100 original components. It is mathematically proven that this process results in a vector space where elements still have similar distances among them, but some information will get lost.
655
656## Examples of projections and loss of precision
657
658To show you a bit of a extreme case, let's take Word2Vec 3 million items and compress them from 300 to 100, 50 and 25 components vectors. Then, we check the recall compared to the ground truth against each of the vector sets produced in this way (using different `REDUCE` parameters of `VADD`). This is the result, obtained asking for the top 10 elements.
659
660```
661----------------------------------------------------------------------
662Key Average Recall % Std Dev
663----------------------------------------------------------------------
664word_embeddings_int8 95.98 12.14
665 ^ This is the same key used for ground truth, but without TRUTH option
666word_embeddings_reduced_100 40.20 20.13
667word_embeddings_reduced_50 24.42 16.89
668word_embeddings_reduced_25 14.31 9.99
669```
670
671Here the dimensionality reduction we are using is quite extreme: from 300 to 100 means that 66.6% of the original information is lost. The recall drops from 96% to 40%, down to 24% and 14% for even more extreme dimension reduction.
672
673Reducing the dimension of vectors that are already relatively small, like the above example, of 300 components, will provide only relatively small memory savings, especially because by default Vector Sets use `int8` quantization, that will use only one byte per component:
674
675```
676> MEMORY USAGE word_embeddings_int8
677(integer) 3107002888
678> MEMORY USAGE word_embeddings_reduced_100
679(integer) 2507122888
680```
681
682Of course going, for example, from 2048 component vectors to 1024 would provide a much more sensible memory saving, even with the `int8` quantization used by Vector Sets, assuming the recall loss is acceptable. Other than the memory saving, there is also the reduction in CPU time, translating to more operations per second.
683
684Another thing to note is that, with certain embedding models, binary quantization (that offers a 8x reduction of memory usage compared to 8 bit quants, and a very big speedup in computation) performs much better than reducing the dimension of vectors of the same amount via random projections:
685
686```
687word_embeddings_bin 35.48 19.78
688```
689
690Here is the same test done above: we have a 35% recall which is not too far from the 40% obtained with a random projection from 300 to 100 components. However, while the first technique reduces the size by 3 times, binary quantization reduces the size by 8 times.
691
692```
693> memory usage word_embeddings_bin
694(integer) 2327002888
695```
696
697In this specific case the key uses JSON attributes and has a graph connection overhead that is much bigger than the 300 bits each vector takes, but, as already said, for big vectors (1024 components, for instance) or for lower values of `M` (see `VADD`, the `M` parameter controls the level of connectivity, so it changes the amount of pointers used per node) the memory saving is much stronger.
698
699# Vector Sets troubleshooting and understandability
700
701## Debugging poor recall or unexpected results
702
703Vector graphs and similarity queries pose many challenges mainly due to the following three problems:
704
7051. The error due to the approximated nature of Vector Sets is hard to evaluate.
7062. The error added by the quantization often depends on the exact vector space (the embedding we are using **and** how far apart the elements we represent into such embeddings are).
7073. We live in the illusion that learned embeddings capture the best similarity possible among elements, which is obviously not always true, and highly application dependent.
708
709The only way to debug such problems, is the ability to inspect step by step what is happening inside our application, and the structure of the HNSW graph itself. To do so, we suggest to consider the following tools:
710
7111. The `TRUTH` option of the `VSIM` command is able to return the ground truth of the most similar elements, without using the HNSW graph, but doing a linear scan.
7122. The `VLINKS` command allows to explore the graph to see if the connections among nodes make sense, and to investigate why a given node may be more isolated than expected. Such command can also be used in a different way, when we want very fast "similar items" without paying the HNSW traversal time. It exploits the fact that we have a direct reference from each element in our vector set to each node in our HNSW graph.
7133. The `WITHSCORES` option, in the supported commands, returns a value that is directly related to the *cosine similarity* between the query and the items vectors; the interval of the similarity is simply rescaled from the original -1, 1 range to 0, 1, otherwise the metric is identical.
714
715## Clients, latency and bandwidth usage
716
717During Vector Sets testing, we discovered that often clients introduce considerable latency and CPU usage (on the client side, not in Redis) for two main reasons:
718
7191. Often the serialization to `VALUES ... list of floats ...` can be very slow.
7202. The vector payload of floats represented as strings is very large, resulting in high bandwidth usage and latency, compared to other Redis commands.
721
722Switching from `VALUES` to `FP32` as a method for transmitting vectors may easily provide 10-20x speedups.
723
724# Implementation details
725
726Vector sets are based on the `hnsw.c` implementation of the HNSW data structure with extensions for speed and functionality.
727
728The main features are:
729
730* Proper nodes deletion with relinking.
731* 8 bits and binary quantization.
732* Threaded queries.
733* Filtered search with predicate callback.
diff --git a/examples/redis-unstable/modules/vector-sets/commands.json b/examples/redis-unstable/modules/vector-sets/commands.json
new file mode 100644
index 0000000..5f42f87
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/commands.json
@@ -0,0 +1,446 @@
1{
2 "VADD": {
3 "summary": "Add one or more elements to a vector set, or update its vector if it already exists",
4 "complexity": "O(log(N)) for each element added, where N is the number of elements in the vector set.",
5 "group": "vector_set",
6 "since": "8.0.0",
7 "arity": -5,
8 "function": "vaddCommand",
9 "arguments": [
10 {
11 "name": "key",
12 "type": "key"
13 },
14 {
15 "token": "REDUCE",
16 "name": "reduce",
17 "type": "block",
18 "optional": true,
19 "arguments": [
20 {
21 "name": "dim",
22 "type": "integer"
23 }
24 ]
25 },
26 {
27 "name": "format",
28 "type": "oneof",
29 "arguments": [
30 {
31 "name": "fp32",
32 "type": "pure-token",
33 "token": "FP32"
34 },
35 {
36 "name": "values",
37 "type": "pure-token",
38 "token": "VALUES"
39 }
40 ]
41 },
42 {
43 "name": "vector",
44 "type": "string"
45 },
46 {
47 "name": "element",
48 "type": "string"
49 },
50 {
51 "token": "CAS",
52 "name": "cas",
53 "type": "pure-token",
54 "optional": true
55 },
56 {
57 "name": "quant_type",
58 "type": "oneof",
59 "optional": true,
60 "arguments": [
61 {
62 "name": "noquant",
63 "type": "pure-token",
64 "token": "NOQUANT"
65 },
66 {
67 "name": "bin",
68 "type": "pure-token",
69 "token": "BIN"
70 },
71 {
72 "name": "q8",
73 "type": "pure-token",
74 "token": "Q8"
75 }
76 ]
77 },
78 {
79 "token": "EF",
80 "name": "build-exploration-factor",
81 "type": "integer",
82 "optional": true
83 },
84 {
85 "token": "SETATTR",
86 "name": "attributes",
87 "type": "string",
88 "optional": true
89 },
90 {
91 "token": "M",
92 "name": "numlinks",
93 "type": "integer",
94 "optional": true
95 }
96 ],
97 "command_flags": [
98 "WRITE",
99 "DENYOOM"
100 ]
101 },
102 "VREM": {
103 "summary": "Remove an element from a vector set",
104 "complexity": "O(log(N)) for each element removed, where N is the number of elements in the vector set.",
105 "group": "vector_set",
106 "since": "8.0.0",
107 "arity": 3,
108 "function": "vremCommand",
109 "command_flags": [
110 "WRITE"
111 ],
112 "arguments": [
113 {
114 "name": "key",
115 "type": "key"
116 },
117 {
118 "name": "element",
119 "type": "string"
120 }
121 ]
122 },
123 "VSIM": {
124 "summary": "Return elements by vector similarity",
125 "complexity": "O(log(N)) where N is the number of elements in the vector set.",
126 "group": "vector_set",
127 "since": "8.0.0",
128 "arity": -4,
129 "function": "vsimCommand",
130 "command_flags": [
131 "READONLY"
132 ],
133 "arguments": [
134 {
135 "name": "key",
136 "type": "key"
137 },
138 {
139 "name": "format",
140 "type": "oneof",
141 "arguments": [
142 {
143 "name": "ele",
144 "type": "pure-token",
145 "token": "ELE"
146 },
147 {
148 "name": "fp32",
149 "type": "pure-token",
150 "token": "FP32"
151 },
152 {
153 "name": "values",
154 "type": "pure-token",
155 "token": "VALUES"
156 }
157 ]
158 },
159 {
160 "name": "vector_or_element",
161 "type": "string"
162 },
163 {
164 "token": "WITHSCORES",
165 "name": "withscores",
166 "type": "pure-token",
167 "optional": true
168 },
169 {
170 "token": "WITHATTRIBS",
171 "name": "withattribs",
172 "type": "pure-token",
173 "optional": true
174 },
175 {
176 "token": "COUNT",
177 "name": "count",
178 "type": "integer",
179 "optional": true
180 },
181 {
182 "token": "EPSILON",
183 "name": "max_distance",
184 "type": "double",
185 "optional": true
186 },
187 {
188 "token": "EF",
189 "name": "search-exploration-factor",
190 "type": "integer",
191 "optional": true
192 },
193 {
194 "token": "FILTER",
195 "name": "expression",
196 "type": "string",
197 "optional": true
198 },
199 {
200 "token": "FILTER-EF",
201 "name": "max-filtering-effort",
202 "type": "integer",
203 "optional": true
204 },
205 {
206 "token": "TRUTH",
207 "name": "truth",
208 "type": "pure-token",
209 "optional": true
210 },
211 {
212 "token": "NOTHREAD",
213 "name": "nothread",
214 "type": "pure-token",
215 "optional": true
216 }
217 ]
218 },
219 "VDIM": {
220 "summary": "Return the dimension of vectors in the vector set",
221 "complexity": "O(1)",
222 "group": "vector_set",
223 "since": "8.0.0",
224 "arity": 2,
225 "function": "vdimCommand",
226 "command_flags": [
227 "READONLY",
228 "FAST"
229 ],
230 "arguments": [
231 {
232 "name": "key",
233 "type": "key"
234 }
235 ]
236 },
237 "VCARD": {
238 "summary": "Return the number of elements in a vector set",
239 "complexity": "O(1)",
240 "group": "vector_set",
241 "since": "8.0.0",
242 "arity": 2,
243 "function": "vcardCommand",
244 "command_flags": [
245 "READONLY",
246 "FAST"
247 ],
248 "arguments": [
249 {
250 "name": "key",
251 "type": "key"
252 }
253 ]
254 },
255 "VEMB": {
256 "summary": "Return the vector associated with an element",
257 "complexity": "O(1)",
258 "group": "vector_set",
259 "since": "8.0.0",
260 "arity": -3,
261 "function": "vembCommand",
262 "command_flags": [
263 "READONLY"
264 ],
265 "arguments": [
266 {
267 "name": "key",
268 "type": "key"
269 },
270 {
271 "name": "element",
272 "type": "string"
273 },
274 {
275 "token": "RAW",
276 "name": "raw",
277 "type": "pure-token",
278 "optional": true
279 }
280 ]
281 },
282 "VLINKS": {
283 "summary": "Return the neighbors of an element at each layer in the HNSW graph",
284 "complexity": "O(1)",
285 "group": "vector_set",
286 "since": "8.0.0",
287 "arity": -3,
288 "function": "vlinksCommand",
289 "command_flags": [
290 "READONLY"
291 ],
292 "arguments": [
293 {
294 "name": "key",
295 "type": "key"
296 },
297 {
298 "name": "element",
299 "type": "string"
300 },
301 {
302 "token": "WITHSCORES",
303 "name": "withscores",
304 "type": "pure-token",
305 "optional": true
306 }
307 ]
308 },
309 "VINFO": {
310 "summary": "Return information about a vector set",
311 "complexity": "O(1)",
312 "group": "vector_set",
313 "since": "8.0.0",
314 "arity": 2,
315 "function": "vinfoCommand",
316 "command_flags": [
317 "READONLY",
318 "FAST"
319 ],
320 "arguments": [
321 {
322 "name": "key",
323 "type": "key"
324 }
325 ]
326 },
327 "VSETATTR": {
328 "summary": "Associate or remove the JSON attributes of elements",
329 "complexity": "O(1)",
330 "group": "vector_set",
331 "since": "8.0.0",
332 "arity": 4,
333 "function": "vsetattrCommand",
334 "command_flags": [
335 "WRITE"
336 ],
337 "arguments": [
338 {
339 "name": "key",
340 "type": "key"
341 },
342 {
343 "name": "element",
344 "type": "string"
345 },
346 {
347 "name": "json",
348 "type": "string"
349 }
350 ]
351 },
352 "VGETATTR": {
353 "summary": "Retrieve the JSON attributes of elements",
354 "complexity": "O(1)",
355 "group": "vector_set",
356 "since": "8.0.0",
357 "arity": 3,
358 "function": "vgetattrCommand",
359 "command_flags": [
360 "READONLY"
361 ],
362 "arguments": [
363 {
364 "name": "key",
365 "type": "key"
366 },
367 {
368 "name": "element",
369 "type": "string"
370 }
371 ]
372 },
373 "VRANDMEMBER": {
374 "summary": "Return one or multiple random members from a vector set",
375 "complexity": "O(N) where N is the absolute value of the count argument.",
376 "group": "vector_set",
377 "since": "8.0.0",
378 "arity": -2,
379 "function": "vrandmemberCommand",
380 "command_flags": [
381 "READONLY"
382 ],
383 "arguments": [
384 {
385 "name": "key",
386 "type": "key"
387 },
388 {
389 "name": "count",
390 "type": "integer",
391 "optional": true
392 }
393 ]
394 },
395 "VISMEMBER": {
396 "summary": "Check if an element exists in a vector set",
397 "complexity": "O(1)",
398 "group": "vector_set",
399 "since": "8.2.0",
400 "arity": 3,
401 "function": "vismemberCommand",
402 "command_flags": [
403 "READONLY"
404 ],
405 "arguments": [
406 {
407 "name": "key",
408 "type": "key"
409 },
410 {
411 "name": "element",
412 "type": "string"
413 }
414 ]
415 },
416 "VRANGE": {
417 "summary": "Return elements in a lexicographical range",
418 "complexity": "O(log(K)+M) where K is the number of elements in the start prefix, and M is the number of elements returned. In practical terms, the command is just O(M)",
419 "group": "vector_set",
420 "since": "8.4.0",
421 "arity": -4,
422 "function": "vrangeCommand",
423 "command_flags": [
424 "READONLY"
425 ],
426 "arguments": [
427 {
428 "name": "key",
429 "type": "key"
430 },
431 {
432 "name": "start",
433 "type": "string"
434 },
435 {
436 "name": "end",
437 "type": "string"
438 },
439 {
440 "name": "count",
441 "type": "integer",
442 "optional": true
443 }
444 ]
445 }
446}
diff --git a/examples/redis-unstable/modules/vector-sets/examples/cli-tool/.gitignore b/examples/redis-unstable/modules/vector-sets/examples/cli-tool/.gitignore
new file mode 100644
index 0000000..5ceb386
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/examples/cli-tool/.gitignore
@@ -0,0 +1 @@
venv
diff --git a/examples/redis-unstable/modules/vector-sets/examples/cli-tool/README.md b/examples/redis-unstable/modules/vector-sets/examples/cli-tool/README.md
new file mode 100644
index 0000000..ad21744
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/examples/cli-tool/README.md
@@ -0,0 +1,44 @@
1This tool is similar to redis-cli (but very basic) but allows
2to specify arguments that are expanded as vectors by calling
3ollama to get the embedding.
4
5Whatever is passed as !"foo bar" gets expanded into
6 VALUES ... embedding ...
7
You must have ollama running with the mxbai-embed-large model
already installed for this to work.
10
11Example:
12
13 redis> KEYS *
14 1) food_items
15 2) glove_embeddings_bin
16 3) many_movies_mxbai-embed-large_BIN
17 4) many_movies_mxbai-embed-large_NOQUANT
18 5) word_embeddings
19 6) word_embeddings_bin
20 7) glove_embeddings_fp32
21
22 redis> VSIM food_items !"drinks with fruit"
23 1) (Fruit)Juices,Lemonade,100ml,50 cal,210 kJ
24 2) (Fruit)Juices,Limeade,100ml,128 cal,538 kJ
25 3) CannedFruit,Canned Fruit Cocktail,100g,81 cal,340 kJ
26 4) (Fruit)Juices,Energy-Drink,100ml,87 cal,365 kJ
27 5) Fruits,Lime,100g,30 cal,126 kJ
28 6) (Fruit)Juices,Coconut Water,100ml,19 cal,80 kJ
29 7) Fruits,Lemon,100g,29 cal,122 kJ
30 8) (Fruit)Juices,Clamato,100ml,60 cal,252 kJ
31 9) Fruits,Fruit salad,100g,50 cal,210 kJ
32 10) (Fruit)Juices,Capri-Sun,100ml,41 cal,172 kJ
33
34 redis> vsim food_items !"barilla"
35 1) Pasta&Noodles,Spirelli,100g,367 cal,1541 kJ
36 2) Pasta&Noodles,Farfalle,100g,358 cal,1504 kJ
37 3) Pasta&Noodles,Capellini,100g,353 cal,1483 kJ
38 4) Pasta&Noodles,Spaetzle,100g,368 cal,1546 kJ
39 5) Pasta&Noodles,Cappelletti,100g,164 cal,689 kJ
40 6) Pasta&Noodles,Penne,100g,351 cal,1474 kJ
41 7) Pasta&Noodles,Shells,100g,353 cal,1483 kJ
42 8) Pasta&Noodles,Linguine,100g,357 cal,1499 kJ
43 9) Pasta&Noodles,Rotini,100g,353 cal,1483 kJ
44 10) Pasta&Noodles,Rigatoni,100g,353 cal,1483 kJ
diff --git a/examples/redis-unstable/modules/vector-sets/examples/cli-tool/cli.py b/examples/redis-unstable/modules/vector-sets/examples/cli-tool/cli.py
new file mode 100755
index 0000000..614a021
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/examples/cli-tool/cli.py
@@ -0,0 +1,160 @@
1#
2# Copyright (c) 2009-Present, Redis Ltd.
3# All rights reserved.
4#
5# Licensed under your choice of (a) the Redis Source Available License 2.0
6# (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
7# GNU Affero General Public License v3 (AGPLv3).
8#
9
#!/usr/bin/env python3
# NOTE: a shebang only takes effect on the very first line of a file;
# placed here, after the license header, it is inert.
11import argparse
12import redis
13import requests
14import re
15import shlex
16from prompt_toolkit import PromptSession
17from prompt_toolkit.history import InMemoryHistory
18
19# Default Ollama embeddings URL (can be overridden with --ollama-url)
20OLLAMA_URL = "http://localhost:11434/api/embeddings"
21
def get_embedding(text):
    """Return the embedding vector for `text` from the local Ollama API.

    Args:
        text: The text to embed.

    Returns:
        list[float]: the embedding produced by the mxbai-embed-large model.

    Raises:
        Exception: if the HTTP request fails, times out, or returns an error.
    """
    url = OLLAMA_URL
    payload = {
        "model": "mxbai-embed-large",
        "prompt": text
    }
    try:
        # Timeout added: without it a stuck/unreachable Ollama instance
        # hangs the CLI forever.
        response = requests.post(url, json=payload, timeout=30)
        response.raise_for_status()
        return response.json()['embedding']
    except requests.exceptions.RequestException as e:
        raise Exception(f"Failed to get embedding: {str(e)}")
35
def process_embedding_patterns(text):
    """Rewrite !"text" and !!"text" patterns into VALUES <dim> <floats...>."""

    def expand(match):
        query = match.group(1)
        vec = get_embedding(query)
        components = ' '.join(map(str, vec))
        return f"VALUES {len(vec)} {components}"

    def expand_keeping_text(match):
        query = match.group(1)
        vec = get_embedding(query)
        # Emit the embedding, then re-append the original text as a
        # quoted argument right after it.
        return f'VALUES {len(vec)} {" ".join(map(str, vec))} "{query}"'

    # The !!"..." form must be rewritten before !"..." — otherwise the
    # shorter pattern would also match inside the longer one.
    text = re.sub(r'!!"([^"]*)"', expand_keeping_text, text)
    text = re.sub(r'!"([^"]*)"', expand, text)
    return text
55
def parse_command(command):
    """Split a command line into arguments, honoring quoted strings."""
    try:
        tokens = shlex.split(command)
    except ValueError as err:
        # shlex raises ValueError on e.g. an unbalanced quote.
        raise Exception(f"Invalid command syntax: {str(err)}")
    return tokens
63
def format_response(response):
    """Render a redis-py reply roughly the way redis-cli would print it."""
    if response is None:
        return "(nil)"
    # bool must be checked before int: in Python, bool is a subclass of int.
    if isinstance(response, bool):
        return "+OK" if response else "(error) Operation failed"
    if isinstance(response, (list, set)):
        if not response:
            return "(empty list or set)"
        numbered = [f"{idx}) {item}" for idx, item in enumerate(response, start=1)]
        return "\n".join(numbered)
    if isinstance(response, int):
        return f"(integer) {response}"
    return str(response)
78
def main():
    """Run the interactive REPL: read commands, expand !"..." embedding
    patterns, send the result to Redis, and pretty-print the reply."""
    global OLLAMA_URL

    parser = argparse.ArgumentParser(prog="cli.py", add_help=False)
    # Fixed: this was a plain string, so --help printed the literal text
    # "{OLLAMA_URL}". An f-string interpolates the actual default.
    parser.add_argument("--ollama-url", dest="ollama_url",
                        help=f"Ollama embeddings API URL (default: {OLLAMA_URL})",
                        default=OLLAMA_URL)
    args, _ = parser.parse_known_args()
    OLLAMA_URL = args.ollama_url

    # Default connection to localhost:6379
    r = redis.Redis(host='localhost', port=6379, decode_responses=True)

    try:
        # Test connection
        r.ping()
        print("Connected to Redis. Type your commands (CTRL+D to exit):")
        print("Special syntax:")
        print(" !\"text\" - Replace with embedding")
        print(" !!\"text\" - Replace with embedding and append text as value")
        print(" \"text\" - Quote strings containing spaces")
    except redis.ConnectionError:
        print("Error: Could not connect to Redis server")
        return

    # Setup prompt session with history
    session = PromptSession(history=InMemoryHistory())

    # Main loop
    while True:
        try:
            # Read input with line editing support
            command = session.prompt("redis> ")

            # Skip empty commands
            if not command.strip():
                continue

            # Process any embedding patterns before parsing
            try:
                processed_command = process_embedding_patterns(command)
            except Exception as e:
                print(f"(error) Embedding processing failed: {str(e)}")
                continue

            # Parse the command respecting quoted strings
            try:
                parts = parse_command(processed_command)
            except Exception as e:
                print(f"(error) {str(e)}")
                continue

            if not parts:
                continue

            cmd = parts[0].lower()
            args = parts[1:]

            # Execute command: prefer the redis-py method of the same name,
            # fall back to a raw command for anything redis-py doesn't know.
            try:
                method = getattr(r, cmd, None)
                if method is not None:
                    result = method(*args)
                else:
                    # Use execute_command for unknown commands
                    result = r.execute_command(cmd, *args)
                print(format_response(result))
            except AttributeError:
                print(f"(error) Unknown command '{cmd}'")

        except EOFError:
            print("\nGoodbye!")
            break
        except KeyboardInterrupt:
            continue  # Allow Ctrl+C to clear current line
        except redis.RedisError as e:
            print(f"(error) {str(e)}")
        except Exception as e:
            print(f"(error) {str(e)}")
158
# Entry point guard: start the REPL only when executed as a script.
if __name__ == "__main__":
    main()
diff --git a/examples/redis-unstable/modules/vector-sets/examples/glove-100/README b/examples/redis-unstable/modules/vector-sets/examples/glove-100/README
new file mode 100644
index 0000000..e8bb6dd
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/examples/glove-100/README
@@ -0,0 +1,3 @@
1wget http://ann-benchmarks.com/glove-100-angular.hdf5
2python insert.py
3python recall.py (use --k <count> optionally, default top-10)
diff --git a/examples/redis-unstable/modules/vector-sets/examples/glove-100/insert.py b/examples/redis-unstable/modules/vector-sets/examples/glove-100/insert.py
new file mode 100644
index 0000000..595c77f
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/examples/glove-100/insert.py
@@ -0,0 +1,56 @@
1#
2# Copyright (c) 2009-Present, Redis Ltd.
3# All rights reserved.
4#
5# Licensed under your choice of (a) the Redis Source Available License 2.0
6# (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
7# GNU Affero General Public License v3 (AGPLv3).
8#
9
10import h5py
11import redis
12from tqdm import tqdm
13
14# Initialize Redis connection
15redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True, encoding='utf-8')
16
def add_to_redis(index, embedding):
    """Insert one 100-dim embedding into 'glove_embeddings' via VADD."""
    # The GloVe vectors in this dataset are 100-dimensional.
    cmd = ["VADD", "glove_embeddings", "VALUES", "100"]
    for component in embedding:
        cmd.append(str(component))
    # No words available in this dataset, so the row index is the element id.
    cmd.append(f"{index}")
    cmd.append("EF")
    cmd.append("200")
    # Alternative quantization flags, left disabled:
    # cmd.append("NOQUANT")
    # cmd.append("BIN")
    redis_client.execute_command(*cmd)
27
def main():
    """Stream the GloVe 'train' vectors out of the HDF5 file and insert
    each one into Redis with add_to_redis()."""
    with h5py.File('glove-100-angular.hdf5', 'r') as f:
        # Get the train dataset
        train_vectors = f['train']
        total_vectors = train_vectors.shape[0]

        print(f"Starting to process {total_vectors} vectors...")

        # Slice the HDF5 dataset in batches so the whole matrix is never
        # materialized in memory at once.
        batch_size = 1000

        for i in tqdm(range(0, total_vectors, batch_size)):
            batch_end = min(i + batch_size, total_vectors)
            batch = train_vectors[i:batch_end]

            for j, vector in enumerate(batch):
                # Compute the index outside the try block so the error
                # message below can never hit an unbound variable.
                current_index = i + j
                try:
                    add_to_redis(current_index, vector)
                except Exception as e:
                    print(f"Error processing vector {current_index}: {str(e)}")
                    continue

            # Use batch_end so the count is accurate even for a final
            # partial batch (the old i + batch_size could overshoot).
            if batch_end % 10000 == 0:
                print(f"Processed {batch_end} vectors")
54
# Entry point guard: run the insertion only when executed as a script.
if __name__ == "__main__":
    main()
diff --git a/examples/redis-unstable/modules/vector-sets/examples/glove-100/recall.py b/examples/redis-unstable/modules/vector-sets/examples/glove-100/recall.py
new file mode 100644
index 0000000..3064aed
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/examples/glove-100/recall.py
@@ -0,0 +1,87 @@
1#
2# Copyright (c) 2009-Present, Redis Ltd.
3# All rights reserved.
4#
5# Licensed under your choice of (a) the Redis Source Available License 2.0
6# (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
7# GNU Affero General Public License v3 (AGPLv3).
8#
9
10import h5py
11import redis
12import numpy as np
13from tqdm import tqdm
14import argparse
15
16# Initialize Redis connection
17redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True, encoding='utf-8')
18
def get_redis_neighbors(query_vector, k):
    """Return the ids of the k nearest neighbors of `query_vector` using
    the VSIM command against the 'glove_embeddings_bin' vector set.

    The set stores integer row indices (see insert.py), so replies are
    converted back to int.
    """
    args = ["VSIM", "glove_embeddings_bin", "VALUES", "100"]
    args.extend(map(str, query_vector))
    args.extend(["COUNT", str(k)])
    # EF is the HNSW search-effort parameter; passed as a string for
    # consistency with the other arguments (was a bare int before).
    args.extend(["EF", "100"])
    # Removed: a dead `if False: print(args); exit(1)` debug block.
    results = redis_client.execute_command(*args)
    return [int(res) for res in results]
30
def calculate_recall(ground_truth, predicted, k):
    """Compute recall@k: the fraction of the true top-k neighbors that
    appear in the predicted top-k.

    Args:
        ground_truth: ids of the true nearest neighbors, best first.
        predicted: ids returned by the index, best first.
        k: number of neighbors to consider.

    Returns:
        float in [0, 1]. Returns 0.0 when there are no relevant items
        (empty ground truth or k <= 0) instead of raising
        ZeroDivisionError.
    """
    relevant = set(ground_truth[:k])
    if not relevant:
        return 0.0
    retrieved = set(predicted[:k])
    return len(relevant.intersection(retrieved)) / len(relevant)
36
def main():
    """Evaluate recall@k of Redis VSIM against the HDF5 ground truth and
    write a small results file."""
    parser = argparse.ArgumentParser(description='Evaluate Redis VSIM recall')
    parser.add_argument('--k', type=int, default=10, help='Number of neighbors to evaluate (default: 10)')
    parser.add_argument('--batch', type=int, default=100, help='Progress update frequency (default: 100)')
    args = parser.parse_args()

    k = args.k
    batch_size = args.batch

    with h5py.File('glove-100-angular.hdf5', 'r') as f:
        # Load both datasets fully; the test split is small.
        test_vectors = f['test'][:]
        ground_truth_neighbors = f['neighbors'][:]

    num_queries = len(test_vectors)
    recalls = []

    print(f"Evaluating recall@{k} for {num_queries} test queries...")

    for i in tqdm(range(num_queries)):
        try:
            # Get Redis results
            redis_neighbors = get_redis_neighbors(test_vectors[i], k)

            # Get ground truth for this query
            true_neighbors = ground_truth_neighbors[i]

            # Calculate recall
            recall = calculate_recall(true_neighbors, redis_neighbors, k)
            recalls.append(recall)

            if (i + 1) % batch_size == 0:
                current_avg_recall = np.mean(recalls)
                print(f"Current average recall@{k} after {i+1} queries: {current_avg_recall:.4f}")

        except Exception as e:
            print(f"Error processing query {i}: {str(e)}")
            continue

    # Guard: if every query errored out, np.mean([]) would return NaN
    # and emit a RuntimeWarning.
    if not recalls:
        print("\nNo queries were evaluated successfully.")
        return

    final_recall = np.mean(recalls)
    print("\nFinal Results:")
    print(f"Average recall@{k}: {final_recall:.4f}")
    print(f"Total queries evaluated: {len(recalls)}")

    # Save detailed results
    with open(f'recall_evaluation_results_k{k}.txt', 'w') as f:
        f.write(f"Average recall@{k}: {final_recall:.4f}\n")
        f.write(f"Total queries evaluated: {len(recalls)}\n")
        f.write(f"Individual query recalls: {recalls}\n")
85
# Entry point guard: run the evaluation only when executed as a script.
if __name__ == "__main__":
    main()
diff --git a/examples/redis-unstable/modules/vector-sets/examples/movies/.gitignore b/examples/redis-unstable/modules/vector-sets/examples/movies/.gitignore
new file mode 100644
index 0000000..e736c6a
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/examples/movies/.gitignore
@@ -0,0 +1,2 @@
1mpst_full_data.csv
2partition.json
diff --git a/examples/redis-unstable/modules/vector-sets/examples/movies/README b/examples/redis-unstable/modules/vector-sets/examples/movies/README
new file mode 100644
index 0000000..3931a6d
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/examples/movies/README
@@ -0,0 +1,30 @@
This example maps long-form movie plots to movie titles.
2It will create fp32 and binary vectors (the two extremes).
3
41. Install ollama, and install the embedding model "mxbai-embed-large"
52. Download mpst_full_data.csv from https://www.kaggle.com/datasets/cryptexcode/mpst-movie-plot-synopses-with-tags
63. python insert.py
7
8127.0.0.1:6379> VSIM many_movies_mxbai-embed-large_NOQUANT ELE "The Matrix"
9 1) "The Matrix"
10 2) "The Matrix Reloaded"
11 3) "The Matrix Revolutions"
12 4) "Commando"
13 5) "Avatar"
14 6) "Forbidden Planet"
15 7) "Terminator Salvation"
16 8) "Mandroid"
17 9) "The Omega Code"
1810) "Coherence"
19
20127.0.0.1:6379> VSIM many_movies_mxbai-embed-large_BIN ELE "The Matrix"
21 1) "The Matrix"
22 2) "The Matrix Reloaded"
23 3) "The Matrix Revolutions"
24 4) "The Omega Code"
25 5) "Forbidden Planet"
26 6) "Avatar"
27 7) "John Carter"
28 8) "System Shock 2"
29 9) "Coherence"
3010) "Tomorrowland"
diff --git a/examples/redis-unstable/modules/vector-sets/examples/movies/insert.py b/examples/redis-unstable/modules/vector-sets/examples/movies/insert.py
new file mode 100644
index 0000000..dad4da3
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/examples/movies/insert.py
@@ -0,0 +1,57 @@
1#
2# Copyright (c) 2009-Present, Redis Ltd.
3# All rights reserved.
4#
5# Licensed under your choice of (a) the Redis Source Available License 2.0
6# (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
7# GNU Affero General Public License v3 (AGPLv3).
8#
9
10import csv
11import requests
12import redis
13
14ModelName="mxbai-embed-large"
15
16# Initialize Redis connection, setting encoding to utf-8
17redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True, encoding='utf-8')
18
def get_embedding(text):
    """Return the embedding for `text` from the local Ollama API.

    Raises:
        requests.HTTPError: if the API answers with an error status.
        requests.RequestException: on connection problems or timeout.
    """
    url = "http://localhost:11434/api/embeddings"
    payload = {
        "model": ModelName,
        "prompt": "Represent this movie plot and genre: "+text
    }
    # Timeout so a stuck Ollama instance can't hang the whole import.
    response = requests.post(url, json=payload, timeout=60)
    # Fail loudly on HTTP errors instead of raising a confusing KeyError
    # when 'embedding' is missing from an error payload (this matches
    # the error handling used by the cli-tool example).
    response.raise_for_status()
    return response.json()['embedding']
28
def add_to_redis(title, embedding, quant_type):
    """Store `embedding` under `title` in the vector set for `quant_type`."""
    key = "many_movies_" + ModelName + "_" + quant_type
    cmd = ["VADD", key, "VALUES", str(len(embedding))]
    cmd.extend(str(v) for v in embedding)
    cmd.append(title)
    # quant_type doubles as the VADD quantization flag (BIN / NOQUANT).
    cmd.append(quant_type)
    redis_client.execute_command(*cmd)
36
def main():
    """Read mpst_full_data.csv and index every movie plot twice: once
    binary-quantized (BIN) and once full precision (NOQUANT)."""
    with open('mpst_full_data.csv', 'r', encoding='utf-8') as file:
        reader = csv.DictReader(file)

        for movie in reader:
            # Resolve the title outside the try block so the error
            # message below cannot itself raise on a malformed row.
            title = movie.get('title', '<unknown>')
            try:
                text_to_embed = f"{title} {movie['plot_synopsis']} {movie['tags']}"

                print(f"Getting embedding for: {title}")
                embedding = get_embedding(text_to_embed)

                # Same embedding in both sets, to compare quantizations.
                add_to_redis(title, embedding, "BIN")
                add_to_redis(title, embedding, "NOQUANT")
                print(f"Successfully processed: {title}")

            except Exception as e:
                print(f"Error processing {title}: {str(e)}")
                continue
55
# Entry point guard: run the import only when executed as a script.
if __name__ == "__main__":
    main()
diff --git a/examples/redis-unstable/modules/vector-sets/expr.c b/examples/redis-unstable/modules/vector-sets/expr.c
new file mode 100644
index 0000000..4f3a1cc
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/expr.c
@@ -0,0 +1,959 @@
1/* Filtering of objects based on simple expressions.
2 * This powers the FILTER option of Vector Sets, but it is otherwise
3 * general code to be used when we want to tell if a given object (with fields)
4 * passes or fails a given test for scalars, strings, ...
5 *
6 * Copyright (c) 2009-Present, Redis Ltd.
7 * All rights reserved.
8 *
9 * Licensed under your choice of (a) the Redis Source Available License 2.0
10 * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
11 * GNU Affero General Public License v3 (AGPLv3).
12 * Originally authored by: Salvatore Sanfilippo.
13 */
14
15#ifdef TEST_MAIN
16#define RedisModule_Alloc malloc
17#define RedisModule_Realloc realloc
18#define RedisModule_Free free
19#define RedisModule_Strdup strdup
20#define RedisModule_Assert assert
21#define _DEFAULT_SOURCE
22#define _USE_MATH_DEFINES
23#include <assert.h>
24#include <math.h>
25#endif
26
27#include <stdio.h>
28#include <stdlib.h>
29#include <ctype.h>
30#include <math.h>
31#include <string.h>
32
33#define EXPR_TOKEN_EOF 0
34#define EXPR_TOKEN_NUM 1
35#define EXPR_TOKEN_STR 2
36#define EXPR_TOKEN_TUPLE 3
37#define EXPR_TOKEN_SELECTOR 4
38#define EXPR_TOKEN_OP 5
39#define EXPR_TOKEN_NULL 6
40
41#define EXPR_OP_OPAREN 0 /* ( */
42#define EXPR_OP_CPAREN 1 /* ) */
43#define EXPR_OP_NOT 2 /* ! */
44#define EXPR_OP_POW 3 /* ** */
45#define EXPR_OP_MULT 4 /* * */
46#define EXPR_OP_DIV 5 /* / */
47#define EXPR_OP_MOD 6 /* % */
48#define EXPR_OP_SUM 7 /* + */
49#define EXPR_OP_DIFF 8 /* - */
50#define EXPR_OP_GT 9 /* > */
51#define EXPR_OP_GTE 10 /* >= */
52#define EXPR_OP_LT 11 /* < */
53#define EXPR_OP_LTE 12 /* <= */
54#define EXPR_OP_EQ 13 /* == */
55#define EXPR_OP_NEQ 14 /* != */
56#define EXPR_OP_IN 15 /* in */
57#define EXPR_OP_AND 16 /* and */
58#define EXPR_OP_OR 17 /* or */
59
/* This structure represents a token in our expression. It's either
 * literals like 4, "foo", or operators like "+", "-", "and", or
 * json selectors, that start with a dot: ".age", ".properties.somearray[1]" */
typedef struct exprtoken {
    int refcount; // Reference counting for memory reclaiming.
    int token_type; // Token type of the just parsed token.
    int offset; // Chars offset in expression.
    union {
        double num; // Value for EXPR_TOKEN_NUM.
        struct {
            char *start; // String pointer for EXPR_TOKEN_STR / SELECTOR.
            size_t len; // String len for EXPR_TOKEN_STR / SELECTOR.
            char *heapstr; // Non-NULL if we have a private allocation for
                           // this string. When possible, it just references
                           // the string expression we compiled,
                           // exprstate->expr.
        } str;
        int opcode; // Opcode ID for EXPR_TOKEN_OP.
        struct {
            struct exprtoken **ele; // Element tokens (numbers / strings).
            size_t len;             // Number of elements.
        } tuple; // Tuples are like [1, 2, 3] for "in" operator.
    };
} exprtoken;

/* Simple stack of expr tokens. This is used both to represent the stack
 * of values and the stack of operands during VM execution. */
typedef struct exprstack {
    exprtoken **items;  // Owned token references.
    int numitems;       // Current number of items on the stack.
    int allocsize;      // Allocated capacity of 'items', in items.
} exprstack;

/* Full state of one compiled expression: source text, parse cursor and
 * the stacks used both at compile time and at run time. */
typedef struct exprstate {
    char *expr; /* Expression string to compile. Note that
                 * expression token strings point directly to this
                 * string. */
    char *p; // Current position inside 'expr', while parsing.

    // Virtual machine state.
    exprstack values_stack;
    exprstack ops_stack; // Operator stack used during compilation.
    exprstack tokens; // Expression processed into a sequence of tokens.
    exprstack program; // Expression compiled into opcodes and values.
} exprstate;

/* Valid operators. Higher precedence binds tighter; arity is the number
 * of operands the operator consumes (parens are pure markers, arity 0). */
struct {
    char *opname;
    int oplen;
    int opcode;
    int precedence;
    int arity;
} ExprOptable[] = {
    {"(", 1, EXPR_OP_OPAREN, 7, 0},
    {")", 1, EXPR_OP_CPAREN, 7, 0},
    {"!", 1, EXPR_OP_NOT, 6, 1},
    {"not", 3, EXPR_OP_NOT, 6, 1},
    {"**", 2, EXPR_OP_POW, 5, 2},
    {"*", 1, EXPR_OP_MULT, 4, 2},
    {"/", 1, EXPR_OP_DIV, 4, 2},
    {"%", 1, EXPR_OP_MOD, 4, 2},
    {"+", 1, EXPR_OP_SUM, 3, 2},
    {"-", 1, EXPR_OP_DIFF, 3, 2},
    {">", 1, EXPR_OP_GT, 2, 2},
    {">=", 2, EXPR_OP_GTE, 2, 2},
    {"<", 1, EXPR_OP_LT, 2, 2},
    {"<=", 2, EXPR_OP_LTE, 2, 2},
    {"==", 2, EXPR_OP_EQ, 2, 2},
    {"!=", 2, EXPR_OP_NEQ, 2, 2},
    {"in", 2, EXPR_OP_IN, 2, 2},
    {"and", 3, EXPR_OP_AND, 1, 2},
    {"&&", 2, EXPR_OP_AND, 1, 2},
    {"or", 2, EXPR_OP_OR, 0, 2},
    {"||", 2, EXPR_OP_OR, 0, 2},
    {NULL, 0, 0, 0, 0} // Terminator.
};

// Chars that may appear in operators, and the extra (non-alphanumeric)
// chars allowed inside selectors, respectively.
#define EXPR_OP_SPECIALCHARS "+-*%/!()<>=|&"
#define EXPR_SELECTOR_SPECIALCHARS "_-"
139
140/* ================================ Expr token ============================== */
141
142/* Return an heap allocated token of the specified type, setting the
143 * reference count to 1. */
144exprtoken *exprNewToken(int type) {
145 exprtoken *t = RedisModule_Alloc(sizeof(exprtoken));
146 memset(t,0,sizeof(*t));
147 t->token_type = type;
148 t->refcount = 1;
149 return t;
150}
151
152/* Generic free token function, can be used to free stack allocated
153 * objects (in this case the pointer itself will not be freed) or
154 * heap allocated objects. See the wrappers below. */
155void exprTokenRelease(exprtoken *t) {
156 if (t == NULL) return;
157
158 RedisModule_Assert(t->refcount > 0); // Catch double free & more.
159 t->refcount--;
160 if (t->refcount > 0) return;
161
162 // We reached refcount 0: free the object.
163 if (t->token_type == EXPR_TOKEN_STR) {
164 if (t->str.heapstr != NULL) RedisModule_Free(t->str.heapstr);
165 } else if (t->token_type == EXPR_TOKEN_TUPLE) {
166 for (size_t j = 0; j < t->tuple.len; j++)
167 exprTokenRelease(t->tuple.ele[j]);
168 if (t->tuple.ele) RedisModule_Free(t->tuple.ele);
169 }
170 RedisModule_Free(t);
171}
172
/* Increment the reference count of 't': the caller becomes one more
 * owner of the token and must balance this with exprTokenRelease(). */
void exprTokenRetain(exprtoken *t) {
    t->refcount++;
}
176
177/* ============================== Stack handling ============================ */
178
179#include <stdlib.h>
180#include <string.h>
181
182#define EXPR_STACK_INITIAL_SIZE 16
183
184/* Initialize a new expression stack. */
185void exprStackInit(exprstack *stack) {
186 stack->items = RedisModule_Alloc(sizeof(exprtoken*) * EXPR_STACK_INITIAL_SIZE);
187 stack->numitems = 0;
188 stack->allocsize = EXPR_STACK_INITIAL_SIZE;
189}
190
191/* Push a token pointer onto the stack. Does not increment the refcount
192 * of the token: it is up to the caller doing this. */
193void exprStackPush(exprstack *stack, exprtoken *token) {
194 /* Check if we need to grow the stack. */
195 if (stack->numitems == stack->allocsize) {
196 size_t newsize = stack->allocsize * 2;
197 exprtoken **newitems =
198 RedisModule_Realloc(stack->items, sizeof(exprtoken*) * newsize);
199 stack->items = newitems;
200 stack->allocsize = newsize;
201 }
202 stack->items[stack->numitems] = token;
203 stack->numitems++;
204}
205
206/* Pop a token pointer from the stack. Return NULL if the stack is
207 * empty. Does NOT recrement the refcount of the token, it's up to the
208 * caller to do so, as the new owner of the reference. */
209exprtoken *exprStackPop(exprstack *stack) {
210 if (stack->numitems == 0) return NULL;
211 stack->numitems--;
212 return stack->items[stack->numitems];
213}
214
215/* Just return the last element pushed, without consuming it nor altering
216 * the reference count. */
217exprtoken *exprStackPeek(exprstack *stack) {
218 if (stack->numitems == 0) return NULL;
219 return stack->items[stack->numitems-1];
220}
221
222/* Free the stack structure state, including the items it contains, that are
223 * assumed to be heap allocated. The passed pointer itself is not freed. */
224void exprStackFree(exprstack *stack) {
225 for (int j = 0; j < stack->numitems; j++)
226 exprTokenRelease(stack->items[j]);
227 RedisModule_Free(stack->items);
228}
229
230/* Just reset the stack removing all the items, but leaving it in a state
231 * that makes it still usable for new elements. */
232void exprStackReset(exprstack *stack) {
233 for (int j = 0; j < stack->numitems; j++)
234 exprTokenRelease(stack->items[j]);
235 stack->numitems = 0;
236}
237
238/* =========================== Expression compilation ======================= */
239
240void exprConsumeSpaces(exprstate *es) {
241 while(es->p[0] && isspace(es->p[0])) es->p++;
242}
243
244/* Parse an operator or a literal (just "null" currently).
245 * When parsing operators, the function will try to match the longest match
246 * in the operators table. */
247exprtoken *exprParseOperatorOrLiteral(exprstate *es) {
248 exprtoken *t = exprNewToken(EXPR_TOKEN_OP);
249 char *start = es->p;
250
251 while(es->p[0] &&
252 (isalpha(es->p[0]) ||
253 strchr(EXPR_OP_SPECIALCHARS,es->p[0]) != NULL))
254 {
255 es->p++;
256 }
257
258 int matchlen = es->p - start;
259 int bestlen = 0;
260 int j;
261
262 // Check if it's a literal.
263 if (matchlen == 4 && !memcmp("null",start,4)) {
264 t->token_type = EXPR_TOKEN_NULL;
265 return t;
266 }
267
268 // Find the longest matching operator.
269 for (j = 0; ExprOptable[j].opname != NULL; j++) {
270 if (ExprOptable[j].oplen > matchlen) continue;
271 if (memcmp(ExprOptable[j].opname, start, ExprOptable[j].oplen) != 0)
272 {
273 continue;
274 }
275 if (ExprOptable[j].oplen > bestlen) {
276 t->opcode = ExprOptable[j].opcode;
277 bestlen = ExprOptable[j].oplen;
278 }
279 }
280 if (bestlen == 0) {
281 exprTokenRelease(t);
282 return NULL;
283 } else {
284 es->p = start + bestlen;
285 }
286 return t;
287}
288
289// Valid selector charset.
290static int is_selector_char(int c) {
291 return (isalpha(c) ||
292 isdigit(c) ||
293 strchr(EXPR_SELECTOR_SPECIALCHARS,c) != NULL);
294}
295
296/* Parse selectors, they start with a dot and can have alphanumerical
297 * or few special chars. */
298exprtoken *exprParseSelector(exprstate *es) {
299 exprtoken *t = exprNewToken(EXPR_TOKEN_SELECTOR);
300 es->p++; // Skip dot.
301 char *start = es->p;
302
303 while(es->p[0] && is_selector_char(es->p[0])) es->p++;
304 int matchlen = es->p - start;
305 t->str.start = start;
306 t->str.len = matchlen;
307 return t;
308}
309
310exprtoken *exprParseNumber(exprstate *es) {
311 exprtoken *t = exprNewToken(EXPR_TOKEN_NUM);
312 char num[256];
313 int idx = 0;
314 while(isdigit(es->p[0]) || es->p[0] == '.' || es->p[0] == 'e' ||
315 es->p[0] == 'E' || (idx == 0 && es->p[0] == '-'))
316 {
317 if (idx >= (int)sizeof(num)-1) {
318 exprTokenRelease(t);
319 return NULL;
320 }
321 num[idx++] = es->p[0];
322 es->p++;
323 }
324 num[idx] = 0;
325
326 char *endptr;
327 t->num = strtod(num, &endptr);
328 if (*endptr != '\0') {
329 exprTokenRelease(t);
330 return NULL;
331 }
332 return t;
333}
334
335exprtoken *exprParseString(exprstate *es) {
336 char quote = es->p[0]; /* Store the quote type (' or "). */
337 es->p++; /* Skip opening quote. */
338
339 exprtoken *t = exprNewToken(EXPR_TOKEN_STR);
340 t->str.start = es->p;
341
342 while(es->p[0] != '\0') {
343 if (es->p[0] == '\\' && es->p[1] != '\0') {
344 es->p += 2; // Skip escaped char.
345 continue;
346 }
347 if (es->p[0] == quote) {
348 t->str.len = es->p - t->str.start;
349 es->p++; // Skip closing quote.
350 return t;
351 }
352 es->p++;
353 }
354 /* If we reach here, string was not terminated. */
355 exprTokenRelease(t);
356 return NULL;
357}
358
359/* Parse a tuple of the form [1, "foo", 42]. No nested tuples are
360 * supported. This type is useful mostly to be used with the "IN"
361 * operator. */
362exprtoken *exprParseTuple(exprstate *es) {
363 exprtoken *t = exprNewToken(EXPR_TOKEN_TUPLE);
364 t->tuple.ele = NULL;
365 t->tuple.len = 0;
366 es->p++; /* Skip opening '['. */
367
368 size_t allocated = 0;
369 while(1) {
370 exprConsumeSpaces(es);
371
372 /* Check for empty tuple or end. */
373 if (es->p[0] == ']') {
374 es->p++;
375 break;
376 }
377
378 /* Grow tuple array if needed. */
379 if (t->tuple.len == allocated) {
380 size_t newsize = allocated == 0 ? 4 : allocated * 2;
381 exprtoken **newele = RedisModule_Realloc(t->tuple.ele,
382 sizeof(exprtoken*) * newsize);
383 t->tuple.ele = newele;
384 allocated = newsize;
385 }
386
387 /* Parse tuple element. */
388 exprtoken *ele = NULL;
389 if (isdigit(es->p[0]) || es->p[0] == '-') {
390 ele = exprParseNumber(es);
391 } else if (es->p[0] == '"' || es->p[0] == '\'') {
392 ele = exprParseString(es);
393 } else {
394 exprTokenRelease(t);
395 return NULL;
396 }
397
398 /* Error parsing number/string? */
399 if (ele == NULL) {
400 exprTokenRelease(t);
401 return NULL;
402 }
403
404 /* Store element if no error was detected. */
405 t->tuple.ele[t->tuple.len] = ele;
406 t->tuple.len++;
407
408 /* Check for next element. */
409 exprConsumeSpaces(es);
410 if (es->p[0] == ']') {
411 es->p++;
412 break;
413 }
414 if (es->p[0] != ',') {
415 exprTokenRelease(t);
416 return NULL;
417 }
418 es->p++; /* Skip comma. */
419 }
420 return t;
421}
422
423/* Deallocate the object returned by exprCompile(). */
424void exprFree(exprstate *es) {
425 if (es == NULL) return;
426
427 /* Free the original expression string. */
428 if (es->expr) RedisModule_Free(es->expr);
429
430 /* Free all stacks. */
431 exprStackFree(&es->values_stack);
432 exprStackFree(&es->ops_stack);
433 exprStackFree(&es->tokens);
434 exprStackFree(&es->program);
435
436 /* Free the state object itself. */
437 RedisModule_Free(es);
438}
439
440/* Split the provided expression into a stack of tokens. Returns
441 * 0 on success, 1 on error. */
442int exprTokenize(exprstate *es, int *errpos) {
443 /* Main parsing loop. */
444 while(1) {
445 exprConsumeSpaces(es);
446
447 /* Set a flag to see if we can consider the - part of the
448 * number, or an operator. */
449 int minus_is_number = 0; // By default is an operator.
450
451 exprtoken *last = exprStackPeek(&es->tokens);
452 if (last == NULL) {
453 /* If we are at the start of an expression, the minus is
454 * considered a number. */
455 minus_is_number = 1;
456 } else if (last->token_type == EXPR_TOKEN_OP &&
457 last->opcode != EXPR_OP_CPAREN)
458 {
459 /* Also, if the previous token was an operator, the minus
460 * is considered a number, unless the previous operator is
461 * a closing parens. In such case it's like (...) -5, or alike
462 * and we want to emit an operator. */
463 minus_is_number = 1;
464 }
465
466 /* Parse based on the current character. */
467 exprtoken *current = NULL;
468 if (*es->p == '\0') {
469 current = exprNewToken(EXPR_TOKEN_EOF);
470 } else if (isdigit(*es->p) ||
471 (minus_is_number && *es->p == '-' && isdigit(es->p[1])))
472 {
473 current = exprParseNumber(es);
474 } else if (*es->p == '"' || *es->p == '\'') {
475 current = exprParseString(es);
476 } else if (*es->p == '.' && is_selector_char(es->p[1])) {
477 current = exprParseSelector(es);
478 } else if (*es->p == '[') {
479 current = exprParseTuple(es);
480 } else if (isalpha(*es->p) || strchr(EXPR_OP_SPECIALCHARS, *es->p)) {
481 current = exprParseOperatorOrLiteral(es);
482 }
483
484 if (current == NULL) {
485 if (errpos) *errpos = es->p - es->expr;
486 return 1; // Syntax Error.
487 }
488
489 /* Push the current token to tokens stack. */
490 exprStackPush(&es->tokens, current);
491 if (current->token_type == EXPR_TOKEN_EOF) break;
492 }
493 return 0;
494}
495
496/* Helper function to get operator precedence from the operator table. */
497int exprGetOpPrecedence(int opcode) {
498 for (int i = 0; ExprOptable[i].opname != NULL; i++) {
499 if (ExprOptable[i].opcode == opcode)
500 return ExprOptable[i].precedence;
501 }
502 return -1;
503}
504
505/* Helper function to get operator arity from the operator table. */
506int exprGetOpArity(int opcode) {
507 for (int i = 0; ExprOptable[i].opname != NULL; i++) {
508 if (ExprOptable[i].opcode == opcode)
509 return ExprOptable[i].arity;
510 }
511 return -1;
512}
513
/* Process an operator during compilation (shunting-yard style).
 * Returns 0 on success, 1 on error. This function will retain a
 * reference of the operator 'op' in case it is pushed on the
 * operators stack.
 *
 * 'stack_items' tracks, by reference, how many values would be present
 * on the runtime stack at this point of the program: this is how arity
 * errors are detected at compile time. On error, 'errpos' (when not
 * NULL) receives the offset of the offending token in the source
 * expression. */
int exprProcessOperator(exprstate *es, exprtoken *op, int *stack_items, int *errpos) {
    if (op->opcode == EXPR_OP_OPAREN) {
        // This is just a marker for us. Do nothing.
        exprStackPush(&es->ops_stack, op);
        exprTokenRetain(op);
        return 0;
    }

    if (op->opcode == EXPR_OP_CPAREN) {
        /* Process operators until we find the matching opening parenthesis. */
        while (1) {
            exprtoken *top_op = exprStackPop(&es->ops_stack);
            if (top_op == NULL) {
                /* Ran out of operators without finding '(': the
                 * parentheses are unbalanced. */
                if (errpos) *errpos = op->offset;
                return 1;
            }

            if (top_op->opcode == EXPR_OP_OPAREN) {
                /* Open parenthesis found. Our work finished. */
                exprTokenRelease(top_op);
                return 0;
            }

            int arity = exprGetOpArity(top_op->opcode);
            if (*stack_items < arity) {
                /* Not enough operands would be available at runtime. */
                exprTokenRelease(top_op);
                if (errpos) *errpos = top_op->offset;
                return 1;
            }

            /* Move the operator on the program stack. */
            exprStackPush(&es->program, top_op);
            *stack_items = *stack_items - arity + 1;
        }
    }

    int curr_prec = exprGetOpPrecedence(op->opcode);

    /* Process operators with higher or equal precedence. */
    while (1) {
        exprtoken *top_op = exprStackPeek(&es->ops_stack);
        if (top_op == NULL || top_op->opcode == EXPR_OP_OPAREN) break;

        int top_prec = exprGetOpPrecedence(top_op->opcode);
        if (top_prec < curr_prec) break;
        /* Special case for **: only pop if precedence is strictly higher
         * so that the operator is right associative, that is:
         * 2 ** 3 ** 2 is evaluated as 2 ** (3 ** 2) == 512 instead
         * of (2 ** 3) ** 2 == 64. */
        if (op->opcode == EXPR_OP_POW && top_prec <= curr_prec) break;

        /* Pop and add to program. */
        top_op = exprStackPop(&es->ops_stack);
        int arity = exprGetOpArity(top_op->opcode);
        if (*stack_items < arity) {
            /* Arity error detected at compile time. */
            exprTokenRelease(top_op);
            if (errpos) *errpos = top_op->offset;
            return 1;
        }

        /* Move to the program stack. */
        exprStackPush(&es->program, top_op);
        *stack_items = *stack_items - arity + 1;
    }

    /* Push current operator. */
    exprStackPush(&es->ops_stack, op);
    exprTokenRetain(op);
    return 0;
}
587
/* Compile the expression into a set of push-value and exec-operator
 * instructions that exprRun() can execute. The function returns an
 * exprstate object that can be used for execution of the program.
 * On error, NULL is returned, and optionally the position of the
 * error inside the expression is returned by reference via 'errpos'.
 *
 * The compilation pipeline is: tokenize the whole expression, then
 * translate the token stream into a postfix (RPN) program using the
 * shunting-yard algorithm (see exprProcessOperator()). */
exprstate *exprCompile(char *expr, int *errpos) {
    /* Initialize expression state. */
    exprstate *es = RedisModule_Alloc(sizeof(exprstate));
    es->expr = RedisModule_Strdup(expr);
    es->p = es->expr;

    /* Initialize all stacks. */
    exprStackInit(&es->values_stack);
    exprStackInit(&es->ops_stack);
    exprStackInit(&es->tokens);
    exprStackInit(&es->program);

    /* Tokenization. */
    if (exprTokenize(es, errpos)) {
        exprFree(es);
        return NULL;
    }

    /* Compile the expression into a sequence of operations. */
    int stack_items = 0; // Track # of items that would be on the stack
                         // during execution. This way we can detect arity
                         // issues at compile time.

    /* Process each token. */
    for (int i = 0; i < es->tokens.numitems; i++) {
        exprtoken *token = es->tokens.items[i];

        if (token->token_type == EXPR_TOKEN_EOF) break;

        /* Handle values (numbers, strings, tuples, selectors, null):
         * they become "push value" instructions in the program. */
        if (token->token_type == EXPR_TOKEN_NUM ||
            token->token_type == EXPR_TOKEN_STR ||
            token->token_type == EXPR_TOKEN_TUPLE ||
            token->token_type == EXPR_TOKEN_SELECTOR ||
            token->token_type == EXPR_TOKEN_NULL)
        {
            exprStackPush(&es->program, token);
            exprTokenRetain(token);
            stack_items++;
            continue;
        }

        /* Handle operators. */
        if (token->token_type == EXPR_TOKEN_OP) {
            if (exprProcessOperator(es, token, &stack_items, errpos)) {
                exprFree(es);
                return NULL;
            }
            continue;
        }
    }

    /* Process remaining operators on the stack. */
    while (es->ops_stack.numitems > 0) {
        exprtoken *op = exprStackPop(&es->ops_stack);
        if (op->opcode == EXPR_OP_OPAREN) {
            /* A leftover '(' means the parentheses were unbalanced. */
            if (errpos) *errpos = op->offset;
            exprTokenRelease(op);
            exprFree(es);
            return NULL;
        }

        int arity = exprGetOpArity(op->opcode);
        if (stack_items < arity) {
            /* Not enough operands would be available at runtime. */
            if (errpos) *errpos = op->offset;
            exprTokenRelease(op);
            exprFree(es);
            return NULL;
        }

        exprStackPush(&es->program, op);
        stack_items = stack_items - arity + 1;
    }

    /* Verify that exactly one value would remain on the stack after
     * execution. We could also check that such value is a number, but this
     * would make the code more complex without much gains. */
    if (stack_items != 1) {
        if (errpos) {
            /* Point to the last token's offset for error reporting. */
            exprtoken *last = es->tokens.items[es->tokens.numitems - 1];
            *errpos = last->offset;
        }
        exprFree(es);
        return NULL;
    }
    return es;
}
681
682/* ============================ Expression execution ======================== */
683
684/* Convert a token to its numeric value. For strings we attempt to parse them
685 * as numbers, returning 0 if conversion fails. */
686double exprTokenToNum(exprtoken *t) {
687 char buf[256];
688 if (t->token_type == EXPR_TOKEN_NUM) {
689 return t->num;
690 } else if (t->token_type == EXPR_TOKEN_STR && t->str.len < sizeof(buf)) {
691 memcpy(buf, t->str.start, t->str.len);
692 buf[t->str.len] = '\0';
693 char *endptr;
694 double val = strtod(buf, &endptr);
695 return *endptr == '\0' ? val : 0;
696 } else {
697 return 0;
698 }
699}
700
701/* Convert object to true/false (0 or 1) */
702double exprTokenToBool(exprtoken *t) {
703 if (t->token_type == EXPR_TOKEN_NUM) {
704 return t->num != 0;
705 } else if (t->token_type == EXPR_TOKEN_STR && t->str.len == 0) {
706 return 0; // Empty string are false, like in Javascript.
707 } else if (t->token_type == EXPR_TOKEN_NULL) {
708 return 0; // Null is surely more false than true...
709 } else {
710 return 1; // Every non numerical type is true.
711 }
712}
713
714/* Compare two tokens. Returns true if they are equal. */
715int exprTokensEqual(exprtoken *a, exprtoken *b) {
716 // If both are strings, do string comparison.
717 if (a->token_type == EXPR_TOKEN_STR && b->token_type == EXPR_TOKEN_STR) {
718 return a->str.len == b->str.len &&
719 memcmp(a->str.start, b->str.start, a->str.len) == 0;
720 }
721
722 // If both are numbers, do numeric comparison.
723 if (a->token_type == EXPR_TOKEN_NUM && b->token_type == EXPR_TOKEN_NUM) {
724 return a->num == b->num;
725 }
726
727 /* If one of the two is null, the expression is true only if
728 * both are null. */
729 if (a->token_type == EXPR_TOKEN_NULL || b->token_type == EXPR_TOKEN_NULL) {
730 return a->token_type == b->token_type;
731 }
732
733 // Mixed types - convert to numbers and compare.
734 return exprTokenToNum(a) == exprTokenToNum(b);
735}
736
737/* Return true if the string a is a substring of b. */
738int exprTokensStringIn(exprtoken *a, exprtoken *b) {
739 RedisModule_Assert(a->token_type == EXPR_TOKEN_STR &&
740 b->token_type == EXPR_TOKEN_STR);
741 if (a->str.len > b->str.len) return 0; // A is bigger, can't be a substring.
742 for (size_t i = 0; i <= b->str.len - a->str.len; i++) {
743 if (memcmp(b->str.start+i,a->str.start,a->str.len) == 0) return 1;
744 }
745 return 0;
746}
747
748#include "fastjson.c" // JSON parser implementation used by exprRun().
749
/* Execute the compiled expression program against the given JSON
 * document ('json' is not required to be NUL-terminated: 'json_len'
 * bytes are considered). Returns 1 if the final stack value evaluates
 * to true, 0 otherwise. Also returns 0 if any selector cannot be
 * resolved against the JSON object.
 *
 * The program is a postfix (RPN) sequence: values are pushed on the
 * values stack, operators pop their operands and push a numeric
 * result. Operand availability was validated at compile time, so no
 * underflow checks are needed here. */
int exprRun(exprstate *es, char *json, size_t json_len) {
    /* The state object is reusable: drop values possibly left on the
     * stack by a previous run. */
    exprStackReset(&es->values_stack);

    // Execute each instruction in the program.
    for (int i = 0; i < es->program.numitems; i++) {
        exprtoken *t = es->program.items[i];

        // Handle selectors by extracting the field from the JSON object.
        if (t->token_type == EXPR_TOKEN_SELECTOR) {
            exprtoken *obj = NULL;
            if (t->str.len > 0)
                obj = jsonExtractField(json,json_len,t->str.start,t->str.len);

            // Selector not found or JSON object not convertible to
            // expression tokens. Evaluate the expression to false.
            if (obj == NULL) return 0;
            exprStackPush(&es->values_stack, obj);
            continue;
        }

        // Push non-operator values directly onto the stack.
        if (t->token_type != EXPR_TOKEN_OP) {
            exprStackPush(&es->values_stack, t);
            exprTokenRetain(t);
            continue;
        }

        // Handle operators. The result of an operator is always a
        // numeric token (0/1 for the boolean ones).
        exprtoken *result = exprNewToken(EXPR_TOKEN_NUM);

        // Pop operands - we know we have enough from compile-time checks.
        exprtoken *b = exprStackPop(&es->values_stack);
        exprtoken *a = NULL;
        if (exprGetOpArity(t->opcode) == 2) {
            a = exprStackPop(&es->values_stack);
        }

        switch(t->opcode) {
        case EXPR_OP_NOT:
            result->num = exprTokenToBool(b) == 0 ? 1 : 0;
            break;
        case EXPR_OP_POW: {
            double base = exprTokenToNum(a);
            double exp = exprTokenToNum(b);
            result->num = pow(base, exp);
            break;
        }
        case EXPR_OP_MULT:
            result->num = exprTokenToNum(a) * exprTokenToNum(b);
            break;
        case EXPR_OP_DIV:
            /* Division by zero yields IEEE inf/nan: by design there
             * are no runtime errors in this VM. */
            result->num = exprTokenToNum(a) / exprTokenToNum(b);
            break;
        case EXPR_OP_MOD: {
            double va = exprTokenToNum(a);
            double vb = exprTokenToNum(b);
            result->num = fmod(va, vb);
            break;
        }
        case EXPR_OP_SUM:
            result->num = exprTokenToNum(a) + exprTokenToNum(b);
            break;
        case EXPR_OP_DIFF:
            result->num = exprTokenToNum(a) - exprTokenToNum(b);
            break;
        case EXPR_OP_GT:
            result->num = exprTokenToNum(a) > exprTokenToNum(b) ? 1 : 0;
            break;
        case EXPR_OP_GTE:
            result->num = exprTokenToNum(a) >= exprTokenToNum(b) ? 1 : 0;
            break;
        case EXPR_OP_LT:
            result->num = exprTokenToNum(a) < exprTokenToNum(b) ? 1 : 0;
            break;
        case EXPR_OP_LTE:
            result->num = exprTokenToNum(a) <= exprTokenToNum(b) ? 1 : 0;
            break;
        case EXPR_OP_EQ:
            result->num = exprTokensEqual(a, b) ? 1 : 0;
            break;
        case EXPR_OP_NEQ:
            result->num = !exprTokensEqual(a, b) ? 1 : 0;
            break;
        case EXPR_OP_IN: {
            /* For 'in' operator, b must be a tuple, and we check for
             * membership. Otherwise both a and b must be strings, and
             * in this case we check if a is a substring of b. */
            result->num = 0; // Default to false.
            if (b->token_type == EXPR_TOKEN_TUPLE) {
                for (size_t j = 0; j < b->tuple.len; j++) {
                    if (exprTokensEqual(a, b->tuple.ele[j])) {
                        result->num = 1; // Found a match.
                        break;
                    }
                }
            } else if (a->token_type == EXPR_TOKEN_STR &&
                       b->token_type == EXPR_TOKEN_STR)
            {
                result->num = exprTokensStringIn(a,b);
            }
            break;
        }
        case EXPR_OP_AND:
            result->num =
                exprTokenToBool(a) != 0 && exprTokenToBool(b) != 0 ? 1 : 0;
            break;
        case EXPR_OP_OR:
            result->num =
                exprTokenToBool(a) != 0 || exprTokenToBool(b) != 0 ? 1 : 0;
            break;
        default:
            // Do nothing: we don't want runtime errors.
            break;
        }

        // Free operands and push result.
        if (a) exprTokenRelease(a);
        exprTokenRelease(b);
        exprStackPush(&es->values_stack, result);
    }

    // Get final result from stack.
    exprtoken *final = exprStackPop(&es->values_stack);
    if (final == NULL) return 0;

    // Convert result to boolean.
    int retval = exprTokenToBool(final);
    exprTokenRelease(final);
    return retval;
}
883
884/* ============================ Simple test main ============================ */
885
886#ifdef TEST_MAIN
887#include "fastjson_test.c"
888
889void exprPrintToken(exprtoken *t) {
890 switch(t->token_type) {
891 case EXPR_TOKEN_EOF:
892 printf("EOF");
893 break;
894 case EXPR_TOKEN_NUM:
895 printf("NUM:%g", t->num);
896 break;
897 case EXPR_TOKEN_STR:
898 printf("STR:\"%.*s\"", (int)t->str.len, t->str.start);
899 break;
900 case EXPR_TOKEN_SELECTOR:
901 printf("SEL:%.*s", (int)t->str.len, t->str.start);
902 break;
903 case EXPR_TOKEN_OP:
904 printf("OP:");
905 for (int i = 0; ExprOptable[i].opname != NULL; i++) {
906 if (ExprOptable[i].opcode == t->opcode) {
907 printf("%s", ExprOptable[i].opname);
908 break;
909 }
910 }
911 break;
912 default:
913 printf("UNKNOWN");
914 break;
915 }
916}
917
918void exprPrintStack(exprstack *stack, const char *name) {
919 printf("%s (%d items):", name, stack->numitems);
920 for (int j = 0; j < stack->numitems; j++) {
921 printf(" ");
922 exprPrintToken(stack->items[j]);
923 }
924 printf("\n");
925}
926
927int main(int argc, char **argv) {
928 /* Check for JSON parser test mode. */
929 if (argc >= 2 && strcmp(argv[1], "--test-json-parser") == 0) {
930 run_fastjson_test();
931 return 0;
932 }
933
934 char *testexpr = "(5+2)*3 and .year > 1980 and 'foo' == 'foo'";
935 char *testjson = "{\"year\": 1984, \"name\": \"The Matrix\"}";
936 if (argc >= 2) testexpr = argv[1];
937 if (argc >= 3) testjson = argv[2];
938
939 printf("Compiling expression: %s\n", testexpr);
940
941 int errpos = 0;
942 exprstate *es = exprCompile(testexpr,&errpos);
943 if (es == NULL) {
944 printf("Compilation failed near \"...%s\"\n", testexpr+errpos);
945 return 1;
946 }
947
948 exprPrintStack(&es->tokens, "Tokens");
949 exprPrintStack(&es->program, "Program");
950 printf("Running against object: %s\n", testjson);
951 int result = exprRun(es,testjson,strlen(testjson));
952 printf("Result1: %s\n", result ? "True" : "False");
953 result = exprRun(es,testjson,strlen(testjson));
954 printf("Result2: %s\n", result ? "True" : "False");
955
956 exprFree(es);
957 return 0;
958}
959#endif
diff --git a/examples/redis-unstable/modules/vector-sets/fastjson.c b/examples/redis-unstable/modules/vector-sets/fastjson.c
new file mode 100644
index 0000000..78926e2
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/fastjson.c
@@ -0,0 +1,441 @@
1/* Ultra‑lightweight top‑level JSON field extractor.
2 * Return the element directly as an expr.c token.
3 * This code is directly included inside expr.c.
4 *
5 * Copyright (c) 2025-Present, Redis Ltd.
6 * All rights reserved.
7 *
8 * Licensed under your choice of the Redis Source Available License 2.0
9 * (RSALv2) or the Server Side Public License v1 (SSPLv1).
10 *
11 * Originally authored by: Salvatore Sanfilippo.
12 *
13 * ------------------------------------------------------------------
14 *
15 * DESIGN GOALS:
16 *
17 * 1. Zero heap allocations while seeking the requested key.
18 * 2. A single parse (and therefore a single allocation, if needed)
19 * when the key finally matches.
20 * 3. Same subset‑of‑JSON coverage needed by expr.c:
21 * - Strings (escapes: \" \\ \n \r \t).
22 * - Numbers (double).
23 * - Booleans.
24 * - Null.
25 * - Flat arrays of the above primitives.
26 *
27 * Any other value (nested object, unicode escape, etc.) returns NULL.
28 * Should be very easy to extend it in case in the future we want
29 * more for the FILTER option of VSIM.
30 * 4. No global state, so this file can be #included directly in expr.c.
31 *
32 * The only API expr.c uses directly is:
33 *
34 * exprtoken *jsonExtractField(const char *json, size_t json_len,
35 * const char *field, size_t field_len);
36 * ------------------------------------------------------------------ */
37
38#include <ctype.h>
39#include <string.h>
40
41// Forward declarations.
42static int jsonSkipValue(const char **p, const char *end);
43static exprtoken *jsonParseValueToken(const char **p, const char *end);
44
/* Similar to ctype.h isdigit() but covers the whole JSON number charset,
 * including sign, decimal point and the exponent markers.
 *
 * The (unsigned char) cast avoids undefined behavior when the caller
 * passes a value derived from a plain 'char' that happens to be
 * negative: the ctype functions are only defined for unsigned char
 * values and EOF. */
static int jsonIsNumberChar(int c) {
    return isdigit((unsigned char)c) ||
           c=='-' || c=='+' || c=='.' || c=='e' || c=='E';
}
50
51/* ========================== Fast skipping of JSON =========================
52 * The helpers here are designed to skip values without performing any
53 * allocation. This way, for the use case of this JSON parser, we are able
54 * to easily (and with good speed) skip fields and values we are not
55 * interested in. Then, later in the code, when we find the field we want
56 * to obtain, we finally call the functions that turn a given JSON value
 * associated to a field into one of our expression tokens.
58 * ========================================================================== */
59
/* Advance *p past any run of whitespace, never moving beyond 'end'. */
static inline void jsonSkipWhiteSpaces(const char **p, const char *end) {
    const char *q = *p;
    while (q < end && isspace((unsigned char)*q)) q++;
    *p = q;
}
64
/* Advance *p past a complete JSON string, both quotes included.
 * Returns 1 on success, 0 when *p does not point at an opening quote
 * or the string is unterminated. Escape sequences are consumed as
 * backslash + following character, so an escaped quote does not end
 * the string. */
static int jsonSkipString(const char **p, const char *end) {
    if (*p >= end || **p != '"') return 0;
    (*p)++; /* Consume the opening quote. */
    while (*p < end) {
        char c = **p;
        if (c == '\\') {
            (*p) += 2; /* Escape: skip backslash and escaped char. */
        } else if (c == '"') {
            (*p)++; /* Consume the closing quote. */
            return 1;
        } else {
            (*p)++;
        }
    }
    return 0; /* Unterminated string. */
}
82
/* Skip a bracketed aggregate generically using a depth counter.
 * 'opener' and 'closer' select the bracket pair being tracked,
 * basically [] or {}. Strings are skipped as a unit so that bracket
 * characters inside them are ignored. *p must point at the opener on
 * entry. Returns 1 when the matching closer was found, 0 on a parse
 * error (unterminated string or missing closer). */
static int jsonSkipBracketed(const char **p, const char *end,
                             char opener, char closer) {
    int depth = 1;
    (*p)++; /* Consume the opener itself. */

    /* Loop until the input ends or the matching closer brings the
     * depth back to zero. */
    while (depth > 0 && *p < end) {
        char c = **p;

        if (c == '"') {
            /* Delegate string skipping: brackets inside strings must
             * not affect the depth. jsonSkipString() leaves *p just
             * past the closing quote. */
            if (!jsonSkipString(p, end)) return 0;
            continue;
        }

        /* Only the tracked bracket pair changes the depth. */
        if (c == opener) depth++;
        else if (c == closer) depth--;

        /* Every non-string character is consumed here: commas, colons,
         * whitespace, numbers, literals, and even brackets of a
         * *different* kind than the one being skipped (e.g. a '{'
         * while skipping a '[...]'). */
        (*p)++;
    }

    /* True only if the matching closer was actually found. */
    return depth == 0;
}
125
/* Consume the fixed literal 'lit' (e.g. "true", "null", ...) at *p.
 * Returns 1 and advances *p past it on a match; returns 0 and leaves
 * *p untouched otherwise. */
static int jsonSkipLiteral(const char **p, const char *end, const char *lit) {
    size_t litlen = strlen(lit);
    if (*p + litlen > end) return 0;       /* Would read past 'end'. */
    if (memcmp(*p, lit, litlen)) return 0; /* No match. */
    *p += litlen;
    return 1;
}
134
/* Greedily consume number-looking characters without validating the
 * number format: for skipping purposes, consuming plausible
 * characters is enough. Returns 1 if at least one character was
 * consumed, 0 when no number was found at *p. */
static int jsonSkipNumber(const char **p, const char *end) {
    const char *begin = *p;
    while (*p < end && jsonIsNumberChar(**p)) (*p)++;
    return *p != begin; /* Any progress made? */
}
145
/* Skip one JSON value of any supported type, dispatching on the first
 * non-space character. Returns 1 on success, 0 on error. */
static int jsonSkipValue(const char **p, const char *end) {
    jsonSkipWhiteSpaces(p, end);
    if (*p >= end) return 0;
    char c = **p;
    if (c == '"') return jsonSkipString(p, end);
    if (c == '{') return jsonSkipBracketed(p, end, '{', '}');
    if (c == '[') return jsonSkipBracketed(p, end, '[', ']');
    if (c == 't') return jsonSkipLiteral(p, end, "true");
    if (c == 'f') return jsonSkipLiteral(p, end, "false");
    if (c == 'n') return jsonSkipLiteral(p, end, "null");
    return jsonSkipNumber(p, end); /* Anything else must be a number. */
}
160
161/* =========================== JSON to exprtoken ============================
162 * The functions below convert a given json value to the equivalent
163 * expression token structure.
164 * ========================================================================== */
165
/* Parse the JSON string at *p into an EXPR_TOKEN_STR token.
 * Returns NULL on a missing opening quote or an unterminated string;
 * on success *p is advanced past the closing quote.
 *
 * Fast path: when the string contains no escapes, the token points
 * directly into the original JSON buffer (zero copy, heapstr NULL).
 * Otherwise a heap buffer is allocated and escapes are decoded while
 * copying. */
static exprtoken *jsonParseStringToken(const char **p, const char *end) {
    if (*p >= end || **p != '"') return NULL;
    const char *start = ++(*p);
    /* First pass: measure the decoded length and detect escapes.
     * Note that an escape pair contributes one decoded character. */
    int esc = 0; size_t len = 0; int has_esc = 0;
    const char *q = *p;
    while (q < end) {
        if (esc) { esc = 0; q++; len++; has_esc = 1; continue; }
        if (*q == '\\') { esc = 1; q++; continue; }
        if (*q == '"') break;
        q++; len++;
    }
    if (q >= end || *q != '"') return NULL; // Unterminated string
    exprtoken *t = exprNewToken(EXPR_TOKEN_STR);

    if (!has_esc) {
        // No escapes, we can point directly into the original JSON string.
        t->str.start = (char*)start; t->str.len = len; t->str.heapstr = NULL;
    } else {
        // Escapes present, need to allocate and copy/process escapes.
        char *dst = RedisModule_Alloc(len + 1);

        t->str.start = t->str.heapstr = dst; t->str.len = len;
        const char *r = start; esc = 0;
        /* Second pass: copy [start, q) decoding the supported escapes. */
        while (r < q) {
            if (esc) {
                switch (*r) {
                // Supported escapes from Goal 3.
                case 'n': *dst='\n'; break;
                case 'r': *dst='\r'; break;
                case 't': *dst='\t'; break;
                case '\\': *dst='\\'; break;
                case '"': *dst='\"'; break;
                // Escapes (like \uXXXX, \b, \f) are not supported for now,
                // we just copy them verbatim.
                default: *dst=*r; break;
                }
                dst++; esc = 0; r++; continue;
            }
            if (*r == '\\') { esc = 1; r++; continue; }
            *dst++ = *r++;
        }
        *dst = '\0'; // Null-terminate the allocated string.
    }
    *p = q + 1; // Advance the main pointer past the closing quote.
    return t;
}
212
213static exprtoken *jsonParseNumberToken(const char **p, const char *end) {
214 // Use a buffer to extract the number literal for parsing with strtod().
215 char buf[256]; int idx = 0;
216 const char *start = *p; // For strtod partial failures check.
217
218 // Copy potential number characters to buffer.
219 while (*p < end && idx < (int)sizeof(buf)-1 && jsonIsNumberChar(**p)) {
220 buf[idx++] = **p;
221 (*p)++;
222 }
223 buf[idx]='\0'; // Null-terminate buffer.
224
225 if (idx==0) return NULL; // No number characters found.
226
227 char *ep; // End pointer for strtod validation.
228 double v = strtod(buf, &ep);
229
230 /* Check if strtod() consumed the entire buffer content.
231 * If not, the number format was invalid. */
232 if (*ep!='\0') {
233 // strtod() failed; rewind p to the start and return NULL
234 *p = start;
235 return NULL;
236 }
237
238 // If strtod() succeeded, create and return the token..
239 exprtoken *t = exprNewToken(EXPR_TOKEN_NUM);
240 t->num = v;
241 return t;
242}
243
244static exprtoken *jsonParseLiteralToken(const char **p, const char *end, const char *lit, int type, double num) {
245 size_t l = strlen(lit);
246
247 // Ensure we don't read past 'end'.
248 if ((*p + l) > end) return NULL;
249
250 if (strncmp(*p, lit, l) != 0) return NULL; // Literal doesn't match.
251
252 // Check that the character *after* the literal is a valid JSON delimiter
253 // (whitespace, comma, closing bracket/brace, or end of input)
254 // This prevents matching "trueblabla" as "true".
255 if ((*p + l) < end) {
256 char next_char = *(*p + l);
257 if (!isspace((unsigned char)next_char) && next_char!=',' &&
258 next_char!=']' && next_char!='}') {
259 return NULL; // Invalid character following literal.
260 }
261 }
262
263 // Literal matched and is correctly terminated.
264 *p += l;
265 exprtoken *t = exprNewToken(type);
266 t->num = num;
267 return t;
268}
269
/* Parse the flat JSON array at *p into an EXPR_TOKEN_TUPLE token whose
 * elements are themselves tokens. Returns NULL on malformed input or
 * unsupported element types, releasing any partially built tuple.
 * On success *p is advanced past the closing ']'. */
static exprtoken *jsonParseArrayToken(const char **p, const char *end) {
    if (*p >= end || **p != '[') return NULL;
    (*p)++; // Skip '['.
    jsonSkipWhiteSpaces(p,end);

    /* The tuple starts empty; element storage is grown on demand. */
    exprtoken *t = exprNewToken(EXPR_TOKEN_TUPLE);
    t->tuple.len = 0; t->tuple.ele = NULL; size_t alloc = 0;

    // Handle empty array [].
    if (*p < end && **p == ']') {
        (*p)++; // Skip ']'.
        return t;
    }

    // Parse array elements.
    while (1) {
        exprtoken *ele = jsonParseValueToken(p,end);
        if (!ele) {
            exprTokenRelease(t); // Clean up partially built array token.
            return NULL;
        }

        // Grow allocated space for elements if needed (doubling,
        // starting from 4 slots).
        if (t->tuple.len == alloc) {
            size_t newsize = alloc ? alloc * 2 : 4;
            // Check for potential overflow if newsize becomes huge.
            if (newsize < alloc) {
                exprTokenRelease(ele);
                exprTokenRelease(t);
                return NULL;
            }
            exprtoken **newele = RedisModule_Realloc(t->tuple.ele,
                                                     sizeof(exprtoken*)*newsize);
            t->tuple.ele = newele;
            alloc = newsize;
        }
        t->tuple.ele[t->tuple.len++] = ele; // Add element (ownership moves).

        jsonSkipWhiteSpaces(p,end);
        if (*p>=end) {
            // Unterminated array. Note that this check is crucial because
            // previous value parsed may seek 'p' to 'end'.
            exprTokenRelease(t);
            return NULL;
        }

        // Check for comma (more elements) or closing bracket.
        if (**p == ',') {
            (*p)++; // Skip ','
            jsonSkipWhiteSpaces(p,end); // Skip whitespace before next element
            continue; // Parse next element
        } else if (**p == ']') {
            (*p)++; // Skip ']'
            return t; // End of array
        } else {
            // Unexpected character (not ',' or ']')
            exprTokenRelease(t);
            return NULL;
        }
    }
}
331
332/* Turn a JSON value into an expr token. */
333static exprtoken *jsonParseValueToken(const char **p, const char *end) {
334 jsonSkipWhiteSpaces(p,end);
335 if (*p >= end) return NULL;
336
337 switch (**p) {
338 case '"': return jsonParseStringToken(p,end);
339 case '[': return jsonParseArrayToken(p,end);
340 case '{': return NULL; // No nested elements support for now.
341 case 't': return jsonParseLiteralToken(p,end,"true",EXPR_TOKEN_NUM,1);
342 case 'f': return jsonParseLiteralToken(p,end,"false",EXPR_TOKEN_NUM,0);
343 case 'n': return jsonParseLiteralToken(p,end,"null",EXPR_TOKEN_NULL,0);
344 default:
345 // Check if it starts like a number.
346 if (isdigit((unsigned char)**p) || **p=='-' || **p=='+') {
347 return jsonParseNumberToken(p,end);
348 }
349 // Anything else is an unsupported type or malformed JSON.
350 return NULL;
351 }
352}
353
354/* ============================== Fast key seeking ========================== */
355
/* Finds the start of the value for a given field key within a JSON object.
 * Returns pointer to the first char of the value, or NULL if not found/error.
 * This function does not perform any allocation and is optimized to seek
 * the specified *toplevel* field as fast as possible. */
static const char *jsonSeekField(const char *json, const char *end,
                                 const char *field, size_t flen) {
    const char *p = json;
    jsonSkipWhiteSpaces(&p,end);
    if (p >= end || *p != '{') return NULL; // Must start with '{'.
    p++; // skip '{'.

    /* Iterate key/value pairs until the field matches, the object
     * ends, or a syntax error is detected. */
    while (1) {
        jsonSkipWhiteSpaces(&p,end);
        if (p >= end) return NULL; // Reached end within object.

        if (*p == '}') return NULL; // End of object, field not found.

        // Expecting a key (string).
        if (*p != '"') return NULL; // Key must be a string.

        // --- Key Matching using jsonSkipString ---
        const char *key_start = p + 1; // Start of key content.
        const char *key_end_p = p; // Will later contain the end.

        // Use jsonSkipString() to find the end.
        if (!jsonSkipString(&key_end_p, end)) {
            // Unterminated / invalid key string.
            return NULL;
        }

        // Calculate the length of the key's content (quotes excluded).
        size_t klen = (key_end_p - 1) - key_start;

        /* Perform the comparison using the raw key content.
         * WARNING: This uses memcmp(), so we don't handle escaped chars
         * within the key matching against unescaped chars in 'field'. */
        int match = klen == flen && !memcmp(key_start, field, flen);

        // Update the main pointer 'p' to be after the key string.
        p = key_end_p;

        // Now we expect to find a ":" followed by a value.
        jsonSkipWhiteSpaces(&p,end);
        if (p>=end || *p!=':') return NULL; // Expect ':' after key
        p++; // Skip ':'.

        // Seek value.
        jsonSkipWhiteSpaces(&p,end);
        if (p>=end) return NULL; // Expect value after ':'

        if (match) {
            // Found the matching key, p now points to the start of the value.
            return p;
        } else {
            // Key didn't match, skip the corresponding value.
            if (!jsonSkipValue(&p,end)) return NULL; // Syntax error.
        }


        // Look for comma or a closing brace.
        jsonSkipWhiteSpaces(&p,end);
        if (p>=end) return NULL; // Reached end after value.

        if (*p == ',') {
            p++; // Skip comma, continue loop to find next key.
            continue;
        } else if (*p == '}') {
            return NULL; // Reached end of object, field not found.
        }
        return NULL; // Malformed JSON (unexpected char after value).
    }
}
428
429/* This is the only real API that this file conceptually exports (it is
430 * inlined, actually). */
431exprtoken *jsonExtractField(const char *json, size_t json_len,
432 const char *field, size_t field_len)
433{
434 const char *end = json + json_len;
435 const char *valptr = jsonSeekField(json,end,field,field_len);
436 if (!valptr) return NULL;
437
438 /* Key found, valptr points to the start of the value.
439 * Convert it into an expression token object. */
440 return jsonParseValueToken(&valptr,end);
441}
diff --git a/examples/redis-unstable/modules/vector-sets/fastjson_test.c b/examples/redis-unstable/modules/vector-sets/fastjson_test.c
new file mode 100644
index 0000000..1ea76a9
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/fastjson_test.c
@@ -0,0 +1,406 @@
1/* fastjson_test.c - Stress test for fastjson.c
2 *
3 * This performs boundary and corruption tests to ensure
4 * the JSON parser handles edge cases without accessing
5 * memory outside the bounds of the input.
6 */
7
8#include <stdio.h>
9#include <stdlib.h>
10#include <string.h>
11#include <unistd.h>
12#include <signal.h>
13#include <time.h>
14#include <sys/mman.h>
15#include <sys/types.h>
16#include <fcntl.h>
17#include <errno.h>
18#include <setjmp.h>
19
/* Page size constant - typically 4096 or 16k bytes (Apple Silicon).
 * We use 16k so that it will work on both, but not with Linux huge pages.
 * NOTE(review): the expansion is unparenthesized (4096*4); safe for the
 * usages in this file, but parenthesize before using it in new expressions. */
#define PAGE_SIZE 4096*4
#define MAX_JSON_SIZE (PAGE_SIZE - 128) /* Keep some margin */
#define MAX_FIELD_SIZE 64
#define NUM_TEST_ITERATIONS 100000
#define NUM_CORRUPTION_TESTS 10000
#define NUM_BOUNDARY_TESTS 10000

/* Test state tracking. Globals shared by the test phases below and by
 * the SIGSEGV/SIGBUS handler. */
static char *safe_page = NULL;    /* Start of readable/writable page */
static char *unsafe_page = NULL;  /* Start of inaccessible guard page */
static int boundary_violation = 0; /* Set by the signal handler on OOB access */
static jmp_buf jmpbuf;            /* longjmp target used by safe_extract_field() */
static int tests_passed = 0;           /* Phase 1 successes (two checks/iter). */
static int tests_failed = 0;           /* Failures across all phases. */
static int corruptions_passed = 0;     /* Phase 2 successes. */
static int boundary_tests_passed = 0;  /* Phase 3 successes. */

/* Test metadata for tracking.
 * NOTE(review): this typedef appears unused in the visible code —
 * candidate for removal after confirming no other references. */
typedef struct {
    char *json;
    size_t json_len;
    char field[MAX_FIELD_SIZE];
    size_t field_len;
    int expected_result;
} test_case_t;
47
48/* Forward declarations for test JSON generation */
49char *generate_random_json(size_t *len, char *field, size_t *field_len, int *has_field);
50void corrupt_json(char *json, size_t len);
51void setup_test_memory(void);
52void cleanup_test_memory(void);
53void run_normal_tests(void);
54void run_corruption_tests(void);
55void run_boundary_tests(void);
56void print_test_summary(void);
57
/* Signal handler for segmentation violations.
 * Installed by setup_test_memory() for SIGSEGV and SIGBUS: records that
 * the parser touched memory it should not have, then longjmps back to
 * the setjmp() point inside safe_extract_field(), aborting the parse.
 * NOTE(review): printf() and longjmp() from a signal handler are not
 * async-signal-safe in general; tolerable here because the faulting
 * code is deterministic, single-threaded test code. */
static void sigsegv_handler(int sig) {
    boundary_violation = 1;
    printf("Boundary violation detected! Caught signal %d\n", sig);
    longjmp(jmpbuf, 1);
}
64
65/* Wrapper for jsonExtractField to check for boundary violations */
66exprtoken *safe_extract_field(const char *json, size_t json_len,
67 const char *field, size_t field_len) {
68 boundary_violation = 0;
69
70 if (setjmp(jmpbuf) == 0) {
71 return jsonExtractField(json, json_len, field, field_len);
72 } else {
73 return NULL; /* Return NULL if boundary violation occurred */
74 }
75}
76
77/* Setup two adjacent memory pages - one readable/writable, one inaccessible */
78void setup_test_memory(void) {
79 /* Request a page of memory, with specific alignment. We rely on the
80 * fact that hopefully the page after that will cause a segfault if
81 * accessed. */
82 void *region = mmap(NULL, PAGE_SIZE,
83 PROT_READ | PROT_WRITE,
84 MAP_PRIVATE | MAP_ANONYMOUS,
85 -1, 0);
86
87 if (region == MAP_FAILED) {
88 perror("mmap failed");
89 exit(EXIT_FAILURE);
90 }
91
92 safe_page = (char*)region;
93 unsafe_page = safe_page + PAGE_SIZE;
94 // Uncomment to make sure it crashes :D
95 // printf("%d\n", unsafe_page[5]);
96
97 /* Set up signal handlers for memory access violations */
98 struct sigaction sa;
99 sa.sa_handler = sigsegv_handler;
100 sigemptyset(&sa.sa_mask);
101 sa.sa_flags = 0;
102
103 sigaction(SIGSEGV, &sa, NULL);
104 sigaction(SIGBUS, &sa, NULL);
105}
106
107void cleanup_test_memory(void) {
108 if (safe_page != NULL) {
109 munmap(safe_page, PAGE_SIZE);
110 safe_page = NULL;
111 unsafe_page = NULL;
112 }
113}
114
/* Fill 'buffer' with a random alphanumeric, NUL-terminated string whose
 * length is between 1 and max_len-2 inclusive (always fits the buffer). */
void generate_random_string(char *buffer, size_t max_len) {
    static const char charset[] = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
    const size_t nchars = sizeof(charset) - 1; /* Exclude terminator. */
    size_t len = 1 + rand() % (max_len - 2);   /* At least one char. */

    size_t i = 0;
    while (i < len) {
        buffer[i] = charset[rand() % nchars];
        i++;
    }
    buffer[len] = '\0';
}
125
/* Write a random numeric literal into 'buffer' (at most max_len bytes):
 * magnitude in [0, ~1001), ~20% negative, ~66% with a fractional part. */
void generate_random_number(char *buffer, size_t max_len) {
    double value = (double)rand() / RAND_MAX * 1000.0;

    if (rand() % 5 == 0) value = -value;
    if (rand() % 3 != 0) value += (double)(rand() % 100) / 100.0;

    snprintf(buffer, max_len, "%.6g", value);
}
136
/* Generate a random field name (alphanumeric, at most MAX_FIELD_SIZE/2 - 2
 * chars, see generate_random_string) and report its length via *field_len. */
void generate_random_field(char *field, size_t *field_len) {
    generate_random_string(field, MAX_FIELD_SIZE / 2);
    *field_len = strlen(field);
}
142
143/* Generate a random JSON object with fields */
144char *generate_random_json(size_t *len, char *field, size_t *field_len, int *has_field) {
145 char *json = malloc(MAX_JSON_SIZE);
146 if (json == NULL) {
147 perror("malloc");
148 exit(EXIT_FAILURE);
149 }
150
151 char buffer[MAX_JSON_SIZE / 4]; /* Buffer for generating values */
152 int pos = 0;
153 int num_fields = 1 + rand() % 10; /* Random number of fields */
154 int target_field_index = rand() % num_fields; /* Which field to return */
155
156 /* Start the JSON object */
157 pos += snprintf(json + pos, MAX_JSON_SIZE - pos, "{");
158
159 /* Generate random field/value pairs */
160 for (int i = 0; i < num_fields; i++) {
161 /* Add a comma if not the first field */
162 if (i > 0) {
163 pos += snprintf(json + pos, MAX_JSON_SIZE - pos, ", ");
164 }
165
166 /* Generate a field name */
167 if (i == target_field_index) {
168 /* This is our target field - save it for the caller */
169 generate_random_field(field, field_len);
170 pos += snprintf(json + pos, MAX_JSON_SIZE - pos, "\"%s\": ", field);
171 *has_field = 1;
172 /* Sometimes change the last char so that it will not match. */
173 if (rand() % 2) {
174 *has_field = 0;
175 field[*field_len-1] = '!';
176 }
177 } else {
178 generate_random_string(buffer, MAX_FIELD_SIZE / 4);
179 pos += snprintf(json + pos, MAX_JSON_SIZE - pos, "\"%s\": ", buffer);
180 }
181
182 /* Generate a random value type */
183 int value_type = rand() % 5;
184 switch (value_type) {
185 case 0: /* String */
186 generate_random_string(buffer, MAX_JSON_SIZE / 8);
187 pos += snprintf(json + pos, MAX_JSON_SIZE - pos, "\"%s\"", buffer);
188 break;
189
190 case 1: /* Number */
191 generate_random_number(buffer, MAX_JSON_SIZE / 8);
192 pos += snprintf(json + pos, MAX_JSON_SIZE - pos, "%s", buffer);
193 break;
194
195 case 2: /* Boolean: true */
196 pos += snprintf(json + pos, MAX_JSON_SIZE - pos, "true");
197 break;
198
199 case 3: /* Boolean: false */
200 pos += snprintf(json + pos, MAX_JSON_SIZE - pos, "false");
201 break;
202
203 case 4: /* Null */
204 pos += snprintf(json + pos, MAX_JSON_SIZE - pos, "null");
205 break;
206
207 case 5: /* Array (simple) */
208 pos += snprintf(json + pos, MAX_JSON_SIZE - pos, "[");
209 int array_items = 1 + rand() % 5;
210 for (int j = 0; j < array_items; j++) {
211 if (j > 0) pos += snprintf(json + pos, MAX_JSON_SIZE - pos, ", ");
212
213 /* Array items - either number or string */
214 if (rand() % 2) {
215 generate_random_number(buffer, MAX_JSON_SIZE / 16);
216 pos += snprintf(json + pos, MAX_JSON_SIZE - pos, "%s", buffer);
217 } else {
218 generate_random_string(buffer, MAX_JSON_SIZE / 16);
219 pos += snprintf(json + pos, MAX_JSON_SIZE - pos, "\"%s\"", buffer);
220 }
221 }
222 pos += snprintf(json + pos, MAX_JSON_SIZE - pos, "]");
223 break;
224 }
225 }
226
227 /* Close the JSON object */
228 pos += snprintf(json + pos, MAX_JSON_SIZE - pos, "}");
229 *len = pos;
230
231 return json;
232}
233
/* Corrupt 'json' in place by overwriting 1-3 random positions with
 * characters likely to produce interesting malformed JSON. Used to
 * verify that the parser never walks out of bounds on broken input.
 *
 * FIX: the original indexed a 29-character string literal with
 * rand() % 30, so index 29 could select the terminating NUL and write
 * '\0' into the middle of the buffer. The count is now derived from
 * sizeof() so it can never go past the last real character. */
void corrupt_json(char *json, size_t len) {
    if (len < 2) return; /* Too short to corrupt safely */

    static const char corruption_chars[] = " \t\n{}[]\":,0123456789abcdefXYZ";
    const size_t num_chars = sizeof(corruption_chars) - 1; /* Exclude NUL. */

    /* Corrupt 1-3 characters */
    int num_corruptions = 1 + rand() % 3;
    for (int i = 0; i < num_corruptions; i++) {
        size_t pos = rand() % len;
        json[pos] = corruption_chars[rand() % num_chars];
    }
}
246
247/* Run standard parser tests with generated valid JSON */
248void run_normal_tests(void) {
249 printf("Running normal JSON extraction tests...\n");
250
251 for (int i = 0; i < NUM_TEST_ITERATIONS; i++) {
252 char field[MAX_FIELD_SIZE] = {0};
253 size_t field_len = 0;
254 size_t json_len = 0;
255 int has_field = 0;
256
257 /* Generate random JSON */
258 char *json = generate_random_json(&json_len, field, &field_len, &has_field);
259
260 /* Use valid field to test parser */
261 exprtoken *token = safe_extract_field(json, json_len, field, field_len);
262
263 /* Check if we got a token as expected */
264 if (has_field && token != NULL) {
265 exprTokenRelease(token);
266 tests_passed++;
267 } else if (!has_field && token == NULL) {
268 tests_passed++;
269 } else {
270 tests_failed++;
271 }
272
273 /* Test with a non-existent field */
274 char nonexistent_field[MAX_FIELD_SIZE] = "nonexistent_field";
275 token = safe_extract_field(json, json_len, nonexistent_field, strlen(nonexistent_field));
276
277 if (token == NULL) {
278 tests_passed++;
279 } else {
280 exprTokenRelease(token);
281 tests_failed++;
282 }
283
284 free(json);
285 }
286}
287
288/* Run tests with corrupted JSON */
289void run_corruption_tests(void) {
290 printf("Running JSON corruption tests...\n");
291
292 for (int i = 0; i < NUM_CORRUPTION_TESTS; i++) {
293 char field[MAX_FIELD_SIZE] = {0};
294 size_t field_len = 0;
295 size_t json_len = 0;
296 int has_field = 0;
297
298 /* Generate random JSON */
299 char *json = generate_random_json(&json_len, field, &field_len, &has_field);
300
301 /* Make a copy and corrupt it */
302 char *corrupted = malloc(json_len + 1);
303 if (!corrupted) {
304 perror("malloc");
305 free(json);
306 exit(EXIT_FAILURE);
307 }
308
309 memcpy(corrupted, json, json_len + 1);
310 corrupt_json(corrupted, json_len);
311
312 /* Test with corrupted JSON */
313 exprtoken *token = safe_extract_field(corrupted, json_len, field, field_len);
314
315 /* We're just testing that it doesn't crash or access invalid memory */
316 if (boundary_violation) {
317 printf("Boundary violation with corrupted JSON!\n");
318 tests_failed++;
319 } else {
320 if (token != NULL) {
321 exprTokenRelease(token);
322 }
323 corruptions_passed++;
324 }
325
326 free(corrupted);
327 free(json);
328 }
329}
330
331/* Run tests at memory boundaries */
332void run_boundary_tests(void) {
333 printf("Running memory boundary tests...\n");
334
335 for (int i = 0; i < NUM_BOUNDARY_TESTS; i++) {
336 char field[MAX_FIELD_SIZE] = {0};
337 size_t field_len = 0;
338 size_t json_len = 0;
339 int has_field = 0;
340
341 /* Generate random JSON */
342 char *temp_json = generate_random_json(&json_len, field, &field_len, &has_field);
343
344 /* Truncate the JSON to a random length */
345 size_t truncated_len = 1 + rand() % json_len;
346
347 /* Place at the edge of the safe page */
348 size_t offset = PAGE_SIZE - truncated_len;
349 memcpy(safe_page + offset, temp_json, truncated_len);
350
351 /* Test parsing with non-existent field (forcing it to scan to end) */
352 char nonexistent_field[MAX_FIELD_SIZE] = "nonexistent_field";
353 exprtoken *token = safe_extract_field(safe_page + offset, truncated_len,
354 nonexistent_field, strlen(nonexistent_field));
355
356 /* We're just testing that it doesn't access memory beyond the boundary */
357 if (boundary_violation) {
358 printf("Boundary violation at edge of memory page!\n");
359 tests_failed++;
360 } else {
361 if (token != NULL) {
362 exprTokenRelease(token);
363 }
364 boundary_tests_passed++;
365 }
366
367 free(temp_json);
368 }
369}
370
/* Print summary of test results.
 * Note: phase 1 runs two checks per iteration, hence the doubled
 * denominator for the normal tests. The overall verdict is based only
 * on tests_failed, which all three phases increment on failure. */
void print_test_summary(void) {
    printf("\n===== FASTJSON PARSER TEST SUMMARY =====\n");
    printf("Normal tests passed: %d/%d\n", tests_passed, NUM_TEST_ITERATIONS * 2);
    printf("Corruption tests passed: %d/%d\n", corruptions_passed, NUM_CORRUPTION_TESTS);
    printf("Boundary tests passed: %d/%d\n", boundary_tests_passed, NUM_BOUNDARY_TESTS);
    printf("Failed tests: %d\n", tests_failed);

    if (tests_failed == 0) {
        printf("\nALL TESTS PASSED! The JSON parser appears to be robust.\n");
    } else {
        printf("\nSome tests FAILED. The JSON parser may be vulnerable.\n");
    }
}
385
/* Entry point of the fastjson stress test: sets up the guard-page
 * environment, runs the three test phases and prints a summary. */
void run_fastjson_test(void) {
    printf("Starting fastjson parser stress test...\n");

    srand(time(NULL));   /* Seed RNG: runs are intentionally not reproducible. */
    setup_test_memory(); /* Guard page + SIGSEGV/SIGBUS handlers. */

    run_normal_tests();
    run_corruption_tests();
    run_boundary_tests();

    print_test_summary();

    cleanup_test_memory();
}
diff --git a/examples/redis-unstable/modules/vector-sets/hnsw.c b/examples/redis-unstable/modules/vector-sets/hnsw.c
new file mode 100644
index 0000000..2b4ebc0
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/hnsw.c
@@ -0,0 +1,2999 @@
1/* HNSW (Hierarchical Navigable Small World) Implementation.
2 *
3 * Based on the paper by Yu. A. Malkov, D. A. Yashunin.
4 *
5 * Many details of this implementation, not covered in the paper, were
6 * obtained simulating different workloads and checking the connection
7 * quality of the graph.
8 *
9 * Notably, this implementation:
10 *
11 * 1. Only uses bi-directional links, implementing strategies in order to
12 * link new nodes even when candidates are full, and our new node would
13 * be not close enough to replace old links in candidate.
14 *
15 * 2. We normalize on-insert, making cosine similarity and dot product the
16 * same. This means we can't use euclidean distance or alike here.
17 * Together with quantization, this provides an important speedup that
18 * makes HNSW more practical.
19 *
20 * 3. The quantization used is int8. And it is performed per-vector, so the
21 * "range" (max abs value) is also stored alongside with the quantized data.
22 *
23 * 4. This library implements true elements deletion, not just marking the
24 * element as deleted, but removing it (we can do it since our links are
25 * bidirectional), and reliking the nodes orphaned of one link among
26 * them.
27 *
28 * Copyright (c) 2009-Present, Redis Ltd.
29 * All rights reserved.
30 *
31 * Licensed under your choice of (a) the Redis Source Available License 2.0
32 * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
33 * GNU Affero General Public License v3 (AGPLv3).
34 * Originally authored by: Salvatore Sanfilippo.
35 */
36
37#define _DEFAULT_SOURCE
38#define _POSIX_C_SOURCE 200809L
39
40#include <stdio.h>
41#include <stdlib.h>
42#include <string.h>
43#include <math.h>
44#include <stdint.h>
45#include <float.h> /* for INFINITY if not in math.h */
46#include <assert.h>
47#include "hnsw.h"
48#include "mixer.h"
49
50/* Check if we can compile SIMD code with function attributes */
51#if defined (__x86_64__) && ((defined(__GNUC__) && __GNUC__ >= 5) || (defined(__clang__) && __clang_major__ >= 4))
52#if defined(__has_attribute) && __has_attribute(target)
53#define HAVE_AVX2
54#define HAVE_AVX512
55#endif
56#endif
57
58#if defined (HAVE_AVX2)
59#define ATTRIBUTE_TARGET_AVX2 __attribute__((target("avx2,fma")))
60#define VSET_USE_AVX2 (__builtin_cpu_supports("avx2") && __builtin_cpu_supports("fma"))
61#else
62#define ATTRIBUTE_TARGET_AVX2
63#define VSET_USE_AVX2 0
64#endif
65
66#if defined (HAVE_AVX512)
67#define ATTRIBUTE_TARGET_AVX512 __attribute__((target("avx512f,fma")))
68#define VSET_USE_AVX512 (__builtin_cpu_supports("avx512f"))
69#else
70#define ATTRIBUTE_TARGET_AVX512
71#define VSET_USE_AVX512 0
72#endif
73
74/* Include SIMD headers when supported */
75#if defined(HAVE_AVX2) || defined(HAVE_AVX512)
76#include <immintrin.h>
77#endif
78
79#if 0
80#define debugmsg printf
81#else
82#define debugmsg if(0) printf
83#endif
84
85#ifndef INFINITY
86#define INFINITY (1.0/0.0)
87#endif
88
89#define MIN(a,b) ((a) < (b) ? (a) : (b))
90
91/* Algorithm parameters. */
92
93#define HNSW_P 0.25 /* Probability of level increase. */
94#define HNSW_MAX_LEVEL 16 /* Max level nodes can reach. */
95#define HNSW_EF_C 200 /* Default size of dynamic candidate list while
96 * inserting a new node, in case 0 is passed to
97 * the 'ef' argument while inserting. This is also
98 * used when deleting nodes for the search step
99 * needed sometimes to reconnect nodes that remain
100 * orphaned of one link. */
101
102static void (*hfree)(void *p) = free;
103static void *(*hmalloc)(size_t s) = malloc;
104static void *(*hrealloc)(void *old, size_t s) = realloc;
105
106void hnsw_set_allocator(void (*free_ptr)(void*), void *(*malloc_ptr)(size_t),
107 void *(*realloc_ptr)(void*, size_t))
108{
109 hfree = free_ptr;
110 hmalloc = malloc_ptr;
111 hrealloc = realloc_ptr;
112}
113
114// Get a warning if you use the libc allocator functions for mistake.
115#define malloc use_hmalloc_instead
116#define realloc use_hrealloc_instead
117#define free use_hfree_instead
118
119/* ============================== Prototypes ================================ */
120void hnsw_cursor_element_deleted(HNSW *index, hnswNode *deleted);
121
122/* ============================ Priority queue ================================
123 * We need a priority queue to take an ordered list of candidates. Right now
124 * it is implemented as a linear array, since it is relatively small.
125 *
126 * You may find it to be odd that we take the best element (smaller distance)
127 * at the end of the array, but this way popping from the pqueue is O(1), as
128 * we need to just decrement the count, and this is a very used operation
129 * in a critical code path. This makes the priority queue implementation a
130 * bit more complex in the insertion, but for good reasons. */
131
132/* Maximum number of candidates we'll ever need (cit. Bill Gates). */
133#define HNSW_MAX_CANDIDATES 256
134
/* A (node, distance) pair as stored in the priority queue. */
typedef struct {
    hnswNode *node;     /* Candidate node. */
    float distance;     /* Its distance from the query vector. */
} pqitem;

/* Fixed-capacity priority queue, stored as a linear array sorted by
 * descending distance (closest element at items[count-1]). */
typedef struct {
    pqitem *items;     /* Array of items. */
    uint32_t count;    /* Current number of items. */
    uint32_t cap;      /* Maximum capacity. */
} pqueue;

/* The HNSW algorithms access the pqueue conceptually from nearest (index 0)
 * to farthest (larger indexes) node, so the following macros are used to
 * access the pqueue in this fashion, even if the internal order is
 * actually reversed (see the comment above about O(1) pops). */
#define pq_get_node(q,i) ((q)->items[(q)->count-(i+1)].node)
#define pq_get_distance(q,i) ((q)->items[(q)->count-(i+1)].distance)
152
153/* Create a new priority queue with given capacity. Adding to the
154 * pqueue only retains 'capacity' elements with the shortest distance. */
155pqueue *pq_new(uint32_t capacity) {
156 pqueue *pq = hmalloc(sizeof(*pq));
157 if (!pq) return NULL;
158
159 pq->items = hmalloc(sizeof(pqitem) * capacity);
160 if (!pq->items) {
161 hfree(pq);
162 return NULL;
163 }
164
165 pq->count = 0;
166 pq->cap = capacity;
167 return pq;
168}
169
170/* Free a priority queue. */
171void pq_free(pqueue *pq) {
172 if (!pq) return;
173 hfree(pq->items);
174 hfree(pq);
175}
176
/* Insert maintaining distance order (higher distances first).
 *
 * Invariant: items[] is sorted by descending distance, so the closest
 * element is at items[count-1] (for O(1) pops) and the worst retained
 * element is at items[0]. When the queue is full, the new element is
 * inserted only if strictly closer than the current worst, which is
 * then dropped. */
void pq_push(pqueue *pq, hnswNode *node, float distance) {
    if (pq->count < pq->cap) {
        /* Queue not full: shift right from high distances to make room. */
        uint32_t i = pq->count;
        while (i > 0 && pq->items[i-1].distance < distance) {
            pq->items[i] = pq->items[i-1];
            i--;
        }
        pq->items[i].node = node;
        pq->items[i].distance = distance;
        pq->count++;
    } else {
        /* Queue full: if new item is worse than worst, ignore it. */
        if (distance >= pq->items[0].distance) return;

        /* Otherwise shift left from low distances to drop worst
         * (items[0]), opening the correct slot for the new item. */
        uint32_t i = 0;
        while (i < pq->cap-1 && pq->items[i+1].distance > distance) {
            pq->items[i] = pq->items[i+1];
            i++;
        }
        pq->items[i].node = node;
        pq->items[i].distance = distance;
    }
}
203
204/* Remove and return the top (closest) element, which is at count-1
205 * since we store elements with higher distances first.
206 * Runs in constant time. */
207hnswNode *pq_pop(pqueue *pq, float *distance) {
208 if (pq->count == 0) return NULL;
209 pq->count--;
210 *distance = pq->items[pq->count].distance;
211 return pq->items[pq->count].node;
212}
213
214/* Get distance of the furthest element.
215 * An empty priority queue has infinite distance as its furthest element,
216 * note that this behavior is needed by the algorithms below. */
217float pq_max_distance(pqueue *pq) {
218 if (pq->count == 0) return INFINITY;
219 return pq->items[0].distance;
220}
221
222/* ============================ HNSW algorithm ============================== */
223
#if defined(HAVE_AVX512)
/* AVX512 optimized dot product for float vectors.
 * Vectors are L2-normalized at insert time (see the header comment of
 * this file), so the dot product equals the cosine similarity and
 * 1.0 - dot is a distance in the [0, 2] range.
 * Selected by vectors_distance_float() only when dim >= 16 and the CPU
 * reports AVX512F support. */
ATTRIBUTE_TARGET_AVX512
float vectors_distance_float_avx512(const float *x, const float *y, uint32_t dim) {
    __m512 sum = _mm512_setzero_ps();
    uint32_t i;

    /* Process 16 floats at a time with AVX512, accumulating with FMA. */
    for (i = 0; i + 15 < dim; i += 16) {
        __m512 vx = _mm512_loadu_ps(&x[i]);
        __m512 vy = _mm512_loadu_ps(&y[i]);
        sum = _mm512_fmadd_ps(vx, vy, sum);
    }

    /* Horizontal sum of the 16 elements in sum */
    float dot = _mm512_reduce_add_ps(sum);

    /* Handle remaining elements (dim not a multiple of 16). */
    for (; i < dim; i++) {
        dot += x[i] * y[i];
    }

    return 1.0f - dot;
}
#endif /* HAVE_AVX512 */
249
#if defined(HAVE_AVX2)
/* AVX2 optimized dot product for float vectors.
 * Same contract as the AVX512 version above: vectors are normalized,
 * so the returned 1.0 - dot is a distance in the [0, 2] range.
 * Uses two independent accumulators to reduce FMA latency chains. */
ATTRIBUTE_TARGET_AVX2
float vectors_distance_float_avx2(const float *x, const float *y, uint32_t dim) {
    __m256 sum1 = _mm256_setzero_ps();
    __m256 sum2 = _mm256_setzero_ps();
    uint32_t i;

    /* Process 16 floats at a time with two AVX2 registers */
    for (i = 0; i + 15 < dim; i += 16) {
        __m256 vx1 = _mm256_loadu_ps(&x[i]);
        __m256 vy1 = _mm256_loadu_ps(&y[i]);
        __m256 vx2 = _mm256_loadu_ps(&x[i + 8]);
        __m256 vy2 = _mm256_loadu_ps(&y[i + 8]);

        sum1 = _mm256_fmadd_ps(vx1, vy1, sum1);
        sum2 = _mm256_fmadd_ps(vx2, vy2, sum2);
    }

    /* Combine the two sums */
    __m256 combined = _mm256_add_ps(sum1, sum2);

    /* Horizontal sum of the 8 elements: fold the high lane onto the
     * low one, then two hadd steps collapse to a single scalar. */
    __m128 sum_high = _mm256_extractf128_ps(combined, 1);
    __m128 sum_low = _mm256_castps256_ps128(combined);
    __m128 sum_128 = _mm_add_ps(sum_high, sum_low);

    sum_128 = _mm_hadd_ps(sum_128, sum_128);
    sum_128 = _mm_hadd_ps(sum_128, sum_128);

    float dot = _mm_cvtss_f32(sum_128);

    /* Handle remaining elements (dim not a multiple of 16). */
    for (; i < dim; i++) {
        dot += x[i] * y[i];
    }

    return 1.0f - dot;
}
#endif /* HAVE_AVX2 */
290
/* Optimized dot product: automatically selects best available implementation
 * (AVX512 > AVX2 > scalar) at runtime.
 * Dot product: our vectors are already normalized.
 * Version for not quantized vectors of floats. Returns 1.0 - dot, a
 * distance in the [0, 2] range (see the long comment at the bottom). */
float vectors_distance_float(const float *x, const float *y, uint32_t dim) {
#if defined(HAVE_AVX512)
    if (dim >= 16 && VSET_USE_AVX512) {
        return vectors_distance_float_avx512(x, y, dim);
    }
#endif

#if defined(HAVE_AVX2)
    if (VSET_USE_AVX2 && dim >= 16) {
        return vectors_distance_float_avx2(x, y, dim);
    }
#endif

    /* Fallback to original scalar implementation */
    float dot0 = 0.0f, dot1 = 0.0f;
    uint32_t i;

    /* Use two accumulators to reduce dependencies among multiplications.
     * This provides a clear speed boost in Apple silicon, but should
     * help in general. */
    for (i = 0; i + 7 < dim; i += 8) {
        dot0 += x[i] * y[i] +
                x[i+1] * y[i+1] +
                x[i+2] * y[i+2] +
                x[i+3] * y[i+3];

        dot1 += x[i+4] * y[i+4] +
                x[i+5] * y[i+5] +
                x[i+6] * y[i+6] +
                x[i+7] * y[i+7];
    }

    /* Handle the remaining elements. These are a minority in the case
     * of a small vector, don't optimize this part. */
    for (; i < dim; i++) dot0 += x[i] * y[i];

    /* The following line may be counter intuitive. The dot product of
     * normalized vectors is equivalent to their cosine similarity. The
     * cosine will be from -1 (vectors facing opposite directions in the
     * N-dim space) to 1 (vectors are facing in the same direction).
     *
     * We kinda want a "score" of distance from 0 to 2 (this is a distance
     * function and we want minimize the distance for K-NN searches), so we
     * can't just add 1: that would return a number in the 0-2 range, with
     * 0 meaning opposite vectors and 2 identical vectors, so this is
     * similarity, not distance.
     *
     * Returning instead (1 - dotprod) inverts the meaning: 0 is identical
     * and 2 is opposite, hence it is their distance.
     *
     * Why don't normalize the similarity right now, and return from 0 to
     * 1? Because division is costly. */
    return 1.0f - (dot0 + dot1);
}
348
/* Distance between two Q8-quantized vectors: the dot product is
 * computed in integer math and converted back to the float domain at
 * the end via the per-vector quantization ranges. The result is
 * clamped to the [0, 2] distance range. */
float vectors_distance_q8(const int8_t *x, const int8_t *y, uint32_t dim,
                          float range_a, float range_b) {
    /* A zero range means an all-zero vector: its dot product with
     * anything is 0, hence the distance is 1.0 (= 1.0 - 0). */
    if (range_a == 0 || range_b == 0) return 1.0f;

    /* Quantization maps [-max_abs, +max_abs] onto [-127, 127], so one
     * integer unit of each vector is worth range/127 original units. */
    const float scale_product = (range_a/127) * (range_b/127);

    int32_t acc0 = 0, acc1 = 0;
    uint32_t j;

    /* Two independent accumulators, eight components per iteration,
     * for better pipeline utilization. */
    for (j = 0; j + 7 < dim; j += 8) {
        acc0 += (int32_t)x[j]   * y[j]   +
                (int32_t)x[j+1] * y[j+1] +
                (int32_t)x[j+2] * y[j+2] +
                (int32_t)x[j+3] * y[j+3];
        acc1 += (int32_t)x[j+4] * y[j+4] +
                (int32_t)x[j+5] * y[j+5] +
                (int32_t)x[j+6] * y[j+6] +
                (int32_t)x[j+7] * y[j+7];
    }

    /* Leftover components. */
    for (; j < dim; j++) acc0 += (int32_t)x[j] * y[j];

    /* Back to the float domain, then similarity -> distance. */
    float distance = 1.0f - (acc0 + acc1) * scale_product;

    /* Clamp to the valid [0, 2] range. */
    if (distance < 0) return 0;
    if (distance > 2) return 2;
    return distance;
}
391
/* Portable population count of a 64 bit word (SWAR: pairwise bit sums,
 * then a multiply to gather the per-byte counts into the top byte). */
static inline int popcount64(uint64_t v) {
    v = v - ((v >> 1) & 0x5555555555555555ULL);
    v = (v & 0x3333333333333333ULL) + ((v >> 2) & 0x3333333333333333ULL);
    v = (v + (v >> 4)) & 0x0F0F0F0F0F0F0F0FULL;
    return (int)((v * 0x0101010101010101ULL) >> 56);
}
401
/* Distance between binary-quantized vectors: the number of differing
 * bits (xor + popcount) scaled by 2/dim, mapping into the same [0, 2]
 * range used by the other distance functions. Padding bits beyond
 * 'dim' are zero in every vector, so they never differ. */
float vectors_distance_bin(const uint64_t *x, const uint64_t *y, uint32_t dim) {
    uint32_t words = (dim+63)/64;
    uint32_t diffbits = 0;
    uint32_t j = 0;
    while (j < words) {
        diffbits += popcount64(x[j] ^ y[j]);
        j++;
    }
    return (float)diffbits*2/dim;
}
412
413/* Dot product between nodes. Will call the right version depending on the
414 * quantization used. */
415float hnsw_distance(HNSW *index, hnswNode *a, hnswNode *b) {
416 switch(index->quant_type) {
417 case HNSW_QUANT_NONE:
418 return vectors_distance_float(a->vector,b->vector,index->vector_dim);
419 case HNSW_QUANT_Q8:
420 return vectors_distance_q8(a->vector,b->vector,index->vector_dim,a->quants_range,b->quants_range);
421 case HNSW_QUANT_BIN:
422 return vectors_distance_bin(a->vector,b->vector,index->vector_dim);
423 default:
424 assert(1 != 1);
425 return 0;
426 }
427}
428
429/* This do Q8 'range' quantization.
430 * For people looking at this code thinking: Oh, I could use min/max
431 * quants instead! Well: I tried with min/max normalization but the dot
432 * product needs to accumulate the sum for later correction, and it's slower. */
433void quantize_to_q8(float *src, int8_t *dst, uint32_t dim, float *rangeptr) {
434 float max_abs = 0;
435 for (uint32_t j = 0; j < dim; j++) {
436 if (src[j] > max_abs) max_abs = src[j];
437 if (-src[j] > max_abs) max_abs = -src[j];
438 }
439
440 if (max_abs == 0) {
441 if (rangeptr) *rangeptr = 0;
442 memset(dst, 0, dim);
443 return;
444 }
445
446 const float scale = 127.0f / max_abs; // Scale to map to [-127, 127].
447
448 for (uint32_t j = 0; j < dim; j++) {
449 dst[j] = (int8_t)roundf(src[j] * scale);
450 }
451 if (rangeptr) *rangeptr = max_abs; // Return max_abs instead of 2*max_abs.
452}
453
/* Binary quantization: one bit per component, set when the component is
 * strictly positive. The destination is handled in whole 64 bit words;
 * the padding bits beyond 'dim' are cleared, so they are identical in
 * every vector and cancel out in the xor+popcount distance, which can
 * therefore work on full words. Like cosine similarity, only the
 * direction of each component is kept, not its magnitude. */
void quantize_to_bin(float *src, uint64_t *dst, uint32_t dim) {
    uint32_t words = (dim+63)/64;
    memset(dst, 0, words*sizeof(uint64_t));
    for (uint32_t i = 0; i < dim; i++) {
        if (src[i] > 0)
            dst[i/64] |= 1ULL << (i&63);
    }
}
470
471/* L2 normalization of the float vector.
472 *
473 * Store the L2 value on 'l2ptr' if not NULL. This way the process
474 * can be reversed even if some precision will be lost. */
475void hnsw_normalize_vector(float *x, float *l2ptr, uint32_t dim) {
476 float l2 = 0;
477 uint32_t i;
478 for (i = 0; i + 3 < dim; i += 4) {
479 l2 += x[i]*x[i] +
480 x[i+1]*x[i+1] +
481 x[i+2]*x[i+2] +
482 x[i+3]*x[i+3];
483 }
484 for (; i < dim; i++) l2 += x[i]*x[i];
485 if (l2 == 0) return; // All zero vector, can't normalize.
486
487 l2 = sqrtf(l2);
488 if (l2ptr) *l2ptr = l2;
489 for (i = 0; i < dim; i++) x[i] /= l2;
490}
491
492/* Helper function to generate random level. */
493uint32_t random_level(void) {
494 static const int threshold = HNSW_P * RAND_MAX;
495 uint32_t level = 0;
496
497 while (rand() < threshold && level < HNSW_MAX_LEVEL)
498 level += 1;
499 return level;
500}
501
502/* Create new HNSW index, quantized or not. */
503HNSW *hnsw_new(uint32_t vector_dim, uint32_t quant_type, uint32_t m) {
504 HNSW *index = hmalloc(sizeof(HNSW));
505 if (!index) return NULL;
506
507 /* M parameter sanity check. */
508 if (m == 0) m = HNSW_DEFAULT_M;
509 else if (m > HNSW_MAX_M) m = HNSW_MAX_M;
510
511 index->M = m;
512 index->quant_type = quant_type;
513 index->enter_point = NULL;
514 index->max_level = 0;
515 index->vector_dim = vector_dim;
516 index->node_count = 0;
517 index->last_id = 0;
518 index->head = NULL;
519 index->cursors = NULL;
520
521 /* Initialize epochs array. */
522 for (int i = 0; i < HNSW_MAX_THREADS; i++)
523 index->current_epoch[i] = 0;
524
525 /* Initialize locks. */
526 if (pthread_rwlock_init(&index->global_lock, NULL) != 0) {
527 hfree(index);
528 return NULL;
529 }
530
531 for (int i = 0; i < HNSW_MAX_THREADS; i++) {
532 if (pthread_mutex_init(&index->slot_locks[i], NULL) != 0) {
533 /* Clean up previously initialized mutexes. */
534 for (int j = 0; j < i; j++)
535 pthread_mutex_destroy(&index->slot_locks[j]);
536 pthread_rwlock_destroy(&index->global_lock);
537 hfree(index);
538 return NULL;
539 }
540 }
541
542 /* Initialize atomic variables. */
543 index->next_slot = 0;
544 index->version = 0;
545 return index;
546}
547
548/* Fill 'vec' with the node vector, de-normalizing and de-quantizing it
549 * as needed. Note that this function will return an approximated version
550 * of the original vector. */
551void hnsw_get_node_vector(HNSW *index, hnswNode *node, float *vec) {
552 if (index->quant_type == HNSW_QUANT_NONE) {
553 memcpy(vec,node->vector,index->vector_dim*sizeof(float));
554 } else if (index->quant_type == HNSW_QUANT_Q8) {
555 int8_t *quants = node->vector;
556 for (uint32_t j = 0; j < index->vector_dim; j++)
557 vec[j] = (quants[j]*node->quants_range)/127;
558 } else if (index->quant_type == HNSW_QUANT_BIN) {
559 uint64_t *bits = node->vector;
560 for (uint32_t j = 0; j < index->vector_dim; j++) {
561 uint32_t word = j/64;
562 uint32_t bit = j&63;
563 vec[j] = (bits[word] & (1ULL<<bit)) ? 1.0f : -1.0f;
564 }
565 }
566
567 // De-normalize.
568 if (index->quant_type != HNSW_QUANT_BIN) {
569 for (uint32_t j = 0; j < index->vector_dim; j++)
570 vec[j] *= node->l2;
571 }
572}
573
574/* Return the number of bytes needed to represent a vector in the index,
575 * that is function of the dimension of the vectors and the quantization
576 * type used. */
577uint32_t hnsw_quants_bytes(HNSW *index) {
578 switch(index->quant_type) {
579 case HNSW_QUANT_NONE: return index->vector_dim * sizeof(float);
580 case HNSW_QUANT_Q8: return index->vector_dim;
581 case HNSW_QUANT_BIN: return (index->vector_dim+63)/64*8;
582 default: assert(0 && "Quantization type not supported.");
583 }
584}
585
586/* Create new node. Returns NULL on out of memory.
587 * It is possible to pass the vector as floats or, in case this index
588 * was already stored on disk and is being loaded, or serialized and
589 * transmitted in any form, the already quantized version in
590 * 'qvector'.
591 *
592 * Only vector or qvector should be non-NULL. The reason why passing
593 * a quantized vector is useful, is that because re-normalizing and
594 * re-quantizing several times the same vector may accumulate rounding
595 * errors. So if you work with quantized indexes, you should save
596 * the quantized indexes.
597 *
598 * Note that, together with qvector, the quantization range is needed,
599 * since this library uses per-vector quantization. In case of quantized
600 * vectors the l2 is considered to be '1', so if you want to restore
601 * the right l2 (to use the API that returns an approximation of the
602 * original vector) make sure to save the l2 on disk and set it back
603 * after the node creation (see later for the serialization API that
604 * handles this and more). */
hnswNode *hnsw_node_new(HNSW *index, uint64_t id, const float *vector, const int8_t *qvector, float qrange, uint32_t level, int normalize) {
    /* The layers array is allocated inline after the node header:
     * 'level' is the node's top layer index, so level+1 entries. */
    hnswNode *node = hmalloc(sizeof(hnswNode)+(sizeof(hnswNodeLayer)*(level+1)));
    if (!node) return NULL;

    /* id == 0 means "assign the next progressive ID". */
    if (id == 0) id = ++index->last_id;
    node->level = level;
    node->id = id;
    node->next = NULL;
    node->vector = NULL;
    node->l2 = 1; // Default in case of already quantized vectors. It is
                  // up to the caller to fill this later, if needed.

    /* Initialize visited epoch array. */
    for (int i = 0; i < HNSW_MAX_THREADS; i++)
        node->visited_epoch[i] = 0;

    if (qvector == NULL) {
        /* Copy input vector. */
        node->vector = hmalloc(sizeof(float) * index->vector_dim);
        if (!node->vector) {
            hfree(node);
            return NULL;
        }
        memcpy(node->vector, vector, sizeof(float) * index->vector_dim);
        if (normalize)
            hnsw_normalize_vector(node->vector,&node->l2,index->vector_dim);

        /* Handle quantization. */
        if (index->quant_type != HNSW_QUANT_NONE) {
            void *quants = hmalloc(hnsw_quants_bytes(index));
            if (quants == NULL) {
                hfree(node->vector);
                hfree(node);
                return NULL;
            }

            // Quantize.
            switch(index->quant_type) {
            case HNSW_QUANT_Q8:
                quantize_to_q8(node->vector,quants,index->vector_dim,&node->quants_range);
                break;
            case HNSW_QUANT_BIN:
                quantize_to_bin(node->vector,quants,index->vector_dim);
                break;
            default:
                assert(0 && "Quantization type not handled.");
                break;
            }

            // Discard the full precision vector.
            hfree(node->vector);
            node->vector = quants;
        }
    } else {
        // We got the already quantized vector. Just copy it.
        assert(index->quant_type != HNSW_QUANT_NONE);
        uint32_t vector_bytes = hnsw_quants_bytes(index);
        node->vector = hmalloc(vector_bytes);
        node->quants_range = qrange;
        if (node->vector == NULL) {
            hfree(node);
            return NULL;
        }
        memcpy(node->vector,qvector,vector_bytes);
    }

    /* Initialize each layer. Layer 0 gets twice the links of the
     * upper layers, as in the usual HNSW design. */
    for (uint32_t i = 0; i <= level; i++) {
        uint32_t max_links = (i == 0) ? index->M*2 : index->M;
        node->layers[i].max_links = max_links;
        node->layers[i].num_links = 0;
        node->layers[i].worst_distance = 0;
        node->layers[i].worst_idx = 0;
        node->layers[i].links = hmalloc(sizeof(hnswNode*) * max_links);
        if (!node->layers[i].links) {
            /* OOM: roll back the link arrays allocated so far. */
            for (uint32_t j = 0; j < i; j++) hfree(node->layers[j].links);
            hfree(node->vector);
            hfree(node);
            return NULL;
        }
    }

    /* NOTE(review): node->prev is left uninitialized here; hnsw_add_node()
     * appears to set it — confirm every created node is added to the list
     * before any traversal. */
    return node;
}
689
690/* Free a node. */
691void hnsw_node_free(hnswNode *node) {
692 if (!node) return;
693
694 for (uint32_t i = 0; i <= node->level; i++)
695 hfree(node->layers[i].links);
696
697 hfree(node->vector);
698 hfree(node);
699}
700
701/* Free the entire index. */
702void hnsw_free(HNSW *index,void(*free_value)(void*value)) {
703 if (!index) return;
704
705 hnswNode *current = index->head;
706 while (current) {
707 hnswNode *next = current->next;
708 if (free_value) free_value(current->value);
709 hnsw_node_free(current);
710 current = next;
711 }
712
713 /* Destroy locks */
714 pthread_rwlock_destroy(&index->global_lock);
715 for (int i = 0; i < HNSW_MAX_THREADS; i++) {
716 pthread_mutex_destroy(&index->slot_locks[i]);
717 }
718
719 hfree(index);
720}
721
722/* Add node to linked list of nodes. We may need to scan the whole
723 * HNSW graph for several reasons. The list is doubly linked since we
724 * also need the ability to remove a node without scanning the whole thing. */
725void hnsw_add_node(HNSW *index, hnswNode *node) {
726 node->next = index->head;
727 node->prev = NULL;
728 if (index->head)
729 index->head->prev = node;
730 index->head = node;
731 index->node_count++;
732}
733
734/* Search the specified layer starting from the specified entry point
735 * to collect 'ef' nodes that are near to 'query'.
736 *
737 * This function implements optional hybrid search, so that each node
738 * can be accepted or not based on its associated value. In this case
739 * a callback 'filter_callback' should be passed, together with a maximum
740 * effort for the search (number of candidates to evaluate), since even
 * with a low "EF" value we risk that there are too few nodes that satisfy
742 * the provided filter, and we could trigger a full scan. */
/* Returns a pqueue with up to 'ef' results ordered by distance (the
 * caller must pq_free() it), or NULL on out of memory. */
pqueue *search_layer_with_filter(
    HNSW *index, hnswNode *query, hnswNode *entry_point,
    uint32_t ef, uint32_t layer, uint32_t slot,
    int (*filter_callback)(void *value, void *privdata),
    void *filter_privdata, uint32_t max_candidates)
{
    // Mark visited nodes with a never seen epoch.
    index->current_epoch[slot]++;

    /* 'candidates' is the frontier of nodes still to expand;
     * 'results' holds the best accepted nodes so far. */
    pqueue *candidates = pq_new(HNSW_MAX_CANDIDATES);
    pqueue *results = pq_new(ef);
    if (!candidates || !results) {
        if (candidates) pq_free(candidates);
        if (results) pq_free(results);
        return NULL;
    }

    // Take track of the total effort: only used when filtering via
    // a callback to have a bound effort.
    uint32_t evaluated_candidates = 1;

    // Add entry point.
    float dist = hnsw_distance(index, query, entry_point);
    pq_push(candidates, entry_point, dist);
    /* The entry point goes into the results only if it passes the
     * filter (or when no filter was given). */
    if (filter_callback == NULL ||
        filter_callback(entry_point->value, filter_privdata))
    {
        pq_push(results, entry_point, dist);
    }
    entry_point->visited_epoch[slot] = index->current_epoch[slot];

    // Process candidates.
    while (candidates->count > 0) {
        // Max effort. If zero, we keep scanning.
        if (filter_callback &&
            max_candidates &&
            evaluated_candidates >= max_candidates) break;

        float cur_dist;
        hnswNode *current = pq_pop(candidates, &cur_dist);
        evaluated_candidates++;

        /* Classic HNSW termination: the closest pending candidate is
         * already worse than our worst result, and we have 'ef' of them. */
        float furthest = pq_max_distance(results);
        if (results->count >= ef && cur_dist > furthest) break;

        /* Check neighbors. */
        for (uint32_t i = 0; i < current->layers[layer].num_links; i++) {
            hnswNode *neighbor = current->layers[layer].links[i];

            if (neighbor->visited_epoch[slot] == index->current_epoch[slot])
                continue; // Already visited during this scan.

            neighbor->visited_epoch[slot] = index->current_epoch[slot];
            float neighbor_dist = hnsw_distance(index, query, neighbor);

            furthest = pq_max_distance(results);
            if (filter_callback == NULL) {
                /* Original HNSW logic when no filtering:
                 * Add to results if better than current max or
                 * results not full. */
                if (neighbor_dist < furthest || results->count < ef) {
                    pq_push(candidates, neighbor, neighbor_dist);
                    pq_push(results, neighbor, neighbor_dist);
                }
            } else {
                /* With filtering: we add candidates even if doesn't match
                 * the filter, in order to continue to explore the graph. */
                if (neighbor_dist < furthest || candidates->count < ef) {
                    pq_push(candidates, neighbor, neighbor_dist);
                }

                /* Add results only if passes filter. */
                if (filter_callback(neighbor->value, filter_privdata)) {
                    if (neighbor_dist < furthest || results->count < ef) {
                        pq_push(results, neighbor, neighbor_dist);
                    }
                }
            }
        }
    }

    pq_free(candidates);
    return results;
}
827
828/* Just a wrapper without hybrid search callback. */
829pqueue *search_layer(HNSW *index, hnswNode *query, hnswNode *entry_point,
830 uint32_t ef, uint32_t layer, uint32_t slot)
831{
832 return search_layer_with_filter(index, query, entry_point, ef, layer, slot,
833 NULL, NULL, 0);
834}
835
836/* This function is used in order to initialize a node allocated in the
837 * function stack with the specified vector. The idea is that we can
838 * easily use hnsw_distance() from a vector and the HNSW nodes this way:
839 *
840 * hnswNode myQuery;
841 * hnsw_init_tmp_node(myIndex,&myQuery,0,some_vector);
842 * hnsw_distance(&myQuery, some_hnsw_node);
843 *
844 * Make sure to later free the node with:
845 *
846 * hnsw_free_tmp_node(&myQuery,some_vector);
847 * You have to pass the vector to the free function, because sometimes
848 * hnsw_init_tmp_node() may just avoid allocating a vector at all,
849 * just reusing 'some_vector' pointer.
850 *
851 * Return 0 on out of memory, 1 on success.
852 */
853int hnsw_init_tmp_node(HNSW *index, hnswNode *node, int is_normalized, const float *vector) {
854 node->vector = NULL;
855
856 /* Work on a normalized query vector if the input vector is
857 * not normalized. */
858 if (!is_normalized) {
859 node->vector = hmalloc(sizeof(float)*index->vector_dim);
860 if (node->vector == NULL) return 0;
861 memcpy(node->vector,vector,sizeof(float)*index->vector_dim);
862 hnsw_normalize_vector(node->vector,NULL,index->vector_dim);
863 } else {
864 node->vector = (float*)vector;
865 }
866
867 /* If quantization is enabled, our query fake node should be
868 * quantized as well. */
869 if (index->quant_type != HNSW_QUANT_NONE) {
870 void *quants = hmalloc(hnsw_quants_bytes(index));
871 if (quants == NULL) {
872 if (node->vector != vector) hfree(node->vector);
873 return 0;
874 }
875 switch(index->quant_type) {
876 case HNSW_QUANT_Q8:
877 quantize_to_q8(node->vector, quants, index->vector_dim, &node->quants_range);
878 break;
879 case HNSW_QUANT_BIN:
880 quantize_to_bin(node->vector, quants, index->vector_dim);
881 }
882 if (node->vector != vector) hfree(node->vector);
883 node->vector = quants;
884 }
885 return 1;
886}
887
888/* Free the stack allocated node initialized by hnsw_init_tmp_node(). */
889void hnsw_free_tmp_node(hnswNode *node, const float *vector) {
890 if (node->vector != vector) hfree(node->vector);
891}
892
893/* Return approximated K-NN items. Note that neighbors and distances
894 * arrays must have space for at least 'k' items.
895 * norm_query should be set to 1 if the query vector is already
896 * normalized, otherwise, if 0, the function will copy the vector,
897 * L2-normalize the copy and search using the normalized version.
898 *
899 * If the filter_privdata callback is passed, only elements passing the
900 * specified filter (invoked with privdata and the value associated
901 * to the node as arguments) are returned. In such case, if max_candidates
902 * is not NULL, it represents the maximum number of nodes to explore, since
903 * the search may be otherwise unbound if few or no elements pass the
904 * filter. */
/* See the contract in the comment above. Returns the number of
 * neighbors written (<= k), 0 on an empty index, -1 on error/OOM.
 * 'distances' may be NULL if the caller only needs the nodes. */
int hnsw_search_with_filter
    (HNSW *index, const float *query_vector, uint32_t k,
     hnswNode **neighbors, float *distances, uint32_t slot,
     int query_vector_is_normalized,
     int (*filter_callback)(void *value, void *privdata),
     void *filter_privdata, uint32_t max_candidates)

{
    if (!index || !query_vector || !neighbors || k == 0) return -1;
    if (!index->enter_point) return 0; // Empty index.

    /* Use a fake node that holds the query vector, this way we can
     * use our normal node to node distance functions when checking
     * the distance between query and graph nodes. */
    hnswNode query;
    if (hnsw_init_tmp_node(index,&query,query_vector_is_normalized,query_vector) == 0) return -1;

    // Start searching from the entry point.
    hnswNode *curr_ep = index->enter_point;

    /* Start from higher layer to layer 1 (layer 0 is handled later)
     * in the next section. Descend to the most similar node found
     * so far. */
    for (int lc = index->max_level; lc > 0; lc--) {
        /* Greedy descent: ef=1, we only need the single closest node
         * as the entry point for the next (lower) layer. */
        pqueue *results = search_layer(index, &query, curr_ep, 1, lc, slot);
        if (!results) continue; /* OOM here is non-fatal: keep the current
                                 * entry point and try the next layer. */

        if (results->count > 0) {
            curr_ep = pq_get_node(results,0);
        }
        pq_free(results);
    }

    /* Search bottom layer (the most densely populated) with ef = k */
    pqueue *results = search_layer_with_filter(
        index, &query, curr_ep, k, 0, slot, filter_callback,
        filter_privdata, max_candidates);
    if (!results) {
        hnsw_free_tmp_node(&query, query_vector);
        return -1;
    }

    /* Copy results. */
    uint32_t found = MIN(k, results->count);
    for (uint32_t i = 0; i < found; i++) {
        neighbors[i] = pq_get_node(results,i);
        if (distances) {
            distances[i] = pq_get_distance(results,i);
        }
    }

    pq_free(results);
    hnsw_free_tmp_node(&query, query_vector);
    return found;
}
960
961/* Wrapper to hnsw_search_with_filter() when no filter is needed. */
962int hnsw_search(HNSW *index, const float *query_vector, uint32_t k,
963 hnswNode **neighbors, float *distances, uint32_t slot,
964 int query_vector_is_normalized)
965{
966 return hnsw_search_with_filter(index,query_vector,k,neighbors,
967 distances,slot,query_vector_is_normalized,
968 NULL,NULL,0);
969}
970
971/* Rescan a node and update the wortst neighbor index.
972 * The followinng two functions are variants of this function to be used
973 * when links are added or removed: they may do less work than a full scan. */
974void hnsw_update_worst_neighbor(HNSW *index, hnswNode *node, uint32_t layer) {
975 float worst_dist = 0;
976 uint32_t worst_idx = 0;
977 for (uint32_t i = 0; i < node->layers[layer].num_links; i++) {
978 float dist = hnsw_distance(index, node, node->layers[layer].links[i]);
979 if (dist > worst_dist) {
980 worst_dist = dist;
981 worst_idx = i;
982 }
983 }
984 node->layers[layer].worst_distance = worst_dist;
985 node->layers[layer].worst_idx = worst_idx;
986}
987
988/* Update node worst neighbor distance information when a new neighbor
989 * is added. */
990void hnsw_update_worst_neighbor_on_add(HNSW *index, hnswNode *node, uint32_t layer, uint32_t added_index, float distance) {
991 (void) index; // Unused but here for API symmetry.
992 if (node->layers[layer].num_links == 1 || // First neighbor?
993 distance > node->layers[layer].worst_distance) // New worst?
994 {
995 node->layers[layer].worst_distance = distance;
996 node->layers[layer].worst_idx = added_index;
997 }
998}
999
1000/* Update node worst neighbor distance information when a linked neighbor
1001 * is removed. */
1002void hnsw_update_worst_neighbor_on_remove(HNSW *index, hnswNode *node, uint32_t layer, uint32_t removed_idx)
1003{
1004 if (node->layers[layer].num_links == 0) {
1005 node->layers[layer].worst_distance = 0;
1006 node->layers[layer].worst_idx = 0;
1007 } else if (removed_idx == node->layers[layer].worst_idx) {
1008 hnsw_update_worst_neighbor(index,node,layer);
1009 } else if (removed_idx < node->layers[layer].worst_idx) {
1010 // Just update index if we removed element before worst.
1011 node->layers[layer].worst_idx--;
1012 }
1013}
1014
1015/* We have a list of candidate nodes to link to the new node, when inserting
1016 * one. This function selects which nodes to link and performs the linking.
1017 *
1018 * Parameters:
1019 *
1020 * - 'candidates' is the priority queue of potential good nodes to link to the
1021 * new node 'new_node'.
1022 * - 'required_links' is as many links we would like our new_node to get
1023 * at the specified layer.
1024 * - 'aggressive' changes the strategy used to find good neighbors as follows:
1025 *
1026 * This function is called with aggressive=0 for all the layers, including
1027 * layer 0. When called like that, it will use the diversity of links and
1028 * quality of links checks before linking our new node with some candidate.
1029 *
1030 * However if the insert function finds that at layer 0, with aggressive=0,
1031 * few connections were made, it calls this function again with aggressiveness
1032 * levels greater up to 2.
1033 *
1034 * At aggressive=1, the diversity checks are disabled, and the candidate
1035 * node for linking is accepted even if it is nearest to an already accepted
1036 * neighbor than it is to the new node.
1037 *
1038 * When we link our new node by replacing the link of a candidate neighbor
1039 * that already has the max number of links, inevitably some other node loses
1040 * a connection (to make space for our new node link). In this case:
1041 *
1042 * 1. If such "dropped" node would remain with too little links, we try with
1043 * some different neighbor instead, however as the 'aggressive' parameter
1044 * has incremental values (0, 1, 2) we are more and more willing to leave
 * the dropped node with fewer connections.
1046 * 2. If aggressive=2, we will scan the candidate neighbor node links to
1047 * find a different linked-node to replace, one better connected even if
1048 * its distance is not the worse.
1049 *
1050 * Note: this function is also called during deletion of nodes in order to
1051 * provide certain nodes with additional links.
1052 */
/* See the long comment above for the full contract of this function. */
void select_neighbors(HNSW *index, pqueue *candidates, hnswNode *new_node,
                      uint32_t layer, uint32_t required_links, int aggressive)
{
    for (uint32_t i = 0; i < candidates->count; i++) {
        hnswNode *neighbor = pq_get_node(candidates,i);
        if (neighbor == new_node) continue; // Don't link node with itself.

        /* Use our cached distance among the new node and the candidate. */
        float dist = pq_get_distance(candidates,i);

        /* First of all, since our links are all bidirectional, if the
         * new node for any reason has no longer room, or if it accumulated
         * the required number of links, return ASAP. */
        if (new_node->layers[layer].num_links >= new_node->layers[layer].max_links ||
            new_node->layers[layer].num_links >= required_links) return;

        /* If aggressive is true, it is possible that the new node
         * already got some link among the candidates (see the top comment,
         * this function gets re-called in case of too few links).
         * So we need to check if this candidate is already linked to
         * the new node. */
        if (aggressive) {
            int duplicated = 0;
            for (uint32_t j = 0; j < new_node->layers[layer].num_links; j++) {
                if (new_node->layers[layer].links[j] == neighbor) {
                    duplicated = 1;
                    break;
                }
            }
            if (duplicated) continue;
        }

        /* Diversity check. We accept new candidates
         * only if there is no element already accepted that is nearest
         * to the candidate than the new element itself.
         * However this check is disabled if we have pressure to find
         * new links (aggressive != 0) */
        if (!aggressive) {
            int diversity_failed = 0;
            for (uint32_t j = 0; j < new_node->layers[layer].num_links; j++) {
                float link_dist = hnsw_distance(index, neighbor,
                                    new_node->layers[layer].links[j]);
                if (link_dist < dist) {
                    diversity_failed = 1;
                    break;
                }
            }
            if (diversity_failed) continue;
        }

        /* If potential neighbor node has space, simply add the new link.
         * We will have space as well. */
        uint32_t n = neighbor->layers[layer].num_links;
        if (n < neighbor->layers[layer].max_links) {
            /* Link candidate to new node. */
            neighbor->layers[layer].links[n] = new_node;
            neighbor->layers[layer].num_links++;

            /* Update candidate worst link info. */
            hnsw_update_worst_neighbor_on_add(index,neighbor,layer,n,dist);

            /* Link new node to candidate. */
            uint32_t new_links = new_node->layers[layer].num_links;
            new_node->layers[layer].links[new_links] = neighbor;
            new_node->layers[layer].num_links++;

            /* Update new node worst link info. */
            hnsw_update_worst_neighbor_on_add(index,new_node,layer,new_links,dist);
            continue;
        }

        /* ====================================================================
         * Replacing existing candidate neighbor link step.
         * ================================================================== */

        /* If we are here, our accepted candidate for linking is full.
         *
         * If new node is more distant to candidate than its current worst link
         * then we skip it: we would not be able to establish a bidirectional
         * connection without compromising link quality of candidate.
         *
         * At aggressiveness > 0 we don't care about this check. */
        if (!aggressive && dist >= neighbor->layers[layer].worst_distance)
            continue;

        /* We can add it: we are ready to replace the candidate neighbor worst
         * link with the new node, assuming certain conditions are met. */
        hnswNode *worst_node = neighbor->layers[layer].links[neighbor->layers[layer].worst_idx];

        /* The worst node linked to our candidate may remain too disconnected
         * if we remove the candidate node as its link. Let's check if
         * this is the case: */
        if (aggressive == 0 &&
            worst_node->layers[layer].num_links <= index->M/2)
            continue;

        /* Aggressive level = 1. It's ok if the node remains with just
         * HNSW_M/4 links. */
        else if (aggressive == 1 &&
            worst_node->layers[layer].num_links <= index->M/4)
            continue;

        /* If aggressive is set to 2, then the new node we are adding failed
         * to find enough neighbors. We can't insert an almost orphaned new
         * node, so let's see if the target node has some other link
         * that is well connected in the graph: we could drop it instead
         * of the worst link. */
        if (aggressive == 2 && worst_node->layers[layer].num_links <=
            index->M/4)
        {
            /* Let's see if we can find at least a candidate link that
             * would remain with a few connections. Track the one
             * that is the farthest away (worst distance) from our candidate
             * neighbor (in order to remove the less interesting link). */
            worst_node = NULL;
            uint32_t worst_idx = 0;
            float max_dist = 0;
            for (uint32_t j = 0; j < neighbor->layers[layer].num_links; j++) {
                hnswNode *to_drop = neighbor->layers[layer].links[j];

                /* Skip this if it would remain too disconnected as well.
                 *
                 * NOTE about index->M/4 min connections requirement:
                 *
                 * It is not too strict, since leaving a node with just a
                 * single link does not just leave it too weakly connected, but
                 * also sometimes creates cycles with few disconnected
                 * nodes linked among them. */
                if (to_drop->layers[layer].num_links <= index->M/4) continue;

                float link_dist = hnsw_distance(index, neighbor, to_drop);
                if (worst_node == NULL || link_dist > max_dist) {
                    worst_node = to_drop;
                    max_dist = link_dist;
                    worst_idx = j;
                }
            }

            if (worst_node != NULL) {
                /* We found a node that we can drop. Let's pretend this is
                 * the worst node of the candidate to unify the following
                 * code path. Later we will fix the worst node info anyway. */
                neighbor->layers[layer].worst_distance = max_dist;
                neighbor->layers[layer].worst_idx = worst_idx;
            } else {
                /* Otherwise we have no other option than reallocating
                 * the max number of links for this target node, and
                 * ensure at least a few connections for our new node. */
                uint32_t reallocation_limit = layer == 0 ?
                    index->M * 3 : index->M *2;
                if (neighbor->layers[layer].max_links >= reallocation_limit)
                    continue;

                uint32_t new_max_links = neighbor->layers[layer].max_links+1;
                hnswNode **new_links = hrealloc(neighbor->layers[layer].links,
                                        sizeof(hnswNode*) * new_max_links);
                if (new_links == NULL) continue; // Non critical.

                /* Update neighbor's link capacity. */
                neighbor->layers[layer].links = new_links;
                neighbor->layers[layer].max_links = new_max_links;

                /* Establish bidirectional link. */
                uint32_t n = neighbor->layers[layer].num_links;
                neighbor->layers[layer].links[n] = new_node;
                neighbor->layers[layer].num_links++;
                hnsw_update_worst_neighbor_on_add(index, neighbor, layer,
                                                  n, dist);

                n = new_node->layers[layer].num_links;
                new_node->layers[layer].links[n] = neighbor;
                new_node->layers[layer].num_links++;
                hnsw_update_worst_neighbor_on_add(index, new_node, layer,
                                                  n, dist);
                /* Both directions linked: done with this candidate. */
                continue;
            }
        }

        // Remove backlink from the worst node of our candidate.
        for (uint64_t j = 0; j < worst_node->layers[layer].num_links; j++) {
            if (worst_node->layers[layer].links[j] == neighbor) {
                memmove(&worst_node->layers[layer].links[j],
                        &worst_node->layers[layer].links[j+1],
                        (worst_node->layers[layer].num_links - j - 1) * sizeof(hnswNode*));
                worst_node->layers[layer].num_links--;
                hnsw_update_worst_neighbor_on_remove(index,worst_node,layer,j);
                break;
            }
        }

        /* Replace worst link with the new node. */
        neighbor->layers[layer].links[neighbor->layers[layer].worst_idx] = new_node;

        /* Update the worst link in the target node, at this point
         * the link that we replaced may no longer be the worst. */
        hnsw_update_worst_neighbor(index,neighbor,layer);

        // Add new node -> candidate link.
        uint32_t new_links = new_node->layers[layer].num_links;
        new_node->layers[layer].links[new_links] = neighbor;
        new_node->layers[layer].num_links++;

        // Update new node worst link.
        hnsw_update_worst_neighbor_on_add(index,new_node,layer,new_links,dist);
    }
}
1259
1260/* This function implements node reconnection after a node deletion in HNSW.
1261 * When a node is deleted, other nodes at the specified layer lose one
1262 * connection (all the neighbors of the deleted node). This function attempts
1263 * to pair such nodes together in a way that maximizes connection quality
1264 * among the M nodes that were former neighbors of our deleted node.
1265 *
1266 * The algorithm works by first building a distance matrix among the nodes:
1267 *
1268 * N0 N1 N2 N3
1269 * N0 0 1.2 0.4 0.9
1270 * N1 1.2 0 0.8 0.5
1271 * N2 0.4 0.8 0 1.1
1272 * N3 0.9 0.5 1.1 0
1273 *
1274 * For each potential pairing (i,j) we compute a score that combines:
1275 * 1. The direct cosine distance between the two nodes
1276 * 2. The average distance to other nodes that would no longer be
1277 * available for pairing if we select this pair
1278 *
1279 * We want to balance local node-to-node requirements and global requirements.
1280 * For instance sometimes connecting A with B, while optimal, would leave
1281 * C and D to be connected without other choices, and this could be a very
1282 * bad connection. Maybe instead A and C and B and D are both relatively high
1283 * quality connections.
1284 *
1285 * The formula used to calculate the score of each connection is:
1286 *
1287 * score[i,j] = W1*(2-distance[i,j]) + W2*((new_avg_i + new_avg_j)/2)
1288 * where new_avg_x is the average of distances in row x excluding distance[i,j]
1289 *
1290 * So the score is directly proportional to the SIMILARITY of the two nodes
1291 * and also directly proportional to the DISTANCE of the potential other
 * connections that we lost by pairing i,j. So we have a cost for missed
1293 * opportunities, or better, in this case, a reward if the missing
1294 * opportunities are not so good (big average distance).
1295 *
1296 * W1 and W2 are weights (defaults: 0.7 and 0.3) that determine the relative
1297 * importance of immediate connection quality vs future pairing potential.
1298 *
1299 * After the initial pairing phase, any nodes that couldn't be paired
1300 * (due to odd count or existing connections) are handled by searching
1301 * the broader graph using the standard HNSW neighbor selection logic.
1302 */
1303void hnsw_reconnect_nodes(HNSW *index, hnswNode **nodes, int count, uint32_t layer) {
1304 if (count <= 0) return;
1305 debugmsg("Reconnecting %d nodes\n", count);
1306
1307 /* Step 1: Build the distance matrix between all nodes.
1308 * Since distance(i,j) = distance(j,i), we only compute the upper triangle
1309 * and mirror it to the lower triangle. */
1310 float *distances = hmalloc((unsigned long) count * count * sizeof(float));
1311 if (!distances) return;
1312
1313 for (int i = 0; i < count; i++) {
1314 distances[i*count + i] = 0; // Distance to self is 0
1315 for (int j = i+1; j < count; j++) {
1316 float dist = hnsw_distance(index, nodes[i], nodes[j]);
1317 distances[i*count + j] = dist; // Upper triangle.
1318 distances[j*count + i] = dist; // Lower triangle.
1319 }
1320 }
1321
1322 /* Step 2: Calculate row averages (will be used in scoring):
1323 * please note that we just calculate row averages and not
1324 * columns averages since the matrix is symmetrical, so those
1325 * are the same: check the image in the top comment if you have any
1326 * doubt about this. */
1327 float *row_avgs = hmalloc(count * sizeof(float));
1328 if (!row_avgs) {
1329 hfree(distances);
1330 return;
1331 }
1332
1333 for (int i = 0; i < count; i++) {
1334 float sum = 0;
1335 int valid_count = 0;
1336 for (int j = 0; j < count; j++) {
1337 if (i != j) {
1338 sum += distances[i*count + j];
1339 valid_count++;
1340 }
1341 }
1342 row_avgs[i] = valid_count ? sum / valid_count : 0;
1343 }
1344
1345 /* Step 3: Build scoring matrix. What we do here is to combine how
1346 * good is a given i,j nodes connection, with how badly connecting
1347 * i,j will affect the remaining quality of connections left to
1348 * pair the other nodes. */
1349 float *scores = hmalloc((unsigned long) count * count * sizeof(float));
1350 if (!scores) {
1351 hfree(distances);
1352 hfree(row_avgs);
1353 return;
1354 }
1355
1356 /* Those weights were obtained manually... No guarantee that they
     * are optimal. However with these values the algorithm is certainly
     * better than its greedy version that just attempts to pick the
1359 * best pair each time (verified experimentally). */
1360 const float W1 = 0.7; // Weight for immediate distance.
1361 const float W2 = 0.3; // Weight for future potential.
1362
1363 for (int i = 0; i < count; i++) {
1364 for (int j = 0; j < count; j++) {
1365 if (i == j) {
1366 scores[i*count + j] = -1; // Invalid pairing.
1367 continue;
1368 }
1369
1370 // Check for existing connection between i and j.
1371 int already_linked = 0;
1372 for (uint32_t k = 0; k < nodes[i]->layers[layer].num_links; k++)
1373 {
1374 if (nodes[i]->layers[layer].links[k] == nodes[j]) {
1375 scores[i*count + j] = -1; // Already linked.
1376 already_linked = 1;
1377 break;
1378 }
1379 }
1380 if (already_linked) continue;
1381
1382 float dist = distances[i*count + j];
1383
1384 /* Calculate new averages excluding this pair.
1385 * Handle edge case where we might have too few elements.
1386 * Note that it would be not very smart to recompute the average
1387 * each time scanning the row, we can remove the element
1388 * and adjust the average without it. */
1389 float new_avg_i = 0, new_avg_j = 0;
1390 if (count > 2) {
1391 new_avg_i = (row_avgs[i] * (count-1) - dist) / (count-2);
1392 new_avg_j = (row_avgs[j] * (count-1) - dist) / (count-2);
1393 }
1394
1395 /* Final weighted score: the more similar i,j, the better
1396 * the score. The more distant are the pairs we lose by
1397 * connecting i,j, the better the score. */
1398 scores[i*count + j] = W1*(2-dist) + W2*((new_avg_i + new_avg_j)/2);
1399 }
1400 }
1401
1402 // Step 5: Pair nodes greedily based on scores.
1403 int *used = hmalloc(count*sizeof(int));
1404 memset(used,0,count*sizeof(int));
1405 if (!used) {
1406 hfree(distances);
1407 hfree(row_avgs);
1408 hfree(scores);
1409 return;
1410 }
1411
1412 /* Scan the matrix looking each time for the potential
1413 * link with the best score. */
1414 while(1) {
1415 float max_score = -1;
1416 int best_j = -1, best_i = -1;
1417
1418 // Seek best score i,j values.
1419 for (int i = 0; i < count; i++) {
1420 if (used[i]) continue; // Already connected.
1421
1422 /* No space left? Not possible after a node deletion but makes
1423 * this function more future-proof. */
1424 if (nodes[i]->layers[layer].num_links >=
1425 nodes[i]->layers[layer].max_links) continue;
1426
1427 for (int j = 0; j < count; j++) {
1428 if (i == j) continue; // Same node, skip.
1429 if (used[j]) continue; // Already connected.
1430 float score = scores[i*count + j];
1431 if (score < 0) continue; // Invalid link.
1432
1433 /* If the target node has space, and its score is better
1434 * than any other seen so far... remember it is the best. */
1435 if (score > max_score &&
1436 nodes[j]->layers[layer].num_links <
1437 nodes[j]->layers[layer].max_links)
1438 {
1439 // Track the best connection found so far.
1440 max_score = score;
1441 best_j = j;
1442 best_i = i;
1443 }
1444 }
1445 }
1446
1447 // Possible link found? Connect i and j.
1448 if (best_j != -1) {
1449 debugmsg("[%d] linking %d with %d: %f\n", layer, (int)best_i, (int)best_j, max_score);
1450 // Link i -> j.
1451 int link_idx = nodes[best_i]->layers[layer].num_links;
1452 nodes[best_i]->layers[layer].links[link_idx] = nodes[best_j];
1453 nodes[best_i]->layers[layer].num_links++;
1454
1455 // Update worst distance if needed.
1456 float dist = distances[best_i*count + best_j];
1457 hnsw_update_worst_neighbor_on_add(index,nodes[best_i],layer,link_idx,dist);
1458
1459 // Link j -> i.
1460 link_idx = nodes[best_j]->layers[layer].num_links;
1461 nodes[best_j]->layers[layer].links[link_idx] = nodes[best_i];
1462 nodes[best_j]->layers[layer].num_links++;
1463
1464 // Update worst distance if needed.
1465 hnsw_update_worst_neighbor_on_add(index,nodes[best_j],layer,link_idx,dist);
1466
1467 // Mark connection as used.
1468 used[best_i] = used[best_j] = 1;
1469 } else {
1470 break; // No more valid connections available.
1471 }
1472 }
1473
1474 /* Step 6: Handle remaining unpaired nodes using the standard HNSW
1475 * neighbor selection. */
1476 for (int i = 0; i < count; i++) {
1477 if (used[i]) continue;
1478
1479 // Skip if node is already at max connections.
1480 if (nodes[i]->layers[layer].num_links >=
1481 nodes[i]->layers[layer].max_links)
1482 continue;
1483
1484 debugmsg("[%d] Force linking %d\n", layer, i);
1485
1486 /* First, try with local nodes as candidates.
1487 * Some candidate may have space. */
1488 pqueue *candidates = pq_new(count);
1489 if (!candidates) continue;
1490
1491 /* Add all the local nodes having some space as candidates
1492 * to be linked with this node. */
1493 for (int j = 0; j < count; j++) {
1494 if (i != j && // Must not be itself.
1495 nodes[j]->layers[layer].num_links < // Must not be full.
1496 nodes[j]->layers[layer].max_links)
1497 {
1498 float dist = distances[i*count + j];
1499 pq_push(candidates, nodes[j], dist);
1500 }
1501 }
1502
1503 /* Try local candidates first with aggressive = 1.
1504 * So we will link only if there is space.
1505 * We want one link more than the links we already have. */
1506 uint32_t wanted_links = nodes[i]->layers[layer].num_links+1;
1507 if (candidates->count > 0) {
1508 select_neighbors(index, candidates, nodes[i], layer,
1509 wanted_links, 1);
1510 debugmsg("Final links after attempt with local nodes: %d (wanted: %d)\n", (int)nodes[i]->layers[layer].num_links, wanted_links);
1511 }
1512
1513 // If still no connection, search the broader graph.
1514 if (nodes[i]->layers[layer].num_links != wanted_links) {
1515 debugmsg("No force linking possible with local candidates\n");
1516 pq_free(candidates);
1517
1518 // Find entry point for target layer by descending through levels.
1519 hnswNode *curr_ep = index->enter_point;
1520 for (uint32_t lc = index->max_level; lc > layer; lc--) {
1521 pqueue *results = search_layer(index, nodes[i], curr_ep, 1, lc, 0);
1522 if (results) {
1523 if (results->count > 0) {
1524 curr_ep = pq_get_node(results,0);
1525 }
1526 pq_free(results);
1527 }
1528 }
1529
1530 if (curr_ep) {
1531 /* Search this layer for candidates.
1532 * Use the default EF_C in this case, since it's not an
1533 * "insert" operation, and we don't know the user
1534 * specified "EF". */
1535 candidates = search_layer(index, nodes[i], curr_ep, HNSW_EF_C, layer, 0);
1536 if (candidates) {
1537 /* Try to connect with aggressiveness proportional to the
1538 * node linking condition. */
1539 int aggressiveness =
1540 (nodes[i]->layers[layer].num_links > index->M / 2)
1541 ? 1 : 2;
1542 select_neighbors(index, candidates, nodes[i], layer,
1543 wanted_links, aggressiveness);
1544 debugmsg("Final links with broader search: %d (wanted: %d)\n", (int)nodes[i]->layers[layer].num_links, wanted_links);
1545 pq_free(candidates);
1546 }
1547 }
1548 } else {
1549 pq_free(candidates);
1550 }
1551 }
1552
1553 // Cleanup.
1554 hfree(distances);
1555 hfree(row_avgs);
1556 hfree(scores);
1557 hfree(used);
1558}
1559
/* This is a helper function in order to support node deletion.
 * Its goal is just to:
 *
 * 1. Remove the node from the bidirectional links of neighbors in the graph.
 * 2. Remove the node from the linked list of nodes.
 * 3. Fix the entry point in the graph. We just select one of the neighbors
 *    of the deleted node at a lower level. If none is found, we do
 *    a full scan.
 * 4. The node itself and its aux value field are NOT freed. It's up to the
 *    caller to do it, by using hnsw_node_free().
 * 5. The node associated value (node->value) is NOT freed.
 *
 * Why this function will not free the node? Because in node updates it
 * could be a good idea to reuse the node allocation for different reasons
 * (currently not implemented).
 * In general it is more future-proof to be able to reuse the node if
 * needed. Right now this library reuses the node only when links are
 * not touched (see hnsw_update() for more information). */
void hnsw_unlink_node(HNSW *index, hnswNode *node) {
    if (!index || !node) return;

    index->version++; // This node may be missing in an already compiled list
                      // of neighbors. Make optimistic concurrent inserts fail.

    /* Remove all bidirectional links at each level.
     * Note that in this implementation all the
     * links are guaranteed to be bidirectional. */

    /* For each level of the deleted node... */
    for (uint32_t level = 0; level <= node->level; level++) {
        /* For each linked node of the deleted node... */
        for (uint32_t i = 0; i < node->layers[level].num_links; i++) {
            hnswNode *linked = node->layers[level].links[i];
            /* Find and remove the backlink in the linked node. */
            for (uint32_t j = 0; j < linked->layers[level].num_links; j++) {
                if (linked->layers[level].links[j] == node) {
                    /* Remove by shifting remaining links left. */
                    memmove(&linked->layers[level].links[j],
                            &linked->layers[level].links[j + 1],
                            (linked->layers[level].num_links - j - 1) * sizeof(hnswNode*));
                    linked->layers[level].num_links--;
                    /* Keep the cached worst-neighbor info of 'linked'
                     * consistent after removing the link at slot 'j'. */
                    hnsw_update_worst_neighbor_on_remove(index,linked,level,j);
                    break;
                }
            }
        }
    }

    /* Update cursors pointing at this element. */
    if (index->cursors) hnsw_cursor_element_deleted(index,node);

    /* Unlink from the doubly linked list of all nodes:
     * update the previous node's next pointer. */
    if (node->prev) {
        node->prev->next = node->next;
    } else {
        /* If there's no previous node, this is the head. */
        index->head = node->next;
    }

    /* Update the next node's prev pointer. */
    if (node->next) node->next->prev = node->prev;

    /* Update node count. */
    index->node_count--;

    /* If this node was the enter_point, we need to update it. */
    if (node == index->enter_point) {
        /* Reset entry point - we'll find a new one (unless the HNSW is
         * now empty). */
        index->enter_point = NULL;
        index->max_level = 0;

        /* Step 1: Try to find a replacement by scanning levels
         * from top to bottom. Under normal conditions, if there is
         * any other node at the same level, we have a link. Anyway
         * we descend levels to find any neighbor at the higher level
         * possible. */
        for (int level = node->level; level >= 0; level--) {
            if (node->layers[level].num_links > 0) {
                index->enter_point = node->layers[level].links[0];
                break;
            }
        }

        /* Step 2: If no links were found at any level, do a full scan.
         * This should never happen in practice if the HNSW is not
         * empty. */
        if (!index->enter_point) {
            uint32_t new_max_level = 0;
            hnswNode *current = index->head;

            while (current) {
                if (current != node && current->level >= new_max_level) {
                    new_max_level = current->level;
                    index->enter_point = current;
                }
                current = current->next;
            }
        }

        /* Update max_level to match the new entry point (if any). */
        if (index->enter_point)
            index->max_level = index->enter_point->level;
    }

    /* Clear the node's list pointers but don't free the node itself.
     * Note that node->layers[*].links are left in place on purpose:
     * the caller may use them to reconnect the orphaned neighbors
     * (see hnsw_delete_node()). */
    node->prev = node->next = NULL;
}
1668
1669/* Higher level API for hnsw_unlink_node() + hnsw_reconnect_nodes() actual work.
1670 * This will get the write lock, will delete the node, free it,
1671 * reconnect the node neighbors among themselves, and unlock again.
1672 * If free_value function pointer is not NULL, then the function provided is
1673 * used to free node->value.
1674 *
1675 * The function returns 0 on error (inability to acquire the lock), otherwise
1676 * 1 is returned. */
1677int hnsw_delete_node(HNSW *index, hnswNode *node, void(*free_value)(void*value)) {
1678 if (pthread_rwlock_wrlock(&index->global_lock) != 0) return 0;
1679 hnsw_unlink_node(index,node);
1680 if (free_value && node->value) free_value(node->value);
1681
1682 /* Relink all the nodes orphaned of this node link.
1683 * Do it for all the levels. */
1684 for (unsigned int j = 0; j <= node->level; j++) {
1685 hnsw_reconnect_nodes(index, node->layers[j].links,
1686 node->layers[j].num_links, j);
1687 }
1688 hnsw_node_free(node);
1689 pthread_rwlock_unlock(&index->global_lock);
1690 return 1;
1691}
1692
1693/* ============================ Threaded API ================================
1694 * Concurrent readers should use the following API to get a slot assigned
1695 * (and a lock, too), do their read-only call, and unlock the slot.
1696 *
1697 * There is a reason why read operations don't implement opaque transparent
1698 * locking directly on behalf of the user: when we return a result set
1699 * with hnsw_search(), we report a set of nodes. The caller will do something
1700 * with the nodes and the associated values, so the unlocking of the
1701 * slot should happen AFTER the result was already used, otherwise we may
1702 * have changes to the HNSW nodes as the result is being accessed. */
1703
1704/* Try to acquire a read slot. Returns the slot number (0 to HNSW_MAX_THREADS-1)
1705 * on success, -1 on error (pthread mutex errors). */
1706int hnsw_acquire_read_slot(HNSW *index) {
1707 /* First try a non-blocking approach on all slots. */
1708 for (uint32_t i = 0; i < HNSW_MAX_THREADS; i++) {
1709 if (pthread_mutex_trylock(&index->slot_locks[i]) == 0) {
1710 if (pthread_rwlock_rdlock(&index->global_lock) != 0) {
1711 pthread_mutex_unlock(&index->slot_locks[i]);
1712 return -1;
1713 }
1714 return i;
1715 }
1716 }
1717
1718 /* All trylock attempts failed, use atomic increment to select slot. */
1719 uint32_t slot = index->next_slot++ % HNSW_MAX_THREADS;
1720
1721 /* Try to lock the selected slot. */
1722 if (pthread_mutex_lock(&index->slot_locks[slot]) != 0) return -1;
1723
1724 /* Get read lock. */
1725 if (pthread_rwlock_rdlock(&index->global_lock) != 0) {
1726 pthread_mutex_unlock(&index->slot_locks[slot]);
1727 return -1;
1728 }
1729
1730 return slot;
1731}
1732
1733/* Release a previously acquired read slot: note that it is important that
1734 * nodes returned by hnsw_search() are accessed while the read lock is
1735 * still active, to be sure that nodes are not freed. */
1736void hnsw_release_read_slot(HNSW *index, int slot) {
1737 if (slot < 0 || slot >= HNSW_MAX_THREADS) return;
1738 pthread_rwlock_unlock(&index->global_lock);
1739 pthread_mutex_unlock(&index->slot_locks[slot]);
1740}
1741
1742/* ============================ Nodes insertion =============================
1743 * We have an optimistic API separating the read-only candidates search
1744 * and the write side (actual node insertion). We internally also use
1745 * this API to provide the plain hnsw_insert() function for code unification. */
1746
/* Context of the optimistic (prepare/commit) insertion API: it carries
 * everything computed during the read-locked preparation step that the
 * write-locked commit step will need. */
struct InsertContext {
    pqueue *level_queues[HNSW_MAX_LEVEL]; /* Candidates for each level. */
    hnswNode *node; /* Pre-allocated node ready for insertion */
    uint64_t version; /* Index version at preparation time. This is used
                       * for CAS-like locking during change commit. */
};
1753
/* Optimistic insertion API.
 *
 * WARNING: Note that this is an internal function: users should call
 * hnsw_prepare_insert() instead.
 *
 * This is how it works: you use hnsw_prepare_insert() and it will return
 * a context where good candidate neighbors are already pre-selected.
 * This step only uses read locks.
 *
 * Then finally you try to actually commit the new node with
 * hnsw_try_commit_insert(): this time we will require a write lock, but
 * for less time than it would be otherwise needed if using directly
 * hnsw_insert(). When you try to commit the write, if no node was deleted in
 * the meantime, your operation will succeed, otherwise it will fail, and
 * you should try to just use the hnsw_insert() API, since there is
 * contention.
 *
 * See hnsw_node_new() for information about 'vector' and 'qvector'
 * arguments, and which one to pass. */
InsertContext *hnsw_prepare_insert_nolock(HNSW *index, const float *vector,
        const int8_t *qvector, float qrange, uint64_t id,
        int slot, int ef)
{
    InsertContext *ctx = hmalloc(sizeof(*ctx));
    if (!ctx) return NULL;

    memset(ctx, 0, sizeof(*ctx));
    ctx->version = index->version; /* Snapshot for the later CAS check. */

    /* Create a new node that we may be able to insert into the
     * graph later, when calling the commit function. */
    uint32_t level = random_level();
    ctx->node = hnsw_node_new(index, id, vector, qvector, qrange, level, 1);
    if (!ctx->node) {
        hfree(ctx);
        return NULL;
    }

    hnswNode *curr_ep = index->enter_point;

    /* Empty graph, no need to collect candidates. */
    if (curr_ep == NULL) return ctx;

    /* Phase 1: Find good entry point on the highest level of the new
     * node we are going to insert. */
    for (unsigned int lc = index->max_level; lc > level; lc--) {
        pqueue *results = search_layer(index, ctx->node, curr_ep, 1, lc, slot);

        if (results) {
            if (results->count > 0) curr_ep = pq_get_node(results,0);
            pq_free(results);
        }
    }

    /* Phase 2: Collect a set of potential connections for each layer of
     * the new node. The queues are stored in the context and consumed
     * later by the commit step. */
    for (int lc = MIN(level, index->max_level); lc >= 0; lc--) {
        pqueue *candidates =
            search_layer(index, ctx->node, curr_ep, ef, lc, slot);

        if (!candidates) continue;
        curr_ep = (candidates->count > 0) ? pq_get_node(candidates,0) : curr_ep;
        ctx->level_queues[lc] = candidates;
    }

    return ctx;
}
1821
1822/* External API for hnsw_prepare_insert_nolock(), handling locking. */
1823InsertContext *hnsw_prepare_insert(HNSW *index, const float *vector,
1824 const int8_t *qvector, float qrange, uint64_t id,
1825 int ef)
1826{
1827 InsertContext *ctx;
1828 int slot = hnsw_acquire_read_slot(index);
1829 ctx = hnsw_prepare_insert_nolock(index,vector,qvector,qrange,id,slot,ef);
1830 hnsw_release_read_slot(index,slot);
1831 return ctx;
1832}
1833
1834/* Free an insert context and all its resources. */
1835void hnsw_free_insert_context(InsertContext *ctx) {
1836 if (!ctx) return;
1837 for (uint32_t i = 0; i < HNSW_MAX_LEVEL; i++) {
1838 if (ctx->level_queues[i]) pq_free(ctx->level_queues[i]);
1839 }
1840 if (ctx->node) hnsw_node_free(ctx->node);
1841 hfree(ctx);
1842}
1843
/* Commit a prepared insert operation. This function is a low level API that
 * should not be called by the user. See instead hnsw_try_commit_insert(), that
 * will perform the CAS check and acquire the write lock.
 *
 * See the top comment in hnsw_prepare_insert() for more information
 * on the optimistic insertion API.
 *
 * This function can't fail and always returns the pointer to the
 * just inserted node. Out of memory is not possible since no critical
 * allocation is ever performed in this code path: we populate links
 * on already allocated nodes.
 *
 * Note: the function takes ownership of 'ctx' and frees it. */
hnswNode *hnsw_commit_insert_nolock(HNSW *index, InsertContext *ctx, void *value) {
    hnswNode *node = ctx->node;
    node->value = value;

    /* Handle first node case. */
    if (index->enter_point == NULL) {
        index->version++; // First node, make concurrent inserts fail.
        index->enter_point = node;
        index->max_level = node->level;
        hnsw_add_node(index, node);
        ctx->node = NULL; // So hnsw_free_insert_context() will not free it.
        hnsw_free_insert_context(ctx);
        return node;
    }

    /* Connect the node with near neighbors at each level. */
    for (int lc = MIN(node->level,index->max_level); lc >= 0; lc--) {
        if (ctx->level_queues[lc] == NULL) continue;

        /* Try to provide index->M connections to our node. The call
         * is not guaranteed to be able to provide all the links we would
         * like to have for the new node: they must be bi-directional, obey
         * certain quality checks, and so forth, so later there are further
         * calls to force the hand a bit if needed.
         *
         * Let's start with aggressiveness = 0. */
        select_neighbors(index, ctx->level_queues[lc], node, lc, index->M, 0);

        /* Layer 0 and too few connections? Let's be more aggressive. */
        if (lc == 0 && node->layers[0].num_links < index->M/2) {
            select_neighbors(index, ctx->level_queues[lc], node, lc,
                             index->M, 1);

            /* Still too few connections? Let's go to
             * aggressiveness level '2' in linking strategy. */
            if (node->layers[0].num_links < index->M/4) {
                select_neighbors(index, ctx->level_queues[lc], node, lc,
                                 index->M/4, 2);
            }
        }
    }

    /* If new node level is higher than current max, update entry point. */
    if (node->level > index->max_level) {
        index->version++; // Entry point changed, make concurrent inserts fail.
        index->enter_point = node;
        index->max_level = node->level;
    }

    /* Add node to the linked list. */
    hnsw_add_node(index, node);
    ctx->node = NULL; // So hnsw_free_insert_context() will not free the node.
    hnsw_free_insert_context(ctx);
    return node;
}
1910
1911/* If the context obtained with hnsw_prepare_insert() is still valid
1912 * (nodes not deleted in the meantime) then add the new node to the HNSW
1913 * index and return its pointer. Otherwise NULL is returned and the operation
1914 * should be either performed with the blocking API hnsw_insert() or attempted
1915 * again. */
1916hnswNode *hnsw_try_commit_insert(HNSW *index, InsertContext *ctx, void *value) {
1917 /* Check if the version changed since preparation. Note that we
1918 * should access index->version under the write lock in order to
1919 * be sure we can safely commit the write: this is just a fast-path
1920 * in order to return ASAP without acquiring the write lock in case
1921 * the version changed. */
1922 if (ctx->version != index->version) {
1923 hnsw_free_insert_context(ctx);
1924 return NULL;
1925 }
1926
1927 /* Try to acquire write lock. */
1928 if (pthread_rwlock_wrlock(&index->global_lock) != 0) {
1929 hnsw_free_insert_context(ctx);
1930 return NULL;
1931 }
1932
1933 /* Check version again under write lock. */
1934 if (ctx->version != index->version) {
1935 pthread_rwlock_unlock(&index->global_lock);
1936 hnsw_free_insert_context(ctx);
1937 return NULL;
1938 }
1939
1940 /* Commit the change: note that it's up to hnsw_commit_insert_nolock()
1941 * to free the insertion context. */
1942 hnswNode *node = hnsw_commit_insert_nolock(index, ctx, value);
1943
1944 /* Release the write lock. */
1945 pthread_rwlock_unlock(&index->global_lock);
1946 return node;
1947}
1948
1949/* Insert a new element into the graph.
1950 * See hnsw_node_new() for information about 'vector' and 'qvector'
1951 * arguments, and which one to pass.
1952 *
1953 * Return NULL on out of memory during insert. Otherwise the newly
1954 * inserted node pointer is returned. */
1955hnswNode *hnsw_insert(HNSW *index, const float *vector, const int8_t *qvector, float qrange, uint64_t id, void *value, int ef) {
1956 /* Write lock. We acquire the write lock even for the prepare()
1957 * operation (that is a read-only operation) since we want this function
1958 * to don't fail in the check-and-set stage of commit().
1959 *
1960 * Basically here we are using the optimistic API in a non-optimistinc
1961 * way in order to have a single insertion code in the implementation. */
1962 if (pthread_rwlock_wrlock(&index->global_lock) != 0) return NULL;
1963
1964 // Prepare the insertion - note we pass slot 0 since we're single threaded.
1965 InsertContext *ctx = hnsw_prepare_insert_nolock(index, vector, qvector,
1966 qrange, id, 0, ef);
1967 if (!ctx) {
1968 pthread_rwlock_unlock(&index->global_lock);
1969 return NULL;
1970 }
1971
1972 // Commit the prepared insertion without version checking.
1973 hnswNode *node = hnsw_commit_insert_nolock(index, ctx, value);
1974
1975 // Release write lock and return our node pointer.
1976 pthread_rwlock_unlock(&index->global_lock);
1977 return node;
1978}
1979
/* qsort() comparator used by hnsw_should_reuse_node(): sorts floats in
 * DESCENDING order, so that the largest (worst) distances come first.
 *
 * The signature uses 'const void *' as required by qsort(): calling a
 * function through a pointer of an incompatible type is undefined
 * behavior in C, so the comparator itself must have the prototype that
 * qsort() expects. */
static int compare_floats(const void *a, const void *b) {
    float fa = *(const float *)a;
    float fb = *(const float *)b;
    if (fa < fb) return 1;   /* Smaller value sorts later: descending. */
    if (fa > fb) return -1;
    return 0;
}
1986
1987/* This function determines if a node can be reused with a new vector by:
1988 *
1989 * 1. Computing average of worst 25% of current distances.
1990 * 2. Checking if at least 50% of new distances stay below this threshold.
1991 * 3. Requiring a minimum number of links for the check to be meaningful.
1992 *
1993 * This check is useful when we want to just update a node that already
1994 * exists in the graph. Often the new vector is a learned embedding generated
1995 * by some model, and the embedding represents some document that perhaps
1996 * changed just slightly compared to the past, so the new embedding will
1997 * be very nearby. We need to find a way do determine if the current node
1998 * neighbors (practically speaking its location in the grapb) are good
1999 * enough even with the new vector.
2000 *
2001 * XXX: this function needs improvements: successive updates to the same
2002 * node with more and more distant vectors will make the node drift away
2003 * from its neighbors. One of the additional metrics used could be
2004 * neighbor-to-neighbor distance, that represents a more absolute check
2005 * of fit for the new vector. */
2006int hnsw_should_reuse_node(HNSW *index, hnswNode *node, int is_normalized, const float *new_vector) {
2007 /* Step 1: Not enough links? Advice to avoid reuse. */
2008 const uint32_t min_links_for_reuse = 4;
2009 uint32_t layer0_connections = node->layers[0].num_links;
2010 if (layer0_connections < min_links_for_reuse) return 0;
2011
2012 /* Step2: get all current distances and run our heuristic. */
2013 float *old_distances = hmalloc(sizeof(float) * layer0_connections);
2014 if (!old_distances) return 0;
2015
2016 // Temporary node with the new vector, to simplify the next logic.
2017 hnswNode tmp_node;
2018 if (hnsw_init_tmp_node(index,&tmp_node,is_normalized,new_vector) == 0) {
2019 hfree(old_distances);
2020 return 0;
2021 }
2022
2023 /* Get old dinstances and sort them to access the 25% worst
2024 * (bigger) ones. */
2025 for (uint32_t i = 0; i < layer0_connections; i++) {
2026 old_distances[i] = hnsw_distance(index, node, node->layers[0].links[i]);
2027 }
2028 qsort(old_distances, layer0_connections, sizeof(float),
2029 (int (*)(const void*, const void*))(&compare_floats));
2030
2031 uint32_t count = (layer0_connections+3)/4; // 25% approx to larger int.
2032 if (count > layer0_connections) count = layer0_connections; // Futureproof.
2033 float worst_avg = 0;
2034
2035 // Compute average of 25% worst dinstances.
2036 for (uint32_t i = 0; i < count; i++) worst_avg += old_distances[i];
2037 worst_avg /= count;
2038 hfree(old_distances);
2039
2040 // Count how many new distances stay below the threshold.
2041 uint32_t good_distances = 0;
2042 for (uint32_t i = 0; i < layer0_connections; i++) {
2043 float new_dist = hnsw_distance(index, &tmp_node, node->layers[0].links[i]);
2044 if (new_dist <= worst_avg) good_distances++;
2045 }
2046 hnsw_free_tmp_node(&tmp_node,new_vector);
2047
2048 /* At least 50% of the nodes should pass our quality test, for the
2049 * node to be reused. */
2050 return good_distances >= layer0_connections/2;
2051}
2052
2053/**
2054 * Return a random node from the HNSW graph.
2055 *
2056 * This function performs a random walk starting from the entry point,
2057 * using only level 0 connections for navigation. It uses log^2(N) steps
2058 * to ensure proper mixing time.
2059 */
2060
2061hnswNode *hnsw_random_node(HNSW *index, int slot) {
2062 if (index->node_count == 0 || index->enter_point == NULL)
2063 return NULL;
2064
2065 (void)slot; // Unused, but we need the caller to acquire the lock.
2066
2067 /* First phase: descend from max level to level 0 taking random paths.
2068 * Note that we don't need a more conservative log^2(N) steps for
2069 * proper mixing, since we already descend to a random cluster here. */
2070 hnswNode *current = index->enter_point;
2071 for (uint32_t level = index->max_level; level > 0; level--) {
2072 /* If current node doesn't have this level or no links, continue
2073 * to lower level. */
2074 if (current->level < level || current->layers[level].num_links == 0)
2075 continue;
2076
2077 /* Choose random neighbor at this level. */
2078 uint32_t rand_neighbor = rand() % current->layers[level].num_links;
2079 current = current->layers[level].links[rand_neighbor];
2080 }
2081
2082 /* Second phase: at level 0, take log(N) * c random steps. */
2083 const int c = 3; // Multiplier for more thorough exploration.
2084 double logN = log2(index->node_count + 1);
2085 uint32_t num_walks = (uint32_t)(logN * c);
2086
2087 /* Avoid the ping-pong effect: imagine there are just two nodes and
2088 * the number of walks selected is even. We will select always the
2089 * first element of the graph; conversely, if it is odd, we will always
2090 * select the other element. One way to add more selection randomness is
2091 * to randomly add '1' or '0' to the number of walks to perform. */
2092 num_walks += rand() & 1;
2093
2094 // Perform random walk at level 0.
2095 for (uint32_t i = 0; i < num_walks; i++) {
2096 if (current->layers[0].num_links == 0) return current;
2097
2098 // Choose random neighbor.
2099 uint32_t rand_neighbor = rand() % current->layers[0].num_links;
2100 current = current->layers[0].links[rand_neighbor];
2101 }
2102 return current;
2103}
2104
2105/* ============================= Serialization ==============================
2106 *
2107 * TO SERIALIZE
2108 * ============
2109 *
2110 * To serialize on disk, you need to persist the vector dimension, number
2111 * of elements, and the quantization type index->quant_type. These are
2112 * global values for the whole index.
2113 *
2114 * Then, to serialize each node:
2115 *
2116 * call hnsw_serialize_node() with each node you find in the linked list
2117 * of nodes, starting at index->head (each node has a next pointer).
2118 * The function will return an hnswSerNode structure, you will need
2119 * to store the following on disk (for each node):
2120 *
2121 * - The sernode->vector data, that is sernode->vector_size bytes.
2122 * - The sernode->params array, that points to an array of uint64_t
2123 * integers. There are sernode->params_count total items. These
2124 * parameters contain everything there is to need about your node: how
2125 * many levels it has, its ID, the list of neighbors for each level (as node
2126 * IDs), and so forth.
2127 *
 * You need to save your own node->value in some way as well, but it already
 * belongs to the user of the API, since, for this library, it's just a pointer,
 * so the user should know how to serialize its private data.
2131 *
2132 * RELOADING FROM DISK / NET
2133 * =========================
2134 *
2135 * When reloading nodes, you first load the index vector dimension and
2136 * quantization type, and create the index with:
2137 *
2138 * HNSW *hnsw_new(uint32_t vector_dim, uint32_t quant_type);
2139 *
2140 * Then you load back, for each node (you stored how many nodes you had)
2141 * the vector and the params array / count.
2142 * You also load the value associated with your node.
2143 *
2144 * At this point you add back the loaded elements into the index with:
2145 *
 *  hnsw_insert_serialized(HNSW *index, void *vector, uint64_t *params,
 *                         uint32_t params_len, void *value);
2148 *
2149 * Once you added all the nodes back, you need to resolve the pointers
2150 * (since so far they are added just with the node IDs as reference), so
2151 * you call:
2152 *
2153 * hnsw_deserialize_index(index);
2154 *
2155 * The index is now ready to be used like if it has been always in memory.
2156 *
2157 * DESIGN NOTES
2158 * ============
2159 *
2160 * Why this API does not just give you a binary blob to save? Because in
2161 * many systems (and in Redis itself) to save integers / floats can have
 * more interesting encodings than just storing a 64 bit value. Many vector
2163 * indexes will be small, and their IDs will be small numbers, so the storage
2164 * system can exploit that and use less disk space, less network bandwidth
2165 * and so forth.
2166 *
2167 * How is the data stored in these arrays of numbers? Oh well, we have
2168 * things that are obviously numbers like node ID, number of levels for the
 * node and so forth. Also each of our nodes has a unique incremental ID,
2170 * so we can store a node set of links in terms of linked node IDs. This
2171 * data is put directly in the loaded node pointer space! We just cast the
2172 * integer to the pointer (so THIS IS NOT SAFE for 32 bit systems). Then
2173 * we want to translate such IDs into pointers. To do that, we build an
2174 * hash table, then scan all the nodes again and fix all the links converting
2175 * the ID to the pointer. */
2176
2177/* History of serialization versions:
2178 * version 0: the first implementation, lacking worst node id/info.
2179 * version 1: includes worst link id/info. */
2180#define HNSW_SERIALIZATION_VERSION 1
2181
2182/* This is a special worst link index that is set when loading a serialized
2183 * node with version 0 (this version of the serialization lacked explicit
2184 * information about the worst link index/distance). This way, later, the
2185 * function that fixes a deserialized index will know to compute the worst
2186 * index info at runtime. */
2187#define HNSW_SER_WORSTLINK_MISSING UINT32_MAX
2188
2189/* Return the serialized node information as specified in the top comment
2190 * above. Note that the returned information is true as long as the node
2191 * provided is not deleted or modified, so this function should be called
2192 * when there are no concurrent writes.
2193 *
 * The function hnsw_free_serialized_node() should be called in order to
 * free the result of this function. */
hnswSerNode *hnsw_serialize_node(HNSW *index, hnswNode *node) {
    /* The first step is calculating the number of uint64_t parameters
     * that we need in order to serialize the node. */
    uint32_t num_params = 0;
    num_params += 2; // Param 0: node ID. Param 1: version + node level.
    for (uint32_t i = 0; i <= node->level; i++) {
        num_params += 2; // max_links and num_links info for this layer.
        num_params += node->layers[i].num_links; // The IDs of linked nodes.
        num_params += 1; // worst link id/distance parameter.
    }

    /* We use another 64bit value to store two floats that are about
     * the vector: l2 and quantization range (that is only used if the
     * vector is quantized). */
    num_params++;

    /* Allocate the return object and the parameters array. */
    hnswSerNode *sn = hmalloc(sizeof(hnswSerNode));
    if (sn == NULL) return NULL;
    sn->params = hmalloc(sizeof(uint64_t)*num_params);
    if (sn->params == NULL) {
        hfree(sn);
        return NULL;
    }

    /* Fill data. Note that sn->vector just aliases the node's own
     * vector storage: it is NOT a copy, so it stays valid only while
     * the node is alive and unmodified. */
    sn->params_count = num_params;
    sn->vector = node->vector;
    sn->vector_size = hnsw_quants_bytes(index);

    uint32_t param_idx = 0;
    sn->params[param_idx++] = node->id;
    /* The second parameter contains information about the serialization
     * version of this node, the node level and some unused field:
     *
     * +--------+--------+--------+--------+
     * |VVVVVVVV|........|........|LLLLLLLL|
     * +--------+--------+--------+--------+
     *
     * V is the version, 8 bits.
     * L is the node level, 8 bits (but actually 16 is the max so far).
     * The middle two bytes are reserved for future uses. */
    sn->params[param_idx] = node->level & 0xff;
    sn->params[param_idx] |= HNSW_SERIALIZATION_VERSION << 24;
    param_idx++;
    for (uint32_t i = 0; i <= node->level; i++) {
        sn->params[param_idx++] = node->layers[i].num_links;
        sn->params[param_idx++] = node->layers[i].max_links;
        /* Links are serialized as the IDs of the linked nodes: they are
         * resolved back to pointers by hnsw_deserialize_index(). */
        for (uint32_t j = 0; j < node->layers[i].num_links; j++) {
            sn->params[param_idx++] = node->layers[i].links[j]->id;
        }
        /* Since version 1: pack and store worst_idx and worst_distance.
         * The float bits go in the high 32 bits, the index in the low. */
        uint32_t worst_distance_bits;
        memcpy(&worst_distance_bits, &node->layers[i].worst_distance,
               sizeof(float));
        uint64_t wi =
            (((uint64_t)worst_distance_bits) << 32) | node->layers[i].worst_idx;
        sn->params[param_idx++] = wi;
    }

    /* Store l2 and range as uint32_t, in a way that is endian-safe.
     * Note that in big endian archs both are reversed: integers and
     * also the bytes of floats, so they will match. */
    uint64_t l2_and_range;
    uint32_t l2_bits, range_bits;
    memcpy(&l2_bits,&node->l2,sizeof(float));
    memcpy(&range_bits,&node->quants_range,sizeof(float));
    l2_and_range = ((uint64_t)range_bits<<32) | l2_bits;

    sn->params[param_idx++] = l2_and_range;

    /* Better safe than sorry: */
    assert(param_idx == num_params);
    return sn;
}
2271
2272/* This is needed in order to free hnsw_serialize_node() returned
2273 * structure. */
2274void hnsw_free_serialized_node(hnswSerNode *sn) {
2275 hfree(sn->params);
2276 hfree(sn);
2277}
2278
2279/* Load a serialized node. See the top comment in this section of code
2280 * for the documentation about how to use this.
2281 *
2282 * The function returns NULL both on out of memory and if the remaining
2283 * parameters length does not match the number of links or other items
2284 * to load. */
hnswNode *hnsw_insert_serialized(HNSW *index, void *vector, uint64_t *params, uint32_t params_len, void *value)
{
    /* We need at least the node ID (param 0) and the version/level
     * word (param 1). */
    if (params_len < 2) return NULL;

    uint64_t id = params[0];
    /* Check the node serialization function for the specific layout
     * of param[1] fields. */
    uint32_t level = params[1] & 0xff; // Node level.
    uint32_t version = (params[1] & 0xff000000) >> 24; // Format version.

    /* Refuse data serialized by a newer version of this library. */
    if (version > HNSW_SERIALIZATION_VERSION) return NULL;
    int has_worst_link_info = version > 0;

    /* Keep track of maximum ID seen while loading. */
    if (id >= index->last_id) index->last_id = id;

    /* Create node, passing vector data directly based on quantization type:
     * hnsw_node_new() takes raw floats and quantized data via different
     * arguments. */
    hnswNode *node;
    if (index->quant_type != HNSW_QUANT_NONE) {
        node = hnsw_node_new(index, id, NULL, vector, 0, level, 0);
    } else {
        node = hnsw_node_new(index, id, vector, NULL, 0, level, 0);
    }
    if (!node) return NULL;

    /* Load params array into the node. */
    uint32_t param_idx = 2;
    for (uint32_t i = 0; i <= level; i++) {
        /* Sanity check: enough params left for num_links, max_links and
         * (since version 1) the worst-link word. */
        if (param_idx + 2 + has_worst_link_info > params_len) {
            hnsw_node_free(node);
            return NULL;
        }

        uint32_t num_links = params[param_idx++];
        uint32_t max_links = params[param_idx++];

        /* Sanity check: links should be less than max links and
         * in general a reasonable amount. */
        if (num_links > max_links || max_links > HNSW_MAX_M*4) {
            hnsw_node_free(node);
            return NULL;
        }

        /* If max_links is larger than current allocation, reallocate.
         * It could happen in select_neighbors() that we over-allocate the
         * node under conditions very unlikely to happen. */
        if (max_links > node->layers[i].max_links) {
            hnswNode **new_links = hrealloc(node->layers[i].links,
                                    sizeof(hnswNode*) * max_links);
            if (!new_links) {
                hnsw_node_free(node);
                return NULL;
            }
            node->layers[i].links = new_links;
            node->layers[i].max_links = max_links;
        }
        node->layers[i].num_links = num_links;

        /* Sanity check: enough params left for the link IDs themselves. */
        if (param_idx + num_links + has_worst_link_info > params_len) {
            hnsw_node_free(node);
            return NULL;
        }

        /* Fill links for this layer with the IDs. Note that this
         * is not going to work on 32 bit systems. Deleting / adding-back
         * nodes can produce IDs larger than 2^32-1 even if we can never
         * fit more than 2^32 nodes in a 32 bit system.
         * The IDs are resolved to real pointers later by
         * hnsw_deserialize_index(). */
        for (uint32_t j = 0; j < num_links; j++)
            node->layers[i].links[j] = (hnswNode*)params[param_idx++];

        if (has_worst_link_info) {
            /* Unpack worst_idx (low 32 bits) and the float bits of
             * worst_distance (high 32 bits). */
            uint64_t wi = params[param_idx++];
            uint32_t worst_idx = wi & 0xffffffff;
            uint32_t worst_distance_bits = wi >> 32;
            float worst_distance;
            memcpy(&worst_distance,&worst_distance_bits,sizeof(float));
            node->layers[i].worst_idx = worst_idx;
            node->layers[i].worst_distance = worst_distance;

            // Sanity check the worst ID range.
            if (node->layers[i].num_links > 0 &&
                node->layers[i].worst_idx >= node->layers[i].num_links)
            {
                hnsw_node_free(node);
                return NULL;
            }
        } else {
            /* Version 0 data: mark the worst link info as missing so
             * hnsw_deserialize_index() will recompute it. */
            node->layers[i].worst_idx = HNSW_SER_WORSTLINK_MISSING;
            node->layers[i].worst_distance = 0;
        }
    }

    /* Get l2 and quantization range. */
    if (param_idx >= params_len) {
        hnsw_node_free(node);
        return NULL;
    }

    /* Load l2 and range packed into an uint64_t in an endian safe way. */
    uint64_t l2_and_range = params[param_idx];
    uint32_t l2_bits, range_bits;
    l2_bits = l2_and_range & 0xffffffff;
    range_bits = l2_and_range >> 32;
    memcpy(&node->l2, &l2_bits, sizeof(float));
    memcpy(&node->quants_range, &range_bits, sizeof(float));

    node->value = value;
    hnsw_add_node(index, node);

    /* Keep track of higher node level and set the entry point to the
     * greatest level node seen so far: thanks to this check we don't
     * need to remember what our entry point was during serialization. */
    if (index->enter_point == NULL || level > index->max_level) {
        index->max_level = level;
        index->enter_point = node;
    }
    return node;
}
2405
/* Integer hashing, used by hnsw_deserialize_index().
 * This is MurmurHash3's 64-bit finalizer (fmix64): a bijective
 * bit mixer, so distinct IDs always hash to distinct values. */
uint64_t hnsw_hash_node_id(uint64_t id) {
    uint64_t h = id;
    h = (h ^ (h >> 33)) * 0xff51afd7ed558ccdULL;
    h = (h ^ (h >> 33)) * 0xc4ceb9fe1a85ec53ULL;
    return h ^ (h >> 33);
}
2416
/* Helper for duplicated link detection in hnsw_deserialize_index():
 * qsort() comparator ordering link pointers by numeric value. */
static int qsort_compare_pointers(const void *aptr, const void *bptr) {
    uintptr_t lhs = *((uintptr_t*)aptr);
    uintptr_t rhs = *((uintptr_t*)bptr);
    /* Branchless three-way compare: 1, -1 or 0. */
    return (lhs > rhs) - (lhs < rhs);
}
2425
2426/* Fix pointers of neighbors nodes: after loading the serialized nodes, the
2427 * neighbors links are just IDs (casted to pointers), instead of the actual
2428 * pointers. We need to resolve IDs into pointers.
2429 *
2430 * The two integers salt0 and salt1 are used to make the internal state
2431 * of the function unguessable to an external attacker, in order to protect
 * from corruptions. Should be two random numbers from /dev/urandom if possible
2433 * otherwise can be just 0,0 if the application is not security critical and
2434 * never processes untrusted inputs.
2435 *
2436 * Return 0 on error (out of memory or some ID that can't be resolved), 1 on
2437 * success. */
int hnsw_deserialize_index(HNSW *index, uint64_t salt0, uint64_t salt1) {
    /* We will use simple linear probing, so over-allocating is a good
     * idea: anyway this flat array of pointers will consume a fraction
     * of the memory of the loaded index.
     * Table size is rounded up to a power of two so we can mask instead
     * of taking the modulo. */
    uint64_t min_size = index->node_count*2;
    uint64_t table_size = 1;
    while(table_size < min_size) table_size <<= 1;

    hnswNode **table = hmalloc(sizeof(hnswNode*) * table_size);
    if (table == NULL) return 0;
    memset(table,0,sizeof(hnswNode*) * table_size);

    /* First pass: populate the ID -> pointer hash table. */
    hnswNode *node = index->head;
    while(node) {
        uint64_t bucket = hnsw_hash_node_id(node->id) & (table_size-1);
        for (uint64_t j = 0; j < table_size; j++) {
            if (table[bucket] == NULL) {
                table[bucket] = node;
                break;
            }
            bucket = (bucket+1) & (table_size-1);
        }
        node = node->next;
    }

    /* Second pass: fix pointers of all the neighbors links.
     * As we scan and fix the links, we also compute the accumulator
     * register "reciprocal", that is used in order to guarantee that all
     * the links are reciprocal.
     *
     * This is how it works, we hash (using a strong hash function) the
     * following key for each link that we see from A to B (or vice versa):
     *
     * hash(salt || A || B || link-level)
     *
     * We always sort A and B, so the same link from A to B and from B to A
     * will hash the same. Then we xor the result into the 128 bit accumulator.
     * If each link has its own backlink, the accumulator is guaranteed to
     * be zero at the end.
     *
     * Collisions are extremely unlikely to happen, and an external attacker
     * can't easily control the hash function output, since the salt is
     * unknown, and one would also need to control the pointers.
     *
     * This algorithm is O(1) for each node so it is basically free for
     * us, as we scan the list of nodes, and runs on constant and very
     * small memory. */
    uint64_t accumulator[2] = {0,0};

    node = index->head; // Rewind.
    while(node) {
        uint64_t this_node_id = node->id;
        for (uint32_t i = 0; i <= node->level; i++) {
            // Check if there are duplicated links: those are
            // also corruptions of the on-disk serialization format.
            // Note: at this point the links are still IDs cast to
            // pointers, so sorting by pointer value sorts by ID.
            if (node->layers[i].num_links > 0) {
                qsort(node->layers[i].links, node->layers[i].num_links,
                      sizeof(void*), qsort_compare_pointers);
                for (uint32_t j = 0; j < node->layers[i].num_links-1; j++) {
                    if (node->layers[i].links[j] == node->layers[i].links[j+1])
                        goto corrupted;
                }
            }

            // Resolve pointers.
            for (uint32_t j = 0; j < node->layers[i].num_links; j++) {
                uint64_t linked_id = (uint64_t) node->layers[i].links[j];

                // We can't link to our own node.
                if (linked_id == this_node_id) goto corrupted;

                // Compute accumulator for reciprocal links check.
                uint64_t mixed_h1, mixed_h2;
                secure_pair_mixer_128(salt0, salt1, this_node_id, linked_id, (uint64_t)i, &mixed_h1, &mixed_h2);

                accumulator[0] ^= mixed_h1;
                accumulator[1] ^= mixed_h2;

                // Fix links: linear probing lookup of the ID.
                uint64_t bucket = hnsw_hash_node_id(linked_id) & (table_size-1);
                hnswNode *neighbor = NULL;
                for (uint64_t k = 0; k < table_size; k++) {
                    if (table[bucket] && table[bucket]->id == linked_id) {
                        neighbor = table[bucket];
                        break;
                    }
                    bucket = (bucket+1) & (table_size-1);
                }

                /* The neighbor must exist and also exist at the right
                 * level. */
                if (neighbor == NULL || neighbor->level < i) {
                    /* Unresolved link! Either a bug in this code
                     * or broken serialization data. */
                    goto corrupted;
                }
                node->layers[i].links[j] = neighbor;
            }

            /* The worst link information was missing from older
             * serialization formats. Compute it on the fly if needed. */
            if (node->layers[i].worst_idx == HNSW_SER_WORSTLINK_MISSING) {
                hnsw_update_worst_neighbor(index,node,i);
            }
        }
        node = node->next;
    }

    /* Check that links are reciprocal, otherwise fail. */
    if (accumulator[0] || accumulator[1]) goto corrupted;

    /* Everything fine. Return success. */
    hfree(table);
    return 1;

corrupted:
    /* Some corruption error detected. */
    hfree(table);
    return 0;
}
2559
2560/* ================================ Iterator ================================ */
2561
2562/* Get a cursor that can be used as argument of hnsw_cursor_next() to iterate
2563 * all the elements that remain there from the start to the end of the
2564 * iteration, excluding newly added elements.
2565 *
2566 * The function returns NULL on out of memory. */
2567hnswCursor *hnsw_cursor_init(HNSW *index) {
2568 if (pthread_rwlock_wrlock(&index->global_lock) != 0) return NULL;
2569 hnswCursor *cursor = hmalloc(sizeof(*cursor));
2570 if (cursor == NULL) {
2571 pthread_rwlock_unlock(&index->global_lock);
2572 return NULL;
2573 }
2574 cursor->index = index;
2575 cursor->next = index->cursors;
2576 cursor->current = index->head;
2577 index->cursors = cursor;
2578 pthread_rwlock_unlock(&index->global_lock);
2579 return cursor;
2580}
2581
2582/* Free the cursor. Can be called both at the end of the iteration, when
2583 * hnsw_cursor_next() returned NULL, or before. */
2584void hnsw_cursor_free(hnswCursor *cursor) {
2585 HNSW *index = cursor->index;
2586 if (pthread_rwlock_wrlock(&index->global_lock) != 0) {
2587 // No easy way to recover from that. We will leak memory.
2588 return;
2589 }
2590
2591 hnswCursor *x = index->cursors;
2592 hnswCursor *prev = NULL;
2593 while(x) {
2594 if (x == cursor) {
2595 if (prev)
2596 prev->next = cursor->next;
2597 else
2598 index->cursors = cursor->next;
2599 hfree(cursor);
2600 break;
2601 }
2602 prev = x;
2603 x = x->next;
2604 }
2605 pthread_rwlock_unlock(&index->global_lock);
2606}
2607
2608/* Acquire a lock to use the cursor. Returns 1 if the lock was acquired
2609 * with success, otherwise zero is returned. The returned element is
2610 * protected after calling hnsw_cursor_next() for all the time required to
2611 * access it, then hnsw_cursor_release_lock() should be called in order
2612 * to unlock the HNSW index. */
2613int hnsw_cursor_acquire_lock(hnswCursor *cursor) {
2614 return pthread_rwlock_rdlock(&cursor->index->global_lock) == 0;
2615}
2616
2617/* Release the cursor lock, see hnsw_cursor_acquire_lock() top comment
2618 * for more information. */
2619void hnsw_cursor_release_lock(hnswCursor *cursor) {
2620 pthread_rwlock_unlock(&cursor->index->global_lock);
2621}
2622
2623/* Return the next element of the HNSW. See hnsw_cursor_init() for
2624 * the guarantees of the function. */
2625hnswNode *hnsw_cursor_next(hnswCursor *cursor) {
2626 hnswNode *ret = cursor->current;
2627 if (ret) cursor->current = ret->next;
2628 return ret;
2629}
2630
2631/* Called by hnsw_unlink_node() if there is at least an active cursor.
2632 * Will scan the cursors to see if any cursor is going to yield this
2633 * one, and in this case, updates the current element to the next. */
2634void hnsw_cursor_element_deleted(HNSW *index, hnswNode *deleted) {
2635 hnswCursor *x = index->cursors;
2636 while(x) {
2637 if (x->current == deleted) x->current = deleted->next;
2638 x = x->next;
2639 }
2640}
2641
2642/* ============================ Debugging stuff ============================= */
2643
2644/* Show stats about nodes connections. */
2645void hnsw_print_stats(HNSW *index) {
2646 if (!index || !index->head) {
2647 printf("Empty index or NULL pointer passed\n");
2648 return;
2649 }
2650
2651 long long total_links = 0;
2652 int min_links = -1; // We'll set this to first node's count.
2653 int isolated_nodes = 0;
2654 uint32_t node_count = 0;
2655
2656 // Iterate through all nodes using the linked list.
2657 hnswNode *current = index->head;
2658 while (current) {
2659 // Count total links for this node across all layers.
2660 int node_total_links = 0;
2661 for (uint32_t layer = 0; layer <= current->level; layer++)
2662 node_total_links += current->layers[layer].num_links;
2663
2664 // Update statistics.
2665 total_links += node_total_links;
2666
2667 // Initialize or update minimum links.
2668 if (min_links == -1 || node_total_links < min_links) {
2669 min_links = node_total_links;
2670 }
2671
2672 // Check if node is isolated (no links at all).
2673 if (node_total_links == 0) isolated_nodes++;
2674
2675 node_count++;
2676 current = current->next;
2677 }
2678
2679 // Print statistics
2680 printf("HNSW Graph Statistics:\n");
2681 printf("----------------------\n");
2682 printf("Total nodes: %u\n", node_count);
2683 if (node_count > 0) {
2684 printf("Average links per node: %.2f\n",
2685 (float)total_links / node_count);
2686 printf("Minimum links in a single node: %d\n", min_links);
2687 printf("Number of isolated nodes: %d (%.1f%%)\n",
2688 isolated_nodes,
2689 (float)isolated_nodes * 100 / node_count);
2690 }
2691}
2692
2693/* Validate graph connectivity and link reciprocity. Takes pointers to store results:
2694 * - connected_nodes: will contain number of reachable nodes from entry point.
2695 * - reciprocal_links: will contain 1 if all links are reciprocal, 0 otherwise.
2696 * Returns 0 on success, -1 on error (NULL parameters and such).
2697 */
2698int hnsw_validate_graph(HNSW *index, uint64_t *connected_nodes, int *reciprocal_links) {
2699 if (!index || !connected_nodes || !reciprocal_links) return -1;
2700 if (!index->enter_point) {
2701 *connected_nodes = 0;
2702 *reciprocal_links = 1; // Empty graph is valid.
2703 return 0;
2704 }
2705
2706 // Initialize connectivity check.
2707 index->current_epoch[0]++;
2708 *connected_nodes = 0;
2709 *reciprocal_links = 1;
2710
2711 // Initialize node stack.
2712 uint64_t stack_size = index->node_count;
2713 hnswNode **stack = hmalloc(sizeof(hnswNode*) * stack_size);
2714 if (!stack) return -1;
2715 uint64_t stack_top = 0;
2716
2717 // Start from entry point.
2718 index->enter_point->visited_epoch[0] = index->current_epoch[0];
2719 (*connected_nodes)++;
2720 stack[stack_top++] = index->enter_point;
2721
2722 // Process all reachable nodes.
2723 while (stack_top > 0) {
2724 hnswNode *current = stack[--stack_top];
2725
2726 // Explore all neighbors at each level.
2727 for (uint32_t level = 0; level <= current->level; level++) {
2728 for (uint64_t i = 0; i < current->layers[level].num_links; i++) {
2729 hnswNode *neighbor = current->layers[level].links[i];
2730
2731 // Check reciprocity.
2732 int found_backlink = 0;
2733 for (uint64_t j = 0; j < neighbor->layers[level].num_links; j++) {
2734 if (neighbor->layers[level].links[j] == current) {
2735 found_backlink = 1;
2736 break;
2737 }
2738 }
2739 if (!found_backlink) {
2740 *reciprocal_links = 0;
2741 }
2742
2743 // If we haven't visited this neighbor yet.
2744 if (neighbor->visited_epoch[0] != index->current_epoch[0]) {
2745 neighbor->visited_epoch[0] = index->current_epoch[0];
2746 (*connected_nodes)++;
2747 if (stack_top < stack_size) {
2748 stack[stack_top++] = neighbor;
2749 } else {
2750 // This should never happen in a valid graph.
2751 hfree(stack);
2752 return -1;
2753 }
2754 }
2755 }
2756 }
2757 }
2758
2759 hfree(stack);
2760
2761 // Now scan for unreachable nodes and print debug info.
2762 printf("\nUnreachable nodes debug information:\n");
2763 printf("=====================================\n");
2764
2765 hnswNode *current = index->head;
2766 while (current) {
2767 if (current->visited_epoch[0] != index->current_epoch[0]) {
2768 printf("\nUnreachable node found:\n");
2769 printf("- Node pointer: %p\n", (void*)current);
2770 printf("- Node ID: %llu\n", (unsigned long long)current->id);
2771 printf("- Node level: %u\n", current->level);
2772
2773 // Print info about all its links at each level.
2774 for (uint32_t level = 0; level <= current->level; level++) {
2775 printf(" Level %u links (%u):\n", level,
2776 current->layers[level].num_links);
2777 for (uint64_t i = 0; i < current->layers[level].num_links; i++) {
2778 hnswNode *neighbor = current->layers[level].links[i];
2779 // Check reciprocity for this specific link
2780 int found_backlink = 0;
2781 for (uint64_t j = 0; j < neighbor->layers[level].num_links; j++) {
2782 if (neighbor->layers[level].links[j] == current) {
2783 found_backlink = 1;
2784 break;
2785 }
2786 }
2787 printf(" - Link %llu: pointer=%p, id=%llu, visited=%s,recpr=%s\n",
2788 (unsigned long long)i, (void*)neighbor,
2789 (unsigned long long)neighbor->id,
2790 neighbor->visited_epoch[0] == index->current_epoch[0] ?
2791 "yes" : "no",
2792 found_backlink ? "yes" : "no");
2793 }
2794 }
2795 }
2796 current = current->next;
2797 }
2798
2799 printf("Total connected nodes: %llu\n", (unsigned long long)*connected_nodes);
2800 printf("All links are bi-directiona? %s\n", (*reciprocal_links)?"yes":"no");
2801 return 0;
2802}
2803
2804/* Test graph recall ability by verifying each node can be found searching
2805 * for its own vector. This helps validate that the majority of nodes are
2806 * properly connected and easily reachable in the graph structure. Every
2807 * unreachable node is reported.
2808 *
2809 * Normally only a small percentage of nodes will be not reachable when
2810 * visited. This is expected and part of the statistical properties
2811 * of HNSW. This happens especially with entries that have an ambiguous
2812 * meaning in the represented space, and are across two or multiple clusters
2813 * of items.
2814 *
2815 * The function works by:
2816 * 1. Iterating through all nodes in the linked list
2817 * 2. Using each node's vector to perform a search with specified EF
2818 * 3. Verifying the node can find itself as nearest neighbor
2819 * 4. Collecting and reporting statistics about reachability
2820 *
2821 * This is just a debugging function that reports stuff in the standard
2822 * output, part of the implementation because this kind of functions
2823 * provide some visibility on what happens inside the HNSW.
2824 */
2825void hnsw_test_graph_recall(HNSW *index, int test_ef, int verbose) {
2826 // Stats
2827 uint32_t total_nodes = 0;
2828 uint32_t unreachable_nodes = 0;
2829 uint32_t perfectly_reachable = 0; // Node finds itself as first result
2830
2831 // For storing search results
2832 hnswNode **neighbors = hmalloc(sizeof(hnswNode*) * test_ef);
2833 float *distances = hmalloc(sizeof(float) * test_ef);
2834 float *test_vector = hmalloc(sizeof(float) * index->vector_dim);
2835 if (!neighbors || !distances || !test_vector) {
2836 hfree(neighbors);
2837 hfree(distances);
2838 hfree(test_vector);
2839 return;
2840 }
2841
2842 // Get a read slot for searching (even if it's highly unlikely that
2843 // this test will be run threaded...).
2844 int slot = hnsw_acquire_read_slot(index);
2845 if (slot < 0) {
2846 hfree(neighbors);
2847 hfree(distances);
2848 return;
2849 }
2850
2851 printf("\nTesting graph recall\n");
2852 printf("====================\n");
2853
2854 // Process one node at a time using the linked list
2855 hnswNode *current = index->head;
2856 while (current) {
2857 total_nodes++;
2858
2859 // If using quantization, we need to reconstruct the normalized vector
2860 if (index->quant_type == HNSW_QUANT_Q8) {
2861 int8_t *quants = current->vector;
2862 // Reconstruct normalized vector from quantized data
2863 for (uint32_t j = 0; j < index->vector_dim; j++) {
2864 test_vector[j] = (quants[j] * current->quants_range) / 127;
2865 }
2866 } else if (index->quant_type == HNSW_QUANT_NONE) {
2867 memcpy(test_vector,current->vector,sizeof(float)*index->vector_dim);
2868 } else {
2869 assert(0 && "Quantization type not supported.");
2870 }
2871
2872 // Search using the node's own vector with high ef
2873 int found = hnsw_search(index, test_vector, test_ef, neighbors,
2874 distances, slot, 1);
2875
2876 if (found == 0) continue; // Empty HNSW?
2877
2878 // Look for the node itself in the results
2879 int found_self = 0;
2880 int self_position = -1;
2881 for (int i = 0; i < found; i++) {
2882 if (neighbors[i] == current) {
2883 found_self = 1;
2884 self_position = i;
2885 break;
2886 }
2887 }
2888
2889 if (!found_self || self_position != 0) {
2890 unreachable_nodes++;
2891 if (verbose) {
2892 if (!found_self)
2893 printf("\nNode %s cannot find itself:\n", (char*)current->value);
2894 else
2895 printf("\nNode %s is not top result:\n", (char*)current->value);
2896 printf("- Node ID: %llu\n", (unsigned long long)current->id);
2897 printf("- Node level: %u\n", current->level);
2898 printf("- Found %d neighbors but self not among them\n", found);
2899 printf("- Closest neighbor distance: %f\n", distances[0]);
2900 printf("- Neighbors: ");
2901 for (uint32_t i = 0; i < current->layers[0].num_links; i++) {
2902 printf("%s ", (char*)current->layers[0].links[i]->value);
2903 }
2904 printf("\n");
2905 printf("\nFound instead: ");
2906 for (int j = 0; j < found && j < 10; j++) {
2907 printf("%s ", (char*)neighbors[j]->value);
2908 }
2909 printf("\n");
2910 }
2911 } else {
2912 perfectly_reachable++;
2913 }
2914 current = current->next;
2915 }
2916
2917 // Release read slot
2918 hnsw_release_read_slot(index, slot);
2919
2920 // Free resources
2921 hfree(neighbors);
2922 hfree(distances);
2923 hfree(test_vector);
2924
2925 // Print final statistics
2926 printf("Total nodes tested: %u\n", total_nodes);
2927 printf("Perfectly reachable nodes: %u (%.1f%%)\n",
2928 perfectly_reachable,
2929 total_nodes ? (float)perfectly_reachable * 100 / total_nodes : 0);
2930 printf("Unreachable/suboptimal nodes: %u (%.1f%%)\n",
2931 unreachable_nodes,
2932 total_nodes ? (float)unreachable_nodes * 100 / total_nodes : 0);
2933}
2934
2935/* Return exact K-NN items by performing a linear scan of all nodes.
2936 * This function has the same signature as hnsw_search_with_filter() but
2937 * instead of using the graph structure, it scans all nodes to find the
2938 * true nearest neighbors.
2939 *
2940 * Note that neighbors and distances arrays must have space for at least 'k' items.
2941 * norm_query should be set to 1 if the query vector is already normalized.
2942 *
2943 * If the filter_callback is passed, only elements passing the specified filter
2944 * are returned. The slot parameter is ignored but kept for API consistency. */
2945int hnsw_ground_truth_with_filter
2946 (HNSW *index, const float *query_vector, uint32_t k,
2947 hnswNode **neighbors, float *distances, uint32_t slot,
2948 int query_vector_is_normalized,
2949 int (*filter_callback)(void *value, void *privdata),
2950 void *filter_privdata)
2951{
2952 /* Note that we don't really use the slot here: it's a linear scan.
2953 * Yet we want the user to acquire the slot as this will hold the
2954 * global lock in read only mode. */
2955 (void) slot;
2956
2957 /* Take our query vector into a temporary node. */
2958 hnswNode query;
2959 if (hnsw_init_tmp_node(index, &query, query_vector_is_normalized, query_vector) == 0) return -1;
2960
2961 /* Accumulate best results into a priority queue. */
2962 pqueue *results = pq_new(k);
2963 if (!results) {
2964 hnsw_free_tmp_node(&query, query_vector);
2965 return -1;
2966 }
2967
2968 /* Scan all nodes linearly. */
2969 hnswNode *current = index->head;
2970 while (current) {
2971 /* Apply filter if needed. */
2972 if (filter_callback &&
2973 !filter_callback(current->value, filter_privdata))
2974 {
2975 current = current->next;
2976 continue;
2977 }
2978
2979 /* Calculate distance to query. */
2980 float dist = hnsw_distance(index, &query, current);
2981
2982 /* Add to results to pqueue. Will be accepted only if better than
2983 * the current worse or pqueue not full. */
2984 pq_push(results, current, dist);
2985 current = current->next;
2986 }
2987
2988 /* Copy results to output arrays. */
2989 uint32_t found = MIN(k, results->count);
2990 for (uint32_t i = 0; i < found; i++) {
2991 neighbors[i] = pq_get_node(results, i);
2992 if (distances) distances[i] = pq_get_distance(results, i);
2993 }
2994
2995 /* Clean up. */
2996 pq_free(results);
2997 hnsw_free_tmp_node(&query, query_vector);
2998 return found;
2999}
diff --git a/examples/redis-unstable/modules/vector-sets/hnsw.h b/examples/redis-unstable/modules/vector-sets/hnsw.h
new file mode 100644
index 0000000..8935521
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/hnsw.h
@@ -0,0 +1,189 @@
1/*
2 * HNSW (Hierarchical Navigable Small World) Implementation
3 * Based on the paper by Yu. A. Malkov, D. A. Yashunin
4 *
5 * Copyright (c) 2009-Present, Redis Ltd.
6 * All rights reserved.
7 *
8 * Licensed under your choice of (a) the Redis Source Available License 2.0
9 * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
10 * GNU Affero General Public License v3 (AGPLv3).
11 * Originally authored by: Salvatore Sanfilippo.
12 */
13
14#ifndef HNSW_H
15#define HNSW_H
16
17#include <pthread.h>
18#include <stdatomic.h>
19
20#define HNSW_DEFAULT_M 16 /* Used when 0 is given at creation time. */
21#define HNSW_MIN_M 4 /* Probably even too low already. */
22#define HNSW_MAX_M 4096 /* Safeguard sanity limit. */
23#define HNSW_MAX_THREADS 32 /* Maximum number of concurrent threads */
24
25/* Quantization types you can enable at creation time in hnsw_new() */
26#define HNSW_QUANT_NONE 0 // No quantization.
27#define HNSW_QUANT_Q8 1 // Q8 quantization.
28#define HNSW_QUANT_BIN 2 // Binary quantization.
29
/* Layer structure for HNSW nodes. Each node will have from one to a few
 * of these depending on its level (see hnswNode.layers below). */
typedef struct {
    struct hnswNode **links;  /* Array of neighbors for this layer */
    uint32_t num_links;       /* Number of used links */
    uint32_t max_links;       /* Maximum links for this layer. We may
                               * reallocate the node in very particular
                               * conditions in order to allow linking of
                               * new inserted nodes, so this may change
                               * dynamically and be > M*2 for a small set of
                               * nodes. */
    float worst_distance;     /* Distance to the worst (i.e. farthest)
                               * neighbor currently stored in links[]. */
    uint32_t worst_idx;       /* Index of that worst neighbor in links[]. */
} hnswNodeLayer;
44
/* Node structure for HNSW graph. One allocation holds the fixed part plus
 * the flexible layers[] array. */
typedef struct hnswNode {
    uint32_t level;     /* Node's maximum level */
    uint64_t id;        /* Unique identifier, may be useful in order to
                         * have a bitmap of visited nodes to use as
                         * alternative to epoch / visited_epoch.
                         * Also used in serialization in order to retain
                         * links specifying IDs. */
    void *vector;       /* The vector, quantized or not. */
    float quants_range; /* Quantization range for this vector:
                         * min/max values will be in the range
                         * -quants_range, +quants_range */
    float l2;           /* L2 before normalization. */

    /* Last time (epoch) this node was visited. We need one per thread.
     * This avoids having a different data structure where we track
     * visited nodes, but costs memory per node. */
    uint64_t visited_epoch[HNSW_MAX_THREADS];

    void *value;        /* Associated value */
    struct hnswNode *prev, *next; /* Prev/Next node in the list starting at
                                   * HNSW->head. */

    /* Links (and links info) per each layer. Note that this is part
     * of the node allocation to be more cache friendly: reliable 3% speedup
     * on Apple silicon, and does not make anything more complex.
     * NOTE(review): presumably one entry per level 0..level — confirm
     * against the allocation site in hnsw.c. */
    hnswNodeLayer layers[];
} hnswNode;
73
74struct HNSW;
75
/* It is possible to navigate an HNSW with a cursor that guarantees
 * visiting all the elements that remain in the HNSW from the start to the
 * end of the process (but not the new ones, so that the process will
 * eventually finish). Check hnsw_cursor_init(), hnsw_cursor_next() and
 * hnsw_cursor_free(). */
typedef struct hnswCursor {
    struct HNSW *index;      // Reference to the index of this cursor.
    hnswNode *current;       // Element to report when hnsw_cursor_next() is called.
    struct hnswCursor *next; // Next cursor active: cursors form a singly
                             // linked list rooted at HNSW->cursors.
} hnswCursor;
86
/* Main HNSW index structure. */
typedef struct HNSW {
    hnswNode *enter_point;  /* Entry point for the graph */
    uint32_t M;             /* M as in the paper: layer 0 has M*2 max
                               neighbors (M populated at insertion time)
                               while all the other layers have M neighbors. */
    uint32_t max_level;     /* Current maximum level in the graph */
    uint32_t vector_dim;    /* Dimensionality of stored vectors */
    uint64_t node_count;    /* Total number of nodes */
    _Atomic uint64_t last_id; /* Last node ID used */
    uint64_t current_epoch[HNSW_MAX_THREADS]; /* Current epoch for visit tracking */
    hnswNode *head;         /* Linked list of nodes. Last first */

    /* We have two locks here:
     * 1. A global_lock that is used to perform write operations blocking all
     * the readers.
     * 2. One mutex per epoch slot, in order for read operations to acquire
     * a lock on a specific slot to use epochs tracking of visited nodes. */
    pthread_rwlock_t global_lock;                 /* Global read-write lock */
    pthread_mutex_t slot_locks[HNSW_MAX_THREADS]; /* Per-slot locks */

    _Atomic uint32_t next_slot; /* Next thread slot to try */
    _Atomic uint64_t version;   /* Version for optimistic concurrency, this is
                                 * incremented on deletions and entry point
                                 * updates. */
    uint32_t quant_type;        /* Quantization used. HNSW_QUANT_... */
    hnswCursor *cursors;        /* Head of the list of active cursors. */
} HNSW;
115
/* Serialized node. This structure is used as return value of
 * hnsw_serialize_node(); release it with hnsw_free_serialized_node(). */
typedef struct hnswSerNode {
    void *vector;          /* Vector bytes as stored in the node. */
    uint32_t vector_size;  /* Size of 'vector' in bytes. */
    uint64_t *params;      /* Opaque parameter array; consumed back by
                            * hnsw_insert_serialized(). */
    uint32_t params_count; /* Number of entries in 'params'. */
} hnswSerNode;
124
125/* Insert preparation context */
126typedef struct InsertContext InsertContext;
127
128/* Core HNSW functions */
129HNSW *hnsw_new(uint32_t vector_dim, uint32_t quant_type, uint32_t m);
130void hnsw_free(HNSW *index,void(*free_value)(void*value));
131void hnsw_node_free(hnswNode *node);
132void hnsw_print_stats(HNSW *index);
133hnswNode *hnsw_insert(HNSW *index, const float *vector, const int8_t *qvector,
134 float qrange, uint64_t id, void *value, int ef);
135int hnsw_search(HNSW *index, const float *query, uint32_t k,
136 hnswNode **neighbors, float *distances, uint32_t slot,
137 int query_vector_is_normalized);
138int hnsw_search_with_filter
139 (HNSW *index, const float *query_vector, uint32_t k,
140 hnswNode **neighbors, float *distances, uint32_t slot,
141 int query_vector_is_normalized,
142 int (*filter_callback)(void *value, void *privdata),
143 void *filter_privdata, uint32_t max_candidates);
144void hnsw_get_node_vector(HNSW *index, hnswNode *node, float *vec);
145int hnsw_delete_node(HNSW *index, hnswNode *node, void(*free_value)(void*value));
146hnswNode *hnsw_random_node(HNSW *index, int slot);
147
148/* Thread safety functions. */
149int hnsw_acquire_read_slot(HNSW *index);
150void hnsw_release_read_slot(HNSW *index, int slot);
151
152/* Optimistic insertion API. */
153InsertContext *hnsw_prepare_insert(HNSW *index, const float *vector, const int8_t *qvector, float qrange, uint64_t id, int ef);
154hnswNode *hnsw_try_commit_insert(HNSW *index, InsertContext *ctx, void *value);
155void hnsw_free_insert_context(InsertContext *ctx);
156
157/* Serialization. */
158hnswSerNode *hnsw_serialize_node(HNSW *index, hnswNode *node);
159void hnsw_free_serialized_node(hnswSerNode *sn);
160hnswNode *hnsw_insert_serialized(HNSW *index, void *vector, uint64_t *params, uint32_t params_len, void *value);
161int hnsw_deserialize_index(HNSW *index, uint64_t salt0, uint64_t salt1);
162
163// Helper function in case the user wants to directly copy
164// the vector bytes.
165uint32_t hnsw_quants_bytes(HNSW *index);
166
167/* Cursors. */
168hnswCursor *hnsw_cursor_init(HNSW *index);
169void hnsw_cursor_free(hnswCursor *cursor);
170hnswNode *hnsw_cursor_next(hnswCursor *cursor);
171int hnsw_cursor_acquire_lock(hnswCursor *cursor);
172void hnsw_cursor_release_lock(hnswCursor *cursor);
173
174/* Allocator selection. */
175void hnsw_set_allocator(void (*free_ptr)(void*), void *(*malloc_ptr)(size_t),
176 void *(*realloc_ptr)(void*, size_t));
177
178/* Testing. */
179int hnsw_validate_graph(HNSW *index, uint64_t *connected_nodes, int *reciprocal_links);
180void hnsw_test_graph_recall(HNSW *index, int test_ef, int verbose);
181float hnsw_distance(HNSW *index, hnswNode *a, hnswNode *b);
182int hnsw_ground_truth_with_filter
183 (HNSW *index, const float *query_vector, uint32_t k,
184 hnswNode **neighbors, float *distances, uint32_t slot,
185 int query_vector_is_normalized,
186 int (*filter_callback)(void *value, void *privdata),
187 void *filter_privdata);
188
189#endif /* HNSW_H */
diff --git a/examples/redis-unstable/modules/vector-sets/mixer.h b/examples/redis-unstable/modules/vector-sets/mixer.h
new file mode 100644
index 0000000..d75e193
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/mixer.h
@@ -0,0 +1,106 @@
/* Mixing function used by the Redis vector sets module; the HNSW data
 * structure itself is implemented in hnsw.c.
3 *
4 * Copyright (c) 2009-Present, Redis Ltd.
5 * All rights reserved.
6 *
7 * Licensed under your choice of (a) the Redis Source Available License 2.0
8 * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
9 * GNU Affero General Public License v3 (AGPLv3).
10 * Originally authored by: Salvatore Sanfilippo.
11 *
12 * =============================================================================
13 *
14 * Mixing function for HNSW link integrity verification
15 * Designed to resist collision attacks when salts are unknown.
16 */
17
18#include <stdint.h>
19#include <string.h>
20
/* Rotate the 64-bit value 'x' left by 'r' bits (0 < r < 64; r == 0 or
 * r == 64 would shift by 64, which is undefined behavior in C). */
static inline uint64_t ROTL64(uint64_t x, int r) {
    return (x << r) | (x >> (64 - r));
}

/* Large odd multiplicative constants used by the mixing rounds below.
 * More rounds and stronger constants than a minimal mixer would use. */
#define MIX_PRIME_1 0xFF51AFD7ED558CCDULL
#define MIX_PRIME_2 0xC4CEB9FE1A85EC53ULL
#define MIX_PRIME_3 0x9E3779B97F4A7C15ULL
#define MIX_PRIME_4 0xBF58476D1CE4E5B9ULL
#define MIX_PRIME_5 0x94D049BB133111EBULL
#define MIX_PRIME_6 0x2B7E151628AED2A7ULL

/* Mix (id1_in, id2_in, level) together with the two secret salts into two
 * 64-bit outputs written to *out_h1 / *out_h2. The function is symmetric
 * in the two IDs, so the hash of the link A -> B equals that of B -> A.
 *
 * Mixer design goals:
 * 1. Thorough mixing of the level parameter.
 * 2. Enough rounds of mixing.
 * 3. Cross-influence between h1 and h2.
 * 4. Domain separation to prevent related-key attacks.
 *
 * Declared 'static inline': this function is *defined* in a header, and
 * without internal linkage every translation unit including mixer.h would
 * emit its own external definition, producing duplicate-symbol errors at
 * link time (one-definition rule). */
static inline void secure_pair_mixer_128(uint64_t salt0, uint64_t salt1,
                                         uint64_t id1_in, uint64_t id2_in,
                                         uint64_t level,
                                         uint64_t* out_h1, uint64_t* out_h2) {
    // Order independence (A -> B links should hash as B -> A links).
    uint64_t id_a = (id1_in < id2_in) ? id1_in : id2_in;
    uint64_t id_b = (id1_in < id2_in) ? id2_in : id1_in;

    // Domain separation: mix salts with a constant to prevent
    // related-key attacks.
    uint64_t h1 = salt0 ^ 0xDEADBEEFDEADBEEFULL;
    uint64_t h2 = salt1 ^ 0xCAFEBABECAFEBABEULL;

    // First, thoroughly mix the level into both accumulators.
    // This prevents predictable level values from being a weakness.
    uint64_t level_mix = level;
    level_mix *= MIX_PRIME_5;
    level_mix ^= level_mix >> 32;
    level_mix *= MIX_PRIME_6;

    h1 ^= level_mix;
    h2 ^= ROTL64(level_mix, 31);

    // Mix in id_a with strong diffusion.
    h1 ^= id_a;
    h1 *= MIX_PRIME_1;
    h1 = ROTL64(h1, 23);
    h1 *= MIX_PRIME_2;

    // Mix in id_b.
    h2 ^= id_b;
    h2 *= MIX_PRIME_3;
    h2 = ROTL64(h2, 29);
    h2 *= MIX_PRIME_4;

    // Three rounds of cross-mixing for better security.
    for (int i = 0; i < 3; i++) {
        // Cross-influence between the two accumulators.
        uint64_t tmp = h1;
        h1 += h2;
        h2 += tmp;

        // Mix h1.
        h1 ^= ROTL64(h1, 31);
        h1 *= MIX_PRIME_1;
        h1 ^= salt0;

        // Mix h2.
        h2 ^= ROTL64(h2, 37);
        h2 *= MIX_PRIME_2;
        h2 ^= salt1;
    }

    // Finalization: xorshift-multiply avalanche rounds on each half.
    h1 ^= h1 >> 33;
    h1 *= MIX_PRIME_3;
    h1 ^= h1 >> 29;
    h1 *= MIX_PRIME_4;
    h1 ^= h1 >> 32;

    h2 ^= h2 >> 33;
    h2 *= MIX_PRIME_5;
    h2 ^= h2 >> 29;
    h2 *= MIX_PRIME_6;
    h2 ^= h2 >> 32;

    *out_h1 = h1;
    *out_h2 = h2;
}
diff --git a/examples/redis-unstable/modules/vector-sets/test.py b/examples/redis-unstable/modules/vector-sets/test.py
new file mode 100755
index 0000000..8d56d58
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/test.py
@@ -0,0 +1,294 @@
1#!/usr/bin/env python3
2#
3# Vector set tests.
4# A Redis instance should be running in the default port.
5#
6# Copyright (c) 2009-Present, Redis Ltd.
7# All rights reserved.
8#
9# Licensed under your choice of (a) the Redis Source Available License 2.0
10# (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
11# GNU Affero General Public License v3 (AGPLv3).
12#
13
14import redis
15import random
16import struct
17import math
18import time
19import sys
20import os
21import importlib
22import inspect
23import argparse
24from typing import List, Tuple, Optional
25from dataclasses import dataclass
26
def colored(text: str, color: str) -> str:
    """Wrap 'text' in the ANSI escape sequence for 'color'.

    Unknown color names produce no color prefix; the reset sequence is
    always appended so the terminal state is restored either way.
    """
    palette = {
        'red': '\033[91m',
        'green': '\033[92m',
        'yellow': '\033[93m',
        'blue': '\033[94m',
        'magenta': '\033[95m',
        'cyan': '\033[96m',
    }
    prefix = palette.get(color, '')
    return prefix + text + '\033[0m'
38
@dataclass
class VectorData:
    # Parallel lists: vectors[i] was inserted under the element name names[i].
    vectors: List[List[float]]
    names: List[str]

    def find_k_nearest(self, query_vector: List[float], k: int) -> List[Tuple[str, float]]:
        """Brute-force k-NN using the same scoring as Redis VSIM WITHSCORES.

        Cosine distance is remapped to the Redis similarity score via
        score = 1 - (1 - cos) / 2, and the top-k (name, score) pairs are
        returned best-first. A zero-norm query yields an empty list, and
        zero-norm stored vectors are skipped.
        """
        query_norm = math.sqrt(sum(c * c for c in query_vector))
        if query_norm == 0:
            return []

        scored = []
        for idx, vec in enumerate(self.vectors):
            vec_norm = math.sqrt(sum(c * c for c in vec))
            if vec_norm == 0:
                continue

            dot = sum(a * b for a, b in zip(query_vector, vec))
            cosine = dot / (query_norm * vec_norm)
            # Same remapping Redis applies: distance in [0,2] -> score in [0,1].
            score = 1.0 - ((1.0 - cosine) / 2.0)
            scored.append((self.names[idx], score))

        scored.sort(key=lambda pair: pair[1], reverse=True)
        return scored[:k]
64
def generate_random_vector(dim: int) -> List[float]:
    """Return a unit-length vector of 'dim' components drawn from N(0, 1)."""
    components = [random.gauss(0, 1) for _ in range(dim)]
    length = math.sqrt(sum(c * c for c in components))
    return [c / length for c in components]
70
def fill_redis_with_vectors(r: redis.Redis, key: str, count: int, dim: int,
                            with_reduce: Optional[int] = None) -> VectorData:
    """Populate 'key' with 'count' random unit vectors via VADD.

    The key is deleted first. When 'with_reduce' is given, the REDUCE
    projection option is passed to VADD. Returns a VectorData mirror of
    everything inserted, for client-side verification.
    """
    r.delete(key)

    all_vectors = []
    all_names = []
    for idx in range(count):
        vector = generate_random_vector(dim)
        item_name = f"{key}:item:{idx}"
        all_vectors.append(vector)
        all_names.append(item_name)

        payload = struct.pack(f'{dim}f', *vector)
        command = [key]
        if with_reduce:
            command += ['REDUCE', with_reduce]
        command += ['FP32', payload, item_name]
        r.execute_command('VADD', *command)

    return VectorData(vectors=all_vectors, names=all_names)
92
class TestCase:
    """Base class for vector set tests: manages Redis connections, the
    per-test key, replication setup and error capture."""

    def __init__(self, primary_port=6379, replica_port=6380):
        self.error_msg = None
        self.error_details = None
        self.test_key = f"test:{self.__class__.__name__.lower()}"
        # Primary instance: both a RESP2 and a RESP3 connection, DB 9.
        self.redis = redis.Redis(port=primary_port, db=9)
        self.redis3 = redis.Redis(port=primary_port, protocol=3, db=9)
        # Replica instance.
        self.replica = redis.Redis(port=replica_port, db=9)
        # Set to True once setup_replication() has succeeded.
        self.replication_setup = False
        self.primary_port = primary_port
        self.replica_port = replica_port

    def setup(self):
        """Per-test setup: start from a clean key."""
        self.redis.delete(self.test_key)

    def teardown(self):
        """Per-test cleanup: remove the key used by the test."""
        self.redis.delete(self.test_key)

    def setup_replication(self) -> bool:
        """Point the replica at the primary and wait for the link to come up.

        Returns True when replication is established; after 50 failed
        probes (~25s) sets self.error_msg and returns False.
        """
        self.replica.execute_command('REPLICAOF', '127.0.0.1', self.primary_port)

        for _ in range(50):
            info = self.replica.info('replication')
            link_up = (info.get('role') == 'slave'
                       and info.get('master_host') == '127.0.0.1'
                       and info.get('master_port') == self.primary_port
                       and info.get('master_link_status') == 'up')
            if link_up:
                self.replication_setup = True
                return True
            # Progress indicator while we keep polling.
            print(colored(".", 'cyan'), end="", flush=True)
            time.sleep(0.5)

        self.error_msg = "Failed to establish replication between primary and replica"
        return False

    def test(self):
        raise NotImplementedError("Subclasses must implement test method")

    def run(self):
        """Run setup/test/teardown, capturing failures instead of raising."""
        try:
            self.setup()
            self.test()
            return True
        except AssertionError as e:
            import traceback
            self.error_msg = str(e)
            self.error_details = traceback.format_exc()
            return False
        except Exception as e:
            import traceback
            self.error_msg = f"Unexpected error: {str(e)}"
            self.error_details = traceback.format_exc()
            return False
        finally:
            self.teardown()

    def getname(self):
        """Human-readable test name; subclasses should override this."""
        return self.__class__.__name__

    def estimated_runtime(self):
        """Estimated runtime in seconds (used to order tests); default 100ms."""
        return 0.1
174
def find_test_classes(primary_port, replica_port):
    """Discover and instantiate test classes from the ./tests directory.

    Every class in tests/*.py exposing a 'test' attribute (excluding the
    TestCase base itself) is instantiated with the given ports. Returns
    the list of instances; an empty list when ./tests does not exist.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    tests_dir = os.path.join(here, 'tests')

    if not os.path.exists(tests_dir):
        return []

    instances = []
    for filename in os.listdir(tests_dir):
        if not filename.endswith('.py'):
            continue
        module_name = f"tests.{filename[:-3]}"
        try:
            module = importlib.import_module(module_name)
            for _, candidate in inspect.getmembers(module):
                if (inspect.isclass(candidate)
                        and candidate.__name__ != 'TestCase'
                        and hasattr(candidate, 'test')):
                    # Instantiate with the ports the runner was given.
                    instances.append(candidate(primary_port, replica_port))
        except Exception as e:
            print(f"Error loading {filename}: {e}")

    return instances
197
def check_redis_empty(r, instance_name):
    """Abort the whole run unless DB 9 of the given instance is empty.

    Exits the process (status 1) when the DB holds keys — protecting any
    production data — or when the instance cannot be reached at all.
    """
    try:
        keys_in_db = r.dbsize()
    except redis.exceptions.ConnectionError:
        print(colored(f"ERROR: Cannot connect to {instance_name} Redis instance.", "red"))
        sys.exit(1)
    if keys_in_db > 0:
        print(colored(f"ERROR: {instance_name} Redis instance DB 9 is not empty (dbsize: {keys_in_db}).", "red"))
        print(colored("Make sure you're not using a production instance and that all data is safe to delete.", "red"))
        sys.exit(1)
209
def check_replica_running(replica_port):
    """Return True when a Redis instance answers PING on 'replica_port'.

    On connection failure, prints a warning (replication tests will be
    skipped) and returns False.
    """
    probe = redis.Redis(port=replica_port)
    try:
        probe.ping()
    except redis.exceptions.ConnectionError:
        print(colored(f"WARNING: Replica Redis instance (port {replica_port}) is not running.", "yellow"))
        print(colored("Replication tests will be skipped. Make sure to start the replica instance.", "yellow"))
        return False
    return True
220
def run_tests():
    """Entry point: discover tests, run them in order of estimated runtime,
    and print a summary.

    Exits with status 1 when any test fails; replication tests are
    skipped (not failed) when no replica instance is reachable.
    """
    parser = argparse.ArgumentParser(description='Run Redis vector tests.')
    parser.add_argument('--primary-port', type=int, default=6379,
                        help='Primary Redis instance port (default: 6379)')
    parser.add_argument('--replica-port', type=int, default=6380,
                        help='Replica Redis instance port (default: 6380)')
    args = parser.parse_args()

    banner = "================================================"
    print(banner)
    print("Make sure to have Redis running on localhost")
    print(f"Primary port: {args.primary_port}")
    print(f"Replica port: {args.replica_port}")
    print("with --enable-debug-command yes")
    print(banner + "\n")

    # Refuse to run against non-empty databases.
    primary = redis.Redis(port=args.primary_port, db=9)
    replica = redis.Redis(port=args.replica_port, db=9)
    check_redis_empty(primary, "Primary")

    replica_running = check_replica_running(args.replica_port)
    if replica_running:
        check_redis_empty(replica, "Replica")

    tests = find_test_classes(args.primary_port, args.replica_port)
    if not tests:
        print("No tests found!")
        return

    # Quick tests first.
    tests.sort(key=lambda t: t.estimated_runtime())

    passed = 0
    skipped = 0
    total = len(tests)

    for test in tests:
        print(f"{test.getname()}: ", end="")
        sys.stdout.flush()

        # Replication tests cannot run without a replica instance.
        if not replica_running and "replication" in test.getname().lower():
            print(colored("SKIPPING", "yellow"))
            skipped += 1
            continue

        started = time.time()
        ok = test.run()
        elapsed = time.time() - started

        if ok:
            print(colored("OK", "green"), f"({elapsed:.2f}s)")
            passed += 1
        else:
            print(colored("ERR", "red"), f"({elapsed:.2f}s)")
            print(f"Error: {test.error_msg}")
            if test.error_details:
                print("\nTraceback:")
                print(test.error_details)

    print("\n" + "=" * 50)
    print(f"\nTest Summary: {passed}/{total} tests passed")

    if passed == total:
        print(colored("ALL TESTS PASSED!", "green"))
    else:
        if total - skipped - passed > 0:
            print(colored(f"{total-skipped-passed} TESTS FAILED!", "red"))
            sys.exit(1)
        if skipped > 0:
            print(colored(f"{skipped} TESTS SKIPPED!", "yellow"))
292
293if __name__ == "__main__":
294 run_tests()
diff --git a/examples/redis-unstable/modules/vector-sets/tests/basic_commands.py b/examples/redis-unstable/modules/vector-sets/tests/basic_commands.py
new file mode 100644
index 0000000..8481a36
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/basic_commands.py
@@ -0,0 +1,21 @@
1from test import TestCase, generate_random_vector
2import struct
3
class BasicCommands(TestCase):
    """Smoke test for the most basic vector set commands."""

    def getname(self):
        return "VADD, VDIM, VCARD basic usage"

    def test(self):
        # VADD of the first element must report an insertion (reply 1).
        vector = generate_random_vector(4)
        packed = struct.pack('4f', *vector)
        added = self.redis.execute_command('VADD', self.test_key, 'FP32',
                                           packed, f'{self.test_key}:item:1')
        assert added == 1, "VADD should return 1 for first item"

        # VDIM reports the dimension fixed by the first insertion.
        dim = self.redis.execute_command('VDIM', self.test_key)
        assert dim == 4, f"VDIM should return 4, got {dim}"

        # VCARD reports the number of stored elements.
        card = self.redis.execute_command('VCARD', self.test_key)
        assert card == 1, f"VCARD should return 1, got {card}"
diff --git a/examples/redis-unstable/modules/vector-sets/tests/basic_similarity.py b/examples/redis-unstable/modules/vector-sets/tests/basic_similarity.py
new file mode 100644
index 0000000..11c3c9b
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/basic_similarity.py
@@ -0,0 +1,35 @@
1from test import TestCase
2
class BasicSimilarity(TestCase):
    """Insert two near-identical vectors plus an outlier, then check the
    scores reported by VSIM WITHSCORES make sense."""

    def getname(self):
        return "VSIM reported distance makes sense with 4D vectors"

    def test(self):
        reference = [1, 0, 0, 0]
        near = [0.99, 0.01, 0, 0]
        far = [0.1, 1, -1, 0.5]

        # Insert all three using the VALUES input format.
        for idx, vec in enumerate((reference, near, far), start=1):
            self.redis.execute_command('VADD', self.test_key, 'VALUES', 4,
                                       *[str(x) for x in vec],
                                       f'{self.test_key}:item:{idx}')

        # Query with the reference vector itself.
        reply = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4,
                                           *[str(x) for x in reference],
                                           'WITHSCORES')

        # Reply is a flat [name, score, name, score, ...] array.
        scores = {}
        for i in range(0, len(reply), 2):
            scores[reply[i].decode()] = float(reply[i+1])

        assert scores[f'{self.test_key}:item:1'] > 0.99, "Self-similarity should be very high"
        assert scores[f'{self.test_key}:item:2'] > 0.99, "Similar vector should have high similarity"
        assert scores[f'{self.test_key}:item:3'] < 0.8, "Not very similar vector should have low similarity"
diff --git a/examples/redis-unstable/modules/vector-sets/tests/concurrent_vadd_cas_del_vsim.py b/examples/redis-unstable/modules/vector-sets/tests/concurrent_vadd_cas_del_vsim.py
new file mode 100644
index 0000000..f4b3a12
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/concurrent_vadd_cas_del_vsim.py
@@ -0,0 +1,156 @@
1from test import TestCase, generate_random_vector
2import threading
3import time
4import struct
5
class ThreadingStressTest(TestCase):
    """Hammer a single key from many threads at once (VADD with CAS, VSIM,
    periodic DEL of the whole key) and verify the server survives.

    The pass criterion is deliberately weak: Redis must still answer PING
    at the end. Command-level errors during the run are collected and
    printed for diagnosis but do not fail the test, since races between
    DEL and the other commands are expected."""

    def getname(self):
        return "Concurrent VADD/DEL/VSIM operations stress test"

    def estimated_runtime(self):
        return 10 # Test runs for 10 seconds

    def test(self):
        # Constants - easy to modify if needed
        NUM_VADD_THREADS = 10
        NUM_VSIM_THREADS = 1
        NUM_DEL_THREADS = 1
        TEST_DURATION = 10 # seconds
        VECTOR_DIM = 100
        DEL_INTERVAL = 1 # seconds

        # Shared flags and state. error_list is guarded by error_lock since
        # all worker threads append to it.
        stop_event = threading.Event()
        error_list = []
        error_lock = threading.Lock()

        def log_error(thread_name, error):
            # Thread-safe append to the shared error list.
            with error_lock:
                error_list.append(f"{thread_name}: {error}")

        def vadd_worker(thread_id):
            """Thread function to perform VADD operations"""
            thread_name = f"VADD-{thread_id}"
            try:
                vector_count = 0
                while not stop_event.is_set():
                    try:
                        # Generate random vector
                        vec = generate_random_vector(VECTOR_DIM)
                        vec_bytes = struct.pack(f'{VECTOR_DIM}f', *vec)

                        # Add vector with CAS option
                        self.redis.execute_command(
                            'VADD',
                            self.test_key,
                            'FP32',
                            vec_bytes,
                            f'{self.test_key}:item:{thread_id}:{vector_count}',
                            'CAS'
                        )

                        vector_count += 1

                        # Small sleep to reduce CPU pressure
                        if vector_count % 10 == 0:
                            time.sleep(0.001)
                    except Exception as e:
                        log_error(thread_name, f"Error: {str(e)}")
                        time.sleep(0.1) # Slight backoff on error
            except Exception as e:
                log_error(thread_name, f"Thread error: {str(e)}")

        def del_worker():
            """Thread function that deletes the key periodically"""
            thread_name = "DEL"
            try:
                del_count = 0
                while not stop_event.is_set():
                    try:
                        # Sleep first, then delete
                        time.sleep(DEL_INTERVAL)
                        if stop_event.is_set():
                            break

                        self.redis.delete(self.test_key)
                        del_count += 1
                    except Exception as e:
                        log_error(thread_name, f"Error: {str(e)}")
            except Exception as e:
                log_error(thread_name, f"Thread error: {str(e)}")

        def vsim_worker(thread_id):
            """Thread function to perform VSIM operations"""
            thread_name = f"VSIM-{thread_id}"
            try:
                search_count = 0
                while not stop_event.is_set():
                    try:
                        # Generate query vector
                        query_vec = generate_random_vector(VECTOR_DIM)
                        query_str = [str(x) for x in query_vec]

                        # Perform similarity search
                        args = ['VSIM', self.test_key, 'VALUES', VECTOR_DIM]
                        args.extend(query_str)
                        args.extend(['COUNT', 10])
                        self.redis.execute_command(*args)

                        search_count += 1

                        # Small sleep to reduce CPU pressure
                        if search_count % 10 == 0:
                            time.sleep(0.005)
                    except Exception as e:
                        # Don't log empty array errors, as they're expected when key doesn't exist
                        if "empty array" not in str(e).lower():
                            log_error(thread_name, f"Error: {str(e)}")
                        time.sleep(0.1) # Slight backoff on error
            except Exception as e:
                log_error(thread_name, f"Thread error: {str(e)}")

        # Start all threads
        threads = []

        # VADD threads
        for i in range(NUM_VADD_THREADS):
            thread = threading.Thread(target=vadd_worker, args=(i,))
            thread.start()
            threads.append(thread)

        # DEL threads
        for _ in range(NUM_DEL_THREADS):
            thread = threading.Thread(target=del_worker)
            thread.start()
            threads.append(thread)

        # VSIM threads
        for i in range(NUM_VSIM_THREADS):
            thread = threading.Thread(target=vsim_worker, args=(i,))
            thread.start()
            threads.append(thread)

        # Let the test run for the specified duration
        time.sleep(TEST_DURATION)

        # Signal all threads to stop
        stop_event.set()

        # Wait for threads to finish (bounded, so a stuck thread cannot
        # hang the whole test run)
        for thread in threads:
            thread.join(timeout=2.0)

        # Check if Redis is still responsive
        try:
            ping_result = self.redis.ping()
            assert ping_result, "Redis did not respond to PING after stress test"
        except Exception as e:
            assert False, f"Redis connection failed after stress test: {str(e)}"

        # Report any errors for diagnosis, but don't fail the test unless PING fails
        if error_list:
            error_count = len(error_list)
            print(f"\nEncountered {error_count} errors during stress test.")
            print("First 5 errors:")
            for error in error_list[:5]:
                print(f"- {error}")
diff --git a/examples/redis-unstable/modules/vector-sets/tests/concurrent_vsim_and_del.py b/examples/redis-unstable/modules/vector-sets/tests/concurrent_vsim_and_del.py
new file mode 100644
index 0000000..9bbf011
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/concurrent_vsim_and_del.py
@@ -0,0 +1,48 @@
1from test import TestCase, fill_redis_with_vectors, generate_random_vector
2import threading, time
3
class ConcurrentVSIMAndDEL(TestCase):
    """Run VSIM queries from several threads while the key is deleted
    underneath them, and verify every thread observes the deletion."""

    def getname(self):
        return "Concurrent VSIM and DEL operations"

    def estimated_runtime(self):
        return 2

    def test(self):
        # Fill the key with 5000 random vectors.
        dim = 128
        count = 5000
        fill_redis_with_vectors(self.redis, self.test_key, count, dim)

        # One True appended per thread that saw the empty reply.
        detections = []

        def query_until_deleted():
            # Keep querying; an empty reply means the key is gone.
            while True:
                probe = generate_random_vector(dim)
                reply = self.redis.execute_command(
                    'VSIM', self.test_key, 'VALUES', dim,
                    *[str(x) for x in probe], 'COUNT', 10)
                if not reply:
                    detections.append(True)
                    break

        workers = [threading.Thread(target=query_until_deleted) for _ in range(4)]
        for w in workers:
            w.start()

        # Give the workers a head start, then delete the key under them.
        time.sleep(1)
        self.redis.delete(self.test_key)

        # Workers exit once they detect the deletion.
        for w in workers:
            w.join()

        assert len(detections) == len(workers), "Not all threads detected the key deletion"
        assert all(detections), "Some threads did not detect an empty array or error after DEL"
diff --git a/examples/redis-unstable/modules/vector-sets/tests/debug_digest.py b/examples/redis-unstable/modules/vector-sets/tests/debug_digest.py
new file mode 100644
index 0000000..78f06d8
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/debug_digest.py
@@ -0,0 +1,39 @@
1from test import TestCase, generate_random_vector
2import struct
3
class DebugDigestTest(TestCase):
    """Regression test: DEBUG DIGEST-VALUE must change when an element's
    attribute is set, modified, or removed."""

    def getname(self):
        return "[regression] DEBUG DIGEST-VALUE with attributes"

    def test(self):
        # Two random vectors: one plain, one carrying a JSON attribute.
        vec1 = generate_random_vector(4)
        vec2 = generate_random_vector(4)
        vec_bytes1 = struct.pack('4f', *vec1)
        vec_bytes2 = struct.pack('4f', *vec2)

        self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes1, f'{self.test_key}:item:1')
        self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes2, f'{self.test_key}:item:2', 'SETATTR', '{"color":"red"}')

        try:
            digest1 = self.redis.execute_command('DEBUG', 'DIGEST-VALUE', self.test_key)
            assert digest1 is not None, "DEBUG DIGEST-VALUE should return a value"

            # Changing the attribute must change the digest.
            self.redis.execute_command('VSETATTR', self.test_key, f'{self.test_key}:item:2', '{"color":"blue"}')
            digest2 = self.redis.execute_command('DEBUG', 'DIGEST-VALUE', self.test_key)
            assert digest2 is not None, "DEBUG DIGEST-VALUE should return a value after attribute change"
            assert digest1 != digest2, "Digest should change when an attribute is modified"

            # Removing the attribute must change the digest again.
            self.redis.execute_command('VSETATTR', self.test_key, f'{self.test_key}:item:2', '')
            digest3 = self.redis.execute_command('DEBUG', 'DIGEST-VALUE', self.test_key)
            assert digest3 is not None, "DEBUG DIGEST-VALUE should return a value after attribute removal"
            assert digest2 != digest3, "Digest should change when an attribute is removed"

        except AssertionError:
            # Bug fix: previously the generic handler below also caught our
            # own assertion failures and rewrapped them as a command error,
            # destroying the real failure message. Let them propagate as-is.
            raise
        except Exception as e:
            raise AssertionError(f"DEBUG DIGEST-VALUE command failed: {str(e)}")
diff --git a/examples/redis-unstable/modules/vector-sets/tests/deletion.py b/examples/redis-unstable/modules/vector-sets/tests/deletion.py
new file mode 100644
index 0000000..cb91959
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/deletion.py
@@ -0,0 +1,173 @@
1from test import TestCase, fill_redis_with_vectors, generate_random_vector
2import random
3
4"""
5A note about this test:
6It was experimentally tried to modify hnsw.c in order to
7avoid calling hnsw_reconnect_nodes(). In this case, the test
8fails very often with EF set to 250, while it hardly
9fails at all with the same parameters if hnsw_reconnect_nodes()
10is called.
11
12Note that for the nature of the test (it is very strict) it can
13still fail from time to time, without this signaling any
14actual bug.
15"""
16
class VREM(TestCase):
    # Deletes 100 of the top-200 VSIM results and verifies that the
    # surviving top-10 items are still retrieved with (near) identical
    # scores, i.e. that VREM leaves the HNSW graph in a usable state.
    def getname(self):
        return "Deletion and graph state after deletion"

    def estimated_runtime(self):
        # Inserting 5000 vectors and probing the graph takes a while.
        return 2.0

    def format_neighbors_with_scores(self, links_result, old_links=None, items_to_remove=None):
        """Format neighbors with their similarity scores and status indicators.

        links_result is the reply of VLINKS ... WITHSCORES: one flat
        [name, score, name, score, ...] list per graph level, highest
        level first. Neighbors present in items_to_remove are tagged
        "[lost]"; neighbors absent from old_links are tagged "[gained]".
        Returns a multi-line human-readable string for diagnostics.
        """
        if not links_result:
            return "No neighbors"

        output = []
        for level, neighbors in enumerate(links_result):
            # VLINKS reports levels top-down; recover the numeric level.
            level_num = len(links_result) - level - 1
            output.append(f"Level {level_num}:")

            # Get neighbors and scores
            neighbors_with_scores = []
            for i in range(0, len(neighbors), 2):
                neighbor = neighbors[i].decode() if isinstance(neighbors[i], bytes) else neighbors[i]
                score = float(neighbors[i+1]) if i+1 < len(neighbors) else None
                status = ""

                # For old links, mark deleted ones
                if items_to_remove and neighbor in items_to_remove:
                    status = " [lost]"
                # For new links, mark newly added ones
                elif old_links is not None:
                    # Check if this neighbor was in the old links at this level
                    was_present = False
                    if old_links and level < len(old_links):
                        old_neighbors = [n.decode() if isinstance(n, bytes) else n
                                       for n in old_links[level]]
                        was_present = neighbor in old_neighbors
                    if not was_present:
                        status = " [gained]"

                if score is not None:
                    neighbors_with_scores.append(f"{len(neighbors_with_scores)+1}. {neighbor} ({score:.6f}){status}")
                else:
                    neighbors_with_scores.append(f"{len(neighbors_with_scores)+1}. {neighbor}{status}")

            output.extend([" " + n for n in neighbors_with_scores])
        return "\n".join(output)

    def test(self):
        """Fill, query, delete half of the top results, re-query, and
        require the surviving top-10 to keep their scores within 0.01.
        On failure, dumps before/after neighbor links for diagnosis."""
        # 1. Fill server with random elements
        dim = 128
        count = 5000
        # NOTE(review): presumably returns the inserted vectors keyed by
        # name (unused here) — confirm in test.py.
        data = fill_redis_with_vectors(self.redis, self.test_key, count, dim)

        # 2. Do VSIM to get 200 items
        query_vec = generate_random_vector(dim)
        results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim,
                                           *[str(x) for x in query_vec],
                                           'COUNT', 200, 'WITHSCORES')

        # Convert results to list of (item, score) pairs, sorted by score
        items = []
        for i in range(0, len(results), 2):
            item = results[i].decode()
            score = float(results[i+1])
            items.append((item, score))
        items.sort(key=lambda x: x[1], reverse=True) # Sort by similarity

        # Store the graph structure for all items before deletion
        neighbors_before = {}
        for item, _ in items:
            links = self.redis.execute_command('VLINKS', self.test_key, item, 'WITHSCORES')
            if links: # Some items might not have links
                neighbors_before[item] = links

        # 3. Remove 100 random items
        items_to_remove = set(item for item, _ in random.sample(items, 100))
        # Keep track of top 10 non-removed items
        top_remaining = []
        for item, score in items:
            if item not in items_to_remove:
                top_remaining.append((item, score))
                if len(top_remaining) == 10:
                    break

        # Remove the items
        for item in items_to_remove:
            result = self.redis.execute_command('VREM', self.test_key, item)
            assert result == 1, f"VREM failed to remove {item}"

        # 4. Do VSIM again with same vector (higher EF for better recall)
        new_results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim,
                                               *[str(x) for x in query_vec],
                                               'COUNT', 200, 'WITHSCORES',
                                               'EF', 500)

        # Convert new results to dict of item -> score
        new_scores = {}
        for i in range(0, len(new_results), 2):
            item = new_results[i].decode()
            score = float(new_results[i+1])
            new_scores[item] = score

        failure = False
        failed_item = None
        failed_reason = None
        # 5. Verify all top 10 non-removed items are still found with similar scores
        for item, old_score in top_remaining:
            if item not in new_scores:
                failure = True
                failed_item = item
                failed_reason = "missing"
                break
            new_score = new_scores[item]
            # Scores may shift slightly after reconnection; 0.01 tolerance.
            if abs(new_score - old_score) >= 0.01:
                failure = True
                failed_item = item
                failed_reason = f"score changed: {old_score:.6f} -> {new_score:.6f}"
                break

        if failure:
            # Dump detailed diagnostics before raising, since the graph
            # state cannot be inspected after the test tears down.
            print("\nTest failed!")
            print(f"Problem with item: {failed_item} ({failed_reason})")

            print("\nOriginal neighbors (with similarity scores):")
            if failed_item in neighbors_before:
                print(self.format_neighbors_with_scores(
                    neighbors_before[failed_item],
                    items_to_remove=items_to_remove))
            else:
                print("No neighbors found in original graph")

            print("\nCurrent neighbors (with similarity scores):")
            current_links = self.redis.execute_command('VLINKS', self.test_key,
                                                     failed_item, 'WITHSCORES')
            if current_links:
                print(self.format_neighbors_with_scores(
                    current_links,
                    old_links=neighbors_before.get(failed_item)))
            else:
                print("No neighbors in current graph")

            print("\nOriginal results (top 20):")
            for item, score in items[:20]:
                deleted = "[deleted]" if item in items_to_remove else ""
                print(f"{item}: {score:.6f} {deleted}")

            print("\nNew results after removal (top 20):")
            new_items = []
            for i in range(0, len(new_results), 2):
                item = new_results[i].decode()
                score = float(new_results[i+1])
                new_items.append((item, score))
            new_items.sort(key=lambda x: x[1], reverse=True)
            for item, score in new_items[:20]:
                print(f"{item}: {score:.6f}")

            raise AssertionError(f"Test failed: Problem with item {failed_item} ({failed_reason}). *** IMPORTANT *** This test may fail from time to time without indicating that there is a bug. However normally it should pass. The fact is that it's a quite extreme test where we destroy 50% of nodes of top results and still expect perfect recall, with vectors that are very hostile because of the distribution used.")
173
diff --git a/examples/redis-unstable/modules/vector-sets/tests/dimension_validation.py b/examples/redis-unstable/modules/vector-sets/tests/dimension_validation.py
new file mode 100644
index 0000000..f081152
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/dimension_validation.py
@@ -0,0 +1,67 @@
1from test import TestCase, generate_random_vector
2import struct
3import redis.exceptions
4
class DimensionValidation(TestCase):
    """Regression test: dimension checks when a set uses random projection
    (REDUCE). Both VADD and VSIM must reject inputs whose dimension does
    not match the projection's expected input dimension."""

    def getname(self):
        return "[regression] Dimension Validation with Projection"

    def estimated_runtime(self):
        return 0.5

    def test(self):
        # Test scenario 1: create a set projecting 100-dim inputs to 50.
        original_dim = 100
        reduced_dim = 50

        # Create the initial vector and set with projection.
        vec1 = generate_random_vector(original_dim)
        vec1_bytes = struct.pack(f'{original_dim}f', *vec1)

        # Add first vector with projection
        result = self.redis.execute_command('VADD', self.test_key,
                                            'REDUCE', reduced_dim,
                                            'FP32', vec1_bytes, f'{self.test_key}:item:1')
        assert result == 1, "First VADD with REDUCE should return 1"

        # VINFO replies with a flat [field, value, ...] list; fold it into
        # a dict keyed by decoded field name.
        info = self.redis.execute_command('VINFO', self.test_key)
        info_map = {k.decode('utf-8'): v for k, v in zip(info[::2], info[1::2])}
        assert 'vector-dim' in info_map, "VINFO should contain vector-dim"
        # BUGFIX: the failure messages previously used info['vector-dim'] /
        # info['projection-input-dim'] — indexing the raw reply *list* with
        # a string, which raises TypeError exactly when the assertion fails.
        # They must read from info_map instead.
        assert info_map['vector-dim'] == reduced_dim, f"Expected reduced dimension {reduced_dim}, got {info_map['vector-dim']}"
        assert 'projection-input-dim' in info_map, "VINFO should contain projection-input-dim"
        assert info_map['projection-input-dim'] == original_dim, f"Expected original dimension {original_dim}, got {info_map['projection-input-dim']}"

        # Test scenario 2: adding a mismatched vector must fail.
        wrong_dim = 80
        wrong_vec = generate_random_vector(wrong_dim)
        wrong_vec_bytes = struct.pack(f'{wrong_dim}f', *wrong_vec)

        # This should fail with a dimension mismatch error.
        try:
            self.redis.execute_command('VADD', self.test_key,
                                       'REDUCE', reduced_dim,
                                       'FP32', wrong_vec_bytes, f'{self.test_key}:item:2')
            assert False, "VADD with wrong dimension should fail"
        except redis.exceptions.ResponseError as e:
            assert "Input dimension mismatch for projection" in str(e), f"Expected dimension mismatch error, got: {e}"

        # Test scenario 3: a correctly-sized vector is still accepted.
        vec2 = generate_random_vector(original_dim)
        vec2_bytes = struct.pack(f'{original_dim}f', *vec2)

        # This should succeed
        result = self.redis.execute_command('VADD', self.test_key,
                                            'REDUCE', reduced_dim,
                                            'FP32', vec2_bytes, f'{self.test_key}:item:3')
        assert result == 1, "VADD with correct dimensions should succeed"

        # VSIM must validate query dimensions the same way.
        wrong_query = generate_random_vector(wrong_dim)
        try:
            self.redis.execute_command('VSIM', self.test_key,
                                       'VALUES', wrong_dim, *[str(x) for x in wrong_query],
                                       'COUNT', 10)
            assert False, "VSIM with wrong dimension should fail"
        except redis.exceptions.ResponseError as e:
            assert "Input dimension mismatch for projection" in str(e), f"Expected dimension mismatch error in VSIM, got: {e}"
diff --git a/examples/redis-unstable/modules/vector-sets/tests/epsilon.py b/examples/redis-unstable/modules/vector-sets/tests/epsilon.py
new file mode 100644
index 0000000..97e11c0
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/epsilon.py
@@ -0,0 +1,77 @@
1from test import TestCase
2
class EpsilonOption(TestCase):
    """Checks that VSIM's EPSILON option prunes results whose cosine
    similarity to the query falls below 1 - epsilon."""

    def getname(self):
        return "VSIM EPSILON option filtering"

    def estimated_runtime(self):
        return 0.1

    def test(self):
        # 2D points chosen so similarity against the (1, 1) query is easy
        # to reason about: 'a' is identical, 'b' at 45 degrees, 'c' is the
        # zero vector (may be handled specially), 'd' at 135 degrees and
        # 'e' opposite the query.
        points = [('a', '1', '1'),
                  ('b', '0', '1'),
                  ('c', '0', '0'),
                  ('d', '0', '-1'),
                  ('e', '-1', '-1')]
        for name, x, y in points:
            added = self.redis.execute_command('VADD', self.test_key, 'VALUES', '2', x, y, name)
            assert added == 1, f"VADD should return 1 for item '{name}'"

        def similar(*extra):
            """VSIM against (1, 1) with WITHSCORES plus any extra options;
            splits the flat reply into decoded names and float scores."""
            reply = self.redis.execute_command('VSIM', self.test_key, 'VALUES', '2', '1', '1',
                                               'WITHSCORES', *extra)
            names = [reply[i].decode() for i in range(0, len(reply), 2)]
            scores = [float(reply[i]) for i in range(1, len(reply), 2)]
            return names, scores

        # Without EPSILON every element comes back, best match first.
        elements_all, scores_all = similar()
        assert len(elements_all) == 5, f"Should return 5 elements without EPSILON, got {len(elements_all)}"
        assert elements_all[0] == 'a', "First element should be 'a' (most similar)"
        assert scores_all[0] == 1.0, "Score for 'a' should be 1.0 (identical)"

        # EPSILON 0.5 keeps only similarity >= 0.5 (distance < 0.5).
        elements_epsilon_0_5, scores_epsilon_0_5 = similar('EPSILON', '0.5')
        assert len(elements_epsilon_0_5) == 3, f"With EPSILON 0.5, should return 3 elements, got {len(elements_epsilon_0_5)}"
        assert set(elements_epsilon_0_5) == {'a', 'b', 'c'}, f"With EPSILON 0.5, should get a, b, c, got {elements_epsilon_0_5}"
        for i, score in enumerate(scores_epsilon_0_5):
            assert score >= 0.5, f"Element {elements_epsilon_0_5[i]} has score {score} which is < 0.5"

        # EPSILON 0.2 keeps only similarity >= 0.8 (distance < 0.2).
        elements_epsilon_0_2, scores_epsilon_0_2 = similar('EPSILON', '0.2')
        assert len(elements_epsilon_0_2) == 2, f"With EPSILON 0.2, should return 2 elements, got {len(elements_epsilon_0_2)}"
        assert set(elements_epsilon_0_2) == {'a', 'b'}, f"With EPSILON 0.2, should get a, b, got {elements_epsilon_0_2}"
        for i, score in enumerate(scores_epsilon_0_2):
            assert score >= 0.8, f"Element {elements_epsilon_0_2[i]} has score {score} which is < 0.8"

        # A tiny EPSILON leaves only the exact match.
        elements_epsilon_small, _ = similar('EPSILON', '0.001')
        assert len(elements_epsilon_small) == 1, f"With EPSILON 0.001, should return only 1 element, got {len(elements_epsilon_small)}"
        assert elements_epsilon_small[0] == 'a', "With very small EPSILON, should only get 'a'"

        # EPSILON 1.0 admits everything (all similarities lie in [0, 1]).
        elements_epsilon_1, _ = similar('EPSILON', '1.0')
        assert len(elements_epsilon_1) == 5, f"With EPSILON 1.0, should return all 5 elements, got {len(elements_epsilon_1)}"
diff --git a/examples/redis-unstable/modules/vector-sets/tests/evict_empty.py b/examples/redis-unstable/modules/vector-sets/tests/evict_empty.py
new file mode 100644
index 0000000..6c78c82
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/evict_empty.py
@@ -0,0 +1,27 @@
1from test import TestCase, generate_random_vector
2import struct
3
class VREM_LastItemDeletesKey(TestCase):
    """Removing the only element of a vector set must delete the key itself."""

    def getname(self):
        return "VREM last item deletes key"

    def test(self):
        item = f'{self.test_key}:item:1'
        blob = struct.pack('4f', *generate_random_vector(4))

        # Create the set with a single element and confirm the key exists.
        added = self.redis.execute_command('VADD', self.test_key, 'FP32', blob, item)
        assert added == 1, "VADD should return 1 for first item"
        assert self.redis.exists(self.test_key) == 1, "Key should exist after VADD"

        # Deleting that element empties the set, which must evict the key.
        removed = self.redis.execute_command('VREM', self.test_key, item)
        assert removed == 1, "VREM should return 1 for successful removal"
        assert self.redis.exists(self.test_key) == 0, "Key should no longer exist after VREM of last item"
diff --git a/examples/redis-unstable/modules/vector-sets/tests/filter_expr.py b/examples/redis-unstable/modules/vector-sets/tests/filter_expr.py
new file mode 100644
index 0000000..364915d
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/filter_expr.py
@@ -0,0 +1,242 @@
1from test import TestCase
2
class VSIMFilterExpressions(TestCase):
    """Exercises the VSIM FILTER expression language: numeric and string
    comparisons, boolean logic, arithmetic, the 'in' operator for arrays
    and substrings, and the exclusion of elements whose attribute is
    missing or malformed JSON."""

    def getname(self):
        return "VSIM FILTER expressions basic functionality"

    def test(self):
        key = self.test_key

        # Orthogonal vectors so similarity ranking never interferes with
        # filtering; item 5 is deliberately non-orthogonal.
        fixtures = [
            (1, [1, 0, 0, 0],
             '{"age": 25, "name": "Alice", "active": true, "scores": [85, 90, 95], "city": "New York"}'),
            (2, [0, 1, 0, 0],
             '{"age": 30, "name": "Bob", "active": false, "scores": [70, 75, 80], "city": "Boston"}'),
            (3, [0, 0, 1, 0],
             '{"age": 35, "name": "Charlie", "scores": [60, 65, 70], "city": "Seattle"}'),
            (4, [0, 0, 0, 1],
             None),           # item 4 has no attribute at all
            (5, [0.5, 0.5, 0, 0],
             'invalid json'), # intentionally malformed JSON
        ]
        for idx, vec, attr in fixtures:
            self.redis.execute_command('VADD', key, 'VALUES', 4,
                                       *[str(x) for x in vec], f'{key}:item:{idx}')
            if attr is not None:
                self.redis.execute_command('VSETATTR', key, f'{key}:item:{idx}', attr)

        query = fixtures[0][1]  # always query with vec1 = [1, 0, 0, 0]

        def matches(expr):
            """VSIM with the given FILTER expression; returns the matching
            element names as decoded strings."""
            reply = self.redis.execute_command('VSIM', key, 'VALUES', 4,
                                               *[str(x) for x in query],
                                               'FILTER', expr)
            return [element.decode() for element in reply]

        # Numeric comparisons.
        result = matches('.age == 25')
        assert len(result) == 1, "Expected 1 result for age == 25"
        assert result[0] == f'{key}:item:1', "Expected item:1 for age == 25"

        assert len(matches('.age > 25')) == 2, "Expected 2 results for age > 25"
        assert len(matches('.age <= 30')) == 2, "Expected 2 results for age <= 30"

        # String equality and inequality.
        result = matches('.name == "Alice"')
        assert len(result) == 1, "Expected 1 result for name == Alice"
        assert result[0] == f'{key}:item:1', "Expected item:1 for name == Alice"

        assert len(matches('.name != "Alice"')) == 2, "Expected 2 results for name != Alice"

        # A bare boolean attribute is its own predicate.
        assert len(matches('.active')) == 1, "Expected 1 result for .active being true"

        # Logical AND / OR / NOT.
        result = matches('.age > 20 and .age < 30')
        assert len(result) == 1, "Expected 1 result for 20 < age < 30"
        assert result[0] == f'{key}:item:1', "Expected item:1 for 20 < age < 30"

        assert len(matches('.age < 30 or .age > 35')) == 1, "Expected 1 result for age < 30 or age > 35"
        assert len(matches('!(.age == 25)')) == 2, "Expected 2 results for NOT(age == 25)"

        # Membership in literal arrays.
        assert len(matches('.age in [25, 35]')) == 2, "Expected 2 results for age in [25, 35]"
        assert len(matches('.name in ["Alice", "David"]')) == 1, "Expected 1 result for name in [Alice, David]"

        # Substring matching with "in".
        result = matches('"lic" in .name')
        assert len(result) == 1, "Expected 1 result for 'lic' in name"
        assert result[0] == f'{key}:item:1', "Expected item:1 (Alice)"

        result = matches('"ork" in .city')
        assert len(result) == 1, "Expected 1 result for 'ork' in city"
        assert result[0] == f'{key}:item:1', "Expected item:1 (New York)"

        assert len(matches('"xyz" in .name')) == 0, "Expected 0 results for 'xyz' in name"

        # Substring boundary (off-by-one) cases.
        result = matches('"Ali" in .name')
        assert len(result) == 1, "Expected 1 result for 'Ali' at beginning of 'Alice'"
        assert result[0] == f'{key}:item:1', "Expected item:1"

        result = matches('"ice" in .name')
        assert len(result) == 1, "Expected 1 result for 'ice' at end of 'Alice'"
        assert result[0] == f'{key}:item:1', "Expected item:1"

        result = matches('"Alice" in .name')
        assert len(result) == 1, "Expected 1 result for exact match 'Alice' in 'Alice'"
        assert result[0] == f'{key}:item:1', "Expected item:1"

        assert len(matches('"A" in .name')) == 1, "Expected 1 result for single char 'A' in 'Alice'"
        assert len(matches('"" in .name')) == 3, "Expected 3 results for empty string (matches all strings)"
        assert len(matches('.name in ""')) == 0, "Expected 0 results for empty string on the right of IN operator"
        assert len(matches('"" in .name && "" in ""')) == 3, "Expected empty string matching empty string"

        # Arithmetic inside expressions.
        assert len(matches('.age + 10 > 40')) == 1, "Expected 1 result for age + 10 > 40"
        assert len(matches('.age * 2 > 60')) == 1, "Expected 1 result for age * 2 > 60"
        assert len(matches('.age / 5 == 5')) == 1, "Expected 1 result for age / 5 == 5"
        assert len(matches('.age % 2 == 0')) == 1, "Expected 1 result for age % 2 == 0"
        assert len(matches('.age ** 2 > 900')) == 1, "Expected 1 result for age^2 > 900"

        # Missing attributes and malformed JSON exclude the element.
        assert len(matches('.missing_field == "value"')) == 0, "Expected 0 results for missing_field == value"
        assert f'{key}:item:4' not in matches('.any_field'), "Item with no attribute should be excluded"
        assert f'{key}:item:5' not in matches('.any_field'), "Item with malformed JSON should be excluded"

        # Complex expression combining comparisons, logic and parentheses.
        result = matches('(.age > 20 and .age < 40) and (.city == "Boston" or .city == "New York")')
        assert len(result) == 2, "Expected 2 results for the complex expression"
        expected_items = [f'{key}:item:1', f'{key}:item:2']
        assert set(result) == set(expected_items), "Expected item:1 and item:2 for the complex expression"

        # Parentheses controlling operator precedence.
        assert len(matches('.age > (20 + 10)')) == 1, "Expected 1 result for age > (20 + 10)"

        # Arrays evaluate to true.
        assert len(matches('.scores')) == 3, "Expected 3 results for .scores (arrays evaluate to true)"
diff --git a/examples/redis-unstable/modules/vector-sets/tests/filter_int.py b/examples/redis-unstable/modules/vector-sets/tests/filter_int.py
new file mode 100644
index 0000000..0fd1dc1
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/filter_int.py
@@ -0,0 +1,668 @@
1from test import TestCase, generate_random_vector
2import struct
3import random
4import math
5import json
6import time
7
8class VSIMFilterAdvanced(TestCase):
9 def getname(self):
10 return "VSIM FILTER comprehensive functionality testing"
11
12 def estimated_runtime(self):
13 return 15 # This test might take up to 15 seconds for the large dataset
14
15 def setup(self):
16 super().setup()
17 self.dim = 32 # Vector dimension
18 self.count = 5000 # Number of vectors for large tests
19 self.small_count = 50 # Number of vectors for small/quick tests
20
21 # Categories for attributes
22 self.categories = ["electronics", "furniture", "clothing", "books", "food"]
23 self.cities = ["New York", "London", "Tokyo", "Paris", "Berlin", "Sydney", "Toronto", "Singapore"]
24 self.price_ranges = [(10, 50), (50, 200), (200, 1000), (1000, 5000)]
25 self.years = list(range(2000, 2025))
26
27 def create_attributes(self, index):
28 """Create realistic attributes for a vector"""
29 category = random.choice(self.categories)
30 city = random.choice(self.cities)
31 min_price, max_price = random.choice(self.price_ranges)
32 price = round(random.uniform(min_price, max_price), 2)
33 year = random.choice(self.years)
34 in_stock = random.random() > 0.3 # 70% chance of being in stock
35 rating = round(random.uniform(1, 5), 1)
36 views = int(random.expovariate(1/1000)) # Exponential distribution for page views
37 tags = random.sample(["popular", "sale", "new", "limited", "exclusive", "clearance"],
38 k=random.randint(0, 3))
39
40 # Add some specific patterns for testing
41 # Every 10th item has a specific property combination for testing
42 is_premium = (index % 10 == 0)
43
44 # Create attributes dictionary
45 attrs = {
46 "id": index,
47 "category": category,
48 "location": city,
49 "price": price,
50 "year": year,
51 "in_stock": in_stock,
52 "rating": rating,
53 "views": views,
54 "tags": tags
55 }
56
57 if is_premium:
58 attrs["is_premium"] = True
59 attrs["special_features"] = ["premium", "warranty", "support"]
60
61 # Add sub-categories for more complex filters
62 if category == "electronics":
63 attrs["subcategory"] = random.choice(["phones", "computers", "cameras", "audio"])
64 elif category == "furniture":
65 attrs["subcategory"] = random.choice(["chairs", "tables", "sofas", "beds"])
66 elif category == "clothing":
67 attrs["subcategory"] = random.choice(["shirts", "pants", "dresses", "shoes"])
68
69 # Add some intentionally missing fields for testing
70 if random.random() > 0.9: # 10% chance of missing price
71 del attrs["price"]
72
73 # Some items have promotion field
74 if random.random() > 0.7: # 30% chance of having a promotion
75 attrs["promotion"] = random.choice(["discount", "bundle", "gift"])
76
77 # Create invalid JSON for a small percentage of vectors
78 if random.random() > 0.98: # 2% chance of having invalid JSON
79 return "{{invalid json}}"
80
81 return json.dumps(attrs)
82
83 def create_vectors_with_attributes(self, key, count):
84 """Create vectors and add attributes to them"""
85 vectors = []
86 names = []
87 attribute_map = {} # To store attributes for verification
88
89 # Create vectors
90 for i in range(count):
91 vec = generate_random_vector(self.dim)
92 vectors.append(vec)
93 name = f"{key}:item:{i}"
94 names.append(name)
95
96 # Add to Redis
97 vec_bytes = struct.pack(f'{self.dim}f', *vec)
98 self.redis.execute_command('VADD', key, 'FP32', vec_bytes, name)
99
100 # Create and add attributes
101 attrs = self.create_attributes(i)
102 self.redis.execute_command('VSETATTR', key, name, attrs)
103
104 # Store attributes for later verification
105 try:
106 attribute_map[name] = json.loads(attrs) if '{' in attrs else None
107 except json.JSONDecodeError:
108 attribute_map[name] = None
109
110 return vectors, names, attribute_map
111
112 def filter_linear_search(self, vectors, names, query_vector, filter_expr, attribute_map, k=10):
113 """Perform a linear search with filtering for verification"""
114 similarities = []
115 query_norm = math.sqrt(sum(x*x for x in query_vector))
116
117 if query_norm == 0:
118 return []
119
120 for i, vec in enumerate(vectors):
121 name = names[i]
122 attributes = attribute_map.get(name)
123
124 # Skip if doesn't match filter
125 if not self.matches_filter(attributes, filter_expr):
126 continue
127
128 vec_norm = math.sqrt(sum(x*x for x in vec))
129 if vec_norm == 0:
130 continue
131
132 dot_product = sum(a*b for a,b in zip(query_vector, vec))
133 cosine_sim = dot_product / (query_norm * vec_norm)
134 distance = 1.0 - cosine_sim
135 redis_similarity = 1.0 - (distance/2.0)
136 similarities.append((name, redis_similarity))
137
138 similarities.sort(key=lambda x: x[1], reverse=True)
139 return similarities[:k]
140
141 def matches_filter(self, attributes, filter_expr):
142 """Filter matching for verification - uses Python eval to handle complex expressions"""
143 if attributes is None:
144 return False # No attributes or invalid JSON
145
146 # Replace JSON path selectors with Python dictionary access
147 py_expr = filter_expr
148
149 # Handle `.field` notation (replace with attributes['field'])
150 i = 0
151 while i < len(py_expr):
152 if py_expr[i] == '.' and (i == 0 or not py_expr[i-1].isalnum()):
153 # Find the end of the selector (stops at operators or whitespace)
154 j = i + 1
155 while j < len(py_expr) and (py_expr[j].isalnum() or py_expr[j] == '_'):
156 j += 1
157
158 if j > i + 1: # Found a valid selector
159 field = py_expr[i+1:j]
160 # Use a safe access pattern that returns a default value based on context
161 py_expr = py_expr[:i] + f"attributes.get('{field}')" + py_expr[j:]
162 i = i + len(f"attributes.get('{field}')")
163 else:
164 i += 1
165 else:
166 i += 1
167
168 # Convert not operator if needed
169 py_expr = py_expr.replace('!', ' not ')
170
171 try:
172 # Custom evaluation that handles exceptions for missing fields
173 # by returning False for the entire expression
174
175 # Split the expression on logical operators
176 parts = []
177 for op in [' and ', ' or ']:
178 if op in py_expr:
179 parts = py_expr.split(op)
180 break
181
182 if not parts: # No logical operators found
183 parts = [py_expr]
184
185 # Try to evaluate each part - if any part fails,
186 # the whole expression should fail
187 try:
188 result = eval(py_expr, {"attributes": attributes})
189 return bool(result)
190 except (TypeError, AttributeError):
191 # This typically happens when trying to compare None with
192 # numbers or other types, or when an attribute doesn't exist
193 return False
194 except Exception as e:
195 print(f"Error evaluating filter expression '{filter_expr}' as '{py_expr}': {e}")
196 return False
197
198 except Exception as e:
199 print(f"Error evaluating filter expression '{filter_expr}' as '{py_expr}': {e}")
200 return False
201
202 def safe_decode(self,item):
203 return item.decode() if isinstance(item, bytes) else item
204
205 def calculate_recall(self, redis_results, linear_results, k=10):
206 """Calculate recall (percentage of correct results retrieved)"""
207 redis_set = set(self.safe_decode(item) for item in redis_results)
208 linear_set = set(item[0] for item in linear_results[:k])
209
210 if not linear_set:
211 return 1.0 # If no linear results, consider it perfect recall
212
213 intersection = redis_set.intersection(linear_set)
214 return len(intersection) / len(linear_set)
215
    def test_recall_with_filter(self, filter_expr, ef=500, filter_ef=None):
        """Test recall for a given filter expression.

        Runs the same filtered query two ways -- a Python linear scan over
        self.vectors (ground truth) and the Redis VSIM command -- then
        compares the result sets.

        filter_expr -- FILTER expression given both to Redis and to
                       self.matches_filter() for the ground-truth scan.
        ef          -- EF (exploration factor) passed to VSIM.
        filter_ef   -- optional FILTER-EF value; None uses the server default.

        Returns (recall, selectivity, query_time_seconds, result_count).
        Raises AssertionError (with an extensive debug report appended)
        when selectivity or recall deviates too far from the ground truth.
        """
        # Create query vector
        query_vec = generate_random_vector(self.dim)

        # First, get ground truth using linear scan
        linear_results = self.filter_linear_search(
            self.vectors, self.names, query_vec, filter_expr, self.attribute_map, k=50)

        # Calculate true selectivity from ground truth
        true_selectivity = len(linear_results) / len(self.names) if self.names else 0

        # Perform Redis search with filter
        cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim]
        cmd_args.extend([str(x) for x in query_vec])
        cmd_args.extend(['COUNT', 50, 'WITHSCORES', 'EF', ef, 'FILTER', filter_expr])
        if filter_ef:
            cmd_args.extend(['FILTER-EF', filter_ef])

        start_time = time.time()
        redis_results = self.redis.execute_command(*cmd_args)
        query_time = time.time() - start_time

        # Convert Redis results (flat name,score pairs) to dict
        redis_items = {}
        for i in range(0, len(redis_results), 2):
            key = redis_results[i].decode() if isinstance(redis_results[i], bytes) else redis_results[i]
            score = float(redis_results[i+1])
            redis_items[key] = score

        # Calculate metrics
        recall = self.calculate_recall(redis_items.keys(), linear_results)
        selectivity = len(redis_items) / len(self.names) if redis_items else 0

        # Compare against the true selectivity from linear scan
        assert abs(selectivity - true_selectivity) < 0.1, \
            f"Redis selectivity {selectivity:.3f} differs significantly from ground truth {true_selectivity:.3f}"

        # We expect high recall for standard parameters
        if ef >= 500 and (filter_ef is None or filter_ef >= 1000):
            try:
                assert recall >= 0.7, \
                    f"Low recall {recall:.2f} for filter '{filter_expr}'"
            except AssertionError as e:
                # On failure, build a verbose report comparing the two
                # result sets, then re-raise with it appended.
                # Get items found in each set
                redis_items_set = set(redis_items.keys())
                linear_items_set = set(item[0] for item in linear_results)

                # Find items in each set
                only_in_redis = redis_items_set - linear_items_set
                only_in_linear = linear_items_set - redis_items_set
                in_both = redis_items_set & linear_items_set

                # Build comprehensive debug message
                debug = f"\nGround Truth: {len(linear_results)} matching items (total vectors: {len(self.vectors)})"
                debug += f"\nRedis Found: {len(redis_items)} items with FILTER-EF: {filter_ef or 'default'}"
                debug += f"\nItems in both sets: {len(in_both)} (recall: {recall:.4f})"
                debug += f"\nItems only in Redis: {len(only_in_redis)}"
                debug += f"\nItems only in Ground Truth: {len(only_in_linear)}"

                # Show some example items from each set with their scores
                if only_in_redis:
                    debug += "\n\nTOP 5 ITEMS ONLY IN REDIS:"
                    sorted_redis = sorted([(k, v) for k, v in redis_items.items()], key=lambda x: x[1], reverse=True)
                    for i, (item, score) in enumerate(sorted_redis[:5]):
                        if item in only_in_redis:
                            debug += f"\n  {i+1}. {item} (Score: {score:.4f})"

                            # Show attribute that should match filter
                            attr = self.attribute_map.get(item)
                            if attr:
                                debug += f" - Attrs: {attr.get('category', 'N/A')}, Price: {attr.get('price', 'N/A')}"

                if only_in_linear:
                    debug += "\n\nTOP 5 ITEMS ONLY IN GROUND TRUTH:"
                    for i, (item, score) in enumerate(linear_results[:5]):
                        if item in only_in_linear:
                            debug += f"\n  {i+1}. {item} (Score: {score:.4f})"

                            # Show attribute that should match filter
                            attr = self.attribute_map.get(item)
                            if attr:
                                debug += f" - Attrs: {attr.get('category', 'N/A')}, Price: {attr.get('price', 'N/A')}"

                # Help identify parsing issues
                debug += "\n\nPARSING CHECK:"
                debug += f"\nRedis command: VSIM {self.test_key} VALUES {self.dim} [...] FILTER '{filter_expr}'"

                # Check for WITHSCORES handling issues
                if len(redis_results) > 0 and len(redis_results) % 2 == 0:
                    debug += f"\nRedis returned {len(redis_results)} items (looks like item,score pairs)"
                    debug += f"\nFirst few results: {redis_results[:4]}"

                # Check the filter implementation
                debug += "\n\nFILTER IMPLEMENTATION CHECK:"
                debug += f"\nFilter expression: '{filter_expr}'"
                debug += "\nSample attribute matches from attribute_map:"
                count_matching = 0
                for i, (name, attrs) in enumerate(self.attribute_map.items()):
                    if attrs and self.matches_filter(attrs, filter_expr):
                        count_matching += 1
                        if i < 3:  # Show first 3 matches
                            debug += f"\n  - {name}: {attrs}"
                debug += f"\nTotal items matching filter in attribute_map: {count_matching}"

                # Check if results array handling could be wrong
                debug += "\n\nRESULT ARRAYS CHECK:"
                if len(linear_results) >= 1:
                    debug += f"\nlinear_results[0]: {linear_results[0]}"
                    if isinstance(linear_results[0], tuple) and len(linear_results[0]) == 2:
                        debug += " (correct tuple format: (name, score))"
                    else:
                        debug += " (UNEXPECTED FORMAT!)"

                # Debug sort order
                debug += "\n\nSORTING CHECK:"
                if len(linear_results) >= 2:
                    debug += f"\nGround truth first item score: {linear_results[0][1]}"
                    debug += f"\nGround truth second item score: {linear_results[1][1]}"
                    debug += f"\nCorrectly sorted by similarity? {linear_results[0][1] >= linear_results[1][1]}"

                # Re-raise with detailed information
                raise AssertionError(str(e) + debug)

        return recall, selectivity, query_time, len(redis_items)
341
    def test(self):
        """Exercise VSIM FILTER end to end: recall vs. ground truth,
        selectivity performance, FILTER-EF behavior, complex expressions,
        attribute types, COUNT interaction, and edge cases."""
        print(f"\nRunning comprehensive VSIM FILTER tests...")

        # Create a larger dataset for testing
        print(f"Creating dataset with {self.count} vectors and attributes...")
        self.vectors, self.names, self.attribute_map = self.create_vectors_with_attributes(
            self.test_key, self.count)

        # ==== 1. Recall and Precision Testing ====
        print("Testing recall for various filters...")

        # Test basic filters with different selectivity
        results = {}
        results["category"] = self.test_recall_with_filter('.category == "electronics"')
        results["price_high"] = self.test_recall_with_filter('.price > 1000')
        results["in_stock"] = self.test_recall_with_filter('.in_stock')
        results["rating"] = self.test_recall_with_filter('.rating >= 4')
        results["complex1"] = self.test_recall_with_filter('.category == "electronics" and .price < 500')

        print("Filter | Recall | Selectivity | Time (ms) | Results")
        print("----------------------------------------------------")
        for name, (recall, selectivity, time_ms, count) in results.items():
            print(f"{name:7} | {recall:.3f} | {selectivity:.3f} | {time_ms*1000:.1f} | {count}")

        # ==== 2. Filter Selectivity Performance ====
        print("\nTesting filter selectivity performance...")

        # High selectivity (very few matches)
        high_sel_recall, _, high_sel_time, _ = self.test_recall_with_filter('.is_premium')

        # Medium selectivity
        med_sel_recall, _, med_sel_time, _ = self.test_recall_with_filter('.price > 100 and .price < 1000')

        # Low selectivity (many matches)
        low_sel_recall, _, low_sel_time, _ = self.test_recall_with_filter('.year > 2000')

        print(f"High selectivity recall: {high_sel_recall:.3f}, time: {high_sel_time*1000:.1f}ms")
        print(f"Med selectivity recall: {med_sel_recall:.3f}, time: {med_sel_time*1000:.1f}ms")
        print(f"Low selectivity recall: {low_sel_recall:.3f}, time: {low_sel_time*1000:.1f}ms")

        # ==== 3. FILTER-EF Parameter Testing ====
        print("\nTesting FILTER-EF parameter...")

        # Test with different FILTER-EF values
        filter_expr = '.category == "electronics" and .price > 200'
        ef_values = [100, 500, 2000, 5000]

        print("FILTER-EF | Recall | Time (ms)")
        print("-----------------------------")
        for filter_ef in ef_values:
            recall, _, query_time, _ = self.test_recall_with_filter(
                filter_expr, ef=500, filter_ef=filter_ef)
            print(f"{filter_ef:9} | {recall:.3f} | {query_time*1000:.1f}")

        # Assert that higher FILTER-EF generally gives better recall
        low_ef_recall, _, _, _ = self.test_recall_with_filter(filter_expr, filter_ef=100)
        high_ef_recall, _, _, _ = self.test_recall_with_filter(filter_expr, filter_ef=5000)

        # This might not always be true due to randomness, but generally holds
        # We use a softer assertion to avoid flaky tests
        assert high_ef_recall >= low_ef_recall * 0.8, \
            f"Higher FILTER-EF should generally give better recall: {high_ef_recall:.3f} vs {low_ef_recall:.3f}"

        # ==== 4. Complex Filter Expressions ====
        print("\nTesting complex filter expressions...")

        # Test a variety of complex expressions
        complex_filters = [
            '.price > 100 and (.category == "electronics" or .category == "furniture")',
            '(.rating > 4 and .in_stock) or (.price < 50 and .views > 1000)',
            '.category in ["electronics", "clothing"] and .price > 200 and .rating >= 3',
            '(.category == "electronics" and .subcategory == "phones") or (.category == "furniture" and .price > 1000)',
            '.year > 2010 and !(.price < 100) and .in_stock'
        ]

        print("Expression | Results | Time (ms)")
        print("-----------------------------")
        for i, expr in enumerate(complex_filters):
            try:
                _, _, query_time, result_count = self.test_recall_with_filter(expr)
                print(f"Complex {i+1} | {result_count:7} | {query_time*1000:.1f}")
            except Exception as e:
                print(f"Complex {i+1} | Error: {str(e)}")

        # ==== 5. Attribute Type Testing ====
        print("\nTesting different attribute types...")

        type_filters = [
            ('.price > 500', "Numeric"),
            ('.category == "books"', "String equality"),
            ('.in_stock', "Boolean"),
            ('.tags in ["sale", "new"]', "Array membership"),
            ('.rating * 2 > 8', "Arithmetic")
        ]

        for expr, type_name in type_filters:
            try:
                _, _, query_time, result_count = self.test_recall_with_filter(expr)
                print(f"{type_name:16} | {expr:30} | {result_count:5} results | {query_time*1000:.1f}ms")
            except Exception as e:
                print(f"{type_name:16} | {expr:30} | Error: {str(e)}")

        # ==== 6. Filter + Count Interaction ====
        print("\nTesting COUNT parameter with filters...")

        filter_expr = '.category == "electronics"'
        counts = [5, 20, 100]

        for count in counts:
            query_vec = generate_random_vector(self.dim)
            cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim]
            cmd_args.extend([str(x) for x in query_vec])
            cmd_args.extend(['COUNT', count, 'WITHSCORES', 'FILTER', filter_expr])

            results = self.redis.execute_command(*cmd_args)
            result_count = len(results) // 2  # Divide by 2 because WITHSCORES returns pairs

            # We expect result count to be at most the requested count
            assert result_count <= count, f"Got {result_count} results with COUNT {count}"
            print(f"COUNT {count:3} | Got {result_count:3} results")

        # ==== 7. Edge Cases ====
        print("\nTesting edge cases...")

        # Test with no matching items
        no_match_expr = '.category == "nonexistent_category"'
        results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', self.dim,
                                             *[str(x) for x in generate_random_vector(self.dim)],
                                             'FILTER', no_match_expr)
        assert len(results) == 0, f"Expected 0 results for non-matching filter, got {len(results)}"
        print(f"No matching items: {len(results)} results (expected 0)")

        # Test with invalid filter syntax
        try:
            self.redis.execute_command('VSIM', self.test_key, 'VALUES', self.dim,
                                       *[str(x) for x in generate_random_vector(self.dim)],
                                       'FILTER', '.category === "books"')  # Triple equals is invalid
            assert False, "Expected error for invalid filter syntax"
        except:
            print("Invalid filter syntax correctly raised an error")

        # Test with extremely long complex expression
        long_expr = ' and '.join([f'.rating > {i/10}' for i in range(10)])
        try:
            results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', self.dim,
                                                 *[str(x) for x in generate_random_vector(self.dim)],
                                                 'FILTER', long_expr)
            print(f"Long expression: {len(results)} results")
        except Exception as e:
            print(f"Long expression error: {str(e)}")

        print("\nComprehensive VSIM FILTER tests completed successfully")
494
495
496class VSIMFilterSelectivityTest(TestCase):
497 def getname(self):
498 return "VSIM FILTER selectivity performance benchmark"
499
500 def estimated_runtime(self):
501 return 8 # This test might take up to 8 seconds
502
503 def setup(self):
504 super().setup()
505 self.dim = 32
506 self.count = 10000
507 self.test_key = f"{self.test_key}:selectivity" # Use a different key
508
509 def create_vector_with_age_attribute(self, name, age):
510 """Create a vector with a specific age attribute"""
511 vec = generate_random_vector(self.dim)
512 vec_bytes = struct.pack(f'{self.dim}f', *vec)
513 self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, name)
514 self.redis.execute_command('VSETATTR', self.test_key, name, json.dumps({"age": age}))
515
516 def test(self):
517 print("\nRunning VSIM FILTER selectivity benchmark...")
518
519 # Create a dataset where we control the exact selectivity
520 print(f"Creating controlled dataset with {self.count} vectors...")
521
522 # Create vectors with age attributes from 1 to 100
523 for i in range(self.count):
524 age = (i % 100) + 1 # Ages from 1 to 100
525 name = f"{self.test_key}:item:{i}"
526 self.create_vector_with_age_attribute(name, age)
527
528 # Create a query vector
529 query_vec = generate_random_vector(self.dim)
530
531 # Test filters with different selectivities
532 selectivities = [0.01, 0.05, 0.10, 0.25, 0.50, 0.75, 0.99]
533 results = []
534
535 print("\nSelectivity | Filter | Results | Time (ms)")
536 print("--------------------------------------------------")
537
538 for target_selectivity in selectivities:
539 # Calculate age threshold for desired selectivity
540 # For example, age <= 10 gives 10% selectivity
541 age_threshold = int(target_selectivity * 100)
542 filter_expr = f'.age <= {age_threshold}'
543
544 # Run query and measure time
545 start_time = time.time()
546 cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim]
547 cmd_args.extend([str(x) for x in query_vec])
548 cmd_args.extend(['COUNT', 100, 'FILTER', filter_expr])
549
550 results = self.redis.execute_command(*cmd_args)
551 query_time = time.time() - start_time
552
553 actual_selectivity = len(results) / min(100, int(target_selectivity * self.count))
554 print(f"{target_selectivity:.2f} | {filter_expr:15} | {len(results):7} | {query_time*1000:.1f}")
555
556 # Add assertion to ensure reasonable performance for different selectivities
557 # For very selective queries (1%), we might need more exploration
558 if target_selectivity <= 0.05:
559 # For very selective queries, ensure we can find some results
560 assert len(results) > 0, f"No results found for {filter_expr}"
561 else:
562 # For less selective queries, performance should be reasonable
563 assert query_time < 1.0, f"Query too slow: {query_time:.3f}s for {filter_expr}"
564
565 print("\nSelectivity benchmark completed successfully")
566
567
568class VSIMFilterComparisonTest(TestCase):
569 def getname(self):
570 return "VSIM FILTER EF parameter comparison"
571
572 def estimated_runtime(self):
573 return 8 # This test might take up to 8 seconds
574
575 def setup(self):
576 super().setup()
577 self.dim = 32
578 self.count = 5000
579 self.test_key = f"{self.test_key}:efparams" # Use a different key
580
581 def create_dataset(self):
582 """Create a dataset with specific attribute patterns for testing FILTER-EF"""
583 vectors = []
584 names = []
585
586 # Create vectors with category and quality score attributes
587 for i in range(self.count):
588 vec = generate_random_vector(self.dim)
589 name = f"{self.test_key}:item:{i}"
590
591 # Add vector to Redis
592 vec_bytes = struct.pack(f'{self.dim}f', *vec)
593 self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, name)
594
595 # Create attributes - we want a very selective filter
596 # Only 2% of items have category=premium AND quality>90
597 category = "premium" if random.random() < 0.1 else random.choice(["standard", "economy", "basic"])
598 quality = random.randint(1, 100)
599
600 attrs = {
601 "id": i,
602 "category": category,
603 "quality": quality
604 }
605
606 self.redis.execute_command('VSETATTR', self.test_key, name, json.dumps(attrs))
607 vectors.append(vec)
608 names.append(name)
609
610 return vectors, names
611
612 def test(self):
613 print("\nRunning VSIM FILTER-EF parameter comparison...")
614
615 # Create dataset
616 vectors, names = self.create_dataset()
617
618 # Create a selective filter that matches ~2% of items
619 filter_expr = '.category == "premium" and .quality > 90'
620
621 # Create query vector
622 query_vec = generate_random_vector(self.dim)
623
624 # Test different FILTER-EF values
625 ef_values = [50, 100, 500, 1000, 5000]
626 results = []
627
628 print("\nFILTER-EF | Results | Time (ms) | Notes")
629 print("---------------------------------------")
630
631 baseline_count = None
632
633 for ef in ef_values:
634 # Run query and measure time
635 start_time = time.time()
636 cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim]
637 cmd_args.extend([str(x) for x in query_vec])
638 cmd_args.extend(['COUNT', 100, 'FILTER', filter_expr, 'FILTER-EF', ef])
639
640 query_results = self.redis.execute_command(*cmd_args)
641 query_time = time.time() - start_time
642
643 # Set baseline for comparison
644 if baseline_count is None:
645 baseline_count = len(query_results)
646
647 recall_rate = len(query_results) / max(1, baseline_count) if baseline_count > 0 else 1.0
648
649 notes = ""
650 if ef == 5000:
651 notes = "Baseline"
652 elif recall_rate < 0.5:
653 notes = "Low recall!"
654
655 print(f"{ef:9} | {len(query_results):7} | {query_time*1000:.1f} | {notes}")
656 results.append((ef, len(query_results), query_time))
657
658 # If we have enough results at highest EF, check that recall improves with higher EF
659 if results[-1][1] >= 5: # At least 5 results for highest EF
660 # Extract result counts
661 result_counts = [r[1] for r in results]
662
663 # The last result (highest EF) should typically find more results than the first (lowest EF)
664 # but we use a soft assertion to avoid flaky tests
665 assert result_counts[-1] >= result_counts[0], \
666 f"Higher FILTER-EF should find at least as many results: {result_counts[-1]} vs {result_counts[0]}"
667
668 print("\nFILTER-EF parameter comparison completed successfully")
diff --git a/examples/redis-unstable/modules/vector-sets/tests/large_scale.py b/examples/redis-unstable/modules/vector-sets/tests/large_scale.py
new file mode 100644
index 0000000..eac5dca
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/large_scale.py
@@ -0,0 +1,56 @@
1from test import TestCase, fill_redis_with_vectors, generate_random_vector
2import random
3
4class LargeScale(TestCase):
5 def getname(self):
6 return "Large Scale Comparison"
7
8 def estimated_runtime(self):
9 return 10
10
    def test(self):
        """Compare VSIM (EF=500) against a brute-force linear scan on a
        20k x 300-dim dataset: require >= 70% top-50 overlap and near-equal
        scores for items found by both."""
        dim = 300
        count = 20000
        k = 50

        # Fill Redis and get reference data for comparison
        random.seed(42) # Make test deterministic
        data = fill_redis_with_vectors(self.redis, self.test_key, count, dim)

        # Generate query vector
        query_vec = generate_random_vector(dim)

        # Get results from Redis with good exploration factor
        redis_raw = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim,
                                               *[str(x) for x in query_vec],
                                               'COUNT', k, 'WITHSCORES', 'EF', 500)

        # Convert Redis results (flat name,score pairs) to dict
        redis_results = {}
        for i in range(0, len(redis_raw), 2):
            key = redis_raw[i].decode()
            score = float(redis_raw[i+1])
            redis_results[key] = score

        # Get results from linear scan
        linear_results = data.find_k_nearest(query_vec, k)
        linear_items = {name: score for name, score in linear_results}

        # Compare overlap
        redis_set = set(redis_results.keys())
        linear_set = set(linear_items.keys())
        overlap = len(redis_set & linear_set)

        # If test fails, print comparison for debugging
        if overlap < k * 0.7:
            data.print_comparison({'items': redis_results, 'query_vector': query_vec}, k)

        assert overlap >= k * 0.7, \
            f"Expected at least 70% overlap in top {k} results, got {overlap/k*100:.1f}%"

        # Verify scores for common items
        for item in redis_set & linear_set:
            redis_score = redis_results[item]
            linear_score = linear_items[item]
            assert abs(redis_score - linear_score) < 0.01, \
                f"Score mismatch for {item}: Redis={redis_score:.3f} Linear={linear_score:.3f}"
diff --git a/examples/redis-unstable/modules/vector-sets/tests/memory_usage.py b/examples/redis-unstable/modules/vector-sets/tests/memory_usage.py
new file mode 100644
index 0000000..d0f3f09
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/memory_usage.py
@@ -0,0 +1,36 @@
1from test import TestCase, generate_random_vector
2import struct
3
4class MemoryUsageTest(TestCase):
5 def getname(self):
6 return "[regression] MEMORY USAGE with attributes"
7
8 def test(self):
9 # Generate random vectors
10 vec1 = generate_random_vector(4)
11 vec2 = generate_random_vector(4)
12 vec_bytes1 = struct.pack('4f', *vec1)
13 vec_bytes2 = struct.pack('4f', *vec2)
14
15 # Add vectors to the key, one with attribute, one without
16 self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes1, f'{self.test_key}:item:1')
17 self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes2, f'{self.test_key}:item:2', 'SETATTR', '{"color":"red"}')
18
19 # Get memory usage for the key
20 try:
21 memory_usage = self.redis.execute_command('MEMORY', 'USAGE', self.test_key)
22 # If we got here without exception, the command worked
23 assert memory_usage > 0, "MEMORY USAGE should return a positive value"
24
25 # Add more attributes to increase complexity
26 self.redis.execute_command('VSETATTR', self.test_key, f'{self.test_key}:item:1', '{"color":"blue","size":10}')
27
28 # Check memory usage again
29 new_memory_usage = self.redis.execute_command('MEMORY', 'USAGE', self.test_key)
30 assert new_memory_usage > 0, "MEMORY USAGE should still return a positive value after setting attributes"
31
32 # Memory usage should be higher after adding attributes
33 assert new_memory_usage > memory_usage, "Memory usage increase after adding attributes"
34
35 except Exception as e:
36 raise AssertionError(f"MEMORY USAGE command failed: {str(e)}")
diff --git a/examples/redis-unstable/modules/vector-sets/tests/node_update.py b/examples/redis-unstable/modules/vector-sets/tests/node_update.py
new file mode 100644
index 0000000..53aa2dd
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/node_update.py
@@ -0,0 +1,85 @@
1from test import TestCase, generate_random_vector
2import struct
3import math
4import random
5
6class VectorUpdateAndClusters(TestCase):
7 def getname(self):
8 return "VADD vector update with cluster relocation"
9
10 def estimated_runtime(self):
11 return 2.0 # Should take around 2 seconds
12
13 def generate_cluster_vector(self, base_vec, noise=0.1):
14 """Generate a vector that's similar to base_vec with some noise."""
15 vec = [x + random.gauss(0, noise) for x in base_vec]
16 # Normalize
17 norm = math.sqrt(sum(x*x for x in vec))
18 return [x/norm for x in vec]
19
    def test(self):
        """Build two opposite clusters, re-VADD one cluster1 item with a
        cluster2-like vector, then verify via VEMB that the embedding was
        replaced and via VSIM that the item relocated to cluster2."""
        dim = 128
        vectors_per_cluster = 5000

        # Create two very different base vectors for our clusters
        cluster1_base = generate_random_vector(dim)
        cluster2_base = [-x for x in cluster1_base] # Opposite direction

        # Add vectors from first cluster
        for i in range(vectors_per_cluster):
            vec = self.generate_cluster_vector(cluster1_base)
            vec_bytes = struct.pack(f'{dim}f', *vec)
            self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes,
                                       f'{self.test_key}:cluster1:{i}')

        # Add vectors from second cluster
        for i in range(vectors_per_cluster):
            vec = self.generate_cluster_vector(cluster2_base)
            vec_bytes = struct.pack(f'{dim}f', *vec)
            self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes,
                                       f'{self.test_key}:cluster2:{i}')

        # Pick a test vector from cluster1
        test_key = f'{self.test_key}:cluster1:0'

        # Verify it's in cluster1 using VSIM
        initial_vec = self.generate_cluster_vector(cluster1_base)
        results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim,
                                             *[str(x) for x in initial_vec],
                                             'COUNT', 100, 'WITHSCORES')

        # Count how many cluster1 items are in top results
        # (results is a flat [name, score, ...] list; step by 2 over names).
        cluster1_count = sum(1 for i in range(0, len(results), 2)
                             if b'cluster1' in results[i])
        assert cluster1_count > 80, "Initial clustering check failed"

        # Now update the test vector to be in cluster2
        new_vec = self.generate_cluster_vector(cluster2_base, noise=0.05)
        vec_bytes = struct.pack(f'{dim}f', *new_vec)
        self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, test_key)

        # Verify the embedding was actually updated using VEMB
        emb_result = self.redis.execute_command('VEMB', self.test_key, test_key)
        updated_vec = [float(x) for x in emb_result]

        # Verify updated vector matches what we inserted (cosine similarity;
        # VEMB may return a quantized/approximate copy, hence the 0.9 bound).
        dot_product = sum(a*b for a,b in zip(updated_vec, new_vec))
        similarity = dot_product / (math.sqrt(sum(x*x for x in updated_vec)) *
                                    math.sqrt(sum(x*x for x in new_vec)))
        assert similarity > 0.9, "Vector was not properly updated"

        # Verify it's now in cluster2 using VSIM
        results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim,
                                             *[str(x) for x in cluster2_base],
                                             'COUNT', 100, 'WITHSCORES')

        # Verify our updated vector is among top results
        found = False
        for i in range(0, len(results), 2):
            if results[i].decode() == test_key:
                found = True
                similarity = float(results[i+1])
                assert similarity > 0.80, f"Updated vector has low similarity: {similarity}"
                break

        assert found, "Updated vector not found in cluster2 proximity"
diff --git a/examples/redis-unstable/modules/vector-sets/tests/persistence.py b/examples/redis-unstable/modules/vector-sets/tests/persistence.py
new file mode 100644
index 0000000..79730f4
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/persistence.py
@@ -0,0 +1,86 @@
1from test import TestCase, fill_redis_with_vectors, generate_random_vector
2import random
3
class HNSWPersistence(TestCase):
    """Verify HNSW vector sets survive an RDB save/reload cycle.

    Builds one plain set and one created with dimension reduction
    (random projection), runs the same VSIM query before and after
    DEBUG RELOAD, and requires identical result sets with matching
    scores.
    """

    def getname(self):
        return "HNSW Persistence"

    def estimated_runtime(self):
        # Two 5000-item indexes plus a DEBUG RELOAD take a while.
        return 30

    def _verify_results(self, key, dim, query_vec, reduced_dim=None):
        """Run a top-10 VSIM query against `key`; return {item: score}.

        `reduced_dim` is kept for call-site symmetry but does not change
        the query: the query vector is always given in the original
        dimension and the server applies its stored projection. (The
        previous version branched on it with two identical branches.)
        """
        k = 10
        args = ['VSIM', key, 'VALUES', dim]
        args.extend(str(x) for x in query_vec)
        args.extend(['COUNT', k, 'WITHSCORES'])
        raw = self.redis.execute_command(*args)

        # Fold the flat [name, score, name, score, ...] reply into a
        # dict. Use `item`, not `key`, so the parameter isn't shadowed.
        results_dict = {}
        for i in range(0, len(raw), 2):
            item = raw[i].decode()
            results_dict[item] = float(raw[i + 1])
        return results_dict

    def test(self):
        # Setup dimensions.
        dim = 128
        reduced_dim = 32
        count = 5000
        random.seed(42)  # deterministic vectors across runs

        # Create two datasets - one normal and one with dimension
        # reduction. Only the side effect on the server matters here,
        # so the returned reference data is not bound.
        fill_redis_with_vectors(self.redis, f"{self.test_key}:normal", count, dim)
        fill_redis_with_vectors(self.redis, f"{self.test_key}:projected",
                                count, dim, reduced_dim)

        # Generate query vectors we'll use before and after reload.
        query_vec_normal = generate_random_vector(dim)
        query_vec_projected = generate_random_vector(dim)

        # Get initial results for both sets.
        initial_normal = self._verify_results(f"{self.test_key}:normal",
                                              dim, query_vec_normal)
        initial_projected = self._verify_results(f"{self.test_key}:projected",
                                                 dim, query_vec_projected, reduced_dim)

        # Force Redis to save and reload the dataset.
        self.redis.execute_command('DEBUG', 'RELOAD')

        # Re-run the exact same queries after reload.
        reloaded_normal = self._verify_results(f"{self.test_key}:normal",
                                               dim, query_vec_normal)
        reloaded_projected = self._verify_results(f"{self.test_key}:projected",
                                                  dim, query_vec_projected, reduced_dim)

        # Normal set: same items, near-identical scores.
        assert len(initial_normal) == len(reloaded_normal), \
            "Normal vectors: Result count mismatch before/after reload"

        for key in initial_normal:
            assert key in reloaded_normal, f"Normal vectors: Missing item after reload: {key}"
            assert abs(initial_normal[key] - reloaded_normal[key]) < 0.0001, \
                f"Normal vectors: Score mismatch for {key}: " + \
                f"before={initial_normal[key]:.6f}, after={reloaded_normal[key]:.6f}"

        # Projected set: same items, near-identical scores.
        assert len(initial_projected) == len(reloaded_projected), \
            "Projected vectors: Result count mismatch before/after reload"

        for key in initial_projected:
            assert key in reloaded_projected, \
                f"Projected vectors: Missing item after reload: {key}"
            assert abs(initial_projected[key] - reloaded_projected[key]) < 0.0001, \
                f"Projected vectors: Score mismatch for {key}: " + \
                f"before={initial_projected[key]:.6f}, after={reloaded_projected[key]:.6f}"

        # Clean up both keys.
        self.redis.delete(f"{self.test_key}:normal")
        self.redis.delete(f"{self.test_key}:projected")
diff --git a/examples/redis-unstable/modules/vector-sets/tests/reduce.py b/examples/redis-unstable/modules/vector-sets/tests/reduce.py
new file mode 100644
index 0000000..e39164f
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/reduce.py
@@ -0,0 +1,71 @@
1from test import TestCase, fill_redis_with_vectors, generate_random_vector
2
class Reduce(TestCase):
    """Validate VADD's REDUCE option: a random-projection set must keep
    a useful fraction of the original nearest-neighbor structure."""

    def getname(self):
        return "Dimension Reduction"

    def estimated_runtime(self):
        return 0.2

    def test(self):
        original_dim = 100
        reduced_dim = 80
        count = 1000
        k = 50  # Number of nearest neighbors to check

        # Fill Redis with vectors using REDUCE and get reference data.
        data = fill_redis_with_vectors(self.redis, self.test_key, count, original_dim, reduced_dim)

        # Verify dimension is reduced.
        dim = self.redis.execute_command('VDIM', self.test_key)
        assert dim == reduced_dim, f"Expected dimension {reduced_dim}, got {dim}"

        # Generate query vector and get nearest neighbors using Redis.
        query_vec = generate_random_vector(original_dim)
        redis_raw = self.redis.execute_command('VSIM', self.test_key, 'VALUES',
                                               original_dim, *[str(x) for x in query_vec],
                                               'COUNT', k, 'WITHSCORES')

        # Convert the flat [name, score, ...] reply into a dict.
        redis_results = {}
        for i in range(0, len(redis_raw), 2):
            key = redis_raw[i].decode()
            score = float(redis_raw[i+1])
            redis_results[key] = score

        # Get results from linear scan with original vectors.
        linear_results = data.find_k_nearest(query_vec, k)
        linear_items = {name: score for name, score in linear_results}

        # Compare overlap between reduced and non-reduced results.
        redis_set = set(redis_results.keys())
        linear_set = set(linear_items.keys())
        overlap = len(redis_set & linear_set)
        overlap_ratio = overlap / k

        # With random projection, we expect some loss of accuracy but should
        # maintain at least some similarity structure.
        # Note that gaussian distribution is the worse with this test, so
        # in real world practice, things will be better.
        min_expected_overlap = 0.1  # At least 10% overlap in top-k

        # Print the comparison BEFORE asserting: the previous version put
        # this debug block after the assert, so it could never execute.
        if overlap_ratio < min_expected_overlap:
            print("\nLow overlap in results. Details:")
            print("\nTop results from linear scan (original vectors):")
            for name, score in linear_results:
                print(f"{name}: {score:.3f}")
            print("\nTop results from Redis (reduced vectors):")
            for item, score in sorted(redis_results.items(), key=lambda x: x[1], reverse=True):
                print(f"{item}: {score:.3f}")

        assert overlap_ratio >= min_expected_overlap, \
            f"Dimension reduction lost too much structure. Only {overlap_ratio*100:.1f}% overlap in top {k}"

        # For items that appear in both results, scores should be reasonably correlated.
        common_items = redis_set & linear_set
        for item in common_items:
            redis_score = redis_results[item]
            linear_score = linear_items[item]
            # Allow for some deviation due to dimensionality reduction.
            assert abs(redis_score - linear_score) < 0.2, \
                f"Score mismatch too high for {item}: Redis={redis_score:.3f} Linear={linear_score:.3f}"
diff --git a/examples/redis-unstable/modules/vector-sets/tests/replication.py b/examples/redis-unstable/modules/vector-sets/tests/replication.py
new file mode 100644
index 0000000..91dfdf7
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/replication.py
@@ -0,0 +1,92 @@
1from test import TestCase, generate_random_vector
2import struct
3import random
4import time
5
class ComprehensiveReplicationTest(TestCase):
    """Mixed VADD / VADD CAS / VREM workload against a primary, then
    verify the replica converged (VCARD, VDIM and DEBUG DIGEST-VALUE
    all match)."""

    def getname(self):
        return "Comprehensive Replication Test with mixed operations"

    def estimated_runtime(self):
        # This test will take longer than the default 100ms.
        return 20.0  # 20 seconds estimate

    def test(self):
        # Setup replication between primary and replica.
        assert self.setup_replication(), "Failed to setup replication"

        # Test parameters.
        num_vectors = 5000
        vector_dim = 8
        delete_probability = 0.1
        cas_probability = 0.3

        # Keep track of added items for potential deletion.
        added_items = []

        # Add vectors and occasionally delete.
        for i in range(num_vectors):
            # Generate a random vector.
            vec = generate_random_vector(vector_dim)
            vec_bytes = struct.pack(f'{vector_dim}f', *vec)
            item_name = f"{self.test_key}:item:{i}"

            # Decide whether to use CAS or not. CAS only makes sense once
            # the set is non-empty, hence the `added_items` guard.
            # (The old code also picked a random "CAS reference" item here
            # but never used it — VADD's CAS flag takes no item argument.)
            use_cas = random.random() < cas_probability

            if use_cas and added_items:
                try:
                    # Add with CAS.
                    result = self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes,
                                                        item_name, 'CAS')
                    # Only add to our list if actually added (CAS might fail).
                    if result == 1:
                        added_items.append(item_name)
                except Exception as e:
                    print(f"  CAS VADD failed: {e}")
            else:
                try:
                    # Add without CAS.
                    result = self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, item_name)
                    # Only add to our list if actually added.
                    if result == 1:
                        added_items.append(item_name)
                except Exception as e:
                    print(f"  VADD failed: {e}")

            # Randomly delete items (with 10% probability).
            if random.random() < delete_probability and added_items:
                try:
                    # Select a random item to delete.
                    item_to_delete = random.choice(added_items)
                    # Delete the item using VREM (not VDEL).
                    self.redis.execute_command('VREM', self.test_key, item_to_delete)
                    # Remove from our list.
                    added_items.remove(item_to_delete)
                except Exception as e:
                    print(f"  VREM failed: {e}")

        # Allow time for replication to complete.
        time.sleep(2.0)

        # Verify final VCARD matches.
        primary_card = self.redis.execute_command('VCARD', self.test_key)
        replica_card = self.replica.execute_command('VCARD', self.test_key)
        assert primary_card == replica_card, f"Final VCARD mismatch: primary={primary_card}, replica={replica_card}"

        # Verify VDIM matches.
        primary_dim = self.redis.execute_command('VDIM', self.test_key)
        replica_dim = self.replica.execute_command('VDIM', self.test_key)
        assert primary_dim == replica_dim, f"VDIM mismatch: primary={primary_dim}, replica={replica_dim}"

        # Verify digests match using DEBUG DIGEST.
        primary_digest = self.redis.execute_command('DEBUG', 'DIGEST-VALUE', self.test_key)
        replica_digest = self.replica.execute_command('DEBUG', 'DIGEST-VALUE', self.test_key)
        assert primary_digest == replica_digest, f"Digest mismatch: primary={primary_digest}, replica={replica_digest}"

        # Print summary.
        print(f"\n  Added and maintained {len(added_items)} vectors with dimension {vector_dim}")
        print(f"  Final vector count: {primary_card}")
        print(f"  Final digest: {primary_digest[0].decode()}")
diff --git a/examples/redis-unstable/modules/vector-sets/tests/threading_config.py b/examples/redis-unstable/modules/vector-sets/tests/threading_config.py
new file mode 100644
index 0000000..dfc931a
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/threading_config.py
@@ -0,0 +1,249 @@
1from test import TestCase, generate_random_vector
2import struct
3
4
class ThreadingConfigTest(TestCase):
    """
    Test suite for vset-force-single-threaded-execution configuration.

    This test validates the behavior of VADD and VSIM commands under different
    threading configurations. The new configuration is MUTABLE and BINARY:
    - false (0): Multi-threaded execution enabled (default)
    - true (1): Force single-threaded execution

    Key behaviors tested:
    - VADD with and without CAS option under both threading modes
    - VSIM with and without NOTHREAD option under both threading modes
    - Configuration reading, validation, and runtime modification
    - Thread behavior switching (multi-threaded vs forced single-threaded)
    """

    def getname(self):
        return "vset-force-single-threaded-execution configuration testing"

    def estimated_runtime(self):
        return 0.5  # Updated for mutable config testing with mode switching

    def get_config_value(self):
        """Get current vset-force-single-threaded-execution config value.

        Returns 'yes'/'no' on success, or None if the config cannot be
        read (e.g. unsupported server version).
        """
        try:
            result = self.redis.execute_command('CONFIG', 'GET', 'vset-force-single-threaded-execution')
            if len(result) >= 2:
                # Redis returns 'yes'/'no' for boolean configs.
                return result[1].decode() if isinstance(result[1], bytes) else result[1]
            return None
        except Exception:
            return None

    def set_config_value(self, value):
        """Set vset-force-single-threaded-execution config value.

        `value` is truthy for single-threaded, falsy for multi-threaded.
        Returns True on success.
        """
        try:
            # Convert boolean to yes/no string.
            str_value = 'yes' if value else 'no'
            result = self.redis.execute_command('CONFIG', 'SET', 'vset-force-single-threaded-execution', str_value)
            return result == b'OK' or result == 'OK'
        except Exception as e:
            print(f"Failed to set config: {e}")
            return False

    def test_config_access_and_mutability(self):
        """Test 1: Configuration access and mutability.

        Toggles the config, verifies the change took, then restores it.
        Returns True if the initial value was 'yes'.
        """
        # Get initial value.
        initial_value = self.get_config_value()
        assert initial_value is not None, "Should be able to read vset-force-single-threaded-execution config"
        assert initial_value in ['yes', 'no'], f"Config value should be yes/no, got {initial_value}"

        # Test mutability by toggling the value.
        new_value = 'no' if initial_value == 'yes' else 'yes'
        assert self.set_config_value(new_value == 'yes'), "Should be able to change config value"

        # Verify the change.
        current_value = self.get_config_value()
        assert current_value == new_value, f"Config should be {new_value}, got {current_value}"

        # Restore original value.
        assert self.set_config_value(initial_value == 'yes'), "Should be able to restore original value"

        return initial_value == 'yes'

    def test_vadd_without_cas(self, force_single_threaded=False):
        """Test 2: VADD command without CAS option."""
        # Set threading mode.
        self.set_config_value(force_single_threaded)

        # Clear test data to avoid dimension conflicts.
        self.redis.delete(self.test_key)

        dim = 64
        vec = generate_random_vector(dim)
        vec_bytes = struct.pack(f'{dim}f', *vec)

        result = self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, f'{self.test_key}:item:1')
        assert result == 1, f"VADD should return 1 for new item, got {result}"

        # Verify the vector was added.
        card = self.redis.execute_command('VCARD', self.test_key)
        assert card == 1, f"VCARD should return 1, got {card}"

    def test_vadd_with_cas(self, force_single_threaded=False):
        """Test 3: VADD command with CAS option."""
        # Set threading mode.
        self.set_config_value(force_single_threaded)

        # Clear test data to avoid dimension conflicts.
        self.redis.delete(self.test_key)

        dim = 64
        vec = generate_random_vector(dim)
        vec_bytes = struct.pack(f'{dim}f', *vec)

        # First insertion with CAS should succeed.
        result = self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, f'{self.test_key}:item:cas', 'CAS')
        assert result == 1, f"First VADD with CAS should return 1, got {result}"

        # Second insertion of same item with CAS should return 0.
        result = self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, f'{self.test_key}:item:cas', 'CAS')
        assert result == 0, f"Duplicate VADD with CAS should return 0, got {result}"

    def test_vsim_without_nothread(self, force_single_threaded=False):
        """Test 4: VSIM command without NOTHREAD."""
        # Set threading mode.
        self.set_config_value(force_single_threaded)

        # Clear test data to avoid dimension conflicts.
        self.redis.delete(self.test_key)

        dim = 64

        # Add test vectors.
        for i in range(5):
            vec = generate_random_vector(dim)
            vec_bytes = struct.pack(f'{dim}f', *vec)
            self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, f'{self.test_key}:item:{i}')

        # Test VSIM without NOTHREAD.
        query_vec = generate_random_vector(dim)
        args = ['VSIM', self.test_key, 'VALUES', dim] + [str(x) for x in query_vec] + ['COUNT', 3]
        result = self.redis.execute_command(*args)

        assert isinstance(result, list), f"VSIM should return a list, got {type(result)}"
        assert len(result) <= 3, f"VSIM should return at most 3 results, got {len(result)}"

    def test_vsim_with_nothread(self, force_single_threaded=False):
        """Test 5: VSIM command with NOTHREAD."""
        # Set threading mode.
        self.set_config_value(force_single_threaded)

        dim = 64

        # Ensure we have vectors to search (use existing vectors from previous test).
        card = self.redis.execute_command('VCARD', self.test_key)
        if card == 0:
            # Add test vectors if none exist.
            for i in range(5):
                vec = generate_random_vector(dim)
                vec_bytes = struct.pack(f'{dim}f', *vec)
                self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, f'{self.test_key}:item:{i}')

        # Test VSIM with NOTHREAD.
        query_vec = generate_random_vector(dim)
        args = ['VSIM', self.test_key, 'VALUES', dim] + [str(x) for x in query_vec] + ['COUNT', 3, 'NOTHREAD']
        result = self.redis.execute_command(*args)

        assert isinstance(result, list), f"VSIM with NOTHREAD should return a list, got {type(result)}"
        assert len(result) <= 3, f"VSIM with NOTHREAD should return at most 3 results, got {len(result)}"

    def test_threading_mode_comparison(self):
        """Test 6: Compare behavior between threading modes.

        Runs the same VADD workload under both modes and requires the
        final cardinality to agree.
        """
        # Clear test data. (An unused `dim = 64` local was removed here:
        # the helper methods define their own dimension.)
        self.redis.delete(self.test_key)

        # Test multi-threaded mode (default).
        self.set_config_value(False)  # Multi-threaded
        self.test_vadd_without_cas(False)
        self.test_vadd_with_cas(False)
        multi_threaded_card = self.redis.execute_command('VCARD', self.test_key)

        # Clear and test single-threaded mode.
        self.redis.delete(self.test_key)
        self.set_config_value(True)  # Single-threaded
        self.test_vadd_without_cas(True)
        self.test_vadd_with_cas(True)
        single_threaded_card = self.redis.execute_command('VCARD', self.test_key)

        # Both modes should produce same results.
        assert multi_threaded_card == single_threaded_card, \
            f"Both modes should produce same results: multi={multi_threaded_card}, single={single_threaded_card}"

    def test_nothread_override_behavior(self):
        """Test 7: NOTHREAD option should work regardless of config."""
        dim = 64

        # Test with both config modes.
        for force_single in [False, True]:
            self.set_config_value(force_single)
            self.redis.delete(self.test_key)

            # Add test vectors.
            for i in range(3):
                vec = generate_random_vector(dim)
                vec_bytes = struct.pack(f'{dim}f', *vec)
                self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, f'{self.test_key}:item:{i}')

            # NOTHREAD should work regardless of config.
            query_vec = generate_random_vector(dim)
            args = ['VSIM', self.test_key, 'VALUES', dim] + [str(x) for x in query_vec] + ['COUNT', 2, 'NOTHREAD']
            result = self.redis.execute_command(*args)

            assert isinstance(result, list), f"NOTHREAD should work with force_single={force_single}"
            assert len(result) <= 2, f"NOTHREAD should return ≤2 results with force_single={force_single}"

    def test(self):
        """Main test method - runs all threading configuration tests."""
        # Get initial configuration.
        initial_force_single = self.test_config_access_and_mutability()
        print(f"Initial vset-force-single-threaded-execution: {'yes' if initial_force_single else 'no'}")

        # Clear test data.
        self.redis.delete(self.test_key)

        # Test both threading modes.
        print("Testing multi-threaded mode...")
        self.set_config_value(False)
        self.test_vadd_without_cas(False)
        self.test_vadd_with_cas(False)
        self.test_vsim_without_nothread(False)
        self.test_vsim_with_nothread(False)

        print("Testing single-threaded mode...")
        self.set_config_value(True)
        self.test_vadd_without_cas(True)
        self.test_vadd_with_cas(True)
        self.test_vsim_without_nothread(True)
        self.test_vsim_with_nothread(True)

        # Test mode comparison and NOTHREAD override.
        self.test_threading_mode_comparison()
        self.test_nothread_override_behavior()

        # Restore initial configuration.
        self.set_config_value(initial_force_single)

        # Print summary.
        self._print_test_summary(initial_force_single)

    def _print_test_summary(self, initial_force_single):
        """Print a summary of what was tested"""
        print(f"\nThreading Configuration Test Summary:")
        print(f"  Configuration: vset-force-single-threaded-execution")
        print(f"  Type: Boolean, Mutable")
        print(f"  Initial value: {'yes' if initial_force_single else 'no'}")
        print(f"  Tested modes: Both multi-threaded (no) and single-threaded (yes)")
        print(f"  VADD: Works correctly in both modes")
        print(f"  VADD with CAS: Works correctly in both modes")
        print(f"  VSIM: Works correctly in both modes")
        print(f"  NOTHREAD option: Overrides config in both modes")
        print(f"  Configuration mutability: ✅ Successfully changed at runtime")
        print(f"  All tests passed successfully!")
diff --git a/examples/redis-unstable/modules/vector-sets/tests/vadd_cas.py b/examples/redis-unstable/modules/vector-sets/tests/vadd_cas.py
new file mode 100644
index 0000000..3cb3508
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/vadd_cas.py
@@ -0,0 +1,98 @@
1from test import TestCase, generate_random_vector
2import threading
3import struct
4import math
5import time
6import random
7from typing import List, Dict
8
class ConcurrentCASTest(TestCase):
    """Insert many vectors concurrently with VADD ... CAS and verify that
    every element landed in the set with the expected embedding."""

    def getname(self):
        return "Concurrent VADD with CAS"

    def estimated_runtime(self):
        return 1.5

    def worker(self, vectors: List[List[float]], start_idx: int, end_idx: int,
               dim: int, results: Dict[str, bool]):
        """Worker thread that adds a subset of vectors using VADD CAS"""
        for idx in range(start_idx, end_idx):
            name = f"{self.test_key}:item:{idx}"
            payload = struct.pack(f'{dim}f', *vectors[idx])
            try:
                reply = self.redis.execute_command('VADD', self.test_key, 'FP32',
                                                   payload, name, 'CAS')
                # Record whether the element was genuinely inserted.
                results[name] = (reply == 1)
            except Exception as e:
                results[name] = False
                print(f"Error adding {name}: {e}")

    def verify_vector_similarity(self, vec1: List[float], vec2: List[float]) -> float:
        """Calculate cosine similarity between two vectors"""
        dot = 0.0
        for a, b in zip(vec1, vec2):
            dot += a * b
        norm1 = math.sqrt(sum(x * x for x in vec1))
        norm2 = math.sqrt(sum(x * x for x in vec2))
        if norm1 > 0 and norm2 > 0:
            return dot / (norm1 * norm2)
        return 0

    def test(self):
        # Workload parameters.
        dim = 128
        total_vectors = 5000
        num_threads = 8
        per_thread = total_vectors // num_threads

        # Pre-generate every vector so workers only talk to Redis.
        random.seed(42)  # For reproducibility
        vectors = [generate_random_vector(dim) for _ in range(total_vectors)]

        # Shared success map; each worker writes disjoint keys.
        results = {}
        threads = []

        # Partition [0, total_vectors) across the workers; the last one
        # absorbs the remainder.
        for t in range(num_threads):
            lo = t * per_thread
            hi = lo + per_thread if t < num_threads - 1 else total_vectors
            th = threading.Thread(target=self.worker,
                                  args=(vectors, lo, hi, dim, results))
            threads.append(th)
            th.start()

        for th in threads:
            th.join()

        # Every element must be present.
        card = self.redis.execute_command('VCARD', self.test_key)
        assert card == total_vectors, \
            f"Expected {total_vectors} elements, but found {card}"

        # Check each stored embedding against its source vector.
        num_verified = 0
        for i in range(total_vectors):
            name = f"{self.test_key}:item:{i}"

            assert results[name], f"Vector {name} was not successfully added"

            raw = self.redis.execute_command('VEMB', self.test_key, name)
            stored_vec = [float(x) for x in raw]

            assert len(stored_vec) == dim, \
                f"Stored vector dimension mismatch for {name}: {len(stored_vec)} != {dim}"

            # Quantization allows a little drift; cosine must stay high.
            similarity = self.verify_vector_similarity(vectors[i], stored_vec)
            assert similarity > 0.99, \
                f"Low similarity ({similarity}) for {name}"

            num_verified += 1

        assert num_verified == total_vectors, \
            f"Only verified {num_verified} out of {total_vectors} vectors"
diff --git a/examples/redis-unstable/modules/vector-sets/tests/vemb.py b/examples/redis-unstable/modules/vector-sets/tests/vemb.py
new file mode 100644
index 0000000..0f4cf77
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/vemb.py
@@ -0,0 +1,41 @@
1from test import TestCase
2import struct
3import math
4
class VEMB(TestCase):
    """Check that VEMB returns the stored embedding for vectors added in
    both the FP32 and the VALUES input formats."""

    def getname(self):
        return "VEMB Command"

    def test(self):
        dim = 4

        # Normalize a unit-axis vector so the stored form is well defined.
        raw = [1, 0, 0, 0]
        norm = math.sqrt(sum(x * x for x in raw))
        vec = [x / norm for x in raw]

        # Insert the same vector once as an FP32 blob...
        packed = struct.pack(f'{dim}f', *vec)
        self.redis.execute_command('VADD', self.test_key, 'FP32', packed, f'{self.test_key}:item:1')

        # ...and once via the VALUES textual form.
        self.redis.execute_command('VADD', self.test_key, 'VALUES', dim,
                                   *[str(x) for x in vec], f'{self.test_key}:item:2')

        # Read both back with VEMB.
        result1 = self.redis.execute_command('VEMB', self.test_key, f'{self.test_key}:item:1')
        result2 = self.redis.execute_command('VEMB', self.test_key, f'{self.test_key}:item:2')
        retrieved_vec1 = [float(x) for x in result1]
        retrieved_vec2 = [float(x) for x in result2]

        # Both copies must match the original within quantization error.
        for label, got_vec in (("FP32", retrieved_vec1), ("VALUES", retrieved_vec2)):
            for i, (expected, got) in enumerate(zip(vec, got_vec)):
                assert abs(expected - got) < 0.01, \
                    f"{label} vector component {i} mismatch: expected {expected}, got {got}"

        # A missing element must yield nil.
        result = self.redis.execute_command('VEMB', self.test_key, 'nonexistent')
        assert result is None, "Non-existent item should return nil"
diff --git a/examples/redis-unstable/modules/vector-sets/tests/vismember.py b/examples/redis-unstable/modules/vector-sets/tests/vismember.py
new file mode 100644
index 0000000..eabebca
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/vismember.py
@@ -0,0 +1,47 @@
1from test import TestCase, generate_random_vector
2import struct
3
class BasicVISMEMBER(TestCase):
    """VISMEMBER behavior: membership of stored elements, unknown
    elements, unknown keys, and elements removed with VREM."""

    def getname(self):
        return "VISMEMBER basic functionality"

    def test(self):
        # Element names used throughout the test.
        item1 = f'{self.test_key}:item:1'
        item2 = f'{self.test_key}:item:2'
        nonexistent_item = f'{self.test_key}:item:nonexistent'

        # Store two random 4-d vectors under the two names.
        for name in (item1, item2):
            payload = struct.pack('4f', *generate_random_vector(4))
            self.redis.execute_command('VADD', self.test_key, 'FP32', payload, name)

        # Both stored elements must be reported as members.
        result1 = self.redis.execute_command('VISMEMBER', self.test_key, item1)
        assert result1 == 1, f"VISMEMBER should return 1 for existing item, got {result1}"

        result2 = self.redis.execute_command('VISMEMBER', self.test_key, item2)
        assert result2 == 1, f"VISMEMBER should return 1 for existing item, got {result2}"

        # An unknown element yields 0.
        result3 = self.redis.execute_command('VISMEMBER', self.test_key, nonexistent_item)
        assert result3 == 0, f"VISMEMBER should return 0 for non-existent item, got {result3}"

        # A missing key also yields 0.
        nonexistent_key = f'{self.test_key}_nonexistent'
        result4 = self.redis.execute_command('VISMEMBER', nonexistent_key, item1)
        assert result4 == 0, f"VISMEMBER should return 0 for non-existent key, got {result4}"

        # After VREM the removed element is no longer a member...
        self.redis.execute_command('VREM', self.test_key, item1)
        result5 = self.redis.execute_command('VISMEMBER', self.test_key, item1)
        assert result5 == 0, f"VISMEMBER should return 0 after element removal, got {result5}"

        # ...while the other element still is.
        result6 = self.redis.execute_command('VISMEMBER', self.test_key, item2)
        assert result6 == 1, f"VISMEMBER should still return 1 for remaining item, got {result6}"
diff --git a/examples/redis-unstable/modules/vector-sets/tests/vrand-ping-pong.py b/examples/redis-unstable/modules/vector-sets/tests/vrand-ping-pong.py
new file mode 100644
index 0000000..99d2e9a
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/vrand-ping-pong.py
@@ -0,0 +1,35 @@
1from test import TestCase, generate_random_vector
2import struct
3
class VRANDMEMBERPingPongRegressionTest(TestCase):
    def getname(self):
        return "[regression] VRANDMEMBER ping-pong"

    def test(self):
        """
        Regression check: with exactly two elements in the set,
        repeated VRANDMEMBER calls must eventually return both of
        them rather than getting stuck on one ("ping-pong" issue).
        """
        self.redis.delete(self.test_key)  # Start from a clean key
        dim = 4

        # Insert exactly two random vectors.
        for member in ("vec1", "vec2"):
            coords = generate_random_vector(dim)
            self.redis.execute_command('VADD', self.test_key, 'VALUES', dim,
                                       *coords, member)

        # Sample many times and collect the distinct names returned.
        iterations = 100
        seen = set()
        for _ in range(iterations):
            reply = self.redis.execute_command('VRANDMEMBER', self.test_key)
            seen.add(reply.decode())

        # Both members must have appeared, proving it's not stuck.
        assert len(seen) == 2, f"Ping-pong test failed: should have returned 2 unique members, but got {len(seen)}."
diff --git a/examples/redis-unstable/modules/vector-sets/tests/vrandmember.py b/examples/redis-unstable/modules/vector-sets/tests/vrandmember.py
new file mode 100644
index 0000000..ca9e006
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/vrandmember.py
@@ -0,0 +1,55 @@
1from test import TestCase, generate_random_vector, fill_redis_with_vectors
2import struct
3
class VRANDMEMBERTest(TestCase):
    """VRANDMEMBER behavior: missing keys, positive and negative counts,
    over-sized requests, and the count=0 edge case."""

    def getname(self):
        return "VRANDMEMBER basic functionality"

    def test(self):
        # Missing key: single form is NULL, count form is an empty array.
        single = self.redis.execute_command('VRANDMEMBER', self.test_key)
        assert single is None, "VRANDMEMBER on non-existent key should return NULL"

        batch = self.redis.execute_command('VRANDMEMBER', self.test_key, 5)
        assert isinstance(batch, list) and len(batch) == 0, "VRANDMEMBER with count on non-existent key should return empty array"

        # Populate the set.
        dim = 4
        count = 100
        data = fill_redis_with_vectors(self.redis, self.test_key, count, dim)

        # Single random member belongs to the set.
        single = self.redis.execute_command('VRANDMEMBER', self.test_key)
        assert single is not None, "VRANDMEMBER should return a random member"
        assert single.decode() in data.names, "Random member should be in the set"

        # Positive count: exactly that many unique members.
        positive_count = 10
        batch = self.redis.execute_command('VRANDMEMBER', self.test_key, positive_count)
        assert isinstance(batch, list), "VRANDMEMBER with positive count should return an array"
        assert len(batch) == positive_count, f"Should return {positive_count} members"

        names = [entry.decode() for entry in batch]
        assert len(names) == len(set(names)), "Results should be unique with positive count"
        for member in names:
            assert member in data.names, "All returned items should be in the set"

        # Asking for more members than exist caps at the cardinality.
        batch = self.redis.execute_command('VRANDMEMBER', self.test_key, count + 10)
        assert len(batch) == count, "Should return only the available members when asking for more than exist"

        # Negative count: duplicates allowed, exactly |count| entries.
        negative_count = -20
        batch = self.redis.execute_command('VRANDMEMBER', self.test_key, negative_count)
        assert isinstance(batch, list), "VRANDMEMBER with negative count should return an array"
        assert len(batch) == abs(negative_count), f"Should return {abs(negative_count)} members"

        # Each returned entry must still be a real set member.
        for member in (entry.decode() for entry in batch):
            assert member in data.names, "All returned items should be in the set"

        # count == 0 yields an empty array.
        batch = self.redis.execute_command('VRANDMEMBER', self.test_key, 0)
        assert isinstance(batch, list) and len(batch) == 0, "VRANDMEMBER with count=0 should return empty array"
diff --git a/examples/redis-unstable/modules/vector-sets/tests/vrange.py b/examples/redis-unstable/modules/vector-sets/tests/vrange.py
new file mode 100644
index 0000000..7e57588
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/vrange.py
@@ -0,0 +1,113 @@
1from test import TestCase, generate_random_vector
2import struct
3
4class BasicVRANGE(TestCase):
5 def getname(self):
6 return "VRANGE basic functionality and iteration"
7
8 def test(self):
9 # Add multiple elements with different names for lexicographical ordering
10 elements = [
11 "apple", "apricot", "banana", "cherry", "date",
12 "elderberry", "fig", "grape", "honeydew", "kiwi",
13 "lemon", "mango", "nectarine", "orange", "papaya",
14 "quince", "raspberry", "strawberry", "tangerine", "watermelon"
15 ]
16
17 # Add all elements to the vector set
18 for elem in elements:
19 vec = generate_random_vector(4)
20 vec_bytes = struct.pack('4f', *vec)
21 self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, elem)
22
23 # Test 1: Basic range with inclusive boundaries
24 result = self.redis.execute_command('VRANGE', self.test_key, '[apple', '[grape', '5')
25 result = [r.decode() for r in result]
26 assert result == ['apple', 'apricot', 'banana', 'cherry', 'date'], f"Expected first 5 elements from apple, got {result}"
27
28 # Test 2: Exclusive start boundary
29 result = self.redis.execute_command('VRANGE', self.test_key, '(apple', '[cherry', '10')
30 result = [r.decode() for r in result]
31 assert result == ['apricot', 'banana', 'cherry'], f"Expected elements after apple up to cherry inclusive, got {result}"
32
33 # Test 3: Exclusive end boundary
34 result = self.redis.execute_command('VRANGE', self.test_key, '[banana', '(cherry', '10')
35 result = [r.decode() for r in result]
36 assert result == ['banana'], f"Expected only banana (cherry excluded), got {result}"
37
38 # Test 4: Using '-' for minimum element
39 result = self.redis.execute_command('VRANGE', self.test_key, '-', '[banana', '10')
40 result = [r.decode() for r in result]
41 assert result[0] == 'apple', "Should start from the first element"
42 assert result[-1] == 'banana', "Should end at banana"
43
44 # Test 5: Using '+' for maximum element
45 result = self.redis.execute_command('VRANGE', self.test_key, '[raspberry', '+', '10')
46 result = [r.decode() for r in result]
47 assert 'raspberry' in result and 'strawberry' in result and 'tangerine' in result and 'watermelon' in result, "Should include all elements from raspberry onwards"
48
49 # Test 6: Full range with '-' and '+'
50 result = self.redis.execute_command('VRANGE', self.test_key, '-', '+', '100')
51 result = [r.decode() for r in result]
52 assert len(result) == len(elements), f"Should return all {len(elements)} elements"
53 assert result == sorted(elements), "Elements should be in lexicographical order"
54
55 # Test 7: Iterator pattern - verify each element appears exactly once
56 seen = set()
57 batch_size = 3
58 current = '-'
59
60 while True:
61 if current == '-':
62 # First iteration
63 result = self.redis.execute_command('VRANGE', self.test_key, '-', '+', str(batch_size))
64 else:
65 # Subsequent iterations - exclusive start from last element
66 result = self.redis.execute_command('VRANGE', self.test_key, f'({current}', '+', str(batch_size))
67
68 result = [r.decode() for r in result]
69
70 if not result:
71 break
72
73 # Check no duplicates in this batch
74 for elem in result:
75 assert elem not in seen, f"Element {elem} appeared more than once"
76 seen.add(elem)
77
78 # Update current to last element
79 current = result[-1]
80
81 # Break if we got less than requested (end of set)
82 if len(result) < batch_size:
83 break
84
85 # Verify we saw all elements exactly once
86 assert seen == set(elements), f"Iterator should visit all elements exactly once. Missing: {set(elements) - seen}, Extra: {seen - set(elements)}"
87
88 # Test 8: Count of 0 returns empty array
89 result = self.redis.execute_command('VRANGE', self.test_key, '-', '+', '0')
90 assert result == [], f"Count of 0 should return empty array, got {result}"
91
92 # Test 9: Range with no matching elements
93 result = self.redis.execute_command('VRANGE', self.test_key, '[zebra', '+', '10')
94 assert result == [], f"Range beyond all elements should return empty array, got {result}"
95
96 # Test 10: Non-existent key
97 result = self.redis.execute_command('VRANGE', 'nonexistent_key', '-', '+', '10')
98 assert result == [], f"Non-existent key should return empty array, got {result}"
99
100 # Test 11: Partial word boundaries
101 result = self.redis.execute_command('VRANGE', self.test_key, '[app', '[apr', '10')
102 result = [r.decode() for r in result]
103 assert 'apple' in result, "Should include 'apple' which starts with 'app'"
104 assert 'apricot' not in result, "Should not include 'apricot' as it's >= 'apr'"
105
106 # Test 12: Single element range
107 result = self.redis.execute_command('VRANGE', self.test_key, '[cherry', '[cherry', '10')
108 result = [r.decode() for r in result]
109 assert result == ['cherry'], f"Inclusive single element range should return that element, got {result}"
110
111 # Test 13: Empty range (start > end)
112 result = self.redis.execute_command('VRANGE', self.test_key, '[grape', '[apple', '10')
113 assert result == [], f"Range where start > end should return empty array, got {result}"
diff --git a/examples/redis-unstable/modules/vector-sets/tests/vsim_limit_efsearch.py b/examples/redis-unstable/modules/vector-sets/tests/vsim_limit_efsearch.py
new file mode 100644
index 0000000..25b9689
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/vsim_limit_efsearch.py
@@ -0,0 +1,32 @@
1from test import TestCase, generate_random_vector
2import struct
3
4class VSIMLimitEFSearch(TestCase):
5 def getname(self):
6 return "VSIM Limit EF Search"
7
8 def estimated_runtime(self):
9 return 0.2
10
11 def test(self):
12 dim = 32
13 vec = generate_random_vector(dim)
14 vec_bytes = struct.pack(f'{dim}f', *vec)
15
16 # Add test vector
17 self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, f'{self.test_key}:item:1')
18
19 query_vec = generate_random_vector(dim)
20
21 # Test EF upper bound (should accept 1000000)
22 result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim,
23 *[str(x) for x in query_vec], 'EF', 1000000)
24 assert isinstance(result, list), "EF=1000000 should be accepted"
25
26 # Test EF over limit (should reject > 1000000)
27 try:
28 self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim,
29 *[str(x) for x in query_vec], 'EF', 1000001)
30 assert False, "EF=1000001 should be rejected"
31 except Exception as e:
32 assert "invalid EF" in str(e), f"Expected EF validation error, got: {e}"
diff --git a/examples/redis-unstable/modules/vector-sets/tests/with.py b/examples/redis-unstable/modules/vector-sets/tests/with.py
new file mode 100644
index 0000000..d14a23f
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/with.py
@@ -0,0 +1,214 @@
1from test import TestCase, generate_random_vector
2import struct
3import json
4import random
5
6class VSIMWithAttribs(TestCase):
7 def getname(self):
8 return "VSIM WITHATTRIBS/WITHSCORES functionality testing"
9
10 def setup(self):
11 super().setup()
12 self.dim = 8
13 self.count = 20
14
15 # Create vectors with attributes
16 for i in range(self.count):
17 vec = generate_random_vector(self.dim)
18 vec_bytes = struct.pack(f'{self.dim}f', *vec)
19
20 # Item name
21 name = f"{self.test_key}:item:{i}"
22
23 # Add to Redis
24 self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, name)
25
26 # Create and add attribute
27 if i % 5 == 0:
28 # Every 5th item has no attribute (for testing NULL responses)
29 continue
30
31 category = random.choice(["electronics", "furniture", "clothing"])
32 price = random.randint(50, 1000)
33 attrs = {"category": category, "price": price, "id": i}
34
35 self.redis.execute_command('VSETATTR', self.test_key, name, json.dumps(attrs))
36
37 def is_numeric(self, value):
38 """Check if a value can be converted to float"""
39 try:
40 if isinstance(value, (int, float)):
41 return True
42 if isinstance(value, bytes):
43 float(value.decode('utf-8'))
44 return True
45 if isinstance(value, str):
46 float(value)
47 return True
48 return False
49 except (ValueError, TypeError):
50 return False
51
52 def test(self):
53 # Create query vector
54 query_vec = generate_random_vector(self.dim)
55
56 # Test 1: VSIM with no additional options (should be same for RESP2 and RESP3)
57 cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim]
58 cmd_args.extend([str(x) for x in query_vec])
59 cmd_args.extend(['COUNT', 5])
60
61 results_resp2 = self.redis.execute_command(*cmd_args)
62 results_resp3 = self.redis3.execute_command(*cmd_args)
63
64 # Both should return simple arrays of item names
65 assert len(results_resp2) == 5, f"RESP2: Expected 5 results, got {len(results_resp2)}"
66 assert len(results_resp3) == 5, f"RESP3: Expected 5 results, got {len(results_resp3)}"
67 assert all(isinstance(item, bytes) for item in results_resp2), "RESP2: Results should be byte strings"
68 assert all(isinstance(item, bytes) for item in results_resp3), "RESP3: Results should be byte strings"
69
70 # Test 2: VSIM with WITHSCORES only
71 cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim]
72 cmd_args.extend([str(x) for x in query_vec])
73 cmd_args.extend(['COUNT', 5, 'WITHSCORES'])
74
75 results_resp2 = self.redis.execute_command(*cmd_args)
76 results_resp3 = self.redis3.execute_command(*cmd_args)
77
78 # RESP2: Should be a flat array alternating item, score
79 assert len(results_resp2) == 10, f"RESP2: Expected 10 elements (5 items × 2), got {len(results_resp2)}"
80 for i in range(0, len(results_resp2), 2):
81 assert isinstance(results_resp2[i], bytes), f"RESP2: Item at {i} should be bytes"
82 assert self.is_numeric(results_resp2[i+1]), f"RESP2: Score at {i+1} should be numeric"
83 score = float(results_resp2[i+1]) if isinstance(results_resp2[i+1], bytes) else results_resp2[i+1]
84 assert 0 <= score <= 1, f"RESP2: Score {score} should be between 0 and 1"
85
86 # RESP3: Should be a dict/map with items as keys and scores as DIRECT values (not arrays)
87 assert isinstance(results_resp3, dict), f"RESP3: Expected dict, got {type(results_resp3)}"
88 assert len(results_resp3) == 5, f"RESP3: Expected 5 entries, got {len(results_resp3)}"
89 for item, score in results_resp3.items():
90 assert isinstance(item, bytes), f"RESP3: Key should be bytes"
91 # Score should be a direct value, NOT an array
92 assert not isinstance(score, list), f"RESP3: With single WITH option, value should not be array"
93 assert self.is_numeric(score), f"RESP3: Score should be numeric, got {type(score)}"
94 score_val = float(score) if isinstance(score, bytes) else score
95 assert 0 <= score_val <= 1, f"RESP3: Score {score_val} should be between 0 and 1"
96
97 # Test 3: VSIM with WITHATTRIBS only
98 cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim]
99 cmd_args.extend([str(x) for x in query_vec])
100 cmd_args.extend(['COUNT', 5, 'WITHATTRIBS'])
101
102 results_resp2 = self.redis.execute_command(*cmd_args)
103 results_resp3 = self.redis3.execute_command(*cmd_args)
104
105 # RESP2: Should be a flat array alternating item, attribute
106 assert len(results_resp2) == 10, f"RESP2: Expected 10 elements (5 items × 2), got {len(results_resp2)}"
107 for i in range(0, len(results_resp2), 2):
108 assert isinstance(results_resp2[i], bytes), f"RESP2: Item at {i} should be bytes"
109 attr = results_resp2[i+1]
110 assert attr is None or isinstance(attr, bytes), f"RESP2: Attribute at {i+1} should be None or bytes"
111 if attr is not None:
112 # Verify it's valid JSON
113 json.loads(attr)
114
115 # RESP3: Should be a dict/map with items as keys and attributes as DIRECT values (not arrays)
116 assert isinstance(results_resp3, dict), f"RESP3: Expected dict, got {type(results_resp3)}"
117 assert len(results_resp3) == 5, f"RESP3: Expected 5 entries, got {len(results_resp3)}"
118 for item, attr in results_resp3.items():
119 assert isinstance(item, bytes), f"RESP3: Key should be bytes"
120 # Attribute should be a direct value, NOT an array
121 assert not isinstance(attr, list), f"RESP3: With single WITH option, value should not be array"
122 assert attr is None or isinstance(attr, bytes), f"RESP3: Attribute should be None or bytes"
123 if attr is not None:
124 # Verify it's valid JSON
125 json.loads(attr)
126
127 # Test 4: VSIM with both WITHSCORES and WITHATTRIBS
128 cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim]
129 cmd_args.extend([str(x) for x in query_vec])
130 cmd_args.extend(['COUNT', 5, 'WITHSCORES', 'WITHATTRIBS'])
131
132 results_resp2 = self.redis.execute_command(*cmd_args)
133 results_resp3 = self.redis3.execute_command(*cmd_args)
134
135 # RESP2: Should be a flat array with pattern: item, score, attribute
136 assert len(results_resp2) == 15, f"RESP2: Expected 15 elements (5 items × 3), got {len(results_resp2)}"
137 for i in range(0, len(results_resp2), 3):
138 assert isinstance(results_resp2[i], bytes), f"RESP2: Item at {i} should be bytes"
139 assert self.is_numeric(results_resp2[i+1]), f"RESP2: Score at {i+1} should be numeric"
140 score = float(results_resp2[i+1]) if isinstance(results_resp2[i+1], bytes) else results_resp2[i+1]
141 assert 0 <= score <= 1, f"RESP2: Score {score} should be between 0 and 1"
142 attr = results_resp2[i+2]
143 assert attr is None or isinstance(attr, bytes), f"RESP2: Attribute at {i+2} should be None or bytes"
144
145 # RESP3: Should be a dict where each value is a 2-element array [score, attribute]
146 assert isinstance(results_resp3, dict), f"RESP3: Expected dict, got {type(results_resp3)}"
147 assert len(results_resp3) == 5, f"RESP3: Expected 5 entries, got {len(results_resp3)}"
148 for item, value in results_resp3.items():
149 assert isinstance(item, bytes), f"RESP3: Key should be bytes"
150 # With BOTH options, value MUST be an array
151 assert isinstance(value, list), f"RESP3: With both WITH options, value should be a list, got {type(value)}"
152 assert len(value) == 2, f"RESP3: Value should have 2 elements [score, attr], got {len(value)}"
153
154 score, attr = value
155 assert self.is_numeric(score), f"RESP3: Score should be numeric"
156 score_val = float(score) if isinstance(score, bytes) else score
157 assert 0 <= score_val <= 1, f"RESP3: Score {score_val} should be between 0 and 1"
158 assert attr is None or isinstance(attr, bytes), f"RESP3: Attribute should be None or bytes"
159
160 # Test 5: Verify consistency - same items returned in same order
161 cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim]
162 cmd_args.extend([str(x) for x in query_vec])
163 cmd_args.extend(['COUNT', 5, 'WITHSCORES', 'WITHATTRIBS'])
164
165 results_resp2 = self.redis.execute_command(*cmd_args)
166 results_resp3 = self.redis3.execute_command(*cmd_args)
167
168 # Extract items from RESP2 (every 3rd element starting from 0)
169 items_resp2 = [results_resp2[i] for i in range(0, len(results_resp2), 3)]
170
171 # Extract items from RESP3 (keys of the dict)
172 items_resp3 = list(results_resp3.keys())
173
174 # Verify same items returned
175 assert set(items_resp2) == set(items_resp3), "RESP2 and RESP3 should return the same items"
176
177 # Build a mapping from items to scores and attributes for comparison
178 data_resp2 = {}
179 for i in range(0, len(results_resp2), 3):
180 item = results_resp2[i]
181 score = float(results_resp2[i+1]) if isinstance(results_resp2[i+1], bytes) else results_resp2[i+1]
182 attr = results_resp2[i+2]
183 data_resp2[item] = (score, attr)
184
185 data_resp3 = {}
186 for item, value in results_resp3.items():
187 score = float(value[0]) if isinstance(value[0], bytes) else value[0]
188 attr = value[1]
189 data_resp3[item] = (score, attr)
190
191 # Verify scores and attributes match for each item
192 for item in data_resp2:
193 score_resp2, attr_resp2 = data_resp2[item]
194 score_resp3, attr_resp3 = data_resp3[item]
195
196 assert abs(score_resp2 - score_resp3) < 0.0001, \
197 f"Scores for {item} don't match: RESP2={score_resp2}, RESP3={score_resp3}"
198 assert attr_resp2 == attr_resp3, \
199 f"Attributes for {item} don't match: RESP2={attr_resp2}, RESP3={attr_resp3}"
200
201 # Test 6: Test ordering of WITHSCORES and WITHATTRIBS doesn't matter
202 cmd_args1 = ['VSIM', self.test_key, 'VALUES', self.dim]
203 cmd_args1.extend([str(x) for x in query_vec])
204 cmd_args1.extend(['COUNT', 3, 'WITHSCORES', 'WITHATTRIBS'])
205
206 cmd_args2 = ['VSIM', self.test_key, 'VALUES', self.dim]
207 cmd_args2.extend([str(x) for x in query_vec])
208 cmd_args2.extend(['COUNT', 3, 'WITHATTRIBS', 'WITHSCORES']) # Reversed order
209
210 results1_resp3 = self.redis3.execute_command(*cmd_args1)
211 results2_resp3 = self.redis3.execute_command(*cmd_args2)
212
213 # Both should return the same structure
214 assert results1_resp3 == results2_resp3, "Order of WITH options shouldn't matter"
diff --git a/examples/redis-unstable/modules/vector-sets/vset.c b/examples/redis-unstable/modules/vector-sets/vset.c
new file mode 100644
index 0000000..500f8e9
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/vset.c
@@ -0,0 +1,2587 @@
1/* Redis implementation for vector sets. The data structure itself
2 * is implemented in hnsw.c.
3 *
4 * Copyright (c) 2009-Present, Redis Ltd.
5 * All rights reserved.
6 *
7 * Licensed under your choice of (a) the Redis Source Available License 2.0
8 * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
9 * GNU Affero General Public License v3 (AGPLv3).
10 * Originally authored by: Salvatore Sanfilippo.
11 *
12 * ======================== Understand threading model =========================
13 * This code implements threaded operations for two of the commands:
14 *
15 * 1. VSIM, by default.
16 * 2. VADD, if the CAS option is specified.
17 *
18 * Note that even if the second operation, VADD, is a write operation, only
19 * the neighbors collection for the new node is performed in a thread: then,
20 * the actual insert is performed in the reply callback VADD_CASReply(),
21 * which is executed in the main thread.
22 *
23 * Threaded operations need us to protect various operations with mutexes,
24 * even if a certain degree of protection is already provided by the HNSW
25 * library. Here are a few very important things about this implementation
26 * and the way locking is performed.
27 *
28 * 1. All the write operations are performed in the main Redis thread:
29 * this also includes the VADD_CASReply() callback, which is called by Redis
30 * internals only in the context of the main thread. However the HNSW
31 * library allows background threads in hnsw_search() (VSIM) to modify
32 * nodes metadata to speedup search (to understand if a node was already
33 * visited), but this only happens after acquiring a specific lock
34 * for a given "read slot".
35 *
36 * 2. We use a global lock for each Vector Set object, called "in_use". This
37 * lock is a read-write lock, and is acquired in read mode by all the
38 * threads that perform reads in the background. It is only acquired in
39 * write mode by vectorSetWaitAllBackgroundClients(): the function acquires
40 * the lock and immediately releases it, with the effect of waiting for all
41 * the background threads still running to end their execution.
42 *
43 * Note that no thread can be spawned, since we only call
44 * vectorSetWaitAllBackgroundClients() from the main Redis thread, that
45 * is also the only thread spawning other threads.
46 *
47 * vectorSetWaitAllBackgroundClients() is used in two ways:
48 * A) When we need to delete a vector set because of (DEL) or other
49 * operations destroying the object, we need to wait until all the
50 * background threads working with this object have finished their work.
51 * B) When we modify the HNSW nodes bypassing the normal locking
52 * provided by the HNSW library. This only happens when we update
53 * an existing node attribute so far, in VSETATTR and when we call
54 * VADD to update a node with the SETATTR option.
55 *
56 * 3. Often during read operations performed by Redis commands in the
57 * main thread (VCARD, VEMB, VRANDMEMBER, ...) we don't acquire any
58 * lock at all. The commands run in the main Redis thread, we can only
59 * have, at the same time, background reads against the same data
60 * structure. Note that VSIM_thread() and VADD_thread() still modify the
61 * read slot metadata, that is node->visited_epoch[slot], but as long as
62 * our read commands running in the main thread don't need to use
63 * hnsw_search() or other HNSW functions using the visited epochs slots
64 * we are safe.
65 *
66 * 4. There is a race from the moment we create a thread, passing the
67 * vector set object, to the moment the thread can actually acquire the
68 * in_use_lock mutex: as the thread starts, in the meantime
69 * a DEL/expire could trigger and remove the object. For this reason
70 * we use an atomic counter that protects our object for this small
71 * time in vectorSetWaitAllBackgroundClients(). This prevents removal
72 * of objects that are about to be taken by threads.
73 *
74 * Note that other competing solutions could be used to fix the problem
75 * but have their set of issues, however they are worth documenting here
76 * and evaluating in the future:
77 *
78 * A. Using a conditional variable we could "wait" for the thread to
79 * acquire the lock. However this means waiting before returning
80 * to the event loop, and would make the command execution slower.
81 * B. We could use again an atomic variable, like we did, but this time
82 * as a refcount for the object, with a vsetAcquire() vsetRelease().
83 * In this case, the command could retain the object in the main thread
84 * before starting the thread, and the thread, after the work is done,
85 * could release it. This way sometimes the object would be freed by
86 * the thread, and while it may now be safe to do the kind of resource
87 * deallocation that vectorSetReleaseObject() does, given that the
88 * Redis Modules API is not always thread safe, this solution may not
89 * be future-proof. However, it is worth evaluating it better in the
90 * future.
91 * C. We could use the "B" solution but instead of freeing the object
92 * in the thread, in this specific case we could just put it into a
93 * list and defer it for later freeing (for instance in the reply
94 * callback), so that the object is always freed in the main thread.
95 * This would require a list of objects to free.
96 *
97 * However, the current solution's only disadvantage is the potential busy
98 * loop, but this busy loop in practical terms will almost never do
99 * much: to trigger it, a number of circumstances must happen: deleting
100 * Vector Set keys while using them, hitting the small window needed to
101 * start the thread and read-lock the mutex.
102 */
103
104#define _DEFAULT_SOURCE
105#define _USE_MATH_DEFINES
106#define _POSIX_C_SOURCE 200809L
107
108#include "../../src/redismodule.h"
109#include <stdio.h>
110#include <stdlib.h>
111#include <ctype.h>
112#include <string.h>
113#include <strings.h>
114#include <stdint.h>
115#include <math.h>
116#include <pthread.h>
117#include <stdatomic.h>
118#include "hnsw.h"
119#include "vset_config.h"
120
121// We inline directly the expression implementation here so that building
122// the module is trivial.
123#include "expr.c"
124
125static RedisModuleType *VectorSetType;
126static uint64_t VectorSetTypeNextId = 0;
127
128// Default EF value if not specified during creation.
129#define VSET_DEFAULT_C_EF 200
130
131// Default EF value if not specified during search.
132#define VSET_DEFAULT_SEARCH_EF 100
133
134// Default num elements returned by VSIM.
135#define VSET_DEFAULT_COUNT 10
136
137/* ========================== Internal data structure ====================== */
138
139/* Our abstract data type needs a dual representation similar to Redis
140 * sorted set: the proximity graph, and also a element -> graph-node map
141 * that will allow us to perform deletions and other operations that have
142 * as input the element itself. */
143struct vsetObject {
144 HNSW *hnsw; // Proximity graph.
145 RedisModuleDict *dict; // Element -> node mapping.
146 float *proj_matrix; // Random projection matrix, NULL if no projection
147 uint32_t proj_input_size; // Input dimension after projection.
148 // Output dimension is implicit in
149 // hnsw->vector_dim.
150 pthread_rwlock_t in_use_lock; // Lock needed to destroy the object safely.
151 uint64_t id; // Unique ID used by threaded VADD to know the
152 // object is still the same.
153 uint64_t numattribs; // Number of nodes associated with an attribute.
154 atomic_int thread_creation_pending; // Number of threads that are currently
155 // pending to lock the object.
156};
157
158/* Each node has two associated values: the associated string (the item
159 * in the set) and potentially a JSON string, that is, the attributes, used
160 * for hybrid search with the VSIM FILTER option. */
161struct vsetNodeVal {
162 RedisModuleString *item;
163 RedisModuleString *attrib;
164};
165
166/* Count the number of set bits in an integer (population count/Hamming weight).
167 * This is a portable implementation that doesn't rely on compiler
168 * extensions. */
169static inline uint32_t bit_count(uint32_t n) {
170 uint32_t count = 0;
171 while (n) {
172 count += n & 1;
173 n >>= 1;
174 }
175 return count;
176}
177
178/* Create a Hadamard-based projection matrix for dimensionality reduction.
179 * Uses {-1, +1} entries with a pattern based on bit operations.
180 * The pattern is matrix[i][j] = (i & j) % 2 == 0 ? 1 : -1
181 * Matrix is scaled by 1/sqrt(input_dim) for normalization.
182 * Returns NULL on allocation failure.
183 *
184 * Note that compared to other approaches (random gaussian weights), what
185 * we have here is deterministic, it means that our replicas will have
186 * the same set of weights. Also this approach seems to work much better
187 * in practice, and the distances between elements are better guaranteed.
188 *
189 * Note that we still save the projection matrix in the RDB file, because
190 * in the future we may change the weights generation, and we want everything
191 * to be backward compatible. */
192float *createProjectionMatrix(uint32_t input_dim, uint32_t output_dim) {
193 float *matrix = RedisModule_Alloc(sizeof(float) * input_dim * output_dim);
194
195 /* Scale factor to normalize the projection. */
196 const float scale = 1.0f / sqrt(input_dim);
197
198 /* Fill the matrix using Hadamard pattern. */
199 for (uint32_t i = 0; i < output_dim; i++) {
200 for (uint32_t j = 0; j < input_dim; j++) {
201 /* Calculate position in the flattened matrix. */
202 uint32_t pos = i * input_dim + j;
203
204 /* Hadamard pattern: use bit operations to determine sign
205 * If the count of 1-bits in the bitwise AND of i and j is even,
206 * the value is 1, otherwise -1. */
207 int value = (bit_count(i & j) % 2 == 0) ? 1 : -1;
208
209 /* Store the scaled value. */
210 matrix[pos] = value * scale;
211 }
212 }
213 return matrix;
214}
215
216/* Apply random projection to input vector. Returns new allocated vector. */
217float *applyProjection(const float *input, const float *proj_matrix,
218 uint32_t input_dim, uint32_t output_dim)
219{
220 float *output = RedisModule_Alloc(sizeof(float) * output_dim);
221
222 for (uint32_t i = 0; i < output_dim; i++) {
223 const float *row = &proj_matrix[i * input_dim];
224 float sum = 0.0f;
225 for (uint32_t j = 0; j < input_dim; j++) {
226 sum += row[j] * input[j];
227 }
228 output[i] = sum;
229 }
230 return output;
231}
232
233/* Create the vector as HNSW+Dictionary combined data structure. */
234struct vsetObject *createVectorSetObject(unsigned int dim, uint32_t quant_type, uint32_t hnsw_M) {
235 struct vsetObject *o;
236 o = RedisModule_Alloc(sizeof(*o));
237
238 o->id = VectorSetTypeNextId++;
239 o->hnsw = hnsw_new(dim,quant_type,hnsw_M);
240 if (!o->hnsw) { // May fail because of mutex creation.
241 RedisModule_Free(o);
242 return NULL;
243 }
244
245 o->dict = RedisModule_CreateDict(NULL);
246 o->proj_matrix = NULL;
247 o->proj_input_size = 0;
248 o->numattribs = 0;
249 o->thread_creation_pending = 0;
250 RedisModule_Assert(pthread_rwlock_init(&o->in_use_lock,NULL) == 0);
251 return o;
252}
253
254void vectorSetReleaseNodeValue(void *v) {
255 struct vsetNodeVal *nv = v;
256 RedisModule_FreeString(NULL,nv->item);
257 if (nv->attrib) RedisModule_FreeString(NULL,nv->attrib);
258 RedisModule_Free(nv);
259}
260
261/* Free the vector set object. */
262void vectorSetReleaseObject(struct vsetObject *o) {
263 if (!o) return;
264 if (o->hnsw) hnsw_free(o->hnsw,vectorSetReleaseNodeValue);
265 if (o->dict) RedisModule_FreeDict(NULL,o->dict);
266 if (o->proj_matrix) RedisModule_Free(o->proj_matrix);
267 pthread_rwlock_destroy(&o->in_use_lock);
268 RedisModule_Free(o);
269}
270
271/* Wait for all the threads performing operations on this
272 * index to terminate their work (locking for write will
273 * wait for all the other threads).
274 *
275 * if 'for_del' is set to 1, we also wait for all the pending threads
276 * that still didn't acquire the lock to finish their work. This
277 * is useful only if we are going to call this function to delete
278 * the object, and not if we want to just to modify it. */
279void vectorSetWaitAllBackgroundClients(struct vsetObject *vset, int for_del) {
280 if (for_del) {
281 // If we are going to destroy the object, after this call, let's
282 // wait for threads that are being created and still didn't have
283 // a chance to acquire the lock.
284 while (vset->thread_creation_pending > 0);
285 }
286 RedisModule_Assert(pthread_rwlock_wrlock(&vset->in_use_lock) == 0);
287 pthread_rwlock_unlock(&vset->in_use_lock);
288}
289
290/* Return a string representing the quantization type name of a vector set. */
291const char *vectorSetGetQuantName(struct vsetObject *o) {
292 switch(o->hnsw->quant_type) {
293 case HNSW_QUANT_NONE: return "f32";
294 case HNSW_QUANT_Q8: return "int8";
295 case HNSW_QUANT_BIN: return "bin";
296 default: return "unknown";
297 }
298}
299
300/* Insert the specified element into the Vector Set.
301 * If update is '1', the existing node will be updated.
302 *
303 * Returns 1 if the element was added, or 0 if the element was already there
304 * and was just updated. */
305int vectorSetInsert(struct vsetObject *o, float *vec, int8_t *qvec, float qrange, RedisModuleString *val, RedisModuleString *attrib, int update, int ef)
306{
307 hnswNode *node = RedisModule_DictGet(o->dict,val,NULL);
308 if (node != NULL) {
309 if (update) {
310 /* Wait for clients in the background: background VSIM
311 * operations touch the nodes attributes we are going
312 * to touch. */
313 vectorSetWaitAllBackgroundClients(o,0);
314
315 struct vsetNodeVal *nv = node->value;
316 /* Pass NULL as value-free function. We want to reuse
317 * the old value. */
318 hnsw_delete_node(o->hnsw, node, NULL);
319 node = hnsw_insert(o->hnsw,vec,qvec,qrange,0,nv,ef);
320 RedisModule_Assert(node != NULL);
321 RedisModule_DictReplace(o->dict,val,node);
322
323 /* If attrib != NULL, the user wants that in case of an update we
324 * update the attribute as well (otherwise it remains as it was).
325 * Note that the order of operations is conceived so that it
326 * works in case the old attrib and the new attrib pointer is the
327 * same. */
328 if (attrib) {
329 // Empty attribute string means: unset the attribute during
330 // the update.
331 size_t attrlen;
332 RedisModule_StringPtrLen(attrib,&attrlen);
333 if (attrlen != 0) {
334 RedisModule_RetainString(NULL,attrib);
335 o->numattribs++;
336 } else {
337 attrib = NULL;
338 }
339
340 if (nv->attrib) {
341 o->numattribs--;
342 RedisModule_FreeString(NULL,nv->attrib);
343 }
344 nv->attrib = attrib;
345 }
346 }
347 return 0;
348 }
349
350 struct vsetNodeVal *nv = RedisModule_Alloc(sizeof(*nv));
351 nv->item = val;
352 nv->attrib = attrib;
353 node = hnsw_insert(o->hnsw,vec,qvec,qrange,0,nv,ef);
354 if (node == NULL) {
355 // XXX Technically in Redis-land we don't have out of memory, as we
356 // crash on OOM. However the HNSW library may fail for error in the
357 // locking libc call. Probably impossible in practical terms.
358 RedisModule_Free(nv);
359 return 0;
360 }
361 if (attrib != NULL) o->numattribs++;
362 RedisModule_DictSet(o->dict,val,node);
363 RedisModule_RetainString(NULL,val);
364 if (attrib) RedisModule_RetainString(NULL,attrib);
365 return 1;
366}
367
368/* Parse vector from FP32 blob or VALUES format, with optional REDUCE.
369 * Format: [REDUCE dim] FP32|VALUES ...
370 * Returns allocated vector and sets dimension in *dim.
371 * If reduce_dim is not NULL, sets it to the requested reduction dimension.
372 * Returns NULL on parsing error.
373 *
374 * The function sets as a reference *consumed_args, so that the caller
375 * knows how many arguments we consumed in order to parse the input
376 * vector. Remaining arguments are often command options. */
377float *parseVector(RedisModuleString **argv, int argc, int start_idx,
378 size_t *dim, uint32_t *reduce_dim, int *consumed_args)
379{
380 int consumed = 0; // Arguments consumed
381
382 /* Check for REDUCE option first. */
383 if (reduce_dim) *reduce_dim = 0;
384 if (reduce_dim && argc > start_idx + 2 &&
385 !strcasecmp(RedisModule_StringPtrLen(argv[start_idx],NULL),"REDUCE"))
386 {
387 long long rdim;
388 if (RedisModule_StringToLongLong(argv[start_idx+1],&rdim)
389 != REDISMODULE_OK || rdim <= 0)
390 {
391 return NULL;
392 }
393 if (reduce_dim) *reduce_dim = rdim;
394 start_idx += 2; // Skip REDUCE and its argument.
395 consumed += 2;
396 }
397
398 /* Now parse the vector format as before. */
399 float *vec = NULL;
400 const char *vec_format = RedisModule_StringPtrLen(argv[start_idx],NULL);
401
402 if (!strcasecmp(vec_format,"FP32")) {
403 if (argc < start_idx + 2) return NULL; // Need FP32 + vector + value.
404 size_t vec_raw_len;
405 const char *blob =
406 RedisModule_StringPtrLen(argv[start_idx+1],&vec_raw_len);
407
408 // Must be 4 bytes per component.
409 if (vec_raw_len % 4 || vec_raw_len < 4) return NULL;
410 *dim = vec_raw_len/4;
411
412 vec = RedisModule_Alloc(vec_raw_len);
413 if (!vec) return NULL;
414 memcpy(vec,blob,vec_raw_len);
415 consumed += 2;
416 } else if (!strcasecmp(vec_format,"VALUES")) {
417 if (argc < start_idx + 2) return NULL; // Need at least the dimension.
418 long long vdim; // Vector dimension passed by the user.
419 if (RedisModule_StringToLongLong(argv[start_idx+1],&vdim)
420 != REDISMODULE_OK || vdim < 1) return NULL;
421
422 // Check that all the arguments are available.
423 if (argc < start_idx + 2 + vdim) return NULL;
424
425 *dim = vdim;
426 vec = RedisModule_Alloc(sizeof(float) * vdim);
427 if (!vec) return NULL;
428
429 for (int j = 0; j < vdim; j++) {
430 double val;
431 if (RedisModule_StringToDouble(argv[start_idx+2+j],&val)
432 != REDISMODULE_OK)
433 {
434 RedisModule_Free(vec);
435 return NULL;
436 }
437 vec[j] = val;
438 }
439 consumed += vdim + 2;
440 } else {
441 return NULL; // Unknown format.
442 }
443
444 if (consumed_args) *consumed_args = consumed;
445 return vec;
446}
447
448/* ========================== Commands implementation ======================= */
449
/* VADD thread handling the "CAS" version of the command, that is
 * performed blocking the client, accumulating here, in the thread, the
 * set of potential candidates, and later inserting the element in the
 * key (if it still exists, and if it is still the *same* vector set)
 * in the Reply callback.
 *
 * 'arg' is the targ[] array built by VADD_RedisCommand(). Here we only
 * read slots 0 (blocked client), 1 (vset), 3 (vector) and 6 (EF), and
 * we fill slot 5 with the insertion context; the remaining slots are
 * consumed by VADD_CASReply(). */
void *VADD_thread(void *arg) {
    pthread_detach(pthread_self());

    void **targ = (void**)arg;
    RedisModuleBlockedClient *bc = targ[0];
    struct vsetObject *vset = targ[1];
    float *vec = targ[3];
    int ef = (uint64_t)targ[6];

    /* Lock the object and signal that we are no longer pending
     * the lock acquisition. */
    RedisModule_Assert(pthread_rwlock_rdlock(&vset->in_use_lock) == 0);
    vset->thread_creation_pending--;

    /* Look for candidates: this is the slow, read-only part of the
     * insertion that we can perform without blocking Redis. */
    InsertContext *ic = hnsw_prepare_insert(vset->hnsw, vec, NULL, 0, 0, ef);
    targ[5] = ic; // Pass the context to the reply callback.

    /* Unblock the client so that our reply callback will be invoked. */
    pthread_rwlock_unlock(&vset->in_use_lock);
    RedisModule_BlockedClientMeasureTimeEnd(bc);
    RedisModule_UnblockClient(bc,targ); // Use targ as privdata.
    return NULL;
}
479
/* Reply callback for CAS variant of VADD.
 * Note: this is called in the main thread, in the background thread
 * we just do the read operation of gathering the neighbors.
 *
 * The private data is the targ[] array allocated by VADD_RedisCommand:
 * [2] = vector set unique ID (to detect key replacement), [3] = vector,
 * [4] = element name, [5] = InsertContext from hnsw_prepare_insert(),
 * [6] = EF, [7] = optional attribute string or NULL. */
int VADD_CASReply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
    (void)argc;
    RedisModule_AutoMemory(ctx); /* Use automatic memory management. */

    int retval = REDISMODULE_OK;
    void **targ = (void**)RedisModule_GetBlockedClientPrivateData(ctx);
    uint64_t vset_id = (unsigned long) targ[2];
    float *vec = targ[3];
    RedisModuleString *val = targ[4];
    InsertContext *ic = targ[5];
    int ef = (uint64_t)targ[6];
    RedisModuleString *attrib = targ[7];
    RedisModule_Free(targ);

    /* Open the key: there are no guarantees it still exists, or contains
     * a vector set, or even the SAME vector set. */
    RedisModuleKey *key = RedisModule_OpenKey(ctx,argv[1],
        REDISMODULE_READ|REDISMODULE_WRITE);
    int type = RedisModule_KeyType(key);
    struct vsetObject *vset = NULL;

    if (type != REDISMODULE_KEYTYPE_EMPTY &&
        RedisModule_ModuleTypeGetType(key) == VectorSetType)
    {
        vset = RedisModule_ModuleTypeGetValue(key);
        // Same vector set?
        if (vset->id != vset_id) vset = NULL;

        /* Also, if the element was already inserted, we just pretend
         * the other insert won. We don't even start a threaded VADD
         * if this was an update, since the deletion of the element itself
         * in order to perform the update would invalidate the CAS state. */
        if (vset && RedisModule_DictGet(vset->dict,val,NULL) != NULL)
            vset = NULL;
    }

    if (vset == NULL) {
        /* If the object does not match the start of the operation, we
         * just pretend the VADD was performed BEFORE the key was deleted
         * or replaced. We return success but don't do anything. */
        hnsw_free_insert_context(ic);
    } else {
        /* Otherwise try to insert the new element with the neighbors
         * collected in background. If we fail, do it synchronously again
         * from scratch. */

        // First: allocate the dual-ported value for the node.
        struct vsetNodeVal *nv = RedisModule_Alloc(sizeof(*nv));
        nv->item = val;
        nv->attrib = attrib;

        /* Then: insert the node in the HNSW data structure. Note that
         * 'ic' could be NULL in case hnsw_prepare_insert() failed because of
         * locking failure (likely impossible in practical terms). */
        hnswNode *newnode;
        if (ic == NULL ||
            (newnode = hnsw_try_commit_insert(vset->hnsw, ic, nv)) == NULL)
        {
            /* If we are here, the CAS insert failed. We need to insert
             * again with full locking for neighbors selection and
             * actual insertion. This time we can't fail: */
            newnode = hnsw_insert(vset->hnsw, vec, NULL, 0, 0, nv, ef);
            RedisModule_Assert(newnode != NULL);
        }
        RedisModule_DictSet(vset->dict,val,newnode);
        /* The node value now owns the references to val/attrib that were
         * retained before spawning the thread. */
        val = NULL; // Don't free it later.
        attrib = NULL; // Don't free it later.

        RedisModule_ReplicateVerbatim(ctx);
    }

    // Whatever happens is a success... :D
    RedisModule_ReplyWithBool(ctx,1);
    if (val) RedisModule_FreeString(ctx,val); // Not added? Free it.
    if (attrib) RedisModule_FreeString(ctx,attrib); // Not added? Free it.
    RedisModule_Free(vec);
    return retval;
}
561
/* VADD key [REDUCE dim] FP32|VALUES vector value [CAS] [NOQUANT] [BIN] [Q8]
 * [EF exploration-factor] [M count] [SETATTR json]
 *
 * Add (or update) an element of the vector set stored at 'key'. With CAS
 * the slow candidates-selection step runs on a background thread (see
 * VADD_thread / VADD_CASReply), otherwise the insertion is synchronous. */
int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
    RedisModule_AutoMemory(ctx); /* Use automatic memory management. */

    if (argc < 5) return RedisModule_WrongArity(ctx);

    /* Parse vector with optional REDUCE */
    size_t dim = 0;
    uint32_t reduce_dim = 0;
    int consumed_args;
    int cas = 0; // Threaded check-and-set style insert.
    long long ef = VSET_DEFAULT_C_EF; // HNSW creation time EF for new nodes.
    long long hnsw_create_M = HNSW_DEFAULT_M; // HNSW creation default M value.
    float *vec = parseVector(argv, argc, 2, &dim, &reduce_dim, &consumed_args);
    RedisModuleString *attrib = NULL; // Attributes if passed via SETATTR.
    if (!vec)
        return RedisModule_ReplyWithError(ctx,"ERR invalid vector specification");

    /* Missing element string at the end? */
    if (argc-2-consumed_args < 1) {
        RedisModule_Free(vec);
        return RedisModule_WrongArity(ctx);
    }

    /* Parse options after the element string. */
    uint32_t quant_type = HNSW_QUANT_Q8; // Default quantization type.

    for (int j = 2 + consumed_args + 1; j < argc; j++) {
        const char *opt = RedisModule_StringPtrLen(argv[j], NULL);
        if (!strcasecmp(opt, "CAS")) {
            cas = 1;
        } else if (!strcasecmp(opt, "EF") && j+1 < argc) {
            if (RedisModule_StringToLongLong(argv[j+1], &ef)
                != REDISMODULE_OK || ef <= 0 || ef > 1000000)
            {
                RedisModule_Free(vec);
                return RedisModule_ReplyWithError(ctx, "ERR invalid EF");
            }
            j++; // skip argument.
        } else if (!strcasecmp(opt, "M") && j+1 < argc) {
            if (RedisModule_StringToLongLong(argv[j+1], &hnsw_create_M)
                != REDISMODULE_OK || hnsw_create_M < HNSW_MIN_M ||
                hnsw_create_M > HNSW_MAX_M)
            {
                RedisModule_Free(vec);
                return RedisModule_ReplyWithError(ctx, "ERR invalid M");
            }
            j++; // skip argument.
        } else if (!strcasecmp(opt, "SETATTR") && j+1 < argc) {
            attrib = argv[j+1];
            j++; // skip argument.
        } else if (!strcasecmp(opt, "NOQUANT")) {
            quant_type = HNSW_QUANT_NONE;
        } else if (!strcasecmp(opt, "BIN")) {
            quant_type = HNSW_QUANT_BIN;
        } else if (!strcasecmp(opt, "Q8")) {
            quant_type = HNSW_QUANT_Q8;
        } else {
            RedisModule_Free(vec);
            return RedisModule_ReplyWithError(ctx,"ERR invalid option after element");
        }
    }

    /* Drop CAS if this is a replica and we are getting the command from the
     * replication link: we want to add/delete items in the same order as
     * the master, while with CAS the timing would be different.
     *
     * Also for Lua scripts and MULTI/EXEC, we want to run the command
     * on the main thread. */
    if (RedisModule_GetContextFlags(ctx) &
        (REDISMODULE_CTX_FLAGS_REPLICATED|
         REDISMODULE_CTX_FLAGS_LUA|
         REDISMODULE_CTX_FLAGS_MULTI))
    {
        cas = 0;
    }

    /* Configuration may force all executions on the main thread. */
    if (VSGlobalConfig.forceSingleThreadExec) {
        cas = 0;
    }

    /* Open/create key */
    RedisModuleKey *key = RedisModule_OpenKey(ctx,argv[1],
        REDISMODULE_READ|REDISMODULE_WRITE);
    int type = RedisModule_KeyType(key);
    if (type != REDISMODULE_KEYTYPE_EMPTY &&
        RedisModule_ModuleTypeGetType(key) != VectorSetType)
    {
        RedisModule_Free(vec);
        return RedisModule_ReplyWithError(ctx,REDISMODULE_ERRORMSG_WRONGTYPE);
    }

    /* Get the correct value argument based on format and REDUCE */
    RedisModuleString *val = argv[2 + consumed_args];

    /* Create or get existing vector set */
    struct vsetObject *vset;
    if (type == REDISMODULE_KEYTYPE_EMPTY) {
        cas = 0; /* Do synchronous insert at creation, otherwise the
                  * key would be left empty until the threaded part
                  * returns. It's also pointless to try
                  * doing threaded first element insertion. */
        vset = createVectorSetObject(reduce_dim ? reduce_dim : dim, quant_type, hnsw_create_M);
        if (vset == NULL) {
            // We can't fail for OOM in Redis, but the mutex initialization
            // at least theoretically COULD fail. Likely this code path
            // is not reachable in practical terms.
            RedisModule_Free(vec);
            return RedisModule_ReplyWithError(ctx,
                "ERR unable to create a Vector Set: system resources issue?");
        }

        /* Initialize projection if requested */
        if (reduce_dim) {
            vset->proj_matrix = createProjectionMatrix(dim, reduce_dim);
            vset->proj_input_size = dim;

            /* Project the vector */
            float *projected = applyProjection(vec, vset->proj_matrix,
                                               dim, reduce_dim);
            RedisModule_Free(vec);
            vec = projected;
        }
        RedisModule_ModuleTypeSetValue(key,VectorSetType,vset);
    } else {
        vset = RedisModule_ModuleTypeGetValue(key);

        /* An existing set constrains quantization, M and dimension:
         * reject inserts that don't match its configuration. */
        if (vset->hnsw->quant_type != quant_type) {
            RedisModule_Free(vec);
            return RedisModule_ReplyWithError(ctx,
                "ERR asked quantization mismatch with existing vector set");
        }

        if (vset->hnsw->M != hnsw_create_M) {
            RedisModule_Free(vec);
            return RedisModule_ReplyWithError(ctx,
                "ERR asked M value mismatch with existing vector set");
        }

        if ((vset->proj_matrix == NULL && vset->hnsw->vector_dim != dim) ||
            (vset->proj_matrix && vset->hnsw->vector_dim != reduce_dim))
        {
            RedisModule_Free(vec);
            return RedisModule_ReplyWithErrorFormat(ctx,
                "ERR Vector dimension mismatch - got %d but set has %d",
                (int)dim, (int)vset->hnsw->vector_dim);
        }

        /* Check REDUCE compatibility */
        if (reduce_dim) {
            if (!vset->proj_matrix) {
                RedisModule_Free(vec);
                return RedisModule_ReplyWithError(ctx,
                    "ERR cannot add projection to existing set without projection");
            }
            if (reduce_dim != vset->hnsw->vector_dim) {
                RedisModule_Free(vec);
                return RedisModule_ReplyWithError(ctx,
                    "ERR projection dimension mismatch with existing set");
            }
        }

        /* Apply projection if needed */
        if (vset->proj_matrix) {
            /* Ensure input dimension matches the projection matrix's expected input dimension */
            if (dim != vset->proj_input_size) {
                RedisModule_Free(vec);
                return RedisModule_ReplyWithErrorFormat(ctx,
                    "ERR Input dimension mismatch for projection - got %d but projection expects %d",
                    (int)dim, (int)vset->proj_input_size);
            }

            float *projected = applyProjection(vec, vset->proj_matrix,
                                               vset->proj_input_size,
                                               vset->hnsw->vector_dim);
            RedisModule_Free(vec);
            vec = projected;
            dim = vset->hnsw->vector_dim;
        }
    }

    /* For existing keys don't do CAS updates. For how things work now, the
     * CAS state would be invalidated by the deletion before adding back. */
    if (cas && RedisModule_DictGet(vset->dict,val,NULL) != NULL)
        cas = 0;

    /* Here depending on the CAS option we directly insert in a blocking
     * way, or use a thread to do candidate neighbors selection and only
     * later, in the reply callback, actually add the element. */
    if (cas) {
        RedisModuleBlockedClient *bc = RedisModule_BlockClient(ctx,VADD_CASReply,NULL,NULL,0);
        pthread_t tid;
        /* Pack the state shared with VADD_thread() / VADD_CASReply():
         * see those functions for the meaning of each slot. */
        void **targ = RedisModule_Alloc(sizeof(void*)*8);
        targ[0] = bc;
        targ[1] = vset;
        targ[2] = (void*)(unsigned long)vset->id;
        targ[3] = vec;
        targ[4] = val;
        targ[5] = NULL; // Used later for insertion context.
        targ[6] = (void*)(unsigned long)ef;
        targ[7] = attrib;
        /* Retain val/attrib: the reply callback runs after argv automatic
         * memory has been released. */
        RedisModule_RetainString(ctx,val);
        if (attrib) RedisModule_RetainString(ctx,attrib);
        RedisModule_BlockedClientMeasureTimeStart(bc);
        vset->thread_creation_pending++;
        if (pthread_create(&tid,NULL,VADD_thread,targ) != 0) {
            vset->thread_creation_pending--;
            RedisModule_AbortBlock(bc);
            RedisModule_Free(targ);
            RedisModule_FreeString(ctx,val);
            if (attrib) RedisModule_FreeString(ctx,attrib);

            // Fall back to synchronous insert, see later in the code.
        } else {
            return REDISMODULE_OK;
        }
    }

    /* Insert vector synchronously: we reach this place even
     * if cas was true but thread creation failed. */
    int added = vectorSetInsert(vset,vec,NULL,0,val,attrib,1,ef);
    RedisModule_Free(vec);

    RedisModule_ReplyWithBool(ctx,added);
    if (added) RedisModule_ReplicateVerbatim(ctx);
    return REDISMODULE_OK;
}
790
791/* HNSW callback to filter items according to a predicate function
792 * (our FILTER expression in this case). */
793int vectorSetFilterCallback(void *value, void *privdata) {
794 exprstate *expr = privdata;
795 struct vsetNodeVal *nv = value;
796 if (nv->attrib == NULL) return 0; // No attributes? No match.
797 size_t json_len;
798 char *json = (char*)RedisModule_StringPtrLen(nv->attrib,&json_len);
799 return exprRun(expr,json,json_len);
800}
801
/* Common path for the execution of the VSIM command both threaded and
 * not threaded. Note that 'ctx' may be a normal context or a thread safe
 * context obtained from a blocked client. The locking that is specific
 * to the vset object is handled by the caller, however the function
 * handles the HNSW locking explicitly.
 *
 * Ownership note: this function takes ownership of 'vec' and of
 * 'filter_expr' (when not NULL) and frees both before returning. */
void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset,
    float *vec, unsigned long count, float epsilon, unsigned long withscores,
    unsigned long withattribs, unsigned long ef, exprstate *filter_expr,
    unsigned long filter_ef, int ground_truth)
{
    /* In our scan, we can't just collect 'count' elements as
     * if count is small we would explore the graph in an insufficient
     * way to provide enough recall.
     *
     * If the user didn't ask for a specific exploration, we use
     * VSET_DEFAULT_SEARCH_EF as minimum, or we match count if count
     * is greater than that. Otherwise the minimum will be the specified
     * EF argument. */
    if (ef == 0) ef = VSET_DEFAULT_SEARCH_EF;
    if (count > ef) ef = count;

    int slot = hnsw_acquire_read_slot(vset->hnsw);
    /* No point in exploring more nodes than the graph contains. */
    if (ef > vset->hnsw->node_count) ef = vset->hnsw->node_count;

    /* Perform search */
    hnswNode **neighbors = RedisModule_Alloc(sizeof(hnswNode*)*ef);
    float *distances = RedisModule_Alloc(sizeof(float)*ef);
    unsigned int found;
    if (ground_truth) {
        /* TRUTH option: exact linear scan instead of approximate search. */
        found = hnsw_ground_truth_with_filter(vset->hnsw, vec, ef, neighbors,
                                 distances, slot, 0,
                                 filter_expr ? vectorSetFilterCallback : NULL,
                                 filter_expr);
    } else {
        if (filter_expr == NULL) {
            found = hnsw_search(vset->hnsw, vec, ef, neighbors,
                                distances, slot, 0);
        } else {
            found = hnsw_search_with_filter(vset->hnsw, vec, ef, neighbors,
                            distances, slot, 0, vectorSetFilterCallback,
                            filter_expr, filter_ef);
        }
    }

    /* Return results */
    int resp3 = RedisModule_GetContextFlags(ctx) & REDISMODULE_CTX_FLAGS_RESP3;
    int reply_with_map = resp3 && (withscores || withattribs);

    if (reply_with_map)
        RedisModule_ReplyWithMap(ctx, REDISMODULE_POSTPONED_LEN);
    else
        RedisModule_ReplyWithArray(ctx, REDISMODULE_POSTPONED_LEN);

    long long arraylen = 0;
    for (unsigned int i = 0; i < found && i < count; i++) {
        /* Results are distance-ordered: past epsilon we can stop. */
        if (distances[i]/2 > epsilon) break;
        struct vsetNodeVal *nv = neighbors[i]->value;
        RedisModule_ReplyWithString(ctx, nv->item);
        arraylen++;

        /* If the user asked for multiple properties at the same time using
         * the RESP3 protocol, we wrap the value of the map into an N-items
         * array. Two for now, since we have just two properties that can be
         * requested.
         *
         * So in the case of RESP2 we will just have the flat reply:
         * item, score, attribute. For RESP3 instead item -> [score, attribute]
         */
        if (resp3 && withscores && withattribs)
            RedisModule_ReplyWithArray(ctx,2);

        if (withscores) {
            /* The similarity score is provided in a 0-1 range. */
            RedisModule_ReplyWithDouble(ctx, 1.0 - distances[i]/2.0);
        }
        if (withattribs) {
            /* Return the attributes as well, if any. */
            if (nv->attrib)
                RedisModule_ReplyWithString(ctx, nv->attrib);
            else
                RedisModule_ReplyWithNull(ctx);
        }
    }
    hnsw_release_read_slot(vset->hnsw,slot);

    /* Fix the postponed length: maps count entries, arrays count every
     * emitted element (item plus optional score/attribute). */
    if (reply_with_map) {
        RedisModule_ReplySetMapLength(ctx, arraylen);
    } else {
        int items_per_ele = 1+withattribs+withscores;
        RedisModule_ReplySetArrayLength(ctx, arraylen * items_per_ele);
    }

    RedisModule_Free(vec);
    RedisModule_Free(neighbors);
    RedisModule_Free(distances);
    if (filter_expr) exprFree(filter_expr);
}
899
/* VSIM thread handling the blocked client request: unpacks the targ[]
 * array built by VSIM_RedisCommand() and runs VSIM_execute() against a
 * thread safe context while holding the object read lock. */
void *VSIM_thread(void *arg) {
    pthread_detach(pthread_self());

    // Extract arguments.
    void **targ = (void**)arg;
    RedisModuleBlockedClient *bc = targ[0];
    struct vsetObject *vset = targ[1];
    float *vec = targ[2];
    unsigned long count = (unsigned long)targ[3];
    float epsilon = *((float*)targ[4]);
    unsigned long withscores = (unsigned long)targ[5];
    unsigned long withattribs = (unsigned long)targ[6];
    unsigned long ef = (unsigned long)targ[7];
    exprstate *filter_expr = targ[8];
    unsigned long filter_ef = (unsigned long)targ[9];
    unsigned long ground_truth = (unsigned long)targ[10];
    /* Epsilon was heap-allocated since a float may not fit a void*. */
    RedisModule_Free(targ[4]);
    RedisModule_Free(targ);

    /* Lock the object and signal that we are no longer pending
     * the lock acquisition. */
    RedisModule_Assert(pthread_rwlock_rdlock(&vset->in_use_lock) == 0);
    vset->thread_creation_pending--;

    // Accumulate reply in a thread safe context: no contention.
    RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(bc);

    // Run the query (frees vec and filter_expr).
    VSIM_execute(ctx, vset, vec, count, epsilon, withscores, withattribs, ef, filter_expr, filter_ef, ground_truth);
    pthread_rwlock_unlock(&vset->in_use_lock);

    // Cleanup.
    RedisModule_FreeThreadSafeContext(ctx);
    RedisModule_BlockedClientMeasureTimeEnd(bc);
    RedisModule_UnblockClient(bc,NULL);
    return NULL;
}
938
939/* VSIM key [ELE|FP32|VALUES] <vector or ele> [WITHSCORES] [WITHATTRIBS] [COUNT num] [EPSILON eps] [EF exploration-factor] [FILTER expression] [FILTER-EF exploration-factor] */
940int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
941 RedisModule_AutoMemory(ctx);
942
943 /* Basic argument check: need at least key and vector specification
944 * method. */
945 if (argc < 4) return RedisModule_WrongArity(ctx);
946
947 /* Defaults */
948 int withscores = 0;
949 int withattribs = 0;
950 long long count = VSET_DEFAULT_COUNT; /* New default value */
951 long long ef = 0; /* Exploration factor (see HNSW paper) */
952 double epsilon = 2.0; /* Max cosine distance */
953 long long ground_truth = 0; /* Linear scan instead of HNSW search? */
954 int no_thread = 0; /* NOTHREAD option: exec on main thread. */
955
956 /* Things computed later. */
957 long long filter_ef = 0;
958 exprstate *filter_expr = NULL;
959
960 /* Get key and vector type */
961 RedisModuleString *key = argv[1];
962 const char *vectorType = RedisModule_StringPtrLen(argv[2], NULL);
963
964 /* Get vector set */
965 RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key, REDISMODULE_READ);
966 int type = RedisModule_KeyType(keyptr);
967 if (type == REDISMODULE_KEYTYPE_EMPTY)
968 return RedisModule_ReplyWithEmptyArray(ctx);
969
970 if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType)
971 return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
972
973 struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr);
974
975 /* Vector parsing stage */
976 float *vec = NULL;
977 size_t dim = 0;
978 int vector_args = 0; /* Number of args consumed by vector specification */
979
980 if (!strcasecmp(vectorType, "ELE")) {
981 /* Get vector from existing element */
982 RedisModuleString *ele = argv[3];
983 hnswNode *node = RedisModule_DictGet(vset->dict, ele, NULL);
984 if (!node) {
985 return RedisModule_ReplyWithError(ctx, "ERR element not found in set");
986 }
987 vec = RedisModule_Alloc(sizeof(float) * vset->hnsw->vector_dim);
988 hnsw_get_node_vector(vset->hnsw,node,vec);
989 dim = vset->hnsw->vector_dim;
990 vector_args = 2; /* ELE + element name */
991 } else {
992 /* Parse vector. */
993 int consumed_args;
994
995 vec = parseVector(argv, argc, 2, &dim, NULL, &consumed_args);
996 if (!vec) {
997 return RedisModule_ReplyWithError(ctx,
998 "ERR invalid vector specification");
999 }
1000 vector_args = consumed_args;
1001
1002 /* Apply projection if the set uses it, with the exception
1003 * of ELE type, that will already have the right dimension. */
1004 if (vset->proj_matrix && dim != vset->hnsw->vector_dim) {
1005 /* Ensure input dimension matches the projection matrix's expected input dimension */
1006 if (dim != vset->proj_input_size) {
1007 RedisModule_Free(vec);
1008 return RedisModule_ReplyWithErrorFormat(ctx,
1009 "ERR Input dimension mismatch for projection - got %d but projection expects %d",
1010 (int)dim, (int)vset->proj_input_size);
1011 }
1012
1013 float *projected = applyProjection(vec, vset->proj_matrix,
1014 vset->proj_input_size,
1015 vset->hnsw->vector_dim);
1016 RedisModule_Free(vec);
1017 vec = projected;
1018 dim = vset->hnsw->vector_dim;
1019 }
1020
1021 /* Count consumed arguments */
1022 if (!strcasecmp(vectorType, "FP32")) {
1023 vector_args = 2; /* FP32 + vector blob */
1024 } else if (!strcasecmp(vectorType, "VALUES")) {
1025 long long vdim;
1026 if (RedisModule_StringToLongLong(argv[3], &vdim) != REDISMODULE_OK) {
1027 RedisModule_Free(vec);
1028 return RedisModule_ReplyWithError(ctx, "ERR invalid vector dimension");
1029 }
1030 vector_args = 2 + vdim; /* VALUES + dim + values */
1031 } else {
1032 RedisModule_Free(vec);
1033 return RedisModule_ReplyWithError(ctx,
1034 "ERR vector type must be ELE, FP32 or VALUES");
1035 }
1036 }
1037
1038 /* Check vector dimension matches set */
1039 if (dim != vset->hnsw->vector_dim) {
1040 RedisModule_Free(vec);
1041 return RedisModule_ReplyWithErrorFormat(ctx,
1042 "ERR Vector dimension mismatch - got %d but set has %d",
1043 (int)dim, (int)vset->hnsw->vector_dim);
1044 }
1045
1046 /* Parse optional arguments - start after vector specification */
1047 int j = 2 + vector_args;
1048 while (j < argc) {
1049 const char *opt = RedisModule_StringPtrLen(argv[j], NULL);
1050 if (!strcasecmp(opt, "WITHSCORES")) {
1051 withscores = 1;
1052 j++;
1053 } else if (!strcasecmp(opt, "WITHATTRIBS")) {
1054 withattribs = 1;
1055 j++;
1056 } else if (!strcasecmp(opt, "TRUTH")) {
1057 ground_truth = 1;
1058 j++;
1059 } else if (!strcasecmp(opt, "NOTHREAD")) {
1060 no_thread = 1;
1061 j++;
1062 } else if (!strcasecmp(opt, "COUNT") && j+1 < argc) {
1063 if (RedisModule_StringToLongLong(argv[j+1], &count)
1064 != REDISMODULE_OK || count <= 0)
1065 {
1066 RedisModule_Free(vec);
1067 return RedisModule_ReplyWithError(ctx, "ERR invalid COUNT");
1068 }
1069 j += 2;
1070 } else if (!strcasecmp(opt, "EPSILON") && j+1 < argc) {
1071 if (RedisModule_StringToDouble(argv[j+1], &epsilon) !=
1072 REDISMODULE_OK || epsilon <= 0)
1073 {
1074 RedisModule_Free(vec);
1075 return RedisModule_ReplyWithError(ctx, "ERR invalid EPSILON");
1076 }
1077 j += 2;
1078 } else if (!strcasecmp(opt, "EF") && j+1 < argc) {
1079 if (RedisModule_StringToLongLong(argv[j+1], &ef) !=
1080 REDISMODULE_OK || ef <= 0 || ef > 1000000)
1081 {
1082 RedisModule_Free(vec);
1083 return RedisModule_ReplyWithError(ctx, "ERR invalid EF");
1084 }
1085 j += 2;
1086 } else if (!strcasecmp(opt, "FILTER-EF") && j+1 < argc) {
1087 if (RedisModule_StringToLongLong(argv[j+1], &filter_ef) !=
1088 REDISMODULE_OK || filter_ef <= 0)
1089 {
1090 RedisModule_Free(vec);
1091 return RedisModule_ReplyWithError(ctx, "ERR invalid FILTER-EF");
1092 }
1093 j += 2;
1094 } else if (!strcasecmp(opt, "FILTER") && j+1 < argc) {
1095 RedisModuleString *exprarg = argv[j+1];
1096 size_t exprlen;
1097 char *exprstr = (char*)RedisModule_StringPtrLen(exprarg,&exprlen);
1098 int errpos;
1099 filter_expr = exprCompile(exprstr,&errpos);
1100 if (filter_expr == NULL) {
1101 if ((size_t)errpos >= exprlen) errpos = 0;
1102 RedisModule_Free(vec);
1103 return RedisModule_ReplyWithErrorFormat(ctx,
1104 "ERR syntax error in FILTER expression near: %s",
1105 exprstr+errpos);
1106 }
1107 j += 2;
1108 } else {
1109 RedisModule_Free(vec);
1110 return RedisModule_ReplyWithError(ctx,
1111 "ERR syntax error in VSIM command");
1112 }
1113 }
1114
1115 int threaded_request = 1; // Run on a thread, by default.
1116 if (filter_ef == 0) filter_ef = count * 100; // Max filter visited nodes.
1117
1118 /* Disable threaded for MULTI/EXEC and Lua, or if explicitly
1119 * requested by the user via the NOTHREAD option. */
1120 if (no_thread || VSGlobalConfig.forceSingleThreadExec ||
1121 (RedisModule_GetContextFlags(ctx) &
1122 (REDISMODULE_CTX_FLAGS_LUA | REDISMODULE_CTX_FLAGS_MULTI)))
1123 {
1124 threaded_request = 0;
1125 }
1126
1127 if (threaded_request) {
1128 /* Note: even if we create one thread per request, the underlying
1129 * HNSW library has a fixed number of slots for the threads, as it's
1130 * defined in HNSW_MAX_THREADS (beware that if you increase it,
1131 * every node will use more memory). This means that while this request
1132 * is threaded, and will NOT block Redis, it may end waiting for a
1133 * free slot if all the HNSW_MAX_THREADS slots are used. */
1134 RedisModuleBlockedClient *bc = RedisModule_BlockClient(ctx,NULL,NULL,NULL,0);
1135 pthread_t tid;
1136 void **targ = RedisModule_Alloc(sizeof(void*)*11);
1137 targ[0] = bc;
1138 targ[1] = vset;
1139 targ[2] = vec;
1140 targ[3] = (void*)count;
1141 targ[4] = RedisModule_Alloc(sizeof(float));
1142 *((float*)targ[4]) = epsilon;
1143 targ[5] = (void*)(unsigned long)withscores;
1144 targ[6] = (void*)(unsigned long)withattribs;
1145 targ[7] = (void*)(unsigned long)ef;
1146 targ[8] = (void*)filter_expr;
1147 targ[9] = (void*)(unsigned long)filter_ef;
1148 targ[10] = (void*)(unsigned long)ground_truth;
1149 RedisModule_BlockedClientMeasureTimeStart(bc);
1150 vset->thread_creation_pending++;
1151 if (pthread_create(&tid,NULL,VSIM_thread,targ) != 0) {
1152 vset->thread_creation_pending--;
1153 RedisModule_AbortBlock(bc);
1154 RedisModule_Free(targ[4]);
1155 RedisModule_Free(targ);
1156 VSIM_execute(ctx, vset, vec, count, epsilon, withscores, withattribs, ef, filter_expr, filter_ef, ground_truth);
1157 }
1158 } else {
1159 VSIM_execute(ctx, vset, vec, count, epsilon, withscores, withattribs, ef, filter_expr, filter_ef, ground_truth);
1160 }
1161
1162 return REDISMODULE_OK;
1163}
1164
1165/* VDIM <key>: return the dimension of vectors in the vector set. */
1166int VDIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
1167 RedisModule_AutoMemory(ctx);
1168
1169 if (argc != 2) return RedisModule_WrongArity(ctx);
1170
1171 RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ);
1172 int type = RedisModule_KeyType(key);
1173
1174 if (type == REDISMODULE_KEYTYPE_EMPTY)
1175 return RedisModule_ReplyWithError(ctx, "ERR key does not exist");
1176
1177 if (RedisModule_ModuleTypeGetType(key) != VectorSetType)
1178 return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
1179
1180 struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key);
1181 return RedisModule_ReplyWithLongLong(ctx, vset->hnsw->vector_dim);
1182}
1183
1184/* VCARD <key>: return cardinality (num of elements) of the vector set. */
1185int VCARD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
1186 RedisModule_AutoMemory(ctx);
1187
1188 if (argc != 2) return RedisModule_WrongArity(ctx);
1189
1190 RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ);
1191 int type = RedisModule_KeyType(key);
1192
1193 if (type == REDISMODULE_KEYTYPE_EMPTY)
1194 return RedisModule_ReplyWithLongLong(ctx, 0);
1195
1196 if (RedisModule_ModuleTypeGetType(key) != VectorSetType)
1197 return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
1198
1199 struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key);
1200 return RedisModule_ReplyWithLongLong(ctx, vset->hnsw->node_count);
1201}
1202
1203/* VREM key element
1204 * Remove an element from a vector set.
1205 * Returns 1 if the element was found and removed, 0 if not found. */
1206int VREM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
1207 RedisModule_AutoMemory(ctx); /* Use automatic memory management. */
1208
1209 if (argc != 3) return RedisModule_WrongArity(ctx);
1210
1211 /* Get key and value */
1212 RedisModuleString *key = argv[1];
1213 RedisModuleString *element = argv[2];
1214
1215 /* Open key */
1216 RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key,
1217 REDISMODULE_READ|REDISMODULE_WRITE);
1218 int type = RedisModule_KeyType(keyptr);
1219
1220 /* Handle non-existing key or wrong type */
1221 if (type == REDISMODULE_KEYTYPE_EMPTY) {
1222 return RedisModule_ReplyWithBool(ctx, 0);
1223 }
1224 if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType) {
1225 return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
1226 }
1227
1228 /* Get vector set from key */
1229 struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr);
1230
1231 /* Find the node for this element */
1232 hnswNode *node = RedisModule_DictGet(vset->dict, element, NULL);
1233 if (!node) {
1234 return RedisModule_ReplyWithBool(ctx, 0);
1235 }
1236
1237 /* Remove from dictionary */
1238 RedisModule_DictDel(vset->dict, element, NULL);
1239
1240 /* Remove from HNSW graph using the high-level API that handles
1241 * locking and cleanup. We pass RedisModule_FreeString as the value
1242 * free function since the strings were retained at insertion time. */
1243 struct vsetNodeVal *nv = node->value;
1244 if (nv->attrib != NULL) vset->numattribs--;
1245 RedisModule_Assert(hnsw_delete_node(vset->hnsw, node, vectorSetReleaseNodeValue) == 1);
1246
1247 /* Destroy empty vector set. */
1248 if (RedisModule_DictSize(vset->dict) == 0) {
1249 RedisModule_DeleteKey(keyptr);
1250 }
1251
1252 /* Reply and propagate the command */
1253 RedisModule_ReplyWithBool(ctx, 1);
1254 RedisModule_ReplicateVerbatim(ctx);
1255 return REDISMODULE_OK;
1256}
1257
1258/* VEMB key element
1259 * Returns the embedding vector associated with an element, or NIL if not
1260 * found. The vector is returned in the same format it was added, but the
1261 * return value will have some lack of precision due to quantization and
1262 * normalization of vectors. Also, if items were added using REDUCE, the
1263 * reduced vector is returned instead. */
1264int VEMB_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
1265 RedisModule_AutoMemory(ctx);
1266 int raw_output = 0; // RAW option.
1267
1268 if (argc < 3) return RedisModule_WrongArity(ctx);
1269
1270 /* Parse arguments. */
1271 for (int j = 3; j < argc; j++) {
1272 const char *opt = RedisModule_StringPtrLen(argv[j], NULL);
1273 if (!strcasecmp(opt,"raw")) {
1274 raw_output = 1;
1275 } else {
1276 return RedisModule_ReplyWithError(ctx,"ERR invalid option");
1277 }
1278 }
1279
1280 /* Get key and element. */
1281 RedisModuleString *key = argv[1];
1282 RedisModuleString *element = argv[2];
1283
1284 /* Open key. */
1285 RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key, REDISMODULE_READ);
1286 int type = RedisModule_KeyType(keyptr);
1287
1288 /* Handle non-existing key and key of wrong type. */
1289 if (type == REDISMODULE_KEYTYPE_EMPTY) {
1290 return RedisModule_ReplyWithNull(ctx);
1291 } else if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType) {
1292 return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
1293 }
1294
1295 /* Lookup the node about the specified element. */
1296 struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr);
1297 hnswNode *node = RedisModule_DictGet(vset->dict, element, NULL);
1298 if (!node) {
1299 return RedisModule_ReplyWithNull(ctx);
1300 }
1301
1302 if (raw_output) {
1303 int output_qrange = vset->hnsw->quant_type == HNSW_QUANT_Q8;
1304 RedisModule_ReplyWithArray(ctx, 3+output_qrange);
1305 RedisModule_ReplyWithSimpleString(ctx, vectorSetGetQuantName(vset));
1306 RedisModule_ReplyWithStringBuffer(ctx, node->vector, hnsw_quants_bytes(vset->hnsw));
1307 RedisModule_ReplyWithDouble(ctx, node->l2);
1308 if (output_qrange) RedisModule_ReplyWithDouble(ctx, node->quants_range);
1309 } else {
1310 /* Get the vector associated with the node. */
1311 float *vec = RedisModule_Alloc(sizeof(float) * vset->hnsw->vector_dim);
1312 hnsw_get_node_vector(vset->hnsw, node, vec); // May dequantize/denorm.
1313
1314 /* Return as array of doubles. */
1315 RedisModule_ReplyWithArray(ctx, vset->hnsw->vector_dim);
1316 for (uint32_t i = 0; i < vset->hnsw->vector_dim; i++)
1317 RedisModule_ReplyWithDouble(ctx, vec[i]);
1318 RedisModule_Free(vec);
1319 }
1320 return REDISMODULE_OK;
1321}
1322
1323/* VSETATTR key element json
1324 * Set or remove the JSON attribute associated with an element.
1325 * Setting an empty string removes the attribute.
1326 * The command returns one if the attribute was actually updated or
1327 * zero if there is no key or element. */
1328int VSETATTR_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
1329 RedisModule_AutoMemory(ctx);
1330
1331 if (argc != 4) return RedisModule_WrongArity(ctx);
1332
1333 RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1],
1334 REDISMODULE_READ|REDISMODULE_WRITE);
1335 int type = RedisModule_KeyType(key);
1336
1337 if (type == REDISMODULE_KEYTYPE_EMPTY)
1338 return RedisModule_ReplyWithBool(ctx, 0);
1339
1340 if (RedisModule_ModuleTypeGetType(key) != VectorSetType)
1341 return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
1342
1343 struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key);
1344 hnswNode *node = RedisModule_DictGet(vset->dict, argv[2], NULL);
1345 if (!node)
1346 return RedisModule_ReplyWithBool(ctx, 0);
1347
1348 struct vsetNodeVal *nv = node->value;
1349 RedisModuleString *new_attr = argv[3];
1350
1351 /* Background VSIM operations use the node attributes, so
1352 * wait for background operations before messing with them. */
1353 vectorSetWaitAllBackgroundClients(vset,0);
1354
1355 /* Set or delete the attribute based on the fact it's an empty
1356 * string or not. */
1357 size_t attrlen;
1358 RedisModule_StringPtrLen(new_attr, &attrlen);
1359 if (attrlen == 0) {
1360 // If we had an attribute before, decrease the count and free it.
1361 if (nv->attrib) {
1362 vset->numattribs--;
1363 RedisModule_FreeString(NULL, nv->attrib);
1364 nv->attrib = NULL;
1365 }
1366 } else {
1367 // If we didn't have an attribute before, increase the count.
1368 // Otherwise free the old one.
1369 if (nv->attrib) {
1370 RedisModule_FreeString(NULL, nv->attrib);
1371 } else {
1372 vset->numattribs++;
1373 }
1374 // Set new attribute.
1375 RedisModule_RetainString(NULL, new_attr);
1376 nv->attrib = new_attr;
1377 }
1378
1379 RedisModule_ReplyWithBool(ctx, 1);
1380 RedisModule_ReplicateVerbatim(ctx);
1381 return REDISMODULE_OK;
1382}
1383
1384/* VGETATTR key element
1385 * Get the JSON attribute associated with an element.
1386 * Returns NIL if the element has no attribute or doesn't exist. */
1387int VGETATTR_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
1388 RedisModule_AutoMemory(ctx);
1389
1390 if (argc != 3) return RedisModule_WrongArity(ctx);
1391
1392 RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ);
1393 int type = RedisModule_KeyType(key);
1394
1395 if (type == REDISMODULE_KEYTYPE_EMPTY)
1396 return RedisModule_ReplyWithNull(ctx);
1397
1398 if (RedisModule_ModuleTypeGetType(key) != VectorSetType)
1399 return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
1400
1401 struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key);
1402 hnswNode *node = RedisModule_DictGet(vset->dict, argv[2], NULL);
1403 if (!node)
1404 return RedisModule_ReplyWithNull(ctx);
1405
1406 struct vsetNodeVal *nv = node->value;
1407 if (!nv->attrib)
1408 return RedisModule_ReplyWithNull(ctx);
1409
1410 return RedisModule_ReplyWithString(ctx, nv->attrib);
1411}
1412
1413/* ============================== Reflection ================================ */
1414
1415/* VLINKS key element [WITHSCORES]
1416 * Returns the neighbors of an element at each layer in the HNSW graph.
1417 * Reply is an array of arrays, where each nested array represents one level
1418 * of neighbors, from highest level to level 0. If WITHSCORES is specified,
1419 * each neighbor is followed by its distance from the element. */
1420int VLINKS_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
1421 RedisModule_AutoMemory(ctx);
1422
1423 if (argc < 3 || argc > 4) return RedisModule_WrongArity(ctx);
1424
1425 RedisModuleString *key = argv[1];
1426 RedisModuleString *element = argv[2];
1427
1428 /* Parse WITHSCORES option. */
1429 int withscores = 0;
1430 if (argc == 4) {
1431 const char *opt = RedisModule_StringPtrLen(argv[3], NULL);
1432 if (strcasecmp(opt, "WITHSCORES") != 0) {
1433 return RedisModule_WrongArity(ctx);
1434 }
1435 withscores = 1;
1436 }
1437
1438 RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key, REDISMODULE_READ);
1439 int type = RedisModule_KeyType(keyptr);
1440
1441 /* Handle non-existing key or wrong type. */
1442 if (type == REDISMODULE_KEYTYPE_EMPTY)
1443 return RedisModule_ReplyWithNull(ctx);
1444
1445 if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType)
1446 return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
1447
1448 /* Find the node for this element. */
1449 struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr);
1450 hnswNode *node = RedisModule_DictGet(vset->dict, element, NULL);
1451 if (!node)
1452 return RedisModule_ReplyWithNull(ctx);
1453
1454 /* Reply with array of arrays, one per level. */
1455 RedisModule_ReplyWithArray(ctx, node->level + 1);
1456
1457 /* For each level, from highest to lowest: */
1458 for (int i = node->level; i >= 0; i--) {
1459 /* Reply with array of neighbors at this level. */
1460 if (withscores)
1461 RedisModule_ReplyWithMap(ctx,node->layers[i].num_links);
1462 else
1463 RedisModule_ReplyWithArray(ctx,node->layers[i].num_links);
1464
1465 /* Add each neighbor's element value to the array. */
1466 for (uint32_t j = 0; j < node->layers[i].num_links; j++) {
1467 struct vsetNodeVal *nv = node->layers[i].links[j]->value;
1468 RedisModule_ReplyWithString(ctx, nv->item);
1469 if (withscores) {
1470 float distance = hnsw_distance(vset->hnsw, node, node->layers[i].links[j]);
1471 /* Convert distance to similarity score to match
1472 * VSIM behavior.*/
1473 float similarity = 1.0 - distance/2.0;
1474 RedisModule_ReplyWithDouble(ctx, similarity);
1475 }
1476 }
1477 }
1478 return REDISMODULE_OK;
1479}
1480
1481/* VINFO key
1482 * Returns information about a vector set, both visible and hidden
1483 * features of the HNSW data structure. */
1484int VINFO_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
1485 RedisModule_AutoMemory(ctx);
1486
1487 if (argc != 2) return RedisModule_WrongArity(ctx);
1488
1489 RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ);
1490 int type = RedisModule_KeyType(key);
1491
1492 if (type == REDISMODULE_KEYTYPE_EMPTY)
1493 return RedisModule_ReplyWithNullArray(ctx);
1494
1495 if (RedisModule_ModuleTypeGetType(key) != VectorSetType)
1496 return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
1497
1498 struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key);
1499
1500 /* Reply with hash */
1501 RedisModule_ReplyWithMap(ctx, 9);
1502
1503 /* Quantization type */
1504 RedisModule_ReplyWithSimpleString(ctx, "quant-type");
1505 RedisModule_ReplyWithSimpleString(ctx, vectorSetGetQuantName(vset));
1506
1507 /* HNSW M value */
1508 RedisModule_ReplyWithSimpleString(ctx, "hnsw-m");
1509 RedisModule_ReplyWithLongLong(ctx, vset->hnsw->M);
1510
1511 /* Vector dimensionality. */
1512 RedisModule_ReplyWithSimpleString(ctx, "vector-dim");
1513 RedisModule_ReplyWithLongLong(ctx, vset->hnsw->vector_dim);
1514
1515 /* Original input dimension before projection.
1516 * This is zero for vector sets without a random projection matrix. */
1517 RedisModule_ReplyWithSimpleString(ctx, "projection-input-dim");
1518 RedisModule_ReplyWithLongLong(ctx, vset->proj_input_size);
1519
1520 /* Number of elements. */
1521 RedisModule_ReplyWithSimpleString(ctx, "size");
1522 RedisModule_ReplyWithLongLong(ctx, vset->hnsw->node_count);
1523
1524 /* Max level of HNSW. */
1525 RedisModule_ReplyWithSimpleString(ctx, "max-level");
1526 RedisModule_ReplyWithLongLong(ctx, vset->hnsw->max_level);
1527
1528 /* Number of nodes with attributes. */
1529 RedisModule_ReplyWithSimpleString(ctx, "attributes-count");
1530 RedisModule_ReplyWithLongLong(ctx, vset->numattribs);
1531
1532 /* Vector set ID. */
1533 RedisModule_ReplyWithSimpleString(ctx, "vset-uid");
1534 RedisModule_ReplyWithLongLong(ctx, vset->id);
1535
1536 /* HNSW max node ID. */
1537 RedisModule_ReplyWithSimpleString(ctx, "hnsw-max-node-uid");
1538 RedisModule_ReplyWithLongLong(ctx, vset->hnsw->last_id);
1539
1540 return REDISMODULE_OK;
1541}
1542
1543/* VRANDMEMBER key [count]
1544 * Return random members from a vector set.
1545 *
1546 * Without count: returns a single random member.
1547 * With positive count: N unique random members (no duplicates).
1548 * With negative count: N random members (with possible duplicates).
1549 *
1550 * If the key doesn't exist, returns NULL if count is not given, or
1551 * an empty array if a count was given. */
1552int VRANDMEMBER_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
1553 RedisModule_AutoMemory(ctx); /* Use automatic memory management. */
1554
1555 /* Check arguments. */
1556 if (argc != 2 && argc != 3) return RedisModule_WrongArity(ctx);
1557
1558 /* Parse optional count argument. */
1559 long long count = 1; /* Default is to return a single element. */
1560 int with_count = (argc == 3);
1561
1562 if (with_count) {
1563 if (RedisModule_StringToLongLong(argv[2], &count) != REDISMODULE_OK) {
1564 return RedisModule_ReplyWithError(ctx,
1565 "ERR COUNT value is not an integer");
1566 }
1567 /* Count = 0 is a special case, return empty array */
1568 if (count == 0) {
1569 return RedisModule_ReplyWithEmptyArray(ctx);
1570 }
1571 }
1572
1573 /* Open key. */
1574 RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ);
1575 int type = RedisModule_KeyType(key);
1576
1577 /* Handle non-existing key. */
1578 if (type == REDISMODULE_KEYTYPE_EMPTY) {
1579 if (!with_count) {
1580 return RedisModule_ReplyWithNull(ctx);
1581 } else {
1582 return RedisModule_ReplyWithEmptyArray(ctx);
1583 }
1584 }
1585
1586 /* Check key type. */
1587 if (RedisModule_ModuleTypeGetType(key) != VectorSetType) {
1588 return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
1589 }
1590
1591 /* Get vector set from key. */
1592 struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key);
1593 uint64_t set_size = vset->hnsw->node_count;
1594
1595 /* No elements in the set? */
1596 if (set_size == 0) {
1597 if (!with_count) {
1598 return RedisModule_ReplyWithNull(ctx);
1599 } else {
1600 return RedisModule_ReplyWithEmptyArray(ctx);
1601 }
1602 }
1603
1604 /* Case 1: No count specified: return a single element. */
1605 if (!with_count) {
1606 hnswNode *random_node = hnsw_random_node(vset->hnsw, 0);
1607 if (random_node) {
1608 struct vsetNodeVal *nv = random_node->value;
1609 return RedisModule_ReplyWithString(ctx, nv->item);
1610 } else {
1611 return RedisModule_ReplyWithNull(ctx);
1612 }
1613 }
1614
1615 /* Case 2: COUNT option given, return an array of elements. */
1616 int allow_duplicates = (count < 0);
1617 long long abs_count = (count < 0) ? -count : count;
1618
1619 /* Cap the count to the set size if we are not allowing duplicates. */
1620 if (!allow_duplicates && abs_count > (long long)set_size)
1621 abs_count = set_size;
1622
1623 /* Prepare reply. */
1624 RedisModule_ReplyWithArray(ctx, abs_count);
1625
1626 if (allow_duplicates) {
1627 /* Simple case: With duplicates, just pick random nodes
1628 * abs_count times. */
1629 for (long long i = 0; i < abs_count; i++) {
1630 hnswNode *random_node = hnsw_random_node(vset->hnsw,0);
1631 struct vsetNodeVal *nv = random_node->value;
1632 RedisModule_ReplyWithString(ctx, nv->item);
1633 }
1634 } else {
1635 /* Case where count is positive: we need unique elements.
1636 * But, if the user asked for many elements, selecting so
1637 * many (> 20%) random nodes may be too expansive: we just start
1638 * from a random element and follow the next link.
1639 *
1640 * Otherwisem for the <= 20% case, a dictionary is used to
1641 * reject duplicates. */
1642 int use_dict = (abs_count <= set_size * 0.2);
1643
1644 if (use_dict) {
1645 RedisModuleDict *returned = RedisModule_CreateDict(ctx);
1646
1647 long long returned_count = 0;
1648 while (returned_count < abs_count) {
1649 hnswNode *random_node = hnsw_random_node(vset->hnsw, 0);
1650 struct vsetNodeVal *nv = random_node->value;
1651
1652 /* Check if we've already returned this element. */
1653 if (RedisModule_DictGet(returned, nv->item, NULL) == NULL) {
1654 /* Mark as returned and add to results. */
1655 RedisModule_DictSet(returned, nv->item, (void*)1);
1656 RedisModule_ReplyWithString(ctx, nv->item);
1657 returned_count++;
1658 }
1659 }
1660 RedisModule_FreeDict(ctx, returned);
1661 } else {
1662 /* For large samples, get a random starting node and walk
1663 * the list.
1664 *
1665 * IMPORTANT: doing so does not really generate random
1666 * elements: it's just a linear scan, but we have no choices.
1667 * If we generate too many random elements, more and more would
1668 * fail the check of being novel (not yet collected in the set
1669 * to return) if the % of elements to emit is too large, we would
1670 * spend too much CPU. */
1671 hnswNode *start_node = hnsw_random_node(vset->hnsw, 0);
1672 hnswNode *current = start_node;
1673
1674 long long returned_count = 0;
1675 while (returned_count < abs_count) {
1676 if (current == NULL) {
1677 /* Restart from head if we hit the end. */
1678 current = vset->hnsw->head;
1679 }
1680 struct vsetNodeVal *nv = current->value;
1681 RedisModule_ReplyWithString(ctx, nv->item);
1682 returned_count++;
1683 current = current->next;
1684 }
1685 }
1686 }
1687 return REDISMODULE_OK;
1688}
1689
1690/* VISMEMBER key element
1691 * Check if an element exists in a vector set.
1692 * Returns 1 if the element exists, 0 if not. */
1693int VISMEMBER_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
1694 RedisModule_AutoMemory(ctx);
1695 if (argc != 3) return RedisModule_WrongArity(ctx);
1696
1697 RedisModuleString *key = argv[1];
1698 RedisModuleString *element = argv[2];
1699
1700 /* Open key. */
1701 RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key, REDISMODULE_READ);
1702 int type = RedisModule_KeyType(keyptr);
1703
1704 /* Handle non-existing key or wrong type. */
1705 if (type == REDISMODULE_KEYTYPE_EMPTY) {
1706 /* An element of a non existing key does not exist, like
1707 * SISMEMBER & similar. */
1708 return RedisModule_ReplyWithBool(ctx, 0);
1709 }
1710 if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType) {
1711 return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
1712 }
1713
1714 /* Get the object and test membership via the dictionary in constant
1715 * time (assuming a member of average size). */
1716 struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr);
1717 hnswNode *node = RedisModule_DictGet(vset->dict, element, NULL);
1718 return RedisModule_ReplyWithBool(ctx, node != NULL);
1719}
1720
/* Structure to represent a VRANGE boundary, as filled by
 * vsetParseRangeOp(). Note that `ele` points into the command argument
 * string (it is not a copy), so it is valid only as long as that
 * argument is alive. */
struct vsetRangeOp {
    int incl; /* 1 if inclusive ([), 0 if exclusive ((). */
    int min; /* 1 if this is "-" (minimum). */
    int max; /* 1 if this is "+" (maximum). */
    unsigned char *ele; /* The actual element, NULL if min/max. */
    size_t ele_len; /* Length of the element. */
};
1729
1730/* Parse a range specification like "[foo" or "(bar" or "-" or "+".
1731 * Returns 1 on success, 0 on error. */
1732int vsetParseRangeOp(RedisModuleString *arg, struct vsetRangeOp *op) {
1733 size_t len;
1734 const char *str = RedisModule_StringPtrLen(arg, &len);
1735
1736 if (len == 0) return 0;
1737
1738 /* Initialize the structure. */
1739 op->incl = 0;
1740 op->min = 0;
1741 op->max = 0;
1742 op->ele = NULL;
1743 op->ele_len = 0;
1744
1745 /* Check for special cases "-" and "+". */
1746 if (len == 1 && str[0] == '-') {
1747 op->min = 1;
1748 return 1;
1749 }
1750 if (len == 1 && str[0] == '+') {
1751 op->max = 1;
1752 return 1;
1753 }
1754
1755 /* Otherwise, must start with ( or [. */
1756 if (str[0] == '[') {
1757 op->incl = 1;
1758 } else if (str[0] == '(') {
1759 op->incl = 0;
1760 } else {
1761 return 0; /* Invalid format. */
1762 }
1763
1764 /* Extract the string part after the bracket. */
1765 if (len > 1) {
1766 op->ele = (unsigned char *)(str + 1);
1767 op->ele_len = len - 1;
1768 } else {
1769 return 0; /* Just a bracket with no string. */
1770 }
1771
1772 return 1;
1773}
1774
1775/* Check if the current element is within the range defined by the end operator.
1776 * Returns 1 if the element is within range, 0 if it has passed the end. */
1777int vsetIsElementInRange(const void *ele, size_t ele_len, struct vsetRangeOp *end_op) {
1778 /* If end is "+", element is always in range. */
1779 if (end_op->max) return 1;
1780
1781 /* Compare current element with end boundary. */
1782 size_t minlen = ele_len < end_op->ele_len ? ele_len : end_op->ele_len;
1783 int cmp = memcmp(ele, end_op->ele, minlen);
1784
1785 if (cmp == 0) {
1786 /* If equal up to minlen, shorter string is smaller. */
1787 if (ele_len < end_op->ele_len) {
1788 cmp = -1;
1789 } else if (ele_len > end_op->ele_len) {
1790 cmp = 1;
1791 }
1792 }
1793
1794 /* Check based on inclusive/exclusive. */
1795 if (end_op->incl) {
1796 return cmp <= 0; /* Inclusive: element <= end. */
1797 } else {
1798 return cmp < 0; /* Exclusive: element < end. */
1799 }
1800}
1801
1802/* VRANGE key start end [count]
1803 * Returns elements in the lexicographical range [start, end]
1804 *
1805 * Elements must be specified in one of the following forms:
1806 *
1807 * [myelement
1808 * (myelement
1809 * +
1810 * -
1811 *
1812 * Elements starting with [ are inclusive, so "myelement" would be
1813 * returned if present in the set. Elements starting with ( are exclusive
1814 * ranges instead. The special - and + elements mean the minimum and maximum
1815 * possible element (inclusive), so "VRANGE key - +" will return everything
1816 * (depending on COUNT of course). The special - element can be used only
1817 * as starting element, the special + element only as ending element. */
1818int VRANGE_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
1819 RedisModule_AutoMemory(ctx);
1820
1821 /* Check arguments. */
1822 if (argc < 4 || argc > 5) return RedisModule_WrongArity(ctx);
1823
1824 /* Parse COUNT if provided. */
1825 long long count = -1; /* Default: return all elements. */
1826 if (argc == 5) {
1827 if (RedisModule_StringToLongLong(argv[4], &count) != REDISMODULE_OK) {
1828 return RedisModule_ReplyWithError(ctx, "ERR invalid COUNT value");
1829 }
1830 }
1831
1832 /* Parse range operators. */
1833 struct vsetRangeOp start_op, end_op;
1834 if (!vsetParseRangeOp(argv[2], &start_op)) {
1835 return RedisModule_ReplyWithError(ctx, "ERR invalid start range format");
1836 }
1837 if (!vsetParseRangeOp(argv[3], &end_op)) {
1838 return RedisModule_ReplyWithError(ctx, "ERR invalid end range format");
1839 }
1840
1841 /* Validate: "-" can only be first arg, "+" can only be second. */
1842 if (start_op.max || end_op.min) {
1843 return RedisModule_ReplyWithError(ctx,
1844 "ERR '-' can only be used as first argument, '+' only as second");
1845 }
1846
1847 /* Open the key. */
1848 RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ);
1849 int type = RedisModule_KeyType(key);
1850
1851 if (type == REDISMODULE_KEYTYPE_EMPTY) {
1852 return RedisModule_ReplyWithEmptyArray(ctx);
1853 }
1854
1855 if (RedisModule_ModuleTypeGetType(key) != VectorSetType) {
1856 return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
1857 }
1858
1859 struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key);
1860
1861 /* Start the iterator. */
1862 RedisModuleDictIter *iter;
1863 if (start_op.min) {
1864 /* Start from the beginning. */
1865 iter = RedisModule_DictIteratorStartC(vset->dict, "^", NULL, 0);
1866 } else {
1867 /* Start from the specified element. */
1868 const char *op = start_op.incl ? ">=" : ">";
1869 iter = RedisModule_DictIteratorStartC(vset->dict, op, start_op.ele, start_op.ele_len);
1870 }
1871
1872 /* Collect results. */
1873 RedisModule_ReplyWithArray(ctx, REDISMODULE_POSTPONED_LEN);
1874 long long returned = 0;
1875
1876 void *key_data;
1877 size_t key_len;
1878 while ((key_data = RedisModule_DictNextC(iter, &key_len, NULL)) != NULL) {
1879 /* Check if we've collected enough elements. */
1880 if (count >= 0 && returned >= count) break;
1881
1882 /* Check if we've passed the end range. */
1883 if (!vsetIsElementInRange(key_data, key_len, &end_op)) break;
1884
1885 /* Add this element to the result. */
1886 RedisModule_ReplyWithStringBuffer(ctx, key_data, key_len);
1887 returned++;
1888 }
1889
1890 RedisModule_ReplySetArrayLength(ctx, returned);
1891
1892 /* Cleanup. */
1893 RedisModule_DictIteratorStop(iter);
1894
1895 return REDISMODULE_OK;
1896}
1897
1898/* ============================== vset type methods ========================= */
1899
#define SAVE_FLAG_HAS_PROJMATRIX (1<<0) /* RDB payload includes a projection matrix. */
#define SAVE_FLAG_HAS_ATTRIBS (1<<1)    /* RDB payload includes per-element attributes. */
1902
1903/* Save object to RDB */
1904void VectorSetRdbSave(RedisModuleIO *rdb, void *value) {
1905 struct vsetObject *vset = value;
1906 RedisModule_SaveUnsigned(rdb, vset->hnsw->vector_dim);
1907 RedisModule_SaveUnsigned(rdb, vset->hnsw->node_count);
1908
1909 uint32_t hnsw_config = (vset->hnsw->quant_type & 0xff) |
1910 ((vset->hnsw->M & 0xffff) << 8);
1911 RedisModule_SaveUnsigned(rdb, hnsw_config);
1912
1913 uint32_t save_flags = 0;
1914 if (vset->proj_matrix) save_flags |= SAVE_FLAG_HAS_PROJMATRIX;
1915 if (vset->numattribs != 0) save_flags |= SAVE_FLAG_HAS_ATTRIBS;
1916 RedisModule_SaveUnsigned(rdb, save_flags);
1917
1918 /* Save projection matrix if present */
1919 if (vset->proj_matrix) {
1920 uint32_t input_dim = vset->proj_input_size;
1921 uint32_t output_dim = vset->hnsw->vector_dim;
1922 RedisModule_SaveUnsigned(rdb, input_dim);
1923 // Output dim is the same as the first value saved
1924 // above, so we don't save it.
1925
1926 // Save projection matrix as binary blob
1927 size_t matrix_size = sizeof(float) * input_dim * output_dim;
1928 RedisModule_SaveStringBuffer(rdb, (const char *)vset->proj_matrix, matrix_size);
1929 }
1930
1931 hnswNode *node = vset->hnsw->head;
1932 while(node) {
1933 struct vsetNodeVal *nv = node->value;
1934 RedisModule_SaveString(rdb, nv->item);
1935 if (vset->numattribs) {
1936 if (nv->attrib)
1937 RedisModule_SaveString(rdb, nv->attrib);
1938 else
1939 RedisModule_SaveStringBuffer(rdb, "", 0);
1940 }
1941 hnswSerNode *sn = hnsw_serialize_node(vset->hnsw,node);
1942 RedisModule_SaveStringBuffer(rdb, (const char *)sn->vector, sn->vector_size);
1943 RedisModule_SaveUnsigned(rdb, sn->params_count);
1944 for (uint32_t j = 0; j < sn->params_count; j++)
1945 RedisModule_SaveUnsigned(rdb, sn->params[j]);
1946 hnsw_free_serialized_node(sn);
1947 node = node->next;
1948 }
1949}
1950
1951/* Load object from RDB. Recover from recoverable errors (read errors)
1952 * by performing cleanup. */
1953void *VectorSetRdbLoad(RedisModuleIO *rdb, int encver) {
1954 if (encver != 0) return NULL; // Invalid version
1955
1956 uint32_t dim = RedisModule_LoadUnsigned(rdb);
1957 uint64_t elements = RedisModule_LoadUnsigned(rdb);
1958 uint32_t hnsw_config = RedisModule_LoadUnsigned(rdb);
1959 if (RedisModule_IsIOError(rdb)) return NULL;
1960 uint32_t quant_type = hnsw_config & 0xff;
1961 uint32_t hnsw_m = (hnsw_config >> 8) & 0xffff;
1962
1963 /* Check that the quantization type is correct. Otherwise
1964 * return ASAP signaling the error. */
1965 if (quant_type != HNSW_QUANT_NONE &&
1966 quant_type != HNSW_QUANT_Q8 &&
1967 quant_type != HNSW_QUANT_BIN) return NULL;
1968
1969 if (hnsw_m == 0) hnsw_m = 16; // Default, useful for RDB files predating
1970 // this configuration parameter: it was fixed
1971 // to 16.
1972 struct vsetObject *vset = createVectorSetObject(dim,quant_type,hnsw_m);
1973 RedisModule_Assert(vset != NULL);
1974
1975 /* Load projection matrix if present */
1976 uint32_t save_flags = RedisModule_LoadUnsigned(rdb);
1977 if (RedisModule_IsIOError(rdb)) goto ioerr;
1978 int has_projection = save_flags & SAVE_FLAG_HAS_PROJMATRIX;
1979 int has_attribs = save_flags & SAVE_FLAG_HAS_ATTRIBS;
1980 if (has_projection) {
1981 uint32_t input_dim = RedisModule_LoadUnsigned(rdb);
1982 if (RedisModule_IsIOError(rdb)) goto ioerr;
1983 uint32_t output_dim = dim;
1984 size_t matrix_size = sizeof(float) * input_dim * output_dim;
1985
1986 vset->proj_matrix = RedisModule_Alloc(matrix_size);
1987 vset->proj_input_size = input_dim;
1988
1989 // Load projection matrix as a binary blob
1990 char *matrix_blob = RedisModule_LoadStringBuffer(rdb, NULL);
1991 if (matrix_blob == NULL) goto ioerr;
1992 memcpy(vset->proj_matrix, matrix_blob, matrix_size);
1993 RedisModule_Free(matrix_blob);
1994 }
1995
1996 while(elements--) {
1997 // Load associated string element.
1998 RedisModuleString *ele = RedisModule_LoadString(rdb);
1999 if (RedisModule_IsIOError(rdb)) goto ioerr;
2000 RedisModuleString *attrib = NULL;
2001 if (has_attribs) {
2002 attrib = RedisModule_LoadString(rdb);
2003 if (RedisModule_IsIOError(rdb)) {
2004 RedisModule_FreeString(NULL,ele);
2005 goto ioerr;
2006 }
2007 size_t attrlen;
2008 RedisModule_StringPtrLen(attrib,&attrlen);
2009 if (attrlen == 0) {
2010 RedisModule_FreeString(NULL,attrib);
2011 attrib = NULL;
2012 }
2013 }
2014 size_t vector_len;
2015 void *vector = RedisModule_LoadStringBuffer(rdb, &vector_len);
2016 if (RedisModule_IsIOError(rdb)) {
2017 RedisModule_FreeString(NULL,ele);
2018 if (attrib) RedisModule_FreeString(NULL,attrib);
2019 goto ioerr;
2020 }
2021 uint32_t vector_bytes = hnsw_quants_bytes(vset->hnsw);
2022 if (vector_len != vector_bytes) {
2023 RedisModule_LogIOError(rdb,"warning",
2024 "Mismatching vector dimension");
2025 RedisModule_FreeString(NULL,ele);
2026 if (attrib) RedisModule_FreeString(NULL,attrib);
2027 RedisModule_Free(vector);
2028 goto ioerr;
2029 }
2030
2031 // Load node parameters back.
2032 uint32_t params_count = RedisModule_LoadUnsigned(rdb);
2033 if (RedisModule_IsIOError(rdb)) {
2034 RedisModule_FreeString(NULL,ele);
2035 if (attrib) RedisModule_FreeString(NULL,attrib);
2036 RedisModule_Free(vector);
2037 goto ioerr;
2038 }
2039
2040 uint64_t *params = RedisModule_Alloc(params_count*sizeof(uint64_t));
2041 for (uint32_t j = 0; j < params_count; j++) {
2042 // Ignore loading errors here: handled at the end of the loop.
2043 params[j] = RedisModule_LoadUnsigned(rdb);
2044 }
2045 if (RedisModule_IsIOError(rdb)) {
2046 RedisModule_FreeString(NULL,ele);
2047 if (attrib) RedisModule_FreeString(NULL,attrib);
2048 RedisModule_Free(vector);
2049 RedisModule_Free(params);
2050 goto ioerr;
2051 }
2052
2053 struct vsetNodeVal *nv = RedisModule_Alloc(sizeof(*nv));
2054 nv->item = ele;
2055 nv->attrib = attrib;
2056 hnswNode *node = hnsw_insert_serialized(vset->hnsw, vector, params, params_count, nv);
2057 if (node == NULL) {
2058 RedisModule_LogIOError(rdb,"warning",
2059 "Vector set node index loading error");
2060 vectorSetReleaseNodeValue(nv);
2061 RedisModule_Free(vector);
2062 RedisModule_Free(params);
2063 goto ioerr;
2064 }
2065 if (nv->attrib) vset->numattribs++;
2066 RedisModule_DictSet(vset->dict,ele,node);
2067 RedisModule_Free(vector);
2068 RedisModule_Free(params);
2069 }
2070
2071 uint64_t salt[2];
2072 RedisModule_GetRandomBytes((unsigned char*)salt,sizeof(salt));
2073 if (!hnsw_deserialize_index(vset->hnsw, salt[0], salt[1])) goto ioerr;
2074
2075 return vset;
2076
2077ioerr:
2078 /* We want to recover from I/O errors and free the partially allocated
2079 * data structure to support diskless replication. */
2080 vectorSetReleaseObject(vset);
2081 return NULL;
2082}
2083
2084/* Calculate memory usage */
2085size_t VectorSetMemUsage(const void *value) {
2086 const struct vsetObject *vset = value;
2087 size_t size = sizeof(*vset);
2088
2089 /* Account for HNSW index base structure */
2090 size += sizeof(HNSW);
2091
2092 /* Account for projection matrix if present */
2093 if (vset->proj_matrix) {
2094 /* For the matrix size, we need the input dimension. We can get it
2095 * from the first node if the set is not empty. */
2096 uint32_t input_dim = vset->proj_input_size;
2097 uint32_t output_dim = vset->hnsw->vector_dim;
2098 size += sizeof(float) * input_dim * output_dim;
2099 }
2100
2101 /* Account for each node's memory usage. */
2102 hnswNode *node = vset->hnsw->head;
2103 if (node == NULL) return size;
2104
2105 /* Base node structure. */
2106 size += sizeof(*node) * vset->hnsw->node_count;
2107
2108 /* Vector storage. */
2109 uint64_t vec_storage = hnsw_quants_bytes(vset->hnsw);
2110 size += vec_storage * vset->hnsw->node_count;
2111
2112 /* Layers array. We use 1.33 as average nodes layers count. */
2113 uint64_t layers_storage = sizeof(hnswNodeLayer) * vset->hnsw->node_count;
2114 layers_storage = layers_storage * 4 / 3; // 1.33 times.
2115 size += layers_storage;
2116
2117 /* All the nodes have layer 0 links. */
2118 uint64_t level0_links = node->layers[0].max_links;
2119 uint64_t other_levels_links = level0_links/2;
2120 size += sizeof(hnswNode*) * level0_links * vset->hnsw->node_count;
2121
2122 /* Add the 0.33 remaining part, but upper layers have less links. */
2123 size += (sizeof(hnswNode*) * other_levels_links * vset->hnsw->node_count)/3;
2124
2125 /* Associated string value and attributres.
2126 * Use Redis Module API to get string size, and guess that all the
2127 * elements have similar size as the first few. */
2128 size_t items_scanned = 0, items_size = 0;
2129 size_t attribs_scanned = 0, attribs_size = 0;
2130 int scan_effort = 20;
2131 while(scan_effort > 0 && node) {
2132 struct vsetNodeVal *nv = node->value;
2133 items_size += RedisModule_MallocSizeString(nv->item);
2134 items_scanned++;
2135 if (nv->attrib) {
2136 attribs_size += RedisModule_MallocSizeString(nv->attrib);
2137 attribs_scanned++;
2138 }
2139 scan_effort--;
2140 node = node->next;
2141 }
2142
2143 /* Add the memory usage due to items. */
2144 if (items_scanned)
2145 size += items_size / items_scanned * vset->hnsw->node_count;
2146
2147 /* Add memory usage due to attributres. */
2148 if (attribs_scanned == 0) {
2149 /* We were not lucky enough to find a single attribute in the
2150 * first few items? Let's use a fixed arbitrary value. */
2151 attribs_scanned = 1;
2152 attribs_size = 64;
2153 }
2154 size += attribs_size / attribs_scanned * vset->numattribs;
2155
2156 /* Account for dictionary overhead - this is an approximation. */
2157 size += RedisModule_DictSize(vset->dict) * (sizeof(void*) * 2);
2158
2159 return size;
2160}
2161
/* Release a vector set: first wait (blocking) for any background
 * clients still using it, then free the whole data structure. */
void VectorSetFree(void *value) {
    struct vsetObject *vset = value;
    vectorSetWaitAllBackgroundClients(vset,1);
    vectorSetReleaseObject(vset);
}
2169
2170/* Add object digest to the digest context */
2171void VectorSetDigest(RedisModuleDigest *md, void *value) {
2172 struct vsetObject *vset = value;
2173
2174 /* Add consistent order-independent hash of all vectors */
2175 hnswNode *node = vset->hnsw->head;
2176
2177 /* Hash the vector dimension and number of nodes. */
2178 RedisModule_DigestAddLongLong(md, vset->hnsw->node_count);
2179 RedisModule_DigestAddLongLong(md, vset->hnsw->vector_dim);
2180 RedisModule_DigestEndSequence(md);
2181
2182 while(node) {
2183 struct vsetNodeVal *nv = node->value;
2184 /* Hash each vector component */
2185 RedisModule_DigestAddStringBuffer(md, node->vector, hnsw_quants_bytes(vset->hnsw));
2186 /* Hash the associated value */
2187 size_t len;
2188 const char *str = RedisModule_StringPtrLen(nv->item, &len);
2189 RedisModule_DigestAddStringBuffer(md, (char*)str, len);
2190 if (nv->attrib) {
2191 str = RedisModule_StringPtrLen(nv->attrib, &len);
2192 RedisModule_DigestAddStringBuffer(md, (char*)str, len);
2193 }
2194 node = node->next;
2195 RedisModule_DigestEndSequence(md);
2196 }
2197}
2198
2199// int VectorSets_InitModuleConfig(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
2200int VectorSets_InitModuleConfig(RedisModuleCtx *ctx) {
2201 if (RegisterModuleConfig(ctx) == REDISMODULE_ERR) {
2202 RedisModule_Log(ctx, "warning", "Error registering module configuration");
2203 return REDISMODULE_ERR;
2204 }
2205 // Load default values
2206 if (RedisModule_LoadDefaultConfigs(ctx) == REDISMODULE_ERR) {
2207 RedisModule_Log(ctx, "warning", "Error loading default module configuration");
2208 return REDISMODULE_ERR;
2209 } else {
2210 RedisModule_Log(ctx, "verbose", "Successfully loaded default module configuration");
2211 }
2212 if (RedisModule_LoadConfigs(ctx) == REDISMODULE_ERR) {
2213 RedisModule_Log(ctx, "warning", "Error loading user module configuration");
2214 return REDISMODULE_ERR;
2215 } else {
2216 RedisModule_Log(ctx, "verbose", "Successfully loaded user module configuration");
2217 }
2218 return REDISMODULE_OK;
2219}
2220
/* This function must be present on each Redis module. It is used in order to
 * register the commands into the Redis server.
 *
 * Registration order: module init -> module configuration -> module
 * options -> the "vectorset" data type -> one CreateCommand() +
 * SetCommandInfo() pair per command -> HNSW allocator hookup.
 * Any failure aborts the load with REDISMODULE_ERR. */
int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
    REDISMODULE_NOT_USED(argv);
    REDISMODULE_NOT_USED(argc);

    if (RedisModule_Init(ctx,"vectorset",1,REDISMODULE_APIVER_1)
        == REDISMODULE_ERR) return REDISMODULE_ERR;

    if (VectorSets_InitModuleConfig(ctx) == REDISMODULE_ERR) {
        return REDISMODULE_ERR;
    }

    /* Handle RDB I/O errors ourselves (see the ioerr recovery path in
     * VectorSetRdbLoad) and support diskless/async replication loads. */
    RedisModule_SetModuleOptions(ctx, REDISMODULE_OPTIONS_HANDLE_IO_ERRORS|REDISMODULE_OPTIONS_HANDLE_REPL_ASYNC_LOAD);

    /* Callbacks for the "vectorset" data type. No AOF rewrite callback:
     * NOTE(review): presumably AOF rewrite is covered elsewhere (e.g.
     * RDB-format AOF preamble) — confirm before relying on it. */
    RedisModuleTypeMethods tm = {
        .version = REDISMODULE_TYPE_METHOD_VERSION,
        .rdb_load = VectorSetRdbLoad,
        .rdb_save = VectorSetRdbSave,
        .aof_rewrite = NULL,
        .mem_usage = VectorSetMemUsage,
        .free = VectorSetFree,
        .digest = VectorSetDigest
    };

    VectorSetType = RedisModule_CreateDataType(ctx,"vectorset",0,&tm);
    if (VectorSetType == NULL) return REDISMODULE_ERR;

    // Register command VADD
    if (RedisModule_CreateCommand(ctx,"VADD",
        VADD_RedisCommand,"write deny-oom",1,1,1) == REDISMODULE_ERR)
        return REDISMODULE_ERR;

    RedisModuleCommand *vadd_cmd = RedisModule_GetCommand(ctx, "VADD");
    if (vadd_cmd == NULL) return REDISMODULE_ERR;

    RedisModuleCommandArg vadd_args[] = {
        { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 },
        { .name = "reduce", .type = REDISMODULE_ARG_TYPE_BLOCK, .token = "REDUCE", .flags = REDISMODULE_CMD_ARG_OPTIONAL,
            .subargs = (RedisModuleCommandArg[]) {
                { .name = "dim", .type = REDISMODULE_ARG_TYPE_INTEGER },
                { .name = NULL }
            }
        },
        { .name = "format", .type = REDISMODULE_ARG_TYPE_ONEOF, .subargs = (RedisModuleCommandArg[]) {
                { .name = "fp32", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "FP32" },
                { .name = "values", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "VALUES" },
                { .name = NULL }
            }
        },
        { .name = "vector", .type = REDISMODULE_ARG_TYPE_STRING },
        { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING },
        { .name = "cas", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "CAS", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
        { .name = "quant_type", .type = REDISMODULE_ARG_TYPE_ONEOF, .flags = REDISMODULE_CMD_ARG_OPTIONAL, .subargs = (RedisModuleCommandArg[]) {
                { .name = "noquant", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "NOQUANT" },
                { .name = "bin", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "BIN" },
                { .name = "q8", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "Q8" },
                { .name = NULL }
            }
        },
        { .name = "build-exploration-factor", .type = REDISMODULE_ARG_TYPE_INTEGER, .token = "EF", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
        { .name = "attributes", .type = REDISMODULE_ARG_TYPE_STRING, .token = "SETATTR", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
        { .name = "numlinks", .type = REDISMODULE_ARG_TYPE_INTEGER, .token = "M", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
        { .name = NULL }
    };
    RedisModuleCommandInfo vadd_info = {
        .version = REDISMODULE_COMMAND_INFO_VERSION,
        .summary = "Add one or more elements to a vector set, or update its vector if it already exists",
        .since = "8.0.0",
        .arity = -5,
        .args = vadd_args,
    };
    if (RedisModule_SetCommandInfo(vadd_cmd, &vadd_info) == REDISMODULE_ERR) return REDISMODULE_ERR;

    // Register command VREM
    if (RedisModule_CreateCommand(ctx,"VREM",
        VREM_RedisCommand,"write",1,1,1) == REDISMODULE_ERR)
        return REDISMODULE_ERR;

    RedisModuleCommand *vrem_cmd = RedisModule_GetCommand(ctx, "VREM");
    if (vrem_cmd == NULL) return REDISMODULE_ERR;

    RedisModuleCommandArg vrem_args[] = {
        { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 },
        { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING },
        { .name = NULL }
    };
    RedisModuleCommandInfo vrem_info = {
        .version = REDISMODULE_COMMAND_INFO_VERSION,
        .summary = "Remove an element from a vector set",
        .since = "8.0.0",
        .arity = 3,
        .args = vrem_args,
    };
    if (RedisModule_SetCommandInfo(vrem_cmd, &vrem_info) == REDISMODULE_ERR) return REDISMODULE_ERR;

    // Register command VSIM
    if (RedisModule_CreateCommand(ctx,"VSIM",
        VSIM_RedisCommand,"readonly",1,1,1) == REDISMODULE_ERR)
        return REDISMODULE_ERR;

    RedisModuleCommand *vsim_cmd = RedisModule_GetCommand(ctx, "VSIM");
    if (vsim_cmd == NULL) return REDISMODULE_ERR;

    RedisModuleCommandArg vsim_args[] = {
        { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 },
        { .name = "format", .type = REDISMODULE_ARG_TYPE_ONEOF, .subargs = (RedisModuleCommandArg[]) {
                { .name = "ele", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "ELE" },
                { .name = "fp32", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "FP32" },
                { .name = "values", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "VALUES" },
                { .name = NULL }
            }
        },
        { .name = "vector_or_element", .type = REDISMODULE_ARG_TYPE_STRING },
        { .name = "withscores", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "WITHSCORES", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
        { .name = "withattribs", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "WITHATTRIBS", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
        { .name = "count", .type = REDISMODULE_ARG_TYPE_INTEGER, .token = "COUNT", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
        { .name = "max_distance", .type = REDISMODULE_ARG_TYPE_DOUBLE, .token = "EPSILON", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
        { .name = "search-exploration-factor", .type = REDISMODULE_ARG_TYPE_INTEGER, .token = "EF", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
        { .name = "expression", .type = REDISMODULE_ARG_TYPE_STRING, .token = "FILTER", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
        { .name = "max-filtering-effort", .type = REDISMODULE_ARG_TYPE_INTEGER, .token = "FILTER-EF", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
        { .name = "truth", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "TRUTH", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
        { .name = "nothread", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "NOTHREAD", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
        { .name = NULL }
    };
    RedisModuleCommandInfo vsim_info = {
        .version = REDISMODULE_COMMAND_INFO_VERSION,
        .summary = "Return elements by vector similarity",
        .since = "8.0.0",
        .arity = -4,
        .args = vsim_args,
    };
    if (RedisModule_SetCommandInfo(vsim_cmd, &vsim_info) == REDISMODULE_ERR) return REDISMODULE_ERR;

    // Register command VDIM
    if (RedisModule_CreateCommand(ctx, "VDIM",
        VDIM_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR)
        return REDISMODULE_ERR;

    RedisModuleCommand *vdim_cmd = RedisModule_GetCommand(ctx, "VDIM");
    if (vdim_cmd == NULL) return REDISMODULE_ERR;

    RedisModuleCommandArg vdim_args[] = {
        { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 },
        { .name = NULL }
    };
    RedisModuleCommandInfo vdim_info = {
        .version = REDISMODULE_COMMAND_INFO_VERSION,
        .summary = "Return the dimension of vectors in the vector set",
        .since = "8.0.0",
        .arity = 2,
        .args = vdim_args,
    };
    if (RedisModule_SetCommandInfo(vdim_cmd, &vdim_info) == REDISMODULE_ERR) return REDISMODULE_ERR;

    // Register command VCARD
    if (RedisModule_CreateCommand(ctx, "VCARD",
        VCARD_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR)
        return REDISMODULE_ERR;

    RedisModuleCommand *vcard_cmd = RedisModule_GetCommand(ctx, "VCARD");
    if (vcard_cmd == NULL) return REDISMODULE_ERR;

    RedisModuleCommandArg vcard_args[] = {
        { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 },
        { .name = NULL }
    };
    RedisModuleCommandInfo vcard_info = {
        .version = REDISMODULE_COMMAND_INFO_VERSION,
        .summary = "Return the number of elements in a vector set",
        .since = "8.0.0",
        .arity = 2,
        .args = vcard_args,
    };
    if (RedisModule_SetCommandInfo(vcard_cmd, &vcard_info) == REDISMODULE_ERR) return REDISMODULE_ERR;

    // Register command VEMB
    if (RedisModule_CreateCommand(ctx, "VEMB",
        VEMB_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR)
        return REDISMODULE_ERR;

    RedisModuleCommand *vemb_cmd = RedisModule_GetCommand(ctx, "VEMB");
    if (vemb_cmd == NULL) return REDISMODULE_ERR;

    RedisModuleCommandArg vemb_args[] = {
        { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 },
        { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING },
        { .name = "raw", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "RAW", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
        { .name = NULL }
    };
    RedisModuleCommandInfo vemb_info = {
        .version = REDISMODULE_COMMAND_INFO_VERSION,
        .summary = "Return the vector associated with an element",
        .since = "8.0.0",
        .arity = -3,
        .args = vemb_args,
    };
    if (RedisModule_SetCommandInfo(vemb_cmd, &vemb_info) == REDISMODULE_ERR) return REDISMODULE_ERR;

    // Register command VLINKS
    if (RedisModule_CreateCommand(ctx, "VLINKS",
        VLINKS_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR)
        return REDISMODULE_ERR;

    RedisModuleCommand *vlinks_cmd = RedisModule_GetCommand(ctx, "VLINKS");
    if (vlinks_cmd == NULL) return REDISMODULE_ERR;

    RedisModuleCommandArg vlinks_args[] = {
        { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 },
        { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING },
        { .name = "withscores", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "WITHSCORES", .flags = REDISMODULE_CMD_ARG_OPTIONAL },
        { .name = NULL }
    };
    RedisModuleCommandInfo vlinks_info = {
        .version = REDISMODULE_COMMAND_INFO_VERSION,
        .summary = "Return the neighbors of an element at each layer in the HNSW graph",
        .since = "8.0.0",
        .arity = -3,
        .args = vlinks_args,
    };
    if (RedisModule_SetCommandInfo(vlinks_cmd, &vlinks_info) == REDISMODULE_ERR) return REDISMODULE_ERR;

    // Register command VINFO
    if (RedisModule_CreateCommand(ctx, "VINFO",
        VINFO_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR)
        return REDISMODULE_ERR;

    RedisModuleCommand *vinfo_cmd = RedisModule_GetCommand(ctx, "VINFO");
    if (vinfo_cmd == NULL) return REDISMODULE_ERR;

    RedisModuleCommandArg vinfo_args[] = {
        { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 },
        { .name = NULL }
    };
    RedisModuleCommandInfo vinfo_info = {
        .version = REDISMODULE_COMMAND_INFO_VERSION,
        .summary = "Return information about a vector set",
        .since = "8.0.0",
        .arity = 2,
        .args = vinfo_args,
    };
    if (RedisModule_SetCommandInfo(vinfo_cmd, &vinfo_info) == REDISMODULE_ERR) return REDISMODULE_ERR;

    // Register command VSETATTR
    if (RedisModule_CreateCommand(ctx, "VSETATTR",
        VSETATTR_RedisCommand, "write fast", 1, 1, 1) == REDISMODULE_ERR)
        return REDISMODULE_ERR;

    RedisModuleCommand *vsetattr_cmd = RedisModule_GetCommand(ctx, "VSETATTR");
    if (vsetattr_cmd == NULL) return REDISMODULE_ERR;

    RedisModuleCommandArg vsetattr_args[] = {
        { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 },
        { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING },
        { .name = "json", .type = REDISMODULE_ARG_TYPE_STRING },
        { .name = NULL }
    };
    RedisModuleCommandInfo vsetattr_info = {
        .version = REDISMODULE_COMMAND_INFO_VERSION,
        .summary = "Associate or remove the JSON attributes of elements",
        .since = "8.0.0",
        .arity = 4,
        .args = vsetattr_args,
    };
    if (RedisModule_SetCommandInfo(vsetattr_cmd, &vsetattr_info) == REDISMODULE_ERR) return REDISMODULE_ERR;

    // Register command VGETATTR
    if (RedisModule_CreateCommand(ctx, "VGETATTR",
        VGETATTR_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR)
        return REDISMODULE_ERR;

    RedisModuleCommand *vgetattr_cmd = RedisModule_GetCommand(ctx, "VGETATTR");
    if (vgetattr_cmd == NULL) return REDISMODULE_ERR;

    RedisModuleCommandArg vgetattr_args[] = {
        { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 },
        { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING },
        { .name = NULL }
    };
    RedisModuleCommandInfo vgetattr_info = {
        .version = REDISMODULE_COMMAND_INFO_VERSION,
        .summary = "Retrieve the JSON attributes of elements",
        .since = "8.0.0",
        .arity = 3,
        .args = vgetattr_args,
    };
    if (RedisModule_SetCommandInfo(vgetattr_cmd, &vgetattr_info) == REDISMODULE_ERR) return REDISMODULE_ERR;

    // Register command VRANDMEMBER
    if (RedisModule_CreateCommand(ctx, "VRANDMEMBER",
        VRANDMEMBER_RedisCommand, "readonly", 1, 1, 1) == REDISMODULE_ERR)
        return REDISMODULE_ERR;

    RedisModuleCommand *vrandmember_cmd = RedisModule_GetCommand(ctx, "VRANDMEMBER");
    if (vrandmember_cmd == NULL) return REDISMODULE_ERR;

    RedisModuleCommandArg vrandmember_args[] = {
        { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 },
        { .name = "count", .type = REDISMODULE_ARG_TYPE_INTEGER, .flags = REDISMODULE_CMD_ARG_OPTIONAL },
        { .name = NULL }
    };
    RedisModuleCommandInfo vrandmember_info = {
        .version = REDISMODULE_COMMAND_INFO_VERSION,
        .summary = "Return one or multiple random members from a vector set",
        .since = "8.0.0",
        .arity = -2,
        .args = vrandmember_args,
    };
    if (RedisModule_SetCommandInfo(vrandmember_cmd, &vrandmember_info) == REDISMODULE_ERR) return REDISMODULE_ERR;

    // Register command VISMEMBER
    if (RedisModule_CreateCommand(ctx, "VISMEMBER",
        VISMEMBER_RedisCommand, "readonly", 1, 1, 1) == REDISMODULE_ERR)
        return REDISMODULE_ERR;

    RedisModuleCommand *vismember_cmd = RedisModule_GetCommand(ctx, "VISMEMBER");
    if (vismember_cmd == NULL) return REDISMODULE_ERR;

    RedisModuleCommandArg vismember_args[] = {
        { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 },
        { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING },
        { .name = NULL }
    };
    RedisModuleCommandInfo vismember_info = {
        .version = REDISMODULE_COMMAND_INFO_VERSION,
        .summary = "Check if an element exists in a vector set",
        .since = "8.2.0",
        .arity = 3,
        .args = vismember_args,
    };
    if (RedisModule_SetCommandInfo(vismember_cmd, &vismember_info) == REDISMODULE_ERR) return REDISMODULE_ERR;

    // Register command VRANGE
    if (RedisModule_CreateCommand(ctx, "VRANGE",
        VRANGE_RedisCommand, "readonly", 1, 1, 1) == REDISMODULE_ERR)
        return REDISMODULE_ERR;

    RedisModuleCommand *vrange_cmd = RedisModule_GetCommand(ctx, "VRANGE");
    if (vrange_cmd == NULL) return REDISMODULE_ERR;

    RedisModuleCommandArg vrange_args[] = {
        { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 },
        { .name = "start", .type = REDISMODULE_ARG_TYPE_STRING },
        { .name = "end", .type = REDISMODULE_ARG_TYPE_STRING },
        { .name = "count", .type = REDISMODULE_ARG_TYPE_INTEGER, .flags = REDISMODULE_CMD_ARG_OPTIONAL },
        { .name = NULL }
    };
    RedisModuleCommandInfo vrange_info = {
        .version = REDISMODULE_COMMAND_INFO_VERSION,
        .summary = "Return vector set elements in a lex range",
        .since = "8.4.0",
        .arity = -4,
        .args = vrange_args,
    };
    if (RedisModule_SetCommandInfo(vrange_cmd, &vrange_info) == REDISMODULE_ERR) return REDISMODULE_ERR;

    // Set the allocator for the HNSW library, so that memory tracking
    // is correct in Redis.
    hnsw_set_allocator(RedisModule_Free, RedisModule_Alloc,
                       RedisModule_Realloc);

    return REDISMODULE_OK;
}
2584
/* Alternative entry point that simply delegates to RedisModule_OnLoad().
 * NOTE(review): presumably used when the module is built into the server
 * under a module-specific init symbol — confirm against the build setup. */
int VectorSets_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
    return RedisModule_OnLoad(ctx, argv, argc);
}
diff --git a/examples/redis-unstable/modules/vector-sets/vset_config.c b/examples/redis-unstable/modules/vector-sets/vset_config.c
new file mode 100644
index 0000000..79dc8a3
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/vset_config.c
@@ -0,0 +1,51 @@
1/* vector set module configuration.
2 *
3 * Copyright (c) 2009-Present, Redis Ltd.
4 * All rights reserved.
5 *
6 * Licensed under your choice of (a) the Redis Source Available License 2.0
7 * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
8 * GNU Affero General Public License v3 (AGPLv3).
9*/
10
11#include "vset_config.h"
12
/* Define __STRING macro for portability (not available in all environments) */
#ifndef __STRING
#define __STRING(x) #x
#endif

/* Evaluate a Redis Module API call; if it returns REDISMODULE_ERR, log the
 * failing expression (stringified) and make the enclosing function return
 * REDISMODULE_ERR. Expects a `RedisModuleCtx *ctx` in the caller's scope. */
#define RM_TRY(expr) \
    if (expr == REDISMODULE_ERR) { \
        RedisModule_Log(ctx, "warning", "Could not run " __STRING(expr)); \
        return REDISMODULE_ERR; \
    }

/* Global configuration state for the vector sets module, written by the
 * config setter callbacks below. */
VSConfig VSGlobalConfig;
25
26int set_bool_config(const char *name, int val, void *privdata,
27 RedisModuleString **err) {
28 REDISMODULE_NOT_USED(name);
29 REDISMODULE_NOT_USED(err);
30 *(int *)privdata = val;
31 return REDISMODULE_OK;
32}
33
/* Getter callback for boolean module configs: report the current value
 * of the int that privdata points at. */
int get_bool_config(const char *name, void *privdata) {
    REDISMODULE_NOT_USED(name);
    int *source = privdata;
    return *source;
}
38
/* Register every vector-set configuration parameter with the server.
 * Returns REDISMODULE_OK on success; on failure RM_TRY logs the failing
 * registration and returns REDISMODULE_ERR. */
int RegisterModuleConfig(RedisModuleCtx *ctx) {
    // Boolean parameters
    RM_TRY(
        RedisModule_RegisterBoolConfig(
            ctx, "vset-force-single-threaded-execution", 0,
            REDISMODULE_CONFIG_UNPREFIXED,
            get_bool_config, set_bool_config, NULL,
            (void *)&(VSGlobalConfig.forceSingleThreadExec)
        )
    )

    return REDISMODULE_OK;
}
diff --git a/examples/redis-unstable/modules/vector-sets/vset_config.h b/examples/redis-unstable/modules/vector-sets/vset_config.h
new file mode 100644
index 0000000..8da94fa
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/vset_config.h
@@ -0,0 +1,24 @@
1/* vector set module configuration.
2 *
3 * Copyright (c) 2009-Present, Redis Ltd.
4 * All rights reserved.
5 *
6 * Licensed under your choice of (a) the Redis Source Available License 2.0
7 * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
8 * GNU Affero General Public License v3 (AGPLv3).
9*/
10
11#ifndef VSET_CONFIG_H
12#define VSET_CONFIG_H
13
14#include "../../src/redismodule.h"
15
/* Runtime-tunable module settings. */
typedef struct {
    /* Non-zero = force single-threaded execution (boolean config). */
    int forceSingleThreadExec;
} VSConfig;

/* Single global instance, defined in vset_config.c. */
extern VSConfig VSGlobalConfig;

int RegisterModuleConfig(RedisModuleCtx *ctx);
23
24#endif
diff --git a/examples/redis-unstable/modules/vector-sets/w2v.c b/examples/redis-unstable/modules/vector-sets/w2v.c
new file mode 100644
index 0000000..bcf6338
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/w2v.c
@@ -0,0 +1,539 @@
/*
 * HNSW (Hierarchical Navigable Small World) test / benchmark driver,
 * exercising the implementation based on the paper by Yu. A. Malkov,
 * D. A. Yashunin, using word2vec embeddings as the data set.
4 *
5 * Copyright (c) 2009-Present, Redis Ltd.
6 * All rights reserved.
7 *
8 * Licensed under your choice of (a) the Redis Source Available License 2.0
9 * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
10 * GNU Affero General Public License v3 (AGPLv3).
11 * Originally authored by: Salvatore Sanfilippo
12 */
13
14#define _DEFAULT_SOURCE
15#define _USE_MATH_DEFINES
16#define _POSIX_C_SOURCE 200809L
17
18#include <stdio.h>
19#include <stdlib.h>
20#include <string.h>
21#include <strings.h>
22#include <sys/time.h>
23#include <time.h>
24#include <stdint.h>
25#include <pthread.h>
26#include <stdatomic.h>
27#include <math.h>
28
29#include "hnsw.h"
30
31/* Get current time in milliseconds */
32uint64_t ms_time(void) {
33 struct timeval tv;
34 gettimeofday(&tv, NULL);
35 return (uint64_t)tv.tv_sec * 1000 + (tv.tv_usec / 1000);
36}
37
38/* Implementation of the recall test with random vectors. */
39void test_recall(HNSW *index, int ef) {
40 const int num_test_vectors = 10000;
41 const int k = 100; // Number of nearest neighbors to find.
42 if (ef < k) ef = k;
43
44 // Add recall distribution counters (2% bins from 0-100%).
45 int recall_bins[50] = {0};
46
47 // Create array to store vectors for mixing.
48 int num_source_vectors = 1000; // Enough, since we mix them.
49 float **source_vectors = malloc(sizeof(float*) * num_source_vectors);
50 if (!source_vectors) {
51 printf("Failed to allocate memory for source vectors\n");
52 return;
53 }
54
55 // Allocate memory for each source vector.
56 for (int i = 0; i < num_source_vectors; i++) {
57 source_vectors[i] = malloc(sizeof(float) * 300);
58 if (!source_vectors[i]) {
59 printf("Failed to allocate memory for source vector %d\n", i);
60 // Clean up already allocated vectors.
61 for (int j = 0; j < i; j++) free(source_vectors[j]);
62 free(source_vectors);
63 return;
64 }
65 }
66
67 /* Populate source vectors from the index, we just scan the
68 * first N items. */
69 int source_count = 0;
70 hnswNode *current = index->head;
71 while (current && source_count < num_source_vectors) {
72 hnsw_get_node_vector(index, current, source_vectors[source_count]);
73 source_count++;
74 current = current->next;
75 }
76
77 if (source_count < num_source_vectors) {
78 printf("Warning: Only found %d nodes for source vectors\n",
79 source_count);
80 num_source_vectors = source_count;
81 }
82
83 // Allocate memory for test vector.
84 float *test_vector = malloc(sizeof(float) * 300);
85 if (!test_vector) {
86 printf("Failed to allocate memory for test vector\n");
87 for (int i = 0; i < num_source_vectors; i++) {
88 free(source_vectors[i]);
89 }
90 free(source_vectors);
91 return;
92 }
93
94 // Allocate memory for results.
95 hnswNode **hnsw_results = malloc(sizeof(hnswNode*) * ef);
96 hnswNode **linear_results = malloc(sizeof(hnswNode*) * ef);
97 float *hnsw_distances = malloc(sizeof(float) * ef);
98 float *linear_distances = malloc(sizeof(float) * ef);
99
100 if (!hnsw_results || !linear_results || !hnsw_distances || !linear_distances) {
101 printf("Failed to allocate memory for results\n");
102 if (hnsw_results) free(hnsw_results);
103 if (linear_results) free(linear_results);
104 if (hnsw_distances) free(hnsw_distances);
105 if (linear_distances) free(linear_distances);
106 for (int i = 0; i < num_source_vectors; i++) free(source_vectors[i]);
107 free(source_vectors);
108 free(test_vector);
109 return;
110 }
111
112 // Initialize random seed.
113 srand(time(NULL));
114
115 // Perform recall test.
116 printf("\nPerforming recall test with EF=%d on %d random vectors...\n",
117 ef, num_test_vectors);
118 double total_recall = 0.0;
119
120 for (int t = 0; t < num_test_vectors; t++) {
121 // Create a random vector by mixing 3 existing vectors.
122 float weights[3] = {0.0};
123 int src_indices[3] = {0};
124
125 // Generate random weights.
126 float weight_sum = 0.0;
127 for (int i = 0; i < 3; i++) {
128 weights[i] = (float)rand() / RAND_MAX;
129 weight_sum += weights[i];
130 src_indices[i] = rand() % num_source_vectors;
131 }
132
133 // Normalize weights.
134 for (int i = 0; i < 3; i++) weights[i] /= weight_sum;
135
136 // Mix vectors.
137 memset(test_vector, 0, sizeof(float) * 300);
138 for (int i = 0; i < 3; i++) {
139 for (int j = 0; j < 300; j++) {
140 test_vector[j] +=
141 weights[i] * source_vectors[src_indices[i]][j];
142 }
143 }
144
145 // Perform HNSW search with the specified EF parameter.
146 int slot = hnsw_acquire_read_slot(index);
147 int hnsw_found = hnsw_search(index, test_vector, ef, hnsw_results, hnsw_distances, slot, 0);
148
149 // Perform linear search (ground truth).
150 int linear_found = hnsw_ground_truth_with_filter(index, test_vector, ef, linear_results, linear_distances, slot, 0, NULL, NULL);
151 hnsw_release_read_slot(index, slot);
152
153 // Calculate recall for this query (intersection size / k).
154 if (hnsw_found > k) hnsw_found = k;
155 if (linear_found > k) linear_found = k;
156 int intersection_count = 0;
157 for (int i = 0; i < linear_found; i++) {
158 for (int j = 0; j < hnsw_found; j++) {
159 if (linear_results[i] == hnsw_results[j]) {
160 intersection_count++;
161 break;
162 }
163 }
164 }
165
166 double recall = (double)intersection_count / linear_found;
167 total_recall += recall;
168
169 // Add to distribution bins (2% steps)
170 int bin_index = (int)(recall * 50);
171 if (bin_index >= 50) bin_index = 49; // Handle 100% recall case
172 recall_bins[bin_index]++;
173
174 // Show progress.
175 if ((t+1) % 1000 == 0 || t == num_test_vectors-1) {
176 printf("Processed %d/%d queries, current avg recall: %.2f%%\n",
177 t+1, num_test_vectors, (total_recall / (t+1)) * 100);
178 }
179 }
180
181 // Calculate and print final average recall.
182 double avg_recall = (total_recall / num_test_vectors) * 100;
183 printf("\nRecall Test Results:\n");
184 printf("Average recall@%d (EF=%d): %.2f%%\n", k, ef, avg_recall);
185
186 // Print recall distribution histogram.
187 printf("\nRecall Distribution (2%% bins):\n");
188 printf("================================\n");
189
190 // Find the maximum bin count for scaling.
191 int max_count = 0;
192 for (int i = 0; i < 50; i++) {
193 if (recall_bins[i] > max_count) max_count = recall_bins[i];
194 }
195
196 // Scale factor for histogram (max 50 chars wide)
197 const int max_bars = 50;
198 double scale = (max_count > max_bars) ? (double)max_bars / max_count : 1.0;
199
200 // Print the histogram.
201 for (int i = 0; i < 50; i++) {
202 int bar_len = (int)(recall_bins[i] * scale);
203 printf("%3d%%-%-3d%% | %-6d |", i*2, (i+1)*2, recall_bins[i]);
204 for (int j = 0; j < bar_len; j++) printf("#");
205 printf("\n");
206 }
207
208 // Cleanup.
209 free(hnsw_results);
210 free(linear_results);
211 free(hnsw_distances);
212 free(linear_distances);
213 free(test_vector);
214 for (int i = 0; i < num_source_vectors; i++) free(source_vectors[i]);
215 free(source_vectors);
216}
217
218/* Example usage in main() */
219int w2v_single_thread(int m_param, int quantization, uint64_t numele, int massdel, int self_recall, int recall_ef) {
220 /* Create index */
221 HNSW *index = hnsw_new(300, quantization, m_param);
222 float v[300];
223 uint16_t wlen;
224
225 FILE *fp = fopen("word2vec.bin","rb");
226 if (fp == NULL) {
227 perror("word2vec.bin file missing");
228 exit(1);
229 }
230 unsigned char header[8];
231 if (fread(header,8,1,fp) <= 0) { // Skip header
232 perror("Unexpected EOF");
233 exit(1);
234 }
235
236 uint64_t id = 0;
237 uint64_t start_time = ms_time();
238 char *word = NULL;
239 hnswNode *search_node = NULL;
240
241 while(id < numele) {
242 if (fread(&wlen,2,1,fp) == 0) break;
243 word = malloc(wlen+1);
244 if (fread(word,wlen,1,fp) <= 0) {
245 perror("unexpected EOF");
246 exit(1);
247 }
248 word[wlen] = 0;
249 if (fread(v,300*sizeof(float),1,fp) <= 0) {
250 perror("unexpected EOF");
251 exit(1);
252 }
253
254 // Plain API that acquires a write lock for the whole time.
255 hnswNode *added = hnsw_insert(index, v, NULL, 0, id++, word, 200);
256
257 if (!strcmp(word,"banana")) search_node = added;
258 if (!(id % 10000)) printf("%llu added\n", (unsigned long long)id);
259 }
260 uint64_t elapsed = ms_time() - start_time;
261 fclose(fp);
262
263 printf("%llu words added (%llu words/sec), last word: %s\n",
264 (unsigned long long)index->node_count,
265 (unsigned long long)id*1000/elapsed, word);
266
267 /* Search query */
268 if (search_node == NULL) search_node = index->head;
269 hnsw_get_node_vector(index,search_node,v);
270 hnswNode *neighbors[10];
271 float distances[10];
272
273 int found, j;
274 start_time = ms_time();
275 for (j = 0; j < 20000; j++)
276 found = hnsw_search(index, v, 10, neighbors, distances, 0, 0);
277 elapsed = ms_time() - start_time;
278 printf("%d searches performed (%llu searches/sec), nodes found: %d\n",
279 j, (unsigned long long)j*1000/elapsed, found);
280
281 if (found > 0) {
282 printf("Found %d neighbors:\n", found);
283 for (int i = 0; i < found; i++) {
284 printf("Node ID: %llu, distance: %f, word: %s\n",
285 (unsigned long long)neighbors[i]->id,
286 distances[i], (char*)neighbors[i]->value);
287 }
288 }
289
290 // Self-recall test (ability to find the node by its own vector).
291 if (self_recall) {
292 hnsw_print_stats(index);
293 hnsw_test_graph_recall(index,200,0);
294 }
295
296 // Recall test with random vectors.
297 if (recall_ef > 0) {
298 test_recall(index, recall_ef);
299 }
300
301 uint64_t connected_nodes;
302 int reciprocal_links;
303 hnsw_validate_graph(index, &connected_nodes, &reciprocal_links);
304
305 if (massdel) {
306 int remove_perc = 95;
307 printf("\nRemoving %d%% of nodes...\n", remove_perc);
308 uint64_t initial_nodes = index->node_count;
309
310 hnswNode *current = index->head;
311 while (current && index->node_count > initial_nodes*(100-remove_perc)/100) {
312 hnswNode *next = current->next;
313 hnsw_delete_node(index,current,free);
314 current = next;
315 // In order to don't remove only contiguous nodes, from time
316 // skip a node.
317 if (current && !(random() % remove_perc)) current = current->next;
318 }
319 printf("%llu nodes left\n", (unsigned long long)index->node_count);
320
321 // Test again.
322 hnsw_validate_graph(index, &connected_nodes, &reciprocal_links);
323 hnsw_test_graph_recall(index,200,0);
324 }
325
326 hnsw_free(index,free);
327 return 0;
328}
329
/* Shared state handed to every worker thread in the multi-threaded
 * insert/search benchmark (see threaded_insert / threaded_search). */
struct threadContext {
    pthread_mutex_t FileAccessMutex; /* Serializes reads from 'fp'. */
    uint64_t numele;                 /* Max number of elements to insert. */
    _Atomic uint64_t SearchesDone;   /* NOTE(review): appears unused in this
                                      * file -- confirm before relying on it. */
    _Atomic uint64_t id;             /* Insert phase: next node ID.
                                      * Search phase: completed-query counter. */
    FILE *fp;                        /* word2vec.bin input stream. */
    HNSW *index;                     /* Index under test. */
    float *search_vector;            /* Query vector used by search threads. */
};
339
340// Note that in practical terms inserting with many concurrent threads
341// may be *slower* and not faster, because there is a lot of
342// contention. So this is more a robustness test than anything else.
343//
344// The optimistic commit API goal is actually to exploit the ability to
345// add faster when there are many concurrent reads.
346void *threaded_insert(void *ctxptr) {
347 struct threadContext *ctx = ctxptr;
348 char *word;
349 float v[300];
350 uint16_t wlen;
351
352 while(1) {
353 pthread_mutex_lock(&ctx->FileAccessMutex);
354 if (fread(&wlen,2,1,ctx->fp) == 0) break;
355 pthread_mutex_unlock(&ctx->FileAccessMutex);
356 word = malloc(wlen+1);
357 if (fread(word,wlen,1,ctx->fp) <= 0) {
358 perror("Unexpected EOF");
359 exit(1);
360 }
361
362 word[wlen] = 0;
363 if (fread(v,300*sizeof(float),1,ctx->fp) <= 0) {
364 perror("Unexpected EOF");
365 exit(1);
366 }
367
368 // Check-and-set API that performs the costly scan for similar
369 // nodes concurrently with other read threads, and finally
370 // applies the check if the graph wasn't modified.
371 InsertContext *ic;
372 uint64_t next_id = ctx->id++;
373 ic = hnsw_prepare_insert(ctx->index, v, NULL, 0, next_id, 200);
374 if (hnsw_try_commit_insert(ctx->index, ic, word) == NULL) {
375 // This time try locking since the start.
376 hnsw_insert(ctx->index, v, NULL, 0, next_id, word, 200);
377 }
378
379 if (next_id >= ctx->numele) break;
380 if (!((next_id+1) % 10000))
381 printf("%llu added\n", (unsigned long long)next_id+1);
382 }
383 return NULL;
384}
385
386void *threaded_search(void *ctxptr) {
387 struct threadContext *ctx = ctxptr;
388
389 /* Search query */
390 hnswNode *neighbors[10];
391 float distances[10];
392 int found = 0;
393 uint64_t last_id = 0;
394
395 while(ctx->id < 1000000) {
396 int slot = hnsw_acquire_read_slot(ctx->index);
397 found = hnsw_search(ctx->index, ctx->search_vector, 10, neighbors, distances, slot, 0);
398 hnsw_release_read_slot(ctx->index,slot);
399 last_id = ++ctx->id;
400 }
401
402 if (found > 0 && last_id == 1000000) {
403 printf("Found %d neighbors:\n", found);
404 for (int i = 0; i < found; i++) {
405 printf("Node ID: %llu, distance: %f, word: %s\n",
406 (unsigned long long)neighbors[i]->id,
407 distances[i], (char*)neighbors[i]->value);
408 }
409 }
410 return NULL;
411}
412
413int w2v_multi_thread(int m_param, int numthreads, int quantization, uint64_t numele) {
414 /* Create index */
415 struct threadContext ctx;
416
417 ctx.index = hnsw_new(300, quantization, m_param);
418
419 ctx.fp = fopen("word2vec.bin","rb");
420 if (ctx.fp == NULL) {
421 perror("word2vec.bin file missing");
422 exit(1);
423 }
424
425 unsigned char header[8];
426 if (fread(header,8,1,ctx.fp) <= 0) { // Skip header
427 perror("Unexpected EOF");
428 exit(1);
429 }
430 pthread_mutex_init(&ctx.FileAccessMutex,NULL);
431
432 uint64_t start_time = ms_time();
433 ctx.id = 0;
434 ctx.numele = numele;
435 pthread_t threads[numthreads];
436 for (int j = 0; j < numthreads; j++)
437 pthread_create(&threads[j], NULL, threaded_insert, &ctx);
438
439 // Wait for all the threads to terminate adding items.
440 for (int j = 0; j < numthreads; j++)
441 pthread_join(threads[j],NULL);
442
443 uint64_t elapsed = ms_time() - start_time;
444 fclose(ctx.fp);
445
446 // Obtain the last word.
447 hnswNode *node = ctx.index->head;
448 char *word = node->value;
449
450 // We will search this last inserted word in the next test.
451 // Let's save its embedding.
452 ctx.search_vector = malloc(sizeof(float)*300);
453 hnsw_get_node_vector(ctx.index,node,ctx.search_vector);
454
455 printf("%llu words added (%llu words/sec), last word: %s\n",
456 (unsigned long long)ctx.index->node_count,
457 (unsigned long long)ctx.id*1000/elapsed, word);
458
459 /* Search query */
460 start_time = ms_time();
461 ctx.id = 0; // We will use this atomic field to stop at N queries done.
462
463 for (int j = 0; j < numthreads; j++)
464 pthread_create(&threads[j], NULL, threaded_search, &ctx);
465
466 // Wait for all the threads to terminate searching.
467 for (int j = 0; j < numthreads; j++)
468 pthread_join(threads[j],NULL);
469
470 elapsed = ms_time() - start_time;
471 printf("%llu searches performed (%llu searches/sec)\n",
472 (unsigned long long)ctx.id,
473 (unsigned long long)ctx.id*1000/elapsed);
474
475 hnsw_print_stats(ctx.index);
476 uint64_t connected_nodes;
477 int reciprocal_links;
478 hnsw_validate_graph(ctx.index, &connected_nodes, &reciprocal_links);
479 printf("%llu connected nodes. Links all reciprocal: %d\n",
480 (unsigned long long)connected_nodes, reciprocal_links);
481 hnsw_free(ctx.index,free);
482 return 0;
483}
484
485int main(int argc, char **argv) {
486 int quantization = HNSW_QUANT_NONE;
487 int numthreads = 0;
488 uint64_t numele = 20000;
489 int m_param = 0; // Default value (0 means use HNSW_DEFAULT_M)
490
491 /* This you can enable in single thread mode for testing: */
492 int massdel = 0; // If true, does the mass deletion test.
493 int self_recall = 0; // If true, does the self-recall test.
494 int recall_ef = 0; // If not 0, does the recall test with this EF value.
495
496 for (int j = 1; j < argc; j++) {
497 int moreargs = argc-j-1;
498
499 if (!strcasecmp(argv[j],"--quant")) {
500 quantization = HNSW_QUANT_Q8;
501 } else if (!strcasecmp(argv[j],"--bin")) {
502 quantization = HNSW_QUANT_BIN;
503 } else if (!strcasecmp(argv[j],"--mass-del")) {
504 massdel = 1;
505 } else if (!strcasecmp(argv[j],"--self-recall")) {
506 self_recall = 1;
507 } else if (moreargs >= 1 && !strcasecmp(argv[j],"--recall")) {
508 recall_ef = atoi(argv[j+1]);
509 j++;
510 } else if (moreargs >= 1 && !strcasecmp(argv[j],"--threads")) {
511 numthreads = atoi(argv[j+1]);
512 j++;
513 } else if (moreargs >= 1 && !strcasecmp(argv[j],"--numele")) {
514 numele = strtoll(argv[j+1],NULL,0);
515 j++;
516 if (numele < 1) numele = 1;
517 } else if (moreargs >= 1 && !strcasecmp(argv[j],"--m")) {
518 m_param = atoi(argv[j+1]);
519 j++;
520 } else if (!strcasecmp(argv[j],"--help")) {
521 printf("%s [--quant] [--bin] [--thread <count>] [--numele <count>] [--m <count>] [--mass-del] [--self-recall] [--recall <ef>]\n", argv[0]);
522 exit(0);
523 } else {
524 printf("Unrecognized option or wrong number of arguments: %s\n", argv[j]);
525 exit(1);
526 }
527 }
528
529 if (quantization == HNSW_QUANT_NONE) {
530 printf("You can enable quantization with --quant\n");
531 }
532
533 if (numthreads > 0) {
534 w2v_multi_thread(m_param, numthreads, quantization, numele);
535 } else {
536 printf("Single thread execution. Use --threads 4 for concurrent API\n");
537 w2v_single_thread(m_param, quantization, numele, massdel, self_recall, recall_ef);
538 }
539}