Rearrange search/queryTask for readability (#16325)

`searchTask` has 683 lines of code, `queryTask` has 485 lines of code.
`task.go` contains 4k+ lines of code, including `searchTask` and `queryTask`.

Navigating the code of `searchTask` and `queryTask` through `task.go` was
painful; `task.go` and `task_test.go` were literally unreadable.

This PR moves
1. 650+ lines of code of `searchTask` from `task.go` to `task_search.go`.
2. 1.4k+ lines of test code of `searchTask` from `task_test.go` to
   `task_search_test.go`.
3. 450+ lines of code of `queryTask` from `task.go` to `task_query.go`.
4. 200+ lines of test code of `queryTask` from `task_test.go` to
   `task_query_test.go`.

This PR also rearranges the methods of `searchTask` and
`queryTask` (see the sketch below):
-  Put the most important methods `PreExecute`, `Execute`, and
   `PostExecute` at the beginning of the file.
-  Move interface methods `ID`, `SetID`, `Type`, `BeginTs`, etc.
   that nobody cares about to the bottom of the file.
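
For reference, a rough sketch of the resulting method order (signatures
abridged, bodies elided; see the actual files in this diff):

    // task_search.go (task_query.go follows the same pattern)
    func (st *searchTask) PreExecute(ctx context.Context) error
    func (st *searchTask) Execute(ctx context.Context) error
    func (st *searchTask) PostExecute(ctx context.Context) error
    // ...free helpers: decodeSearchResults, reduceSearchResultData, ...
    func (st *searchTask) ID() UniqueID
    // ...remaining interface boilerplate down to OnEnqueue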

See also: #16298

Signed-off-by: yangxuan <xuan.yang@zilliz.com>
XuanYang-cn 2022-04-01 18:59:29 +08:00 committed by GitHub
parent 7a44fff8cd
commit e9090a62ab
6 changed files with 2848 additions and 2757 deletions

task.go: file diff suppressed because it is too large.

task_query.go (new file)

@@ -0,0 +1,485 @@
package proxy
import (
"context"
"errors"
"fmt"
"regexp"
"strings"
"time"
"github.com/golang/protobuf/proto"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/util/timerecord"
"github.com/milvus-io/milvus/internal/util/tsoutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
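// queryTask wraps the user-facing milvuspb.QueryRequest together with the
// internalpb.RetrieveRequest that is actually sent to query nodes; partial
// results are collected back through resultBuf.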
type queryTask struct {
Condition
*internalpb.RetrieveRequest
ctx context.Context
resultBuf chan []*internalpb.RetrieveResults
result *milvuspb.QueryResults
query *milvuspb.QueryRequest
chMgr channelsMgr
qc types.QueryCoord
ids *schemapb.IDs
collectionName string
collectionID UniqueID
}
func (qt *queryTask) PreExecute(ctx context.Context) error {
qt.Base.MsgType = commonpb.MsgType_Retrieve
qt.Base.SourceID = Params.ProxyCfg.ProxyID
collectionName := qt.query.CollectionName
if err := validateCollectionName(qt.query.CollectionName); err != nil {
log.Warn("Invalid collection name.", zap.String("collectionName", collectionName),
zap.Int64("requestID", qt.Base.MsgID), zap.String("requestType", "query"))
return err
}
log.Info("Validate collection name.", zap.Any("collectionName", collectionName),
zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
info, err := globalMetaCache.GetCollectionInfo(ctx, collectionName)
if err != nil {
log.Debug("Failed to get collection id.", zap.Any("collectionName", collectionName),
zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
return err
}
qt.collectionName = info.schema.Name
log.Info("Get collection id by name.", zap.Any("collectionName", collectionName),
zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
for _, tag := range qt.query.PartitionNames {
if err := validatePartitionTag(tag, false); err != nil {
log.Debug("Invalid partition name.", zap.Any("partitionName", tag),
zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
return err
}
}
log.Info("Validate partition names.",
zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
// check if collection was already loaded into query node
showResp, err := qt.qc.ShowCollections(qt.ctx, &querypb.ShowCollectionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ShowCollections,
MsgID: qt.Base.MsgID,
Timestamp: qt.Base.Timestamp,
SourceID: Params.ProxyCfg.ProxyID,
},
DbID: 0, // TODO(dragondriver)
})
if err != nil {
return err
}
if showResp.Status.ErrorCode != commonpb.ErrorCode_Success {
return errors.New(showResp.Status.Reason)
}
log.Debug("QueryCoord show collections",
zap.Any("collections", showResp.CollectionIDs),
zap.Any("collID", info.collID))
collectionLoaded := false
for _, collID := range showResp.CollectionIDs {
if info.collID == collID {
collectionLoaded = true
break
}
}
if !collectionLoaded {
return fmt.Errorf("collection %v was not loaded into memory", collectionName)
}
schema, _ := globalMetaCache.GetCollectionSchema(ctx, qt.query.CollectionName)
if qt.ids != nil {
pkField := ""
for _, field := range schema.Fields {
if field.IsPrimaryKey {
pkField = field.Name
}
}
qt.query.Expr = IDs2Expr(pkField, qt.ids.GetIntId().Data)
}
if qt.query.Expr == "" {
errMsg := "Query expression is empty"
return fmt.Errorf(errMsg)
}
plan, err := createExprPlan(schema, qt.query.Expr)
if err != nil {
return err
}
qt.query.OutputFields, err = translateOutputFields(qt.query.OutputFields, schema, true)
if err != nil {
return err
}
log.Debug("translate output fields", zap.Any("OutputFields", qt.query.OutputFields))
if len(qt.query.OutputFields) == 0 {
for _, field := range schema.Fields {
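// Select every non-vector, user-defined field (FieldID >= 100, i.e.
// common.StartOfUserFieldID) as the default output fields.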
if field.FieldID >= 100 && field.DataType != schemapb.DataType_FloatVector && field.DataType != schemapb.DataType_BinaryVector {
qt.OutputFieldsId = append(qt.OutputFieldsId, field.FieldID)
}
}
} else {
addPrimaryKey := false
for _, reqField := range qt.query.OutputFields {
findField := false
for _, field := range schema.Fields {
if reqField == field.Name {
if field.IsPrimaryKey {
addPrimaryKey = true
}
findField = true
qt.OutputFieldsId = append(qt.OutputFieldsId, field.FieldID)
plan.OutputFieldIds = append(plan.OutputFieldIds, field.FieldID)
} else {
if field.IsPrimaryKey && !addPrimaryKey {
qt.OutputFieldsId = append(qt.OutputFieldsId, field.FieldID)
plan.OutputFieldIds = append(plan.OutputFieldIds, field.FieldID)
addPrimaryKey = true
}
}
}
if !findField {
errMsg := "Field " + reqField + " not exist"
return errors.New(errMsg)
}
}
}
log.Debug("translate output fields to field ids", zap.Any("OutputFieldsID", qt.OutputFieldsId))
qt.RetrieveRequest.SerializedExprPlan, err = proto.Marshal(plan)
if err != nil {
return err
}
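// A zero travel timestamp means reading the latest data (BeginTs); a non-zero
// value may not travel back further than the configured retention duration.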
travelTimestamp := qt.query.TravelTimestamp
if travelTimestamp == 0 {
travelTimestamp = qt.BeginTs()
} else {
durationSeconds := tsoutil.CalculateDuration(qt.BeginTs(), travelTimestamp) / 1000
if durationSeconds > Params.CommonCfg.RetentionDuration {
duration := time.Second * time.Duration(durationSeconds)
return fmt.Errorf("only support to travel back to %s so far", duration.String())
}
}
guaranteeTimestamp := qt.query.GuaranteeTimestamp
if guaranteeTimestamp == 0 {
guaranteeTimestamp = qt.BeginTs()
}
qt.TravelTimestamp = travelTimestamp
qt.GuaranteeTimestamp = guaranteeTimestamp
deadline, ok := qt.TraceCtx().Deadline()
if ok {
qt.RetrieveRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
}
qt.ResultChannelID = Params.ProxyCfg.RetrieveResultChannelNames[0]
qt.DbID = 0 // todo(yukun)
qt.CollectionID = info.collID
qt.PartitionIDs = make([]UniqueID, 0)
partitionsMap, err := globalMetaCache.GetPartitions(ctx, collectionName)
if err != nil {
log.Debug("Failed to get partitions in collection.", zap.Any("collectionName", collectionName),
zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
return err
}
log.Info("Get partitions in collection.", zap.Any("collectionName", collectionName),
zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
partitionsRecord := make(map[UniqueID]bool)
for _, partitionName := range qt.query.PartitionNames {
pattern := fmt.Sprintf("^%s$", partitionName)
re, err := regexp.Compile(pattern)
if err != nil {
log.Debug("Failed to compile partition name regex expression.", zap.Any("partitionName", partitionName),
zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
return errors.New("invalid partition names")
}
found := false
for name, pID := range partitionsMap {
if re.MatchString(name) {
if _, exist := partitionsRecord[pID]; !exist {
qt.PartitionIDs = append(qt.PartitionIDs, pID)
partitionsRecord[pID] = true
}
found = true
}
}
if !found {
// FIXME(wxyu): undefined behavior
errMsg := fmt.Sprintf("PartitonName: %s not found", partitionName)
return errors.New(errMsg)
}
}
log.Info("Query PreExecute done.",
zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
return nil
}
func (qt *queryTask) Execute(ctx context.Context) error {
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute query %d", qt.ID()))
defer tr.Elapse("done")
var tsMsg msgstream.TsMsg = &msgstream.RetrieveMsg{
RetrieveRequest: *qt.RetrieveRequest,
BaseMsg: msgstream.BaseMsg{
Ctx: ctx,
HashValues: []uint32{uint32(Params.ProxyCfg.ProxyID)},
BeginTimestamp: qt.Base.Timestamp,
EndTimestamp: qt.Base.Timestamp,
},
}
msgPack := msgstream.MsgPack{
BeginTs: qt.Base.Timestamp,
EndTs: qt.Base.Timestamp,
Msgs: make([]msgstream.TsMsg, 1),
}
msgPack.Msgs[0] = tsMsg
stream, err := qt.chMgr.getDQLStream(qt.CollectionID)
if err != nil {
err = qt.chMgr.createDQLStream(qt.CollectionID)
if err != nil {
qt.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
qt.result.Status.Reason = err.Error()
return err
}
stream, err = qt.chMgr.getDQLStream(qt.CollectionID)
if err != nil {
qt.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
qt.result.Status.Reason = err.Error()
return err
}
}
tr.Record("get used message stream")
err = stream.Produce(&msgPack)
if err != nil {
log.Debug("Failed to send retrieve request.",
zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
}
log.Debug("proxy sent one retrieveMsg",
zap.Int64("collectionID", qt.CollectionID),
zap.Int64("msgID", tsMsg.ID()),
zap.Int("length of search msg", len(msgPack.Msgs)),
zap.Uint64("timeoutTs", qt.RetrieveRequest.TimeoutTimestamp))
tr.Record("send retrieve request to message stream")
log.Info("Query Execute done.",
zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
return err
}
func (qt *queryTask) PostExecute(ctx context.Context) error {
tr := timerecord.NewTimeRecorder("queryTask PostExecute")
defer func() {
tr.Elapse("done")
}()
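// Block until either the trace context expires or the query nodes report
// their retrieve results.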
select {
case <-qt.TraceCtx().Done():
log.Debug("proxy", zap.Int64("Query: wait to finish failed, timeout!, taskID:", qt.ID()))
return fmt.Errorf("queryTask:wait to finish failed, timeout : %d", qt.ID())
case retrieveResults := <-qt.resultBuf:
filterRetrieveResults := make([]*internalpb.RetrieveResults, 0)
var reason string
for _, partialRetrieveResult := range retrieveResults {
if partialRetrieveResult.Status.ErrorCode == commonpb.ErrorCode_Success {
filterRetrieveResults = append(filterRetrieveResults, partialRetrieveResult)
} else {
reason += partialRetrieveResult.Status.Reason + "\n"
}
}
if len(filterRetrieveResults) == 0 {
qt.result = &milvuspb.QueryResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: reason,
},
CollectionName: qt.collectionName,
}
log.Debug("Query failed on all querynodes.",
zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
return errors.New(reason)
}
var err error
qt.result, err = mergeRetrieveResults(filterRetrieveResults)
if err != nil {
return err
}
qt.result.CollectionName = qt.collectionName
if len(qt.result.FieldsData) > 0 {
qt.result.Status = &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}
} else {
log.Info("Query result is nil", zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
qt.result.Status = &commonpb.Status{
ErrorCode: commonpb.ErrorCode_EmptyCollection,
Reason: reason,
}
return nil
}
schema, err := globalMetaCache.GetCollectionSchema(ctx, qt.query.CollectionName)
if err != nil {
return err
}
for i := 0; i < len(qt.result.FieldsData); i++ {
for _, field := range schema.Fields {
if field.FieldID == qt.OutputFieldsId[i] {
qt.result.FieldsData[i].FieldName = field.Name
qt.result.FieldsData[i].FieldId = field.FieldID
qt.result.FieldsData[i].Type = field.DataType
}
}
}
}
log.Info("Query PostExecute done", zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
return nil
}
func (qt *queryTask) getChannels() ([]pChan, error) {
collID, err := globalMetaCache.GetCollectionID(qt.ctx, qt.query.CollectionName)
if err != nil {
return nil, err
}
var channels []pChan
channels, err = qt.chMgr.getChannels(collID)
if err != nil {
err := qt.chMgr.createDMLMsgStream(collID)
if err != nil {
return nil, err
}
return qt.chMgr.getChannels(collID)
}
return channels, nil
}
func (qt *queryTask) getVChannels() ([]vChan, error) {
collID, err := globalMetaCache.GetCollectionID(qt.ctx, qt.query.CollectionName)
if err != nil {
return nil, err
}
var channels []vChan
channels, err = qt.chMgr.getVChannels(collID)
if err != nil {
err := qt.chMgr.createDMLMsgStream(collID)
if err != nil {
return nil, err
}
return qt.chMgr.getVChannels(collID)
}
return channels, nil
}
// IDs2Expr converts an ID slice to a bool expression with the specified field name
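// For example, IDs2Expr("pk", []int64{10, 20}) returns `pk in [ 10, 20 ]`.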
func IDs2Expr(fieldName string, ids []int64) string {
idsStr := strings.Trim(strings.Join(strings.Fields(fmt.Sprint(ids)), ", "), "[]")
return fieldName + " in [ " + idsStr + " ]"
}
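// mergeRetrieveResults merges the results returned by multiple query nodes,
// skipping rows whose primary key has already been collected.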
func mergeRetrieveResults(retrieveResults []*internalpb.RetrieveResults) (*milvuspb.QueryResults, error) {
var ret *milvuspb.QueryResults
var skipDupCnt int64
var idSet = make(map[int64]struct{})
// merge results and remove duplicates
for _, rr := range retrieveResults {
// skip empty result, it will break merge result
if rr == nil || rr.Ids == nil || rr.Ids.GetIntId() == nil || len(rr.Ids.GetIntId().Data) == 0 {
continue
}
if ret == nil {
ret = &milvuspb.QueryResults{
FieldsData: make([]*schemapb.FieldData, len(rr.FieldsData)),
}
}
if len(ret.FieldsData) != len(rr.FieldsData) {
return nil, fmt.Errorf("mismatch FieldData in proxy RetrieveResults, expect %d get %d", len(ret.FieldsData), len(rr.FieldsData))
}
for i, id := range rr.Ids.GetIntId().GetData() {
if _, ok := idSet[id]; !ok {
typeutil.AppendFieldData(ret.FieldsData, rr.FieldsData, int64(i))
idSet[id] = struct{}{}
} else {
// primary keys duplicate
skipDupCnt++
}
}
}
log.Debug("skip duplicated query result", zap.Int64("count", skipDupCnt))
if ret == nil {
ret = &milvuspb.QueryResults{
FieldsData: []*schemapb.FieldData{},
}
}
return ret, nil
}
func (qt *queryTask) TraceCtx() context.Context {
return qt.ctx
}
func (qt *queryTask) ID() UniqueID {
return qt.Base.MsgID
}
func (qt *queryTask) SetID(uid UniqueID) {
qt.Base.MsgID = uid
}
func (qt *queryTask) Name() string {
return RetrieveTaskName
}
func (qt *queryTask) Type() commonpb.MsgType {
return qt.Base.MsgType
}
func (qt *queryTask) BeginTs() Timestamp {
return qt.Base.Timestamp
}
func (qt *queryTask) EndTs() Timestamp {
return qt.Base.Timestamp
}
func (qt *queryTask) SetTs(ts Timestamp) {
qt.Base.Timestamp = ts
}
func (qt *queryTask) OnEnqueue() error {
qt.Base.MsgType = commonpb.MsgType_Retrieve
return nil
}

task_query_test.go (new file)

@@ -0,0 +1,246 @@
package proxy
import (
"context"
"fmt"
"strconv"
"sync"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
func TestQueryTask_all(t *testing.T) {
var err error
Params.Init()
Params.ProxyCfg.RetrieveResultChannelNames = []string{funcutil.GenRandomStr()}
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
ctx := context.Background()
err = InitMetaCache(rc)
assert.NoError(t, err)
shardsNum := int32(2)
prefix := "TestQueryTask_all"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
fieldName2Types := map[string]schemapb.DataType{
testBoolField: schemapb.DataType_Bool,
testInt32Field: schemapb.DataType_Int32,
testInt64Field: schemapb.DataType_Int64,
testFloatField: schemapb.DataType_Float,
testDoubleField: schemapb.DataType_Double,
testFloatVecField: schemapb.DataType_FloatVector,
}
if enableMultipleVectorFields {
fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector
}
expr := fmt.Sprintf("%s > 0", testInt64Field)
hitNum := 10
schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testInt64Field, false)
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
createColT := &createCollectionTask{
Condition: NewTaskCondition(ctx),
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: shardsNum,
},
ctx: ctx,
rootCoord: rc,
result: nil,
schema: nil,
}
assert.NoError(t, createColT.OnEnqueue())
assert.NoError(t, createColT.PreExecute(ctx))
assert.NoError(t, createColT.Execute(ctx))
assert.NoError(t, createColT.PostExecute(ctx))
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
query := newMockGetChannelsService()
factory := newSimpleMockMsgStreamFactory()
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory)
defer chMgr.removeAllDMLStream()
defer chMgr.removeAllDQLStream()
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
assert.NoError(t, err)
qc := NewQueryCoordMock()
qc.Start()
defer qc.Stop()
status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadCollection,
MsgID: 0,
Timestamp: 0,
SourceID: Params.ProxyCfg.ProxyID,
},
DbID: 0,
CollectionID: collectionID,
Schema: nil,
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
task := &queryTask{
Condition: NewTaskCondition(ctx),
RetrieveRequest: &internalpb.RetrieveRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
MsgID: 0,
Timestamp: 0,
SourceID: Params.ProxyCfg.ProxyID,
},
ResultChannelID: strconv.Itoa(int(Params.ProxyCfg.ProxyID)),
DbID: 0,
CollectionID: collectionID,
PartitionIDs: nil,
SerializedExprPlan: nil,
OutputFieldsId: make([]int64, len(fieldName2Types)),
TravelTimestamp: 0,
GuaranteeTimestamp: 0,
},
ctx: ctx,
resultBuf: make(chan []*internalpb.RetrieveResults),
result: &milvuspb.QueryResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
FieldsData: nil,
},
query: &milvuspb.QueryRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
MsgID: 0,
Timestamp: 0,
SourceID: Params.ProxyCfg.ProxyID,
},
DbName: dbName,
CollectionName: collectionName,
Expr: expr,
OutputFields: nil,
PartitionNames: nil,
TravelTimestamp: 0,
GuaranteeTimestamp: 0,
},
chMgr: chMgr,
qc: qc,
ids: nil,
}
for i := 0; i < len(fieldName2Types); i++ {
task.RetrieveRequest.OutputFieldsId[i] = int64(common.StartOfUserFieldID + i)
}
// simple mock for query node
// TODO(dragondriver): should we replace this mock using RocksMq or MemMsgStream?
err = chMgr.createDQLStream(collectionID)
assert.NoError(t, err)
stream, err := chMgr.getDQLStream(collectionID)
assert.NoError(t, err)
var wg sync.WaitGroup
wg.Add(1)
consumeCtx, cancel := context.WithCancel(ctx)
go func() {
defer wg.Done()
for {
select {
case <-consumeCtx.Done():
return
case pack, ok := <-stream.Chan():
assert.True(t, ok)
if pack == nil {
continue
}
for _, msg := range pack.Msgs {
_, ok := msg.(*msgstream.RetrieveMsg)
assert.True(t, ok)
// TODO(dragondriver): construct result according to the request
result1 := &internalpb.RetrieveResults{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_RetrieveResult,
MsgID: 0,
Timestamp: 0,
SourceID: 0,
},
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
ResultChannelID: strconv.Itoa(int(Params.ProxyCfg.ProxyID)),
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: generateInt64Array(hitNum),
},
},
},
SealedSegmentIDsRetrieved: nil,
ChannelIDsRetrieved: nil,
GlobalSealedSegmentIDs: nil,
}
fieldID := common.StartOfUserFieldID
for fieldName, dataType := range fieldName2Types {
result1.FieldsData = append(result1.FieldsData, generateFieldData(dataType, fieldName, int64(fieldID), hitNum))
fieldID++
}
// send retrieve result
task.resultBuf <- []*internalpb.RetrieveResults{result1}
}
}
}
}()
assert.NoError(t, task.OnEnqueue())
// test query task with timeout
ctx1, cancel1 := context.WithTimeout(ctx, 10*time.Second)
defer cancel1()
// before preExecute
assert.Equal(t, typeutil.ZeroTimestamp, task.TimeoutTimestamp)
task.ctx = ctx1
assert.NoError(t, task.PreExecute(ctx))
// after preExecute
assert.Greater(t, task.TimeoutTimestamp, typeutil.ZeroTimestamp)
task.ctx = ctx
assert.NoError(t, task.Execute(ctx))
assert.NoError(t, task.PostExecute(ctx))
cancel()
wg.Wait()
}

task_search.go (new file)

@@ -0,0 +1,683 @@
package proxy
import (
"context"
"errors"
"fmt"
"regexp"
"strconv"
"time"
"github.com/golang/protobuf/proto"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/distance"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/timerecord"
"github.com/milvus-io/milvus/internal/util/trace"
"github.com/milvus-io/milvus/internal/util/tsoutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
)
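// searchTask wraps the user-facing milvuspb.SearchRequest together with the
// internalpb.SearchRequest that is actually sent to query nodes; partial
// results are collected back through resultBuf.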
type searchTask struct {
Condition
*internalpb.SearchRequest
ctx context.Context
resultBuf chan []*internalpb.SearchResults
result *milvuspb.SearchResults
query *milvuspb.SearchRequest
chMgr channelsMgr
qc types.QueryCoord
collectionName string
tr *timerecord.TimeRecorder
collectionID UniqueID
}
func (st *searchTask) PreExecute(ctx context.Context) error {
sp, ctx := trace.StartSpanFromContextWithOperationName(st.TraceCtx(), "Proxy-Search-PreExecute")
defer sp.Finish()
st.Base.MsgType = commonpb.MsgType_Search
st.Base.SourceID = Params.ProxyCfg.ProxyID
collectionName := st.query.CollectionName
collID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
if err != nil { // err is not nil if collection not exists
return err
}
st.collectionID = collID
if err := validateCollectionName(st.query.CollectionName); err != nil {
return err
}
for _, tag := range st.query.PartitionNames {
if err := validatePartitionTag(tag, false); err != nil {
return err
}
}
// check if collection was already loaded into query node
showResp, err := st.qc.ShowCollections(st.ctx, &querypb.ShowCollectionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ShowCollections,
MsgID: st.Base.MsgID,
Timestamp: st.Base.Timestamp,
SourceID: Params.ProxyCfg.ProxyID,
},
DbID: 0, // TODO(dragondriver)
})
if err != nil {
return err
}
if showResp.Status.ErrorCode != commonpb.ErrorCode_Success {
return errors.New(showResp.Status.Reason)
}
log.Debug("successfully get collections from QueryCoord",
zap.String("target collection name", collectionName),
zap.Int64("target collection ID", collID),
zap.Any("collections", showResp.CollectionIDs),
)
collectionLoaded := false
for _, collectionID := range showResp.CollectionIDs {
if collectionID == collID {
collectionLoaded = true
break
}
}
if !collectionLoaded {
return fmt.Errorf("collection %v was not loaded into memory", collectionName)
}
// TODO(dragondriver): necessary to check if partition was loaded into query node?
st.Base.MsgType = commonpb.MsgType_Search
schema, _ := globalMetaCache.GetCollectionSchema(ctx, collectionName)
outputFields, err := translateOutputFields(st.query.OutputFields, schema, false)
if err != nil {
return err
}
log.Debug("translate output fields", zap.Any("OutputFields", outputFields))
st.query.OutputFields = outputFields
if st.query.GetDslType() == commonpb.DslType_BoolExprV1 {
annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, st.query.SearchParams)
if err != nil {
return errors.New(AnnsFieldKey + " not found in search_params")
}
topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, st.query.SearchParams)
if err != nil {
return errors.New(TopKKey + " not found in search_params")
}
topK, err := strconv.Atoi(topKStr)
if err != nil {
return errors.New(TopKKey + " " + topKStr + " is invalid")
}
metricType, err := funcutil.GetAttrByKeyFromRepeatedKV(MetricTypeKey, st.query.SearchParams)
if err != nil {
return errors.New(MetricTypeKey + " not found in search_params")
}
searchParams, err := funcutil.GetAttrByKeyFromRepeatedKV(SearchParamsKey, st.query.SearchParams)
if err != nil {
return errors.New(SearchParamsKey + " not found in search_params")
}
roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, st.query.SearchParams)
if err != nil {
roundDecimalStr = "-1"
}
roundDecimal, err := strconv.Atoi(roundDecimalStr)
if err != nil {
return errors.New(RoundDecimalKey + " " + roundDecimalStr + " is invalid")
}
if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) {
return errors.New(RoundDecimalKey + " " + roundDecimalStr + " is invalid")
}
queryInfo := &planpb.QueryInfo{
Topk: int64(topK),
MetricType: metricType,
SearchParams: searchParams,
RoundDecimal: int64(roundDecimal),
}
log.Debug("create query plan",
//zap.Any("schema", schema),
zap.String("dsl", st.query.Dsl),
zap.String("anns field", annsField),
zap.Any("query info", queryInfo))
plan, err := createQueryPlan(schema, st.query.Dsl, annsField, queryInfo)
if err != nil {
log.Debug("failed to create query plan",
zap.Error(err),
//zap.Any("schema", schema),
zap.String("dsl", st.query.Dsl),
zap.String("anns field", annsField),
zap.Any("query info", queryInfo))
return fmt.Errorf("failed to create query plan: %v", err)
}
for _, name := range st.query.OutputFields {
hitField := false
for _, field := range schema.Fields {
if field.Name == name {
if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector {
return errors.New("search doesn't support vector field as output_fields")
}
st.SearchRequest.OutputFieldsId = append(st.SearchRequest.OutputFieldsId, field.FieldID)
plan.OutputFieldIds = append(plan.OutputFieldIds, field.FieldID)
hitField = true
break
}
}
if !hitField {
errMsg := "Field " + name + " not exist"
return errors.New(errMsg)
}
}
st.SearchRequest.DslType = commonpb.DslType_BoolExprV1
st.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan)
if err != nil {
return err
}
log.Debug("Proxy::searchTask::PreExecute", zap.Any("plan.OutputFieldIds", plan.OutputFieldIds),
zap.Any("plan", plan.String()))
}
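// As in queryTask: a zero travel timestamp defaults to BeginTs, and traveling
// back is capped by the configured retention duration.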
travelTimestamp := st.query.TravelTimestamp
if travelTimestamp == 0 {
travelTimestamp = st.BeginTs()
} else {
durationSeconds := tsoutil.CalculateDuration(st.BeginTs(), travelTimestamp) / 1000
if durationSeconds > Params.CommonCfg.RetentionDuration {
duration := time.Second * time.Duration(durationSeconds)
return fmt.Errorf("only support to travel back to %s so far", duration.String())
}
}
guaranteeTimestamp := st.query.GuaranteeTimestamp
if guaranteeTimestamp == 0 {
guaranteeTimestamp = st.BeginTs()
}
st.SearchRequest.TravelTimestamp = travelTimestamp
st.SearchRequest.GuaranteeTimestamp = guaranteeTimestamp
deadline, ok := st.TraceCtx().Deadline()
if ok {
st.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
}
st.SearchRequest.ResultChannelID = Params.ProxyCfg.SearchResultChannelNames[0]
st.SearchRequest.DbID = 0 // todo
st.SearchRequest.CollectionID = collID
st.SearchRequest.PartitionIDs = make([]UniqueID, 0)
partitionsMap, err := globalMetaCache.GetPartitions(ctx, collectionName)
if err != nil {
return err
}
partitionsRecord := make(map[UniqueID]bool)
for _, partitionName := range st.query.PartitionNames {
pattern := fmt.Sprintf("^%s$", partitionName)
re, err := regexp.Compile(pattern)
if err != nil {
return errors.New("invalid partition names")
}
found := false
for name, pID := range partitionsMap {
if re.MatchString(name) {
if _, exist := partitionsRecord[pID]; !exist {
st.PartitionIDs = append(st.PartitionIDs, pID)
partitionsRecord[pID] = true
}
found = true
}
}
if !found {
errMsg := fmt.Sprintf("PartitonName: %s not found", partitionName)
return errors.New(errMsg)
}
}
st.SearchRequest.Dsl = st.query.Dsl
st.SearchRequest.PlaceholderGroup = st.query.PlaceholderGroup
return nil
}
func (st *searchTask) Execute(ctx context.Context) error {
sp, ctx := trace.StartSpanFromContextWithOperationName(st.TraceCtx(), "Proxy-Search-Execute")
defer sp.Finish()
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute search %d", st.ID()))
defer tr.Elapse("done")
var tsMsg msgstream.TsMsg = &msgstream.SearchMsg{
SearchRequest: *st.SearchRequest,
BaseMsg: msgstream.BaseMsg{
Ctx: ctx,
HashValues: []uint32{uint32(Params.ProxyCfg.ProxyID)},
BeginTimestamp: st.Base.Timestamp,
EndTimestamp: st.Base.Timestamp,
},
}
msgPack := msgstream.MsgPack{
BeginTs: st.Base.Timestamp,
EndTs: st.Base.Timestamp,
Msgs: make([]msgstream.TsMsg, 1),
}
msgPack.Msgs[0] = tsMsg
collectionName := st.query.CollectionName
info, err := globalMetaCache.GetCollectionInfo(ctx, collectionName)
if err != nil { // err is not nil if collection not exists
return err
}
st.collectionName = info.schema.Name
stream, err := st.chMgr.getDQLStream(info.collID)
if err != nil {
err = st.chMgr.createDQLStream(info.collID)
if err != nil {
st.result = &milvuspb.SearchResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}
return err
}
stream, err = st.chMgr.getDQLStream(info.collID)
if err != nil {
st.result = &milvuspb.SearchResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}
return err
}
}
tr.Record("get used message stream")
err = stream.Produce(&msgPack)
if err != nil {
log.Debug("proxy", zap.String("send search request failed", err.Error()))
}
st.tr.Record("send message done")
log.Debug("proxy sent one searchMsg",
zap.Int64("collectionID", st.CollectionID),
zap.Int64("msgID", tsMsg.ID()),
zap.Int("length of search msg", len(msgPack.Msgs)),
zap.Uint64("timeoutTs", st.SearchRequest.TimeoutTimestamp))
sendMsgDur := tr.Record("send search msg to message stream")
metrics.ProxySendMessageLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.ProxyID, 10),
metrics.SearchLabel).Observe(float64(sendMsgDur.Milliseconds()))
return err
}
func (st *searchTask) PostExecute(ctx context.Context) error {
sp, ctx := trace.StartSpanFromContextWithOperationName(st.TraceCtx(), "Proxy-Search-PostExecute")
defer sp.Finish()
tr := timerecord.NewTimeRecorder("searchTask PostExecute")
defer func() {
tr.Elapse("done")
}()
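// Loop until either the trace context expires or a final result is assembled
// from the search results reported by the query nodes.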
for {
select {
case <-st.TraceCtx().Done():
log.Debug("Proxy searchTask PostExecute Loop exit caused by ctx.Done", zap.Int64("taskID", st.ID()))
return fmt.Errorf("searchTask:wait to finish failed, timeout: %d", st.ID())
case searchResults := <-st.resultBuf:
// fmt.Println("searchResults: ", searchResults)
filterSearchResults := make([]*internalpb.SearchResults, 0)
var filterReason string
errNum := 0
for _, partialSearchResult := range searchResults {
if partialSearchResult.Status.ErrorCode == commonpb.ErrorCode_Success {
filterSearchResults = append(filterSearchResults, partialSearchResult)
// For debugging, please don't delete.
// printSearchResult(partialSearchResult)
} else {
errNum++
filterReason += partialSearchResult.Status.Reason + "\n"
}
}
log.Debug("Proxy Search PostExecute stage1",
zap.Any("len(filterSearchResults)", len(filterSearchResults)))
metrics.ProxyWaitForSearchResultLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.ProxyID, 10), metrics.SearchLabel).Observe(float64(st.tr.RecordSpan().Milliseconds()))
tr.Record("Proxy Search PostExecute stage1 done")
if len(filterSearchResults) <= 0 || errNum > 0 {
st.result = &milvuspb.SearchResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: filterReason,
},
CollectionName: st.collectionName,
}
return fmt.Errorf("QueryNode search fail, reason %s: id %d", filterReason, st.ID())
}
tr.Record("decodeResultStart")
validSearchResults, err := decodeSearchResults(filterSearchResults)
if err != nil {
return err
}
metrics.ProxyDecodeSearchResultLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.ProxyID, 10), metrics.SearchLabel).Observe(float64(tr.RecordSpan().Milliseconds()))
log.Debug("Proxy Search PostExecute stage2", zap.Any("len(validSearchResults)", len(validSearchResults)))
if len(validSearchResults) <= 0 {
filterReason += "empty search result\n"
log.Debug("Proxy Search PostExecute stage2 failed", zap.Any("filterReason", filterReason))
st.result = &milvuspb.SearchResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: filterReason,
},
Results: &schemapb.SearchResultData{
NumQueries: searchResults[0].NumQueries,
Topks: make([]int64, searchResults[0].NumQueries),
},
CollectionName: st.collectionName,
}
return nil
}
tr.Record("reduceResultStart")
st.result, err = reduceSearchResultData(validSearchResults, searchResults[0].NumQueries, searchResults[0].TopK, searchResults[0].MetricType)
if err != nil {
return err
}
metrics.ProxyReduceSearchResultLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.ProxyID, 10), metrics.SuccessLabel).Observe(float64(tr.RecordSpan().Milliseconds()))
st.result.CollectionName = st.collectionName
schema, err := globalMetaCache.GetCollectionSchema(ctx, st.query.CollectionName)
if err != nil {
return err
}
if len(st.query.OutputFields) != 0 && len(st.result.Results.FieldsData) != 0 {
for k, fieldName := range st.query.OutputFields {
for _, field := range schema.Fields {
if st.result.Results.FieldsData[k] != nil && field.Name == fieldName {
st.result.Results.FieldsData[k].FieldName = field.Name
st.result.Results.FieldsData[k].FieldId = field.FieldID
st.result.Results.FieldsData[k].Type = field.DataType
}
}
}
}
return nil
}
}
}
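// decodeSearchResults unmarshals every non-empty SlicedBlob into a
// schemapb.SearchResultData.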
func decodeSearchResults(searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) {
tr := timerecord.NewTimeRecorder("decodeSearchResults")
results := make([]*schemapb.SearchResultData, 0)
for _, partialSearchResult := range searchResults {
if partialSearchResult.SlicedBlob == nil {
continue
}
var partialResultData schemapb.SearchResultData
err := proto.Unmarshal(partialSearchResult.SlicedBlob, &partialResultData)
if err != nil {
return nil, err
}
results = append(results, &partialResultData)
}
tr.Elapse("decodeSearchResults done")
return results, nil
}
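// checkSearchResultData verifies that a partial result carries exactly nq
// queries and topk hits per query, with matching id and score lengths.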
func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64) error {
if data.NumQueries != nq {
return fmt.Errorf("search result's nq(%d) mis-match with %d", data.NumQueries, nq)
}
if data.TopK != topk {
return fmt.Errorf("search result's topk(%d) mis-match with %d", data.TopK, topk)
}
if len(data.Ids.GetIntId().Data) != (int)(nq*topk) {
return fmt.Errorf("search result's id length %d invalid", len(data.Ids.GetIntId().Data))
}
if len(data.Scores) != (int)(nq*topk) {
return fmt.Errorf("search result's score length %d invalid", len(data.Scores))
}
return nil
}
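// selectSearchResultData returns the index of the partial result that
// currently offers the highest score for query qi, or -1 when all offsets
// are exhausted.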
func selectSearchResultData(dataArray []*schemapb.SearchResultData, offsets []int64, topk int64, qi int64) int {
sel := -1
maxDistance := minFloat32
for i, offset := range offsets { // query num, the number of ways to merge
if offset >= topk {
continue
}
idx := qi*topk + offset
id := dataArray[i].Ids.GetIntId().Data[idx]
if id != -1 {
distance := dataArray[i].Scores[idx]
if distance > maxDistance {
sel = i
maxDistance = distance
}
}
}
return sel
}
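// reduceSearchResultData performs a k-way merge of the partial results by
// score, de-duplicating primary keys, and negates scores for metric types
// where smaller distances are better.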
func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string) (*milvuspb.SearchResults, error) {
tr := timerecord.NewTimeRecorder("reduceSearchResultData")
defer func() {
tr.Elapse("done")
}()
log.Debug("reduceSearchResultData", zap.Int("len(searchResultData)", len(searchResultData)),
zap.Int64("nq", nq), zap.Int64("topk", topk), zap.String("metricType", metricType))
ret := &milvuspb.SearchResults{
Status: &commonpb.Status{
ErrorCode: 0,
},
Results: &schemapb.SearchResultData{
NumQueries: nq,
TopK: topk,
FieldsData: make([]*schemapb.FieldData, len(searchResultData[0].FieldsData)),
Scores: make([]float32, 0),
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: make([]int64, 0),
},
},
},
Topks: make([]int64, 0),
},
}
for i, sData := range searchResultData {
log.Debug("reduceSearchResultData",
zap.Int("i", i),
zap.Int64("nq", sData.NumQueries),
zap.Int64("topk", sData.TopK),
zap.Any("len(FieldsData)", len(sData.FieldsData)))
if err := checkSearchResultData(sData, nq, topk); err != nil {
return ret, err
}
//printSearchResultData(sData, strconv.FormatInt(int64(i), 10))
}
var skipDupCnt int64
var realTopK int64 = -1
for i := int64(0); i < nq; i++ {
offsets := make([]int64, len(searchResultData))
var idSet = make(map[int64]struct{})
var j int64
for j = 0; j < topk; {
sel := selectSearchResultData(searchResultData, offsets, topk, i)
if sel == -1 {
break
}
idx := i*topk + offsets[sel]
id := searchResultData[sel].Ids.GetIntId().Data[idx]
score := searchResultData[sel].Scores[idx]
// ignore invalid search result
if id == -1 {
continue
}
// remove duplicates
if _, ok := idSet[id]; !ok {
typeutil.AppendFieldData(ret.Results.FieldsData, searchResultData[sel].FieldsData, idx)
ret.Results.Ids.GetIntId().Data = append(ret.Results.Ids.GetIntId().Data, id)
ret.Results.Scores = append(ret.Results.Scores, score)
idSet[id] = struct{}{}
j++
} else {
// skip entity with same id
skipDupCnt++
}
offsets[sel]++
}
if realTopK != -1 && realTopK != j {
log.Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
// return nil, errors.New("the length (topk) between all result of query is different")
}
realTopK = j
ret.Results.Topks = append(ret.Results.Topks, realTopK)
}
log.Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))
ret.Results.TopK = realTopK
if !distance.PositivelyRelated(metricType) {
for k := range ret.Results.Scores {
ret.Results.Scores[k] *= -1
}
}
return ret, nil
}
//func printSearchResultData(data *schemapb.SearchResultData, header string) {
// size := len(data.Ids.GetIntId().Data)
// if size != len(data.Scores) {
// log.Error("SearchResultData length mis-match")
// }
// log.Debug("==== SearchResultData ====",
// zap.String("header", header), zap.Int64("nq", data.NumQueries), zap.Int64("topk", data.TopK))
// for i := 0; i < size; i++ {
// log.Debug("", zap.Int("i", i), zap.Int64("id", data.Ids.GetIntId().Data[i]), zap.Float32("score", data.Scores[i]))
// }
//}
// func printSearchResult(partialSearchResult *internalpb.SearchResults) {
// for i := 0; i < len(partialSearchResult.Hits); i++ {
// testHits := milvuspb.Hits{}
// err := proto.Unmarshal(partialSearchResult.Hits[i], &testHits)
// if err != nil {
// panic(err)
// }
// fmt.Println(testHits.IDs)
// fmt.Println(testHits.Scores)
// }
// }
func (st *searchTask) TraceCtx() context.Context {
return st.ctx
}
func (st *searchTask) ID() UniqueID {
return st.Base.MsgID
}
func (st *searchTask) SetID(uid UniqueID) {
st.Base.MsgID = uid
}
func (st *searchTask) Name() string {
return SearchTaskName
}
func (st *searchTask) Type() commonpb.MsgType {
return st.Base.MsgType
}
func (st *searchTask) BeginTs() Timestamp {
return st.Base.Timestamp
}
func (st *searchTask) EndTs() Timestamp {
return st.Base.Timestamp
}
func (st *searchTask) SetTs(ts Timestamp) {
st.Base.Timestamp = ts
}
func (st *searchTask) OnEnqueue() error {
st.Base = &commonpb.MsgBase{}
st.Base.MsgType = commonpb.MsgType_Search
st.Base.SourceID = Params.ProxyCfg.ProxyID
return nil
}
func (st *searchTask) getChannels() ([]pChan, error) {
collID, err := globalMetaCache.GetCollectionID(st.ctx, st.query.CollectionName)
if err != nil {
return nil, err
}
var channels []pChan
channels, err = st.chMgr.getChannels(collID)
if err != nil {
err := st.chMgr.createDMLMsgStream(collID)
if err != nil {
return nil, err
}
return st.chMgr.getChannels(collID)
}
return channels, nil
}
func (st *searchTask) getVChannels() ([]vChan, error) {
collID, err := globalMetaCache.GetCollectionID(st.ctx, st.query.CollectionName)
if err != nil {
return nil, err
}
var channels []vChan
channels, err = st.chMgr.getVChannels(collID)
if err != nil {
err := st.chMgr.createDMLMsgStream(collID)
if err != nil {
return nil, err
}
return st.chMgr.getVChannels(collID)
}
return channels, nil
}

task_search_test.go: file diff suppressed because it is too large.

task_test.go: file diff suppressed because it is too large.