mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-03 12:29:36 +08:00
cdbc6d2c94
Signed-off-by: xige-16 <xi.ge@zilliz.com> Signed-off-by: bigsheeper <yihao.dai@zilliz.com> Co-authored-by: xige-16 <xi.ge@zilliz.com> Co-authored-by: yudong.cai <yudong.cai@zilliz.com>
184 lines
5.3 KiB
Go
184 lines
5.3 KiB
Go
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
|
// with the License. You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
|
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
|
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
|
|
|
package querynode
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"strconv"
|
|
|
|
"go.uber.org/zap"
|
|
|
|
"github.com/milvus-io/milvus/internal/log"
|
|
"github.com/milvus-io/milvus/internal/msgstream"
|
|
"github.com/milvus-io/milvus/internal/util/trace"
|
|
)
|
|
|
|
type retrieveService struct {
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
|
|
historical *historical
|
|
streaming *streaming
|
|
|
|
retrieveMsgStream msgstream.MsgStream
|
|
retrieveResultMsgStream msgstream.MsgStream
|
|
|
|
queryNodeID UniqueID
|
|
retrieveCollections map[UniqueID]*retrieveCollection
|
|
}
|
|
|
|
func newRetrieveService(ctx context.Context,
|
|
historical *historical,
|
|
streaming *streaming,
|
|
factory msgstream.Factory) *retrieveService {
|
|
|
|
retrieveStream, _ := factory.NewQueryMsgStream(ctx)
|
|
retrieveResultStream, _ := factory.NewQueryMsgStream(ctx)
|
|
|
|
if len(Params.SearchChannelNames) > 0 && len(Params.SearchResultChannelNames) > 0 {
|
|
consumeChannels := Params.SearchChannelNames
|
|
consumeSubName := "RetrieveSubName"
|
|
retrieveStream.AsConsumer(consumeChannels, consumeSubName)
|
|
log.Debug("query node AsConsumer", zap.Any("retrieveChannels", consumeChannels), zap.Any("consumeSubName", consumeSubName))
|
|
producerChannels := Params.SearchResultChannelNames
|
|
retrieveResultStream.AsProducer(producerChannels)
|
|
log.Debug("query node AsProducer", zap.Any("retrieveResultChannels", producerChannels))
|
|
}
|
|
|
|
retrieveServiceCtx, retrieveServiceCancel := context.WithCancel(ctx)
|
|
return &retrieveService{
|
|
ctx: retrieveServiceCtx,
|
|
cancel: retrieveServiceCancel,
|
|
|
|
historical: historical,
|
|
streaming: streaming,
|
|
|
|
retrieveMsgStream: retrieveStream,
|
|
retrieveResultMsgStream: retrieveResultStream,
|
|
|
|
queryNodeID: Params.QueryNodeID,
|
|
retrieveCollections: make(map[UniqueID]*retrieveCollection),
|
|
}
|
|
}
|
|
|
|
func (rs *retrieveService) start() {
|
|
rs.retrieveMsgStream.Start()
|
|
rs.retrieveResultMsgStream.Start()
|
|
rs.consumeRetrieve()
|
|
}
|
|
|
|
func (rs *retrieveService) collectionCheck(collectionID UniqueID) error {
|
|
if ok := rs.historical.replica.hasCollection(collectionID); !ok {
|
|
err := errors.New("no collection found, collectionID = " + strconv.FormatInt(collectionID, 10))
|
|
log.Error(err.Error())
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (rs *retrieveService) consumeRetrieve() {
|
|
for {
|
|
select {
|
|
case <-rs.ctx.Done():
|
|
return
|
|
default:
|
|
msgPack := rs.retrieveMsgStream.Consume()
|
|
if msgPack == nil || len(msgPack.Msgs) <= 0 {
|
|
continue
|
|
}
|
|
for _, msg := range msgPack.Msgs {
|
|
rm, ok := msg.(*msgstream.RetrieveMsg)
|
|
if !ok {
|
|
// Not a retrieve request, discard
|
|
continue
|
|
}
|
|
log.Info("RetrieveService consume retrieve message",
|
|
zap.Int64("collectionID", rm.CollectionID),
|
|
zap.Int64("requestID", msg.ID()),
|
|
zap.Any("requestType", "retrieve"),
|
|
)
|
|
|
|
sp, ctx := trace.StartSpanFromContext(rm.TraceCtx())
|
|
rm.SetTraceCtx(ctx)
|
|
err := rs.collectionCheck(rm.CollectionID)
|
|
if err != nil {
|
|
log.Debug("Failed to check collection exist, discard.",
|
|
zap.Int64("collectionID", rm.CollectionID),
|
|
zap.Int64("requestID", msg.ID()),
|
|
zap.Any("requestType", "retrieve"),
|
|
)
|
|
continue
|
|
}
|
|
|
|
_, ok = rs.retrieveCollections[rm.CollectionID]
|
|
if !ok {
|
|
rs.startRetrieveCollection(rm.CollectionID)
|
|
log.Debug("Receive retrieve request on new collection, start an new retrieve collection service",
|
|
zap.Int64("collectionID", rm.CollectionID),
|
|
zap.Int64("requestID", msg.ID()),
|
|
zap.Any("requestType", "retrieve"),
|
|
)
|
|
}
|
|
|
|
rs.retrieveCollections[rm.CollectionID].msgBuffer <- rm
|
|
log.Info("Put retrieve msg into msgBuffer",
|
|
zap.Any("requestID", msg.ID),
|
|
zap.Any("requestType", "retrieve"),
|
|
)
|
|
sp.Finish()
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (rs *retrieveService) close() {
|
|
if rs.retrieveMsgStream != nil {
|
|
rs.retrieveMsgStream.Close()
|
|
}
|
|
if rs.retrieveResultMsgStream != nil {
|
|
rs.retrieveResultMsgStream.Close()
|
|
}
|
|
for collectionID := range rs.retrieveCollections {
|
|
rs.stopRetrieveCollection(collectionID)
|
|
}
|
|
rs.retrieveCollections = make(map[UniqueID]*retrieveCollection)
|
|
rs.cancel()
|
|
}
|
|
|
|
func (rs *retrieveService) startRetrieveCollection(collectionID UniqueID) {
|
|
ctx1, cancel := context.WithCancel(rs.ctx)
|
|
rc := newRetrieveCollection(ctx1,
|
|
cancel,
|
|
collectionID,
|
|
rs.historical,
|
|
rs.streaming,
|
|
rs.retrieveResultMsgStream)
|
|
rs.retrieveCollections[collectionID] = rc
|
|
rc.start()
|
|
}
|
|
|
|
func (rs *retrieveService) hasRetrieveCollection(collectionID UniqueID) bool {
|
|
_, ok := rs.retrieveCollections[collectionID]
|
|
return ok
|
|
}
|
|
|
|
func (rs *retrieveService) stopRetrieveCollection(collectionID UniqueID) {
|
|
rc, ok := rs.retrieveCollections[collectionID]
|
|
if !ok {
|
|
log.Error("stopRetrieveCollection failed, collection doesn't exist", zap.Int64("collectionID", collectionID))
|
|
}
|
|
rc.cancel()
|
|
delete(rs.retrieveCollections, collectionID)
|
|
}
|