Diffstat (limited to 'llama.cpp/scripts/server-bench.py')
| -rwxr-xr-x | llama.cpp/scripts/server-bench.py | 297 |
1 file changed, 297 insertions, 0 deletions
diff --git a/llama.cpp/scripts/server-bench.py b/llama.cpp/scripts/server-bench.py
new file mode 100755
index 0000000..dbbb093
--- /dev/null
+++ b/llama.cpp/scripts/server-bench.py
@@ -0,0 +1,297 @@
+#!/usr/bin/env python3
+
+import argparse
+import json
+import os
+import random
+import sqlite3
+import subprocess
+from time import sleep, time
+from typing import Optional, Union
+
+import datasets
+import logging
+import matplotlib.pyplot as plt
+import numpy as np
+import requests
+from tqdm.contrib.concurrent import thread_map
+
+
+logging.basicConfig(level=logging.INFO, format='%(message)s')
+logger = logging.getLogger("server-bench")
+
+
+def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]:
+    ret = []
+    if dataset_name.lower() == "mmlu":
+        logger.info("Loading MMLU dataset...")
+        ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
+    else:
+        return None
+    if n_prompts >= 0:
+        ret = ret[:n_prompts]
+    return ret
+
+
+def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int, seed_offset: int) -> list[int]:
+    assert n_prompts >= 0
+    ret: list[int] = []
+    for i in range(n_prompts):
+        if seed_offset >= 0:
+            random.seed(3 * (seed_offset + 1000 * i) + 0)
+        ret.append(random.randint(prompt_length_min, prompt_length_max))
+    return ret
+
+
+def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
+    return [[random.randint(100, 10000) for _ in range(pl)] for pl in prompt_lengths]
+
+
+def get_server(path_server: str, path_log: Optional[str]) -> dict:
+    if path_server.startswith("http://") or path_server.startswith("https://"):
+        return {"process": None, "address": path_server, "fout": None}
+    if os.environ.get("LLAMA_ARG_HOST") is None:
+        logger.info("LLAMA_ARG_HOST not explicitly set, using 127.0.0.1")
+        os.environ["LLAMA_ARG_HOST"] = "127.0.0.1"
+    if os.environ.get("LLAMA_ARG_PORT") is None:
+        logger.info("LLAMA_ARG_PORT not explicitly set, using 8080")
+        os.environ["LLAMA_ARG_PORT"] = "8080"
+    hostname: Optional[str] = os.environ.get("LLAMA_ARG_HOST")
+    port: Optional[str] = os.environ.get("LLAMA_ARG_PORT")
+    assert hostname is not None
+    assert port is not None
+    address: str = f"http://{hostname}:{port}"
+    logger.info(f"Starting the llama.cpp server under {address}...")
+
+    fout = open(path_log.format(port=port), "w") if path_log is not None else subprocess.DEVNULL
+    process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)
+
+    n_failures: int = 0
+    while True:
+        try:
+            sleep(1.0)
+            exit_code = process.poll()
+            if exit_code is not None:
+                raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}{path_log and f', see {path_log.format(port=port)}' or ''}")
+            response = requests.get(f"{address}/health")
+            if response.status_code == 200:
+                break
+        except requests.ConnectionError:
+            n_failures += 1
+            if n_failures >= 10:
+                raise RuntimeError("llama.cpp server is not healthy after 10 seconds")
+
+    return {"process": process, "address": address, "fout": fout}
+
+
+def get_prompt_length(data: dict) -> int:
+    session = data["session"]
+    server_address: str = data["server_address"]
+
+    response = session.post(
+        f"{server_address}/apply-template",
+        json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
+    )
+    response.raise_for_status()
+    prompt: str = json.loads(response.text)["prompt"]
+    response = session.post(
+        f"{server_address}/tokenize",
+        json={"content": prompt, "add_special": True}
+    )
+    response.raise_for_status()
+    tokens: list[str] = json.loads(response.text)["tokens"]
+    return len(tokens)
+
+
+def send_prompt(data: dict) -> tuple[float, list[float]]:
+    session = data["session"]
+    server_address: str = data["server_address"]
+
+    t_submit = time()
+    if data["external_server"]:
+        json_data: dict = {
+            "prompt": data["prompt"], "ignore_eos": True,
+            "seed": data["seed"], "max_tokens": data["n_predict"], "stream": True}
+        response = session.post(f"{server_address}/v1/completions", json=json_data, stream=True)
+    elif data["synthetic_prompt"]:
+        json_data: dict = {
+            "prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False,
+            "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
+        response = session.post(f"{server_address}/completion", json=json_data, stream=True)
+    else:
+        response = session.post(
+            f"{server_address}/apply-template",
+            json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
+        )
+        response.raise_for_status()
+        prompt: str = json.loads(response.text)["prompt"]
+
+        json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
+        response = session.post(f"{server_address}/completion", json=json_data, stream=True)
+    response.raise_for_status()
+
+    lines = []
+    token_arrival_times: list[float] = []
+    for line in response.iter_lines(decode_unicode=False):
+        if not line.startswith(b"data: "):
+            continue
+        lines.append(line)
+        token_arrival_times.append(time())
+    token_arrival_times = token_arrival_times[:-1]
+    if len(lines) > 1 and "timings" in json.loads(lines[-2][6:]):
+        token_arrival_times = token_arrival_times[:-1]
+
+    return (t_submit, token_arrival_times)
+
+
+def benchmark(
+        path_server: str, path_log: Optional[str], path_db: Optional[str], name: Optional[str], prompt_source: str, n_prompts: int,
+        n_predict: int, n_predict_min: int, seed_offset: int):
+    external_server: bool = path_server.startswith("http://") or path_server.startswith("https://")
+    if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
+        logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
+        os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
+
+    parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL"))  # type: ignore
+    prompts: Union[None, list[str], list[list[int]]] = get_prompts_text(prompt_source, n_prompts)
+    synthetic_prompts: bool = prompts is None
+    prompt_n = []
+
+    if synthetic_prompts:
+        prompt_source_split: list[str] = prompt_source.split("-")
+        assert len(prompt_source_split) == 3
+        assert prompt_source_split[0].lower() == "rng"
+        prompt_length_min: int = int(prompt_source_split[1])
+        prompt_length_max: int = int(prompt_source_split[2])
+        logger.info("Generating random prompts...")
+        prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max, seed_offset)
+        prompts = get_prompts_rng(prompt_n)
+    else:
+        n_predict_min = n_predict
+
+    if not external_server and os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
+        context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048)))
+        context_total: int = context_per_slot * parallel
+        os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total)
+        logger.info(f"LLAMA_ARG_CTX_SIZE not explicitly set, using {context_total} ({context_per_slot} per slot).")
+
+    server: Optional[dict] = None
+    session = None
+    try:
+        server = get_server(path_server, path_log)
+        server_address: str = server["address"]
+        assert external_server == (server["process"] is None)
+
+        adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel)  # type: ignore
+        session = requests.Session()
+        session.mount("http://", adapter)
+        session.mount("https://", adapter)
+
+        data: list[dict] = []
+
+        for i, p in enumerate(prompts):
+            if seed_offset >= 0:
+                random.seed(3 * (seed_offset + 1000 * i) + 1)
+            data.append({
+                "session": session, "server_address": server_address, "external_server": external_server, "prompt": p,
+                "synthetic_prompt": synthetic_prompts, "n_predict": random.randint(n_predict_min, n_predict),
+                "seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1})
+
+        if not synthetic_prompts:
+            logger.info("Getting the prompt lengths...")
+            prompt_n = [get_prompt_length(d) for d in data]
+
+        logger.info("Starting the benchmark...\n")
+        t0 = time()
+        results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1)
+    finally:
+        if server is not None and server["process"] is not None:
+            server["process"].terminate()
+            server["process"].wait()
+        if session is not None:
+            session.close()
+
+    prompt_t = []
+    token_t = []
+    depth_sum: int = 0
+    for pn, (t_submit, tat) in zip(prompt_n, results):
+        prompt_t.append(tat[0] - t_submit)
+        token_t += tat
+        n_tokens: int = len(tat)
+        depth_sum += n_tokens * pn
+        depth_sum += n_tokens * (n_tokens + 1) // 2
+    assert len(token_t) > 0
+    prompt_n = np.array(prompt_n, dtype=np.int64)
+    prompt_t = np.array(prompt_t, dtype=np.float64)
+    token_t = np.array(token_t, dtype=np.float64)
+
+    token_t -= t0
+    token_t_last = np.max(token_t)
+
+    logger.info("")
+    logger.info(f"Benchmark duration: {token_t_last:.2f} s")
+    logger.info(f"Request throughput: {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last/60):.2f} requests/min")
+    logger.info(f"Total prompt length: {np.sum(prompt_n)} tokens")
+    logger.info(f"Average prompt length: {np.mean(prompt_n):.2f} tokens")
+    logger.info(f"Average prompt latency: {1e3 * np.mean(prompt_t):.2f} ms")
+    logger.info(f"Average prompt speed: {np.sum(prompt_n) / np.sum(prompt_t):.2f} tokens/s")
+    logger.info(f"Total generated tokens: {token_t.shape[0]}")
+    logger.info(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
+    logger.info(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
+    logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
+
+    if path_db is not None:
+        con = sqlite3.connect(path_db)
+        cursor = con.cursor()
+        cursor.execute(
+            "CREATE TABLE IF NOT EXISTS server_bench"
+            "(name TEXT, n_parallel INTEGER, prompt_source TEXT, n_prompts INTEGER, "
+            "n_predict INTEGER, n_predict_min INTEGER, seed_offset INTEGER, runtime REAL);")
+        cursor.execute(
+            "INSERT INTO server_bench VALUES (?, ?, ?, ?, ?, ?, ?, ?);",
+            [name, parallel, prompt_source, n_prompts, n_predict, n_predict_min, seed_offset, token_t_last])
+        con.commit()
+
+    plt.figure()
+    plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25)
+    plt.xlim(0, 1.05e0 * np.max(prompt_n))
+    plt.ylim(0, 1.05e3 * np.max(prompt_t))
+    plt.title(name or "")
+    plt.xlabel("Prompt length [tokens]")
+    plt.ylabel("Time to first token [ms]")
+    plt.savefig("prompt_time.png", dpi=240)
+
+    bin_max = np.ceil(token_t_last) + 1
+    plt.figure()
+    plt.hist(token_t, np.arange(0, bin_max))
+    plt.xlim(0, bin_max + 1)
+    plt.title(name or "")
+    plt.xlabel("Time [s]")
+    plt.ylabel("Num. tokens generated per second")
+    plt.savefig("gen_rate.png", dpi=240)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
+        "Results are printed to console and visualized as plots (saved to current working directory). "
+        "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help). "
+        "The reported numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
+        "particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")
+    parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
+    parser.add_argument("--path_log", type=str, default="server-bench-{port}.log", help="Path to the file for logging server output ('{port}' is replaced with the server port)")
+    parser.add_argument("--path_db", type=str, default=None, help="Path to an sqlite database to store the benchmark results in")
+    parser.add_argument("--name", type=str, default=None, help="Name to label plots and database entries with")
+    parser.add_argument(
+        "--prompt_source", type=str, default="rng-1024-2048",
+        help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "
+        "rng-MIN-MAX for synthetic prompts with random lengths in the interval [MIN, MAX]")
+    parser.add_argument("--n_prompts", type=int, default=100, help="Number of prompts to evaluate")
+    parser.add_argument("--n_predict", type=int, default=2048, help="Max. number of tokens to predict per prompt")
+    parser.add_argument(
+        "--n_predict_min", type=int, default=1024,
+        help="Min. number of tokens to predict per prompt (supported for synthetic prompts only)")
+    parser.add_argument("--seed_offset", type=int, default=0, help="Offset for determining the seeds for pseudorandom prompt/generation lengths. "
+                        "Correlations between seeds can occur when set >= 1000. Negative values mean no seed.")
+    args = parser.parse_args()
+    benchmark(**vars(args))
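
Example usage (a hypothetical invocation, not prescribed by the script itself): the benchmark's own options are passed on the command line, while server settings such as the model path go through environment variables, in the same way the script already reads LLAMA_ARG_HOST, LLAMA_ARG_PORT, LLAMA_ARG_N_PARALLEL and LLAMA_ARG_CTX_SIZE. The model path below is a placeholder, and LLAMA_ARG_MODEL is assumed to be the environment-variable form of llama-server's model option:

    LLAMA_ARG_MODEL=/path/to/model.gguf LLAMA_ARG_N_PARALLEL=32 \
        python3 llama.cpp/scripts/server-bench.py --prompt_source rng-1024-2048 --n_prompts 100 --path_db results.sqlite

With these arguments the script would start llama-server, send 100 synthetic prompts of 1024-2048 tokens, print the throughput summary, save prompt_time.png and gen_rate.png to the current working directory, and append one row to the server_bench table in results.sqlite.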
