1import base64
2import struct
3import pytest
4from openai import OpenAI
5from utils import *
6
# module-level server handle; recreated for each test by the autouse fixture
server = ServerPreset.bert_bge_small()

# tolerance for floating-point comparisons on embedding components
EPSILON = 1e-3
10
@pytest.fixture(autouse=True)
def create_server():
    """Rebind the module-level ``server`` to a fresh preset before every test."""
    global server
    server = ServerPreset.bert_bge_small()
15
16
def test_embedding_single():
    """A single string prompt yields exactly one L2-normalized embedding."""
    global server
    server.pooling = 'last'
    server.start()
    res = server.make_request("POST", "/v1/embeddings", data={
        "input": "I believe the meaning of life is",
    })
    assert res.status_code == 200
    entries = res.body['data']
    assert len(entries) == 1
    assert 'embedding' in entries[0]
    vec = entries[0]['embedding']
    assert len(vec) > 1

    # pooled embeddings are expected to be unit-length
    norm_sq = sum(component ** 2 for component in vec)
    assert abs(norm_sq - 1) < EPSILON
31
32
def test_embedding_multiple():
    """A batch of four prompts yields four non-trivial embedding vectors."""
    global server
    server.pooling = 'last'
    server.start()
    prompts = [
        "I believe the meaning of life is",
        "Write a joke about AI from a very long prompt which will not be truncated",
        "This is a test",
        "This is another test",
    ]
    res = server.make_request("POST", "/v1/embeddings", data={"input": prompts})
    assert res.status_code == 200
    entries = res.body['data']
    assert len(entries) == 4
    for entry in entries:
        assert 'embedding' in entry
        assert len(entry['embedding']) > 1
50
51
def test_embedding_multiple_with_fa():
    """Batch embeddings on the flash-attention preset.

    The prompt lengths are chosen so that at least one of them should hit the
    FA branch (i.e. context size % 256 == 0).
    """
    # deliberately shadows the module-level server: this test needs the FA preset
    server = ServerPreset.bert_bge_small_with_fa()
    server.pooling = 'last'
    server.start()
    prompts = [
        "a "*253,
        "b "*254,
        "c "*255,
        "d "*256,
    ]
    res = server.make_request("POST", "/v1/embeddings", data={"input": prompts})
    assert res.status_code == 200
    entries = res.body['data']
    assert len(entries) == 4
    for entry in entries:
        assert 'embedding' in entry
        assert len(entry['embedding']) > 1
70
71
@pytest.mark.parametrize(
    "input,is_multi_prompt",
    [
        # do not crash on empty input
        ("", False),
        # single prompt
        ("string", False),
        ([12, 34, 56], False),
        ([12, 34, "string", 56, 78], False),
        # multiple prompts
        (["string1", "string2"], True),
        (["string1", [12, 34, 56]], True),
        ([[12, 34, 56], [12, 34, 56]], True),
        ([[12, 34, 56], [12, "string", 34, 56]], True),
    ]
)
def test_embedding_mixed_input(input, is_multi_prompt: bool):
    """The endpoint accepts strings, token-id lists, and mixed batches of both."""
    global server
    server.start()
    res = server.make_request("POST", "/v1/embeddings", data={"input": input})
    assert res.status_code == 200
    entries = res.body['data']
    if is_multi_prompt:
        assert len(entries) == len(input)
        to_check = entries
    else:
        # single prompt: only the first (and expected-only) entry is validated
        to_check = [entries[0]]
    for entry in to_check:
        assert 'embedding' in entry
        assert len(entry['embedding']) > 1
102
103
def test_embedding_pooling_none():
    """With pooling 'none', /embeddings returns one vector per input token.

    Fix: the original verification loop reused ``x`` as both the token-vector
    loop variable and the inner comprehension variable (``for x in x``), which
    only worked by accident of comprehension scoping. Renamed for clarity;
    behavior is unchanged.
    """
    global server
    server.pooling = 'none'
    server.start()
    res = server.make_request("POST", "/embeddings", data={
        "input": "hello hello hello",
    })
    assert res.status_code == 200
    assert 'embedding' in res.body[0]
    assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special

    # make sure the per-token embedding vectors are NOT normalized
    for token_vec in res.body[0]['embedding']:
        assert abs(sum(comp ** 2 for comp in token_vec) - 1) > EPSILON
118
119
def test_embedding_pooling_none_oai():
    """The OAI-compatible endpoint rejects pooling type 'none'."""
    global server
    server.pooling = 'none'
    server.start()
    res = server.make_request("POST", "/v1/embeddings", data={
        "input": "hello hello hello",
    })

    # /v1/embeddings does not support pooling type 'none'
    assert res.status_code == 400
    assert "error" in res.body
131
132
def test_embedding_openai_library_single():
    """The official OpenAI client can fetch a single embedding from our server."""
    global server
    server.pooling = 'last'
    server.start()
    base_url = f"http://{server.server_host}:{server.server_port}/v1"
    client = OpenAI(api_key="dummy", base_url=base_url)
    res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
    assert len(res.data) == 1
    assert len(res.data[0].embedding) > 1
141
142
def test_embedding_openai_library_multiple():
    """The official OpenAI client can fetch a batch of embeddings."""
    global server
    server.pooling = 'last'
    server.start()
    base_url = f"http://{server.server_host}:{server.server_port}/v1"
    client = OpenAI(api_key="dummy", base_url=base_url)
    prompts = [
        "I believe the meaning of life is",
        "Write a joke about AI from a very long prompt which will not be truncated",
        "This is a test",
        "This is another test",
    ]
    res = client.embeddings.create(model="text-embedding-3-small", input=prompts)
    assert len(res.data) == 4
    for item in res.data:
        assert len(item.embedding) > 1
157
158
def test_embedding_error_prompt_too_long():
    """A prompt exceeding the context size is rejected with an explicit error."""
    global server
    server.pooling = 'last'
    server.start()
    oversized_prompt = "This is a test " * 512
    res = server.make_request("POST", "/v1/embeddings", data={"input": oversized_prompt})
    assert res.status_code != 200
    assert "too large" in res.body["error"]["message"]
168
169
def test_same_prompt_give_same_result():
    """Identical prompts within one batch must yield (near-)identical vectors."""
    server.pooling = 'last'
    server.start()
    prompt = "I believe the meaning of life is"
    res = server.make_request("POST", "/v1/embeddings", data={
        "input": [prompt] * 5,
    })
    assert res.status_code == 200
    entries = res.body['data']
    assert len(entries) == 5
    reference = entries[0]['embedding']
    for entry in entries[1:]:
        for a, b in zip(reference, entry['embedding']):
            assert abs(a - b) < EPSILON
189
190
@pytest.mark.parametrize(
    "content,n_tokens",
    [
        ("I believe the meaning of life is", 9),
        ("This is a test", 6),
    ]
)
def test_embedding_usage_single(content, n_tokens):
    """Usage accounting for a single prompt reports the expected token count."""
    global server
    server.start()
    res = server.make_request("POST", "/v1/embeddings", data={"input": content})
    assert res.status_code == 200
    usage = res.body['usage']
    assert usage['prompt_tokens'] == usage['total_tokens']
    assert usage['prompt_tokens'] == n_tokens
205
206
def test_embedding_usage_multiple():
    """Usage accounting sums the token counts of every prompt in the batch."""
    global server
    server.start()
    prompt = "I believe the meaning of life is"  # expected to tokenize to 9 tokens
    res = server.make_request("POST", "/v1/embeddings", data={
        "input": [prompt, prompt],
    })
    assert res.status_code == 200
    usage = res.body['usage']
    assert usage['prompt_tokens'] == usage['total_tokens']
    assert usage['prompt_tokens'] == 2 * 9
219
220
def test_embedding_openai_library_base64():
    """encoding_format='base64' returns the same vector as the default format,
    packed as raw native-endian float32 values."""
    server.start()
    test_input = "Test base64 embedding output"

    # reference embedding in the default (list-of-floats) format
    res = server.make_request("POST", "/v1/embeddings", data={
        "input": test_input
    })
    assert res.status_code == 200
    reference_vec = res.body["data"][0]["embedding"]

    # same prompt, base64-encoded output
    res = server.make_request("POST", "/v1/embeddings", data={
        "input": test_input,
        "encoding_format": "base64"
    })
    assert res.status_code == 200
    assert "data" in res.body
    assert len(res.body["data"]) == 1

    entry = res.body["data"][0]
    assert "embedding" in entry
    assert isinstance(entry["embedding"], str)

    # the payload must be valid base64 ...
    raw = base64.b64decode(entry["embedding"])
    # ... and decode into a float32 array (4 bytes per float)
    n_floats = len(raw) // 4
    decoded = struct.unpack(f'{n_floats}f', raw)
    assert len(decoded) > 0
    assert all(isinstance(v, float) for v in decoded)
    assert len(decoded) == len(reference_vec)

    # element-wise match against the reference vector
    for v, ref in zip(decoded, reference_vec):
        assert abs(v - ref) < EPSILON