Refactor the workflow of receiving search result from query node (#5527)

Signed-off-by: dragondriver <jiquan.long@zilliz.com>
This commit is contained in:
dragondriver 2021-06-02 10:17:32 +08:00 committed by zhenshan.cao
parent 54ab03e28f
commit 31b400a9f7
3 changed files with 102 additions and 15 deletions

View File

@ -1130,6 +1130,7 @@ func (node *ProxyNode) Search(ctx context.Context, request *milvuspb.SearchReque
queryMsgStream: node.queryMsgStream,
resultBuf: make(chan []*internalpb.SearchResults),
query: request,
chMgr: node.chMgr,
}
err := node.sched.DqQueue.Enqueue(qt)

View File

@ -104,6 +104,11 @@ type dmlTask interface {
getStatistics(pchan pChan) (pChanStatistics, error)
}
type dqlTask interface {
task
getVChannels() ([]vChan, error)
}
type BaseInsertTask = msgstream.InsertMsg
type InsertTask struct {
@ -978,6 +983,7 @@ type SearchTask struct {
resultBuf chan []*internalpb.SearchResults
result *milvuspb.SearchResults
query *milvuspb.SearchRequest
chMgr channelsMgr
}
func (st *SearchTask) TraceCtx() context.Context {
@ -1017,6 +1023,23 @@ func (st *SearchTask) OnEnqueue() error {
return nil
}
func (st *SearchTask) getVChannels() ([]vChan, error) {
collID, err := globalMetaCache.GetCollectionID(st.ctx, st.query.CollectionName)
if err != nil {
return nil, err
}
_, err = st.chMgr.getChannels(collID)
if err != nil {
err := st.chMgr.createDMLMsgStream(collID)
if err != nil {
return nil, err
}
}
return st.chMgr.getVChannels(collID)
}
func (st *SearchTask) PreExecute(ctx context.Context) error {
st.Base.MsgType = commonpb.MsgType_Search
st.Base.SourceID = Params.ProxyID

View File

@ -423,6 +423,63 @@ func (sched *TaskScheduler) queryLoop() {
}
}
type searchResultBuf struct {
usedVChans map[interface{}]struct{} // set of vChan
receivedVChansSet map[interface{}]struct{} // set of vChan
receivedSealedSegmentIDsSet map[interface{}]struct{} // set of UniqueID
receivedGlobalSegmentIDsSet map[interface{}]struct{} // set of UniqueID
resultBuf []*internalpb.SearchResults
}
func newSearchResultBuf() *searchResultBuf {
return &searchResultBuf{
usedVChans: make(map[interface{}]struct{}),
receivedVChansSet: make(map[interface{}]struct{}),
receivedSealedSegmentIDsSet: make(map[interface{}]struct{}),
receivedGlobalSegmentIDsSet: make(map[interface{}]struct{}),
resultBuf: make([]*internalpb.SearchResults, 0),
}
}
func setContain(m1, m2 map[interface{}]struct{}) bool {
if len(m1) < len(m2) {
return false
}
for k2 := range m2 {
_, ok := m1[k2]
if !ok {
return false
}
}
return true
}
func (sr *searchResultBuf) readyToReduce() bool {
if !setContain(sr.receivedVChansSet, sr.usedVChans) {
return false
}
return setContain(sr.receivedSealedSegmentIDsSet, sr.receivedGlobalSegmentIDsSet)
}
func (sr *searchResultBuf) addPartialResult(result *internalpb.SearchResults) {
sr.resultBuf = append(sr.resultBuf, result)
for _, vchan := range result.ChannelIDsSearched {
sr.receivedVChansSet[vchan] = struct{}{}
}
for _, sealedSegment := range result.SealedSegmentIDsSearched {
sr.receivedSealedSegmentIDsSet[sealedSegment] = struct{}{}
}
for _, globalSegment := range result.GlobalSealedSegmentIDs {
sr.receivedGlobalSegmentIDsSet[globalSegment] = struct{}{}
}
}
func (sched *TaskScheduler) queryResultLoop() {
defer sched.wg.Done()
@ -436,7 +493,7 @@ func (sched *TaskScheduler) queryResultLoop() {
queryResultMsgStream.Start()
defer queryResultMsgStream.Close()
queryResultBuf := make(map[UniqueID][]*internalpb.SearchResults)
queryResultBuf := make(map[UniqueID]*searchResultBuf)
retrieveResultBuf := make(map[UniqueID][]*internalpb.RetrieveResults)
for {
@ -462,30 +519,36 @@ func (sched *TaskScheduler) queryResultLoop() {
continue
}
st, ok := t.(*SearchTask)
if !ok {
delete(queryResultBuf, reqID)
continue
}
_, ok = queryResultBuf[reqID]
if !ok {
queryResultBuf[reqID] = make([]*internalpb.SearchResults, 0)
queryResultBuf[reqID] = newSearchResultBuf()
vchans, err := st.getVChannels()
if err != nil {
delete(queryResultBuf, reqID)
continue
}
for _, vchan := range vchans {
queryResultBuf[reqID].usedVChans[vchan] = struct{}{}
}
}
queryResultBuf[reqID] = append(queryResultBuf[reqID], &searchResultMsg.SearchResults)
queryResultBuf[reqID].addPartialResult(&searchResultMsg.SearchResults)
//t := sched.getTaskByReqID(reqID)
{
colName := t.(*SearchTask).query.CollectionName
log.Debug("Getcollection", zap.String("collection name", colName), zap.String("reqID", reqIDStr), zap.Int("answer cnt", len(queryResultBuf[reqID])))
log.Debug("Getcollection", zap.String("collection name", colName), zap.String("reqID", reqIDStr), zap.Int("answer cnt", len(queryResultBuf[reqID].resultBuf)))
}
if len(queryResultBuf[reqID]) == queryNodeNum {
t := sched.getTaskByReqID(reqID)
if t != nil {
qt, ok := t.(*SearchTask)
if ok {
qt.resultBuf <- queryResultBuf[reqID]
delete(queryResultBuf, reqID)
}
} else {
// log.Printf("task with reqID %v is nil", reqID)
}
if queryResultBuf[reqID].readyToReduce() {
st.resultBuf <- queryResultBuf[reqID].resultBuf
}
sp.Finish()
}
if retrieveResultMsg, rtOk := tsMsg.(*msgstream.RetrieveResultMsg); rtOk {