mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-04 04:49:08 +08:00
Dedup output fields for task query (#18673)
Signed-off-by: Congqi Xia <congqi.xia@zilliz.com> Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
parent
8387e3e0e2
commit
d4c54d96b0
@ -7,6 +7,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus/internal/common"
|
||||||
"github.com/milvus-io/milvus/internal/parser/planparserv2"
|
"github.com/milvus-io/milvus/internal/parser/planparserv2"
|
||||||
|
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
@ -53,35 +54,47 @@ type queryTask struct {
|
|||||||
|
|
||||||
// translateToOutputFieldIDs translates output field names to output field IDs.
|
// translateToOutputFieldIDs translates output field names to output field IDs.
|
||||||
func translateToOutputFieldIDs(outputFields []string, schema *schemapb.CollectionSchema) ([]UniqueID, error) {
|
func translateToOutputFieldIDs(outputFields []string, schema *schemapb.CollectionSchema) ([]UniqueID, error) {
|
||||||
outputFieldIDs := make([]UniqueID, 0, len(outputFields))
|
outputFieldIDs := make([]UniqueID, 0, len(outputFields)+1)
|
||||||
if len(outputFields) == 0 {
|
if len(outputFields) == 0 {
|
||||||
for _, field := range schema.Fields {
|
for _, field := range schema.Fields {
|
||||||
if field.FieldID >= 100 && field.DataType != schemapb.DataType_FloatVector && field.DataType != schemapb.DataType_BinaryVector {
|
if field.FieldID >= common.StartOfUserFieldID && field.DataType != schemapb.DataType_FloatVector && field.DataType != schemapb.DataType_BinaryVector {
|
||||||
outputFieldIDs = append(outputFieldIDs, field.FieldID)
|
outputFieldIDs = append(outputFieldIDs, field.FieldID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
addPrimaryKey := false
|
var pkFieldID UniqueID
|
||||||
|
for _, field := range schema.Fields {
|
||||||
|
if field.IsPrimaryKey {
|
||||||
|
pkFieldID = field.FieldID
|
||||||
|
}
|
||||||
|
}
|
||||||
for _, reqField := range outputFields {
|
for _, reqField := range outputFields {
|
||||||
findField := false
|
var fieldFound bool
|
||||||
for _, field := range schema.Fields {
|
for _, field := range schema.Fields {
|
||||||
if reqField == field.Name {
|
if reqField == field.Name {
|
||||||
if field.IsPrimaryKey {
|
|
||||||
addPrimaryKey = true
|
|
||||||
}
|
|
||||||
findField = true
|
|
||||||
outputFieldIDs = append(outputFieldIDs, field.FieldID)
|
outputFieldIDs = append(outputFieldIDs, field.FieldID)
|
||||||
} else {
|
fieldFound = true
|
||||||
if field.IsPrimaryKey && !addPrimaryKey {
|
break
|
||||||
outputFieldIDs = append(outputFieldIDs, field.FieldID)
|
|
||||||
addPrimaryKey = true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !findField {
|
if !fieldFound {
|
||||||
return nil, fmt.Errorf("field %s not exist", reqField)
|
return nil, fmt.Errorf("field %s not exist", reqField)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// pk field needs to be in output field list
|
||||||
|
var pkFound bool
|
||||||
|
for _, outputField := range outputFieldIDs {
|
||||||
|
if outputField == pkFieldID {
|
||||||
|
pkFound = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !pkFound {
|
||||||
|
outputFieldIDs = append(outputFieldIDs, pkFieldID)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
return outputFieldIDs, nil
|
return outputFieldIDs, nil
|
||||||
}
|
}
|
||||||
|
@ -196,3 +196,145 @@ func TestQueryTask_all(t *testing.T) {
|
|||||||
|
|
||||||
assert.NoError(t, task.PostExecute(ctx))
|
assert.NoError(t, task.PostExecute(ctx))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_translateToOutputFieldIDs(t *testing.T) {
|
||||||
|
type testCases struct {
|
||||||
|
name string
|
||||||
|
outputFields []string
|
||||||
|
schema *schemapb.CollectionSchema
|
||||||
|
expectedError bool
|
||||||
|
expectedIDs []int64
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []testCases{
|
||||||
|
{
|
||||||
|
name: "empty output fields",
|
||||||
|
outputFields: []string{},
|
||||||
|
schema: &schemapb.CollectionSchema{
|
||||||
|
Fields: []*schemapb.FieldSchema{
|
||||||
|
{
|
||||||
|
FieldID: common.RowIDField,
|
||||||
|
Name: common.RowIDFieldName,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
FieldID: 100,
|
||||||
|
Name: "ID",
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
FieldID: 101,
|
||||||
|
Name: "Vector",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedError: false,
|
||||||
|
expectedIDs: []int64{100, 101},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil output fields",
|
||||||
|
outputFields: nil,
|
||||||
|
schema: &schemapb.CollectionSchema{
|
||||||
|
Fields: []*schemapb.FieldSchema{
|
||||||
|
{
|
||||||
|
FieldID: common.RowIDField,
|
||||||
|
Name: common.RowIDFieldName,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
FieldID: 100,
|
||||||
|
Name: "ID",
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
FieldID: 101,
|
||||||
|
Name: "Vector",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedError: false,
|
||||||
|
expectedIDs: []int64{100, 101},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "full list",
|
||||||
|
outputFields: []string{"ID", "Vector"},
|
||||||
|
schema: &schemapb.CollectionSchema{
|
||||||
|
Fields: []*schemapb.FieldSchema{
|
||||||
|
{
|
||||||
|
FieldID: common.RowIDField,
|
||||||
|
Name: common.RowIDFieldName,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
FieldID: 100,
|
||||||
|
Name: "ID",
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
FieldID: 101,
|
||||||
|
Name: "Vector",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedError: false,
|
||||||
|
expectedIDs: []int64{100, 101},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "vector only",
|
||||||
|
outputFields: []string{"Vector"},
|
||||||
|
schema: &schemapb.CollectionSchema{
|
||||||
|
Fields: []*schemapb.FieldSchema{
|
||||||
|
{
|
||||||
|
FieldID: common.RowIDField,
|
||||||
|
Name: common.RowIDFieldName,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
FieldID: 100,
|
||||||
|
Name: "ID",
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
FieldID: 101,
|
||||||
|
Name: "Vector",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedError: false,
|
||||||
|
expectedIDs: []int64{101, 100},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with field not exist",
|
||||||
|
outputFields: []string{"ID", "Vector", "Extra"},
|
||||||
|
schema: &schemapb.CollectionSchema{
|
||||||
|
Fields: []*schemapb.FieldSchema{
|
||||||
|
{
|
||||||
|
FieldID: common.RowIDField,
|
||||||
|
Name: common.RowIDFieldName,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
FieldID: 100,
|
||||||
|
Name: "ID",
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
FieldID: 101,
|
||||||
|
Name: "Vector",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
ids, err := translateToOutputFieldIDs(tc.outputFields, tc.schema)
|
||||||
|
if tc.expectedError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
require.Equal(t, len(tc.expectedIDs), len(ids))
|
||||||
|
for idx, expectedID := range tc.expectedIDs {
|
||||||
|
assert.Equal(t, expectedID, ids[idx])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -35,7 +35,7 @@ class TestQueryParams(TestcaseBase):
|
|||||||
test Query interface
|
test Query interface
|
||||||
query(collection_name, expr, output_fields=None, partition_names=None, timeout=None)
|
query(collection_name, expr, output_fields=None, partition_names=None, timeout=None)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L2)
|
@pytest.mark.tags(CaseLabel.L2)
|
||||||
def test_query_invalid(self):
|
def test_query_invalid(self):
|
||||||
"""
|
"""
|
||||||
@ -525,7 +525,7 @@ class TestQueryParams(TestcaseBase):
|
|||||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
|
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
|
||||||
for fields in [None, []]:
|
for fields in [None, []]:
|
||||||
res, _ = collection_w.query(default_term_expr, output_fields=fields)
|
res, _ = collection_w.query(default_term_expr, output_fields=fields)
|
||||||
assert list(res[0].keys()) == [ct.default_int64_field_name]
|
assert res[0].keys() == {ct.default_int64_field_name}
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L0)
|
@pytest.mark.tags(CaseLabel.L0)
|
||||||
def test_query_output_one_field(self):
|
def test_query_output_one_field(self):
|
||||||
@ -649,7 +649,7 @@ class TestQueryParams(TestcaseBase):
|
|||||||
fields = [[ct.default_binary_vec_field_name], [ct.default_int64_field_name, ct.default_binary_vec_field_name]]
|
fields = [[ct.default_binary_vec_field_name], [ct.default_int64_field_name, ct.default_binary_vec_field_name]]
|
||||||
for output_fields in fields:
|
for output_fields in fields:
|
||||||
res, _ = collection_w.query(default_term_expr, output_fields=output_fields)
|
res, _ = collection_w.query(default_term_expr, output_fields=output_fields)
|
||||||
assert list(res[0].keys()) == fields[-1]
|
assert res[0].keys() == set(fields[-1])
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L1)
|
@pytest.mark.tags(CaseLabel.L1)
|
||||||
def test_query_output_primary_field(self):
|
def test_query_output_primary_field(self):
|
||||||
@ -660,7 +660,7 @@ class TestQueryParams(TestcaseBase):
|
|||||||
"""
|
"""
|
||||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
|
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
|
||||||
res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_int64_field_name])
|
res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_int64_field_name])
|
||||||
assert list(res[0].keys()) == [ct.default_int64_field_name]
|
assert res[0].keys() == {ct.default_int64_field_name}
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L2)
|
@pytest.mark.tags(CaseLabel.L2)
|
||||||
def test_query_output_not_existed_field(self):
|
def test_query_output_not_existed_field(self):
|
||||||
@ -1097,7 +1097,7 @@ class TestQueryOperation(TestcaseBase):
|
|||||||
collection_w.create_index(ct.default_binary_vec_field_name, binary_index_params)
|
collection_w.create_index(ct.default_binary_vec_field_name, binary_index_params)
|
||||||
assert collection_w.has_index()[0]
|
assert collection_w.has_index()[0]
|
||||||
res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_binary_vec_field_name])
|
res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_binary_vec_field_name])
|
||||||
assert list(res[0].keys()) == fields
|
assert res[0].keys() == set(fields)
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L2)
|
@pytest.mark.tags(CaseLabel.L2)
|
||||||
def test_query_partition_repeatedly(self):
|
def test_query_partition_repeatedly(self):
|
||||||
@ -1210,7 +1210,7 @@ class TestqueryString(TestcaseBase):
|
|||||||
The following cases are used to test query with string
|
The following cases are used to test query with string
|
||||||
******************************************************************
|
******************************************************************
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L1)
|
@pytest.mark.tags(CaseLabel.L1)
|
||||||
def test_query_string_is_not_primary(self):
|
def test_query_string_is_not_primary(self):
|
||||||
"""
|
"""
|
||||||
@ -1220,13 +1220,13 @@ class TestqueryString(TestcaseBase):
|
|||||||
query with string expr in string field is not primary
|
query with string expr in string field is not primary
|
||||||
expected: query successfully
|
expected: query successfully
|
||||||
"""
|
"""
|
||||||
|
|
||||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
|
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
|
||||||
res = vectors[0].iloc[:2, :3].to_dict('records')
|
res = vectors[0].iloc[:2, :3].to_dict('records')
|
||||||
output_fields = [default_float_field_name, default_string_field_name]
|
output_fields = [default_float_field_name, default_string_field_name]
|
||||||
collection_w.query(default_string_term_expr, output_fields=output_fields,
|
collection_w.query(default_string_term_expr, output_fields=output_fields,
|
||||||
check_task=CheckTasks.check_query_results, check_items={exp_res: res})
|
check_task=CheckTasks.check_query_results, check_items={exp_res: res})
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L1)
|
@pytest.mark.tags(CaseLabel.L1)
|
||||||
@pytest.mark.parametrize("expression", cf.gen_normal_string_expressions(default_string_field_name))
|
@pytest.mark.parametrize("expression", cf.gen_normal_string_expressions(default_string_field_name))
|
||||||
def test_query_string_is_primary(self, expression):
|
def test_query_string_is_primary(self, expression):
|
||||||
@ -1237,7 +1237,7 @@ class TestqueryString(TestcaseBase):
|
|||||||
"""
|
"""
|
||||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True, primary_field=ct.default_string_field_name)[0:2]
|
collection_w, vectors = self.init_collection_general(prefix, insert_data=True, primary_field=ct.default_string_field_name)[0:2]
|
||||||
res, _ = collection_w.query(expression, output_fields=[ct.default_string_field_name])
|
res, _ = collection_w.query(expression, output_fields=[ct.default_string_field_name])
|
||||||
assert list(res[0].keys()) == [ct.default_string_field_name]
|
assert res[0].keys() == {ct.default_string_field_name}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L1)
|
@pytest.mark.tags(CaseLabel.L1)
|
||||||
@ -1251,7 +1251,7 @@ class TestqueryString(TestcaseBase):
|
|||||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True, primary_field=ct.default_string_field_name)[0:2]
|
collection_w, vectors = self.init_collection_general(prefix, insert_data=True, primary_field=ct.default_string_field_name)[0:2]
|
||||||
res = vectors[0].iloc[:, 1:3].to_dict('records')
|
res = vectors[0].iloc[:, 1:3].to_dict('records')
|
||||||
output_fields = [default_float_field_name, default_string_field_name]
|
output_fields = [default_float_field_name, default_string_field_name]
|
||||||
collection_w.query(default_mix_expr, output_fields=output_fields,
|
collection_w.query(default_mix_expr, output_fields=output_fields,
|
||||||
check_task=CheckTasks.check_query_results, check_items={exp_res: res})
|
check_task=CheckTasks.check_query_results, check_items={exp_res: res})
|
||||||
|
|
||||||
|
|
||||||
@ -1267,12 +1267,12 @@ class TestqueryString(TestcaseBase):
|
|||||||
collection_w = self.init_collection_general(prefix, insert_data=True)[0]
|
collection_w = self.init_collection_general(prefix, insert_data=True)[0]
|
||||||
collection_w.query(expression, check_task=CheckTasks.err_res,
|
collection_w.query(expression, check_task=CheckTasks.err_res,
|
||||||
check_items={ct.err_code: 1, ct.err_msg: "type mismatch"})
|
check_items={ct.err_code: 1, ct.err_msg: "type mismatch"})
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L1)
|
@pytest.mark.tags(CaseLabel.L1)
|
||||||
def test_query_string_expr_with_binary(self):
|
def test_query_string_expr_with_binary(self):
|
||||||
"""
|
"""
|
||||||
target: test query string expr with binary
|
target: test query string expr with binary
|
||||||
method: query string expr with binary
|
method: query string expr with binary
|
||||||
expected: verify query successfully
|
expected: verify query successfully
|
||||||
"""
|
"""
|
||||||
collection_w, vectors= self.init_collection_general(prefix, insert_data=True, is_binary=True, is_index=True)[0:2]
|
collection_w, vectors= self.init_collection_general(prefix, insert_data=True, is_binary=True, is_index=True)[0:2]
|
||||||
@ -1285,7 +1285,7 @@ class TestqueryString(TestcaseBase):
|
|||||||
@pytest.mark.tags(CaseLabel.L1)
|
@pytest.mark.tags(CaseLabel.L1)
|
||||||
def test_query_string_expr_with_prefixes(self):
|
def test_query_string_expr_with_prefixes(self):
|
||||||
"""
|
"""
|
||||||
target: test query with
|
target: test query with prefix string expression
|
||||||
method: specify string is primary field, use prefix string expr
|
method: specify string is primary field, use prefix string expr
|
||||||
expected: verify query successfully
|
expected: verify query successfully
|
||||||
"""
|
"""
|
||||||
@ -1293,13 +1293,13 @@ class TestqueryString(TestcaseBase):
|
|||||||
res = vectors[0].iloc[:1, :3].to_dict('records')
|
res = vectors[0].iloc[:1, :3].to_dict('records')
|
||||||
expression = 'varchar like "0%"'
|
expression = 'varchar like "0%"'
|
||||||
output_fields = [default_int_field_name, default_float_field_name, default_string_field_name]
|
output_fields = [default_int_field_name, default_float_field_name, default_string_field_name]
|
||||||
collection_w.query(expression, output_fields=output_fields,
|
collection_w.query(expression, output_fields=output_fields,
|
||||||
check_task=CheckTasks.check_query_results, check_items={exp_res: res})
|
check_task=CheckTasks.check_query_results, check_items={exp_res: res})
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L1)
|
@pytest.mark.tags(CaseLabel.L1)
|
||||||
def test_query_string_with_invaild_prefix_expr(self):
|
def test_query_string_with_invaild_prefix_expr(self):
|
||||||
"""
|
"""
|
||||||
target: test query with
|
target: test query with invalid prefix string expression
|
||||||
method: specify string primary field, use invalid prefix string expr
|
method: specify string primary field, use invalid prefix string expr
|
||||||
expected: raise error
|
expected: raise error
|
||||||
"""
|
"""
|
||||||
@ -1312,7 +1312,7 @@ class TestqueryString(TestcaseBase):
|
|||||||
@pytest.mark.tags(CaseLabel.L1)
|
@pytest.mark.tags(CaseLabel.L1)
|
||||||
def test_query_compare_two_fields(self):
|
def test_query_compare_two_fields(self):
|
||||||
"""
|
"""
|
||||||
target: test query with
|
target: test query with bool expression comparing two fields
|
||||||
method: specify string primary field, compare two fields
|
method: specify string primary field, compare two fields
|
||||||
expected: verify query successfully
|
expected: verify query successfully
|
||||||
"""
|
"""
|
||||||
@ -1320,13 +1320,13 @@ class TestqueryString(TestcaseBase):
|
|||||||
res = []
|
res = []
|
||||||
expression = 'float > int64'
|
expression = 'float > int64'
|
||||||
output_fields = [default_int_field_name, default_float_field_name, default_string_field_name]
|
output_fields = [default_int_field_name, default_float_field_name, default_string_field_name]
|
||||||
collection_w.query(expression, output_fields=output_fields,
|
collection_w.query(expression, output_fields=output_fields,
|
||||||
check_task=CheckTasks.check_query_results, check_items={exp_res: res})
|
check_task=CheckTasks.check_query_results, check_items={exp_res: res})
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L1)
|
@pytest.mark.tags(CaseLabel.L1)
|
||||||
def test_query_compare_invalid_fields(self):
|
def test_query_compare_invalid_fields(self):
|
||||||
"""
|
"""
|
||||||
target: test query with bool expression comparing invalid fields
|
target: test query with bool expression comparing invalid fields
|
||||||
method: specify string primary field, compare string and int field
|
method: specify string primary field, compare string and int field
|
||||||
expected: raise error
|
expected: raise error
|
||||||
"""
|
"""
|
||||||
@ -1381,4 +1381,3 @@ class TestqueryString(TestcaseBase):
|
|||||||
check_items={exp_res: df_dict_list,
|
check_items={exp_res: df_dict_list,
|
||||||
"primary_field": default_int_field_name,
|
"primary_field": default_int_field_name,
|
||||||
"with_vec": True})
|
"with_vec": True})
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user