1import pytest
2from utils import *
3import base64
4import requests
5
# Server under test. The autouse `create_server` fixture rebinds this to a
# fresh `ServerPreset.tinygemma3()` before every test; each test applies its
# own settings and then calls `server.start()` itself.
server: ServerProcess
7
# Test images published next to the tinygemma3 test model.
_IMG_URLS = {
    "IMG_URL_0": "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png",
    "IMG_URL_1": "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png",
}

# Base64 placeholder id -> (id of the source image URL, whether to wrap the
# encoded payload in a "data:image/png;base64," URI).
_IMG_BASE64_IDS = {
    "IMG_BASE64_URI_0": ("IMG_URL_0", True),
    "IMG_BASE64_0": ("IMG_URL_0", False),
    "IMG_BASE64_URI_1": ("IMG_URL_1", True),
    "IMG_BASE64_1": ("IMG_URL_1", False),
}

# Seconds to wait on the image host before failing, instead of hanging the run.
_IMG_DOWNLOAD_TIMEOUT_S = 30

# Raw image bytes keyed by URL, so each image is downloaded at most once per
# test process rather than once per parametrized case.
_img_bytes_cache: "dict[str, bytes]" = {}


def _fetch_image_bytes(url: str) -> bytes:
    """Return the body of ``url``, downloading it only on first use.

    Raises:
        requests.HTTPError: if the host answers with an error status.
        requests.Timeout: if the host does not respond within the timeout.
    """
    if url not in _img_bytes_cache:
        response = requests.get(url, timeout=_IMG_DOWNLOAD_TIMEOUT_S)
        # Fail loudly on 4xx/5xx rather than base64-encoding an error page.
        response.raise_for_status()
        _img_bytes_cache[url] = response.content
    return _img_bytes_cache[url]


def get_img_url(id: str) -> str:
    """Resolve a test-image placeholder id to the value sent to the server.

    - ``IMG_URL_0`` / ``IMG_URL_1``: the remote PNG URL itself.
    - ``IMG_BASE64_URI_0`` / ``IMG_BASE64_URI_1``: the downloaded PNG as a
      ``data:image/png;base64,`` URI.
    - ``IMG_BASE64_0`` / ``IMG_BASE64_1``: the downloaded PNG as bare base64.
    - Anything else is returned unchanged, so tests can pass malformed or
      arbitrary values straight through.

    Args:
        id: placeholder id or literal value. (It shadows the builtin ``id``;
            the name is kept so existing keyword callers keep working.)

    Returns:
        The URL, data URI, base64 text, or the input unchanged.

    Raises:
        requests.HTTPError, requests.Timeout: if downloading an image fails.
    """
    if id in _IMG_URLS:
        return _IMG_URLS[id]
    if id in _IMG_BASE64_IDS:
        url_id, as_data_uri = _IMG_BASE64_IDS[id]
        encoded = base64.b64encode(_fetch_image_bytes(_IMG_URLS[url_id])).decode("utf-8")
        return "data:image/png;base64," + encoded if as_data_uri else encoded
    return id
33
# Field names of the structured prompt object sent to /completions and
# /embeddings: the text itself, plus a list of attached media payloads.
JSON_PROMPT_STRING_KEY = "prompt_string"
JSON_MULTIMODAL_KEY = "multimodal_data"
36
@pytest.fixture(autouse=True, scope="function")
def create_server():
    """Rebind the module-level ``server`` to a fresh tinygemma3 preset per test.

    The preset is only constructed here; each test adjusts settings as needed
    and then calls ``server.start()`` itself.
    """
    global server
    server = ServerPreset.tinygemma3()
41
def test_models_supports_multimodal_capability():
    """GET /models must list both the completion and multimodal capabilities."""
    server.start()
    res = server.make_request("GET", "/models", data={})
    assert res.status_code == 200
    first_model = res.body["models"][0]
    print(first_model)
    capabilities = first_model["capabilities"]
    for expected in ("completion", "multimodal"):
        assert expected in capabilities
51
def test_v1_models_supports_multimodal_capability():
    """The ``/v1``-prefixed models route must report the same capabilities."""
    server.start()
    res = server.make_request("GET", "/v1/models", data={})
    assert res.status_code == 200
    entry = res.body["models"][0]
    print(entry)
    caps = entry["capabilities"]
    assert "completion" in caps
    assert "multimodal" in caps
61
@pytest.mark.parametrize(
    "prompt, image_url, success, re_content",
    [
        # test model is trained on CIFAR-10, but it's quite dumb due to small size
        ("What is this:\n", "IMG_URL_0", True, "(cat)+"),
        ("What is this:\n", "IMG_BASE64_URI_0", True, "(cat)+"),
        ("What is this:\n", "IMG_URL_1", True, "(frog)+"),
        ("Test test\n", "IMG_URL_1", True, "(frog)+"), # test invalidate cache
        ("What is this:\n", "malformed", False, None),
        ("What is this:\n", "https://google.com/404", False, None), # non-existent image
        ("What is this:\n", "https://ggml.ai", False, None), # non-image data
        # TODO @ngxson : test with multiple images, no images and with audio
    ]
)
def test_vision_chat_completion(prompt, image_url, success, re_content):
    """Send one text part plus one image part to /chat/completions.

    Usable images must yield an assistant message matching ``re_content``;
    unusable image references must make the request fail.
    """
    server.start()
    image_part = {
        "type": "image_url",
        "image_url": {"url": get_img_url(image_url)},
    }
    payload = {
        "temperature": 0.0,
        "top_k": 1,
        "messages": [
            {"role": "user", "content": [{"type": "text", "text": prompt}, image_part]},
        ],
    }
    res = server.make_request("POST", "/chat/completions", data=payload)
    if not success:
        assert res.status_code != 200
        return
    assert res.status_code == 200
    message = res.body["choices"][0]["message"]
    assert message["role"] == "assistant"
    assert match_regex(re_content, message["content"])
98
99
@pytest.mark.parametrize(
    "prompt, image_data, success, re_content",
    [
        # test model is trained on CIFAR-10, but it's quite dumb due to small size
        ("What is this: <__media__>\n", "IMG_BASE64_0", True, "(cat)+"),
        ("What is this: <__media__>\n", "IMG_BASE64_1", True, "(frog)+"),
        ("What is this: <__media__>\n", "malformed", False, None), # non-image data
        ("What is this:\n", "", False, None), # empty string
    ]
)
def test_vision_completion(prompt, image_data, success, re_content):
    """POST a structured prompt with inline image data to /completions.

    Valid base64 images must produce content matching ``re_content``;
    malformed or empty image data must make the request fail.
    """
    server.start()
    structured_prompt = {
        JSON_PROMPT_STRING_KEY: prompt,
        JSON_MULTIMODAL_KEY: [get_img_url(image_data)],
    }
    res = server.make_request("POST", "/completions", data={
        "temperature": 0.0,
        "top_k": 1,
        "prompt": structured_prompt,
    })
    if not success:
        assert res.status_code != 200
        return
    assert res.status_code == 200
    assert match_regex(re_content, res.body["content"])
127
128
@pytest.mark.parametrize(
    "prompt, image_data, success",
    [
        # test model is trained on CIFAR-10, but it's quite dumb due to small size
        ("What is this: <__media__>\n", "IMG_BASE64_0", True),
        ("What is this: <__media__>\n", "IMG_BASE64_1", True),
        ("What is this: <__media__>\n", "malformed", False), # non-image data
        ("What is this:\n", "base64", False), # non-image data
    ]
)
def test_vision_embeddings(prompt, image_data, success):
    """Embed the same prompt twice with an image and once without it.

    On success, the two image-bearing inputs must embed identically and
    differ from the text-only input; bad image data must fail the request.
    """
    server.server_embeddings = True
    server.n_batch = 512
    server.start()
    image_data = get_img_url(image_data)

    def with_image():
        # A fresh object per call so the two multimodal inputs are independent.
        return {JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [image_data]}

    res = server.make_request("POST", "/embeddings", data={
        "content": [with_image(), with_image(), {JSON_PROMPT_STRING_KEY: prompt}],
    })
    if not success:
        assert res.status_code != 200
        return
    assert res.status_code == 200
    body = res.body
    # Identical multimodal inputs must produce identical embeddings.
    assert body[0]['embedding'] == body[1]['embedding']
    # Dropping the image from the same prompt must change the embedding.
    assert body[0]['embedding'] != body[2]['embedding']