summary | refs | log | tree | commit | diff
path: root/llama.cpp/tools/server/tests/unit/test_vision_api.py
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/tools/server/tests/unit/test_vision_api.py')
-rw-r--r-- llama.cpp/tools/server/tests/unit/test_vision_api.py 160
1 files changed, 160 insertions, 0 deletions
diff --git a/llama.cpp/tools/server/tests/unit/test_vision_api.py b/llama.cpp/tools/server/tests/unit/test_vision_api.py
new file mode 100644
index 0000000..9408116
--- /dev/null
+++ b/llama.cpp/tools/server/tests/unit/test_vision_api.py
@@ -0,0 +1,160 @@
+import pytest
+from utils import *
+import base64
+import requests
+
+server: ServerProcess
+
def _download_base64(url: str, timeout: float = 60.0) -> str:
    """Download ``url`` and return the response body base64-encoded as text.

    Raises whatever ``requests`` raises on connection problems, on a non-2xx
    status (via ``raise_for_status``), or when ``timeout`` seconds elapse.
    """
    # Without a timeout a stalled connection would hang the test run forever.
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()  # Raise an exception for bad status codes
    return base64.b64encode(response.content).decode("utf-8")


def get_img_url(id: str) -> str:
    """Resolve a symbolic test-image identifier to the value sent to the server.

    Recognised identifiers, where ``N`` is ``0`` or ``1``:

    * ``IMG_URL_N``        -> the remote URL of test image N (no download)
    * ``IMG_BASE64_URI_N`` -> image N downloaded, as a ``data:image/png;base64,`` URI
    * ``IMG_BASE64_N``     -> image N downloaded, as a bare base64 string

    Any other string is returned unchanged, which lets the test tables pass
    literal (e.g. deliberately invalid) values straight through.

    NOTE: the parameter name shadows the builtin ``id``; it is kept as-is
    because it is part of the function's signature.
    """
    IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
    IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"

    for suffix, url in (("0", IMG_URL_0), ("1", IMG_URL_1)):
        if id == "IMG_URL_" + suffix:
            return url
        if id == "IMG_BASE64_URI_" + suffix:
            return "data:image/png;base64," + _download_base64(url)
        if id == "IMG_BASE64_" + suffix:
            return _download_base64(url)

    # Not a known identifier: hand the caller's string back untouched.
    return id
+
# Keys of the JSON-object prompt form used by the /completions and
# /embeddings requests in this module.
JSON_MULTIMODAL_KEY = "multimodal_data"
JSON_PROMPT_STRING_KEY = "prompt_string"


@pytest.fixture(autouse=True)
def create_server():
    """Rebind the module-level ``server`` to the tinygemma3 preset.

    The fixture is ``autouse``, so pytest runs it before every test in this
    module; each test is then responsible for calling ``server.start()``.
    """
    global server
    server = ServerPreset.tinygemma3()
+
def test_models_supports_multimodal_capability():
    """GET /models must list both the completion and multimodal capabilities."""
    server.start()
    res = server.make_request("GET", "/models", data={})
    assert res.status_code == 200

    # Only the first listed model is inspected.
    model_info = res.body["models"][0]
    print(model_info)  # shown by pytest when the test fails

    capabilities = model_info["capabilities"]
    assert "completion" in capabilities
    assert "multimodal" in capabilities
+
def test_v1_models_supports_multimodal_capability():
    """GET /v1/models must list both the completion and multimodal capabilities."""
    server.start()
    res = server.make_request("GET", "/v1/models", data={})
    assert res.status_code == 200

    # Only the first listed model is inspected.
    model_info = res.body["models"][0]
    print(model_info)  # shown by pytest when the test fails

    capabilities = model_info["capabilities"]
    assert "completion" in capabilities
    assert "multimodal" in capabilities
+
@pytest.mark.parametrize(
    "prompt, image_url, success, re_content",
    [
        # test model is trained on CIFAR-10, but it's quite dumb due to small size
        ("What is this:\n", "IMG_URL_0", True, "(cat)+"),
        ("What is this:\n", "IMG_BASE64_URI_0", True, "(cat)+"),
        ("What is this:\n", "IMG_URL_1", True, "(frog)+"),
        ("Test test\n", "IMG_URL_1", True, "(frog)+"),  # test invalidate cache
        ("What is this:\n", "malformed", False, None),
        ("What is this:\n", "https://google.com/404", False, None),  # non-existent image
        ("What is this:\n", "https://ggml.ai", False, None),  # non-image data
        # TODO @ngxson : test with multiple images, no images and with audio
    ]
)
def test_vision_chat_completion(prompt, image_url, success, re_content):
    """POST a text + image_url user message to /chat/completions.

    When ``success`` is true the reply must be a 200 whose assistant message
    matches ``re_content``; otherwise any non-200 status is accepted.
    """
    server.start()

    # Resolved after the server is started, same as the request is built.
    user_parts = [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": {"url": get_img_url(image_url)}},
    ]
    payload = {
        "temperature": 0.0,
        "top_k": 1,
        "messages": [{"role": "user", "content": user_parts}],
    }
    res = server.make_request("POST", "/chat/completions", data=payload)

    if not success:
        assert res.status_code != 200
        return

    assert res.status_code == 200
    message = res.body["choices"][0]["message"]
    assert message["role"] == "assistant"
    assert match_regex(re_content, message["content"])
+
+
@pytest.mark.parametrize(
    "prompt, image_data, success, re_content",
    [
        # test model is trained on CIFAR-10, but it's quite dumb due to small size
        ("What is this: <__media__>\n", "IMG_BASE64_0", True, "(cat)+"),
        ("What is this: <__media__>\n", "IMG_BASE64_1", True, "(frog)+"),
        ("What is this: <__media__>\n", "malformed", False, None),  # non-image data
        ("What is this:\n", "", False, None),  # empty string
    ]
)
def test_vision_completion(prompt, image_data, success, re_content):
    """POST a JSON-object prompt carrying one media entry to /completions.

    When ``success`` is true the reply must be a 200 whose ``content``
    matches ``re_content``; otherwise any non-200 status is accepted.
    """
    server.start()

    multimodal_prompt = {
        JSON_PROMPT_STRING_KEY: prompt,
        JSON_MULTIMODAL_KEY: [get_img_url(image_data)],
    }
    payload = {
        "temperature": 0.0,
        "top_k": 1,
        "prompt": multimodal_prompt,
    }
    res = server.make_request("POST", "/completions", data=payload)

    if not success:
        assert res.status_code != 200
        return

    assert res.status_code == 200
    assert match_regex(re_content, res.body["content"])
+
+
@pytest.mark.parametrize(
    "prompt, image_data, success",
    [
        # test model is trained on CIFAR-10, but it's quite dumb due to small size
        ("What is this: <__media__>\n", "IMG_BASE64_0", True),
        ("What is this: <__media__>\n", "IMG_BASE64_1", True),
        ("What is this: <__media__>\n", "malformed", False),  # non-image data
        ("What is this:\n", "base64", False),  # non-image data
    ]
)
def test_vision_embeddings(prompt, image_data, success):
    """POST three prompts to /embeddings: two identical multimodal, one text-only.

    When ``success`` is true the two multimodal embeddings must be equal and
    must differ from the text-only one; otherwise any non-200 status is
    accepted.
    """
    # Settings are applied before the server is started.
    server.server_embeddings = True
    server.n_batch = 512
    server.start()

    media = get_img_url(image_data)
    with_media_a = {JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [media]}
    with_media_b = {JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [media]}
    text_only = {JSON_PROMPT_STRING_KEY: prompt}
    res = server.make_request("POST", "/embeddings", data={
        "content": [with_media_a, with_media_b, text_only],
    })

    if not success:
        assert res.status_code != 200
        return

    assert res.status_code == 200
    results = res.body
    multimodal_embedding = results[0]['embedding']
    # Ensure embeddings are stable when multimodal.
    assert multimodal_embedding == results[1]['embedding']
    # Ensure embeddings without multimodal but same prompt do not match multimodal embeddings.
    assert multimodal_embedding != results[2]['embedding']