1#!/usr/bin/env python3
   2
   3import argparse
   4import csv
   5import heapq
   6import json
   7import logging
   8import os
   9import sqlite3
  10import sys
  11from collections.abc import Iterator, Sequence
  12from glob import glob
  13from typing import Any, Optional, Union
  14
  15try:
  16    import git
  17    from tabulate import tabulate
  18except ImportError as e:
  19    print("the following Python libraries are required: GitPython, tabulate.") # noqa: NP100
  20    raise e
  21
  22
  23logger = logging.getLogger("compare-llama-bench")
  24
  25# All llama-bench SQL fields
  26LLAMA_BENCH_DB_FIELDS = [
  27    "build_commit", "build_number", "cpu_info",       "gpu_info",   "backends",     "model_filename",
  28    "model_type",   "model_size",   "model_n_params", "n_batch",    "n_ubatch",     "n_threads",
  29    "cpu_mask",     "cpu_strict",   "poll",           "type_k",     "type_v",       "n_gpu_layers",
  30    "split_mode",   "main_gpu",     "no_kv_offload",  "flash_attn", "tensor_split", "tensor_buft_overrides",
  31    "use_mmap",     "embeddings",   "no_op_offload",  "n_prompt",   "n_gen",        "n_depth",
  32    "test_time",    "avg_ns",       "stddev_ns",      "avg_ts",     "stddev_ts",    "n_cpu_moe"
  33]
  34
  35LLAMA_BENCH_DB_TYPES = [
  36    "TEXT",    "INTEGER", "TEXT",    "TEXT",    "TEXT",    "TEXT",
  37    "TEXT",    "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER",
  38    "TEXT",    "INTEGER", "INTEGER", "TEXT",    "TEXT",    "INTEGER",
  39    "TEXT",    "INTEGER", "INTEGER", "INTEGER", "TEXT",    "TEXT",
  40    "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER",
  41    "TEXT",    "INTEGER", "INTEGER", "REAL",    "REAL",    "INTEGER",
  42]
  43
  44# All test-backend-ops SQL fields
  45TEST_BACKEND_OPS_DB_FIELDS = [
  46    "test_time", "build_commit", "backend_name",  "op_name", "op_params", "test_mode",
  47    "supported", "passed",       "error_message", "time_us", "flops",     "bandwidth_gb_s",
  48    "memory_kb", "n_runs"
  49]
  50
  51TEST_BACKEND_OPS_DB_TYPES = [
  52    "TEXT",    "TEXT",    "TEXT", "TEXT", "TEXT", "TEXT",
  53    "INTEGER", "INTEGER", "TEXT", "REAL", "REAL", "REAL",
  54    "INTEGER", "INTEGER"
  55]
  56
  57assert len(LLAMA_BENCH_DB_FIELDS) == len(LLAMA_BENCH_DB_TYPES)
  58assert len(TEST_BACKEND_OPS_DB_FIELDS) == len(TEST_BACKEND_OPS_DB_TYPES)
  59
  60# Properties by which to differentiate results per commit for llama-bench:
  61LLAMA_BENCH_KEY_PROPERTIES = [
  62    "cpu_info", "gpu_info", "backends", "n_gpu_layers", "n_cpu_moe", "tensor_buft_overrides", "model_filename", "model_type",
  63    "n_batch", "n_ubatch", "embeddings", "cpu_mask", "cpu_strict", "poll", "n_threads", "type_k", "type_v",
  64    "use_mmap", "no_kv_offload", "split_mode", "main_gpu", "tensor_split", "flash_attn", "n_prompt", "n_gen", "n_depth"
  65]
  66
  67# Properties by which to differentiate results per commit for test-backend-ops:
  68TEST_BACKEND_OPS_KEY_PROPERTIES = [
  69    "backend_name", "op_name", "op_params", "test_mode"
  70]
  71
  72# Properties that are boolean and are converted to Yes/No for the table:
  73LLAMA_BENCH_BOOL_PROPERTIES = ["embeddings", "cpu_strict", "use_mmap", "no_kv_offload", "flash_attn"]
  74TEST_BACKEND_OPS_BOOL_PROPERTIES = ["supported", "passed"]
  75
  76# Header names for the table (llama-bench):
  77LLAMA_BENCH_PRETTY_NAMES = {
  78    "cpu_info": "CPU", "gpu_info": "GPU", "backends": "Backends", "n_gpu_layers": "GPU layers",
  79    "tensor_buft_overrides": "Tensor overrides", "model_filename": "File", "model_type": "Model", "model_size": "Model size [GiB]",
  80    "model_n_params": "Num. of par.", "n_batch": "Batch size", "n_ubatch": "Microbatch size", "embeddings": "Embeddings",
  81    "cpu_mask": "CPU mask", "cpu_strict": "CPU strict", "poll": "Poll", "n_threads": "Threads", "type_k": "K type", "type_v": "V type",
  82    "use_mmap": "Use mmap", "no_kv_offload": "NKVO", "split_mode": "Split mode", "main_gpu": "Main GPU", "tensor_split": "Tensor split",
  83    "flash_attn": "FlashAttention",
  84}
  85
  86# Header names for the table (test-backend-ops):
  87TEST_BACKEND_OPS_PRETTY_NAMES = {
  88    "backend_name": "Backend", "op_name": "GGML op", "op_params": "Op parameters", "test_mode": "Mode",
  89    "supported": "Supported", "passed": "Passed", "error_message": "Error",
  90    "flops": "FLOPS", "bandwidth_gb_s": "Bandwidth (GB/s)", "memory_kb": "Memory (KB)", "n_runs": "Runs"
  91}
  92
  93DEFAULT_SHOW_LLAMA_BENCH = ["model_type"]  # Always show these properties by default.
  94DEFAULT_HIDE_LLAMA_BENCH = ["model_filename"]  # Always hide these properties by default.
  95
  96DEFAULT_SHOW_TEST_BACKEND_OPS = ["backend_name", "op_name"]  # Always show these properties by default.
  97DEFAULT_HIDE_TEST_BACKEND_OPS = ["error_message"]  # Always hide these properties by default.
  98
  99GPU_NAME_STRIP = ["NVIDIA GeForce ", "Tesla ", "AMD Radeon ", "AMD Instinct "]  # Strip prefixes for smaller tables.
 100MODEL_SUFFIX_REPLACE = {" - Small": "_S", " - Medium": "_M", " - Large": "_L"}
 101
 102DESCRIPTION = """Creates tables from llama-bench or test-backend-ops data written to multiple JSON/CSV files, a single JSONL file or SQLite database. Example usage (Linux):
 103
 104For llama-bench:
 105$ git checkout master
 106$ cmake -B ${BUILD_DIR} ${CMAKE_OPTS} && cmake --build ${BUILD_DIR} -t llama-bench -j $(nproc)
 107$ ./llama-bench -o sql | sqlite3 llama-bench.sqlite
 108$ git checkout some_branch
 109$ cmake -B ${BUILD_DIR} ${CMAKE_OPTS} && cmake --build ${BUILD_DIR} -t llama-bench -j $(nproc)
 110$ ./llama-bench -o sql | sqlite3 llama-bench.sqlite
 111$ ./scripts/compare-llama-bench.py
 112
 113For test-backend-ops:
 114$ git checkout master
 115$ cmake -B ${BUILD_DIR} ${CMAKE_OPTS} && cmake --build ${BUILD_DIR} -t test-backend-ops -j $(nproc)
 116$ ./test-backend-ops perf --output sql | sqlite3 test-backend-ops.sqlite
 117$ git checkout some_branch
 118$ cmake -B ${BUILD_DIR} ${CMAKE_OPTS} && cmake --build ${BUILD_DIR} -t test-backend-ops -j $(nproc)
 119$ ./test-backend-ops perf --output sql | sqlite3 test-backend-ops.sqlite
 120$ ./scripts/compare-llama-bench.py --tool test-backend-ops -i test-backend-ops.sqlite
 121
 122Performance numbers from multiple runs per commit are averaged WITHOUT being weighted by the --repetitions parameter of llama-bench.
 123"""
 124
 125parser = argparse.ArgumentParser(
 126    description=DESCRIPTION, formatter_class=argparse.RawDescriptionHelpFormatter)
 127help_b = (
 128    "The baseline commit to compare performance to. "
 129    "Accepts either a branch name, tag name, or commit hash. "
 130    "Defaults to latest master commit with data."
 131)
 132parser.add_argument("-b", "--baseline", help=help_b)
 133help_c = (
 134    "The commit whose performance is to be compared to the baseline. "
 135    "Accepts either a branch name, tag name, or commit hash. "
 136    "Defaults to the non-master commit for which llama-bench was run most recently."
 137)
 138parser.add_argument("-c", "--compare", help=help_c)
 139help_t = (
 140    "The tool whose data is being compared. "
 141    "Either 'llama-bench' or 'test-backend-ops'. "
 142    "This determines the database schema and comparison logic used. "
 143    "If left unspecified, try to determine from the input file."
 144)
 145parser.add_argument("-t", "--tool", help=help_t, default=None, choices=[None, "llama-bench", "test-backend-ops"])
 146help_i = (
 147    "JSON/JSONL/SQLite/CSV files for comparing commits. "
 148    "Specify multiple times to use multiple input files (JSON/CSV only). "
 149    "Defaults to 'llama-bench.sqlite' in the current working directory. "
 150    "If no such file is found and there is exactly one .sqlite file in the current directory, "
 151    "that file is instead used as input."
 152)
 153parser.add_argument("-i", "--input", action="append", help=help_i)
 154help_o = (
 155    "Output format for the table. "
 156    "Defaults to 'pipe' (GitHub compatible). "
 157    "Also supports e.g. 'latex' or 'mediawiki'. "
 158    "See tabulate documentation for full list."
 159)
 160parser.add_argument("-o", "--output", help=help_o, default="pipe")
 161help_s = (
 162    "Columns to add to the table. "
 163    "Accepts a comma-separated list of values. "
 164    f"Legal values for test-backend-ops: {', '.join(TEST_BACKEND_OPS_KEY_PROPERTIES)}. "
 165    f"Legal values for llama-bench: {', '.join(LLAMA_BENCH_KEY_PROPERTIES[:-3])}. "
 166    "Defaults to model name (model_type) and CPU and/or GPU name (cpu_info, gpu_info) "
 167    "plus any column where not all data points are the same. "
 168    "If the columns are manually specified, then the results for each unique combination of the "
 169    "specified values are averaged WITHOUT weighing by the --repetitions parameter of llama-bench."
 170)
 171parser.add_argument("--check", action="store_true", help="check if all required Python libraries are installed")
 172parser.add_argument("-s", "--show", help=help_s)
 173parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
 174parser.add_argument("--plot", help="generate a performance comparison plot and save to specified file (e.g., plot.png)")
 175parser.add_argument("--plot_x", help="parameter to use as x axis for plotting (default: n_depth)", default="n_depth")
 176parser.add_argument("--plot_log_scale", action="store_true", help="use log scale for x axis in plots (off by default)")
 177
 178known_args, unknown_args = parser.parse_known_args()
 179
 180logging.basicConfig(level=logging.DEBUG if known_args.verbose else logging.INFO)
 181
 182
 183if known_args.check:
 184    # Check if all required Python libraries are installed. Would have failed earlier if not.
 185    sys.exit(0)
 186
 187if unknown_args:
 188    logger.error(f"Received unknown args: {unknown_args}.\n")
 189    parser.print_help()
 190    sys.exit(1)
 191
 192input_file = known_args.input
 193tool = known_args.tool
 194
 195if not input_file:
 196    if tool == "llama-bench" and os.path.exists("./llama-bench.sqlite"):
 197        input_file = ["llama-bench.sqlite"]
 198    elif tool == "test-backend-ops" and os.path.exists("./test-backend-ops.sqlite"):
 199        input_file = ["test-backend-ops.sqlite"]
 200
 201if not input_file:
 202    sqlite_files = glob("*.sqlite")
 203    if len(sqlite_files) == 1:
 204        input_file = sqlite_files
 205
 206if not input_file:
 207    logger.error("Cannot find a suitable input file, please provide one.\n")
 208    parser.print_help()
 209    sys.exit(1)
 210
 211
 212class LlamaBenchData:
 213    repo: Optional[git.Repo]
 214    build_len_min: int
 215    build_len_max: int
 216    build_len: int = 8
 217    builds: list[str] = []
 218    tool: str = "llama-bench"  # Tool type: "llama-bench" or "test-backend-ops"
 219
 220    def __init__(self, tool: str = "llama-bench"):
 221        self.tool = tool
 222        try:
 223            self.repo = git.Repo(".", search_parent_directories=True)
 224        except git.InvalidGitRepositoryError:
 225            self.repo = None
 226
 227        # Set schema-specific properties based on tool
 228        if self.tool == "llama-bench":
 229            self.check_keys = set(LLAMA_BENCH_KEY_PROPERTIES + ["build_commit", "test_time", "avg_ts"])
 230        elif self.tool == "test-backend-ops":
 231            self.check_keys = set(TEST_BACKEND_OPS_KEY_PROPERTIES + ["build_commit", "test_time"])
 232        else:
 233            assert False
 234
 235    def _builds_init(self):
 236        self.build_len = self.build_len_min
 237
 238    def _check_keys(self, keys: set) -> Optional[set]:
 239        """Private helper method that checks against required data keys and returns missing ones."""
 240        if not keys >= self.check_keys:
 241            return self.check_keys - keys
 242        return None
 243
 244    def find_parent_in_data(self, commit: git.Commit) -> Optional[str]:
 245        """Helper method to find the most recent parent measured in number of commits for which there is data."""
 246        heap: list[tuple[int, git.Commit]] = [(0, commit)]
 247        seen_hexsha8 = set()
 248        while heap:
 249            depth, current_commit = heapq.heappop(heap)
 250            current_hexsha8 = commit.hexsha[:self.build_len]
 251            if current_hexsha8 in self.builds:
 252                return current_hexsha8
 253            for parent in commit.parents:
 254                parent_hexsha8 = parent.hexsha[:self.build_len]
 255                if parent_hexsha8 not in seen_hexsha8:
 256                    seen_hexsha8.add(parent_hexsha8)
 257                    heapq.heappush(heap, (depth + 1, parent))
 258        return None
 259
 260    def get_all_parent_hexsha8s(self, commit: git.Commit) -> Sequence[str]:
 261        """Helper method to recursively get hexsha8 values for all parents of a commit."""
 262        unvisited = [commit]
 263        visited   = []
 264
 265        while unvisited:
 266            current_commit = unvisited.pop(0)
 267            visited.append(current_commit.hexsha[:self.build_len])
 268            for parent in current_commit.parents:
 269                if parent.hexsha[:self.build_len] not in visited:
 270                    unvisited.append(parent)
 271
 272        return visited
 273
 274    def get_commit_name(self, hexsha8: str) -> str:
 275        """Helper method to find a human-readable name for a commit if possible."""
 276        if self.repo is None:
 277            return hexsha8
 278        for h in self.repo.heads:
 279            if h.commit.hexsha[:self.build_len] == hexsha8:
 280                return h.name
 281        for t in self.repo.tags:
 282            if t.commit.hexsha[:self.build_len] == hexsha8:
 283                return t.name
 284        return hexsha8
 285
 286    def get_commit_hexsha8(self, name: str) -> Optional[str]:
 287        """Helper method to search for a commit given a human-readable name."""
 288        if self.repo is None:
 289            return None
 290        for h in self.repo.heads:
 291            if h.name == name:
 292                return h.commit.hexsha[:self.build_len]
 293        for t in self.repo.tags:
 294            if t.name == name:
 295                return t.commit.hexsha[:self.build_len]
 296        for c in self.repo.iter_commits("--all"):
 297            if c.hexsha[:self.build_len] == name[:self.build_len]:
 298                return c.hexsha[:self.build_len]
 299        return None
 300
 301    def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequence[tuple]]:
 302        """Helper method that gets rows of (build_commit, test_time) sorted by the latter."""
 303        return []
 304
 305    def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]:
 306        """
 307        Helper method that gets table rows for some list of properties.
 308        Rows are created by combining those where all provided properties are equal.
 309        The resulting rows are then grouped by the provided properties and the t/s values are averaged.
 310        The returned rows are unique in terms of property combinations.
 311        """
 312        return []
 313
 314
 315class LlamaBenchDataSQLite3(LlamaBenchData):
 316    connection: Optional[sqlite3.Connection] = None
 317    cursor: sqlite3.Cursor
 318    table_name: str
 319
 320    def __init__(self, tool: str = "llama-bench"):
 321        super().__init__(tool)
 322        if self.connection is None:
 323            self.connection = sqlite3.connect(":memory:")
 324            self.cursor = self.connection.cursor()
 325
 326            # Set table name and schema based on tool
 327            if self.tool == "llama-bench":
 328                self.table_name = "llama_bench"
 329                db_fields = LLAMA_BENCH_DB_FIELDS
 330                db_types = LLAMA_BENCH_DB_TYPES
 331            elif self.tool == "test-backend-ops":
 332                self.table_name = "test_backend_ops"
 333                db_fields = TEST_BACKEND_OPS_DB_FIELDS
 334                db_types = TEST_BACKEND_OPS_DB_TYPES
 335            else:
 336                assert False
 337
 338            self.cursor.execute(f"CREATE TABLE {self.table_name}({', '.join(' '.join(x) for x in zip(db_fields, db_types))});")
 339
 340    def _builds_init(self):
 341        if self.connection:
 342            self.build_len_min = self.cursor.execute(f"SELECT MIN(LENGTH(build_commit)) from {self.table_name};").fetchone()[0]
 343            self.build_len_max = self.cursor.execute(f"SELECT MAX(LENGTH(build_commit)) from {self.table_name};").fetchone()[0]
 344
 345            if self.build_len_min != self.build_len_max:
 346                logger.warning("Data contains commit hashes of differing lengths. It's possible that the wrong commits will be compared. "
 347                               "Try purging the the database of old commits.")
 348                self.cursor.execute(f"UPDATE {self.table_name} SET build_commit = SUBSTRING(build_commit, 1, {self.build_len_min});")
 349
 350            builds = self.cursor.execute(f"SELECT DISTINCT build_commit FROM {self.table_name};").fetchall()
 351            self.builds = list(map(lambda b: b[0], builds))  # list[tuple[str]] -> list[str]
 352        super()._builds_init()
 353
 354    def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequence[tuple]]:
 355        data = self.cursor.execute(
 356            f"SELECT build_commit, test_time FROM {self.table_name} ORDER BY test_time;").fetchall()
 357        return reversed(data) if reverse else data
 358
 359    def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]:
 360        if self.tool == "llama-bench":
 361            return self._get_rows_llama_bench(properties, hexsha8_baseline, hexsha8_compare)
 362        elif self.tool == "test-backend-ops":
 363            return self._get_rows_test_backend_ops(properties, hexsha8_baseline, hexsha8_compare)
 364        else:
 365            assert False
 366
 367    def _get_rows_llama_bench(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]:
 368        select_string = ", ".join(
 369            [f"tb.{p}" for p in properties] + ["tb.n_prompt", "tb.n_gen", "tb.n_depth", "AVG(tb.avg_ts)", "AVG(tc.avg_ts)"])
 370        equal_string = " AND ".join(
 371            [f"tb.{p} = tc.{p}" for p in LLAMA_BENCH_KEY_PROPERTIES] + [
 372                f"tb.build_commit = '{hexsha8_baseline}'", f"tc.build_commit = '{hexsha8_compare}'"]
 373        )
 374        group_order_string = ", ".join([f"tb.{p}" for p in properties] + ["tb.n_gen", "tb.n_prompt", "tb.n_depth"])
 375        query = (f"SELECT {select_string} FROM {self.table_name} tb JOIN {self.table_name} tc ON {equal_string} "
 376                 f"GROUP BY {group_order_string} ORDER BY {group_order_string};")
 377        return self.cursor.execute(query).fetchall()
 378
 379    def _get_rows_test_backend_ops(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]:
 380        # For test-backend-ops, we compare FLOPS and bandwidth metrics (prioritizing FLOPS over bandwidth)
 381        select_string = ", ".join(
 382            [f"tb.{p}" for p in properties] + [
 383                "AVG(tb.flops)", "AVG(tc.flops)",
 384                "AVG(tb.bandwidth_gb_s)", "AVG(tc.bandwidth_gb_s)"
 385            ])
 386        equal_string = " AND ".join(
 387            [f"tb.{p} = tc.{p}" for p in TEST_BACKEND_OPS_KEY_PROPERTIES] + [
 388                f"tb.build_commit = '{hexsha8_baseline}'", f"tc.build_commit = '{hexsha8_compare}'",
 389                "tb.supported = 1", "tc.supported = 1", "tb.passed = 1", "tc.passed = 1"]  # Only compare successful tests
 390        )
 391        group_order_string = ", ".join([f"tb.{p}" for p in properties])
 392        query = (f"SELECT {select_string} FROM {self.table_name} tb JOIN {self.table_name} tc ON {equal_string} "
 393                 f"GROUP BY {group_order_string} ORDER BY {group_order_string};")
 394        return self.cursor.execute(query).fetchall()
 395
 396
 397class LlamaBenchDataSQLite3File(LlamaBenchDataSQLite3):
 398    def __init__(self, data_file: str, tool: Any):
 399        self.connection = sqlite3.connect(data_file)
 400        self.cursor = self.connection.cursor()
 401
 402        # Check which table exists in the database
 403        tables = self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall()
 404        table_names = [table[0] for table in tables]
 405
 406        # Tool selection logic
 407        if tool is None:
 408            if "llama_bench" in table_names:
 409                self.table_name = "llama_bench"
 410                tool = "llama-bench"
 411            elif "test_backend_ops" in table_names:
 412                self.table_name = "test_backend_ops"
 413                tool = "test-backend-ops"
 414            else:
 415                raise RuntimeError(f"No suitable table found in database. Available tables: {table_names}")
 416        elif tool == "llama-bench":
 417            if "llama_bench" in table_names:
 418                self.table_name = "llama_bench"
 419                tool = "llama-bench"
 420            else:
 421                raise RuntimeError(f"Table 'test' not found for tool 'llama-bench'. Available tables: {table_names}")
 422        elif tool == "test-backend-ops":
 423            if "test_backend_ops" in table_names:
 424                self.table_name = "test_backend_ops"
 425                tool = "test-backend-ops"
 426            else:
 427                raise RuntimeError(f"Table 'test_backend_ops' not found for tool 'test-backend-ops'. Available tables: {table_names}")
 428        else:
 429            raise RuntimeError(f"Unknown tool: {tool}")
 430
 431        super().__init__(tool)
 432        self._builds_init()
 433
 434    @staticmethod
 435    def valid_format(data_file: str) -> bool:
 436        connection = sqlite3.connect(data_file)
 437        cursor = connection.cursor()
 438
 439        try:
 440            if cursor.execute("PRAGMA schema_version;").fetchone()[0] == 0:
 441                raise sqlite3.DatabaseError("The provided input file does not exist or is empty.")
 442        except sqlite3.DatabaseError as e:
 443            logger.debug(f'"{data_file}" is not a valid SQLite3 file.', exc_info=e)
 444            cursor = None
 445
 446        connection.close()
 447        return True if cursor else False
 448
 449
 450class LlamaBenchDataJSONL(LlamaBenchDataSQLite3):
 451    def __init__(self, data_file: str, tool: str = "llama-bench"):
 452        super().__init__(tool)
 453
 454        # Get the appropriate field list based on tool
 455        db_fields = LLAMA_BENCH_DB_FIELDS if tool == "llama-bench" else TEST_BACKEND_OPS_DB_FIELDS
 456
 457        with open(data_file, "r", encoding="utf-8") as fp:
 458            for i, line in enumerate(fp):
 459                parsed = json.loads(line)
 460
 461                for k in parsed.keys() - set(db_fields):
 462                    del parsed[k]
 463
 464                if (missing_keys := self._check_keys(parsed.keys())):
 465                    raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}")
 466
 467                self.cursor.execute(f"INSERT INTO {self.table_name}({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values()))
 468
 469        self._builds_init()
 470
 471    @staticmethod
 472    def valid_format(data_file: str) -> bool:
 473        try:
 474            with open(data_file, "r", encoding="utf-8") as fp:
 475                for line in fp:
 476                    json.loads(line)
 477                    break
 478        except Exception as e:
 479            logger.debug(f'"{data_file}" is not a valid JSONL file.', exc_info=e)
 480            return False
 481
 482        return True
 483
 484
 485class LlamaBenchDataJSON(LlamaBenchDataSQLite3):
 486    def __init__(self, data_files: list[str], tool: str = "llama-bench"):
 487        super().__init__(tool)
 488
 489        # Get the appropriate field list based on tool
 490        db_fields = LLAMA_BENCH_DB_FIELDS if tool == "llama-bench" else TEST_BACKEND_OPS_DB_FIELDS
 491
 492        for data_file in data_files:
 493            with open(data_file, "r", encoding="utf-8") as fp:
 494                parsed = json.load(fp)
 495
 496                for i, entry in enumerate(parsed):
 497                    for k in entry.keys() - set(db_fields):
 498                        del entry[k]
 499
 500                    if (missing_keys := self._check_keys(entry.keys())):
 501                        raise RuntimeError(f"Missing required data key(s) at entry {i + 1}: {', '.join(missing_keys)}")
 502
 503                    self.cursor.execute(f"INSERT INTO {self.table_name}({', '.join(entry.keys())}) VALUES({', '.join('?' * len(entry))});", tuple(entry.values()))
 504
 505        self._builds_init()
 506
 507    @staticmethod
 508    def valid_format(data_files: list[str]) -> bool:
 509        if not data_files:
 510            return False
 511
 512        for data_file in data_files:
 513            try:
 514                with open(data_file, "r", encoding="utf-8") as fp:
 515                    json.load(fp)
 516            except Exception as e:
 517                logger.debug(f'"{data_file}" is not a valid JSON file.', exc_info=e)
 518                return False
 519
 520        return True
 521
 522
 523class LlamaBenchDataCSV(LlamaBenchDataSQLite3):
 524    def __init__(self, data_files: list[str], tool: str = "llama-bench"):
 525        super().__init__(tool)
 526
 527        # Get the appropriate field list based on tool
 528        db_fields = LLAMA_BENCH_DB_FIELDS if tool == "llama-bench" else TEST_BACKEND_OPS_DB_FIELDS
 529
 530        for data_file in data_files:
 531            with open(data_file, "r", encoding="utf-8") as fp:
 532                for i, parsed in enumerate(csv.DictReader(fp)):
 533                    keys = set(parsed.keys())
 534
 535                    for k in keys - set(db_fields):
 536                        del parsed[k]
 537
 538                    if (missing_keys := self._check_keys(keys)):
 539                        raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}")
 540
 541                    self.cursor.execute(f"INSERT INTO {self.table_name}({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values()))
 542
 543        self._builds_init()
 544
 545    @staticmethod
 546    def valid_format(data_files: list[str]) -> bool:
 547        if not data_files:
 548            return False
 549
 550        for data_file in data_files:
 551            try:
 552                with open(data_file, "r", encoding="utf-8") as fp:
 553                    for parsed in csv.DictReader(fp):
 554                        break
 555            except Exception as e:
 556                logger.debug(f'"{data_file}" is not a valid CSV file.', exc_info=e)
 557                return False
 558
 559        return True
 560
 561
 562def format_flops(flops_value: float) -> str:
 563    """Format FLOPS values with appropriate units for better readability."""
 564    if flops_value == 0:
 565        return "0.00"
 566
 567    # Define unit thresholds and names
 568    units = [
 569        (1e12, "T"),   # TeraFLOPS
 570        (1e9, "G"),    # GigaFLOPS
 571        (1e6, "M"),    # MegaFLOPS
 572        (1e3, "k"),    # kiloFLOPS
 573        (1, "")        # FLOPS
 574    ]
 575
 576    for threshold, unit in units:
 577        if abs(flops_value) >= threshold:
 578            formatted_value = flops_value / threshold
 579            if formatted_value >= 100:
 580                return f"{formatted_value:.1f}{unit}"
 581            else:
 582                return f"{formatted_value:.2f}{unit}"
 583
 584    # Fallback for very small values
 585    return f"{flops_value:.2f}"
 586
 587
 588def format_flops_for_table(flops_value: float, target_unit: str) -> str:
 589    """Format FLOPS values for table display without unit suffix (since unit is in header)."""
 590    if flops_value == 0:
 591        return "0.00"
 592
 593    # Define unit thresholds based on target unit
 594    unit_divisors = {
 595        "TFLOPS": 1e12,
 596        "GFLOPS": 1e9,
 597        "MFLOPS": 1e6,
 598        "kFLOPS": 1e3,
 599        "FLOPS": 1
 600    }
 601
 602    divisor = unit_divisors.get(target_unit, 1)
 603    formatted_value = flops_value / divisor
 604
 605    if formatted_value >= 100:
 606        return f"{formatted_value:.1f}"
 607    else:
 608        return f"{formatted_value:.2f}"
 609
 610
 611def get_flops_unit_name(flops_values: list) -> str:
 612    """Determine the best FLOPS unit name based on the magnitude of values."""
 613    if not flops_values or all(v == 0 for v in flops_values):
 614        return "FLOPS"
 615
 616    # Find the maximum absolute value to determine appropriate unit
 617    max_flops = max(abs(v) for v in flops_values if v != 0)
 618
 619    if max_flops >= 1e12:
 620        return "TFLOPS"
 621    elif max_flops >= 1e9:
 622        return "GFLOPS"
 623    elif max_flops >= 1e6:
 624        return "MFLOPS"
 625    elif max_flops >= 1e3:
 626        return "kFLOPS"
 627    else:
 628        return "FLOPS"
 629
 630
 631bench_data = None
 632if len(input_file) == 1:
 633    if LlamaBenchDataSQLite3File.valid_format(input_file[0]):
 634        bench_data = LlamaBenchDataSQLite3File(input_file[0], tool)
 635    elif LlamaBenchDataJSON.valid_format(input_file):
 636        bench_data = LlamaBenchDataJSON(input_file, tool)
 637    elif LlamaBenchDataJSONL.valid_format(input_file[0]):
 638        bench_data = LlamaBenchDataJSONL(input_file[0], tool)
 639    elif LlamaBenchDataCSV.valid_format(input_file):
 640        bench_data = LlamaBenchDataCSV(input_file, tool)
 641else:
 642    if LlamaBenchDataJSON.valid_format(input_file):
 643        bench_data = LlamaBenchDataJSON(input_file, tool)
 644    elif LlamaBenchDataCSV.valid_format(input_file):
 645        bench_data = LlamaBenchDataCSV(input_file, tool)
 646
 647if not bench_data:
 648    raise RuntimeError("No valid (or some invalid) input files found.")
 649
 650if not bench_data.builds:
 651    raise RuntimeError(f"{input_file} does not contain any builds.")
 652
 653tool = bench_data.tool  # May have chosen a default if tool was None.
 654
 655
 656hexsha8_baseline = name_baseline = None
 657
 658# If the user specified a baseline, try to find a commit for it:
 659if known_args.baseline is not None:
 660    if known_args.baseline in bench_data.builds:
 661        hexsha8_baseline = known_args.baseline
 662    if hexsha8_baseline is None:
 663        hexsha8_baseline = bench_data.get_commit_hexsha8(known_args.baseline)
 664        name_baseline = known_args.baseline
 665    if hexsha8_baseline is None:
 666        logger.error(f"cannot find data for baseline={known_args.baseline}.")
 667        sys.exit(1)
 668# Otherwise, search for the most recent parent of master for which there is data:
 669elif bench_data.repo is not None:
 670    hexsha8_baseline = bench_data.find_parent_in_data(bench_data.repo.heads.master.commit)
 671
 672    if hexsha8_baseline is None:
 673        logger.error("No baseline was provided and did not find data for any master branch commits.\n")
 674        parser.print_help()
 675        sys.exit(1)
 676else:
 677    logger.error("No baseline was provided and the current working directory "
 678                 "is not part of a git repository from which a baseline could be inferred.\n")
 679    parser.print_help()
 680    sys.exit(1)
 681
 682
 683name_baseline = bench_data.get_commit_name(hexsha8_baseline)
 684
 685hexsha8_compare = name_compare = None
 686
 687# If the user has specified a compare value, try to find a corresponding commit:
 688if known_args.compare is not None:
 689    if known_args.compare in bench_data.builds:
 690        hexsha8_compare = known_args.compare
 691    if hexsha8_compare is None:
 692        hexsha8_compare = bench_data.get_commit_hexsha8(known_args.compare)
 693        name_compare = known_args.compare
 694    if hexsha8_compare is None:
 695        logger.error(f"cannot find data for compare={known_args.compare}.")
 696        sys.exit(1)
 697# Otherwise, search for the commit for llama-bench was most recently run
 698# and that is not a parent of master:
 699elif bench_data.repo is not None:
 700    hexsha8s_master = bench_data.get_all_parent_hexsha8s(bench_data.repo.heads.master.commit)
 701    for (hexsha8, _) in bench_data.builds_timestamp(reverse=True):
 702        if hexsha8 not in hexsha8s_master:
 703            hexsha8_compare = hexsha8
 704            break
 705
 706    if hexsha8_compare is None:
 707        logger.error("No compare target was provided and did not find data for any non-master commits.\n")
 708        parser.print_help()
 709        sys.exit(1)
 710else:
 711    logger.error("No compare target was provided and the current working directory "
 712                 "is not part of a git repository from which a compare target could be inferred.\n")
 713    parser.print_help()
 714    sys.exit(1)
 715
 716name_compare = bench_data.get_commit_name(hexsha8_compare)
 717
 718# Get tool-specific configuration
 719if tool == "llama-bench":
 720    key_properties = LLAMA_BENCH_KEY_PROPERTIES
 721    bool_properties = LLAMA_BENCH_BOOL_PROPERTIES
 722    pretty_names = LLAMA_BENCH_PRETTY_NAMES
 723    default_show = DEFAULT_SHOW_LLAMA_BENCH
 724    default_hide = DEFAULT_HIDE_LLAMA_BENCH
 725elif tool == "test-backend-ops":
 726    key_properties = TEST_BACKEND_OPS_KEY_PROPERTIES
 727    bool_properties = TEST_BACKEND_OPS_BOOL_PROPERTIES
 728    pretty_names = TEST_BACKEND_OPS_PRETTY_NAMES
 729    default_show = DEFAULT_SHOW_TEST_BACKEND_OPS
 730    default_hide = DEFAULT_HIDE_TEST_BACKEND_OPS
 731else:
 732    assert False
 733
 734# If the user provided columns to group the results by, use them:
 735if known_args.show is not None:
 736    show = known_args.show.split(",")
 737    unknown_cols = []
 738    for prop in show:
 739        valid_props = key_properties if tool == "test-backend-ops" else key_properties[:-3]  # Exclude n_prompt, n_gen, n_depth for llama-bench
 740        if prop not in valid_props:
 741            unknown_cols.append(prop)
 742    if unknown_cols:
 743        logger.error(f"Unknown values for --show: {', '.join(unknown_cols)}")
 744        parser.print_usage()
 745        sys.exit(1)
 746    rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare)
 747# Otherwise, select those columns where the values are not all the same:
 748else:
 749    rows_full = bench_data.get_rows(key_properties, hexsha8_baseline, hexsha8_compare)
 750    properties_different = []
 751
 752    if tool == "llama-bench":
 753        # For llama-bench, skip n_prompt, n_gen, n_depth from differentiation logic
 754        check_properties = [kp for kp in key_properties if kp not in ["n_prompt", "n_gen", "n_depth"]]
 755        for i, kp_i in enumerate(key_properties):
 756            if kp_i in default_show or kp_i in ["n_prompt", "n_gen", "n_depth"]:
 757                continue
 758            for row_full in rows_full:
 759                if row_full[i] != rows_full[0][i]:
 760                    properties_different.append(kp_i)
 761                    break
 762    elif tool == "test-backend-ops":
 763        # For test-backend-ops, check all key properties
 764        for i, kp_i in enumerate(key_properties):
 765            if kp_i in default_show:
 766                continue
 767            for row_full in rows_full:
 768                if row_full[i] != rows_full[0][i]:
 769                    properties_different.append(kp_i)
 770                    break
 771    else:
 772        assert False
 773
 774    show = []
 775
 776    if tool == "llama-bench":
 777        # Show CPU and/or GPU by default even if the hardware for all results is the same:
 778        if rows_full and "n_gpu_layers" not in properties_different:
 779            ngl = int(rows_full[0][key_properties.index("n_gpu_layers")])
 780
 781            if ngl != 99 and "cpu_info" not in properties_different:
 782                show.append("cpu_info")
 783
 784        show += properties_different
 785
 786        index_default = 0
 787        for prop in ["cpu_info", "gpu_info", "n_gpu_layers", "main_gpu"]:
 788            if prop in show:
 789                index_default += 1
 790        show = show[:index_default] + default_show + show[index_default:]
 791    elif tool == "test-backend-ops":
 792        show = default_show + properties_different
 793    else:
 794        assert False
 795
 796    for prop in default_hide:
 797        try:
 798            show.remove(prop)
 799        except ValueError:
 800            pass
 801
 802    # Add plot_x parameter to parameters to show if it's not already present:
 803    if known_args.plot:
 804        for k, v in pretty_names.items():
 805            if v == known_args.plot_x and k not in show:
 806                show.append(k)
 807                break
 808
 809    rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare)
 810
 811if not rows_show:
 812    logger.error(f"No comparable data was found between {name_baseline} and {name_compare}.\n")
 813    sys.exit(1)
 814
 815table = []
 816primary_metric = "FLOPS"  # Default to FLOPS for test-backend-ops
 817
 818if tool == "llama-bench":
 819    # For llama-bench, create test names and compare avg_ts values
 820    for row in rows_show:
 821        n_prompt = int(row[-5])
 822        n_gen    = int(row[-4])
 823        n_depth  = int(row[-3])
 824        if n_prompt != 0 and n_gen == 0:
 825            test_name = f"pp{n_prompt}"
 826        elif n_prompt == 0 and n_gen != 0:
 827            test_name = f"tg{n_gen}"
 828        else:
 829            test_name = f"pp{n_prompt}+tg{n_gen}"
 830        if n_depth != 0:
 831            test_name = f"{test_name}@d{n_depth}"
 832        #           Regular columns    test name    avg t/s values              Speedup
 833        #            VVVVVVVVVVVVV     VVVVVVVVV    VVVVVVVVVVVVVV              VVVVVVV
 834        table.append(list(row[:-5]) + [test_name] + list(row[-2:]) + [float(row[-1]) / float(row[-2])])
 835elif tool == "test-backend-ops":
 836    # Determine the primary metric by checking rows until we find one with valid data
 837    if rows_show:
 838        primary_metric = "FLOPS"  # Default to FLOPS
 839        flops_values = []
 840
 841        # Collect all FLOPS values to determine the best unit
 842        for sample_row in rows_show:
 843            baseline_flops = float(sample_row[-4])
 844            compare_flops = float(sample_row[-3])
 845            baseline_bandwidth = float(sample_row[-2])
 846
 847            if baseline_flops > 0:
 848                flops_values.extend([baseline_flops, compare_flops])
 849            elif baseline_bandwidth > 0 and not flops_values:
 850                primary_metric = "Bandwidth (GB/s)"
 851
 852        # If we have FLOPS data, determine the appropriate unit
 853        if flops_values:
 854            primary_metric = get_flops_unit_name(flops_values)
 855
 856    # For test-backend-ops, prioritize FLOPS > bandwidth for comparison
 857    for row in rows_show:
 858        # Extract metrics: flops, bandwidth_gb_s (baseline and compare)
 859        baseline_flops = float(row[-4])
 860        compare_flops = float(row[-3])
 861        baseline_bandwidth = float(row[-2])
 862        compare_bandwidth = float(row[-1])
 863
 864        # Determine which metric to use for comparison (prioritize FLOPS > bandwidth)
 865        if baseline_flops > 0 and compare_flops > 0:
 866            # Use FLOPS comparison (higher is better)
 867            speedup = compare_flops / baseline_flops
 868            baseline_str = format_flops_for_table(baseline_flops, primary_metric)
 869            compare_str = format_flops_for_table(compare_flops, primary_metric)
 870        elif baseline_bandwidth > 0 and compare_bandwidth > 0:
 871            # Use bandwidth comparison (higher is better)
 872            speedup = compare_bandwidth / baseline_bandwidth
 873            baseline_str = f"{baseline_bandwidth:.2f}"
 874            compare_str = f"{compare_bandwidth:.2f}"
 875        else:
 876            # Fallback if no valid data is available
 877            baseline_str = "N/A"
 878            compare_str = "N/A"
 879            from math import nan
 880            speedup = nan
 881
 882        table.append(list(row[:-4]) + [baseline_str, compare_str, speedup])
 883else:
 884    assert False
 885
 886# Some a-posteriori fixes to make the table contents prettier:
 887for bool_property in bool_properties:
 888    if bool_property in show:
 889        ip = show.index(bool_property)
 890        for row_table in table:
 891            row_table[ip] = "Yes" if int(row_table[ip]) == 1 else "No"
 892
 893if tool == "llama-bench":
 894    if "model_type" in show:
 895        ip = show.index("model_type")
 896        for (old, new) in MODEL_SUFFIX_REPLACE.items():
 897            for row_table in table:
 898                row_table[ip] = row_table[ip].replace(old, new)
 899
 900    if "model_size" in show:
 901        ip = show.index("model_size")
 902        for row_table in table:
 903            row_table[ip] = float(row_table[ip]) / 1024 ** 3
 904
 905    if "gpu_info" in show:
 906        ip = show.index("gpu_info")
 907        for row_table in table:
 908            for gns in GPU_NAME_STRIP:
 909                row_table[ip] = row_table[ip].replace(gns, "")
 910
 911            gpu_names = row_table[ip].split(", ")
 912            num_gpus = len(gpu_names)
 913            all_names_the_same = len(set(gpu_names)) == 1
 914            if len(gpu_names) >= 2 and all_names_the_same:
 915                row_table[ip] = f"{num_gpus}x {gpu_names[0]}"
 916
 917headers  = [pretty_names.get(p, p) for p in show]
 918if tool == "llama-bench":
 919    headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"]
 920elif tool == "test-backend-ops":
 921    headers += [f"{primary_metric} {name_baseline}", f"{primary_metric} {name_compare}", "Speedup"]
 922else:
 923    assert False
 924
 925if known_args.plot:
 926    def create_performance_plot(table_data: list[list[str]], headers: list[str], baseline_name: str, compare_name: str, output_file: str, plot_x_param: str, log_scale: bool = False, tool_type: str = "llama-bench", metric_name: str = "t/s"):
 927        try:
 928            import matplotlib
 929            import matplotlib.pyplot as plt
 930            matplotlib.use('Agg')
 931        except ImportError as e:
 932            logger.error("matplotlib is required for --plot.")
 933            raise e
 934
 935        data_headers = headers[:-4] # Exclude the last 4 columns (Test, baseline t/s, compare t/s, Speedup)
 936        plot_x_index = None
 937        plot_x_label = plot_x_param
 938
 939        if plot_x_param not in ["n_prompt", "n_gen", "n_depth"]:
 940            pretty_name = LLAMA_BENCH_PRETTY_NAMES.get(plot_x_param, plot_x_param)
 941            if pretty_name in data_headers:
 942                plot_x_index = data_headers.index(pretty_name)
 943                plot_x_label = pretty_name
 944            elif plot_x_param in data_headers:
 945                plot_x_index = data_headers.index(plot_x_param)
 946                plot_x_label = plot_x_param
 947            else:
 948                logger.error(f"Parameter '{plot_x_param}' not found in current table columns. Available columns: {', '.join(data_headers)}")
 949                return
 950
 951        grouped_data = {}
 952
 953        for i, row in enumerate(table_data):
 954            group_key_parts = []
 955            test_name = row[-4]
 956
 957            base_test = ""
 958            x_value = None
 959
 960            if plot_x_param in ["n_prompt", "n_gen", "n_depth"]:
 961                for j, val in enumerate(row[:-4]):
 962                    header_name = data_headers[j]
 963                    if val is not None and str(val).strip():
 964                        group_key_parts.append(f"{header_name}={val}")
 965
 966                if plot_x_param == "n_prompt" and "pp" in test_name:
 967                    base_test = test_name.split("@")[0]
 968                    x_value = base_test
 969                elif plot_x_param == "n_gen" and "tg" in test_name:
 970                    x_value = test_name.split("@")[0]
 971                elif plot_x_param == "n_depth" and "@d" in test_name:
 972                    base_test = test_name.split("@d")[0]
 973                    x_value = int(test_name.split("@d")[1])
 974                else:
 975                    base_test = test_name
 976
 977                if base_test.strip():
 978                    group_key_parts.append(f"Test={base_test}")
 979            else:
 980                for j, val in enumerate(row[:-4]):
 981                    if j != plot_x_index:
 982                        header_name = data_headers[j]
 983                        if val is not None and str(val).strip():
 984                            group_key_parts.append(f"{header_name}={val}")
 985                    else:
 986                        x_value = val
 987
 988                group_key_parts.append(f"Test={test_name}")
 989
 990            group_key = tuple(group_key_parts)
 991
 992            if group_key not in grouped_data:
 993                grouped_data[group_key] = []
 994
 995            grouped_data[group_key].append({
 996                'x_value': x_value,
 997                'baseline': float(row[-3]),
 998                'compare': float(row[-2]),
 999                'speedup': float(row[-1])
1000            })
1001
1002        if not grouped_data:
1003            logger.error("No data available for plotting")
1004            return
1005
1006        def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
1007            from math import ceil
1008            cols = 1 if num_groups == 1 else min(max_cols, num_groups)
1009            rows = ceil(num_groups / cols)
1010
1011            # Scale figure size by grid dimensions
1012            w, h = base_size
1013            fig, ax_arr = plt.subplots(rows, cols,
1014                                       figsize=(w * cols, h * rows),
1015                                       squeeze=False)
1016
1017            axes = ax_arr.flatten()[:num_groups]
1018            return fig, axes
1019
1020        num_groups = len(grouped_data)
1021        fig, axes = make_axes(num_groups)
1022
1023        plot_idx = 0
1024
1025        for group_key, points in grouped_data.items():
1026            if plot_idx >= len(axes):
1027                break
1028            ax = axes[plot_idx]
1029
1030            try:
1031                points_sorted = sorted(points, key=lambda p: float(p['x_value']) if p['x_value'] is not None else 0)
1032                x_values = [float(p['x_value']) if p['x_value'] is not None else 0 for p in points_sorted]
1033            except ValueError:
1034                points_sorted = sorted(points, key=lambda p: group_key)
1035                x_values = [p['x_value'] for p in points_sorted]
1036
1037            baseline_vals = [p['baseline'] for p in points_sorted]
1038            compare_vals = [p['compare'] for p in points_sorted]
1039
1040            ax.plot(x_values, baseline_vals, 'o-', color='skyblue',
1041                    label=f'{baseline_name}', linewidth=2, markersize=6)
1042            ax.plot(x_values, compare_vals, 's--', color='lightcoral', alpha=0.8,
1043                    label=f'{compare_name}', linewidth=2, markersize=6)
1044
1045            if log_scale:
1046                ax.set_xscale('log', base=2)
1047                unique_x = sorted(set(x_values))
1048                ax.set_xticks(unique_x)
1049                ax.set_xticklabels([str(int(x)) for x in unique_x])
1050
1051            title_parts = []
1052            for part in group_key:
1053                if '=' in part:
1054                    key, value = part.split('=', 1)
1055                    title_parts.append(f"{key}: {value}")
1056
1057            title = ', '.join(title_parts) if title_parts else "Performance comparison"
1058
1059            # Determine y-axis label based on tool type
1060            if tool_type == "llama-bench":
1061                y_label = "Tokens per second (t/s)"
1062            elif tool_type == "test-backend-ops":
1063                y_label = metric_name
1064            else:
1065                assert False
1066
1067            ax.set_xlabel(plot_x_label, fontsize=12, fontweight='bold')
1068            ax.set_ylabel(y_label, fontsize=12, fontweight='bold')
1069            ax.set_title(title, fontsize=12, fontweight='bold')
1070            ax.legend(loc='best', fontsize=10)
1071            ax.grid(True, alpha=0.3)
1072
1073            plot_idx += 1
1074
1075        for i in range(plot_idx, len(axes)):
1076            axes[i].set_visible(False)
1077
1078        fig.suptitle(f'Performance comparison: {compare_name} vs. {baseline_name}',
1079                     fontsize=14, fontweight='bold')
1080        fig.subplots_adjust(top=1)
1081
1082        plt.tight_layout()
1083        plt.savefig(output_file, dpi=300, bbox_inches='tight')
1084        plt.close()
1085
1086    create_performance_plot(table, headers, name_baseline, name_compare, known_args.plot, known_args.plot_x, known_args.plot_log_scale, tool, primary_metric)
1087
# Render the comparison table in the requested output format and print it.
rendered_table = tabulate(table, headers=headers, floatfmt=".2f", tablefmt=known_args.output)
print(rendered_table) # noqa: NP100