# milvus/tests/python_client/testcases/test_issues.py

from utils.util_pymilvus import *
from common.common_type import CaseLabel, CheckTasks
from common import common_type as ct
from common import common_func as cf
from utils.util_log import test_log as log
from base.client_base import TestcaseBase
import random
import pytest


class TestIssues(TestcaseBase):
    @pytest.mark.tags(CaseLabel.L0)
    @pytest.mark.parametrize("par_key_field", [ct.default_int64_field_name])
    @pytest.mark.parametrize("use_upsert", [True, False])
    def test_issue_30607(self, par_key_field, use_upsert):
        """
        Method
        1. create a collection with a partition key field on the schema and a customized num_partitions
        2. insert entities so that each partition key value maps to multiple rows
        3. randomly check 200 entities
        4. verify partition key values are hashed into correct partitions
        """
        self._connect()
        pk_field = cf.gen_string_field(name='pk', is_primary=True)
        int64_field = cf.gen_int64_field()
        string_field = cf.gen_string_field()
        vector_field = cf.gen_float_vec_field()
        schema = cf.gen_collection_schema(fields=[pk_field, int64_field, string_field, vector_field],
                                          auto_id=False, partition_key_field=par_key_field)
        c_name = cf.gen_unique_str("par_key")
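        # create the collection with a customized num_partitions (9) so partition-key hashing is exercised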
        collection_w = self.init_collection_wrap(name=c_name, schema=schema, num_partitions=9)
        # insert
        nb = 500
        string_prefix = cf.gen_str_by_length(length=6)
        entities_per_parkey = 20
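        # insert entities_per_parkey batches of nb rows each; the int64 partition key repeats the
        # same 0..nb-1 values in every batch, so each key value ends up with entities_per_parkey rows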
        for n in range(entities_per_parkey):
            pk_values = [str(i) for i in range(n * nb, (n + 1) * nb)]
            int64_values = [i for i in range(0, nb)]
            string_values = [string_prefix + str(i) for i in range(0, nb)]
            float_vec_values = gen_vectors(nb, ct.default_dim)
            data = [pk_values, int64_values, string_values, float_vec_values]
            if use_upsert:
                collection_w.upsert(data)
            else:
                collection_w.insert(data)
        # flush
        collection_w.flush()
        num_entities = collection_w.num_entities
        # build index
        collection_w.create_index(field_name=vector_field.name, index_params=ct.default_index)
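        # run the verification twice: first without, then with a scalar index on the partition key field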
        for index_on_par_key_field in [False, True]:
            collection_w.release()
            if index_on_par_key_field:
                collection_w.create_index(field_name=par_key_field, index_params={})
            # load
            collection_w.load()
            # verify the partition key values are hashed correctly
            seeds = 200
            rand_ids = random.sample(range(0, num_entities), seeds)
            rand_ids = [str(rand_ids[i]) for i in range(len(rand_ids))]
            res, _ = collection_w.query(expr=f"pk in {rand_ids}", output_fields=["pk", par_key_field])
            # verify every random id exists
            assert len(res) == len(rand_ids)
            dirty_count = 0
            for i in range(len(res)):
                pk = res[i].get("pk")
                parkey_value = res[i].get(par_key_field)
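                # query by partition key value and pk; with correct hashing the partition-pruned
                # query returns exactly one row, so 0 or >1 rows indicates dirty data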
                res_parkey, _ = collection_w.query(expr=f"{par_key_field}=={parkey_value} and pk=='{pk}'",
                                                   output_fields=["pk", par_key_field])
                if len(res_parkey) != 1:
                    log.info(f"dirty data found: pk {pk} with parkey {parkey_value}")
                    dirty_count += 1
            assert dirty_count == 0
            log.info(f"check randomly {seeds}/{num_entities}, dirty count={dirty_count}")