diff options
Diffstat (limited to 'llama.cpp/tools/server/tests/unit/test_vision_api.py')
| -rw-r--r-- | llama.cpp/tools/server/tests/unit/test_vision_api.py | 160 |
1 file changed, 160 insertions, 0 deletions
import base64

import pytest
import requests
from utils import *

server: ServerProcess

# CIFAR-10-style reference images hosted alongside the tinygemma3 test model.
IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"

JSON_MULTIMODAL_KEY = "multimodal_data"
JSON_PROMPT_STRING_KEY = "prompt_string"


def _download_b64(url: str) -> str:
    """Download *url* and return its payload base64-encoded (no data: prefix)."""
    response = requests.get(url)
    response.raise_for_status()  # Raise an exception for bad status codes
    return base64.b64encode(response.content).decode("utf-8")


def get_img_url(id: str) -> str:
    """Resolve a symbolic image id used by the parametrized tests.

    Supported ids:
      * ``IMG_URL_0`` / ``IMG_URL_1``               -> plain HTTP URL
      * ``IMG_BASE64_URI_0`` / ``IMG_BASE64_URI_1`` -> ``data:image/png;base64,...`` URI
      * ``IMG_BASE64_0`` / ``IMG_BASE64_1``         -> raw base64 payload
    Any other value is returned unchanged, which lets tests feed malformed
    input (e.g. "malformed", non-image URLs) straight through to the server.
    """
    if id == "IMG_URL_0":
        return IMG_URL_0
    if id == "IMG_URL_1":
        return IMG_URL_1
    if id == "IMG_BASE64_URI_0":
        return "data:image/png;base64," + _download_b64(IMG_URL_0)
    if id == "IMG_BASE64_0":
        return _download_b64(IMG_URL_0)
    if id == "IMG_BASE64_URI_1":
        return "data:image/png;base64," + _download_b64(IMG_URL_1)
    if id == "IMG_BASE64_1":
        return _download_b64(IMG_URL_1)
    return id


@pytest.fixture(autouse=True)
def create_server():
    """Give every test a fresh (not yet started) tinygemma3 server process."""
    global server
    server = ServerPreset.tinygemma3()


def test_models_supports_multimodal_capability():
    """/models must advertise both completion and multimodal capabilities."""
    global server
    server.start()
    res = server.make_request("GET", "/models", data={})
    assert res.status_code == 200
    model_info = res.body["models"][0]
    print(model_info)
    assert "completion" in model_info["capabilities"]
    assert "multimodal" in model_info["capabilities"]


def test_v1_models_supports_multimodal_capability():
    """Same capability check via the OpenAI-compatible /v1/models endpoint."""
    global server
    server.start()
    res = server.make_request("GET", "/v1/models", data={})
    assert res.status_code == 200
    model_info = res.body["models"][0]
    print(model_info)
    assert "completion" in model_info["capabilities"]
    assert "multimodal" in model_info["capabilities"]


@pytest.mark.parametrize(
    "prompt, image_url, success, re_content",
    [
        # test model is trained on CIFAR-10, but it's quite dumb due to small size
        ("What is this:\n", "IMG_URL_0", True, "(cat)+"),
        ("What is this:\n", "IMG_BASE64_URI_0", True, "(cat)+"),
        ("What is this:\n", "IMG_URL_1", True, "(frog)+"),
        ("Test test\n", "IMG_URL_1", True, "(frog)+"), # test invalidate cache
        ("What is this:\n", "malformed", False, None),
        ("What is this:\n", "https://google.com/404", False, None), # non-existent image
        ("What is this:\n", "https://ggml.ai", False, None), # non-image data
        # TODO @ngxson : test with multiple images, no images and with audio
    ]
)
def test_vision_chat_completion(prompt, image_url, success, re_content):
    """OpenAI-style chat completion with an image_url content part."""
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        # greedy sampling so the tiny model's output is deterministic
        "temperature": 0.0,
        "top_k": 1,
        "messages": [
            {"role": "user", "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {
                    "url": get_img_url(image_url),
                }},
            ]},
        ],
    })
    if success:
        assert res.status_code == 200
        choice = res.body["choices"][0]
        assert "assistant" == choice["message"]["role"]
        assert match_regex(re_content, choice["message"]["content"])
    else:
        # malformed / non-image inputs must be rejected, not silently accepted
        assert res.status_code != 200


@pytest.mark.parametrize(
    "prompt, image_data, success, re_content",
    [
        # test model is trained on CIFAR-10, but it's quite dumb due to small size
        ("What is this: <__media__>\n", "IMG_BASE64_0", True, "(cat)+"),
        ("What is this: <__media__>\n", "IMG_BASE64_1", True, "(frog)+"),
        ("What is this: <__media__>\n", "malformed", False, None), # non-image data
        ("What is this:\n", "", False, None), # empty string
    ]
)
def test_vision_completion(prompt, image_data, success, re_content):
    """Raw /completions endpoint with base64 image data in the prompt object."""
    global server
    server.start()
    res = server.make_request("POST", "/completions", data={
        # greedy sampling so the tiny model's output is deterministic
        "temperature": 0.0,
        "top_k": 1,
        "prompt": {
            JSON_PROMPT_STRING_KEY: prompt,
            JSON_MULTIMODAL_KEY: [ get_img_url(image_data) ],
        },
    })
    if success:
        assert res.status_code == 200
        content = res.body["content"]
        assert match_regex(re_content, content)
    else:
        assert res.status_code != 200


@pytest.mark.parametrize(
    "prompt, image_data, success",
    [
        # test model is trained on CIFAR-10, but it's quite dumb due to small size
        ("What is this: <__media__>\n", "IMG_BASE64_0", True),
        ("What is this: <__media__>\n", "IMG_BASE64_1", True),
        ("What is this: <__media__>\n", "malformed", False), # non-image data
        ("What is this:\n", "base64", False), # non-image data
    ]
)
def test_vision_embeddings(prompt, image_data, success):
    """Multimodal /embeddings: identical inputs must embed identically."""
    global server
    server.server_embeddings = True
    server.n_batch = 512
    server.start()
    image_data = get_img_url(image_data)
    res = server.make_request("POST", "/embeddings", data={
        "content": [
            { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
            { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
            { JSON_PROMPT_STRING_KEY: prompt, },
        ],
    })
    if success:
        assert res.status_code == 200
        content = res.body
        # Ensure embeddings are stable when multimodal.
        assert content[0]['embedding'] == content[1]['embedding']
        # Ensure embeddings without multimodal but same prompt do not match multimodal embeddings.
        assert content[0]['embedding'] != content[2]['embedding']
    else:
        assert res.status_code != 200
