from __future__ import annotations

import argparse
import json
import os
import re
import signal
import socket
import subprocess
import sys
import threading
import time
import traceback
from contextlib import closing
from datetime import datetime
from statistics import mean

import matplotlib
import matplotlib.dates
import matplotlib.pyplot as plt
import requests


def main(args_in: list[str] | None = None) -> None:
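    """Run a llama-server benchmark scenario end to end.

    Starts llama-server, drives it with a k6 scenario, queries a local
    Prometheus instance for llamacpp metrics over the run, renders per-metric
    plots, and writes the aggregated results to results.github.env.
    """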
    parser = argparse.ArgumentParser(description="Start server benchmark scenario")
    parser.add_argument("--name", type=str, help="Bench name", required=True)
    parser.add_argument("--runner-label", type=str, help="Runner label", required=True)
    parser.add_argument("--branch", type=str, help="Branch name", default="detached")
    parser.add_argument("--commit", type=str, help="Commit name", default="dirty")
    parser.add_argument("--host", type=str, help="Server listen host", default="0.0.0.0")
    parser.add_argument("--port", type=int, help="Server listen port", default=8080)
    parser.add_argument("--model-path-prefix", type=str, help="Path prefix where model files are stored", default="models")
    parser.add_argument("--n-prompts", type=int,
                        help="SERVER_BENCH_N_PROMPTS: total prompts to randomly select in the benchmark", required=True)
    parser.add_argument("--max-prompt-tokens", type=int,
                        help="SERVER_BENCH_MAX_PROMPT_TOKENS: maximum prompt size; larger prompts are filtered out of the dataset",
                        required=True)
    parser.add_argument("--max-tokens", type=int,
                        help="SERVER_BENCH_MAX_CONTEXT: maximum context size (prompt + predicted tokens); larger requests are filtered out of the dataset",
                        required=True)
    parser.add_argument("--hf-repo", type=str, help="Hugging Face model repository", required=True)
    parser.add_argument("--hf-file", type=str, help="Hugging Face model file", required=True)
    parser.add_argument("-ngl", "--n-gpu-layers", type=int, help="Number of layers to offload to the GPU", required=True)
    parser.add_argument("--ctx-size", type=int, help="Set the size of the prompt context", required=True)
    parser.add_argument("--parallel", type=int, help="Set the number of slots for processing requests", required=True)
    parser.add_argument("--batch-size", type=int, help="Set the batch size for prompt processing", required=True)
    parser.add_argument("--ubatch-size", type=int, help="Set the physical maximum batch size", required=True)
    parser.add_argument("--scenario", type=str, help="Scenario to run", required=True)
    parser.add_argument("--duration", type=str, help="Bench scenario duration", required=True)

    args = parser.parse_args(args_in)

    start_time = time.time()

    # Start the server
    try:
        server_process = start_server(args)
    except Exception:
        print("bench: server start error:")
        traceback.print_exc(file=sys.stdout)
        sys.exit(1)

    # start the benchmark
    iterations = 0
    data = {}
    try:
        start_benchmark(args)

        with open("results.github.env", 'w') as github_env:
            # parse the k6 summary and export every numeric metric value
            with open('k6-results.json', 'r') as bench_results:
                # Load JSON data from file
                data = json.load(bench_results)
                for metric_name in data['metrics']:
                    for metric_stat in data['metrics'][metric_name]:
                        value = data['metrics'][metric_name][metric_stat]
                        if isinstance(value, (int, float)):
                            value = round(value, 2)
                            data['metrics'][metric_name][metric_stat] = value
                            github_env.write(
                                f"{escape_metric_name(metric_name)}_{escape_metric_name(metric_stat)}={value}\n")
                iterations = data['root_group']['checks']['success completion']['passes']

    except Exception:
        print("bench: error:")
        traceback.print_exc(file=sys.stdout)

    # Stop the server
    if server_process:
        try:
            print(f"bench: shutting down server pid={server_process.pid} ...")
            if os.name == 'nt':
                # CTRL_C_EVENT may not reach the child on Windows unless it
                # shares this console; see the subprocess docs on send_signal
                interrupt = signal.CTRL_C_EVENT
            else:
                interrupt = signal.SIGINT
            server_process.send_signal(interrupt)
            server_process.wait(0.5)

        except subprocess.TimeoutExpired:
            print(f"server still alive after 500ms, force-killing pid={server_process.pid} ...")
            server_process.kill()  # SIGKILL
            server_process.wait()

        while is_server_listening(args.host, args.port):
            time.sleep(0.1)

    title = (f"llama.cpp {args.name} on {args.runner_label}\n "
             f"duration={args.duration} {iterations} iterations")
    xlabel = (f"{args.hf_repo}/{args.hf_file}\n"
              f"parallel={args.parallel} ctx-size={args.ctx_size} ngl={args.n_gpu_layers} batch-size={args.batch_size} ubatch-size={args.ubatch_size} pp={args.max_prompt_tokens} pp+tg={args.max_tokens}\n"
              f"branch={args.branch} commit={args.commit}")

    # Prometheus
    end_time = time.time()
    prometheus_metrics = {}
    if is_server_listening("0.0.0.0", 9090):
        metrics = ['prompt_tokens_seconds', 'predicted_tokens_seconds',
                   'kv_cache_usage_ratio', 'requests_processing', 'requests_deferred']

        for metric in metrics:
            resp = requests.get("http://localhost:9090/api/v1/query_range",
                                params={'query': 'llamacpp:' + metric, 'start': start_time, 'end': end_time, 'step': 2})

            with open(f"{metric}.json", 'w') as metric_json:
                metric_json.write(resp.text)

            if resp.status_code != 200:
                print(f"bench: unable to extract prometheus metric {metric}: {resp.text}")
            else:
                metric_data = resp.json()
                results = metric_data['data']['result']
                if not results:
                    # nothing recorded for this metric over the run
                    print(f"bench: no data points returned for prometheus metric {metric}")
                    continue
                timestamps, metric_values = zip(*results[0]['values'])
                metric_values = [float(value) for value in metric_values]
                prometheus_metrics[metric] = metric_values
                # plot datetime objects so the date locator/formatter below produce real labels
                timestamps_dt = [datetime.fromtimestamp(int(ts)) for ts in timestamps]
                plt.figure(figsize=(16, 10), dpi=80)
                plt.plot(timestamps_dt, metric_values, label=metric)
                plt.xticks(rotation=0, fontsize=14, horizontalalignment='center', alpha=.7)
                plt.yticks(fontsize=12, alpha=.7)

                ylabel = f"llamacpp:{metric}"
                plt.title(title, fontsize=14, wrap=True)
                plt.grid(axis='both', alpha=.3)
                plt.ylabel(ylabel, fontsize=22)
                plt.xlabel(xlabel, fontsize=14, wrap=True)
                plt.gca().xaxis.set_major_locator(matplotlib.dates.MinuteLocator())
                plt.gca().xaxis.set_major_formatter(matplotlib.dates.DateFormatter("%Y-%m-%d %H:%M:%S"))
                plt.gcf().autofmt_xdate()

                # Remove borders
                plt.gca().spines["top"].set_alpha(0.0)
                plt.gca().spines["bottom"].set_alpha(0.3)
                plt.gca().spines["right"].set_alpha(0.0)
                plt.gca().spines["left"].set_alpha(0.3)

                # Save the plot as a jpg image
                plt.savefig(f'{metric}.jpg', dpi=60)
                plt.close()

                # Mermaid format as a fallback in case image upload fails
                with open(f"{metric}.mermaid", 'w') as mermaid_f:
                    mermaid = (
                    f"""---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "{title}"
    y-axis "llamacpp:{metric}"
    x-axis "unix timestamp" {int(min(timestamps))} --> {int(max(timestamps))}
    line [{', '.join([str(round(float(value), 2)) for value in metric_values])}]
                    """)
                    mermaid_f.write(mermaid)

    # 140 chars max for commit status description
    bench_results = {
        "i": iterations,
        "req": {
            "p95": round(data['metrics']["http_req_duration"]["p(95)"], 2),
            "avg": round(data['metrics']["http_req_duration"]["avg"], 2),
        },
        "pp": {
            "p95": round(data['metrics']["llamacpp_prompt_processing_second"]["p(95)"], 2),
            "avg": round(data['metrics']["llamacpp_prompt_processing_second"]["avg"], 2),
            "0": round(mean(prometheus_metrics['prompt_tokens_seconds']), 2) if 'prompt_tokens_seconds' in prometheus_metrics else 0,
        },
        "tg": {
            "p95": round(data['metrics']["llamacpp_tokens_second"]["p(95)"], 2),
            "avg": round(data['metrics']["llamacpp_tokens_second"]["avg"], 2),
            "0": round(mean(prometheus_metrics['predicted_tokens_seconds']), 2) if 'predicted_tokens_seconds' in prometheus_metrics else 0,
        },
    }
    with open("results.github.env", 'a') as github_env:
        github_env.write(f"BENCH_RESULTS={json.dumps(bench_results, indent=None, separators=(',', ':'))}\n")
        github_env.write(f"BENCH_ITERATIONS={iterations}\n")

        title = title.replace('\n', ' ')
        xlabel = xlabel.replace('\n', ' ')
        github_env.write(f"BENCH_GRAPH_TITLE={title}\n")
        github_env.write(f"BENCH_GRAPH_XLABEL={xlabel}\n")


def start_benchmark(args):
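    """Run the k6 scenario against the server and export the summary to k6-results.json."""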
    k6_path = './k6'
    if 'BENCH_K6_BIN_PATH' in os.environ:
        k6_path = os.environ['BENCH_K6_BIN_PATH']
    k6_args = [
        'run', args.scenario,
        '--no-color',
        '--no-connection-reuse',
        '--no-vu-connection-reuse',
    ]
    k6_args.extend(['--duration', args.duration])
    k6_args.extend(['--iterations', args.n_prompts])
    k6_args.extend(['--vus', args.parallel])
    k6_args.extend(['--summary-export', 'k6-results.json'])
    k6_args.extend(['--out', 'csv=k6-results.csv'])
    # pass the dataset filters to the scenario through environment variables
    command = f"SERVER_BENCH_N_PROMPTS={args.n_prompts} SERVER_BENCH_MAX_PROMPT_TOKENS={args.max_prompt_tokens} SERVER_BENCH_MAX_CONTEXT={args.max_tokens} "
    command += ' '.join([str(arg) for arg in [k6_path, *k6_args]])
    print(f"bench: starting k6 with: {command}")
    k6_completed = subprocess.run(command, shell=True, stdout=sys.stdout, stderr=sys.stderr)
    if k6_completed.returncode != 0:
        raise Exception("bench: unable to run k6")


def start_server(args):
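    """Start llama-server in the background and wait until it is listening and reports healthy."""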
    server_process = start_server_background(args)

    attempts = 0
    max_attempts = 600
    if 'GITHUB_ACTIONS' in os.environ:
        max_attempts *= 2

    while not is_server_listening(args.host, args.port):
        attempts += 1
        if attempts > max_attempts:
            raise Exception("bench: server not started")
        print("bench:     waiting for server to start ...")
        time.sleep(0.5)

    attempts = 0
    while not is_server_ready(args.host, args.port):
        attempts += 1
        if attempts > max_attempts:
            raise Exception("bench: server not ready")
        print("bench:     waiting for server to be ready ...")
        time.sleep(0.5)

    print("bench: server started and ready.")
    return server_process


def start_server_background(args):
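    """Spawn llama-server with the benchmark arguments and relay its output to stdout/stderr."""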
    # Start the server
    server_path = '../../../build/bin/llama-server'
    if 'LLAMA_SERVER_BIN_PATH' in os.environ:
        server_path = os.environ['LLAMA_SERVER_BIN_PATH']
    server_args = [
        '--host', args.host,
        '--port', args.port,
    ]
    server_args.extend(['--hf-repo', args.hf_repo])
    server_args.extend(['--hf-file', args.hf_file])
    server_args.extend(['--n-gpu-layers', args.n_gpu_layers])
    server_args.extend(['--ctx-size', args.ctx_size])
    server_args.extend(['--parallel', args.parallel])
    server_args.extend(['--batch-size', args.batch_size])
    server_args.extend(['--ubatch-size', args.ubatch_size])
    server_args.extend(['--n-predict', args.max_tokens * 2])
    server_args.append('--cont-batching')
    server_args.append('--metrics')
    server_args.append('--flash-attn')
    command = [str(arg) for arg in [server_path, *server_args]]
    print(f"bench: starting server with: {' '.join(command)}")
    server_process = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE)

    def server_log(in_stream, out_stream):
        # relay the server output line by line until the stream closes
        for line in iter(in_stream.readline, b''):
            print(line.decode('utf-8'), end='', file=out_stream)

    thread_stdout = threading.Thread(target=server_log, args=(server_process.stdout, sys.stdout))
    thread_stdout.start()
    thread_stderr = threading.Thread(target=server_log, args=(server_process.stderr, sys.stderr))
    thread_stderr.start()

    return server_process


def is_server_listening(server_fqdn, server_port):
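    """Return True if a TCP connection to server_fqdn:server_port succeeds."""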
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        result = sock.connect_ex((server_fqdn, server_port))
        _is_server_listening = result == 0
        if _is_server_listening:
            print(f"server is listening on {server_fqdn}:{server_port}...")
        return _is_server_listening


def is_server_ready(server_fqdn, server_port):
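    """Return True if the server's /health endpoint responds with HTTP 200."""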
    url = f"http://{server_fqdn}:{server_port}/health"
    try:
        response = requests.get(url)
    except requests.exceptions.ConnectionError:
        # the port may be open while the HTTP endpoint is not up yet
        return False
    return response.status_code == 200


def escape_metric_name(metric_name):
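    """Upper-case a metric name and replace every non-alphanumeric character with '_'."""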
    return re.sub('[^A-Z0-9]', '_', metric_name.upper())


if __name__ == '__main__':
    main()