Fix retrieve_service (#5531)

* Fix retrieve_service

Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Fix retrieve_collection

Signed-off-by: fishpenguin <kun.yu@zilliz.com>
This commit is contained in:
yukun 2021-06-02 19:18:33 +08:00 committed by GitHub
parent 56da071cce
commit dcb4161c9f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 20 additions and 10 deletions

View File

@ -13,6 +13,7 @@ package msgstream
import ( import (
"errors" "errors"
"strconv"
"github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/internalpb"
@ -127,7 +128,7 @@ func DefaultRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack,
for i, request := range tsMsgs { for i, request := range tsMsgs {
keys := hashKeys[i] keys := hashKeys[i]
if len(keys) != 1 { if len(keys) != 1 {
return nil, errors.New("len(msg.hashValue) must equal 1") return nil, errors.New("len(msg.hashValue) must equal 1, but it is: " + strconv.Itoa(len(keys)))
} }
key := keys[0] key := keys[0]
_, ok := result[key] _, ok := result[key]

View File

@ -401,8 +401,6 @@ func (sched *TaskScheduler) queryResultLoop() {
queryResultMsgStream, _ := sched.msFactory.NewQueryMsgStream(sched.ctx) queryResultMsgStream, _ := sched.msFactory.NewQueryMsgStream(sched.ctx)
queryResultMsgStream.AsConsumer(Params.SearchResultChannelNames, Params.ProxySubName) queryResultMsgStream.AsConsumer(Params.SearchResultChannelNames, Params.ProxySubName)
log.Debug("proxynode", zap.Strings("search result channel names", Params.SearchResultChannelNames)) log.Debug("proxynode", zap.Strings("search result channel names", Params.SearchResultChannelNames))
queryResultMsgStream.AsConsumer(Params.RetrieveResultChannelNames, Params.ProxySubName)
log.Debug("proxynode", zap.Strings("Retrieve result channel names", Params.RetrieveResultChannelNames))
log.Debug("proxynode", zap.String("proxySubName", Params.ProxySubName)) log.Debug("proxynode", zap.String("proxySubName", Params.ProxySubName))
queryNodeNum := Params.QueryNodeNum queryNodeNum := Params.QueryNodeNum

View File

@ -85,11 +85,13 @@ func (node *QueryNode) AddQueryChannel(ctx context.Context, in *queryPb.AddQuery
consumeChannels := []string{in.RequestChannelID} consumeChannels := []string{in.RequestChannelID}
consumeSubName := Params.MsgChannelSubName consumeSubName := Params.MsgChannelSubName
node.searchService.searchMsgStream.AsConsumer(consumeChannels, consumeSubName) node.searchService.searchMsgStream.AsConsumer(consumeChannels, consumeSubName)
node.retrieveService.retrieveMsgStream.AsConsumer(consumeChannels, "RetrieveSubName")
log.Debug("querynode AsConsumer: " + strings.Join(consumeChannels, ", ") + " : " + consumeSubName) log.Debug("querynode AsConsumer: " + strings.Join(consumeChannels, ", ") + " : " + consumeSubName)
// add result channel // add result channel
producerChannels := []string{in.ResultChannelID} producerChannels := []string{in.ResultChannelID}
node.searchService.searchResultMsgStream.AsProducer(producerChannels) node.searchService.searchResultMsgStream.AsProducer(producerChannels)
node.retrieveService.retrieveResultMsgStream.AsProducer(producerChannels)
log.Debug("querynode AsProducer: " + strings.Join(producerChannels, ", ")) log.Debug("querynode AsProducer: " + strings.Join(producerChannels, ", "))
status := &commonpb.Status{ status := &commonpb.Status{

View File

@ -69,6 +69,8 @@ func newRetrieveCollection(releaseCtx context.Context,
msgBuffer: msgBuffer, msgBuffer: msgBuffer,
unsolvedMsg: unsolvedMsg, unsolvedMsg: unsolvedMsg,
retrieveResultMsgStream: retrieveResultStream,
} }
rc.register(collectionID) rc.register(collectionID)
@ -103,6 +105,11 @@ func (rc *retrieveCollection) waitNewTSafe() Timestamp {
return ts return ts
} }
func (rc *retrieveCollection) start() {
go rc.receiveRetrieveMsg()
go rc.doUnsolvedMsgRetrieve()
}
func (rc *retrieveCollection) register(collectionID UniqueID) { func (rc *retrieveCollection) register(collectionID UniqueID) {
vChannel := collectionIDToChannel(collectionID) vChannel := collectionIDToChannel(collectionID)
rc.tSafeReplica.addTSafe(vChannel) rc.tSafeReplica.addTSafe(vChannel)
@ -237,8 +244,9 @@ func (rc *retrieveCollection) doUnsolvedMsgRetrieve() {
func (rc *retrieveCollection) retrieve(retrieveMsg *msgstream.RetrieveMsg) error { func (rc *retrieveCollection) retrieve(retrieveMsg *msgstream.RetrieveMsg) error {
// TODO(yukun) // TODO(yukun)
resultChannelInt := 0
retrieveResultMsg := &msgstream.RetrieveResultMsg{ retrieveResultMsg := &msgstream.RetrieveResultMsg{
BaseMsg: msgstream.BaseMsg{Ctx: retrieveMsg.Ctx}, BaseMsg: msgstream.BaseMsg{Ctx: retrieveMsg.Ctx, HashValues: []uint32{uint32(resultChannelInt)}},
RetrieveResults: internalpb.RetrieveResults{ RetrieveResults: internalpb.RetrieveResults{
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_RetrieveResult, MsgType: commonpb.MsgType_RetrieveResult,

View File

@ -42,15 +42,16 @@ func newRetrieveService(ctx context.Context,
streamingReplica ReplicaInterface, streamingReplica ReplicaInterface,
tSafeReplica TSafeReplicaInterface, tSafeReplica TSafeReplicaInterface,
factory msgstream.Factory) *retrieveService { factory msgstream.Factory) *retrieveService {
retrieveStream, _ := factory.NewQueryMsgStream(ctx) retrieveStream, _ := factory.NewQueryMsgStream(ctx)
retrieveResultStream, _ := factory.NewQueryMsgStream(ctx) retrieveResultStream, _ := factory.NewQueryMsgStream(ctx)
if len(Params.RetrieveChannelNames) > 0 && len(Params.RetrieveResultChannelNames) > 0 { if len(Params.SearchChannelNames) > 0 && len(Params.SearchResultChannelNames) > 0 {
consumeChannels := Params.RetrieveChannelNames consumeChannels := Params.SearchChannelNames
consumeSubName := Params.MsgChannelSubName consumeSubName := "RetrieveSubName"
retrieveStream.AsConsumer(consumeChannels, consumeSubName) retrieveStream.AsConsumer(consumeChannels, consumeSubName)
log.Debug("query node AdConsumer", zap.Any("retrieveChannels", consumeChannels), zap.Any("consumeSubName", consumeSubName)) log.Debug("query node AsConsumer", zap.Any("retrieveChannels", consumeChannels), zap.Any("consumeSubName", consumeSubName))
producerChannels := Params.RetrieveChannelNames producerChannels := Params.SearchResultChannelNames
retrieveResultStream.AsProducer(producerChannels) retrieveResultStream.AsProducer(producerChannels)
log.Debug("query node AsProducer", zap.Any("retrieveResultChannels", producerChannels)) log.Debug("query node AsProducer", zap.Any("retrieveResultChannels", producerChannels))
} }
@ -147,7 +148,7 @@ func (rs *retrieveService) startRetrieveCollection(collectionID UniqueID) {
rs.tSafeReplica, rs.tSafeReplica,
rs.retrieveResultMsgStream) rs.retrieveResultMsgStream)
rs.retrieveCollections[collectionID] = rc rs.retrieveCollections[collectionID] = rc
rs.start() rc.start()
} }
func (rs *retrieveService) hasRetrieveCollection(collectionID UniqueID) bool { func (rs *retrieveService) hasRetrieveCollection(collectionID UniqueID) bool {