1import pytest
  2from openai import OpenAI
  3from utils import *
  4
  5server = ServerPreset.tinyllama2()
  6
  7TEST_API_KEY = "sk-this-is-the-secret-key"
  8
  9@pytest.fixture(autouse=True)
 10def create_server():
 11    global server
 12    server = ServerPreset.tinyllama2()
 13    server.api_key = TEST_API_KEY
 14
 15
 16@pytest.mark.parametrize("endpoint", ["/health", "/models"])
 17def test_access_public_endpoint(endpoint: str):
 18    global server
 19    server.start()
 20    res = server.make_request("GET", endpoint)
 21    assert res.status_code == 200
 22    assert "error" not in res.body
 23
 24
 25@pytest.mark.parametrize("api_key", [None, "invalid-key"])
 26def test_incorrect_api_key(api_key: str):
 27    global server
 28    server.start()
 29    res = server.make_request("POST", "/completions", data={
 30        "prompt": "I believe the meaning of life is",
 31    }, headers={
 32        "Authorization": f"Bearer {api_key}" if api_key else None,
 33    })
 34    assert res.status_code == 401
 35    assert "error" in res.body
 36    assert res.body["error"]["type"] == "authentication_error"
 37
 38
 39def test_correct_api_key():
 40    global server
 41    server.start()
 42    res = server.make_request("POST", "/completions", data={
 43        "prompt": "I believe the meaning of life is",
 44    }, headers={
 45        "Authorization": f"Bearer {TEST_API_KEY}",
 46    })
 47    assert res.status_code == 200
 48    assert "error" not in res.body
 49    assert "content" in res.body
 50
 51
 52def test_correct_api_key_anthropic_header():
 53    global server
 54    server.start()
 55    res = server.make_request("POST", "/completions", data={
 56        "prompt": "I believe the meaning of life is",
 57    }, headers={
 58        "X-Api-Key": TEST_API_KEY,
 59    })
 60    assert res.status_code == 200
 61    assert "error" not in res.body
 62    assert "content" in res.body
 63
 64
 65def test_openai_library_correct_api_key():
 66    global server
 67    server.start()
 68    client = OpenAI(api_key=TEST_API_KEY, base_url=f"http://{server.server_host}:{server.server_port}")
 69    res = client.chat.completions.create(
 70        model="gpt-3.5-turbo",
 71        messages=[
 72            {"role": "system", "content": "You are a chatbot."},
 73            {"role": "user", "content": "What is the meaning of life?"},
 74        ],
 75    )
 76    assert len(res.choices) == 1
 77
 78
 79@pytest.mark.parametrize("origin,cors_header,cors_header_value", [
 80    ("localhost", "Access-Control-Allow-Origin", "localhost"),
 81    ("web.mydomain.fr", "Access-Control-Allow-Origin", "web.mydomain.fr"),
 82    ("origin", "Access-Control-Allow-Credentials", "true"),
 83    ("web.mydomain.fr", "Access-Control-Allow-Methods", "GET, POST"),
 84    ("web.mydomain.fr", "Access-Control-Allow-Headers", "*"),
 85])
 86def test_cors_options(origin: str, cors_header: str, cors_header_value: str):
 87    global server
 88    server.start()
 89    res = server.make_request("OPTIONS", "/completions", headers={
 90        "Origin": origin,
 91        "Access-Control-Request-Method": "POST",
 92        "Access-Control-Request-Headers": "Authorization",
 93    })
 94    assert res.status_code == 200
 95    assert cors_header in res.headers
 96    assert res.headers[cors_header] == cors_header_value
 97
 98
 99@pytest.mark.parametrize(
100    "media_path, image_url, success",
101    [
102        (None,             "file://mtmd/test-1.jpeg",    False), # disabled media path, should fail
103        ("../../../tools", "file://mtmd/test-1.jpeg",    True),
104        ("../../../tools", "file:////mtmd//test-1.jpeg", True),  # should be the same file as above
105        ("../../../tools", "file://mtmd/notfound.jpeg",  False), # non-existent file
106        ("../../../tools", "file://../mtmd/test-1.jpeg", False), # no directory traversal
107    ]
108)
109def test_local_media_file(media_path, image_url, success,):
110    server = ServerPreset.tinygemma3()
111    server.media_path = media_path
112    server.start()
113    res = server.make_request("POST", "/chat/completions", data={
114        "max_tokens": 1,
115        "messages": [
116            {"role": "user", "content": [
117                {"type": "text", "text": "test"},
118                {"type": "image_url", "image_url": {
119                    "url": image_url,
120                }},
121            ]},
122        ],
123    })
124    if success:
125        assert res.status_code == 200
126    else:
127        assert res.status_code == 400