Fix test cases for search pagination (#21179)

Signed-off-by: nico <cheng.yuan@zilliz.com>

Signed-off-by: nico <cheng.yuan@zilliz.com>
This commit is contained in:
NicoYuan1986 2022-12-14 14:57:27 +08:00 committed by GitHub
parent 0e34decf23
commit 98d6d0feba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3758,14 +3758,10 @@ class TestsearchString(TestcaseBase):
class TestsearchPagination(TestcaseBase):
""" Test case of search pagination """
@pytest.fixture(scope="function", params=[0, 10])
@pytest.fixture(scope="function", params=[0, 10, 100])
def offset(self, request):
yield request.param
@pytest.fixture(scope="function", params=[32, 128])
def dim(self, request):
yield request.param
@pytest.fixture(scope="function", params=[False, True])
def auto_id(self, request):
yield request.param
@ -3782,7 +3778,7 @@ class TestsearchPagination(TestcaseBase):
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.parametrize("limit", [10, 20])
def test_search_with_pagination(self, offset, auto_id, dim, limit, _async):
def test_search_with_pagination(self, offset, auto_id, limit, _async):
"""
target: test search with pagination
method: 1. connect and create a collection
@ -3792,10 +3788,10 @@ class TestsearchPagination(TestcaseBase):
expected: search successfully and ids is correct
"""
# 1. create a collection
collection_w = self.init_collection_general(prefix, True, auto_id=auto_id, dim=dim)[0]
collection_w = self.init_collection_general(prefix, True, auto_id=auto_id, dim=default_dim)[0]
# 2. search pagination with offset
search_param = {"metric_type": "L2", "params": {"nprobe": 10}, "offset": offset}
vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)]
vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)]
search_res = collection_w.search(vectors[:default_nq], default_search_field,
search_param, limit,
default_search_exp, _async=_async,
@ -3811,9 +3807,8 @@ class TestsearchPagination(TestcaseBase):
search_res = search_res.result()
res.done()
res = res.result()
assert res[0].distances == sorted(res[0].distances)
assert search_res[0].distances == sorted(search_res[0].distances)
assert search_res[0].distances == res[0].distances[offset:]
res_distance = res[0].distances[offset:]
assert sorted(search_res[0].distances) == sorted(res_distance)
assert set(search_res[0].ids) == set(res[0].ids[offset:])
@pytest.mark.tags(CaseLabel.L1)
@ -3851,9 +3846,8 @@ class TestsearchPagination(TestcaseBase):
search_res = search_res.result()
res.done()
res = res.result()
assert res[0].distances == sorted(res[0].distances)
assert search_res[0].distances == sorted(search_res[0].distances)
assert search_res[0].distances == res[0].distances[offset:]
res_distance = res[0].distances[offset:]
assert sorted(search_res[0].distances) == sorted(res_distance)
assert set(search_res[0].ids) == set(res[0].ids[offset:])
@pytest.mark.tags(CaseLabel.L1)
@ -3890,7 +3884,7 @@ class TestsearchPagination(TestcaseBase):
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("limit", [100, 3000, 10000])
def test_search_with_pagination_topK(self, auto_id, dim, limit, _async):
def test_search_with_pagination_topK(self, auto_id, limit, _async):
"""
target: test search with pagination limit + offset = topK
method: 1. connect and create a collection
@ -3902,10 +3896,10 @@ class TestsearchPagination(TestcaseBase):
# 1. create a collection
topK = 16384
offset = topK - limit
collection_w = self.init_collection_general(prefix, True, nb=20000, auto_id=auto_id, dim=dim)[0]
collection_w = self.init_collection_general(prefix, True, nb=20000, auto_id=auto_id, dim=default_dim)[0]
# 2. search
search_param = {"metric_type": "L2", "params": {"nprobe": 10}, "offset": offset}
vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)]
vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)]
search_res = collection_w.search(vectors[:default_nq], default_search_field,
search_param, limit,
default_search_exp, _async=_async,
@ -3921,9 +3915,8 @@ class TestsearchPagination(TestcaseBase):
search_res = search_res.result()
res.done()
res = res.result()
assert res[0].distances == sorted(res[0].distances)
assert search_res[0].distances == sorted(search_res[0].distances)
assert search_res[0].distances == res[0].distances[offset:]
res_distance = res[0].distances[offset:]
assert sorted(search_res[0].distances) == sorted(res_distance)
assert set(search_res[0].ids) == set(res[0].ids[offset:])
@pytest.mark.tags(CaseLabel.L2)
@ -3977,9 +3970,8 @@ class TestsearchPagination(TestcaseBase):
for hits in search_res:
ids = hits.ids
assert set(ids).issubset(filter_ids_set)
assert res[0].distances == sorted(res[0].distances)
assert search_res[0].distances == sorted(search_res[0].distances)
assert search_res[0].distances == res[0].distances[offset:]
res_distance = res[0].distances[offset:]
assert sorted(search_res[0].distances) == sorted(res_distance)
assert set(search_res[0].ids) == set(res[0].ids[offset:])
@pytest.mark.tags(CaseLabel.L2)
@ -4020,9 +4012,8 @@ class TestsearchPagination(TestcaseBase):
search_res = search_res.result()
res.done()
res = res.result()
assert res[0].distances == sorted(res[0].distances)
assert search_res[0].distances == sorted(search_res[0].distances)
assert search_res[0].distances == res[0].distances[offset:]
res_distance = res[0].distances[offset:]
assert sorted(search_res[0].distances) == sorted(res_distance)
assert set(search_res[0].ids) == set(res[0].ids[offset:])
@pytest.mark.tags(CaseLabel.L2)
@ -4066,7 +4057,7 @@ class TestsearchPagination(TestcaseBase):
assert set(search_res[0].ids) == set(res[0].ids[offset:])
@pytest.mark.tags(CaseLabel.L2)
def test_search_pagination_with_inserted_data(self, offset, dim, _async):
def test_search_pagination_with_inserted_data(self, offset, _async):
"""
target: test search pagination with inserted data
method: create connection, collection, insert data and search
@ -4074,14 +4065,14 @@ class TestsearchPagination(TestcaseBase):
expected: searched successfully
"""
# 1. create collection
collection_w = self.init_collection_general(prefix, False, dim=dim)[0]
collection_w = self.init_collection_general(prefix, False, dim=default_dim)[0]
# 2. insert data
data = cf.gen_default_dataframe_data(dim=dim)
data = cf.gen_default_dataframe_data(dim=default_dim)
collection_w.insert(data)
collection_w.load()
# 3. search
search_param = {"metric_type": "L2", "params": {"nprobe": 10}, "offset": offset}
vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)]
vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)]
search_res = collection_w.search(vectors[:default_nq], default_search_field,
search_param, default_limit,
default_search_exp, _async=_async,
@ -4097,9 +4088,8 @@ class TestsearchPagination(TestcaseBase):
search_res = search_res.result()
res.done()
res = res.result()
assert res[0].distances == sorted(res[0].distances)
assert search_res[0].distances == sorted(search_res[0].distances)
assert search_res[0].distances == res[0].distances[offset:]
res_distance = res[0].distances[offset:]
assert sorted(search_res[0].distances) == sorted(res_distance)
assert set(search_res[0].ids) == set(res[0].ids[offset:])
@pytest.mark.tags(CaseLabel.L2)
@ -4172,9 +4162,8 @@ class TestsearchPagination(TestcaseBase):
search_res = search_res.result()
res.done()
res = res.result()
assert res[0].distances == sorted(res[0].distances)
assert search_res[0].distances == sorted(search_res[0].distances)
assert search_res[0].distances == res[0].distances[offset:]
res_distance = res[0].distances[offset:]
assert sorted(search_res[0].distances) == sorted(res_distance)
assert set(search_res[0].ids) == set(res[0].ids[offset:])