1import base64
  2import struct
  3import pytest
  4from openai import OpenAI
  5from utils import *
  6
  7server = ServerPreset.bert_bge_small()
  8
  9EPSILON = 1e-3
 10
 11@pytest.fixture(autouse=True)
 12def create_server():
 13    global server
 14    server = ServerPreset.bert_bge_small()
 15
 16
 17def test_embedding_single():
 18    global server
 19    server.pooling = 'last'
 20    server.start()
 21    res = server.make_request("POST", "/v1/embeddings", data={
 22        "input": "I believe the meaning of life is",
 23    })
 24    assert res.status_code == 200
 25    assert len(res.body['data']) == 1
 26    assert 'embedding' in res.body['data'][0]
 27    assert len(res.body['data'][0]['embedding']) > 1
 28
 29    # make sure embedding vector is normalized
 30    assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < EPSILON
 31
 32
 33def test_embedding_multiple():
 34    global server
 35    server.pooling = 'last'
 36    server.start()
 37    res = server.make_request("POST", "/v1/embeddings", data={
 38        "input": [
 39            "I believe the meaning of life is",
 40            "Write a joke about AI from a very long prompt which will not be truncated",
 41            "This is a test",
 42            "This is another test",
 43        ],
 44    })
 45    assert res.status_code == 200
 46    assert len(res.body['data']) == 4
 47    for d in res.body['data']:
 48        assert 'embedding' in d
 49        assert len(d['embedding']) > 1
 50
 51
 52def test_embedding_multiple_with_fa():
 53    server = ServerPreset.bert_bge_small_with_fa()
 54    server.pooling = 'last'
 55    server.start()
 56    # one of these should trigger the FA branch (i.e. context size % 256 == 0)
 57    res = server.make_request("POST", "/v1/embeddings", data={
 58        "input": [
 59            "a "*253,
 60            "b "*254,
 61            "c "*255,
 62            "d "*256,
 63        ],
 64    })
 65    assert res.status_code == 200
 66    assert len(res.body['data']) == 4
 67    for d in res.body['data']:
 68        assert 'embedding' in d
 69        assert len(d['embedding']) > 1
 70
 71
 72@pytest.mark.parametrize(
 73    "input,is_multi_prompt",
 74    [
 75        # do not crash on empty input
 76        ("", False),
 77        # single prompt
 78        ("string", False),
 79        ([12, 34, 56], False),
 80        ([12, 34, "string", 56, 78], False),
 81        # multiple prompts
 82        (["string1", "string2"], True),
 83        (["string1", [12, 34, 56]], True),
 84        ([[12, 34, 56], [12, 34, 56]], True),
 85        ([[12, 34, 56], [12, "string", 34, 56]], True),
 86    ]
 87)
 88def test_embedding_mixed_input(input, is_multi_prompt: bool):
 89    global server
 90    server.start()
 91    res = server.make_request("POST", "/v1/embeddings", data={"input": input})
 92    assert res.status_code == 200
 93    data = res.body['data']
 94    if is_multi_prompt:
 95        assert len(data) == len(input)
 96        for d in data:
 97            assert 'embedding' in d
 98            assert len(d['embedding']) > 1
 99    else:
100        assert 'embedding' in data[0]
101        assert len(data[0]['embedding']) > 1
102
103
def test_embedding_pooling_none():
    """With pooling 'none', /embeddings returns one raw vector per token.

    The per-token vectors are expected NOT to be normalized.
    """
    global server
    server.pooling = 'none'
    server.start()
    res = server.make_request("POST", "/embeddings", data={
        "input": "hello hello hello",
    })
    assert res.status_code == 200
    assert 'embedding' in res.body[0]
    token_vectors = res.body[0]['embedding']
    assert len(token_vectors) == 5  # 3 text tokens + 2 special

    # make sure no per-token embedding vector is normalized; distinct names
    # avoid the confusing `for x in x` shadowing of the outer loop variable
    for token_vec in token_vectors:
        squared_norm = sum(v ** 2 for v in token_vec)
        assert abs(squared_norm - 1) > EPSILON
118
119
def test_embedding_pooling_none_oai():
    """The OpenAI-compatible endpoint rejects pooling 'none' with HTTP 400."""
    global server
    server.pooling = 'none'
    server.start()
    request_body = {"input": "hello hello hello"}
    res = server.make_request("POST", "/v1/embeddings", data=request_body)

    # /v1/embeddings does not support pooling type 'none'
    assert res.status_code == 400
    assert "error" in res.body
131
132
def test_embedding_openai_library_single():
    """The official OpenAI client can request a single embedding from the server."""
    global server
    server.pooling = 'last'
    server.start()
    base_url = f"http://{server.server_host}:{server.server_port}/v1"
    client = OpenAI(api_key="dummy", base_url=base_url)
    response = client.embeddings.create(
        model="text-embedding-3-small",
        input="I believe the meaning of life is",
    )
    assert len(response.data) == 1
    assert len(response.data[0].embedding) > 1
141
142
def test_embedding_openai_library_multiple():
    """The official OpenAI client can request a batch of four embeddings."""
    global server
    server.pooling = 'last'
    server.start()
    base_url = f"http://{server.server_host}:{server.server_port}/v1"
    client = OpenAI(api_key="dummy", base_url=base_url)
    prompts = [
        "I believe the meaning of life is",
        "Write a joke about AI from a very long prompt which will not be truncated",
        "This is a test",
        "This is another test",
    ]
    response = client.embeddings.create(model="text-embedding-3-small", input=prompts)
    assert len(response.data) == 4
    for item in response.data:
        assert len(item.embedding) > 1
157
158
def test_embedding_error_prompt_too_long():
    """An oversized prompt is rejected with an error mentioning 'too large'."""
    global server
    server.pooling = 'last'
    server.start()
    oversized_prompt = "This is a test " * 512
    res = server.make_request("POST", "/v1/embeddings", data={"input": oversized_prompt})
    assert res.status_code != 200
    error_message = res.body["error"]["message"]
    assert "too large" in error_message
168
169
def test_same_prompt_give_same_result():
    """Identical prompts in one batch must produce (near-)identical embeddings."""
    global server
    server.pooling = 'last'
    server.start()
    prompt = "I believe the meaning of life is"
    res = server.make_request("POST", "/v1/embeddings", data={
        "input": [prompt] * 5,
    })
    assert res.status_code == 200
    data = res.body['data']
    assert len(data) == 5

    # Fetch the reference vector once instead of on every iteration.
    reference = data[0]['embedding']
    for item in data[1:]:
        other = item['embedding']
        # zip() silently truncates, so a length mismatch would otherwise pass.
        assert len(other) == len(reference)
        for x, y in zip(reference, other):
            assert abs(x - y) < EPSILON
189
190
@pytest.mark.parametrize(
    "content,n_tokens",
    [
        ("I believe the meaning of life is", 9),
        ("This is a test", 6),
    ]
)
def test_embedding_usage_single(content, n_tokens):
    """Usage reporting for one prompt: total equals prompt tokens, which equals n_tokens."""
    global server
    server.start()
    res = server.make_request("POST", "/v1/embeddings", data={"input": content})
    assert res.status_code == 200
    usage = res.body['usage']
    assert usage['prompt_tokens'] == usage['total_tokens']
    assert usage['prompt_tokens'] == n_tokens
205
206
def test_embedding_usage_multiple():
    """Usage for a two-prompt batch is the sum of both prompts' token counts."""
    global server
    server.start()
    prompt = "I believe the meaning of life is"
    res = server.make_request("POST", "/v1/embeddings", data={"input": [prompt, prompt]})
    assert res.status_code == 200
    usage = res.body['usage']
    assert usage['prompt_tokens'] == usage['total_tokens']
    tokens_per_prompt = 9
    assert usage['prompt_tokens'] == 2 * tokens_per_prompt
219
220
def test_embedding_openai_library_base64():
    """encoding_format='base64' must encode the same vector as the default format.

    Despite its name, this test talks to the server via ``make_request``
    rather than through the OpenAI client library.
    """
    global server
    server.start()
    test_input = "Test base64 embedding output"

    # get embedding in default format
    res = server.make_request("POST", "/v1/embeddings", data={
        "input": test_input
    })
    assert res.status_code == 200
    vec0 = res.body["data"][0]["embedding"]

    # get embedding in base64 format
    res = server.make_request("POST", "/v1/embeddings", data={
        "input": test_input,
        "encoding_format": "base64"
    })

    assert res.status_code == 200
    assert "data" in res.body
    assert len(res.body["data"]) == 1

    embedding_data = res.body["data"][0]
    assert "embedding" in embedding_data
    assert isinstance(embedding_data["embedding"], str)

    # Verify embedding is valid base64: validate=True raises on characters
    # outside the alphabet instead of silently discarding them.
    decoded = base64.b64decode(embedding_data["embedding"], validate=True)
    # The payload must be a whole number of 4-byte float32 values; checking
    # here gives a clear failure instead of an opaque struct.error below.
    assert len(decoded) % 4 == 0
    float_count = len(decoded) // 4  # 4 bytes per float
    # NOTE(review): native byte order ('f') assumes the server emits its
    # in-memory float buffer on the same host -- confirm for big-endian hosts.
    floats = struct.unpack(f'{float_count}f', decoded)
    assert len(floats) > 0
    assert len(floats) == len(vec0)

    # make sure the decoded data is the same as the original
    for x, y in zip(floats, vec0):
        assert abs(x - y) < EPSILON