diff options
Diffstat (limited to 'examples/redis-unstable/modules')
56 files changed, 13346 insertions, 0 deletions
diff --git a/examples/redis-unstable/modules/Makefile b/examples/redis-unstable/modules/Makefile new file mode 100644 index 0000000..11fcdc3 --- /dev/null +++ b/examples/redis-unstable/modules/Makefile @@ -0,0 +1,90 @@ + +SUBDIRS = redisjson redistimeseries redisbloom redisearch + +define submake + for dir in $(SUBDIRS); do $(MAKE) -C $$dir $(1); done +endef + +all: prepare_source + $(call submake,$@) + +get_source: + $(call submake,$@) + +prepare_source: get_source handle-werrors setup_environment + +clean: + $(call submake,$@) + +distclean: clean_environment + $(call submake,$@) + +pristine: + $(call submake,$@) + +install: + $(call submake,$@) + +setup_environment: install-rust handle-werrors + +clean_environment: uninstall-rust + +# Keep all of the Rust stuff in one place +install-rust: +ifeq ($(INSTALL_RUST_TOOLCHAIN),yes) + @RUST_VERSION=1.92.0; \ + ARCH="$$(uname -m)"; \ + if ldd --version 2>&1 | grep -q musl; then LIBC_TYPE="musl"; else LIBC_TYPE="gnu"; fi; \ + echo "Detected architecture: $${ARCH} and libc: $${LIBC_TYPE}"; \ + case "$${ARCH}" in \ + 'x86_64') \ + if [ "$${LIBC_TYPE}" = "musl" ]; then \ + RUST_INSTALLER="rust-$${RUST_VERSION}-x86_64-unknown-linux-musl"; \ + RUST_SHA256="e9b977ef480504d3beef0a095b470945af79e8496bcbb0e4a9d4601d6d72dd91"; \ + else \ + RUST_INSTALLER="rust-$${RUST_VERSION}-x86_64-unknown-linux-gnu"; \ + RUST_SHA256="d2ccef59dd9f7439f2c694948069f789a044dc1addcc0803613232af8f88ee0c"; \ + fi ;; \ + 'aarch64') \ + if [ "$${LIBC_TYPE}" = "musl" ]; then \ + RUST_INSTALLER="rust-$${RUST_VERSION}-aarch64-unknown-linux-musl"; \ + RUST_SHA256="0c35fca319d01c3d55345d44dbe66c987858c45e262a77c4db679e9869b0af73"; \ + else \ + RUST_INSTALLER="rust-$${RUST_VERSION}-aarch64-unknown-linux-gnu"; \ + RUST_SHA256="3e383f8b4fca710d0600d0c1de97b78281672be2cda6575ecbe1c183a12e3822"; \ + fi ;; \ + *) echo >&2 "Unsupported architecture: '$${ARCH}'"; exit 1 ;; \ + esac; \ + echo "Downloading and installing Rust standalone installer: $${RUST_INSTALLER}"; \ + wget --quiet -O $${RUST_INSTALLER}.tar.xz https://static.rust-lang.org/dist/$${RUST_INSTALLER}.tar.xz; \ + echo "$${RUST_SHA256} $${RUST_INSTALLER}.tar.xz" | sha256sum -c --quiet || { echo "Rust standalone installer checksum failed!"; exit 1; }; \ + tar -xf $${RUST_INSTALLER}.tar.xz; \ + (cd $${RUST_INSTALLER} && ./install.sh); \ + rm -rf $${RUST_INSTALLER} +endif + +uninstall-rust: +ifeq ($(INSTALL_RUST_TOOLCHAIN),yes) + @if [ -x "/usr/local/lib/rustlib/uninstall.sh" ]; then \ + echo "Uninstalling Rust using uninstall.sh script"; \ + rm -rf ~/.cargo; \ + /usr/local/lib/rustlib/uninstall.sh; \ + else \ + echo "WARNING: Rust toolchain not found or uninstall script is missing."; \ + fi +endif + +handle-werrors: get_source +ifeq ($(DISABLE_WERRORS),yes) + @echo "Disabling -Werror for all modules" + @for dir in $(SUBDIRS); do \ + echo "Processing $$dir"; \ + find $$dir/src -type f \ + \( -name "Makefile" \ + -o -name "*.mk" \ + -o -name "CMakeLists.txt" \) \ + -exec sed -i 's/-Werror//g' {} +; \ + done +endif + +.PHONY: all clean distclean install $(SUBDIRS) setup_environment clean_environment install-rust uninstall-rust handle-werrors diff --git a/examples/redis-unstable/modules/common.mk b/examples/redis-unstable/modules/common.mk new file mode 100644 index 0000000..7e6a975 --- /dev/null +++ b/examples/redis-unstable/modules/common.mk @@ -0,0 +1,51 @@ +PREFIX ?= /usr/local +INSTALL_DIR ?= $(DESTDIR)$(PREFIX)/lib/redis/modules +INSTALL ?= install + +# This logic *partially* follows the current module build system. It is a bit awkward and +# should be changed if/when the modules' build process is refactored. + +ARCH_MAP_x86_64 := x64 +ARCH_MAP_i386 := x86 +ARCH_MAP_i686 := x86 +ARCH_MAP_aarch64 := arm64v8 +ARCH_MAP_arm64 := arm64v8 + +OS := $(shell uname -s | tr '[:upper:]' '[:lower:]') +ARCH := $(ARCH_MAP_$(shell uname -m)) +ifeq ($(ARCH),) + $(error Unrecognized CPU architecture $(shell uname -m)) +endif + +FULL_VARIANT := $(OS)-$(ARCH)-release + +# Common rules for all modules, based on per-module configuration + +all: $(TARGET_MODULE) + +$(TARGET_MODULE): get_source + $(MAKE) -C $(SRC_DIR) + cp ${TARGET_MODULE} ./ + +get_source: $(SRC_DIR)/.prepared + +$(SRC_DIR)/.prepared: + mkdir -p $(SRC_DIR) + git clone --recursive --depth 1 --branch $(MODULE_VERSION) $(MODULE_REPO) $(SRC_DIR) + touch $@ + +clean: + -$(MAKE) -C $(SRC_DIR) clean + -rm -f ./*.so + +distclean: clean + -$(MAKE) -C $(SRC_DIR) distclean + +pristine: + -rm -rf $(SRC_DIR) + +install: $(TARGET_MODULE) + mkdir -p $(INSTALL_DIR) + $(INSTALL) -m 0755 -D $(TARGET_MODULE) $(INSTALL_DIR) + +.PHONY: all clean distclean pristine install diff --git a/examples/redis-unstable/modules/redisbloom/Makefile b/examples/redis-unstable/modules/redisbloom/Makefile new file mode 100644 index 0000000..32bab9a --- /dev/null +++ b/examples/redis-unstable/modules/redisbloom/Makefile @@ -0,0 +1,6 @@ +SRC_DIR = src +MODULE_VERSION = v8.5.90 +MODULE_REPO = https://github.com/redisbloom/redisbloom +TARGET_MODULE = $(SRC_DIR)/bin/$(FULL_VARIANT)/redisbloom.so + +include ../common.mk diff --git a/examples/redis-unstable/modules/redisearch/Makefile b/examples/redis-unstable/modules/redisearch/Makefile new file mode 100644 index 0000000..1672d74 --- /dev/null +++ b/examples/redis-unstable/modules/redisearch/Makefile @@ -0,0 +1,7 @@ +SRC_DIR = src +MODULE_VERSION = v8.5.90 +MODULE_REPO = https://github.com/redisearch/redisearch +TARGET_MODULE = $(SRC_DIR)/bin/$(FULL_VARIANT)/search-community/redisearch.so + +include ../common.mk + diff --git a/examples/redis-unstable/modules/redisjson/Makefile b/examples/redis-unstable/modules/redisjson/Makefile new file mode 100644 index 0000000..a3c49c5 --- /dev/null +++ b/examples/redis-unstable/modules/redisjson/Makefile @@ -0,0 +1,11 @@ +SRC_DIR = src +MODULE_VERSION = v8.5.90 +MODULE_REPO = https://github.com/redisjson/redisjson +TARGET_MODULE = $(SRC_DIR)/bin/$(FULL_VARIANT)/rejson.so + +include ../common.mk + +$(SRC_DIR)/.cargo_fetched: + cd $(SRC_DIR) && cargo fetch + +get_source: $(SRC_DIR)/.cargo_fetched diff --git a/examples/redis-unstable/modules/redistimeseries/Makefile b/examples/redis-unstable/modules/redistimeseries/Makefile new file mode 100644 index 0000000..fac3c3c --- /dev/null +++ b/examples/redis-unstable/modules/redistimeseries/Makefile @@ -0,0 +1,6 @@ +SRC_DIR = src +MODULE_VERSION = v8.5.90 +MODULE_REPO = https://github.com/redistimeseries/redistimeseries +TARGET_MODULE = $(SRC_DIR)/bin/$(FULL_VARIANT)/redistimeseries.so + +include ../common.mk diff --git a/examples/redis-unstable/modules/vector-sets/.gitignore b/examples/redis-unstable/modules/vector-sets/.gitignore new file mode 100644 index 0000000..c72b1b8 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/.gitignore @@ -0,0 +1,11 @@ +__pycache__ +misc +*.so +*.xo +*.o +.DS_Store +w2v +word2vec.bin +TODO +*.txt +*.rdb diff --git a/examples/redis-unstable/modules/vector-sets/Makefile b/examples/redis-unstable/modules/vector-sets/Makefile new file mode 100644 index 0000000..f8c05c9 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/Makefile @@ -0,0 +1,87 @@ +# Compiler settings +CC = cc + +ifdef SANITIZER +ifeq ($(SANITIZER),address) + SAN=-fsanitize=address +else +ifeq ($(SANITIZER),undefined) + SAN=-fsanitize=undefined +else +ifeq ($(SANITIZER),thread) + SAN=-fsanitize=thread +else + $(error "unknown sanitizer=${SANITIZER}") +endif +endif +endif +endif + +CFLAGS = -O2 -Wall -Wextra -g $(SAN) -std=c11 +LDFLAGS = -lm $(SAN) + +# Detect OS +uname_S := $(shell sh -c 'uname -s 2>/dev/null || echo not') +uname_M := $(shell sh -c 'uname -m 2>/dev/null || echo not') + +# Shared library compile flags for linux / osx +ifeq ($(uname_S),Linux) + SHOBJ_CFLAGS ?= -W -Wall -fno-common -g -ggdb -std=c11 -O2 + SHOBJ_LDFLAGS ?= -shared +ifneq (,$(findstring armv,$(uname_M))) + SHOBJ_LDFLAGS += -latomic +endif +ifneq (,$(findstring aarch64,$(uname_M))) + SHOBJ_LDFLAGS += -latomic +endif +else + SHOBJ_CFLAGS ?= -W -Wall -dynamic -fno-common -g -ggdb -std=c11 -O3 + SHOBJ_LDFLAGS ?= -bundle -undefined dynamic_lookup +endif + +# OS X 11.x doesn't have /usr/lib/libSystem.dylib and needs an explicit setting. +ifeq ($(uname_S),Darwin) +ifeq ("$(wildcard /usr/lib/libSystem.dylib)","") +LIBS = -L /Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib -lsystem +endif +endif + +.SUFFIXES: .c .so .xo .o + +all: vset.so + +.c.xo: + $(CC) -I. $(CFLAGS) $(SHOBJ_CFLAGS) -fPIC -c $< -o $@ + +vset.xo: ../../src/redismodule.h expr.c + +vset.so: vset.xo hnsw.xo vset_config.xo + $(CC) -o $@ $^ $(SHOBJ_LDFLAGS) $(LIBS) $(SAN) -lc + +# Example sources / objects +SRCS = hnsw.c w2v.c vset_config.c +OBJS = $(SRCS:.c=.o) + +TARGET = w2v +MODULE = vset.so + +# Default target +all: $(TARGET) $(MODULE) + +# Example linking rule +$(TARGET): $(OBJS) + $(CC) $(OBJS) $(LDFLAGS) -o $(TARGET) + +# Compilation rule for object files +%.o: %.c + $(CC) $(CFLAGS) -c $< -o $@ + +expr-test: expr.c fastjson.c fastjson_test.c + $(CC) $(CFLAGS) expr.c -o expr-test -DTEST_MAIN -lm + +# Clean rule +clean: + rm -f $(TARGET) $(OBJS) *.xo *.so + +# Declare phony targets +.PHONY: all clean diff --git a/examples/redis-unstable/modules/vector-sets/README.md b/examples/redis-unstable/modules/vector-sets/README.md new file mode 100644 index 0000000..be87ff3 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/README.md @@ -0,0 +1,733 @@ +**IMPORTANT:** *Please note that this is a merged module, it's part of the Redis binary now, and you don't need to build it and load it into Redis. Compiling Redis version 8 or greater will result into having the Vector Sets commands available. However, you could compile this module as a shared library in order to load it in older versions of Redis.* + +This module implements Vector Sets for Redis, a new Redis data type similar +to Sorted Sets but having string elements associated to a vector instead of +a score. The fundamental goal of Vector Sets is to make possible adding items, +and later get a subset of the added items that are the most similar to a +specified vector (often a learned embedding), or the most similar to the vector +of an element that is already part of the Vector Set. + +Moreover, Vector sets implement optional filtered search capabilities: it is possible to associate attributes to all or to a subset of elements in the set, and then, using the `FILTER` option of the `VSIM` command, to ask for items similar to a given vector but also passing a filter specified as a simple mathematical expression (Like `".year > 1950"` or similar). This means that **you can have vector similarity and scalar filters at the same time**. + +## Installation + +**WARNING:** If you are running **Redis 8.0 RC1 or greater** you don't need to install anything, just compile Redis, and the Vector Sets commands will be part of the default install. Otherwise to test Vector Sets with older Redis versions follow the following instructions. + +Build with: + + make + +Then load the module with the following command line, or by inserting the needed directives in the `redis.conf` file. + + ./redis-server --loadmodule vset.so + +To run tests, I suggest using this: + + ./redis-server --save "" --enable-debug-command yes + +The execute the tests with: + + ./test.py + +## Reference of available commands + +**VADD: add items into a vector set** + + VADD key [REDUCE dim] FP32|VALUES vector element [CAS] [NOQUANT | Q8 | BIN] + [EF build-exploration-factor] [SETATTR <attributes>] [M <numlinks>] + +Add a new element into the vector set specified by the key. +The vector can be provided as FP32 blob of values, or as floating point +numbers as strings, prefixed by the number of elements (3 in the example): + + VADD mykey VALUES 3 0.1 1.2 0.5 my-element + +Meaning of the options: + +`REDUCE` implements random projection, in order to reduce the +dimensionality of the vector. The projection matrix is saved and reloaded +along with the vector set. **Please note that** the `REDUCE` option must be passed immediately before the vector, like in `REDUCE 50 VALUES ...`. + +`CAS` performs the operation partially using threads, in a +check-and-set style. The neighbor candidates collection, which is slow, is +performed in the background, while the command is executed in the main thread. + +`NOQUANT` forces the vector to be created (in the first VADD call to a given key) without integer 8 quantization, which is otherwise the default. + +`BIN` forces the vector to use binary quantization instead of int8. This is much faster and uses less memory, but has impacts on the recall quality. + +`Q8` forces the vector to use signed 8 bit quantization. This is the default, and the option only exists in order to make sure to check at insertion time if the vector set is of the same format. + +`EF` plays a role in the effort made to find good candidates when connecting the new node to the existing HNSW graph. The default is 200. Using a larger value, may help to have a better recall. To improve the recall it is also possible to increase `EF` during `VSIM` searches. + +`SETATTR` associates attributes to the newly created entry or update the entry attributes (if it already exists). It is the same as calling the `VSETATTR` attribute separately, so please check the documentation of that command in the filtered search section of this documentation. + +`M` defaults to 16 and is the HNSW famous `M` parameters. It is the maximum number of connections that each node of the graph have with other nodes: more connections mean more memory, but a better ability to explore the graph. Nodes at layer zero (every node exists at least at layer zero) have `M*2` connections, while the other layers only have `M` connections. This means that, for instance, an `M` of 64 will use at least 1024 bytes of memory for each node! That is, `64 links * 2 times * 8 bytes pointers`, and even more, since on average each node has something like 1.33 layers (but the other layers have just `M` connections, instead of `M*2`). If you don't have a recall quality problem, the default is fine, and uses a limited amount of memory. + +**VSIM: return elements by vector similarity** + + VSIM key [ELE|FP32|VALUES] <vector or element> [WITHSCORES] [WITHATTRIBS] [COUNT num] [EPSILON delta] [EF search-exploration-factor] [FILTER expression] [FILTER-EF max-filtering-effort] [TRUTH] [NOTHREAD] + +The command returns similar vectors, for simplicity (and verbosity) in the following example, instead of providing a vector using FP32 or VALUES (like in `VADD`), we will ask for elements having a vector similar to a given element already in the sorted set: + + > VSIM word_embeddings ELE apple + 1) "apple" + 2) "apples" + 3) "pear" + 4) "fruit" + 5) "berry" + 6) "pears" + 7) "strawberry" + 8) "peach" + 9) "potato" + 10) "grape" + +It is possible to specify a `COUNT` and also to get the similarity score (from 1 to 0, where 1 is identical, 0 is opposite vector) between the query and the returned items. + + > VSIM word_embeddings ELE apple WITHSCORES COUNT 3 + 1) "apple" + 2) "0.9998867657923256" + 3) "apples" + 4) "0.8598527610301971" + 5) "pear" + 6) "0.8226882219314575" + +It is also possible to specify a `EPSILON`, that is a floating point number between 0 and 1 in order to only return elements that have a distance that is no further than the specified one. In vector sets, the returned elements have a similarity score (when compared to the query vector) that is between 1 and 0, where 1 means identical, 0 opposite vectors. If for instance the `EPSILON` option is specified with an argument of 0.2, it means that we will get only elements that have a similarity of 0.8 or better (a distance < 0.2). This is useful when a large `COUNT` is specified, yet we don't want elements that are too far away our query vector. + +The `EF` argument is the exploration factor: the higher it is, the slower the command becomes, but the better the index is explored to find nodes that are near to our query. Sensible values are from 50 to 1000. + +The `TRUTH` option forces the command to perform a linear scan of all the entries inside the set, without using the graph search inside the HNSW, so it returns the best matching elements (the perfect result set) that can be used in order to easily calculate the recall. Of course the linear scan is `O(N)`, so it is much slower than the `log(N)` (considering a small `COUNT`) provided by the HNSW index. + +The `NOTHREAD` option forces the command to execute the search on the data structure in the main thread. Normally `VSIM` spawns a thread instead. This may be useful for benchmarking purposes, or when we work with extremely small vector sets and don't want to pay the cost of spawning a thread. It is possible that in the future this option will be automatically used by Redis when we detect small vector sets. Note that this option blocks the server for all the time needed to complete the command, so it is a source of potential latency issues: if you are in doubt, never use it. + +The `WITHSCORES` option returns, for each returned element, a floating point number representing how near the element is from the query, as a similarity between 0 and 1, where 0 means the vectors are opposite, and 1 means they are pointing exactly in the same direction (maximum similarity). + +The `WITHATTRIBS` option returns, for each element, the JSON attribute associated with the element, or NULL for the elements missing an attribute. + +For `FILTER` and `FILTER-EF` options, please check the filtered search section of this documentation. + +Note that when `WITHSCORES` and `WITHATTRIBS` are provided at the same time, the RESP2 reply guarantees that the returned elements are always in the sequence *ele*,*score*,*attribs*, while RESP3 replies will be in the form *ele > score|attrib* when just one is provided, or *ele -> [score,attrib]* when both are provided, that is, when both options are used and RESP3 is used the score and attribute will be a two-items array associated to the element key. + +**VDIM: return the dimension of the vectors inside the vector set** + + VDIM keyname + +Example: + + > VDIM word_embeddings + (integer) 300 + +Note that in the case of vectors that were populated using the `REDUCE` +option, for random projection, the vector set will report the size of +the projected (reduced) dimension. Yet the user should perform all the +queries using full-size vectors. + +**VCARD: return the number of elements in a vector set** + + VCARD key + +Example: + + > VCARD word_embeddings + (integer) 3000000 + + +**VREM: remove elements from vector set** + + VREM key element + +Example: + + > VADD vset VALUES 3 1 0 1 bar + (integer) 1 + > VREM vset bar + (integer) 1 + > VREM vset bar + (integer) 0 + +VREM does not perform thumstone / logical deletion, but will actually reclaim +the memory from the vector set, so it is save to add and remove elements +in a vector set in the context of long running applications that continuously +update the same index. + +**VEMB: return the approximated vector of an element** + + VEMB key element + +Example: + + > VEMB word_embeddings SQL + 1) "0.18208661675453186" + 2) "0.08535309880971909" + 3) "0.1365649551153183" + 4) "-0.16501599550247192" + 5) "0.14225517213344574" + ... 295 more elements ... + +Because vector sets perform insertion time normalization and optional +quantization, the returned vector could be approximated. `VEMB` will take +care to de-quantized and de-normalize the vector before returning it. + +It is possible to ask VEMB to return raw data, that is, the internal representation used by the vector: fp32, int8, or a bitmap for binary quantization. This behavior is triggered by the `RAW` option of of VEMB: + + VEMB word_embedding apple RAW + +In this case the return value of the command is an array of three or more elements: +1. The name of the quantization used, that is one of: "fp32", "bin", "q8". +2. The a string blob containing the raw data, 4 bytes fp32 floats for fp32, a bitmap for binary quants, or int8 bytes array for q8 quants. +3. A float representing the l2 of the vector before normalization. You need to multiply by this vector if you want to de-normalize the value for any reason. + +For q8 quantization, an additional elements is also returned: the quantization +range, so the integers from -127 to 127 represent (normalized) components +in the range `-range`, `+range`. + +**VISMEMBER: test if a given element already exists** + +This command will return 1 (or true) if the specified element is already in the vector set, otherwise 0 (or false) is returned. + + VISMEMBER key element + +As with other existence check Redis commands, if the key does not exist it is considered as if it was empty, thus the element is reported as non existing. + +**VRANGE: return elements in a lexicographical range + + VRANGE key start end count + +The `VRANGE` command has many different use cases, but its main goal is to +provide a stateless iterator for the elements inside a vector set: that is, +it allows to retrieve all the elements inside a vector set in small amounts +for each call, without an explicit cursor, and with guarantees about what +the user will miss in case the vector set is changing (elements added and/or +removed) during the iteration. + +The command usage is straightforward: + +``` +> VRANGE word_embeddings_int8 [Redis + 10 + 1) "Redis" + 2) "Rediscover" + 3) "Rediscover_Ashland" + 4) "Rediscover_Northern_Ireland" + 5) "Rediscovered" + 6) "Rediscovered_Bookshop" + 7) "Rediscovering" + 8) "Rediscovering_God" + 9) "Rediscovering_Lost" +10) "Rediscovers" +``` + +The above command returns 10 (or less, if less are available in the specified range) elements from "Redis" (inclusive) to the maximum possible element. The comparison is performed byte by byte, as `memcmp()` would do, in this way the elements have a total order. The start and end range can be either a string, prefixed by `[` or `(` (the prefix is mandatory) to tell the command if the range is inclusive or exclusive, or can be the special symbols `-` and `+` that means the maximum and minimum element. + +So for instance if I want to iterate all the elements, ten elements for each call, I'll proceed as such: + +``` +> VRANGE mykey - + 10 + 1) "a" + 2) "a-league" + 3) "a." + 4) "a.d." + 5) "a.k.a." + 6) "a.m." + 7) "a1" + 8) "a2" + 9) "a3" +10) "a7" +``` + +This will give me the first 10 elements. Then I want the next ten elements +starting from the last element in the previous result, but *excluding* it, +so the next range will use the `(` prefix with the last element of the +previous call, that was `"a7"`: + +``` +> VRANGE mykey (a7 + 10 + 1) "a930913" + 2) "aa" + 3) "aaa" + 4) "aaron" + 5) "ab" + 6) "aba" + 7) "abandon" + 8) "abandoned" + 9) "abandoning" +10) "abandonment" +``` + +And so forth. + +The command count is mandatory, however a negative count means to return all the elements in the set. This means that `VRANGE mykey - + -1` will return every element. Of course, iterating like that means that it is possible to block the server for a long time. + +The command time complexity is O(1) to seek to the element (considering the element would be of reasonable size), since we use a Radix Tree in the underlying implementation, plus the time to yield "M" elements. So if M is small, each call is just executed in constant time. However the iteration of a total set (via multiple calls) of N elements is O(N). Basically: this command, with a small count, will never produce latency issues in the Redis server. + +In case the elements are changing continuously as the set is iterated, the guarantees are very simple: each range will produce exactly the elements that were present in the range in the moment the `VRANGE` command was executed. In other words, an iteration performed in this way is *guaranteed* to return all the elements that stayed within the vector set from the start to the end of the iteration. Elements removed or added in the meantime may be returned or not depending on the moment they were added or removed. + +**VLINKS: introspection command that shows neighbors for a node** + + VLINKS key element [WITHSCORES] + +The command reports the neighbors for each level. + +**VINFO: introspection command that shows info about a vector set** + + VINFO key + +Example: + + > VINFO word_embeddings + 1) quant-type + 2) int8 + 3) vector-dim + 4) (integer) 300 + 5) size + 6) (integer) 3000000 + 7) max-level + 8) (integer) 12 + 9) vset-uid + 10) (integer) 1 + 11) hnsw-max-node-uid + 12) (integer) 3000000 + +**VSETATTR: associate or remove the JSON attributes of elements** + + VSETATTR key element "{... json ...}" + +Each element of a vector set can be optionally associated with a JSON string +in order to use the `FILTER` option of `VSIM` to filter elements by scalars +(see the filtered search section for more information). This command can set, +update (if already set) or delete (if you set to an empty string) the +associated JSON attributes of an element. + +The command returns 0 if the element or the key don't exist, without +raising an error, otherwise 1 is returned, and the element attributes +are set or updated. + +**VGETATTR: retrieve the JSON attributes of elements** + + VGETATTR key element + +The command returns the JSON attribute associated with an element, or +null if there is no element associated, or no element at all, or no key. + +**VRANDMEMBER: return random members from a vector set** + + VRANDMEMBER key [count] + +Return one or more random elements from a vector set. + +The semantics of this command are similar to Redis's native SRANDMEMBER command: + +- When called without count, returns a single random element from the set, as a single string (no array reply). +- When called with a positive count, returns up to count distinct random elements (no duplicates). +- When called with a negative count, returns count random elements, potentially with duplicates. +- If the count value is larger than the set size (and positive), only the entire set is returned. + +If the key doesn't exist, returns a Null reply if count is not given, or an empty array if a count is provided. + +Examples: + + > VADD vset VALUES 3 1 0 0 elem1 + (integer) 1 + > VADD vset VALUES 3 0 1 0 elem2 + (integer) 1 + > VADD vset VALUES 3 0 0 1 elem3 + (integer) 1 + + # Return a single random element + > VRANDMEMBER vset + "elem2" + + # Return 2 distinct random elements + > VRANDMEMBER vset 2 + 1) "elem1" + 2) "elem3" + + # Return 3 random elements with possible duplicates + > VRANDMEMBER vset -3 + 1) "elem2" + 2) "elem2" + 3) "elem1" + + # Return more elements than in the set (returns all elements) + > VRANDMEMBER vset 10 + 1) "elem1" + 2) "elem2" + 3) "elem3" + + # When key doesn't exist + > VRANDMEMBER nonexistent + (nil) + > VRANDMEMBER nonexistent 3 + (empty array) + +This command is particularly useful for: + +1. Selecting random samples from a vector set for testing or training. +2. Performance testing by retrieving random elements for subsequent similarity searches. + +When the user asks for unique elements (positev count) the implementation optimizes for two scenarios: +- For small sample sizes (less than 20% of the set size), it uses a dictionary to avoid duplicates, and performs a real random walk inside the graph. +- For large sample sizes (more than 20% of the set size), it starts from a random node and sequentially traverses the internal list, providing faster performances but not really "random" elements. + +The command has `O(N)` worst-case time complexity when requesting many unique elements (it uses linear scanning), or `O(M*log(N))` complexity when the users asks for `M` random elements in a sorted set of `N` elements, with `M` much smaller than `N`. + +# Filtered search + +Each element of the vector set can be associated with a set of attributes specified as a JSON blob: + + > VADD vset VALUES 3 1 1 1 a SETATTR '{"year": 1950}' + (integer) 1 + > VADD vset VALUES 3 -1 -1 -1 b SETATTR '{"year": 1951}' + (integer) 1 + +Specifying an attribute with the `SETATTR` option of `VADD` is exactly equivalent to adding an element and then setting (or updating, if already set) the attributes JSON string. Also the symmetrical `VGETATTR` command returns the attribute associated to a given element. + + > VADD vset VALUES 3 0 1 0 c + (integer) 1 + > VSETATTR vset c '{"year": 1952}' + (integer) 1 + > VGETATTR vset c + "{\"year\": 1952}" + +At this point, I may use the FILTER option of VSIM to only ask for the subset of elements that are verified by my expression: + + > VSIM vset VALUES 3 0 0 0 FILTER '.year > 1950' + 1) "c" + 2) "b" + +The items will be returned again in order of similarity (most similar first), but only the items with the year field matching the expression is returned. + +The expressions are similar to what you would write inside the `if` statement of JavaScript or other familiar programming languages: you can use `and`, `or`, the obvious math operators like `+`, `-`, `/`, `>=`, `<`, ... and so forth (see the expressions section for more info). The selectors of the JSON object attributes start with a dot followed by the name of the key inside the JSON objects. + +Elements with invalid JSON or not having a given specified field **are considered as not matching** the expression, but will not generate any error at runtime. + +## FILTER expressions capabilities + +FILTER expressions allow you to perform complex filtering on vector similarity results using a JavaScript-like syntax. The expression is evaluated against each element's JSON attributes, with only elements that satisfy the expression being included in the results. + +### Expression Syntax + +Expressions support the following operators and capabilities: + +1. **Arithmetic operators**: `+`, `-`, `*`, `/`, `%` (modulo), `**` (exponentiation) +2. **Comparison operators**: `>`, `>=`, `<`, `<=`, `==`, `!=` +3. **Logical operators**: `and`/`&&`, `or`/`||`, `!`/`not` +4. **Containment operator**: `in` +5. **Parentheses** for grouping: `(...)` + +### Selector Notation + +Attributes are accessed using dot notation: + +- `.year` references the "year" attribute +- `.movie.year` would **NOT** reference the "year" field inside a "movie" object, only keys that are at the first level of the JSON object are accessible. + +### JSON and expressions data types + +Expressions can work with: + +- Numbers (dobule precision floats) +- Strings (enclosed in single or double quotes) +- Booleans (no native type: they are represented as 1 for true, 0 for false) +- Arrays (for use with the `in` operator: `value in [1, 2, 3]`) + +JSON attributes are converted in this way: + +- Numbers will be converted to numbers. +- Strings to strings. +- Booleans to 0 or 1 number. +- Arrays to tuples (for "in" operator), but only if composed of just numbers and strings. + +Any other type is ignored, and accessig it will make the expression evaluate to false. + +### The IN operator + +The `IN` operator works in two ways, it can test for membership in an array, like in: + + 5 in [1, 2, 3] + "foo" in [1, "foo", "bar"] + +But can also check for substrings, in case the A and B operators are both strings. + + "foo" in "barfoobar" # Will evaluate to true + "zap" in "foobar" # Will evaluate to false + +### Examples + +``` +# Find items from the 1980s +VSIM movies VALUES 3 0.5 0.8 0.2 FILTER '.year >= 1980 and .year < 1990' + +# Find action movies with high ratings +VSIM movies VALUES 3 0.5 0.8 0.2 FILTER '.genre == "action" and .rating > 8.0' + +# Find movies directed by either Spielberg or Nolan +VSIM movies VALUES 3 0.5 0.8 0.2 FILTER '.director in ["Spielberg", "Nolan"]' + +# Complex condition with numerical operations +VSIM movies VALUES 3 0.5 0.8 0.2 FILTER '(.year - 2000) ** 2 < 100 and .rating / 2 > 4' +``` + +### Error Handling + +Elements with any of the following conditions are considered not matching: +- Missing the queried JSON attribute +- Having invalid JSON in their attributes +- Having a JSON value that cannot be converted to the expected type + +This behavior allows you to safely filter on optional attributes without generating errors. + +### FILTER effort + +The `FILTER-EF` option controls the maximum effort spent when filtering vector search results. + +When performing vector similarity search with filtering, Vector Sets perform the standard similarity search as they apply the filter expression to each node. Since many results might be filtered out, Vector Sets may need to examine a lot more candidates than the requested `COUNT` to ensure sufficient matching results are returned. Actually, if the elements matching the filter are very rare or if there are less than elements matching than the specified count, this would trigger a full scan of the HNSW graph. + +For this reason, by default, the maximum effort is limited to a reasonable amount of nodes explored. + +### Modifying the FILTER effort + +1. By default, Vector Sets will explore up to `COUNT * 100` candidates to find matching results. +2. You can control this exploration with the `FILTER-EF` parameter. +3. A higher `FILTER-EF` value increases the chances of finding all relevant matches at the cost of increased processing time. +4. A `FILTER-EF` of zero will explore as many nodes as needed in order to actually return the number of elements specified by `COUNT`. +5. Even when a high `FILTER-EF` value is specified **the implementation will do a lot less work** if the elements passing the filter are very common, because of the early stop conditions of the HNSW implementation (once the specified amount of elements is reached and the quality check of the other candidates trigger an early stop). + +``` +VSIM key [ELE|FP32|VALUES] <vector or element> COUNT 10 FILTER '.year > 2000' FILTER-EF 500 +``` + +In this example, Vector Sets will examine up to 500 potential nodes. Of course if count is reached before exploring 500 nodes, and the quality checks show that it is not possible to make progresses on similarity, the search is ended sooner. + +### Performance Considerations + +- If you have highly selective filters (few items match), use a higher `FILTER-EF`, or just design your application in order to handle a result set that is smaller than the requested count. Note that anyway the additional elements may be too distant than the query vector. +- For less selective filters, the default should be sufficient. +- Very selective filters with low `FILTER-EF` values may return fewer items than requested. +- Extremely high values may impact performance without significantly improving results. + +The optimal `FILTER-EF` value depends on: +1. The selectivity of your filter. +2. The distribution of your data. +3. The required recall quality. + +A good practice is to start with the default and increase if needed when you observe fewer results than expected. + +### Testing a larg-ish data set + +To really see how things work at scale, you can [download](https://antirez.com/word2vec_with_attribs.rdb) the following dataset: + + wget https://antirez.com/word2vec_with_attribs.rdb + +It contains the 3 million words in Word2Vec having as attribute a JSON with just the length of the word. Because of the length distribution of words in large amounts of texts, where longer words become less and less common, this is ideal to check how filtering behaves with a filter verifying as true with less and less elements in a vector set. + +For instance: + + > VSIM word_embeddings_bin ele "pasta" FILTER ".len == 6" + 1) "pastas" + 2) "rotini" + 3) "gnocci" + 4) "panino" + 5) "salads" + 6) "breads" + 7) "salame" + 8) "sauces" + 9) "cheese" + 10) "fritti" + +This will easily retrieve the desired amount of items (`COUNT` is 10 by default) since there are many items of length 6. However: + + > VSIM word_embeddings_bin ele "pasta" FILTER ".len == 33" + 1) "skinless_boneless_chicken_breasts" + 2) "boneless_skinless_chicken_breasts" + 3) "Boneless_skinless_chicken_breasts" + +This time even if we asked for 10 items, we only get 3, since the default filter effort will be `10*100 = 1000`. We can tune this giving the effort in an explicit way, with the risk of our query being slower, of course: + + > VSIM word_embeddings_bin ele "pasta" FILTER ".len == 33" FILTER-EF 10000 + 1) "skinless_boneless_chicken_breasts" + 2) "boneless_skinless_chicken_breasts" + 3) "Boneless_skinless_chicken_breasts" + 4) "mozzarella_feta_provolone_cheddar" + 5) "Greatfood.com_R_www.greatfood.com" + 6) "Pepperidge_Farm_Goldfish_crackers" + 7) "Prosecuted_Mobsters_Rebuilt_Dying" + 8) "Crispy_Snacker_Sandwiches_Popcorn" + 9) "risultati_delle_partite_disputate" + 10) "Peppermint_Mocha_Twist_Gingersnap" + +This time we get all the ten items, even if the last one will be quite far from our query vector. We encourage to experiment with this test dataset in order to understand better the dynamics of the implementation and the natural tradeoffs of filtered search. + +**Keep in mind** that by default, Redis Vector Sets will try to avoid a likely very useless huge scan of the HNSW graph, and will be more happy to return few or no elements at all, since this is almost always what the user actually wants in the context of retrieving *similar* items to the query. + +# Single Instance Scalability and Latency + +Vector Sets implement a threading model that allows Redis to handle many concurrent requests: by default `VSIM` is always threaded, and `VADD` is not (but can be partially threaded using the `CAS` option). This section explains how the threading and locking mechanisms work, and what to expect in terms of performance. + +## Threading Model + +- The `VSIM` command runs in a separate thread by default, allowing Redis to continue serving other commands. +- A maximum of 32 threads can run concurrently (defined by `HNSW_MAX_THREADS`). +- When this limit is reached, additional `VSIM` requests are queued - Redis remains responsive, no latency event is generated. +- The `VADD` command with the `CAS` option also leverages threading for the computation-heavy candidate search phase, but the insertion itself is performed in the main thread. `VADD` always runs in a sub-millisecond time, so this is not a source of latency, but having too many hundreds of writes per second can be challenging to handle with a single instance. Please, look at the next section about multiple instances scalability. +- Commands run within Lua scripts, MULTI/EXEC blocks, or from replication are executed in the main thread to ensure consistency. + +``` +> VSIM vset VALUES 3 1 1 1 FILTER '.year > 2000' # This runs in a thread. +> VADD vset VALUES 3 1 1 1 element CAS # Candidate search runs in a thread. +``` + +## Locking Mechanism + +Vector Sets use a read/write locking mechanism to coordinate access: + +- Reads (`VSIM`, `VEMB`, etc.) acquire a read lock, allowing multiple concurrent reads. +- Writes (`VADD`, `VREM`, etc.) acquire a write lock, temporarily blocking all reads. +- When a write lock is requested while reads are in progress, the write operation waits for all reads to complete. +- Once a write lock is granted, all reads are blocked until the write completes. +- Each thread has a dedicated slot for tracking visited nodes during graph traversal, avoiding contention. This improves performances but limits the maximum number of concurrent threads, since each node has a memory cost proportional to the number of slots. + +## DEL latency + +Deleting a very large vector set (millions of elements) can cause latency spikes, as deletion rebuilds connections between nodes. This may change in the future. +The deletion latency is most noticeable when using `DEL` on a key containing a large vector set or when the key expires. + +## Performance Characteristics + +- Search operations (`VSIM`) scale almost linearly with the number of CPU cores available, up to the thread limit. You can expect a Vector Set composed of million of items associated with components of dimension 300, with the default int8 quantization, to deliver around 50k VSIM operations per second in a single host. +- Insertion operations (`VADD`) are more computationally expensive than searches, and can't be threaded: expect much lower throughput, in the range of a few thousands inserts per second. +- Binary quantization offers significantly faster search performance at the cost of some recall quality, while int8 quantization, the default, seems to have very small impacts on recall quality, while it significantly improves performances and space efficiency. +- The `EF` parameter has a major impact on both search quality and performance - higher values mean better recall but slower searches. +- Graph traversal time scales logarithmically with the number of elements, making Vector Sets efficient even with millions of vectors + +## Loading / Saving performances + +Vector Sets are able to serialize on disk the graph structure as it is in memory, so loading back the data does not need to rebuild the HNSW graph. This means that Redis can load millions of items per minute. For instance 3 million items with 300 components vectors can be loaded back into memory into around 15 seconds. + +# Scaling vector sets to multiple instances + +The fundamental way vector sets can be scaled to very large data sets +and to many Redis instances is that a given very large set of vectors +can be partitioned into N different Redis keys, that can also live into +different Redis instances. + +For instance, I could add my elements into `key0`, `key1`, `key2`, by hashing +the item in some way, like doing `crc32(item)%3`, effectively splitting +the dataset into three different parts. However once I want all the vectors +of my dataset near to a given query vector, I could simply perform the +`VSIM` command against all the three keys, merging the results by +score (so the commands must be called using the `WITHSCORES` option) on +the client side: once the union of the results are ordered by the +similarity score, the query is equivalent to having a single key `key1+2+3` +containing all the items. + +There are a few interesting facts to note about this pattern: + +1. It is possible to have a logical sorted set that is as big as the sum of all the Redis instances we are using. +2. Deletion operations remain simple, we can hash the key and select the key where our item belongs. +3. However, even if I use 10 different Redis instances, I'm not going to reach 10x the **read** operations per second, compared to using a single server: for each logical query, I need to query all the instances. Yet, smaller graphs are faster to navigate, so there is some win even from the point of view of CPU usage. +4. Insertions, so **write** queries, will be scaled linearly: I can add N items against N instances at the same time, splitting the insertion load evenly. This is very important since vector sets, being based on HNSW data structures, are slower to add items than to query similar items, by a very big factor. +5. While it cannot guarantee always the best results, with proper timeout management this system may be considered *highly available*, since if a subset of N instances are reachable, I'll be still be able to return similar items to my query vector. + +Notably, this pattern can be implemented in a way that avoids paying the sum of the round trip time with all the servers: it is possible to send the queries at the same time to all the instances, so that latency will be equal the slower reply out of of the N servers queries. + +# Optimizing memory usage + +Vector Sets, or better, HNSWs, the underlying data structure used by Vector Sets, combined with the features provided by the Vector Sets themselves (quantization, random projection, filtering, ...) form an implementation that has a non-trivial space of parameters that can be tuned. Despite to the complexity of the implementation and of vector similarity problems, here there is a list of simple ideas that can drive the user to pick the best settings: + +* 8 bit quantization (the default) is almost always a win. It reduces the memory usage of vectors by a factor of 4, yet the performance penalty in terms of recall is minimal. It also reduces insertion and search time by around 2 times or more. +* Binary quantization is much more extreme: it makes vector sets a lot faster, but increases the recall error in a sensible way, for instance from 95% to 80% if all the parameters remain the same. Yet, the speedup is really big, and the memory usage of vectors, compaerd to full precision vectors, 32 times smaller. +* Vectors memory usage are not the only responsible for Vector Set high memory usage per entry: nodes contain, on average `M*2 + M*0.33` pointers, where M is by default 16 (but can be tuned in `VADD`, see the `M` option). Also each node has the string item and the optional JSON attributes: those should be as small as possible in order to avoid contributing more to the memory usage. +* The `M` parameter should be increased to 32 or more only when a near perfect recall is really needed. +* It is possible to gain space (less memory usage) sacrificing time (more CPU time) by using a low `M` (the default of 16, for instance) and a high `EF` (the effort parameter of `VSIM`) in order to scan the graph more deeply. +* When memory usage is seriosu concern, and there is the suspect the vectors we are storing don't contain as much information - at least for our use case - to justify the number of components they feature, random projection (the `REDUCE` option of `VADD`) could be tested to see if dimensionality reduction is possible with acceptable precision loss. + +## Random projection tradeoffs + +Sometimes learned vectors are not as information dense as we could guess, that +is there are components having similar meanings in the space, and components +having values that don't really represent features that matter in our use case. + +At the same time, certain vectors are very big, 1024 components or more. In this cases, it is possible to use the random projection feature of Redis Vector Sets in order to reduce both space (less RAM used) and space (more operstions per second). The feature is accessible via the `REDUCE` option of the `VADD` command. However, keep in mind that you need to test how much reduction impacts the performances of your vectors in term of recall and quality of the results you get back. + +## What is a random projection? + +The concept of Random Projection is relatively simple to grasp. For instance, a projection that turns a 100 components vector into a 10 components vector will perform a different linear transformation between the 100 components and each of the target 10 components. Please note that *each of the target components* will get some random amount of all the 100 original components. It is mathematically proved that this process results in a vector space where elements still have similar distances among them, but still some information will get lost. + +## Examples of projections and loss of precision + +To show you a bit of a extreme case, let's take Word2Vec 3 million items and compress them from 300 to 100, 50 and 25 components vectors. Then, we check the recall compared to the ground truth against each of the vector sets produced in this way (using different `REDUCE` parameters of `VADD`). This is the result, obtained asking for the top 10 elements. + +``` +---------------------------------------------------------------------- +Key Average Recall % Std Dev +---------------------------------------------------------------------- +word_embeddings_int8 95.98 12.14 + ^ This is the same key used for ground truth, but without TRUTH option +word_embeddings_reduced_100 40.20 20.13 +word_embeddings_reduced_50 24.42 16.89 +word_embeddings_reduced_25 14.31 9.99 +``` + +Here the dimensionality reduction we are using is quite extreme: from 300 to 100 means that 66.6% of the original information is lost. The recall drops from 96% to 40%, down to 24% and 14% for even more extreme dimension reduction. + +Reducing the dimension of vectors that are already relatively small, like the above example, of 300 components, will provide only relatively small memory savings, especially because by default Vector Sets use `int8` quantization, that will use only one byte per component: + +``` +> MEMORY USAGE word_embeddings_int8 +(integer) 3107002888 +> MEMORY USAGE word_embeddings_reduced_100 +(integer) 2507122888 +``` + +Of course going, for example, from 2048 component vectors to 1024 would provide a much more sensible memory saving, even with the `int8` quantization used by Vector Sets, assuming the recall loss is acceptable. Other than the memory saving, there is also the reduction in CPU time, translating to more operations per second. + +Another thing to note is that, with certain embedding models, binary quantization (that offers a 8x reduction of memory usage compared to 8 bit quants, and a very big speedup in computation) performs much better than reducing the dimension of vectors of the same amount via random projections: + +``` +word_embeddings_bin 35.48 19.78 +``` + +Here in the same test did above: we have a 35% recall which is not too far than the 40% obtained with a random projection from 300 to 100 components. However, while the first technique reduces the size by 3 times, the size reduced of binary quantization is by 8 times. + +``` +> memory usage word_embeddings_bin +(integer) 2327002888 +``` + +In this specific case the key uses JSON attributes and has a graph connection overhead that is much bigger than the 300 bits each vector takes, but, as already said, for big vectors (1024 components, for instance) or for lower values of `M` (see `VADD`, the `M` parameter connects the level of connectivity, so it changes the amount of pointers used per node) the memory saving is much stronger. + +# Vector Sets troubleshooting and understandability + +## Debugging poor recall or unexpected results + +Vector graphs and similarity queries pose many challenges mainly due to the following three problems: + +1. The error due to the approximated nature of Vector Sets is hard to evaluate. +2. The error added by the quantization is often depends on the exact vector space (the embedding we are using **and** how far apart the elements we represent into such embeddings are). +3. We live in the illusion that learned embeddings capture the best similarity possible among elements, which is obviously not always true, and highly application dependent. + +The only way to debug such problems, is the ability to inspect step by step what is happening inside our application, and the structure of the HNSW graph itself. To do so, we suggest to consider the following tools: + +1. The `TRUTH` option of the `VSIM` command is able to return the ground truth of the most similar elements, without using the HNSW graph, but doing a linear scan. +2. The `VLINKS` command allows to explore the graph to see if the connections among nodes make sense, and to investigate why a given node may be more isolated than expected. Such command can also be used in a different way, when we want very fast "similar items" without paying the HNSW traversal time. It exploits the fact that we have a direct reference from each element in our vector set to each node in our HNSW graph. +3. The `WITHSCORES` option, in the supported commands, return a value that is directly related to the *cosine similarity* between the query and the items vectors, the interval of the similarity is simply rescaled from the -1, 1 original range to 0, 1, otherwise the metric is identical. + +## Clients, latency and bandwidth usage + +During Vector Sets testing, we discovered that often clients introduce considerable latecy and CPU usage (in the client side, not in Redis) for two main reasons: + +1. Often the serialization to `VALUES ... list of floats ...` can be very slow. +2. The vector payload of floats represented as strings is very large, resulting in high bandwidth usage and latency, compared to other Redis commands. + +Switching from `VALUES` to `FP32` as a method for transmitting vectors may easily provide 10-20x speedups. + +# Implementation details + +Vector sets are based on the `hnsw.c` implementation of the HNSW data structure with extensions for speed and functionality. + +The main features are: + +* Proper nodes deletion with relinking. +* 8 bits and binary quantization. +* Threaded queries. +* Filtered search with predicate callback. diff --git a/examples/redis-unstable/modules/vector-sets/commands.json b/examples/redis-unstable/modules/vector-sets/commands.json new file mode 100644 index 0000000..5f42f87 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/commands.json @@ -0,0 +1,446 @@ +{ + "VADD": { + "summary": "Add one or more elements to a vector set, or update its vector if it already exists", + "complexity": "O(log(N)) for each element added, where N is the number of elements in the vector set.", + "group": "vector_set", + "since": "8.0.0", + "arity": -5, + "function": "vaddCommand", + "arguments": [ + { + "name": "key", + "type": "key" + }, + { + "token": "REDUCE", + "name": "reduce", + "type": "block", + "optional": true, + "arguments": [ + { + "name": "dim", + "type": "integer" + } + ] + }, + { + "name": "format", + "type": "oneof", + "arguments": [ + { + "name": "fp32", + "type": "pure-token", + "token": "FP32" + }, + { + "name": "values", + "type": "pure-token", + "token": "VALUES" + } + ] + }, + { + "name": "vector", + "type": "string" + }, + { + "name": "element", + "type": "string" + }, + { + "token": "CAS", + "name": "cas", + "type": "pure-token", + "optional": true + }, + { + "name": "quant_type", + "type": "oneof", + "optional": true, + "arguments": [ + { + "name": "noquant", + "type": "pure-token", + "token": "NOQUANT" + }, + { + "name": "bin", + "type": "pure-token", + "token": "BIN" + }, + { + "name": "q8", + "type": "pure-token", + "token": "Q8" + } + ] + }, + { + "token": "EF", + "name": "build-exploration-factor", + "type": "integer", + "optional": true + }, + { + "token": "SETATTR", + "name": "attributes", + "type": "string", + "optional": true + }, + { + "token": "M", + "name": "numlinks", + "type": "integer", + "optional": true + } + ], + "command_flags": [ + "WRITE", + "DENYOOM" + ] + }, + "VREM": { + "summary": "Remove an element from a vector set", + "complexity": "O(log(N)) for each element removed, where N is the number of elements in the vector set.", + "group": "vector_set", + "since": "8.0.0", + "arity": 3, + "function": "vremCommand", + "command_flags": [ + "WRITE" + ], + "arguments": [ + { + "name": "key", + "type": "key" + }, + { + "name": "element", + "type": "string" + } + ] + }, + "VSIM": { + "summary": "Return elements by vector similarity", + "complexity": "O(log(N)) where N is the number of elements in the vector set.", + "group": "vector_set", + "since": "8.0.0", + "arity": -4, + "function": "vsimCommand", + "command_flags": [ + "READONLY" + ], + "arguments": [ + { + "name": "key", + "type": "key" + }, + { + "name": "format", + "type": "oneof", + "arguments": [ + { + "name": "ele", + "type": "pure-token", + "token": "ELE" + }, + { + "name": "fp32", + "type": "pure-token", + "token": "FP32" + }, + { + "name": "values", + "type": "pure-token", + "token": "VALUES" + } + ] + }, + { + "name": "vector_or_element", + "type": "string" + }, + { + "token": "WITHSCORES", + "name": "withscores", + "type": "pure-token", + "optional": true + }, + { + "token": "WITHATTRIBS", + "name": "withattribs", + "type": "pure-token", + "optional": true + }, + { + "token": "COUNT", + "name": "count", + "type": "integer", + "optional": true + }, + { + "token": "EPSILON", + "name": "max_distance", + "type": "double", + "optional": true + }, + { + "token": "EF", + "name": "search-exploration-factor", + "type": "integer", + "optional": true + }, + { + "token": "FILTER", + "name": "expression", + "type": "string", + "optional": true + }, + { + "token": "FILTER-EF", + "name": "max-filtering-effort", + "type": "integer", + "optional": true + }, + { + "token": "TRUTH", + "name": "truth", + "type": "pure-token", + "optional": true + }, + { + "token": "NOTHREAD", + "name": "nothread", + "type": "pure-token", + "optional": true + } + ] + }, + "VDIM": { + "summary": "Return the dimension of vectors in the vector set", + "complexity": "O(1)", + "group": "vector_set", + "since": "8.0.0", + "arity": 2, + "function": "vdimCommand", + "command_flags": [ + "READONLY", + "FAST" + ], + "arguments": [ + { + "name": "key", + "type": "key" + } + ] + }, + "VCARD": { + "summary": "Return the number of elements in a vector set", + "complexity": "O(1)", + "group": "vector_set", + "since": "8.0.0", + "arity": 2, + "function": "vcardCommand", + "command_flags": [ + "READONLY", + "FAST" + ], + "arguments": [ + { + "name": "key", + "type": "key" + } + ] + }, + "VEMB": { + "summary": "Return the vector associated with an element", + "complexity": "O(1)", + "group": "vector_set", + "since": "8.0.0", + "arity": -3, + "function": "vembCommand", + "command_flags": [ + "READONLY" + ], + "arguments": [ + { + "name": "key", + "type": "key" + }, + { + "name": "element", + "type": "string" + }, + { + "token": "RAW", + "name": "raw", + "type": "pure-token", + "optional": true + } + ] + }, + "VLINKS": { + "summary": "Return the neighbors of an element at each layer in the HNSW graph", + "complexity": "O(1)", + "group": "vector_set", + "since": "8.0.0", + "arity": -3, + "function": "vlinksCommand", + "command_flags": [ + "READONLY" + ], + "arguments": [ + { + "name": "key", + "type": "key" + }, + { + "name": "element", + "type": "string" + }, + { + "token": "WITHSCORES", + "name": "withscores", + "type": "pure-token", + "optional": true + } + ] + }, + "VINFO": { + "summary": "Return information about a vector set", + "complexity": "O(1)", + "group": "vector_set", + "since": "8.0.0", + "arity": 2, + "function": "vinfoCommand", + "command_flags": [ + "READONLY", + "FAST" + ], + "arguments": [ + { + "name": "key", + "type": "key" + } + ] + }, + "VSETATTR": { + "summary": "Associate or remove the JSON attributes of elements", + "complexity": "O(1)", + "group": "vector_set", + "since": "8.0.0", + "arity": 4, + "function": "vsetattrCommand", + "command_flags": [ + "WRITE" + ], + "arguments": [ + { + "name": "key", + "type": "key" + }, + { + "name": "element", + "type": "string" + }, + { + "name": "json", + "type": "string" + } + ] + }, + "VGETATTR": { + "summary": "Retrieve the JSON attributes of elements", + "complexity": "O(1)", + "group": "vector_set", + "since": "8.0.0", + "arity": 3, + "function": "vgetattrCommand", + "command_flags": [ + "READONLY" + ], + "arguments": [ + { + "name": "key", + "type": "key" + }, + { + "name": "element", + "type": "string" + } + ] + }, + "VRANDMEMBER": { + "summary": "Return one or multiple random members from a vector set", + "complexity": "O(N) where N is the absolute value of the count argument.", + "group": "vector_set", + "since": "8.0.0", + "arity": -2, + "function": "vrandmemberCommand", + "command_flags": [ + "READONLY" + ], + "arguments": [ + { + "name": "key", + "type": "key" + }, + { + "name": "count", + "type": "integer", + "optional": true + } + ] + }, + "VISMEMBER": { + "summary": "Check if an element exists in a vector set", + "complexity": "O(1)", + "group": "vector_set", + "since": "8.2.0", + "arity": 3, + "function": "vismemberCommand", + "command_flags": [ + "READONLY" + ], + "arguments": [ + { + "name": "key", + "type": "key" + }, + { + "name": "element", + "type": "string" + } + ] + }, + "VRANGE": { + "summary": "Return elements in a lexicographical range", + "complexity": "O(log(K)+M) where K is the number of elements in the start prefix, and M is the number of elements returned. In practical terms, the command is just O(M)", + "group": "vector_set", + "since": "8.4.0", + "arity": -4, + "function": "vrangeCommand", + "command_flags": [ + "READONLY" + ], + "arguments": [ + { + "name": "key", + "type": "key" + }, + { + "name": "start", + "type": "string" + }, + { + "name": "end", + "type": "string" + }, + { + "name": "count", + "type": "integer", + "optional": true + } + ] + } +} diff --git a/examples/redis-unstable/modules/vector-sets/examples/cli-tool/.gitignore b/examples/redis-unstable/modules/vector-sets/examples/cli-tool/.gitignore new file mode 100644 index 0000000..5ceb386 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/examples/cli-tool/.gitignore @@ -0,0 +1 @@ +venv diff --git a/examples/redis-unstable/modules/vector-sets/examples/cli-tool/README.md b/examples/redis-unstable/modules/vector-sets/examples/cli-tool/README.md new file mode 100644 index 0000000..ad21744 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/examples/cli-tool/README.md @@ -0,0 +1,44 @@ +This tool is similar to redis-cli (but very basic) but allows +to specify arguments that are expanded as vectors by calling +ollama to get the embedding. + +Whatever is passed as !"foo bar" gets expanded into + VALUES ... embedding ... + +You must have ollama running with the mxbai-emb-large model +already installed for this to work. + +Example: + + redis> KEYS * + 1) food_items + 2) glove_embeddings_bin + 3) many_movies_mxbai-embed-large_BIN + 4) many_movies_mxbai-embed-large_NOQUANT + 5) word_embeddings + 6) word_embeddings_bin + 7) glove_embeddings_fp32 + + redis> VSIM food_items !"drinks with fruit" + 1) (Fruit)Juices,Lemonade,100ml,50 cal,210 kJ + 2) (Fruit)Juices,Limeade,100ml,128 cal,538 kJ + 3) CannedFruit,Canned Fruit Cocktail,100g,81 cal,340 kJ + 4) (Fruit)Juices,Energy-Drink,100ml,87 cal,365 kJ + 5) Fruits,Lime,100g,30 cal,126 kJ + 6) (Fruit)Juices,Coconut Water,100ml,19 cal,80 kJ + 7) Fruits,Lemon,100g,29 cal,122 kJ + 8) (Fruit)Juices,Clamato,100ml,60 cal,252 kJ + 9) Fruits,Fruit salad,100g,50 cal,210 kJ + 10) (Fruit)Juices,Capri-Sun,100ml,41 cal,172 kJ + + redis> vsim food_items !"barilla" + 1) Pasta&Noodles,Spirelli,100g,367 cal,1541 kJ + 2) Pasta&Noodles,Farfalle,100g,358 cal,1504 kJ + 3) Pasta&Noodles,Capellini,100g,353 cal,1483 kJ + 4) Pasta&Noodles,Spaetzle,100g,368 cal,1546 kJ + 5) Pasta&Noodles,Cappelletti,100g,164 cal,689 kJ + 6) Pasta&Noodles,Penne,100g,351 cal,1474 kJ + 7) Pasta&Noodles,Shells,100g,353 cal,1483 kJ + 8) Pasta&Noodles,Linguine,100g,357 cal,1499 kJ + 9) Pasta&Noodles,Rotini,100g,353 cal,1483 kJ + 10) Pasta&Noodles,Rigatoni,100g,353 cal,1483 kJ diff --git a/examples/redis-unstable/modules/vector-sets/examples/cli-tool/cli.py b/examples/redis-unstable/modules/vector-sets/examples/cli-tool/cli.py new file mode 100755 index 0000000..614a021 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/examples/cli-tool/cli.py @@ -0,0 +1,160 @@ +# +# Copyright (c) 2009-Present, Redis Ltd. +# All rights reserved. +# +# Licensed under your choice of (a) the Redis Source Available License 2.0 +# (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the +# GNU Affero General Public License v3 (AGPLv3). +# + +#!/usr/bin/env python3 +import argparse +import redis +import requests +import re +import shlex +from prompt_toolkit import PromptSession +from prompt_toolkit.history import InMemoryHistory + +# Default Ollama embeddings URL (can be overridden with --ollama-url) +OLLAMA_URL = "http://localhost:11434/api/embeddings" + +def get_embedding(text): + """Get embedding from local Ollama API""" + url = OLLAMA_URL + payload = { + "model": "mxbai-embed-large", + "prompt": text + } + try: + response = requests.post(url, json=payload) + response.raise_for_status() + return response.json()['embedding'] + except requests.exceptions.RequestException as e: + raise Exception(f"Failed to get embedding: {str(e)}") + +def process_embedding_patterns(text): + """Process !"text" and !!"text" patterns in the command""" + + def replace_with_embedding(match): + text = match.group(1) + embedding = get_embedding(text) + return f"VALUES {len(embedding)} {' '.join(map(str, embedding))}" + + def replace_with_embedding_and_text(match): + text = match.group(1) + embedding = get_embedding(text) + # Return both the embedding values and the original text as next argument + return f'VALUES {len(embedding)} {" ".join(map(str, embedding))} "{text}"' + + # First handle !!"text" pattern (must be done before !"text") + text = re.sub(r'!!"([^"]*)"', replace_with_embedding_and_text, text) + # Then handle !"text" pattern + text = re.sub(r'!"([^"]*)"', replace_with_embedding, text) + return text + +def parse_command(command): + """Parse command respecting quoted strings""" + try: + # Use shlex to properly handle quoted strings + return shlex.split(command) + except ValueError as e: + raise Exception(f"Invalid command syntax: {str(e)}") + +def format_response(response): + """Format the response to match Redis protocol style""" + if response is None: + return "(nil)" + elif isinstance(response, bool): + return "+OK" if response else "(error) Operation failed" + elif isinstance(response, (list, set)): + if not response: + return "(empty list or set)" + return "\n".join(f"{i+1}) {item}" for i, item in enumerate(response)) + elif isinstance(response, int): + return f"(integer) {response}" + else: + return str(response) + +def main(): + global OLLAMA_URL + + parser = argparse.ArgumentParser(prog="cli.py", add_help=False) + parser.add_argument("--ollama-url", dest="ollama_url", + help="Ollama embeddings API URL (default: {OLLAMA_URL})", + default=OLLAMA_URL) + args, _ = parser.parse_known_args() + OLLAMA_URL = args.ollama_url + + # Default connection to localhost:6379 + r = redis.Redis(host='localhost', port=6379, decode_responses=True) + + try: + # Test connection + r.ping() + print("Connected to Redis. Type your commands (CTRL+D to exit):") + print("Special syntax:") + print(" !\"text\" - Replace with embedding") + print(" !!\"text\" - Replace with embedding and append text as value") + print(" \"text\" - Quote strings containing spaces") + except redis.ConnectionError: + print("Error: Could not connect to Redis server") + return + + # Setup prompt session with history + session = PromptSession(history=InMemoryHistory()) + + # Main loop + while True: + try: + # Read input with line editing support + command = session.prompt("redis> ") + + # Skip empty commands + if not command.strip(): + continue + + # Process any embedding patterns before parsing + try: + processed_command = process_embedding_patterns(command) + except Exception as e: + print(f"(error) Embedding processing failed: {str(e)}") + continue + + # Parse the command respecting quoted strings + try: + parts = parse_command(processed_command) + except Exception as e: + print(f"(error) {str(e)}") + continue + + if not parts: + continue + + cmd = parts[0].lower() + args = parts[1:] + + # Execute command + try: + method = getattr(r, cmd, None) + if method is not None: + result = method(*args) + else: + # Use execute_command for unknown commands + result = r.execute_command(cmd, *args) + print(format_response(result)) + except AttributeError: + print(f"(error) Unknown command '{cmd}'") + + except EOFError: + print("\nGoodbye!") + break + except KeyboardInterrupt: + continue # Allow Ctrl+C to clear current line + except redis.RedisError as e: + print(f"(error) {str(e)}") + except Exception as e: + print(f"(error) {str(e)}") + +if __name__ == "__main__": + main() diff --git a/examples/redis-unstable/modules/vector-sets/examples/glove-100/README b/examples/redis-unstable/modules/vector-sets/examples/glove-100/README new file mode 100644 index 0000000..e8bb6dd --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/examples/glove-100/README @@ -0,0 +1,3 @@ +wget http://ann-benchmarks.com/glove-100-angular.hdf5 +python insert.py +python recall.py (use --k <count> optionally, default top-10) diff --git a/examples/redis-unstable/modules/vector-sets/examples/glove-100/insert.py b/examples/redis-unstable/modules/vector-sets/examples/glove-100/insert.py new file mode 100644 index 0000000..595c77f --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/examples/glove-100/insert.py @@ -0,0 +1,56 @@ +# +# Copyright (c) 2009-Present, Redis Ltd. +# All rights reserved. +# +# Licensed under your choice of (a) the Redis Source Available License 2.0 +# (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the +# GNU Affero General Public License v3 (AGPLv3). +# + +import h5py +import redis +from tqdm import tqdm + +# Initialize Redis connection +redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True, encoding='utf-8') + +def add_to_redis(index, embedding): + """Add embedding to Redis using VADD command""" + args = ["VADD", "glove_embeddings", "VALUES", "100"] # 100 is vector dimension + args.extend(map(str, embedding)) + args.append(f"{index}") # Using index as identifier since we don't have words + args.append("EF") + args.append("200") + # args.append("NOQUANT") + # args.append("BIN") + redis_client.execute_command(*args) + +def main(): + with h5py.File('glove-100-angular.hdf5', 'r') as f: + # Get the train dataset + train_vectors = f['train'] + total_vectors = train_vectors.shape[0] + + print(f"Starting to process {total_vectors} vectors...") + + # Process in batches to avoid memory issues + batch_size = 1000 + + for i in tqdm(range(0, total_vectors, batch_size)): + batch_end = min(i + batch_size, total_vectors) + batch = train_vectors[i:batch_end] + + for j, vector in enumerate(batch): + try: + current_index = i + j + add_to_redis(current_index, vector) + + except Exception as e: + print(f"Error processing vector {current_index}: {str(e)}") + continue + + if (i + batch_size) % 10000 == 0: + print(f"Processed {i + batch_size} vectors") + +if __name__ == "__main__": + main() diff --git a/examples/redis-unstable/modules/vector-sets/examples/glove-100/recall.py b/examples/redis-unstable/modules/vector-sets/examples/glove-100/recall.py new file mode 100644 index 0000000..3064aed --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/examples/glove-100/recall.py @@ -0,0 +1,87 @@ +# +# Copyright (c) 2009-Present, Redis Ltd. +# All rights reserved. +# +# Licensed under your choice of (a) the Redis Source Available License 2.0 +# (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the +# GNU Affero General Public License v3 (AGPLv3). +# + +import h5py +import redis +import numpy as np +from tqdm import tqdm +import argparse + +# Initialize Redis connection +redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True, encoding='utf-8') + +def get_redis_neighbors(query_vector, k): + """Get nearest neighbors using Redis VSIM command""" + args = ["VSIM", "glove_embeddings_bin", "VALUES", "100"] + args.extend(map(str, query_vector)) + args.extend(["COUNT", str(k)]) + args.extend(["EF", 100]) + if False: + print(args) + exit(1) + results = redis_client.execute_command(*args) + return [int(res) for res in results] + +def calculate_recall(ground_truth, predicted, k): + """Calculate recall@k""" + relevant = set(ground_truth[:k]) + retrieved = set(predicted[:k]) + return len(relevant.intersection(retrieved)) / len(relevant) + +def main(): + parser = argparse.ArgumentParser(description='Evaluate Redis VSIM recall') + parser.add_argument('--k', type=int, default=10, help='Number of neighbors to evaluate (default: 10)') + parser.add_argument('--batch', type=int, default=100, help='Progress update frequency (default: 100)') + args = parser.parse_args() + + k = args.k + batch_size = args.batch + + with h5py.File('glove-100-angular.hdf5', 'r') as f: + test_vectors = f['test'][:] + ground_truth_neighbors = f['neighbors'][:] + + num_queries = len(test_vectors) + recalls = [] + + print(f"Evaluating recall@{k} for {num_queries} test queries...") + + for i in tqdm(range(num_queries)): + try: + # Get Redis results + redis_neighbors = get_redis_neighbors(test_vectors[i], k) + + # Get ground truth for this query + true_neighbors = ground_truth_neighbors[i] + + # Calculate recall + recall = calculate_recall(true_neighbors, redis_neighbors, k) + recalls.append(recall) + + if (i + 1) % batch_size == 0: + current_avg_recall = np.mean(recalls) + print(f"Current average recall@{k} after {i+1} queries: {current_avg_recall:.4f}") + + except Exception as e: + print(f"Error processing query {i}: {str(e)}") + continue + + final_recall = np.mean(recalls) + print("\nFinal Results:") + print(f"Average recall@{k}: {final_recall:.4f}") + print(f"Total queries evaluated: {len(recalls)}") + + # Save detailed results + with open(f'recall_evaluation_results_k{k}.txt', 'w') as f: + f.write(f"Average recall@{k}: {final_recall:.4f}\n") + f.write(f"Total queries evaluated: {len(recalls)}\n") + f.write(f"Individual query recalls: {recalls}\n") + +if __name__ == "__main__": + main() diff --git a/examples/redis-unstable/modules/vector-sets/examples/movies/.gitignore b/examples/redis-unstable/modules/vector-sets/examples/movies/.gitignore new file mode 100644 index 0000000..e736c6a --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/examples/movies/.gitignore @@ -0,0 +1,2 @@ +mpst_full_data.csv +partition.json diff --git a/examples/redis-unstable/modules/vector-sets/examples/movies/README b/examples/redis-unstable/modules/vector-sets/examples/movies/README new file mode 100644 index 0000000..3931a6d --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/examples/movies/README @@ -0,0 +1,30 @@ +This example maps long form movies plots to movies titles. +It will create fp32 and binary vectors (the two extremes). + +1. Install ollama, and install the embedding model "mxbai-embed-large" +2. Download mpst_full_data.csv from https://www.kaggle.com/datasets/cryptexcode/mpst-movie-plot-synopses-with-tags +3. python insert.py + +127.0.0.1:6379> VSIM many_movies_mxbai-embed-large_NOQUANT ELE "The Matrix" + 1) "The Matrix" + 2) "The Matrix Reloaded" + 3) "The Matrix Revolutions" + 4) "Commando" + 5) "Avatar" + 6) "Forbidden Planet" + 7) "Terminator Salvation" + 8) "Mandroid" + 9) "The Omega Code" +10) "Coherence" + +127.0.0.1:6379> VSIM many_movies_mxbai-embed-large_BIN ELE "The Matrix" + 1) "The Matrix" + 2) "The Matrix Reloaded" + 3) "The Matrix Revolutions" + 4) "The Omega Code" + 5) "Forbidden Planet" + 6) "Avatar" + 7) "John Carter" + 8) "System Shock 2" + 9) "Coherence" +10) "Tomorrowland" diff --git a/examples/redis-unstable/modules/vector-sets/examples/movies/insert.py b/examples/redis-unstable/modules/vector-sets/examples/movies/insert.py new file mode 100644 index 0000000..dad4da3 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/examples/movies/insert.py @@ -0,0 +1,57 @@ +# +# Copyright (c) 2009-Present, Redis Ltd. +# All rights reserved. +# +# Licensed under your choice of (a) the Redis Source Available License 2.0 +# (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the +# GNU Affero General Public License v3 (AGPLv3). +# + +import csv +import requests +import redis + +ModelName="mxbai-embed-large" + +# Initialize Redis connection, setting encoding to utf-8 +redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True, encoding='utf-8') + +def get_embedding(text): + """Get embedding from local API""" + url = "http://localhost:11434/api/embeddings" + payload = { + "model": ModelName, + "prompt": "Represent this movie plot and genre: "+text + } + response = requests.post(url, json=payload) + return response.json()['embedding'] + +def add_to_redis(title, embedding, quant_type): + """Add embedding to Redis using VADD command""" + args = ["VADD", "many_movies_"+ModelName+"_"+quant_type, "VALUES", str(len(embedding))] + args.extend(map(str, embedding)) + args.append(title) + args.append(quant_type) + redis_client.execute_command(*args) + +def main(): + with open('mpst_full_data.csv', 'r', encoding='utf-8') as file: + reader = csv.DictReader(file) + + for movie in reader: + try: + text_to_embed = f"{movie['title']} {movie['plot_synopsis']} {movie['tags']}" + + print(f"Getting embedding for: {movie['title']}") + embedding = get_embedding(text_to_embed) + + add_to_redis(movie['title'], embedding, "BIN") + add_to_redis(movie['title'], embedding, "NOQUANT") + print(f"Successfully processed: {movie['title']}") + + except Exception as e: + print(f"Error processing {movie['title']}: {str(e)}") + continue + +if __name__ == "__main__": + main() diff --git a/examples/redis-unstable/modules/vector-sets/expr.c b/examples/redis-unstable/modules/vector-sets/expr.c new file mode 100644 index 0000000..4f3a1cc --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/expr.c @@ -0,0 +1,959 @@ +/* Filtering of objects based on simple expressions. + * This powers the FILTER option of Vector Sets, but it is otherwise + * general code to be used when we want to tell if a given object (with fields) + * passes or fails a given test for scalars, strings, ... + * + * Copyright (c) 2009-Present, Redis Ltd. + * All rights reserved. + * + * Licensed under your choice of (a) the Redis Source Available License 2.0 + * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the + * GNU Affero General Public License v3 (AGPLv3). + * Originally authored by: Salvatore Sanfilippo. + */ + +#ifdef TEST_MAIN +#define RedisModule_Alloc malloc +#define RedisModule_Realloc realloc +#define RedisModule_Free free +#define RedisModule_Strdup strdup +#define RedisModule_Assert assert +#define _DEFAULT_SOURCE +#define _USE_MATH_DEFINES +#include <assert.h> +#include <math.h> +#endif + +#include <stdio.h> +#include <stdlib.h> +#include <ctype.h> +#include <math.h> +#include <string.h> + +#define EXPR_TOKEN_EOF 0 +#define EXPR_TOKEN_NUM 1 +#define EXPR_TOKEN_STR 2 +#define EXPR_TOKEN_TUPLE 3 +#define EXPR_TOKEN_SELECTOR 4 +#define EXPR_TOKEN_OP 5 +#define EXPR_TOKEN_NULL 6 + +#define EXPR_OP_OPAREN 0 /* ( */ +#define EXPR_OP_CPAREN 1 /* ) */ +#define EXPR_OP_NOT 2 /* ! */ +#define EXPR_OP_POW 3 /* ** */ +#define EXPR_OP_MULT 4 /* * */ +#define EXPR_OP_DIV 5 /* / */ +#define EXPR_OP_MOD 6 /* % */ +#define EXPR_OP_SUM 7 /* + */ +#define EXPR_OP_DIFF 8 /* - */ +#define EXPR_OP_GT 9 /* > */ +#define EXPR_OP_GTE 10 /* >= */ +#define EXPR_OP_LT 11 /* < */ +#define EXPR_OP_LTE 12 /* <= */ +#define EXPR_OP_EQ 13 /* == */ +#define EXPR_OP_NEQ 14 /* != */ +#define EXPR_OP_IN 15 /* in */ +#define EXPR_OP_AND 16 /* and */ +#define EXPR_OP_OR 17 /* or */ + +/* This structure represents a token in our expression. It's either + * literals like 4, "foo", or operators like "+", "-", "and", or + * json selectors, that start with a dot: ".age", ".properties.somearray[1]" */ +typedef struct exprtoken { + int refcount; // Reference counting for memory reclaiming. + int token_type; // Token type of the just parsed token. + int offset; // Chars offset in expression. + union { + double num; // Value for EXPR_TOKEN_NUM. + struct { + char *start; // String pointer for EXPR_TOKEN_STR / SELECTOR. + size_t len; // String len for EXPR_TOKEN_STR / SELECTOR. + char *heapstr; // True if we have a private allocation for this + // string. When possible, it just references to the + // string expression we compiled, exprstate->expr. + } str; + int opcode; // Opcode ID for EXPR_TOKEN_OP. + struct { + struct exprtoken **ele; + size_t len; + } tuple; // Tuples are like [1, 2, 3] for "in" operator. + }; +} exprtoken; + +/* Simple stack of expr tokens. This is used both to represent the stack + * of values and the stack of operands during VM execution. */ +typedef struct exprstack { + exprtoken **items; + int numitems; + int allocsize; +} exprstack; + +typedef struct exprstate { + char *expr; /* Expression string to compile. Note that + * expression token strings point directly to this + * string. */ + char *p; // Current position inside 'expr', while parsing. + + // Virtual machine state. + exprstack values_stack; + exprstack ops_stack; // Operator stack used during compilation. + exprstack tokens; // Expression processed into a sequence of tokens. + exprstack program; // Expression compiled into opcodes and values. +} exprstate; + +/* Valid operators. */ +struct { + char *opname; + int oplen; + int opcode; + int precedence; + int arity; +} ExprOptable[] = { + {"(", 1, EXPR_OP_OPAREN, 7, 0}, + {")", 1, EXPR_OP_CPAREN, 7, 0}, + {"!", 1, EXPR_OP_NOT, 6, 1}, + {"not", 3, EXPR_OP_NOT, 6, 1}, + {"**", 2, EXPR_OP_POW, 5, 2}, + {"*", 1, EXPR_OP_MULT, 4, 2}, + {"/", 1, EXPR_OP_DIV, 4, 2}, + {"%", 1, EXPR_OP_MOD, 4, 2}, + {"+", 1, EXPR_OP_SUM, 3, 2}, + {"-", 1, EXPR_OP_DIFF, 3, 2}, + {">", 1, EXPR_OP_GT, 2, 2}, + {">=", 2, EXPR_OP_GTE, 2, 2}, + {"<", 1, EXPR_OP_LT, 2, 2}, + {"<=", 2, EXPR_OP_LTE, 2, 2}, + {"==", 2, EXPR_OP_EQ, 2, 2}, + {"!=", 2, EXPR_OP_NEQ, 2, 2}, + {"in", 2, EXPR_OP_IN, 2, 2}, + {"and", 3, EXPR_OP_AND, 1, 2}, + {"&&", 2, EXPR_OP_AND, 1, 2}, + {"or", 2, EXPR_OP_OR, 0, 2}, + {"||", 2, EXPR_OP_OR, 0, 2}, + {NULL, 0, 0, 0, 0} // Terminator. +}; + +#define EXPR_OP_SPECIALCHARS "+-*%/!()<>=|&" +#define EXPR_SELECTOR_SPECIALCHARS "_-" + +/* ================================ Expr token ============================== */ + +/* Return an heap allocated token of the specified type, setting the + * reference count to 1. */ +exprtoken *exprNewToken(int type) { + exprtoken *t = RedisModule_Alloc(sizeof(exprtoken)); + memset(t,0,sizeof(*t)); + t->token_type = type; + t->refcount = 1; + return t; +} + +/* Generic free token function, can be used to free stack allocated + * objects (in this case the pointer itself will not be freed) or + * heap allocated objects. See the wrappers below. */ +void exprTokenRelease(exprtoken *t) { + if (t == NULL) return; + + RedisModule_Assert(t->refcount > 0); // Catch double free & more. + t->refcount--; + if (t->refcount > 0) return; + + // We reached refcount 0: free the object. + if (t->token_type == EXPR_TOKEN_STR) { + if (t->str.heapstr != NULL) RedisModule_Free(t->str.heapstr); + } else if (t->token_type == EXPR_TOKEN_TUPLE) { + for (size_t j = 0; j < t->tuple.len; j++) + exprTokenRelease(t->tuple.ele[j]); + if (t->tuple.ele) RedisModule_Free(t->tuple.ele); + } + RedisModule_Free(t); +} + +void exprTokenRetain(exprtoken *t) { + t->refcount++; +} + +/* ============================== Stack handling ============================ */ + +#include <stdlib.h> +#include <string.h> + +#define EXPR_STACK_INITIAL_SIZE 16 + +/* Initialize a new expression stack. */ +void exprStackInit(exprstack *stack) { + stack->items = RedisModule_Alloc(sizeof(exprtoken*) * EXPR_STACK_INITIAL_SIZE); + stack->numitems = 0; + stack->allocsize = EXPR_STACK_INITIAL_SIZE; +} + +/* Push a token pointer onto the stack. Does not increment the refcount + * of the token: it is up to the caller doing this. */ +void exprStackPush(exprstack *stack, exprtoken *token) { + /* Check if we need to grow the stack. */ + if (stack->numitems == stack->allocsize) { + size_t newsize = stack->allocsize * 2; + exprtoken **newitems = + RedisModule_Realloc(stack->items, sizeof(exprtoken*) * newsize); + stack->items = newitems; + stack->allocsize = newsize; + } + stack->items[stack->numitems] = token; + stack->numitems++; +} + +/* Pop a token pointer from the stack. Return NULL if the stack is + * empty. Does NOT recrement the refcount of the token, it's up to the + * caller to do so, as the new owner of the reference. */ +exprtoken *exprStackPop(exprstack *stack) { + if (stack->numitems == 0) return NULL; + stack->numitems--; + return stack->items[stack->numitems]; +} + +/* Just return the last element pushed, without consuming it nor altering + * the reference count. */ +exprtoken *exprStackPeek(exprstack *stack) { + if (stack->numitems == 0) return NULL; + return stack->items[stack->numitems-1]; +} + +/* Free the stack structure state, including the items it contains, that are + * assumed to be heap allocated. The passed pointer itself is not freed. */ +void exprStackFree(exprstack *stack) { + for (int j = 0; j < stack->numitems; j++) + exprTokenRelease(stack->items[j]); + RedisModule_Free(stack->items); +} + +/* Just reset the stack removing all the items, but leaving it in a state + * that makes it still usable for new elements. */ +void exprStackReset(exprstack *stack) { + for (int j = 0; j < stack->numitems; j++) + exprTokenRelease(stack->items[j]); + stack->numitems = 0; +} + +/* =========================== Expression compilation ======================= */ + +void exprConsumeSpaces(exprstate *es) { + while(es->p[0] && isspace(es->p[0])) es->p++; +} + +/* Parse an operator or a literal (just "null" currently). + * When parsing operators, the function will try to match the longest match + * in the operators table. */ +exprtoken *exprParseOperatorOrLiteral(exprstate *es) { + exprtoken *t = exprNewToken(EXPR_TOKEN_OP); + char *start = es->p; + + while(es->p[0] && + (isalpha(es->p[0]) || + strchr(EXPR_OP_SPECIALCHARS,es->p[0]) != NULL)) + { + es->p++; + } + + int matchlen = es->p - start; + int bestlen = 0; + int j; + + // Check if it's a literal. + if (matchlen == 4 && !memcmp("null",start,4)) { + t->token_type = EXPR_TOKEN_NULL; + return t; + } + + // Find the longest matching operator. + for (j = 0; ExprOptable[j].opname != NULL; j++) { + if (ExprOptable[j].oplen > matchlen) continue; + if (memcmp(ExprOptable[j].opname, start, ExprOptable[j].oplen) != 0) + { + continue; + } + if (ExprOptable[j].oplen > bestlen) { + t->opcode = ExprOptable[j].opcode; + bestlen = ExprOptable[j].oplen; + } + } + if (bestlen == 0) { + exprTokenRelease(t); + return NULL; + } else { + es->p = start + bestlen; + } + return t; +} + +// Valid selector charset. +static int is_selector_char(int c) { + return (isalpha(c) || + isdigit(c) || + strchr(EXPR_SELECTOR_SPECIALCHARS,c) != NULL); +} + +/* Parse selectors, they start with a dot and can have alphanumerical + * or few special chars. */ +exprtoken *exprParseSelector(exprstate *es) { + exprtoken *t = exprNewToken(EXPR_TOKEN_SELECTOR); + es->p++; // Skip dot. + char *start = es->p; + + while(es->p[0] && is_selector_char(es->p[0])) es->p++; + int matchlen = es->p - start; + t->str.start = start; + t->str.len = matchlen; + return t; +} + +exprtoken *exprParseNumber(exprstate *es) { + exprtoken *t = exprNewToken(EXPR_TOKEN_NUM); + char num[256]; + int idx = 0; + while(isdigit(es->p[0]) || es->p[0] == '.' || es->p[0] == 'e' || + es->p[0] == 'E' || (idx == 0 && es->p[0] == '-')) + { + if (idx >= (int)sizeof(num)-1) { + exprTokenRelease(t); + return NULL; + } + num[idx++] = es->p[0]; + es->p++; + } + num[idx] = 0; + + char *endptr; + t->num = strtod(num, &endptr); + if (*endptr != '\0') { + exprTokenRelease(t); + return NULL; + } + return t; +} + +exprtoken *exprParseString(exprstate *es) { + char quote = es->p[0]; /* Store the quote type (' or "). */ + es->p++; /* Skip opening quote. */ + + exprtoken *t = exprNewToken(EXPR_TOKEN_STR); + t->str.start = es->p; + + while(es->p[0] != '\0') { + if (es->p[0] == '\\' && es->p[1] != '\0') { + es->p += 2; // Skip escaped char. + continue; + } + if (es->p[0] == quote) { + t->str.len = es->p - t->str.start; + es->p++; // Skip closing quote. + return t; + } + es->p++; + } + /* If we reach here, string was not terminated. */ + exprTokenRelease(t); + return NULL; +} + +/* Parse a tuple of the form [1, "foo", 42]. No nested tuples are + * supported. This type is useful mostly to be used with the "IN" + * operator. */ +exprtoken *exprParseTuple(exprstate *es) { + exprtoken *t = exprNewToken(EXPR_TOKEN_TUPLE); + t->tuple.ele = NULL; + t->tuple.len = 0; + es->p++; /* Skip opening '['. */ + + size_t allocated = 0; + while(1) { + exprConsumeSpaces(es); + + /* Check for empty tuple or end. */ + if (es->p[0] == ']') { + es->p++; + break; + } + + /* Grow tuple array if needed. */ + if (t->tuple.len == allocated) { + size_t newsize = allocated == 0 ? 4 : allocated * 2; + exprtoken **newele = RedisModule_Realloc(t->tuple.ele, + sizeof(exprtoken*) * newsize); + t->tuple.ele = newele; + allocated = newsize; + } + + /* Parse tuple element. */ + exprtoken *ele = NULL; + if (isdigit(es->p[0]) || es->p[0] == '-') { + ele = exprParseNumber(es); + } else if (es->p[0] == '"' || es->p[0] == '\'') { + ele = exprParseString(es); + } else { + exprTokenRelease(t); + return NULL; + } + + /* Error parsing number/string? */ + if (ele == NULL) { + exprTokenRelease(t); + return NULL; + } + + /* Store element if no error was detected. */ + t->tuple.ele[t->tuple.len] = ele; + t->tuple.len++; + + /* Check for next element. */ + exprConsumeSpaces(es); + if (es->p[0] == ']') { + es->p++; + break; + } + if (es->p[0] != ',') { + exprTokenRelease(t); + return NULL; + } + es->p++; /* Skip comma. */ + } + return t; +} + +/* Deallocate the object returned by exprCompile(). */ +void exprFree(exprstate *es) { + if (es == NULL) return; + + /* Free the original expression string. */ + if (es->expr) RedisModule_Free(es->expr); + + /* Free all stacks. */ + exprStackFree(&es->values_stack); + exprStackFree(&es->ops_stack); + exprStackFree(&es->tokens); + exprStackFree(&es->program); + + /* Free the state object itself. */ + RedisModule_Free(es); +} + +/* Split the provided expression into a stack of tokens. Returns + * 0 on success, 1 on error. */ +int exprTokenize(exprstate *es, int *errpos) { + /* Main parsing loop. */ + while(1) { + exprConsumeSpaces(es); + + /* Set a flag to see if we can consider the - part of the + * number, or an operator. */ + int minus_is_number = 0; // By default is an operator. + + exprtoken *last = exprStackPeek(&es->tokens); + if (last == NULL) { + /* If we are at the start of an expression, the minus is + * considered a number. */ + minus_is_number = 1; + } else if (last->token_type == EXPR_TOKEN_OP && + last->opcode != EXPR_OP_CPAREN) + { + /* Also, if the previous token was an operator, the minus + * is considered a number, unless the previous operator is + * a closing parens. In such case it's like (...) -5, or alike + * and we want to emit an operator. */ + minus_is_number = 1; + } + + /* Parse based on the current character. */ + exprtoken *current = NULL; + if (*es->p == '\0') { + current = exprNewToken(EXPR_TOKEN_EOF); + } else if (isdigit(*es->p) || + (minus_is_number && *es->p == '-' && isdigit(es->p[1]))) + { + current = exprParseNumber(es); + } else if (*es->p == '"' || *es->p == '\'') { + current = exprParseString(es); + } else if (*es->p == '.' && is_selector_char(es->p[1])) { + current = exprParseSelector(es); + } else if (*es->p == '[') { + current = exprParseTuple(es); + } else if (isalpha(*es->p) || strchr(EXPR_OP_SPECIALCHARS, *es->p)) { + current = exprParseOperatorOrLiteral(es); + } + + if (current == NULL) { + if (errpos) *errpos = es->p - es->expr; + return 1; // Syntax Error. + } + + /* Push the current token to tokens stack. */ + exprStackPush(&es->tokens, current); + if (current->token_type == EXPR_TOKEN_EOF) break; + } + return 0; +} + +/* Helper function to get operator precedence from the operator table. */ +int exprGetOpPrecedence(int opcode) { + for (int i = 0; ExprOptable[i].opname != NULL; i++) { + if (ExprOptable[i].opcode == opcode) + return ExprOptable[i].precedence; + } + return -1; +} + +/* Helper function to get operator arity from the operator table. */ +int exprGetOpArity(int opcode) { + for (int i = 0; ExprOptable[i].opname != NULL; i++) { + if (ExprOptable[i].opcode == opcode) + return ExprOptable[i].arity; + } + return -1; +} + +/* Process an operator during compilation. Returns 0 on success, 1 on error. + * This function will retain a reference of the operator 'op' in case it + * is pushed on the operators stack. */ +int exprProcessOperator(exprstate *es, exprtoken *op, int *stack_items, int *errpos) { + if (op->opcode == EXPR_OP_OPAREN) { + // This is just a marker for us. Do nothing. + exprStackPush(&es->ops_stack, op); + exprTokenRetain(op); + return 0; + } + + if (op->opcode == EXPR_OP_CPAREN) { + /* Process operators until we find the matching opening parenthesis. */ + while (1) { + exprtoken *top_op = exprStackPop(&es->ops_stack); + if (top_op == NULL) { + if (errpos) *errpos = op->offset; + return 1; + } + + if (top_op->opcode == EXPR_OP_OPAREN) { + /* Open parethesis found. Our work finished. */ + exprTokenRelease(top_op); + return 0; + } + + int arity = exprGetOpArity(top_op->opcode); + if (*stack_items < arity) { + exprTokenRelease(top_op); + if (errpos) *errpos = top_op->offset; + return 1; + } + + /* Move the operator on the program stack. */ + exprStackPush(&es->program, top_op); + *stack_items = *stack_items - arity + 1; + } + } + + int curr_prec = exprGetOpPrecedence(op->opcode); + + /* Process operators with higher or equal precedence. */ + while (1) { + exprtoken *top_op = exprStackPeek(&es->ops_stack); + if (top_op == NULL || top_op->opcode == EXPR_OP_OPAREN) break; + + int top_prec = exprGetOpPrecedence(top_op->opcode); + if (top_prec < curr_prec) break; + /* Special case for **: only pop if precedence is strictly higher + * so that the operator is right associative, that is: + * 2 ** 3 ** 2 is evaluated as 2 ** (3 ** 2) == 512 instead + * of (2 ** 3) ** 2 == 64. */ + if (op->opcode == EXPR_OP_POW && top_prec <= curr_prec) break; + + /* Pop and add to program. */ + top_op = exprStackPop(&es->ops_stack); + int arity = exprGetOpArity(top_op->opcode); + if (*stack_items < arity) { + exprTokenRelease(top_op); + if (errpos) *errpos = top_op->offset; + return 1; + } + + /* Move to the program stack. */ + exprStackPush(&es->program, top_op); + *stack_items = *stack_items - arity + 1; + } + + /* Push current operator. */ + exprStackPush(&es->ops_stack, op); + exprTokenRetain(op); + return 0; +} + +/* Compile the expression into a set of push-value and exec-operator + * that exprRun() can execute. The function returns an expstate object + * that can be used for execution of the program. On error, NULL + * is returned, and optionally the position of the error into the + * expression is returned by reference. */ +exprstate *exprCompile(char *expr, int *errpos) { + /* Initialize expression state. */ + exprstate *es = RedisModule_Alloc(sizeof(exprstate)); + es->expr = RedisModule_Strdup(expr); + es->p = es->expr; + + /* Initialize all stacks. */ + exprStackInit(&es->values_stack); + exprStackInit(&es->ops_stack); + exprStackInit(&es->tokens); + exprStackInit(&es->program); + + /* Tokenization. */ + if (exprTokenize(es, errpos)) { + exprFree(es); + return NULL; + } + + /* Compile the expression into a sequence of operations. */ + int stack_items = 0; // Track # of items that would be on the stack + // during execution. This way we can detect arity + // issues at compile time. + + /* Process each token. */ + for (int i = 0; i < es->tokens.numitems; i++) { + exprtoken *token = es->tokens.items[i]; + + if (token->token_type == EXPR_TOKEN_EOF) break; + + /* Handle values (numbers, strings, selectors, null). */ + if (token->token_type == EXPR_TOKEN_NUM || + token->token_type == EXPR_TOKEN_STR || + token->token_type == EXPR_TOKEN_TUPLE || + token->token_type == EXPR_TOKEN_SELECTOR || + token->token_type == EXPR_TOKEN_NULL) + { + exprStackPush(&es->program, token); + exprTokenRetain(token); + stack_items++; + continue; + } + + /* Handle operators. */ + if (token->token_type == EXPR_TOKEN_OP) { + if (exprProcessOperator(es, token, &stack_items, errpos)) { + exprFree(es); + return NULL; + } + continue; + } + } + + /* Process remaining operators on the stack. */ + while (es->ops_stack.numitems > 0) { + exprtoken *op = exprStackPop(&es->ops_stack); + if (op->opcode == EXPR_OP_OPAREN) { + if (errpos) *errpos = op->offset; + exprTokenRelease(op); + exprFree(es); + return NULL; + } + + int arity = exprGetOpArity(op->opcode); + if (stack_items < arity) { + if (errpos) *errpos = op->offset; + exprTokenRelease(op); + exprFree(es); + return NULL; + } + + exprStackPush(&es->program, op); + stack_items = stack_items - arity + 1; + } + + /* Verify that exactly one value would remain on the stack after + * execution. We could also check that such value is a number, but this + * would make the code more complex without much gains. */ + if (stack_items != 1) { + if (errpos) { + /* Point to the last token's offset for error reporting. */ + exprtoken *last = es->tokens.items[es->tokens.numitems - 1]; + *errpos = last->offset; + } + exprFree(es); + return NULL; + } + return es; +} + +/* ============================ Expression execution ======================== */ + +/* Convert a token to its numeric value. For strings we attempt to parse them + * as numbers, returning 0 if conversion fails. */ +double exprTokenToNum(exprtoken *t) { + char buf[256]; + if (t->token_type == EXPR_TOKEN_NUM) { + return t->num; + } else if (t->token_type == EXPR_TOKEN_STR && t->str.len < sizeof(buf)) { + memcpy(buf, t->str.start, t->str.len); + buf[t->str.len] = '\0'; + char *endptr; + double val = strtod(buf, &endptr); + return *endptr == '\0' ? val : 0; + } else { + return 0; + } +} + +/* Convert object to true/false (0 or 1) */ +double exprTokenToBool(exprtoken *t) { + if (t->token_type == EXPR_TOKEN_NUM) { + return t->num != 0; + } else if (t->token_type == EXPR_TOKEN_STR && t->str.len == 0) { + return 0; // Empty string are false, like in Javascript. + } else if (t->token_type == EXPR_TOKEN_NULL) { + return 0; // Null is surely more false than true... + } else { + return 1; // Every non numerical type is true. + } +} + +/* Compare two tokens. Returns true if they are equal. */ +int exprTokensEqual(exprtoken *a, exprtoken *b) { + // If both are strings, do string comparison. + if (a->token_type == EXPR_TOKEN_STR && b->token_type == EXPR_TOKEN_STR) { + return a->str.len == b->str.len && + memcmp(a->str.start, b->str.start, a->str.len) == 0; + } + + // If both are numbers, do numeric comparison. + if (a->token_type == EXPR_TOKEN_NUM && b->token_type == EXPR_TOKEN_NUM) { + return a->num == b->num; + } + + /* If one of the two is null, the expression is true only if + * both are null. */ + if (a->token_type == EXPR_TOKEN_NULL || b->token_type == EXPR_TOKEN_NULL) { + return a->token_type == b->token_type; + } + + // Mixed types - convert to numbers and compare. + return exprTokenToNum(a) == exprTokenToNum(b); +} + +/* Return true if the string a is a substring of b. */ +int exprTokensStringIn(exprtoken *a, exprtoken *b) { + RedisModule_Assert(a->token_type == EXPR_TOKEN_STR && + b->token_type == EXPR_TOKEN_STR); + if (a->str.len > b->str.len) return 0; // A is bigger, can't be a substring. + for (size_t i = 0; i <= b->str.len - a->str.len; i++) { + if (memcmp(b->str.start+i,a->str.start,a->str.len) == 0) return 1; + } + return 0; +} + +#include "fastjson.c" // JSON parser implementation used by exprRun(). + +/* Execute the compiled expression program. Returns 1 if the final stack value + * evaluates to true, 0 otherwise. Also returns 0 if any selector callback + * fails. */ +int exprRun(exprstate *es, char *json, size_t json_len) { + exprStackReset(&es->values_stack); + + // Execute each instruction in the program. + for (int i = 0; i < es->program.numitems; i++) { + exprtoken *t = es->program.items[i]; + + // Handle selectors by calling the callback. + if (t->token_type == EXPR_TOKEN_SELECTOR) { + exprtoken *obj = NULL; + if (t->str.len > 0) + obj = jsonExtractField(json,json_len,t->str.start,t->str.len); + + // Selector not found or JSON object not convertible to + // expression tokens. Evaluate the expression to false. + if (obj == NULL) return 0; + exprStackPush(&es->values_stack, obj); + continue; + } + + // Push non-operator values directly onto the stack. + if (t->token_type != EXPR_TOKEN_OP) { + exprStackPush(&es->values_stack, t); + exprTokenRetain(t); + continue; + } + + // Handle operators. + exprtoken *result = exprNewToken(EXPR_TOKEN_NUM); + + // Pop operands - we know we have enough from compile-time checks. + exprtoken *b = exprStackPop(&es->values_stack); + exprtoken *a = NULL; + if (exprGetOpArity(t->opcode) == 2) { + a = exprStackPop(&es->values_stack); + } + + switch(t->opcode) { + case EXPR_OP_NOT: + result->num = exprTokenToBool(b) == 0 ? 1 : 0; + break; + case EXPR_OP_POW: { + double base = exprTokenToNum(a); + double exp = exprTokenToNum(b); + result->num = pow(base, exp); + break; + } + case EXPR_OP_MULT: + result->num = exprTokenToNum(a) * exprTokenToNum(b); + break; + case EXPR_OP_DIV: + result->num = exprTokenToNum(a) / exprTokenToNum(b); + break; + case EXPR_OP_MOD: { + double va = exprTokenToNum(a); + double vb = exprTokenToNum(b); + result->num = fmod(va, vb); + break; + } + case EXPR_OP_SUM: + result->num = exprTokenToNum(a) + exprTokenToNum(b); + break; + case EXPR_OP_DIFF: + result->num = exprTokenToNum(a) - exprTokenToNum(b); + break; + case EXPR_OP_GT: + result->num = exprTokenToNum(a) > exprTokenToNum(b) ? 1 : 0; + break; + case EXPR_OP_GTE: + result->num = exprTokenToNum(a) >= exprTokenToNum(b) ? 1 : 0; + break; + case EXPR_OP_LT: + result->num = exprTokenToNum(a) < exprTokenToNum(b) ? 1 : 0; + break; + case EXPR_OP_LTE: + result->num = exprTokenToNum(a) <= exprTokenToNum(b) ? 1 : 0; + break; + case EXPR_OP_EQ: + result->num = exprTokensEqual(a, b) ? 1 : 0; + break; + case EXPR_OP_NEQ: + result->num = !exprTokensEqual(a, b) ? 1 : 0; + break; + case EXPR_OP_IN: { + /* For 'in' operator, b must be a tuple, and we check for + * membership. Otherwise both a and b must be strings, and + * in this case we check if a is a substring of b. */ + result->num = 0; // Default to false. + if (b->token_type == EXPR_TOKEN_TUPLE) { + for (size_t j = 0; j < b->tuple.len; j++) { + if (exprTokensEqual(a, b->tuple.ele[j])) { + result->num = 1; // Found a match. + break; + } + } + } else if (a->token_type == EXPR_TOKEN_STR && + b->token_type == EXPR_TOKEN_STR) + { + result->num = exprTokensStringIn(a,b); + } + break; + } + case EXPR_OP_AND: + result->num = + exprTokenToBool(a) != 0 && exprTokenToBool(b) != 0 ? 1 : 0; + break; + case EXPR_OP_OR: + result->num = + exprTokenToBool(a) != 0 || exprTokenToBool(b) != 0 ? 1 : 0; + break; + default: + // Do nothing: we don't want runtime errors. + break; + } + + // Free operands and push result. + if (a) exprTokenRelease(a); + exprTokenRelease(b); + exprStackPush(&es->values_stack, result); + } + + // Get final result from stack. + exprtoken *final = exprStackPop(&es->values_stack); + if (final == NULL) return 0; + + // Convert result to boolean. + int retval = exprTokenToBool(final); + exprTokenRelease(final); + return retval; +} + +/* ============================ Simple test main ============================ */ + +#ifdef TEST_MAIN +#include "fastjson_test.c" + +void exprPrintToken(exprtoken *t) { + switch(t->token_type) { + case EXPR_TOKEN_EOF: + printf("EOF"); + break; + case EXPR_TOKEN_NUM: + printf("NUM:%g", t->num); + break; + case EXPR_TOKEN_STR: + printf("STR:\"%.*s\"", (int)t->str.len, t->str.start); + break; + case EXPR_TOKEN_SELECTOR: + printf("SEL:%.*s", (int)t->str.len, t->str.start); + break; + case EXPR_TOKEN_OP: + printf("OP:"); + for (int i = 0; ExprOptable[i].opname != NULL; i++) { + if (ExprOptable[i].opcode == t->opcode) { + printf("%s", ExprOptable[i].opname); + break; + } + } + break; + default: + printf("UNKNOWN"); + break; + } +} + +void exprPrintStack(exprstack *stack, const char *name) { + printf("%s (%d items):", name, stack->numitems); + for (int j = 0; j < stack->numitems; j++) { + printf(" "); + exprPrintToken(stack->items[j]); + } + printf("\n"); +} + +int main(int argc, char **argv) { + /* Check for JSON parser test mode. */ + if (argc >= 2 && strcmp(argv[1], "--test-json-parser") == 0) { + run_fastjson_test(); + return 0; + } + + char *testexpr = "(5+2)*3 and .year > 1980 and 'foo' == 'foo'"; + char *testjson = "{\"year\": 1984, \"name\": \"The Matrix\"}"; + if (argc >= 2) testexpr = argv[1]; + if (argc >= 3) testjson = argv[2]; + + printf("Compiling expression: %s\n", testexpr); + + int errpos = 0; + exprstate *es = exprCompile(testexpr,&errpos); + if (es == NULL) { + printf("Compilation failed near \"...%s\"\n", testexpr+errpos); + return 1; + } + + exprPrintStack(&es->tokens, "Tokens"); + exprPrintStack(&es->program, "Program"); + printf("Running against object: %s\n", testjson); + int result = exprRun(es,testjson,strlen(testjson)); + printf("Result1: %s\n", result ? "True" : "False"); + result = exprRun(es,testjson,strlen(testjson)); + printf("Result2: %s\n", result ? "True" : "False"); + + exprFree(es); + return 0; +} +#endif diff --git a/examples/redis-unstable/modules/vector-sets/fastjson.c b/examples/redis-unstable/modules/vector-sets/fastjson.c new file mode 100644 index 0000000..78926e2 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/fastjson.c @@ -0,0 +1,441 @@ +/* Ultra‑lightweight top‑level JSON field extractor. + * Return the element directly as an expr.c token. + * This code is directly included inside expr.c. + * + * Copyright (c) 2025-Present, Redis Ltd. + * All rights reserved. + * + * Licensed under your choice of the Redis Source Available License 2.0 + * (RSALv2) or the Server Side Public License v1 (SSPLv1). + * + * Originally authored by: Salvatore Sanfilippo. + * + * ------------------------------------------------------------------ + * + * DESIGN GOALS: + * + * 1. Zero heap allocations while seeking the requested key. + * 2. A single parse (and therefore a single allocation, if needed) + * when the key finally matches. + * 3. Same subset‑of‑JSON coverage needed by expr.c: + * - Strings (escapes: \" \\ \n \r \t). + * - Numbers (double). + * - Booleans. + * - Null. + * - Flat arrays of the above primitives. + * + * Any other value (nested object, unicode escape, etc.) returns NULL. + * Should be very easy to extend it in case in the future we want + * more for the FILTER option of VSIM. + * 4. No global state, so this file can be #included directly in expr.c. + * + * The only API expr.c uses directly is: + * + * exprtoken *jsonExtractField(const char *json, size_t json_len, + * const char *field, size_t field_len); + * ------------------------------------------------------------------ */ + +#include <ctype.h> +#include <string.h> + +// Forward declarations. +static int jsonSkipValue(const char **p, const char *end); +static exprtoken *jsonParseValueToken(const char **p, const char *end); + +/* Similar to ctype.h isdigit() but covers the whole JSON number charset, + * including exp form. */ +static int jsonIsNumberChar(int c) { + return isdigit(c) || c=='-' || c=='+' || c=='.' || c=='e' || c=='E'; +} + +/* ========================== Fast skipping of JSON ========================= + * The helpers here are designed to skip values without performing any + * allocation. This way, for the use case of this JSON parser, we are able + * to easily (and with good speed) skip fields and values we are not + * interested in. Then, later in the code, when we find the field we want + * to obtain, we finally call the functions that turn a given JSON value + * associated to a field into our of our expressions token. + * ========================================================================== */ + +/* Advance *p consuming all the spaces. */ +static inline void jsonSkipWhiteSpaces(const char **p, const char *end) { + while (*p < end && isspace((unsigned char)**p)) (*p)++; +} + +/* Advance *p past a JSON string. Returns 1 on success, 0 on error. */ +static int jsonSkipString(const char **p, const char *end) { + if (*p >= end || **p != '"') return 0; + (*p)++; /* Skip opening quote. */ + while (*p < end) { + if (**p == '\\') { + (*p) += 2; + continue; + } + if (**p == '"') { + (*p)++; /* Skip closing quote. */ + return 1; + } + (*p)++; + } + return 0; /* unterminated */ +} + +/* Skip an array or object generically using depth counter. + * Opener and closer tells the function how the aggregated + * data type starts/stops, basically [] or {}. */ +static int jsonSkipBracketed(const char **p, const char *end, + char opener, char closer) { + int depth = 1; + (*p)++; /* Skip opener. */ + + /* Loop until we reach the end of the input or find the matching + * closer (depth becomes 0). */ + while (*p < end && depth > 0) { + char c = **p; + + if (c == '"') { + // Found a string, delegate skipping to jsonSkipString(). + if (!jsonSkipString(p, end)) { + return 0; // String skipping failed (e.g., unterminated) + } + /* jsonSkipString() advances *p past the closing quote. + * Continue the loop to process the character *after* the string. */ + continue; + } + + /* If it's not a string, check if it affects the depth for the + * specific brackets we are currently tracking. */ + if (c == opener) { + depth++; + } else if (c == closer) { + depth--; + } + + /* Always advance the pointer for any non-string character. + * This handles commas, colons, whitespace, numbers, literals, + * and even nested brackets of a *different* type than the + * one we are currently skipping (e.g. skipping a { inside []). */ + (*p)++; + } + + /* Return 1 (true) if we successfully found the matching closer, + * otherwise there is a parse error and we return 0. */ + return depth == 0; +} + +/* Skip a single JSON literal (true, null, ...) starting at *p. + * Returns 1 on success, 0 on failure. */ +static int jsonSkipLiteral(const char **p, const char *end, const char *lit) { + size_t l = strlen(lit); + if (*p + l > end) return 0; + if (strncmp(*p, lit, l) == 0) { *p += l; return 1; } + return 0; +} + +/* Skip number, don't check that number format is correct, just consume + * number-alike characters. + * + * Note: More robust number skipping might check validity, + * but for skipping, just consuming plausible characters is enough. */ +static int jsonSkipNumber(const char **p, const char *end) { + const char *num_start = *p; + while (*p < end && jsonIsNumberChar(**p)) (*p)++; + return *p > num_start; // Any progress made? Otherwise no number found. +} + +/* Skip any JSON value. 1 = success, 0 = error. */ +static int jsonSkipValue(const char **p, const char *end) { + jsonSkipWhiteSpaces(p, end); + if (*p >= end) return 0; + switch (**p) { + case '"': return jsonSkipString(p, end); + case '{': return jsonSkipBracketed(p, end, '{', '}'); + case '[': return jsonSkipBracketed(p, end, '[', ']'); + case 't': return jsonSkipLiteral(p, end, "true"); + case 'f': return jsonSkipLiteral(p, end, "false"); + case 'n': return jsonSkipLiteral(p, end, "null"); + default: return jsonSkipNumber(p, end); + } +} + +/* =========================== JSON to exprtoken ============================ + * The functions below convert a given json value to the equivalent + * expression token structure. + * ========================================================================== */ + +static exprtoken *jsonParseStringToken(const char **p, const char *end) { + if (*p >= end || **p != '"') return NULL; + const char *start = ++(*p); + int esc = 0; size_t len = 0; int has_esc = 0; + const char *q = *p; + while (q < end) { + if (esc) { esc = 0; q++; len++; has_esc = 1; continue; } + if (*q == '\\') { esc = 1; q++; continue; } + if (*q == '"') break; + q++; len++; + } + if (q >= end || *q != '"') return NULL; // Unterminated string + exprtoken *t = exprNewToken(EXPR_TOKEN_STR); + + if (!has_esc) { + // No escapes, we can point directly into the original JSON string. + t->str.start = (char*)start; t->str.len = len; t->str.heapstr = NULL; + } else { + // Escapes present, need to allocate and copy/process escapes. + char *dst = RedisModule_Alloc(len + 1); + + t->str.start = t->str.heapstr = dst; t->str.len = len; + const char *r = start; esc = 0; + while (r < q) { + if (esc) { + switch (*r) { + // Supported escapes from Goal 3. + case 'n': *dst='\n'; break; + case 'r': *dst='\r'; break; + case 't': *dst='\t'; break; + case '\\': *dst='\\'; break; + case '"': *dst='\"'; break; + // Escapes (like \uXXXX, \b, \f) are not supported for now, + // we just copy them verbatim. + default: *dst=*r; break; + } + dst++; esc = 0; r++; continue; + } + if (*r == '\\') { esc = 1; r++; continue; } + *dst++ = *r++; + } + *dst = '\0'; // Null-terminate the allocated string. + } + *p = q + 1; // Advance the main pointer past the closing quote. + return t; +} + +static exprtoken *jsonParseNumberToken(const char **p, const char *end) { + // Use a buffer to extract the number literal for parsing with strtod(). + char buf[256]; int idx = 0; + const char *start = *p; // For strtod partial failures check. + + // Copy potential number characters to buffer. + while (*p < end && idx < (int)sizeof(buf)-1 && jsonIsNumberChar(**p)) { + buf[idx++] = **p; + (*p)++; + } + buf[idx]='\0'; // Null-terminate buffer. + + if (idx==0) return NULL; // No number characters found. + + char *ep; // End pointer for strtod validation. + double v = strtod(buf, &ep); + + /* Check if strtod() consumed the entire buffer content. + * If not, the number format was invalid. */ + if (*ep!='\0') { + // strtod() failed; rewind p to the start and return NULL + *p = start; + return NULL; + } + + // If strtod() succeeded, create and return the token.. + exprtoken *t = exprNewToken(EXPR_TOKEN_NUM); + t->num = v; + return t; +} + +static exprtoken *jsonParseLiteralToken(const char **p, const char *end, const char *lit, int type, double num) { + size_t l = strlen(lit); + + // Ensure we don't read past 'end'. + if ((*p + l) > end) return NULL; + + if (strncmp(*p, lit, l) != 0) return NULL; // Literal doesn't match. + + // Check that the character *after* the literal is a valid JSON delimiter + // (whitespace, comma, closing bracket/brace, or end of input) + // This prevents matching "trueblabla" as "true". + if ((*p + l) < end) { + char next_char = *(*p + l); + if (!isspace((unsigned char)next_char) && next_char!=',' && + next_char!=']' && next_char!='}') { + return NULL; // Invalid character following literal. + } + } + + // Literal matched and is correctly terminated. + *p += l; + exprtoken *t = exprNewToken(type); + t->num = num; + return t; +} + +static exprtoken *jsonParseArrayToken(const char **p, const char *end) { + if (*p >= end || **p != '[') return NULL; + (*p)++; // Skip '['. + jsonSkipWhiteSpaces(p,end); + + exprtoken *t = exprNewToken(EXPR_TOKEN_TUPLE); + t->tuple.len = 0; t->tuple.ele = NULL; size_t alloc = 0; + + // Handle empty array []. + if (*p < end && **p == ']') { + (*p)++; // Skip ']'. + return t; + } + + // Parse array elements. + while (1) { + exprtoken *ele = jsonParseValueToken(p,end); + if (!ele) { + exprTokenRelease(t); // Clean up partially built array token. + return NULL; + } + + // Grow allocated space for elements if needed. + if (t->tuple.len == alloc) { + size_t newsize = alloc ? alloc * 2 : 4; + // Check for potential overflow if newsize becomes huge. + if (newsize < alloc) { + exprTokenRelease(ele); + exprTokenRelease(t); + return NULL; + } + exprtoken **newele = RedisModule_Realloc(t->tuple.ele, + sizeof(exprtoken*)*newsize); + t->tuple.ele = newele; + alloc = newsize; + } + t->tuple.ele[t->tuple.len++] = ele; // Add element. + + jsonSkipWhiteSpaces(p,end); + if (*p>=end) { + // Unterminated array. Note that this check is crucial because + // previous value parsed may seek 'p' to 'end'. + exprTokenRelease(t); + return NULL; + } + + // Check for comma (more elements) or closing bracket. + if (**p == ',') { + (*p)++; // Skip ',' + jsonSkipWhiteSpaces(p,end); // Skip whitespace before next element + continue; // Parse next element + } else if (**p == ']') { + (*p)++; // Skip ']' + return t; // End of array + } else { + // Unexpected character (not ',' or ']') + exprTokenRelease(t); + return NULL; + } + } +} + +/* Turn a JSON value into an expr token. */ +static exprtoken *jsonParseValueToken(const char **p, const char *end) { + jsonSkipWhiteSpaces(p,end); + if (*p >= end) return NULL; + + switch (**p) { + case '"': return jsonParseStringToken(p,end); + case '[': return jsonParseArrayToken(p,end); + case '{': return NULL; // No nested elements support for now. + case 't': return jsonParseLiteralToken(p,end,"true",EXPR_TOKEN_NUM,1); + case 'f': return jsonParseLiteralToken(p,end,"false",EXPR_TOKEN_NUM,0); + case 'n': return jsonParseLiteralToken(p,end,"null",EXPR_TOKEN_NULL,0); + default: + // Check if it starts like a number. + if (isdigit((unsigned char)**p) || **p=='-' || **p=='+') { + return jsonParseNumberToken(p,end); + } + // Anything else is an unsupported type or malformed JSON. + return NULL; + } +} + +/* ============================== Fast key seeking ========================== */ + +/* Finds the start of the value for a given field key within a JSON object. + * Returns pointer to the first char of the value, or NULL if not found/error. + * This function does not perform any allocation and is optimized to seek + * the specified *toplevel* filed as fast as possible. */ +static const char *jsonSeekField(const char *json, const char *end, + const char *field, size_t flen) { + const char *p = json; + jsonSkipWhiteSpaces(&p,end); + if (p >= end || *p != '{') return NULL; // Must start with '{'. + p++; // skip '{'. + + while (1) { + jsonSkipWhiteSpaces(&p,end); + if (p >= end) return NULL; // Reached end within object. + + if (*p == '}') return NULL; // End of object, field not found. + + // Expecting a key (string). + if (*p != '"') return NULL; // Key must be a string. + + // --- Key Matching using jsonSkipString --- + const char *key_start = p + 1; // Start of key content. + const char *key_end_p = p; // Will later contain the end. + + // Use jsonSkipString() to find the end. + if (!jsonSkipString(&key_end_p, end)) { + // Unterminated / invalid key string. + return NULL; + } + + // Calculate the length of the key's content. + size_t klen = (key_end_p - 1) - key_start; + + /* Perform the comparison using the raw key content. + * WARNING: This uses memcmp(), so we don't handle escaped chars + * within the key matching against unescaped chars in 'field'. */ + int match = klen == flen && !memcmp(key_start, field, flen); + + // Update the main pointer 'p' to be after the key string. + p = key_end_p; + + // Now we expect to find a ":" followed by a value. + jsonSkipWhiteSpaces(&p,end); + if (p>=end || *p!=':') return NULL; // Expect ':' after key + p++; // Skip ':'. + + // Seek value. + jsonSkipWhiteSpaces(&p,end); + if (p>=end) return NULL; // Expect value after ':' + + if (match) { + // Found the matching key, p now points to the start of the value. + return p; + } else { + // Key didn't match, skip the corresponding value. + if (!jsonSkipValue(&p,end)) return NULL; // Syntax error. + } + + + // Look for comma or a closing brace. + jsonSkipWhiteSpaces(&p,end); + if (p>=end) return NULL; // Reached end after value. + + if (*p == ',') { + p++; // Skip comma, continue loop to find next key. + continue; + } else if (*p == '}') { + return NULL; // Reached end of object, field not found. + } + return NULL; // Malformed JSON (unexpected char after value). + } +} + +/* This is the only real API that this file conceptually exports (it is + * inlined, actually). */ +exprtoken *jsonExtractField(const char *json, size_t json_len, + const char *field, size_t field_len) +{ + const char *end = json + json_len; + const char *valptr = jsonSeekField(json,end,field,field_len); + if (!valptr) return NULL; + + /* Key found, valptr points to the start of the value. + * Convert it into an expression token object. */ + return jsonParseValueToken(&valptr,end); +} diff --git a/examples/redis-unstable/modules/vector-sets/fastjson_test.c b/examples/redis-unstable/modules/vector-sets/fastjson_test.c new file mode 100644 index 0000000..1ea76a9 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/fastjson_test.c @@ -0,0 +1,406 @@ +/* fastjson_test.c - Stress test for fastjson.c + * + * This performs boundary and corruption tests to ensure + * the JSON parser handles edge cases without accessing + * memory outside the bounds of the input. + */ + +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <unistd.h> +#include <signal.h> +#include <time.h> +#include <sys/mman.h> +#include <sys/types.h> +#include <fcntl.h> +#include <errno.h> +#include <setjmp.h> + +/* Page size constant - typically 4096 or 16k bytes (Apple Silicon). + * We use 16k so that it will work on both, but not with Linux huge pages. */ +#define PAGE_SIZE 4096*4 +#define MAX_JSON_SIZE (PAGE_SIZE - 128) /* Keep some margin */ +#define MAX_FIELD_SIZE 64 +#define NUM_TEST_ITERATIONS 100000 +#define NUM_CORRUPTION_TESTS 10000 +#define NUM_BOUNDARY_TESTS 10000 + +/* Test state tracking */ +static char *safe_page = NULL; /* Start of readable/writable page */ +static char *unsafe_page = NULL; /* Start of inaccessible guard page */ +static int boundary_violation = 0; /* Flag for boundary violations */ +static jmp_buf jmpbuf; /* For signal handling */ +static int tests_passed = 0; +static int tests_failed = 0; +static int corruptions_passed = 0; +static int boundary_tests_passed = 0; + +/* Test metadata for tracking */ +typedef struct { + char *json; + size_t json_len; + char field[MAX_FIELD_SIZE]; + size_t field_len; + int expected_result; +} test_case_t; + +/* Forward declarations for test JSON generation */ +char *generate_random_json(size_t *len, char *field, size_t *field_len, int *has_field); +void corrupt_json(char *json, size_t len); +void setup_test_memory(void); +void cleanup_test_memory(void); +void run_normal_tests(void); +void run_corruption_tests(void); +void run_boundary_tests(void); +void print_test_summary(void); + +/* Signal handler for segmentation violations */ +static void sigsegv_handler(int sig) { + boundary_violation = 1; + printf("Boundary violation detected! Caught signal %d\n", sig); + longjmp(jmpbuf, 1); +} + +/* Wrapper for jsonExtractField to check for boundary violations */ +exprtoken *safe_extract_field(const char *json, size_t json_len, + const char *field, size_t field_len) { + boundary_violation = 0; + + if (setjmp(jmpbuf) == 0) { + return jsonExtractField(json, json_len, field, field_len); + } else { + return NULL; /* Return NULL if boundary violation occurred */ + } +} + +/* Setup two adjacent memory pages - one readable/writable, one inaccessible */ +void setup_test_memory(void) { + /* Request a page of memory, with specific alignment. We rely on the + * fact that hopefully the page after that will cause a segfault if + * accessed. */ + void *region = mmap(NULL, PAGE_SIZE, + PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, + -1, 0); + + if (region == MAP_FAILED) { + perror("mmap failed"); + exit(EXIT_FAILURE); + } + + safe_page = (char*)region; + unsafe_page = safe_page + PAGE_SIZE; + // Uncomment to make sure it crashes :D + // printf("%d\n", unsafe_page[5]); + + /* Set up signal handlers for memory access violations */ + struct sigaction sa; + sa.sa_handler = sigsegv_handler; + sigemptyset(&sa.sa_mask); + sa.sa_flags = 0; + + sigaction(SIGSEGV, &sa, NULL); + sigaction(SIGBUS, &sa, NULL); +} + +void cleanup_test_memory(void) { + if (safe_page != NULL) { + munmap(safe_page, PAGE_SIZE); + safe_page = NULL; + unsafe_page = NULL; + } +} + +/* Generate random strings with proper escaping for JSON */ +void generate_random_string(char *buffer, size_t max_len) { + static const char charset[] = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; + size_t len = 1 + rand() % (max_len - 2); /* Ensure at least 1 char */ + + for (size_t i = 0; i < len; i++) { + buffer[i] = charset[rand() % (sizeof(charset) - 1)]; + } + buffer[len] = '\0'; +} + +/* Generate random numbers as strings */ +void generate_random_number(char *buffer, size_t max_len) { + double num = (double)rand() / RAND_MAX * 1000.0; + + /* Occasionally make it negative or add decimal places */ + if (rand() % 5 == 0) num = -num; + if (rand() % 3 != 0) num += (double)(rand() % 100) / 100.0; + + snprintf(buffer, max_len, "%.6g", num); +} + +/* Generate a random field name */ +void generate_random_field(char *field, size_t *field_len) { + generate_random_string(field, MAX_FIELD_SIZE / 2); + *field_len = strlen(field); +} + +/* Generate a random JSON object with fields */ +char *generate_random_json(size_t *len, char *field, size_t *field_len, int *has_field) { + char *json = malloc(MAX_JSON_SIZE); + if (json == NULL) { + perror("malloc"); + exit(EXIT_FAILURE); + } + + char buffer[MAX_JSON_SIZE / 4]; /* Buffer for generating values */ + int pos = 0; + int num_fields = 1 + rand() % 10; /* Random number of fields */ + int target_field_index = rand() % num_fields; /* Which field to return */ + + /* Start the JSON object */ + pos += snprintf(json + pos, MAX_JSON_SIZE - pos, "{"); + + /* Generate random field/value pairs */ + for (int i = 0; i < num_fields; i++) { + /* Add a comma if not the first field */ + if (i > 0) { + pos += snprintf(json + pos, MAX_JSON_SIZE - pos, ", "); + } + + /* Generate a field name */ + if (i == target_field_index) { + /* This is our target field - save it for the caller */ + generate_random_field(field, field_len); + pos += snprintf(json + pos, MAX_JSON_SIZE - pos, "\"%s\": ", field); + *has_field = 1; + /* Sometimes change the last char so that it will not match. */ + if (rand() % 2) { + *has_field = 0; + field[*field_len-1] = '!'; + } + } else { + generate_random_string(buffer, MAX_FIELD_SIZE / 4); + pos += snprintf(json + pos, MAX_JSON_SIZE - pos, "\"%s\": ", buffer); + } + + /* Generate a random value type */ + int value_type = rand() % 5; + switch (value_type) { + case 0: /* String */ + generate_random_string(buffer, MAX_JSON_SIZE / 8); + pos += snprintf(json + pos, MAX_JSON_SIZE - pos, "\"%s\"", buffer); + break; + + case 1: /* Number */ + generate_random_number(buffer, MAX_JSON_SIZE / 8); + pos += snprintf(json + pos, MAX_JSON_SIZE - pos, "%s", buffer); + break; + + case 2: /* Boolean: true */ + pos += snprintf(json + pos, MAX_JSON_SIZE - pos, "true"); + break; + + case 3: /* Boolean: false */ + pos += snprintf(json + pos, MAX_JSON_SIZE - pos, "false"); + break; + + case 4: /* Null */ + pos += snprintf(json + pos, MAX_JSON_SIZE - pos, "null"); + break; + + case 5: /* Array (simple) */ + pos += snprintf(json + pos, MAX_JSON_SIZE - pos, "["); + int array_items = 1 + rand() % 5; + for (int j = 0; j < array_items; j++) { + if (j > 0) pos += snprintf(json + pos, MAX_JSON_SIZE - pos, ", "); + + /* Array items - either number or string */ + if (rand() % 2) { + generate_random_number(buffer, MAX_JSON_SIZE / 16); + pos += snprintf(json + pos, MAX_JSON_SIZE - pos, "%s", buffer); + } else { + generate_random_string(buffer, MAX_JSON_SIZE / 16); + pos += snprintf(json + pos, MAX_JSON_SIZE - pos, "\"%s\"", buffer); + } + } + pos += snprintf(json + pos, MAX_JSON_SIZE - pos, "]"); + break; + } + } + + /* Close the JSON object */ + pos += snprintf(json + pos, MAX_JSON_SIZE - pos, "}"); + *len = pos; + + return json; +} + +/* Corrupt JSON by replacing random characters */ +void corrupt_json(char *json, size_t len) { + if (len < 2) return; /* Too short to corrupt safely */ + + /* Corrupt 1-3 characters */ + int num_corruptions = 1 + rand() % 3; + for (int i = 0; i < num_corruptions; i++) { + size_t pos = rand() % len; + char corruption = " \t\n{}[]\":,0123456789abcdefXYZ"[rand() % 30]; + json[pos] = corruption; + } +} + +/* Run standard parser tests with generated valid JSON */ +void run_normal_tests(void) { + printf("Running normal JSON extraction tests...\n"); + + for (int i = 0; i < NUM_TEST_ITERATIONS; i++) { + char field[MAX_FIELD_SIZE] = {0}; + size_t field_len = 0; + size_t json_len = 0; + int has_field = 0; + + /* Generate random JSON */ + char *json = generate_random_json(&json_len, field, &field_len, &has_field); + + /* Use valid field to test parser */ + exprtoken *token = safe_extract_field(json, json_len, field, field_len); + + /* Check if we got a token as expected */ + if (has_field && token != NULL) { + exprTokenRelease(token); + tests_passed++; + } else if (!has_field && token == NULL) { + tests_passed++; + } else { + tests_failed++; + } + + /* Test with a non-existent field */ + char nonexistent_field[MAX_FIELD_SIZE] = "nonexistent_field"; + token = safe_extract_field(json, json_len, nonexistent_field, strlen(nonexistent_field)); + + if (token == NULL) { + tests_passed++; + } else { + exprTokenRelease(token); + tests_failed++; + } + + free(json); + } +} + +/* Run tests with corrupted JSON */ +void run_corruption_tests(void) { + printf("Running JSON corruption tests...\n"); + + for (int i = 0; i < NUM_CORRUPTION_TESTS; i++) { + char field[MAX_FIELD_SIZE] = {0}; + size_t field_len = 0; + size_t json_len = 0; + int has_field = 0; + + /* Generate random JSON */ + char *json = generate_random_json(&json_len, field, &field_len, &has_field); + + /* Make a copy and corrupt it */ + char *corrupted = malloc(json_len + 1); + if (!corrupted) { + perror("malloc"); + free(json); + exit(EXIT_FAILURE); + } + + memcpy(corrupted, json, json_len + 1); + corrupt_json(corrupted, json_len); + + /* Test with corrupted JSON */ + exprtoken *token = safe_extract_field(corrupted, json_len, field, field_len); + + /* We're just testing that it doesn't crash or access invalid memory */ + if (boundary_violation) { + printf("Boundary violation with corrupted JSON!\n"); + tests_failed++; + } else { + if (token != NULL) { + exprTokenRelease(token); + } + corruptions_passed++; + } + + free(corrupted); + free(json); + } +} + +/* Run tests at memory boundaries */ +void run_boundary_tests(void) { + printf("Running memory boundary tests...\n"); + + for (int i = 0; i < NUM_BOUNDARY_TESTS; i++) { + char field[MAX_FIELD_SIZE] = {0}; + size_t field_len = 0; + size_t json_len = 0; + int has_field = 0; + + /* Generate random JSON */ + char *temp_json = generate_random_json(&json_len, field, &field_len, &has_field); + + /* Truncate the JSON to a random length */ + size_t truncated_len = 1 + rand() % json_len; + + /* Place at the edge of the safe page */ + size_t offset = PAGE_SIZE - truncated_len; + memcpy(safe_page + offset, temp_json, truncated_len); + + /* Test parsing with non-existent field (forcing it to scan to end) */ + char nonexistent_field[MAX_FIELD_SIZE] = "nonexistent_field"; + exprtoken *token = safe_extract_field(safe_page + offset, truncated_len, + nonexistent_field, strlen(nonexistent_field)); + + /* We're just testing that it doesn't access memory beyond the boundary */ + if (boundary_violation) { + printf("Boundary violation at edge of memory page!\n"); + tests_failed++; + } else { + if (token != NULL) { + exprTokenRelease(token); + } + boundary_tests_passed++; + } + + free(temp_json); + } +} + +/* Print summary of test results */ +void print_test_summary(void) { + printf("\n===== FASTJSON PARSER TEST SUMMARY =====\n"); + printf("Normal tests passed: %d/%d\n", tests_passed, NUM_TEST_ITERATIONS * 2); + printf("Corruption tests passed: %d/%d\n", corruptions_passed, NUM_CORRUPTION_TESTS); + printf("Boundary tests passed: %d/%d\n", boundary_tests_passed, NUM_BOUNDARY_TESTS); + printf("Failed tests: %d\n", tests_failed); + + if (tests_failed == 0) { + printf("\nALL TESTS PASSED! The JSON parser appears to be robust.\n"); + } else { + printf("\nSome tests FAILED. The JSON parser may be vulnerable.\n"); + } +} + +/* Entry point for fastjson parser test */ +void run_fastjson_test(void) { + printf("Starting fastjson parser stress test...\n"); + + /* Seed the random number generator */ + srand(time(NULL)); + + /* Setup test memory environment */ + setup_test_memory(); + + /* Run the various test phases */ + run_normal_tests(); + run_corruption_tests(); + run_boundary_tests(); + + /* Print summary */ + print_test_summary(); + + /* Cleanup */ + cleanup_test_memory(); +} diff --git a/examples/redis-unstable/modules/vector-sets/hnsw.c b/examples/redis-unstable/modules/vector-sets/hnsw.c new file mode 100644 index 0000000..2b4ebc0 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/hnsw.c @@ -0,0 +1,2999 @@ +/* HNSW (Hierarchical Navigable Small World) Implementation. + * + * Based on the paper by Yu. A. Malkov, D. A. Yashunin. + * + * Many details of this implementation, not covered in the paper, were + * obtained simulating different workloads and checking the connection + * quality of the graph. + * + * Notably, this implementation: + * + * 1. Only uses bi-directional links, implementing strategies in order to + * link new nodes even when candidates are full, and our new node would + * be not close enough to replace old links in candidate. + * + * 2. We normalize on-insert, making cosine similarity and dot product the + * same. This means we can't use euclidean distance or alike here. + * Together with quantization, this provides an important speedup that + * makes HNSW more practical. + * + * 3. The quantization used is int8. And it is performed per-vector, so the + * "range" (max abs value) is also stored alongside with the quantized data. + * + * 4. This library implements true elements deletion, not just marking the + * element as deleted, but removing it (we can do it since our links are + * bidirectional), and reliking the nodes orphaned of one link among + * them. + * + * Copyright (c) 2009-Present, Redis Ltd. + * All rights reserved. + * + * Licensed under your choice of (a) the Redis Source Available License 2.0 + * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the + * GNU Affero General Public License v3 (AGPLv3). + * Originally authored by: Salvatore Sanfilippo. + */ + +#define _DEFAULT_SOURCE +#define _POSIX_C_SOURCE 200809L + +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <math.h> +#include <stdint.h> +#include <float.h> /* for INFINITY if not in math.h */ +#include <assert.h> +#include "hnsw.h" +#include "mixer.h" + +/* Check if we can compile SIMD code with function attributes */ +#if defined (__x86_64__) && ((defined(__GNUC__) && __GNUC__ >= 5) || (defined(__clang__) && __clang_major__ >= 4)) +#if defined(__has_attribute) && __has_attribute(target) +#define HAVE_AVX2 +#define HAVE_AVX512 +#endif +#endif + +#if defined (HAVE_AVX2) +#define ATTRIBUTE_TARGET_AVX2 __attribute__((target("avx2,fma"))) +#define VSET_USE_AVX2 (__builtin_cpu_supports("avx2") && __builtin_cpu_supports("fma")) +#else +#define ATTRIBUTE_TARGET_AVX2 +#define VSET_USE_AVX2 0 +#endif + +#if defined (HAVE_AVX512) +#define ATTRIBUTE_TARGET_AVX512 __attribute__((target("avx512f,fma"))) +#define VSET_USE_AVX512 (__builtin_cpu_supports("avx512f")) +#else +#define ATTRIBUTE_TARGET_AVX512 +#define VSET_USE_AVX512 0 +#endif + +/* Include SIMD headers when supported */ +#if defined(HAVE_AVX2) || defined(HAVE_AVX512) +#include <immintrin.h> +#endif + +#if 0 +#define debugmsg printf +#else +#define debugmsg if(0) printf +#endif + +#ifndef INFINITY +#define INFINITY (1.0/0.0) +#endif + +#define MIN(a,b) ((a) < (b) ? (a) : (b)) + +/* Algorithm parameters. */ + +#define HNSW_P 0.25 /* Probability of level increase. */ +#define HNSW_MAX_LEVEL 16 /* Max level nodes can reach. */ +#define HNSW_EF_C 200 /* Default size of dynamic candidate list while + * inserting a new node, in case 0 is passed to + * the 'ef' argument while inserting. This is also + * used when deleting nodes for the search step + * needed sometimes to reconnect nodes that remain + * orphaned of one link. */ + +static void (*hfree)(void *p) = free; +static void *(*hmalloc)(size_t s) = malloc; +static void *(*hrealloc)(void *old, size_t s) = realloc; + +void hnsw_set_allocator(void (*free_ptr)(void*), void *(*malloc_ptr)(size_t), + void *(*realloc_ptr)(void*, size_t)) +{ + hfree = free_ptr; + hmalloc = malloc_ptr; + hrealloc = realloc_ptr; +} + +// Get a warning if you use the libc allocator functions for mistake. +#define malloc use_hmalloc_instead +#define realloc use_hrealloc_instead +#define free use_hfree_instead + +/* ============================== Prototypes ================================ */ +void hnsw_cursor_element_deleted(HNSW *index, hnswNode *deleted); + +/* ============================ Priority queue ================================ + * We need a priority queue to take an ordered list of candidates. Right now + * it is implemented as a linear array, since it is relatively small. + * + * You may find it to be odd that we take the best element (smaller distance) + * at the end of the array, but this way popping from the pqueue is O(1), as + * we need to just decrement the count, and this is a very used operation + * in a critical code path. This makes the priority queue implementation a + * bit more complex in the insertion, but for good reasons. */ + +/* Maximum number of candidates we'll ever need (cit. Bill Gates). */ +#define HNSW_MAX_CANDIDATES 256 + +typedef struct { + hnswNode *node; + float distance; +} pqitem; + +typedef struct { + pqitem *items; /* Array of items. */ + uint32_t count; /* Current number of items. */ + uint32_t cap; /* Maximum capacity. */ +} pqueue; + +/* The HNSW algorithms access the pqueue conceptually from nearest (index 0) + * to farthest (larger indexes) node, so the following macros are used to + * access the pqueue in this fashion, even if the internal order is + * actually reversed. */ +#define pq_get_node(q,i) ((q)->items[(q)->count-(i+1)].node) +#define pq_get_distance(q,i) ((q)->items[(q)->count-(i+1)].distance) + +/* Create a new priority queue with given capacity. Adding to the + * pqueue only retains 'capacity' elements with the shortest distance. */ +pqueue *pq_new(uint32_t capacity) { + pqueue *pq = hmalloc(sizeof(*pq)); + if (!pq) return NULL; + + pq->items = hmalloc(sizeof(pqitem) * capacity); + if (!pq->items) { + hfree(pq); + return NULL; + } + + pq->count = 0; + pq->cap = capacity; + return pq; +} + +/* Free a priority queue. */ +void pq_free(pqueue *pq) { + if (!pq) return; + hfree(pq->items); + hfree(pq); +} + +/* Insert maintaining distance order (higher distances first). */ +void pq_push(pqueue *pq, hnswNode *node, float distance) { + if (pq->count < pq->cap) { + /* Queue not full: shift right from high distances to make room. */ + uint32_t i = pq->count; + while (i > 0 && pq->items[i-1].distance < distance) { + pq->items[i] = pq->items[i-1]; + i--; + } + pq->items[i].node = node; + pq->items[i].distance = distance; + pq->count++; + } else { + /* Queue full: if new item is worse than worst, ignore it. */ + if (distance >= pq->items[0].distance) return; + + /* Otherwise shift left from low distances to drop worst. */ + uint32_t i = 0; + while (i < pq->cap-1 && pq->items[i+1].distance > distance) { + pq->items[i] = pq->items[i+1]; + i++; + } + pq->items[i].node = node; + pq->items[i].distance = distance; + } +} + +/* Remove and return the top (closest) element, which is at count-1 + * since we store elements with higher distances first. + * Runs in constant time. */ +hnswNode *pq_pop(pqueue *pq, float *distance) { + if (pq->count == 0) return NULL; + pq->count--; + *distance = pq->items[pq->count].distance; + return pq->items[pq->count].node; +} + +/* Get distance of the furthest element. + * An empty priority queue has infinite distance as its furthest element, + * note that this behavior is needed by the algorithms below. */ +float pq_max_distance(pqueue *pq) { + if (pq->count == 0) return INFINITY; + return pq->items[0].distance; +} + +/* ============================ HNSW algorithm ============================== */ + +#if defined(HAVE_AVX512) +/* AVX512 optimized dot product for float vectors */ +ATTRIBUTE_TARGET_AVX512 +float vectors_distance_float_avx512(const float *x, const float *y, uint32_t dim) { + __m512 sum = _mm512_setzero_ps(); + uint32_t i; + + /* Process 16 floats at a time with AVX512 */ + for (i = 0; i + 15 < dim; i += 16) { + __m512 vx = _mm512_loadu_ps(&x[i]); + __m512 vy = _mm512_loadu_ps(&y[i]); + sum = _mm512_fmadd_ps(vx, vy, sum); + } + + /* Horizontal sum of the 16 elements in sum */ + float dot = _mm512_reduce_add_ps(sum); + + /* Handle remaining elements */ + for (; i < dim; i++) { + dot += x[i] * y[i]; + } + + return 1.0f - dot; +} +#endif /* HAVE_AVX512 */ + +#if defined(HAVE_AVX2) +/* AVX2 optimized dot product for float vectors */ +ATTRIBUTE_TARGET_AVX2 +float vectors_distance_float_avx2(const float *x, const float *y, uint32_t dim) { + __m256 sum1 = _mm256_setzero_ps(); + __m256 sum2 = _mm256_setzero_ps(); + uint32_t i; + + /* Process 16 floats at a time with two AVX2 registers */ + for (i = 0; i + 15 < dim; i += 16) { + __m256 vx1 = _mm256_loadu_ps(&x[i]); + __m256 vy1 = _mm256_loadu_ps(&y[i]); + __m256 vx2 = _mm256_loadu_ps(&x[i + 8]); + __m256 vy2 = _mm256_loadu_ps(&y[i + 8]); + + sum1 = _mm256_fmadd_ps(vx1, vy1, sum1); + sum2 = _mm256_fmadd_ps(vx2, vy2, sum2); + } + + /* Combine the two sums */ + __m256 combined = _mm256_add_ps(sum1, sum2); + + /* Horizontal sum of the 8 elements */ + __m128 sum_high = _mm256_extractf128_ps(combined, 1); + __m128 sum_low = _mm256_castps256_ps128(combined); + __m128 sum_128 = _mm_add_ps(sum_high, sum_low); + + sum_128 = _mm_hadd_ps(sum_128, sum_128); + sum_128 = _mm_hadd_ps(sum_128, sum_128); + + float dot = _mm_cvtss_f32(sum_128); + + /* Handle remaining elements */ + for (; i < dim; i++) { + dot += x[i] * y[i]; + } + + return 1.0f - dot; +} +#endif /* HAVE_AVX2 */ + +/* Optimized dot product: automatically selects best available implementation + * Dot product: our vectors are already normalized. + * Version for not quantized vectors of floats. */ +float vectors_distance_float(const float *x, const float *y, uint32_t dim) { +#if defined(HAVE_AVX512) + if (dim >= 16 && VSET_USE_AVX512) { + return vectors_distance_float_avx512(x, y, dim); + } +#endif + +#if defined(HAVE_AVX2) + if (VSET_USE_AVX2 && dim >= 16) { + return vectors_distance_float_avx2(x, y, dim); + } +#endif + + /* Fallback to original scalar implementation */ + float dot0 = 0.0f, dot1 = 0.0f; + uint32_t i; + + /* Use two accumulators to reduce dependencies among multiplications. + * This provides a clear speed boost in Apple silicon, but should be + * help in general. */ + for (i = 0; i + 7 < dim; i += 8) { + dot0 += x[i] * y[i] + + x[i+1] * y[i+1] + + x[i+2] * y[i+2] + + x[i+3] * y[i+3]; + + dot1 += x[i+4] * y[i+4] + + x[i+5] * y[i+5] + + x[i+6] * y[i+6] + + x[i+7] * y[i+7]; + } + + /* Handle the remaining elements. These are a minority in the case + * of a small vector, don't optimize this part. */ + for (; i < dim; i++) dot0 += x[i] * y[i]; + + /* The following line may be counter intuitive. The dot product of + * normalized vectors is equivalent to their cosine similarity. The + * cosine will be from -1 (vectors facing opposite directions in the + * N-dim space) to 1 (vectors are facing in the same direction). + * + * We kinda want a "score" of distance from 0 to 2 (this is a distance + * function and we want minimize the distance for K-NN searches), so we + * can't just add 1: that would return a number in the 0-2 range, with + * 0 meaning opposite vectors and 2 identical vectors, so this is + * similarity, not distance. + * + * Returning instead (1 - dotprod) inverts the meaning: 0 is identical + * and 2 is opposite, hence it is their distance. + * + * Why don't normalize the similarity right now, and return from 0 to + * 1? Because division is costly. */ + return 1.0f - (dot0 + dot1); +} + +/* Q8 quants dotproduct. We do integer math and later fix it by range. */ +float vectors_distance_q8(const int8_t *x, const int8_t *y, uint32_t dim, + float range_a, float range_b) { + // Handle zero vectors special case. + if (range_a == 0 || range_b == 0) { + /* Zero vector distance from anything is 1.0 + * (since 1.0 - dot_product where dot_product = 0). */ + return 1.0f; + } + + /* Each vector is quantized from [-max_abs, +max_abs] to [-127, 127] + * where range = 2*max_abs. */ + const float scale_product = (range_a/127) * (range_b/127); + + int32_t dot0 = 0, dot1 = 0; + uint32_t i; + + // Process 8 elements at a time for better pipeline utilization. + for (i = 0; i + 7 < dim; i += 8) { + dot0 += ((int32_t)x[i]) * ((int32_t)y[i]) + + ((int32_t)x[i+1]) * ((int32_t)y[i+1]) + + ((int32_t)x[i+2]) * ((int32_t)y[i+2]) + + ((int32_t)x[i+3]) * ((int32_t)y[i+3]); + + dot1 += ((int32_t)x[i+4]) * ((int32_t)y[i+4]) + + ((int32_t)x[i+5]) * ((int32_t)y[i+5]) + + ((int32_t)x[i+6]) * ((int32_t)y[i+6]) + + ((int32_t)x[i+7]) * ((int32_t)y[i+7]); + } + + // Handle remaining elements. + for (; i < dim; i++) dot0 += ((int32_t)x[i]) * ((int32_t)y[i]); + + // Convert to original range. + float dotf = (dot0 + dot1) * scale_product; + float distance = 1.0f - dotf; + + // Clamp distance to [0, 2]. + if (distance < 0) distance = 0; + else if (distance > 2) distance = 2; + return distance; +} + +static inline int popcount64(uint64_t x) { + x = (x & 0x5555555555555555) + ((x >> 1) & 0x5555555555555555); + x = (x & 0x3333333333333333) + ((x >> 2) & 0x3333333333333333); + x = (x & 0x0F0F0F0F0F0F0F0F) + ((x >> 4) & 0x0F0F0F0F0F0F0F0F); + x = (x & 0x00FF00FF00FF00FF) + ((x >> 8) & 0x00FF00FF00FF00FF); + x = (x & 0x0000FFFF0000FFFF) + ((x >> 16) & 0x0000FFFF0000FFFF); + x = (x & 0x00000000FFFFFFFF) + ((x >> 32) & 0x00000000FFFFFFFF); + return x; +} + +/* Binary vectors distance. */ +float vectors_distance_bin(const uint64_t *x, const uint64_t *y, uint32_t dim) { + uint32_t len = (dim+63)/64; + uint32_t opposite = 0; + for (uint32_t j = 0; j < len; j++) { + uint64_t xor = x[j]^y[j]; + opposite += popcount64(xor); + } + return (float)opposite*2/dim; +} + +/* Dot product between nodes. Will call the right version depending on the + * quantization used. */ +float hnsw_distance(HNSW *index, hnswNode *a, hnswNode *b) { + switch(index->quant_type) { + case HNSW_QUANT_NONE: + return vectors_distance_float(a->vector,b->vector,index->vector_dim); + case HNSW_QUANT_Q8: + return vectors_distance_q8(a->vector,b->vector,index->vector_dim,a->quants_range,b->quants_range); + case HNSW_QUANT_BIN: + return vectors_distance_bin(a->vector,b->vector,index->vector_dim); + default: + assert(1 != 1); + return 0; + } +} + +/* This do Q8 'range' quantization. + * For people looking at this code thinking: Oh, I could use min/max + * quants instead! Well: I tried with min/max normalization but the dot + * product needs to accumulate the sum for later correction, and it's slower. */ +void quantize_to_q8(float *src, int8_t *dst, uint32_t dim, float *rangeptr) { + float max_abs = 0; + for (uint32_t j = 0; j < dim; j++) { + if (src[j] > max_abs) max_abs = src[j]; + if (-src[j] > max_abs) max_abs = -src[j]; + } + + if (max_abs == 0) { + if (rangeptr) *rangeptr = 0; + memset(dst, 0, dim); + return; + } + + const float scale = 127.0f / max_abs; // Scale to map to [-127, 127]. + + for (uint32_t j = 0; j < dim; j++) { + dst[j] = (int8_t)roundf(src[j] * scale); + } + if (rangeptr) *rangeptr = max_abs; // Return max_abs instead of 2*max_abs. +} + +/* Binary quantization of vector 'src' to 'dst'. We use full words of + * 64 bit as smallest unit, we will just set all the unused bits to 0 + * so that they'll be the same in all the vectors, and when xor+popcount + * is used to compute the distance, such bits are not considered. This + * allows to go faster. */ +void quantize_to_bin(float *src, uint64_t *dst, uint32_t dim) { + memset(dst,0,(dim+63)/64*sizeof(uint64_t)); + for (uint32_t j = 0; j < dim; j++) { + uint32_t word = j/64; + uint32_t bit = j&63; + /* Since cosine similarity checks the vector direction and + * not magnitudo, we do likewise in the binary quantization and + * just remember if the component is positive or negative. */ + if (src[j] > 0) dst[word] |= 1ULL<<bit; + } +} + +/* L2 normalization of the float vector. + * + * Store the L2 value on 'l2ptr' if not NULL. This way the process + * can be reversed even if some precision will be lost. */ +void hnsw_normalize_vector(float *x, float *l2ptr, uint32_t dim) { + float l2 = 0; + uint32_t i; + for (i = 0; i + 3 < dim; i += 4) { + l2 += x[i]*x[i] + + x[i+1]*x[i+1] + + x[i+2]*x[i+2] + + x[i+3]*x[i+3]; + } + for (; i < dim; i++) l2 += x[i]*x[i]; + if (l2 == 0) return; // All zero vector, can't normalize. + + l2 = sqrtf(l2); + if (l2ptr) *l2ptr = l2; + for (i = 0; i < dim; i++) x[i] /= l2; +} + +/* Helper function to generate random level. */ +uint32_t random_level(void) { + static const int threshold = HNSW_P * RAND_MAX; + uint32_t level = 0; + + while (rand() < threshold && level < HNSW_MAX_LEVEL) + level += 1; + return level; +} + +/* Create new HNSW index, quantized or not. */ +HNSW *hnsw_new(uint32_t vector_dim, uint32_t quant_type, uint32_t m) { + HNSW *index = hmalloc(sizeof(HNSW)); + if (!index) return NULL; + + /* M parameter sanity check. */ + if (m == 0) m = HNSW_DEFAULT_M; + else if (m > HNSW_MAX_M) m = HNSW_MAX_M; + + index->M = m; + index->quant_type = quant_type; + index->enter_point = NULL; + index->max_level = 0; + index->vector_dim = vector_dim; + index->node_count = 0; + index->last_id = 0; + index->head = NULL; + index->cursors = NULL; + + /* Initialize epochs array. */ + for (int i = 0; i < HNSW_MAX_THREADS; i++) + index->current_epoch[i] = 0; + + /* Initialize locks. */ + if (pthread_rwlock_init(&index->global_lock, NULL) != 0) { + hfree(index); + return NULL; + } + + for (int i = 0; i < HNSW_MAX_THREADS; i++) { + if (pthread_mutex_init(&index->slot_locks[i], NULL) != 0) { + /* Clean up previously initialized mutexes. */ + for (int j = 0; j < i; j++) + pthread_mutex_destroy(&index->slot_locks[j]); + pthread_rwlock_destroy(&index->global_lock); + hfree(index); + return NULL; + } + } + + /* Initialize atomic variables. */ + index->next_slot = 0; + index->version = 0; + return index; +} + +/* Fill 'vec' with the node vector, de-normalizing and de-quantizing it + * as needed. Note that this function will return an approximated version + * of the original vector. */ +void hnsw_get_node_vector(HNSW *index, hnswNode *node, float *vec) { + if (index->quant_type == HNSW_QUANT_NONE) { + memcpy(vec,node->vector,index->vector_dim*sizeof(float)); + } else if (index->quant_type == HNSW_QUANT_Q8) { + int8_t *quants = node->vector; + for (uint32_t j = 0; j < index->vector_dim; j++) + vec[j] = (quants[j]*node->quants_range)/127; + } else if (index->quant_type == HNSW_QUANT_BIN) { + uint64_t *bits = node->vector; + for (uint32_t j = 0; j < index->vector_dim; j++) { + uint32_t word = j/64; + uint32_t bit = j&63; + vec[j] = (bits[word] & (1ULL<<bit)) ? 1.0f : -1.0f; + } + } + + // De-normalize. + if (index->quant_type != HNSW_QUANT_BIN) { + for (uint32_t j = 0; j < index->vector_dim; j++) + vec[j] *= node->l2; + } +} + +/* Return the number of bytes needed to represent a vector in the index, + * that is function of the dimension of the vectors and the quantization + * type used. */ +uint32_t hnsw_quants_bytes(HNSW *index) { + switch(index->quant_type) { + case HNSW_QUANT_NONE: return index->vector_dim * sizeof(float); + case HNSW_QUANT_Q8: return index->vector_dim; + case HNSW_QUANT_BIN: return (index->vector_dim+63)/64*8; + default: assert(0 && "Quantization type not supported."); + } +} + +/* Create new node. Returns NULL on out of memory. + * It is possible to pass the vector as floats or, in case this index + * was already stored on disk and is being loaded, or serialized and + * transmitted in any form, the already quantized version in + * 'qvector'. + * + * Only vector or qvector should be non-NULL. The reason why passing + * a quantized vector is useful, is that because re-normalizing and + * re-quantizing several times the same vector may accumulate rounding + * errors. So if you work with quantized indexes, you should save + * the quantized indexes. + * + * Note that, together with qvector, the quantization range is needed, + * since this library uses per-vector quantization. In case of quantized + * vectors the l2 is considered to be '1', so if you want to restore + * the right l2 (to use the API that returns an approximation of the + * original vector) make sure to save the l2 on disk and set it back + * after the node creation (see later for the serialization API that + * handles this and more). */ +hnswNode *hnsw_node_new(HNSW *index, uint64_t id, const float *vector, const int8_t *qvector, float qrange, uint32_t level, int normalize) { + hnswNode *node = hmalloc(sizeof(hnswNode)+(sizeof(hnswNodeLayer)*(level+1))); + if (!node) return NULL; + + if (id == 0) id = ++index->last_id; + node->level = level; + node->id = id; + node->next = NULL; + node->vector = NULL; + node->l2 = 1; // Default in case of already quantized vectors. It is + // up to the caller to fill this later, if needed. + + /* Initialize visited epoch array. */ + for (int i = 0; i < HNSW_MAX_THREADS; i++) + node->visited_epoch[i] = 0; + + if (qvector == NULL) { + /* Copy input vector. */ + node->vector = hmalloc(sizeof(float) * index->vector_dim); + if (!node->vector) { + hfree(node); + return NULL; + } + memcpy(node->vector, vector, sizeof(float) * index->vector_dim); + if (normalize) + hnsw_normalize_vector(node->vector,&node->l2,index->vector_dim); + + /* Handle quantization. */ + if (index->quant_type != HNSW_QUANT_NONE) { + void *quants = hmalloc(hnsw_quants_bytes(index)); + if (quants == NULL) { + hfree(node->vector); + hfree(node); + return NULL; + } + + // Quantize. + switch(index->quant_type) { + case HNSW_QUANT_Q8: + quantize_to_q8(node->vector,quants,index->vector_dim,&node->quants_range); + break; + case HNSW_QUANT_BIN: + quantize_to_bin(node->vector,quants,index->vector_dim); + break; + default: + assert(0 && "Quantization type not handled."); + break; + } + + // Discard the full precision vector. + hfree(node->vector); + node->vector = quants; + } + } else { + // We got the already quantized vector. Just copy it. + assert(index->quant_type != HNSW_QUANT_NONE); + uint32_t vector_bytes = hnsw_quants_bytes(index); + node->vector = hmalloc(vector_bytes); + node->quants_range = qrange; + if (node->vector == NULL) { + hfree(node); + return NULL; + } + memcpy(node->vector,qvector,vector_bytes); + } + + /* Initialize each layer. */ + for (uint32_t i = 0; i <= level; i++) { + uint32_t max_links = (i == 0) ? index->M*2 : index->M; + node->layers[i].max_links = max_links; + node->layers[i].num_links = 0; + node->layers[i].worst_distance = 0; + node->layers[i].worst_idx = 0; + node->layers[i].links = hmalloc(sizeof(hnswNode*) * max_links); + if (!node->layers[i].links) { + for (uint32_t j = 0; j < i; j++) hfree(node->layers[j].links); + hfree(node->vector); + hfree(node); + return NULL; + } + } + + return node; +} + +/* Free a node. */ +void hnsw_node_free(hnswNode *node) { + if (!node) return; + + for (uint32_t i = 0; i <= node->level; i++) + hfree(node->layers[i].links); + + hfree(node->vector); + hfree(node); +} + +/* Free the entire index. */ +void hnsw_free(HNSW *index,void(*free_value)(void*value)) { + if (!index) return; + + hnswNode *current = index->head; + while (current) { + hnswNode *next = current->next; + if (free_value) free_value(current->value); + hnsw_node_free(current); + current = next; + } + + /* Destroy locks */ + pthread_rwlock_destroy(&index->global_lock); + for (int i = 0; i < HNSW_MAX_THREADS; i++) { + pthread_mutex_destroy(&index->slot_locks[i]); + } + + hfree(index); +} + +/* Add node to linked list of nodes. We may need to scan the whole + * HNSW graph for several reasons. The list is doubly linked since we + * also need the ability to remove a node without scanning the whole thing. */ +void hnsw_add_node(HNSW *index, hnswNode *node) { + node->next = index->head; + node->prev = NULL; + if (index->head) + index->head->prev = node; + index->head = node; + index->node_count++; +} + +/* Search the specified layer starting from the specified entry point + * to collect 'ef' nodes that are near to 'query'. + * + * This function implements optional hybrid search, so that each node + * can be accepted or not based on its associated value. In this case + * a callback 'filter_callback' should be passed, together with a maximum + * effort for the search (number of candidates to evaluate), since even + * with a a low "EF" value we risk that there are too few nodes that satisfy + * the provided filter, and we could trigger a full scan. */ +pqueue *search_layer_with_filter( + HNSW *index, hnswNode *query, hnswNode *entry_point, + uint32_t ef, uint32_t layer, uint32_t slot, + int (*filter_callback)(void *value, void *privdata), + void *filter_privdata, uint32_t max_candidates) +{ + // Mark visited nodes with a never seen epoch. + index->current_epoch[slot]++; + + pqueue *candidates = pq_new(HNSW_MAX_CANDIDATES); + pqueue *results = pq_new(ef); + if (!candidates || !results) { + if (candidates) pq_free(candidates); + if (results) pq_free(results); + return NULL; + } + + // Take track of the total effort: only used when filtering via + // a callback to have a bound effort. + uint32_t evaluated_candidates = 1; + + // Add entry point. + float dist = hnsw_distance(index, query, entry_point); + pq_push(candidates, entry_point, dist); + if (filter_callback == NULL || + filter_callback(entry_point->value, filter_privdata)) + { + pq_push(results, entry_point, dist); + } + entry_point->visited_epoch[slot] = index->current_epoch[slot]; + + // Process candidates. + while (candidates->count > 0) { + // Max effort. If zero, we keep scanning. + if (filter_callback && + max_candidates && + evaluated_candidates >= max_candidates) break; + + float cur_dist; + hnswNode *current = pq_pop(candidates, &cur_dist); + evaluated_candidates++; + + float furthest = pq_max_distance(results); + if (results->count >= ef && cur_dist > furthest) break; + + /* Check neighbors. */ + for (uint32_t i = 0; i < current->layers[layer].num_links; i++) { + hnswNode *neighbor = current->layers[layer].links[i]; + + if (neighbor->visited_epoch[slot] == index->current_epoch[slot]) + continue; // Already visited during this scan. + + neighbor->visited_epoch[slot] = index->current_epoch[slot]; + float neighbor_dist = hnsw_distance(index, query, neighbor); + + furthest = pq_max_distance(results); + if (filter_callback == NULL) { + /* Original HNSW logic when no filtering: + * Add to results if better than current max or + * results not full. */ + if (neighbor_dist < furthest || results->count < ef) { + pq_push(candidates, neighbor, neighbor_dist); + pq_push(results, neighbor, neighbor_dist); + } + } else { + /* With filtering: we add candidates even if doesn't match + * the filter, in order to continue to explore the graph. */ + if (neighbor_dist < furthest || candidates->count < ef) { + pq_push(candidates, neighbor, neighbor_dist); + } + + /* Add results only if passes filter. */ + if (filter_callback(neighbor->value, filter_privdata)) { + if (neighbor_dist < furthest || results->count < ef) { + pq_push(results, neighbor, neighbor_dist); + } + } + } + } + } + + pq_free(candidates); + return results; +} + +/* Just a wrapper without hybrid search callback. */ +pqueue *search_layer(HNSW *index, hnswNode *query, hnswNode *entry_point, + uint32_t ef, uint32_t layer, uint32_t slot) +{ + return search_layer_with_filter(index, query, entry_point, ef, layer, slot, + NULL, NULL, 0); +} + +/* This function is used in order to initialize a node allocated in the + * function stack with the specified vector. The idea is that we can + * easily use hnsw_distance() from a vector and the HNSW nodes this way: + * + * hnswNode myQuery; + * hnsw_init_tmp_node(myIndex,&myQuery,0,some_vector); + * hnsw_distance(&myQuery, some_hnsw_node); + * + * Make sure to later free the node with: + * + * hnsw_free_tmp_node(&myQuery,some_vector); + * You have to pass the vector to the free function, because sometimes + * hnsw_init_tmp_node() may just avoid allocating a vector at all, + * just reusing 'some_vector' pointer. + * + * Return 0 on out of memory, 1 on success. + */ +int hnsw_init_tmp_node(HNSW *index, hnswNode *node, int is_normalized, const float *vector) { + node->vector = NULL; + + /* Work on a normalized query vector if the input vector is + * not normalized. */ + if (!is_normalized) { + node->vector = hmalloc(sizeof(float)*index->vector_dim); + if (node->vector == NULL) return 0; + memcpy(node->vector,vector,sizeof(float)*index->vector_dim); + hnsw_normalize_vector(node->vector,NULL,index->vector_dim); + } else { + node->vector = (float*)vector; + } + + /* If quantization is enabled, our query fake node should be + * quantized as well. */ + if (index->quant_type != HNSW_QUANT_NONE) { + void *quants = hmalloc(hnsw_quants_bytes(index)); + if (quants == NULL) { + if (node->vector != vector) hfree(node->vector); + return 0; + } + switch(index->quant_type) { + case HNSW_QUANT_Q8: + quantize_to_q8(node->vector, quants, index->vector_dim, &node->quants_range); + break; + case HNSW_QUANT_BIN: + quantize_to_bin(node->vector, quants, index->vector_dim); + } + if (node->vector != vector) hfree(node->vector); + node->vector = quants; + } + return 1; +} + +/* Free the stack allocated node initialized by hnsw_init_tmp_node(). */ +void hnsw_free_tmp_node(hnswNode *node, const float *vector) { + if (node->vector != vector) hfree(node->vector); +} + +/* Return approximated K-NN items. Note that neighbors and distances + * arrays must have space for at least 'k' items. + * norm_query should be set to 1 if the query vector is already + * normalized, otherwise, if 0, the function will copy the vector, + * L2-normalize the copy and search using the normalized version. + * + * If the filter_privdata callback is passed, only elements passing the + * specified filter (invoked with privdata and the value associated + * to the node as arguments) are returned. In such case, if max_candidates + * is not NULL, it represents the maximum number of nodes to explore, since + * the search may be otherwise unbound if few or no elements pass the + * filter. */ +int hnsw_search_with_filter + (HNSW *index, const float *query_vector, uint32_t k, + hnswNode **neighbors, float *distances, uint32_t slot, + int query_vector_is_normalized, + int (*filter_callback)(void *value, void *privdata), + void *filter_privdata, uint32_t max_candidates) + +{ + if (!index || !query_vector || !neighbors || k == 0) return -1; + if (!index->enter_point) return 0; // Empty index. + + /* Use a fake node that holds the query vector, this way we can + * use our normal node to node distance functions when checking + * the distance between query and graph nodes. */ + hnswNode query; + if (hnsw_init_tmp_node(index,&query,query_vector_is_normalized,query_vector) == 0) return -1; + + // Start searching from the entry point. + hnswNode *curr_ep = index->enter_point; + + /* Start from higher layer to layer 1 (layer 0 is handled later) + * in the next section. Descend to the most similar node found + * so far. */ + for (int lc = index->max_level; lc > 0; lc--) { + pqueue *results = search_layer(index, &query, curr_ep, 1, lc, slot); + if (!results) continue; + + if (results->count > 0) { + curr_ep = pq_get_node(results,0); + } + pq_free(results); + } + + /* Search bottom layer (the most densely populated) with ef = k */ + pqueue *results = search_layer_with_filter( + index, &query, curr_ep, k, 0, slot, filter_callback, + filter_privdata, max_candidates); + if (!results) { + hnsw_free_tmp_node(&query, query_vector); + return -1; + } + + /* Copy results. */ + uint32_t found = MIN(k, results->count); + for (uint32_t i = 0; i < found; i++) { + neighbors[i] = pq_get_node(results,i); + if (distances) { + distances[i] = pq_get_distance(results,i); + } + } + + pq_free(results); + hnsw_free_tmp_node(&query, query_vector); + return found; +} + +/* Wrapper to hnsw_search_with_filter() when no filter is needed. */ +int hnsw_search(HNSW *index, const float *query_vector, uint32_t k, + hnswNode **neighbors, float *distances, uint32_t slot, + int query_vector_is_normalized) +{ + return hnsw_search_with_filter(index,query_vector,k,neighbors, + distances,slot,query_vector_is_normalized, + NULL,NULL,0); +} + +/* Rescan a node and update the wortst neighbor index. + * The followinng two functions are variants of this function to be used + * when links are added or removed: they may do less work than a full scan. */ +void hnsw_update_worst_neighbor(HNSW *index, hnswNode *node, uint32_t layer) { + float worst_dist = 0; + uint32_t worst_idx = 0; + for (uint32_t i = 0; i < node->layers[layer].num_links; i++) { + float dist = hnsw_distance(index, node, node->layers[layer].links[i]); + if (dist > worst_dist) { + worst_dist = dist; + worst_idx = i; + } + } + node->layers[layer].worst_distance = worst_dist; + node->layers[layer].worst_idx = worst_idx; +} + +/* Update node worst neighbor distance information when a new neighbor + * is added. */ +void hnsw_update_worst_neighbor_on_add(HNSW *index, hnswNode *node, uint32_t layer, uint32_t added_index, float distance) { + (void) index; // Unused but here for API symmetry. + if (node->layers[layer].num_links == 1 || // First neighbor? + distance > node->layers[layer].worst_distance) // New worst? + { + node->layers[layer].worst_distance = distance; + node->layers[layer].worst_idx = added_index; + } +} + +/* Update node worst neighbor distance information when a linked neighbor + * is removed. */ +void hnsw_update_worst_neighbor_on_remove(HNSW *index, hnswNode *node, uint32_t layer, uint32_t removed_idx) +{ + if (node->layers[layer].num_links == 0) { + node->layers[layer].worst_distance = 0; + node->layers[layer].worst_idx = 0; + } else if (removed_idx == node->layers[layer].worst_idx) { + hnsw_update_worst_neighbor(index,node,layer); + } else if (removed_idx < node->layers[layer].worst_idx) { + // Just update index if we removed element before worst. + node->layers[layer].worst_idx--; + } +} + +/* We have a list of candidate nodes to link to the new node, when inserting + * one. This function selects which nodes to link and performs the linking. + * + * Parameters: + * + * - 'candidates' is the priority queue of potential good nodes to link to the + * new node 'new_node'. + * - 'required_links' is as many links we would like our new_node to get + * at the specified layer. + * - 'aggressive' changes the strategy used to find good neighbors as follows: + * + * This function is called with aggressive=0 for all the layers, including + * layer 0. When called like that, it will use the diversity of links and + * quality of links checks before linking our new node with some candidate. + * + * However if the insert function finds that at layer 0, with aggressive=0, + * few connections were made, it calls this function again with aggressiveness + * levels greater up to 2. + * + * At aggressive=1, the diversity checks are disabled, and the candidate + * node for linking is accepted even if it is nearest to an already accepted + * neighbor than it is to the new node. + * + * When we link our new node by replacing the link of a candidate neighbor + * that already has the max number of links, inevitably some other node loses + * a connection (to make space for our new node link). In this case: + * + * 1. If such "dropped" node would remain with too little links, we try with + * some different neighbor instead, however as the 'aggressive' parameter + * has incremental values (0, 1, 2) we are more and more willing to leave + * the dropped node with fever connections. + * 2. If aggressive=2, we will scan the candidate neighbor node links to + * find a different linked-node to replace, one better connected even if + * its distance is not the worse. + * + * Note: this function is also called during deletion of nodes in order to + * provide certain nodes with additional links. + */ +void select_neighbors(HNSW *index, pqueue *candidates, hnswNode *new_node, + uint32_t layer, uint32_t required_links, int aggressive) +{ + for (uint32_t i = 0; i < candidates->count; i++) { + hnswNode *neighbor = pq_get_node(candidates,i); + if (neighbor == new_node) continue; // Don't link node with itself. + + /* Use our cached distance among the new node and the candidate. */ + float dist = pq_get_distance(candidates,i); + + /* First of all, since our links are all bidirectional, if the + * new node for any reason has no longer room, or if it accumulated + * the required number of links, return ASAP. */ + if (new_node->layers[layer].num_links >= new_node->layers[layer].max_links || + new_node->layers[layer].num_links >= required_links) return; + + /* If aggressive is true, it is possible that the new node + * already got some link among the candidates (see the top comment, + * this function gets re-called in case of too few links). + * So we need to check if this candidate is already linked to + * the new node. */ + if (aggressive) { + int duplicated = 0; + for (uint32_t j = 0; j < new_node->layers[layer].num_links; j++) { + if (new_node->layers[layer].links[j] == neighbor) { + duplicated = 1; + break; + } + } + if (duplicated) continue; + } + + /* Diversity check. We accept new candidates + * only if there is no element already accepted that is nearest + * to the candidate than the new element itself. + * However this check is disabled if we have pressure to find + * new links (aggressive != 0) */ + if (!aggressive) { + int diversity_failed = 0; + for (uint32_t j = 0; j < new_node->layers[layer].num_links; j++) { + float link_dist = hnsw_distance(index, neighbor, + new_node->layers[layer].links[j]); + if (link_dist < dist) { + diversity_failed = 1; + break; + } + } + if (diversity_failed) continue; + } + + /* If potential neighbor node has space, simply add the new link. + * We will have space as well. */ + uint32_t n = neighbor->layers[layer].num_links; + if (n < neighbor->layers[layer].max_links) { + /* Link candidate to new node. */ + neighbor->layers[layer].links[n] = new_node; + neighbor->layers[layer].num_links++; + + /* Update candidate worst link info. */ + hnsw_update_worst_neighbor_on_add(index,neighbor,layer,n,dist); + + /* Link new node to candidate. */ + uint32_t new_links = new_node->layers[layer].num_links; + new_node->layers[layer].links[new_links] = neighbor; + new_node->layers[layer].num_links++; + + /* Update new node worst link info. */ + hnsw_update_worst_neighbor_on_add(index,new_node,layer,new_links,dist); + continue; + } + + /* ==================================================================== + * Replacing existing candidate neighbor link step. + * ================================================================== */ + + /* If we are here, our accepted candidate for linking is full. + * + * If new node is more distant to candidate than its current worst link + * then we skip it: we would not be able to establish a bidirectional + * connection without compromising link quality of candidate. + * + * At aggressiveness > 0 we don't care about this check. */ + if (!aggressive && dist >= neighbor->layers[layer].worst_distance) + continue; + + /* We can add it: we are ready to replace the candidate neighbor worst + * link with the new node, assuming certain conditions are met. */ + hnswNode *worst_node = neighbor->layers[layer].links[neighbor->layers[layer].worst_idx]; + + /* The worst node linked to our candidate may remain too disconnected + * if we remove the candidate node as its link. Let's check if + * this is the case: */ + if (aggressive == 0 && + worst_node->layers[layer].num_links <= index->M/2) + continue; + + /* Aggressive level = 1. It's ok if the node remains with just + * HNSW_M/4 links. */ + else if (aggressive == 1 && + worst_node->layers[layer].num_links <= index->M/4) + continue; + + /* If aggressive is set to 2, then the new node we are adding failed + * to find enough neighbors. We can't insert an almost orphaned new + * node, so let's see if the target node has some other link + * that is well connected in the graph: we could drop it instead + * of the worst link. */ + if (aggressive == 2 && worst_node->layers[layer].num_links <= + index->M/4) + { + /* Let's see if we can find at least a candidate link that + * would remain with a few connections. Track the one + * that is the farthest away (worst distance) from our candidate + * neighbor (in order to remove the less interesting link). */ + worst_node = NULL; + uint32_t worst_idx = 0; + float max_dist = 0; + for (uint32_t j = 0; j < neighbor->layers[layer].num_links; j++) { + hnswNode *to_drop = neighbor->layers[layer].links[j]; + + /* Skip this if it would remain too disconnected as well. + * + * NOTE about index->M/4 min connections requirement: + * + * It is not too strict, since leaving a node with just a + * single link does not just leave it too weakly connected, but + * also sometimes creates cycles with few disconnected + * nodes linked among them. */ + if (to_drop->layers[layer].num_links <= index->M/4) continue; + + float link_dist = hnsw_distance(index, neighbor, to_drop); + if (worst_node == NULL || link_dist > max_dist) { + worst_node = to_drop; + max_dist = link_dist; + worst_idx = j; + } + } + + if (worst_node != NULL) { + /* We found a node that we can drop. Let's pretend this is + * the worst node of the candidate to unify the following + * code path. Later we will fix the worst node info anyway. */ + neighbor->layers[layer].worst_distance = max_dist; + neighbor->layers[layer].worst_idx = worst_idx; + } else { + /* Otherwise we have no other option than reallocating + * the max number of links for this target node, and + * ensure at least a few connections for our new node. */ + uint32_t reallocation_limit = layer == 0 ? + index->M * 3 : index->M *2; + if (neighbor->layers[layer].max_links >= reallocation_limit) + continue; + + uint32_t new_max_links = neighbor->layers[layer].max_links+1; + hnswNode **new_links = hrealloc(neighbor->layers[layer].links, + sizeof(hnswNode*) * new_max_links); + if (new_links == NULL) continue; // Non critical. + + /* Update neighbor's link capacity. */ + neighbor->layers[layer].links = new_links; + neighbor->layers[layer].max_links = new_max_links; + + /* Establish bidirectional link. */ + uint32_t n = neighbor->layers[layer].num_links; + neighbor->layers[layer].links[n] = new_node; + neighbor->layers[layer].num_links++; + hnsw_update_worst_neighbor_on_add(index, neighbor, layer, + n, dist); + + n = new_node->layers[layer].num_links; + new_node->layers[layer].links[n] = neighbor; + new_node->layers[layer].num_links++; + hnsw_update_worst_neighbor_on_add(index, new_node, layer, + n, dist); + continue; + } + } + + // Remove backlink from the worst node of our candidate. + for (uint64_t j = 0; j < worst_node->layers[layer].num_links; j++) { + if (worst_node->layers[layer].links[j] == neighbor) { + memmove(&worst_node->layers[layer].links[j], + &worst_node->layers[layer].links[j+1], + (worst_node->layers[layer].num_links - j - 1) * sizeof(hnswNode*)); + worst_node->layers[layer].num_links--; + hnsw_update_worst_neighbor_on_remove(index,worst_node,layer,j); + break; + } + } + + /* Replace worst link with the new node. */ + neighbor->layers[layer].links[neighbor->layers[layer].worst_idx] = new_node; + + /* Update the worst link in the target node, at this point + * the link that we replaced may no longer be the worst. */ + hnsw_update_worst_neighbor(index,neighbor,layer); + + // Add new node -> candidate link. + uint32_t new_links = new_node->layers[layer].num_links; + new_node->layers[layer].links[new_links] = neighbor; + new_node->layers[layer].num_links++; + + // Update new node worst link. + hnsw_update_worst_neighbor_on_add(index,new_node,layer,new_links,dist); + } +} + +/* This function implements node reconnection after a node deletion in HNSW. + * When a node is deleted, other nodes at the specified layer lose one + * connection (all the neighbors of the deleted node). This function attempts + * to pair such nodes together in a way that maximizes connection quality + * among the M nodes that were former neighbors of our deleted node. + * + * The algorithm works by first building a distance matrix among the nodes: + * + * N0 N1 N2 N3 + * N0 0 1.2 0.4 0.9 + * N1 1.2 0 0.8 0.5 + * N2 0.4 0.8 0 1.1 + * N3 0.9 0.5 1.1 0 + * + * For each potential pairing (i,j) we compute a score that combines: + * 1. The direct cosine distance between the two nodes + * 2. The average distance to other nodes that would no longer be + * available for pairing if we select this pair + * + * We want to balance local node-to-node requirements and global requirements. + * For instance sometimes connecting A with B, while optimal, would leave + * C and D to be connected without other choices, and this could be a very + * bad connection. Maybe instead A and C and B and D are both relatively high + * quality connections. + * + * The formula used to calculate the score of each connection is: + * + * score[i,j] = W1*(2-distance[i,j]) + W2*((new_avg_i + new_avg_j)/2) + * where new_avg_x is the average of distances in row x excluding distance[i,j] + * + * So the score is directly proportional to the SIMILARITY of the two nodes + * and also directly proportional to the DISTANCE of the potential other + * connections that we lost by pairign i,j. So we have a cost for missed + * opportunities, or better, in this case, a reward if the missing + * opportunities are not so good (big average distance). + * + * W1 and W2 are weights (defaults: 0.7 and 0.3) that determine the relative + * importance of immediate connection quality vs future pairing potential. + * + * After the initial pairing phase, any nodes that couldn't be paired + * (due to odd count or existing connections) are handled by searching + * the broader graph using the standard HNSW neighbor selection logic. + */ +void hnsw_reconnect_nodes(HNSW *index, hnswNode **nodes, int count, uint32_t layer) { + if (count <= 0) return; + debugmsg("Reconnecting %d nodes\n", count); + + /* Step 1: Build the distance matrix between all nodes. + * Since distance(i,j) = distance(j,i), we only compute the upper triangle + * and mirror it to the lower triangle. */ + float *distances = hmalloc((unsigned long) count * count * sizeof(float)); + if (!distances) return; + + for (int i = 0; i < count; i++) { + distances[i*count + i] = 0; // Distance to self is 0 + for (int j = i+1; j < count; j++) { + float dist = hnsw_distance(index, nodes[i], nodes[j]); + distances[i*count + j] = dist; // Upper triangle. + distances[j*count + i] = dist; // Lower triangle. + } + } + + /* Step 2: Calculate row averages (will be used in scoring): + * please note that we just calculate row averages and not + * columns averages since the matrix is symmetrical, so those + * are the same: check the image in the top comment if you have any + * doubt about this. */ + float *row_avgs = hmalloc(count * sizeof(float)); + if (!row_avgs) { + hfree(distances); + return; + } + + for (int i = 0; i < count; i++) { + float sum = 0; + int valid_count = 0; + for (int j = 0; j < count; j++) { + if (i != j) { + sum += distances[i*count + j]; + valid_count++; + } + } + row_avgs[i] = valid_count ? sum / valid_count : 0; + } + + /* Step 3: Build scoring matrix. What we do here is to combine how + * good is a given i,j nodes connection, with how badly connecting + * i,j will affect the remaining quality of connections left to + * pair the other nodes. */ + float *scores = hmalloc((unsigned long) count * count * sizeof(float)); + if (!scores) { + hfree(distances); + hfree(row_avgs); + return; + } + + /* Those weights were obtained manually... No guarantee that they + * are optimal. However with these values the algorithm is certain + * better than its greedy version that just attempts to pick the + * best pair each time (verified experimentally). */ + const float W1 = 0.7; // Weight for immediate distance. + const float W2 = 0.3; // Weight for future potential. + + for (int i = 0; i < count; i++) { + for (int j = 0; j < count; j++) { + if (i == j) { + scores[i*count + j] = -1; // Invalid pairing. + continue; + } + + // Check for existing connection between i and j. + int already_linked = 0; + for (uint32_t k = 0; k < nodes[i]->layers[layer].num_links; k++) + { + if (nodes[i]->layers[layer].links[k] == nodes[j]) { + scores[i*count + j] = -1; // Already linked. + already_linked = 1; + break; + } + } + if (already_linked) continue; + + float dist = distances[i*count + j]; + + /* Calculate new averages excluding this pair. + * Handle edge case where we might have too few elements. + * Note that it would be not very smart to recompute the average + * each time scanning the row, we can remove the element + * and adjust the average without it. */ + float new_avg_i = 0, new_avg_j = 0; + if (count > 2) { + new_avg_i = (row_avgs[i] * (count-1) - dist) / (count-2); + new_avg_j = (row_avgs[j] * (count-1) - dist) / (count-2); + } + + /* Final weighted score: the more similar i,j, the better + * the score. The more distant are the pairs we lose by + * connecting i,j, the better the score. */ + scores[i*count + j] = W1*(2-dist) + W2*((new_avg_i + new_avg_j)/2); + } + } + + // Step 5: Pair nodes greedily based on scores. + int *used = hmalloc(count*sizeof(int)); + memset(used,0,count*sizeof(int)); + if (!used) { + hfree(distances); + hfree(row_avgs); + hfree(scores); + return; + } + + /* Scan the matrix looking each time for the potential + * link with the best score. */ + while(1) { + float max_score = -1; + int best_j = -1, best_i = -1; + + // Seek best score i,j values. + for (int i = 0; i < count; i++) { + if (used[i]) continue; // Already connected. + + /* No space left? Not possible after a node deletion but makes + * this function more future-proof. */ + if (nodes[i]->layers[layer].num_links >= + nodes[i]->layers[layer].max_links) continue; + + for (int j = 0; j < count; j++) { + if (i == j) continue; // Same node, skip. + if (used[j]) continue; // Already connected. + float score = scores[i*count + j]; + if (score < 0) continue; // Invalid link. + + /* If the target node has space, and its score is better + * than any other seen so far... remember it is the best. */ + if (score > max_score && + nodes[j]->layers[layer].num_links < + nodes[j]->layers[layer].max_links) + { + // Track the best connection found so far. + max_score = score; + best_j = j; + best_i = i; + } + } + } + + // Possible link found? Connect i and j. + if (best_j != -1) { + debugmsg("[%d] linking %d with %d: %f\n", layer, (int)best_i, (int)best_j, max_score); + // Link i -> j. + int link_idx = nodes[best_i]->layers[layer].num_links; + nodes[best_i]->layers[layer].links[link_idx] = nodes[best_j]; + nodes[best_i]->layers[layer].num_links++; + + // Update worst distance if needed. + float dist = distances[best_i*count + best_j]; + hnsw_update_worst_neighbor_on_add(index,nodes[best_i],layer,link_idx,dist); + + // Link j -> i. + link_idx = nodes[best_j]->layers[layer].num_links; + nodes[best_j]->layers[layer].links[link_idx] = nodes[best_i]; + nodes[best_j]->layers[layer].num_links++; + + // Update worst distance if needed. + hnsw_update_worst_neighbor_on_add(index,nodes[best_j],layer,link_idx,dist); + + // Mark connection as used. + used[best_i] = used[best_j] = 1; + } else { + break; // No more valid connections available. + } + } + + /* Step 6: Handle remaining unpaired nodes using the standard HNSW + * neighbor selection. */ + for (int i = 0; i < count; i++) { + if (used[i]) continue; + + // Skip if node is already at max connections. + if (nodes[i]->layers[layer].num_links >= + nodes[i]->layers[layer].max_links) + continue; + + debugmsg("[%d] Force linking %d\n", layer, i); + + /* First, try with local nodes as candidates. + * Some candidate may have space. */ + pqueue *candidates = pq_new(count); + if (!candidates) continue; + + /* Add all the local nodes having some space as candidates + * to be linked with this node. */ + for (int j = 0; j < count; j++) { + if (i != j && // Must not be itself. + nodes[j]->layers[layer].num_links < // Must not be full. + nodes[j]->layers[layer].max_links) + { + float dist = distances[i*count + j]; + pq_push(candidates, nodes[j], dist); + } + } + + /* Try local candidates first with aggressive = 1. + * So we will link only if there is space. + * We want one link more than the links we already have. */ + uint32_t wanted_links = nodes[i]->layers[layer].num_links+1; + if (candidates->count > 0) { + select_neighbors(index, candidates, nodes[i], layer, + wanted_links, 1); + debugmsg("Final links after attempt with local nodes: %d (wanted: %d)\n", (int)nodes[i]->layers[layer].num_links, wanted_links); + } + + // If still no connection, search the broader graph. + if (nodes[i]->layers[layer].num_links != wanted_links) { + debugmsg("No force linking possible with local candidates\n"); + pq_free(candidates); + + // Find entry point for target layer by descending through levels. + hnswNode *curr_ep = index->enter_point; + for (uint32_t lc = index->max_level; lc > layer; lc--) { + pqueue *results = search_layer(index, nodes[i], curr_ep, 1, lc, 0); + if (results) { + if (results->count > 0) { + curr_ep = pq_get_node(results,0); + } + pq_free(results); + } + } + + if (curr_ep) { + /* Search this layer for candidates. + * Use the default EF_C in this case, since it's not an + * "insert" operation, and we don't know the user + * specified "EF". */ + candidates = search_layer(index, nodes[i], curr_ep, HNSW_EF_C, layer, 0); + if (candidates) { + /* Try to connect with aggressiveness proportional to the + * node linking condition. */ + int aggressiveness = + (nodes[i]->layers[layer].num_links > index->M / 2) + ? 1 : 2; + select_neighbors(index, candidates, nodes[i], layer, + wanted_links, aggressiveness); + debugmsg("Final links with broader search: %d (wanted: %d)\n", (int)nodes[i]->layers[layer].num_links, wanted_links); + pq_free(candidates); + } + } + } else { + pq_free(candidates); + } + } + + // Cleanup. + hfree(distances); + hfree(row_avgs); + hfree(scores); + hfree(used); +} + +/* This is an helper function in order to support node deletion. + * It's goal is just to: + * + * 1. Remove the node from the bidirectional links of neighbors in the graph. + * 2. Remove the node from the linked list of nodes. + * 3. Fix the entry point in the graph. We just select one of the neighbors + * of the deleted node at a lower level. If none is found, we do + * a full scan. + * 4. The node itself amd its aux value field are NOT freed. It's up to the + * caller to do it, by using hnsw_node_free(). + * 5. The node associated value (node->value) is NOT freed. + * + * Why this function will not free the node? Because in node updates it + * could be a good idea to reuse the node allocation for different reasons + * (currently not implemented). + * In general it is more future-proof to be able to reuse the node if + * needed. Right now this library reuses the node only when links are + * not touched (see hnsw_update() for more information). */ +void hnsw_unlink_node(HNSW *index, hnswNode *node) { + if (!index || !node) return; + + index->version++; // This node may be missing in an already compiled list + // of neighbors. Make optimistic concurrent inserts fail. + + /* Remove all bidirectional links at each level. + * Note that in this implementation all the + * links are guaranteed to be bedirectional. */ + + /* For each level of the deleted node... */ + for (uint32_t level = 0; level <= node->level; level++) { + /* For each linked node of the deleted node... */ + for (uint32_t i = 0; i < node->layers[level].num_links; i++) { + hnswNode *linked = node->layers[level].links[i]; + /* Find and remove the backlink in the linked node */ + for (uint32_t j = 0; j < linked->layers[level].num_links; j++) { + if (linked->layers[level].links[j] == node) { + /* Remove by shifting remaining links left */ + memmove(&linked->layers[level].links[j], + &linked->layers[level].links[j + 1], + (linked->layers[level].num_links - j - 1) * sizeof(hnswNode*)); + linked->layers[level].num_links--; + hnsw_update_worst_neighbor_on_remove(index,linked,level,j); + break; + } + } + } + } + + /* Update cursors pointing at this element. */ + if (index->cursors) hnsw_cursor_element_deleted(index,node); + + /* Update the previous node's next pointer. */ + if (node->prev) { + node->prev->next = node->next; + } else { + /* If there's no previous node, this is the head. */ + index->head = node->next; + } + + /* Update the next node's prev pointer. */ + if (node->next) node->next->prev = node->prev; + + /* Update node count. */ + index->node_count--; + + /* If this node was the enter_point, we need to update it. */ + if (node == index->enter_point) { + /* Reset entry point - we'll find a new one (unless the HNSW is + * now empty) */ + index->enter_point = NULL; + index->max_level = 0; + + /* Step 1: Try to find a replacement by scanning levels + * from top to bottom. Under normal conditions, if there is + * any other node at the same level, we have a link. Anyway + * we descend levels to find any neighbor at the higher level + * possible. */ + for (int level = node->level; level >= 0; level--) { + if (node->layers[level].num_links > 0) { + index->enter_point = node->layers[level].links[0]; + break; + } + } + + /* Step 2: If no links were found at any level, do a full scan. + * This should never happen in practice if the HNSW is not + * empty. */ + if (!index->enter_point) { + uint32_t new_max_level = 0; + hnswNode *current = index->head; + + while (current) { + if (current != node && current->level >= new_max_level) { + new_max_level = current->level; + index->enter_point = current; + } + current = current->next; + } + } + + /* Update max_level. */ + if (index->enter_point) + index->max_level = index->enter_point->level; + } + + /* Clear the node's links but don't free the node itself */ + node->prev = node->next = NULL; +} + +/* Higher level API for hnsw_unlink_node() + hnsw_reconnect_nodes() actual work. + * This will get the write lock, will delete the node, free it, + * reconnect the node neighbors among themselves, and unlock again. + * If free_value function pointer is not NULL, then the function provided is + * used to free node->value. + * + * The function returns 0 on error (inability to acquire the lock), otherwise + * 1 is returned. */ +int hnsw_delete_node(HNSW *index, hnswNode *node, void(*free_value)(void*value)) { + if (pthread_rwlock_wrlock(&index->global_lock) != 0) return 0; + hnsw_unlink_node(index,node); + if (free_value && node->value) free_value(node->value); + + /* Relink all the nodes orphaned of this node link. + * Do it for all the levels. */ + for (unsigned int j = 0; j <= node->level; j++) { + hnsw_reconnect_nodes(index, node->layers[j].links, + node->layers[j].num_links, j); + } + hnsw_node_free(node); + pthread_rwlock_unlock(&index->global_lock); + return 1; +} + +/* ============================ Threaded API ================================ + * Concurrent readers should use the following API to get a slot assigned + * (and a lock, too), do their read-only call, and unlock the slot. + * + * There is a reason why read operations don't implement opaque transparent + * locking directly on behalf of the user: when we return a result set + * with hnsw_search(), we report a set of nodes. The caller will do something + * with the nodes and the associated values, so the unlocking of the + * slot should happen AFTER the result was already used, otherwise we may + * have changes to the HNSW nodes as the result is being accessed. */ + +/* Try to acquire a read slot. Returns the slot number (0 to HNSW_MAX_THREADS-1) + * on success, -1 on error (pthread mutex errors). */ +int hnsw_acquire_read_slot(HNSW *index) { + /* First try a non-blocking approach on all slots. */ + for (uint32_t i = 0; i < HNSW_MAX_THREADS; i++) { + if (pthread_mutex_trylock(&index->slot_locks[i]) == 0) { + if (pthread_rwlock_rdlock(&index->global_lock) != 0) { + pthread_mutex_unlock(&index->slot_locks[i]); + return -1; + } + return i; + } + } + + /* All trylock attempts failed, use atomic increment to select slot. */ + uint32_t slot = index->next_slot++ % HNSW_MAX_THREADS; + + /* Try to lock the selected slot. */ + if (pthread_mutex_lock(&index->slot_locks[slot]) != 0) return -1; + + /* Get read lock. */ + if (pthread_rwlock_rdlock(&index->global_lock) != 0) { + pthread_mutex_unlock(&index->slot_locks[slot]); + return -1; + } + + return slot; +} + +/* Release a previously acquired read slot: note that it is important that + * nodes returned by hnsw_search() are accessed while the read lock is + * still active, to be sure that nodes are not freed. */ +void hnsw_release_read_slot(HNSW *index, int slot) { + if (slot < 0 || slot >= HNSW_MAX_THREADS) return; + pthread_rwlock_unlock(&index->global_lock); + pthread_mutex_unlock(&index->slot_locks[slot]); +} + +/* ============================ Nodes insertion ============================= + * We have an optimistic API separating the read-only candidates search + * and the write side (actual node insertion). We internally also use + * this API to provide the plain hnsw_insert() function for code unification. */ + +struct InsertContext { + pqueue *level_queues[HNSW_MAX_LEVEL]; /* Candidates for each level. */ + hnswNode *node; /* Pre-allocated node ready for insertion */ + uint64_t version; /* Index version at preparation time. This is used + * for CAS-like locking during change commit. */ +}; + +/* Optimistic insertion API. + * + * WARNING: Note that this is an internal function: users should call + * hnsw_prepare_insert() instead. + * + * This is how it works: you use hnsw_prepare_insert() and it will return + * a context where good candidate neighbors are already pre-selected. + * This step only uses read locks. + * + * Then finally you try to actually commit the new node with + * hnsw_try_commit_insert(): this time we will require a write lock, but + * for less time than it would be otherwise needed if using directly + * hnsw_insert(). When you try to commit the write, if no node was deleted in + * the meantime, your operation will succeed, otherwise it will fail, and + * you should try to just use the hnsw_insert() API, since there is + * contention. + * + * See hnsw_node_new() for information about 'vector' and 'qvector' + * arguments, and which one to pass. */ +InsertContext *hnsw_prepare_insert_nolock(HNSW *index, const float *vector, + const int8_t *qvector, float qrange, uint64_t id, + int slot, int ef) +{ + InsertContext *ctx = hmalloc(sizeof(*ctx)); + if (!ctx) return NULL; + + memset(ctx, 0, sizeof(*ctx)); + ctx->version = index->version; + + /* Crete a new node that we may be able to insert into the + * graph later, when calling the commit function. */ + uint32_t level = random_level(); + ctx->node = hnsw_node_new(index, id, vector, qvector, qrange, level, 1); + if (!ctx->node) { + hfree(ctx); + return NULL; + } + + hnswNode *curr_ep = index->enter_point; + + /* Empty graph, no need to collect candidates. */ + if (curr_ep == NULL) return ctx; + + /* Phase 1: Find good entry point on the highest level of the new + * node we are going to insert. */ + for (unsigned int lc = index->max_level; lc > level; lc--) { + pqueue *results = search_layer(index, ctx->node, curr_ep, 1, lc, slot); + + if (results) { + if (results->count > 0) curr_ep = pq_get_node(results,0); + pq_free(results); + } + } + + /* Phase 2: Collect a set of potential connections for each layer of + * the new node. */ + for (int lc = MIN(level, index->max_level); lc >= 0; lc--) { + pqueue *candidates = + search_layer(index, ctx->node, curr_ep, ef, lc, slot); + + if (!candidates) continue; + curr_ep = (candidates->count > 0) ? pq_get_node(candidates,0) : curr_ep; + ctx->level_queues[lc] = candidates; + } + + return ctx; +} + +/* External API for hnsw_prepare_insert_nolock(), handling locking. */ +InsertContext *hnsw_prepare_insert(HNSW *index, const float *vector, + const int8_t *qvector, float qrange, uint64_t id, + int ef) +{ + InsertContext *ctx; + int slot = hnsw_acquire_read_slot(index); + ctx = hnsw_prepare_insert_nolock(index,vector,qvector,qrange,id,slot,ef); + hnsw_release_read_slot(index,slot); + return ctx; +} + +/* Free an insert context and all its resources. */ +void hnsw_free_insert_context(InsertContext *ctx) { + if (!ctx) return; + for (uint32_t i = 0; i < HNSW_MAX_LEVEL; i++) { + if (ctx->level_queues[i]) pq_free(ctx->level_queues[i]); + } + if (ctx->node) hnsw_node_free(ctx->node); + hfree(ctx); +} + +/* Commit a prepared insert operation. This function is a low level API that + * should not be called by the user. See instead hnsw_try_commit_insert(), that + * will perform the CAS check and acquire the write lock. + * + * See the top comment in hnsw_prepare_insert() for more information + * on the optimistic insertion API. + * + * This function can't fail and always returns the pointer to the + * just inserted node. Out of memory is not possible since no critical + * allocation is never performed in this code path: we populate links + * on already allocated nodes. */ +hnswNode *hnsw_commit_insert_nolock(HNSW *index, InsertContext *ctx, void *value) { + hnswNode *node = ctx->node; + node->value = value; + + /* Handle first node case. */ + if (index->enter_point == NULL) { + index->version++; // First node, make concurrent inserts fail. + index->enter_point = node; + index->max_level = node->level; + hnsw_add_node(index, node); + ctx->node = NULL; // So hnsw_free_insert_context() will not free it. + hnsw_free_insert_context(ctx); + return node; + } + + /* Connect the node with near neighbors at each level. */ + for (int lc = MIN(node->level,index->max_level); lc >= 0; lc--) { + if (ctx->level_queues[lc] == NULL) continue; + + /* Try to provide index->M connections to our node. The call + * is not guaranteed to be able to provide all the links we would + * like to have for the new node: they must be bi-directional, obey + * certain quality checks, and so forth, so later there are further + * calls to force the hand a bit if needed. + * + * Let's start with aggressiveness = 0. */ + select_neighbors(index, ctx->level_queues[lc], node, lc, index->M, 0); + + /* Layer 0 and too few connections? Let's be more aggressive. */ + if (lc == 0 && node->layers[0].num_links < index->M/2) { + select_neighbors(index, ctx->level_queues[lc], node, lc, + index->M, 1); + + /* Still too few connections? Let's go to + * aggressiveness level '2' in linking strategy. */ + if (node->layers[0].num_links < index->M/4) { + select_neighbors(index, ctx->level_queues[lc], node, lc, + index->M/4, 2); + } + } + } + + /* If new node level is higher than current max, update entry point. */ + if (node->level > index->max_level) { + index->version++; // Entry point changed, make concurrent inserts fail. + index->enter_point = node; + index->max_level = node->level; + } + + /* Add node to the linked list. */ + hnsw_add_node(index, node); + ctx->node = NULL; // So hnsw_free_insert_context() will not free the node. + hnsw_free_insert_context(ctx); + return node; +} + +/* If the context obtained with hnsw_prepare_insert() is still valid + * (nodes not deleted in the meantime) then add the new node to the HNSW + * index and return its pointer. Otherwise NULL is returned and the operation + * should be either performed with the blocking API hnsw_insert() or attempted + * again. */ +hnswNode *hnsw_try_commit_insert(HNSW *index, InsertContext *ctx, void *value) { + /* Check if the version changed since preparation. Note that we + * should access index->version under the write lock in order to + * be sure we can safely commit the write: this is just a fast-path + * in order to return ASAP without acquiring the write lock in case + * the version changed. */ + if (ctx->version != index->version) { + hnsw_free_insert_context(ctx); + return NULL; + } + + /* Try to acquire write lock. */ + if (pthread_rwlock_wrlock(&index->global_lock) != 0) { + hnsw_free_insert_context(ctx); + return NULL; + } + + /* Check version again under write lock. */ + if (ctx->version != index->version) { + pthread_rwlock_unlock(&index->global_lock); + hnsw_free_insert_context(ctx); + return NULL; + } + + /* Commit the change: note that it's up to hnsw_commit_insert_nolock() + * to free the insertion context. */ + hnswNode *node = hnsw_commit_insert_nolock(index, ctx, value); + + /* Release the write lock. */ + pthread_rwlock_unlock(&index->global_lock); + return node; +} + +/* Insert a new element into the graph. + * See hnsw_node_new() for information about 'vector' and 'qvector' + * arguments, and which one to pass. + * + * Return NULL on out of memory during insert. Otherwise the newly + * inserted node pointer is returned. */ +hnswNode *hnsw_insert(HNSW *index, const float *vector, const int8_t *qvector, float qrange, uint64_t id, void *value, int ef) { + /* Write lock. We acquire the write lock even for the prepare() + * operation (that is a read-only operation) since we want this function + * to don't fail in the check-and-set stage of commit(). + * + * Basically here we are using the optimistic API in a non-optimistinc + * way in order to have a single insertion code in the implementation. */ + if (pthread_rwlock_wrlock(&index->global_lock) != 0) return NULL; + + // Prepare the insertion - note we pass slot 0 since we're single threaded. + InsertContext *ctx = hnsw_prepare_insert_nolock(index, vector, qvector, + qrange, id, 0, ef); + if (!ctx) { + pthread_rwlock_unlock(&index->global_lock); + return NULL; + } + + // Commit the prepared insertion without version checking. + hnswNode *node = hnsw_commit_insert_nolock(index, ctx, value); + + // Release write lock and return our node pointer. + pthread_rwlock_unlock(&index->global_lock); + return node; +} + +/* Helper function for qsort call in hnsw_should_reuse_node(). */ +static int compare_floats(const float *a, const float *b) { + if (*a < *b) return 1; + if (*a > *b) return -1; + return 0; +} + +/* This function determines if a node can be reused with a new vector by: + * + * 1. Computing average of worst 25% of current distances. + * 2. Checking if at least 50% of new distances stay below this threshold. + * 3. Requiring a minimum number of links for the check to be meaningful. + * + * This check is useful when we want to just update a node that already + * exists in the graph. Often the new vector is a learned embedding generated + * by some model, and the embedding represents some document that perhaps + * changed just slightly compared to the past, so the new embedding will + * be very nearby. We need to find a way do determine if the current node + * neighbors (practically speaking its location in the grapb) are good + * enough even with the new vector. + * + * XXX: this function needs improvements: successive updates to the same + * node with more and more distant vectors will make the node drift away + * from its neighbors. One of the additional metrics used could be + * neighbor-to-neighbor distance, that represents a more absolute check + * of fit for the new vector. */ +int hnsw_should_reuse_node(HNSW *index, hnswNode *node, int is_normalized, const float *new_vector) { + /* Step 1: Not enough links? Advice to avoid reuse. */ + const uint32_t min_links_for_reuse = 4; + uint32_t layer0_connections = node->layers[0].num_links; + if (layer0_connections < min_links_for_reuse) return 0; + + /* Step2: get all current distances and run our heuristic. */ + float *old_distances = hmalloc(sizeof(float) * layer0_connections); + if (!old_distances) return 0; + + // Temporary node with the new vector, to simplify the next logic. + hnswNode tmp_node; + if (hnsw_init_tmp_node(index,&tmp_node,is_normalized,new_vector) == 0) { + hfree(old_distances); + return 0; + } + + /* Get old dinstances and sort them to access the 25% worst + * (bigger) ones. */ + for (uint32_t i = 0; i < layer0_connections; i++) { + old_distances[i] = hnsw_distance(index, node, node->layers[0].links[i]); + } + qsort(old_distances, layer0_connections, sizeof(float), + (int (*)(const void*, const void*))(&compare_floats)); + + uint32_t count = (layer0_connections+3)/4; // 25% approx to larger int. + if (count > layer0_connections) count = layer0_connections; // Futureproof. + float worst_avg = 0; + + // Compute average of 25% worst dinstances. + for (uint32_t i = 0; i < count; i++) worst_avg += old_distances[i]; + worst_avg /= count; + hfree(old_distances); + + // Count how many new distances stay below the threshold. + uint32_t good_distances = 0; + for (uint32_t i = 0; i < layer0_connections; i++) { + float new_dist = hnsw_distance(index, &tmp_node, node->layers[0].links[i]); + if (new_dist <= worst_avg) good_distances++; + } + hnsw_free_tmp_node(&tmp_node,new_vector); + + /* At least 50% of the nodes should pass our quality test, for the + * node to be reused. */ + return good_distances >= layer0_connections/2; +} + +/** + * Return a random node from the HNSW graph. + * + * This function performs a random walk starting from the entry point, + * using only level 0 connections for navigation. It uses log^2(N) steps + * to ensure proper mixing time. + */ + +hnswNode *hnsw_random_node(HNSW *index, int slot) { + if (index->node_count == 0 || index->enter_point == NULL) + return NULL; + + (void)slot; // Unused, but we need the caller to acquire the lock. + + /* First phase: descend from max level to level 0 taking random paths. + * Note that we don't need a more conservative log^2(N) steps for + * proper mixing, since we already descend to a random cluster here. */ + hnswNode *current = index->enter_point; + for (uint32_t level = index->max_level; level > 0; level--) { + /* If current node doesn't have this level or no links, continue + * to lower level. */ + if (current->level < level || current->layers[level].num_links == 0) + continue; + + /* Choose random neighbor at this level. */ + uint32_t rand_neighbor = rand() % current->layers[level].num_links; + current = current->layers[level].links[rand_neighbor]; + } + + /* Second phase: at level 0, take log(N) * c random steps. */ + const int c = 3; // Multiplier for more thorough exploration. + double logN = log2(index->node_count + 1); + uint32_t num_walks = (uint32_t)(logN * c); + + /* Avoid the ping-pong effect: imagine there are just two nodes and + * the number of walks selected is even. We will select always the + * first element of the graph; conversely, if it is odd, we will always + * select the other element. One way to add more selection randomness is + * to randomly add '1' or '0' to the number of walks to perform. */ + num_walks += rand() & 1; + + // Perform random walk at level 0. + for (uint32_t i = 0; i < num_walks; i++) { + if (current->layers[0].num_links == 0) return current; + + // Choose random neighbor. + uint32_t rand_neighbor = rand() % current->layers[0].num_links; + current = current->layers[0].links[rand_neighbor]; + } + return current; +} + +/* ============================= Serialization ============================== + * + * TO SERIALIZE + * ============ + * + * To serialize on disk, you need to persist the vector dimension, number + * of elements, and the quantization type index->quant_type. These are + * global values for the whole index. + * + * Then, to serialize each node: + * + * call hnsw_serialize_node() with each node you find in the linked list + * of nodes, starting at index->head (each node has a next pointer). + * The function will return an hnswSerNode structure, you will need + * to store the following on disk (for each node): + * + * - The sernode->vector data, that is sernode->vector_size bytes. + * - The sernode->params array, that points to an array of uint64_t + * integers. There are sernode->params_count total items. These + * parameters contain everything there is to need about your node: how + * many levels it has, its ID, the list of neighbors for each level (as node + * IDs), and so forth. + * + * You need to to save your own node->value in some way as well, but it already + * belongs to the user of the API, since, for this library, it's just a pointer, + * so the user should know how to serialized its private data. + * + * RELOADING FROM DISK / NET + * ========================= + * + * When reloading nodes, you first load the index vector dimension and + * quantization type, and create the index with: + * + * HNSW *hnsw_new(uint32_t vector_dim, uint32_t quant_type); + * + * Then you load back, for each node (you stored how many nodes you had) + * the vector and the params array / count. + * You also load the value associated with your node. + * + * At this point you add back the loaded elements into the index with: + * + * hnsw_insert_serialized(HNSW *index, void *vector, uint64_t params, + * uint32_t params_len, void *value); + * + * Once you added all the nodes back, you need to resolve the pointers + * (since so far they are added just with the node IDs as reference), so + * you call: + * + * hnsw_deserialize_index(index); + * + * The index is now ready to be used like if it has been always in memory. + * + * DESIGN NOTES + * ============ + * + * Why this API does not just give you a binary blob to save? Because in + * many systems (and in Redis itself) to save integers / floats can have + * more interesting encodings that just storing a 64 bit value. Many vector + * indexes will be small, and their IDs will be small numbers, so the storage + * system can exploit that and use less disk space, less network bandwidth + * and so forth. + * + * How is the data stored in these arrays of numbers? Oh well, we have + * things that are obviously numbers like node ID, number of levels for the + * node and so forth. Also each of our nodes have an unique incremental ID, + * so we can store a node set of links in terms of linked node IDs. This + * data is put directly in the loaded node pointer space! We just cast the + * integer to the pointer (so THIS IS NOT SAFE for 32 bit systems). Then + * we want to translate such IDs into pointers. To do that, we build an + * hash table, then scan all the nodes again and fix all the links converting + * the ID to the pointer. */ + +/* History of serialization versions: + * version 0: the first implementation, lacking worst node id/info. + * version 1: includes worst link id/info. */ +#define HNSW_SERIALIZATION_VERSION 1 + +/* This is a special worst link index that is set when loading a serialized + * node with version 0 (this version of the serialization lacked explicit + * information about the worst link index/distance). This way, later, the + * function that fixes a deserialized index will know to compute the worst + * index info at runtime. */ +#define HNSW_SER_WORSTLINK_MISSING UINT32_MAX + +/* Return the serialized node information as specified in the top comment + * above. Note that the returned information is true as long as the node + * provided is not deleted or modified, so this function should be called + * when there are no concurrent writes. + * + * The function hnsw_serialize_node() should be called in order to + * free the result of this function. */ +hnswSerNode *hnsw_serialize_node(HNSW *index, hnswNode *node) { + /* The first step is calculating the number of uint64_t parameters + * that we need in order to serialize the node. */ + uint32_t num_params = 0; + num_params += 2; // node ID, number of layers. + for (uint32_t i = 0; i <= node->level; i++) { + num_params += 2; // max_links and num_links info for this layer. + num_params += node->layers[i].num_links; // The IDs of linked nodes. + num_params += 1; // worst link id/distance parameter. + } + + /* We use another 64bit value to store two floats that are about + * the vector: l2 and quantization range (that is only used if the + * vector is quantized). */ + num_params++; + + /* Allocate the return object and the parameters array. */ + hnswSerNode *sn = hmalloc(sizeof(hnswSerNode)); + if (sn == NULL) return NULL; + sn->params = hmalloc(sizeof(uint64_t)*num_params); + if (sn->params == NULL) { + hfree(sn); + return NULL; + } + + /* Fill data. */ + sn->params_count = num_params; + sn->vector = node->vector; + sn->vector_size = hnsw_quants_bytes(index); + + uint32_t param_idx = 0; + sn->params[param_idx++] = node->id; + /* The second parameter contains information about the serialization + * version of this node, the node level and some unused field: + * + * +--------+--------+--------+--------+ + * |VVVVVVVV|........|........|LLLLLLLL| + * +--------+--------+--------+--------+ + * + * V is the version, 8 bits. + * L is the node level, 8 bits (but actually 16 is the max so far). + * The middle two bytes are reserved for future uses. */ + sn->params[param_idx] = node->level & 0xff; + sn->params[param_idx] |= HNSW_SERIALIZATION_VERSION << 24; + param_idx++; + for (uint32_t i = 0; i <= node->level; i++) { + sn->params[param_idx++] = node->layers[i].num_links; + sn->params[param_idx++] = node->layers[i].max_links; + for (uint32_t j = 0; j < node->layers[i].num_links; j++) { + sn->params[param_idx++] = node->layers[i].links[j]->id; + } + /* Since version 1: pack and store worst_idx and worst_distance. */ + uint32_t worst_distance_bits; + memcpy(&worst_distance_bits, &node->layers[i].worst_distance, + sizeof(float)); + uint64_t wi = + (((uint64_t)worst_distance_bits) << 32) | node->layers[i].worst_idx; + sn->params[param_idx++] = wi; + } + + /* Store l2 and range as uint32_t, in a way that is endian-safe. + * Note that in big endian archs both are reversed: integers and + * also the bytes of floats, so they will match. */ + uint64_t l2_and_range; + uint32_t l2_bits, range_bits; + memcpy(&l2_bits,&node->l2,sizeof(float)); + memcpy(&range_bits,&node->quants_range,sizeof(float)); + l2_and_range = ((uint64_t)range_bits<<32) | l2_bits; + + sn->params[param_idx++] = l2_and_range; + + /* Better safe than sorry: */ + assert(param_idx == num_params); + return sn; +} + +/* This is needed in order to free hnsw_serialize_node() returned + * structure. */ +void hnsw_free_serialized_node(hnswSerNode *sn) { + hfree(sn->params); + hfree(sn); +} + +/* Load a serialized node. See the top comment in this section of code + * for the documentation about how to use this. + * + * The function returns NULL both on out of memory and if the remaining + * parameters length does not match the number of links or other items + * to load. */ +hnswNode *hnsw_insert_serialized(HNSW *index, void *vector, uint64_t *params, uint32_t params_len, void *value) +{ + if (params_len < 2) return NULL; + + uint64_t id = params[0]; + /* Check the node serialization function for the specific layout + * of param[1] fields. */ + uint32_t level = params[1] & 0xff; // Node level. + uint32_t version = (params[1] & 0xff000000) >> 24; // Format version. + + if (version > HNSW_SERIALIZATION_VERSION) return NULL; + int has_worst_link_info = version > 0; + + /* Keep track of maximum ID seen while loading. */ + if (id >= index->last_id) index->last_id = id; + + /* Create node, passing vector data directly based on quantization type. */ + hnswNode *node; + if (index->quant_type != HNSW_QUANT_NONE) { + node = hnsw_node_new(index, id, NULL, vector, 0, level, 0); + } else { + node = hnsw_node_new(index, id, vector, NULL, 0, level, 0); + } + if (!node) return NULL; + + /* Load params array into the node. */ + uint32_t param_idx = 2; + for (uint32_t i = 0; i <= level; i++) { + /* Sanity check. */ + if (param_idx + 2 + has_worst_link_info > params_len) { + hnsw_node_free(node); + return NULL; + } + + uint32_t num_links = params[param_idx++]; + uint32_t max_links = params[param_idx++]; + + /* Sanity check: links should be less than max links and + * in general a reasonable amount. */ + if (num_links > max_links || max_links > HNSW_MAX_M*4) { + hnsw_node_free(node); + return NULL; + } + + /* If max_links is larger than current allocation, reallocate. + * It could happen in select_neighbors() that we over-allocate the + * node under very unlikely to happen conditions. */ + if (max_links > node->layers[i].max_links) { + hnswNode **new_links = hrealloc(node->layers[i].links, + sizeof(hnswNode*) * max_links); + if (!new_links) { + hnsw_node_free(node); + return NULL; + } + node->layers[i].links = new_links; + node->layers[i].max_links = max_links; + } + node->layers[i].num_links = num_links; + + /* Sanity check. */ + if (param_idx + num_links + has_worst_link_info > params_len) { + hnsw_node_free(node); + return NULL; + } + + /* Fill links for this layer with the IDs. Note that this + * is going to not work in 32 bit systems. Deleting / adding-back + * nodes can produce IDs larger than 2^32-1 even if we can't never + * fit more than 2^32 nodes in a 32 bit system. */ + for (uint32_t j = 0; j < num_links; j++) + node->layers[i].links[j] = (hnswNode*)params[param_idx++]; + + if (has_worst_link_info) { + uint64_t wi = params[param_idx++]; + uint32_t worst_idx = wi & 0xffffffff; + uint32_t worst_distance_bits = wi >> 32; + float worst_distance; + memcpy(&worst_distance,&worst_distance_bits,sizeof(float)); + node->layers[i].worst_idx = worst_idx; + node->layers[i].worst_distance = worst_distance; + + // Sanity check the worst ID range. + if (node->layers[i].num_links > 0 && + node->layers[i].worst_idx >= node->layers[i].num_links) + { + hnsw_node_free(node); + return NULL; + } + } else { + node->layers[i].worst_idx = HNSW_SER_WORSTLINK_MISSING; + node->layers[i].worst_distance = 0; + } + } + + /* Get l2 and quantization range. */ + if (param_idx >= params_len) { + hnsw_node_free(node); + return NULL; + } + + /* Load l2 and range packed into an uint64_t in an endian safe way. */ + uint64_t l2_and_range = params[param_idx]; + uint32_t l2_bits, range_bits; + l2_bits = l2_and_range & 0xffffffff; + range_bits = l2_and_range >> 32; + memcpy(&node->l2, &l2_bits, sizeof(float)); + memcpy(&node->quants_range, &range_bits, sizeof(float)); + + node->value = value; + hnsw_add_node(index, node); + + /* Keep track of higher node level and set the entry point to the + * greatest level node seen so far: thanks to this check we don't + * need to remember what our entry point was during serialization. */ + if (index->enter_point == NULL || level > index->max_level) { + index->max_level = level; + index->enter_point = node; + } + return node; +} + +/* Integer hashing, used by hnsw_deserialize_index(). + * MurmurHash3's 64-bit finalizer function. */ +uint64_t hnsw_hash_node_id(uint64_t id) { + id ^= id >> 33; + id *= 0xff51afd7ed558ccd; + id ^= id >> 33; + id *= 0xc4ceb9fe1a85ec53; + id ^= id >> 33; + return id; +} + +/* Helper for duplicated link detection in hnsw_deserialize_index(). */ +static int qsort_compare_pointers(const void *aptr, const void *bptr) { + uintptr_t a = *((uintptr_t*)aptr); + uintptr_t b = *((uintptr_t*)bptr); + if (a > b) return 1; + if (a < b) return -1; + return 0; +} + +/* Fix pointers of neighbors nodes: after loading the serialized nodes, the + * neighbors links are just IDs (casted to pointers), instead of the actual + * pointers. We need to resolve IDs into pointers. + * + * The two integers salt0 and salt1 are used to make the internal state + * of the function unguessable to an external attacker, in order to protect + * from corruptions. Show be two random numbers from /dev/urandom if possible + * otherwise can be just 0,0 if the application is not security critical and + * never processes untrusted inputs. + * + * Return 0 on error (out of memory or some ID that can't be resolved), 1 on + * success. */ +int hnsw_deserialize_index(HNSW *index, uint64_t salt0, uint64_t salt1) { + /* We will use simple linear probing, so over-allocating is a good + * idea: anyway this flat array of pointers will consume a fraction + * of the memory of the loaded index. */ + uint64_t min_size = index->node_count*2; + uint64_t table_size = 1; + while(table_size < min_size) table_size <<= 1; + + hnswNode **table = hmalloc(sizeof(hnswNode*) * table_size); + if (table == NULL) return 0; + memset(table,0,sizeof(hnswNode*) * table_size); + + /* First pass: populate the ID -> pointer hash table. */ + hnswNode *node = index->head; + while(node) { + uint64_t bucket = hnsw_hash_node_id(node->id) & (table_size-1); + for (uint64_t j = 0; j < table_size; j++) { + if (table[bucket] == NULL) { + table[bucket] = node; + break; + } + bucket = (bucket+1) & (table_size-1); + } + node = node->next; + } + + /* Second pass: fix pointers of all the neighbors links. + * As we scan and fix the links, we also compute the accumulator + * register "reciprocal", that is used in order to guarantee that all + * the links are reciprocal. + * + * This is how it works, we hash (using a strong hash function) the + * following key for each link that we see from A to B (or vice versa): + * + * hash(salt || A || B || link-level) + * + * We always sort A and B, so the same link from A to B and from B to A + * will hash the same. The we xor the result into the 128 bit accumulator. + * If each link has its own backlink, the accumulator is guaranteed to + * be zero at the end. + * + * Collisions are extremely unlikely to happen, and an external attacker + * can't easily control the hash function output, since the salt is + * unknown, and also there would be to control the pointers. + * + * This algorithm is O(1) for each node so it is basically free for + * us, as we scan the list of nodes, and runs on constant and very + * small memory. */ + uint64_t accumulator[2] = {0,0}; + + node = index->head; // Rewind. + while(node) { + uint64_t this_node_id = node->id; + for (uint32_t i = 0; i <= node->level; i++) { + // Check if there are duplicated links: those are + // also corruptions of the on-disk serialization format. + if (node->layers[i].num_links > 0) { + qsort(node->layers[i].links, node->layers[i].num_links, + sizeof(void*), qsort_compare_pointers); + for (uint32_t j = 0; j < node->layers[i].num_links-1; j++) { + if (node->layers[i].links[j] == node->layers[i].links[j+1]) + goto corrupted; + } + } + + // Resolve pointers. + for (uint32_t j = 0; j < node->layers[i].num_links; j++) { + uint64_t linked_id = (uint64_t) node->layers[i].links[j]; + + // We can't link to our own node. + if (linked_id == this_node_id) goto corrupted; + + // Compute accumulator for reciprocal links check. + uint64_t mixed_h1, mixed_h2; + secure_pair_mixer_128(salt0, salt1, this_node_id, linked_id, (uint64_t)i, &mixed_h1, &mixed_h2); + + accumulator[0] ^= mixed_h1; + accumulator[1] ^= mixed_h2; + + // Fix links. + uint64_t bucket = hnsw_hash_node_id(linked_id) & (table_size-1); + hnswNode *neighbor = NULL; + for (uint64_t k = 0; k < table_size; k++) { + if (table[bucket] && table[bucket]->id == linked_id) { + neighbor = table[bucket]; + break; + } + bucket = (bucket+1) & (table_size-1); + } + + /* The neighbor must exist and also exist at the right + * level. */ + if (neighbor == NULL || neighbor->level < i) { + /* Unresolved link! Either a bug in this code + * or broken serialization data. */ + goto corrupted; + } + node->layers[i].links[j] = neighbor; + } + + /* The worst link information was missing from older + * serialization formats. Compute it on the fly if needed. */ + if (node->layers[i].worst_idx == HNSW_SER_WORSTLINK_MISSING) { + hnsw_update_worst_neighbor(index,node,i); + } + } + node = node->next; + } + + /* Check that links are reciprocal, otherwise fail. */ + if (accumulator[0] || accumulator[1]) goto corrupted; + + /* Everything fine. Return success. */ + hfree(table); + return 1; + +corrupted: + /* Some corruption error detected. */ + hfree(table); + return 0; +} + +/* ================================ Iterator ================================ */ + +/* Get a cursor that can be used as argument of hnsw_cursor_next() to iterate + * all the elements that remain there from the start to the end of the + * iteration, excluding newly added elements. + * + * The function returns NULL on out of memory. */ +hnswCursor *hnsw_cursor_init(HNSW *index) { + if (pthread_rwlock_wrlock(&index->global_lock) != 0) return NULL; + hnswCursor *cursor = hmalloc(sizeof(*cursor)); + if (cursor == NULL) { + pthread_rwlock_unlock(&index->global_lock); + return NULL; + } + cursor->index = index; + cursor->next = index->cursors; + cursor->current = index->head; + index->cursors = cursor; + pthread_rwlock_unlock(&index->global_lock); + return cursor; +} + +/* Free the cursor. Can be called both at the end of the iteration, when + * hnsw_cursor_next() returned NULL, or before. */ +void hnsw_cursor_free(hnswCursor *cursor) { + HNSW *index = cursor->index; + if (pthread_rwlock_wrlock(&index->global_lock) != 0) { + // No easy way to recover from that. We will leak memory. + return; + } + + hnswCursor *x = index->cursors; + hnswCursor *prev = NULL; + while(x) { + if (x == cursor) { + if (prev) + prev->next = cursor->next; + else + index->cursors = cursor->next; + hfree(cursor); + break; + } + prev = x; + x = x->next; + } + pthread_rwlock_unlock(&index->global_lock); +} + +/* Acquire a lock to use the cursor. Returns 1 if the lock was acquired + * with success, otherwise zero is returned. The returned element is + * protected after calling hnsw_cursor_next() for all the time required to + * access it, then hnsw_cursor_release_lock() should be called in order + * to unlock the HNSW index. */ +int hnsw_cursor_acquire_lock(hnswCursor *cursor) { + return pthread_rwlock_rdlock(&cursor->index->global_lock) == 0; +} + +/* Release the cursor lock, see hnsw_cursor_acquire_lock() top comment + * for more information. */ +void hnsw_cursor_release_lock(hnswCursor *cursor) { + pthread_rwlock_unlock(&cursor->index->global_lock); +} + +/* Return the next element of the HNSW. See hnsw_cursor_init() for + * the guarantees of the function. */ +hnswNode *hnsw_cursor_next(hnswCursor *cursor) { + hnswNode *ret = cursor->current; + if (ret) cursor->current = ret->next; + return ret; +} + +/* Called by hnsw_unlink_node() if there is at least an active cursor. + * Will scan the cursors to see if any cursor is going to yield this + * one, and in this case, updates the current element to the next. */ +void hnsw_cursor_element_deleted(HNSW *index, hnswNode *deleted) { + hnswCursor *x = index->cursors; + while(x) { + if (x->current == deleted) x->current = deleted->next; + x = x->next; + } +} + +/* ============================ Debugging stuff ============================= */ + +/* Show stats about nodes connections. */ +void hnsw_print_stats(HNSW *index) { + if (!index || !index->head) { + printf("Empty index or NULL pointer passed\n"); + return; + } + + long long total_links = 0; + int min_links = -1; // We'll set this to first node's count. + int isolated_nodes = 0; + uint32_t node_count = 0; + + // Iterate through all nodes using the linked list. + hnswNode *current = index->head; + while (current) { + // Count total links for this node across all layers. + int node_total_links = 0; + for (uint32_t layer = 0; layer <= current->level; layer++) + node_total_links += current->layers[layer].num_links; + + // Update statistics. + total_links += node_total_links; + + // Initialize or update minimum links. + if (min_links == -1 || node_total_links < min_links) { + min_links = node_total_links; + } + + // Check if node is isolated (no links at all). + if (node_total_links == 0) isolated_nodes++; + + node_count++; + current = current->next; + } + + // Print statistics + printf("HNSW Graph Statistics:\n"); + printf("----------------------\n"); + printf("Total nodes: %u\n", node_count); + if (node_count > 0) { + printf("Average links per node: %.2f\n", + (float)total_links / node_count); + printf("Minimum links in a single node: %d\n", min_links); + printf("Number of isolated nodes: %d (%.1f%%)\n", + isolated_nodes, + (float)isolated_nodes * 100 / node_count); + } +} + +/* Validate graph connectivity and link reciprocity. Takes pointers to store results: + * - connected_nodes: will contain number of reachable nodes from entry point. + * - reciprocal_links: will contain 1 if all links are reciprocal, 0 otherwise. + * Returns 0 on success, -1 on error (NULL parameters and such). + */ +int hnsw_validate_graph(HNSW *index, uint64_t *connected_nodes, int *reciprocal_links) { + if (!index || !connected_nodes || !reciprocal_links) return -1; + if (!index->enter_point) { + *connected_nodes = 0; + *reciprocal_links = 1; // Empty graph is valid. + return 0; + } + + // Initialize connectivity check. + index->current_epoch[0]++; + *connected_nodes = 0; + *reciprocal_links = 1; + + // Initialize node stack. + uint64_t stack_size = index->node_count; + hnswNode **stack = hmalloc(sizeof(hnswNode*) * stack_size); + if (!stack) return -1; + uint64_t stack_top = 0; + + // Start from entry point. + index->enter_point->visited_epoch[0] = index->current_epoch[0]; + (*connected_nodes)++; + stack[stack_top++] = index->enter_point; + + // Process all reachable nodes. + while (stack_top > 0) { + hnswNode *current = stack[--stack_top]; + + // Explore all neighbors at each level. + for (uint32_t level = 0; level <= current->level; level++) { + for (uint64_t i = 0; i < current->layers[level].num_links; i++) { + hnswNode *neighbor = current->layers[level].links[i]; + + // Check reciprocity. + int found_backlink = 0; + for (uint64_t j = 0; j < neighbor->layers[level].num_links; j++) { + if (neighbor->layers[level].links[j] == current) { + found_backlink = 1; + break; + } + } + if (!found_backlink) { + *reciprocal_links = 0; + } + + // If we haven't visited this neighbor yet. + if (neighbor->visited_epoch[0] != index->current_epoch[0]) { + neighbor->visited_epoch[0] = index->current_epoch[0]; + (*connected_nodes)++; + if (stack_top < stack_size) { + stack[stack_top++] = neighbor; + } else { + // This should never happen in a valid graph. + hfree(stack); + return -1; + } + } + } + } + } + + hfree(stack); + + // Now scan for unreachable nodes and print debug info. + printf("\nUnreachable nodes debug information:\n"); + printf("=====================================\n"); + + hnswNode *current = index->head; + while (current) { + if (current->visited_epoch[0] != index->current_epoch[0]) { + printf("\nUnreachable node found:\n"); + printf("- Node pointer: %p\n", (void*)current); + printf("- Node ID: %llu\n", (unsigned long long)current->id); + printf("- Node level: %u\n", current->level); + + // Print info about all its links at each level. + for (uint32_t level = 0; level <= current->level; level++) { + printf(" Level %u links (%u):\n", level, + current->layers[level].num_links); + for (uint64_t i = 0; i < current->layers[level].num_links; i++) { + hnswNode *neighbor = current->layers[level].links[i]; + // Check reciprocity for this specific link + int found_backlink = 0; + for (uint64_t j = 0; j < neighbor->layers[level].num_links; j++) { + if (neighbor->layers[level].links[j] == current) { + found_backlink = 1; + break; + } + } + printf(" - Link %llu: pointer=%p, id=%llu, visited=%s,recpr=%s\n", + (unsigned long long)i, (void*)neighbor, + (unsigned long long)neighbor->id, + neighbor->visited_epoch[0] == index->current_epoch[0] ? + "yes" : "no", + found_backlink ? "yes" : "no"); + } + } + } + current = current->next; + } + + printf("Total connected nodes: %llu\n", (unsigned long long)*connected_nodes); + printf("All links are bi-directiona? %s\n", (*reciprocal_links)?"yes":"no"); + return 0; +} + +/* Test graph recall ability by verifying each node can be found searching + * for its own vector. This helps validate that the majority of nodes are + * properly connected and easily reachable in the graph structure. Every + * unreachable node is reported. + * + * Normally only a small percentage of nodes will be not reachable when + * visited. This is expected and part of the statistical properties + * of HNSW. This happens especially with entries that have an ambiguous + * meaning in the represented space, and are across two or multiple clusters + * of items. + * + * The function works by: + * 1. Iterating through all nodes in the linked list + * 2. Using each node's vector to perform a search with specified EF + * 3. Verifying the node can find itself as nearest neighbor + * 4. Collecting and reporting statistics about reachability + * + * This is just a debugging function that reports stuff in the standard + * output, part of the implementation because this kind of functions + * provide some visibility on what happens inside the HNSW. + */ +void hnsw_test_graph_recall(HNSW *index, int test_ef, int verbose) { + // Stats + uint32_t total_nodes = 0; + uint32_t unreachable_nodes = 0; + uint32_t perfectly_reachable = 0; // Node finds itself as first result + + // For storing search results + hnswNode **neighbors = hmalloc(sizeof(hnswNode*) * test_ef); + float *distances = hmalloc(sizeof(float) * test_ef); + float *test_vector = hmalloc(sizeof(float) * index->vector_dim); + if (!neighbors || !distances || !test_vector) { + hfree(neighbors); + hfree(distances); + hfree(test_vector); + return; + } + + // Get a read slot for searching (even if it's highly unlikely that + // this test will be run threaded...). + int slot = hnsw_acquire_read_slot(index); + if (slot < 0) { + hfree(neighbors); + hfree(distances); + return; + } + + printf("\nTesting graph recall\n"); + printf("====================\n"); + + // Process one node at a time using the linked list + hnswNode *current = index->head; + while (current) { + total_nodes++; + + // If using quantization, we need to reconstruct the normalized vector + if (index->quant_type == HNSW_QUANT_Q8) { + int8_t *quants = current->vector; + // Reconstruct normalized vector from quantized data + for (uint32_t j = 0; j < index->vector_dim; j++) { + test_vector[j] = (quants[j] * current->quants_range) / 127; + } + } else if (index->quant_type == HNSW_QUANT_NONE) { + memcpy(test_vector,current->vector,sizeof(float)*index->vector_dim); + } else { + assert(0 && "Quantization type not supported."); + } + + // Search using the node's own vector with high ef + int found = hnsw_search(index, test_vector, test_ef, neighbors, + distances, slot, 1); + + if (found == 0) continue; // Empty HNSW? + + // Look for the node itself in the results + int found_self = 0; + int self_position = -1; + for (int i = 0; i < found; i++) { + if (neighbors[i] == current) { + found_self = 1; + self_position = i; + break; + } + } + + if (!found_self || self_position != 0) { + unreachable_nodes++; + if (verbose) { + if (!found_self) + printf("\nNode %s cannot find itself:\n", (char*)current->value); + else + printf("\nNode %s is not top result:\n", (char*)current->value); + printf("- Node ID: %llu\n", (unsigned long long)current->id); + printf("- Node level: %u\n", current->level); + printf("- Found %d neighbors but self not among them\n", found); + printf("- Closest neighbor distance: %f\n", distances[0]); + printf("- Neighbors: "); + for (uint32_t i = 0; i < current->layers[0].num_links; i++) { + printf("%s ", (char*)current->layers[0].links[i]->value); + } + printf("\n"); + printf("\nFound instead: "); + for (int j = 0; j < found && j < 10; j++) { + printf("%s ", (char*)neighbors[j]->value); + } + printf("\n"); + } + } else { + perfectly_reachable++; + } + current = current->next; + } + + // Release read slot + hnsw_release_read_slot(index, slot); + + // Free resources + hfree(neighbors); + hfree(distances); + hfree(test_vector); + + // Print final statistics + printf("Total nodes tested: %u\n", total_nodes); + printf("Perfectly reachable nodes: %u (%.1f%%)\n", + perfectly_reachable, + total_nodes ? (float)perfectly_reachable * 100 / total_nodes : 0); + printf("Unreachable/suboptimal nodes: %u (%.1f%%)\n", + unreachable_nodes, + total_nodes ? (float)unreachable_nodes * 100 / total_nodes : 0); +} + +/* Return exact K-NN items by performing a linear scan of all nodes. + * This function has the same signature as hnsw_search_with_filter() but + * instead of using the graph structure, it scans all nodes to find the + * true nearest neighbors. + * + * Note that neighbors and distances arrays must have space for at least 'k' items. + * norm_query should be set to 1 if the query vector is already normalized. + * + * If the filter_callback is passed, only elements passing the specified filter + * are returned. The slot parameter is ignored but kept for API consistency. */ +int hnsw_ground_truth_with_filter + (HNSW *index, const float *query_vector, uint32_t k, + hnswNode **neighbors, float *distances, uint32_t slot, + int query_vector_is_normalized, + int (*filter_callback)(void *value, void *privdata), + void *filter_privdata) +{ + /* Note that we don't really use the slot here: it's a linear scan. + * Yet we want the user to acquire the slot as this will hold the + * global lock in read only mode. */ + (void) slot; + + /* Take our query vector into a temporary node. */ + hnswNode query; + if (hnsw_init_tmp_node(index, &query, query_vector_is_normalized, query_vector) == 0) return -1; + + /* Accumulate best results into a priority queue. */ + pqueue *results = pq_new(k); + if (!results) { + hnsw_free_tmp_node(&query, query_vector); + return -1; + } + + /* Scan all nodes linearly. */ + hnswNode *current = index->head; + while (current) { + /* Apply filter if needed. */ + if (filter_callback && + !filter_callback(current->value, filter_privdata)) + { + current = current->next; + continue; + } + + /* Calculate distance to query. */ + float dist = hnsw_distance(index, &query, current); + + /* Add to results to pqueue. Will be accepted only if better than + * the current worse or pqueue not full. */ + pq_push(results, current, dist); + current = current->next; + } + + /* Copy results to output arrays. */ + uint32_t found = MIN(k, results->count); + for (uint32_t i = 0; i < found; i++) { + neighbors[i] = pq_get_node(results, i); + if (distances) distances[i] = pq_get_distance(results, i); + } + + /* Clean up. */ + pq_free(results); + hnsw_free_tmp_node(&query, query_vector); + return found; +} diff --git a/examples/redis-unstable/modules/vector-sets/hnsw.h b/examples/redis-unstable/modules/vector-sets/hnsw.h new file mode 100644 index 0000000..8935521 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/hnsw.h @@ -0,0 +1,189 @@ +/* + * HNSW (Hierarchical Navigable Small World) Implementation + * Based on the paper by Yu. A. Malkov, D. A. Yashunin + * + * Copyright (c) 2009-Present, Redis Ltd. + * All rights reserved. + * + * Licensed under your choice of (a) the Redis Source Available License 2.0 + * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the + * GNU Affero General Public License v3 (AGPLv3). + * Originally authored by: Salvatore Sanfilippo. + */ + +#ifndef HNSW_H +#define HNSW_H + +#include <pthread.h> +#include <stdatomic.h> + +#define HNSW_DEFAULT_M 16 /* Used when 0 is given at creation time. */ +#define HNSW_MIN_M 4 /* Probably even too low already. */ +#define HNSW_MAX_M 4096 /* Safeguard sanity limit. */ +#define HNSW_MAX_THREADS 32 /* Maximum number of concurrent threads */ + +/* Quantization types you can enable at creation time in hnsw_new() */ +#define HNSW_QUANT_NONE 0 // No quantization. +#define HNSW_QUANT_Q8 1 // Q8 quantization. +#define HNSW_QUANT_BIN 2 // Binary quantization. + +/* Layer structure for HNSW nodes. Each node will have from one to a few + * of this depending on its level. */ +typedef struct { + struct hnswNode **links; /* Array of neighbors for this layer */ + uint32_t num_links; /* Number of used links */ + uint32_t max_links; /* Maximum links for this layer. We may + * reallocate the node in very particular + * conditions in order to allow linking of + * new inserted nodes, so this may change + * dynamically and be > M*2 for a small set of + * nodes. */ + float worst_distance; /* Distance to the worst neighbor */ + uint32_t worst_idx; /* Index of the worst neighbor */ +} hnswNodeLayer; + +/* Node structure for HNSW graph */ +typedef struct hnswNode { + uint32_t level; /* Node's maximum level */ + uint64_t id; /* Unique identifier, may be useful in order to + * have a bitmap of visited notes to use as + * alternative to epoch / visited_epoch. + * Also used in serialization in order to retain + * links specifying IDs. */ + void *vector; /* The vector, quantized or not. */ + float quants_range; /* Quantization range for this vector: + * min/max values will be in the range + * -quants_range, +quants_range */ + float l2; /* L2 before normalization. */ + + /* Last time (epoch) this node was visited. We need one per thread. + * This avoids having a different data structure where we track + * visited nodes, but costs memory per node. */ + uint64_t visited_epoch[HNSW_MAX_THREADS]; + + void *value; /* Associated value */ + struct hnswNode *prev, *next; /* Prev/Next node in the list starting at + * HNSW->head. */ + + /* Links (and links info) per each layer. Note that this is part + * of the node allocation to be more cache friendly: reliable 3% speedup + * on Apple silicon, and does not make anything more complex. */ + hnswNodeLayer layers[]; +} hnswNode; + +struct HNSW; + +/* It is possible to navigate an HNSW with a cursor that guarantees + * visiting all the elements that remain in the HNSW from the start to the + * end of the process (but not the new ones, so that the process will + * eventually finish). Check hnsw_cursor_init(), hnsw_cursor_next() and + * hnsw_cursor_free(). */ +typedef struct hnswCursor { + struct HNSW *index; // Reference to the index of this cursor. + hnswNode *current; // Element to report when hnsw_cursor_next() is called. + struct hnswCursor *next; // Next cursor active. +} hnswCursor; + +/* Main HNSW index structure */ +typedef struct HNSW { + hnswNode *enter_point; /* Entry point for the graph */ + uint32_t M; /* M as in the paper: layer 0 has M*2 max + neighbors (M populated at insertion time) + while all the other layers have M neighbors. */ + uint32_t max_level; /* Current maximum level in the graph */ + uint32_t vector_dim; /* Dimensionality of stored vectors */ + uint64_t node_count; /* Total number of nodes */ + _Atomic uint64_t last_id; /* Last node ID used */ + uint64_t current_epoch[HNSW_MAX_THREADS]; /* Current epoch for visit tracking */ + hnswNode *head; /* Linked list of nodes. Last first */ + + /* We have two locks here: + * 1. A global_lock that is used to perform write operations blocking all + * the readers. + * 2. One mutex per epoch slot, in order for read operations to acquire + * a lock on a specific slot to use epochs tracking of visited nodes. */ + pthread_rwlock_t global_lock; /* Global read-write lock */ + pthread_mutex_t slot_locks[HNSW_MAX_THREADS]; /* Per-slot locks */ + + _Atomic uint32_t next_slot; /* Next thread slot to try */ + _Atomic uint64_t version; /* Version for optimistic concurrency, this is + * incremented on deletions and entry point + * updates. */ + uint32_t quant_type; /* Quantization used. HNSW_QUANT_... */ + hnswCursor *cursors; +} HNSW; + +/* Serialized node. This structure is used as return value of + * hnsw_serialize_node(). */ +typedef struct hnswSerNode { + void *vector; + uint32_t vector_size; + uint64_t *params; + uint32_t params_count; +} hnswSerNode; + +/* Insert preparation context */ +typedef struct InsertContext InsertContext; + +/* Core HNSW functions */ +HNSW *hnsw_new(uint32_t vector_dim, uint32_t quant_type, uint32_t m); +void hnsw_free(HNSW *index,void(*free_value)(void*value)); +void hnsw_node_free(hnswNode *node); +void hnsw_print_stats(HNSW *index); +hnswNode *hnsw_insert(HNSW *index, const float *vector, const int8_t *qvector, + float qrange, uint64_t id, void *value, int ef); +int hnsw_search(HNSW *index, const float *query, uint32_t k, + hnswNode **neighbors, float *distances, uint32_t slot, + int query_vector_is_normalized); +int hnsw_search_with_filter + (HNSW *index, const float *query_vector, uint32_t k, + hnswNode **neighbors, float *distances, uint32_t slot, + int query_vector_is_normalized, + int (*filter_callback)(void *value, void *privdata), + void *filter_privdata, uint32_t max_candidates); +void hnsw_get_node_vector(HNSW *index, hnswNode *node, float *vec); +int hnsw_delete_node(HNSW *index, hnswNode *node, void(*free_value)(void*value)); +hnswNode *hnsw_random_node(HNSW *index, int slot); + +/* Thread safety functions. */ +int hnsw_acquire_read_slot(HNSW *index); +void hnsw_release_read_slot(HNSW *index, int slot); + +/* Optimistic insertion API. */ +InsertContext *hnsw_prepare_insert(HNSW *index, const float *vector, const int8_t *qvector, float qrange, uint64_t id, int ef); +hnswNode *hnsw_try_commit_insert(HNSW *index, InsertContext *ctx, void *value); +void hnsw_free_insert_context(InsertContext *ctx); + +/* Serialization. */ +hnswSerNode *hnsw_serialize_node(HNSW *index, hnswNode *node); +void hnsw_free_serialized_node(hnswSerNode *sn); +hnswNode *hnsw_insert_serialized(HNSW *index, void *vector, uint64_t *params, uint32_t params_len, void *value); +int hnsw_deserialize_index(HNSW *index, uint64_t salt0, uint64_t salt1); + +// Helper function in case the user wants to directly copy +// the vector bytes. +uint32_t hnsw_quants_bytes(HNSW *index); + +/* Cursors. */ +hnswCursor *hnsw_cursor_init(HNSW *index); +void hnsw_cursor_free(hnswCursor *cursor); +hnswNode *hnsw_cursor_next(hnswCursor *cursor); +int hnsw_cursor_acquire_lock(hnswCursor *cursor); +void hnsw_cursor_release_lock(hnswCursor *cursor); + +/* Allocator selection. */ +void hnsw_set_allocator(void (*free_ptr)(void*), void *(*malloc_ptr)(size_t), + void *(*realloc_ptr)(void*, size_t)); + +/* Testing. */ +int hnsw_validate_graph(HNSW *index, uint64_t *connected_nodes, int *reciprocal_links); +void hnsw_test_graph_recall(HNSW *index, int test_ef, int verbose); +float hnsw_distance(HNSW *index, hnswNode *a, hnswNode *b); +int hnsw_ground_truth_with_filter + (HNSW *index, const float *query_vector, uint32_t k, + hnswNode **neighbors, float *distances, uint32_t slot, + int query_vector_is_normalized, + int (*filter_callback)(void *value, void *privdata), + void *filter_privdata); + +#endif /* HNSW_H */ diff --git a/examples/redis-unstable/modules/vector-sets/mixer.h b/examples/redis-unstable/modules/vector-sets/mixer.h new file mode 100644 index 0000000..d75e193 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/mixer.h @@ -0,0 +1,106 @@ +/* Redis implementation for vector sets. The data structure itself + * is implemented in hnsw.c. + * + * Copyright (c) 2009-Present, Redis Ltd. + * All rights reserved. + * + * Licensed under your choice of (a) the Redis Source Available License 2.0 + * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the + * GNU Affero General Public License v3 (AGPLv3). + * Originally authored by: Salvatore Sanfilippo. + * + * ============================================================================= + * + * Mixing function for HNSW link integrity verification + * Designed to resist collision attacks when salts are unknown. + */ + +#include <stdint.h> +#include <string.h> + +static inline uint64_t ROTL64(uint64_t x, int r) { + return (x << r) | (x >> (64 - r)); +} + +// Use more rounds and stronger constants +#define MIX_PRIME_1 0xFF51AFD7ED558CCDULL +#define MIX_PRIME_2 0xC4CEB9FE1A85EC53ULL +#define MIX_PRIME_3 0x9E3779B97F4A7C15ULL +#define MIX_PRIME_4 0xBF58476D1CE4E5B9ULL +#define MIX_PRIME_5 0x94D049BB133111EBULL +#define MIX_PRIME_6 0x2B7E151628AED2A7ULL + +/* Mixer design goals: + * 1. Thorough mixing of the level parameter. + * 2. Enough rounds of mixing. + * 3. Cross-influence between h1 and h2. + * 4. Domain separation to prevent related-key attacks. + */ +void secure_pair_mixer_128(uint64_t salt0, uint64_t salt1, + uint64_t id1_in, uint64_t id2_in, uint64_t level, + uint64_t* out_h1, uint64_t* out_h2) { + // Order independence (A -> B links should hash as B -> A links). + uint64_t id_a = (id1_in < id2_in) ? id1_in : id2_in; + uint64_t id_b = (id1_in < id2_in) ? id2_in : id1_in; + + // Domain separation: mix salts with a constant to prevent + // related-key attacks. + uint64_t h1 = salt0 ^ 0xDEADBEEFDEADBEEFULL; + uint64_t h2 = salt1 ^ 0xCAFEBABECAFEBABEULL; + + // First, thoroughly mix the level into both accumulators + // This prevents predictable level values from being a weakness + uint64_t level_mix = level; + level_mix *= MIX_PRIME_5; + level_mix ^= level_mix >> 32; + level_mix *= MIX_PRIME_6; + + h1 ^= level_mix; + h2 ^= ROTL64(level_mix, 31); + + // Mix in id_a with strong diffusion. + h1 ^= id_a; + h1 *= MIX_PRIME_1; + h1 = ROTL64(h1, 23); + h1 *= MIX_PRIME_2; + + // Mix in id_b. + h2 ^= id_b; + h2 *= MIX_PRIME_3; + h2 = ROTL64(h2, 29); + h2 *= MIX_PRIME_4; + + // Three rounds of cross-mixing for better security. + for (int i = 0; i < 3; i++) { + // Cross-influence. + uint64_t tmp = h1; + h1 += h2; + h2 += tmp; + + // Mix h1. + h1 ^= ROTL64(h1, 31); + h1 *= MIX_PRIME_1; + h1 ^= salt0; + + // Mix h2. + h2 ^= ROTL64(h2, 37); + h2 *= MIX_PRIME_2; + h2 ^= salt1; + } + + // Finalization with avalanche rounds. + h1 ^= h1 >> 33; + h1 *= MIX_PRIME_3; + h1 ^= h1 >> 29; + h1 *= MIX_PRIME_4; + h1 ^= h1 >> 32; + + h2 ^= h2 >> 33; + h2 *= MIX_PRIME_5; + h2 ^= h2 >> 29; + h2 *= MIX_PRIME_6; + h2 ^= h2 >> 32; + + *out_h1 = h1; + *out_h2 = h2; +} diff --git a/examples/redis-unstable/modules/vector-sets/test.py b/examples/redis-unstable/modules/vector-sets/test.py new file mode 100755 index 0000000..8d56d58 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/test.py @@ -0,0 +1,294 @@ +#!/usr/bin/env python3 +# +# Vector set tests. +# A Redis instance should be running in the default port. +# +# Copyright (c) 2009-Present, Redis Ltd. +# All rights reserved. +# +# Licensed under your choice of (a) the Redis Source Available License 2.0 +# (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the +# GNU Affero General Public License v3 (AGPLv3). +# + +import redis +import random +import struct +import math +import time +import sys +import os +import importlib +import inspect +import argparse +from typing import List, Tuple, Optional +from dataclasses import dataclass + +def colored(text: str, color: str) -> str: + colors = { + 'red': '\033[91m', + 'green': '\033[92m', + 'yellow': '\033[93m', + 'blue': '\033[94m', + 'magenta': '\033[95m', + 'cyan': '\033[96m', + } + reset = '\033[0m' + return f"{colors.get(color, '')}{text}{reset}" + +@dataclass +class VectorData: + vectors: List[List[float]] + names: List[str] + + def find_k_nearest(self, query_vector: List[float], k: int) -> List[Tuple[str, float]]: + """Find k-nearest neighbors using the same scoring as Redis VSIM WITHSCORES.""" + similarities = [] + query_norm = math.sqrt(sum(x*x for x in query_vector)) + if query_norm == 0: + return [] + + for i, vec in enumerate(self.vectors): + vec_norm = math.sqrt(sum(x*x for x in vec)) + if vec_norm == 0: + continue + + dot_product = sum(a*b for a,b in zip(query_vector, vec)) + cosine_sim = dot_product / (query_norm * vec_norm) + distance = 1.0 - cosine_sim + redis_similarity = 1.0 - (distance/2.0) + similarities.append((self.names[i], redis_similarity)) + + similarities.sort(key=lambda x: x[1], reverse=True) + return similarities[:k] + +def generate_random_vector(dim: int) -> List[float]: + """Generate a random normalized vector.""" + vec = [random.gauss(0, 1) for _ in range(dim)] + norm = math.sqrt(sum(x*x for x in vec)) + return [x/norm for x in vec] + +def fill_redis_with_vectors(r: redis.Redis, key: str, count: int, dim: int, + with_reduce: Optional[int] = None) -> VectorData: + """Fill Redis with random vectors and return a VectorData object for verification.""" + vectors = [] + names = [] + + r.delete(key) + for i in range(count): + vec = generate_random_vector(dim) + name = f"{key}:item:{i}" + vectors.append(vec) + names.append(name) + + vec_bytes = struct.pack(f'{dim}f', *vec) + args = [key] + if with_reduce: + args.extend(['REDUCE', with_reduce]) + args.extend(['FP32', vec_bytes, name]) + r.execute_command('VADD', *args) + + return VectorData(vectors=vectors, names=names) + +class TestCase: + def __init__(self, primary_port=6379, replica_port=6380): + self.error_msg = None + self.error_details = None + self.test_key = f"test:{self.__class__.__name__.lower()}" + # Primary Redis instance + self.redis = redis.Redis(port=primary_port,db=9) + self.redis3 = redis.Redis(port=primary_port,protocol=3,db=9) + # Replica Redis instance + self.replica = redis.Redis(port=replica_port,db=9) + # Replication status + self.replication_setup = False + # Ports + self.primary_port = primary_port + self.replica_port = replica_port + + def setup(self): + self.redis.delete(self.test_key) + + def teardown(self): + self.redis.delete(self.test_key) + + def setup_replication(self) -> bool: + """ + Setup replication between primary and replica Redis instances. + Returns True if replication is successfully established, False otherwise. + """ + # Configure replica to replicate from primary + self.replica.execute_command('REPLICAOF', '127.0.0.1', self.primary_port) + + # Wait for replication to be established + max_attempts = 50 + for attempt in range(max_attempts): + # Check replication info + repl_info = self.replica.info('replication') + + # Check if replication is established + if (repl_info.get('role') == 'slave' and + repl_info.get('master_host') == '127.0.0.1' and + repl_info.get('master_port') == self.primary_port and + repl_info.get('master_link_status') == 'up'): + + self.replication_setup = True + return True + + # Wait before next attempt + print(colored(".",'cyan'),end="",flush=True) + time.sleep(0.5) + + # If we get here, replication wasn't established + self.error_msg = "Failed to establish replication between primary and replica" + return False + + def test(self): + raise NotImplementedError("Subclasses must implement test method") + + def run(self): + try: + self.setup() + self.test() + return True + except AssertionError as e: + self.error_msg = str(e) + import traceback + self.error_details = traceback.format_exc() + return False + except Exception as e: + self.error_msg = f"Unexpected error: {str(e)}" + import traceback + self.error_details = traceback.format_exc() + return False + finally: + self.teardown() + + def getname(self): + """Each test class should override this to provide its name""" + return self.__class__.__name__ + + def estimated_runtime(self): + """"Each test class should override this if it takes a significant amount of time to run. Default is 100ms""" + return 0.1 + +def find_test_classes(primary_port, replica_port): + test_classes = [] + script_dir = os.path.dirname(os.path.abspath(__file__)) + tests_dir = os.path.join(script_dir, 'tests') + + if not os.path.exists(tests_dir): + return [] + + for file in os.listdir(tests_dir): + if file.endswith('.py'): + module_name = f"tests.{file[:-3]}" + try: + module = importlib.import_module(module_name) + for name, obj in inspect.getmembers(module): + if inspect.isclass(obj) and obj.__name__ != 'TestCase' and hasattr(obj, 'test'): + # Create test instance with specified ports + test_instance = obj(primary_port,replica_port) + test_classes.append(test_instance) + except Exception as e: + print(f"Error loading {file}: {e}") + + return test_classes + +def check_redis_empty(r, instance_name): + """Check if Redis instance is empty""" + try: + dbsize = r.dbsize() + if dbsize > 0: + print(colored(f"ERROR: {instance_name} Redis instance DB 9 is not empty (dbsize: {dbsize}).", "red")) + print(colored("Make sure you're not using a production instance and that all data is safe to delete.", "red")) + sys.exit(1) + except redis.exceptions.ConnectionError: + print(colored(f"ERROR: Cannot connect to {instance_name} Redis instance.", "red")) + sys.exit(1) + +def check_replica_running(replica_port): + """Check if replica Redis instance is running""" + r = redis.Redis(port=replica_port) + try: + r.ping() + return True + except redis.exceptions.ConnectionError: + print(colored(f"WARNING: Replica Redis instance (port {replica_port}) is not running.", "yellow")) + print(colored("Replication tests will be skipped. Make sure to start the replica instance.", "yellow")) + return False + +def run_tests(): + # Parse command line arguments + parser = argparse.ArgumentParser(description='Run Redis vector tests.') + parser.add_argument('--primary-port', type=int, default=6379, help='Primary Redis instance port (default: 6379)') + parser.add_argument('--replica-port', type=int, default=6380, help='Replica Redis instance port (default: 6380)') + args = parser.parse_args() + + print("================================================") + print(f"Make sure to have Redis running on localhost") + print(f"Primary port: {args.primary_port}") + print(f"Replica port: {args.replica_port}") + print("with --enable-debug-command yes") + print("================================================\n") + + # Check if Redis instances are empty + primary = redis.Redis(port=args.primary_port,db=9) + replica = redis.Redis(port=args.replica_port,db=9) + + check_redis_empty(primary, "Primary") + + # Check if replica is running + replica_running = check_replica_running(args.replica_port) + if replica_running: + check_redis_empty(replica, "Replica") + + tests = find_test_classes(args.primary_port, args.replica_port) + if not tests: + print("No tests found!") + return + + # Sort tests by estimated runtime + tests.sort(key=lambda t: t.estimated_runtime()) + + passed = 0 + skipped = 0 + total = len(tests) + + for test in tests: + print(f"{test.getname()}: ", end="") + sys.stdout.flush() + + if not replica_running and test.getname().lower().find("replication") != -1: + print(colored("SKIPPING","yellow")) + skipped += 1 + continue + + start_time = time.time() + success = test.run() + duration = time.time() - start_time + + if success: + print(colored("OK", "green"), f"({duration:.2f}s)") + passed += 1 + else: + print(colored("ERR", "red"), f"({duration:.2f}s)") + print(f"Error: {test.error_msg}") + if test.error_details: + print("\nTraceback:") + print(test.error_details) + + print("\n" + "="*50) + print(f"\nTest Summary: {passed}/{total} tests passed") + + if passed == total: + print(colored("ALL TESTS PASSED!", "green")) + else: + if total-skipped-passed > 0: + print(colored(f"{total-skipped-passed} TESTS FAILED!", "red")) + sys.exit(1) + if skipped > 0: + print(colored(f"{skipped} TESTS SKIPPED!", "yellow")) + +if __name__ == "__main__": + run_tests() diff --git a/examples/redis-unstable/modules/vector-sets/tests/basic_commands.py b/examples/redis-unstable/modules/vector-sets/tests/basic_commands.py new file mode 100644 index 0000000..8481a36 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/basic_commands.py @@ -0,0 +1,21 @@ +from test import TestCase, generate_random_vector +import struct + +class BasicCommands(TestCase): + def getname(self): + return "VADD, VDIM, VCARD basic usage" + + def test(self): + # Test VADD + vec = generate_random_vector(4) + vec_bytes = struct.pack('4f', *vec) + result = self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, f'{self.test_key}:item:1') + assert result == 1, "VADD should return 1 for first item" + + # Test VDIM + dim = self.redis.execute_command('VDIM', self.test_key) + assert dim == 4, f"VDIM should return 4, got {dim}" + + # Test VCARD + card = self.redis.execute_command('VCARD', self.test_key) + assert card == 1, f"VCARD should return 1, got {card}" diff --git a/examples/redis-unstable/modules/vector-sets/tests/basic_similarity.py b/examples/redis-unstable/modules/vector-sets/tests/basic_similarity.py new file mode 100644 index 0000000..11c3c9b --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/basic_similarity.py @@ -0,0 +1,35 @@ +from test import TestCase + +class BasicSimilarity(TestCase): + def getname(self): + return "VSIM reported distance makes sense with 4D vectors" + + def test(self): + # Add two very similar vectors, one different + vec1 = [1, 0, 0, 0] + vec2 = [0.99, 0.01, 0, 0] + vec3 = [0.1, 1, -1, 0.5] + + # Add vectors using VALUES format + self.redis.execute_command('VADD', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], f'{self.test_key}:item:1') + self.redis.execute_command('VADD', self.test_key, 'VALUES', 4, + *[str(x) for x in vec2], f'{self.test_key}:item:2') + self.redis.execute_command('VADD', self.test_key, 'VALUES', 4, + *[str(x) for x in vec3], f'{self.test_key}:item:3') + + # Query similarity with vec1 + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], 'WITHSCORES') + + # Convert results to dictionary + results_dict = {} + for i in range(0, len(result), 2): + key = result[i].decode() + score = float(result[i+1]) + results_dict[key] = score + + # Verify results + assert results_dict[f'{self.test_key}:item:1'] > 0.99, "Self-similarity should be very high" + assert results_dict[f'{self.test_key}:item:2'] > 0.99, "Similar vector should have high similarity" + assert results_dict[f'{self.test_key}:item:3'] < 0.8, "Not very similar vector should have low similarity" diff --git a/examples/redis-unstable/modules/vector-sets/tests/concurrent_vadd_cas_del_vsim.py b/examples/redis-unstable/modules/vector-sets/tests/concurrent_vadd_cas_del_vsim.py new file mode 100644 index 0000000..f4b3a12 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/concurrent_vadd_cas_del_vsim.py @@ -0,0 +1,156 @@ +from test import TestCase, generate_random_vector +import threading +import time +import struct + +class ThreadingStressTest(TestCase): + def getname(self): + return "Concurrent VADD/DEL/VSIM operations stress test" + + def estimated_runtime(self): + return 10 # Test runs for 10 seconds + + def test(self): + # Constants - easy to modify if needed + NUM_VADD_THREADS = 10 + NUM_VSIM_THREADS = 1 + NUM_DEL_THREADS = 1 + TEST_DURATION = 10 # seconds + VECTOR_DIM = 100 + DEL_INTERVAL = 1 # seconds + + # Shared flags and state + stop_event = threading.Event() + error_list = [] + error_lock = threading.Lock() + + def log_error(thread_name, error): + with error_lock: + error_list.append(f"{thread_name}: {error}") + + def vadd_worker(thread_id): + """Thread function to perform VADD operations""" + thread_name = f"VADD-{thread_id}" + try: + vector_count = 0 + while not stop_event.is_set(): + try: + # Generate random vector + vec = generate_random_vector(VECTOR_DIM) + vec_bytes = struct.pack(f'{VECTOR_DIM}f', *vec) + + # Add vector with CAS option + self.redis.execute_command( + 'VADD', + self.test_key, + 'FP32', + vec_bytes, + f'{self.test_key}:item:{thread_id}:{vector_count}', + 'CAS' + ) + + vector_count += 1 + + # Small sleep to reduce CPU pressure + if vector_count % 10 == 0: + time.sleep(0.001) + except Exception as e: + log_error(thread_name, f"Error: {str(e)}") + time.sleep(0.1) # Slight backoff on error + except Exception as e: + log_error(thread_name, f"Thread error: {str(e)}") + + def del_worker(): + """Thread function that deletes the key periodically""" + thread_name = "DEL" + try: + del_count = 0 + while not stop_event.is_set(): + try: + # Sleep first, then delete + time.sleep(DEL_INTERVAL) + if stop_event.is_set(): + break + + self.redis.delete(self.test_key) + del_count += 1 + except Exception as e: + log_error(thread_name, f"Error: {str(e)}") + except Exception as e: + log_error(thread_name, f"Thread error: {str(e)}") + + def vsim_worker(thread_id): + """Thread function to perform VSIM operations""" + thread_name = f"VSIM-{thread_id}" + try: + search_count = 0 + while not stop_event.is_set(): + try: + # Generate query vector + query_vec = generate_random_vector(VECTOR_DIM) + query_str = [str(x) for x in query_vec] + + # Perform similarity search + args = ['VSIM', self.test_key, 'VALUES', VECTOR_DIM] + args.extend(query_str) + args.extend(['COUNT', 10]) + self.redis.execute_command(*args) + + search_count += 1 + + # Small sleep to reduce CPU pressure + if search_count % 10 == 0: + time.sleep(0.005) + except Exception as e: + # Don't log empty array errors, as they're expected when key doesn't exist + if "empty array" not in str(e).lower(): + log_error(thread_name, f"Error: {str(e)}") + time.sleep(0.1) # Slight backoff on error + except Exception as e: + log_error(thread_name, f"Thread error: {str(e)}") + + # Start all threads + threads = [] + + # VADD threads + for i in range(NUM_VADD_THREADS): + thread = threading.Thread(target=vadd_worker, args=(i,)) + thread.start() + threads.append(thread) + + # DEL threads + for _ in range(NUM_DEL_THREADS): + thread = threading.Thread(target=del_worker) + thread.start() + threads.append(thread) + + # VSIM threads + for i in range(NUM_VSIM_THREADS): + thread = threading.Thread(target=vsim_worker, args=(i,)) + thread.start() + threads.append(thread) + + # Let the test run for the specified duration + time.sleep(TEST_DURATION) + + # Signal all threads to stop + stop_event.set() + + # Wait for threads to finish + for thread in threads: + thread.join(timeout=2.0) + + # Check if Redis is still responsive + try: + ping_result = self.redis.ping() + assert ping_result, "Redis did not respond to PING after stress test" + except Exception as e: + assert False, f"Redis connection failed after stress test: {str(e)}" + + # Report any errors for diagnosis, but don't fail the test unless PING fails + if error_list: + error_count = len(error_list) + print(f"\nEncountered {error_count} errors during stress test.") + print("First 5 errors:") + for error in error_list[:5]: + print(f"- {error}") diff --git a/examples/redis-unstable/modules/vector-sets/tests/concurrent_vsim_and_del.py b/examples/redis-unstable/modules/vector-sets/tests/concurrent_vsim_and_del.py new file mode 100644 index 0000000..9bbf011 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/concurrent_vsim_and_del.py @@ -0,0 +1,48 @@ +from test import TestCase, fill_redis_with_vectors, generate_random_vector +import threading, time + +class ConcurrentVSIMAndDEL(TestCase): + def getname(self): + return "Concurrent VSIM and DEL operations" + + def estimated_runtime(self): + return 2 + + def test(self): + # Fill the key with 5000 random vectors + dim = 128 + count = 5000 + fill_redis_with_vectors(self.redis, self.test_key, count, dim) + + # List to store results from threads + thread_results = [] + + def vsim_thread(): + """Thread function to perform VSIM operations until the key is deleted""" + while True: + query_vec = generate_random_vector(dim) + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim, + *[str(x) for x in query_vec], 'COUNT', 10) + if not result: + # Empty array detected, key is deleted + thread_results.append(True) + break + + # Start multiple threads to perform VSIM operations + threads = [] + for _ in range(4): # Start 4 threads + t = threading.Thread(target=vsim_thread) + t.start() + threads.append(t) + + # Delete the key while threads are still running + time.sleep(1) + self.redis.delete(self.test_key) + + # Wait for all threads to finish (they will exit once they detect the key is deleted) + for t in threads: + t.join() + + # Verify that all threads detected an empty array or error + assert len(thread_results) == len(threads), "Not all threads detected the key deletion" + assert all(thread_results), "Some threads did not detect an empty array or error after DEL" diff --git a/examples/redis-unstable/modules/vector-sets/tests/debug_digest.py b/examples/redis-unstable/modules/vector-sets/tests/debug_digest.py new file mode 100644 index 0000000..78f06d8 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/debug_digest.py @@ -0,0 +1,39 @@ +from test import TestCase, generate_random_vector +import struct + +class DebugDigestTest(TestCase): + def getname(self): + return "[regression] DEBUG DIGEST-VALUE with attributes" + + def test(self): + # Generate random vectors + vec1 = generate_random_vector(4) + vec2 = generate_random_vector(4) + vec_bytes1 = struct.pack('4f', *vec1) + vec_bytes2 = struct.pack('4f', *vec2) + + # Add vectors to the key, one with attribute, one without + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes1, f'{self.test_key}:item:1') + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes2, f'{self.test_key}:item:2', 'SETATTR', '{"color":"red"}') + + # Call DEBUG DIGEST-VALUE on the key + try: + digest1 = self.redis.execute_command('DEBUG', 'DIGEST-VALUE', self.test_key) + assert digest1 is not None, "DEBUG DIGEST-VALUE should return a value" + + # Change attribute and verify digest changes + self.redis.execute_command('VSETATTR', self.test_key, f'{self.test_key}:item:2', '{"color":"blue"}') + + digest2 = self.redis.execute_command('DEBUG', 'DIGEST-VALUE', self.test_key) + assert digest2 is not None, "DEBUG DIGEST-VALUE should return a value after attribute change" + assert digest1 != digest2, "Digest should change when an attribute is modified" + + # Remove attribute and verify digest changes again + self.redis.execute_command('VSETATTR', self.test_key, f'{self.test_key}:item:2', '') + + digest3 = self.redis.execute_command('DEBUG', 'DIGEST-VALUE', self.test_key) + assert digest3 is not None, "DEBUG DIGEST-VALUE should return a value after attribute removal" + assert digest2 != digest3, "Digest should change when an attribute is removed" + + except Exception as e: + raise AssertionError(f"DEBUG DIGEST-VALUE command failed: {str(e)}") diff --git a/examples/redis-unstable/modules/vector-sets/tests/deletion.py b/examples/redis-unstable/modules/vector-sets/tests/deletion.py new file mode 100644 index 0000000..cb91959 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/deletion.py @@ -0,0 +1,173 @@ +from test import TestCase, fill_redis_with_vectors, generate_random_vector +import random + +""" +A note about this test: +It was experimentally tried to modify hnsw.c in order to +avoid calling hnsw_reconnect_nodes(). In this case, the test +fails very often with EF set to 250, while it hardly +fails at all with the same parameters if hnsw_reconnect_nodes() +is called. + +Note that for the nature of the test (it is very strict) it can +still fail from time to time, without this signaling any +actual bug. +""" + +class VREM(TestCase): + def getname(self): + return "Deletion and graph state after deletion" + + def estimated_runtime(self): + return 2.0 + + def format_neighbors_with_scores(self, links_result, old_links=None, items_to_remove=None): + """Format neighbors with their similarity scores and status indicators""" + if not links_result: + return "No neighbors" + + output = [] + for level, neighbors in enumerate(links_result): + level_num = len(links_result) - level - 1 + output.append(f"Level {level_num}:") + + # Get neighbors and scores + neighbors_with_scores = [] + for i in range(0, len(neighbors), 2): + neighbor = neighbors[i].decode() if isinstance(neighbors[i], bytes) else neighbors[i] + score = float(neighbors[i+1]) if i+1 < len(neighbors) else None + status = "" + + # For old links, mark deleted ones + if items_to_remove and neighbor in items_to_remove: + status = " [lost]" + # For new links, mark newly added ones + elif old_links is not None: + # Check if this neighbor was in the old links at this level + was_present = False + if old_links and level < len(old_links): + old_neighbors = [n.decode() if isinstance(n, bytes) else n + for n in old_links[level]] + was_present = neighbor in old_neighbors + if not was_present: + status = " [gained]" + + if score is not None: + neighbors_with_scores.append(f"{len(neighbors_with_scores)+1}. {neighbor} ({score:.6f}){status}") + else: + neighbors_with_scores.append(f"{len(neighbors_with_scores)+1}. {neighbor}{status}") + + output.extend([" " + n for n in neighbors_with_scores]) + return "\n".join(output) + + def test(self): + # 1. Fill server with random elements + dim = 128 + count = 5000 + data = fill_redis_with_vectors(self.redis, self.test_key, count, dim) + + # 2. Do VSIM to get 200 items + query_vec = generate_random_vector(dim) + results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim, + *[str(x) for x in query_vec], + 'COUNT', 200, 'WITHSCORES') + + # Convert results to list of (item, score) pairs, sorted by score + items = [] + for i in range(0, len(results), 2): + item = results[i].decode() + score = float(results[i+1]) + items.append((item, score)) + items.sort(key=lambda x: x[1], reverse=True) # Sort by similarity + + # Store the graph structure for all items before deletion + neighbors_before = {} + for item, _ in items: + links = self.redis.execute_command('VLINKS', self.test_key, item, 'WITHSCORES') + if links: # Some items might not have links + neighbors_before[item] = links + + # 3. Remove 100 random items + items_to_remove = set(item for item, _ in random.sample(items, 100)) + # Keep track of top 10 non-removed items + top_remaining = [] + for item, score in items: + if item not in items_to_remove: + top_remaining.append((item, score)) + if len(top_remaining) == 10: + break + + # Remove the items + for item in items_to_remove: + result = self.redis.execute_command('VREM', self.test_key, item) + assert result == 1, f"VREM failed to remove {item}" + + # 4. Do VSIM again with same vector + new_results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim, + *[str(x) for x in query_vec], + 'COUNT', 200, 'WITHSCORES', + 'EF', 500) + + # Convert new results to dict of item -> score + new_scores = {} + for i in range(0, len(new_results), 2): + item = new_results[i].decode() + score = float(new_results[i+1]) + new_scores[item] = score + + failure = False + failed_item = None + failed_reason = None + # 5. Verify all top 10 non-removed items are still found with similar scores + for item, old_score in top_remaining: + if item not in new_scores: + failure = True + failed_item = item + failed_reason = "missing" + break + new_score = new_scores[item] + if abs(new_score - old_score) >= 0.01: + failure = True + failed_item = item + failed_reason = f"score changed: {old_score:.6f} -> {new_score:.6f}" + break + + if failure: + print("\nTest failed!") + print(f"Problem with item: {failed_item} ({failed_reason})") + + print("\nOriginal neighbors (with similarity scores):") + if failed_item in neighbors_before: + print(self.format_neighbors_with_scores( + neighbors_before[failed_item], + items_to_remove=items_to_remove)) + else: + print("No neighbors found in original graph") + + print("\nCurrent neighbors (with similarity scores):") + current_links = self.redis.execute_command('VLINKS', self.test_key, + failed_item, 'WITHSCORES') + if current_links: + print(self.format_neighbors_with_scores( + current_links, + old_links=neighbors_before.get(failed_item))) + else: + print("No neighbors in current graph") + + print("\nOriginal results (top 20):") + for item, score in items[:20]: + deleted = "[deleted]" if item in items_to_remove else "" + print(f"{item}: {score:.6f} {deleted}") + + print("\nNew results after removal (top 20):") + new_items = [] + for i in range(0, len(new_results), 2): + item = new_results[i].decode() + score = float(new_results[i+1]) + new_items.append((item, score)) + new_items.sort(key=lambda x: x[1], reverse=True) + for item, score in new_items[:20]: + print(f"{item}: {score:.6f}") + + raise AssertionError(f"Test failed: Problem with item {failed_item} ({failed_reason}). *** IMPORTANT *** This test may fail from time to time without indicating that there is a bug. However normally it should pass. The fact is that it's a quite extreme test where we destroy 50% of nodes of top results and still expect perfect recall, with vectors that are very hostile because of the distribution used.") + diff --git a/examples/redis-unstable/modules/vector-sets/tests/dimension_validation.py b/examples/redis-unstable/modules/vector-sets/tests/dimension_validation.py new file mode 100644 index 0000000..f081152 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/dimension_validation.py @@ -0,0 +1,67 @@ +from test import TestCase, generate_random_vector +import struct +import redis.exceptions + +class DimensionValidation(TestCase): + def getname(self): + return "[regression] Dimension Validation with Projection" + + def estimated_runtime(self): + return 0.5 + + def test(self): + # Test scenario 1: Create a set with projection + original_dim = 100 + reduced_dim = 50 + + # Create the initial vector and set with projection + vec1 = generate_random_vector(original_dim) + vec1_bytes = struct.pack(f'{original_dim}f', *vec1) + + # Add first vector with projection + result = self.redis.execute_command('VADD', self.test_key, + 'REDUCE', reduced_dim, + 'FP32', vec1_bytes, f'{self.test_key}:item:1') + assert result == 1, "First VADD with REDUCE should return 1" + + # Check VINFO returns the correct projection information + info = self.redis.execute_command('VINFO', self.test_key) + info_map = {k.decode('utf-8'): v for k, v in zip(info[::2], info[1::2])} + assert 'vector-dim' in info_map, "VINFO should contain vector-dim" + assert info_map['vector-dim'] == reduced_dim, f"Expected reduced dimension {reduced_dim}, got {info['vector-dim']}" + assert 'projection-input-dim' in info_map, "VINFO should contain projection-input-dim" + assert info_map['projection-input-dim'] == original_dim, f"Expected original dimension {original_dim}, got {info['projection-input-dim']}" + + # Test scenario 2: Try adding a mismatched vector - should fail + wrong_dim = 80 + wrong_vec = generate_random_vector(wrong_dim) + wrong_vec_bytes = struct.pack(f'{wrong_dim}f', *wrong_vec) + + # This should fail with dimension mismatch error + try: + self.redis.execute_command('VADD', self.test_key, + 'REDUCE', reduced_dim, + 'FP32', wrong_vec_bytes, f'{self.test_key}:item:2') + assert False, "VADD with wrong dimension should fail" + except redis.exceptions.ResponseError as e: + assert "Input dimension mismatch for projection" in str(e), f"Expected dimension mismatch error, got: {e}" + + # Test scenario 3: Add a correctly-sized vector + vec2 = generate_random_vector(original_dim) + vec2_bytes = struct.pack(f'{original_dim}f', *vec2) + + # This should succeed + result = self.redis.execute_command('VADD', self.test_key, + 'REDUCE', reduced_dim, + 'FP32', vec2_bytes, f'{self.test_key}:item:3') + assert result == 1, "VADD with correct dimensions should succeed" + + # Check VSIM also validates input dimensions + wrong_query = generate_random_vector(wrong_dim) + try: + self.redis.execute_command('VSIM', self.test_key, + 'VALUES', wrong_dim, *[str(x) for x in wrong_query], + 'COUNT', 10) + assert False, "VSIM with wrong dimension should fail" + except redis.exceptions.ResponseError as e: + assert "Input dimension mismatch for projection" in str(e), f"Expected dimension mismatch error in VSIM, got: {e}" diff --git a/examples/redis-unstable/modules/vector-sets/tests/epsilon.py b/examples/redis-unstable/modules/vector-sets/tests/epsilon.py new file mode 100644 index 0000000..97e11c0 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/epsilon.py @@ -0,0 +1,77 @@ +from test import TestCase + +class EpsilonOption(TestCase): + def getname(self): + return "VSIM EPSILON option filtering" + + def estimated_runtime(self): + return 0.1 + + def test(self): + # Add vectors as shown in the example + # Vector 'a' at (1, 1) - normalized to (0.707, 0.707) + result = self.redis.execute_command('VADD', self.test_key, 'VALUES', '2', '1', '1', 'a') + assert result == 1, "VADD should return 1 for item 'a'" + + # Vector 'b' at (0, 1) - normalized to (0, 1) + result = self.redis.execute_command('VADD', self.test_key, 'VALUES', '2', '0', '1', 'b') + assert result == 1, "VADD should return 1 for item 'b'" + + # Vector 'c' at (0, 0) - this will be a zero vector, might be handled specially + result = self.redis.execute_command('VADD', self.test_key, 'VALUES', '2', '0', '0', 'c') + assert result == 1, "VADD should return 1 for item 'c'" + + # Vector 'd' at (0, -1) - normalized to (0, -1) + result = self.redis.execute_command('VADD', self.test_key, 'VALUES', '2', '0', '-1', 'd') + assert result == 1, "VADD should return 1 for item 'd'" + + # Vector 'e' at (-1, -1) - normalized to (-0.707, -0.707) + result = self.redis.execute_command('VADD', self.test_key, 'VALUES', '2', '-1', '-1', 'e') + assert result == 1, "VADD should return 1 for item 'e'" + + # Test without EPSILON - should return all items + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', '2', '1', '1', 'WITHSCORES') + # Result is a flat list: [elem1, score1, elem2, score2, ...] + elements_all = [result[i].decode() for i in range(0, len(result), 2)] + scores_all = [float(result[i]) for i in range(1, len(result), 2)] + + assert len(elements_all) == 5, f"Should return 5 elements without EPSILON, got {len(elements_all)}" + assert elements_all[0] == 'a', "First element should be 'a' (most similar)" + assert scores_all[0] == 1.0, "Score for 'a' should be 1.0 (identical)" + + # Test with EPSILON 0.5 - should return only elements with similarity >= 0.5 (distance < 0.5) + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', '2', '1', '1', 'WITHSCORES', 'EPSILON', '0.5') + elements_epsilon_0_5 = [result[i].decode() for i in range(0, len(result), 2)] + scores_epsilon_0_5 = [float(result[i]) for i in range(1, len(result), 2)] + + assert len(elements_epsilon_0_5) == 3, f"With EPSILON 0.5, should return 3 elements, got {len(elements_epsilon_0_5)}" + assert set(elements_epsilon_0_5) == {'a', 'b', 'c'}, f"With EPSILON 0.5, should get a, b, c, got {elements_epsilon_0_5}" + + # Verify all returned scores are >= 0.5 + for i, score in enumerate(scores_epsilon_0_5): + assert score >= 0.5, f"Element {elements_epsilon_0_5[i]} has score {score} which is < 0.5" + + # Test with EPSILON 0.2 - should return only elements with similarity >= 0.8 (distance < 0.2) + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', '2', '1', '1', 'WITHSCORES', 'EPSILON', '0.2') + elements_epsilon_0_2 = [result[i].decode() for i in range(0, len(result), 2)] + scores_epsilon_0_2 = [float(result[i]) for i in range(1, len(result), 2)] + + assert len(elements_epsilon_0_2) == 2, f"With EPSILON 0.2, should return 2 elements, got {len(elements_epsilon_0_2)}" + assert set(elements_epsilon_0_2) == {'a', 'b'}, f"With EPSILON 0.2, should get a, b, got {elements_epsilon_0_2}" + + # Verify all returned scores are >= 0.8 (since distance < 0.2 means similarity > 0.8) + for i, score in enumerate(scores_epsilon_0_2): + assert score >= 0.8, f"Element {elements_epsilon_0_2[i]} has score {score} which is < 0.8" + + # Test with very small EPSILON - should return only the exact match + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', '2', '1', '1', 'WITHSCORES', 'EPSILON', '0.001') + elements_epsilon_small = [result[i].decode() for i in range(0, len(result), 2)] + + assert len(elements_epsilon_small) == 1, f"With EPSILON 0.001, should return only 1 element, got {len(elements_epsilon_small)}" + assert elements_epsilon_small[0] == 'a', "With very small EPSILON, should only get 'a'" + + # Test with EPSILON 1.0 - should return all elements (since all similarities are between 0 and 1) + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', '2', '1', '1', 'WITHSCORES', 'EPSILON', '1.0') + elements_epsilon_1 = [result[i].decode() for i in range(0, len(result), 2)] + + assert len(elements_epsilon_1) == 5, f"With EPSILON 1.0, should return all 5 elements, got {len(elements_epsilon_1)}" diff --git a/examples/redis-unstable/modules/vector-sets/tests/evict_empty.py b/examples/redis-unstable/modules/vector-sets/tests/evict_empty.py new file mode 100644 index 0000000..6c78c82 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/evict_empty.py @@ -0,0 +1,27 @@ +from test import TestCase, generate_random_vector +import struct + +class VREM_LastItemDeletesKey(TestCase): + def getname(self): + return "VREM last item deletes key" + + def test(self): + # Generate a random vector + vec = generate_random_vector(4) + vec_bytes = struct.pack('4f', *vec) + + # Add the vector to the key + result = self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, f'{self.test_key}:item:1') + assert result == 1, "VADD should return 1 for first item" + + # Verify the key exists + exists = self.redis.exists(self.test_key) + assert exists == 1, "Key should exist after VADD" + + # Remove the item + result = self.redis.execute_command('VREM', self.test_key, f'{self.test_key}:item:1') + assert result == 1, "VREM should return 1 for successful removal" + + # Verify the key no longer exists + exists = self.redis.exists(self.test_key) + assert exists == 0, "Key should no longer exist after VREM of last item" diff --git a/examples/redis-unstable/modules/vector-sets/tests/filter_expr.py b/examples/redis-unstable/modules/vector-sets/tests/filter_expr.py new file mode 100644 index 0000000..364915d --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/filter_expr.py @@ -0,0 +1,242 @@ +from test import TestCase + +class VSIMFilterExpressions(TestCase): + def getname(self): + return "VSIM FILTER expressions basic functionality" + + def test(self): + # Create a small set of vectors with different attributes + + # Basic vectors for testing - all orthogonal for clear results + vec1 = [1, 0, 0, 0] + vec2 = [0, 1, 0, 0] + vec3 = [0, 0, 1, 0] + vec4 = [0, 0, 0, 1] + vec5 = [0.5, 0.5, 0, 0] + + # Add vectors with various attributes + self.redis.execute_command('VADD', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], f'{self.test_key}:item:1') + self.redis.execute_command('VSETATTR', self.test_key, f'{self.test_key}:item:1', + '{"age": 25, "name": "Alice", "active": true, "scores": [85, 90, 95], "city": "New York"}') + + self.redis.execute_command('VADD', self.test_key, 'VALUES', 4, + *[str(x) for x in vec2], f'{self.test_key}:item:2') + self.redis.execute_command('VSETATTR', self.test_key, f'{self.test_key}:item:2', + '{"age": 30, "name": "Bob", "active": false, "scores": [70, 75, 80], "city": "Boston"}') + + self.redis.execute_command('VADD', self.test_key, 'VALUES', 4, + *[str(x) for x in vec3], f'{self.test_key}:item:3') + self.redis.execute_command('VSETATTR', self.test_key, f'{self.test_key}:item:3', + '{"age": 35, "name": "Charlie", "scores": [60, 65, 70], "city": "Seattle"}') + + self.redis.execute_command('VADD', self.test_key, 'VALUES', 4, + *[str(x) for x in vec4], f'{self.test_key}:item:4') + # Item 4 has no attribute at all + + self.redis.execute_command('VADD', self.test_key, 'VALUES', 4, + *[str(x) for x in vec5], f'{self.test_key}:item:5') + self.redis.execute_command('VSETATTR', self.test_key, f'{self.test_key}:item:5', + 'invalid json') # Intentionally malformed JSON + + # Basic equality with numbers + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age == 25') + assert len(result) == 1, "Expected 1 result for age == 25" + assert result[0].decode() == f'{self.test_key}:item:1', "Expected item:1 for age == 25" + + # Greater than + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age > 25') + assert len(result) == 2, "Expected 2 results for age > 25" + + # Less than or equal + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age <= 30') + assert len(result) == 2, "Expected 2 results for age <= 30" + + # String equality + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.name == "Alice"') + assert len(result) == 1, "Expected 1 result for name == Alice" + assert result[0].decode() == f'{self.test_key}:item:1', "Expected item:1 for name == Alice" + + # String inequality + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.name != "Alice"') + assert len(result) == 2, "Expected 2 results for name != Alice" + + # Boolean value + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.active') + assert len(result) == 1, "Expected 1 result for .active being true" + + # Logical AND + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age > 20 and .age < 30') + assert len(result) == 1, "Expected 1 result for 20 < age < 30" + assert result[0].decode() == f'{self.test_key}:item:1', "Expected item:1 for 20 < age < 30" + + # Logical OR + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age < 30 or .age > 35') + assert len(result) == 1, "Expected 1 result for age < 30 or age > 35" + + # Logical NOT + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '!(.age == 25)') + assert len(result) == 2, "Expected 2 results for NOT(age == 25)" + + # The "in" operator with array + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age in [25, 35]') + assert len(result) == 2, "Expected 2 results for age in [25, 35]" + + # The "in" operator with strings in array + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.name in ["Alice", "David"]') + assert len(result) == 1, "Expected 1 result for name in [Alice, David]" + + # The "in" operator for substring matching + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '"lic" in .name') + assert len(result) == 1, "Expected 1 result for 'lic' in name" + assert result[0].decode() == f'{self.test_key}:item:1', "Expected item:1 (Alice)" + + # The "in" operator with city substring + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '"ork" in .city') + assert len(result) == 1, "Expected 1 result for 'ork' in city" + assert result[0].decode() == f'{self.test_key}:item:1', "Expected item:1 (New York)" + + # The "in" operator with no matches + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '"xyz" in .name') + assert len(result) == 0, "Expected 0 results for 'xyz' in name" + + # Off-by-one tests - substring at the beginning + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '"Ali" in .name') + assert len(result) == 1, "Expected 1 result for 'Ali' at beginning of 'Alice'" + assert result[0].decode() == f'{self.test_key}:item:1', "Expected item:1" + + # Off-by-one tests - substring at the end + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '"ice" in .name') + assert len(result) == 1, "Expected 1 result for 'ice' at end of 'Alice'" + assert result[0].decode() == f'{self.test_key}:item:1', "Expected item:1" + + # Off-by-one tests - exact match (entire string) + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '"Alice" in .name') + assert len(result) == 1, "Expected 1 result for exact match 'Alice' in 'Alice'" + assert result[0].decode() == f'{self.test_key}:item:1', "Expected item:1" + + # Off-by-one tests - single character + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '"A" in .name') + assert len(result) == 1, "Expected 1 result for single char 'A' in 'Alice'" + + # Off-by-one tests - empty string (should match all strings) + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '"" in .name') + assert len(result) == 3, "Expected 3 results for empty string (matches all strings)" + + # Off-by-one tests - non-empty strings are never substrings of "" + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.name in ""') + assert len(result) == 0, "Expected 0 results for empty string on the right of IN operator" + + # Off-by-one tests - empty string match empty string. + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '"" in .name && "" in ""') + assert len(result) == 3, "Expected empty string matching empty string" + + # Arithmetic operations - addition + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age + 10 > 40') + assert len(result) == 1, "Expected 1 result for age + 10 > 40" + + # Arithmetic operations - multiplication + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age * 2 > 60') + assert len(result) == 1, "Expected 1 result for age * 2 > 60" + + # Arithmetic operations - division + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age / 5 == 5') + assert len(result) == 1, "Expected 1 result for age / 5 == 5" + + # Arithmetic operations - modulo + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age % 2 == 0') + assert len(result) == 1, "Expected 1 result for age % 2 == 0" + + # Power operator + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age ** 2 > 900') + assert len(result) == 1, "Expected 1 result for age^2 > 900" + + # Missing attribute (should exclude items missing that attribute) + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.missing_field == "value"') + assert len(result) == 0, "Expected 0 results for missing_field == value" + + # No attribute set at all + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.any_field') + assert f'{self.test_key}:item:4' not in [item.decode() for item in result], "Item with no attribute should be excluded" + + # Malformed JSON + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.any_field') + assert f'{self.test_key}:item:5' not in [item.decode() for item in result], "Item with malformed JSON should be excluded" + + # Complex expression combining multiple operators + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '(.age > 20 and .age < 40) and (.city == "Boston" or .city == "New York")') + assert len(result) == 2, "Expected 2 results for the complex expression" + expected_items = [f'{self.test_key}:item:1', f'{self.test_key}:item:2'] + assert set([item.decode() for item in result]) == set(expected_items), "Expected item:1 and item:2 for the complex expression" + + # Parentheses to control operator precedence + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age > (20 + 10)') + assert len(result) == 1, "Expected 1 result for age > (20 + 10)" + + # Array access (arrays evaluate to true) + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.scores') + assert len(result) == 3, "Expected 3 results for .scores (arrays evaluate to true)" diff --git a/examples/redis-unstable/modules/vector-sets/tests/filter_int.py b/examples/redis-unstable/modules/vector-sets/tests/filter_int.py new file mode 100644 index 0000000..0fd1dc1 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/filter_int.py @@ -0,0 +1,668 @@ +from test import TestCase, generate_random_vector +import struct +import random +import math +import json +import time + +class VSIMFilterAdvanced(TestCase): + def getname(self): + return "VSIM FILTER comprehensive functionality testing" + + def estimated_runtime(self): + return 15 # This test might take up to 15 seconds for the large dataset + + def setup(self): + super().setup() + self.dim = 32 # Vector dimension + self.count = 5000 # Number of vectors for large tests + self.small_count = 50 # Number of vectors for small/quick tests + + # Categories for attributes + self.categories = ["electronics", "furniture", "clothing", "books", "food"] + self.cities = ["New York", "London", "Tokyo", "Paris", "Berlin", "Sydney", "Toronto", "Singapore"] + self.price_ranges = [(10, 50), (50, 200), (200, 1000), (1000, 5000)] + self.years = list(range(2000, 2025)) + + def create_attributes(self, index): + """Create realistic attributes for a vector""" + category = random.choice(self.categories) + city = random.choice(self.cities) + min_price, max_price = random.choice(self.price_ranges) + price = round(random.uniform(min_price, max_price), 2) + year = random.choice(self.years) + in_stock = random.random() > 0.3 # 70% chance of being in stock + rating = round(random.uniform(1, 5), 1) + views = int(random.expovariate(1/1000)) # Exponential distribution for page views + tags = random.sample(["popular", "sale", "new", "limited", "exclusive", "clearance"], + k=random.randint(0, 3)) + + # Add some specific patterns for testing + # Every 10th item has a specific property combination for testing + is_premium = (index % 10 == 0) + + # Create attributes dictionary + attrs = { + "id": index, + "category": category, + "location": city, + "price": price, + "year": year, + "in_stock": in_stock, + "rating": rating, + "views": views, + "tags": tags + } + + if is_premium: + attrs["is_premium"] = True + attrs["special_features"] = ["premium", "warranty", "support"] + + # Add sub-categories for more complex filters + if category == "electronics": + attrs["subcategory"] = random.choice(["phones", "computers", "cameras", "audio"]) + elif category == "furniture": + attrs["subcategory"] = random.choice(["chairs", "tables", "sofas", "beds"]) + elif category == "clothing": + attrs["subcategory"] = random.choice(["shirts", "pants", "dresses", "shoes"]) + + # Add some intentionally missing fields for testing + if random.random() > 0.9: # 10% chance of missing price + del attrs["price"] + + # Some items have promotion field + if random.random() > 0.7: # 30% chance of having a promotion + attrs["promotion"] = random.choice(["discount", "bundle", "gift"]) + + # Create invalid JSON for a small percentage of vectors + if random.random() > 0.98: # 2% chance of having invalid JSON + return "{{invalid json}}" + + return json.dumps(attrs) + + def create_vectors_with_attributes(self, key, count): + """Create vectors and add attributes to them""" + vectors = [] + names = [] + attribute_map = {} # To store attributes for verification + + # Create vectors + for i in range(count): + vec = generate_random_vector(self.dim) + vectors.append(vec) + name = f"{key}:item:{i}" + names.append(name) + + # Add to Redis + vec_bytes = struct.pack(f'{self.dim}f', *vec) + self.redis.execute_command('VADD', key, 'FP32', vec_bytes, name) + + # Create and add attributes + attrs = self.create_attributes(i) + self.redis.execute_command('VSETATTR', key, name, attrs) + + # Store attributes for later verification + try: + attribute_map[name] = json.loads(attrs) if '{' in attrs else None + except json.JSONDecodeError: + attribute_map[name] = None + + return vectors, names, attribute_map + + def filter_linear_search(self, vectors, names, query_vector, filter_expr, attribute_map, k=10): + """Perform a linear search with filtering for verification""" + similarities = [] + query_norm = math.sqrt(sum(x*x for x in query_vector)) + + if query_norm == 0: + return [] + + for i, vec in enumerate(vectors): + name = names[i] + attributes = attribute_map.get(name) + + # Skip if doesn't match filter + if not self.matches_filter(attributes, filter_expr): + continue + + vec_norm = math.sqrt(sum(x*x for x in vec)) + if vec_norm == 0: + continue + + dot_product = sum(a*b for a,b in zip(query_vector, vec)) + cosine_sim = dot_product / (query_norm * vec_norm) + distance = 1.0 - cosine_sim + redis_similarity = 1.0 - (distance/2.0) + similarities.append((name, redis_similarity)) + + similarities.sort(key=lambda x: x[1], reverse=True) + return similarities[:k] + + def matches_filter(self, attributes, filter_expr): + """Filter matching for verification - uses Python eval to handle complex expressions""" + if attributes is None: + return False # No attributes or invalid JSON + + # Replace JSON path selectors with Python dictionary access + py_expr = filter_expr + + # Handle `.field` notation (replace with attributes['field']) + i = 0 + while i < len(py_expr): + if py_expr[i] == '.' and (i == 0 or not py_expr[i-1].isalnum()): + # Find the end of the selector (stops at operators or whitespace) + j = i + 1 + while j < len(py_expr) and (py_expr[j].isalnum() or py_expr[j] == '_'): + j += 1 + + if j > i + 1: # Found a valid selector + field = py_expr[i+1:j] + # Use a safe access pattern that returns a default value based on context + py_expr = py_expr[:i] + f"attributes.get('{field}')" + py_expr[j:] + i = i + len(f"attributes.get('{field}')") + else: + i += 1 + else: + i += 1 + + # Convert not operator if needed + py_expr = py_expr.replace('!', ' not ') + + try: + # Custom evaluation that handles exceptions for missing fields + # by returning False for the entire expression + + # Split the expression on logical operators + parts = [] + for op in [' and ', ' or ']: + if op in py_expr: + parts = py_expr.split(op) + break + + if not parts: # No logical operators found + parts = [py_expr] + + # Try to evaluate each part - if any part fails, + # the whole expression should fail + try: + result = eval(py_expr, {"attributes": attributes}) + return bool(result) + except (TypeError, AttributeError): + # This typically happens when trying to compare None with + # numbers or other types, or when an attribute doesn't exist + return False + except Exception as e: + print(f"Error evaluating filter expression '{filter_expr}' as '{py_expr}': {e}") + return False + + except Exception as e: + print(f"Error evaluating filter expression '{filter_expr}' as '{py_expr}': {e}") + return False + + def safe_decode(self,item): + return item.decode() if isinstance(item, bytes) else item + + def calculate_recall(self, redis_results, linear_results, k=10): + """Calculate recall (percentage of correct results retrieved)""" + redis_set = set(self.safe_decode(item) for item in redis_results) + linear_set = set(item[0] for item in linear_results[:k]) + + if not linear_set: + return 1.0 # If no linear results, consider it perfect recall + + intersection = redis_set.intersection(linear_set) + return len(intersection) / len(linear_set) + + def test_recall_with_filter(self, filter_expr, ef=500, filter_ef=None): + """Test recall for a given filter expression""" + # Create query vector + query_vec = generate_random_vector(self.dim) + + # First, get ground truth using linear scan + linear_results = self.filter_linear_search( + self.vectors, self.names, query_vec, filter_expr, self.attribute_map, k=50) + + # Calculate true selectivity from ground truth + true_selectivity = len(linear_results) / len(self.names) if self.names else 0 + + # Perform Redis search with filter + cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim] + cmd_args.extend([str(x) for x in query_vec]) + cmd_args.extend(['COUNT', 50, 'WITHSCORES', 'EF', ef, 'FILTER', filter_expr]) + if filter_ef: + cmd_args.extend(['FILTER-EF', filter_ef]) + + start_time = time.time() + redis_results = self.redis.execute_command(*cmd_args) + query_time = time.time() - start_time + + # Convert Redis results to dict + redis_items = {} + for i in range(0, len(redis_results), 2): + key = redis_results[i].decode() if isinstance(redis_results[i], bytes) else redis_results[i] + score = float(redis_results[i+1]) + redis_items[key] = score + + # Calculate metrics + recall = self.calculate_recall(redis_items.keys(), linear_results) + selectivity = len(redis_items) / len(self.names) if redis_items else 0 + + # Compare against the true selectivity from linear scan + assert abs(selectivity - true_selectivity) < 0.1, \ + f"Redis selectivity {selectivity:.3f} differs significantly from ground truth {true_selectivity:.3f}" + + # We expect high recall for standard parameters + if ef >= 500 and (filter_ef is None or filter_ef >= 1000): + try: + assert recall >= 0.7, \ + f"Low recall {recall:.2f} for filter '{filter_expr}'" + except AssertionError as e: + # Get items found in each set + redis_items_set = set(redis_items.keys()) + linear_items_set = set(item[0] for item in linear_results) + + # Find items in each set + only_in_redis = redis_items_set - linear_items_set + only_in_linear = linear_items_set - redis_items_set + in_both = redis_items_set & linear_items_set + + # Build comprehensive debug message + debug = f"\nGround Truth: {len(linear_results)} matching items (total vectors: {len(self.vectors)})" + debug += f"\nRedis Found: {len(redis_items)} items with FILTER-EF: {filter_ef or 'default'}" + debug += f"\nItems in both sets: {len(in_both)} (recall: {recall:.4f})" + debug += f"\nItems only in Redis: {len(only_in_redis)}" + debug += f"\nItems only in Ground Truth: {len(only_in_linear)}" + + # Show some example items from each set with their scores + if only_in_redis: + debug += "\n\nTOP 5 ITEMS ONLY IN REDIS:" + sorted_redis = sorted([(k, v) for k, v in redis_items.items()], key=lambda x: x[1], reverse=True) + for i, (item, score) in enumerate(sorted_redis[:5]): + if item in only_in_redis: + debug += f"\n {i+1}. {item} (Score: {score:.4f})" + + # Show attribute that should match filter + attr = self.attribute_map.get(item) + if attr: + debug += f" - Attrs: {attr.get('category', 'N/A')}, Price: {attr.get('price', 'N/A')}" + + if only_in_linear: + debug += "\n\nTOP 5 ITEMS ONLY IN GROUND TRUTH:" + for i, (item, score) in enumerate(linear_results[:5]): + if item in only_in_linear: + debug += f"\n {i+1}. {item} (Score: {score:.4f})" + + # Show attribute that should match filter + attr = self.attribute_map.get(item) + if attr: + debug += f" - Attrs: {attr.get('category', 'N/A')}, Price: {attr.get('price', 'N/A')}" + + # Help identify parsing issues + debug += "\n\nPARSING CHECK:" + debug += f"\nRedis command: VSIM {self.test_key} VALUES {self.dim} [...] FILTER '{filter_expr}'" + + # Check for WITHSCORES handling issues + if len(redis_results) > 0 and len(redis_results) % 2 == 0: + debug += f"\nRedis returned {len(redis_results)} items (looks like item,score pairs)" + debug += f"\nFirst few results: {redis_results[:4]}" + + # Check the filter implementation + debug += "\n\nFILTER IMPLEMENTATION CHECK:" + debug += f"\nFilter expression: '{filter_expr}'" + debug += "\nSample attribute matches from attribute_map:" + count_matching = 0 + for i, (name, attrs) in enumerate(self.attribute_map.items()): + if attrs and self.matches_filter(attrs, filter_expr): + count_matching += 1 + if i < 3: # Show first 3 matches + debug += f"\n - {name}: {attrs}" + debug += f"\nTotal items matching filter in attribute_map: {count_matching}" + + # Check if results array handling could be wrong + debug += "\n\nRESULT ARRAYS CHECK:" + if len(linear_results) >= 1: + debug += f"\nlinear_results[0]: {linear_results[0]}" + if isinstance(linear_results[0], tuple) and len(linear_results[0]) == 2: + debug += " (correct tuple format: (name, score))" + else: + debug += " (UNEXPECTED FORMAT!)" + + # Debug sort order + debug += "\n\nSORTING CHECK:" + if len(linear_results) >= 2: + debug += f"\nGround truth first item score: {linear_results[0][1]}" + debug += f"\nGround truth second item score: {linear_results[1][1]}" + debug += f"\nCorrectly sorted by similarity? {linear_results[0][1] >= linear_results[1][1]}" + + # Re-raise with detailed information + raise AssertionError(str(e) + debug) + + return recall, selectivity, query_time, len(redis_items) + + def test(self): + print(f"\nRunning comprehensive VSIM FILTER tests...") + + # Create a larger dataset for testing + print(f"Creating dataset with {self.count} vectors and attributes...") + self.vectors, self.names, self.attribute_map = self.create_vectors_with_attributes( + self.test_key, self.count) + + # ==== 1. Recall and Precision Testing ==== + print("Testing recall for various filters...") + + # Test basic filters with different selectivity + results = {} + results["category"] = self.test_recall_with_filter('.category == "electronics"') + results["price_high"] = self.test_recall_with_filter('.price > 1000') + results["in_stock"] = self.test_recall_with_filter('.in_stock') + results["rating"] = self.test_recall_with_filter('.rating >= 4') + results["complex1"] = self.test_recall_with_filter('.category == "electronics" and .price < 500') + + print("Filter | Recall | Selectivity | Time (ms) | Results") + print("----------------------------------------------------") + for name, (recall, selectivity, time_ms, count) in results.items(): + print(f"{name:7} | {recall:.3f} | {selectivity:.3f} | {time_ms*1000:.1f} | {count}") + + # ==== 2. Filter Selectivity Performance ==== + print("\nTesting filter selectivity performance...") + + # High selectivity (very few matches) + high_sel_recall, _, high_sel_time, _ = self.test_recall_with_filter('.is_premium') + + # Medium selectivity + med_sel_recall, _, med_sel_time, _ = self.test_recall_with_filter('.price > 100 and .price < 1000') + + # Low selectivity (many matches) + low_sel_recall, _, low_sel_time, _ = self.test_recall_with_filter('.year > 2000') + + print(f"High selectivity recall: {high_sel_recall:.3f}, time: {high_sel_time*1000:.1f}ms") + print(f"Med selectivity recall: {med_sel_recall:.3f}, time: {med_sel_time*1000:.1f}ms") + print(f"Low selectivity recall: {low_sel_recall:.3f}, time: {low_sel_time*1000:.1f}ms") + + # ==== 3. FILTER-EF Parameter Testing ==== + print("\nTesting FILTER-EF parameter...") + + # Test with different FILTER-EF values + filter_expr = '.category == "electronics" and .price > 200' + ef_values = [100, 500, 2000, 5000] + + print("FILTER-EF | Recall | Time (ms)") + print("-----------------------------") + for filter_ef in ef_values: + recall, _, query_time, _ = self.test_recall_with_filter( + filter_expr, ef=500, filter_ef=filter_ef) + print(f"{filter_ef:9} | {recall:.3f} | {query_time*1000:.1f}") + + # Assert that higher FILTER-EF generally gives better recall + low_ef_recall, _, _, _ = self.test_recall_with_filter(filter_expr, filter_ef=100) + high_ef_recall, _, _, _ = self.test_recall_with_filter(filter_expr, filter_ef=5000) + + # This might not always be true due to randomness, but generally holds + # We use a softer assertion to avoid flaky tests + assert high_ef_recall >= low_ef_recall * 0.8, \ + f"Higher FILTER-EF should generally give better recall: {high_ef_recall:.3f} vs {low_ef_recall:.3f}" + + # ==== 4. Complex Filter Expressions ==== + print("\nTesting complex filter expressions...") + + # Test a variety of complex expressions + complex_filters = [ + '.price > 100 and (.category == "electronics" or .category == "furniture")', + '(.rating > 4 and .in_stock) or (.price < 50 and .views > 1000)', + '.category in ["electronics", "clothing"] and .price > 200 and .rating >= 3', + '(.category == "electronics" and .subcategory == "phones") or (.category == "furniture" and .price > 1000)', + '.year > 2010 and !(.price < 100) and .in_stock' + ] + + print("Expression | Results | Time (ms)") + print("-----------------------------") + for i, expr in enumerate(complex_filters): + try: + _, _, query_time, result_count = self.test_recall_with_filter(expr) + print(f"Complex {i+1} | {result_count:7} | {query_time*1000:.1f}") + except Exception as e: + print(f"Complex {i+1} | Error: {str(e)}") + + # ==== 5. Attribute Type Testing ==== + print("\nTesting different attribute types...") + + type_filters = [ + ('.price > 500', "Numeric"), + ('.category == "books"', "String equality"), + ('.in_stock', "Boolean"), + ('.tags in ["sale", "new"]', "Array membership"), + ('.rating * 2 > 8', "Arithmetic") + ] + + for expr, type_name in type_filters: + try: + _, _, query_time, result_count = self.test_recall_with_filter(expr) + print(f"{type_name:16} | {expr:30} | {result_count:5} results | {query_time*1000:.1f}ms") + except Exception as e: + print(f"{type_name:16} | {expr:30} | Error: {str(e)}") + + # ==== 6. Filter + Count Interaction ==== + print("\nTesting COUNT parameter with filters...") + + filter_expr = '.category == "electronics"' + counts = [5, 20, 100] + + for count in counts: + query_vec = generate_random_vector(self.dim) + cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim] + cmd_args.extend([str(x) for x in query_vec]) + cmd_args.extend(['COUNT', count, 'WITHSCORES', 'FILTER', filter_expr]) + + results = self.redis.execute_command(*cmd_args) + result_count = len(results) // 2 # Divide by 2 because WITHSCORES returns pairs + + # We expect result count to be at most the requested count + assert result_count <= count, f"Got {result_count} results with COUNT {count}" + print(f"COUNT {count:3} | Got {result_count:3} results") + + # ==== 7. Edge Cases ==== + print("\nTesting edge cases...") + + # Test with no matching items + no_match_expr = '.category == "nonexistent_category"' + results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', self.dim, + *[str(x) for x in generate_random_vector(self.dim)], + 'FILTER', no_match_expr) + assert len(results) == 0, f"Expected 0 results for non-matching filter, got {len(results)}" + print(f"No matching items: {len(results)} results (expected 0)") + + # Test with invalid filter syntax + try: + self.redis.execute_command('VSIM', self.test_key, 'VALUES', self.dim, + *[str(x) for x in generate_random_vector(self.dim)], + 'FILTER', '.category === "books"') # Triple equals is invalid + assert False, "Expected error for invalid filter syntax" + except: + print("Invalid filter syntax correctly raised an error") + + # Test with extremely long complex expression + long_expr = ' and '.join([f'.rating > {i/10}' for i in range(10)]) + try: + results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', self.dim, + *[str(x) for x in generate_random_vector(self.dim)], + 'FILTER', long_expr) + print(f"Long expression: {len(results)} results") + except Exception as e: + print(f"Long expression error: {str(e)}") + + print("\nComprehensive VSIM FILTER tests completed successfully") + + +class VSIMFilterSelectivityTest(TestCase): + def getname(self): + return "VSIM FILTER selectivity performance benchmark" + + def estimated_runtime(self): + return 8 # This test might take up to 8 seconds + + def setup(self): + super().setup() + self.dim = 32 + self.count = 10000 + self.test_key = f"{self.test_key}:selectivity" # Use a different key + + def create_vector_with_age_attribute(self, name, age): + """Create a vector with a specific age attribute""" + vec = generate_random_vector(self.dim) + vec_bytes = struct.pack(f'{self.dim}f', *vec) + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, name) + self.redis.execute_command('VSETATTR', self.test_key, name, json.dumps({"age": age})) + + def test(self): + print("\nRunning VSIM FILTER selectivity benchmark...") + + # Create a dataset where we control the exact selectivity + print(f"Creating controlled dataset with {self.count} vectors...") + + # Create vectors with age attributes from 1 to 100 + for i in range(self.count): + age = (i % 100) + 1 # Ages from 1 to 100 + name = f"{self.test_key}:item:{i}" + self.create_vector_with_age_attribute(name, age) + + # Create a query vector + query_vec = generate_random_vector(self.dim) + + # Test filters with different selectivities + selectivities = [0.01, 0.05, 0.10, 0.25, 0.50, 0.75, 0.99] + results = [] + + print("\nSelectivity | Filter | Results | Time (ms)") + print("--------------------------------------------------") + + for target_selectivity in selectivities: + # Calculate age threshold for desired selectivity + # For example, age <= 10 gives 10% selectivity + age_threshold = int(target_selectivity * 100) + filter_expr = f'.age <= {age_threshold}' + + # Run query and measure time + start_time = time.time() + cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim] + cmd_args.extend([str(x) for x in query_vec]) + cmd_args.extend(['COUNT', 100, 'FILTER', filter_expr]) + + results = self.redis.execute_command(*cmd_args) + query_time = time.time() - start_time + + actual_selectivity = len(results) / min(100, int(target_selectivity * self.count)) + print(f"{target_selectivity:.2f} | {filter_expr:15} | {len(results):7} | {query_time*1000:.1f}") + + # Add assertion to ensure reasonable performance for different selectivities + # For very selective queries (1%), we might need more exploration + if target_selectivity <= 0.05: + # For very selective queries, ensure we can find some results + assert len(results) > 0, f"No results found for {filter_expr}" + else: + # For less selective queries, performance should be reasonable + assert query_time < 1.0, f"Query too slow: {query_time:.3f}s for {filter_expr}" + + print("\nSelectivity benchmark completed successfully") + + +class VSIMFilterComparisonTest(TestCase): + def getname(self): + return "VSIM FILTER EF parameter comparison" + + def estimated_runtime(self): + return 8 # This test might take up to 8 seconds + + def setup(self): + super().setup() + self.dim = 32 + self.count = 5000 + self.test_key = f"{self.test_key}:efparams" # Use a different key + + def create_dataset(self): + """Create a dataset with specific attribute patterns for testing FILTER-EF""" + vectors = [] + names = [] + + # Create vectors with category and quality score attributes + for i in range(self.count): + vec = generate_random_vector(self.dim) + name = f"{self.test_key}:item:{i}" + + # Add vector to Redis + vec_bytes = struct.pack(f'{self.dim}f', *vec) + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, name) + + # Create attributes - we want a very selective filter + # Only 2% of items have category=premium AND quality>90 + category = "premium" if random.random() < 0.1 else random.choice(["standard", "economy", "basic"]) + quality = random.randint(1, 100) + + attrs = { + "id": i, + "category": category, + "quality": quality + } + + self.redis.execute_command('VSETATTR', self.test_key, name, json.dumps(attrs)) + vectors.append(vec) + names.append(name) + + return vectors, names + + def test(self): + print("\nRunning VSIM FILTER-EF parameter comparison...") + + # Create dataset + vectors, names = self.create_dataset() + + # Create a selective filter that matches ~2% of items + filter_expr = '.category == "premium" and .quality > 90' + + # Create query vector + query_vec = generate_random_vector(self.dim) + + # Test different FILTER-EF values + ef_values = [50, 100, 500, 1000, 5000] + results = [] + + print("\nFILTER-EF | Results | Time (ms) | Notes") + print("---------------------------------------") + + baseline_count = None + + for ef in ef_values: + # Run query and measure time + start_time = time.time() + cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim] + cmd_args.extend([str(x) for x in query_vec]) + cmd_args.extend(['COUNT', 100, 'FILTER', filter_expr, 'FILTER-EF', ef]) + + query_results = self.redis.execute_command(*cmd_args) + query_time = time.time() - start_time + + # Set baseline for comparison + if baseline_count is None: + baseline_count = len(query_results) + + recall_rate = len(query_results) / max(1, baseline_count) if baseline_count > 0 else 1.0 + + notes = "" + if ef == 5000: + notes = "Baseline" + elif recall_rate < 0.5: + notes = "Low recall!" + + print(f"{ef:9} | {len(query_results):7} | {query_time*1000:.1f} | {notes}") + results.append((ef, len(query_results), query_time)) + + # If we have enough results at highest EF, check that recall improves with higher EF + if results[-1][1] >= 5: # At least 5 results for highest EF + # Extract result counts + result_counts = [r[1] for r in results] + + # The last result (highest EF) should typically find more results than the first (lowest EF) + # but we use a soft assertion to avoid flaky tests + assert result_counts[-1] >= result_counts[0], \ + f"Higher FILTER-EF should find at least as many results: {result_counts[-1]} vs {result_counts[0]}" + + print("\nFILTER-EF parameter comparison completed successfully") diff --git a/examples/redis-unstable/modules/vector-sets/tests/large_scale.py b/examples/redis-unstable/modules/vector-sets/tests/large_scale.py new file mode 100644 index 0000000..eac5dca --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/large_scale.py @@ -0,0 +1,56 @@ +from test import TestCase, fill_redis_with_vectors, generate_random_vector +import random + +class LargeScale(TestCase): + def getname(self): + return "Large Scale Comparison" + + def estimated_runtime(self): + return 10 + + def test(self): + dim = 300 + count = 20000 + k = 50 + + # Fill Redis and get reference data for comparison + random.seed(42) # Make test deterministic + data = fill_redis_with_vectors(self.redis, self.test_key, count, dim) + + # Generate query vector + query_vec = generate_random_vector(dim) + + # Get results from Redis with good exploration factor + redis_raw = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim, + *[str(x) for x in query_vec], + 'COUNT', k, 'WITHSCORES', 'EF', 500) + + # Convert Redis results to dict + redis_results = {} + for i in range(0, len(redis_raw), 2): + key = redis_raw[i].decode() + score = float(redis_raw[i+1]) + redis_results[key] = score + + # Get results from linear scan + linear_results = data.find_k_nearest(query_vec, k) + linear_items = {name: score for name, score in linear_results} + + # Compare overlap + redis_set = set(redis_results.keys()) + linear_set = set(linear_items.keys()) + overlap = len(redis_set & linear_set) + + # If test fails, print comparison for debugging + if overlap < k * 0.7: + data.print_comparison({'items': redis_results, 'query_vector': query_vec}, k) + + assert overlap >= k * 0.7, \ + f"Expected at least 70% overlap in top {k} results, got {overlap/k*100:.1f}%" + + # Verify scores for common items + for item in redis_set & linear_set: + redis_score = redis_results[item] + linear_score = linear_items[item] + assert abs(redis_score - linear_score) < 0.01, \ + f"Score mismatch for {item}: Redis={redis_score:.3f} Linear={linear_score:.3f}" diff --git a/examples/redis-unstable/modules/vector-sets/tests/memory_usage.py b/examples/redis-unstable/modules/vector-sets/tests/memory_usage.py new file mode 100644 index 0000000..d0f3f09 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/memory_usage.py @@ -0,0 +1,36 @@ +from test import TestCase, generate_random_vector +import struct + +class MemoryUsageTest(TestCase): + def getname(self): + return "[regression] MEMORY USAGE with attributes" + + def test(self): + # Generate random vectors + vec1 = generate_random_vector(4) + vec2 = generate_random_vector(4) + vec_bytes1 = struct.pack('4f', *vec1) + vec_bytes2 = struct.pack('4f', *vec2) + + # Add vectors to the key, one with attribute, one without + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes1, f'{self.test_key}:item:1') + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes2, f'{self.test_key}:item:2', 'SETATTR', '{"color":"red"}') + + # Get memory usage for the key + try: + memory_usage = self.redis.execute_command('MEMORY', 'USAGE', self.test_key) + # If we got here without exception, the command worked + assert memory_usage > 0, "MEMORY USAGE should return a positive value" + + # Add more attributes to increase complexity + self.redis.execute_command('VSETATTR', self.test_key, f'{self.test_key}:item:1', '{"color":"blue","size":10}') + + # Check memory usage again + new_memory_usage = self.redis.execute_command('MEMORY', 'USAGE', self.test_key) + assert new_memory_usage > 0, "MEMORY USAGE should still return a positive value after setting attributes" + + # Memory usage should be higher after adding attributes + assert new_memory_usage > memory_usage, "Memory usage increase after adding attributes" + + except Exception as e: + raise AssertionError(f"MEMORY USAGE command failed: {str(e)}") diff --git a/examples/redis-unstable/modules/vector-sets/tests/node_update.py b/examples/redis-unstable/modules/vector-sets/tests/node_update.py new file mode 100644 index 0000000..53aa2dd --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/node_update.py @@ -0,0 +1,85 @@ +from test import TestCase, generate_random_vector +import struct +import math +import random + +class VectorUpdateAndClusters(TestCase): + def getname(self): + return "VADD vector update with cluster relocation" + + def estimated_runtime(self): + return 2.0 # Should take around 2 seconds + + def generate_cluster_vector(self, base_vec, noise=0.1): + """Generate a vector that's similar to base_vec with some noise.""" + vec = [x + random.gauss(0, noise) for x in base_vec] + # Normalize + norm = math.sqrt(sum(x*x for x in vec)) + return [x/norm for x in vec] + + def test(self): + dim = 128 + vectors_per_cluster = 5000 + + # Create two very different base vectors for our clusters + cluster1_base = generate_random_vector(dim) + cluster2_base = [-x for x in cluster1_base] # Opposite direction + + # Add vectors from first cluster + for i in range(vectors_per_cluster): + vec = self.generate_cluster_vector(cluster1_base) + vec_bytes = struct.pack(f'{dim}f', *vec) + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, + f'{self.test_key}:cluster1:{i}') + + # Add vectors from second cluster + for i in range(vectors_per_cluster): + vec = self.generate_cluster_vector(cluster2_base) + vec_bytes = struct.pack(f'{dim}f', *vec) + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, + f'{self.test_key}:cluster2:{i}') + + # Pick a test vector from cluster1 + test_key = f'{self.test_key}:cluster1:0' + + # Verify it's in cluster1 using VSIM + initial_vec = self.generate_cluster_vector(cluster1_base) + results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim, + *[str(x) for x in initial_vec], + 'COUNT', 100, 'WITHSCORES') + + # Count how many cluster1 items are in top results + cluster1_count = sum(1 for i in range(0, len(results), 2) + if b'cluster1' in results[i]) + assert cluster1_count > 80, "Initial clustering check failed" + + # Now update the test vector to be in cluster2 + new_vec = self.generate_cluster_vector(cluster2_base, noise=0.05) + vec_bytes = struct.pack(f'{dim}f', *new_vec) + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, test_key) + + # Verify the embedding was actually updated using VEMB + emb_result = self.redis.execute_command('VEMB', self.test_key, test_key) + updated_vec = [float(x) for x in emb_result] + + # Verify updated vector matches what we inserted + dot_product = sum(a*b for a,b in zip(updated_vec, new_vec)) + similarity = dot_product / (math.sqrt(sum(x*x for x in updated_vec)) * + math.sqrt(sum(x*x for x in new_vec))) + assert similarity > 0.9, "Vector was not properly updated" + + # Verify it's now in cluster2 using VSIM + results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim, + *[str(x) for x in cluster2_base], + 'COUNT', 100, 'WITHSCORES') + + # Verify our updated vector is among top results + found = False + for i in range(0, len(results), 2): + if results[i].decode() == test_key: + found = True + similarity = float(results[i+1]) + assert similarity > 0.80, f"Updated vector has low similarity: {similarity}" + break + + assert found, "Updated vector not found in cluster2 proximity" diff --git a/examples/redis-unstable/modules/vector-sets/tests/persistence.py b/examples/redis-unstable/modules/vector-sets/tests/persistence.py new file mode 100644 index 0000000..79730f4 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/persistence.py @@ -0,0 +1,86 @@ +from test import TestCase, fill_redis_with_vectors, generate_random_vector +import random + +class HNSWPersistence(TestCase): + def getname(self): + return "HNSW Persistence" + + def estimated_runtime(self): + return 30 + + def _verify_results(self, key, dim, query_vec, reduced_dim=None): + """Run a query and return results dict""" + k = 10 + args = ['VSIM', key] + + if reduced_dim: + args.extend(['VALUES', dim]) + args.extend([str(x) for x in query_vec]) + else: + args.extend(['VALUES', dim]) + args.extend([str(x) for x in query_vec]) + + args.extend(['COUNT', k, 'WITHSCORES']) + results = self.redis.execute_command(*args) + + results_dict = {} + for i in range(0, len(results), 2): + key = results[i].decode() + score = float(results[i+1]) + results_dict[key] = score + return results_dict + + def test(self): + # Setup dimensions + dim = 128 + reduced_dim = 32 + count = 5000 + random.seed(42) + + # Create two datasets - one normal and one with dimension reduction + normal_data = fill_redis_with_vectors(self.redis, f"{self.test_key}:normal", count, dim) + projected_data = fill_redis_with_vectors(self.redis, f"{self.test_key}:projected", + count, dim, reduced_dim) + + # Generate query vectors we'll use before and after reload + query_vec_normal = generate_random_vector(dim) + query_vec_projected = generate_random_vector(dim) + + # Get initial results for both sets + initial_normal = self._verify_results(f"{self.test_key}:normal", + dim, query_vec_normal) + initial_projected = self._verify_results(f"{self.test_key}:projected", + dim, query_vec_projected, reduced_dim) + + # Force Redis to save and reload the dataset + self.redis.execute_command('DEBUG', 'RELOAD') + + # Verify results after reload + reloaded_normal = self._verify_results(f"{self.test_key}:normal", + dim, query_vec_normal) + reloaded_projected = self._verify_results(f"{self.test_key}:projected", + dim, query_vec_projected, reduced_dim) + + # Verify normal vectors results + assert len(initial_normal) == len(reloaded_normal), \ + "Normal vectors: Result count mismatch before/after reload" + + for key in initial_normal: + assert key in reloaded_normal, f"Normal vectors: Missing item after reload: {key}" + assert abs(initial_normal[key] - reloaded_normal[key]) < 0.0001, \ + f"Normal vectors: Score mismatch for {key}: " + \ + f"before={initial_normal[key]:.6f}, after={reloaded_normal[key]:.6f}" + + # Verify projected vectors results + assert len(initial_projected) == len(reloaded_projected), \ + "Projected vectors: Result count mismatch before/after reload" + + for key in initial_projected: + assert key in reloaded_projected, \ + f"Projected vectors: Missing item after reload: {key}" + assert abs(initial_projected[key] - reloaded_projected[key]) < 0.0001, \ + f"Projected vectors: Score mismatch for {key}: " + \ + f"before={initial_projected[key]:.6f}, after={reloaded_projected[key]:.6f}" + + self.redis.delete(f"{self.test_key}:normal") + self.redis.delete(f"{self.test_key}:projected") diff --git a/examples/redis-unstable/modules/vector-sets/tests/reduce.py b/examples/redis-unstable/modules/vector-sets/tests/reduce.py new file mode 100644 index 0000000..e39164f --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/reduce.py @@ -0,0 +1,71 @@ +from test import TestCase, fill_redis_with_vectors, generate_random_vector + +class Reduce(TestCase): + def getname(self): + return "Dimension Reduction" + + def estimated_runtime(self): + return 0.2 + + def test(self): + original_dim = 100 + reduced_dim = 80 + count = 1000 + k = 50 # Number of nearest neighbors to check + + # Fill Redis with vectors using REDUCE and get reference data + data = fill_redis_with_vectors(self.redis, self.test_key, count, original_dim, reduced_dim) + + # Verify dimension is reduced + dim = self.redis.execute_command('VDIM', self.test_key) + assert dim == reduced_dim, f"Expected dimension {reduced_dim}, got {dim}" + + # Generate query vector and get nearest neighbors using Redis + query_vec = generate_random_vector(original_dim) + redis_raw = self.redis.execute_command('VSIM', self.test_key, 'VALUES', + original_dim, *[str(x) for x in query_vec], + 'COUNT', k, 'WITHSCORES') + + # Convert Redis results to dict + redis_results = {} + for i in range(0, len(redis_raw), 2): + key = redis_raw[i].decode() + score = float(redis_raw[i+1]) + redis_results[key] = score + + # Get results from linear scan with original vectors + linear_results = data.find_k_nearest(query_vec, k) + linear_items = {name: score for name, score in linear_results} + + # Compare overlap between reduced and non-reduced results + redis_set = set(redis_results.keys()) + linear_set = set(linear_items.keys()) + overlap = len(redis_set & linear_set) + overlap_ratio = overlap / k + + # With random projection, we expect some loss of accuracy but should + # maintain at least some similarity structure. + # Note that gaussian distribution is the worse with this test, so + # in real world practice, things will be better. + min_expected_overlap = 0.1 # At least 10% overlap in top-k + assert overlap_ratio >= min_expected_overlap, \ + f"Dimension reduction lost too much structure. Only {overlap_ratio*100:.1f}% overlap in top {k}" + + # For items that appear in both results, scores should be reasonably correlated + common_items = redis_set & linear_set + for item in common_items: + redis_score = redis_results[item] + linear_score = linear_items[item] + # Allow for some deviation due to dimensionality reduction + assert abs(redis_score - linear_score) < 0.2, \ + f"Score mismatch too high for {item}: Redis={redis_score:.3f} Linear={linear_score:.3f}" + + # If test fails, print comparison for debugging + if overlap_ratio < min_expected_overlap: + print("\nLow overlap in results. Details:") + print("\nTop results from linear scan (original vectors):") + for name, score in linear_results: + print(f"{name}: {score:.3f}") + print("\nTop results from Redis (reduced vectors):") + for item, score in sorted(redis_results.items(), key=lambda x: x[1], reverse=True): + print(f"{item}: {score:.3f}") diff --git a/examples/redis-unstable/modules/vector-sets/tests/replication.py b/examples/redis-unstable/modules/vector-sets/tests/replication.py new file mode 100644 index 0000000..91dfdf7 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/replication.py @@ -0,0 +1,92 @@ +from test import TestCase, generate_random_vector +import struct +import random +import time + +class ComprehensiveReplicationTest(TestCase): + def getname(self): + return "Comprehensive Replication Test with mixed operations" + + def estimated_runtime(self): + # This test will take longer than the default 100ms + return 20.0 # 20 seconds estimate + + def test(self): + # Setup replication between primary and replica + assert self.setup_replication(), "Failed to setup replication" + + # Test parameters + num_vectors = 5000 + vector_dim = 8 + delete_probability = 0.1 + cas_probability = 0.3 + + # Keep track of added items for potential deletion + added_items = [] + + # Add vectors and occasionally delete + for i in range(num_vectors): + # Generate a random vector + vec = generate_random_vector(vector_dim) + vec_bytes = struct.pack(f'{vector_dim}f', *vec) + item_name = f"{self.test_key}:item:{i}" + + # Decide whether to use CAS or not + use_cas = random.random() < cas_probability + + if use_cas and added_items: + # Get an existing item for CAS reference (if available) + cas_item = random.choice(added_items) + try: + # Add with CAS + result = self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, + item_name, 'CAS') + # Only add to our list if actually added (CAS might fail) + if result == 1: + added_items.append(item_name) + except Exception as e: + print(f" CAS VADD failed: {e}") + else: + try: + # Add without CAS + result = self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, item_name) + # Only add to our list if actually added + if result == 1: + added_items.append(item_name) + except Exception as e: + print(f" VADD failed: {e}") + + # Randomly delete items (with 10% probability) + if random.random() < delete_probability and added_items: + try: + # Select a random item to delete + item_to_delete = random.choice(added_items) + # Delete the item using VREM (not VDEL) + self.redis.execute_command('VREM', self.test_key, item_to_delete) + # Remove from our list + added_items.remove(item_to_delete) + except Exception as e: + print(f" VREM failed: {e}") + + # Allow time for replication to complete + time.sleep(2.0) + + # Verify final VCARD matches + primary_card = self.redis.execute_command('VCARD', self.test_key) + replica_card = self.replica.execute_command('VCARD', self.test_key) + assert primary_card == replica_card, f"Final VCARD mismatch: primary={primary_card}, replica={replica_card}" + + # Verify VDIM matches + primary_dim = self.redis.execute_command('VDIM', self.test_key) + replica_dim = self.replica.execute_command('VDIM', self.test_key) + assert primary_dim == replica_dim, f"VDIM mismatch: primary={primary_dim}, replica={replica_dim}" + + # Verify digests match using DEBUG DIGEST + primary_digest = self.redis.execute_command('DEBUG', 'DIGEST-VALUE', self.test_key) + replica_digest = self.replica.execute_command('DEBUG', 'DIGEST-VALUE', self.test_key) + assert primary_digest == replica_digest, f"Digest mismatch: primary={primary_digest}, replica={replica_digest}" + + # Print summary + print(f"\n Added and maintained {len(added_items)} vectors with dimension {vector_dim}") + print(f" Final vector count: {primary_card}") + print(f" Final digest: {primary_digest[0].decode()}") diff --git a/examples/redis-unstable/modules/vector-sets/tests/threading_config.py b/examples/redis-unstable/modules/vector-sets/tests/threading_config.py new file mode 100644 index 0000000..dfc931a --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/threading_config.py @@ -0,0 +1,249 @@ +from test import TestCase, generate_random_vector +import struct + + +class ThreadingConfigTest(TestCase): + """ + Test suite for vset-force-single-threaded-execution configuration. + + This test validates the behavior of VADD and VSIM commands under different + threading configurations. The new configuration is MUTABLE and BINARY: + - false (0): Multi-threaded execution enabled (default) + - true (1): Force single-threaded execution + + Key behaviors tested: + - VADD with and without CAS option under both threading modes + - VSIM with and without NOTHREAD option under both threading modes + - Configuration reading, validation, and runtime modification + - Thread behavior switching (multi-threaded vs forced single-threaded) + """ + + def getname(self): + return "vset-force-single-threaded-execution configuration testing" + + def estimated_runtime(self): + return 0.5 # Updated for mutable config testing with mode switching + + def get_config_value(self): + """Get current vset-force-single-threaded-execution config value""" + try: + result = self.redis.execute_command('CONFIG', 'GET', 'vset-force-single-threaded-execution') + if len(result) >= 2: + # Redis returns 'yes'/'no' for boolean configs + return result[1].decode() if isinstance(result[1], bytes) else result[1] + return None + except Exception: + return None + + def set_config_value(self, value): + """Set vset-force-single-threaded-execution config value""" + try: + # Convert boolean to yes/no string + str_value = 'yes' if value else 'no' + result = self.redis.execute_command('CONFIG', 'SET', 'vset-force-single-threaded-execution', str_value) + return result == b'OK' or result == 'OK' + except Exception as e: + print(f"Failed to set config: {e}") + return False + + def test_config_access_and_mutability(self): + """Test 1: Configuration access and mutability""" + # Get initial value + initial_value = self.get_config_value() + assert initial_value is not None, "Should be able to read vset-force-single-threaded-execution config" + assert initial_value in ['yes', 'no'], f"Config value should be yes/no, got {initial_value}" + + # Test mutability by toggling the value + new_value = 'no' if initial_value == 'yes' else 'yes' + assert self.set_config_value(new_value == 'yes'), "Should be able to change config value" + + # Verify the change + current_value = self.get_config_value() + assert current_value == new_value, f"Config should be {new_value}, got {current_value}" + + # Restore original value + assert self.set_config_value(initial_value == 'yes'), "Should be able to restore original value" + + return initial_value == 'yes' + + def test_vadd_without_cas(self, force_single_threaded=False): + """Test 2: VADD command without CAS option""" + # Set threading mode + self.set_config_value(force_single_threaded) + + # Clear test data to avoid dimension conflicts + self.redis.delete(self.test_key) + + dim = 64 + vec = generate_random_vector(dim) + vec_bytes = struct.pack(f'{dim}f', *vec) + + result = self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, f'{self.test_key}:item:1') + assert result == 1, f"VADD should return 1 for new item, got {result}" + + # Verify the vector was added + card = self.redis.execute_command('VCARD', self.test_key) + assert card == 1, f"VCARD should return 1, got {card}" + + def test_vadd_with_cas(self, force_single_threaded=False): + """Test 3: VADD command with CAS option""" + # Set threading mode + self.set_config_value(force_single_threaded) + + # Clear test data to avoid dimension conflicts + self.redis.delete(self.test_key) + + dim = 64 + vec = generate_random_vector(dim) + vec_bytes = struct.pack(f'{dim}f', *vec) + + # First insertion with CAS should succeed + result = self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, f'{self.test_key}:item:cas', 'CAS') + assert result == 1, f"First VADD with CAS should return 1, got {result}" + + # Second insertion of same item with CAS should return 0 + result = self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, f'{self.test_key}:item:cas', 'CAS') + assert result == 0, f"Duplicate VADD with CAS should return 0, got {result}" + + def test_vsim_without_nothread(self, force_single_threaded=False): + """Test 4: VSIM command without NOTHREAD""" + # Set threading mode + self.set_config_value(force_single_threaded) + + # Clear test data to avoid dimension conflicts + self.redis.delete(self.test_key) + + dim = 64 + + # Add test vectors + for i in range(5): + vec = generate_random_vector(dim) + vec_bytes = struct.pack(f'{dim}f', *vec) + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, f'{self.test_key}:item:{i}') + + # Test VSIM without NOTHREAD + query_vec = generate_random_vector(dim) + args = ['VSIM', self.test_key, 'VALUES', dim] + [str(x) for x in query_vec] + ['COUNT', 3] + result = self.redis.execute_command(*args) + + assert isinstance(result, list), f"VSIM should return a list, got {type(result)}" + assert len(result) <= 3, f"VSIM should return at most 3 results, got {len(result)}" + + def test_vsim_with_nothread(self, force_single_threaded=False): + """Test 5: VSIM command with NOTHREAD""" + # Set threading mode + self.set_config_value(force_single_threaded) + + dim = 64 + + # Ensure we have vectors to search (use existing vectors from previous test) + card = self.redis.execute_command('VCARD', self.test_key) + if card == 0: + # Add test vectors if none exist + for i in range(5): + vec = generate_random_vector(dim) + vec_bytes = struct.pack(f'{dim}f', *vec) + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, f'{self.test_key}:item:{i}') + + # Test VSIM with NOTHREAD + query_vec = generate_random_vector(dim) + args = ['VSIM', self.test_key, 'VALUES', dim] + [str(x) for x in query_vec] + ['COUNT', 3, 'NOTHREAD'] + result = self.redis.execute_command(*args) + + assert isinstance(result, list), f"VSIM with NOTHREAD should return a list, got {type(result)}" + assert len(result) <= 3, f"VSIM with NOTHREAD should return at most 3 results, got {len(result)}" + + def test_threading_mode_comparison(self): + """Test 6: Compare behavior between threading modes""" + dim = 64 + + # Clear test data + self.redis.delete(self.test_key) + + # Test multi-threaded mode (default) + self.set_config_value(False) # Multi-threaded + self.test_vadd_without_cas(False) + self.test_vadd_with_cas(False) + multi_threaded_card = self.redis.execute_command('VCARD', self.test_key) + + # Clear and test single-threaded mode + self.redis.delete(self.test_key) + self.set_config_value(True) # Single-threaded + self.test_vadd_without_cas(True) + self.test_vadd_with_cas(True) + single_threaded_card = self.redis.execute_command('VCARD', self.test_key) + + # Both modes should produce same results + assert multi_threaded_card == single_threaded_card, \ + f"Both modes should produce same results: multi={multi_threaded_card}, single={single_threaded_card}" + + def test_nothread_override_behavior(self): + """Test 7: NOTHREAD option should work regardless of config""" + dim = 64 + + # Test with both config modes + for force_single in [False, True]: + self.set_config_value(force_single) + self.redis.delete(self.test_key) + + # Add test vectors + for i in range(3): + vec = generate_random_vector(dim) + vec_bytes = struct.pack(f'{dim}f', *vec) + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, f'{self.test_key}:item:{i}') + + # NOTHREAD should work regardless of config + query_vec = generate_random_vector(dim) + args = ['VSIM', self.test_key, 'VALUES', dim] + [str(x) for x in query_vec] + ['COUNT', 2, 'NOTHREAD'] + result = self.redis.execute_command(*args) + + assert isinstance(result, list), f"NOTHREAD should work with force_single={force_single}" + assert len(result) <= 2, f"NOTHREAD should return ≤2 results with force_single={force_single}" + + def test(self): + """Main test method - runs all threading configuration tests""" + # Get initial configuration + initial_force_single = self.test_config_access_and_mutability() + print(f"Initial vset-force-single-threaded-execution: {'yes' if initial_force_single else 'no'}") + + # Clear test data + self.redis.delete(self.test_key) + + # Test both threading modes + print("Testing multi-threaded mode...") + self.set_config_value(False) + self.test_vadd_without_cas(False) + self.test_vadd_with_cas(False) + self.test_vsim_without_nothread(False) + self.test_vsim_with_nothread(False) + + print("Testing single-threaded mode...") + self.set_config_value(True) + self.test_vadd_without_cas(True) + self.test_vadd_with_cas(True) + self.test_vsim_without_nothread(True) + self.test_vsim_with_nothread(True) + + # Test mode comparison and NOTHREAD override + self.test_threading_mode_comparison() + self.test_nothread_override_behavior() + + # Restore initial configuration + self.set_config_value(initial_force_single) + + # Print summary + self._print_test_summary(initial_force_single) + + def _print_test_summary(self, initial_force_single): + """Print a summary of what was tested""" + print(f"\nThreading Configuration Test Summary:") + print(f" Configuration: vset-force-single-threaded-execution") + print(f" Type: Boolean, Mutable") + print(f" Initial value: {'yes' if initial_force_single else 'no'}") + print(f" Tested modes: Both multi-threaded (no) and single-threaded (yes)") + print(f" VADD: Works correctly in both modes") + print(f" VADD with CAS: Works correctly in both modes") + print(f" VSIM: Works correctly in both modes") + print(f" NOTHREAD option: Overrides config in both modes") + print(f" Configuration mutability: ✅ Successfully changed at runtime") + print(f" All tests passed successfully!") diff --git a/examples/redis-unstable/modules/vector-sets/tests/vadd_cas.py b/examples/redis-unstable/modules/vector-sets/tests/vadd_cas.py new file mode 100644 index 0000000..3cb3508 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/vadd_cas.py @@ -0,0 +1,98 @@ +from test import TestCase, generate_random_vector +import threading +import struct +import math +import time +import random +from typing import List, Dict + +class ConcurrentCASTest(TestCase): + def getname(self): + return "Concurrent VADD with CAS" + + def estimated_runtime(self): + return 1.5 + + def worker(self, vectors: List[List[float]], start_idx: int, end_idx: int, + dim: int, results: Dict[str, bool]): + """Worker thread that adds a subset of vectors using VADD CAS""" + for i in range(start_idx, end_idx): + vec = vectors[i] + name = f"{self.test_key}:item:{i}" + vec_bytes = struct.pack(f'{dim}f', *vec) + + # Try to add the vector with CAS + try: + result = self.redis.execute_command('VADD', self.test_key, 'FP32', + vec_bytes, name, 'CAS') + results[name] = (result == 1) # Store if it was actually added + except Exception as e: + results[name] = False + print(f"Error adding {name}: {e}") + + def verify_vector_similarity(self, vec1: List[float], vec2: List[float]) -> float: + """Calculate cosine similarity between two vectors""" + dot_product = sum(a*b for a,b in zip(vec1, vec2)) + norm1 = math.sqrt(sum(x*x for x in vec1)) + norm2 = math.sqrt(sum(x*x for x in vec2)) + return dot_product / (norm1 * norm2) if norm1 > 0 and norm2 > 0 else 0 + + def test(self): + # Test parameters + dim = 128 + total_vectors = 5000 + num_threads = 8 + vectors_per_thread = total_vectors // num_threads + + # Generate all vectors upfront + random.seed(42) # For reproducibility + vectors = [generate_random_vector(dim) for _ in range(total_vectors)] + + # Prepare threads and results dictionary + threads = [] + results = {} # Will store success/failure for each vector + + # Launch threads + for i in range(num_threads): + start_idx = i * vectors_per_thread + end_idx = start_idx + vectors_per_thread if i < num_threads-1 else total_vectors + thread = threading.Thread(target=self.worker, + args=(vectors, start_idx, end_idx, dim, results)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify cardinality + card = self.redis.execute_command('VCARD', self.test_key) + assert card == total_vectors, \ + f"Expected {total_vectors} elements, but found {card}" + + # Verify each vector + num_verified = 0 + for i in range(total_vectors): + name = f"{self.test_key}:item:{i}" + + # Verify the item was successfully added + assert results[name], f"Vector {name} was not successfully added" + + # Get the stored vector + stored_vec_raw = self.redis.execute_command('VEMB', self.test_key, name) + stored_vec = [float(x) for x in stored_vec_raw] + + # Verify vector dimensions + assert len(stored_vec) == dim, \ + f"Stored vector dimension mismatch for {name}: {len(stored_vec)} != {dim}" + + # Calculate similarity with original vector + similarity = self.verify_vector_similarity(vectors[i], stored_vec) + assert similarity > 0.99, \ + f"Low similarity ({similarity}) for {name}" + + num_verified += 1 + + # Final verification + assert num_verified == total_vectors, \ + f"Only verified {num_verified} out of {total_vectors} vectors" diff --git a/examples/redis-unstable/modules/vector-sets/tests/vemb.py b/examples/redis-unstable/modules/vector-sets/tests/vemb.py new file mode 100644 index 0000000..0f4cf77 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/vemb.py @@ -0,0 +1,41 @@ +from test import TestCase +import struct +import math + +class VEMB(TestCase): + def getname(self): + return "VEMB Command" + + def test(self): + dim = 4 + + # Add same vector in both formats + vec = [1, 0, 0, 0] + norm = math.sqrt(sum(x*x for x in vec)) + vec = [x/norm for x in vec] # Normalize the vector + + # Add using FP32 + vec_bytes = struct.pack(f'{dim}f', *vec) + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, f'{self.test_key}:item:1') + + # Add using VALUES + self.redis.execute_command('VADD', self.test_key, 'VALUES', dim, + *[str(x) for x in vec], f'{self.test_key}:item:2') + + # Get both back with VEMB + result1 = self.redis.execute_command('VEMB', self.test_key, f'{self.test_key}:item:1') + result2 = self.redis.execute_command('VEMB', self.test_key, f'{self.test_key}:item:2') + + retrieved_vec1 = [float(x) for x in result1] + retrieved_vec2 = [float(x) for x in result2] + + # Compare both vectors with original (allow for small quantization errors) + for i in range(dim): + assert abs(vec[i] - retrieved_vec1[i]) < 0.01, \ + f"FP32 vector component {i} mismatch: expected {vec[i]}, got {retrieved_vec1[i]}" + assert abs(vec[i] - retrieved_vec2[i]) < 0.01, \ + f"VALUES vector component {i} mismatch: expected {vec[i]}, got {retrieved_vec2[i]}" + + # Test non-existent item + result = self.redis.execute_command('VEMB', self.test_key, 'nonexistent') + assert result is None, "Non-existent item should return nil" diff --git a/examples/redis-unstable/modules/vector-sets/tests/vismember.py b/examples/redis-unstable/modules/vector-sets/tests/vismember.py new file mode 100644 index 0000000..eabebca --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/vismember.py @@ -0,0 +1,47 @@ +from test import TestCase, generate_random_vector +import struct + +class BasicVISMEMBER(TestCase): + def getname(self): + return "VISMEMBER basic functionality" + + def test(self): + # Add multiple vectors to the vector set + vec1 = generate_random_vector(4) + vec2 = generate_random_vector(4) + vec_bytes1 = struct.pack('4f', *vec1) + vec_bytes2 = struct.pack('4f', *vec2) + + # Create item keys + item1 = f'{self.test_key}:item:1' + item2 = f'{self.test_key}:item:2' + nonexistent_item = f'{self.test_key}:item:nonexistent' + + # Add the vectors + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes1, item1) + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes2, item2) + + # Test VISMEMBER with existing elements + result1 = self.redis.execute_command('VISMEMBER', self.test_key, item1) + assert result1 == 1, f"VISMEMBER should return 1 for existing item, got {result1}" + + result2 = self.redis.execute_command('VISMEMBER', self.test_key, item2) + assert result2 == 1, f"VISMEMBER should return 1 for existing item, got {result2}" + + # Test VISMEMBER with non-existent element + result3 = self.redis.execute_command('VISMEMBER', self.test_key, nonexistent_item) + assert result3 == 0, f"VISMEMBER should return 0 for non-existent item, got {result3}" + + # Test VISMEMBER with non-existent key + nonexistent_key = f'{self.test_key}_nonexistent' + result4 = self.redis.execute_command('VISMEMBER', nonexistent_key, item1) + assert result4 == 0, f"VISMEMBER should return 0 for non-existent key, got {result4}" + + # Test VISMEMBER after removing an element + self.redis.execute_command('VREM', self.test_key, item1) + result5 = self.redis.execute_command('VISMEMBER', self.test_key, item1) + assert result5 == 0, f"VISMEMBER should return 0 after element removal, got {result5}" + + # Verify item2 still exists + result6 = self.redis.execute_command('VISMEMBER', self.test_key, item2) + assert result6 == 1, f"VISMEMBER should still return 1 for remaining item, got {result6}" diff --git a/examples/redis-unstable/modules/vector-sets/tests/vrand-ping-pong.py b/examples/redis-unstable/modules/vector-sets/tests/vrand-ping-pong.py new file mode 100644 index 0000000..99d2e9a --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/vrand-ping-pong.py @@ -0,0 +1,35 @@ +from test import TestCase, generate_random_vector +import struct + +class VRANDMEMBERPingPongRegressionTest(TestCase): + def getname(self): + return "[regression] VRANDMEMBER ping-pong" + + def test(self): + """ + This test ensures that when only two vectors exist, VRANDMEMBER + does not get stuck returning only one of them due to the "ping-pong" issue. + """ + self.redis.delete(self.test_key) # Clean up before test + dim = 4 + + # Add exactly two vectors + vec1_name = "vec1" + vec1_data = generate_random_vector(dim) + self.redis.execute_command('VADD', self.test_key, 'VALUES', dim, *vec1_data, vec1_name) + + vec2_name = "vec2" + vec2_data = generate_random_vector(dim) + self.redis.execute_command('VADD', self.test_key, 'VALUES', dim, *vec2_data, vec2_name) + + # Call VRANDMEMBER many times and check for distribution + iterations = 100 + results = [] + for _ in range(iterations): + member = self.redis.execute_command('VRANDMEMBER', self.test_key) + results.append(member.decode()) + + # Verify that both members were returned, proving it's not stuck + unique_results = set(results) + + assert len(unique_results) == 2, f"Ping-pong test failed: should have returned 2 unique members, but got {len(unique_results)}." diff --git a/examples/redis-unstable/modules/vector-sets/tests/vrandmember.py b/examples/redis-unstable/modules/vector-sets/tests/vrandmember.py new file mode 100644 index 0000000..ca9e006 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/vrandmember.py @@ -0,0 +1,55 @@ +from test import TestCase, generate_random_vector, fill_redis_with_vectors +import struct + +class VRANDMEMBERTest(TestCase): + def getname(self): + return "VRANDMEMBER basic functionality" + + def test(self): + # Test with empty key + result = self.redis.execute_command('VRANDMEMBER', self.test_key) + assert result is None, "VRANDMEMBER on non-existent key should return NULL" + + result = self.redis.execute_command('VRANDMEMBER', self.test_key, 5) + assert isinstance(result, list) and len(result) == 0, "VRANDMEMBER with count on non-existent key should return empty array" + + # Fill with vectors + dim = 4 + count = 100 + data = fill_redis_with_vectors(self.redis, self.test_key, count, dim) + + # Test single random member + result = self.redis.execute_command('VRANDMEMBER', self.test_key) + assert result is not None, "VRANDMEMBER should return a random member" + assert result.decode() in data.names, "Random member should be in the set" + + # Test multiple unique members (positive count) + positive_count = 10 + result = self.redis.execute_command('VRANDMEMBER', self.test_key, positive_count) + assert isinstance(result, list), "VRANDMEMBER with positive count should return an array" + assert len(result) == positive_count, f"Should return {positive_count} members" + + # Check for uniqueness + decoded_results = [r.decode() for r in result] + assert len(decoded_results) == len(set(decoded_results)), "Results should be unique with positive count" + for item in decoded_results: + assert item in data.names, "All returned items should be in the set" + + # Test more members than in the set + result = self.redis.execute_command('VRANDMEMBER', self.test_key, count + 10) + assert len(result) == count, "Should return only the available members when asking for more than exist" + + # Test with duplicates (negative count) + negative_count = -20 + result = self.redis.execute_command('VRANDMEMBER', self.test_key, negative_count) + assert isinstance(result, list), "VRANDMEMBER with negative count should return an array" + assert len(result) == abs(negative_count), f"Should return {abs(negative_count)} members" + + # Check that all returned elements are valid + decoded_results = [r.decode() for r in result] + for item in decoded_results: + assert item in data.names, "All returned items should be in the set" + + # Test with count = 0 (edge case) + result = self.redis.execute_command('VRANDMEMBER', self.test_key, 0) + assert isinstance(result, list) and len(result) == 0, "VRANDMEMBER with count=0 should return empty array" diff --git a/examples/redis-unstable/modules/vector-sets/tests/vrange.py b/examples/redis-unstable/modules/vector-sets/tests/vrange.py new file mode 100644 index 0000000..7e57588 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/vrange.py @@ -0,0 +1,113 @@ +from test import TestCase, generate_random_vector +import struct + +class BasicVRANGE(TestCase): + def getname(self): + return "VRANGE basic functionality and iteration" + + def test(self): + # Add multiple elements with different names for lexicographical ordering + elements = [ + "apple", "apricot", "banana", "cherry", "date", + "elderberry", "fig", "grape", "honeydew", "kiwi", + "lemon", "mango", "nectarine", "orange", "papaya", + "quince", "raspberry", "strawberry", "tangerine", "watermelon" + ] + + # Add all elements to the vector set + for elem in elements: + vec = generate_random_vector(4) + vec_bytes = struct.pack('4f', *vec) + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, elem) + + # Test 1: Basic range with inclusive boundaries + result = self.redis.execute_command('VRANGE', self.test_key, '[apple', '[grape', '5') + result = [r.decode() for r in result] + assert result == ['apple', 'apricot', 'banana', 'cherry', 'date'], f"Expected first 5 elements from apple, got {result}" + + # Test 2: Exclusive start boundary + result = self.redis.execute_command('VRANGE', self.test_key, '(apple', '[cherry', '10') + result = [r.decode() for r in result] + assert result == ['apricot', 'banana', 'cherry'], f"Expected elements after apple up to cherry inclusive, got {result}" + + # Test 3: Exclusive end boundary + result = self.redis.execute_command('VRANGE', self.test_key, '[banana', '(cherry', '10') + result = [r.decode() for r in result] + assert result == ['banana'], f"Expected only banana (cherry excluded), got {result}" + + # Test 4: Using '-' for minimum element + result = self.redis.execute_command('VRANGE', self.test_key, '-', '[banana', '10') + result = [r.decode() for r in result] + assert result[0] == 'apple', "Should start from the first element" + assert result[-1] == 'banana', "Should end at banana" + + # Test 5: Using '+' for maximum element + result = self.redis.execute_command('VRANGE', self.test_key, '[raspberry', '+', '10') + result = [r.decode() for r in result] + assert 'raspberry' in result and 'strawberry' in result and 'tangerine' in result and 'watermelon' in result, "Should include all elements from raspberry onwards" + + # Test 6: Full range with '-' and '+' + result = self.redis.execute_command('VRANGE', self.test_key, '-', '+', '100') + result = [r.decode() for r in result] + assert len(result) == len(elements), f"Should return all {len(elements)} elements" + assert result == sorted(elements), "Elements should be in lexicographical order" + + # Test 7: Iterator pattern - verify each element appears exactly once + seen = set() + batch_size = 3 + current = '-' + + while True: + if current == '-': + # First iteration + result = self.redis.execute_command('VRANGE', self.test_key, '-', '+', str(batch_size)) + else: + # Subsequent iterations - exclusive start from last element + result = self.redis.execute_command('VRANGE', self.test_key, f'({current}', '+', str(batch_size)) + + result = [r.decode() for r in result] + + if not result: + break + + # Check no duplicates in this batch + for elem in result: + assert elem not in seen, f"Element {elem} appeared more than once" + seen.add(elem) + + # Update current to last element + current = result[-1] + + # Break if we got less than requested (end of set) + if len(result) < batch_size: + break + + # Verify we saw all elements exactly once + assert seen == set(elements), f"Iterator should visit all elements exactly once. Missing: {set(elements) - seen}, Extra: {seen - set(elements)}" + + # Test 8: Count of 0 returns empty array + result = self.redis.execute_command('VRANGE', self.test_key, '-', '+', '0') + assert result == [], f"Count of 0 should return empty array, got {result}" + + # Test 9: Range with no matching elements + result = self.redis.execute_command('VRANGE', self.test_key, '[zebra', '+', '10') + assert result == [], f"Range beyond all elements should return empty array, got {result}" + + # Test 10: Non-existent key + result = self.redis.execute_command('VRANGE', 'nonexistent_key', '-', '+', '10') + assert result == [], f"Non-existent key should return empty array, got {result}" + + # Test 11: Partial word boundaries + result = self.redis.execute_command('VRANGE', self.test_key, '[app', '[apr', '10') + result = [r.decode() for r in result] + assert 'apple' in result, "Should include 'apple' which starts with 'app'" + assert 'apricot' not in result, "Should not include 'apricot' as it's >= 'apr'" + + # Test 12: Single element range + result = self.redis.execute_command('VRANGE', self.test_key, '[cherry', '[cherry', '10') + result = [r.decode() for r in result] + assert result == ['cherry'], f"Inclusive single element range should return that element, got {result}" + + # Test 13: Empty range (start > end) + result = self.redis.execute_command('VRANGE', self.test_key, '[grape', '[apple', '10') + assert result == [], f"Range where start > end should return empty array, got {result}" diff --git a/examples/redis-unstable/modules/vector-sets/tests/vsim_limit_efsearch.py b/examples/redis-unstable/modules/vector-sets/tests/vsim_limit_efsearch.py new file mode 100644 index 0000000..25b9689 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/vsim_limit_efsearch.py @@ -0,0 +1,32 @@ +from test import TestCase, generate_random_vector +import struct + +class VSIMLimitEFSearch(TestCase): + def getname(self): + return "VSIM Limit EF Search" + + def estimated_runtime(self): + return 0.2 + + def test(self): + dim = 32 + vec = generate_random_vector(dim) + vec_bytes = struct.pack(f'{dim}f', *vec) + + # Add test vector + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, f'{self.test_key}:item:1') + + query_vec = generate_random_vector(dim) + + # Test EF upper bound (should accept 1000000) + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim, + *[str(x) for x in query_vec], 'EF', 1000000) + assert isinstance(result, list), "EF=1000000 should be accepted" + + # Test EF over limit (should reject > 1000000) + try: + self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim, + *[str(x) for x in query_vec], 'EF', 1000001) + assert False, "EF=1000001 should be rejected" + except Exception as e: + assert "invalid EF" in str(e), f"Expected EF validation error, got: {e}" diff --git a/examples/redis-unstable/modules/vector-sets/tests/with.py b/examples/redis-unstable/modules/vector-sets/tests/with.py new file mode 100644 index 0000000..d14a23f --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/with.py @@ -0,0 +1,214 @@ +from test import TestCase, generate_random_vector +import struct +import json +import random + +class VSIMWithAttribs(TestCase): + def getname(self): + return "VSIM WITHATTRIBS/WITHSCORES functionality testing" + + def setup(self): + super().setup() + self.dim = 8 + self.count = 20 + + # Create vectors with attributes + for i in range(self.count): + vec = generate_random_vector(self.dim) + vec_bytes = struct.pack(f'{self.dim}f', *vec) + + # Item name + name = f"{self.test_key}:item:{i}" + + # Add to Redis + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, name) + + # Create and add attribute + if i % 5 == 0: + # Every 5th item has no attribute (for testing NULL responses) + continue + + category = random.choice(["electronics", "furniture", "clothing"]) + price = random.randint(50, 1000) + attrs = {"category": category, "price": price, "id": i} + + self.redis.execute_command('VSETATTR', self.test_key, name, json.dumps(attrs)) + + def is_numeric(self, value): + """Check if a value can be converted to float""" + try: + if isinstance(value, (int, float)): + return True + if isinstance(value, bytes): + float(value.decode('utf-8')) + return True + if isinstance(value, str): + float(value) + return True + return False + except (ValueError, TypeError): + return False + + def test(self): + # Create query vector + query_vec = generate_random_vector(self.dim) + + # Test 1: VSIM with no additional options (should be same for RESP2 and RESP3) + cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim] + cmd_args.extend([str(x) for x in query_vec]) + cmd_args.extend(['COUNT', 5]) + + results_resp2 = self.redis.execute_command(*cmd_args) + results_resp3 = self.redis3.execute_command(*cmd_args) + + # Both should return simple arrays of item names + assert len(results_resp2) == 5, f"RESP2: Expected 5 results, got {len(results_resp2)}" + assert len(results_resp3) == 5, f"RESP3: Expected 5 results, got {len(results_resp3)}" + assert all(isinstance(item, bytes) for item in results_resp2), "RESP2: Results should be byte strings" + assert all(isinstance(item, bytes) for item in results_resp3), "RESP3: Results should be byte strings" + + # Test 2: VSIM with WITHSCORES only + cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim] + cmd_args.extend([str(x) for x in query_vec]) + cmd_args.extend(['COUNT', 5, 'WITHSCORES']) + + results_resp2 = self.redis.execute_command(*cmd_args) + results_resp3 = self.redis3.execute_command(*cmd_args) + + # RESP2: Should be a flat array alternating item, score + assert len(results_resp2) == 10, f"RESP2: Expected 10 elements (5 items × 2), got {len(results_resp2)}" + for i in range(0, len(results_resp2), 2): + assert isinstance(results_resp2[i], bytes), f"RESP2: Item at {i} should be bytes" + assert self.is_numeric(results_resp2[i+1]), f"RESP2: Score at {i+1} should be numeric" + score = float(results_resp2[i+1]) if isinstance(results_resp2[i+1], bytes) else results_resp2[i+1] + assert 0 <= score <= 1, f"RESP2: Score {score} should be between 0 and 1" + + # RESP3: Should be a dict/map with items as keys and scores as DIRECT values (not arrays) + assert isinstance(results_resp3, dict), f"RESP3: Expected dict, got {type(results_resp3)}" + assert len(results_resp3) == 5, f"RESP3: Expected 5 entries, got {len(results_resp3)}" + for item, score in results_resp3.items(): + assert isinstance(item, bytes), f"RESP3: Key should be bytes" + # Score should be a direct value, NOT an array + assert not isinstance(score, list), f"RESP3: With single WITH option, value should not be array" + assert self.is_numeric(score), f"RESP3: Score should be numeric, got {type(score)}" + score_val = float(score) if isinstance(score, bytes) else score + assert 0 <= score_val <= 1, f"RESP3: Score {score_val} should be between 0 and 1" + + # Test 3: VSIM with WITHATTRIBS only + cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim] + cmd_args.extend([str(x) for x in query_vec]) + cmd_args.extend(['COUNT', 5, 'WITHATTRIBS']) + + results_resp2 = self.redis.execute_command(*cmd_args) + results_resp3 = self.redis3.execute_command(*cmd_args) + + # RESP2: Should be a flat array alternating item, attribute + assert len(results_resp2) == 10, f"RESP2: Expected 10 elements (5 items × 2), got {len(results_resp2)}" + for i in range(0, len(results_resp2), 2): + assert isinstance(results_resp2[i], bytes), f"RESP2: Item at {i} should be bytes" + attr = results_resp2[i+1] + assert attr is None or isinstance(attr, bytes), f"RESP2: Attribute at {i+1} should be None or bytes" + if attr is not None: + # Verify it's valid JSON + json.loads(attr) + + # RESP3: Should be a dict/map with items as keys and attributes as DIRECT values (not arrays) + assert isinstance(results_resp3, dict), f"RESP3: Expected dict, got {type(results_resp3)}" + assert len(results_resp3) == 5, f"RESP3: Expected 5 entries, got {len(results_resp3)}" + for item, attr in results_resp3.items(): + assert isinstance(item, bytes), f"RESP3: Key should be bytes" + # Attribute should be a direct value, NOT an array + assert not isinstance(attr, list), f"RESP3: With single WITH option, value should not be array" + assert attr is None or isinstance(attr, bytes), f"RESP3: Attribute should be None or bytes" + if attr is not None: + # Verify it's valid JSON + json.loads(attr) + + # Test 4: VSIM with both WITHSCORES and WITHATTRIBS + cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim] + cmd_args.extend([str(x) for x in query_vec]) + cmd_args.extend(['COUNT', 5, 'WITHSCORES', 'WITHATTRIBS']) + + results_resp2 = self.redis.execute_command(*cmd_args) + results_resp3 = self.redis3.execute_command(*cmd_args) + + # RESP2: Should be a flat array with pattern: item, score, attribute + assert len(results_resp2) == 15, f"RESP2: Expected 15 elements (5 items × 3), got {len(results_resp2)}" + for i in range(0, len(results_resp2), 3): + assert isinstance(results_resp2[i], bytes), f"RESP2: Item at {i} should be bytes" + assert self.is_numeric(results_resp2[i+1]), f"RESP2: Score at {i+1} should be numeric" + score = float(results_resp2[i+1]) if isinstance(results_resp2[i+1], bytes) else results_resp2[i+1] + assert 0 <= score <= 1, f"RESP2: Score {score} should be between 0 and 1" + attr = results_resp2[i+2] + assert attr is None or isinstance(attr, bytes), f"RESP2: Attribute at {i+2} should be None or bytes" + + # RESP3: Should be a dict where each value is a 2-element array [score, attribute] + assert isinstance(results_resp3, dict), f"RESP3: Expected dict, got {type(results_resp3)}" + assert len(results_resp3) == 5, f"RESP3: Expected 5 entries, got {len(results_resp3)}" + for item, value in results_resp3.items(): + assert isinstance(item, bytes), f"RESP3: Key should be bytes" + # With BOTH options, value MUST be an array + assert isinstance(value, list), f"RESP3: With both WITH options, value should be a list, got {type(value)}" + assert len(value) == 2, f"RESP3: Value should have 2 elements [score, attr], got {len(value)}" + + score, attr = value + assert self.is_numeric(score), f"RESP3: Score should be numeric" + score_val = float(score) if isinstance(score, bytes) else score + assert 0 <= score_val <= 1, f"RESP3: Score {score_val} should be between 0 and 1" + assert attr is None or isinstance(attr, bytes), f"RESP3: Attribute should be None or bytes" + + # Test 5: Verify consistency - same items returned in same order + cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim] + cmd_args.extend([str(x) for x in query_vec]) + cmd_args.extend(['COUNT', 5, 'WITHSCORES', 'WITHATTRIBS']) + + results_resp2 = self.redis.execute_command(*cmd_args) + results_resp3 = self.redis3.execute_command(*cmd_args) + + # Extract items from RESP2 (every 3rd element starting from 0) + items_resp2 = [results_resp2[i] for i in range(0, len(results_resp2), 3)] + + # Extract items from RESP3 (keys of the dict) + items_resp3 = list(results_resp3.keys()) + + # Verify same items returned + assert set(items_resp2) == set(items_resp3), "RESP2 and RESP3 should return the same items" + + # Build a mapping from items to scores and attributes for comparison + data_resp2 = {} + for i in range(0, len(results_resp2), 3): + item = results_resp2[i] + score = float(results_resp2[i+1]) if isinstance(results_resp2[i+1], bytes) else results_resp2[i+1] + attr = results_resp2[i+2] + data_resp2[item] = (score, attr) + + data_resp3 = {} + for item, value in results_resp3.items(): + score = float(value[0]) if isinstance(value[0], bytes) else value[0] + attr = value[1] + data_resp3[item] = (score, attr) + + # Verify scores and attributes match for each item + for item in data_resp2: + score_resp2, attr_resp2 = data_resp2[item] + score_resp3, attr_resp3 = data_resp3[item] + + assert abs(score_resp2 - score_resp3) < 0.0001, \ + f"Scores for {item} don't match: RESP2={score_resp2}, RESP3={score_resp3}" + assert attr_resp2 == attr_resp3, \ + f"Attributes for {item} don't match: RESP2={attr_resp2}, RESP3={attr_resp3}" + + # Test 6: Test ordering of WITHSCORES and WITHATTRIBS doesn't matter + cmd_args1 = ['VSIM', self.test_key, 'VALUES', self.dim] + cmd_args1.extend([str(x) for x in query_vec]) + cmd_args1.extend(['COUNT', 3, 'WITHSCORES', 'WITHATTRIBS']) + + cmd_args2 = ['VSIM', self.test_key, 'VALUES', self.dim] + cmd_args2.extend([str(x) for x in query_vec]) + cmd_args2.extend(['COUNT', 3, 'WITHATTRIBS', 'WITHSCORES']) # Reversed order + + results1_resp3 = self.redis3.execute_command(*cmd_args1) + results2_resp3 = self.redis3.execute_command(*cmd_args2) + + # Both should return the same structure + assert results1_resp3 == results2_resp3, "Order of WITH options shouldn't matter" diff --git a/examples/redis-unstable/modules/vector-sets/vset.c b/examples/redis-unstable/modules/vector-sets/vset.c new file mode 100644 index 0000000..500f8e9 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/vset.c @@ -0,0 +1,2587 @@ +/* Redis implementation for vector sets. The data structure itself + * is implemented in hnsw.c. + * + * Copyright (c) 2009-Present, Redis Ltd. + * All rights reserved. + * + * Licensed under your choice of (a) the Redis Source Available License 2.0 + * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the + * GNU Affero General Public License v3 (AGPLv3). + * Originally authored by: Salvatore Sanfilippo. + * + * ======================== Understand threading model ========================= + * This code implements threaded operarations for two of the commands: + * + * 1. VSIM, by default. + * 2. VADD, if the CAS option is specified. + * + * Note that even if the second operation, VADD, is a write operation, only + * the neighbors collection for the new node is performed in a thread: then, + * the actual insert is performed in the reply callback VADD_CASReply(), + * which is executed in the main thread. + * + * Threaded operations need us to protect various operations with mutexes, + * even if a certain degree of protection is already provided by the HNSW + * library. Here are a few very important things about this implementation + * and the way locking is performed. + * + * 1. All the write operations are performed in the main Redis thread: + * this also include VADD_CASReply() callback, that is called by Redis + * internals only in the context of the main thread. However the HNSW + * library allows background threads in hnsw_search() (VSIM) to modify + * nodes metadata to speedup search (to understand if a node was already + * visited), but this only happens after acquiring a specific lock + * for a given "read slot". + * + * 2. We use a global lock for each Vector Set object, called "in_use". This + * lock is a read-write lock, and is acquired in read mode by all the + * threads that perform reads in the background. It is only acquired in + * write mode by vectorSetWaitAllBackgroundClients(): the function acquires + * the lock and immediately releases it, with the effect of waiting all the + * background threads still running from ending their execution. + * + * Note that no thread can be spawned, since we only call + * vectorSetWaitAllBackgroundClients() from the main Redis thread, that + * is also the only thread spawning other threads. + * + * vectorSetWaitAllBackgroundClients() is used in two ways: + * A) When we need to delete a vector set because of (DEL) or other + * operations destroying the object, we need to wait that all the + * background threads working with this object finished their work. + * B) When we modify the HNSW nodes bypassing the normal locking + * provided by the HNSW library. This only happens when we update + * an existing node attribute so far, in VSETATTR and when we call + * VADD to update a node with the SETATTR option. + * + * 3. Often during read operations performed by Redis commands in the + * main thread (VCARD, VEMB, VRANDMEMBER, ...) we don't acquire any + * lock at all. The commands run in the main Redis thread, we can only + * have, at the same time, background reads against the same data + * structure. Note that VSIM_thread() and VADD_thread() still modify the + * read slot metadata, that is node->visited_epoch[slot], but as long as + * our read commands running in the main thread don't need to use + * hnsw_search() or other HNSW functions using the visited epochs slots + * we are safe. + * + * 4. There is a race from the moment we create a thread, passing the + * vector set object, to the moment the thread can actually lock the + * result win the in_use_lock mutex: as the thread starts, in the meanwhile + * a DEL/expire could trigger and remove the object. For this reason + * we use an atomic counter that protects our object for this small + * time in vectorSetWaitAllBackgroundClients(). This prevents removal + * of objects that are about to be taken by threads. + * + * Note that other competing solutions could be used to fix the problem + * but have their set of issues, however they are worth documenting here + * and evaluating in the future: + * + * A. Using a conditional variable we could "wait" for the thread to + * acquire the lock. However this means waiting before returning + * to the event loop, and would make the command execution slower. + * B. We could use again an atomic variable, like we did, but this time + * as a refcount for the object, with a vsetAcquire() vsetRelease(). + * In this case, the command could retain the object in the main thread + * before starting the thread, and the thread, after the work is done, + * could release it. This way sometimes the object would be freed by + * the thread, and it's while now can be safe to do the kind of resource + * deallocation that vectorSetReleaseObject() does, given that the + * Redis Modules API is not always thread safe this solution may not + * be future-proof. However there is to evaluate it better in the + * future. + * C. We could use the "B" solution but instead of freeing the object + * in the thread, in this specific case we could just put it into a + * list and defer it for later freeing (for instance in the reply + * callback), so that the object is always freed in the main thread. + * This would require a list of objects to free. + * + * However the current solution only disadvantage is the potential busy + * loop, but this busy loop in practical terms will almost never do + * much: to trigger it, a number of circumnstances must happen: deleting + * Vector Set keys while using them, hitting the small window needed to + * start the thread and read-lock the mutex. + */ + +#define _DEFAULT_SOURCE +#define _USE_MATH_DEFINES +#define _POSIX_C_SOURCE 200809L + +#include "../../src/redismodule.h" +#include <stdio.h> +#include <stdlib.h> +#include <ctype.h> +#include <string.h> +#include <strings.h> +#include <stdint.h> +#include <math.h> +#include <pthread.h> +#include <stdatomic.h> +#include "hnsw.h" +#include "vset_config.h" + +// We inline directly the expression implementation here so that building +// the module is trivial. +#include "expr.c" + +static RedisModuleType *VectorSetType; +static uint64_t VectorSetTypeNextId = 0; + +// Default EF value if not specified during creation. +#define VSET_DEFAULT_C_EF 200 + +// Default EF value if not specified during search. +#define VSET_DEFAULT_SEARCH_EF 100 + +// Default num elements returned by VSIM. +#define VSET_DEFAULT_COUNT 10 + +/* ========================== Internal data structure ====================== */ + +/* Our abstract data type needs a dual representation similar to Redis + * sorted set: the proximity graph, and also a element -> graph-node map + * that will allow us to perform deletions and other operations that have + * as input the element itself. */ +struct vsetObject { + HNSW *hnsw; // Proximity graph. + RedisModuleDict *dict; // Element -> node mapping. + float *proj_matrix; // Random projection matrix, NULL if no projection + uint32_t proj_input_size; // Input dimension after projection. + // Output dimension is implicit in + // hnsw->vector_dim. + pthread_rwlock_t in_use_lock; // Lock needed to destroy the object safely. + uint64_t id; // Unique ID used by threaded VADD to know the + // object is still the same. + uint64_t numattribs; // Number of nodes associated with an attribute. + atomic_int thread_creation_pending; // Number of threads that are currently + // pending to lock the object. +}; + +/* Each node has two associated values: the associated string (the item + * in the set) and potentially a JSON string, that is, the attributes, used + * for hybrid search with the VSIM FILTER option. */ +struct vsetNodeVal { + RedisModuleString *item; + RedisModuleString *attrib; +}; + +/* Count the number of set bits in an integer (population count/Hamming weight). + * This is a portable implementation that doesn't rely on compiler + * extensions. */ +static inline uint32_t bit_count(uint32_t n) { + uint32_t count = 0; + while (n) { + count += n & 1; + n >>= 1; + } + return count; +} + +/* Create a Hadamard-based projection matrix for dimensionality reduction. + * Uses {-1, +1} entries with a pattern based on bit operations. + * The pattern is matrix[i][j] = (i & j) % 2 == 0 ? 1 : -1 + * Matrix is scaled by 1/sqrt(input_dim) for normalization. + * Returns NULL on allocation failure. + * + * Note that compared to other approaches (random gaussian weights), what + * we have here is deterministic, it means that our replicas will have + * the same set of weights. Also this approach seems to work much better + * in practice, and the distances between elements are better guaranteed. + * + * Note that we still save the projection matrix in the RDB file, because + * in the future we may change the weights generation, and we want everything + * to be backward compatible. */ +float *createProjectionMatrix(uint32_t input_dim, uint32_t output_dim) { + float *matrix = RedisModule_Alloc(sizeof(float) * input_dim * output_dim); + + /* Scale factor to normalize the projection. */ + const float scale = 1.0f / sqrt(input_dim); + + /* Fill the matrix using Hadamard pattern. */ + for (uint32_t i = 0; i < output_dim; i++) { + for (uint32_t j = 0; j < input_dim; j++) { + /* Calculate position in the flattened matrix. */ + uint32_t pos = i * input_dim + j; + + /* Hadamard pattern: use bit operations to determine sign + * If the count of 1-bits in the bitwise AND of i and j is even, + * the value is 1, otherwise -1. */ + int value = (bit_count(i & j) % 2 == 0) ? 1 : -1; + + /* Store the scaled value. */ + matrix[pos] = value * scale; + } + } + return matrix; +} + +/* Apply random projection to input vector. Returns new allocated vector. */ +float *applyProjection(const float *input, const float *proj_matrix, + uint32_t input_dim, uint32_t output_dim) +{ + float *output = RedisModule_Alloc(sizeof(float) * output_dim); + + for (uint32_t i = 0; i < output_dim; i++) { + const float *row = &proj_matrix[i * input_dim]; + float sum = 0.0f; + for (uint32_t j = 0; j < input_dim; j++) { + sum += row[j] * input[j]; + } + output[i] = sum; + } + return output; +} + +/* Create the vector as HNSW+Dictionary combined data structure. */ +struct vsetObject *createVectorSetObject(unsigned int dim, uint32_t quant_type, uint32_t hnsw_M) { + struct vsetObject *o; + o = RedisModule_Alloc(sizeof(*o)); + + o->id = VectorSetTypeNextId++; + o->hnsw = hnsw_new(dim,quant_type,hnsw_M); + if (!o->hnsw) { // May fail because of mutex creation. + RedisModule_Free(o); + return NULL; + } + + o->dict = RedisModule_CreateDict(NULL); + o->proj_matrix = NULL; + o->proj_input_size = 0; + o->numattribs = 0; + o->thread_creation_pending = 0; + RedisModule_Assert(pthread_rwlock_init(&o->in_use_lock,NULL) == 0); + return o; +} + +void vectorSetReleaseNodeValue(void *v) { + struct vsetNodeVal *nv = v; + RedisModule_FreeString(NULL,nv->item); + if (nv->attrib) RedisModule_FreeString(NULL,nv->attrib); + RedisModule_Free(nv); +} + +/* Free the vector set object. */ +void vectorSetReleaseObject(struct vsetObject *o) { + if (!o) return; + if (o->hnsw) hnsw_free(o->hnsw,vectorSetReleaseNodeValue); + if (o->dict) RedisModule_FreeDict(NULL,o->dict); + if (o->proj_matrix) RedisModule_Free(o->proj_matrix); + pthread_rwlock_destroy(&o->in_use_lock); + RedisModule_Free(o); +} + +/* Wait for all the threads performing operations on this + * index to terminate their work (locking for write will + * wait for all the other threads). + * + * if 'for_del' is set to 1, we also wait for all the pending threads + * that still didn't acquire the lock to finish their work. This + * is useful only if we are going to call this function to delete + * the object, and not if we want to just to modify it. */ +void vectorSetWaitAllBackgroundClients(struct vsetObject *vset, int for_del) { + if (for_del) { + // If we are going to destroy the object, after this call, let's + // wait for threads that are being created and still didn't had + // a chance to acquire the lock. + while (vset->thread_creation_pending > 0); + } + RedisModule_Assert(pthread_rwlock_wrlock(&vset->in_use_lock) == 0); + pthread_rwlock_unlock(&vset->in_use_lock); +} + +/* Return a string representing the quantization type name of a vector set. */ +const char *vectorSetGetQuantName(struct vsetObject *o) { + switch(o->hnsw->quant_type) { + case HNSW_QUANT_NONE: return "f32"; + case HNSW_QUANT_Q8: return "int8"; + case HNSW_QUANT_BIN: return "bin"; + default: return "unknown"; + } +} + +/* Insert the specified element into the Vector Set. + * If update is '1', the existing node will be updated. + * + * Returns 1 if the element was added, or 0 if the element was already there + * and was just updated. */ +int vectorSetInsert(struct vsetObject *o, float *vec, int8_t *qvec, float qrange, RedisModuleString *val, RedisModuleString *attrib, int update, int ef) +{ + hnswNode *node = RedisModule_DictGet(o->dict,val,NULL); + if (node != NULL) { + if (update) { + /* Wait for clients in the background: background VSIM + * operations touch the nodes attributes we are going + * to touch. */ + vectorSetWaitAllBackgroundClients(o,0); + + struct vsetNodeVal *nv = node->value; + /* Pass NULL as value-free function. We want to reuse + * the old value. */ + hnsw_delete_node(o->hnsw, node, NULL); + node = hnsw_insert(o->hnsw,vec,qvec,qrange,0,nv,ef); + RedisModule_Assert(node != NULL); + RedisModule_DictReplace(o->dict,val,node); + + /* If attrib != NULL, the user wants that in case of an update we + * update the attribute as well (otherwise it remains as it was). + * Note that the order of operations is conceinved so that it + * works in case the old attrib and the new attrib pointer is the + * same. */ + if (attrib) { + // Empty attribute string means: unset the attribute during + // the update. + size_t attrlen; + RedisModule_StringPtrLen(attrib,&attrlen); + if (attrlen != 0) { + RedisModule_RetainString(NULL,attrib); + o->numattribs++; + } else { + attrib = NULL; + } + + if (nv->attrib) { + o->numattribs--; + RedisModule_FreeString(NULL,nv->attrib); + } + nv->attrib = attrib; + } + } + return 0; + } + + struct vsetNodeVal *nv = RedisModule_Alloc(sizeof(*nv)); + nv->item = val; + nv->attrib = attrib; + node = hnsw_insert(o->hnsw,vec,qvec,qrange,0,nv,ef); + if (node == NULL) { + // XXX Technically in Redis-land we don't have out of memory, as we + // crash on OOM. However the HNSW library may fail for error in the + // locking libc call. Probably impossible in practical terms. + RedisModule_Free(nv); + return 0; + } + if (attrib != NULL) o->numattribs++; + RedisModule_DictSet(o->dict,val,node); + RedisModule_RetainString(NULL,val); + if (attrib) RedisModule_RetainString(NULL,attrib); + return 1; +} + +/* Parse vector from FP32 blob or VALUES format, with optional REDUCE. + * Format: [REDUCE dim] FP32|VALUES ... + * Returns allocated vector and sets dimension in *dim. + * If reduce_dim is not NULL, sets it to the requested reduction dimension. + * Returns NULL on parsing error. + * + * The function sets as a reference *consumed_args, so that the caller + * knows how many arguments we consumed in order to parse the input + * vector. Remaining arguments are often command options. */ +float *parseVector(RedisModuleString **argv, int argc, int start_idx, + size_t *dim, uint32_t *reduce_dim, int *consumed_args) +{ + int consumed = 0; // Arguments consumed + + /* Check for REDUCE option first. */ + if (reduce_dim) *reduce_dim = 0; + if (reduce_dim && argc > start_idx + 2 && + !strcasecmp(RedisModule_StringPtrLen(argv[start_idx],NULL),"REDUCE")) + { + long long rdim; + if (RedisModule_StringToLongLong(argv[start_idx+1],&rdim) + != REDISMODULE_OK || rdim <= 0) + { + return NULL; + } + if (reduce_dim) *reduce_dim = rdim; + start_idx += 2; // Skip REDUCE and its argument. + consumed += 2; + } + + /* Now parse the vector format as before. */ + float *vec = NULL; + const char *vec_format = RedisModule_StringPtrLen(argv[start_idx],NULL); + + if (!strcasecmp(vec_format,"FP32")) { + if (argc < start_idx + 2) return NULL; // Need FP32 + vector + value. + size_t vec_raw_len; + const char *blob = + RedisModule_StringPtrLen(argv[start_idx+1],&vec_raw_len); + + // Must be 4 bytes per component. + if (vec_raw_len % 4 || vec_raw_len < 4) return NULL; + *dim = vec_raw_len/4; + + vec = RedisModule_Alloc(vec_raw_len); + if (!vec) return NULL; + memcpy(vec,blob,vec_raw_len); + consumed += 2; + } else if (!strcasecmp(vec_format,"VALUES")) { + if (argc < start_idx + 2) return NULL; // Need at least the dimension. + long long vdim; // Vector dimension passed by the user. + if (RedisModule_StringToLongLong(argv[start_idx+1],&vdim) + != REDISMODULE_OK || vdim < 1) return NULL; + + // Check that all the arguments are available. + if (argc < start_idx + 2 + vdim) return NULL; + + *dim = vdim; + vec = RedisModule_Alloc(sizeof(float) * vdim); + if (!vec) return NULL; + + for (int j = 0; j < vdim; j++) { + double val; + if (RedisModule_StringToDouble(argv[start_idx+2+j],&val) + != REDISMODULE_OK) + { + RedisModule_Free(vec); + return NULL; + } + vec[j] = val; + } + consumed += vdim + 2; + } else { + return NULL; // Unknown format. + } + + if (consumed_args) *consumed_args = consumed; + return vec; +} + +/* ========================== Commands implementation ======================= */ + +/* VADD thread handling the "CAS" version of the command, that is + * performed blocking the client, accumulating here, in the thread, the + * set of potential candidates, and later inserting the element in the + * key (if it still exists, and if it is still the *same* vector set) + * in the Reply callback. */ +void *VADD_thread(void *arg) { + pthread_detach(pthread_self()); + + void **targ = (void**)arg; + RedisModuleBlockedClient *bc = targ[0]; + struct vsetObject *vset = targ[1]; + float *vec = targ[3]; + int ef = (uint64_t)targ[6]; + + /* Lock the object and signal that we are no longer pending + * the lock acquisition. */ + RedisModule_Assert(pthread_rwlock_rdlock(&vset->in_use_lock) == 0); + vset->thread_creation_pending--; + + /* Look for candidates... */ + InsertContext *ic = hnsw_prepare_insert(vset->hnsw, vec, NULL, 0, 0, ef); + targ[5] = ic; // Pass the context to the reply callback. + + /* Unblock the client so that our read reply will be invoked. */ + pthread_rwlock_unlock(&vset->in_use_lock); + RedisModule_BlockedClientMeasureTimeEnd(bc); + RedisModule_UnblockClient(bc,targ); // Use targ as privdata. + return NULL; +} + +/* Reply callback for CAS variant of VADD. + * Note: this is called in the main thread, in the background thread + * we just do the read operation of gathering the neighbors. */ +int VADD_CASReply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + (void)argc; + RedisModule_AutoMemory(ctx); /* Use automatic memory management. */ + + int retval = REDISMODULE_OK; + void **targ = (void**)RedisModule_GetBlockedClientPrivateData(ctx); + uint64_t vset_id = (unsigned long) targ[2]; + float *vec = targ[3]; + RedisModuleString *val = targ[4]; + InsertContext *ic = targ[5]; + int ef = (uint64_t)targ[6]; + RedisModuleString *attrib = targ[7]; + RedisModule_Free(targ); + + /* Open the key: there are no guarantees it still exists, or contains + * a vector set, or even the SAME vector set. */ + RedisModuleKey *key = RedisModule_OpenKey(ctx,argv[1], + REDISMODULE_READ|REDISMODULE_WRITE); + int type = RedisModule_KeyType(key); + struct vsetObject *vset = NULL; + + if (type != REDISMODULE_KEYTYPE_EMPTY && + RedisModule_ModuleTypeGetType(key) == VectorSetType) + { + vset = RedisModule_ModuleTypeGetValue(key); + // Same vector set? + if (vset->id != vset_id) vset = NULL; + + /* Also, if the element was already inserted, we just pretend + * the other insert won. We don't even start a threaded VADD + * if this was an update, since the deletion of the element itself + * in order to perform the update would invalidate the CAS state. */ + if (vset && RedisModule_DictGet(vset->dict,val,NULL) != NULL) + vset = NULL; + } + + if (vset == NULL) { + /* If the object does not match the start of the operation, we + * just pretend the VADD was performed BEFORE the key was deleted + * or replaced. We return success but don't do anything. */ + hnsw_free_insert_context(ic); + } else { + /* Otherwise try to insert the new element with the neighbors + * collected in background. If we fail, do it synchronously again + * from scratch. */ + + // First: allocate the dual-ported value for the node. + struct vsetNodeVal *nv = RedisModule_Alloc(sizeof(*nv)); + nv->item = val; + nv->attrib = attrib; + + /* Then: insert the node in the HNSW data structure. Note that + * 'ic' could be NULL in case hnsw_prepare_insert() failed because of + * locking failure (likely impossible in practical terms). */ + hnswNode *newnode; + if (ic == NULL || + (newnode = hnsw_try_commit_insert(vset->hnsw, ic, nv)) == NULL) + { + /* If we are here, the CAS insert failed. We need to insert + * again with full locking for neighbors selection and + * actual insertion. This time we can't fail: */ + newnode = hnsw_insert(vset->hnsw, vec, NULL, 0, 0, nv, ef); + RedisModule_Assert(newnode != NULL); + } + RedisModule_DictSet(vset->dict,val,newnode); + val = NULL; // Don't free it later. + attrib = NULL; // Don't free it later. + + RedisModule_ReplicateVerbatim(ctx); + } + + // Whatever happens is a success... :D + RedisModule_ReplyWithBool(ctx,1); + if (val) RedisModule_FreeString(ctx,val); // Not added? Free it. + if (attrib) RedisModule_FreeString(ctx,attrib); // Not added? Free it. + RedisModule_Free(vec); + return retval; +} + +/* VADD key [REDUCE dim] FP32|VALUES vector value [CAS] [NOQUANT] [BIN] [Q8] + * [M count] */ +int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); /* Use automatic memory management. */ + + if (argc < 5) return RedisModule_WrongArity(ctx); + + /* Parse vector with optional REDUCE */ + size_t dim = 0; + uint32_t reduce_dim = 0; + int consumed_args; + int cas = 0; // Threaded check-and-set style insert. + long long ef = VSET_DEFAULT_C_EF; // HNSW creation time EF for new nodes. + long long hnsw_create_M = HNSW_DEFAULT_M; // HNSW creation default M value. + float *vec = parseVector(argv, argc, 2, &dim, &reduce_dim, &consumed_args); + RedisModuleString *attrib = NULL; // Attributes if passed via ATTRIB. + if (!vec) + return RedisModule_ReplyWithError(ctx,"ERR invalid vector specification"); + + /* Missing element string at the end? */ + if (argc-2-consumed_args < 1) { + RedisModule_Free(vec); + return RedisModule_WrongArity(ctx); + } + + /* Parse options after the element string. */ + uint32_t quant_type = HNSW_QUANT_Q8; // Default quantization type. + + for (int j = 2 + consumed_args + 1; j < argc; j++) { + const char *opt = RedisModule_StringPtrLen(argv[j], NULL); + if (!strcasecmp(opt, "CAS")) { + cas = 1; + } else if (!strcasecmp(opt, "EF") && j+1 < argc) { + if (RedisModule_StringToLongLong(argv[j+1], &ef) + != REDISMODULE_OK || ef <= 0 || ef > 1000000) + { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, "ERR invalid EF"); + } + j++; // skip argument. + } else if (!strcasecmp(opt, "M") && j+1 < argc) { + if (RedisModule_StringToLongLong(argv[j+1], &hnsw_create_M) + != REDISMODULE_OK || hnsw_create_M < HNSW_MIN_M || + hnsw_create_M > HNSW_MAX_M) + { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, "ERR invalid M"); + } + j++; // skip argument. + } else if (!strcasecmp(opt, "SETATTR") && j+1 < argc) { + attrib = argv[j+1]; + j++; // skip argument. + } else if (!strcasecmp(opt, "NOQUANT")) { + quant_type = HNSW_QUANT_NONE; + } else if (!strcasecmp(opt, "BIN")) { + quant_type = HNSW_QUANT_BIN; + } else if (!strcasecmp(opt, "Q8")) { + quant_type = HNSW_QUANT_Q8; + } else { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx,"ERR invalid option after element"); + } + } + + /* Drop CAS if this is a replica and we are getting the command from the + * replication link: we want to add/delete items in the same order as + * the master, while with CAS the timing would be different. + * + * Also for Lua scripts and MULTI/EXEC, we want to run the command + * on the main thread. */ + if (RedisModule_GetContextFlags(ctx) & + (REDISMODULE_CTX_FLAGS_REPLICATED| + REDISMODULE_CTX_FLAGS_LUA| + REDISMODULE_CTX_FLAGS_MULTI)) + { + cas = 0; + } + + if (VSGlobalConfig.forceSingleThreadExec) { + cas = 0; + } + + /* Open/create key */ + RedisModuleKey *key = RedisModule_OpenKey(ctx,argv[1], + REDISMODULE_READ|REDISMODULE_WRITE); + int type = RedisModule_KeyType(key); + if (type != REDISMODULE_KEYTYPE_EMPTY && + RedisModule_ModuleTypeGetType(key) != VectorSetType) + { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx,REDISMODULE_ERRORMSG_WRONGTYPE); + } + + /* Get the correct value argument based on format and REDUCE */ + RedisModuleString *val = argv[2 + consumed_args]; + + /* Create or get existing vector set */ + struct vsetObject *vset; + if (type == REDISMODULE_KEYTYPE_EMPTY) { + cas = 0; /* Do synchronous insert at creation, otherwise the + * key would be left empty until the threaded part + * does not return. It's also pointless to try try + * doing threaded first element insertion. */ + vset = createVectorSetObject(reduce_dim ? reduce_dim : dim, quant_type, hnsw_create_M); + if (vset == NULL) { + // We can't fail for OOM in Redis, but the mutex initialization + // at least theoretically COULD fail. Likely this code path + // is not reachable in practical terms. + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, + "ERR unable to create a Vector Set: system resources issue?"); + } + + /* Initialize projection if requested */ + if (reduce_dim) { + vset->proj_matrix = createProjectionMatrix(dim, reduce_dim); + vset->proj_input_size = dim; + + /* Project the vector */ + float *projected = applyProjection(vec, vset->proj_matrix, + dim, reduce_dim); + RedisModule_Free(vec); + vec = projected; + } + RedisModule_ModuleTypeSetValue(key,VectorSetType,vset); + } else { + vset = RedisModule_ModuleTypeGetValue(key); + + if (vset->hnsw->quant_type != quant_type) { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, + "ERR asked quantization mismatch with existing vector set"); + } + + if (vset->hnsw->M != hnsw_create_M) { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, + "ERR asked M value mismatch with existing vector set"); + } + + if ((vset->proj_matrix == NULL && vset->hnsw->vector_dim != dim) || + (vset->proj_matrix && vset->hnsw->vector_dim != reduce_dim)) + { + RedisModule_Free(vec); + return RedisModule_ReplyWithErrorFormat(ctx, + "ERR Vector dimension mismatch - got %d but set has %d", + (int)dim, (int)vset->hnsw->vector_dim); + } + + /* Check REDUCE compatibility */ + if (reduce_dim) { + if (!vset->proj_matrix) { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, + "ERR cannot add projection to existing set without projection"); + } + if (reduce_dim != vset->hnsw->vector_dim) { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, + "ERR projection dimension mismatch with existing set"); + } + } + + /* Apply projection if needed */ + if (vset->proj_matrix) { + /* Ensure input dimension matches the projection matrix's expected input dimension */ + if (dim != vset->proj_input_size) { + RedisModule_Free(vec); + return RedisModule_ReplyWithErrorFormat(ctx, + "ERR Input dimension mismatch for projection - got %d but projection expects %d", + (int)dim, (int)vset->proj_input_size); + } + + float *projected = applyProjection(vec, vset->proj_matrix, + vset->proj_input_size, + vset->hnsw->vector_dim); + RedisModule_Free(vec); + vec = projected; + dim = vset->hnsw->vector_dim; + } + } + + /* For existing keys don't do CAS updates. For how things work now, the + * CAS state would be invalidated by the deletion before adding back. */ + if (cas && RedisModule_DictGet(vset->dict,val,NULL) != NULL) + cas = 0; + + /* Here depending on the CAS option we directly insert in a blocking + * way, or use a thread to do candidate neighbors selection and only + * later, in the reply callback, actually add the element. */ + if (cas) { + RedisModuleBlockedClient *bc = RedisModule_BlockClient(ctx,VADD_CASReply,NULL,NULL,0); + pthread_t tid; + void **targ = RedisModule_Alloc(sizeof(void*)*8); + targ[0] = bc; + targ[1] = vset; + targ[2] = (void*)(unsigned long)vset->id; + targ[3] = vec; + targ[4] = val; + targ[5] = NULL; // Used later for insertion context. + targ[6] = (void*)(unsigned long)ef; + targ[7] = attrib; + RedisModule_RetainString(ctx,val); + if (attrib) RedisModule_RetainString(ctx,attrib); + RedisModule_BlockedClientMeasureTimeStart(bc); + vset->thread_creation_pending++; + if (pthread_create(&tid,NULL,VADD_thread,targ) != 0) { + vset->thread_creation_pending--; + RedisModule_AbortBlock(bc); + RedisModule_Free(targ); + RedisModule_FreeString(ctx,val); + if (attrib) RedisModule_FreeString(ctx,attrib); + + // Fall back to synchronous insert, see later in the code. + } else { + return REDISMODULE_OK; + } + } + + /* Insert vector synchronously: we reach this place even + * if cas was true but thread creation failed. */ + int added = vectorSetInsert(vset,vec,NULL,0,val,attrib,1,ef); + RedisModule_Free(vec); + + RedisModule_ReplyWithBool(ctx,added); + if (added) RedisModule_ReplicateVerbatim(ctx); + return REDISMODULE_OK; +} + +/* HNSW callback to filter items according to a predicate function + * (our FILTER expression in this case). */ +int vectorSetFilterCallback(void *value, void *privdata) { + exprstate *expr = privdata; + struct vsetNodeVal *nv = value; + if (nv->attrib == NULL) return 0; // No attributes? No match. + size_t json_len; + char *json = (char*)RedisModule_StringPtrLen(nv->attrib,&json_len); + return exprRun(expr,json,json_len); +} + +/* Common path for the execution of the VSIM command both threaded and + * not threaded. Note that 'ctx' may be normal context of a thread safe + * context obtained from a blocked client. The locking that is specific + * to the vset object is handled by the caller, however the function + * handles the HNSW locking explicitly. */ +void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset, + float *vec, unsigned long count, float epsilon, unsigned long withscores, + unsigned long withattribs, unsigned long ef, exprstate *filter_expr, + unsigned long filter_ef, int ground_truth) +{ + /* In our scan, we can't just collect 'count' elements as + * if count is small we would explore the graph in an insufficient + * way to provide enough recall. + * + * If the user didn't asked for a specific exploration, we use + * VSET_DEFAULT_SEARCH_EF as minimum, or we match count if count + * is greater than that. Otherwise the minumim will be the specified + * EF argument. */ + if (ef == 0) ef = VSET_DEFAULT_SEARCH_EF; + if (count > ef) ef = count; + + int slot = hnsw_acquire_read_slot(vset->hnsw); + if (ef > vset->hnsw->node_count) ef = vset->hnsw->node_count; + + /* Perform search */ + hnswNode **neighbors = RedisModule_Alloc(sizeof(hnswNode*)*ef); + float *distances = RedisModule_Alloc(sizeof(float)*ef); + unsigned int found; + if (ground_truth) { + found = hnsw_ground_truth_with_filter(vset->hnsw, vec, ef, neighbors, + distances, slot, 0, + filter_expr ? vectorSetFilterCallback : NULL, + filter_expr); + } else { + if (filter_expr == NULL) { + found = hnsw_search(vset->hnsw, vec, ef, neighbors, + distances, slot, 0); + } else { + found = hnsw_search_with_filter(vset->hnsw, vec, ef, neighbors, + distances, slot, 0, vectorSetFilterCallback, + filter_expr, filter_ef); + } + } + + /* Return results */ + int resp3 = RedisModule_GetContextFlags(ctx) & REDISMODULE_CTX_FLAGS_RESP3; + int reply_with_map = resp3 && (withscores || withattribs); + + if (reply_with_map) + RedisModule_ReplyWithMap(ctx, REDISMODULE_POSTPONED_LEN); + else + RedisModule_ReplyWithArray(ctx, REDISMODULE_POSTPONED_LEN); + + long long arraylen = 0; + for (unsigned int i = 0; i < found && i < count; i++) { + if (distances[i]/2 > epsilon) break; + struct vsetNodeVal *nv = neighbors[i]->value; + RedisModule_ReplyWithString(ctx, nv->item); + arraylen++; + + /* If the user asked for multiple properties at the same time using + * the RESP3 protocol, we wrap the value of the map into an N-items + * array. Two for now, since we have just two properties that can be + * requested. + * + * So in the case of RESP2 we will just have the flat reply: + * item, score, attribute. For RESP3 instead item -> [score, attribute] + */ + if (resp3 && withscores && withattribs) + RedisModule_ReplyWithArray(ctx,2); + + if (withscores) { + /* The similarity score is provided in a 0-1 range. */ + RedisModule_ReplyWithDouble(ctx, 1.0 - distances[i]/2.0); + } + if (withattribs) { + /* Return the attributes as well, if any. */ + if (nv->attrib) + RedisModule_ReplyWithString(ctx, nv->attrib); + else + RedisModule_ReplyWithNull(ctx); + } + } + hnsw_release_read_slot(vset->hnsw,slot); + + if (reply_with_map) { + RedisModule_ReplySetMapLength(ctx, arraylen); + } else { + int items_per_ele = 1+withattribs+withscores; + RedisModule_ReplySetArrayLength(ctx, arraylen * items_per_ele); + } + + RedisModule_Free(vec); + RedisModule_Free(neighbors); + RedisModule_Free(distances); + if (filter_expr) exprFree(filter_expr); +} + +/* VSIM thread handling the blocked client request. */ +void *VSIM_thread(void *arg) { + pthread_detach(pthread_self()); + + // Extract arguments. + void **targ = (void**)arg; + RedisModuleBlockedClient *bc = targ[0]; + struct vsetObject *vset = targ[1]; + float *vec = targ[2]; + unsigned long count = (unsigned long)targ[3]; + float epsilon = *((float*)targ[4]); + unsigned long withscores = (unsigned long)targ[5]; + unsigned long withattribs = (unsigned long)targ[6]; + unsigned long ef = (unsigned long)targ[7]; + exprstate *filter_expr = targ[8]; + unsigned long filter_ef = (unsigned long)targ[9]; + unsigned long ground_truth = (unsigned long)targ[10]; + RedisModule_Free(targ[4]); + RedisModule_Free(targ); + + /* Lock the object and signal that we are no longer pending + * the lock acquisition. */ + RedisModule_Assert(pthread_rwlock_rdlock(&vset->in_use_lock) == 0); + vset->thread_creation_pending--; + + // Accumulate reply in a thread safe context: no contention. + RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(bc); + + // Run the query. + VSIM_execute(ctx, vset, vec, count, epsilon, withscores, withattribs, ef, filter_expr, filter_ef, ground_truth); + pthread_rwlock_unlock(&vset->in_use_lock); + + // Cleanup. + RedisModule_FreeThreadSafeContext(ctx); + RedisModule_BlockedClientMeasureTimeEnd(bc); + RedisModule_UnblockClient(bc,NULL); + return NULL; +} + +/* VSIM key [ELE|FP32|VALUES] <vector or ele> [WITHSCORES] [WITHATTRIBS] [COUNT num] [EPSILON eps] [EF exploration-factor] [FILTER expression] [FILTER-EF exploration-factor] */ +int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); + + /* Basic argument check: need at least key and vector specification + * method. */ + if (argc < 4) return RedisModule_WrongArity(ctx); + + /* Defaults */ + int withscores = 0; + int withattribs = 0; + long long count = VSET_DEFAULT_COUNT; /* New default value */ + long long ef = 0; /* Exploration factor (see HNSW paper) */ + double epsilon = 2.0; /* Max cosine distance */ + long long ground_truth = 0; /* Linear scan instead of HNSW search? */ + int no_thread = 0; /* NOTHREAD option: exec on main thread. */ + + /* Things computed later. */ + long long filter_ef = 0; + exprstate *filter_expr = NULL; + + /* Get key and vector type */ + RedisModuleString *key = argv[1]; + const char *vectorType = RedisModule_StringPtrLen(argv[2], NULL); + + /* Get vector set */ + RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key, REDISMODULE_READ); + int type = RedisModule_KeyType(keyptr); + if (type == REDISMODULE_KEYTYPE_EMPTY) + return RedisModule_ReplyWithEmptyArray(ctx); + + if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType) + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr); + + /* Vector parsing stage */ + float *vec = NULL; + size_t dim = 0; + int vector_args = 0; /* Number of args consumed by vector specification */ + + if (!strcasecmp(vectorType, "ELE")) { + /* Get vector from existing element */ + RedisModuleString *ele = argv[3]; + hnswNode *node = RedisModule_DictGet(vset->dict, ele, NULL); + if (!node) { + return RedisModule_ReplyWithError(ctx, "ERR element not found in set"); + } + vec = RedisModule_Alloc(sizeof(float) * vset->hnsw->vector_dim); + hnsw_get_node_vector(vset->hnsw,node,vec); + dim = vset->hnsw->vector_dim; + vector_args = 2; /* ELE + element name */ + } else { + /* Parse vector. */ + int consumed_args; + + vec = parseVector(argv, argc, 2, &dim, NULL, &consumed_args); + if (!vec) { + return RedisModule_ReplyWithError(ctx, + "ERR invalid vector specification"); + } + vector_args = consumed_args; + + /* Apply projection if the set uses it, with the exception + * of ELE type, that will already have the right dimension. */ + if (vset->proj_matrix && dim != vset->hnsw->vector_dim) { + /* Ensure input dimension matches the projection matrix's expected input dimension */ + if (dim != vset->proj_input_size) { + RedisModule_Free(vec); + return RedisModule_ReplyWithErrorFormat(ctx, + "ERR Input dimension mismatch for projection - got %d but projection expects %d", + (int)dim, (int)vset->proj_input_size); + } + + float *projected = applyProjection(vec, vset->proj_matrix, + vset->proj_input_size, + vset->hnsw->vector_dim); + RedisModule_Free(vec); + vec = projected; + dim = vset->hnsw->vector_dim; + } + + /* Count consumed arguments */ + if (!strcasecmp(vectorType, "FP32")) { + vector_args = 2; /* FP32 + vector blob */ + } else if (!strcasecmp(vectorType, "VALUES")) { + long long vdim; + if (RedisModule_StringToLongLong(argv[3], &vdim) != REDISMODULE_OK) { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, "ERR invalid vector dimension"); + } + vector_args = 2 + vdim; /* VALUES + dim + values */ + } else { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, + "ERR vector type must be ELE, FP32 or VALUES"); + } + } + + /* Check vector dimension matches set */ + if (dim != vset->hnsw->vector_dim) { + RedisModule_Free(vec); + return RedisModule_ReplyWithErrorFormat(ctx, + "ERR Vector dimension mismatch - got %d but set has %d", + (int)dim, (int)vset->hnsw->vector_dim); + } + + /* Parse optional arguments - start after vector specification */ + int j = 2 + vector_args; + while (j < argc) { + const char *opt = RedisModule_StringPtrLen(argv[j], NULL); + if (!strcasecmp(opt, "WITHSCORES")) { + withscores = 1; + j++; + } else if (!strcasecmp(opt, "WITHATTRIBS")) { + withattribs = 1; + j++; + } else if (!strcasecmp(opt, "TRUTH")) { + ground_truth = 1; + j++; + } else if (!strcasecmp(opt, "NOTHREAD")) { + no_thread = 1; + j++; + } else if (!strcasecmp(opt, "COUNT") && j+1 < argc) { + if (RedisModule_StringToLongLong(argv[j+1], &count) + != REDISMODULE_OK || count <= 0) + { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, "ERR invalid COUNT"); + } + j += 2; + } else if (!strcasecmp(opt, "EPSILON") && j+1 < argc) { + if (RedisModule_StringToDouble(argv[j+1], &epsilon) != + REDISMODULE_OK || epsilon <= 0) + { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, "ERR invalid EPSILON"); + } + j += 2; + } else if (!strcasecmp(opt, "EF") && j+1 < argc) { + if (RedisModule_StringToLongLong(argv[j+1], &ef) != + REDISMODULE_OK || ef <= 0 || ef > 1000000) + { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, "ERR invalid EF"); + } + j += 2; + } else if (!strcasecmp(opt, "FILTER-EF") && j+1 < argc) { + if (RedisModule_StringToLongLong(argv[j+1], &filter_ef) != + REDISMODULE_OK || filter_ef <= 0) + { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, "ERR invalid FILTER-EF"); + } + j += 2; + } else if (!strcasecmp(opt, "FILTER") && j+1 < argc) { + RedisModuleString *exprarg = argv[j+1]; + size_t exprlen; + char *exprstr = (char*)RedisModule_StringPtrLen(exprarg,&exprlen); + int errpos; + filter_expr = exprCompile(exprstr,&errpos); + if (filter_expr == NULL) { + if ((size_t)errpos >= exprlen) errpos = 0; + RedisModule_Free(vec); + return RedisModule_ReplyWithErrorFormat(ctx, + "ERR syntax error in FILTER expression near: %s", + exprstr+errpos); + } + j += 2; + } else { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, + "ERR syntax error in VSIM command"); + } + } + + int threaded_request = 1; // Run on a thread, by default. + if (filter_ef == 0) filter_ef = count * 100; // Max filter visited nodes. + + /* Disable threaded for MULTI/EXEC and Lua, or if explicitly + * requested by the user via the NOTHREAD option. */ + if (no_thread || VSGlobalConfig.forceSingleThreadExec || + (RedisModule_GetContextFlags(ctx) & + (REDISMODULE_CTX_FLAGS_LUA | REDISMODULE_CTX_FLAGS_MULTI))) + { + threaded_request = 0; + } + + if (threaded_request) { + /* Note: even if we create one thread per request, the underlying + * HNSW library has a fixed number of slots for the threads, as it's + * defined in HNSW_MAX_THREADS (beware that if you increase it, + * every node will use more memory). This means that while this request + * is threaded, and will NOT block Redis, it may end waiting for a + * free slot if all the HNSW_MAX_THREADS slots are used. */ + RedisModuleBlockedClient *bc = RedisModule_BlockClient(ctx,NULL,NULL,NULL,0); + pthread_t tid; + void **targ = RedisModule_Alloc(sizeof(void*)*11); + targ[0] = bc; + targ[1] = vset; + targ[2] = vec; + targ[3] = (void*)count; + targ[4] = RedisModule_Alloc(sizeof(float)); + *((float*)targ[4]) = epsilon; + targ[5] = (void*)(unsigned long)withscores; + targ[6] = (void*)(unsigned long)withattribs; + targ[7] = (void*)(unsigned long)ef; + targ[8] = (void*)filter_expr; + targ[9] = (void*)(unsigned long)filter_ef; + targ[10] = (void*)(unsigned long)ground_truth; + RedisModule_BlockedClientMeasureTimeStart(bc); + vset->thread_creation_pending++; + if (pthread_create(&tid,NULL,VSIM_thread,targ) != 0) { + vset->thread_creation_pending--; + RedisModule_AbortBlock(bc); + RedisModule_Free(targ[4]); + RedisModule_Free(targ); + VSIM_execute(ctx, vset, vec, count, epsilon, withscores, withattribs, ef, filter_expr, filter_ef, ground_truth); + } + } else { + VSIM_execute(ctx, vset, vec, count, epsilon, withscores, withattribs, ef, filter_expr, filter_ef, ground_truth); + } + + return REDISMODULE_OK; +} + +/* VDIM <key>: return the dimension of vectors in the vector set. */ +int VDIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); + + if (argc != 2) return RedisModule_WrongArity(ctx); + + RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ); + int type = RedisModule_KeyType(key); + + if (type == REDISMODULE_KEYTYPE_EMPTY) + return RedisModule_ReplyWithError(ctx, "ERR key does not exist"); + + if (RedisModule_ModuleTypeGetType(key) != VectorSetType) + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key); + return RedisModule_ReplyWithLongLong(ctx, vset->hnsw->vector_dim); +} + +/* VCARD <key>: return cardinality (num of elements) of the vector set. */ +int VCARD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); + + if (argc != 2) return RedisModule_WrongArity(ctx); + + RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ); + int type = RedisModule_KeyType(key); + + if (type == REDISMODULE_KEYTYPE_EMPTY) + return RedisModule_ReplyWithLongLong(ctx, 0); + + if (RedisModule_ModuleTypeGetType(key) != VectorSetType) + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key); + return RedisModule_ReplyWithLongLong(ctx, vset->hnsw->node_count); +} + +/* VREM key element + * Remove an element from a vector set. + * Returns 1 if the element was found and removed, 0 if not found. */ +int VREM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); /* Use automatic memory management. */ + + if (argc != 3) return RedisModule_WrongArity(ctx); + + /* Get key and value */ + RedisModuleString *key = argv[1]; + RedisModuleString *element = argv[2]; + + /* Open key */ + RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key, + REDISMODULE_READ|REDISMODULE_WRITE); + int type = RedisModule_KeyType(keyptr); + + /* Handle non-existing key or wrong type */ + if (type == REDISMODULE_KEYTYPE_EMPTY) { + return RedisModule_ReplyWithBool(ctx, 0); + } + if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType) { + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + } + + /* Get vector set from key */ + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr); + + /* Find the node for this element */ + hnswNode *node = RedisModule_DictGet(vset->dict, element, NULL); + if (!node) { + return RedisModule_ReplyWithBool(ctx, 0); + } + + /* Remove from dictionary */ + RedisModule_DictDel(vset->dict, element, NULL); + + /* Remove from HNSW graph using the high-level API that handles + * locking and cleanup. We pass RedisModule_FreeString as the value + * free function since the strings were retained at insertion time. */ + struct vsetNodeVal *nv = node->value; + if (nv->attrib != NULL) vset->numattribs--; + RedisModule_Assert(hnsw_delete_node(vset->hnsw, node, vectorSetReleaseNodeValue) == 1); + + /* Destroy empty vector set. */ + if (RedisModule_DictSize(vset->dict) == 0) { + RedisModule_DeleteKey(keyptr); + } + + /* Reply and propagate the command */ + RedisModule_ReplyWithBool(ctx, 1); + RedisModule_ReplicateVerbatim(ctx); + return REDISMODULE_OK; +} + +/* VEMB key element + * Returns the embedding vector associated with an element, or NIL if not + * found. The vector is returned in the same format it was added, but the + * return value will have some lack of precision due to quantization and + * normalization of vectors. Also, if items were added using REDUCE, the + * reduced vector is returned instead. */ +int VEMB_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); + int raw_output = 0; // RAW option. + + if (argc < 3) return RedisModule_WrongArity(ctx); + + /* Parse arguments. */ + for (int j = 3; j < argc; j++) { + const char *opt = RedisModule_StringPtrLen(argv[j], NULL); + if (!strcasecmp(opt,"raw")) { + raw_output = 1; + } else { + return RedisModule_ReplyWithError(ctx,"ERR invalid option"); + } + } + + /* Get key and element. */ + RedisModuleString *key = argv[1]; + RedisModuleString *element = argv[2]; + + /* Open key. */ + RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key, REDISMODULE_READ); + int type = RedisModule_KeyType(keyptr); + + /* Handle non-existing key and key of wrong type. */ + if (type == REDISMODULE_KEYTYPE_EMPTY) { + return RedisModule_ReplyWithNull(ctx); + } else if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType) { + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + } + + /* Lookup the node about the specified element. */ + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr); + hnswNode *node = RedisModule_DictGet(vset->dict, element, NULL); + if (!node) { + return RedisModule_ReplyWithNull(ctx); + } + + if (raw_output) { + int output_qrange = vset->hnsw->quant_type == HNSW_QUANT_Q8; + RedisModule_ReplyWithArray(ctx, 3+output_qrange); + RedisModule_ReplyWithSimpleString(ctx, vectorSetGetQuantName(vset)); + RedisModule_ReplyWithStringBuffer(ctx, node->vector, hnsw_quants_bytes(vset->hnsw)); + RedisModule_ReplyWithDouble(ctx, node->l2); + if (output_qrange) RedisModule_ReplyWithDouble(ctx, node->quants_range); + } else { + /* Get the vector associated with the node. */ + float *vec = RedisModule_Alloc(sizeof(float) * vset->hnsw->vector_dim); + hnsw_get_node_vector(vset->hnsw, node, vec); // May dequantize/denorm. + + /* Return as array of doubles. */ + RedisModule_ReplyWithArray(ctx, vset->hnsw->vector_dim); + for (uint32_t i = 0; i < vset->hnsw->vector_dim; i++) + RedisModule_ReplyWithDouble(ctx, vec[i]); + RedisModule_Free(vec); + } + return REDISMODULE_OK; +} + +/* VSETATTR key element json + * Set or remove the JSON attribute associated with an element. + * Setting an empty string removes the attribute. + * The command returns one if the attribute was actually updated or + * zero if there is no key or element. */ +int VSETATTR_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); + + if (argc != 4) return RedisModule_WrongArity(ctx); + + RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], + REDISMODULE_READ|REDISMODULE_WRITE); + int type = RedisModule_KeyType(key); + + if (type == REDISMODULE_KEYTYPE_EMPTY) + return RedisModule_ReplyWithBool(ctx, 0); + + if (RedisModule_ModuleTypeGetType(key) != VectorSetType) + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key); + hnswNode *node = RedisModule_DictGet(vset->dict, argv[2], NULL); + if (!node) + return RedisModule_ReplyWithBool(ctx, 0); + + struct vsetNodeVal *nv = node->value; + RedisModuleString *new_attr = argv[3]; + + /* Background VSIM operations use the node attributes, so + * wait for background operations before messing with them. */ + vectorSetWaitAllBackgroundClients(vset,0); + + /* Set or delete the attribute based on the fact it's an empty + * string or not. */ + size_t attrlen; + RedisModule_StringPtrLen(new_attr, &attrlen); + if (attrlen == 0) { + // If we had an attribute before, decrease the count and free it. + if (nv->attrib) { + vset->numattribs--; + RedisModule_FreeString(NULL, nv->attrib); + nv->attrib = NULL; + } + } else { + // If we didn't have an attribute before, increase the count. + // Otherwise free the old one. + if (nv->attrib) { + RedisModule_FreeString(NULL, nv->attrib); + } else { + vset->numattribs++; + } + // Set new attribute. + RedisModule_RetainString(NULL, new_attr); + nv->attrib = new_attr; + } + + RedisModule_ReplyWithBool(ctx, 1); + RedisModule_ReplicateVerbatim(ctx); + return REDISMODULE_OK; +} + +/* VGETATTR key element + * Get the JSON attribute associated with an element. + * Returns NIL if the element has no attribute or doesn't exist. */ +int VGETATTR_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); + + if (argc != 3) return RedisModule_WrongArity(ctx); + + RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ); + int type = RedisModule_KeyType(key); + + if (type == REDISMODULE_KEYTYPE_EMPTY) + return RedisModule_ReplyWithNull(ctx); + + if (RedisModule_ModuleTypeGetType(key) != VectorSetType) + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key); + hnswNode *node = RedisModule_DictGet(vset->dict, argv[2], NULL); + if (!node) + return RedisModule_ReplyWithNull(ctx); + + struct vsetNodeVal *nv = node->value; + if (!nv->attrib) + return RedisModule_ReplyWithNull(ctx); + + return RedisModule_ReplyWithString(ctx, nv->attrib); +} + +/* ============================== Reflection ================================ */ + +/* VLINKS key element [WITHSCORES] + * Returns the neighbors of an element at each layer in the HNSW graph. + * Reply is an array of arrays, where each nested array represents one level + * of neighbors, from highest level to level 0. If WITHSCORES is specified, + * each neighbor is followed by its distance from the element. */ +int VLINKS_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); + + if (argc < 3 || argc > 4) return RedisModule_WrongArity(ctx); + + RedisModuleString *key = argv[1]; + RedisModuleString *element = argv[2]; + + /* Parse WITHSCORES option. */ + int withscores = 0; + if (argc == 4) { + const char *opt = RedisModule_StringPtrLen(argv[3], NULL); + if (strcasecmp(opt, "WITHSCORES") != 0) { + return RedisModule_WrongArity(ctx); + } + withscores = 1; + } + + RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key, REDISMODULE_READ); + int type = RedisModule_KeyType(keyptr); + + /* Handle non-existing key or wrong type. */ + if (type == REDISMODULE_KEYTYPE_EMPTY) + return RedisModule_ReplyWithNull(ctx); + + if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType) + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + + /* Find the node for this element. */ + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr); + hnswNode *node = RedisModule_DictGet(vset->dict, element, NULL); + if (!node) + return RedisModule_ReplyWithNull(ctx); + + /* Reply with array of arrays, one per level. */ + RedisModule_ReplyWithArray(ctx, node->level + 1); + + /* For each level, from highest to lowest: */ + for (int i = node->level; i >= 0; i--) { + /* Reply with array of neighbors at this level. */ + if (withscores) + RedisModule_ReplyWithMap(ctx,node->layers[i].num_links); + else + RedisModule_ReplyWithArray(ctx,node->layers[i].num_links); + + /* Add each neighbor's element value to the array. */ + for (uint32_t j = 0; j < node->layers[i].num_links; j++) { + struct vsetNodeVal *nv = node->layers[i].links[j]->value; + RedisModule_ReplyWithString(ctx, nv->item); + if (withscores) { + float distance = hnsw_distance(vset->hnsw, node, node->layers[i].links[j]); + /* Convert distance to similarity score to match + * VSIM behavior.*/ + float similarity = 1.0 - distance/2.0; + RedisModule_ReplyWithDouble(ctx, similarity); + } + } + } + return REDISMODULE_OK; +} + +/* VINFO key + * Returns information about a vector set, both visible and hidden + * features of the HNSW data structure. */ +int VINFO_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); + + if (argc != 2) return RedisModule_WrongArity(ctx); + + RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ); + int type = RedisModule_KeyType(key); + + if (type == REDISMODULE_KEYTYPE_EMPTY) + return RedisModule_ReplyWithNullArray(ctx); + + if (RedisModule_ModuleTypeGetType(key) != VectorSetType) + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key); + + /* Reply with hash */ + RedisModule_ReplyWithMap(ctx, 9); + + /* Quantization type */ + RedisModule_ReplyWithSimpleString(ctx, "quant-type"); + RedisModule_ReplyWithSimpleString(ctx, vectorSetGetQuantName(vset)); + + /* HNSW M value */ + RedisModule_ReplyWithSimpleString(ctx, "hnsw-m"); + RedisModule_ReplyWithLongLong(ctx, vset->hnsw->M); + + /* Vector dimensionality. */ + RedisModule_ReplyWithSimpleString(ctx, "vector-dim"); + RedisModule_ReplyWithLongLong(ctx, vset->hnsw->vector_dim); + + /* Original input dimension before projection. + * This is zero for vector sets without a random projection matrix. */ + RedisModule_ReplyWithSimpleString(ctx, "projection-input-dim"); + RedisModule_ReplyWithLongLong(ctx, vset->proj_input_size); + + /* Number of elements. */ + RedisModule_ReplyWithSimpleString(ctx, "size"); + RedisModule_ReplyWithLongLong(ctx, vset->hnsw->node_count); + + /* Max level of HNSW. */ + RedisModule_ReplyWithSimpleString(ctx, "max-level"); + RedisModule_ReplyWithLongLong(ctx, vset->hnsw->max_level); + + /* Number of nodes with attributes. */ + RedisModule_ReplyWithSimpleString(ctx, "attributes-count"); + RedisModule_ReplyWithLongLong(ctx, vset->numattribs); + + /* Vector set ID. */ + RedisModule_ReplyWithSimpleString(ctx, "vset-uid"); + RedisModule_ReplyWithLongLong(ctx, vset->id); + + /* HNSW max node ID. */ + RedisModule_ReplyWithSimpleString(ctx, "hnsw-max-node-uid"); + RedisModule_ReplyWithLongLong(ctx, vset->hnsw->last_id); + + return REDISMODULE_OK; +} + +/* VRANDMEMBER key [count] + * Return random members from a vector set. + * + * Without count: returns a single random member. + * With positive count: N unique random members (no duplicates). + * With negative count: N random members (with possible duplicates). + * + * If the key doesn't exist, returns NULL if count is not given, or + * an empty array if a count was given. */ +int VRANDMEMBER_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); /* Use automatic memory management. */ + + /* Check arguments. */ + if (argc != 2 && argc != 3) return RedisModule_WrongArity(ctx); + + /* Parse optional count argument. */ + long long count = 1; /* Default is to return a single element. */ + int with_count = (argc == 3); + + if (with_count) { + if (RedisModule_StringToLongLong(argv[2], &count) != REDISMODULE_OK) { + return RedisModule_ReplyWithError(ctx, + "ERR COUNT value is not an integer"); + } + /* Count = 0 is a special case, return empty array */ + if (count == 0) { + return RedisModule_ReplyWithEmptyArray(ctx); + } + } + + /* Open key. */ + RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ); + int type = RedisModule_KeyType(key); + + /* Handle non-existing key. */ + if (type == REDISMODULE_KEYTYPE_EMPTY) { + if (!with_count) { + return RedisModule_ReplyWithNull(ctx); + } else { + return RedisModule_ReplyWithEmptyArray(ctx); + } + } + + /* Check key type. */ + if (RedisModule_ModuleTypeGetType(key) != VectorSetType) { + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + } + + /* Get vector set from key. */ + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key); + uint64_t set_size = vset->hnsw->node_count; + + /* No elements in the set? */ + if (set_size == 0) { + if (!with_count) { + return RedisModule_ReplyWithNull(ctx); + } else { + return RedisModule_ReplyWithEmptyArray(ctx); + } + } + + /* Case 1: No count specified: return a single element. */ + if (!with_count) { + hnswNode *random_node = hnsw_random_node(vset->hnsw, 0); + if (random_node) { + struct vsetNodeVal *nv = random_node->value; + return RedisModule_ReplyWithString(ctx, nv->item); + } else { + return RedisModule_ReplyWithNull(ctx); + } + } + + /* Case 2: COUNT option given, return an array of elements. */ + int allow_duplicates = (count < 0); + long long abs_count = (count < 0) ? -count : count; + + /* Cap the count to the set size if we are not allowing duplicates. */ + if (!allow_duplicates && abs_count > (long long)set_size) + abs_count = set_size; + + /* Prepare reply. */ + RedisModule_ReplyWithArray(ctx, abs_count); + + if (allow_duplicates) { + /* Simple case: With duplicates, just pick random nodes + * abs_count times. */ + for (long long i = 0; i < abs_count; i++) { + hnswNode *random_node = hnsw_random_node(vset->hnsw,0); + struct vsetNodeVal *nv = random_node->value; + RedisModule_ReplyWithString(ctx, nv->item); + } + } else { + /* Case where count is positive: we need unique elements. + * But, if the user asked for many elements, selecting so + * many (> 20%) random nodes may be too expansive: we just start + * from a random element and follow the next link. + * + * Otherwisem for the <= 20% case, a dictionary is used to + * reject duplicates. */ + int use_dict = (abs_count <= set_size * 0.2); + + if (use_dict) { + RedisModuleDict *returned = RedisModule_CreateDict(ctx); + + long long returned_count = 0; + while (returned_count < abs_count) { + hnswNode *random_node = hnsw_random_node(vset->hnsw, 0); + struct vsetNodeVal *nv = random_node->value; + + /* Check if we've already returned this element. */ + if (RedisModule_DictGet(returned, nv->item, NULL) == NULL) { + /* Mark as returned and add to results. */ + RedisModule_DictSet(returned, nv->item, (void*)1); + RedisModule_ReplyWithString(ctx, nv->item); + returned_count++; + } + } + RedisModule_FreeDict(ctx, returned); + } else { + /* For large samples, get a random starting node and walk + * the list. + * + * IMPORTANT: doing so does not really generate random + * elements: it's just a linear scan, but we have no choices. + * If we generate too many random elements, more and more would + * fail the check of being novel (not yet collected in the set + * to return) if the % of elements to emit is too large, we would + * spend too much CPU. */ + hnswNode *start_node = hnsw_random_node(vset->hnsw, 0); + hnswNode *current = start_node; + + long long returned_count = 0; + while (returned_count < abs_count) { + if (current == NULL) { + /* Restart from head if we hit the end. */ + current = vset->hnsw->head; + } + struct vsetNodeVal *nv = current->value; + RedisModule_ReplyWithString(ctx, nv->item); + returned_count++; + current = current->next; + } + } + } + return REDISMODULE_OK; +} + +/* VISMEMBER key element + * Check if an element exists in a vector set. + * Returns 1 if the element exists, 0 if not. */ +int VISMEMBER_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); + if (argc != 3) return RedisModule_WrongArity(ctx); + + RedisModuleString *key = argv[1]; + RedisModuleString *element = argv[2]; + + /* Open key. */ + RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key, REDISMODULE_READ); + int type = RedisModule_KeyType(keyptr); + + /* Handle non-existing key or wrong type. */ + if (type == REDISMODULE_KEYTYPE_EMPTY) { + /* An element of a non existing key does not exist, like + * SISMEMBER & similar. */ + return RedisModule_ReplyWithBool(ctx, 0); + } + if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType) { + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + } + + /* Get the object and test membership via the dictionary in constant + * time (assuming a member of average size). */ + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr); + hnswNode *node = RedisModule_DictGet(vset->dict, element, NULL); + return RedisModule_ReplyWithBool(ctx, node != NULL); +} + +/* Structure to represent a range boundary. */ +struct vsetRangeOp { + int incl; /* 1 if inclusive ([), 0 if exclusive ((). */ + int min; /* 1 if this is "-" (minimum). */ + int max; /* 1 if this is "+" (maximum). */ + unsigned char *ele; /* The actual element, NULL if min/max. */ + size_t ele_len; /* Length of the element. */ +}; + +/* Parse a range specification like "[foo" or "(bar" or "-" or "+". + * Returns 1 on success, 0 on error. */ +int vsetParseRangeOp(RedisModuleString *arg, struct vsetRangeOp *op) { + size_t len; + const char *str = RedisModule_StringPtrLen(arg, &len); + + if (len == 0) return 0; + + /* Initialize the structure. */ + op->incl = 0; + op->min = 0; + op->max = 0; + op->ele = NULL; + op->ele_len = 0; + + /* Check for special cases "-" and "+". */ + if (len == 1 && str[0] == '-') { + op->min = 1; + return 1; + } + if (len == 1 && str[0] == '+') { + op->max = 1; + return 1; + } + + /* Otherwise, must start with ( or [. */ + if (str[0] == '[') { + op->incl = 1; + } else if (str[0] == '(') { + op->incl = 0; + } else { + return 0; /* Invalid format. */ + } + + /* Extract the string part after the bracket. */ + if (len > 1) { + op->ele = (unsigned char *)(str + 1); + op->ele_len = len - 1; + } else { + return 0; /* Just a bracket with no string. */ + } + + return 1; +} + +/* Check if the current element is within the range defined by the end operator. + * Returns 1 if the element is within range, 0 if it has passed the end. */ +int vsetIsElementInRange(const void *ele, size_t ele_len, struct vsetRangeOp *end_op) { + /* If end is "+", element is always in range. */ + if (end_op->max) return 1; + + /* Compare current element with end boundary. */ + size_t minlen = ele_len < end_op->ele_len ? ele_len : end_op->ele_len; + int cmp = memcmp(ele, end_op->ele, minlen); + + if (cmp == 0) { + /* If equal up to minlen, shorter string is smaller. */ + if (ele_len < end_op->ele_len) { + cmp = -1; + } else if (ele_len > end_op->ele_len) { + cmp = 1; + } + } + + /* Check based on inclusive/exclusive. */ + if (end_op->incl) { + return cmp <= 0; /* Inclusive: element <= end. */ + } else { + return cmp < 0; /* Exclusive: element < end. */ + } +} + +/* VRANGE key start end [count] + * Returns elements in the lexicographical range [start, end] + * + * Elements must be specified in one of the following forms: + * + * [myelement + * (myelement + * + + * - + * + * Elements starting with [ are inclusive, so "myelement" would be + * returned if present in the set. Elements starting with ( are exclusive + * ranges instead. The special - and + elements mean the minimum and maximum + * possible element (inclusive), so "VRANGE key - +" will return everything + * (depending on COUNT of course). The special - element can be used only + * as starting element, the special + element only as ending element. */ +int VRANGE_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); + + /* Check arguments. */ + if (argc < 4 || argc > 5) return RedisModule_WrongArity(ctx); + + /* Parse COUNT if provided. */ + long long count = -1; /* Default: return all elements. */ + if (argc == 5) { + if (RedisModule_StringToLongLong(argv[4], &count) != REDISMODULE_OK) { + return RedisModule_ReplyWithError(ctx, "ERR invalid COUNT value"); + } + } + + /* Parse range operators. */ + struct vsetRangeOp start_op, end_op; + if (!vsetParseRangeOp(argv[2], &start_op)) { + return RedisModule_ReplyWithError(ctx, "ERR invalid start range format"); + } + if (!vsetParseRangeOp(argv[3], &end_op)) { + return RedisModule_ReplyWithError(ctx, "ERR invalid end range format"); + } + + /* Validate: "-" can only be first arg, "+" can only be second. */ + if (start_op.max || end_op.min) { + return RedisModule_ReplyWithError(ctx, + "ERR '-' can only be used as first argument, '+' only as second"); + } + + /* Open the key. */ + RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ); + int type = RedisModule_KeyType(key); + + if (type == REDISMODULE_KEYTYPE_EMPTY) { + return RedisModule_ReplyWithEmptyArray(ctx); + } + + if (RedisModule_ModuleTypeGetType(key) != VectorSetType) { + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + } + + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key); + + /* Start the iterator. */ + RedisModuleDictIter *iter; + if (start_op.min) { + /* Start from the beginning. */ + iter = RedisModule_DictIteratorStartC(vset->dict, "^", NULL, 0); + } else { + /* Start from the specified element. */ + const char *op = start_op.incl ? ">=" : ">"; + iter = RedisModule_DictIteratorStartC(vset->dict, op, start_op.ele, start_op.ele_len); + } + + /* Collect results. */ + RedisModule_ReplyWithArray(ctx, REDISMODULE_POSTPONED_LEN); + long long returned = 0; + + void *key_data; + size_t key_len; + while ((key_data = RedisModule_DictNextC(iter, &key_len, NULL)) != NULL) { + /* Check if we've collected enough elements. */ + if (count >= 0 && returned >= count) break; + + /* Check if we've passed the end range. */ + if (!vsetIsElementInRange(key_data, key_len, &end_op)) break; + + /* Add this element to the result. */ + RedisModule_ReplyWithStringBuffer(ctx, key_data, key_len); + returned++; + } + + RedisModule_ReplySetArrayLength(ctx, returned); + + /* Cleanup. */ + RedisModule_DictIteratorStop(iter); + + return REDISMODULE_OK; +} + +/* ============================== vset type methods ========================= */ + +#define SAVE_FLAG_HAS_PROJMATRIX (1<<0) +#define SAVE_FLAG_HAS_ATTRIBS (1<<1) + +/* Save object to RDB */ +void VectorSetRdbSave(RedisModuleIO *rdb, void *value) { + struct vsetObject *vset = value; + RedisModule_SaveUnsigned(rdb, vset->hnsw->vector_dim); + RedisModule_SaveUnsigned(rdb, vset->hnsw->node_count); + + uint32_t hnsw_config = (vset->hnsw->quant_type & 0xff) | + ((vset->hnsw->M & 0xffff) << 8); + RedisModule_SaveUnsigned(rdb, hnsw_config); + + uint32_t save_flags = 0; + if (vset->proj_matrix) save_flags |= SAVE_FLAG_HAS_PROJMATRIX; + if (vset->numattribs != 0) save_flags |= SAVE_FLAG_HAS_ATTRIBS; + RedisModule_SaveUnsigned(rdb, save_flags); + + /* Save projection matrix if present */ + if (vset->proj_matrix) { + uint32_t input_dim = vset->proj_input_size; + uint32_t output_dim = vset->hnsw->vector_dim; + RedisModule_SaveUnsigned(rdb, input_dim); + // Output dim is the same as the first value saved + // above, so we don't save it. + + // Save projection matrix as binary blob + size_t matrix_size = sizeof(float) * input_dim * output_dim; + RedisModule_SaveStringBuffer(rdb, (const char *)vset->proj_matrix, matrix_size); + } + + hnswNode *node = vset->hnsw->head; + while(node) { + struct vsetNodeVal *nv = node->value; + RedisModule_SaveString(rdb, nv->item); + if (vset->numattribs) { + if (nv->attrib) + RedisModule_SaveString(rdb, nv->attrib); + else + RedisModule_SaveStringBuffer(rdb, "", 0); + } + hnswSerNode *sn = hnsw_serialize_node(vset->hnsw,node); + RedisModule_SaveStringBuffer(rdb, (const char *)sn->vector, sn->vector_size); + RedisModule_SaveUnsigned(rdb, sn->params_count); + for (uint32_t j = 0; j < sn->params_count; j++) + RedisModule_SaveUnsigned(rdb, sn->params[j]); + hnsw_free_serialized_node(sn); + node = node->next; + } +} + +/* Load object from RDB. Recover from recoverable errors (read errors) + * by performing cleanup. */ +void *VectorSetRdbLoad(RedisModuleIO *rdb, int encver) { + if (encver != 0) return NULL; // Invalid version + + uint32_t dim = RedisModule_LoadUnsigned(rdb); + uint64_t elements = RedisModule_LoadUnsigned(rdb); + uint32_t hnsw_config = RedisModule_LoadUnsigned(rdb); + if (RedisModule_IsIOError(rdb)) return NULL; + uint32_t quant_type = hnsw_config & 0xff; + uint32_t hnsw_m = (hnsw_config >> 8) & 0xffff; + + /* Check that the quantization type is correct. Otherwise + * return ASAP signaling the error. */ + if (quant_type != HNSW_QUANT_NONE && + quant_type != HNSW_QUANT_Q8 && + quant_type != HNSW_QUANT_BIN) return NULL; + + if (hnsw_m == 0) hnsw_m = 16; // Default, useful for RDB files predating + // this configuration parameter: it was fixed + // to 16. + struct vsetObject *vset = createVectorSetObject(dim,quant_type,hnsw_m); + RedisModule_Assert(vset != NULL); + + /* Load projection matrix if present */ + uint32_t save_flags = RedisModule_LoadUnsigned(rdb); + if (RedisModule_IsIOError(rdb)) goto ioerr; + int has_projection = save_flags & SAVE_FLAG_HAS_PROJMATRIX; + int has_attribs = save_flags & SAVE_FLAG_HAS_ATTRIBS; + if (has_projection) { + uint32_t input_dim = RedisModule_LoadUnsigned(rdb); + if (RedisModule_IsIOError(rdb)) goto ioerr; + uint32_t output_dim = dim; + size_t matrix_size = sizeof(float) * input_dim * output_dim; + + vset->proj_matrix = RedisModule_Alloc(matrix_size); + vset->proj_input_size = input_dim; + + // Load projection matrix as a binary blob + char *matrix_blob = RedisModule_LoadStringBuffer(rdb, NULL); + if (matrix_blob == NULL) goto ioerr; + memcpy(vset->proj_matrix, matrix_blob, matrix_size); + RedisModule_Free(matrix_blob); + } + + while(elements--) { + // Load associated string element. + RedisModuleString *ele = RedisModule_LoadString(rdb); + if (RedisModule_IsIOError(rdb)) goto ioerr; + RedisModuleString *attrib = NULL; + if (has_attribs) { + attrib = RedisModule_LoadString(rdb); + if (RedisModule_IsIOError(rdb)) { + RedisModule_FreeString(NULL,ele); + goto ioerr; + } + size_t attrlen; + RedisModule_StringPtrLen(attrib,&attrlen); + if (attrlen == 0) { + RedisModule_FreeString(NULL,attrib); + attrib = NULL; + } + } + size_t vector_len; + void *vector = RedisModule_LoadStringBuffer(rdb, &vector_len); + if (RedisModule_IsIOError(rdb)) { + RedisModule_FreeString(NULL,ele); + if (attrib) RedisModule_FreeString(NULL,attrib); + goto ioerr; + } + uint32_t vector_bytes = hnsw_quants_bytes(vset->hnsw); + if (vector_len != vector_bytes) { + RedisModule_LogIOError(rdb,"warning", + "Mismatching vector dimension"); + RedisModule_FreeString(NULL,ele); + if (attrib) RedisModule_FreeString(NULL,attrib); + RedisModule_Free(vector); + goto ioerr; + } + + // Load node parameters back. + uint32_t params_count = RedisModule_LoadUnsigned(rdb); + if (RedisModule_IsIOError(rdb)) { + RedisModule_FreeString(NULL,ele); + if (attrib) RedisModule_FreeString(NULL,attrib); + RedisModule_Free(vector); + goto ioerr; + } + + uint64_t *params = RedisModule_Alloc(params_count*sizeof(uint64_t)); + for (uint32_t j = 0; j < params_count; j++) { + // Ignore loading errors here: handled at the end of the loop. + params[j] = RedisModule_LoadUnsigned(rdb); + } + if (RedisModule_IsIOError(rdb)) { + RedisModule_FreeString(NULL,ele); + if (attrib) RedisModule_FreeString(NULL,attrib); + RedisModule_Free(vector); + RedisModule_Free(params); + goto ioerr; + } + + struct vsetNodeVal *nv = RedisModule_Alloc(sizeof(*nv)); + nv->item = ele; + nv->attrib = attrib; + hnswNode *node = hnsw_insert_serialized(vset->hnsw, vector, params, params_count, nv); + if (node == NULL) { + RedisModule_LogIOError(rdb,"warning", + "Vector set node index loading error"); + vectorSetReleaseNodeValue(nv); + RedisModule_Free(vector); + RedisModule_Free(params); + goto ioerr; + } + if (nv->attrib) vset->numattribs++; + RedisModule_DictSet(vset->dict,ele,node); + RedisModule_Free(vector); + RedisModule_Free(params); + } + + uint64_t salt[2]; + RedisModule_GetRandomBytes((unsigned char*)salt,sizeof(salt)); + if (!hnsw_deserialize_index(vset->hnsw, salt[0], salt[1])) goto ioerr; + + return vset; + +ioerr: + /* We want to recover from I/O errors and free the partially allocated + * data structure to support diskless replication. */ + vectorSetReleaseObject(vset); + return NULL; +} + +/* Calculate memory usage */ +size_t VectorSetMemUsage(const void *value) { + const struct vsetObject *vset = value; + size_t size = sizeof(*vset); + + /* Account for HNSW index base structure */ + size += sizeof(HNSW); + + /* Account for projection matrix if present */ + if (vset->proj_matrix) { + /* For the matrix size, we need the input dimension. We can get it + * from the first node if the set is not empty. */ + uint32_t input_dim = vset->proj_input_size; + uint32_t output_dim = vset->hnsw->vector_dim; + size += sizeof(float) * input_dim * output_dim; + } + + /* Account for each node's memory usage. */ + hnswNode *node = vset->hnsw->head; + if (node == NULL) return size; + + /* Base node structure. */ + size += sizeof(*node) * vset->hnsw->node_count; + + /* Vector storage. */ + uint64_t vec_storage = hnsw_quants_bytes(vset->hnsw); + size += vec_storage * vset->hnsw->node_count; + + /* Layers array. We use 1.33 as average nodes layers count. */ + uint64_t layers_storage = sizeof(hnswNodeLayer) * vset->hnsw->node_count; + layers_storage = layers_storage * 4 / 3; // 1.33 times. + size += layers_storage; + + /* All the nodes have layer 0 links. */ + uint64_t level0_links = node->layers[0].max_links; + uint64_t other_levels_links = level0_links/2; + size += sizeof(hnswNode*) * level0_links * vset->hnsw->node_count; + + /* Add the 0.33 remaining part, but upper layers have less links. */ + size += (sizeof(hnswNode*) * other_levels_links * vset->hnsw->node_count)/3; + + /* Associated string value and attributres. + * Use Redis Module API to get string size, and guess that all the + * elements have similar size as the first few. */ + size_t items_scanned = 0, items_size = 0; + size_t attribs_scanned = 0, attribs_size = 0; + int scan_effort = 20; + while(scan_effort > 0 && node) { + struct vsetNodeVal *nv = node->value; + items_size += RedisModule_MallocSizeString(nv->item); + items_scanned++; + if (nv->attrib) { + attribs_size += RedisModule_MallocSizeString(nv->attrib); + attribs_scanned++; + } + scan_effort--; + node = node->next; + } + + /* Add the memory usage due to items. */ + if (items_scanned) + size += items_size / items_scanned * vset->hnsw->node_count; + + /* Add memory usage due to attributres. */ + if (attribs_scanned == 0) { + /* We were not lucky enough to find a single attribute in the + * first few items? Let's use a fixed arbitrary value. */ + attribs_scanned = 1; + attribs_size = 64; + } + size += attribs_size / attribs_scanned * vset->numattribs; + + /* Account for dictionary overhead - this is an approximation. */ + size += RedisModule_DictSize(vset->dict) * (sizeof(void*) * 2); + + return size; +} + +/* Free the entire data structure */ +void VectorSetFree(void *value) { + struct vsetObject *vset = value; + + vectorSetWaitAllBackgroundClients(vset,1); + vectorSetReleaseObject(value); +} + +/* Add object digest to the digest context */ +void VectorSetDigest(RedisModuleDigest *md, void *value) { + struct vsetObject *vset = value; + + /* Add consistent order-independent hash of all vectors */ + hnswNode *node = vset->hnsw->head; + + /* Hash the vector dimension and number of nodes. */ + RedisModule_DigestAddLongLong(md, vset->hnsw->node_count); + RedisModule_DigestAddLongLong(md, vset->hnsw->vector_dim); + RedisModule_DigestEndSequence(md); + + while(node) { + struct vsetNodeVal *nv = node->value; + /* Hash each vector component */ + RedisModule_DigestAddStringBuffer(md, node->vector, hnsw_quants_bytes(vset->hnsw)); + /* Hash the associated value */ + size_t len; + const char *str = RedisModule_StringPtrLen(nv->item, &len); + RedisModule_DigestAddStringBuffer(md, (char*)str, len); + if (nv->attrib) { + str = RedisModule_StringPtrLen(nv->attrib, &len); + RedisModule_DigestAddStringBuffer(md, (char*)str, len); + } + node = node->next; + RedisModule_DigestEndSequence(md); + } +} + +// int VectorSets_InitModuleConfig(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { +int VectorSets_InitModuleConfig(RedisModuleCtx *ctx) { + if (RegisterModuleConfig(ctx) == REDISMODULE_ERR) { + RedisModule_Log(ctx, "warning", "Error registering module configuration"); + return REDISMODULE_ERR; + } + // Load default values + if (RedisModule_LoadDefaultConfigs(ctx) == REDISMODULE_ERR) { + RedisModule_Log(ctx, "warning", "Error loading default module configuration"); + return REDISMODULE_ERR; + } else { + RedisModule_Log(ctx, "verbose", "Successfully loaded default module configuration"); + } + if (RedisModule_LoadConfigs(ctx) == REDISMODULE_ERR) { + RedisModule_Log(ctx, "warning", "Error loading user module configuration"); + return REDISMODULE_ERR; + } else { + RedisModule_Log(ctx, "verbose", "Successfully loaded user module configuration"); + } + return REDISMODULE_OK; +} + +/* This function must be present on each Redis module. It is used in order to + * register the commands into the Redis server. */ +int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + REDISMODULE_NOT_USED(argv); + REDISMODULE_NOT_USED(argc); + + if (RedisModule_Init(ctx,"vectorset",1,REDISMODULE_APIVER_1) + == REDISMODULE_ERR) return REDISMODULE_ERR; + + if (VectorSets_InitModuleConfig(ctx) == REDISMODULE_ERR) { + return REDISMODULE_ERR; + } + + RedisModule_SetModuleOptions(ctx, REDISMODULE_OPTIONS_HANDLE_IO_ERRORS|REDISMODULE_OPTIONS_HANDLE_REPL_ASYNC_LOAD); + + RedisModuleTypeMethods tm = { + .version = REDISMODULE_TYPE_METHOD_VERSION, + .rdb_load = VectorSetRdbLoad, + .rdb_save = VectorSetRdbSave, + .aof_rewrite = NULL, + .mem_usage = VectorSetMemUsage, + .free = VectorSetFree, + .digest = VectorSetDigest + }; + + VectorSetType = RedisModule_CreateDataType(ctx,"vectorset",0,&tm); + if (VectorSetType == NULL) return REDISMODULE_ERR; + + // Register command VADD + if (RedisModule_CreateCommand(ctx,"VADD", + VADD_RedisCommand,"write deny-oom",1,1,1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + RedisModuleCommand *vadd_cmd = RedisModule_GetCommand(ctx, "VADD"); + if (vadd_cmd == NULL) return REDISMODULE_ERR; + + RedisModuleCommandArg vadd_args[] = { + { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 }, + { .name = "reduce", .type = REDISMODULE_ARG_TYPE_BLOCK, .token = "REDUCE", .flags = REDISMODULE_CMD_ARG_OPTIONAL, + .subargs = (RedisModuleCommandArg[]) { + { .name = "dim", .type = REDISMODULE_ARG_TYPE_INTEGER }, + { .name = NULL } + } + }, + { .name = "format", .type = REDISMODULE_ARG_TYPE_ONEOF, .subargs = (RedisModuleCommandArg[]) { + { .name = "fp32", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "FP32" }, + { .name = "values", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "VALUES" }, + { .name = NULL } + } + }, + { .name = "vector", .type = REDISMODULE_ARG_TYPE_STRING }, + { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING }, + { .name = "cas", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "CAS", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, + { .name = "quant_type", .type = REDISMODULE_ARG_TYPE_ONEOF, .flags = REDISMODULE_CMD_ARG_OPTIONAL, .subargs = (RedisModuleCommandArg[]) { + { .name = "noquant", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "NOQUANT" }, + { .name = "bin", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "BIN" }, + { .name = "q8", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "Q8" }, + { .name = NULL } + } + }, + { .name = "build-exploration-factor", .type = REDISMODULE_ARG_TYPE_INTEGER, .token = "EF", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, + { .name = "attributes", .type = REDISMODULE_ARG_TYPE_STRING, .token = "SETATTR", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, + { .name = "numlinks", .type = REDISMODULE_ARG_TYPE_INTEGER, .token = "M", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, + { .name = NULL } + }; + RedisModuleCommandInfo vadd_info = { + .version = REDISMODULE_COMMAND_INFO_VERSION, + .summary = "Add one or more elements to a vector set, or update its vector if it already exists", + .since = "8.0.0", + .arity = -5, + .args = vadd_args, + }; + if (RedisModule_SetCommandInfo(vadd_cmd, &vadd_info) == REDISMODULE_ERR) return REDISMODULE_ERR; + + // Register command VREM + if (RedisModule_CreateCommand(ctx,"VREM", + VREM_RedisCommand,"write",1,1,1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + RedisModuleCommand *vrem_cmd = RedisModule_GetCommand(ctx, "VREM"); + if (vrem_cmd == NULL) return REDISMODULE_ERR; + + RedisModuleCommandArg vrem_args[] = { + { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 }, + { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING }, + { .name = NULL } + }; + RedisModuleCommandInfo vrem_info = { + .version = REDISMODULE_COMMAND_INFO_VERSION, + .summary = "Remove an element from a vector set", + .since = "8.0.0", + .arity = 3, + .args = vrem_args, + }; + if (RedisModule_SetCommandInfo(vrem_cmd, &vrem_info) == REDISMODULE_ERR) return REDISMODULE_ERR; + + // Register command VSIM + if (RedisModule_CreateCommand(ctx,"VSIM", + VSIM_RedisCommand,"readonly",1,1,1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + RedisModuleCommand *vsim_cmd = RedisModule_GetCommand(ctx, "VSIM"); + if (vsim_cmd == NULL) return REDISMODULE_ERR; + + RedisModuleCommandArg vsim_args[] = { + { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 }, + { .name = "format", .type = REDISMODULE_ARG_TYPE_ONEOF, .subargs = (RedisModuleCommandArg[]) { + { .name = "ele", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "ELE" }, + { .name = "fp32", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "FP32" }, + { .name = "values", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "VALUES" }, + { .name = NULL } + } + }, + { .name = "vector_or_element", .type = REDISMODULE_ARG_TYPE_STRING }, + { .name = "withscores", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "WITHSCORES", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, + { .name = "withattribs", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "WITHATTRIBS", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, + { .name = "count", .type = REDISMODULE_ARG_TYPE_INTEGER, .token = "COUNT", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, + { .name = "max_distance", .type = REDISMODULE_ARG_TYPE_DOUBLE, .token = "EPSILON", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, + { .name = "search-exploration-factor", .type = REDISMODULE_ARG_TYPE_INTEGER, .token = "EF", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, + { .name = "expression", .type = REDISMODULE_ARG_TYPE_STRING, .token = "FILTER", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, + { .name = "max-filtering-effort", .type = REDISMODULE_ARG_TYPE_INTEGER, .token = "FILTER-EF", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, + { .name = "truth", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "TRUTH", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, + { .name = "nothread", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "NOTHREAD", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, + { .name = NULL } + }; + RedisModuleCommandInfo vsim_info = { + .version = REDISMODULE_COMMAND_INFO_VERSION, + .summary = "Return elements by vector similarity", + .since = "8.0.0", + .arity = -4, + .args = vsim_args, + }; + if (RedisModule_SetCommandInfo(vsim_cmd, &vsim_info) == REDISMODULE_ERR) return REDISMODULE_ERR; + + // Register command VDIM + if (RedisModule_CreateCommand(ctx, "VDIM", + VDIM_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + RedisModuleCommand *vdim_cmd = RedisModule_GetCommand(ctx, "VDIM"); + if (vdim_cmd == NULL) return REDISMODULE_ERR; + + RedisModuleCommandArg vdim_args[] = { + { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 }, + { .name = NULL } + }; + RedisModuleCommandInfo vdim_info = { + .version = REDISMODULE_COMMAND_INFO_VERSION, + .summary = "Return the dimension of vectors in the vector set", + .since = "8.0.0", + .arity = 2, + .args = vdim_args, + }; + if (RedisModule_SetCommandInfo(vdim_cmd, &vdim_info) == REDISMODULE_ERR) return REDISMODULE_ERR; + + // Register command VCARD + if (RedisModule_CreateCommand(ctx, "VCARD", + VCARD_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + RedisModuleCommand *vcard_cmd = RedisModule_GetCommand(ctx, "VCARD"); + if (vcard_cmd == NULL) return REDISMODULE_ERR; + + RedisModuleCommandArg vcard_args[] = { + { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 }, + { .name = NULL } + }; + RedisModuleCommandInfo vcard_info = { + .version = REDISMODULE_COMMAND_INFO_VERSION, + .summary = "Return the number of elements in a vector set", + .since = "8.0.0", + .arity = 2, + .args = vcard_args, + }; + if (RedisModule_SetCommandInfo(vcard_cmd, &vcard_info) == REDISMODULE_ERR) return REDISMODULE_ERR; + + // Register command VEMB + if (RedisModule_CreateCommand(ctx, "VEMB", + VEMB_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + RedisModuleCommand *vemb_cmd = RedisModule_GetCommand(ctx, "VEMB"); + if (vemb_cmd == NULL) return REDISMODULE_ERR; + + RedisModuleCommandArg vemb_args[] = { + { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 }, + { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING }, + { .name = "raw", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "RAW", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, + { .name = NULL } + }; + RedisModuleCommandInfo vemb_info = { + .version = REDISMODULE_COMMAND_INFO_VERSION, + .summary = "Return the vector associated with an element", + .since = "8.0.0", + .arity = -3, + .args = vemb_args, + }; + if (RedisModule_SetCommandInfo(vemb_cmd, &vemb_info) == REDISMODULE_ERR) return REDISMODULE_ERR; + + // Register command VLINKS + if (RedisModule_CreateCommand(ctx, "VLINKS", + VLINKS_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + RedisModuleCommand *vlinks_cmd = RedisModule_GetCommand(ctx, "VLINKS"); + if (vlinks_cmd == NULL) return REDISMODULE_ERR; + + RedisModuleCommandArg vlinks_args[] = { + { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 }, + { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING }, + { .name = "withscores", .type = REDISMODULE_ARG_TYPE_PURE_TOKEN, .token = "WITHSCORES", .flags = REDISMODULE_CMD_ARG_OPTIONAL }, + { .name = NULL } + }; + RedisModuleCommandInfo vlinks_info = { + .version = REDISMODULE_COMMAND_INFO_VERSION, + .summary = "Return the neighbors of an element at each layer in the HNSW graph", + .since = "8.0.0", + .arity = -3, + .args = vlinks_args, + }; + if (RedisModule_SetCommandInfo(vlinks_cmd, &vlinks_info) == REDISMODULE_ERR) return REDISMODULE_ERR; + + // Register command VINFO + if (RedisModule_CreateCommand(ctx, "VINFO", + VINFO_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + RedisModuleCommand *vinfo_cmd = RedisModule_GetCommand(ctx, "VINFO"); + if (vinfo_cmd == NULL) return REDISMODULE_ERR; + + RedisModuleCommandArg vinfo_args[] = { + { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 }, + { .name = NULL } + }; + RedisModuleCommandInfo vinfo_info = { + .version = REDISMODULE_COMMAND_INFO_VERSION, + .summary = "Return information about a vector set", + .since = "8.0.0", + .arity = 2, + .args = vinfo_args, + }; + if (RedisModule_SetCommandInfo(vinfo_cmd, &vinfo_info) == REDISMODULE_ERR) return REDISMODULE_ERR; + + // Register command VSETATTR + if (RedisModule_CreateCommand(ctx, "VSETATTR", + VSETATTR_RedisCommand, "write fast", 1, 1, 1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + RedisModuleCommand *vsetattr_cmd = RedisModule_GetCommand(ctx, "VSETATTR"); + if (vsetattr_cmd == NULL) return REDISMODULE_ERR; + + RedisModuleCommandArg vsetattr_args[] = { + { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 }, + { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING }, + { .name = "json", .type = REDISMODULE_ARG_TYPE_STRING }, + { .name = NULL } + }; + RedisModuleCommandInfo vsetattr_info = { + .version = REDISMODULE_COMMAND_INFO_VERSION, + .summary = "Associate or remove the JSON attributes of elements", + .since = "8.0.0", + .arity = 4, + .args = vsetattr_args, + }; + if (RedisModule_SetCommandInfo(vsetattr_cmd, &vsetattr_info) == REDISMODULE_ERR) return REDISMODULE_ERR; + + // Register command VGETATTR + if (RedisModule_CreateCommand(ctx, "VGETATTR", + VGETATTR_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + RedisModuleCommand *vgetattr_cmd = RedisModule_GetCommand(ctx, "VGETATTR"); + if (vgetattr_cmd == NULL) return REDISMODULE_ERR; + + RedisModuleCommandArg vgetattr_args[] = { + { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 }, + { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING }, + { .name = NULL } + }; + RedisModuleCommandInfo vgetattr_info = { + .version = REDISMODULE_COMMAND_INFO_VERSION, + .summary = "Retrieve the JSON attributes of elements", + .since = "8.0.0", + .arity = 3, + .args = vgetattr_args, + }; + if (RedisModule_SetCommandInfo(vgetattr_cmd, &vgetattr_info) == REDISMODULE_ERR) return REDISMODULE_ERR; + + // Register command VRANDMEMBER + if (RedisModule_CreateCommand(ctx, "VRANDMEMBER", + VRANDMEMBER_RedisCommand, "readonly", 1, 1, 1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + RedisModuleCommand *vrandmember_cmd = RedisModule_GetCommand(ctx, "VRANDMEMBER"); + if (vrandmember_cmd == NULL) return REDISMODULE_ERR; + + RedisModuleCommandArg vrandmember_args[] = { + { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 }, + { .name = "count", .type = REDISMODULE_ARG_TYPE_INTEGER, .flags = REDISMODULE_CMD_ARG_OPTIONAL }, + { .name = NULL } + }; + RedisModuleCommandInfo vrandmember_info = { + .version = REDISMODULE_COMMAND_INFO_VERSION, + .summary = "Return one or multiple random members from a vector set", + .since = "8.0.0", + .arity = -2, + .args = vrandmember_args, + }; + if (RedisModule_SetCommandInfo(vrandmember_cmd, &vrandmember_info) == REDISMODULE_ERR) return REDISMODULE_ERR; + + // Register command VISMEMBER + if (RedisModule_CreateCommand(ctx, "VISMEMBER", + VISMEMBER_RedisCommand, "readonly", 1, 1, 1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + RedisModuleCommand *vismember_cmd = RedisModule_GetCommand(ctx, "VISMEMBER"); + if (vismember_cmd == NULL) return REDISMODULE_ERR; + + RedisModuleCommandArg vismember_args[] = { + { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 }, + { .name = "element", .type = REDISMODULE_ARG_TYPE_STRING }, + { .name = NULL } + }; + RedisModuleCommandInfo vismember_info = { + .version = REDISMODULE_COMMAND_INFO_VERSION, + .summary = "Check if an element exists in a vector set", + .since = "8.2.0", + .arity = 3, + .args = vismember_args, + }; + if (RedisModule_SetCommandInfo(vismember_cmd, &vismember_info) == REDISMODULE_ERR) return REDISMODULE_ERR; + + // Register command VRANGE + if (RedisModule_CreateCommand(ctx, "VRANGE", + VRANGE_RedisCommand, "readonly", 1, 1, 1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + RedisModuleCommand *vrange_cmd = RedisModule_GetCommand(ctx, "VRANGE"); + if (vrange_cmd == NULL) return REDISMODULE_ERR; + + RedisModuleCommandArg vrange_args[] = { + { .name = "key", .type = REDISMODULE_ARG_TYPE_KEY, .key_spec_index = 0 }, + { .name = "start", .type = REDISMODULE_ARG_TYPE_STRING }, + { .name = "end", .type = REDISMODULE_ARG_TYPE_STRING }, + { .name = "count", .type = REDISMODULE_ARG_TYPE_INTEGER, .flags = REDISMODULE_CMD_ARG_OPTIONAL }, + { .name = NULL } + }; + RedisModuleCommandInfo vrange_info = { + .version = REDISMODULE_COMMAND_INFO_VERSION, + .summary = "Return vector set elements in a lex range", + .since = "8.4.0", + .arity = -4, + .args = vrange_args, + }; + if (RedisModule_SetCommandInfo(vrange_cmd, &vrange_info) == REDISMODULE_ERR) return REDISMODULE_ERR; + + // Set the allocator for the HNSW library, so that memory tracking + // is correct in Redis. + hnsw_set_allocator(RedisModule_Free, RedisModule_Alloc, + RedisModule_Realloc); + + return REDISMODULE_OK; +} + +int VectorSets_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + return RedisModule_OnLoad(ctx, argv, argc); +} diff --git a/examples/redis-unstable/modules/vector-sets/vset_config.c b/examples/redis-unstable/modules/vector-sets/vset_config.c new file mode 100644 index 0000000..79dc8a3 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/vset_config.c @@ -0,0 +1,51 @@ +/* vector set module configuration. + * + * Copyright (c) 2009-Present, Redis Ltd. + * All rights reserved. + * + * Licensed under your choice of (a) the Redis Source Available License 2.0 + * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the + * GNU Affero General Public License v3 (AGPLv3). +*/ + +#include "vset_config.h" + +/* Define __STRING macro for portability (not available in all environments) */ +#ifndef __STRING +#define __STRING(x) #x +#endif + +#define RM_TRY(expr) \ + if (expr == REDISMODULE_ERR) { \ + RedisModule_Log(ctx, "warning", "Could not run " __STRING(expr)); \ + return REDISMODULE_ERR; \ + } + +VSConfig VSGlobalConfig; + +int set_bool_config(const char *name, int val, void *privdata, + RedisModuleString **err) { + REDISMODULE_NOT_USED(name); + REDISMODULE_NOT_USED(err); + *(int *)privdata = val; + return REDISMODULE_OK; +} + +int get_bool_config(const char *name, void *privdata) { + REDISMODULE_NOT_USED(name); + return *(int *)privdata; +} + +int RegisterModuleConfig(RedisModuleCtx *ctx) { + // Numeric parameters + RM_TRY( + RedisModule_RegisterBoolConfig( + ctx, "vset-force-single-threaded-execution", 0, + REDISMODULE_CONFIG_UNPREFIXED, + get_bool_config, set_bool_config, NULL, + (void *)&(VSGlobalConfig.forceSingleThreadExec) + ) + ) + + return REDISMODULE_OK; +} diff --git a/examples/redis-unstable/modules/vector-sets/vset_config.h b/examples/redis-unstable/modules/vector-sets/vset_config.h new file mode 100644 index 0000000..8da94fa --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/vset_config.h @@ -0,0 +1,24 @@ +/* vector set module configuration. + * + * Copyright (c) 2009-Present, Redis Ltd. + * All rights reserved. + * + * Licensed under your choice of (a) the Redis Source Available License 2.0 + * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the + * GNU Affero General Public License v3 (AGPLv3). +*/ + +#ifndef VSET_CONFIG_H +#define VSET_CONFIG_H + +#include "../../src/redismodule.h" + +typedef struct { + int forceSingleThreadExec; +} VSConfig; + +extern VSConfig VSGlobalConfig; + +int RegisterModuleConfig(RedisModuleCtx *ctx); + +#endif diff --git a/examples/redis-unstable/modules/vector-sets/w2v.c b/examples/redis-unstable/modules/vector-sets/w2v.c new file mode 100644 index 0000000..bcf6338 --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/w2v.c @@ -0,0 +1,539 @@ +/* + * HNSW (Hierarchical Navigable Small World) Implementation + * Based on the paper by Yu. A. Malkov, D. A. Yashunin + * + * Copyright (c) 2009-Present, Redis Ltd. + * All rights reserved. + * + * Licensed under your choice of (a) the Redis Source Available License 2.0 + * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the + * GNU Affero General Public License v3 (AGPLv3). + * Originally authored by: Salvatore Sanfilippo + */ + +#define _DEFAULT_SOURCE +#define _USE_MATH_DEFINES +#define _POSIX_C_SOURCE 200809L + +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <strings.h> +#include <sys/time.h> +#include <time.h> +#include <stdint.h> +#include <pthread.h> +#include <stdatomic.h> +#include <math.h> + +#include "hnsw.h" + +/* Get current time in milliseconds */ +uint64_t ms_time(void) { + struct timeval tv; + gettimeofday(&tv, NULL); + return (uint64_t)tv.tv_sec * 1000 + (tv.tv_usec / 1000); +} + +/* Implementation of the recall test with random vectors. */ +void test_recall(HNSW *index, int ef) { + const int num_test_vectors = 10000; + const int k = 100; // Number of nearest neighbors to find. + if (ef < k) ef = k; + + // Add recall distribution counters (2% bins from 0-100%). + int recall_bins[50] = {0}; + + // Create array to store vectors for mixing. + int num_source_vectors = 1000; // Enough, since we mix them. + float **source_vectors = malloc(sizeof(float*) * num_source_vectors); + if (!source_vectors) { + printf("Failed to allocate memory for source vectors\n"); + return; + } + + // Allocate memory for each source vector. + for (int i = 0; i < num_source_vectors; i++) { + source_vectors[i] = malloc(sizeof(float) * 300); + if (!source_vectors[i]) { + printf("Failed to allocate memory for source vector %d\n", i); + // Clean up already allocated vectors. + for (int j = 0; j < i; j++) free(source_vectors[j]); + free(source_vectors); + return; + } + } + + /* Populate source vectors from the index, we just scan the + * first N items. */ + int source_count = 0; + hnswNode *current = index->head; + while (current && source_count < num_source_vectors) { + hnsw_get_node_vector(index, current, source_vectors[source_count]); + source_count++; + current = current->next; + } + + if (source_count < num_source_vectors) { + printf("Warning: Only found %d nodes for source vectors\n", + source_count); + num_source_vectors = source_count; + } + + // Allocate memory for test vector. + float *test_vector = malloc(sizeof(float) * 300); + if (!test_vector) { + printf("Failed to allocate memory for test vector\n"); + for (int i = 0; i < num_source_vectors; i++) { + free(source_vectors[i]); + } + free(source_vectors); + return; + } + + // Allocate memory for results. + hnswNode **hnsw_results = malloc(sizeof(hnswNode*) * ef); + hnswNode **linear_results = malloc(sizeof(hnswNode*) * ef); + float *hnsw_distances = malloc(sizeof(float) * ef); + float *linear_distances = malloc(sizeof(float) * ef); + + if (!hnsw_results || !linear_results || !hnsw_distances || !linear_distances) { + printf("Failed to allocate memory for results\n"); + if (hnsw_results) free(hnsw_results); + if (linear_results) free(linear_results); + if (hnsw_distances) free(hnsw_distances); + if (linear_distances) free(linear_distances); + for (int i = 0; i < num_source_vectors; i++) free(source_vectors[i]); + free(source_vectors); + free(test_vector); + return; + } + + // Initialize random seed. + srand(time(NULL)); + + // Perform recall test. + printf("\nPerforming recall test with EF=%d on %d random vectors...\n", + ef, num_test_vectors); + double total_recall = 0.0; + + for (int t = 0; t < num_test_vectors; t++) { + // Create a random vector by mixing 3 existing vectors. + float weights[3] = {0.0}; + int src_indices[3] = {0}; + + // Generate random weights. + float weight_sum = 0.0; + for (int i = 0; i < 3; i++) { + weights[i] = (float)rand() / RAND_MAX; + weight_sum += weights[i]; + src_indices[i] = rand() % num_source_vectors; + } + + // Normalize weights. + for (int i = 0; i < 3; i++) weights[i] /= weight_sum; + + // Mix vectors. + memset(test_vector, 0, sizeof(float) * 300); + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 300; j++) { + test_vector[j] += + weights[i] * source_vectors[src_indices[i]][j]; + } + } + + // Perform HNSW search with the specified EF parameter. + int slot = hnsw_acquire_read_slot(index); + int hnsw_found = hnsw_search(index, test_vector, ef, hnsw_results, hnsw_distances, slot, 0); + + // Perform linear search (ground truth). + int linear_found = hnsw_ground_truth_with_filter(index, test_vector, ef, linear_results, linear_distances, slot, 0, NULL, NULL); + hnsw_release_read_slot(index, slot); + + // Calculate recall for this query (intersection size / k). + if (hnsw_found > k) hnsw_found = k; + if (linear_found > k) linear_found = k; + int intersection_count = 0; + for (int i = 0; i < linear_found; i++) { + for (int j = 0; j < hnsw_found; j++) { + if (linear_results[i] == hnsw_results[j]) { + intersection_count++; + break; + } + } + } + + double recall = (double)intersection_count / linear_found; + total_recall += recall; + + // Add to distribution bins (2% steps) + int bin_index = (int)(recall * 50); + if (bin_index >= 50) bin_index = 49; // Handle 100% recall case + recall_bins[bin_index]++; + + // Show progress. + if ((t+1) % 1000 == 0 || t == num_test_vectors-1) { + printf("Processed %d/%d queries, current avg recall: %.2f%%\n", + t+1, num_test_vectors, (total_recall / (t+1)) * 100); + } + } + + // Calculate and print final average recall. + double avg_recall = (total_recall / num_test_vectors) * 100; + printf("\nRecall Test Results:\n"); + printf("Average recall@%d (EF=%d): %.2f%%\n", k, ef, avg_recall); + + // Print recall distribution histogram. + printf("\nRecall Distribution (2%% bins):\n"); + printf("================================\n"); + + // Find the maximum bin count for scaling. + int max_count = 0; + for (int i = 0; i < 50; i++) { + if (recall_bins[i] > max_count) max_count = recall_bins[i]; + } + + // Scale factor for histogram (max 50 chars wide) + const int max_bars = 50; + double scale = (max_count > max_bars) ? (double)max_bars / max_count : 1.0; + + // Print the histogram. + for (int i = 0; i < 50; i++) { + int bar_len = (int)(recall_bins[i] * scale); + printf("%3d%%-%-3d%% | %-6d |", i*2, (i+1)*2, recall_bins[i]); + for (int j = 0; j < bar_len; j++) printf("#"); + printf("\n"); + } + + // Cleanup. + free(hnsw_results); + free(linear_results); + free(hnsw_distances); + free(linear_distances); + free(test_vector); + for (int i = 0; i < num_source_vectors; i++) free(source_vectors[i]); + free(source_vectors); +} + +/* Example usage in main() */ +int w2v_single_thread(int m_param, int quantization, uint64_t numele, int massdel, int self_recall, int recall_ef) { + /* Create index */ + HNSW *index = hnsw_new(300, quantization, m_param); + float v[300]; + uint16_t wlen; + + FILE *fp = fopen("word2vec.bin","rb"); + if (fp == NULL) { + perror("word2vec.bin file missing"); + exit(1); + } + unsigned char header[8]; + if (fread(header,8,1,fp) <= 0) { // Skip header + perror("Unexpected EOF"); + exit(1); + } + + uint64_t id = 0; + uint64_t start_time = ms_time(); + char *word = NULL; + hnswNode *search_node = NULL; + + while(id < numele) { + if (fread(&wlen,2,1,fp) == 0) break; + word = malloc(wlen+1); + if (fread(word,wlen,1,fp) <= 0) { + perror("unexpected EOF"); + exit(1); + } + word[wlen] = 0; + if (fread(v,300*sizeof(float),1,fp) <= 0) { + perror("unexpected EOF"); + exit(1); + } + + // Plain API that acquires a write lock for the whole time. + hnswNode *added = hnsw_insert(index, v, NULL, 0, id++, word, 200); + + if (!strcmp(word,"banana")) search_node = added; + if (!(id % 10000)) printf("%llu added\n", (unsigned long long)id); + } + uint64_t elapsed = ms_time() - start_time; + fclose(fp); + + printf("%llu words added (%llu words/sec), last word: %s\n", + (unsigned long long)index->node_count, + (unsigned long long)id*1000/elapsed, word); + + /* Search query */ + if (search_node == NULL) search_node = index->head; + hnsw_get_node_vector(index,search_node,v); + hnswNode *neighbors[10]; + float distances[10]; + + int found, j; + start_time = ms_time(); + for (j = 0; j < 20000; j++) + found = hnsw_search(index, v, 10, neighbors, distances, 0, 0); + elapsed = ms_time() - start_time; + printf("%d searches performed (%llu searches/sec), nodes found: %d\n", + j, (unsigned long long)j*1000/elapsed, found); + + if (found > 0) { + printf("Found %d neighbors:\n", found); + for (int i = 0; i < found; i++) { + printf("Node ID: %llu, distance: %f, word: %s\n", + (unsigned long long)neighbors[i]->id, + distances[i], (char*)neighbors[i]->value); + } + } + + // Self-recall test (ability to find the node by its own vector). + if (self_recall) { + hnsw_print_stats(index); + hnsw_test_graph_recall(index,200,0); + } + + // Recall test with random vectors. + if (recall_ef > 0) { + test_recall(index, recall_ef); + } + + uint64_t connected_nodes; + int reciprocal_links; + hnsw_validate_graph(index, &connected_nodes, &reciprocal_links); + + if (massdel) { + int remove_perc = 95; + printf("\nRemoving %d%% of nodes...\n", remove_perc); + uint64_t initial_nodes = index->node_count; + + hnswNode *current = index->head; + while (current && index->node_count > initial_nodes*(100-remove_perc)/100) { + hnswNode *next = current->next; + hnsw_delete_node(index,current,free); + current = next; + // In order to don't remove only contiguous nodes, from time + // skip a node. + if (current && !(random() % remove_perc)) current = current->next; + } + printf("%llu nodes left\n", (unsigned long long)index->node_count); + + // Test again. + hnsw_validate_graph(index, &connected_nodes, &reciprocal_links); + hnsw_test_graph_recall(index,200,0); + } + + hnsw_free(index,free); + return 0; +} + +struct threadContext { + pthread_mutex_t FileAccessMutex; + uint64_t numele; + _Atomic uint64_t SearchesDone; + _Atomic uint64_t id; + FILE *fp; + HNSW *index; + float *search_vector; +}; + +// Note that in practical terms inserting with many concurrent threads +// may be *slower* and not faster, because there is a lot of +// contention. So this is more a robustness test than anything else. +// +// The optimistic commit API goal is actually to exploit the ability to +// add faster when there are many concurrent reads. +void *threaded_insert(void *ctxptr) { + struct threadContext *ctx = ctxptr; + char *word; + float v[300]; + uint16_t wlen; + + while(1) { + pthread_mutex_lock(&ctx->FileAccessMutex); + if (fread(&wlen,2,1,ctx->fp) == 0) break; + pthread_mutex_unlock(&ctx->FileAccessMutex); + word = malloc(wlen+1); + if (fread(word,wlen,1,ctx->fp) <= 0) { + perror("Unexpected EOF"); + exit(1); + } + + word[wlen] = 0; + if (fread(v,300*sizeof(float),1,ctx->fp) <= 0) { + perror("Unexpected EOF"); + exit(1); + } + + // Check-and-set API that performs the costly scan for similar + // nodes concurrently with other read threads, and finally + // applies the check if the graph wasn't modified. + InsertContext *ic; + uint64_t next_id = ctx->id++; + ic = hnsw_prepare_insert(ctx->index, v, NULL, 0, next_id, 200); + if (hnsw_try_commit_insert(ctx->index, ic, word) == NULL) { + // This time try locking since the start. + hnsw_insert(ctx->index, v, NULL, 0, next_id, word, 200); + } + + if (next_id >= ctx->numele) break; + if (!((next_id+1) % 10000)) + printf("%llu added\n", (unsigned long long)next_id+1); + } + return NULL; +} + +void *threaded_search(void *ctxptr) { + struct threadContext *ctx = ctxptr; + + /* Search query */ + hnswNode *neighbors[10]; + float distances[10]; + int found = 0; + uint64_t last_id = 0; + + while(ctx->id < 1000000) { + int slot = hnsw_acquire_read_slot(ctx->index); + found = hnsw_search(ctx->index, ctx->search_vector, 10, neighbors, distances, slot, 0); + hnsw_release_read_slot(ctx->index,slot); + last_id = ++ctx->id; + } + + if (found > 0 && last_id == 1000000) { + printf("Found %d neighbors:\n", found); + for (int i = 0; i < found; i++) { + printf("Node ID: %llu, distance: %f, word: %s\n", + (unsigned long long)neighbors[i]->id, + distances[i], (char*)neighbors[i]->value); + } + } + return NULL; +} + +int w2v_multi_thread(int m_param, int numthreads, int quantization, uint64_t numele) { + /* Create index */ + struct threadContext ctx; + + ctx.index = hnsw_new(300, quantization, m_param); + + ctx.fp = fopen("word2vec.bin","rb"); + if (ctx.fp == NULL) { + perror("word2vec.bin file missing"); + exit(1); + } + + unsigned char header[8]; + if (fread(header,8,1,ctx.fp) <= 0) { // Skip header + perror("Unexpected EOF"); + exit(1); + } + pthread_mutex_init(&ctx.FileAccessMutex,NULL); + + uint64_t start_time = ms_time(); + ctx.id = 0; + ctx.numele = numele; + pthread_t threads[numthreads]; + for (int j = 0; j < numthreads; j++) + pthread_create(&threads[j], NULL, threaded_insert, &ctx); + + // Wait for all the threads to terminate adding items. + for (int j = 0; j < numthreads; j++) + pthread_join(threads[j],NULL); + + uint64_t elapsed = ms_time() - start_time; + fclose(ctx.fp); + + // Obtain the last word. + hnswNode *node = ctx.index->head; + char *word = node->value; + + // We will search this last inserted word in the next test. + // Let's save its embedding. + ctx.search_vector = malloc(sizeof(float)*300); + hnsw_get_node_vector(ctx.index,node,ctx.search_vector); + + printf("%llu words added (%llu words/sec), last word: %s\n", + (unsigned long long)ctx.index->node_count, + (unsigned long long)ctx.id*1000/elapsed, word); + + /* Search query */ + start_time = ms_time(); + ctx.id = 0; // We will use this atomic field to stop at N queries done. + + for (int j = 0; j < numthreads; j++) + pthread_create(&threads[j], NULL, threaded_search, &ctx); + + // Wait for all the threads to terminate searching. + for (int j = 0; j < numthreads; j++) + pthread_join(threads[j],NULL); + + elapsed = ms_time() - start_time; + printf("%llu searches performed (%llu searches/sec)\n", + (unsigned long long)ctx.id, + (unsigned long long)ctx.id*1000/elapsed); + + hnsw_print_stats(ctx.index); + uint64_t connected_nodes; + int reciprocal_links; + hnsw_validate_graph(ctx.index, &connected_nodes, &reciprocal_links); + printf("%llu connected nodes. Links all reciprocal: %d\n", + (unsigned long long)connected_nodes, reciprocal_links); + hnsw_free(ctx.index,free); + return 0; +} + +int main(int argc, char **argv) { + int quantization = HNSW_QUANT_NONE; + int numthreads = 0; + uint64_t numele = 20000; + int m_param = 0; // Default value (0 means use HNSW_DEFAULT_M) + + /* This you can enable in single thread mode for testing: */ + int massdel = 0; // If true, does the mass deletion test. + int self_recall = 0; // If true, does the self-recall test. + int recall_ef = 0; // If not 0, does the recall test with this EF value. + + for (int j = 1; j < argc; j++) { + int moreargs = argc-j-1; + + if (!strcasecmp(argv[j],"--quant")) { + quantization = HNSW_QUANT_Q8; + } else if (!strcasecmp(argv[j],"--bin")) { + quantization = HNSW_QUANT_BIN; + } else if (!strcasecmp(argv[j],"--mass-del")) { + massdel = 1; + } else if (!strcasecmp(argv[j],"--self-recall")) { + self_recall = 1; + } else if (moreargs >= 1 && !strcasecmp(argv[j],"--recall")) { + recall_ef = atoi(argv[j+1]); + j++; + } else if (moreargs >= 1 && !strcasecmp(argv[j],"--threads")) { + numthreads = atoi(argv[j+1]); + j++; + } else if (moreargs >= 1 && !strcasecmp(argv[j],"--numele")) { + numele = strtoll(argv[j+1],NULL,0); + j++; + if (numele < 1) numele = 1; + } else if (moreargs >= 1 && !strcasecmp(argv[j],"--m")) { + m_param = atoi(argv[j+1]); + j++; + } else if (!strcasecmp(argv[j],"--help")) { + printf("%s [--quant] [--bin] [--thread <count>] [--numele <count>] [--m <count>] [--mass-del] [--self-recall] [--recall <ef>]\n", argv[0]); + exit(0); + } else { + printf("Unrecognized option or wrong number of arguments: %s\n", argv[j]); + exit(1); + } + } + + if (quantization == HNSW_QUANT_NONE) { + printf("You can enable quantization with --quant\n"); + } + + if (numthreads > 0) { + w2v_multi_thread(m_param, numthreads, quantization, numele); + } else { + printf("Single thread execution. Use --threads 4 for concurrent API\n"); + w2v_single_thread(m_param, quantization, numele, massdel, self_recall, recall_ef); + } +} |
