mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-01 11:29:48 +08:00
Fix retrieve_service (#5531)
* Fix retrieve_service Signed-off-by: fishpenguin <kun.yu@zilliz.com> * Fix retrieve_collection Signed-off-by: fishpenguin <kun.yu@zilliz.com>
This commit is contained in:
parent
56da071cce
commit
dcb4161c9f
@ -13,6 +13,7 @@ package msgstream
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||||
@ -127,7 +128,7 @@ func DefaultRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack,
|
|||||||
for i, request := range tsMsgs {
|
for i, request := range tsMsgs {
|
||||||
keys := hashKeys[i]
|
keys := hashKeys[i]
|
||||||
if len(keys) != 1 {
|
if len(keys) != 1 {
|
||||||
return nil, errors.New("len(msg.hashValue) must equal 1")
|
return nil, errors.New("len(msg.hashValue) must equal 1, but it is: " + strconv.Itoa(len(keys)))
|
||||||
}
|
}
|
||||||
key := keys[0]
|
key := keys[0]
|
||||||
_, ok := result[key]
|
_, ok := result[key]
|
||||||
|
@ -401,8 +401,6 @@ func (sched *TaskScheduler) queryResultLoop() {
|
|||||||
queryResultMsgStream, _ := sched.msFactory.NewQueryMsgStream(sched.ctx)
|
queryResultMsgStream, _ := sched.msFactory.NewQueryMsgStream(sched.ctx)
|
||||||
queryResultMsgStream.AsConsumer(Params.SearchResultChannelNames, Params.ProxySubName)
|
queryResultMsgStream.AsConsumer(Params.SearchResultChannelNames, Params.ProxySubName)
|
||||||
log.Debug("proxynode", zap.Strings("search result channel names", Params.SearchResultChannelNames))
|
log.Debug("proxynode", zap.Strings("search result channel names", Params.SearchResultChannelNames))
|
||||||
queryResultMsgStream.AsConsumer(Params.RetrieveResultChannelNames, Params.ProxySubName)
|
|
||||||
log.Debug("proxynode", zap.Strings("Retrieve result channel names", Params.RetrieveResultChannelNames))
|
|
||||||
log.Debug("proxynode", zap.String("proxySubName", Params.ProxySubName))
|
log.Debug("proxynode", zap.String("proxySubName", Params.ProxySubName))
|
||||||
|
|
||||||
queryNodeNum := Params.QueryNodeNum
|
queryNodeNum := Params.QueryNodeNum
|
||||||
|
@ -85,11 +85,13 @@ func (node *QueryNode) AddQueryChannel(ctx context.Context, in *queryPb.AddQuery
|
|||||||
consumeChannels := []string{in.RequestChannelID}
|
consumeChannels := []string{in.RequestChannelID}
|
||||||
consumeSubName := Params.MsgChannelSubName
|
consumeSubName := Params.MsgChannelSubName
|
||||||
node.searchService.searchMsgStream.AsConsumer(consumeChannels, consumeSubName)
|
node.searchService.searchMsgStream.AsConsumer(consumeChannels, consumeSubName)
|
||||||
|
node.retrieveService.retrieveMsgStream.AsConsumer(consumeChannels, "RetrieveSubName")
|
||||||
log.Debug("querynode AsConsumer: " + strings.Join(consumeChannels, ", ") + " : " + consumeSubName)
|
log.Debug("querynode AsConsumer: " + strings.Join(consumeChannels, ", ") + " : " + consumeSubName)
|
||||||
|
|
||||||
// add result channel
|
// add result channel
|
||||||
producerChannels := []string{in.ResultChannelID}
|
producerChannels := []string{in.ResultChannelID}
|
||||||
node.searchService.searchResultMsgStream.AsProducer(producerChannels)
|
node.searchService.searchResultMsgStream.AsProducer(producerChannels)
|
||||||
|
node.retrieveService.retrieveResultMsgStream.AsProducer(producerChannels)
|
||||||
log.Debug("querynode AsProducer: " + strings.Join(producerChannels, ", "))
|
log.Debug("querynode AsProducer: " + strings.Join(producerChannels, ", "))
|
||||||
|
|
||||||
status := &commonpb.Status{
|
status := &commonpb.Status{
|
||||||
|
@ -69,6 +69,8 @@ func newRetrieveCollection(releaseCtx context.Context,
|
|||||||
|
|
||||||
msgBuffer: msgBuffer,
|
msgBuffer: msgBuffer,
|
||||||
unsolvedMsg: unsolvedMsg,
|
unsolvedMsg: unsolvedMsg,
|
||||||
|
|
||||||
|
retrieveResultMsgStream: retrieveResultStream,
|
||||||
}
|
}
|
||||||
|
|
||||||
rc.register(collectionID)
|
rc.register(collectionID)
|
||||||
@ -103,6 +105,11 @@ func (rc *retrieveCollection) waitNewTSafe() Timestamp {
|
|||||||
return ts
|
return ts
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rc *retrieveCollection) start() {
|
||||||
|
go rc.receiveRetrieveMsg()
|
||||||
|
go rc.doUnsolvedMsgRetrieve()
|
||||||
|
}
|
||||||
|
|
||||||
func (rc *retrieveCollection) register(collectionID UniqueID) {
|
func (rc *retrieveCollection) register(collectionID UniqueID) {
|
||||||
vChannel := collectionIDToChannel(collectionID)
|
vChannel := collectionIDToChannel(collectionID)
|
||||||
rc.tSafeReplica.addTSafe(vChannel)
|
rc.tSafeReplica.addTSafe(vChannel)
|
||||||
@ -237,8 +244,9 @@ func (rc *retrieveCollection) doUnsolvedMsgRetrieve() {
|
|||||||
|
|
||||||
func (rc *retrieveCollection) retrieve(retrieveMsg *msgstream.RetrieveMsg) error {
|
func (rc *retrieveCollection) retrieve(retrieveMsg *msgstream.RetrieveMsg) error {
|
||||||
// TODO(yukun)
|
// TODO(yukun)
|
||||||
|
resultChannelInt := 0
|
||||||
retrieveResultMsg := &msgstream.RetrieveResultMsg{
|
retrieveResultMsg := &msgstream.RetrieveResultMsg{
|
||||||
BaseMsg: msgstream.BaseMsg{Ctx: retrieveMsg.Ctx},
|
BaseMsg: msgstream.BaseMsg{Ctx: retrieveMsg.Ctx, HashValues: []uint32{uint32(resultChannelInt)}},
|
||||||
RetrieveResults: internalpb.RetrieveResults{
|
RetrieveResults: internalpb.RetrieveResults{
|
||||||
Base: &commonpb.MsgBase{
|
Base: &commonpb.MsgBase{
|
||||||
MsgType: commonpb.MsgType_RetrieveResult,
|
MsgType: commonpb.MsgType_RetrieveResult,
|
||||||
|
@ -42,15 +42,16 @@ func newRetrieveService(ctx context.Context,
|
|||||||
streamingReplica ReplicaInterface,
|
streamingReplica ReplicaInterface,
|
||||||
tSafeReplica TSafeReplicaInterface,
|
tSafeReplica TSafeReplicaInterface,
|
||||||
factory msgstream.Factory) *retrieveService {
|
factory msgstream.Factory) *retrieveService {
|
||||||
|
|
||||||
retrieveStream, _ := factory.NewQueryMsgStream(ctx)
|
retrieveStream, _ := factory.NewQueryMsgStream(ctx)
|
||||||
retrieveResultStream, _ := factory.NewQueryMsgStream(ctx)
|
retrieveResultStream, _ := factory.NewQueryMsgStream(ctx)
|
||||||
|
|
||||||
if len(Params.RetrieveChannelNames) > 0 && len(Params.RetrieveResultChannelNames) > 0 {
|
if len(Params.SearchChannelNames) > 0 && len(Params.SearchResultChannelNames) > 0 {
|
||||||
consumeChannels := Params.RetrieveChannelNames
|
consumeChannels := Params.SearchChannelNames
|
||||||
consumeSubName := Params.MsgChannelSubName
|
consumeSubName := "RetrieveSubName"
|
||||||
retrieveStream.AsConsumer(consumeChannels, consumeSubName)
|
retrieveStream.AsConsumer(consumeChannels, consumeSubName)
|
||||||
log.Debug("query node AdConsumer", zap.Any("retrieveChannels", consumeChannels), zap.Any("consumeSubName", consumeSubName))
|
log.Debug("query node AsConsumer", zap.Any("retrieveChannels", consumeChannels), zap.Any("consumeSubName", consumeSubName))
|
||||||
producerChannels := Params.RetrieveChannelNames
|
producerChannels := Params.SearchResultChannelNames
|
||||||
retrieveResultStream.AsProducer(producerChannels)
|
retrieveResultStream.AsProducer(producerChannels)
|
||||||
log.Debug("query node AsProducer", zap.Any("retrieveResultChannels", producerChannels))
|
log.Debug("query node AsProducer", zap.Any("retrieveResultChannels", producerChannels))
|
||||||
}
|
}
|
||||||
@ -147,7 +148,7 @@ func (rs *retrieveService) startRetrieveCollection(collectionID UniqueID) {
|
|||||||
rs.tSafeReplica,
|
rs.tSafeReplica,
|
||||||
rs.retrieveResultMsgStream)
|
rs.retrieveResultMsgStream)
|
||||||
rs.retrieveCollections[collectionID] = rc
|
rs.retrieveCollections[collectionID] = rc
|
||||||
rs.start()
|
rc.start()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs *retrieveService) hasRetrieveCollection(collectionID UniqueID) bool {
|
func (rs *retrieveService) hasRetrieveCollection(collectionID UniqueID) bool {
|
||||||
|
Loading…
Reference in New Issue
Block a user