1import pytest
 2from utils import *
 3
 4server = ServerPreset.tinyllama2()
 5
 6@pytest.fixture(autouse=True)
 7def create_server():
 8    global server
 9    server = ServerPreset.tinyllama2()
10    server.slot_save_path = "./tmp"
11    server.temperature = 0.0
12
13
def test_slot_save_restore():
    """Save slot 1's state to a file, restore it into slot 0, and check cache reuse.

    ``timings.prompt_n`` is the number of prompt tokens the server actually
    processed, so it shows whether a cached prefix was reused.
    """
    server.start()

    def complete(prompt, id_slot):
        """POST a cached /completion pinned to ``id_slot``; assert HTTP 200 and return the response."""
        response = server.make_request("POST", "/completion", data={
            "prompt": prompt,
            "id_slot": id_slot,
            "cache_prompt": True,
        })
        assert response.status_code == 200
        return response

    def slot_file_request(path, filename):
        """POST a slot save/restore request naming ``filename``; assert HTTP 200 and return the body."""
        response = server.make_request("POST", path, data={
            "filename": filename,
        })
        assert response.status_code == 200
        return response.body

    # Slot 1 starts cold, so every token of the French prompt is processed.
    body = complete("What is the capital of France?", 1).body
    assert match_regex("(Whiskers|Flana)+", body["content"])
    assert body["timings"]["prompt_n"] == 21

    # Write slot 1's current state to slot1.bin.
    saved = slot_file_request("/slots/1?action=save", "slot1.bin")
    assert saved["n_saved"] == 84

    # Slot 1 shares a prefix with the German prompt: only the differing tail is processed.
    body = complete("What is the capital of Germany?", 1).body
    assert match_regex("(Jack|said)+", body["content"])
    assert body["timings"]["prompt_n"] == 6

    # Load the state saved from slot 1 into slot 0.
    restored = slot_file_request("/slots/0?action=restore", "slot1.bin")
    assert restored["n_restored"] == 84

    # Slot 0 now holds the restored French-prompt state, so again only the tail is processed.
    body = complete("What is the capital of Germany?", 0).body
    assert match_regex("(Jack|said)+", body["content"])
    assert body["timings"]["prompt_n"] == 6

    # Slot 1 must not have been disturbed by the restore into slot 0. It still
    # holds the German prompt from its earlier request, so only 1 prompt token
    # is processed this time.
    body = complete("What is the capital of Germany?", 1).body
    assert match_regex("(Jack|said)+", body["content"])
    assert body["timings"]["prompt_n"] == 1
71
72
def test_slot_erase():
    """Erasing slot 1 drops its cached prompt, so the same prompt is fully reprocessed."""
    server.start()

    def ask_capital_of_france():
        """Send the cached France prompt to slot 1 and assert every prompt token was processed."""
        response = server.make_request("POST", "/completion", data={
            "prompt": "What is the capital of France?",
            "id_slot": 1,
            "cache_prompt": True,
        })
        assert response.status_code == 200
        assert match_regex("(Whiskers|Flana)+", response.body["content"])
        assert response.body["timings"]["prompt_n"] == 21

    # Cold slot: the full prompt is processed.
    ask_capital_of_france()

    # Wipe slot 1.
    erased = server.make_request("POST", "/slots/1?action=erase")
    assert erased.status_code == 200

    # No cached prefix remains, so the full prompt is processed again.
    ask_capital_of_france()