diff --git a/tests/python_client/base/client_base.py b/tests/python_client/base/client_base.py index 512bdba73b..27a28b00e6 100644 --- a/tests/python_client/base/client_base.py +++ b/tests/python_client/base/client_base.py @@ -15,21 +15,6 @@ from common import common_func as cf from common import common_type as ct -class ParamInfo: - def __init__(self): - self.param_host = "" - self.param_port = "" - self.param_handler = "" - - def prepare_param_info(self, host, port, handler): - self.param_host = host - self.param_port = port - self.param_handler = handler - - -param_info = ParamInfo() - - class Base: """ Initialize class object """ connection_wrap = None @@ -65,8 +50,8 @@ class Base: try: """ Drop collection before disconnect """ if not self.connection_wrap.has_connection(alias=DefaultConfig.DEFAULT_USING)[0]: - self.connection_wrap.connect(alias=DefaultConfig.DEFAULT_USING, host=param_info.param_host, - port=param_info.param_port) + self.connection_wrap.connect(alias=DefaultConfig.DEFAULT_USING, host=cf.param_info.param_host, + port=cf.param_info.param_port) if self.collection_wrap.collection is not None: self.collection_wrap.drop(check_task=ct.CheckTasks.check_nothing) @@ -100,8 +85,8 @@ class TestcaseBase(Base): def _connect(self): """ Add a connection and create the connect """ - res, is_succ = self.connection_wrap.connect(alias=DefaultConfig.DEFAULT_USING, host=param_info.param_host, - port=param_info.param_port) + res, is_succ = self.connection_wrap.connect(alias=DefaultConfig.DEFAULT_USING, host=cf.param_info.param_host, + port=cf.param_info.param_port) return res def init_collection_wrap(self, name=None, schema=None, shards_num=2, check_task=None, check_items=None, **kwargs): diff --git a/tests/python_client/base/collection_wrapper.py b/tests/python_client/base/collection_wrapper.py index b92fd43e03..e47697c3d0 100644 --- a/tests/python_client/base/collection_wrapper.py +++ b/tests/python_client/base/collection_wrapper.py @@ -1,6 +1,7 @@ import sys import time import timeout_decorator +from numpy import NaN from pymilvus import Collection @@ -10,6 +11,7 @@ from utils.api_request import api_request from utils.wrapper import trace from utils.util_log import test_log as log from pymilvus.orm.types import CONSISTENCY_STRONG +from common.common_func import param_info TIMEOUT = 20 @@ -85,9 +87,10 @@ class ApiCollectionWrapper: return res, check_result @trace() - def load(self, partition_names=None, replica_number=1, timeout=None, check_task=None, check_items=None, **kwargs): + def load(self, partition_names=None, replica_number=NaN, timeout=None, check_task=None, check_items=None, **kwargs): timeout = TIMEOUT if timeout is None else timeout - + replica_number = param_info.param_replica_num if replica_number is NaN else replica_number + func_name = sys._getframe().f_code.co_name res, check = api_request([self.collection.load, partition_names, replica_number, timeout], **kwargs) check_result = ResponseChecker(res, func_name, check_task, check_items, check, diff --git a/tests/python_client/base/partition_wrapper.py b/tests/python_client/base/partition_wrapper.py index 7f3d4ab6ff..5f58d1415d 100644 --- a/tests/python_client/base/partition_wrapper.py +++ b/tests/python_client/base/partition_wrapper.py @@ -1,10 +1,12 @@ import sys +from numpy import NaN from pymilvus import Partition sys.path.append("..") from check.func_check import ResponseChecker from utils.api_request import api_request +from common.common_func import param_info TIMEOUT = 20 @@ -49,8 +51,9 @@ class ApiPartitionWrapper: check_task, check_items, succ, **kwargs).run() return res, check_result - def load(self, replica_number=1, timeout=None, check_task=None, check_items=None, **kwargs): + def load(self, replica_number=NaN, timeout=None, check_task=None, check_items=None, **kwargs): timeout = TIMEOUT if timeout is None else timeout + replica_number = param_info.param_replica_num if replica_number is NaN else replica_number func_name = sys._getframe().f_code.co_name res, succ = api_request([self.partition.load, replica_number, timeout], **kwargs) diff --git a/tests/python_client/common/common_func.py b/tests/python_client/common/common_func.py index c42d24b3bd..f8609fae4f 100644 --- a/tests/python_client/common/common_func.py +++ b/tests/python_client/common/common_func.py @@ -13,6 +13,22 @@ from utils.util_log import test_log as log """" Methods of processing data """ +class ParamInfo: + def __init__(self): + self.param_host = "" + self.param_port = "" + self.param_handler = "" + self.param_replica_num = ct.default_replica_num + + def prepare_param_info(self, host, port, handler, replica_num): + self.param_host = host + self.param_port = port + self.param_handler = handler + self.param_replica_num = replica_num + + +param_info = ParamInfo() + def gen_unique_str(str_value=None): prefix = "".join(random.choice(string.ascii_letters + string.digits) for _ in range(8)) diff --git a/tests/python_client/common/common_type.py b/tests/python_client/common/common_type.py index cacde89e24..6f9a874702 100644 --- a/tests/python_client/common/common_type.py +++ b/tests/python_client/common/common_type.py @@ -50,6 +50,7 @@ compact_delta_ratio_reciprocal = 5 # compact_delta_binlog_ratio is 0.2 compact_retention_duration = 40 # compaction travel time retention range 20s max_compaction_interval = 60 # the max time interval (s) from the last compaction max_field_num = 256 # Maximum number of fields in a collection +default_replica_num = 1 # default memory replica number Not_Exist = "Not_Exist" Connect_Object_Name = True diff --git a/tests/python_client/conftest.py b/tests/python_client/conftest.py index ece41d343f..f9de799178 100644 --- a/tests/python_client/conftest.py +++ b/tests/python_client/conftest.py @@ -7,7 +7,7 @@ import socket import common.common_type as ct import common.common_func as cf from utils.util_log import test_log as log -from base.client_base import param_info +from common.common_func import param_info from check.param_check import ip_check, number_check from config.log_config import log_config from utils.util_pymilvus import get_milvus, gen_unique_str, gen_default_fields, gen_binary_default_fields @@ -39,6 +39,7 @@ def pytest_addoption(parser): parser.addoption('--term_expr', action='store', default="term_expr", help="expr of query quest") parser.addoption('--check_content', action='store', default="check_content", help="content of check") parser.addoption('--field_name', action='store', default="field_name", help="field_name of index") + parser.addoption('--replica_num', type='int', action='store', default=ct.default_replica_num, help="memory replica number") @pytest.fixture @@ -153,6 +154,7 @@ def initialize_env(request): port = request.config.getoption("--port") handler = request.config.getoption("--handler") clean_log = request.config.getoption("--clean_log") + replica_num = request.config.getoption("--replica_num") """ params check """ assert ip_check(host) and number_check(port) @@ -165,7 +167,7 @@ def initialize_env(request): log.info("#" * 80) log.info("[initialize_milvus] Log cleaned up, start testing...") - param_info.prepare_param_info(host, port, handler) + param_info.prepare_param_info(host, port, handler, replica_num) @pytest.fixture(params=ct.get_invalid_strs) diff --git a/tests/python_client/testcases/test_collection.py b/tests/python_client/testcases/test_collection.py index 56550c8b15..97338a2880 100644 --- a/tests/python_client/testcases/test_collection.py +++ b/tests/python_client/testcases/test_collection.py @@ -2195,7 +2195,7 @@ class TestLoadCollection(TestcaseBase): """ target: test load partition with invalid replica number method: load with invalid replica number - expected: raise exception + expected: load successfully as replica = 1 """ # create, insert collection_w = self.init_collection_wrap(cf.gen_unique_str(prefix)) diff --git a/tests/python_client/testcases/test_compaction.py b/tests/python_client/testcases/test_compaction.py index e8fe3887f6..459bfa44b4 100644 --- a/tests/python_client/testcases/test_compaction.py +++ b/tests/python_client/testcases/test_compaction.py @@ -321,8 +321,10 @@ class TestCompactionParams(TestcaseBase): # verify queryNode load the compacted segments collection_w.load() + replicas = collection_w.get_replicas()[0] + replica_num = len(replicas.groups) segment_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0] - assert len(segment_info) == 1 + assert len(segment_info) == 1*replica_num @pytest.mark.skip(reason="TODO") @pytest.mark.tags(CaseLabel.L2) @@ -763,8 +765,10 @@ class TestCompactionOperation(TestcaseBase): collection_w.get_compaction_plans(check_task=CheckTasks.check_merge_compact, check_items={"segment_num": 2}) collection_w.load() + replicas = collection_w.get_replicas()[0] + replica_num = len(replicas.groups) segments_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0] - assert len(segments_info) == 1 + assert len(segments_info) == 1*replica_num @pytest.mark.tags(CaseLabel.L1) def test_compact_merge_multi_segments(self): @@ -790,8 +794,10 @@ class TestCompactionOperation(TestcaseBase): target = c_plans.plans[0].target collection_w.load() + replicas = collection_w.get_replicas()[0] + replica_num = len(replicas.groups) segments_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0] - assert len(segments_info) == 1 + assert len(segments_info) == 1*replica_num assert segments_info[0].segmentID == target @pytest.mark.tags(CaseLabel.L2) @@ -845,13 +851,15 @@ class TestCompactionOperation(TestcaseBase): # Estimated auto-merging takes 30s cost = 60 collection_w.load() + replicas = collection_w.get_replicas()[0] + replica_num = len(replicas.groups) start = time() while True: sleep(5) segments_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0] # verify segments reaches threshold, auto-merge ten segments into one - if len(segments_info) == 1: + if len(segments_info) == 1*replica_num: break end = time() if end - start > cost: @@ -874,8 +882,10 @@ class TestCompactionOperation(TestcaseBase): # load and verify no auto-merge collection_w.load() + replicas = collection_w.get_replicas()[0] + replica_num = len(replicas.groups) segments_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0] - assert len(segments_info) == less_threshold + assert len(segments_info) == less_threshold*replica_num @pytest.mark.skip(reason="Todo") @pytest.mark.tags(CaseLabel.L2) @@ -1042,14 +1052,16 @@ class TestCompactionOperation(TestcaseBase): t.join() collection_w.load() + replicas = collection_w.get_replicas()[0] + replica_num = len(replicas.groups) seg_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0] - assert len(seg_info) == 2 + assert len(seg_info) == 2*replica_num @pytest.mark.tags(CaseLabel.L2) def test_compact_during_index(self): """ target: test compact during index - method: while compact collection start a thread to creat index + method: while compact collection start a thread to create index expected: No exception """ collection_w = self.collection_insert_multi_segments_one_shard(prefix, nb_of_segment=ct.default_nb, @@ -1068,8 +1080,10 @@ class TestCompactionOperation(TestcaseBase): t.join() collection_w.load() + replicas = collection_w.get_replicas()[0] + replica_num = len(replicas.groups) seg_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0] - assert len(seg_info) == 1 + assert len(seg_info) == 1*replica_num def test_compact_during_search(self): """