aboutsummaryrefslogtreecommitdiff
path: root/llama.cpp/scripts/fetch_server_test_models.py
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/scripts/fetch_server_test_models.py')
-rwxr-xr-xllama.cpp/scripts/fetch_server_test_models.py105
1 files changed, 105 insertions, 0 deletions
diff --git a/llama.cpp/scripts/fetch_server_test_models.py b/llama.cpp/scripts/fetch_server_test_models.py
new file mode 100755
index 0000000..ac483ef
--- /dev/null
+++ b/llama.cpp/scripts/fetch_server_test_models.py
@@ -0,0 +1,105 @@
1#!/usr/bin/env python
2'''
3 This script fetches all the models used in the server tests.
4
5 This is useful for slow tests that use larger models, to avoid them timing out on the model downloads.
6
7 It is meant to be run from the root of the repository.
8
9 Example:
10 python scripts/fetch_server_test_models.py
11 ( cd tools/server/tests && ./tests.sh -v -x -m slow )
12'''
import ast
import glob
import logging
import os
import subprocess
import sys
from typing import Generator, Optional

from pydantic import BaseModel
21
22
class HuggingFaceModel(BaseModel):
    """A model reference extracted from a pytest parametrize decorator.

    hf_repo is the Hugging Face repository id; hf_file is the file within
    that repo, or None when the test only names a repo.
    """
    hf_repo: str
    # None when the parametrize entry has no 'hf_file' column.
    hf_file: Optional[str] = None

    class Config:
        # frozen=True makes instances immutable and hashable, so they can be
        # deduplicated in a set by the __main__ driver below.
        frozen = True
29
30
def collect_hf_model_test_parameters(test_file) -> Generator[HuggingFaceModel, None, None]:
    """Yield every HuggingFaceModel referenced by @pytest.mark.parametrize
    decorators in *test_file*.

    Only parametrize calls whose argnames include 'hf_repo' are considered;
    'hf_file' is picked up when present. Files that fail to parse (or
    decorators with non-literal arguments) are skipped with a log message
    rather than aborting the scan.
    """
    try:
        with open(test_file, encoding='utf-8') as f:
            tree = ast.parse(f.read())
    except Exception as e:
        logging.error(f'collect_hf_model_test_parameters failed on {test_file}: {e}')
        return

    for node in ast.walk(tree):
        if not isinstance(node, ast.FunctionDef):
            continue
        for dec in node.decorator_list:
            if not (isinstance(dec, ast.Call) and isinstance(dec.func, ast.Attribute) and dec.func.attr == 'parametrize'):
                continue
            # A well-formed parametrize call has (argnames, argvalues).
            if len(dec.args) < 2:
                continue

            try:
                raw_names = ast.literal_eval(dec.args[0])
            except (ValueError, SyntaxError):
                # argnames is not a literal (e.g. built dynamically) — skip.
                continue
            # pytest accepts argnames as a comma-separated string (whitespace
            # around each name is ignored) or as a list/tuple of strings.
            if isinstance(raw_names, str):
                param_names = [name.strip() for name in raw_names.split(",")]
            else:
                param_names = [str(name).strip() for name in raw_names]
            if "hf_repo" not in param_names:
                continue

            raw_param_values = dec.args[1]
            if not isinstance(raw_param_values, ast.List):
                logging.warning(f'Skipping non-list parametrize entry at {test_file}:{node.lineno}')
                continue

            hf_repo_idx = param_names.index("hf_repo")
            hf_file_idx = param_names.index("hf_file") if "hf_file" in param_names else None

            for t in raw_param_values.elts:
                if not isinstance(t, ast.Tuple):
                    logging.warning(f'Skipping non-tuple parametrize entry at {test_file}:{node.lineno}')
                    continue
                yield HuggingFaceModel(
                    hf_repo=ast.literal_eval(t.elts[hf_repo_idx]),
                    hf_file=ast.literal_eval(t.elts[hf_file_idx]) if hf_file_idx is not None else None)
62
63
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

    # Deduplicate via a set (HuggingFaceModel is frozen, hence hashable) and
    # sort for stable output. `m.hf_file or ''` keeps the sort key comparable
    # when hf_file is None — a (str, None) vs (str, str) tuple comparison
    # would raise TypeError for two models sharing a repo.
    models = sorted(
        {model
         for test_file in glob.glob('tools/server/tests/unit/test_*.py')
         for model in collect_hf_model_test_parameters(test_file)},
        key=lambda m: (m.hf_repo, m.hf_file or ''))

    logging.info(f'Found {len(models)} models in parameterized tests:')
    for m in models:
        logging.info(f'  - {m.hf_repo} / {m.hf_file}')

    # Allow overriding the llama-cli binary; default to the build tree next to
    # this script (Windows builds place binaries under Release/).
    cli_path = os.environ.get(
        'LLAMA_CLI_BIN_PATH',
        os.path.join(
            os.path.dirname(__file__),
            '../build/bin/Release/llama-cli.exe' if os.name == 'nt' else '../build/bin/llama-cli'))

    for m in models:
        # Skip templated placeholders such as '<repo>' left in test files.
        if '<' in m.hf_repo or (m.hf_file is not None and '<' in m.hf_file):
            continue
        # Multi-part GGUF files (e.g. '...-00001-of-00002.gguf') can't be
        # fetched with a single -hff argument.
        if m.hf_file is not None and '-of-' in m.hf_file:
            logging.warning(f'Skipping model at {m.hf_repo} / {m.hf_file} because it is a split file')
            continue
        logging.info(f'Using llama-cli to ensure model {m.hf_repo}/{m.hf_file} was fetched')
        # Run a 1-token generation; llama-cli downloads the model into its
        # cache as a side effect.
        cmd = [
            cli_path,
            '-hfr', m.hf_repo,
            *([] if m.hf_file is None else ['-hff', m.hf_file]),
            '-n', '1',
            '-p', 'Hey',
            '--no-warmup',
            '--log-disable',
            '-no-cnv']
        # These specific models don't support flash attention; everything
        # else gets -fa.
        if m.hf_file != 'tinyllamas/stories260K.gguf' and 'Mistral-Nemo' not in m.hf_repo:
            cmd.append('-fa')
        try:
            subprocess.check_call(cmd)
        except subprocess.CalledProcessError:
            logging.error(f'Failed to fetch model at {m.hf_repo} / {m.hf_file} with command:\n {" ".join(cmd)}')
            sys.exit(1)