#!/usr/bin/env uv run
'''
    Simplistic tool call benchmarks for llama-server and ollama.

    Essentially runs the tests at tools/server/tests/unit/test_tool_call.py N times, at different temperatures and
    on different backends (current llama-server, baseline llama-server, and ollama), then plots the results of one
    or more runs (from a single .jsonl file or several) as a success-rate heatmap.

    Simple usage example:

        cmake -B build && cmake --build build --config Release -j -t llama-server

        export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server
        export LLAMA_CACHE=${LLAMA_CACHE:-$HOME/Library/Caches/llama.cpp}

        ./scripts/tool_bench.py run --n 10 --temp -1 --temp 0 --temp 1 --temp 2 --temp 5 --llama-baseline $PWD/buildMaster/bin/llama-server --output qwen14b.jsonl --hf bartowski/Qwen2.5-14B-Instruct-GGUF:Q4_K_L
        ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 1.5B Q4_K_M"      --output qwen1.5b.jsonl  --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF      --ollama qwen2.5:1.5b-instruct-q4_K_M
        ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 Coder 7B Q4_K_M"  --output qwenc7b.jsonl   --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF  --ollama qwen2.5-coder:7b

        ./scripts/tool_bench.py plot *.jsonl                         # Opens window w/ heatmap
        ./scripts/tool_bench.py plot qwen*.jsonl  --output qwen.png  # Saves heatmap to qwen.png

    (please see ./scripts/tool_bench.sh for a more complete example)
'''
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "pytest",
#     "pandas",
#     "matplotlib",
#     "seaborn",
#     "requests",
#     "wget",
#     "typer",
# ]
# ///
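# (The PEP 723 metadata above lets the `uv run` shebang fetch these dependencies
# automatically into an ephemeral environment.)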
from contextlib import contextmanager
from pathlib import Path
import re
from statistics import mean, median
from typing import Annotated, Dict, List, Optional, Tuple
import atexit
import json
import logging
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import subprocess
import sys
import time
import typer

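# Make the repo root importable so the llama-server test helpers resolve; the
# `if True:` block presumably keeps formatters/linters from hoisting these
# imports above the sys.path tweak.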
sys.path.insert(0, Path(__file__).parent.parent.as_posix())
if True:
    from tools.server.tests.utils import ServerProcess
    from tools.server.tests.unit.test_tool_call import do_test_calc_result, do_test_hello_world, do_test_weather


@contextmanager
def scoped_server(sp: ServerProcess):
    """Stop the wrapped server when the scope exits; atexit is registered as a safety net."""
    def stop():
        nonlocal sp
        if sp is not None:
            sp.stop()
            sp = None  # type: ignore
    atexit.register(stop)
    try:
        yield sp
    finally:
        stop()
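
# Typical usage (mirrors the `run` command below):
#
#     with scoped_server(server):
#         server.start(timeout_seconds=15 * 60)
#         ...  # the server is stopped when the block exits, even on error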


logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

app = typer.Typer()


@app.command()
def plot(files: List[Path], output: Optional[Path] = None, test_regex: Optional[str] = None, server_regex: Optional[str] = None):

    lines: List[Dict] = []
    for file in files:
        if not file.exists():
            logger.error(f"File not found: {file}")
            continue

        try:
            with file.open() as f:
                raw_data = f.read()
            logger.info(f"Reading {file} ({len(raw_data)} bytes)")

            for line_num, line in enumerate(raw_data.split('\n'), 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                    lines.append(record)
                except json.JSONDecodeError as e:
                    logger.warning(f"Invalid JSON at {file}:{line_num} - {e}")
        except Exception as e:
            logger.error(f"Error processing {file}: {e}")

    if not lines:
        raise Exception("No valid data was loaded")

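    # Each record is one JSONL line as written by the `run` command below, e.g.:
    #   {"model": ..., "temp": ..., "server_name": ..., "test": ...,
    #    "success_ratio": ..., "success_count": ..., "failure_count": ..., ...}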
    data_dict: Dict[Tuple, float] = {}
    models: List[str] = []
    temps = set()
    tests = set()
    server_names = set()
    total_counts = set()
    for rec in lines:
        try:
            model = rec["model"]
            temp = rec["temp"]
            server_name = rec["server_name"]
            test = rec["test"]
            success = rec["success_ratio"]
            success_count = rec["success_count"]
            failure_count = rec["failure_count"]
            total_count = success_count + failure_count
            total_counts.add(total_count)

            if test_regex and not re.search(test_regex, test):
                continue

            if server_regex and not re.search(server_regex, server_name):
                continue

            data_dict[(model, temp, server_name, test)] = success

            if model not in models:
                models.append(model)
            temps.add(temp)
            tests.add(test)
            server_names.add(server_name)

        except KeyError as e:
            logger.warning(f"Missing required field in record: {e}")

    if len(total_counts) > 1:
        logger.warning(f"Total counts are not consistent: {total_counts}")

    # Sort the collected values
    temps = sorted(temps, key=lambda x: x if x is not None else -1)
    tests = sorted(tests)
    server_names = sorted(server_names)

    logger.info(f"Processed {len(lines)} lines")
    logger.info(f"Found {len(data_dict)} valid data points")
    logger.info(f"Models: {models}")
    logger.info(f"Temperatures: {temps}")
    logger.info(f"Tests: {tests}")
    logger.info(f"Servers: {server_names}")

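    # Build the heatmap grid: one row per (model, temperature) pair, one column
    # per (server, test) pair; cells with no matching record stay NaN.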
    matrix: list[list[float]] = []
    index: list[str] = []

    all_cols = [
        (server_name, test)
        for server_name in server_names
        for test in tests
    ]
    for model in models:
        for temp in temps:
            index.append(f"{model} @ {temp}")
            row_vals = [
                data_dict.get((model, temp, server_name, test), np.nan)
                for server_name, test in all_cols
            ]
            matrix.append(row_vals)

    columns: list[str] = [f"{server_name}\n{test}" for server_name, test in all_cols]

    df = pd.DataFrame(matrix, index=np.array(index), columns=np.array(columns))

    plt.figure(figsize=(12, 6))

    sns.heatmap(
        df, annot=True, cmap="RdYlGn", vmin=0.0, vmax=1.0, cbar=True, fmt=".2f", center=0.5, square=True, linewidths=0.5,
        cbar_kws={"label": "Success Ratio"},
    )

    n_str = str(min(total_counts)) if len(total_counts) == 1 else f"{min(total_counts)}-{max(total_counts)}"
    plt.title(f"Tool Call Bench (n = {n_str})\nSuccess Ratios by Server & Test", pad=20)
    plt.xlabel("Server & Test", labelpad=10)
    plt.ylabel("Model @ Temperature", labelpad=10)

    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)

    plt.tight_layout()

    if output:
        plt.savefig(output, dpi=300, bbox_inches='tight')
        logger.info(f"Plot saved to {output}")
    else:
        plt.show()


@app.command()
def run(
    output: Annotated[Path, typer.Option(help="Output JSONL file")],
    model: Annotated[Optional[str], typer.Option(help="Name of the model to test (server agnostic)")] = None,
    hf: Annotated[Optional[str], typer.Option(help="GGUF huggingface model repo id (+ optional quant) to test w/ llama-server")] = None,
    chat_template: Annotated[Optional[str], typer.Option(help="Chat template override for llama-server")] = None,
    chat_template_file: Annotated[Optional[str], typer.Option(help="Chat template file override for llama-server")] = None,
    ollama: Annotated[Optional[str], typer.Option(help="Ollama model tag to test")] = None,
    llama_baseline: Annotated[Optional[str], typer.Option(help="llama-server baseline binary path to use as baseline")] = None,
    n: Annotated[int, typer.Option(help="Number of times to run each test")] = 10,
    temp: Annotated[Optional[List[float]], typer.Option(help="Set of temperatures to test (negative values mean the server default)")] = None,
    top_p: Annotated[Optional[float], typer.Option(help="top_p sampling parameter")] = None,
    top_k: Annotated[Optional[int], typer.Option(help="top_k sampling parameter")] = None,
    ctk: Annotated[Optional[str], typer.Option(help="KV cache type for K (llama-server --cache-type-k)")] = None,
    ctv: Annotated[Optional[str], typer.Option(help="KV cache type for V (llama-server --cache-type-v)")] = None,
    fa: Annotated[Optional[bool], typer.Option(help="Enable Flash Attention (llama-server --flash-attn)")] = None,
    seed: Annotated[Optional[int], typer.Option(help="Random seed")] = None,
    port: Annotated[int, typer.Option(help="llama-server port")] = 8084,
    force: Annotated[bool, typer.Option(help="Force overwrite of output file")] = False,
    append: Annotated[bool, typer.Option(help="Append to output file")] = False,

    test_hello_world: Annotated[bool, typer.Option(help="Whether to run the hello world test")] = True,
    test_weather: Annotated[bool, typer.Option(help="Whether to run the weather test")] = True,
    test_calc_result: Annotated[bool, typer.Option(help="Whether to run the calc result test")] = False,
):
    # Refuse to clobber an existing output file unless --force or --append is given (asserted below).

    n_predict = 512  # High because of DeepSeek R1
    # n_ctx = 8192
    n_ctx = 2048

    if model is None:
        if hf is not None:
            model = hf.split("/")[-1]
        elif ollama is not None:
            model = ollama

    assert force or append or not output.exists(), f"Output file already exists: {output}; use --force to overwrite or --append to extend"

    with output.open('a' if append else 'w') as output_file:

        def run_tests(server: ServerProcess, *, server_name: str, model_id: str, temp: Optional[float] = None, output_kwargs={}, request_kwargs={}):
            request_kwargs = {**request_kwargs}
            if temp is not None:
                request_kwargs['temperature'] = temp
            if top_p is not None:
                request_kwargs['top_p'] = top_p
            if top_k is not None:
                request_kwargs['top_k'] = top_k
            if seed is not None:
                request_kwargs['seed'] = seed

            # Disable prompt caching so every repetition is served from scratch.
            request_kwargs['cache_prompt'] = False

            tests = {}
            if test_hello_world:
                tests["hello world"] = lambda server: do_test_hello_world(server, **request_kwargs)
            if test_weather:
                tests["weather"] = lambda server: do_test_weather(server, **request_kwargs)
            if test_calc_result:
                tests["calc result"] = lambda server: do_test_calc_result(server, None, 512, **request_kwargs)

            for test_name, test in tests.items():
                success_count = 0
                failure_count = 0
                failures = []
                success_times = []
                failure_times = []
                logger.info(f"Running {test_name} ({server_name}, {model}): ")
                for i in range(n):
                    start_time = time.time()

                    def elapsed():
                        return time.time() - start_time

                    try:
                        test(server)
                        success_times.append(elapsed())
                        success_count += 1
                        logger.info('success')
                    except Exception as e:
                        logger.error(f'failure: {e}')
                        failure_count += 1
                        failure_times.append(elapsed())
                        failures.append(str(e))
                        # import traceback
                        # traceback.print_exc()
                output_file.write(json.dumps({**output_kwargs, **dict(
                    model=model,
                    server_name=server_name,
                    model_id=model_id,
                    test=test_name,
                    temp=temp,
                    top_p=top_p,
                    top_k=top_k,
                    ctk=ctk,
                    ctv=ctv,
                    seed=seed,
                    success_ratio=float(success_count) / n,
                    avg_time=mean(success_times + failure_times),
                    median_time=median(success_times + failure_times),
                    success_count=success_count,
                    success_times=success_times,
                    failure_count=failure_count,
                    failure_times=failure_times,
                    failures=list(set(failures)),
                )}) + '\n')
                output_file.flush()

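        # A negative --temp is mapped to None, i.e. no explicit temperature (use the server default).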
        for t in [None] if temp is None else [t if t >= 0 else None for t in temp]:
            if hf is not None:

                servers: list[Tuple[str, Optional[str]]] = [('llama-server', None)]
                if llama_baseline is not None:
                    servers.append(('llama-server (baseline)', llama_baseline))

                for server_name, server_path in servers:
                    server = ServerProcess()
                    server.n_ctx = n_ctx
                    server.n_slots = 1
                    server.jinja = True
                    server.ctk = ctk
                    server.ctv = ctv
                    server.fa = "on" if fa else "off"
                    server.n_predict = n_predict
                    server.model_hf_repo = hf
                    server.model_hf_file = None
                    server.chat_template = chat_template
                    server.chat_template_file = chat_template_file
                    server.server_path = server_path
                    if port is not None:
                        server.server_port = port
                    # server.debug = True

                    with scoped_server(server):
                        server.start(timeout_seconds=15 * 60)
                        # Only False for now; add True here to also benchmark w/o the chat grammar constraint.
                        for ignore_chat_grammar in [False]:
                            run_tests(
                                server,
                                server_name=server_name,
                                model_id=hf,
                                temp=t,
                                output_kwargs=dict(
                                    chat_template=chat_template,
                                    chat_template_file=chat_template_file,
                                ),
                                request_kwargs=dict(
                                    ignore_chat_grammar=ignore_chat_grammar,
                                ),
                            )

            if ollama is not None:
                server = ServerProcess()
                server.server_port = 11434
                server.server_host = "localhost"
                subprocess.check_call(["ollama", "pull", ollama])

                with scoped_server(server):
                    run_tests(
                        server,
                        server_name="ollama",
                        model_id=ollama,
                        temp=t,
                        output_kwargs=dict(
                            chat_template=None,
                            chat_template_file=None,
                        ),
                        request_kwargs=dict(
                            model=ollama,
                            max_tokens=n_predict,
                            num_ctx=n_ctx,
                        ),
                    )


if __name__ == "__main__":
    app()