mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-04 04:49:08 +08:00
Dedup output fields for task query (#18673)
Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
parent
8387e3e0e2
commit
d4c54d96b0
@ -7,6 +7,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
"github.com/milvus-io/milvus/internal/parser/planparserv2"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
@ -53,35 +54,47 @@ type queryTask struct {
|
||||
|
||||
// translateToOutputFieldIDs translates output field names to output field IDs.
|
||||
func translateToOutputFieldIDs(outputFields []string, schema *schemapb.CollectionSchema) ([]UniqueID, error) {
|
||||
outputFieldIDs := make([]UniqueID, 0, len(outputFields))
|
||||
outputFieldIDs := make([]UniqueID, 0, len(outputFields)+1)
|
||||
if len(outputFields) == 0 {
|
||||
for _, field := range schema.Fields {
|
||||
if field.FieldID >= 100 && field.DataType != schemapb.DataType_FloatVector && field.DataType != schemapb.DataType_BinaryVector {
|
||||
if field.FieldID >= common.StartOfUserFieldID && field.DataType != schemapb.DataType_FloatVector && field.DataType != schemapb.DataType_BinaryVector {
|
||||
outputFieldIDs = append(outputFieldIDs, field.FieldID)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
addPrimaryKey := false
|
||||
var pkFieldID UniqueID
|
||||
for _, field := range schema.Fields {
|
||||
if field.IsPrimaryKey {
|
||||
pkFieldID = field.FieldID
|
||||
}
|
||||
}
|
||||
for _, reqField := range outputFields {
|
||||
findField := false
|
||||
var fieldFound bool
|
||||
for _, field := range schema.Fields {
|
||||
if reqField == field.Name {
|
||||
if field.IsPrimaryKey {
|
||||
addPrimaryKey = true
|
||||
}
|
||||
findField = true
|
||||
outputFieldIDs = append(outputFieldIDs, field.FieldID)
|
||||
} else {
|
||||
if field.IsPrimaryKey && !addPrimaryKey {
|
||||
outputFieldIDs = append(outputFieldIDs, field.FieldID)
|
||||
addPrimaryKey = true
|
||||
fieldFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !findField {
|
||||
if !fieldFound {
|
||||
return nil, fmt.Errorf("field %s not exist", reqField)
|
||||
}
|
||||
}
|
||||
|
||||
// pk field needs to be in output field list
|
||||
var pkFound bool
|
||||
for _, outputField := range outputFieldIDs {
|
||||
if outputField == pkFieldID {
|
||||
pkFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !pkFound {
|
||||
outputFieldIDs = append(outputFieldIDs, pkFieldID)
|
||||
}
|
||||
|
||||
}
|
||||
return outputFieldIDs, nil
|
||||
}
|
||||
|
@ -196,3 +196,145 @@ func TestQueryTask_all(t *testing.T) {
|
||||
|
||||
assert.NoError(t, task.PostExecute(ctx))
|
||||
}
|
||||
|
||||
func Test_translateToOutputFieldIDs(t *testing.T) {
|
||||
type testCases struct {
|
||||
name string
|
||||
outputFields []string
|
||||
schema *schemapb.CollectionSchema
|
||||
expectedError bool
|
||||
expectedIDs []int64
|
||||
}
|
||||
|
||||
cases := []testCases{
|
||||
{
|
||||
name: "empty output fields",
|
||||
outputFields: []string{},
|
||||
schema: &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: common.RowIDField,
|
||||
Name: common.RowIDFieldName,
|
||||
},
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: "ID",
|
||||
IsPrimaryKey: true,
|
||||
},
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "Vector",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
expectedIDs: []int64{100, 101},
|
||||
},
|
||||
{
|
||||
name: "nil output fields",
|
||||
outputFields: nil,
|
||||
schema: &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: common.RowIDField,
|
||||
Name: common.RowIDFieldName,
|
||||
},
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: "ID",
|
||||
IsPrimaryKey: true,
|
||||
},
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "Vector",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
expectedIDs: []int64{100, 101},
|
||||
},
|
||||
{
|
||||
name: "full list",
|
||||
outputFields: []string{"ID", "Vector"},
|
||||
schema: &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: common.RowIDField,
|
||||
Name: common.RowIDFieldName,
|
||||
},
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: "ID",
|
||||
IsPrimaryKey: true,
|
||||
},
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "Vector",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
expectedIDs: []int64{100, 101},
|
||||
},
|
||||
{
|
||||
name: "vector only",
|
||||
outputFields: []string{"Vector"},
|
||||
schema: &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: common.RowIDField,
|
||||
Name: common.RowIDFieldName,
|
||||
},
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: "ID",
|
||||
IsPrimaryKey: true,
|
||||
},
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "Vector",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
expectedIDs: []int64{101, 100},
|
||||
},
|
||||
{
|
||||
name: "with field not exist",
|
||||
outputFields: []string{"ID", "Vector", "Extra"},
|
||||
schema: &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: common.RowIDField,
|
||||
Name: common.RowIDFieldName,
|
||||
},
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: "ID",
|
||||
IsPrimaryKey: true,
|
||||
},
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "Vector",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ids, err := translateToOutputFieldIDs(tc.outputFields, tc.schema)
|
||||
if tc.expectedError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
require.Equal(t, len(tc.expectedIDs), len(ids))
|
||||
for idx, expectedID := range tc.expectedIDs {
|
||||
assert.Equal(t, expectedID, ids[idx])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -525,7 +525,7 @@ class TestQueryParams(TestcaseBase):
|
||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
|
||||
for fields in [None, []]:
|
||||
res, _ = collection_w.query(default_term_expr, output_fields=fields)
|
||||
assert list(res[0].keys()) == [ct.default_int64_field_name]
|
||||
assert res[0].keys() == {ct.default_int64_field_name}
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
def test_query_output_one_field(self):
|
||||
@ -649,7 +649,7 @@ class TestQueryParams(TestcaseBase):
|
||||
fields = [[ct.default_binary_vec_field_name], [ct.default_int64_field_name, ct.default_binary_vec_field_name]]
|
||||
for output_fields in fields:
|
||||
res, _ = collection_w.query(default_term_expr, output_fields=output_fields)
|
||||
assert list(res[0].keys()) == fields[-1]
|
||||
assert res[0].keys() == set(fields[-1])
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
def test_query_output_primary_field(self):
|
||||
@ -660,7 +660,7 @@ class TestQueryParams(TestcaseBase):
|
||||
"""
|
||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
|
||||
res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_int64_field_name])
|
||||
assert list(res[0].keys()) == [ct.default_int64_field_name]
|
||||
assert res[0].keys() == {ct.default_int64_field_name}
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
def test_query_output_not_existed_field(self):
|
||||
@ -1097,7 +1097,7 @@ class TestQueryOperation(TestcaseBase):
|
||||
collection_w.create_index(ct.default_binary_vec_field_name, binary_index_params)
|
||||
assert collection_w.has_index()[0]
|
||||
res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_binary_vec_field_name])
|
||||
assert list(res[0].keys()) == fields
|
||||
assert res[0].keys() == set(fields)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
def test_query_partition_repeatedly(self):
|
||||
@ -1237,7 +1237,7 @@ class TestqueryString(TestcaseBase):
|
||||
"""
|
||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True, primary_field=ct.default_string_field_name)[0:2]
|
||||
res, _ = collection_w.query(expression, output_fields=[ct.default_string_field_name])
|
||||
assert list(res[0].keys()) == [ct.default_string_field_name]
|
||||
assert res[0].keys() == {ct.default_string_field_name}
|
||||
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@ -1285,7 +1285,7 @@ class TestqueryString(TestcaseBase):
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
def test_query_string_expr_with_prefixes(self):
|
||||
"""
|
||||
target: test query with
|
||||
target: test query with prefix string expression
|
||||
method: specify a string field as the primary field, use prefix string expr
|
||||
expected: verify query successfully
|
||||
"""
|
||||
@ -1299,7 +1299,7 @@ class TestqueryString(TestcaseBase):
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
def test_query_string_with_invaild_prefix_expr(self):
|
||||
"""
|
||||
target: test query with
|
||||
target: test query with invalid prefix string expression
|
||||
method: specify string primary field, use invalid prefix string expr
|
||||
expected: raise error
|
||||
"""
|
||||
@ -1312,7 +1312,7 @@ class TestqueryString(TestcaseBase):
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
def test_query_compare_two_fields(self):
|
||||
"""
|
||||
target: test query with
|
||||
target: test query with bool expression comparing two fields
|
||||
method: specify string primary field, compare two fields
|
||||
expected: verify query successfully
|
||||
"""
|
||||
@ -1381,4 +1381,3 @@ class TestqueryString(TestcaseBase):
|
||||
check_items={exp_res: df_dict_list,
|
||||
"primary_field": default_int_field_name,
|
||||
"with_vec": True})
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user