Diffstat (limited to 'llama.cpp/scripts/fetch_server_test_models.py')
-rwxr-xr-x  llama.cpp/scripts/fetch_server_test_models.py | 105
1 file changed, 105 insertions, 0 deletions
diff --git a/llama.cpp/scripts/fetch_server_test_models.py b/llama.cpp/scripts/fetch_server_test_models.py
new file mode 100755
index 0000000..ac483ef
--- /dev/null
+++ b/llama.cpp/scripts/fetch_server_test_models.py
@@ -0,0 +1,105 @@
+#!/usr/bin/env python
+'''
+This script fetches all the models used in the server tests.
+
+This is useful for the slow tests that use larger models, to keep them from timing out on the model downloads.
+
+It is meant to be run from the root of the repository.
+
+Example:
+    python scripts/fetch_server_test_models.py
+    ( cd tools/server/tests && ./tests.sh -v -x -m slow )
+'''
+import ast
+import glob
+import logging
+import os
+import subprocess
+from typing import Generator, Optional
+
+from pydantic import BaseModel
+
+
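+# Frozen so instances are hashable and can be deduplicated via a set below.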
+class HuggingFaceModel(BaseModel):
+    hf_repo: str
+    hf_file: Optional[str] = None
+
+    class Config:
+        frozen = True
+
+
+def collect_hf_model_test_parameters(test_file) -> Generator[HuggingFaceModel, None, None]:
+    try:
+        with open(test_file) as f:
+            tree = ast.parse(f.read())
+    except Exception as e:
+        logging.error(f'collect_hf_model_test_parameters failed on {test_file}: {e}')
+        return
+
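+    # Walk the AST, looking for pytest parametrize decorators whose parameter names include hf_repo.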
+    for node in ast.walk(tree):
+        if isinstance(node, ast.FunctionDef):
+            for dec in node.decorator_list:
+                if isinstance(dec, ast.Call) and isinstance(dec.func, ast.Attribute) and dec.func.attr == 'parametrize':
+                    param_names = ast.literal_eval(dec.args[0]).split(",")
+                    if "hf_repo" not in param_names:
+                        continue
+
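+                    # The decorator's second argument is the list of parameter tuples.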
+                    raw_param_values = dec.args[1]
+                    if not isinstance(raw_param_values, ast.List):
+                        logging.warning(f'Skipping non-list parametrize entry at {test_file}:{node.lineno}')
+                        continue
+
+                    hf_repo_idx = param_names.index("hf_repo")
+                    hf_file_idx = param_names.index("hf_file") if "hf_file" in param_names else None
+
+                    for t in raw_param_values.elts:
+                        if not isinstance(t, ast.Tuple):
+                            logging.warning(f'Skipping non-tuple parametrize entry at {test_file}:{node.lineno}')
+                            continue
+                        yield HuggingFaceModel(
+                            hf_repo=ast.literal_eval(t.elts[hf_repo_idx]),
+                            hf_file=ast.literal_eval(t.elts[hf_file_idx]) if hf_file_idx is not None else None)
+
+
+if __name__ == '__main__':
+    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
+
+    models = sorted(set(
+        model
+        for test_file in glob.glob('tools/server/tests/unit/test_*.py')
+        for model in collect_hf_model_test_parameters(test_file)
+    ), key=lambda m: (m.hf_repo, m.hf_file or ''))
+
+    logging.info(f'Found {len(models)} models in parameterized tests:')
+    for m in models:
+        logging.info(f'  - {m.hf_repo} / {m.hf_file}')
+
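+    # Locate the llama-cli binary; LLAMA_CLI_BIN_PATH overrides the default build location.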
+    cli_path = os.environ.get(
+        'LLAMA_CLI_BIN_PATH',
+        os.path.join(
+            os.path.dirname(__file__),
+            '../build/bin/Release/llama-cli.exe' if os.name == 'nt' else '../build/bin/llama-cli'))
+
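+    # Have llama-cli generate a single token per model so the download happens (and is cached) up front.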
+    for m in models:
+        if '<' in m.hf_repo or (m.hf_file is not None and '<' in m.hf_file):
+            continue
+        if m.hf_file is not None and '-of-' in m.hf_file:
+            logging.warning(f'Skipping model at {m.hf_repo} / {m.hf_file} because it is a split file')
+            continue
+        logging.info(f'Using llama-cli to ensure model {m.hf_repo}/{m.hf_file} was fetched')
+        cmd = [
+            cli_path,
+            '-hfr', m.hf_repo,
+            *([] if m.hf_file is None else ['-hff', m.hf_file]),
+            '-n', '1',
+            '-p', 'Hey',
+            '--no-warmup',
+            '--log-disable',
+            '-no-cnv']
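+        # -fa (flash attention) is skipped for these two models, presumably because they don't support it.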
+        if m.hf_file != 'tinyllamas/stories260K.gguf' and 'Mistral-Nemo' not in m.hf_repo:
+            cmd.append('-fa')
+        try:
+            subprocess.check_call(cmd)
+        except subprocess.CalledProcessError:
+            logging.error(f'Failed to fetch model at {m.hf_repo} / {m.hf_file} with command:\n  {" ".join(cmd)}')
+            exit(1)
