[test]Fix testcase assertion (#15595)

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
This commit is contained in:
zhuwenxing 2022-02-16 19:35:49 +08:00 committed by GitHub
parent de9ee84bbb
commit bebbc841e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 10 deletions

View File

@ -2071,7 +2071,7 @@ class TestCollectionSearch(TestcaseBase):
"""
def init_data(connect, collection, nb=3000, partition_names=None, auto_id=True):
def init_data(connect, collection, start=0, nb=3000, partition_names=None, auto_id=True):
"""
Generate entities and add it in collection
"""
@ -2079,7 +2079,7 @@ def init_data(connect, collection, nb=3000, partition_names=None, auto_id=True):
if nb == 3000:
insert_entities = entities
else:
insert_entities = gen_entities(nb, is_normal=True)
insert_entities = gen_entities(nb, start=start, is_normal=True)
if partition_names is None:
res = connect.insert(collection, insert_entities)
else:
@ -2336,20 +2336,21 @@ class TestSearchBase:
connect.create_partition(collection, default_tag)
connect.create_partition(collection, new_tag)
entities, ids = init_data(connect, collection, partition_names=default_tag)
new_entities, new_ids = init_data(connect, collection, nb=6001, partition_names=new_tag)
start = max(ids) + 1
new_entities, new_ids = init_data(connect, collection, start=start, nb=6001, partition_names=new_tag)
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, _ = gen_search_vectors_params(field_name, entities, top_k, nq, search_params=search_param)
if top_k > max_top_k:
with pytest.raises(Exception) as e:
res = connect.search(collection, **query)
res = connect.search(collection, partition_names=[default_tag], **query)
else:
connect.load_collection(collection)
res = connect.search(collection, **query)
res = connect.search(collection, partition_names=[default_tag], **query)
assert check_id_result(res[0], ids[0])
assert res[0]._distances[0] < epsilon
assert res[1]._distances[0] < epsilon
res = connect.search(collection, **query, partition_names=[new_tag])
res = connect.search(collection, partition_names=[new_tag], **query)
assert res[0]._distances[0] > epsilon
assert res[1]._distances[0] > epsilon
connect.release_collection(collection)
@ -2447,18 +2448,24 @@ class TestSearchBase:
connect.create_partition(collection, default_tag)
connect.create_partition(collection, new_tag)
entities, ids = init_data(connect, collection, partition_names=default_tag)
new_entities, new_ids = init_data(connect, collection, nb=6001, partition_names=new_tag)
start = max(ids) + 1
new_entities, new_ids = init_data(connect, collection, start=start, nb=6001, partition_names=new_tag)
get_simple_index["metric_type"] = metric_type
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
# query vectors are selected from default partition
query, _ = gen_search_vectors_params(field_name, entities, top_k, nq, metric_type="IP",
search_params=search_param)
connect.load_collection(collection)
res = connect.search(collection, **query)
# search only in the default partition, so the returned ids should belong to the default partition
res = connect.search(collection, partition_names=[default_tag], **query)
assert check_id_result(res[0], ids[0])
assert not check_id_result(res[1], new_ids[0])
# the top-1 hits of res[0] and res[1] are the query vectors themselves, so the distance is 1 (when metric_type is IP, a distance closer to 1 means more similar)
assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0])
res = connect.search(collection, **query, partition_names=["new_tag"])
# the query vector is selected from the default partition, so the top-1 hit can't be the query itself when searching in the new_tag partition, which means the distance is less than 1
assert res[0]._distances[0] < 1 - gen_inaccuracy(res[0]._distances[0])
# TODO:
# assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0])

View File

@ -274,10 +274,10 @@ def gen_binary_default_fields(auto_id=True):
return default_fields
def gen_entities(nb, is_normal=False):
def gen_entities(nb, start=0, is_normal=False):
vectors = gen_vectors(nb, default_dim, is_normal)
entities = [
{"name": "int64", "type": DataType.INT64, "values": [i for i in range(nb)]},
{"name": "int64", "type": DataType.INT64, "values": [i for i in range(start, nb+start)]},
{"name": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(nb)]},
{"name": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR, "values": vectors}
]