import pytest
from openai import OpenAI
from utils import *

server: ServerProcess


@pytest.fixture(autouse=True)
def create_server():
    global server
    server = ServerPreset.tinyllama2()


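# Matrix over the template engine (jinja on/off), explicit template overrides,
# and string vs. multipart user content; the expected token counts and content
# regexes are tied to the tinyllama2 preset.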
@pytest.mark.parametrize(
    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template",
    [
        (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", False, None),
        (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", True, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", False, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'),
        (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", False, None),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", True, None),
        (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
        (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
    ]
)
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
    global server
    server.jinja = jinja
    server.chat_template = chat_template
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "model": model,
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    })
    assert res.status_code == 200
    assert "cmpl" in res.body["id"]  # make sure the completion id has the expected format
    assert res.body["system_fingerprint"].startswith("b")
    # we no longer reflect back the model name, see https://github.com/ggml-org/llama.cpp/pull/17668
    # assert res.body["model"] == model if model is not None else server.model_alias
    assert res.body["usage"]["prompt_tokens"] == n_prompt
    assert res.body["usage"]["completion_tokens"] == n_predicted
    choice = res.body["choices"][0]
    assert "assistant" == choice["message"]["role"]
    assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
    assert choice["finish_reason"] == finish_reason


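# Streamed variant: the first delta announces the role only, the completion id
# must stay constant across all events, and the trailing choice-less event
# carries the usage counters.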
@pytest.mark.parametrize(
    "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
    [
        ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
        ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length"),
    ]
)
def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
    global server
    server.model_alias = "llama-test-model"
    server.start()
    res = server.make_stream_request("POST", "/chat/completions", data={
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "stream": True,
    })
    content = ""
    last_cmpl_id = None
    for i, data in enumerate(res):
        if data["choices"]:
            choice = data["choices"][0]
            if i == 0:
                # Check first role message for stream=True
                assert choice["delta"]["content"] is None
                assert choice["delta"]["role"] == "assistant"
            else:
                assert "role" not in choice["delta"]
            assert data["system_fingerprint"].startswith("b")
            assert data["model"] == "llama-test-model"
            if last_cmpl_id is None:
                last_cmpl_id = data["id"]
            assert last_cmpl_id == data["id"]  # make sure the completion id is the same for all events in the stream
            if choice["finish_reason"] in ["stop", "length"]:
                assert "content" not in choice["delta"]
                assert match_regex(re_content, content)
                assert choice["finish_reason"] == finish_reason
            else:
                assert choice["finish_reason"] is None
                content += choice["delta"]["content"] or ''
        else:
            assert data["usage"]["prompt_tokens"] == n_prompt
            assert data["usage"]["completion_tokens"] == n_predicted


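# Same endpoint through the official `openai` client to catch wire-format
# regressions; the model name here is only a placeholder.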
def test_chat_completion_with_openai_library():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=8,
        seed=42,
        temperature=0.8,
    )
    assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
    assert res.choices[0].finish_reason == "length"
    assert res.choices[0].message.content is not None
    assert match_regex("(Suddenly)+", res.choices[0].message.content)


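# With server.debug enabled, the response includes a "__verbose" object that
# exposes the rendered prompt, letting us pin the exact llama3 template output.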
def test_chat_template():
    global server
    server.chat_template = "llama3"
    server.debug = True  # to get the "__verbose" object in the response
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": 8,
        "messages": [
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ]
    })
    assert res.status_code == 200
    assert "__verbose" in res.body
    assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"


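# A trailing assistant message acts as a prefill: the rendered prompt must end
# with the prefill text instead of closing the assistant turn.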
@pytest.mark.parametrize("prefill,re_prefill", [
    ("Whill", "Whill"),
    ([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Whill"),
])
def test_chat_template_assistant_prefill(prefill, re_prefill):
    global server
    server.chat_template = "llama3"
    server.debug = True  # to get the "__verbose" object in the response
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": 8,
        "messages": [
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
            {"role": "assistant", "content": prefill},
        ]
    })
    assert res.status_code == 200
    assert "__verbose" in res.body
    assert res.body["__verbose"]["prompt"] == f"<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}"


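# /apply-template renders the prompt from the messages without generating tokens.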
def test_apply_chat_template():
    global server
    server.chat_template = "command-r"
    server.start()
    res = server.make_request("POST", "/apply-template", data={
        "messages": [
            {"role": "system", "content": "You are a test."},
            {"role": "user", "content": "Hi there"},
        ]
    })
    assert res.status_code == 200
    assert "prompt" in res.body
    assert res.body["prompt"] == "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a test.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"


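# Valid response_format constraints must shape the output; malformed or unknown
# formats (marked by n_predicted=0, re_content=None) must be rejected with 400.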
@pytest.mark.parametrize("response_format,n_predicted,re_content", [
    ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
    ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
    ({"type": "json_schema", "json_schema": {"schema": {"const": "foooooo"}}}, 10, "\"foooooo\""),
    ({"type": "json_object"}, 10, "(\\{|John)+"),
    # invalid response formats (expected to fail)
    ({"type": "sound"}, 0, None),
    ({"type": "json_object", "schema": 123}, 0, None),
    ({"type": "json_object", "schema": {"type": 123}}, 0, None),
    ({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
])
def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None):
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": n_predicted,
        "messages": [
            {"role": "system", "content": "You are a coding assistant."},
            {"role": "user", "content": "Write an example"},
        ],
        "response_format": response_format,
    })
    if re_content is not None:
        assert res.status_code == 200
        choice = res.body["choices"][0]
        assert match_regex(re_content, choice["message"]["content"])
    else:
        assert res.status_code == 400
        assert "error" in res.body


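# The top-level "json_schema" field must be honored with and without Jinja.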
@pytest.mark.parametrize("jinja,json_schema,n_predicted,re_content", [
    (False, {"const": "42"}, 6, "\"42\""),
    (True, {"const": "42"}, 6, "\"42\""),
])
def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str):
    global server
    server.jinja = jinja
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": n_predicted,
        "messages": [
            {"role": "system", "content": "You are a coding assistant."},
            {"role": "user", "content": "Write an example"},
        ],
        "json_schema": json_schema,
    })
    assert res.status_code == 200, f'Expected 200, got {res.status_code}'
    choice = res.body["choices"][0]
    assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'


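# Raw GBNF "grammar" field; note that re_content "a{5,5}" is a regex matching
# exactly five "a" characters, not a literal string.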
@pytest.mark.parametrize("jinja,grammar,n_predicted,re_content", [
    (False, 'root ::= "a"{5,5}', 6, "a{5,5}"),
    (True, 'root ::= "a"{5,5}', 6, "a{5,5}"),
])
def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re_content: str):
    global server
    server.jinja = jinja
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": n_predicted,
        "messages": [
            {"role": "user", "content": "Does not matter what I say, does it?"},
        ],
        "grammar": grammar,
    })
    assert res.status_code == 200, res.body
    choice = res.body["choices"][0]
    assert match_regex(re_content, choice["message"]["content"]), choice["message"]["content"]


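# Malformed "messages" payloads must be rejected; depending on where validation
# fails, the server may answer with either 400 or 500.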
@pytest.mark.parametrize("messages", [
    None,
    "string",
    [123],
    [{}],
    [{"role": 123}],
    [{"role": "system", "content": 123}],
    # [{"content": "hello"}], # TODO: should not be a valid case
    [{"role": "system", "content": "test"}, {}],
    [{"role": "user", "content": "test"}, {"role": "assistant", "content": "test"}, {"role": "assistant", "content": "test"}],
])
def test_invalid_chat_completion_req(messages):
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "messages": messages,
    })
    assert res.status_code in (400, 500)
    assert "error" in res.body


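# With "timings_per_token", the stream must include at least one choice-less
# event carrying a "timings" object with throughput and predicted-token counters.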
def test_chat_completion_with_timings_per_token():
    global server
    server.start()
    res = server.make_stream_request("POST", "/chat/completions", data={
        "max_tokens": 10,
        "messages": [{"role": "user", "content": "test"}],
        "stream": True,
        "stream_options": {"include_usage": True},
        "timings_per_token": True,
    })
    stats_received = False
    for i, data in enumerate(res):
        if i == 0:
            # Check first role message for stream=True
            assert data["choices"][0]["delta"]["content"] is None
            assert data["choices"][0]["delta"]["role"] == "assistant"
            assert "timings" not in data, f'First event should not have timings: {data}'
        else:
            if data["choices"]:
                assert "role" not in data["choices"][0]["delta"]
            else:
                assert "timings" in data
                assert "prompt_per_second" in data["timings"]
                assert "predicted_per_second" in data["timings"]
                assert "predicted_n" in data["timings"]
                assert data["timings"]["predicted_n"] <= 10
                stats_received = True
    assert stats_received


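# Per-token logprob entries must concatenate back to the full message content,
# each with a non-positive logprob, raw bytes, and a non-empty top_logprobs list.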
def test_logprobs():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=5,
        logprobs=True,
        top_logprobs=10,
    )
    output_text = res.choices[0].message.content
    aggregated_text = ''
    assert res.choices[0].logprobs is not None
    assert res.choices[0].logprobs.content is not None
    for token in res.choices[0].logprobs.content:
        aggregated_text += token.token
        assert token.logprob <= 0.0
        assert token.bytes is not None
        assert len(token.top_logprobs) > 0
    assert aggregated_text == output_text


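# Same logprob invariants as above, but accumulated across streamed deltas.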
def test_logprobs_stream():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=5,
        logprobs=True,
        top_logprobs=10,
        stream=True,
    )
    output_text = ''
    aggregated_text = ''
    for i, data in enumerate(res):
        if data.choices:
            choice = data.choices[0]
            if i == 0:
                # Check first role message for stream=True
                assert choice.delta.content is None
                assert choice.delta.role == "assistant"
            else:
                assert choice.delta.role is None
                if choice.finish_reason is None:
                    if choice.delta.content:
                        output_text += choice.delta.content
                    assert choice.logprobs is not None
                    assert choice.logprobs.content is not None
                    for token in choice.logprobs.content:
                        aggregated_text += token.token
                        assert token.logprob <= 0.0
                        assert token.bytes is not None
                        assert token.top_logprobs is not None
                        assert len(token.top_logprobs) > 0
    assert aggregated_text == output_text


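# Tokenize a list of common words, bias each resulting token to -100, then
# verify none of those words appear as standalone words in the completion.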
def test_logit_bias():
    global server
    server.start()

    exclude = ["i", "I", "the", "The", "to", "a", "an", "be", "is", "was", "but", "But", "and", "And", "so", "So", "you", "You", "he", "He", "she", "She", "we", "We", "they", "They", "it", "It", "his", "His", "her", "Her", "book", "Book"]

    res = server.make_request("POST", "/tokenize", data={
        "content": " " + " ".join(exclude) + " ",
    })
    assert res.status_code == 200
    tokens = res.body["tokens"]
    logit_bias = {tok: -100 for tok in tokens}

    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=64,
        logit_bias=logit_bias
    )
    output_text = res.choices[0].message.content
    assert output_text
    assert all(" " + tok + " " not in output_text for tok in exclude)


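# An over-long prompt must be rejected with "exceed_context_size_error"; the
# n_ctx reported in the error is the per-slot context (n_ctx // n_slots).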
def test_context_size_exceeded():
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "messages": [
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ] * 100,  # make the prompt too long
    })
    assert res.status_code == 400
    assert "error" in res.body
    assert res.body["error"]["type"] == "exceed_context_size_error"
    assert res.body["error"]["n_prompt_tokens"] > 0
    assert server.n_ctx is not None
    assert server.n_slots is not None
    assert res.body["error"]["n_ctx"] == server.n_ctx // server.n_slots


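# The streamed variant surfaces the same error as a ServerError raised while
# iterating the response.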
def test_context_size_exceeded_stream():
    global server
    server.start()
    try:
        for _ in server.make_stream_request("POST", "/chat/completions", data={
            "messages": [
                {"role": "system", "content": "Book"},
                {"role": "user", "content": "What is the best book"},
            ] * 100,  # make the prompt too long
            "stream": True,
        }):
            pass
        assert False, "Should have failed"
    except ServerError as e:
        assert e.code == 400
        assert "error" in e.body
        assert e.body["error"]["type"] == "exceed_context_size_error"
        assert e.body["error"]["n_prompt_tokens"] > 0
        assert server.n_ctx is not None
        assert server.n_slots is not None
        assert e.body["error"]["n_ctx"] == server.n_ctx // server.n_slots


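# "return_progress" adds prompt_progress reports to streamed events: the total
# stays fixed, processed grows monotonically, and a warm cache must show reused tokens.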
@pytest.mark.parametrize(
    "n_batch,batch_count,reuse_cache",
    [
        (64, 4, False),
        (64, 2, True),
    ]
)
def test_return_progress(n_batch, batch_count, reuse_cache):
    global server
    server.n_batch = n_batch
    server.n_ctx = 256
    server.n_slots = 1
    server.start()

    def make_cmpl_request():
        return server.make_stream_request("POST", "/chat/completions", data={
            "max_tokens": 10,
            "messages": [
                {"role": "user", "content": "This is a test" * 10},
            ],
            "stream": True,
            "return_progress": True,
        })

    if reuse_cache:
        # make a first request to populate the cache
        res0 = make_cmpl_request()
        for _ in res0:
            pass  # discard the output

    res = make_cmpl_request()
    last_progress = None
    total_batch_count = 0

    for data in res:
        cur_progress = data.get("prompt_progress", None)
        if cur_progress is None:
            continue
        if total_batch_count == 0:
            # first progress report must have n_cache == n_processed
            assert cur_progress["total"] > 0
            assert cur_progress["cache"] == cur_progress["processed"]
            if reuse_cache:
                # when reusing cache, we expect some cached tokens
                assert cur_progress["cache"] > 0
        if last_progress is not None:
            assert cur_progress["total"] == last_progress["total"]
            assert cur_progress["cache"] == last_progress["cache"]
            assert cur_progress["processed"] > last_progress["processed"]
        total_batch_count += 1
        last_progress = cur_progress

    # last progress should indicate completion (all tokens processed)
    assert last_progress is not None
    assert last_progress["total"] > 0
    assert last_progress["processed"] == last_progress["total"]
    assert total_batch_count == batch_count


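# "n": 2 asks for two parallel completions of the same prompt in one request.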
def test_chat_completions_multiple_choices():
    global server
    server.start()
    # make sure cache can be reused across multiple choices and multiple requests
    # ref: https://github.com/ggml-org/llama.cpp/pull/18663
    for _ in range(2):
        res = server.make_request("POST", "/chat/completions", data={
            "max_tokens": 8,
            "n": 2,
            "messages": [
                {"role": "system", "content": "Book"},
                {"role": "user", "content": "What is the best book"},
            ],
            # test forcing the same slot to be used
            # the scheduler should not be locked up in this case
            "id_slot": 0,
        })
        assert res.status_code == 200
        assert len(res.body["choices"]) == 2
        for choice in res.body["choices"]:
            assert "assistant" == choice["message"]["role"]
            assert choice["finish_reason"] == "length"