milvus/internal/querynode/retrieve_service.go
bigsheeper cdbc6d2c94
Refactor query node and query service (#5751)
Signed-off-by: xige-16 <xi.ge@zilliz.com>
Signed-off-by: bigsheeper <yihao.dai@zilliz.com>

Co-authored-by: xige-16 <xi.ge@zilliz.com>
Co-authored-by: yudong.cai <yudong.cai@zilliz.com>
2021-06-15 12:41:40 +08:00

184 lines
5.3 KiB
Go

// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package querynode
import (
"context"
"errors"
"strconv"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/msgstream"
"github.com/milvus-io/milvus/internal/util/trace"
)
type retrieveService struct {
ctx context.Context
cancel context.CancelFunc
historical *historical
streaming *streaming
retrieveMsgStream msgstream.MsgStream
retrieveResultMsgStream msgstream.MsgStream
queryNodeID UniqueID
retrieveCollections map[UniqueID]*retrieveCollection
}
func newRetrieveService(ctx context.Context,
historical *historical,
streaming *streaming,
factory msgstream.Factory) *retrieveService {
retrieveStream, _ := factory.NewQueryMsgStream(ctx)
retrieveResultStream, _ := factory.NewQueryMsgStream(ctx)
if len(Params.SearchChannelNames) > 0 && len(Params.SearchResultChannelNames) > 0 {
consumeChannels := Params.SearchChannelNames
consumeSubName := "RetrieveSubName"
retrieveStream.AsConsumer(consumeChannels, consumeSubName)
log.Debug("query node AsConsumer", zap.Any("retrieveChannels", consumeChannels), zap.Any("consumeSubName", consumeSubName))
producerChannels := Params.SearchResultChannelNames
retrieveResultStream.AsProducer(producerChannels)
log.Debug("query node AsProducer", zap.Any("retrieveResultChannels", producerChannels))
}
retrieveServiceCtx, retrieveServiceCancel := context.WithCancel(ctx)
return &retrieveService{
ctx: retrieveServiceCtx,
cancel: retrieveServiceCancel,
historical: historical,
streaming: streaming,
retrieveMsgStream: retrieveStream,
retrieveResultMsgStream: retrieveResultStream,
queryNodeID: Params.QueryNodeID,
retrieveCollections: make(map[UniqueID]*retrieveCollection),
}
}
func (rs *retrieveService) start() {
rs.retrieveMsgStream.Start()
rs.retrieveResultMsgStream.Start()
rs.consumeRetrieve()
}
func (rs *retrieveService) collectionCheck(collectionID UniqueID) error {
if ok := rs.historical.replica.hasCollection(collectionID); !ok {
err := errors.New("no collection found, collectionID = " + strconv.FormatInt(collectionID, 10))
log.Error(err.Error())
return err
}
return nil
}
func (rs *retrieveService) consumeRetrieve() {
for {
select {
case <-rs.ctx.Done():
return
default:
msgPack := rs.retrieveMsgStream.Consume()
if msgPack == nil || len(msgPack.Msgs) <= 0 {
continue
}
for _, msg := range msgPack.Msgs {
rm, ok := msg.(*msgstream.RetrieveMsg)
if !ok {
// Not a retrieve request, discard
continue
}
log.Info("RetrieveService consume retrieve message",
zap.Int64("collectionID", rm.CollectionID),
zap.Int64("requestID", msg.ID()),
zap.Any("requestType", "retrieve"),
)
sp, ctx := trace.StartSpanFromContext(rm.TraceCtx())
rm.SetTraceCtx(ctx)
err := rs.collectionCheck(rm.CollectionID)
if err != nil {
log.Debug("Failed to check collection exist, discard.",
zap.Int64("collectionID", rm.CollectionID),
zap.Int64("requestID", msg.ID()),
zap.Any("requestType", "retrieve"),
)
continue
}
_, ok = rs.retrieveCollections[rm.CollectionID]
if !ok {
rs.startRetrieveCollection(rm.CollectionID)
log.Debug("Receive retrieve request on new collection, start an new retrieve collection service",
zap.Int64("collectionID", rm.CollectionID),
zap.Int64("requestID", msg.ID()),
zap.Any("requestType", "retrieve"),
)
}
rs.retrieveCollections[rm.CollectionID].msgBuffer <- rm
log.Info("Put retrieve msg into msgBuffer",
zap.Any("requestID", msg.ID),
zap.Any("requestType", "retrieve"),
)
sp.Finish()
}
}
}
}
func (rs *retrieveService) close() {
if rs.retrieveMsgStream != nil {
rs.retrieveMsgStream.Close()
}
if rs.retrieveResultMsgStream != nil {
rs.retrieveResultMsgStream.Close()
}
for collectionID := range rs.retrieveCollections {
rs.stopRetrieveCollection(collectionID)
}
rs.retrieveCollections = make(map[UniqueID]*retrieveCollection)
rs.cancel()
}
func (rs *retrieveService) startRetrieveCollection(collectionID UniqueID) {
ctx1, cancel := context.WithCancel(rs.ctx)
rc := newRetrieveCollection(ctx1,
cancel,
collectionID,
rs.historical,
rs.streaming,
rs.retrieveResultMsgStream)
rs.retrieveCollections[collectionID] = rc
rc.start()
}
func (rs *retrieveService) hasRetrieveCollection(collectionID UniqueID) bool {
_, ok := rs.retrieveCollections[collectionID]
return ok
}
func (rs *retrieveService) stopRetrieveCollection(collectionID UniqueID) {
rc, ok := rs.retrieveCollections[collectionID]
if !ok {
log.Error("stopRetrieveCollection failed, collection doesn't exist", zap.Int64("collectionID", collectionID))
}
rc.cancel()
delete(rs.retrieveCollections, collectionID)
}