1import pytest
  2from utils import *
  3import base64
  4import requests
  5
# Module-level handle to the server under test. It is (re)assigned by the
# autouse `create_server` fixture for each test; tests then tweak its settings
# and call `server.start()` themselves.
server: ServerProcess
  7
  8def get_img_url(id: str) -> str:
  9    IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
 10    IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
 11    if id == "IMG_URL_0":
 12        return IMG_URL_0
 13    elif id == "IMG_URL_1":
 14        return IMG_URL_1
 15    elif id == "IMG_BASE64_URI_0":
 16        response = requests.get(IMG_URL_0)
 17        response.raise_for_status() # Raise an exception for bad status codes
 18        return "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
 19    elif id == "IMG_BASE64_0":
 20        response = requests.get(IMG_URL_0)
 21        response.raise_for_status() # Raise an exception for bad status codes
 22        return base64.b64encode(response.content).decode("utf-8")
 23    elif id == "IMG_BASE64_URI_1":
 24        response = requests.get(IMG_URL_1)
 25        response.raise_for_status() # Raise an exception for bad status codes
 26        return "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
 27    elif id == "IMG_BASE64_1":
 28        response = requests.get(IMG_URL_1)
 29        response.raise_for_status() # Raise an exception for bad status codes
 30        return base64.b64encode(response.content).decode("utf-8")
 31    else:
 32        return id
 33
 34JSON_MULTIMODAL_KEY = "multimodal_data"
 35JSON_PROMPT_STRING_KEY = "prompt_string"
 36
 37@pytest.fixture(autouse=True)
 38def create_server():
 39    global server
 40    server = ServerPreset.tinygemma3()
 41
 42def test_models_supports_multimodal_capability():
 43    global server
 44    server.start()
 45    res = server.make_request("GET", "/models", data={})
 46    assert res.status_code == 200
 47    model_info = res.body["models"][0]
 48    print(model_info)
 49    assert "completion" in model_info["capabilities"]
 50    assert "multimodal" in model_info["capabilities"]
 51
 52def test_v1_models_supports_multimodal_capability():
 53    global server
 54    server.start()
 55    res = server.make_request("GET", "/v1/models", data={})
 56    assert res.status_code == 200
 57    model_info = res.body["models"][0]
 58    print(model_info)
 59    assert "completion" in model_info["capabilities"]
 60    assert "multimodal" in model_info["capabilities"]
 61
 62@pytest.mark.parametrize(
 63    "prompt, image_url, success, re_content",
 64    [
 65        # test model is trained on CIFAR-10, but it's quite dumb due to small size
 66        ("What is this:\n", "IMG_URL_0",              True, "(cat)+"),
 67        ("What is this:\n", "IMG_BASE64_URI_0",       True, "(cat)+"),
 68        ("What is this:\n", "IMG_URL_1",              True, "(frog)+"),
 69        ("Test test\n",     "IMG_URL_1",              True, "(frog)+"), # test invalidate cache
 70        ("What is this:\n", "malformed",              False, None),
 71        ("What is this:\n", "https://google.com/404", False, None), # non-existent image
 72        ("What is this:\n", "https://ggml.ai",        False, None), # non-image data
 73        # TODO @ngxson : test with multiple images, no images and with audio
 74    ]
 75)
 76def test_vision_chat_completion(prompt, image_url, success, re_content):
 77    global server
 78    server.start()
 79    res = server.make_request("POST", "/chat/completions", data={
 80        "temperature": 0.0,
 81        "top_k": 1,
 82        "messages": [
 83            {"role": "user", "content": [
 84                {"type": "text", "text": prompt},
 85                {"type": "image_url", "image_url": {
 86                    "url": get_img_url(image_url),
 87                }},
 88            ]},
 89        ],
 90    })
 91    if success:
 92        assert res.status_code == 200
 93        choice = res.body["choices"][0]
 94        assert "assistant" == choice["message"]["role"]
 95        assert match_regex(re_content, choice["message"]["content"])
 96    else:
 97        assert res.status_code != 200
 98
 99
@pytest.mark.parametrize(
    "prompt, image_data, success, re_content",
    [
        # test model is trained on CIFAR-10, but it's quite dumb due to small size
        ("What is this: <__media__>\n", "IMG_BASE64_0",         True, "(cat)+"),
        ("What is this: <__media__>\n", "IMG_BASE64_1",         True, "(frog)+"),
        ("What is this: <__media__>\n", "malformed",            False, None), # not valid image data
        # NOTE(review): this prompt has no <__media__> marker, so the rejection may come
        # from the marker/media mismatch rather than the empty payload itself — confirm.
        ("What is this:\n",             "",                     False, None), # empty string
    ]
)
def test_vision_completion(prompt, image_data, success, re_content):
    """POST a raw prompt with attached media to ``/completions``.

    ``<__media__>`` in the prompt is presumably the placeholder the server
    substitutes with the attached media (assumption — verify against the server).
    """
    server.start()
    multimodal_prompt = {
        JSON_PROMPT_STRING_KEY: prompt,
        JSON_MULTIMODAL_KEY: [get_img_url(image_data)],
    }
    res = server.make_request("POST", "/completions", data={
        "temperature": 0.0,
        "top_k": 1,
        "prompt": multimodal_prompt,
    })
    if not success:
        assert res.status_code != 200
        return
    assert res.status_code == 200
    assert match_regex(re_content, res.body["content"])
127
128
@pytest.mark.parametrize(
    "prompt, image_data, success",
    [
        # test model is trained on CIFAR-10, but it's quite dumb due to small size
        ("What is this: <__media__>\n", "IMG_BASE64_0",         True),
        ("What is this: <__media__>\n", "IMG_BASE64_1",         True),
        ("What is this: <__media__>\n", "malformed",            False), # not valid image data
        # NOTE(review): no <__media__> marker here either, so the failure may be caused
        # by the marker/media mismatch rather than the bogus payload — confirm.
        ("What is this:\n",             "base64",               False), # not valid image data
    ]
)
def test_vision_embeddings(prompt, image_data, success):
    """Multimodal embeddings must be stable and differ from the text-only embedding.

    The request embeds the same multimodal prompt twice, followed by the bare
    text prompt without any media.
    """
    server.server_embeddings = True
    server.n_batch = 512
    server.start()
    media = get_img_url(image_data)

    def with_media():
        # Fresh object per call, so the two multimodal items are independent.
        return {JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [media]}

    res = server.make_request("POST", "/embeddings", data={
        "content": [with_media(), with_media(), {JSON_PROMPT_STRING_KEY: prompt, }],
    })
    if not success:
        assert res.status_code != 200
        return
    assert res.status_code == 200
    items = res.body
    # Identical multimodal inputs must produce identical embeddings ...
    assert items[0]['embedding'] == items[1]['embedding']
    # ... while dropping the media must change the embedding.
    assert items[0]['embedding'] != items[2]['embedding']