mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-02 03:48:37 +08:00
[test]Fix testcase assertion (#15595)
Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
This commit is contained in:
parent
de9ee84bbb
commit
bebbc841e1
@ -2071,7 +2071,7 @@ class TestCollectionSearch(TestcaseBase):
|
||||
"""
|
||||
|
||||
|
||||
def init_data(connect, collection, nb=3000, partition_names=None, auto_id=True):
|
||||
def init_data(connect, collection, start=0, nb=3000, partition_names=None, auto_id=True):
|
||||
"""
|
||||
Generate entities and add it in collection
|
||||
"""
|
||||
@ -2079,7 +2079,7 @@ def init_data(connect, collection, nb=3000, partition_names=None, auto_id=True):
|
||||
if nb == 3000:
|
||||
insert_entities = entities
|
||||
else:
|
||||
insert_entities = gen_entities(nb, is_normal=True)
|
||||
insert_entities = gen_entities(nb, start=start, is_normal=True)
|
||||
if partition_names is None:
|
||||
res = connect.insert(collection, insert_entities)
|
||||
else:
|
||||
@ -2336,20 +2336,21 @@ class TestSearchBase:
|
||||
connect.create_partition(collection, default_tag)
|
||||
connect.create_partition(collection, new_tag)
|
||||
entities, ids = init_data(connect, collection, partition_names=default_tag)
|
||||
new_entities, new_ids = init_data(connect, collection, nb=6001, partition_names=new_tag)
|
||||
start = max(ids) + 1
|
||||
new_entities, new_ids = init_data(connect, collection, start=start, nb=6001, partition_names=new_tag)
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
search_param = get_search_param(index_type)
|
||||
query, _ = gen_search_vectors_params(field_name, entities, top_k, nq, search_params=search_param)
|
||||
if top_k > max_top_k:
|
||||
with pytest.raises(Exception) as e:
|
||||
res = connect.search(collection, **query)
|
||||
res = connect.search(collection, partition_names=[default_tag], **query)
|
||||
else:
|
||||
connect.load_collection(collection)
|
||||
res = connect.search(collection, **query)
|
||||
res = connect.search(collection, partition_names=[default_tag], **query)
|
||||
assert check_id_result(res[0], ids[0])
|
||||
assert res[0]._distances[0] < epsilon
|
||||
assert res[1]._distances[0] < epsilon
|
||||
res = connect.search(collection, **query, partition_names=[new_tag])
|
||||
res = connect.search(collection, partition_names=[new_tag], **query)
|
||||
assert res[0]._distances[0] > epsilon
|
||||
assert res[1]._distances[0] > epsilon
|
||||
connect.release_collection(collection)
|
||||
@ -2447,18 +2448,24 @@ class TestSearchBase:
|
||||
connect.create_partition(collection, default_tag)
|
||||
connect.create_partition(collection, new_tag)
|
||||
entities, ids = init_data(connect, collection, partition_names=default_tag)
|
||||
new_entities, new_ids = init_data(connect, collection, nb=6001, partition_names=new_tag)
|
||||
start = max(ids) + 1
|
||||
new_entities, new_ids = init_data(connect, collection, start=start, nb=6001, partition_names=new_tag)
|
||||
get_simple_index["metric_type"] = metric_type
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
search_param = get_search_param(index_type)
|
||||
# query vectors are selected from default partition
|
||||
query, _ = gen_search_vectors_params(field_name, entities, top_k, nq, metric_type="IP",
|
||||
search_params=search_param)
|
||||
connect.load_collection(collection)
|
||||
res = connect.search(collection, **query)
|
||||
# do search in default partition, so the results's id should be in default partition
|
||||
res = connect.search(collection, partition_names=[default_tag], **query)
|
||||
assert check_id_result(res[0], ids[0])
|
||||
assert not check_id_result(res[1], new_ids[0])
|
||||
# the top_1 of res[0] and res[1] are themselfs, so the distance is 1 (when metric_type is IP, the distance more closer to 1 means more similar)
|
||||
assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
|
||||
assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0])
|
||||
res = connect.search(collection, **query, partition_names=["new_tag"])
|
||||
# the query vector is selected from default partition, so the top 1 can't be itself when searching in new_tag partition, which means the distance less than 1
|
||||
assert res[0]._distances[0] < 1 - gen_inaccuracy(res[0]._distances[0])
|
||||
# TODO:
|
||||
# assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0])
|
||||
|
@ -274,10 +274,10 @@ def gen_binary_default_fields(auto_id=True):
|
||||
return default_fields
|
||||
|
||||
|
||||
def gen_entities(nb, is_normal=False):
|
||||
def gen_entities(nb, start=0, is_normal=False):
|
||||
vectors = gen_vectors(nb, default_dim, is_normal)
|
||||
entities = [
|
||||
{"name": "int64", "type": DataType.INT64, "values": [i for i in range(nb)]},
|
||||
{"name": "int64", "type": DataType.INT64, "values": [i for i in range(start, nb+start)]},
|
||||
{"name": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(nb)]},
|
||||
{"name": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR, "values": vectors}
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user