#!/usr/bin/env uv run
'''
    Simplistic tool call benchmarks for llama-server and ollama.

    Essentially runs the tests at ./tools/server/tests/unit/test_tool_call.py N times, at different temperatures and against different backends
    (current llama-server, baseline llama-server, and ollama), and plots the results of the runs (from the same .jsonl file or several) as a success-rate heatmap.

    Simple usage example:

        cmake -B build && cmake --build build --config Release -j -t llama-server

        export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server
        export LLAMA_CACHE=${LLAMA_CACHE:-$HOME/Library/Caches/llama.cpp}

        ./scripts/tool_bench.py run --n 10 --temp -1 --temp 0 --temp 1 --temp 2 --temp 5 --llama-baseline $PWD/buildMaster/bin/llama-server --output qwen14b.jsonl --hf bartowski/Qwen2.5-14B-Instruct-GGUF:Q4_K_L
        ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 1.5B Q4_K_M" --output qwen1.5b.jsonl --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF --ollama qwen2.5:1.5b-instruct-q4_K_M
        ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 Coder 7B Q4_K_M" --output qwenc7b.jsonl --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF --ollama qwen2.5-coder:7b

        ./scripts/tool_bench.py plot *.jsonl                         # Opens a window w/ the heatmap
        ./scripts/tool_bench.py plot qwen*.jsonl --output qwen.png   # Saves the heatmap to qwen.png

    (please see ./scripts/tool_bench.sh for a more complete example)
'''
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "pytest",
#     "pandas",
#     "matplotlib",
#     "seaborn",
#     "requests",
#     "wget",
#     "typer",
# ]
# ///
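# The `# /// script` block above is PEP 723 inline metadata; `uv run` (invoked
# via the shebang) reads it to provision the script's dependencies automatically.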
from contextlib import contextmanager
from pathlib import Path
import re
from statistics import mean, median
from typing import Annotated, Dict, List, Optional, Tuple
import atexit
import json
import logging
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import subprocess
import sys
import time
import typer

sys.path.insert(0, Path(__file__).parent.parent.as_posix())
# The `if True:` guard keeps import sorters/formatters from hoisting these
# imports above the sys.path tweak they depend on.
if True:
    from tools.server.tests.utils import ServerProcess
    from tools.server.tests.unit.test_tool_call import do_test_calc_result, do_test_hello_world, do_test_weather


@contextmanager
def scoped_server(sp: ServerProcess):
    """Yield the server and guarantee it is stopped on scope exit or interpreter exit."""
    def stop():
        nonlocal sp
        if sp is not None:
            sp.stop()
            sp = None  # type: ignore
    atexit.register(stop)
    yield sp
    stop()
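# (Used below as `with scoped_server(server): ...` so that each benchmarked
# server is torn down even if a test run aborts part-way.)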

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

app = typer.Typer()

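# Each `run` invocation appends one JSON object per test configuration to the
# output .jsonl file, which `plot` consumes. An illustrative record (values are
# made up; field names match the json.dumps call in `run` below):
#
#   {"chat_template": null, "chat_template_file": null,
#    "model": "Qwen 2.5 1.5B Q4_K_M", "server_name": "llama-server",
#    "model_id": "bartowski/Qwen2.5-1.5B-Instruct-GGUF", "test": "weather",
#    "temp": 0.0, "top_p": null, "top_k": null, "ctk": null, "ctv": null,
#    "seed": null, "success_ratio": 0.9, "avg_time": 1.84, "median_time": 1.73,
#    "success_count": 9, "success_times": [1.7, 1.9], "failure_count": 1,
#    "failure_times": [2.1], "failures": ["tool call mismatch"]}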
@app.command()
def plot(files: List[Path], output: Optional[Path] = None, test_regex: Optional[str] = None, server_regex: Optional[str] = None):

    lines: List[Dict] = []
    for file in files:
        if not file.exists():
            logger.error(f"File not found: {file}")
            continue

        try:
            with file.open() as f:
                raw_data = f.read()
            logger.info(f"Reading {file} ({len(raw_data)} bytes)")

            for line_num, line in enumerate(raw_data.split('\n'), 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                    lines.append(record)
                except json.JSONDecodeError as e:
                    logger.warning(f"Invalid JSON at {file}:{line_num} - {e}")
        except Exception as e:
            logger.error(f"Error processing {file}: {e}")

    if not lines:
        raise Exception("No valid data was loaded")

    data_dict: Dict[Tuple, float] = {}
    models: List[str] = []
    temps = set()
    tests = set()
    server_names = set()
    total_counts = set()
    for rec in lines:
        try:
            model = rec["model"]
            temp = rec["temp"]
            server_name = rec["server_name"]
            test = rec["test"]
            success = rec["success_ratio"]
            success_count = rec["success_count"]
            failure_count = rec["failure_count"]
            total_count = success_count + failure_count
            total_counts.add(total_count)

            if test_regex and not re.search(test_regex, test):
                continue

            if server_regex and not re.search(server_regex, server_name):
                continue

            data_dict[(model, temp, server_name, test)] = success

            if model not in models:
                models.append(model)
            temps.add(temp)
            tests.add(test)
            server_names.add(server_name)

        except KeyError as e:
            logger.warning(f"Missing required field in record: {e}")

    if len(total_counts) > 1:
        logger.warning(f"Total counts are not consistent: {total_counts}")

    # Sort the collected values
    temps = list(sorted(temps, key=lambda x: x if x is not None else -1))
    tests = list(sorted(tests))
    server_names = list(sorted(server_names))

    logger.info(f"Processed {len(lines)} lines")
    logger.info(f"Found {len(data_dict)} valid data points")
    logger.info(f"Models: {models}")
    logger.info(f"Temperatures: {temps}")
    logger.info(f"Tests: {tests}")
    logger.info(f"Servers: {server_names}")

    matrix: list[list[float]] = []
    index: list[str] = []

    all_cols = [
        (server_name, test)
        for server_name in server_names
        for test in tests
    ]
    for model in models:
        for temp in temps:
            index.append(f"{model} @ {temp}")
            row_vals = [
                data_dict.get((model, temp, server_name, test), np.nan)
                for server_name, test in all_cols
            ]
            matrix.append(row_vals)

    columns: list[str] = [f"{server_name}\n{test}" for server_name, test in all_cols]

    df = pd.DataFrame(matrix, index=np.array(index), columns=np.array(columns))

    plt.figure(figsize=(12, 6))

    sns.heatmap(
        df, annot=True, cmap="RdYlGn", vmin=0.0, vmax=1.0, cbar=True, fmt=".2f", center=0.5, square=True, linewidths=0.5,
        cbar_kws={"label": "Success Ratio"},
    )

    plt.title(f"Tool Call Bench (n = {str(min(total_counts)) if len(total_counts) == 1 else f'{min(total_counts)}-{max(total_counts)}'})\nSuccess Ratios by Server & Test", pad=20)
    plt.xlabel("Server & Test", labelpad=10)
    plt.ylabel("Model @ Temperature", labelpad=10)

    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)

    plt.tight_layout()

    if output:
        plt.savefig(output, dpi=300, bbox_inches='tight')
        logger.info(f"Plot saved to {output}")
    else:
        plt.show()


@app.command()
def run(
    output: Annotated[Path, typer.Option(help="Output JSONL file")],
    model: Annotated[Optional[str], typer.Option(help="Name of the model to test (server agnostic)")] = None,
    hf: Annotated[Optional[str], typer.Option(help="GGUF huggingface model repo id (+ optional quant) to test w/ llama-server")] = None,
    chat_template: Annotated[Optional[str], typer.Option(help="Chat template override for llama-server")] = None,
    chat_template_file: Annotated[Optional[str], typer.Option(help="Chat template file override for llama-server")] = None,
    ollama: Annotated[Optional[str], typer.Option(help="Ollama model tag to test")] = None,
    llama_baseline: Annotated[Optional[str], typer.Option(help="Path to a baseline llama-server binary to compare against")] = None,
    n: Annotated[int, typer.Option(help="Number of times to run each test")] = 10,
    temp: Annotated[Optional[List[float]], typer.Option(help="Set of temperatures to test (a negative value means the server default)")] = None,
    top_p: Annotated[Optional[float], typer.Option(help="Sampling top_p")] = None,
    top_k: Annotated[Optional[int], typer.Option(help="Sampling top_k")] = None,
    ctk: Annotated[Optional[str], typer.Option(help="KV cache type for K (--cache-type-k)")] = None,
    ctv: Annotated[Optional[str], typer.Option(help="KV cache type for V (--cache-type-v)")] = None,
    fa: Annotated[Optional[bool], typer.Option(help="Enable flash attention")] = None,
    seed: Annotated[Optional[int], typer.Option(help="Random seed")] = None,
    port: Annotated[int, typer.Option(help="llama-server port")] = 8084,
    force: Annotated[bool, typer.Option(help="Force overwrite of output file")] = False,
    append: Annotated[bool, typer.Option(help="Append to output file")] = False,

    test_hello_world: Annotated[bool, typer.Option(help="Whether to run the hello world test")] = True,
    test_weather: Annotated[bool, typer.Option(help="Whether to run the weather test")] = True,
    test_calc_result: Annotated[bool, typer.Option(help="Whether to run the calc result test")] = False,
):
    n_predict = 512  # High because of DeepSeek R1
    # n_ctx = 8192
    n_ctx = 2048

    if model is None:
        if hf is not None:
            model = hf.split("/")[-1]
        elif ollama is not None:
            model = ollama
        else:
            raise ValueError("--model is required if neither --hf nor --ollama is provided")

    # Refuse to clobber an existing output file unless --force or --append is given.
    assert force or append or not output.exists(), f"Output file already exists: {output}; use --force to overwrite or --append to add to it"

    with output.open('a' if append else 'w') as output_file:

        def run(server: ServerProcess, *, server_name: str, model_id: str, temp: Optional[float] = None, output_kwargs={}, request_kwargs={}):
            request_kwargs = {**request_kwargs}
            if temp is not None:
                request_kwargs['temperature'] = temp
            if top_p is not None:
                request_kwargs['top_p'] = top_p
            if top_k is not None:
                request_kwargs['top_k'] = top_k
            if seed is not None:
                request_kwargs['seed'] = seed

            request_kwargs['cache_prompt'] = False

            tests = {}
            if test_hello_world:
                tests["hello world"] = lambda server: do_test_hello_world(server, **request_kwargs)
            if test_weather:
                tests["weather"] = lambda server: do_test_weather(server, **request_kwargs)
            if test_calc_result:
                tests["calc result"] = lambda server: do_test_calc_result(server, None, 512, **request_kwargs)

            for test_name, test in tests.items():
                success_count = 0
                failure_count = 0
                failures = []
                success_times = []
                failure_times = []
                logger.info(f"Running {test_name} ({server_name}, {model}): ")
                for i in range(n):
                    start_time = time.time()

                    def elapsed():
                        return time.time() - start_time

                    try:
                        test(server)
                        success_times.append(elapsed())
                        success_count += 1
                        logger.info('success')
                    except Exception as e:
                        logger.error(f'failure: {e}')
                        failure_count += 1
                        failure_times.append(elapsed())
                        failures.append(str(e))
                        # import traceback
                        # traceback.print_exc()

                output_file.write(json.dumps({**output_kwargs, **dict(
                    model=model,
                    server_name=server_name,
                    model_id=model_id,
                    test=test_name,
                    temp=temp,
                    top_p=top_p,
                    top_k=top_k,
                    ctk=ctk,
                    ctv=ctv,
                    seed=seed,
                    success_ratio=float(success_count) / n,
                    avg_time=mean(success_times + failure_times),
                    median_time=median(success_times + failure_times),
                    success_count=success_count,
                    success_times=success_times,
                    failure_count=failure_count,
                    failure_times=failure_times,
                    failures=list(set(failures)),
                )}) + '\n')
                output_file.flush()

        # A negative --temp maps to None, i.e. "use the server's default temperature".
        for t in [None] if temp is None else [t if t >= 0 else None for t in temp]:
            if hf is not None:
                servers: list[Tuple[str, Optional[str]]] = [('llama-server', None)]
                if llama_baseline is not None:
                    servers.append(('llama-server (baseline)', llama_baseline))

                for server_name, server_path in servers:
                    server = ServerProcess()
                    server.n_ctx = n_ctx
                    server.n_slots = 1
                    server.jinja = True
                    server.ctk = ctk
                    server.ctv = ctv
                    server.fa = "on" if fa else "off"
                    server.n_predict = n_predict
                    server.model_hf_repo = hf
                    server.model_hf_file = None
                    server.chat_template = chat_template
                    server.chat_template_file = chat_template_file
                    server.server_path = server_path
                    if port is not None:
                        server.server_port = port
                    # server.debug = True

                    with scoped_server(server):
                        server.start(timeout_seconds=15 * 60)
                        for ignore_chat_grammar in [False]:
                            run(
                                server,
                                server_name=server_name,
                                model_id=hf,
                                temp=t,
                                output_kwargs=dict(
                                    chat_template=chat_template,
                                    chat_template_file=chat_template_file,
                                ),
                                request_kwargs=dict(
                                    ignore_chat_grammar=ignore_chat_grammar,
                                ),
                            )

            if ollama is not None:
                # Assumes an ollama daemon is already listening on localhost:11434;
                # the ServerProcess is only used as a client handle here, never started.
                server = ServerProcess()
                server.server_port = 11434
                server.server_host = "localhost"
                subprocess.check_call(["ollama", "pull", ollama])

                with scoped_server(server):
                    run(
                        server,
                        server_name="ollama",
                        model_id=ollama,
                        temp=t,
                        output_kwargs=dict(
                            chat_template=None,
                            chat_template_file=None,
                        ),
                        request_kwargs=dict(
                            model=ollama,
                            max_tokens=n_predict,
                            num_ctx=n_ctx,
                        ),
                    )


if __name__ == "__main__":
    app()