aboutsummaryrefslogtreecommitdiff
path: root/examples/redis-unstable/modules/vector-sets/tests/persistence.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/redis-unstable/modules/vector-sets/tests/persistence.py')
-rw-r--r--examples/redis-unstable/modules/vector-sets/tests/persistence.py86
1 files changed, 86 insertions, 0 deletions
diff --git a/examples/redis-unstable/modules/vector-sets/tests/persistence.py b/examples/redis-unstable/modules/vector-sets/tests/persistence.py
new file mode 100644
index 0000000..79730f4
--- /dev/null
+++ b/examples/redis-unstable/modules/vector-sets/tests/persistence.py
@@ -0,0 +1,86 @@
1from test import TestCase, fill_redis_with_vectors, generate_random_vector
2import random
3
4class HNSWPersistence(TestCase):
5 def getname(self):
6 return "HNSW Persistence"
7
8 def estimated_runtime(self):
9 return 30
10
11 def _verify_results(self, key, dim, query_vec, reduced_dim=None):
12 """Run a query and return results dict"""
13 k = 10
14 args = ['VSIM', key]
15
16 if reduced_dim:
17 args.extend(['VALUES', dim])
18 args.extend([str(x) for x in query_vec])
19 else:
20 args.extend(['VALUES', dim])
21 args.extend([str(x) for x in query_vec])
22
23 args.extend(['COUNT', k, 'WITHSCORES'])
24 results = self.redis.execute_command(*args)
25
26 results_dict = {}
27 for i in range(0, len(results), 2):
28 key = results[i].decode()
29 score = float(results[i+1])
30 results_dict[key] = score
31 return results_dict
32
33 def test(self):
34 # Setup dimensions
35 dim = 128
36 reduced_dim = 32
37 count = 5000
38 random.seed(42)
39
40 # Create two datasets - one normal and one with dimension reduction
41 normal_data = fill_redis_with_vectors(self.redis, f"{self.test_key}:normal", count, dim)
42 projected_data = fill_redis_with_vectors(self.redis, f"{self.test_key}:projected",
43 count, dim, reduced_dim)
44
45 # Generate query vectors we'll use before and after reload
46 query_vec_normal = generate_random_vector(dim)
47 query_vec_projected = generate_random_vector(dim)
48
49 # Get initial results for both sets
50 initial_normal = self._verify_results(f"{self.test_key}:normal",
51 dim, query_vec_normal)
52 initial_projected = self._verify_results(f"{self.test_key}:projected",
53 dim, query_vec_projected, reduced_dim)
54
55 # Force Redis to save and reload the dataset
56 self.redis.execute_command('DEBUG', 'RELOAD')
57
58 # Verify results after reload
59 reloaded_normal = self._verify_results(f"{self.test_key}:normal",
60 dim, query_vec_normal)
61 reloaded_projected = self._verify_results(f"{self.test_key}:projected",
62 dim, query_vec_projected, reduced_dim)
63
64 # Verify normal vectors results
65 assert len(initial_normal) == len(reloaded_normal), \
66 "Normal vectors: Result count mismatch before/after reload"
67
68 for key in initial_normal:
69 assert key in reloaded_normal, f"Normal vectors: Missing item after reload: {key}"
70 assert abs(initial_normal[key] - reloaded_normal[key]) < 0.0001, \
71 f"Normal vectors: Score mismatch for {key}: " + \
72 f"before={initial_normal[key]:.6f}, after={reloaded_normal[key]:.6f}"
73
74 # Verify projected vectors results
75 assert len(initial_projected) == len(reloaded_projected), \
76 "Projected vectors: Result count mismatch before/after reload"
77
78 for key in initial_projected:
79 assert key in reloaded_projected, \
80 f"Projected vectors: Missing item after reload: {key}"
81 assert abs(initial_projected[key] - reloaded_projected[key]) < 0.0001, \
82 f"Projected vectors: Score mismatch for {key}: " + \
83 f"before={initial_projected[key]:.6f}, after={reloaded_projected[key]:.6f}"
84
85 self.redis.delete(f"{self.test_key}:normal")
86 self.redis.delete(f"{self.test_key}:projected")