mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-01 11:29:48 +08:00
Fix retrieve_service (#5531)
* Fix retrieve_service Signed-off-by: fishpenguin <kun.yu@zilliz.com> * Fix retrieve_collection Signed-off-by: fishpenguin <kun.yu@zilliz.com>
This commit is contained in:
parent
56da071cce
commit
dcb4161c9f
@ -13,6 +13,7 @@ package msgstream
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
@ -127,7 +128,7 @@ func DefaultRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack,
|
||||
for i, request := range tsMsgs {
|
||||
keys := hashKeys[i]
|
||||
if len(keys) != 1 {
|
||||
return nil, errors.New("len(msg.hashValue) must equal 1")
|
||||
return nil, errors.New("len(msg.hashValue) must equal 1, but it is: " + strconv.Itoa(len(keys)))
|
||||
}
|
||||
key := keys[0]
|
||||
_, ok := result[key]
|
||||
|
@ -401,8 +401,6 @@ func (sched *TaskScheduler) queryResultLoop() {
|
||||
queryResultMsgStream, _ := sched.msFactory.NewQueryMsgStream(sched.ctx)
|
||||
queryResultMsgStream.AsConsumer(Params.SearchResultChannelNames, Params.ProxySubName)
|
||||
log.Debug("proxynode", zap.Strings("search result channel names", Params.SearchResultChannelNames))
|
||||
queryResultMsgStream.AsConsumer(Params.RetrieveResultChannelNames, Params.ProxySubName)
|
||||
log.Debug("proxynode", zap.Strings("Retrieve result channel names", Params.RetrieveResultChannelNames))
|
||||
log.Debug("proxynode", zap.String("proxySubName", Params.ProxySubName))
|
||||
|
||||
queryNodeNum := Params.QueryNodeNum
|
||||
|
@ -85,11 +85,13 @@ func (node *QueryNode) AddQueryChannel(ctx context.Context, in *queryPb.AddQuery
|
||||
consumeChannels := []string{in.RequestChannelID}
|
||||
consumeSubName := Params.MsgChannelSubName
|
||||
node.searchService.searchMsgStream.AsConsumer(consumeChannels, consumeSubName)
|
||||
node.retrieveService.retrieveMsgStream.AsConsumer(consumeChannels, "RetrieveSubName")
|
||||
log.Debug("querynode AsConsumer: " + strings.Join(consumeChannels, ", ") + " : " + consumeSubName)
|
||||
|
||||
// add result channel
|
||||
producerChannels := []string{in.ResultChannelID}
|
||||
node.searchService.searchResultMsgStream.AsProducer(producerChannels)
|
||||
node.retrieveService.retrieveResultMsgStream.AsProducer(producerChannels)
|
||||
log.Debug("querynode AsProducer: " + strings.Join(producerChannels, ", "))
|
||||
|
||||
status := &commonpb.Status{
|
||||
|
@ -69,6 +69,8 @@ func newRetrieveCollection(releaseCtx context.Context,
|
||||
|
||||
msgBuffer: msgBuffer,
|
||||
unsolvedMsg: unsolvedMsg,
|
||||
|
||||
retrieveResultMsgStream: retrieveResultStream,
|
||||
}
|
||||
|
||||
rc.register(collectionID)
|
||||
@ -103,6 +105,11 @@ func (rc *retrieveCollection) waitNewTSafe() Timestamp {
|
||||
return ts
|
||||
}
|
||||
|
||||
func (rc *retrieveCollection) start() {
|
||||
go rc.receiveRetrieveMsg()
|
||||
go rc.doUnsolvedMsgRetrieve()
|
||||
}
|
||||
|
||||
func (rc *retrieveCollection) register(collectionID UniqueID) {
|
||||
vChannel := collectionIDToChannel(collectionID)
|
||||
rc.tSafeReplica.addTSafe(vChannel)
|
||||
@ -237,8 +244,9 @@ func (rc *retrieveCollection) doUnsolvedMsgRetrieve() {
|
||||
|
||||
func (rc *retrieveCollection) retrieve(retrieveMsg *msgstream.RetrieveMsg) error {
|
||||
// TODO(yukun)
|
||||
resultChannelInt := 0
|
||||
retrieveResultMsg := &msgstream.RetrieveResultMsg{
|
||||
BaseMsg: msgstream.BaseMsg{Ctx: retrieveMsg.Ctx},
|
||||
BaseMsg: msgstream.BaseMsg{Ctx: retrieveMsg.Ctx, HashValues: []uint32{uint32(resultChannelInt)}},
|
||||
RetrieveResults: internalpb.RetrieveResults{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_RetrieveResult,
|
||||
|
@ -42,15 +42,16 @@ func newRetrieveService(ctx context.Context,
|
||||
streamingReplica ReplicaInterface,
|
||||
tSafeReplica TSafeReplicaInterface,
|
||||
factory msgstream.Factory) *retrieveService {
|
||||
|
||||
retrieveStream, _ := factory.NewQueryMsgStream(ctx)
|
||||
retrieveResultStream, _ := factory.NewQueryMsgStream(ctx)
|
||||
|
||||
if len(Params.RetrieveChannelNames) > 0 && len(Params.RetrieveResultChannelNames) > 0 {
|
||||
consumeChannels := Params.RetrieveChannelNames
|
||||
consumeSubName := Params.MsgChannelSubName
|
||||
if len(Params.SearchChannelNames) > 0 && len(Params.SearchResultChannelNames) > 0 {
|
||||
consumeChannels := Params.SearchChannelNames
|
||||
consumeSubName := "RetrieveSubName"
|
||||
retrieveStream.AsConsumer(consumeChannels, consumeSubName)
|
||||
log.Debug("query node AdConsumer", zap.Any("retrieveChannels", consumeChannels), zap.Any("consumeSubName", consumeSubName))
|
||||
producerChannels := Params.RetrieveChannelNames
|
||||
log.Debug("query node AsConsumer", zap.Any("retrieveChannels", consumeChannels), zap.Any("consumeSubName", consumeSubName))
|
||||
producerChannels := Params.SearchResultChannelNames
|
||||
retrieveResultStream.AsProducer(producerChannels)
|
||||
log.Debug("query node AsProducer", zap.Any("retrieveResultChannels", producerChannels))
|
||||
}
|
||||
@ -147,7 +148,7 @@ func (rs *retrieveService) startRetrieveCollection(collectionID UniqueID) {
|
||||
rs.tSafeReplica,
|
||||
rs.retrieveResultMsgStream)
|
||||
rs.retrieveCollections[collectionID] = rc
|
||||
rs.start()
|
||||
rc.start()
|
||||
}
|
||||
|
||||
func (rs *retrieveService) hasRetrieveCollection(collectionID UniqueID) bool {
|
||||
|
Loading…
Reference in New Issue
Block a user