Diffstat (limited to 'llama.cpp/tools/server/tests/utils.py')
| -rw-r--r-- | llama.cpp/tools/server/tests/utils.py | 643 |
1 files changed, 643 insertions, 0 deletions
diff --git a/llama.cpp/tools/server/tests/utils.py b/llama.cpp/tools/server/tests/utils.py
new file mode 100644
index 0000000..f76bb1a
--- /dev/null
+++ b/llama.cpp/tools/server/tests/utils.py
@@ -0,0 +1,643 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# type: ignore[reportUnusedImport]
+
+import subprocess
+import os
+import re
+import json
+from json import JSONDecodeError
+import sys
+import requests
+import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Iterable,
+    Iterator,
+    List,
+    Literal,
+    Tuple,
+    Set,
+)
+from re import RegexFlag
+import wget
+
+
+DEFAULT_HTTP_TIMEOUT = 60
+
+
+class ServerResponse:
+    headers: dict
+    status_code: int
+    body: dict | Any
+
+
+class ServerError(Exception):
+    def __init__(self, code, body):
+        self.code = code
+        self.body = body
+
+
+class ServerProcess:
+    # default options
+    debug: bool = False
+    server_port: int = 8080
+    server_host: str = "127.0.0.1"
+    model_hf_repo: str | None = "ggml-org/models"
+    model_hf_file: str | None = "tinyllamas/stories260K.gguf"
+    model_alias: str = "tinyllama-2"
+    temperature: float = 0.8
+    seed: int = 42
+    offline: bool = False
+
+    # custom options
+    model_alias: str | None = None
+    model_url: str | None = None
+    model_file: str | None = None
+    model_draft: str | None = None
+    n_threads: int | None = None
+    n_gpu_layer: int | None = None
+    n_batch: int | None = None
+    n_ubatch: int | None = None
+    n_ctx: int | None = None
+    n_ga: int | None = None
+    n_ga_w: int | None = None
+    n_predict: int | None = None
+    n_prompts: int | None = 0
+    slot_save_path: str | None = None
+    id_slot: int | None = None
+    cache_prompt: bool | None = None
+    n_slots: int | None = None
+    ctk: str | None = None
+    ctv: str | None = None
+    fa: str | None = None
+    server_continuous_batching: bool | None = False
+    server_embeddings: bool | None = False
+    server_reranking: bool | None = False
+    server_metrics: bool | None = False
+    kv_unified: bool | None = False
+    server_slots: bool | None = False
+    pooling: str | None = None
+    draft: int | None = None
+    api_key: str | None = None
+    models_dir: str | None = None
+    models_max: int | None = None
+    no_models_autoload: bool | None = None
+    lora_files: List[str] | None = None
+    enable_ctx_shift: int | None = False
+    draft_min: int | None = None
+    draft_max: int | None = None
+    no_webui: bool | None = None
+    jinja: bool | None = None
+    reasoning_format: Literal['deepseek', 'none', 'nothink'] | None = None
+    reasoning_budget: int | None = None
+    chat_template: str | None = None
+    chat_template_file: str | None = None
+    server_path: str | None = None
+    mmproj_url: str | None = None
+    media_path: str | None = None
+    sleep_idle_seconds: int | None = None
+
+    # session variables
+    process: subprocess.Popen | None = None
+
+    def __init__(self):
+        if "N_GPU_LAYERS" in os.environ:
+            self.n_gpu_layer = int(os.environ["N_GPU_LAYERS"])
+        if "DEBUG" in os.environ:
+            self.debug = True
+        if "PORT" in os.environ:
+            self.server_port = int(os.environ["PORT"])
+        self.external_server = "DEBUG_EXTERNAL" in os.environ
+
+    def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
+        if self.external_server:
+            print(f"[external_server]: Assuming external server running on {self.server_host}:{self.server_port}")
+            return
+        if self.server_path is not None:
+            server_path = self.server_path
+        elif "LLAMA_SERVER_BIN_PATH" in os.environ:
+            server_path = os.environ["LLAMA_SERVER_BIN_PATH"]
+        elif os.name == "nt":
+            server_path = "../../../build/bin/Release/llama-server.exe"
+        else:
+            server_path = "../../../build/bin/llama-server"
+        server_args = [
+            "--host",
+            self.server_host,
+            "--port",
+            self.server_port,
+            "--temp",
+            self.temperature,
+            "--seed",
+            self.seed,
+        ]
+        if self.offline:
+            server_args.append("--offline")
+        if self.model_file:
+            server_args.extend(["--model", self.model_file])
+        if self.model_url:
+            server_args.extend(["--model-url", self.model_url])
+        if self.model_draft:
+            server_args.extend(["--model-draft", self.model_draft])
+        if self.model_hf_repo:
+            server_args.extend(["--hf-repo", self.model_hf_repo])
+        if self.model_hf_file:
+            server_args.extend(["--hf-file", self.model_hf_file])
+        if self.models_dir:
+            server_args.extend(["--models-dir", self.models_dir])
+        if self.models_max is not None:
+            server_args.extend(["--models-max", self.models_max])
+        if self.n_batch:
+            server_args.extend(["--batch-size", self.n_batch])
+        if self.n_ubatch:
+            server_args.extend(["--ubatch-size", self.n_ubatch])
+        if self.n_threads:
+            server_args.extend(["--threads", self.n_threads])
+        if self.n_gpu_layer:
+            server_args.extend(["--n-gpu-layers", self.n_gpu_layer])
+        if self.draft is not None:
+            server_args.extend(["--draft", self.draft])
+        if self.server_continuous_batching:
+            server_args.append("--cont-batching")
+        if self.server_embeddings:
+            server_args.append("--embedding")
+        if self.server_reranking:
+            server_args.append("--reranking")
+        if self.server_metrics:
+            server_args.append("--metrics")
+        if self.kv_unified:
+            server_args.append("--kv-unified")
+        if self.server_slots:
+            server_args.append("--slots")
+        else:
+            server_args.append("--no-slots")
+        if self.pooling:
+            server_args.extend(["--pooling", self.pooling])
+        if self.model_alias:
+            server_args.extend(["--alias", self.model_alias])
+        if self.n_ctx:
+            server_args.extend(["--ctx-size", self.n_ctx])
+        if self.n_slots:
+            server_args.extend(["--parallel", self.n_slots])
+        if self.ctk:
+            server_args.extend(["-ctk", self.ctk])
+        if self.ctv:
+            server_args.extend(["-ctv", self.ctv])
+        if self.fa is not None:
+            server_args.extend(["-fa", self.fa])
+        if self.n_predict:
+            server_args.extend(["--n-predict", self.n_predict])
+        if self.slot_save_path:
+            server_args.extend(["--slot-save-path", self.slot_save_path])
+        if self.n_ga:
+            server_args.extend(["--grp-attn-n", self.n_ga])
+        if self.n_ga_w:
+            server_args.extend(["--grp-attn-w", self.n_ga_w])
+        if self.debug:
+            server_args.append("--verbose")
+        if self.lora_files:
+            for lora_file in self.lora_files:
+                server_args.extend(["--lora", lora_file])
+        if self.enable_ctx_shift:
+            server_args.append("--context-shift")
+        if self.api_key:
+            server_args.extend(["--api-key", self.api_key])
+        if self.draft_max:
+            server_args.extend(["--draft-max", self.draft_max])
+        if self.draft_min:
+            server_args.extend(["--draft-min", self.draft_min])
+        if self.no_webui:
+            server_args.append("--no-webui")
+        if self.no_models_autoload:
+            server_args.append("--no-models-autoload")
+        if self.jinja:
+            server_args.append("--jinja")
+        else:
+            server_args.append("--no-jinja")
+        if self.reasoning_format is not None:
+            server_args.extend(("--reasoning-format", self.reasoning_format))
+        if self.reasoning_budget is not None:
+            server_args.extend(("--reasoning-budget", self.reasoning_budget))
+        if self.chat_template:
+            server_args.extend(["--chat-template", self.chat_template])
+        if self.chat_template_file:
+            server_args.extend(["--chat-template-file", self.chat_template_file])
+        if self.mmproj_url:
+            server_args.extend(["--mmproj-url", self.mmproj_url])
+        if self.media_path:
+            server_args.extend(["--media-path", self.media_path])
+        if self.sleep_idle_seconds is not None:
+            server_args.extend(["--sleep-idle-seconds", self.sleep_idle_seconds])
+
+        args = [str(arg) for arg in [server_path, *server_args]]
+        print(f"tests: starting server with: {' '.join(args)}")
+
+        flags = 0
+        if "nt" == os.name:
+            flags |= subprocess.DETACHED_PROCESS
+            flags |= subprocess.CREATE_NEW_PROCESS_GROUP
+            flags |= subprocess.CREATE_NO_WINDOW
+
+        self.process = subprocess.Popen(
+            [str(arg) for arg in [server_path, *server_args]],
+            creationflags=flags,
+            stdout=sys.stdout,
+            stderr=sys.stdout,
+            env={**os.environ, "LLAMA_CACHE": "tmp"} if "LLAMA_CACHE" not in os.environ else None,
+        )
+        server_instances.add(self)
+
+        print(f"server pid={self.process.pid}, pytest pid={os.getpid()}")
+
+        # wait for server to start
+        start_time = time.time()
+        while time.time() - start_time < timeout_seconds:
+            try:
+                response = self.make_request("GET", "/health", headers={
+                    "Authorization": f"Bearer {self.api_key}" if self.api_key else None
+                })
+                if response.status_code == 200:
+                    self.ready = True
+                    return  # server is ready
+            except Exception as e:
+                pass
+            # Check if process died
+            if self.process.poll() is not None:
+                raise RuntimeError(f"Server process died with return code {self.process.returncode}")
+
+            print(f"Waiting for server to start...")
+            time.sleep(0.5)
+        raise TimeoutError(f"Server did not start within {timeout_seconds} seconds")
+
+    def stop(self) -> None:
+        if self.external_server:
+            print("[external_server]: Not stopping external server")
+            return
+        if self in server_instances:
+            server_instances.remove(self)
+        if self.process:
+            print(f"Stopping server with pid={self.process.pid}")
+            self.process.kill()
+            self.process = None
+
+    def make_request(
+        self,
+        method: str,
+        path: str,
+        data: dict | Any | None = None,
+        headers: dict | None = None,
+        timeout: float | None = None,
+    ) -> ServerResponse:
+        url = f"http://{self.server_host}:{self.server_port}{path}"
+        parse_body = False
+        if method == "GET":
+            response = requests.get(url, headers=headers, timeout=timeout)
+            parse_body = True
+        elif method == "POST":
+            response = requests.post(url, headers=headers, json=data, timeout=timeout)
+            parse_body = True
+        elif method == "OPTIONS":
+            response = requests.options(url, headers=headers, timeout=timeout)
+        else:
+            raise ValueError(f"Unimplemented method: {method}")
+        result = ServerResponse()
+        result.headers = dict(response.headers)
+        result.status_code = response.status_code
+        if parse_body:
+            try:
+                result.body = response.json()
+            except JSONDecodeError:
+                result.body = response.text
+        else:
+            result.body = None
+        print("Response from server", json.dumps(result.body, indent=2))
+        return result
+
+    def make_stream_request(
+        self,
+        method: str,
+        path: str,
+        data: dict | None = None,
+        headers: dict | None = None,
+    ) -> Iterator[dict]:
+        url = f"http://{self.server_host}:{self.server_port}{path}"
+        if method == "POST":
+            response = requests.post(url, headers=headers, json=data, stream=True)
+        else:
+            raise ValueError(f"Unimplemented method: {method}")
+        if response.status_code != 200:
+            raise ServerError(response.status_code, response.json())
+        for line_bytes in response.iter_lines():
+            line = line_bytes.decode("utf-8")
+            if '[DONE]' in line:
+                break
+            elif line.startswith('data: '):
+                data = json.loads(line[6:])
+                print("Partial response from server", json.dumps(data, indent=2))
+                yield data
+
+    def make_any_request(
+        self,
+        method: str,
+        path: str,
+        data: dict | None = None,
+        headers: dict | None = None,
+        timeout: float | None = None,
+    ) -> dict:
+        stream = data.get('stream', False)
+        if stream:
+            content: list[str] = []
+            reasoning_content: list[str] = []
+            tool_calls: list[dict] = []
+            finish_reason: str | None = None
+
+            content_parts = 0
+            reasoning_content_parts = 0
+            tool_call_parts = 0
+            arguments_parts = 0
+
+            for chunk in self.make_stream_request(method, path, data, headers):
+                if chunk['choices']:
+                    assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}'
+                    choice = chunk['choices'][0]
+                    if choice['delta'].get('content') is not None:
+                        assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!'
+                        content.append(choice['delta']['content'])
+                        content_parts += 1
+                    if choice['delta'].get('reasoning_content') is not None:
+                        assert len(choice['delta']['reasoning_content']) > 0, f'Expected non empty reasoning_content delta!'
+                        reasoning_content.append(choice['delta']['reasoning_content'])
+                        reasoning_content_parts += 1
+                    if choice['delta'].get('finish_reason') is not None:
+                        finish_reason = choice['delta']['finish_reason']
+                    for tc in choice['delta'].get('tool_calls', []):
+                        if 'function' not in tc:
+                            raise ValueError(f"Expected function type, got {tc['type']}")
+                        if tc['index'] >= len(tool_calls):
+                            assert 'id' in tc
+                            assert tc.get('type') == 'function'
+                            assert 'function' in tc and 'name' in tc['function'] and len(tc['function']['name']) > 0, \
+                                f"Expected function call with name, got {tc.get('function')}"
+                            tool_calls.append(dict(
+                                id="",
+                                type="function",
+                                function=dict(
+                                    name="",
+                                    arguments="",
+                                )
+                            ))
+                        tool_call = tool_calls[tc['index']]
+                        if tc.get('id') is not None:
+                            tool_call['id'] = tc['id']
+                        fct = tc['function']
+                        assert 'id' not in fct, f"Function call should not have id: {fct}"
+                        if fct.get('name') is not None:
+                            tool_call['function']['name'] = tool_call['function'].get('name', '') + fct['name']
+                        if fct.get('arguments') is not None:
+                            tool_call['function']['arguments'] += fct['arguments']
+                            arguments_parts += 1
+                        tool_call_parts += 1
+                else:
+                    # When `include_usage` is True (the default), we expect the last chunk of the stream
+                    # immediately preceding the `data: [DONE]` message to contain a `choices` field with an empty array
+                    # and a `usage` field containing the usage statistics (n.b., llama-server also returns `timings` in
+                    # the last chunk)
+                    assert 'usage' in chunk, f"Expected usage in chunk: {chunk}"
+                    assert 'timings' in chunk, f"Expected timings in chunk: {chunk}"
+            print(f'Streamed response had {content_parts} content parts, {reasoning_content_parts} reasoning_content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
+            result = dict(
+                choices=[
+                    dict(
+                        index=0,
+                        finish_reason=finish_reason,
+                        message=dict(
+                            role='assistant',
+                            content=''.join(content) if content else None,
+                            reasoning_content=''.join(reasoning_content) if reasoning_content else None,
+                            tool_calls=tool_calls if tool_calls else None,
+                        ),
+                    )
+                ],
+            )
+            print("Final response from server", json.dumps(result, indent=2))
+            return result
+        else:
+            response = self.make_request(method, path, data, headers, timeout=timeout)
+            assert response.status_code == 200, f"Server returned error: {response.status_code}"
+            return response.body
+
+
+server_instances: Set[ServerProcess] = set()
+
+
+class ServerPreset:
+    @staticmethod
+    def load_all() -> None:
+        """ Load all server presets to ensure model files are cached. """
+        servers: List[ServerProcess] = [
+            method()
+            for name, method in ServerPreset.__dict__.items()
+            if callable(method) and name != "load_all"
+        ]
+        for server in servers:
+            server.offline = False
+            server.start()
+            server.stop()
+
+    @staticmethod
+    def tinyllama2() -> ServerProcess:
+        server = ServerProcess()
+        server.offline = True  # will be downloaded by load_all()
+        server.model_hf_repo = "ggml-org/test-model-stories260K"
+        server.model_hf_file = None
+        server.model_alias = "tinyllama-2"
+        server.n_ctx = 512
+        server.n_batch = 32
+        server.n_slots = 2
+        server.n_predict = 64
+        server.seed = 42
+        return server
+
+    @staticmethod
+    def bert_bge_small() -> ServerProcess:
+        server = ServerProcess()
+        server.offline = True  # will be downloaded by load_all()
+        server.model_hf_repo = "ggml-org/models"
+        server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
+        server.model_alias = "bert-bge-small"
+        server.n_ctx = 512
+        server.n_batch = 128
+        server.n_ubatch = 128
+        server.n_slots = 2
+        server.seed = 42
+        server.server_embeddings = True
+        return server
+
+    @staticmethod
+    def bert_bge_small_with_fa() -> ServerProcess:
+        server = ServerProcess()
+        server.offline = True  # will be downloaded by load_all()
+        server.model_hf_repo = "ggml-org/models"
+        server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
+        server.model_alias = "bert-bge-small"
+        server.n_ctx = 1024
+        server.n_batch = 300
+        server.n_ubatch = 300
+        server.n_slots = 2
+        server.fa = "on"
+        server.seed = 42
+        server.server_embeddings = True
+        return server
+
+    @staticmethod
+    def tinyllama_infill() -> ServerProcess:
+        server = ServerProcess()
+        server.offline = True  # will be downloaded by load_all()
+        server.model_hf_repo = "ggml-org/test-model-stories260K-infill"
+        server.model_hf_file = None
+        server.model_alias = "tinyllama-infill"
+        server.n_ctx = 2048
+        server.n_batch = 1024
+        server.n_slots = 1
+        server.n_predict = 64
+        server.temperature = 0.0
+        server.seed = 42
+        return server
+
+    @staticmethod
+    def stories15m_moe() -> ServerProcess:
+        server = ServerProcess()
+        server.offline = True  # will be downloaded by load_all()
+        server.model_hf_repo = "ggml-org/stories15M_MOE"
+        server.model_hf_file = "stories15M_MOE-F16.gguf"
+        server.model_alias = "stories15m-moe"
+        server.n_ctx = 2048
+        server.n_batch = 1024
+        server.n_slots = 1
+        server.n_predict = 64
+        server.temperature = 0.0
+        server.seed = 42
+        return server
+
+    @staticmethod
+    def jina_reranker_tiny() -> ServerProcess:
+        server = ServerProcess()
+        server.offline = True  # will be downloaded by load_all()
+        server.model_hf_repo = "ggml-org/models"
+        server.model_hf_file = "jina-reranker-v1-tiny-en/ggml-model-f16.gguf"
+        server.model_alias = "jina-reranker"
+        server.n_ctx = 512
+        server.n_batch = 512
+        server.n_slots = 1
+        server.seed = 42
+        server.server_reranking = True
+        return server
+
+    @staticmethod
+    def tinygemma3() -> ServerProcess:
+        server = ServerProcess()
+        server.offline = True  # will be downloaded by load_all()
+        # mmproj is already provided by HF registry API
+        server.model_hf_file = None
+        server.model_hf_repo = "ggml-org/tinygemma3-GGUF:Q8_0"
+        server.model_alias = "tinygemma3"
+        server.n_ctx = 1024
+        server.n_batch = 32
+        server.n_slots = 2
+        server.n_predict = 4
+        server.seed = 42
+        return server
+
+    @staticmethod
+    def router() -> ServerProcess:
+        server = ServerProcess()
+        server.offline = True  # will be downloaded by load_all()
+        # router server has no models
+        server.model_file = None
+        server.model_alias = None
+        server.model_hf_repo = None
+        server.model_hf_file = None
+        server.n_ctx = 1024
+        server.n_batch = 16
+        server.n_slots = 1
+        server.n_predict = 16
+        server.seed = 42
+        return server
+
+
+def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]:
+    """
+    Run multiple functions in parallel and return results in the same order as calls. Equivalent to Promise.all in JS.
+
+    Example usage:
+
+    results = parallel_function_calls([
+        (func1, (arg1, arg2)),
+        (func2, (arg3, arg4)),
+    ])
+    """
+    results = [None] * len(function_list)
+    exceptions = []
+
+    def worker(index, func, args):
+        try:
+            result = func(*args)
+            results[index] = result
+        except Exception as e:
+            exceptions.append((index, str(e)))
+
+    with ThreadPoolExecutor() as executor:
+        futures = []
+        for i, (func, args) in enumerate(function_list):
+            future = executor.submit(worker, i, func, args)
+            futures.append(future)
+
+        # Wait for all futures to complete
+        for future in as_completed(futures):
+            pass
+
+        # Check if there were any exceptions
+        if exceptions:
+            print("Exceptions occurred:")
+            for index, error in exceptions:
+                print(f"Function at index {index}: {error}")
+
+    return results
+
+
+def match_regex(regex: str, text: str) -> bool:
+    return (
+        re.compile(
+            regex, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL
+        ).search(text)
+        is not None
+    )
+
+
+def download_file(url: str, output_file_path: str | None = None) -> str:
+    """
+    Download a file from a URL to a local path. If the file already exists, it will not be downloaded again.
+
+    output_file_path is the local path to save the downloaded file. If not provided, the file will be saved in the root directory.
+
+    Returns the local path of the downloaded file.
+    """
+    file_name = url.split('/').pop()
+    output_file = f'./tmp/{file_name}' if output_file_path is None else output_file_path
+    if not os.path.exists(output_file):
+        print(f"Downloading {url} to {output_file}")
+        wget.download(url, out=output_file)
+        print(f"Done downloading to {output_file}")
+    else:
+        print(f"File already exists at {output_file}")
+    return output_file
+
+
+def is_slow_test_allowed():
+    return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON"
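For orientation, here is a minimal usage sketch (not part of the commit above) showing how these helpers are typically combined from a test module living next to utils.py: pick a ServerPreset, start it, issue a request with make_request, and fan out concurrent calls with parallel_function_calls. The "/completion" and "/props" endpoints and the "prompt"/"n_predict" payload fields are assumptions about the llama-server HTTP API; only "/health" is actually referenced in utils.py itself, and the file/script name is hypothetical.

    # usage_sketch.py -- hypothetical example next to utils.py, assumes a built llama-server binary
    from utils import ServerPreset, parallel_function_calls


    def smoke_test() -> None:
        server = ServerPreset.tinyllama2()
        server.offline = False  # allow the model to be fetched if it is not already cached
        server.start(timeout_seconds=120)
        try:
            # "/completion" with "prompt"/"n_predict" is an assumed llama-server endpoint,
            # not something defined by utils.py
            res = server.make_request("POST", "/completion", data={
                "prompt": "I believe the meaning of life is",
                "n_predict": 8,
            })
            assert res.status_code == 200

            # run two requests concurrently; results come back in submission order
            health, props = parallel_function_calls([
                (server.make_request, ("GET", "/health")),
                (server.make_request, ("GET", "/props")),  # "/props" is assumed to exist
            ])
            assert health.status_code == 200 and props.status_code == 200
        finally:
            server.stop()


    if __name__ == "__main__":
        smoke_test()

The explicit try/finally around stop() mirrors what the real test suite delegates to its pytest fixtures: start() registers the instance in server_instances and polls "/health" until the process responds, so a stop() is needed to avoid leaking child processes.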
