1import pytest
  2from utils import *
  3
  4server = ServerPreset.jina_reranker_tiny()
  5
  6
  7@pytest.fixture(autouse=True)
  8def create_server():
  9    global server
 10    server = ServerPreset.jina_reranker_tiny()
 11
 12
 13TEST_DOCUMENTS = [
 14    "A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.",
 15    "Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.",
 16    "Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.",
 17    "Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine."
 18]
 19
 20
 21def test_rerank():
 22    global server
 23    server.start()
 24    res = server.make_request("POST", "/rerank", data={
 25        "query": "Machine learning is",
 26        "documents": TEST_DOCUMENTS,
 27    })
 28    assert res.status_code == 200
 29    assert len(res.body["results"]) == 4
 30
 31    most_relevant = res.body["results"][0]
 32    least_relevant = res.body["results"][0]
 33    for doc in res.body["results"]:
 34        if doc["relevance_score"] > most_relevant["relevance_score"]:
 35            most_relevant = doc
 36        if doc["relevance_score"] < least_relevant["relevance_score"]:
 37            least_relevant = doc
 38
 39    assert most_relevant["relevance_score"] > least_relevant["relevance_score"]
 40    assert most_relevant["index"] == 2
 41    assert least_relevant["index"] == 3
 42
 43
 44def test_rerank_tei_format():
 45    global server
 46    server.start()
 47    res = server.make_request("POST", "/rerank", data={
 48        "query": "Machine learning is",
 49        "texts": TEST_DOCUMENTS,
 50    })
 51    assert res.status_code == 200
 52    assert len(res.body) == 4
 53
 54    most_relevant = res.body[0]
 55    least_relevant = res.body[0]
 56    for doc in res.body:
 57        if doc["score"] > most_relevant["score"]:
 58            most_relevant = doc
 59        if doc["score"] < least_relevant["score"]:
 60            least_relevant = doc
 61
 62    assert most_relevant["score"] > least_relevant["score"]
 63    assert most_relevant["index"] == 2
 64    assert least_relevant["index"] == 3
 65
 66
 67@pytest.mark.parametrize("documents", [
 68    [],
 69    None,
 70    123,
 71    [1, 2, 3],
 72])
 73def test_invalid_rerank_req(documents):
 74    global server
 75    server.start()
 76    res = server.make_request("POST", "/rerank", data={
 77        "query": "Machine learning is",
 78        "documents": documents,
 79    })
 80    assert res.status_code == 400
 81    assert "error" in res.body
 82
 83
 84@pytest.mark.parametrize(
 85    "query,doc1,doc2,n_tokens",
 86    [
 87        ("Machine learning is", "A machine", "Learning is", 19),
 88        ("Which city?", "Machine learning is ", "Paris, capitale de la", 26),
 89    ]
 90)
 91def test_rerank_usage(query, doc1, doc2, n_tokens):
 92    global server
 93    server.start()
 94
 95    res = server.make_request("POST", "/rerank", data={
 96        "query": query,
 97        "documents": [
 98            doc1,
 99            doc2,
100        ]
101    })
102    assert res.status_code == 200
103    assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
104    assert res.body['usage']['prompt_tokens'] == n_tokens
105
106
@pytest.mark.parametrize("top_n,expected_len", [
    (None, len(TEST_DOCUMENTS)),  # top_n omitted from the request
    (2, 2),
    (4, 4),
    (99, len(TEST_DOCUMENTS)),    # exceeds the number of documents
])
def test_rerank_top_n(top_n, expected_len):
    """`results` length is capped by `top_n` (all documents when omitted)."""
    server.start()
    # Only send `top_n` when the case specifies one.
    optional = {} if top_n is None else {"top_n": top_n}
    data = {
        "query": "Machine learning is",
        "documents": TEST_DOCUMENTS,
        **optional,
    }

    res = server.make_request("POST", "/rerank", data=data)
    assert res.status_code == 200
    assert len(res.body["results"]) == expected_len
126
127
@pytest.mark.parametrize("top_n,expected_len", [
    (None, len(TEST_DOCUMENTS)),  # top_n omitted from the request
    (2, 2),
    (4, 4),
    (99, len(TEST_DOCUMENTS)),    # exceeds the number of documents
])
def test_rerank_tei_top_n(top_n, expected_len):
    """TEI-format response list is capped by `top_n` (all docs when omitted)."""
    server.start()
    # Only send `top_n` when the case specifies one.
    optional = {} if top_n is None else {"top_n": top_n}
    data = {
        "query": "Machine learning is",
        "texts": TEST_DOCUMENTS,
        **optional,
    }

    res = server.make_request("POST", "/rerank", data=data)
    assert res.status_code == 200
    assert len(res.body) == expected_len