import pytest
from utils import *

server = ServerPreset.tinyllama2()


SHORT_TEXT = """
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
""".strip()

LONG_TEXT = """
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
""".strip()

@pytest.fixture(autouse=True)
def create_server():
    global server
    server = ServerPreset.tinyllama2()
    server.n_ctx = 512
    server.n_slots = 2
    server.n_predict = 128
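    # note: with n_ctx = 512 and n_slots = 2, each slot gets 512 / 2 = 256 context tokens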


def test_ctx_shift_enabled():
    # the prompt is 226 tokens
    # the slot context is 512/2 = 256 tokens
    # 96 tokens are generated thanks to shifting the context when it gets full
    global server
    server.enable_ctx_shift = True
    server.start()
    res = server.make_request("POST", "/completion", data={
        "n_predict": 96,
        "prompt": SHORT_TEXT,
    })
    assert res.status_code == 200
    assert res.body["timings"]["prompt_n"] == 226
    assert res.body["timings"]["predicted_n"] == 96
    assert res.body["truncated"] is True


@pytest.mark.parametrize("n_predict,n_token_output,truncated", [
    (64, 64, False),
    (-1, 248, True), # 8 tokens prompt + 248 tokens generated = 256 tokens total
])
def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
    global server
    server.n_predict = -1
    server.start()
    res = server.make_request("POST", "/completion", data={
        "n_predict": n_predict,
        "prompt": "Hi how are you",
    })
    assert res.status_code == 200
    assert res.body["timings"]["predicted_n"] == n_token_output
    assert res.body["truncated"] == truncated


def test_ctx_shift_disabled_long_prompt():
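    # the prompt is longer than the slot context (512/2 = 256 tokens) and context shift
    # is disabled, so the request is expected to be rejected with an error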
    global server
    server.start()
    res = server.make_request("POST", "/completion", data={
        "n_predict": 64,
        "prompt": LONG_TEXT,
    })
    assert res.status_code != 200
    assert "error" in res.body
    assert "exceeds the available context size" in res.body["error"]["message"]


def test_ctx_shift_disabled_stream():
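    # with context shift disabled and a short prompt, the stream is expected to stop
    # with finish_reason "length" rather than returning an error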
    global server
    server.start()
    res = server.make_stream_request("POST", "/v1/completions", data={
        "n_predict": 256,
        "prompt": "Once",
        "stream": True,
    })
    content = ""
    for data in res:
        choice = data["choices"][0]
        if choice["finish_reason"] == "length":
            assert len(content) > 0
        else:
            assert choice["finish_reason"] is None
            content += choice["text"]