#!/usr/bin/env python
import pytest

# ensure grandparent path is in sys.path
from pathlib import Path
import sys

path = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(path))

# NOTE: this import must come AFTER the sys.path.insert above — the `unit`
# package (and `utils` below) are resolved relative to the grandparent
# directory, so importing it earlier fails when the tests are run standalone.
from unit.test_tool_call import TEST_TOOL

import datetime
from utils import *

# Shared handle to the server under test; re-initialized for every test by
# the autouse `create_server` fixture below.
server: ServerProcess
16
@pytest.fixture(autouse=True)
def create_server():
    """Reset the module-level server to a fresh tinyllama-2 preset before each test."""
    global server
    server = ServerPreset.tinyllama2()
    # one slot is sufficient for these template-rendering checks
    server.n_slots = 1
    server.model_alias = "tinyllama-2"
23
24
@pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]])
@pytest.mark.parametrize("template_name,reasoning_budget,expected_end", [
    ("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B", None, "<think>\n"),
    ("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B", -1, "<think>\n"),
    ("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B", 0, "<think>\n</think>"),

    ("Qwen-Qwen3-0.6B", -1, "<|im_start|>assistant\n"),
    ("Qwen-Qwen3-0.6B", 0, "<|im_start|>assistant\n<think>\n\n</think>\n\n"),

    ("Qwen-QwQ-32B", -1, "<|im_start|>assistant\n<think>\n"),
    ("Qwen-QwQ-32B", 0, "<|im_start|>assistant\n<think>\n</think>"),

    ("CohereForAI-c4ai-command-r7b-12-2024-tool_use", -1, "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"),
    ("CohereForAI-c4ai-command-r7b-12-2024-tool_use", 0, "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|><|END_THINKING|>"),
])
def test_reasoning_budget(template_name: str, reasoning_budget: int | None, expected_end: str, tools: list[dict]):
    """The rendered prompt must end with the template-specific suffix for each reasoning budget."""
    global server
    server.jinja = True
    server.reasoning_budget = reasoning_budget
    server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
    server.start()

    payload = {
        "messages": [{"role": "user", "content": "What is today?"}],
        "tools": tools,
    }
    res = server.make_request("POST", "/apply-template", data=payload)
    assert res.status_code == 200

    prompt = res.body["prompt"]
    assert prompt.endswith(expected_end), f"Expected prompt to end with '{expected_end}', got '{prompt}'"
57
58
@pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]])
@pytest.mark.parametrize("template_name,date_format", [
    ("meta-llama-Llama-3.3-70B-Instruct", "%d %b %Y"),
    ("fireworks-ai-llama-3-firefunction-v2", "%b %d %Y"),
])
def test_date_inside_prompt(template_name: str, date_format: str, tools: list[dict]):
    """The rendered prompt should contain today's date in the template's own strftime format.

    Note: the parameter was renamed from `format` to `date_format` to stop
    shadowing the builtin `format`; the parametrize argnames string is renamed
    in lockstep so pytest resolves it identically.
    """
    global server
    server.jinja = True
    server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
    server.start()

    res = server.make_request("POST", "/apply-template", data={
        "messages": [
            {"role": "user", "content": "What is today?"},
        ],
        "tools": tools,
    })
    assert res.status_code == 200
    prompt = res.body["prompt"]

    today_str = datetime.date.today().strftime(date_format)
    assert today_str in prompt, f"Expected today's date ({today_str}) in content ({prompt})"
81
82
@pytest.mark.parametrize("add_generation_prompt", [False, True])
@pytest.mark.parametrize("template_name,expected_generation_prompt", [
    ("meta-llama-Llama-3.3-70B-Instruct", "<|start_header_id|>assistant<|end_header_id|>"),
])
def test_add_generation_prompt(template_name: str, expected_generation_prompt: str, add_generation_prompt: bool):
    """The assistant header appears in the prompt iff add_generation_prompt was requested."""
    global server
    server.jinja = True
    server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
    server.start()

    payload = {
        "messages": [{"role": "user", "content": "What is today?"}],
        "add_generation_prompt": add_generation_prompt,
    }
    res = server.make_request("POST", "/apply-template", data=payload)
    assert res.status_code == 200
    prompt = res.body["prompt"]

    # presence of the generation header must exactly track the request flag
    header_present = expected_generation_prompt in prompt
    if add_generation_prompt:
        assert header_present, f"Expected generation prompt ({expected_generation_prompt}) in content ({prompt})"
    else:
        assert not header_present, f"Did not expect generation prompt ({expected_generation_prompt}) in content ({prompt})"