Add test cases of upsert (#22630)

Signed-off-by: nico <cheng.yuan@zilliz.com>
This commit is contained in:
NicoYuan1986 2023-03-13 15:35:54 +08:00 committed by GitHub
parent 30efaa3495
commit f314d887f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 175 additions and 13 deletions

View File

@ -132,6 +132,17 @@ class ApiPartitionWrapper:
**kwargs).run() **kwargs).run()
return res, check_result return res, check_result
def upsert(self, data, check_task=None, check_items=None, **kwargs):
timeout = kwargs.get("timeout", TIMEOUT)
kwargs.update({"timeout": timeout})
func_name = sys._getframe().f_code.co_name
res, succ = api_request([self.partition.upsert, data], **kwargs)
check_result = ResponseChecker(res, func_name, check_task,
check_items, is_succ=succ, data=data,
**kwargs).run()
return res, check_result
def get_replicas(self, timeout=None, check_task=None, check_items=None, **kwargs): def get_replicas(self, timeout=None, check_task=None, check_items=None, **kwargs):
timeout = TIMEOUT if timeout is None else timeout timeout = TIMEOUT if timeout is None else timeout
func_name = sys._getframe().f_code.co_name func_name = sys._getframe().f_code.co_name

View File

@ -538,6 +538,40 @@ class TestPartitionParams(TestcaseBase):
assert not partition_w.is_empty assert not partition_w.is_empty
assert partition_w.num_entities == (nums + nums) assert partition_w.num_entities == (nums + nums)
@pytest.mark.tags(CaseLabel.L1)
def test_partition_upsert(self):
"""
target: verify upsert entities multiple times
method: 1. create a collection and a partition
2. partition.upsert(data)
3. upsert data again
expected: upsert data successfully
"""
# create collection and a partition
collection_w = self.init_collection_wrap()
partition_name = cf.gen_unique_str(prefix)
partition_w = self.init_partition_wrap(collection_w, partition_name)
# insert data and load
cf.insert_data(collection_w)
collection_w.create_index(ct.default_float_vec_field_name, ct.default_index)
collection_w.load()
# upsert data
upsert_nb = 1000
data, values = cf.gen_default_data_for_upsert(nb=upsert_nb, start=2000)
partition_w.upsert(data)
res = partition_w.query("int64 >= 2000 && int64 < 3000", [ct.default_float_field_name])[0]
assert partition_w.num_entities == upsert_nb + ct.default_nb // 2
assert [res[i][ct.default_float_field_name] for i in range(upsert_nb)] == values.to_list()
# upsert data
data, values = cf.gen_default_data_for_upsert(nb=upsert_nb, start=ct.default_nb)
partition_w.upsert(data)
res = partition_w.query("int64 >= 3000 && int64 < 4000", [ct.default_float_field_name])[0]
assert partition_w.num_entities == upsert_nb * 2 + ct.default_nb // 2
assert [res[i][ct.default_float_field_name] for i in range(upsert_nb)] == values.to_list()
class TestPartitionOperations(TestcaseBase): class TestPartitionOperations(TestcaseBase):
""" Test case of partition interface in operations """ """ Test case of partition interface in operations """
@ -1012,6 +1046,128 @@ class TestPartitionOperations(TestcaseBase):
res = partition_w.delete(expr) res = partition_w.delete(expr)
assert len(res) == 2 assert len(res) == 2
@pytest.mark.tags(CaseLabel.L1)
def test_partition_upsert_empty_partition(self):
"""
target: verify upsert data in empty partition
method: 1. create a collection
2. upsert some data in empty partition
expected: upsert successfully
"""
# create collection
collection_w = self.init_collection_wrap()
# get the default partition
partition_name = ct.default_partition_name
partition_w = self.init_partition_wrap(collection_w, partition_name)
assert partition_w.num_entities == 0
# upsert data to the empty partition
data = cf.gen_default_data_for_upsert()[0]
partition_w.upsert(data)
assert partition_w.num_entities == ct.default_nb
@pytest.mark.tags(CaseLabel.L1)
def test_partition_upsert_dropped_partition(self):
"""
target: verify upsert data in a dropped partition
method: 1. create a partition and drop
2. upsert some data into the dropped partition
expected: raise exception
"""
# create partition
partition_w = self.init_partition_wrap()
# drop partition
partition_w.drop()
# insert data to partition
partition_w.upsert(cf.gen_default_dataframe_data(),
check_task=CheckTasks.err_res,
check_items={ct.err_code: 1, ct.err_msg: "Partition not exist"})
@pytest.mark.tags(CaseLabel.L2)
def test_partition_upsert_mismatched_data(self):
"""
target: test upsert mismatched data in partition
method: 1. create a partition
2. insert some data
3. upsert with mismatched data
expected: raise exception
"""
# create a partition
partition_w = self.init_partition_wrap()
# insert data
data = cf.gen_default_dataframe_data()
partition_w.insert(data)
# upsert mismatched data
upsert_data = cf.gen_default_data_for_upsert(dim=ct.default_dim-1)[0]
error = {ct.err_code: 1, ct.err_msg: "Collection field dim is 128, but entities field dim is 127"}
partition_w.upsert(upsert_data, check_task=CheckTasks.err_res, check_items=error)
@pytest.mark.tags(CaseLabel.L2)
def test_partition_upsert_with_auto_id(self):
"""
target: test upsert data in partition when auto_id=True
method: 1. create a partition
2. insert some data
3. upsert data
expected: raise exception
"""
# create a partition
schema = cf.gen_default_collection_schema(auto_id=True)
collection_w = self.init_collection_wrap(schema=schema)
partition_w = self.init_partition_wrap(collection_w)
# insert data
data = cf.gen_default_dataframe_data()
data.drop(ct.default_int64_field_name, axis=1, inplace=True)
partition_w.insert(data)
# upsert data
upsert_data = cf.gen_default_data_for_upsert()[0]
upsert_data.drop(ct.default_int64_field_name, axis=1, inplace=True)
error = {ct.err_code: 1, ct.err_msg: "Upsert don't support autoid == true"}
partition_w.upsert(upsert_data, check_task=CheckTasks.err_res, check_items=error)
@pytest.mark.tags(CaseLabel.L2)
def test_partition_upsert_same_pk_in_different_partitions(self):
"""
target: test upsert same pk in different partitions
method: 1. create 2 partitions
2. insert some data
3. upsert data
expected: raise exception
"""
# create 2 partitions
collection_w = self.init_collection_wrap()
partition_1 = self.init_partition_wrap(collection_w)
partition_2 = self.init_partition_wrap(collection_w)
# insert data
nb = 1000
data = cf.gen_default_dataframe_data(nb)
partition_1.insert(data)
data = cf.gen_default_dataframe_data(nb, start=nb)
partition_2.insert(data)
# upsert data in 2 partitions
upsert_data, values = cf.gen_default_data_for_upsert(1)
partition_1.upsert(upsert_data)
partition_2.upsert(upsert_data)
# load
collection_w.create_index(ct.default_float_vec_field_name, ct.default_flat_index)
collection_w.load()
# query and check the results
expr = "int64 == 0"
res1 = partition_1.query(expr, [ct.default_float_field_name])[0]
res2 = partition_2.query(expr, [ct.default_float_field_name])[0]
assert res1 == res2
@pytest.mark.tags(CaseLabel.L0) @pytest.mark.tags(CaseLabel.L0)
def test_create_partition_repeat(self): def test_create_partition_repeat(self):
""" """

View File

@ -3070,7 +3070,7 @@ class TestSearchBase(TestcaseBase):
partition_num=1, partition_num=1,
dim=dim, is_index=False)[0:5] dim=dim, is_index=False)[0:5]
vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] vectors = [[random.random() for _ in range(dim)] for _ in range(nq)]
# 2. create patition # 2. create partition
partition_name = "search_partition_empty" partition_name = "search_partition_empty"
collection_w.create_partition(partition_name=partition_name, description="search partition empty") collection_w.create_partition(partition_name=partition_name, description="search partition empty")
par = collection_w.partitions par = collection_w.partitions
@ -3236,7 +3236,7 @@ class TestSearchBase(TestcaseBase):
partition_num=1, partition_num=1,
dim=dim, is_index=False)[0:5] dim=dim, is_index=False)[0:5]
vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] vectors = [[random.random() for _ in range(dim)] for _ in range(nq)]
# 2. create patition # 2. create partition
partition_name = "search_partition_empty" partition_name = "search_partition_empty"
collection_w.create_partition(partition_name=partition_name, description="search partition empty") collection_w.create_partition(partition_name=partition_name, description="search partition empty")
par = collection_w.partitions par = collection_w.partitions
@ -3279,7 +3279,7 @@ class TestSearchBase(TestcaseBase):
partition_num=1, partition_num=1,
dim=dim, is_index=False)[0:5] dim=dim, is_index=False)[0:5]
vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] vectors = [[random.random() for _ in range(dim)] for _ in range(nq)]
# 2. create patition # 2. create partition
par_name = collection_w.partitions[0].name par_name = collection_w.partitions[0].name
# collection_w.load() # collection_w.load()
# 3. create different index # 3. create different index
@ -3399,7 +3399,7 @@ class TestSearchDSL(TestcaseBase):
"limit": ct.default_top_k}) "limit": ct.default_top_k})
class TestsearchString(TestcaseBase): class TestsearchString(TestcaseBase):
""" """
****************************************************************** ******************************************************************
The following cases are used to test search about string The following cases are used to test search about string
@ -3455,7 +3455,6 @@ class TestsearchString(TestcaseBase):
"limit": default_limit, "limit": default_limit,
"_async": _async}) "_async": _async})
@pytest.mark.tags(CaseLabel.L2) @pytest.mark.tags(CaseLabel.L2)
def test_search_string_field_is_primary_true(self, dim, _async): def test_search_string_field_is_primary_true(self, dim, _async):
""" """
@ -3505,7 +3504,7 @@ class TestsearchString(TestcaseBase):
default_search_params, default_limit, default_search_params, default_limit,
default_search_mix_exp, default_search_mix_exp,
output_fields=output_fields, output_fields=output_fields,
_async=_async, _async=_async,
travel_timestamp=0, travel_timestamp=0,
check_task=CheckTasks.check_search_results, check_task=CheckTasks.check_search_results,
check_items={"nq": default_nq, check_items={"nq": default_nq,
@ -3522,7 +3521,6 @@ class TestsearchString(TestcaseBase):
collection search uses invalid string expr collection search uses invalid string expr
expected: Raise exception expected: Raise exception
""" """
# 1. initialize with data # 1. initialize with data
collection_w, _, _, insert_ids = \ collection_w, _, _, insert_ids = \
self.init_collection_general(prefix, True, auto_id=auto_id, dim=default_dim)[0:4] self.init_collection_general(prefix, True, auto_id=auto_id, dim=default_dim)[0:4]
@ -3537,8 +3535,6 @@ class TestsearchString(TestcaseBase):
"err_msg": "failed to create query plan: type mismatch"} "err_msg": "failed to create query plan: type mismatch"}
) )
@pytest.mark.tags(CaseLabel.L2) @pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("expression", cf.gen_normal_string_expressions(ct.default_string_field_name)) @pytest.mark.parametrize("expression", cf.gen_normal_string_expressions(ct.default_string_field_name))
def test_search_with_different_string_expr(self, dim, expression, _async): def test_search_with_different_string_expr(self, dim, expression, _async):
@ -3631,8 +3627,7 @@ class TestsearchString(TestcaseBase):
collection search uses string expr in string field, string field is not primary collection search uses string expr in string field, string field is not primary
expected: Search successfully expected: Search successfully
""" """
# 1. initialize with binary data # 1. initialize with binary data
collection_w, _, binary_raw_vector, insert_ids = self.init_collection_general(prefix, True, 2, collection_w, _, binary_raw_vector, insert_ids = self.init_collection_general(prefix, True, 2,
is_binary=True, is_binary=True,
auto_id=auto_id, auto_id=auto_id,
@ -3680,7 +3675,7 @@ class TestsearchString(TestcaseBase):
search_params, default_limit, search_params, default_limit,
default_search_mix_exp, default_search_mix_exp,
output_fields=output_fields, output_fields=output_fields,
_async=_async, _async=_async,
travel_timestamp=0, travel_timestamp=0,
check_task=CheckTasks.check_search_results, check_task=CheckTasks.check_search_results,
check_items={"nq": default_nq, check_items={"nq": default_nq,
@ -3786,7 +3781,7 @@ class TestsearchString(TestcaseBase):
search_string_exp = "varchar >= \"\"" search_string_exp = "varchar >= \"\""
limit =1 limit = 1
# 2. search # 2. search
log.info("test_search_string_field_is_primary_true: searching collection %s" % collection_w.name) log.info("test_search_string_field_is_primary_true: searching collection %s" % collection_w.name)