related issue: #30607; also update some tests for group by

Signed-off-by: yanliang567 <yanliang.qiao@zilliz.com>
commit d930666b3e (parent 0c7474d7e8)

tests/python_client/testcases/test_issues.py (new file, +76 lines)
@@ -0,0 +1,76 @@
+from utils.util_pymilvus import *
+from common.common_type import CaseLabel, CheckTasks
+from common import common_type as ct
+from common import common_func as cf
+from utils.util_log import test_log as log
+from base.client_base import TestcaseBase
+import random
+import pytest
+
+
+class TestIssues(TestcaseBase):
+    @pytest.mark.tags(CaseLabel.L0)
+    @pytest.mark.parametrize("par_key_field", [ct.default_int64_field_name])
+    @pytest.mark.parametrize("index_on_par_key_field", [True])
+    @pytest.mark.parametrize("use_upsert", [True, False])
+    def test_issue_30607(self, par_key_field, index_on_par_key_field, use_upsert):
+        """
+        Method:
+        1. create a collection with partition key on collection schema with customized num_partitions
+        2. randomly check 200 entities
+        3. verify partition key values are hashed into correct partitions
+        """
+        self._connect()
+        pk_field = cf.gen_string_field(name='pk', is_primary=True)
+        int64_field = cf.gen_int64_field()
+        string_field = cf.gen_string_field()
+        vector_field = cf.gen_float_vec_field()
+        schema = cf.gen_collection_schema(fields=[pk_field, int64_field, string_field, vector_field],
+                                          auto_id=False, partition_key_field=par_key_field)
+        c_name = cf.gen_unique_str("par_key")
+        collection_w, _ = self.collection_wrap.init_collection(name=c_name, schema=schema, num_partitions=9)
+
+        # insert
+        nb = 500
+        string_prefix = cf.gen_str_by_length(length=6)
+        entities_per_parkey = 20
+        for n in range(entities_per_parkey):
+            pk_values = [str(i) for i in range(n * nb, (n + 1) * nb)]
+            int64_values = [i for i in range(0, nb)]
+            string_values = [string_prefix + str(i) for i in range(0, nb)]
+            float_vec_values = gen_vectors(nb, ct.default_dim)
+            data = [pk_values, int64_values, string_values, float_vec_values]
+            if use_upsert:
+                collection_w.upsert(data)
+            else:
+                collection_w.insert(data)
+
+        # flush
+        collection_w.flush()
+        num_entities = collection_w.num_entities
+        # build index
+        collection_w.create_index(field_name=vector_field.name, index_params=ct.default_index)
+        if index_on_par_key_field:
+            collection_w.create_index(field_name=par_key_field, index_params={})
+        # load
+        collection_w.load()
+
+        # verify the partition key values are hashed into correct partitions
+        seeds = 200
+        rand_ids = random.sample(range(0, num_entities), seeds)
+        rand_ids = [str(rand_ids[i]) for i in range(len(rand_ids))]
+        res = collection_w.query(expr=f"pk in {rand_ids}", output_fields=["pk", par_key_field])
+        # verify every random id exists
+        assert len(res) == len(rand_ids)
+
+        dirty_count = 0
+        for i in range(len(res)):
+            pk = res[i].get("pk")
+            parkey_value = res[i].get(par_key_field)
+            res_parkey = collection_w.query(expr=f"{par_key_field}=={parkey_value} and pk=='{pk}'",
+                                            output_fields=["pk", par_key_field])
+            if len(res_parkey) != 1:
+                log.info(f"dirty data found: pk {pk} with parkey {parkey_value}")
+                dirty_count += 1
+        assert dirty_count == 0
+        log.info(f"check randomly {seeds}/{num_entities}, dirty count={dirty_count}")
tests/python_client/testcases/test_search.py

@@ -9798,9 +9798,9 @@ class TestSearchGroupBy(TestcaseBase):
     """ Test case of search group by """

     @pytest.mark.tags(CaseLabel.L0)
-    @pytest.mark.parametrize("metric", ct.float_metrics)
-    @pytest.mark.skip(reason="issue #29883")
-    def test_search_group_by_default(self, metric):
+    @pytest.mark.parametrize("index_type, metric", zip(["FLAT", "IVF_FLAT", "HNSW"], ct.float_metrics))
+    def test_search_group_by_default(self, index_type, metric):
         """
         target: test search group by
         method: 1. create a collection with data
@@ -9812,28 +9812,23 @@ class TestSearchGroupBy(TestcaseBase):
         """
         collection_w = self.init_collection_general(prefix, auto_id=True, insert_data=False, is_index=False,
                                                     is_all_data_type=True, with_json=False)[0]

-        # create index and load
-        vector_name_list = cf.extract_vector_field_name_list(collection_w)
-        vector_name_list.append(ct.default_float_vec_field_name)
-        _index = {"index_type": "HNSW", "metric_type": metric, "params": {"M": 16, "efConstruction": 128}}
-        for vector_name in vector_name_list:
-            collection_w.create_index(vector_name, _index)
-        collection_w.load()
-
+        _index_params = {"index_type": index_type, "metric_type": metric, "params": {"M": 16, "efConstruction": 128}}
+        if index_type in ["IVF_FLAT", "FLAT"]:
+            _index_params = {"index_type": index_type, "metric_type": metric, "params": {"nlist": 128}}
+        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index_params)
         # insert with the same values for scalar fields
-        for _ in range(30):
-            data = cf.gen_dataframe_all_data_type(nb=100, auto_id=True, with_json=False)
+        for _ in range(200):
+            data = cf.gen_dataframe_all_data_type(nb=200, auto_id=True, with_json=False)
             collection_w.insert(data)

         collection_w.flush()
-        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index)
+        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index_params)
+        time.sleep(30)
         collection_w.load()

         search_params = {"metric_type": metric, "params": {"ef": 128}}
         nq = 2
-        limit = 10
+        limit = 50
         search_vectors = cf.gen_vectors(nq, dim=ct.default_dim)

         # verify the results are same if group by pk
@@ -9842,8 +9837,13 @@ class TestSearchGroupBy(TestcaseBase):
                                    group_by_field=ct.default_int64_field_name)[0]
         res2 = collection_w.search(data=search_vectors, anns_field=ct.default_float_vec_field_name,
                                    param=search_params, limit=limit, consistency_level=CONSISTENCY_STRONG)[0]
+        hits_num = 0
         for i in range(nq):
-            assert res1[i].ids == res2[i].ids
+            # assert res1[i].ids == res2[i].ids
+            hits_num += len(set(res1[i].ids).intersection(set(res2[i].ids)))
+        hit_rate = hits_num / (nq * limit)
+        log.info(f"group by primary key hits_num: {hits_num}, nq: {nq}, limit: {limit}, hit_rate: {hit_rate}")
+        assert hit_rate > 0.80

         # verify that every record in groupby results is the top1 for that value of the group_by_field
         supported_grpby_fields = [ct.default_int8_field_name, ct.default_int16_field_name,
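
The hunk above swaps the strict `res1[i].ids == res2[i].ids` assertion for an overlap-based hit rate, since grouping by the primary key need not reproduce the exact ranking of a plain search. A pure-Python sketch of that relaxed metric (illustrative, not from the commit):

    # Hit rate between two top-k ID lists: the fraction of shared IDs.
    def hit_rate(ids_a, ids_b):
        assert len(ids_a) == len(ids_b) and ids_a
        return len(set(ids_a) & set(ids_b)) / len(ids_a)

    # 9 of 10 IDs agree -> 0.9, which clears the test's 0.80 threshold.
    assert hit_rate(list(range(10)), [0, 1, 2, 3, 4, 5, 6, 7, 8, 99]) == 0.9
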
@@ -9856,6 +9856,7 @@ class TestSearchGroupBy(TestcaseBase):
                                       output_fields=[grpby_field])[0]
             for i in range(nq):
                 grpby_values = []
+                dismatch = 0
                 results_num = 2 if grpby_field == ct.default_bool_field_name else limit
                 for l in range(results_num):
                     top1 = res1[i][l]
@@ -9870,7 +9871,12 @@ class TestSearchGroupBy(TestcaseBase):
                                               expr=expr,
                                               output_fields=[grpby_field])[0]
                     top1_expr_pk = res_tmp[0][0].id
-                    assert top1_grpby_pk == top1_expr_pk
+                    if top1_grpby_pk != top1_expr_pk:
+                        dismatch += 1
+                        log.info(f"{grpby_field} on {metric} dismatch_item, top1_grpby_dis: {top1.distance}, top1_expr_dis: {res_tmp[0][0].distance}")
+                log.info(f"{grpby_field} on {metric} top1_dismatch_num: {dismatch}, results_num: {results_num}, dismatch_rate: {dismatch / results_num}")
+                baseline = 1 if grpby_field == ct.default_bool_field_name else 0.2    # skip baseline check for boolean
+                assert dismatch / results_num <= baseline
                 # verify no dup values of the group_by_field in results
                 assert len(grpby_values) == len(set(grpby_values))

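Similarly, the top-1-per-group check above now tolerates a bounded mismatch rate (0.2 for most scalar types; boolean fields get a baseline of 1, which effectively skips the check because they only form two groups). A sketch of the pattern in isolation, with illustrative values:

    # Count disagreements between expected and actual values and fail only when
    # the mismatch rate exceeds a baseline, instead of failing on the first miss.
    def assert_mismatch_rate(pairs, baseline=0.2):
        pairs = list(pairs)
        mismatches = sum(1 for expected, actual in pairs if expected != actual)
        rate = mismatches / len(pairs)
        assert rate <= baseline, f"mismatch rate {rate:.2f} exceeds {baseline}"

    assert_mismatch_rate([(1, 1), (2, 2), (3, 4), (5, 5), (6, 6)])  # 1/5 = 0.2, passes
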
@@ -9958,10 +9964,7 @@ class TestSearchGroupBy(TestcaseBase):
         collection_w = self.init_collection_general(prefix, auto_id=True, insert_data=False, is_index=False,
                                                     is_all_data_type=True, with_json=False)[0]
         _index = {"index_type": "HNSW", "metric_type": metric, "params": {"M": 16, "efConstruction": 128}}
-        vector_name_list = cf.extract_vector_field_name_list(collection_w)
-        vector_name_list.append(ct.default_float_vec_field_name)
-        for vector_name in vector_name_list:
-            collection_w.create_index(vector_name, _index)
+        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index)
         # insert with the same values(by insert rounds) for scalar fields
         for _ in range(100):
             data = cf.gen_dataframe_all_data_type(nb=100, auto_id=True, with_json=False)
@@ -10020,10 +10023,7 @@ class TestSearchGroupBy(TestcaseBase):
         collection_w = self.init_collection_general(prefix, insert_data=True, is_index=False,
                                                     is_all_data_type=True, with_json=True,)[0]
         _index = {"index_type": "HNSW", "metric_type": metric, "params": {"M": 16, "efConstruction": 128}}
-        vector_name_list = cf.extract_vector_field_name_list(collection_w)
-        vector_name_list.append(ct.default_float_vec_field_name)
-        for vector_name in vector_name_list:
-            collection_w.create_index(vector_name, _index)
+        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index)
         collection_w.load()

         search_params = {"metric_type": metric, "params": {"ef": 128}}
@@ -10044,6 +10044,7 @@ class TestSearchGroupBy(TestcaseBase):
     @pytest.mark.parametrize("index, params",
                              zip(ct.all_index_types[:7],
                                  ct.default_index_params[:7]))
+    @pytest.mark.skip(reason="issue #29968")
     def test_search_group_by_unsupported_index(self, index, params):
         """
         target: test search group by with the unsupported vector index
@@ -10052,17 +10053,14 @@ class TestSearchGroupBy(TestcaseBase):
         3. search with group by
         verify: the error code and msg
         """
-        if index == "HNSW":
-            pass  # HNSW is supported
+        if index in ["HNSW", "IVF_FLAT", "FLAT"]:
+            pass  # HNSW, IVF_FLAT and FLAT support group by
         else:
             metric = "L2"
             collection_w = self.init_collection_general(prefix, insert_data=True, is_index=False,
                                                         is_all_data_type=True, with_json=False)[0]
             index_params = {"index_type": index, "params": params, "metric_type": metric}
-            vector_name_list = cf.extract_vector_field_name_list(collection_w)
-            vector_name_list.append(ct.default_float_vec_field_name)
-            for vector_name in vector_name_list:
-                collection_w.create_index(vector_name, index_params)
+            collection_w.create_index(ct.default_float_vec_field_name, index_params)
             collection_w.load()

             search_params = {"params": {}}
@@ -10072,10 +10070,9 @@ class TestSearchGroupBy(TestcaseBase):

             # search with groupby
             err_code = 999
             err_msg = "Unexpected index"
             if index in ["IVF_FLAT", "IVF_SQ8", "IVF_PQ", "SCANN"]:
                 err_code = 65535
-                err_msg = "Returned knowhere iterator has non-ready iterators inside, terminate group_by operation"
+                err_msg = "terminate group_by operation"
             if index in ["DISKANN"]:
                 err_msg = "not supported for current index type"
             collection_w.search(data=search_vectors, anns_field=ct.default_float_vec_field_name,
                                 param=search_params, limit=limit,
                                 group_by_field=ct.default_int8_field_name,
@@ -10094,12 +10091,9 @@ class TestSearchGroupBy(TestcaseBase):
         """
         metric = "IP"
         collection_w = self.init_collection_general(prefix, insert_data=True, is_index=False,
-                                                    is_all_data_type=True, with_json=True)[0]
+                                                    is_all_data_type=True, with_json=True, )[0]
         _index = {"index_type": "HNSW", "metric_type": metric, "params": {"M": 16, "efConstruction": 128}}
-        vector_name_list = cf.extract_vector_field_name_list(collection_w)
-        vector_name_list.append(ct.default_float_vec_field_name)
-        for vector_name in vector_name_list:
-            collection_w.create_index(vector_name, _index)
+        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index)
         collection_w.load()

         search_params = {"metric_type": metric, "params": {"ef": 128}}
@@ -10117,7 +10111,7 @@ class TestSearchGroupBy(TestcaseBase):
                             check_items={"err_code": err_code, "err_msg": err_msg})

     @pytest.mark.tags(CaseLabel.L1)
-    @pytest.mark.skip(reason="issue #30033")
+    @pytest.mark.skip(reason="issue #30033, #30828")
     def test_search_pagination_group_by(self):
         """
         target: test search group by
@@ -10135,11 +10129,10 @@ class TestSearchGroupBy(TestcaseBase):
         for _ in range(50):
             data = cf.gen_dataframe_all_data_type(nb=100, auto_id=True, with_json=False)
             collection_w.insert(data)

         collection_w.flush()
         _index = {"index_type": "HNSW", "metric_type": metric, "params": {"M": 16, "efConstruction": 128}}
-        vector_name_list = cf.extract_vector_field_name_list(collection_w)
-        vector_name_list.append(ct.default_float_vec_field_name)
-        for vector_name in vector_name_list:
-            collection_w.create_index(vector_name, _index)
+        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index)
         collection_w.load()
         # 2. search pagination with offset
         limit = 10
@@ -10148,14 +10141,19 @@ class TestSearchGroupBy(TestcaseBase):
         grpby_field = ct.default_string_field_name
         search_vectors = cf.gen_vectors(1, dim=ct.default_dim)
         all_pages_ids = []
+        all_pages_grpby_field_values = []
         for r in range(page_rounds):
             page_res = collection_w.search(search_vectors, anns_field=default_search_field,
                                            param=search_param, limit=limit, offset=limit * r,
                                            expr=default_search_exp, group_by_field=grpby_field,
+                                           output_fields=["*"],
                                            check_task=CheckTasks.check_search_results,
                                            check_items={"nq": 1, "limit": limit},
                                            )[0]
+            for j in range(limit):
+                all_pages_grpby_field_values.append(page_res[0][j].get(grpby_field))
             all_pages_ids += page_res[0].ids
+        assert len(all_pages_grpby_field_values) == len(set(all_pages_grpby_field_values))

         total_res = collection_w.search(search_vectors, anns_field=default_search_field,
                                         param=search_param, limit=limit * page_rounds,
@@ -10164,14 +10162,15 @@ class TestSearchGroupBy(TestcaseBase):
                                         check_task=CheckTasks.check_search_results,
                                         check_items={"nq": 1, "limit": limit * page_rounds}
                                         )[0]
-        assert total_res[0].ids == all_pages_ids
+        hit_num = len(set(total_res[0].ids).intersection(set(all_pages_ids)))
+        assert hit_num / (limit * page_rounds) > 0.90
         grpby_field_values = []
         for i in range(limit * page_rounds):
             grpby_field_values.append(total_res[0][i].fields.get(grpby_field))
         assert len(grpby_field_values) == len(set(grpby_field_values))

     @pytest.mark.tags(CaseLabel.L1)
-    @pytest.mark.skip(reason="issue #30033")
+    @pytest.mark.skip(reason="not support iterator + group by")
     def test_search_iterator_group_by(self):
         """
         target: test search group by
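
The two pagination hunks above encode the invariant the test now checks: group-by pages fetched with increasing offsets must never repeat a group value, and their combined IDs should overlap a single large-limit search by more than 90%. A Milvus-free sketch of that invariant (the helper name and sample data are hypothetical):

    # pages: list of pages, each a list of (entity_id, group_value) tuples;
    # full_ids: IDs from one search with limit = page size * page count.
    def check_paged_group_by(pages, full_ids, threshold=0.90):
        group_values, page_ids = [], []
        for page in pages:
            for entity_id, group_value in page:
                group_values.append(group_value)
                page_ids.append(entity_id)
        # no group value may appear on two different pages
        assert len(group_values) == len(set(group_values))
        # paged results should mostly agree with the one-shot search
        assert len(set(page_ids) & set(full_ids)) / len(full_ids) > threshold

    check_paged_group_by(pages=[[(1, "a"), (2, "b")], [(3, "c"), (4, "d")]],
                         full_ids=[1, 2, 3, 4])
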
@@ -10189,14 +10188,13 @@ class TestSearchGroupBy(TestcaseBase):
         for _ in range(value_num):
             data = cf.gen_dataframe_all_data_type(nb=100, auto_id=True, with_json=False)
             collection_w.insert(data)

         collection_w.flush()
         _index = {"index_type": "HNSW", "metric_type": metric, "params": {"M": 16, "efConstruction": 128}}
-        vector_name_list = cf.extract_vector_field_name_list(collection_w)
-        vector_name_list.append(ct.default_float_vec_field_name)
-        for vector_name in vector_name_list:
-            collection_w.create_index(vector_name, _index)
+        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index)
         collection_w.load()

-        grpby_field = ct.default_int64_field_name
+        grpby_field = ct.default_int32_field_name
         search_vectors = cf.gen_vectors(1, dim=ct.default_dim)
         search_params = {"metric_type": metric}
         batch_size = 10
@@ -10205,18 +10203,24 @@ class TestSearchGroupBy(TestcaseBase):
         #                            search_params, group_by_field=grpby_field, limit=10)[0]

         ite_res = collection_w.search_iterator(search_vectors, ct.default_float_vec_field_name,
-                                               search_params, batch_size, group_by_field=grpby_field
+                                               search_params, batch_size, group_by_field=grpby_field,
+                                               output_fields=[grpby_field]
                                                )[0]
-        iterators = 0
-        while True:
-            res = ite_res.next()  # turn to the next page
-            if len(res) == 0:
-                ite_res.close()  # close the iterator
-                break
-            iterators += 1
-        assert iterators == value_num/batch_size
+        # iterators = 0
+        # while True and iterators < value_num/batch_size:
+        #     res = ite_res.next()  # turn to the next page
+        #     if len(res) == 0:
+        #         ite_res.close()  # close the iterator
+        #         break
+        #     iterators += 1
+        #     grp_values = []
+        #     for j in range(len(res)):
+        #         grp_values.append(res.get__item(j).get(grpby_field))
+        #     log.info(f"iterators: {iterators}, grp_values: {grp_values}")
+        # assert iterators == value_num/batch_size

     @pytest.mark.tags(CaseLabel.L2)
     @pytest.mark.skip(reason="not support range search + group by")
     def test_range_search_group_by(self):
         """
         target: test search group by
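
For reference, the commented-out block in the hunk above follows a standard iterator-draining pattern: pull pages until an empty one arrives, close the iterator, and bound the loop so the open issue cannot hang the test. The same pattern against a stand-in iterator (pure Python, no Milvus required; FakeIterator is an illustrative stub, not a pymilvus type):

    # FakeIterator mimics the next()/close() surface of a search iterator.
    class FakeIterator:
        def __init__(self, total, batch):
            self._ids = list(range(total))
            self._batch = batch

        def next(self):
            page, self._ids = self._ids[:self._batch], self._ids[self._batch:]
            return page

        def close(self):
            pass

    def drain(iterator, max_pages):
        pages = 0
        while pages < max_pages:           # bound the loop instead of `while True`
            page = iterator.next()
            if len(page) == 0:
                iterator.close()           # release the iterator when exhausted
                break
            pages += 1
        return pages

    assert drain(FakeIterator(total=100, batch=10), max_pages=10) == 10
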
@@ -10236,10 +10240,7 @@ class TestSearchGroupBy(TestcaseBase):
             collection_w.insert(data)

         collection_w.flush()
-        vector_name_list = cf.extract_vector_field_name_list(collection_w)
-        vector_name_list.append(ct.default_float_vec_field_name)
-        for vector_name in vector_name_list:
-            collection_w.create_index(vector_name, _index)
+        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index)
         time.sleep(10)
         collection_w.load()

@@ -10255,14 +10256,14 @@ class TestSearchGroupBy(TestcaseBase):
                                      output_fields=[grpby_field],
                                      check_task=CheckTasks.check_search_results,
                                      check_items={"nq": nq, "limit": limit})[0]
-        grpby_field_values = []
-        for i in range(limit):
-            grpby_field_values.append(res[0][i].fields.get(grpby_field))
-        assert len(grpby_field_values) == len(set(grpby_field_values))
+        # grpby_field_values = []
+        # for i in range(limit):
+        #     grpby_field_values.append(res[0][i].fields.get(grpby_field))
+        # assert len(grpby_field_values) == len(set(grpby_field_values))

     @pytest.mark.tags(CaseLabel.L2)
     @pytest.mark.skip(reason="not completed")
-    def test_hybrid_search_group_by(self):
+    def test_advanced_search_group_by(self):
         """
         target: test search group by
         method: 1. create a collection with multiple vector fields