#!/usr/bin/env python3
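"""Benchmark the throughput of a llama.cpp HTTP server.

Server settings are passed via the LLAMA_ARG_* environment variables understood by
llama-server. A minimal invocation might look like this (the model path is hypothetical):

    LLAMA_ARG_MODEL=models/model.gguf python server-bench.py --n_prompts 100
"""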

import argparse
import json
import logging
import os
import random
import sqlite3
import subprocess
from time import sleep, time
from typing import Optional, Union

import datasets
import matplotlib.pyplot as plt
import numpy as np
import requests
from tqdm.contrib.concurrent import thread_map


logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger("server-bench")


def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]:
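    """Return the first n_prompts questions of a text dataset, or None if the name is unknown.

    A None return value signals to the caller that synthetic prompts should be generated instead.
    """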
    ret = []
    if dataset_name.lower() == "mmlu":
        logger.info("Loading MMLU dataset...")
        ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
    else:
        return None
    if n_prompts >= 0:  # a negative n_prompts means "use the whole dataset"
        ret = ret[:n_prompts]
    return ret


def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int, seed_offset: int) -> list[int]:
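    """Draw n_prompts pseudorandom prompt lengths from [prompt_length_min, prompt_length_max].

    With seed_offset >= 0 the RNG is re-seeded deterministically for each prompt so that
    runs are reproducible; a negative seed_offset leaves the RNG state untouched.
    """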
    assert n_prompts >= 0
    ret: list[int] = []
    for i in range(n_prompts):
        if seed_offset >= 0:
            # Seed stream 0 of 3; streams +1 and +2 are used for the per-request
            # generation lengths and sampling seeds in benchmark() below.
            random.seed(3 * (seed_offset + 1000 * i) + 0)
        ret.append(random.randint(prompt_length_min, prompt_length_max))
    return ret


def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
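    """Materialize prompts as lists of random token IDs, one list per requested length
    (IDs drawn from [100, 10000], presumably to stay clear of special tokens)."""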
    return [[random.randint(100, 10000) for _ in range(pl)] for pl in prompt_lengths]


def get_server(path_server: str, path_log: Optional[str]) -> dict:
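    """Start a llama.cpp server, or attach to an already running one, and wait until it is healthy.

    If path_server is an http(s) URL, no process is spawned and the URL is used as-is.
    Otherwise the binary is launched (configured via LLAMA_ARG_* environment variables)
    and polled on its /health endpoint for roughly 10 seconds before giving up.
    """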
    if path_server.startswith("http://") or path_server.startswith("https://"):
        return {"process": None, "address": path_server, "fout": None}
    if os.environ.get("LLAMA_ARG_HOST") is None:
        logger.info("LLAMA_ARG_HOST not explicitly set, using 127.0.0.1")
        os.environ["LLAMA_ARG_HOST"] = "127.0.0.1"
    if os.environ.get("LLAMA_ARG_PORT") is None:
        logger.info("LLAMA_ARG_PORT not explicitly set, using 8080")
        os.environ["LLAMA_ARG_PORT"] = "8080"
    hostname: Optional[str] = os.environ.get("LLAMA_ARG_HOST")
    port: Optional[str] = os.environ.get("LLAMA_ARG_PORT")
    assert hostname is not None
    assert port is not None
    address: str = f"http://{hostname}:{port}"
    logger.info(f"Starting the llama.cpp server under {address}...")

    fout = open(path_log.format(port=port), "w") if path_log is not None else subprocess.DEVNULL
    process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)

    n_failures: int = 0
    while True:
        try:
            sleep(1.0)
            exit_code = process.poll()
            if exit_code is not None:
                log_hint: str = f", see {path_log.format(port=port)}" if path_log is not None else ""
                raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}{log_hint}")
            response = requests.get(f"{address}/health")
            if response.status_code == 200:
                break
        except requests.ConnectionError:
            n_failures += 1
            if n_failures >= 10:
                raise RuntimeError("llama.cpp server is not healthy after 10 seconds")

    return {"process": process, "address": address, "fout": fout}


def get_prompt_length(data: dict) -> int:
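    """Return the prompt's length in tokens after applying the server's chat template."""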
    session = data["session"]
    server_address: str = data["server_address"]

    response = session.post(
        f"{server_address}/apply-template",
        json={"messages": [{"role": "user", "content": data["prompt"]}]}
    )
    response.raise_for_status()
    prompt: str = json.loads(response.text)["prompt"]
    response = session.post(
        f"{server_address}/tokenize",
        json={"content": prompt, "add_special": True}
    )
    response.raise_for_status()
    tokens: list[int] = json.loads(response.text)["tokens"]
    return len(tokens)


def send_prompt(data: dict) -> tuple[float, list[float]]:
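    """Send a single request to the server and return (submission time, token arrival times).

    Arrival times are recorded as observed by this script; trailing SSE chunks
    that are not generated tokens are excluded.
    """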
    session = data["session"]
    server_address: str = data["server_address"]

    t_submit = time()
    if data["external_server"]:
        # External server: use the OpenAI-compatible completions endpoint.
        json_data: dict = {
            "prompt": data["prompt"], "ignore_eos": True,
            "seed": data["seed"], "max_tokens": data["n_predict"], "stream": True}
        response = session.post(f"{server_address}/v1/completions", json=json_data, stream=True)
    elif data["synthetic_prompt"]:
        # Synthetic prompts are raw token IDs: send them directly, disable prompt caching,
        # and ignore EOS so that exactly n_predict tokens are generated.
        json_data: dict = {
            "prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False,
            "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
        response = session.post(f"{server_address}/completion", json=json_data, stream=True)
    else:
        # Dataset prompts: apply the server's chat template first, then stream the completion.
        response = session.post(
            f"{server_address}/apply-template",
            json={"messages": [{"role": "user", "content": data["prompt"]}]}
        )
        response.raise_for_status()
        prompt: str = json.loads(response.text)["prompt"]

        json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
        response = session.post(f"{server_address}/completion", json=json_data, stream=True)
    response.raise_for_status()

    lines = []
    token_arrival_times: list[float] = []
    for line in response.iter_lines(decode_unicode=False):
        if not line.startswith(b"data: "):
            continue
        lines.append(line)
        token_arrival_times.append(time())
    # Drop timestamps of trailing SSE chunks that are not generated tokens:
    # the final chunk (e.g. '[DONE]'), and the server's timing summary if present.
    token_arrival_times = token_arrival_times[:-1]
    if len(lines) > 1 and "timings" in json.loads(lines[-2][6:]):
        token_arrival_times = token_arrival_times[:-1]

    return (t_submit, token_arrival_times)


def benchmark(
        path_server: str, path_log: Optional[str], path_db: Optional[str], name: Optional[str], prompt_source: str, n_prompts: int,
        n_predict: int, n_predict_min: int, seed_offset: int):
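    """Benchmark one server configuration end to end.

    Starts (or attaches to) the server, sends n_prompts requests with up to
    LLAMA_ARG_N_PARALLEL in flight at once, prints aggregate statistics, optionally
    appends them to an SQLite database, and saves plots ('prompt_time.png',
    'gen_rate.png') to the current working directory.
    """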
    external_server: bool = path_server.startswith("http://") or path_server.startswith("https://")
    if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
        logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
        os.environ["LLAMA_ARG_N_PARALLEL"] = "32"

    parallel: int = int(os.environ["LLAMA_ARG_N_PARALLEL"])
    prompts: Union[None, list[str], list[list[int]]] = get_prompts_text(prompt_source, n_prompts)
    synthetic_prompts: bool = prompts is None
    prompt_n = []

    if synthetic_prompts:
        # prompt_source has the form "rng-MIN-MAX", e.g. "rng-1024-2048".
        prompt_source_split: list[str] = prompt_source.split("-")
        assert len(prompt_source_split) == 3
        assert prompt_source_split[0].lower() == "rng"
        prompt_length_min: int = int(prompt_source_split[1])
        prompt_length_max: int = int(prompt_source_split[2])
        logger.info("Generating random prompts...")
        prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max, seed_offset)
        prompts = get_prompts_rng(prompt_n)
    else:
        # n_predict_min is only supported for synthetic prompts, so force a fixed generation length.
        n_predict_min = n_predict

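    # Default context size: per slot, the longest possible prompt plus the generated tokens,
    # with a 5% safety margin; 2048 is assumed for dataset prompts since their tokenized
    # lengths are only measured once the server is running.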
    if not external_server and os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
        context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048)))
        context_total: int = context_per_slot * parallel
        os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total)
        logger.info(f"LLAMA_ARG_CTX_SIZE not explicitly set, using {context_total} ({context_per_slot} per slot).")

    server: Optional[dict] = None
    session = None
    try:
        server = get_server(path_server, path_log)
        server_address: str = server["address"]
        assert external_server == (server["process"] is None)

        adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel)  # type: ignore
        session = requests.Session()
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        data: list[dict] = []

        for i, p in enumerate(prompts):
            if seed_offset >= 0:
                # Seed stream 1: per-request generation length; stream 2 below is the sampling seed.
                random.seed(3 * (seed_offset + 1000 * i) + 1)
            data.append({
                "session": session, "server_address": server_address, "external_server": external_server, "prompt": p,
                "synthetic_prompt": synthetic_prompts, "n_predict": random.randint(n_predict_min, n_predict),
                "seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1})

        if not synthetic_prompts:
            logger.info("Getting the prompt lengths...")
            prompt_n = [get_prompt_length(d) for d in data]

        logger.info("Starting the benchmark...\n")
        t0 = time()
        results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1)
    finally:
        if server is not None and server["process"] is not None:
            server["process"].terminate()
            server["process"].wait()
        if server is not None and server["fout"] is not None and server["fout"] is not subprocess.DEVNULL:
            server["fout"].close()
        if session is not None:
            session.close()

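    # Aggregate per-request results: time to first token, arrival time of every generated
    # token, and the total context depth at which tokens were generated.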
    prompt_t = []
    token_t = []
    depth_sum: int = 0
    for pn, (t_submit, tat) in zip(prompt_n, results):
        prompt_t.append(tat[0] - t_submit)
        token_t += tat
        n_tokens: int = len(tat)
        depth_sum += n_tokens * pn
        depth_sum += n_tokens * (n_tokens + 1) // 2
    assert len(token_t) > 0
    prompt_n = np.array(prompt_n, dtype=np.int64)
    prompt_t = np.array(prompt_t, dtype=np.float64)
    token_t = np.array(token_t, dtype=np.float64)

    # Express token arrival times relative to the benchmark start; the last arrival marks the total runtime.
    token_t -= t0
    token_t_last = np.max(token_t)

    logger.info("")
    logger.info(f"Benchmark duration:                {token_t_last:.2f} s")
    logger.info(f"Request throughput:                {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last/60):.2f} requests/min")
    logger.info(f"Total prompt length:               {np.sum(prompt_n)} tokens")
    logger.info(f"Average prompt length:             {np.mean(prompt_n):.2f} tokens")
    logger.info(f"Average prompt latency:            {1e3 * np.mean(prompt_t):.2f} ms")
    logger.info(f"Average prompt speed:              {np.sum(prompt_n) / np.sum(prompt_t):.2f} tokens/s")
    logger.info(f"Total generated tokens:            {token_t.shape[0]}")
    logger.info(f"Average generation depth:          {depth_sum / token_t.shape[0]:.2f} tokens")
    logger.info(f"Average total generation speed:    {token_t.shape[0] / token_t_last:.2f} tokens/s")
    logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")

    if path_db is not None:
        con = sqlite3.connect(path_db)
        cursor = con.cursor()
        cursor.execute(
            "CREATE TABLE IF NOT EXISTS server_bench"
            "(name TEXT, n_parallel INTEGER, prompt_source TEXT, n_prompts INTEGER, "
            "n_predict INTEGER, n_predict_min INTEGER, seed_offset INTEGER, runtime REAL);")
        cursor.execute(
            "INSERT INTO server_bench VALUES (?, ?, ?, ?, ?, ?, ?, ?);",
            [name, parallel, prompt_source, n_prompts, n_predict, n_predict_min, seed_offset, token_t_last])
        con.commit()
        con.close()

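    # Scatter plot: time to first token vs. prompt length.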
    plt.figure()
    plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25)
    plt.xlim(0, 1.05e0 * np.max(prompt_n))
    plt.ylim(0, 1.05e3 * np.max(prompt_t))
    plt.title(name or "")
    plt.xlabel("Prompt length [tokens]")
    plt.ylabel("Time to first token [ms]")
    plt.savefig("prompt_time.png", dpi=240)

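    # Histogram: number of tokens generated in each one-second interval of the benchmark.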
    bin_max = np.ceil(token_t_last) + 1
    plt.figure()
    plt.hist(token_t, np.arange(0, bin_max))
    plt.xlim(0, bin_max + 1)
    plt.title(name or "")
    plt.xlabel("Time [s]")
    plt.ylabel("Num. tokens generated per second")
    plt.savefig("gen_rate.png", dpi=240)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
        "Results are printed to the console and visualized as plots (saved to the current working directory). "
        "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help). "
        "The reported numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
        "particularly when the server is fast compared to the network or the Python script (e.g. when serving a very small model).")
    parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary, or the URL of an already running server")
    parser.add_argument("--path_log", type=str, default="server-bench-{port}.log", help="Path to the server log file; '{port}' is replaced with the server port")
    parser.add_argument("--path_db", type=str, default=None, help="Path to an SQLite database to store the benchmark results in")
    parser.add_argument("--name", type=str, default=None, help="Name to label plots and database entries with")
    parser.add_argument(
        "--prompt_source", type=str, default="rng-1024-2048",
        help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "
        "rng-MIN-MAX for synthetic prompts with random lengths in the interval [MIN, MAX]")
    parser.add_argument("--n_prompts", type=int, default=100, help="Number of prompts to evaluate")
    parser.add_argument("--n_predict", type=int, default=2048, help="Max. number of tokens to predict per prompt")
    parser.add_argument(
        "--n_predict_min", type=int, default=1024,
        help="Min. number of tokens to predict per prompt (supported for synthetic prompts only)")
    parser.add_argument("--seed_offset", type=int, default=0, help="Offset for determining the seeds for pseudorandom prompt/generation lengths. "
                        "Correlations between seeds can occur when set >= 1000. Negative values mean no seed.")
    args = parser.parse_args()
    benchmark(**vars(args))