#!/usr/bin/env python3

import argparse
import csv
import heapq
import json
import logging
import os
import sqlite3
import sys
from collections.abc import Iterator, Sequence
from glob import glob
from typing import Optional, Union

try:
    import git
    from tabulate import tabulate
except ImportError as e:
    print("The following Python libraries are required: GitPython, tabulate.")  # noqa: NP100
    raise e


logger = logging.getLogger("compare-llama-bench")

# All llama-bench SQL fields
LLAMA_BENCH_DB_FIELDS = [
    "build_commit", "build_number", "cpu_info", "gpu_info", "backends", "model_filename",
    "model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "n_threads",
    "cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers",
    "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "tensor_buft_overrides",
    "use_mmap", "embeddings", "no_op_offload", "n_prompt", "n_gen", "n_depth",
    "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts", "n_cpu_moe"
]

LLAMA_BENCH_DB_TYPES = [
    "TEXT", "INTEGER", "TEXT", "TEXT", "TEXT", "TEXT",
    "TEXT", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER",
    "TEXT", "INTEGER", "INTEGER", "TEXT", "TEXT", "INTEGER",
    "TEXT", "INTEGER", "INTEGER", "INTEGER", "TEXT", "TEXT",
    "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER",
    "TEXT", "INTEGER", "INTEGER", "REAL", "REAL", "INTEGER",
]

# All test-backend-ops SQL fields
TEST_BACKEND_OPS_DB_FIELDS = [
    "test_time", "build_commit", "backend_name", "op_name", "op_params", "test_mode",
    "supported", "passed", "error_message", "time_us", "flops", "bandwidth_gb_s",
    "memory_kb", "n_runs"
]

TEST_BACKEND_OPS_DB_TYPES = [
    "TEXT", "TEXT", "TEXT", "TEXT", "TEXT", "TEXT",
    "INTEGER", "INTEGER", "TEXT", "REAL", "REAL", "REAL",
    "INTEGER", "INTEGER"
]

assert len(LLAMA_BENCH_DB_FIELDS) == len(LLAMA_BENCH_DB_TYPES)
assert len(TEST_BACKEND_OPS_DB_FIELDS) == len(TEST_BACKEND_OPS_DB_TYPES)

# Properties by which to differentiate results per commit for llama-bench:
LLAMA_BENCH_KEY_PROPERTIES = [
    "cpu_info", "gpu_info", "backends", "n_gpu_layers", "n_cpu_moe", "tensor_buft_overrides", "model_filename", "model_type",
    "n_batch", "n_ubatch", "embeddings", "cpu_mask", "cpu_strict", "poll", "n_threads", "type_k", "type_v",
    "use_mmap", "no_kv_offload", "split_mode", "main_gpu", "tensor_split", "flash_attn", "n_prompt", "n_gen", "n_depth"
]

# Properties by which to differentiate results per commit for test-backend-ops:
TEST_BACKEND_OPS_KEY_PROPERTIES = [
    "backend_name", "op_name", "op_params", "test_mode"
]

# Properties that are boolean and are converted to Yes/No for the table:
LLAMA_BENCH_BOOL_PROPERTIES = ["embeddings", "cpu_strict", "use_mmap", "no_kv_offload", "flash_attn"]
TEST_BACKEND_OPS_BOOL_PROPERTIES = ["supported", "passed"]

# Header names for the table (llama-bench):
LLAMA_BENCH_PRETTY_NAMES = {
    "cpu_info": "CPU", "gpu_info": "GPU", "backends": "Backends", "n_gpu_layers": "GPU layers",
    "tensor_buft_overrides": "Tensor overrides", "model_filename": "File", "model_type": "Model", "model_size": "Model size [GiB]",
    "model_n_params": "Num. of par.", "n_batch": "Batch size", "n_ubatch": "Microbatch size", "embeddings": "Embeddings",
    "cpu_mask": "CPU mask", "cpu_strict": "CPU strict", "poll": "Poll", "n_threads": "Threads", "type_k": "K type", "type_v": "V type",
    "use_mmap": "Use mmap", "no_kv_offload": "NKVO", "split_mode": "Split mode", "main_gpu": "Main GPU", "tensor_split": "Tensor split",
    "flash_attn": "FlashAttention",
}

# Header names for the table (test-backend-ops):
TEST_BACKEND_OPS_PRETTY_NAMES = {
    "backend_name": "Backend", "op_name": "GGML op", "op_params": "Op parameters", "test_mode": "Mode",
    "supported": "Supported", "passed": "Passed", "error_message": "Error",
    "flops": "FLOPS", "bandwidth_gb_s": "Bandwidth (GB/s)", "memory_kb": "Memory (KB)", "n_runs": "Runs"
}

DEFAULT_SHOW_LLAMA_BENCH = ["model_type"]  # Always show these properties by default.
DEFAULT_HIDE_LLAMA_BENCH = ["model_filename"]  # Always hide these properties by default.

DEFAULT_SHOW_TEST_BACKEND_OPS = ["backend_name", "op_name"]  # Always show these properties by default.
DEFAULT_HIDE_TEST_BACKEND_OPS = ["error_message"]  # Always hide these properties by default.

GPU_NAME_STRIP = ["NVIDIA GeForce ", "Tesla ", "AMD Radeon ", "AMD Instinct "]  # Strip prefixes for smaller tables.
MODEL_SUFFIX_REPLACE = {" - Small": "_S", " - Medium": "_M", " - Large": "_L"}
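# Illustrative examples of the two mappings above (hypothetical names, not from real data):
#   "NVIDIA GeForce RTX 4090" -> "RTX 4090"  (GPU_NAME_STRIP prefix removal)
#   "Phi 3 - Medium"          -> "Phi 3_M"   (MODEL_SUFFIX_REPLACE suffix rewrite)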

DESCRIPTION = """Creates tables from llama-bench or test-backend-ops data written to multiple JSON/CSV files, a single JSONL file, or an SQLite database. Example usage (Linux):

For llama-bench:
$ git checkout master
$ cmake -B ${BUILD_DIR} ${CMAKE_OPTS} && cmake --build ${BUILD_DIR} -t llama-bench -j $(nproc)
$ ./llama-bench -o sql | sqlite3 llama-bench.sqlite
$ git checkout some_branch
$ cmake -B ${BUILD_DIR} ${CMAKE_OPTS} && cmake --build ${BUILD_DIR} -t llama-bench -j $(nproc)
$ ./llama-bench -o sql | sqlite3 llama-bench.sqlite
$ ./scripts/compare-llama-bench.py

For test-backend-ops:
$ git checkout master
$ cmake -B ${BUILD_DIR} ${CMAKE_OPTS} && cmake --build ${BUILD_DIR} -t test-backend-ops -j $(nproc)
$ ./test-backend-ops perf --output sql | sqlite3 test-backend-ops.sqlite
$ git checkout some_branch
$ cmake -B ${BUILD_DIR} ${CMAKE_OPTS} && cmake --build ${BUILD_DIR} -t test-backend-ops -j $(nproc)
$ ./test-backend-ops perf --output sql | sqlite3 test-backend-ops.sqlite
$ ./scripts/compare-llama-bench.py --tool test-backend-ops -i test-backend-ops.sqlite

Performance numbers from multiple runs per commit are averaged WITHOUT being weighted by the --repetitions parameter of llama-bench.
"""

parser = argparse.ArgumentParser(
    description=DESCRIPTION, formatter_class=argparse.RawDescriptionHelpFormatter)
help_b = (
    "The baseline commit to compare performance to. "
    "Accepts either a branch name, tag name, or commit hash. "
    "Defaults to the latest master commit with data."
)
parser.add_argument("-b", "--baseline", help=help_b)
help_c = (
    "The commit whose performance is to be compared to the baseline. "
    "Accepts either a branch name, tag name, or commit hash. "
    "Defaults to the non-master commit for which llama-bench was run most recently."
)
parser.add_argument("-c", "--compare", help=help_c)
help_t = (
    "The tool whose data is being compared. "
    "Either 'llama-bench' or 'test-backend-ops'. "
    "This determines the database schema and comparison logic used. "
    "If left unspecified, the tool is inferred from the input file."
)
parser.add_argument("-t", "--tool", help=help_t, default=None, choices=[None, "llama-bench", "test-backend-ops"])
help_i = (
    "JSON/JSONL/SQLite/CSV files for comparing commits. "
    "Specify multiple times to use multiple input files (JSON/CSV only). "
    "Defaults to 'llama-bench.sqlite' in the current working directory. "
    "If no such file is found and there is exactly one .sqlite file in the current directory, "
    "that file is instead used as input."
)
parser.add_argument("-i", "--input", action="append", help=help_i)
help_o = (
    "Output format for the table. "
    "Defaults to 'pipe' (GitHub compatible). "
    "Also supports e.g. 'latex' or 'mediawiki'. "
    "See the tabulate documentation for the full list."
)
parser.add_argument("-o", "--output", help=help_o, default="pipe")
help_s = (
    "Columns to add to the table. "
    "Accepts a comma-separated list of values. "
    f"Legal values for test-backend-ops: {', '.join(TEST_BACKEND_OPS_KEY_PROPERTIES)}. "
    f"Legal values for llama-bench: {', '.join(LLAMA_BENCH_KEY_PROPERTIES[:-3])}. "
    "Defaults to model name (model_type) and CPU and/or GPU name (cpu_info, gpu_info) "
    "plus any column where not all data points are the same. "
    "If the columns are manually specified, then the results for each unique combination of the "
    "specified values are averaged WITHOUT being weighted by the --repetitions parameter of llama-bench."
)
parser.add_argument("--check", action="store_true", help="check if all required Python libraries are installed")
parser.add_argument("-s", "--show", help=help_s)
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
parser.add_argument("--plot", help="generate a performance comparison plot and save it to the specified file (e.g. plot.png)")
parser.add_argument("--plot_x", help="parameter to use as x axis for plotting (default: n_depth)", default="n_depth")
parser.add_argument("--plot_log_scale", action="store_true", help="use log scale for x axis in plots (off by default)")

known_args, unknown_args = parser.parse_known_args()

logging.basicConfig(level=logging.DEBUG if known_args.verbose else logging.INFO)


if known_args.check:
    # Check if all required Python libraries are installed. Would have failed earlier if not.
    sys.exit(0)

if unknown_args:
    logger.error(f"Received unknown args: {unknown_args}.\n")
    parser.print_help()
    sys.exit(1)

input_file = known_args.input
tool = known_args.tool

if not input_file:
    if tool == "llama-bench" and os.path.exists("./llama-bench.sqlite"):
        input_file = ["llama-bench.sqlite"]
    elif tool == "test-backend-ops" and os.path.exists("./test-backend-ops.sqlite"):
        input_file = ["test-backend-ops.sqlite"]

if not input_file:
    sqlite_files = glob("*.sqlite")
    if len(sqlite_files) == 1:
        input_file = sqlite_files

if not input_file:
    logger.error("Cannot find a suitable input file, please provide one.\n")
    parser.print_help()
    sys.exit(1)


class LlamaBenchData:
    repo: Optional[git.Repo]
    build_len_min: int
    build_len_max: int
    build_len: int = 8
    builds: list[str] = []
    tool: str = "llama-bench"  # Tool type: "llama-bench" or "test-backend-ops"

    def __init__(self, tool: str = "llama-bench"):
        self.tool = tool
        try:
            self.repo = git.Repo(".", search_parent_directories=True)
        except git.InvalidGitRepositoryError:
            self.repo = None

        # Set schema-specific properties based on tool
        if self.tool == "llama-bench":
            self.check_keys = set(LLAMA_BENCH_KEY_PROPERTIES + ["build_commit", "test_time", "avg_ts"])
        elif self.tool == "test-backend-ops":
            self.check_keys = set(TEST_BACKEND_OPS_KEY_PROPERTIES + ["build_commit", "test_time"])
        else:
            assert False

    def _builds_init(self):
        self.build_len = self.build_len_min

    def _check_keys(self, keys: set) -> Optional[set]:
        """Private helper method that checks against required data keys and returns missing ones."""
        if not keys >= self.check_keys:
            return self.check_keys - keys
        return None
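    # Illustrative example (hypothetical keys): for tool "llama-bench",
    # _check_keys({"build_commit", "test_time"}) returns the rest of the
    # required keys (e.g. "avg_ts", "cpu_info", ...); any superset of
    # self.check_keys returns None.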

    def find_parent_in_data(self, commit: git.Commit) -> Optional[str]:
        """Helper method to find the most recent parent measured in number of commits for which there is data."""
        heap: list[tuple[int, git.Commit]] = [(0, commit)]
        seen_hexsha8 = set()
        while heap:
            depth, current_commit = heapq.heappop(heap)
            current_hexsha8 = current_commit.hexsha[:self.build_len]
            if current_hexsha8 in self.builds:
                return current_hexsha8
            for parent in current_commit.parents:
                parent_hexsha8 = parent.hexsha[:self.build_len]
                if parent_hexsha8 not in seen_hexsha8:
                    seen_hexsha8.add(parent_hexsha8)
                    heapq.heappush(heap, (depth + 1, parent))
        return None
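    # The heap above orders candidate commits by their distance (in commits)
    # from the starting commit, so the first hit in self.builds is the
    # nearest ancestor with data -- effectively a best-first search over parents.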

    def get_all_parent_hexsha8s(self, commit: git.Commit) -> Sequence[str]:
        """Helper method to recursively get hexsha8 values for all parents of a commit."""
        unvisited = [commit]
        visited = []

        while unvisited:
            current_commit = unvisited.pop(0)
            visited.append(current_commit.hexsha[:self.build_len])
            for parent in current_commit.parents:
                if parent.hexsha[:self.build_len] not in visited:
                    unvisited.append(parent)

        return visited

    def get_commit_name(self, hexsha8: str) -> str:
        """Helper method to find a human-readable name for a commit if possible."""
        if self.repo is None:
            return hexsha8
        for h in self.repo.heads:
            if h.commit.hexsha[:self.build_len] == hexsha8:
                return h.name
        for t in self.repo.tags:
            if t.commit.hexsha[:self.build_len] == hexsha8:
                return t.name
        return hexsha8

    def get_commit_hexsha8(self, name: str) -> Optional[str]:
        """Helper method to search for a commit given a human-readable name."""
        if self.repo is None:
            return None
        for h in self.repo.heads:
            if h.name == name:
                return h.commit.hexsha[:self.build_len]
        for t in self.repo.tags:
            if t.name == name:
                return t.commit.hexsha[:self.build_len]
        for c in self.repo.iter_commits("--all"):
            if c.hexsha[:self.build_len] == name[:self.build_len]:
                return c.hexsha[:self.build_len]
        return None

    def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequence[tuple]]:
        """Helper method that gets rows of (build_commit, test_time) sorted by the latter."""
        return []

    def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]:
        """
        Helper method that gets table rows for some list of properties.
        Rows are created by combining those where all provided properties are equal.
        The resulting rows are then grouped by the provided properties and the t/s values are averaged.
        The returned rows are unique in terms of property combinations.
        """
        return []


class LlamaBenchDataSQLite3(LlamaBenchData):
    connection: Optional[sqlite3.Connection] = None
    cursor: sqlite3.Cursor
    table_name: str

    def __init__(self, tool: str = "llama-bench"):
        super().__init__(tool)
        if self.connection is None:
            self.connection = sqlite3.connect(":memory:")
            self.cursor = self.connection.cursor()

            # Set table name and schema based on tool
            if self.tool == "llama-bench":
                self.table_name = "llama_bench"
                db_fields = LLAMA_BENCH_DB_FIELDS
                db_types = LLAMA_BENCH_DB_TYPES
            elif self.tool == "test-backend-ops":
                self.table_name = "test_backend_ops"
                db_fields = TEST_BACKEND_OPS_DB_FIELDS
                db_types = TEST_BACKEND_OPS_DB_TYPES
            else:
                assert False

            self.cursor.execute(f"CREATE TABLE {self.table_name}({', '.join(' '.join(x) for x in zip(db_fields, db_types))});")

    def _builds_init(self):
        if self.connection:
            self.build_len_min = self.cursor.execute(f"SELECT MIN(LENGTH(build_commit)) FROM {self.table_name};").fetchone()[0]
            self.build_len_max = self.cursor.execute(f"SELECT MAX(LENGTH(build_commit)) FROM {self.table_name};").fetchone()[0]

            if self.build_len_min != self.build_len_max:
                logger.warning("Data contains commit hashes of differing lengths. It's possible that the wrong commits will be compared. "
                               "Try purging the database of old commits.")
                self.cursor.execute(f"UPDATE {self.table_name} SET build_commit = SUBSTRING(build_commit, 1, {self.build_len_min});")

            builds = self.cursor.execute(f"SELECT DISTINCT build_commit FROM {self.table_name};").fetchall()
            self.builds = list(map(lambda b: b[0], builds))  # list[tuple[str]] -> list[str]
        super()._builds_init()

    def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequence[tuple]]:
        data = self.cursor.execute(
            f"SELECT build_commit, test_time FROM {self.table_name} ORDER BY test_time;").fetchall()
        return reversed(data) if reverse else data

    def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]:
        if self.tool == "llama-bench":
            return self._get_rows_llama_bench(properties, hexsha8_baseline, hexsha8_compare)
        elif self.tool == "test-backend-ops":
            return self._get_rows_test_backend_ops(properties, hexsha8_baseline, hexsha8_compare)
        else:
            assert False

    def _get_rows_llama_bench(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]:
        select_string = ", ".join(
            [f"tb.{p}" for p in properties] + ["tb.n_prompt", "tb.n_gen", "tb.n_depth", "AVG(tb.avg_ts)", "AVG(tc.avg_ts)"])
        equal_string = " AND ".join(
            [f"tb.{p} = tc.{p}" for p in LLAMA_BENCH_KEY_PROPERTIES] + [
                f"tb.build_commit = '{hexsha8_baseline}'", f"tc.build_commit = '{hexsha8_compare}'"]
        )
        group_order_string = ", ".join([f"tb.{p}" for p in properties] + ["tb.n_gen", "tb.n_prompt", "tb.n_depth"])
        query = (f"SELECT {select_string} FROM {self.table_name} tb JOIN {self.table_name} tc ON {equal_string} "
                 f"GROUP BY {group_order_string} ORDER BY {group_order_string};")
        return self.cursor.execute(query).fetchall()
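    # A minimal sketch of the query _get_rows_llama_bench generates, assuming
    # properties=["model_type"] (placeholders, for illustration only):
    #   SELECT tb.model_type, tb.n_prompt, tb.n_gen, tb.n_depth,
    #          AVG(tb.avg_ts), AVG(tc.avg_ts)
    #   FROM llama_bench tb JOIN llama_bench tc
    #   ON tb.cpu_info = tc.cpu_info AND ... AND tb.build_commit = '<baseline>'
    #      AND tc.build_commit = '<compare>'
    #   GROUP BY tb.model_type, tb.n_gen, tb.n_prompt, tb.n_depth
    #   ORDER BY tb.model_type, tb.n_gen, tb.n_prompt, tb.n_depth;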

    def _get_rows_test_backend_ops(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]:
        # For test-backend-ops, we compare FLOPS and bandwidth metrics (prioritizing FLOPS over bandwidth)
        select_string = ", ".join(
            [f"tb.{p}" for p in properties] + [
                "AVG(tb.flops)", "AVG(tc.flops)",
                "AVG(tb.bandwidth_gb_s)", "AVG(tc.bandwidth_gb_s)"
            ])
        equal_string = " AND ".join(
            [f"tb.{p} = tc.{p}" for p in TEST_BACKEND_OPS_KEY_PROPERTIES] + [
                f"tb.build_commit = '{hexsha8_baseline}'", f"tc.build_commit = '{hexsha8_compare}'",
                "tb.supported = 1", "tc.supported = 1", "tb.passed = 1", "tc.passed = 1"]  # Only compare successful tests
        )
        group_order_string = ", ".join([f"tb.{p}" for p in properties])
        query = (f"SELECT {select_string} FROM {self.table_name} tb JOIN {self.table_name} tc ON {equal_string} "
                 f"GROUP BY {group_order_string} ORDER BY {group_order_string};")
        return self.cursor.execute(query).fetchall()


class LlamaBenchDataSQLite3File(LlamaBenchDataSQLite3):
    def __init__(self, data_file: str, tool: Optional[str]):
        self.connection = sqlite3.connect(data_file)
        self.cursor = self.connection.cursor()

        # Check which table exists in the database
        tables = self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall()
        table_names = [table[0] for table in tables]

        # Tool selection logic
        if tool is None:
            if "llama_bench" in table_names:
                self.table_name = "llama_bench"
                tool = "llama-bench"
            elif "test_backend_ops" in table_names:
                self.table_name = "test_backend_ops"
                tool = "test-backend-ops"
            else:
                raise RuntimeError(f"No suitable table found in database. Available tables: {table_names}")
        elif tool == "llama-bench":
            if "llama_bench" in table_names:
                self.table_name = "llama_bench"
            else:
                raise RuntimeError(f"Table 'llama_bench' not found for tool 'llama-bench'. Available tables: {table_names}")
        elif tool == "test-backend-ops":
            if "test_backend_ops" in table_names:
                self.table_name = "test_backend_ops"
            else:
                raise RuntimeError(f"Table 'test_backend_ops' not found for tool 'test-backend-ops'. Available tables: {table_names}")
        else:
            raise RuntimeError(f"Unknown tool: {tool}")

        super().__init__(tool)
        self._builds_init()

    @staticmethod
    def valid_format(data_file: str) -> bool:
        connection = sqlite3.connect(data_file)
        cursor = connection.cursor()

        try:
            if cursor.execute("PRAGMA schema_version;").fetchone()[0] == 0:
                raise sqlite3.DatabaseError("The provided input file does not exist or is empty.")
        except sqlite3.DatabaseError as e:
            logger.debug(f'"{data_file}" is not a valid SQLite3 file.', exc_info=e)
            cursor = None

        connection.close()
        return cursor is not None
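    # Note on valid_format: sqlite3.connect() is lazy, so the
    # "PRAGMA schema_version;" query above is what actually forces a read of
    # the file; it reports 0 for a freshly created (i.e. empty) database,
    # which is treated as invalid input here.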


class LlamaBenchDataJSONL(LlamaBenchDataSQLite3):
    def __init__(self, data_file: str, tool: str = "llama-bench"):
        super().__init__(tool)

        # Get the appropriate field list based on tool
        db_fields = LLAMA_BENCH_DB_FIELDS if tool == "llama-bench" else TEST_BACKEND_OPS_DB_FIELDS

        with open(data_file, "r", encoding="utf-8") as fp:
            for i, line in enumerate(fp):
                parsed = json.loads(line)

                for k in parsed.keys() - set(db_fields):
                    del parsed[k]

                if (missing_keys := self._check_keys(parsed.keys())):
                    raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}")

                self.cursor.execute(f"INSERT INTO {self.table_name}({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values()))

        self._builds_init()

    @staticmethod
    def valid_format(data_file: str) -> bool:
        try:
            with open(data_file, "r", encoding="utf-8") as fp:
                for line in fp:
                    json.loads(line)
                    break
        except Exception as e:
            logger.debug(f'"{data_file}" is not a valid JSONL file.', exc_info=e)
            return False

        return True


class LlamaBenchDataJSON(LlamaBenchDataSQLite3):
    def __init__(self, data_files: list[str], tool: str = "llama-bench"):
        super().__init__(tool)

        # Get the appropriate field list based on tool
        db_fields = LLAMA_BENCH_DB_FIELDS if tool == "llama-bench" else TEST_BACKEND_OPS_DB_FIELDS

        for data_file in data_files:
            with open(data_file, "r", encoding="utf-8") as fp:
                parsed = json.load(fp)

            for i, entry in enumerate(parsed):
                for k in entry.keys() - set(db_fields):
                    del entry[k]

                if (missing_keys := self._check_keys(entry.keys())):
                    raise RuntimeError(f"Missing required data key(s) at entry {i + 1}: {', '.join(missing_keys)}")

                self.cursor.execute(f"INSERT INTO {self.table_name}({', '.join(entry.keys())}) VALUES({', '.join('?' * len(entry))});", tuple(entry.values()))

        self._builds_init()

    @staticmethod
    def valid_format(data_files: list[str]) -> bool:
        if not data_files:
            return False

        for data_file in data_files:
            try:
                with open(data_file, "r", encoding="utf-8") as fp:
                    json.load(fp)
            except Exception as e:
                logger.debug(f'"{data_file}" is not a valid JSON file.', exc_info=e)
                return False

        return True


class LlamaBenchDataCSV(LlamaBenchDataSQLite3):
    def __init__(self, data_files: list[str], tool: str = "llama-bench"):
        super().__init__(tool)

        # Get the appropriate field list based on tool
        db_fields = LLAMA_BENCH_DB_FIELDS if tool == "llama-bench" else TEST_BACKEND_OPS_DB_FIELDS

        for data_file in data_files:
            with open(data_file, "r", encoding="utf-8") as fp:
                for i, parsed in enumerate(csv.DictReader(fp)):
                    keys = set(parsed.keys())

                    for k in keys - set(db_fields):
                        del parsed[k]

                    if (missing_keys := self._check_keys(keys)):
                        raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}")

                    self.cursor.execute(f"INSERT INTO {self.table_name}({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values()))

        self._builds_init()

    @staticmethod
    def valid_format(data_files: list[str]) -> bool:
        if not data_files:
            return False

        for data_file in data_files:
            try:
                with open(data_file, "r", encoding="utf-8") as fp:
                    for _ in csv.DictReader(fp):
                        break
            except Exception as e:
                logger.debug(f'"{data_file}" is not a valid CSV file.', exc_info=e)
                return False

        return True


def format_flops(flops_value: float) -> str:
    """Format FLOPS values with appropriate units for better readability."""
    if flops_value == 0:
        return "0.00"

    # Define unit thresholds and names
    units = [
        (1e12, "T"),  # TeraFLOPS
        (1e9, "G"),   # GigaFLOPS
        (1e6, "M"),   # MegaFLOPS
        (1e3, "k"),   # kiloFLOPS
        (1, "")       # FLOPS
    ]

    for threshold, unit in units:
        if abs(flops_value) >= threshold:
            formatted_value = flops_value / threshold
            if formatted_value >= 100:
                return f"{formatted_value:.1f}{unit}"
            else:
                return f"{formatted_value:.2f}{unit}"

    # Fallback for very small values
    return f"{flops_value:.2f}"
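# Illustrative examples (given the thresholds above):
#   format_flops(1.5e9)   -> "1.50G"
#   format_flops(123.4e9) -> "123.4G"  (>= 100 units, so one decimal place)
#   format_flops(0.5)     -> "0.50"    (below all thresholds, fallback path)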


def format_flops_for_table(flops_value: float, target_unit: str) -> str:
    """Format FLOPS values for table display without unit suffix (since the unit is in the header)."""
    if flops_value == 0:
        return "0.00"

    # Define unit divisors based on target unit
    unit_divisors = {
        "TFLOPS": 1e12,
        "GFLOPS": 1e9,
        "MFLOPS": 1e6,
        "kFLOPS": 1e3,
        "FLOPS": 1
    }

    divisor = unit_divisors.get(target_unit, 1)
    formatted_value = flops_value / divisor

    if formatted_value >= 100:
        return f"{formatted_value:.1f}"
    else:
        return f"{formatted_value:.2f}"
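# Illustrative examples with target_unit="GFLOPS" (divisor 1e9):
#   format_flops_for_table(2.5e9, "GFLOPS")   -> "2.50"
#   format_flops_for_table(250.0e9, "GFLOPS") -> "250.0"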


def get_flops_unit_name(flops_values: list) -> str:
    """Determine the best FLOPS unit name based on the magnitude of values."""
    if not flops_values or all(v == 0 for v in flops_values):
        return "FLOPS"

    # Find the maximum absolute value to determine the appropriate unit
    max_flops = max(abs(v) for v in flops_values if v != 0)

    if max_flops >= 1e12:
        return "TFLOPS"
    elif max_flops >= 1e9:
        return "GFLOPS"
    elif max_flops >= 1e6:
        return "MFLOPS"
    elif max_flops >= 1e3:
        return "kFLOPS"
    else:
        return "FLOPS"
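# Illustrative example: get_flops_unit_name([2.0e9, 8.0e11]) -> "GFLOPS",
# since the largest magnitude (8.0e11) is below the 1e12 TFLOPS threshold.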


bench_data = None
if len(input_file) == 1:
    # Unlike SQLite files, JSON/JSONL/CSV inputs carry no table name from which
    # the tool could be inferred, so fall back to llama-bench if --tool was not given:
    if LlamaBenchDataSQLite3File.valid_format(input_file[0]):
        bench_data = LlamaBenchDataSQLite3File(input_file[0], tool)
    elif LlamaBenchDataJSON.valid_format(input_file):
        bench_data = LlamaBenchDataJSON(input_file, tool or "llama-bench")
    elif LlamaBenchDataJSONL.valid_format(input_file[0]):
        bench_data = LlamaBenchDataJSONL(input_file[0], tool or "llama-bench")
    elif LlamaBenchDataCSV.valid_format(input_file):
        bench_data = LlamaBenchDataCSV(input_file, tool or "llama-bench")
else:
    if LlamaBenchDataJSON.valid_format(input_file):
        bench_data = LlamaBenchDataJSON(input_file, tool or "llama-bench")
    elif LlamaBenchDataCSV.valid_format(input_file):
        bench_data = LlamaBenchDataCSV(input_file, tool or "llama-bench")

if not bench_data:
    raise RuntimeError("No valid input files found (or at least one input file was invalid).")

if not bench_data.builds:
    raise RuntimeError(f"{input_file} does not contain any builds.")

tool = bench_data.tool  # May have chosen a default if tool was None.


hexsha8_baseline = name_baseline = None

# If the user specified a baseline, try to find a commit for it:
if known_args.baseline is not None:
    if known_args.baseline in bench_data.builds:
        hexsha8_baseline = known_args.baseline
    if hexsha8_baseline is None:
        hexsha8_baseline = bench_data.get_commit_hexsha8(known_args.baseline)
        name_baseline = known_args.baseline
    if hexsha8_baseline is None:
        logger.error(f"cannot find data for baseline={known_args.baseline}.")
        sys.exit(1)
# Otherwise, search for the most recent parent of master for which there is data:
elif bench_data.repo is not None:
    hexsha8_baseline = bench_data.find_parent_in_data(bench_data.repo.heads.master.commit)

    if hexsha8_baseline is None:
        logger.error("No baseline was provided and did not find data for any master branch commits.\n")
        parser.print_help()
        sys.exit(1)
else:
    logger.error("No baseline was provided and the current working directory "
                 "is not part of a git repository from which a baseline could be inferred.\n")
    parser.print_help()
    sys.exit(1)


name_baseline = bench_data.get_commit_name(hexsha8_baseline)

hexsha8_compare = name_compare = None

# If the user has specified a compare value, try to find a corresponding commit:
if known_args.compare is not None:
    if known_args.compare in bench_data.builds:
        hexsha8_compare = known_args.compare
    if hexsha8_compare is None:
        hexsha8_compare = bench_data.get_commit_hexsha8(known_args.compare)
        name_compare = known_args.compare
    if hexsha8_compare is None:
        logger.error(f"cannot find data for compare={known_args.compare}.")
        sys.exit(1)
# Otherwise, search for the commit for which llama-bench was most recently run
# and that is not a parent of master:
elif bench_data.repo is not None:
    hexsha8s_master = bench_data.get_all_parent_hexsha8s(bench_data.repo.heads.master.commit)
    for (hexsha8, _) in bench_data.builds_timestamp(reverse=True):
        if hexsha8 not in hexsha8s_master:
            hexsha8_compare = hexsha8
            break

    if hexsha8_compare is None:
        logger.error("No compare target was provided and did not find data for any non-master commits.\n")
        parser.print_help()
        sys.exit(1)
else:
    logger.error("No compare target was provided and the current working directory "
                 "is not part of a git repository from which a compare target could be inferred.\n")
    parser.print_help()
    sys.exit(1)

name_compare = bench_data.get_commit_name(hexsha8_compare)

# Get tool-specific configuration
if tool == "llama-bench":
    key_properties = LLAMA_BENCH_KEY_PROPERTIES
    bool_properties = LLAMA_BENCH_BOOL_PROPERTIES
    pretty_names = LLAMA_BENCH_PRETTY_NAMES
    default_show = DEFAULT_SHOW_LLAMA_BENCH
    default_hide = DEFAULT_HIDE_LLAMA_BENCH
elif tool == "test-backend-ops":
    key_properties = TEST_BACKEND_OPS_KEY_PROPERTIES
    bool_properties = TEST_BACKEND_OPS_BOOL_PROPERTIES
    pretty_names = TEST_BACKEND_OPS_PRETTY_NAMES
    default_show = DEFAULT_SHOW_TEST_BACKEND_OPS
    default_hide = DEFAULT_HIDE_TEST_BACKEND_OPS
else:
    assert False

# If the user provided columns to group the results by, use them:
if known_args.show is not None:
    show = known_args.show.split(",")
    unknown_cols = []
    for prop in show:
        valid_props = key_properties if tool == "test-backend-ops" else key_properties[:-3]  # Exclude n_prompt, n_gen, n_depth for llama-bench
        if prop not in valid_props:
            unknown_cols.append(prop)
    if unknown_cols:
        logger.error(f"Unknown values for --show: {', '.join(unknown_cols)}")
        parser.print_usage()
        sys.exit(1)
    rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare)
# Otherwise, select those columns where the values are not all the same:
else:
    rows_full = bench_data.get_rows(key_properties, hexsha8_baseline, hexsha8_compare)
    properties_different = []

    if tool == "llama-bench":
        # For llama-bench, skip n_prompt, n_gen, n_depth in the differentiation logic:
        for i, kp_i in enumerate(key_properties):
            if kp_i in default_show or kp_i in ["n_prompt", "n_gen", "n_depth"]:
                continue
            for row_full in rows_full:
                if row_full[i] != rows_full[0][i]:
                    properties_different.append(kp_i)
                    break
    elif tool == "test-backend-ops":
        # For test-backend-ops, check all key properties
        for i, kp_i in enumerate(key_properties):
            if kp_i in default_show:
                continue
            for row_full in rows_full:
                if row_full[i] != rows_full[0][i]:
                    properties_different.append(kp_i)
                    break
    else:
        assert False

    show = []

    if tool == "llama-bench":
        # Show CPU and/or GPU by default even if the hardware for all results is the same:
        if rows_full and "n_gpu_layers" not in properties_different:
            ngl = int(rows_full[0][key_properties.index("n_gpu_layers")])

            if ngl != 99 and "cpu_info" not in properties_different:
                show.append("cpu_info")

        show += properties_different

        index_default = 0
        for prop in ["cpu_info", "gpu_info", "n_gpu_layers", "main_gpu"]:
            if prop in show:
                index_default += 1
        show = show[:index_default] + default_show + show[index_default:]
    elif tool == "test-backend-ops":
        show = default_show + properties_different
    else:
        assert False

    for prop in default_hide:
        try:
            show.remove(prop)
        except ValueError:
            pass

    # Add the plot_x parameter to the shown properties if it's not already present:
    if known_args.plot:
        for k, v in pretty_names.items():
            if v == known_args.plot_x and k not in show:
                show.append(k)
                break

    rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare)

if not rows_show:
    logger.error(f"No comparable data was found between {name_baseline} and {name_compare}.\n")
    sys.exit(1)

table = []
primary_metric = "FLOPS"  # Default to FLOPS for test-backend-ops

if tool == "llama-bench":
    # For llama-bench, create test names and compare avg_ts values
    for row in rows_show:
        n_prompt = int(row[-5])
        n_gen = int(row[-4])
        n_depth = int(row[-3])
        if n_prompt != 0 and n_gen == 0:
            test_name = f"pp{n_prompt}"
        elif n_prompt == 0 and n_gen != 0:
            test_name = f"tg{n_gen}"
        else:
            test_name = f"pp{n_prompt}+tg{n_gen}"
        if n_depth != 0:
            test_name = f"{test_name}@d{n_depth}"
        #           Regular columns   Test name     avg t/s values    Speedup
        #           VVVVVVVVVVVVVV    VVVVVVVVV     VVVVVVVVVVVVVV    VVVVVVV
        table.append(list(row[:-5]) + [test_name] + list(row[-2:]) + [float(row[-1]) / float(row[-2])])
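    # Illustrative test names produced above (hypothetical parameter values):
    #   n_prompt=512, n_gen=0,   n_depth=0    -> "pp512"
    #   n_prompt=0,   n_gen=128, n_depth=0    -> "tg128"
    #   n_prompt=512, n_gen=128, n_depth=1024 -> "pp512+tg128@d1024"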
elif tool == "test-backend-ops":
    # Determine the primary metric by checking rows until we find one with valid data
    if rows_show:
        primary_metric = "FLOPS"  # Default to FLOPS
        flops_values = []

        # Collect all FLOPS values to determine the best unit
        for sample_row in rows_show:
            baseline_flops = float(sample_row[-4])
            compare_flops = float(sample_row[-3])
            baseline_bandwidth = float(sample_row[-2])

            if baseline_flops > 0:
                flops_values.extend([baseline_flops, compare_flops])
            elif baseline_bandwidth > 0 and not flops_values:
                primary_metric = "Bandwidth (GB/s)"

        # If we have FLOPS data, determine the appropriate unit
        if flops_values:
            primary_metric = get_flops_unit_name(flops_values)

    # For test-backend-ops, prioritize FLOPS > bandwidth for comparison
    for row in rows_show:
        # Extract metrics: flops, bandwidth_gb_s (baseline and compare)
        baseline_flops = float(row[-4])
        compare_flops = float(row[-3])
        baseline_bandwidth = float(row[-2])
        compare_bandwidth = float(row[-1])

        # Determine which metric to use for comparison (prioritize FLOPS > bandwidth)
        if baseline_flops > 0 and compare_flops > 0:
            # Use FLOPS comparison (higher is better)
            speedup = compare_flops / baseline_flops
            baseline_str = format_flops_for_table(baseline_flops, primary_metric)
            compare_str = format_flops_for_table(compare_flops, primary_metric)
        elif baseline_bandwidth > 0 and compare_bandwidth > 0:
            # Use bandwidth comparison (higher is better)
            speedup = compare_bandwidth / baseline_bandwidth
            baseline_str = f"{baseline_bandwidth:.2f}"
            compare_str = f"{compare_bandwidth:.2f}"
        else:
            # Fallback if no valid data is available
            baseline_str = "N/A"
            compare_str = "N/A"
            speedup = float("nan")

        table.append(list(row[:-4]) + [baseline_str, compare_str, speedup])
else:
    assert False

# Some a posteriori fixes to make the table contents prettier:
for bool_property in bool_properties:
    if bool_property in show:
        ip = show.index(bool_property)
        for row_table in table:
            row_table[ip] = "Yes" if int(row_table[ip]) == 1 else "No"

if tool == "llama-bench":
    if "model_type" in show:
        ip = show.index("model_type")
        for (old, new) in MODEL_SUFFIX_REPLACE.items():
            for row_table in table:
                row_table[ip] = row_table[ip].replace(old, new)

    if "model_size" in show:
        ip = show.index("model_size")
        for row_table in table:
            row_table[ip] = float(row_table[ip]) / 1024 ** 3

    if "gpu_info" in show:
        ip = show.index("gpu_info")
        for row_table in table:
            for gns in GPU_NAME_STRIP:
                row_table[ip] = row_table[ip].replace(gns, "")

            gpu_names = row_table[ip].split(", ")
            num_gpus = len(gpu_names)
            all_names_the_same = len(set(gpu_names)) == 1
            if num_gpus >= 2 and all_names_the_same:
                row_table[ip] = f"{num_gpus}x {gpu_names[0]}"

headers = [pretty_names.get(p, p) for p in show]
if tool == "llama-bench":
    headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"]
elif tool == "test-backend-ops":
    headers += [f"{primary_metric} {name_baseline}", f"{primary_metric} {name_compare}", "Speedup"]
else:
    assert False

if known_args.plot:
    def create_performance_plot(table_data: list[list[str]], headers: list[str], baseline_name: str, compare_name: str, output_file: str, plot_x_param: str, log_scale: bool = False, tool_type: str = "llama-bench", metric_name: str = "t/s"):
        try:
            import matplotlib
            import matplotlib.pyplot as plt
            matplotlib.use('Agg')
        except ImportError as e:
            logger.error("matplotlib is required for --plot.")
            raise e

        data_headers = headers[:-4]  # Exclude the last 4 columns (Test, baseline t/s, compare t/s, Speedup)
        plot_x_index = None
        plot_x_label = plot_x_param

        if plot_x_param not in ["n_prompt", "n_gen", "n_depth"]:
            pretty_name = LLAMA_BENCH_PRETTY_NAMES.get(plot_x_param, plot_x_param)
            if pretty_name in data_headers:
                plot_x_index = data_headers.index(pretty_name)
                plot_x_label = pretty_name
            elif plot_x_param in data_headers:
                plot_x_index = data_headers.index(plot_x_param)
                plot_x_label = plot_x_param
            else:
                logger.error(f"Parameter '{plot_x_param}' not found in current table columns. Available columns: {', '.join(data_headers)}")
                return

        grouped_data = {}

        for row in table_data:
            group_key_parts = []
            test_name = row[-4]

            base_test = ""
            x_value = None

            if plot_x_param in ["n_prompt", "n_gen", "n_depth"]:
                for j, val in enumerate(row[:-4]):
                    header_name = data_headers[j]
                    if val is not None and str(val).strip():
                        group_key_parts.append(f"{header_name}={val}")

                if plot_x_param == "n_prompt" and "pp" in test_name:
                    base_test = test_name.split("@")[0]
                    x_value = base_test
                elif plot_x_param == "n_gen" and "tg" in test_name:
                    x_value = test_name.split("@")[0]
                elif plot_x_param == "n_depth" and "@d" in test_name:
                    base_test = test_name.split("@d")[0]
                    x_value = int(test_name.split("@d")[1])
                else:
                    base_test = test_name

                if base_test.strip():
                    group_key_parts.append(f"Test={base_test}")
            else:
                for j, val in enumerate(row[:-4]):
                    if j != plot_x_index:
                        header_name = data_headers[j]
                        if val is not None and str(val).strip():
                            group_key_parts.append(f"{header_name}={val}")
                    else:
                        x_value = val

                group_key_parts.append(f"Test={test_name}")

            group_key = tuple(group_key_parts)

            if group_key not in grouped_data:
                grouped_data[group_key] = []

            grouped_data[group_key].append({
                'x_value': x_value,
                'baseline': float(row[-3]),
                'compare': float(row[-2]),
                'speedup': float(row[-1])
            })

        if not grouped_data:
            logger.error("No data available for plotting")
            return

        def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
            from math import ceil
            cols = 1 if num_groups == 1 else min(max_cols, num_groups)
            rows = ceil(num_groups / cols)

            # Scale figure size by grid dimensions
            w, h = base_size
            fig, ax_arr = plt.subplots(rows, cols,
                                       figsize=(w * cols, h * rows),
                                       squeeze=False)

            axes = ax_arr.flatten()[:num_groups]
            return fig, axes
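        # Illustrative example: make_axes(3) lays out a 2x2 grid (cols=2,
        # rows=2) and returns the first three axes; leftover axes are hidden
        # further below.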

        num_groups = len(grouped_data)
        fig, axes = make_axes(num_groups)

        plot_idx = 0

        for group_key, points in grouped_data.items():
            if plot_idx >= len(axes):
                break
            ax = axes[plot_idx]

            try:
                points_sorted = sorted(points, key=lambda p: float(p['x_value']) if p['x_value'] is not None else 0)
                x_values = [float(p['x_value']) if p['x_value'] is not None else 0 for p in points_sorted]
            except ValueError:
                # Fall back to the original order if the x values are not numeric:
                points_sorted = points
                x_values = [p['x_value'] for p in points_sorted]

            baseline_vals = [p['baseline'] for p in points_sorted]
            compare_vals = [p['compare'] for p in points_sorted]

            ax.plot(x_values, baseline_vals, 'o-', color='skyblue',
                    label=f'{baseline_name}', linewidth=2, markersize=6)
            ax.plot(x_values, compare_vals, 's--', color='lightcoral', alpha=0.8,
                    label=f'{compare_name}', linewidth=2, markersize=6)

            if log_scale:
                ax.set_xscale('log', base=2)
                unique_x = sorted(set(x_values))
                ax.set_xticks(unique_x)
                ax.set_xticklabels([str(int(x)) for x in unique_x])

            title_parts = []
            for part in group_key:
                if '=' in part:
                    key, value = part.split('=', 1)
                    title_parts.append(f"{key}: {value}")

            title = ', '.join(title_parts) if title_parts else "Performance comparison"

            # Determine y-axis label based on tool type
            if tool_type == "llama-bench":
                y_label = "Tokens per second (t/s)"
            elif tool_type == "test-backend-ops":
                y_label = metric_name
            else:
                assert False

            ax.set_xlabel(plot_x_label, fontsize=12, fontweight='bold')
            ax.set_ylabel(y_label, fontsize=12, fontweight='bold')
            ax.set_title(title, fontsize=12, fontweight='bold')
            ax.legend(loc='best', fontsize=10)
            ax.grid(True, alpha=0.3)

            plot_idx += 1

        for i in range(plot_idx, len(axes)):
            axes[i].set_visible(False)

        fig.suptitle(f'Performance comparison: {compare_name} vs. {baseline_name}',
                     fontsize=14, fontweight='bold')
        fig.subplots_adjust(top=1)

        plt.tight_layout()
        plt.savefig(output_file, dpi=300, bbox_inches='tight')
        plt.close()

    create_performance_plot(table, headers, name_baseline, name_compare, known_args.plot, known_args.plot_x, known_args.plot_log_scale, tool, primary_metric)

print(tabulate(  # noqa: NP100
    table,
    headers=headers,
    floatfmt=".2f",
    tablefmt=known_args.output
))