Dedup output fields for task query (#18673)

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
congqixia 2022-08-16 18:16:48 +08:00 committed by GitHub
parent 8387e3e0e2
commit d4c54d96b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 189 additions and 35 deletions

View File

@ -7,6 +7,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/parser/planparserv2" "github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
@ -53,35 +54,47 @@ type queryTask struct {
// translateOutputFields translates output fields name to output fields id. // translateOutputFields translates output fields name to output fields id.
func translateToOutputFieldIDs(outputFields []string, schema *schemapb.CollectionSchema) ([]UniqueID, error) { func translateToOutputFieldIDs(outputFields []string, schema *schemapb.CollectionSchema) ([]UniqueID, error) {
outputFieldIDs := make([]UniqueID, 0, len(outputFields)) outputFieldIDs := make([]UniqueID, 0, len(outputFields)+1)
if len(outputFields) == 0 { if len(outputFields) == 0 {
for _, field := range schema.Fields { for _, field := range schema.Fields {
if field.FieldID >= 100 && field.DataType != schemapb.DataType_FloatVector && field.DataType != schemapb.DataType_BinaryVector { if field.FieldID >= common.StartOfUserFieldID && field.DataType != schemapb.DataType_FloatVector && field.DataType != schemapb.DataType_BinaryVector {
outputFieldIDs = append(outputFieldIDs, field.FieldID) outputFieldIDs = append(outputFieldIDs, field.FieldID)
} }
} }
} else { } else {
addPrimaryKey := false var pkFieldID UniqueID
for _, field := range schema.Fields {
if field.IsPrimaryKey {
pkFieldID = field.FieldID
}
}
for _, reqField := range outputFields { for _, reqField := range outputFields {
findField := false var fieldFound bool
for _, field := range schema.Fields { for _, field := range schema.Fields {
if reqField == field.Name { if reqField == field.Name {
if field.IsPrimaryKey {
addPrimaryKey = true
}
findField = true
outputFieldIDs = append(outputFieldIDs, field.FieldID) outputFieldIDs = append(outputFieldIDs, field.FieldID)
} else { fieldFound = true
if field.IsPrimaryKey && !addPrimaryKey { break
outputFieldIDs = append(outputFieldIDs, field.FieldID)
addPrimaryKey = true
}
} }
} }
if !findField { if !fieldFound {
return nil, fmt.Errorf("field %s not exist", reqField) return nil, fmt.Errorf("field %s not exist", reqField)
} }
} }
// pk field needs to be in output field list
var pkFound bool
for _, outputField := range outputFieldIDs {
if outputField == pkFieldID {
pkFound = true
break
}
}
if !pkFound {
outputFieldIDs = append(outputFieldIDs, pkFieldID)
}
} }
return outputFieldIDs, nil return outputFieldIDs, nil
} }

View File

@ -196,3 +196,145 @@ func TestQueryTask_all(t *testing.T) {
assert.NoError(t, task.PostExecute(ctx)) assert.NoError(t, task.PostExecute(ctx))
} }
// Test_translateToOutputFieldIDs is a table-driven test for
// translateToOutputFieldIDs. Every scenario uses the same three-field schema:
// the row ID field, a primary key named "ID" (field ID 100) and a field named
// "Vector" (field ID 101) whose DataType is left unset.
//
// Scenarios covered:
//   - empty and nil output field lists: both user fields are expected and the
//     row ID field is not;
//   - an explicit list naming both user fields: IDs come back in request order;
//   - a list naming only "Vector": the primary key ID is expected after the
//     requested field;
//   - a list naming a field missing from the schema: an error is expected.
func Test_translateToOutputFieldIDs(t *testing.T) {
	// newSchema returns a fresh schema instance so that every scenario owns
	// its own copy, exactly as if the literal were spelled out per case.
	newSchema := func() *schemapb.CollectionSchema {
		return &schemapb.CollectionSchema{
			Fields: []*schemapb.FieldSchema{
				{
					FieldID: common.RowIDField,
					Name:    common.RowIDFieldName,
				},
				{
					FieldID:      100,
					Name:         "ID",
					IsPrimaryKey: true,
				},
				{
					FieldID: 101,
					Name:    "Vector",
				},
			},
		}
	}

	scenarios := []struct {
		name          string
		outputFields  []string
		schema        *schemapb.CollectionSchema
		expectedError bool
		expectedIDs   []int64
	}{
		{
			name:          "empty output fields",
			outputFields:  []string{},
			schema:        newSchema(),
			expectedError: false,
			expectedIDs:   []int64{100, 101},
		},
		{
			name:          "nil output fields",
			outputFields:  nil,
			schema:        newSchema(),
			expectedError: false,
			expectedIDs:   []int64{100, 101},
		},
		{
			name:          "full list",
			outputFields:  []string{"ID", "Vector"},
			schema:        newSchema(),
			expectedError: false,
			expectedIDs:   []int64{100, 101},
		},
		{
			name:          "vector only",
			outputFields:  []string{"Vector"},
			schema:        newSchema(),
			expectedError: false,
			// The primary key was not requested, so it is expected last.
			expectedIDs: []int64{101, 100},
		},
		{
			name:          "with field not exist",
			outputFields:  []string{"ID", "Vector", "Extra"},
			schema:        newSchema(),
			expectedError: true,
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			got, err := translateToOutputFieldIDs(sc.outputFields, sc.schema)
			if sc.expectedError {
				assert.Error(t, err)
				return
			}

			assert.NoError(t, err)
			// Stop here on a length mismatch so the indexed comparison
			// below cannot run past the end of got.
			require.Equal(t, len(sc.expectedIDs), len(got))
			for i := range sc.expectedIDs {
				assert.Equal(t, sc.expectedIDs[i], got[i])
			}
		})
	}
}

View File

@ -35,7 +35,7 @@ class TestQueryParams(TestcaseBase):
test Query interface test Query interface
query(collection_name, expr, output_fields=None, partition_names=None, timeout=None) query(collection_name, expr, output_fields=None, partition_names=None, timeout=None)
""" """
@pytest.mark.tags(CaseLabel.L2) @pytest.mark.tags(CaseLabel.L2)
def test_query_invalid(self): def test_query_invalid(self):
""" """
@ -525,7 +525,7 @@ class TestQueryParams(TestcaseBase):
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
for fields in [None, []]: for fields in [None, []]:
res, _ = collection_w.query(default_term_expr, output_fields=fields) res, _ = collection_w.query(default_term_expr, output_fields=fields)
assert list(res[0].keys()) == [ct.default_int64_field_name] assert res[0].keys() == {ct.default_int64_field_name}
@pytest.mark.tags(CaseLabel.L0) @pytest.mark.tags(CaseLabel.L0)
def test_query_output_one_field(self): def test_query_output_one_field(self):
@ -649,7 +649,7 @@ class TestQueryParams(TestcaseBase):
fields = [[ct.default_binary_vec_field_name], [ct.default_int64_field_name, ct.default_binary_vec_field_name]] fields = [[ct.default_binary_vec_field_name], [ct.default_int64_field_name, ct.default_binary_vec_field_name]]
for output_fields in fields: for output_fields in fields:
res, _ = collection_w.query(default_term_expr, output_fields=output_fields) res, _ = collection_w.query(default_term_expr, output_fields=output_fields)
assert list(res[0].keys()) == fields[-1] assert res[0].keys() == set(fields[-1])
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
def test_query_output_primary_field(self): def test_query_output_primary_field(self):
@ -660,7 +660,7 @@ class TestQueryParams(TestcaseBase):
""" """
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_int64_field_name]) res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_int64_field_name])
assert list(res[0].keys()) == [ct.default_int64_field_name] assert res[0].keys() == {ct.default_int64_field_name}
@pytest.mark.tags(CaseLabel.L2) @pytest.mark.tags(CaseLabel.L2)
def test_query_output_not_existed_field(self): def test_query_output_not_existed_field(self):
@ -1097,7 +1097,7 @@ class TestQueryOperation(TestcaseBase):
collection_w.create_index(ct.default_binary_vec_field_name, binary_index_params) collection_w.create_index(ct.default_binary_vec_field_name, binary_index_params)
assert collection_w.has_index()[0] assert collection_w.has_index()[0]
res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_binary_vec_field_name]) res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_binary_vec_field_name])
assert list(res[0].keys()) == fields assert res[0].keys() == set(fields)
@pytest.mark.tags(CaseLabel.L2) @pytest.mark.tags(CaseLabel.L2)
def test_query_partition_repeatedly(self): def test_query_partition_repeatedly(self):
@ -1210,7 +1210,7 @@ class TestqueryString(TestcaseBase):
The following cases are used to test query with string The following cases are used to test query with string
****************************************************************** ******************************************************************
""" """
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
def test_query_string_is_not_primary(self): def test_query_string_is_not_primary(self):
""" """
@ -1220,13 +1220,13 @@ class TestqueryString(TestcaseBase):
query with string expr in string field is not primary query with string expr in string field is not primary
expected: query successfully expected: query successfully
""" """
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
res = vectors[0].iloc[:2, :3].to_dict('records') res = vectors[0].iloc[:2, :3].to_dict('records')
output_fields = [default_float_field_name, default_string_field_name] output_fields = [default_float_field_name, default_string_field_name]
collection_w.query(default_string_term_expr, output_fields=output_fields, collection_w.query(default_string_term_expr, output_fields=output_fields,
check_task=CheckTasks.check_query_results, check_items={exp_res: res}) check_task=CheckTasks.check_query_results, check_items={exp_res: res})
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@pytest.mark.parametrize("expression", cf.gen_normal_string_expressions(default_string_field_name)) @pytest.mark.parametrize("expression", cf.gen_normal_string_expressions(default_string_field_name))
def test_query_string_is_primary(self, expression): def test_query_string_is_primary(self, expression):
@ -1237,7 +1237,7 @@ class TestqueryString(TestcaseBase):
""" """
collection_w, vectors = self.init_collection_general(prefix, insert_data=True, primary_field=ct.default_string_field_name)[0:2] collection_w, vectors = self.init_collection_general(prefix, insert_data=True, primary_field=ct.default_string_field_name)[0:2]
res, _ = collection_w.query(expression, output_fields=[ct.default_string_field_name]) res, _ = collection_w.query(expression, output_fields=[ct.default_string_field_name])
assert list(res[0].keys()) == [ct.default_string_field_name] assert res[0].keys() == {ct.default_string_field_name}
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@ -1251,7 +1251,7 @@ class TestqueryString(TestcaseBase):
collection_w, vectors = self.init_collection_general(prefix, insert_data=True, primary_field=ct.default_string_field_name)[0:2] collection_w, vectors = self.init_collection_general(prefix, insert_data=True, primary_field=ct.default_string_field_name)[0:2]
res = vectors[0].iloc[:, 1:3].to_dict('records') res = vectors[0].iloc[:, 1:3].to_dict('records')
output_fields = [default_float_field_name, default_string_field_name] output_fields = [default_float_field_name, default_string_field_name]
collection_w.query(default_mix_expr, output_fields=output_fields, collection_w.query(default_mix_expr, output_fields=output_fields,
check_task=CheckTasks.check_query_results, check_items={exp_res: res}) check_task=CheckTasks.check_query_results, check_items={exp_res: res})
@ -1267,12 +1267,12 @@ class TestqueryString(TestcaseBase):
collection_w = self.init_collection_general(prefix, insert_data=True)[0] collection_w = self.init_collection_general(prefix, insert_data=True)[0]
collection_w.query(expression, check_task=CheckTasks.err_res, collection_w.query(expression, check_task=CheckTasks.err_res,
check_items={ct.err_code: 1, ct.err_msg: "type mismatch"}) check_items={ct.err_code: 1, ct.err_msg: "type mismatch"})
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
def test_query_string_expr_with_binary(self): def test_query_string_expr_with_binary(self):
""" """
target: test query string expr with binary target: test query string expr with binary
method: query string expr with binary method: query string expr with binary
expected: verify query successfully expected: verify query successfully
""" """
collection_w, vectors= self.init_collection_general(prefix, insert_data=True, is_binary=True, is_index=True)[0:2] collection_w, vectors= self.init_collection_general(prefix, insert_data=True, is_binary=True, is_index=True)[0:2]
@ -1285,7 +1285,7 @@ class TestqueryString(TestcaseBase):
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
def test_query_string_expr_with_prefixes(self): def test_query_string_expr_with_prefixes(self):
""" """
target: test query with target: test query with prefix string expression
method: specify string is primary field, use prefix string expr method: specify string is primary field, use prefix string expr
expected: verify query successfully expected: verify query successfully
""" """
@ -1293,13 +1293,13 @@ class TestqueryString(TestcaseBase):
res = vectors[0].iloc[:1, :3].to_dict('records') res = vectors[0].iloc[:1, :3].to_dict('records')
expression = 'varchar like "0%"' expression = 'varchar like "0%"'
output_fields = [default_int_field_name, default_float_field_name, default_string_field_name] output_fields = [default_int_field_name, default_float_field_name, default_string_field_name]
collection_w.query(expression, output_fields=output_fields, collection_w.query(expression, output_fields=output_fields,
check_task=CheckTasks.check_query_results, check_items={exp_res: res}) check_task=CheckTasks.check_query_results, check_items={exp_res: res})
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
def test_query_string_with_invaild_prefix_expr(self): def test_query_string_with_invaild_prefix_expr(self):
""" """
target: test query with target: test query with invalid prefix string expression
method: specify string primary field, use invaild prefix string expr method: specify string primary field, use invaild prefix string expr
expected: raise error expected: raise error
""" """
@ -1312,7 +1312,7 @@ class TestqueryString(TestcaseBase):
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
def test_query_compare_two_fields(self): def test_query_compare_two_fields(self):
""" """
target: test query with target: test query with bool expression comparing two fields
method: specify string primary field, compare two fields method: specify string primary field, compare two fields
expected: verify query successfully expected: verify query successfully
""" """
@ -1320,13 +1320,13 @@ class TestqueryString(TestcaseBase):
res = [] res = []
expression = 'float > int64' expression = 'float > int64'
output_fields = [default_int_field_name, default_float_field_name, default_string_field_name] output_fields = [default_int_field_name, default_float_field_name, default_string_field_name]
collection_w.query(expression, output_fields=output_fields, collection_w.query(expression, output_fields=output_fields,
check_task=CheckTasks.check_query_results, check_items={exp_res: res}) check_task=CheckTasks.check_query_results, check_items={exp_res: res})
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
def test_query_compare_invalid_fields(self): def test_query_compare_invalid_fields(self):
""" """
target: test query with target: test query with
method: specify string primary field, compare string and int field method: specify string primary field, compare string and int field
expected: raise error expected: raise error
""" """
@ -1381,4 +1381,3 @@ class TestqueryString(TestcaseBase):
check_items={exp_res: df_dict_list, check_items={exp_res: df_dict_list,
"primary_field": default_int_field_name, "primary_field": default_int_field_name,
"with_vec": True}) "with_vec": True})