import pytest
from openai import OpenAI
from utils import *

server: ServerProcess

@pytest.fixture(autouse=True)
def create_server():
    global server
    # create a fresh server preset before every test; each test configures it and calls start() itself
    server = ServerPreset.tinyllama2()


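# Each tuple below is one test case: re_content is a regex the response content must match,
# n_prompt/n_predicted are the expected usage token counts, and the jinja/chat_template
# columns control how the server renders the prompt.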
@pytest.mark.parametrize(
    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template",
    [
        (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", False, None),
        (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", True, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", False, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'),
        (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", False, None),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", True, None),
        (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
        (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
    ]
)
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
    global server
    server.jinja = jinja
    server.chat_template = chat_template
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "model": model,
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    })
    assert res.status_code == 200
    assert "cmpl" in res.body["id"]  # make sure the completion id has the expected format
    assert res.body["system_fingerprint"].startswith("b")
    # we no longer reflect back the model name, see https://github.com/ggml-org/llama.cpp/pull/17668
    # assert res.body["model"] == model if model is not None else server.model_alias
    assert res.body["usage"]["prompt_tokens"] == n_prompt
    assert res.body["usage"]["completion_tokens"] == n_predicted
    choice = res.body["choices"][0]
    assert "assistant" == choice["message"]["role"]
    assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
    assert choice["finish_reason"] == finish_reason


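# Streaming variant: the first delta carries only the role, later deltas carry content,
# the completion id stays the same across all events, and the final usage-only event
# reports the prompt/completion token counts.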
@pytest.mark.parametrize(
    "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
    [
        ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
        ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length"),
    ]
)
def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
    global server
    server.model_alias = "llama-test-model"
    server.start()
    res = server.make_stream_request("POST", "/chat/completions", data={
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "stream": True,
    })
    content = ""
    last_cmpl_id = None
    for i, data in enumerate(res):
        if data["choices"]:
            choice = data["choices"][0]
            if i == 0:
                # Check first role message for stream=True
                assert choice["delta"]["content"] is None
                assert choice["delta"]["role"] == "assistant"
            else:
                assert "role" not in choice["delta"]
            assert data["system_fingerprint"].startswith("b")
            assert data["model"] == "llama-test-model"
            if last_cmpl_id is None:
                last_cmpl_id = data["id"]
            assert last_cmpl_id == data["id"]  # make sure the completion id is the same for all events in the stream
            if choice["finish_reason"] in ["stop", "length"]:
                assert "content" not in choice["delta"]
                assert match_regex(re_content, content)
                assert choice["finish_reason"] == finish_reason
            else:
                assert choice["finish_reason"] is None
                content += choice["delta"]["content"] or ''
        else:
            assert data["usage"]["prompt_tokens"] == n_prompt
            assert data["usage"]["completion_tokens"] == n_predicted


def test_chat_completion_with_openai_library():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=8,
        seed=42,
        temperature=0.8,
    )
    assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
    assert res.choices[0].finish_reason == "length"
    assert res.choices[0].message.content is not None
    assert match_regex("(Suddenly)+", res.choices[0].message.content)


def test_chat_template():
    global server
    server.chat_template = "llama3"
    server.debug = True  # to get the "__verbose" object in the response
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": 8,
        "messages": [
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ]
    })
    assert res.status_code == 200
    assert "__verbose" in res.body
    assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"


@pytest.mark.parametrize("prefill,re_prefill", [
    ("Whill", "Whill"),
    ([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Whill"),
])
def test_chat_template_assistant_prefill(prefill, re_prefill):
    global server
    server.chat_template = "llama3"
    server.debug = True  # to get the "__verbose" object in the response
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": 8,
        "messages": [
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
            {"role": "assistant", "content": prefill},
        ]
    })
    assert res.status_code == 200
    assert "__verbose" in res.body
    assert res.body["__verbose"]["prompt"] == f"<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}"


def test_apply_chat_template():
    global server
    server.chat_template = "command-r"
    server.start()
    res = server.make_request("POST", "/apply-template", data={
        "messages": [
            {"role": "system", "content": "You are a test."},
            {"role": "user", "content": "Hi there"},
        ]
    })
    assert res.status_code == 200
    assert "prompt" in res.body
    assert res.body["prompt"] == "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a test.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"


@pytest.mark.parametrize("response_format,n_predicted,re_content", [
    ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
    ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
    ({"type": "json_schema", "json_schema": {"schema": {"const": "foooooo"}}}, 10, "\"foooooo\""),
    ({"type": "json_object"}, 10, "(\\{|John)+"),
    # invalid response formats (expected to fail with HTTP 400)
    ({"type": "sound"}, 0, None),
    ({"type": "json_object", "schema": 123}, 0, None),
    ({"type": "json_object", "schema": {"type": 123}}, 0, None),
    ({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
])
def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None):
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": n_predicted,
        "messages": [
            {"role": "system", "content": "You are a coding assistant."},
            {"role": "user", "content": "Write an example"},
        ],
        "response_format": response_format,
    })
    if re_content is not None:
        assert res.status_code == 200
        choice = res.body["choices"][0]
        assert match_regex(re_content, choice["message"]["content"])
    else:
        assert res.status_code == 400
        assert "error" in res.body


@pytest.mark.parametrize("jinja,json_schema,n_predicted,re_content", [
    (False, {"const": "42"}, 6, "\"42\""),
    (True, {"const": "42"}, 6, "\"42\""),
])
def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str):
    global server
    server.jinja = jinja
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": n_predicted,
        "messages": [
            {"role": "system", "content": "You are a coding assistant."},
            {"role": "user", "content": "Write an example"},
        ],
        "json_schema": json_schema,
    })
    assert res.status_code == 200, f'Expected 200, got {res.status_code}'
    choice = res.body["choices"][0]
    assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'


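# The "grammar" parameter takes a GBNF grammar; 'root ::= "a"{5,5}' constrains the output
# to exactly five "a" characters, which the a{5,5} regex then verifies.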
@pytest.mark.parametrize("jinja,grammar,n_predicted,re_content", [
    (False, 'root ::= "a"{5,5}', 6, "a{5,5}"),
    (True, 'root ::= "a"{5,5}', 6, "a{5,5}"),
])
def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re_content: str):
    global server
    server.jinja = jinja
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": n_predicted,
        "messages": [
            {"role": "user", "content": "Does not matter what I say, does it?"},
        ],
        "grammar": grammar,
    })
    assert res.status_code == 200, res.body
    choice = res.body["choices"][0]
    assert match_regex(re_content, choice["message"]["content"]), choice["message"]["content"]


@pytest.mark.parametrize("messages", [
    None,
    "string",
    [123],
    [{}],
    [{"role": 123}],
    [{"role": "system", "content": 123}],
    # [{"content": "hello"}], # TODO: should not be a valid case
    [{"role": "system", "content": "test"}, {}],
    [{"role": "user", "content": "test"}, {"role": "assistant", "content": "test"}, {"role": "assistant", "content": "test"}],
])
def test_invalid_chat_completion_req(messages):
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "messages": messages,
    })
    assert res.status_code in (400, 500)
    assert "error" in res.body


def test_chat_completion_with_timings_per_token():
    global server
    server.start()
    res = server.make_stream_request("POST", "/chat/completions", data={
        "max_tokens": 10,
        "messages": [{"role": "user", "content": "test"}],
        "stream": True,
        "stream_options": {"include_usage": True},
        "timings_per_token": True,
    })
    stats_received = False
    for i, data in enumerate(res):
        if i == 0:
            # Check first role message for stream=True
            assert data["choices"][0]["delta"]["content"] is None
            assert data["choices"][0]["delta"]["role"] == "assistant"
            assert "timings" not in data, f'First event should not have timings: {data}'
        else:
            if data["choices"]:
                assert "role" not in data["choices"][0]["delta"]
            else:
                assert "timings" in data
                assert "prompt_per_second" in data["timings"]
                assert "predicted_per_second" in data["timings"]
                assert "predicted_n" in data["timings"]
                assert data["timings"]["predicted_n"] <= 10
                stats_received = True
    assert stats_received


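# Logprobs via the OpenAI client: every returned token must carry a non-positive logprob,
# raw bytes, and a non-empty top_logprobs list, and the concatenated tokens must equal
# the final message content.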
def test_logprobs():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=5,
        logprobs=True,
        top_logprobs=10,
    )
    output_text = res.choices[0].message.content
    aggregated_text = ''
    assert res.choices[0].logprobs is not None
    assert res.choices[0].logprobs.content is not None
    for token in res.choices[0].logprobs.content:
        aggregated_text += token.token
        assert token.logprob <= 0.0
        assert token.bytes is not None
        assert len(token.top_logprobs) > 0
    assert aggregated_text == output_text


def test_logprobs_stream():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=5,
        logprobs=True,
        top_logprobs=10,
        stream=True,
    )
    output_text = ''
    aggregated_text = ''
    for i, data in enumerate(res):
        if data.choices:
            choice = data.choices[0]
            if i == 0:
                # Check first role message for stream=True
                assert choice.delta.content is None
                assert choice.delta.role == "assistant"
            else:
                assert choice.delta.role is None
                if choice.finish_reason is None:
                    if choice.delta.content:
                        output_text += choice.delta.content
                    assert choice.logprobs is not None
                    assert choice.logprobs.content is not None
                    for token in choice.logprobs.content:
                        aggregated_text += token.token
                        assert token.logprob <= 0.0
                        assert token.bytes is not None
                        assert token.top_logprobs is not None
                        assert len(token.top_logprobs) > 0
    assert aggregated_text == output_text


def test_logit_bias():
    global server
    server.start()

    exclude = ["i", "I", "the", "The", "to", "a", "an", "be", "is", "was", "but", "But", "and", "And", "so", "So", "you", "You", "he", "He", "she", "She", "we", "We", "they", "They", "it", "It", "his", "His", "her", "Her", "book", "Book"]

    res = server.make_request("POST", "/tokenize", data={
        "content": " " + " ".join(exclude) + " ",
    })
    assert res.status_code == 200
    tokens = res.body["tokens"]
    logit_bias = {tok: -100 for tok in tokens}

    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=64,
        logit_bias=logit_bias
    )
    output_text = res.choices[0].message.content
    assert output_text
    assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude)


def test_context_size_exceeded():
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "messages": [
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ] * 100,  # make the prompt too long
    })
    assert res.status_code == 400
    assert "error" in res.body
    assert res.body["error"]["type"] == "exceed_context_size_error"
    assert res.body["error"]["n_prompt_tokens"] > 0
    assert server.n_ctx is not None
    assert server.n_slots is not None
    assert res.body["error"]["n_ctx"] == server.n_ctx // server.n_slots


def test_context_size_exceeded_stream():
    global server
    server.start()
    try:
        for _ in server.make_stream_request("POST", "/chat/completions", data={
            "messages": [
                {"role": "system", "content": "Book"},
                {"role": "user", "content": "What is the best book"},
            ] * 100,  # make the prompt too long
            "stream": True}):
            pass
        assert False, "Should have failed"
    except ServerError as e:
        assert e.code == 400
        assert "error" in e.body
        assert e.body["error"]["type"] == "exceed_context_size_error"
        assert e.body["error"]["n_prompt_tokens"] > 0
        assert server.n_ctx is not None
        assert server.n_slots is not None
        assert e.body["error"]["n_ctx"] == server.n_ctx // server.n_slots


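# With n_slots=1 and a small n_batch, the prompt below is processed in several batches;
# each batch emits a "prompt_progress" object whose "processed" count must grow while
# "total" and "cache" stay fixed. The reuse_cache case first warms the prompt cache, so
# the second request should report cached tokens.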
@pytest.mark.parametrize(
    "n_batch,batch_count,reuse_cache",
    [
        (64, 4, False),
        (64, 2, True),
    ]
)
def test_return_progress(n_batch, batch_count, reuse_cache):
    global server
    server.n_batch = n_batch
    server.n_ctx = 256
    server.n_slots = 1
    server.start()
    def make_cmpl_request():
        return server.make_stream_request("POST", "/chat/completions", data={
            "max_tokens": 10,
            "messages": [
                {"role": "user", "content": "This is a test" * 10},
            ],
            "stream": True,
            "return_progress": True,
        })
    if reuse_cache:
        # make a first request to populate the cache
        res0 = make_cmpl_request()
        for _ in res0:
            pass  # discard the output

    res = make_cmpl_request()
    last_progress = None
    total_batch_count = 0

    for data in res:
        cur_progress = data.get("prompt_progress", None)
        if cur_progress is None:
            continue
        if total_batch_count == 0:
            # first progress report must have n_cache == n_processed
            assert cur_progress["total"] > 0
            assert cur_progress["cache"] == cur_progress["processed"]
            if reuse_cache:
                # when reusing cache, we expect some cached tokens
                assert cur_progress["cache"] > 0
        if last_progress is not None:
            assert cur_progress["total"] == last_progress["total"]
            assert cur_progress["cache"] == last_progress["cache"]
            assert cur_progress["processed"] > last_progress["processed"]
        total_batch_count += 1
        last_progress = cur_progress

    # last progress should indicate completion (all tokens processed)
    assert last_progress is not None
    assert last_progress["total"] > 0
    assert last_progress["processed"] == last_progress["total"]
    assert total_batch_count == batch_count


def test_chat_completions_multiple_choices():
    global server
    server.start()
    # make sure cache can be reused across multiple choices and multiple requests
    # ref: https://github.com/ggml-org/llama.cpp/pull/18663
    for _ in range(2):
        res = server.make_request("POST", "/chat/completions", data={
            "max_tokens": 8,
            "n": 2,
            "messages": [
                {"role": "system", "content": "Book"},
                {"role": "user", "content": "What is the best book"},
            ],
            # test forcing the same slot to be used
            # the scheduler should not be locked up in this case
            "id_slot": 0,
        })
        assert res.status_code == 200
        assert len(res.body["choices"]) == 2
        for choice in res.body["choices"]:
            assert "assistant" == choice["message"]["role"]
            assert choice["finish_reason"] == "length"