import pytest
from utils import *

server = ServerPreset.tinyllama_infill()
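
# autouse fixture: recreate the server preset before every test so state does not leak between tests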
@pytest.fixture(autouse=True)
def create_server():
    global server
    server = ServerPreset.tinyllama_infill()
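

# basic infill: the model completes between input_prefix and input_suffix with no extra file context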
def test_infill_without_input_extra():
    global server
    server.start()
    res = server.make_request("POST", "/infill", data={
        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
        "prompt": "    int n_threads = llama_",
        "input_suffix": "}\n",
    })
    assert res.status_code == 200
    assert match_regex("(Ann|small|shiny|Daddy|Jimmy)+", res.body["content"])
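

# repo-level FIM: input_extra supplies chunks from other files as additional context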
def test_infill_with_input_extra():
    global server
    server.start()
    res = server.make_request("POST", "/infill", data={
        "input_extra": [{
            "filename": "llama.h",
            "text": "LLAMA_API int32_t llama_n_threads();\n"
        }],
        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
        "prompt": "    int n_threads = llama_",
        "input_suffix": "}\n",
    })
    assert res.status_code == 200
    assert match_regex("(Dad|excited|park|Jimmy)+", res.body["content"])
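

# malformed input_extra entries (missing or mistyped fields) must be rejected with HTTP 400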
@pytest.mark.parametrize("input_extra", [
    {},
    {"filename": "ok"},
    {"filename": 123},
    {"filename": 123, "text": "abc"},
    {"filename": 123, "text": 456},
])
def test_invalid_input_extra_req(input_extra):
    global server
    server.start()
    res = server.make_request("POST", "/infill", data={
        "input_extra": [input_extra],
        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
        "prompt": "    int n_threads = llama_",
        "input_suffix": "}\n",
    })
    assert res.status_code == 400
    assert "error" in res.body
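

# slow test: fetches a FIM-capable coder model from Hugging Face and checks the exact completion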
@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test")
def test_with_qwen_model():
    global server
    server.model_file = None
    server.model_hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-IQ3_XXS-GGUF"
    server.model_hf_file = "qwen2.5-coder-1.5b-iq3_xxs-imat.gguf"
    server.start(timeout_seconds=600)
    res = server.make_request("POST", "/infill", data={
        "input_extra": [{
            "filename": "llama.h",
            "text": "LLAMA_API int32_t llama_n_threads();\n"
        }],
        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
        "prompt": "    int n_threads = llama_",
        "input_suffix": "}\n",
    })
    assert res.status_code == 200
    assert res.body["content"] == "n_threads();\n    printf(\"Number of threads: %d\\n\", n_threads);\n    return 0;\n"