Update query cases for output vector field (#6633)

* fix check query result with vec fields

Signed-off-by: ThreadDao <yufen.zong@zilliz.com>

* update query cases for output vec field

Signed-off-by: ThreadDao <yufen.zong@zilliz.com>
This commit is contained in:
ThreadDao 2021-07-20 16:06:08 +08:00 committed by GitHub
parent 234954931f
commit 860ca4b40e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 185 additions and 117 deletions

View File

@ -202,8 +202,9 @@ class ResponseChecker:
if len(check_items) == 0: if len(check_items) == 0:
raise Exception("No expect values found in the check task") raise Exception("No expect values found in the check task")
exp_res = check_items.get("exp_res", None) exp_res = check_items.get("exp_res", None)
with_vec = check_items.get("with_vec", False)
if exp_res and isinstance(query_res, list): if exp_res and isinstance(query_res, list):
assert pc.equal_entities_list(exp=exp_res, actual=query_res) assert pc.equal_entities_list(exp=exp_res, actual=query_res, with_vec=with_vec)
# assert len(exp_res) == len(query_res) # assert len(exp_res) == len(query_res)
# for i in range(len(exp_res)): # for i in range(len(exp_res)):
# assert_entity_equal(exp=exp_res[i], actual=query_res[i]) # assert_entity_equal(exp=exp_res[i], actual=query_res[i])

View File

@ -151,10 +151,11 @@ def equal_entity(exp, actual):
for field, value in exp.items(): for field, value in exp.items():
if isinstance(value, list): if isinstance(value, list):
assert len(actual[field]) == len(exp[field]) assert len(actual[field]) == len(exp[field])
for i in range(len(exp[field])): for i in range(0, len(exp[field]), 2):
assert abs(actual[field][i] - exp[field][i]) < ct.epsilon assert abs(actual[field][i] - exp[field][i]) < ct.epsilon
else: else:
assert actual[field] == exp[field] assert actual[field] == exp[field]
return True
def entity_in(entity, entities, primary_field=ct.default_int64_field_name): def entity_in(entity, entities, primary_field=ct.default_int64_field_name):
@ -173,7 +174,7 @@ def entity_in(entity, entities, primary_field=ct.default_int64_field_name):
primary_keys.append(e[primary_field]) primary_keys.append(e[primary_field])
if primary_key not in primary_keys: if primary_key not in primary_keys:
return False return False
index = primary_key.index(primary_key) index = primary_keys.index(primary_key)
return equal_entity(entities[index], entity) return equal_entity(entities[index], entity)
@ -196,9 +197,10 @@ def remove_entity(entity, entities, primary_field=ct.default_int64_field_name):
return entities return entities
def equal_entities_list(exp, actual): def equal_entities_list(exp, actual, with_vec=False):
""" """
compare two entities lists in inconsistent order compare two entities lists in inconsistent order
:param with_vec: whether entities with vec field
:param exp: exp entities list, list of dict :param exp: exp entities list, list of dict
:param actual: actual entities list, list of dict :param actual: actual entities list, list of dict
:return: True or False :return: True or False
@ -209,14 +211,21 @@ def equal_entities_list(exp, actual):
""" """
if len(exp) != len(actual): if len(exp) != len(actual):
return False return False
if with_vec:
for a in actual: for a in actual:
# if vec field returned in query res # if vec field returned in query res
# if entity_in_entities(a, exp): if entity_in(a, exp):
try:
# if vec field returned in query res
remove_entity(a, exp)
except Exception as ex:
log.error(ex)
else:
for a in actual:
if a in exp: if a in exp:
try: try:
exp.remove(a) exp.remove(a)
# if vec field returned in query res
# remove_entity(a, exp)
except Exception as ex: except Exception as ex:
print(ex) log.error(ex)
return True if len(exp) == 0 else False return True if len(exp) == 0 else False

View File

@ -13,6 +13,8 @@ from utils.util_log import test_log as log
prefix = "query" prefix = "query"
exp_res = "exp_res" exp_res = "exp_res"
default_term_expr = f'{ct.default_int64_field_name} in [0, 1]' default_term_expr = f'{ct.default_int64_field_name} in [0, 1]'
default_index_params = {"index_type": "IVF_SQ8", "metric_type": "L2", "params": {"nlist": 64}}
binary_index_params = {"index_type": "BIN_IVF_FLAT", "metric_type": "JACCARD", "params": {"nlist": 64}}
class TestQueryBase(TestcaseBase): class TestQueryBase(TestcaseBase):
@ -52,20 +54,30 @@ class TestQueryBase(TestcaseBase):
@pytest.mark.tags(CaseLabel.L0) @pytest.mark.tags(CaseLabel.L0)
def test_query_auto_id_collection(self): def test_query_auto_id_collection(self):
""" """
target: test query on collection that primary field auto_id=True target: test query with auto_id=True collection
method: 1.create collection with auto_id=True 2.query on primary field method: test query with auto id
expected: verify primary field values of query result expected: query result is correct
""" """
schema = cf.gen_default_collection_schema(auto_id=True) self._connect()
collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix), schema=schema)
df = cf.gen_default_dataframe_data(ct.default_nb) df = cf.gen_default_dataframe_data(ct.default_nb)
df.drop(ct.default_int64_field_name, axis=1, inplace=True) df[ct.default_int64_field_name] = None
mutation_res, _ = collection_w.insert(data=df) res, _, = self.collection_wrap.construct_from_dataframe(cf.gen_unique_str(prefix), df,
assert collection_w.num_entities == ct.default_nb primary_field=ct.default_int64_field_name, auto_id=True)
collection_w.load() assert self.collection_wrap.num_entities == ct.default_nb
term_expr = f'{ct.default_int64_field_name} in [{mutation_res.primary_keys[0]}]' ids = res[1].primary_keys
res, _ = collection_w.query(term_expr) res = df.iloc[:2, :2].to_dict('records')
assert res[0][ct.default_int64_field_name] == mutation_res.primary_keys[0] self.collection_wrap.load()
# query with all primary keys
term_expr_1 = f'{ct.default_int64_field_name} in {ids[:2]}'
for i in range(2):
res[i][ct.default_int64_field_name] = ids[i]
self.collection_wrap.query(term_expr_1, check_task=CheckTasks.check_query_results, check_items={exp_res: res})
# query with part primary keys
term_expr_2 = f'{ct.default_int64_field_name} in {[ids[0], 0]}'
self.collection_wrap.query(term_expr_2, check_task=CheckTasks.check_query_results,
check_items={exp_res: res[:1]})
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
def test_query_auto_id_not_existed_primary_key(self): def test_query_auto_id_not_existed_primary_key(self):
@ -97,20 +109,20 @@ class TestQueryBase(TestcaseBase):
collection_w.query(None, check_task=CheckTasks.err_res, check_items=error) collection_w.query(None, check_task=CheckTasks.err_res, check_items=error)
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@pytest.mark.parametrize("expr", [1, 2., [], {}, ()]) def test_query_expr_non_string(self):
def test_query_expr_non_string(self, expr):
""" """
target: test query with non-string expr target: test query with non-string expr
method: query with non-string expr, eg 1, [] .. method: query with non-string expr, eg 1, [] ..
expected: raise exception expected: raise exception
""" """
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
exprs = [1, 2., [], {}, ()]
error = {ct.err_code: 0, ct.err_msg: "The type of expr must be string"} error = {ct.err_code: 0, ct.err_msg: "The type of expr must be string"}
for expr in exprs:
collection_w.query(expr, check_task=CheckTasks.err_res, check_items=error) collection_w.query(expr, check_task=CheckTasks.err_res, check_items=error)
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@pytest.mark.parametrize("expr", ["12-s", "中文", "a", " "]) def test_query_expr_invalid_string(self):
def test_query_expr_invalid_string(self, expr):
""" """
target: test query with invalid expr target: test query with invalid expr
method: query with invalid string expr method: query with invalid string expr
@ -118,10 +130,12 @@ class TestQueryBase(TestcaseBase):
""" """
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
error = {ct.err_code: 1, ct.err_msg: "Invalid expression!"} error = {ct.err_code: 1, ct.err_msg: "Invalid expression!"}
exprs = ["12-s", "中文", "a", " "]
for expr in exprs:
collection_w.query(expr, check_task=CheckTasks.err_res, check_items=error) collection_w.query(expr, check_task=CheckTasks.err_res, check_items=error)
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
def test_query_expr_term(self): def _test_query_expr_term(self):
""" """
target: test query with TermExpr target: test query with TermExpr
method: query with TermExpr method: query with TermExpr
@ -132,33 +146,30 @@ class TestQueryBase(TestcaseBase):
collection_w.query(default_term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) collection_w.query(default_term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res})
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6259")
def test_query_expr_not_existed_field(self): def test_query_expr_not_existed_field(self):
""" """
target: test query with not existed field target: test query with not existed field
method: query by term expr with fake field method: query by term expr with fake field
expected: raise exception expected: raise exception
""" """
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] collection_w = self.init_collection_wrap(cf.gen_unique_str(prefix))
term_expr = 'field in [1, 2]' term_expr = 'field in [1, 2]'
error = {ct.err_code: 1, ct.err_msg: "fieldName(field) not found"} error = {ct.err_code: 1, ct.err_msg: "fieldName(field) not found"}
collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error) collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error)
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6259")
def test_query_expr_unsupported_field(self): def test_query_expr_unsupported_field(self):
""" """
target: test query on unsupported field target: test query on unsupported field
method: query on float field method: query on float field
expected: raise exception expected: raise exception
""" """
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] collection_w = self.init_collection_wrap(cf.gen_unique_str(prefix))
term_expr = f'{ct.default_float_field_name} in [1., 2.]' term_expr = f'{ct.default_float_field_name} in [1., 2.]'
error = {ct.err_code: 1, ct.err_msg: "column is not int64"} error = {ct.err_code: 1, ct.err_msg: "column is not int64"}
collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error) collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error)
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6259")
def test_query_expr_non_primary_field(self): def test_query_expr_non_primary_field(self):
""" """
target: test query on non-primary field target: test query on non-primary field
@ -177,7 +188,6 @@ class TestQueryBase(TestcaseBase):
collection_w.query(default_term_expr, check_task=CheckTasks.err_res, check_items=error) collection_w.query(default_term_expr, check_task=CheckTasks.err_res, check_items=error)
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6259")
def test_query_expr_wrong_term_keyword(self): def test_query_expr_wrong_term_keyword(self):
""" """
target: test query with wrong term expr keyword target: test query with wrong term expr keyword
@ -198,18 +208,18 @@ class TestQueryBase(TestcaseBase):
collection_w.query(expr_3, check_task=CheckTasks.err_res, check_items=error_3) collection_w.query(expr_3, check_task=CheckTasks.err_res, check_items=error_3)
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6259") def test_query_expr_non_array_term(self):
@pytest.mark.parametrize("expr", [f'{ct.default_int64_field_name} in 1',
f'{ct.default_int64_field_name} in "in"',
f'{ct.default_int64_field_name} in (mn)'])
def test_query_expr_non_array_term(self, expr):
""" """
target: test query with non-array term expr target: test query with non-array term expr
method: query with non-array term expr method: query with non-array term expr
expected: raise exception expected: raise exception
""" """
exprs = [f'{ct.default_int64_field_name} in 1',
f'{ct.default_int64_field_name} in "in"',
f'{ct.default_int64_field_name} in (mn)']
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
error = {ct.err_code: 1, ct.err_msg: "right operand of the InExpr must be array"} error = {ct.err_code: 1, ct.err_msg: "right operand of the InExpr must be array"}
for expr in exprs:
collection_w.query(expr, check_task=CheckTasks.err_res, check_items=error) collection_w.query(expr, check_task=CheckTasks.err_res, check_items=error)
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@ -225,46 +235,32 @@ class TestQueryBase(TestcaseBase):
assert len(res) == 0 assert len(res) == 0
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6259") def test_query_expr_inconsistent_mix_term_array(self):
def test_query_expr_inconstant_term_array(self):
""" """
target: test query with term expr that field and array are inconsistent target: test query with term expr that field and array are inconsistent or mix type
method: query with int field and float values method: 1.query with int field and float values
2.query with term expr that has int and float type value
expected: raise exception expected: raise exception
""" """
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] collection_w = self.init_collection_wrap(cf.gen_unique_str(prefix))
int_values = [1., 2.] int_values = [[1., 2.], [1, 2.]]
term_expr = f'{ct.default_int64_field_name} in {int_values}'
error = {ct.err_code: 1, ct.err_msg: "type mismatch"} error = {ct.err_code: 1, ct.err_msg: "type mismatch"}
for values in int_values:
term_expr = f'{ct.default_int64_field_name} in {values}'
collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error) collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error)
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6259") def test_query_expr_non_constant_array_term(self):
def test_query_expr_mix_term_array(self):
"""
target: test query with mix type value expr
method: query with term expr that has int and float type value
expected: raise exception
"""
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
int_values = [1, 2.]
term_expr = f'{ct.default_int64_field_name} in {int_values}'
error = {ct.err_code: 1, ct.err_msg: "type mismatch"}
collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error)
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6259")
@pytest.mark.parametrize("constant", [[1], (), {}])
def test_query_expr_non_constant_array_term(self, constant):
""" """
target: test query with non-constant array term expr target: test query with non-constant array term expr
method: query with non-constant array expr method: query with non-constant array expr
expected: raise exception expected: raise exception
""" """
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
term_expr = f'{ct.default_int64_field_name} in [{constant}]' constants = [[1], (), {}]
log.debug(term_expr)
error = {ct.err_code: 1, ct.err_msg: "unsupported leaf node"} error = {ct.err_code: 1, ct.err_msg: "unsupported leaf node"}
for constant in constants:
term_expr = f'{ct.default_int64_field_name} in [{constant}]'
collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error) collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error)
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@ -287,8 +283,8 @@ class TestQueryBase(TestcaseBase):
expected: return one field expected: return one field
""" """
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_int64_field_name]) res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_float_field_name])
assert set(res[0].keys()) == set([ct.default_int64_field_name]) assert set(res[0].keys()) == set([ct.default_int64_field_name, ct.default_float_field_name])
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
def test_query_output_all_fields(self): def test_query_output_all_fields(self):
@ -297,27 +293,52 @@ class TestQueryBase(TestcaseBase):
method: query with output field=None method: query with output field=None
expected: return all fields expected: return all fields
""" """
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
fields = [ct.default_int64_field_name, ct.default_float_field_name] df = cf.gen_default_dataframe_data()
res, _ = collection_w.query(default_term_expr, output_fields=fields) collection_w.insert(df)
assert set(res[0].keys()) == set(fields) assert collection_w.num_entities == ct.default_nb
res_1, _ = collection_w.query(default_term_expr, output_fields=[ct.default_float_field_name]) all_fields = [ct.default_int64_field_name, ct.default_float_field_name, ct.default_float_vec_field_name]
assert set(res_1[0].keys()) == set(fields) res = df.iloc[:2].to_dict('records')
collection_w.load()
actual_res, _ = collection_w.query(default_term_expr, output_fields=all_fields,
check_task=CheckTasks.check_query_results,
check_items={exp_res: res, "with_vec": True})
assert set(actual_res[0].keys()) == set(all_fields)
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6299")
def test_query_output_vec_field(self): def test_query_output_vec_field(self):
""" """
target: test query with vec output field target: test query with vec output field
method: specify vec field as output field method: specify vec field as output field
expected: raise exception expected: return primary field and vec field
""" """
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
df = cf.gen_default_dataframe_data()
collection_w.insert(df)
assert collection_w.num_entities == ct.default_nb
fields = [[ct.default_float_vec_field_name], [ct.default_int64_field_name, ct.default_float_vec_field_name]] fields = [[ct.default_float_vec_field_name], [ct.default_int64_field_name, ct.default_float_vec_field_name]]
error = {ct.err_code: 1, ct.err_msg: "Query does not support vector field currently"} res = df.loc[:1, [ct.default_int64_field_name, ct.default_float_vec_field_name]].to_dict('records')
collection_w.load()
for output_fields in fields: for output_fields in fields:
collection_w.query(default_term_expr, output_fields=output_fields, collection_w.query(default_term_expr, output_fields=output_fields,
check_task=CheckTasks.err_res, check_items=error) check_task=CheckTasks.check_query_results,
check_items={exp_res: res, "with_vec": True})
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6594")
# todo
def test_query_output_binary_vec_field(self):
"""
target: test query with binary vec output field
method: specify binary vec field as output field
expected: return primary field and binary vec field
"""
collection_w, vectors = self.init_collection_general(prefix, insert_data=True, is_binary=True)[0:2]
log.debug(collection_w.schema)
fields = [[ct.default_binary_vec_field_name], [ct.default_int64_field_name, ct.default_binary_vec_field_name]]
for output_fields in fields:
res, _ = collection_w.query(default_term_expr, output_fields=output_fields)
assert list(res[0].keys()) == fields[-1]
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
def test_query_output_primary_field(self): def test_query_output_primary_field(self):
@ -331,9 +352,7 @@ class TestQueryBase(TestcaseBase):
assert list(res[0].keys()) == [ct.default_int64_field_name] assert list(res[0].keys()) == [ct.default_int64_field_name]
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@pytest.mark.parametrize("output_fields", [["int"], def test_query_output_not_existed_field(self):
[ct.default_int64_field_name, "int"]])
def test_query_output_not_existed_field(self, output_fields):
""" """
target: test query output not existed field target: test query output not existed field
method: query with not existed output field method: query with not existed output field
@ -341,7 +360,10 @@ class TestQueryBase(TestcaseBase):
""" """
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
error = {ct.err_code: 1, ct.err_msg: 'Field int not exist'} error = {ct.err_code: 1, ct.err_msg: 'Field int not exist'}
collection_w.query(default_term_expr, output_fields=output_fields, check_task=CheckTasks.err_res, check_items=error) output_fields = [["int"], [ct.default_int64_field_name, "int"]]
for fields in output_fields:
collection_w.query(default_term_expr, output_fields=fields, check_task=CheckTasks.err_res,
check_items=error)
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
def test_query_empty_output_fields(self): def test_query_empty_output_fields(self):
@ -355,18 +377,20 @@ class TestQueryBase(TestcaseBase):
fields = [ct.default_int64_field_name, ct.default_float_field_name] fields = [ct.default_int64_field_name, ct.default_float_field_name]
assert list(query_res[0].keys()) == fields assert list(query_res[0].keys()) == fields
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L2)
@pytest.mark.xfail(reason="exception not MilvusException") @pytest.mark.xfail(reason="exception not MilvusException")
@pytest.mark.parametrize("output_fields", ["12-s", 1, [1, "2", 3], (1,), {1: 1}]) def test_query_invalid_output_fields(self):
def test_query_invalid_output_fields(self, output_fields):
""" """
target: test query with invalid output fields target: test query with invalid output fields
method: query with invalid field fields method: query with invalid field fields
expected: raise exception expected: raise exception
""" """
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
output_fields = ["12-s", 1, [1, "2", 3], (1,), {1: 1}]
error = {ct.err_code: 0, ct.err_msg: f'Invalid query format. \'output_fields\' must be a list'} error = {ct.err_code: 0, ct.err_msg: f'Invalid query format. \'output_fields\' must be a list'}
collection_w.query(default_term_expr, output_fields=output_fields, check_task=CheckTasks.err_res, check_items=error) for fields in output_fields:
collection_w.query(default_term_expr, output_fields=fields, check_task=CheckTasks.err_res,
check_items=error)
@pytest.mark.tags(CaseLabel.L0) @pytest.mark.tags(CaseLabel.L0)
def test_query_partition(self): def test_query_partition(self):
@ -568,7 +592,7 @@ class TestQueryOperation(TestcaseBase):
""" """
target: test query with repeated term array on primary field with unique value target: test query with repeated term array on primary field with unique value
method: query with repeated array value method: query with repeated array value
expected: verify query result expected: todo
""" """
collection_w, vectors, binary_raw_vectors = self.init_collection_general(prefix, insert_data=True)[0:3] collection_w, vectors, binary_raw_vectors = self.init_collection_general(prefix, insert_data=True)[0:3]
int_values = [0, 0, 0, 0] int_values = [0, 0, 0, 0]
@ -577,6 +601,25 @@ class TestQueryOperation(TestcaseBase):
assert len(res) == 1 assert len(res) == 1
assert res[0][ct.default_int64_field_name] == int_values[0] assert res[0][ct.default_int64_field_name] == int_values[0]
@pytest.mark.tags(ct.CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6624")
def test_query_dup_ids_dup_term_array(self):
"""
target: test query on duplicate primary keys with dup term array
method: 1.create collection and insert dup primary keys
2.query with dup term array
expected: todo
"""
collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
df = cf.gen_default_dataframe_data(nb=ct.default_nb)
df[ct.default_int64_field_name] = 0
mutation_res, _ = collection_w.insert(df)
assert mutation_res.primary_keys == df[ct.default_int64_field_name].tolist()
collection_w.load()
term_expr = f'{ct.default_int64_field_name} in {[0, 0, 0]}'
res, _ = collection_w.query(term_expr)
log.debug(res)
@pytest.mark.tags(ct.CaseLabel.L0) @pytest.mark.tags(ct.CaseLabel.L0)
def test_query_after_index(self): def test_query_after_index(self):
""" """
@ -587,9 +630,7 @@ class TestQueryOperation(TestcaseBase):
collection_w, vectors, binary_raw_vectors = self.init_collection_general(prefix, insert_data=True)[0:3] collection_w, vectors, binary_raw_vectors = self.init_collection_general(prefix, insert_data=True)[0:3]
default_field_name = ct.default_float_vec_field_name default_field_name = ct.default_float_vec_field_name
default_index_params = {"index_type": "IVF_SQ8", "metric_type": "L2", "params": {"nlist": 64}} collection_w.create_index(default_field_name, default_index_params)
index_name = ct.default_index_name
collection_w.create_index(default_field_name, default_index_params, index_name=index_name)
collection_w.load() collection_w.load()
@ -625,6 +666,42 @@ class TestQueryOperation(TestcaseBase):
check_vec = vectors[0].iloc[:, [0, 1]][0:2].to_dict('records') check_vec = vectors[0].iloc[:, [0, 1]][0:2].to_dict('records')
collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: check_vec}) collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: check_vec})
@pytest.mark.tags(ct.CaseLabel.L1)
def test_query_output_vec_field_after_index(self):
"""
target: test query output vec field after index
method: create index and specify vec field as output field
expected: return primary field and vec field
"""
collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
df = cf.gen_default_dataframe_data(nb=5000)
collection_w.insert(df)
assert collection_w.num_entities == 5000
fields = [ct.default_int64_field_name, ct.default_float_vec_field_name]
collection_w.create_index(ct.default_float_vec_field_name, default_index_params)
assert collection_w.has_index()[0]
res = df.loc[:1, [ct.default_int64_field_name, ct.default_float_vec_field_name]].to_dict('records')
collection_w.load()
collection_w.query(default_term_expr, output_fields=fields,
check_task=CheckTasks.check_query_results,
check_items={exp_res: res, "with_vec": True})
@pytest.mark.tags(ct.CaseLabel.L2)
@pytest.mark.xfail(reason="issue #6594")
# todo
def test_query_output_binary_vec_field_after_index(self):
"""
target: test query output vec field after index
method: create index and specify vec field as output field
expected: return primary field and vec field
"""
collection_w, vectors = self.init_collection_general(prefix, insert_data=True, is_binary=True)[0:2]
fields = [ct.default_int64_field_name, ct.default_binary_vec_field_name]
collection_w.create_index(ct.default_binary_vec_field_name, binary_index_params)
assert collection_w.has_index()[0]
res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_binary_vec_field_name])
assert list(res[0].keys()) == fields
@pytest.mark.tags(ct.CaseLabel.L2) @pytest.mark.tags(ct.CaseLabel.L2)
def test_query_partition_repeatedly(self): def test_query_partition_repeatedly(self):
""" """
@ -705,22 +782,3 @@ class TestQueryOperation(TestcaseBase):
res, _ = collection_w.query(term_expr, partition_names=[ct.default_partition_name, partition_w.name]) res, _ = collection_w.query(term_expr, partition_names=[ct.default_partition_name, partition_w.name])
assert len(res) == 1 assert len(res) == 1
assert res[0][ct.default_int64_field_name] == half assert res[0][ct.default_int64_field_name] == half
# def insert_entities_into_two_partitions_in_half(self, half):
# """
# insert default entities into two partitions(partition_w and _default) in half(int64 and float fields values)
# :param half: half of nb
# :return: collection wrap and partition wrap
# """
# conn = self._connect()
# collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
# partition_w = self.init_partition_wrap(collection_wrap=collection_w)
# # insert [0, half) into partition_w
# df_partition = cf.gen_default_dataframe_data(nb=half, start=0)
# partition_w.insert(df_partition)
# # insert [half, nb) into _default
# df_default = cf.gen_default_dataframe_data(nb=half, start=half)
# collection_w.insert(df_default)
# conn.flush([collection_w.name])
# collection_w.load()
# return collection_w, partition_w, df_partition, df_default