1#!/usr/bin/env python3
  2# -*- coding: utf-8 -*-
  3
  4# type: ignore[reportUnusedImport]
  5
  6import subprocess
  7import os
  8import re
  9import json
 10from json import JSONDecodeError
 11import sys
 12import requests
 13import time
 14from concurrent.futures import ThreadPoolExecutor, as_completed
 15from typing import (
 16    Any,
 17    Callable,
 18    ContextManager,
 19    Iterable,
 20    Iterator,
 21    List,
 22    Literal,
 23    Tuple,
 24    Set,
 25)
 26from re import RegexFlag
 27import wget
 28
 29
 30DEFAULT_HTTP_TIMEOUT = 60
 31
 32
 33class ServerResponse:
 34    headers: dict
 35    status_code: int
 36    body: dict | Any
 37
 38
 39class ServerError(Exception):
 40    def __init__(self, code, body):
 41        self.code = code
 42        self.body = body
 43
 44
 45class ServerProcess:
 46    # default options
 47    debug: bool = False
 48    server_port: int = 8080
 49    server_host: str = "127.0.0.1"
 50    model_hf_repo: str | None = "ggml-org/models"
 51    model_hf_file: str | None = "tinyllamas/stories260K.gguf"
 52    model_alias: str = "tinyllama-2"
 53    temperature: float = 0.8
 54    seed: int = 42
 55    offline: bool = False
 56
 57    # custom options
 58    model_alias: str | None = None
 59    model_url: str | None = None
 60    model_file: str | None = None
 61    model_draft: str | None = None
 62    n_threads: int | None = None
 63    n_gpu_layer: int | None = None
 64    n_batch: int | None = None
 65    n_ubatch: int | None = None
 66    n_ctx: int | None = None
 67    n_ga: int | None = None
 68    n_ga_w: int | None = None
 69    n_predict: int | None = None
 70    n_prompts: int | None = 0
 71    slot_save_path: str | None = None
 72    id_slot: int | None = None
 73    cache_prompt: bool | None = None
 74    n_slots: int | None = None
 75    ctk: str | None = None
 76    ctv: str | None = None
 77    fa: str | None = None
 78    server_continuous_batching: bool | None = False
 79    server_embeddings: bool | None = False
 80    server_reranking: bool | None = False
 81    server_metrics: bool | None = False
 82    kv_unified: bool | None = False
 83    server_slots: bool | None = False
 84    pooling: str | None = None
 85    draft: int | None = None
 86    api_key: str | None = None
 87    models_dir: str | None = None
 88    models_max: int | None = None
 89    no_models_autoload: bool | None = None
 90    lora_files: List[str] | None = None
 91    enable_ctx_shift: int | None = False
 92    draft_min: int | None = None
 93    draft_max: int | None = None
 94    no_webui: bool | None = None
 95    jinja: bool | None = None
 96    reasoning_format: Literal['deepseek', 'none', 'nothink'] | None = None
 97    reasoning_budget: int | None = None
 98    chat_template: str | None = None
 99    chat_template_file: str | None = None
100    server_path: str | None = None
101    mmproj_url: str | None = None
102    media_path: str | None = None
103    sleep_idle_seconds: int | None = None
104
105    # session variables
106    process: subprocess.Popen | None = None
107
108    def __init__(self):
109        if "N_GPU_LAYERS" in os.environ:
110            self.n_gpu_layer = int(os.environ["N_GPU_LAYERS"])
111        if "DEBUG" in os.environ:
112            self.debug = True
113        if "PORT" in os.environ:
114            self.server_port = int(os.environ["PORT"])
115        self.external_server = "DEBUG_EXTERNAL" in os.environ
116
117    def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
118        if self.external_server:
119            print(f"[external_server]: Assuming external server running on {self.server_host}:{self.server_port}")
120            return
121        if self.server_path is not None:
122            server_path = self.server_path
123        elif "LLAMA_SERVER_BIN_PATH" in os.environ:
124            server_path = os.environ["LLAMA_SERVER_BIN_PATH"]
125        elif os.name == "nt":
126            server_path = "../../../build/bin/Release/llama-server.exe"
127        else:
128            server_path = "../../../build/bin/llama-server"
129        server_args = [
130            "--host",
131            self.server_host,
132            "--port",
133            self.server_port,
134            "--temp",
135            self.temperature,
136            "--seed",
137            self.seed,
138        ]
139        if self.offline:
140            server_args.append("--offline")
141        if self.model_file:
142            server_args.extend(["--model", self.model_file])
143        if self.model_url:
144            server_args.extend(["--model-url", self.model_url])
145        if self.model_draft:
146            server_args.extend(["--model-draft", self.model_draft])
147        if self.model_hf_repo:
148            server_args.extend(["--hf-repo", self.model_hf_repo])
149        if self.model_hf_file:
150            server_args.extend(["--hf-file", self.model_hf_file])
151        if self.models_dir:
152            server_args.extend(["--models-dir", self.models_dir])
153        if self.models_max is not None:
154            server_args.extend(["--models-max", self.models_max])
155        if self.n_batch:
156            server_args.extend(["--batch-size", self.n_batch])
157        if self.n_ubatch:
158            server_args.extend(["--ubatch-size", self.n_ubatch])
159        if self.n_threads:
160            server_args.extend(["--threads", self.n_threads])
161        if self.n_gpu_layer:
162            server_args.extend(["--n-gpu-layers", self.n_gpu_layer])
163        if self.draft is not None:
164            server_args.extend(["--draft", self.draft])
165        if self.server_continuous_batching:
166            server_args.append("--cont-batching")
167        if self.server_embeddings:
168            server_args.append("--embedding")
169        if self.server_reranking:
170            server_args.append("--reranking")
171        if self.server_metrics:
172            server_args.append("--metrics")
173        if self.kv_unified:
174            server_args.append("--kv-unified")
175        if self.server_slots:
176            server_args.append("--slots")
177        else:
178            server_args.append("--no-slots")
179        if self.pooling:
180            server_args.extend(["--pooling", self.pooling])
181        if self.model_alias:
182            server_args.extend(["--alias", self.model_alias])
183        if self.n_ctx:
184            server_args.extend(["--ctx-size", self.n_ctx])
185        if self.n_slots:
186            server_args.extend(["--parallel", self.n_slots])
187        if self.ctk:
188            server_args.extend(["-ctk", self.ctk])
189        if self.ctv:
190            server_args.extend(["-ctv", self.ctv])
191        if self.fa is not None:
192            server_args.extend(["-fa", self.fa])
193        if self.n_predict:
194            server_args.extend(["--n-predict", self.n_predict])
195        if self.slot_save_path:
196            server_args.extend(["--slot-save-path", self.slot_save_path])
197        if self.n_ga:
198            server_args.extend(["--grp-attn-n", self.n_ga])
199        if self.n_ga_w:
200            server_args.extend(["--grp-attn-w", self.n_ga_w])
201        if self.debug:
202            server_args.append("--verbose")
203        if self.lora_files:
204            for lora_file in self.lora_files:
205                server_args.extend(["--lora", lora_file])
206        if self.enable_ctx_shift:
207            server_args.append("--context-shift")
208        if self.api_key:
209            server_args.extend(["--api-key", self.api_key])
210        if self.draft_max:
211            server_args.extend(["--draft-max", self.draft_max])
212        if self.draft_min:
213            server_args.extend(["--draft-min", self.draft_min])
214        if self.no_webui:
215            server_args.append("--no-webui")
216        if self.no_models_autoload:
217            server_args.append("--no-models-autoload")
218        if self.jinja:
219            server_args.append("--jinja")
220        else:
221            server_args.append("--no-jinja")
222        if self.reasoning_format is not None:
223            server_args.extend(("--reasoning-format", self.reasoning_format))
224        if self.reasoning_budget is not None:
225            server_args.extend(("--reasoning-budget", self.reasoning_budget))
226        if self.chat_template:
227            server_args.extend(["--chat-template", self.chat_template])
228        if self.chat_template_file:
229            server_args.extend(["--chat-template-file", self.chat_template_file])
230        if self.mmproj_url:
231            server_args.extend(["--mmproj-url", self.mmproj_url])
232        if self.media_path:
233            server_args.extend(["--media-path", self.media_path])
234        if self.sleep_idle_seconds is not None:
235            server_args.extend(["--sleep-idle-seconds", self.sleep_idle_seconds])
236
237        args = [str(arg) for arg in [server_path, *server_args]]
238        print(f"tests: starting server with: {' '.join(args)}")
239
240        flags = 0
241        if "nt" == os.name:
242            flags |= subprocess.DETACHED_PROCESS
243            flags |= subprocess.CREATE_NEW_PROCESS_GROUP
244            flags |= subprocess.CREATE_NO_WINDOW
245
246        self.process = subprocess.Popen(
247            [str(arg) for arg in [server_path, *server_args]],
248            creationflags=flags,
249            stdout=sys.stdout,
250            stderr=sys.stdout,
251            env={**os.environ, "LLAMA_CACHE": "tmp"} if "LLAMA_CACHE" not in os.environ else None,
252        )
253        server_instances.add(self)
254
255        print(f"server pid={self.process.pid}, pytest pid={os.getpid()}")
256
257        # wait for server to start
258        start_time = time.time()
259        while time.time() - start_time < timeout_seconds:
260            try:
261                response = self.make_request("GET", "/health", headers={
262                    "Authorization": f"Bearer {self.api_key}" if self.api_key else None
263                })
264                if response.status_code == 200:
265                    self.ready = True
266                    return  # server is ready
267            except Exception as e:
268                pass
269            # Check if process died
270            if self.process.poll() is not None:
271                raise RuntimeError(f"Server process died with return code {self.process.returncode}")
272
273            print(f"Waiting for server to start...")
274            time.sleep(0.5)
275        raise TimeoutError(f"Server did not start within {timeout_seconds} seconds")
276
277    def stop(self) -> None:
278        if self.external_server:
279            print("[external_server]: Not stopping external server")
280            return
281        if self in server_instances:
282            server_instances.remove(self)
283        if self.process:
284            print(f"Stopping server with pid={self.process.pid}")
285            self.process.kill()
286            self.process = None
287
288    def make_request(
289        self,
290        method: str,
291        path: str,
292        data: dict | Any | None = None,
293        headers: dict | None = None,
294        timeout: float | None = None,
295    ) -> ServerResponse:
296        url = f"http://{self.server_host}:{self.server_port}{path}"
297        parse_body = False
298        if method == "GET":
299            response = requests.get(url, headers=headers, timeout=timeout)
300            parse_body = True
301        elif method == "POST":
302            response = requests.post(url, headers=headers, json=data, timeout=timeout)
303            parse_body = True
304        elif method == "OPTIONS":
305            response = requests.options(url, headers=headers, timeout=timeout)
306        else:
307            raise ValueError(f"Unimplemented method: {method}")
308        result = ServerResponse()
309        result.headers = dict(response.headers)
310        result.status_code = response.status_code
311        if parse_body:
312            try:
313                result.body = response.json()
314            except JSONDecodeError:
315                result.body = response.text
316        else:
317            result.body = None
318        print("Response from server", json.dumps(result.body, indent=2))
319        return result
320
321    def make_stream_request(
322        self,
323        method: str,
324        path: str,
325        data: dict | None = None,
326        headers: dict | None = None,
327    ) -> Iterator[dict]:
328        url = f"http://{self.server_host}:{self.server_port}{path}"
329        if method == "POST":
330            response = requests.post(url, headers=headers, json=data, stream=True)
331        else:
332            raise ValueError(f"Unimplemented method: {method}")
333        if response.status_code != 200:
334            raise ServerError(response.status_code, response.json())
335        for line_bytes in response.iter_lines():
336            line = line_bytes.decode("utf-8")
337            if '[DONE]' in line:
338                break
339            elif line.startswith('data: '):
340                data = json.loads(line[6:])
341                print("Partial response from server", json.dumps(data, indent=2))
342                yield data
343
344    def make_any_request(
345        self,
346        method: str,
347        path: str,
348        data: dict | None = None,
349        headers: dict | None = None,
350        timeout: float | None = None,
351    ) -> dict:
352        stream = data.get('stream', False)
353        if stream:
354            content: list[str] = []
355            reasoning_content: list[str] = []
356            tool_calls: list[dict] = []
357            finish_reason: Optional[str] = None
358
359            content_parts = 0
360            reasoning_content_parts = 0
361            tool_call_parts = 0
362            arguments_parts = 0
363
364            for chunk in self.make_stream_request(method, path, data, headers):
365                if chunk['choices']:
366                    assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}'
367                    choice = chunk['choices'][0]
368                    if choice['delta'].get('content') is not None:
369                        assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!'
370                        content.append(choice['delta']['content'])
371                        content_parts += 1
372                    if choice['delta'].get('reasoning_content') is not None:
373                        assert len(choice['delta']['reasoning_content']) > 0, f'Expected non empty reasoning_content delta!'
374                        reasoning_content.append(choice['delta']['reasoning_content'])
375                        reasoning_content_parts += 1
376                    if choice['delta'].get('finish_reason') is not None:
377                        finish_reason = choice['delta']['finish_reason']
378                    for tc in choice['delta'].get('tool_calls', []):
379                        if 'function' not in tc:
380                            raise ValueError(f"Expected function type, got {tc['type']}")
381                        if tc['index'] >= len(tool_calls):
382                            assert 'id' in tc
383                            assert tc.get('type') == 'function'
384                            assert 'function' in tc and 'name' in tc['function'] and len(tc['function']['name']) > 0, \
385                                f"Expected function call with name, got {tc.get('function')}"
386                            tool_calls.append(dict(
387                                id="",
388                                type="function",
389                                function=dict(
390                                    name="",
391                                    arguments="",
392                                )
393                            ))
394                        tool_call = tool_calls[tc['index']]
395                        if tc.get('id') is not None:
396                            tool_call['id'] = tc['id']
397                        fct = tc['function']
398                        assert 'id' not in fct, f"Function call should not have id: {fct}"
399                        if fct.get('name') is not None:
400                            tool_call['function']['name'] = tool_call['function'].get('name', '') + fct['name']
401                        if fct.get('arguments') is not None:
402                            tool_call['function']['arguments'] += fct['arguments']
403                            arguments_parts += 1
404                        tool_call_parts += 1
405                else:
406                    # When `include_usage` is True (the default), we expect the last chunk of the stream
407                    # immediately preceding the `data: [DONE]` message to contain a `choices` field with an empty array
408                    # and a `usage` field containing the usage statistics (n.b., llama-server also returns `timings` in
409                    # the last chunk)
410                    assert 'usage' in chunk, f"Expected finish_reason in chunk: {chunk}"
411                    assert 'timings' in chunk, f"Expected finish_reason in chunk: {chunk}"
412            print(f'Streamed response had {content_parts} content parts, {reasoning_content_parts} reasoning_content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
413            result = dict(
414                choices=[
415                    dict(
416                        index=0,
417                        finish_reason=finish_reason,
418                        message=dict(
419                            role='assistant',
420                            content=''.join(content) if content else None,
421                            reasoning_content=''.join(reasoning_content) if reasoning_content else None,
422                            tool_calls=tool_calls if tool_calls else None,
423                        ),
424                    )
425                ],
426            )
427            print("Final response from server", json.dumps(result, indent=2))
428            return result
429        else:
430            response = self.make_request(method, path, data, headers, timeout=timeout)
431            assert response.status_code == 200, f"Server returned error: {response.status_code}"
432            return response.body
433
434
435
# ServerProcess objects whose start() has spawned a process and whose stop()
# has not run yet: start() adds to this set and stop() removes from it.
# NOTE(review): presumably consulted by test fixtures to kill leftover
# servers -- confirm against the callers outside this module.
server_instances: Set[ServerProcess] = set()
437
438
class ServerPreset:
    """Factories for the ServerProcess configurations shared by the tests.

    Each factory below sets ``offline = True`` so the server is started with
    ``--offline``; ``load_all()`` starts every preset once with ``offline``
    cleared so the model files get cached beforehand.
    """

    @staticmethod
    def load_all() -> None:
        """Start and stop every preset once (online) so model files are cached.

        Each server is stopped even when its start() raises, so a failing
        preset does not leave an orphan server process behind.
        """
        servers: List[ServerProcess] = [
            getattr(ServerPreset, name)()
            for name, member in vars(ServerPreset).items()
            if isinstance(member, staticmethod) and name != "load_all"
        ]
        for server in servers:
            server.offline = False
            try:
                server.start()
            finally:
                server.stop()

    @staticmethod
    def tinyllama2() -> ServerProcess:
        """stories260K test model (HF repo only), 512 ctx, 2 slots, 64 tokens max."""
        server = ServerProcess()
        server.offline = True # will be downloaded by load_all()
        server.model_hf_repo = "ggml-org/test-model-stories260K"
        server.model_hf_file = None
        server.model_alias = "tinyllama-2"
        server.n_ctx = 512
        server.n_batch = 32
        server.n_slots = 2
        server.n_predict = 64
        server.seed = 42
        return server

    @staticmethod
    def bert_bge_small() -> ServerProcess:
        """bert-bge-small f16 with embeddings enabled, 512 ctx, 2 slots."""
        server = ServerProcess()
        server.offline = True # will be downloaded by load_all()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
        server.model_alias = "bert-bge-small"
        server.n_ctx = 512
        server.n_batch = 128
        server.n_ubatch = 128
        server.n_slots = 2
        server.seed = 42
        server.server_embeddings = True
        return server

    @staticmethod
    def bert_bge_small_with_fa() -> ServerProcess:
        """bert-bge-small embeddings with ``-fa on``, 1024 ctx, batch/ubatch 300."""
        server = ServerProcess()
        server.offline = True # will be downloaded by load_all()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
        server.model_alias = "bert-bge-small"
        server.n_ctx = 1024
        server.n_batch = 300
        server.n_ubatch = 300
        server.n_slots = 2
        server.fa = "on"
        server.seed = 42
        server.server_embeddings = True
        return server

    @staticmethod
    def tinyllama_infill() -> ServerProcess:
        """stories260K infill test model, 2048 ctx, 1 slot, temperature 0."""
        server = ServerProcess()
        server.offline = True # will be downloaded by load_all()
        server.model_hf_repo = "ggml-org/test-model-stories260K-infill"
        server.model_hf_file = None
        server.model_alias = "tinyllama-infill"
        server.n_ctx = 2048
        server.n_batch = 1024
        server.n_slots = 1
        server.n_predict = 64
        server.temperature = 0.0
        server.seed = 42
        return server

    @staticmethod
    def stories15m_moe() -> ServerProcess:
        """stories15M MoE (F16), 2048 ctx, 1 slot, temperature 0."""
        server = ServerProcess()
        server.offline = True # will be downloaded by load_all()
        server.model_hf_repo = "ggml-org/stories15M_MOE"
        server.model_hf_file = "stories15M_MOE-F16.gguf"
        server.model_alias = "stories15m-moe"
        server.n_ctx = 2048
        server.n_batch = 1024
        server.n_slots = 1
        server.n_predict = 64
        server.temperature = 0.0
        server.seed = 42
        return server

    @staticmethod
    def jina_reranker_tiny() -> ServerProcess:
        """jina-reranker-v1-tiny-en f16 with reranking enabled, 512 ctx, 1 slot."""
        server = ServerProcess()
        server.offline = True # will be downloaded by load_all()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "jina-reranker-v1-tiny-en/ggml-model-f16.gguf"
        server.model_alias = "jina-reranker"
        server.n_ctx = 512
        server.n_batch = 512
        server.n_slots = 1
        server.seed = 42
        server.server_reranking = True
        return server

    @staticmethod
    def tinygemma3() -> ServerProcess:
        """tinygemma3 Q8_0 (repo:quant form), 1024 ctx, 2 slots, 4 tokens max."""
        server = ServerProcess()
        server.offline = True # will be downloaded by load_all()
        # mmproj is already provided by HF registry API
        server.model_hf_file = None
        server.model_hf_repo = "ggml-org/tinygemma3-GGUF:Q8_0"
        server.model_alias = "tinygemma3"
        server.n_ctx = 1024
        server.n_batch = 32
        server.n_slots = 2
        server.n_predict = 4
        server.seed = 42
        return server

    @staticmethod
    def router() -> ServerProcess:
        """Server started without any model source (no file, URL, or HF repo)."""
        server = ServerProcess()
        server.offline = True # will be downloaded by load_all()
        # router server has no models
        server.model_file = None
        server.model_alias = None
        server.model_hf_repo = None
        server.model_hf_file = None
        server.n_ctx = 1024
        server.n_batch = 16
        server.n_slots = 1
        server.n_predict = 16
        server.seed = 42
        return server
572
573
def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]:
    """
    Run multiple functions in parallel and return results in the same order as calls.

    Unlike JS ``Promise.all``, a failing call does not abort the batch: an
    ``Exception`` raised by a call is caught, reported on stdout (in call
    order, with its exception type), and that call's slot in the returned
    list is left as ``None``.

    Example usage:

    results = parallel_function_calls([
        (func1, (arg1, arg2)),
        (func2, (arg3, arg4)),
    ])
    """
    results: List[Any] = [None] * len(function_list)
    errors: List[Tuple[int, str]] = []

    def worker(index: int, func: Callable[..., Any], args: Tuple[Any, ...]) -> None:
        try:
            results[index] = func(*args)
        except Exception as exc:
            # Include the type: many exceptions (e.g. bare AssertionError)
            # have an empty str().
            errors.append((index, f"{type(exc).__name__}: {exc}"))

    # Leaving the with-block waits for every submitted call to finish.
    with ThreadPoolExecutor() as executor:
        for index, (func, args) in enumerate(function_list):
            executor.submit(worker, index, func, args)

    if errors:
        print("Exceptions occurred:")
        # Report in call order rather than completion order.
        for index, error in sorted(errors, key=lambda item: item[0]):
            print(f"Function at index {index}: {error}")

    return results
612
613
def match_regex(regex: str, text: str) -> bool:
    """Return True if ``regex`` matches anywhere in ``text``.

    The search is case-insensitive, ``^``/``$`` also match at line
    boundaries, and ``.`` matches newlines as well.
    """
    search_flags = re.IGNORECASE | re.MULTILINE | re.DOTALL
    found = re.search(regex, text, search_flags)
    return found is not None
621
622
623def download_file(url: str, output_file_path: str | None = None) -> str:
624    """
625    Download a file from a URL to a local path. If the file already exists, it will not be downloaded again.
626
627    output_file_path is the local path to save the downloaded file. If not provided, the file will be saved in the root directory.
628
629    Returns the local path of the downloaded file.
630    """
631    file_name = url.split('/').pop()
632    output_file = f'./tmp/{file_name}' if output_file_path is None else output_file_path
633    if not os.path.exists(output_file):
634        print(f"Downloading {url} to {output_file}")
635        wget.download(url, out=output_file)
636        print(f"Done downloading to {output_file}")
637    else:
638        print(f"File already exists at {output_file}")
639    return output_file
640
641
def is_slow_test_allowed():
    """Return True when the SLOW_TESTS environment variable is exactly "1" or "ON"."""
    return os.environ.get("SLOW_TESTS") in ("1", "ON")