1import pytest
  2from utils import *
  3
# We use an F16 MoE GGUF as the main model, and a Q4_0 GGUF as the draft model
  5
# Module-level server handle; recreated with default settings before every
# test by the autouse fixture below.
server = ServerPreset.stories15m_moe()

# Q4_0 GGUF downloaded on demand and used as the speculative-decoding draft model.
MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf"
  9
 10def create_server():
 11    global server
 12    server = ServerPreset.stories15m_moe()
 13    # set default values
 14    server.model_draft = download_file(MODEL_DRAFT_FILE_URL)
 15    server.draft_min = 4
 16    server.draft_max = 8
 17    server.fa = "off"
 18
 19
 20@pytest.fixture(autouse=True)
 21def fixture_create_server():
 22    return create_server()
 23
 24
 25def test_with_and_without_draft():
 26    global server
 27    server.model_draft = None  # disable draft model
 28    server.start()
 29    res = server.make_request("POST", "/completion", data={
 30        "prompt": "I believe the meaning of life is",
 31        "temperature": 0.0,
 32        "top_k": 1,
 33        "n_predict": 16,
 34    })
 35    assert res.status_code == 200
 36    content_no_draft = res.body["content"]
 37    server.stop()
 38
 39    # create new server with draft model
 40    create_server()
 41    server.start()
 42    res = server.make_request("POST", "/completion", data={
 43        "prompt": "I believe the meaning of life is",
 44        "temperature": 0.0,
 45        "top_k": 1,
 46        "n_predict": 16,
 47    })
 48    assert res.status_code == 200
 49    content_draft = res.body["content"]
 50
 51    assert content_no_draft == content_draft
 52
 53
 54def test_different_draft_min_draft_max():
 55    global server
 56    test_values = [
 57        (1, 2),
 58        (1, 4),
 59        (4, 8),
 60        (4, 12),
 61        (8, 16),
 62    ]
 63    last_content = None
 64    for draft_min, draft_max in test_values:
 65        server.stop()
 66        server.draft_min = draft_min
 67        server.draft_max = draft_max
 68        server.start()
 69        res = server.make_request("POST", "/completion", data={
 70            "prompt": "I believe the meaning of life is",
 71            "temperature": 0.0,
 72            "top_k": 1,
 73            "n_predict": 16,
 74        })
 75        assert res.status_code == 200
 76        if last_content is not None:
 77            assert last_content == res.body["content"]
 78        last_content = res.body["content"]
 79
 80
 81def test_slot_ctx_not_exceeded():
 82    global server
 83    server.n_ctx = 256
 84    server.start()
 85    res = server.make_request("POST", "/completion", data={
 86        "prompt": "Hello " * 248,
 87        "temperature": 0.0,
 88        "top_k": 1,
 89        "speculative.p_min": 0.0,
 90    })
 91    assert res.status_code == 200
 92    assert len(res.body["content"]) > 0
 93
 94
 95def test_with_ctx_shift():
 96    global server
 97    server.n_ctx = 256
 98    server.enable_ctx_shift = True
 99    server.start()
100    res = server.make_request("POST", "/completion", data={
101        "prompt": "Hello " * 248,
102        "temperature": 0.0,
103        "top_k": 1,
104        "n_predict": 256,
105        "speculative.p_min": 0.0,
106    })
107    assert res.status_code == 200
108    assert len(res.body["content"]) > 0
109    assert res.body["tokens_predicted"] == 256
110    assert res.body["truncated"] == True
111
112
@pytest.mark.parametrize("n_slots,n_requests", [
    (1, 2),
    (2, 2),
])
def test_multi_requests_parallel(n_slots: int, n_requests: int):
    """Concurrent completion requests must all succeed, with or without slot contention."""
    global server
    server.n_slots = n_slots
    server.start()
    # fan out n_requests identical greedy completions
    tasks = [
        (server.make_request, ("POST", "/completion", {
            "prompt": "I believe the meaning of life is",
            "temperature": 0.0,
            "top_k": 1,
        }))
        for _ in range(n_requests)
    ]
    for res in parallel_function_calls(tasks):
        assert res.status_code == 200
        assert match_regex("(wise|kind|owl|answer)+", res.body["content"])