from __future__ import annotations

import argparse
import json
import os
import re
import signal
import socket
import subprocess
import sys
import threading
import time
import traceback
from contextlib import closing
from datetime import datetime
from statistics import mean

import matplotlib
import matplotlib.dates
import matplotlib.pyplot as plt
import requests

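# Orchestrate a llama.cpp server benchmark run: start llama-server, drive it
# with a k6 scenario, flatten the k6 summary into GitHub env variables, scrape
# Prometheus for server-side metrics and render them as jpg + mermaid charts.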
def main(args_in: list[str] | None = None) -> None:
    parser = argparse.ArgumentParser(description="Start server benchmark scenario")
    parser.add_argument("--name", type=str, help="Bench name", required=True)
    parser.add_argument("--runner-label", type=str, help="Runner label", required=True)
    parser.add_argument("--branch", type=str, help="Branch name", default="detached")
    parser.add_argument("--commit", type=str, help="Commit name", default="dirty")
    parser.add_argument("--host", type=str, help="Server listen host", default="0.0.0.0")
    parser.add_argument("--port", type=int, help="Server listen port", default=8080)
    parser.add_argument("--model-path-prefix", type=str, help="Prefix where to store the model files", default="models")
    parser.add_argument("--n-prompts", type=int,
                        help="SERVER_BENCH_N_PROMPTS: total prompts to randomly select in the benchmark", required=True)
    parser.add_argument("--max-prompt-tokens", type=int,
                        help="SERVER_BENCH_MAX_PROMPT_TOKENS: maximum prompt tokens to filter out in the dataset",
                        required=True)
    parser.add_argument("--max-tokens", type=int,
                        help="SERVER_BENCH_MAX_CONTEXT: maximum context size of the completions request to filter out in the dataset: prompt + predicted tokens",
                        required=True)
    parser.add_argument("--hf-repo", type=str, help="Hugging Face model repository", required=True)
    parser.add_argument("--hf-file", type=str, help="Hugging Face model file", required=True)
    parser.add_argument("-ngl", "--n-gpu-layers", type=int, help="Number of layers to offload to the GPU", required=True)
    parser.add_argument("--ctx-size", type=int, help="Set the size of the prompt context", required=True)
    parser.add_argument("--parallel", type=int, help="Set the number of slots for processing requests", required=True)
    parser.add_argument("--batch-size", type=int, help="Set the batch size for prompt processing", required=True)
    parser.add_argument("--ubatch-size", type=int, help="Physical maximum batch size", required=True)
    parser.add_argument("--scenario", type=str, help="k6 scenario file to run", required=True)
    parser.add_argument("--duration", type=str, help="Bench scenario duration", required=True)

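    # Example invocation (values are illustrative, not defaults):
    #   python bench.py --runner-label local --name phi2 --scenario script.js --duration 10m \
    #       --n-prompts 480 --max-prompt-tokens 1024 --max-tokens 2048 \
    #       --hf-repo ggml-org/models --hf-file phi-2/ggml-model-q4_0.gguf \
    #       -ngl 33 --ctx-size 16384 --parallel 8 --batch-size 2048 --ubatch-size 256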
    args = parser.parse_args(args_in)

    start_time = time.time()

    # Start the server and performance scenario
    try:
        server_process = start_server(args)
    except Exception:
        print("bench: server start error:")
        traceback.print_exc(file=sys.stdout)
        sys.exit(1)
    # start the benchmark
    iterations = 0
    data = {}
    try:
        start_benchmark(args)

        with open("results.github.env", 'w') as github_env:
            # parse the k6 summary and flatten each metric into a KEY=value line
            with open('k6-results.json', 'r') as bench_results:
                # Load JSON data from file
                data = json.load(bench_results)
                for metric_name in data['metrics']:
                    for metric_metric in data['metrics'][metric_name]:
                        value = data['metrics'][metric_name][metric_metric]
                        if isinstance(value, (int, float)):
                            value = round(value, 2)
                        data['metrics'][metric_name][metric_metric] = value
                        github_env.write(
                            f"{escape_metric_name(metric_name)}_{escape_metric_name(metric_metric)}={value}\n")

            iterations = data['root_group']['checks']['success completion']['passes']

    except Exception:
        print("bench: error:")
        traceback.print_exc(file=sys.stdout)

    # Stop the server
    if server_process:
        try:
            print(f"bench: shutting down server pid={server_process.pid} ...")
            if os.name == 'nt':
                interrupt = signal.CTRL_C_EVENT
            else:
                interrupt = signal.SIGINT
            server_process.send_signal(interrupt)
            server_process.wait(0.5)

        except subprocess.TimeoutExpired:
            print(f"server still alive after 500ms, force-killing pid={server_process.pid} ...")
            server_process.kill()  # SIGKILL
            server_process.wait()

        while is_server_listening(args.host, args.port):
            time.sleep(0.1)

    title = (f"llama.cpp {args.name} on {args.runner_label}\n "
             f"duration={args.duration} {iterations} iterations")
    xlabel = (f"{args.hf_repo}/{args.hf_file}\n"
              f"parallel={args.parallel} ctx-size={args.ctx_size} ngl={args.n_gpu_layers} batch-size={args.batch_size} ubatch-size={args.ubatch_size} pp={args.max_prompt_tokens} pp+tg={args.max_tokens}\n"
              f"branch={args.branch} commit={args.commit}")

    # Prometheus
    end_time = time.time()
    prometheus_metrics = {}
    if is_server_listening("0.0.0.0", 9090):
        metrics = ['prompt_tokens_seconds', 'predicted_tokens_seconds',
                   'kv_cache_usage_ratio', 'requests_processing', 'requests_deferred']
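        # The queries below assume a Prometheus instance on localhost:9090
        # scraping the llama-server /metrics endpoint; step=2 yields one
        # sample every two seconds across the bench window.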

        for metric in metrics:
            resp = requests.get("http://localhost:9090/api/v1/query_range",
                                params={'query': 'llamacpp:' + metric, 'start': start_time, 'end': end_time, 'step': 2})

            with open(f"{metric}.json", 'w') as metric_json:
                metric_json.write(resp.text)

            if resp.status_code != 200:
                print(f"bench: unable to extract prometheus metric {metric}: {resp.text}")
            else:
                metric_data = resp.json()
                # guard against an empty result set so a missing metric does not abort the run
                if not metric_data['data']['result']:
                    print(f"bench: no prometheus data for metric {metric}")
                    continue
                values = metric_data['data']['result'][0]['values']
                timestamps, metric_values = zip(*values)
                metric_values = [float(value) for value in metric_values]
                prometheus_metrics[metric] = metric_values
                timestamps_dt = [str(datetime.fromtimestamp(int(ts))) for ts in timestamps]
                plt.figure(figsize=(16, 10), dpi=80)
                plt.plot(timestamps_dt, metric_values, label=metric)
                plt.xticks(rotation=0, fontsize=14, horizontalalignment='center', alpha=.7)
                plt.yticks(fontsize=12, alpha=.7)

                ylabel = f"llamacpp:{metric}"
                plt.title(title, fontsize=14, wrap=True)
                plt.grid(axis='both', alpha=.3)
                plt.ylabel(ylabel, fontsize=22)
                plt.xlabel(xlabel, fontsize=14, wrap=True)
                plt.gca().xaxis.set_major_locator(matplotlib.dates.MinuteLocator())
                plt.gca().xaxis.set_major_formatter(matplotlib.dates.DateFormatter("%Y-%m-%d %H:%M:%S"))
                plt.gcf().autofmt_xdate()

                # Remove borders
                plt.gca().spines["top"].set_alpha(0.0)
                plt.gca().spines["bottom"].set_alpha(0.3)
                plt.gca().spines["right"].set_alpha(0.0)
                plt.gca().spines["left"].set_alpha(0.3)

                # Save the plot as a jpg image
                plt.savefig(f'{metric}.jpg', dpi=60)
                plt.close()

                # Mermaid chart as a fallback in case the image upload fails
                with open(f"{metric}.mermaid", 'w') as mermaid_f:
                    mermaid = (
                        f"""---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "{title}"
    y-axis "llamacpp:{metric}"
    x-axis "llamacpp:{metric}" {int(min(timestamps))} --> {int(max(timestamps))}
    line [{', '.join([str(round(float(value), 2)) for value in metric_values])}]
""")
                    mermaid_f.write(mermaid)

    # keep the JSON compact: commit status descriptions are capped at 140 chars
    bench_results = {
        "i": iterations,
        "req": {
            "p95": round(data['metrics']["http_req_duration"]["p(95)"], 2),
            "avg": round(data['metrics']["http_req_duration"]["avg"], 2),
        },
        "pp": {
            "p95": round(data['metrics']["llamacpp_prompt_processing_second"]["p(95)"], 2),
            "avg": round(data['metrics']["llamacpp_prompt_processing_second"]["avg"], 2),
            "0": round(mean(prometheus_metrics['prompt_tokens_seconds']), 2) if 'prompt_tokens_seconds' in prometheus_metrics else 0,
        },
        "tg": {
            "p95": round(data['metrics']["llamacpp_tokens_second"]["p(95)"], 2),
            "avg": round(data['metrics']["llamacpp_tokens_second"]["avg"], 2),
            "0": round(mean(prometheus_metrics['predicted_tokens_seconds']), 2) if 'predicted_tokens_seconds' in prometheus_metrics else 0,
        },
    }
    with open("results.github.env", 'a') as github_env:
        github_env.write(f"BENCH_RESULTS={json.dumps(bench_results, indent=None, separators=(',', ':'))}\n")
        github_env.write(f"BENCH_ITERATIONS={iterations}\n")

        title = title.replace('\n', ' ')
        xlabel = xlabel.replace('\n', ' ')
        github_env.write(f"BENCH_GRAPH_TITLE={title}\n")
        github_env.write(f"BENCH_GRAPH_XLABEL={xlabel}\n")


def start_benchmark(args):
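    """Run the k6 scenario against the server and export the summary to k6-results.json.

    The scenario script is expected to read the SERVER_BENCH_* environment
    variables prepended to the command; raises if k6 exits non-zero.
    """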
    k6_path = './k6'
    if 'BENCH_K6_BIN_PATH' in os.environ:
        k6_path = os.environ['BENCH_K6_BIN_PATH']
    k6_args = [
        'run', args.scenario,
        '--no-color',
        '--no-connection-reuse',
        '--no-vu-connection-reuse',
    ]
    k6_args.extend(['--duration', args.duration])
    k6_args.extend(['--iterations', args.n_prompts])
    k6_args.extend(['--vus', args.parallel])
    k6_args.extend(['--summary-export', 'k6-results.json'])
    k6_args.extend(['--out', 'csv=k6-results.csv'])
    # build the shell command in its own variable instead of shadowing `args`
    command = f"SERVER_BENCH_N_PROMPTS={args.n_prompts} SERVER_BENCH_MAX_PROMPT_TOKENS={args.max_prompt_tokens} SERVER_BENCH_MAX_CONTEXT={args.max_tokens} "
    command = command + ' '.join([str(arg) for arg in [k6_path, *k6_args]])
    print(f"bench: starting k6 with: {command}")
    k6_completed = subprocess.run(command, shell=True, stdout=sys.stdout, stderr=sys.stderr)
    if k6_completed.returncode != 0:
        raise Exception("bench: unable to run k6")


def start_server(args):
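    """Start llama-server in the background and block until it accepts TCP
    connections and reports healthy on /health; asserts if it never comes up."""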
    server_process = start_server_background(args)

    attempts = 0
    max_attempts = 600
    if 'GITHUB_ACTIONS' in os.environ:
        max_attempts *= 2

    while not is_server_listening(args.host, args.port):
        attempts += 1
        if attempts > max_attempts:
            assert False, "server not started"
        print("bench: waiting for server to start ...")
        time.sleep(0.5)

    attempts = 0
    while not is_server_ready(args.host, args.port):
        attempts += 1
        if attempts > max_attempts:
            assert False, "server not ready"
        print("bench: waiting for server to be ready ...")
        time.sleep(0.5)

    print("bench: server started and ready.")
    return server_process


def start_server_background(args):
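    """Launch llama-server as a subprocess and mirror its stdout/stderr to ours.

    The binary path defaults to the in-tree build and can be overridden with
    the LLAMA_SERVER_BIN_PATH environment variable. Returns the Popen handle.
    """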
    # Start the server
    server_path = '../../../build/bin/llama-server'
    if 'LLAMA_SERVER_BIN_PATH' in os.environ:
        server_path = os.environ['LLAMA_SERVER_BIN_PATH']
    server_args = [
        '--host', args.host,
        '--port', args.port,
    ]
    server_args.extend(['--hf-repo', args.hf_repo])
    server_args.extend(['--hf-file', args.hf_file])
    server_args.extend(['--n-gpu-layers', args.n_gpu_layers])
    server_args.extend(['--ctx-size', args.ctx_size])
    server_args.extend(['--parallel', args.parallel])
    server_args.extend(['--batch-size', args.batch_size])
    server_args.extend(['--ubatch-size', args.ubatch_size])
    server_args.extend(['--n-predict', args.max_tokens * 2])
    server_args.append('--cont-batching')
    server_args.append('--metrics')
    server_args.append('--flash-attn')
    # keep the stringified argv in its own variable instead of shadowing `args`
    argv = [str(arg) for arg in [server_path, *server_args]]
    print(f"bench: starting server with: {' '.join(argv)}")
    pkwargs = {
        'stdout': subprocess.PIPE,
        'stderr': subprocess.PIPE
    }
    server_process = subprocess.Popen(
        argv,
        **pkwargs)  # pyright: ignore[reportArgumentType, reportCallIssue]

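    # Pump the child's stdout/stderr from background threads so the OS pipe
    # buffers never fill up and stall the server.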
    def server_log(in_stream, out_stream):
        for line in iter(in_stream.readline, b''):
            print(line.decode('utf-8'), end='', file=out_stream)

    thread_stdout = threading.Thread(target=server_log, args=(server_process.stdout, sys.stdout))
    thread_stdout.start()
    thread_stderr = threading.Thread(target=server_log, args=(server_process.stderr, sys.stderr))
    thread_stderr.start()

    return server_process


def is_server_listening(server_fqdn, server_port):
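    """Return True if a TCP connection to (server_fqdn, server_port) succeeds."""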
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        result = sock.connect_ex((server_fqdn, server_port))
        _is_server_listening = result == 0
        if _is_server_listening:
            print(f"server is listening on {server_fqdn}:{server_port}...")
        return _is_server_listening


def is_server_ready(server_fqdn, server_port):
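    """Return True once the server's /health endpoint answers HTTP 200."""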
    url = f"http://{server_fqdn}:{server_port}/health"
    response = requests.get(url)
    return response.status_code == 200


def escape_metric_name(metric_name):
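    """Upper-case a metric name and replace every non-alphanumeric character
    with '_' so it is usable as an env variable name, e.g. 'p(95)' -> 'P_95_'."""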
    return re.sub('[^A-Z0-9]', '_', metric_name.upper())


if __name__ == '__main__':
    main()