1import pytest
2from utils import *
3
# Shared server handle; the autouse fixture re-creates it before every test.
server = ServerPreset.tinyllama2()


# Prompt used by the ctx-shift-enabled test (tokenizes to ~226 tokens — TODO
# confirm against the model's tokenizer); fits in a 256-token slot context.
SHORT_TEXT = """
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
""".strip()

# Longer prompt that on its own exceeds the per-slot context, used to check
# that the server rejects it when context shifting is disabled.
LONG_TEXT = """
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
""".strip()
19
@pytest.fixture(autouse=True)
def create_server():
    """Provision a fresh tinyllama2 server configuration before each test."""
    global server
    server = ServerPreset.tinyllama2()
    # Two slots sharing a 512-token context -> 256 tokens per slot.
    server.n_slots = 2
    server.n_ctx = 512
    # Default generation cap; individual tests may override it.
    server.n_predict = 128
28
def test_ctx_shift_enabled():
    """With context shifting on, generation continues past a full slot context.

    The prompt occupies 226 tokens and each slot only holds 512/2 = 256, yet
    all 96 requested tokens are produced because the context is shifted once
    it fills up.
    """
    global server
    server.enable_ctx_shift = True
    server.start()
    res = server.make_request("POST", "/completion", data={
        "n_predict": 96,
        "prompt": SHORT_TEXT,
    })
    assert res.status_code == 200
    timings = res.body["timings"]
    assert timings["prompt_n"] == 226
    assert timings["predicted_n"] == 96
    assert res.body["truncated"] is True
44
45
@pytest.mark.parametrize("n_predict,n_token_output,truncated", [
    (64, 64, False),
    (-1, 248, True), # 8 tokens prompt + 248 tokens generated = 256 tokens total
])
def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
    """Without ctx shift, generation stops (and is flagged truncated) when the slot context runs out."""
    global server
    # Lift the server-side cap so the per-request n_predict is what matters.
    server.n_predict = -1
    server.start()
    payload = {
        "n_predict": n_predict,
        "prompt": "Hi how are you",
    }
    res = server.make_request("POST", "/completion", data=payload)
    assert res.status_code == 200
    assert res.body["timings"]["predicted_n"] == n_token_output
    assert res.body["truncated"] == truncated
61
62
def test_ctx_shift_disabled_long_prompt():
    """A prompt exceeding the slot context must be rejected outright when shifting is off."""
    global server
    server.start()
    res = server.make_request("POST", "/completion", data={
        "n_predict": 64,
        "prompt": LONG_TEXT,
    })
    # The server refuses the request rather than silently truncating the prompt.
    assert res.status_code != 200
    assert "error" in res.body
    message = res.body["error"]["message"]
    assert "exceeds the available context size" in message
73
def test_ctx_shift_disabled_stream():
    """Streaming without ctx shift: the stream must terminate with
    finish_reason "length" after producing some output, because generation
    cannot continue once the slot context is full.
    """
    global server
    server.start()
    res = server.make_stream_request("POST", "/v1/completions", data={
        "n_predict": 256,
        "prompt": "Once",
        "stream": True,
    })
    content = ""
    finish_reason = None
    for data in res:
        choice = data["choices"][0]
        if choice["finish_reason"] is None:
            content += choice["text"]
        else:
            # Remember the terminal reason; it is checked after the loop.
            finish_reason = choice["finish_reason"]
    # Fix: the original only asserted inside the loop, so a stream that ended
    # without ever reporting a finish reason passed vacuously. Assert the
    # expected outcome unconditionally once the stream is exhausted.
    assert finish_reason == "length"
    assert len(content) > 0