1import pytest
2from utils import *
3
# Module-level placeholder so the test functions can refer to `server`;
# the autouse fixture below replaces it with a freshly configured instance.
server = ServerPreset.tinyllama2()

@pytest.fixture(autouse=True)
def create_server():
    """Rebuild the shared ``server`` object before every test.

    The preset is given a slot save directory (needed by the
    ``/slots/<id>?action=save|restore`` endpoints used below) and a
    temperature of 0.0 so the generated text checked by the tests is
    reproducible.
    """
    global server
    preset = ServerPreset.tinyllama2()
    preset.slot_save_path = "./tmp"
    preset.temperature = 0.0
    server = preset
12
13
def test_slot_save_restore():
    """Save slot 1's cache to a file, restore it into slot 0, and verify reuse.

    Every completion is sent with ``cache_prompt`` enabled, so
    ``timings.prompt_n`` shows how many prompt tokens the server actually had
    to evaluate. This reveals whether a cached prefix was reused.
    """
    server.start()

    def complete_and_check(prompt, id_slot, pattern, expected_prompt_n):
        # Send one cached completion pinned to `id_slot`, then check status,
        # generated text and the number of prompt tokens evaluated (in that order).
        res = server.make_request("POST", "/completion", data={
            "prompt": prompt,
            "id_slot": id_slot,
            "cache_prompt": True,
        })
        assert res.status_code == 200
        assert match_regex(pattern, res.body["content"])
        assert res.body["timings"]["prompt_n"] == expected_prompt_n

    # Cold slot 1: the whole prompt has to be evaluated.
    complete_and_check("What is the capital of France?", 1, "(Whiskers|Flana)+", 21)

    # Persist slot 1 to disk.
    # NOTE(review): 84 is the token count the server reports as saved for this
    # preset; it presumably covers prompt + generated tokens -- confirm if the
    # preset's generation length changes.
    saved = server.make_request("POST", "/slots/1?action=save", data={
        "filename": "slot1.bin",
    })
    assert saved.status_code == 200
    assert saved.body["n_saved"] == 84

    # Slot 1 shares a prefix with the cached prompt: only the differing tail is evaluated.
    complete_and_check("What is the capital of Germany?", 1, "(Jack|said)+", 6)

    # Load the file saved from slot 1 into slot 0.
    restored = server.make_request("POST", "/slots/0?action=restore", data={
        "filename": "slot1.bin",
    })
    assert restored.status_code == 200
    assert restored.body["n_restored"] == 84

    # Slot 0 now holds the France cache, so it too only evaluates the tail.
    complete_and_check("What is the capital of Germany?", 0, "(Jack|said)+", 6)

    # Slot 1 must be unaffected by the restore into slot 0. It already holds this
    # exact prompt, so the test expects just a single prompt token to be evaluated.
    complete_and_check("What is the capital of Germany?", 1, "(Jack|said)+", 1)
71
72
def test_slot_erase():
    """Erasing a slot must discard its prompt cache.

    The same cached prompt is sent to slot 1 before and after
    ``/slots/1?action=erase``. Both runs are expected to evaluate every prompt
    token, which shows that the erase removed what the first run cached.
    """
    server.start()

    def run_france_on_slot_1():
        # One cached completion on slot 1, expected to evaluate all 21 prompt tokens.
        res = server.make_request("POST", "/completion", data={
            "prompt": "What is the capital of France?",
            "id_slot": 1,
            "cache_prompt": True,
        })
        assert res.status_code == 200
        assert match_regex("(Whiskers|Flana)+", res.body["content"])
        assert res.body["timings"]["prompt_n"] == 21

    run_france_on_slot_1()

    erase = server.make_request("POST", "/slots/1?action=erase")
    assert erase.status_code == 200

    # With the cache gone, the identical prompt is processed from scratch again.
    run_france_on_slot_1()