diff options
Diffstat (limited to 'examples/redis-unstable/modules/vector-sets/tests/node_update.py')
| -rw-r--r-- | examples/redis-unstable/modules/vector-sets/tests/node_update.py | 85 |
1 files changed, 85 insertions, 0 deletions
diff --git a/examples/redis-unstable/modules/vector-sets/tests/node_update.py b/examples/redis-unstable/modules/vector-sets/tests/node_update.py new file mode 100644 index 0000000..53aa2dd --- /dev/null +++ b/examples/redis-unstable/modules/vector-sets/tests/node_update.py | |||
| @@ -0,0 +1,85 @@ | |||
| 1 | from test import TestCase, generate_random_vector | ||
| 2 | import struct | ||
| 3 | import math | ||
| 4 | import random | ||
| 5 | |||
| 6 | class VectorUpdateAndClusters(TestCase): | ||
| 7 | def getname(self): | ||
| 8 | return "VADD vector update with cluster relocation" | ||
| 9 | |||
| 10 | def estimated_runtime(self): | ||
| 11 | return 2.0 # Should take around 2 seconds | ||
| 12 | |||
| 13 | def generate_cluster_vector(self, base_vec, noise=0.1): | ||
| 14 | """Generate a vector that's similar to base_vec with some noise.""" | ||
| 15 | vec = [x + random.gauss(0, noise) for x in base_vec] | ||
| 16 | # Normalize | ||
| 17 | norm = math.sqrt(sum(x*x for x in vec)) | ||
| 18 | return [x/norm for x in vec] | ||
| 19 | |||
| 20 | def test(self): | ||
| 21 | dim = 128 | ||
| 22 | vectors_per_cluster = 5000 | ||
| 23 | |||
| 24 | # Create two very different base vectors for our clusters | ||
| 25 | cluster1_base = generate_random_vector(dim) | ||
| 26 | cluster2_base = [-x for x in cluster1_base] # Opposite direction | ||
| 27 | |||
| 28 | # Add vectors from first cluster | ||
| 29 | for i in range(vectors_per_cluster): | ||
| 30 | vec = self.generate_cluster_vector(cluster1_base) | ||
| 31 | vec_bytes = struct.pack(f'{dim}f', *vec) | ||
| 32 | self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, | ||
| 33 | f'{self.test_key}:cluster1:{i}') | ||
| 34 | |||
| 35 | # Add vectors from second cluster | ||
| 36 | for i in range(vectors_per_cluster): | ||
| 37 | vec = self.generate_cluster_vector(cluster2_base) | ||
| 38 | vec_bytes = struct.pack(f'{dim}f', *vec) | ||
| 39 | self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, | ||
| 40 | f'{self.test_key}:cluster2:{i}') | ||
| 41 | |||
| 42 | # Pick a test vector from cluster1 | ||
| 43 | test_key = f'{self.test_key}:cluster1:0' | ||
| 44 | |||
| 45 | # Verify it's in cluster1 using VSIM | ||
| 46 | initial_vec = self.generate_cluster_vector(cluster1_base) | ||
| 47 | results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim, | ||
| 48 | *[str(x) for x in initial_vec], | ||
| 49 | 'COUNT', 100, 'WITHSCORES') | ||
| 50 | |||
| 51 | # Count how many cluster1 items are in top results | ||
| 52 | cluster1_count = sum(1 for i in range(0, len(results), 2) | ||
| 53 | if b'cluster1' in results[i]) | ||
| 54 | assert cluster1_count > 80, "Initial clustering check failed" | ||
| 55 | |||
| 56 | # Now update the test vector to be in cluster2 | ||
| 57 | new_vec = self.generate_cluster_vector(cluster2_base, noise=0.05) | ||
| 58 | vec_bytes = struct.pack(f'{dim}f', *new_vec) | ||
| 59 | self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, test_key) | ||
| 60 | |||
| 61 | # Verify the embedding was actually updated using VEMB | ||
| 62 | emb_result = self.redis.execute_command('VEMB', self.test_key, test_key) | ||
| 63 | updated_vec = [float(x) for x in emb_result] | ||
| 64 | |||
| 65 | # Verify updated vector matches what we inserted | ||
| 66 | dot_product = sum(a*b for a,b in zip(updated_vec, new_vec)) | ||
| 67 | similarity = dot_product / (math.sqrt(sum(x*x for x in updated_vec)) * | ||
| 68 | math.sqrt(sum(x*x for x in new_vec))) | ||
| 69 | assert similarity > 0.9, "Vector was not properly updated" | ||
| 70 | |||
| 71 | # Verify it's now in cluster2 using VSIM | ||
| 72 | results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim, | ||
| 73 | *[str(x) for x in cluster2_base], | ||
| 74 | 'COUNT', 100, 'WITHSCORES') | ||
| 75 | |||
| 76 | # Verify our updated vector is among top results | ||
| 77 | found = False | ||
| 78 | for i in range(0, len(results), 2): | ||
| 79 | if results[i].decode() == test_key: | ||
| 80 | found = True | ||
| 81 | similarity = float(results[i+1]) | ||
| 82 | assert similarity > 0.80, f"Updated vector has low similarity: {similarity}" | ||
| 83 | break | ||
| 84 | |||
| 85 | assert found, "Updated vector not found in cluster2 proximity" | ||
