#!/usr/bin/env python3

import argparse
import json
import logging
import os
import random
import sqlite3
import subprocess
from time import sleep, time
from typing import Optional, Union

import datasets
import matplotlib.pyplot as plt
import numpy as np
import requests
from tqdm.contrib.concurrent import thread_map


logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger("server-bench")


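# Load the prompts for a named text dataset; returns None if the name is not a known dataset
# (the caller then falls back to synthetic prompts).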
def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]:
    ret = []
    if dataset_name.lower() == "mmlu":
        logger.info("Loading MMLU dataset...")
        ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
    else:
        return None
    if n_prompts >= 0:
        ret = ret[:n_prompts]
    return ret


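# Draw one pseudorandom prompt length per prompt; if seed_offset is non-negative the lengths
# are seeded deterministically from seed_offset and the prompt index.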
def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int, seed_offset: int) -> list[int]:
    assert n_prompts >= 0
    ret: list[int] = []
    for i in range(n_prompts):
        if seed_offset >= 0:
            random.seed(3 * (seed_offset + 1000 * i) + 0)
        ret.append(random.randint(prompt_length_min, prompt_length_max))
    return ret


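# Generate synthetic prompts as lists of random token IDs with the given lengths.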
def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
    return [[random.randint(100, 10000) for _ in range(pl)] for pl in prompt_lengths]


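# Return a handle to the server to benchmark: either an already running server reachable under the
# given http(s) URL, or a llama-server process started locally and polled via /health until ready.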
def get_server(path_server: str, path_log: Optional[str]) -> dict:
    if path_server.startswith("http://") or path_server.startswith("https://"):
        return {"process": None, "address": path_server, "fout": None}
    if os.environ.get("LLAMA_ARG_HOST") is None:
        logger.info("LLAMA_ARG_HOST not explicitly set, using 127.0.0.1")
        os.environ["LLAMA_ARG_HOST"] = "127.0.0.1"
    if os.environ.get("LLAMA_ARG_PORT") is None:
        logger.info("LLAMA_ARG_PORT not explicitly set, using 8080")
        os.environ["LLAMA_ARG_PORT"] = "8080"
    hostname: Optional[str] = os.environ.get("LLAMA_ARG_HOST")
    port: Optional[str] = os.environ.get("LLAMA_ARG_PORT")
    assert hostname is not None
    assert port is not None
    address: str = f"http://{hostname}:{port}"
    logger.info(f"Starting the llama.cpp server under {address}...")

    fout = open(path_log.format(port=port), "w") if path_log is not None else subprocess.DEVNULL
    process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)

    n_failures: int = 0
    while True:
        try:
            sleep(1.0)
            exit_code = process.poll()
            if exit_code is not None:
                raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}{path_log and f', see {path_log.format(port=port)}' or ''}")
            response = requests.get(f"{address}/health")
            if response.status_code == 200:
                break
        except requests.ConnectionError:
            n_failures += 1
            if n_failures >= 10:
                raise RuntimeError("llama.cpp server is not healthy after 10 seconds")

    return {"process": process, "address": address, "fout": fout}


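# Determine the length of a chat prompt in tokens by applying the chat template and
# tokenizing the result via the server.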
def get_prompt_length(data: dict) -> int:
    session = data["session"]
    server_address: str = data["server_address"]

    response = session.post(
        f"{server_address}/apply-template",
        json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
    )
    response.raise_for_status()
    prompt: str = json.loads(response.text)["prompt"]
    response = session.post(
        f"{server_address}/tokenize",
        json={"content": prompt, "add_special": True}
    )
    response.raise_for_status()
    tokens: list[str] = json.loads(response.text)["tokens"]
    return len(tokens)


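# Send a single streaming completion request and return the submission time together with the
# arrival times of the streamed tokens.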
def send_prompt(data: dict) -> tuple[float, list[float]]:
    session = data["session"]
    server_address: str = data["server_address"]

    t_submit = time()
    if data["external_server"]:
        json_data: dict = {
            "prompt": data["prompt"], "ignore_eos": True,
            "seed": data["seed"], "max_tokens": data["n_predict"], "stream": True}
        response = session.post(f"{server_address}/v1/completions", json=json_data, stream=True)
    elif data["synthetic_prompt"]:
        json_data: dict = {
            "prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False,
            "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
        response = session.post(f"{server_address}/completion", json=json_data, stream=True)
    else:
        response = session.post(
            f"{server_address}/apply-template",
            json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
        )
        response.raise_for_status()
        prompt: str = json.loads(response.text)["prompt"]

        json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
        response = session.post(f"{server_address}/completion", json=json_data, stream=True)
    response.raise_for_status()

    lines = []
    token_arrival_times: list[float] = []
    for line in response.iter_lines(decode_unicode=False):
        if not line.startswith(b"data: "):
            continue
        lines.append(line)
        token_arrival_times.append(time())
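    # The trailing "data: " line(s) terminate the stream or carry timings rather than new tokens,
    # so their arrival times are dropped.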
    token_arrival_times = token_arrival_times[:-1]
    if len(lines) > 1 and "timings" in json.loads(lines[-2][6:]):
        token_arrival_times = token_arrival_times[:-1]

    return (t_submit, token_arrival_times)


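# Run the full benchmark: start or connect to the server, send all prompts in parallel,
# then report statistics, optionally store them in the database, and render the plots.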
def benchmark(
        path_server: str, path_log: Optional[str], path_db: Optional[str], name: Optional[str], prompt_source: str, n_prompts: int,
        n_predict: int, n_predict_min: int, seed_offset: int):
    external_server: bool = path_server.startswith("http://") or path_server.startswith("https://")
    if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
        logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
        os.environ["LLAMA_ARG_N_PARALLEL"] = "32"

    parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL"))  # type: ignore
    prompts: Union[None, list[str], list[list[int]]] = get_prompts_text(prompt_source, n_prompts)
    synthetic_prompts: bool = prompts is None
    prompt_n = []

    if synthetic_prompts:
        prompt_source_split: list[str] = prompt_source.split("-")
        assert len(prompt_source_split) == 3
        assert prompt_source_split[0].lower() == "rng"
        prompt_length_min: int = int(prompt_source_split[1])
        prompt_length_max: int = int(prompt_source_split[2])
        logger.info("Generating random prompts...")
        prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max, seed_offset)
        prompts = get_prompts_rng(prompt_n)
    else:
        n_predict_min = n_predict

    if not external_server and os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
        context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048)))
        context_total: int = context_per_slot * parallel
        os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total)
        logger.info(f"LLAMA_ARG_CTX_SIZE not explicitly set, using {context_total} ({context_per_slot} per slot).")

    server: Optional[dict] = None
    session = None
    try:
        server = get_server(path_server, path_log)
        server_address: str = server["address"]
        assert external_server == (server["process"] is None)

        adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel)  # type: ignore
        session = requests.Session()
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        data: list[dict] = []

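        # Each prompt index i gets decorrelated seeds derived from seed_offset: offset +0 is used
        # for the prompt length, +1 for n_predict, and +2 as the generation seed sent to the server.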
        for i, p in enumerate(prompts):
            if seed_offset >= 0:
                random.seed(3 * (seed_offset + 1000 * i) + 1)
            data.append({
                "session": session, "server_address": server_address, "external_server": external_server, "prompt": p,
                "synthetic_prompt": synthetic_prompts, "n_predict": random.randint(n_predict_min, n_predict),
                "seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1})

        if not synthetic_prompts:
            logger.info("Getting the prompt lengths...")
            prompt_n = [get_prompt_length(d) for d in data]

        logger.info("Starting the benchmark...\n")
        t0 = time()
        results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1)
    finally:
        if server is not None and server["process"] is not None:
            server["process"].terminate()
            server["process"].wait()
        if session is not None:
            session.close()

    prompt_t = []
    token_t = []
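    # Accumulate the context depth at which each token was generated:
    # the prompt length plus the token's position within the generated sequence.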
    depth_sum: int = 0
    for pn, (t_submit, tat) in zip(prompt_n, results):
        prompt_t.append(tat[0] - t_submit)
        token_t += tat
        n_tokens: int = len(tat)
        depth_sum += n_tokens * pn
        depth_sum += n_tokens * (n_tokens + 1) // 2
    assert len(token_t) > 0
    prompt_n = np.array(prompt_n, dtype=np.int64)
    prompt_t = np.array(prompt_t, dtype=np.float64)
    token_t = np.array(token_t, dtype=np.float64)

    token_t -= t0
    token_t_last = np.max(token_t)

    logger.info("")
    logger.info(f"Benchmark duration: {token_t_last:.2f} s")
    logger.info(f"Request throughput: {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last/60):.2f} requests/min")
    logger.info(f"Total prompt length: {np.sum(prompt_n)} tokens")
    logger.info(f"Average prompt length: {np.mean(prompt_n):.2f} tokens")
    logger.info(f"Average prompt latency: {1e3 * np.mean(prompt_t):.2f} ms")
    logger.info(f"Average prompt speed: {np.sum(prompt_n) / np.sum(prompt_t):.2f} tokens/s")
    logger.info(f"Total generated tokens: {token_t.shape[0]}")
    logger.info(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
    logger.info(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
    logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")

    if path_db is not None:
        con = sqlite3.connect(path_db)
        cursor = con.cursor()
        cursor.execute(
            "CREATE TABLE IF NOT EXISTS server_bench"
            "(name TEXT, n_parallel INTEGER, prompt_source TEXT, n_prompts INTEGER, "
            "n_predict INTEGER, n_predict_min INTEGER, seed_offset INTEGER, runtime REAL);")
        cursor.execute(
            "INSERT INTO server_bench VALUES (?, ?, ?, ?, ?, ?, ?, ?);",
            [name, parallel, prompt_source, n_prompts, n_predict, n_predict_min, seed_offset, token_t_last])
        con.commit()

    plt.figure()
    plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25)
    plt.xlim(0, 1.05e0 * np.max(prompt_n))
    plt.ylim(0, 1.05e3 * np.max(prompt_t))
    plt.title(name or "")
    plt.xlabel("Prompt length [tokens]")
    plt.ylabel("Time to first token [ms]")
    plt.savefig("prompt_time.png", dpi=240)

    bin_max = np.ceil(token_t_last) + 1
    plt.figure()
    plt.hist(token_t, np.arange(0, bin_max))
    plt.xlim(0, bin_max + 1)
    plt.title(name or "")
    plt.xlabel("Time [s]")
    plt.ylabel("Num. tokens generated per second")
    plt.savefig("gen_rate.png", dpi=240)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
        "Results are printed to console and visualized as plots (saved to current working directory). "
        "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help). "
        "The reported numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
        "particularly when the server is fast relative to the network or the Python script (e.g. when serving a very small model).")
    parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary, or the http(s) URL of an already running server")
    parser.add_argument("--path_log", type=str, default="server-bench-{port}.log", help="Path to the file to write the server log output to; '{port}' is replaced with the server port")
    parser.add_argument("--path_db", type=str, default=None, help="Path to an SQLite database to store the benchmark results in")
    parser.add_argument("--name", type=str, default=None, help="Name to label plots and database entries with")
    parser.add_argument(
        "--prompt_source", type=str, default="rng-1024-2048",
        help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "
        "'rng-MIN-MAX' for synthetic prompts with random lengths in the interval [MIN, MAX]")
    parser.add_argument("--n_prompts", type=int, default=100, help="Number of prompts to evaluate")
    parser.add_argument("--n_predict", type=int, default=2048, help="Max. number of tokens to predict per prompt")
    parser.add_argument(
        "--n_predict_min", type=int, default=1024,
        help="Min. number of tokens to predict per prompt (supported for synthetic prompts only)")
    parser.add_argument("--seed_offset", type=int, default=0, help="Offset for determining the seeds for pseudorandom prompt/generation lengths. "
                        "Correlations between seeds can occur when set >= 1000. Negative values mean no seed.")
    args = parser.parse_args()
    benchmark(**vars(args))