#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# type: ignore[reportUnusedImport]

import subprocess
import os
import re
import json
from json import JSONDecodeError
import sys
import requests
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import (
    Any,
    Callable,
    ContextManager,
    Iterable,
    Iterator,
    List,
    Literal,
    Tuple,
    Set,
)
from re import RegexFlag
import wget


DEFAULT_HTTP_TIMEOUT = 60


class ServerResponse:
    headers: dict
    status_code: int
    body: dict | Any


class ServerError(Exception):
    def __init__(self, code, body):
        self.code = code
        self.body = body


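# ServerError is raised by make_stream_request() below on non-200 responses; a
# typical test-side pattern (a sketch, not part of the original file) is:
#
#   try:
#       for chunk in server.make_stream_request("POST", "/v1/chat/completions", data=payload):
#           ...
#   except ServerError as e:
#       assert e.code == 400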
class ServerProcess:
    # default options
    debug: bool = False
    server_port: int = 8080
    server_host: str = "127.0.0.1"
    model_hf_repo: str | None = "ggml-org/models"
    model_hf_file: str | None = "tinyllamas/stories260K.gguf"
    model_alias: str = "tinyllama-2"
    temperature: float = 0.8
    seed: int = 42
    offline: bool = False

    # custom options
    model_url: str | None = None
    model_file: str | None = None
    model_draft: str | None = None
    n_threads: int | None = None
    n_gpu_layer: int | None = None
    n_batch: int | None = None
    n_ubatch: int | None = None
    n_ctx: int | None = None
    n_ga: int | None = None
    n_ga_w: int | None = None
    n_predict: int | None = None
    n_prompts: int | None = 0
    slot_save_path: str | None = None
    id_slot: int | None = None
    cache_prompt: bool | None = None
    n_slots: int | None = None
    ctk: str | None = None
    ctv: str | None = None
    fa: str | None = None
    server_continuous_batching: bool | None = False
    server_embeddings: bool | None = False
    server_reranking: bool | None = False
    server_metrics: bool | None = False
    kv_unified: bool | None = False
    server_slots: bool | None = False
    pooling: str | None = None
    draft: int | None = None
    api_key: str | None = None
    models_dir: str | None = None
    models_max: int | None = None
    no_models_autoload: bool | None = None
    lora_files: List[str] | None = None
    enable_ctx_shift: bool | None = False
    draft_min: int | None = None
    draft_max: int | None = None
    no_webui: bool | None = None
    jinja: bool | None = None
    reasoning_format: Literal['deepseek', 'none', 'nothink'] | None = None
    reasoning_budget: int | None = None
    chat_template: str | None = None
    chat_template_file: str | None = None
    server_path: str | None = None
    mmproj_url: str | None = None
    media_path: str | None = None
    sleep_idle_seconds: int | None = None

    # session variables
    process: subprocess.Popen | None = None

    def __init__(self):
        if "N_GPU_LAYERS" in os.environ:
            self.n_gpu_layer = int(os.environ["N_GPU_LAYERS"])
        if "DEBUG" in os.environ:
            self.debug = True
        if "PORT" in os.environ:
            self.server_port = int(os.environ["PORT"])
        self.external_server = "DEBUG_EXTERNAL" in os.environ

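    # The environment variables above let the test suite be retargeted without
    # code changes, e.g. (a sketch):
    #
    #   DEBUG=1 PORT=8081 N_GPU_LAYERS=99 python -m pytest ...
    #   DEBUG_EXTERNAL=1 PORT=8080 python -m pytest ...  # reuse a server that is already running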
    def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
        if self.external_server:
            print(f"[external_server]: Assuming external server running on {self.server_host}:{self.server_port}")
            return
        if self.server_path is not None:
            server_path = self.server_path
        elif "LLAMA_SERVER_BIN_PATH" in os.environ:
            server_path = os.environ["LLAMA_SERVER_BIN_PATH"]
        elif os.name == "nt":
            server_path = "../../../build/bin/Release/llama-server.exe"
        else:
            server_path = "../../../build/bin/llama-server"
        server_args = [
            "--host",
            self.server_host,
            "--port",
            self.server_port,
            "--temp",
            self.temperature,
            "--seed",
            self.seed,
        ]
        if self.offline:
            server_args.append("--offline")
        if self.model_file:
            server_args.extend(["--model", self.model_file])
        if self.model_url:
            server_args.extend(["--model-url", self.model_url])
        if self.model_draft:
            server_args.extend(["--model-draft", self.model_draft])
        if self.model_hf_repo:
            server_args.extend(["--hf-repo", self.model_hf_repo])
        if self.model_hf_file:
            server_args.extend(["--hf-file", self.model_hf_file])
        if self.models_dir:
            server_args.extend(["--models-dir", self.models_dir])
        if self.models_max is not None:
            server_args.extend(["--models-max", self.models_max])
        if self.n_batch:
            server_args.extend(["--batch-size", self.n_batch])
        if self.n_ubatch:
            server_args.extend(["--ubatch-size", self.n_ubatch])
        if self.n_threads:
            server_args.extend(["--threads", self.n_threads])
        if self.n_gpu_layer:
            server_args.extend(["--n-gpu-layers", self.n_gpu_layer])
        if self.draft is not None:
            server_args.extend(["--draft", self.draft])
        if self.server_continuous_batching:
            server_args.append("--cont-batching")
        if self.server_embeddings:
            server_args.append("--embedding")
        if self.server_reranking:
            server_args.append("--reranking")
        if self.server_metrics:
            server_args.append("--metrics")
        if self.kv_unified:
            server_args.append("--kv-unified")
        if self.server_slots:
            server_args.append("--slots")
        else:
            server_args.append("--no-slots")
        if self.pooling:
            server_args.extend(["--pooling", self.pooling])
        if self.model_alias:
            server_args.extend(["--alias", self.model_alias])
        if self.n_ctx:
            server_args.extend(["--ctx-size", self.n_ctx])
        if self.n_slots:
            server_args.extend(["--parallel", self.n_slots])
        if self.ctk:
            server_args.extend(["-ctk", self.ctk])
        if self.ctv:
            server_args.extend(["-ctv", self.ctv])
        if self.fa is not None:
            server_args.extend(["-fa", self.fa])
        if self.n_predict:
            server_args.extend(["--n-predict", self.n_predict])
        if self.slot_save_path:
            server_args.extend(["--slot-save-path", self.slot_save_path])
        if self.n_ga:
            server_args.extend(["--grp-attn-n", self.n_ga])
        if self.n_ga_w:
            server_args.extend(["--grp-attn-w", self.n_ga_w])
        if self.debug:
            server_args.append("--verbose")
        if self.lora_files:
            for lora_file in self.lora_files:
                server_args.extend(["--lora", lora_file])
        if self.enable_ctx_shift:
            server_args.append("--context-shift")
        if self.api_key:
            server_args.extend(["--api-key", self.api_key])
        if self.draft_max:
            server_args.extend(["--draft-max", self.draft_max])
        if self.draft_min:
            server_args.extend(["--draft-min", self.draft_min])
        if self.no_webui:
            server_args.append("--no-webui")
        if self.no_models_autoload:
            server_args.append("--no-models-autoload")
        if self.jinja:
            server_args.append("--jinja")
        else:
            server_args.append("--no-jinja")
        if self.reasoning_format is not None:
            server_args.extend(("--reasoning-format", self.reasoning_format))
        if self.reasoning_budget is not None:
            server_args.extend(("--reasoning-budget", self.reasoning_budget))
        if self.chat_template:
            server_args.extend(["--chat-template", self.chat_template])
        if self.chat_template_file:
            server_args.extend(["--chat-template-file", self.chat_template_file])
        if self.mmproj_url:
            server_args.extend(["--mmproj-url", self.mmproj_url])
        if self.media_path:
            server_args.extend(["--media-path", self.media_path])
        if self.sleep_idle_seconds is not None:
            server_args.extend(["--sleep-idle-seconds", self.sleep_idle_seconds])

        args = [str(arg) for arg in [server_path, *server_args]]
        print(f"tests: starting server with: {' '.join(args)}")

        flags = 0
        if "nt" == os.name:
            flags |= subprocess.DETACHED_PROCESS
            flags |= subprocess.CREATE_NEW_PROCESS_GROUP
            flags |= subprocess.CREATE_NO_WINDOW

        self.process = subprocess.Popen(
            args,
            creationflags=flags,
            stdout=sys.stdout,
            stderr=sys.stdout,
            env={**os.environ, "LLAMA_CACHE": "tmp"} if "LLAMA_CACHE" not in os.environ else None,
        )
        server_instances.add(self)

        print(f"server pid={self.process.pid}, pytest pid={os.getpid()}")

        # wait for server to start
        start_time = time.time()
        while time.time() - start_time < timeout_seconds:
            try:
                response = self.make_request("GET", "/health", headers={
                    "Authorization": f"Bearer {self.api_key}" if self.api_key else None
                })
                if response.status_code == 200:
                    self.ready = True
                    return  # server is ready
            except Exception:
                pass
            # check whether the server process died while we were waiting
            if self.process.poll() is not None:
                raise RuntimeError(f"Server process died with return code {self.process.returncode}")

            print("Waiting for server to start...")
            time.sleep(0.5)
        raise TimeoutError(f"Server did not start within {timeout_seconds} seconds")

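    # Typical lifecycle (a sketch based on the helpers in this file; assumes a
    # llama-server binary at one of the default paths above):
    #
    #   server = ServerPreset.tinyllama2()
    #   server.start()
    #   res = server.make_request("GET", "/health")
    #   assert res.status_code == 200
    #   server.stop()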
    def stop(self) -> None:
        if self.external_server:
            print("[external_server]: Not stopping external server")
            return
        if self in server_instances:
            server_instances.remove(self)
        if self.process:
            print(f"Stopping server with pid={self.process.pid}")
            self.process.kill()
            self.process = None

    def make_request(
        self,
        method: str,
        path: str,
        data: dict | Any | None = None,
        headers: dict | None = None,
        timeout: float | None = None,
    ) -> ServerResponse:
        url = f"http://{self.server_host}:{self.server_port}{path}"
        parse_body = False
        if method == "GET":
            response = requests.get(url, headers=headers, timeout=timeout)
            parse_body = True
        elif method == "POST":
            response = requests.post(url, headers=headers, json=data, timeout=timeout)
            parse_body = True
        elif method == "OPTIONS":
            response = requests.options(url, headers=headers, timeout=timeout)
        else:
            raise ValueError(f"Unimplemented method: {method}")
        result = ServerResponse()
        result.headers = dict(response.headers)
        result.status_code = response.status_code
        if parse_body:
            try:
                result.body = response.json()
            except JSONDecodeError:
                result.body = response.text
        else:
            result.body = None
        print("Response from server", json.dumps(result.body, indent=2))
        return result

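    # Example (a sketch; /completion is llama-server's completion endpoint):
    #
    #   res = server.make_request("POST", "/completion", data={
    #       "prompt": "Once upon a time",
    #       "n_predict": 8,
    #   })
    #   assert res.status_code == 200
    #   print(res.body["content"])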
    def make_stream_request(
        self,
        method: str,
        path: str,
        data: dict | None = None,
        headers: dict | None = None,
    ) -> Iterator[dict]:
        url = f"http://{self.server_host}:{self.server_port}{path}"
        if method == "POST":
            response = requests.post(url, headers=headers, json=data, stream=True)
        else:
            raise ValueError(f"Unimplemented method: {method}")
        if response.status_code != 200:
            raise ServerError(response.status_code, response.json())
        for line_bytes in response.iter_lines():
            line = line_bytes.decode("utf-8")
            if '[DONE]' in line:
                break
            elif line.startswith('data: '):
                data = json.loads(line[6:])
                print("Partial response from server", json.dumps(data, indent=2))
                yield data

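    # Example (a sketch; assumes an OpenAI-compatible chat endpoint with "stream": True):
    #
    #   for chunk in server.make_stream_request("POST", "/v1/chat/completions", data={
    #       "messages": [{"role": "user", "content": "Hello"}],
    #       "stream": True,
    #   }):
    #       delta = chunk["choices"][0]["delta"] if chunk["choices"] else {}
    #       print(delta.get("content") or "", end="")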
    def make_any_request(
        self,
        method: str,
        path: str,
        data: dict | None = None,
        headers: dict | None = None,
        timeout: float | None = None,
    ) -> dict:
        stream = data.get('stream', False) if data else False
        if stream:
            content: list[str] = []
            reasoning_content: list[str] = []
            tool_calls: list[dict] = []
            finish_reason: str | None = None

            content_parts = 0
            reasoning_content_parts = 0
            tool_call_parts = 0
            arguments_parts = 0

            for chunk in self.make_stream_request(method, path, data, headers):
                if chunk['choices']:
                    assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}'
                    choice = chunk['choices'][0]
                    if choice['delta'].get('content') is not None:
                        assert len(choice['delta']['content']) > 0, 'Expected non-empty content delta!'
                        content.append(choice['delta']['content'])
                        content_parts += 1
                    if choice['delta'].get('reasoning_content') is not None:
                        assert len(choice['delta']['reasoning_content']) > 0, 'Expected non-empty reasoning_content delta!'
                        reasoning_content.append(choice['delta']['reasoning_content'])
                        reasoning_content_parts += 1
                    if choice['delta'].get('finish_reason') is not None:
                        finish_reason = choice['delta']['finish_reason']
                    for tc in choice['delta'].get('tool_calls', []):
                        if 'function' not in tc:
                            raise ValueError(f"Expected function type, got {tc.get('type')}")
                        if tc['index'] >= len(tool_calls):
                            assert 'id' in tc
                            assert tc.get('type') == 'function'
                            assert 'function' in tc and 'name' in tc['function'] and len(tc['function']['name']) > 0, \
                                f"Expected function call with name, got {tc.get('function')}"
                            tool_calls.append(dict(
                                id="",
                                type="function",
                                function=dict(
                                    name="",
                                    arguments="",
                                )
                            ))
                        tool_call = tool_calls[tc['index']]
                        if tc.get('id') is not None:
                            tool_call['id'] = tc['id']
                        fct = tc['function']
                        assert 'id' not in fct, f"Function call should not have id: {fct}"
                        if fct.get('name') is not None:
                            tool_call['function']['name'] = tool_call['function'].get('name', '') + fct['name']
                        if fct.get('arguments') is not None:
                            tool_call['function']['arguments'] += fct['arguments']
                            arguments_parts += 1
                        tool_call_parts += 1
                else:
                    # When `include_usage` is True (the default), we expect the last chunk of the stream
                    # immediately preceding the `data: [DONE]` message to contain a `choices` field with an empty array
                    # and a `usage` field containing the usage statistics (n.b., llama-server also returns `timings` in
                    # the last chunk).
                    assert 'usage' in chunk, f"Expected usage in chunk: {chunk}"
                    assert 'timings' in chunk, f"Expected timings in chunk: {chunk}"
            print(f'Streamed response had {content_parts} content parts, {reasoning_content_parts} reasoning_content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
            result = dict(
                choices=[
                    dict(
                        index=0,
                        finish_reason=finish_reason,
                        message=dict(
                            role='assistant',
                            content=''.join(content) if content else None,
                            reasoning_content=''.join(reasoning_content) if reasoning_content else None,
                            tool_calls=tool_calls if tool_calls else None,
                        ),
                    )
                ],
            )
            print("Final response from server", json.dumps(result, indent=2))
            return result
        else:
            response = self.make_request(method, path, data, headers, timeout=timeout)
            assert response.status_code == 200, f"Server returned error: {response.status_code}"
            return response.body
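    # Example (a sketch): the same call works for streaming and non-streaming
    # requests, always returning a final, non-streamed response shape:
    #
    #   body = server.make_any_request("POST", "/v1/chat/completions", data={
    #       "messages": [{"role": "user", "content": "Hello"}],
    #       "stream": True,
    #   })
    #   print(body["choices"][0]["message"]["content"])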


server_instances: Set[ServerProcess] = set()


class ServerPreset:
    @staticmethod
    def load_all() -> None:
        """ Load all server presets to ensure model files are cached. """
        servers: List[ServerProcess] = [
            method()
            for name, method in ServerPreset.__dict__.items()
            if callable(method) and name != "load_all"
        ]
        for server in servers:
            server.offline = False
            server.start()
            server.stop()

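    # load_all() is meant to run once up front (e.g. from a session-scoped test
    # fixture) so every model lands in the local cache; the presets below then
    # set offline = True and rely on the cached files.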
    @staticmethod
    def tinyllama2() -> ServerProcess:
        server = ServerProcess()
        server.offline = True  # will be downloaded by load_all()
        server.model_hf_repo = "ggml-org/test-model-stories260K"
        server.model_hf_file = None
        server.model_alias = "tinyllama-2"
        server.n_ctx = 512
        server.n_batch = 32
        server.n_slots = 2
        server.n_predict = 64
        server.seed = 42
        return server

    @staticmethod
    def bert_bge_small() -> ServerProcess:
        server = ServerProcess()
        server.offline = True  # will be downloaded by load_all()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
        server.model_alias = "bert-bge-small"
        server.n_ctx = 512
        server.n_batch = 128
        server.n_ubatch = 128
        server.n_slots = 2
        server.seed = 42
        server.server_embeddings = True
        return server

    @staticmethod
    def bert_bge_small_with_fa() -> ServerProcess:
        server = ServerProcess()
        server.offline = True  # will be downloaded by load_all()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
        server.model_alias = "bert-bge-small"
        server.n_ctx = 1024
        server.n_batch = 300
        server.n_ubatch = 300
        server.n_slots = 2
        server.fa = "on"
        server.seed = 42
        server.server_embeddings = True
        return server

    @staticmethod
    def tinyllama_infill() -> ServerProcess:
        server = ServerProcess()
        server.offline = True  # will be downloaded by load_all()
        server.model_hf_repo = "ggml-org/test-model-stories260K-infill"
        server.model_hf_file = None
        server.model_alias = "tinyllama-infill"
        server.n_ctx = 2048
        server.n_batch = 1024
        server.n_slots = 1
        server.n_predict = 64
        server.temperature = 0.0
        server.seed = 42
        return server

    @staticmethod
    def stories15m_moe() -> ServerProcess:
        server = ServerProcess()
        server.offline = True  # will be downloaded by load_all()
        server.model_hf_repo = "ggml-org/stories15M_MOE"
        server.model_hf_file = "stories15M_MOE-F16.gguf"
        server.model_alias = "stories15m-moe"
        server.n_ctx = 2048
        server.n_batch = 1024
        server.n_slots = 1
        server.n_predict = 64
        server.temperature = 0.0
        server.seed = 42
        return server

    @staticmethod
    def jina_reranker_tiny() -> ServerProcess:
        server = ServerProcess()
        server.offline = True  # will be downloaded by load_all()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "jina-reranker-v1-tiny-en/ggml-model-f16.gguf"
        server.model_alias = "jina-reranker"
        server.n_ctx = 512
        server.n_batch = 512
        server.n_slots = 1
        server.seed = 42
        server.server_reranking = True
        return server

    @staticmethod
    def tinygemma3() -> ServerProcess:
        server = ServerProcess()
        server.offline = True  # will be downloaded by load_all()
        # mmproj is already provided by HF registry API
        server.model_hf_file = None
        server.model_hf_repo = "ggml-org/tinygemma3-GGUF:Q8_0"
        server.model_alias = "tinygemma3"
        server.n_ctx = 1024
        server.n_batch = 32
        server.n_slots = 2
        server.n_predict = 4
        server.seed = 42
        return server

    @staticmethod
    def router() -> ServerProcess:
        server = ServerProcess()
        server.offline = True  # will be downloaded by load_all()
        # router server has no models
        server.model_file = None
        server.model_alias = None
        server.model_hf_repo = None
        server.model_hf_file = None
        server.n_ctx = 1024
        server.n_batch = 16
        server.n_slots = 1
        server.n_predict = 16
        server.seed = 42
        return server
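    # Presets are plain ServerProcess objects, so tests can override any field
    # before start(), e.g. (a sketch):
    #
    #   server = ServerPreset.tinyllama2()
    #   server.n_ctx = 256
    #   server.api_key = "sk-test"
    #   server.start()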


def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]:
    """
    Run multiple functions in parallel and return results in the same order as the calls,
    similar to Promise.all in JS. Unlike Promise.all, exceptions are not propagated:
    they are printed, and the corresponding result entry is left as None.

    Example usage:

    results = parallel_function_calls([
        (func1, (arg1, arg2)),
        (func2, (arg3, arg4)),
    ])
    """
    results = [None] * len(function_list)
    exceptions = []

    def worker(index, func, args):
        try:
            result = func(*args)
            results[index] = result
        except Exception as e:
            exceptions.append((index, str(e)))

    with ThreadPoolExecutor() as executor:
        futures = []
        for i, (func, args) in enumerate(function_list):
            future = executor.submit(worker, i, func, args)
            futures.append(future)

        # wait for all futures to complete
        for future in as_completed(futures):
            pass

    # report any exceptions raised inside the workers
    if exceptions:
        print("Exceptions occurred:")
        for index, error in exceptions:
            print(f"Function at index {index}: {error}")

    return results
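# Example (a sketch): issue several requests against one server concurrently,
# e.g. to exercise multiple slots:
#
#   results = parallel_function_calls([
#       (server.make_request, ("POST", "/completion", {"prompt": "Hi", "n_predict": 4})),
#       (server.make_request, ("POST", "/completion", {"prompt": "Yo", "n_predict": 4})),
#   ])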


def match_regex(regex: str, text: str) -> bool:
    return (
        re.compile(
            regex, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL
        ).search(text)
        is not None
    )
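# Example (a sketch): the combined flags make matching case-insensitive and let
# '.' cross newlines, so e.g.:
#
#   match_regex(r"once.*end", "Once upon a time...\nTHE END")  # -> True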


def download_file(url: str, output_file_path: str | None = None) -> str:
    """
    Download a file from a URL to a local path. If the file already exists, it will not be downloaded again.

    output_file_path is the local path to save the downloaded file. If not provided, the file will be saved as ./tmp/<file name>.

    Returns the local path of the downloaded file.
    """
    file_name = url.split('/').pop()
    output_file = f'./tmp/{file_name}' if output_file_path is None else output_file_path
    if not os.path.exists(output_file):
        print(f"Downloading {url} to {output_file}")
        wget.download(url, out=output_file)
        print(f"Done downloading to {output_file}")
    else:
        print(f"File already exists at {output_file}")
    return output_file
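# Example (a sketch; assumes the ./tmp directory already exists, since
# wget.download does not create parent directories):
#
#   path = download_file("https://example.com/model.gguf")
#   # -> "./tmp/model.gguf"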


def is_slow_test_allowed():
    return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON"
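# Example (a sketch): tests that download large models can be gated like
#
#   @pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test")
#   def test_large_model():
#       ...
#
# and enabled with SLOW_TESTS=1 (or SLOW_TESTS=ON) in the environment.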