1#!/usr/bin/env python
  2'''
  3    This script fetches all the models used in the server tests.
  4
  5    This is useful for slow tests that use larger models, to avoid them timing out on the model downloads.
  6
  7    It is meant to be run from the root of the repository.
  8
  9    Example:
 10        python scripts/fetch_server_test_models.py
 11        ( cd tools/server/tests && ./tests.sh -v -x -m slow )
 12'''
 13import ast
 14import glob
 15import logging
 16import os
 17from typing import Generator
 18from pydantic import BaseModel
 19from typing import Optional
 20import subprocess
 21
 22
 23class HuggingFaceModel(BaseModel):
 24    hf_repo: str
 25    hf_file: Optional[str] = None
 26
 27    class Config:
 28        frozen = True
 29
 30
 31def collect_hf_model_test_parameters(test_file) -> Generator[HuggingFaceModel, None, None]:
 32    try:
 33        with open(test_file) as f:
 34            tree = ast.parse(f.read())
 35    except Exception as e:
 36        logging.error(f'collect_hf_model_test_parameters failed on {test_file}: {e}')
 37        return
 38
 39    for node in ast.walk(tree):
 40        if isinstance(node, ast.FunctionDef):
 41            for dec in node.decorator_list:
 42                if isinstance(dec, ast.Call) and isinstance(dec.func, ast.Attribute) and dec.func.attr == 'parametrize':
 43                    param_names = ast.literal_eval(dec.args[0]).split(",")
 44                    if "hf_repo" not in param_names:
 45                        continue
 46
 47                    raw_param_values = dec.args[1]
 48                    if not isinstance(raw_param_values, ast.List):
 49                        logging.warning(f'Skipping non-list parametrize entry at {test_file}:{node.lineno}')
 50                        continue
 51
 52                    hf_repo_idx = param_names.index("hf_repo")
 53                    hf_file_idx = param_names.index("hf_file") if "hf_file" in param_names else None
 54
 55                    for t in raw_param_values.elts:
 56                        if not isinstance(t, ast.Tuple):
 57                            logging.warning(f'Skipping non-tuple parametrize entry at {test_file}:{node.lineno}')
 58                            continue
 59                        yield HuggingFaceModel(
 60                            hf_repo=ast.literal_eval(t.elts[hf_repo_idx]),
 61                            hf_file=ast.literal_eval(t.elts[hf_file_idx]) if hf_file_idx is not None else None)
 62
 63
 64if __name__ == '__main__':
 65    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
 66
 67    models = sorted(list(set([
 68        model
 69        for test_file in glob.glob('tools/server/tests/unit/test_*.py')
 70        for model in collect_hf_model_test_parameters(test_file)
 71    ])), key=lambda m: (m.hf_repo, m.hf_file))
 72
 73    logging.info(f'Found {len(models)} models in parameterized tests:')
 74    for m in models:
 75        logging.info(f'  - {m.hf_repo} / {m.hf_file}')
 76
 77    cli_path = os.environ.get(
 78        'LLAMA_CLI_BIN_PATH',
 79        os.path.join(
 80            os.path.dirname(__file__),
 81            '../build/bin/Release/llama-cli.exe' if os.name == 'nt' else '../build/bin/llama-cli'))
 82
 83    for m in models:
 84        if '<' in m.hf_repo or (m.hf_file is not None and '<' in m.hf_file):
 85            continue
 86        if m.hf_file is not None and '-of-' in m.hf_file:
 87            logging.warning(f'Skipping model at {m.hf_repo} / {m.hf_file} because it is a split file')
 88            continue
 89        logging.info(f'Using llama-cli to ensure model {m.hf_repo}/{m.hf_file} was fetched')
 90        cmd = [
 91            cli_path,
 92            '-hfr', m.hf_repo,
 93            *([] if m.hf_file is None else ['-hff', m.hf_file]),
 94            '-n', '1',
 95            '-p', 'Hey',
 96            '--no-warmup',
 97            '--log-disable',
 98            '-no-cnv']
 99        if m.hf_file != 'tinyllamas/stories260K.gguf' and 'Mistral-Nemo' not in m.hf_repo:
100            cmd.append('-fa')
101        try:
102            subprocess.check_call(cmd)
103        except subprocess.CalledProcessError:
104            logging.error(f'Failed to fetch model at {m.hf_repo} / {m.hf_file} with command:\n  {" ".join(cmd)}')
105            exit(1)