diff --git a/internal/querynode/global_sealed_segment_manager.go b/internal/querynode/global_sealed_segment_manager.go index 1ab71da799..d5810eba23 100644 --- a/internal/querynode/global_sealed_segment_manager.go +++ b/internal/querynode/global_sealed_segment_manager.go @@ -68,6 +68,13 @@ func (g *globalSealedSegmentManager) getGlobalSegmentIDsByPartitionIds(partition return resIDs } +func (g *globalSealedSegmentManager) hasGlobalSegment(segmentID UniqueID) bool { + g.mu.Lock() + defer g.mu.Unlock() + _, ok := g.globalSealedSegments[segmentID] + return ok +} + func (g *globalSealedSegmentManager) removeGlobalSegmentInfo(segmentID UniqueID) { g.mu.Lock() defer g.mu.Unlock() diff --git a/internal/querynode/global_sealed_segment_manager_test.go b/internal/querynode/global_sealed_segment_manager_test.go index 05e56cddc1..bb44895464 100644 --- a/internal/querynode/global_sealed_segment_manager_test.go +++ b/internal/querynode/global_sealed_segment_manager_test.go @@ -56,6 +56,9 @@ func TestGlobalSealedSegmentManager(t *testing.T) { ids = manager.getGlobalSegmentIDs() assert.Len(t, ids, 0) + has := manager.hasGlobalSegment(defaultSegmentID) + assert.False(t, has) + segmentInfo.CollectionID = defaultCollectionID err = manager.addGlobalSegmentInfo(segmentInfo) assert.NoError(t, err) diff --git a/internal/querynode/mock_test.go b/internal/querynode/mock_test.go index aacd090cf6..2e637cdf65 100644 --- a/internal/querynode/mock_test.go +++ b/internal/querynode/mock_test.go @@ -34,6 +34,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/milvuspb" "github.com/milvus-io/milvus/internal/proto/planpb" + "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/milvus-io/milvus/internal/storage" ) @@ -1248,6 +1249,33 @@ func consumeSimpleRetrieveResult(stream msgstream.MsgStream) (*msgstream.Retriev return res.Msgs[0].(*msgstream.RetrieveResultMsg), nil } +func genSimpleChangeInfo() *querypb.SealedSegmentsChangeInfo { + return &querypb.SealedSegmentsChangeInfo{ + Base: genCommonMsgBase(commonpb.MsgType_LoadBalanceSegments), + OnlineNodeID: Params.QueryNodeID, + OnlineSegments: []*querypb.SegmentInfo{ + genSimpleSegmentInfo(), + }, + OfflineNodeID: Params.QueryNodeID + 1, + OfflineSegments: []*querypb.SegmentInfo{ + genSimpleSegmentInfo(), + }, + } +} + +func saveChangeInfo(key string, value string) error { + log.Debug(".. [query node unittest] Saving change info") + + kv, err := genEtcdKV() + if err != nil { + return err + } + + key = changeInfoMetaPrefix + "/" + key + + return kv.Save(key, value) +} + // node func genSimpleQueryNode(ctx context.Context) (*QueryNode, error) { fac, err := genFactory() @@ -1256,6 +1284,13 @@ func genSimpleQueryNode(ctx context.Context) (*QueryNode, error) { } node := NewQueryNode(ctx, fac) + etcdKV, err := genEtcdKV() + if err != nil { + return nil, err + } + + node.etcdKV = etcdKV + streaming, err := genSimpleStreaming(ctx) if err != nil { return nil, err diff --git a/internal/querynode/query_node.go b/internal/querynode/query_node.go index fd1d415cbf..0002a3e5cf 100644 --- a/internal/querynode/query_node.go +++ b/internal/querynode/query_node.go @@ -27,25 +27,33 @@ import "C" import ( "context" "errors" + "fmt" + "path/filepath" "strconv" "sync" "sync/atomic" "time" "unsafe" + "github.com/golang/protobuf/proto" + "go.etcd.io/etcd/api/v3/mvccpb" "go.uber.org/zap" "github.com/milvus-io/milvus/internal/kv" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/msgstream" + "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/retry" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/typeutil" ) +const changeInfoMetaPrefix = "query-changeInfo" + // make sure QueryNode implements types.QueryNode var _ types.QueryNode = (*QueryNode)(nil) @@ -212,6 +220,7 @@ func (node *QueryNode) Start() error { // start services go node.historical.start() + go node.watchChangeInfo() Params.CreatedTime = time.Now() Params.UpdatedTime = time.Now() @@ -256,3 +265,137 @@ func (node *QueryNode) SetIndexCoord(index types.IndexCoord) error { node.indexCoord = index return nil } + +func (node *QueryNode) watchChangeInfo() { + log.Debug("query node watchChangeInfo start") + watchChan := node.etcdKV.WatchWithPrefix(changeInfoMetaPrefix) + + for { + select { + case <-node.queryNodeLoopCtx.Done(): + log.Debug("query node watchChangeInfo close") + return + case resp := <-watchChan: + for _, event := range resp.Events { + switch event.Type { + case mvccpb.PUT: + infoID, err := strconv.ParseInt(filepath.Base(string(event.Kv.Key)), 10, 64) + if err != nil { + log.Warn("Parse SealedSegmentsChangeInfo id failed", zap.Any("error", err.Error())) + continue + } + log.Debug("get SealedSegmentsChangeInfo from etcd", + zap.Any("infoID", infoID), + ) + info := &querypb.SealedSegmentsChangeInfo{} + err = proto.Unmarshal(event.Kv.Value, info) + if err != nil { + log.Warn("Unmarshal SealedSegmentsChangeInfo failed", zap.Any("error", err.Error())) + continue + } + go func() { + err = node.adjustByChangeInfo(info) + if err != nil { + log.Warn("adjustByChangeInfo failed", zap.Any("error", err.Error())) + } + }() + default: + // do nothing + } + } + } + } +} + +func (node *QueryNode) waitChangeInfo(info *querypb.SealedSegmentsChangeInfo) error { + fn := func() error { + canDoLoadBalance := true + // Check online segments: + for _, segmentInfo := range info.OnlineSegments { + if node.queryService.hasQueryCollection(segmentInfo.CollectionID) { + qc, err := node.queryService.getQueryCollection(segmentInfo.CollectionID) + if err != nil { + canDoLoadBalance = false + break + } + if info.OnlineNodeID == Params.QueryNodeID && !qc.globalSegmentManager.hasGlobalSegment(segmentInfo.SegmentID) { + canDoLoadBalance = false + break + } + } + } + // Check offline segments: + for _, segmentInfo := range info.OfflineSegments { + if node.queryService.hasQueryCollection(segmentInfo.CollectionID) { + qc, err := node.queryService.getQueryCollection(segmentInfo.CollectionID) + if err != nil { + canDoLoadBalance = false + break + } + if info.OfflineNodeID == Params.QueryNodeID && qc.globalSegmentManager.hasGlobalSegment(segmentInfo.SegmentID) { + canDoLoadBalance = false + break + } + } + } + if canDoLoadBalance { + return nil + } + return errors.New(fmt.Sprintln("waitChangeInfo failed, infoID = ", info.Base.GetMsgID())) + } + + return retry.Do(context.TODO(), fn, retry.Attempts(10)) +} + +func (node *QueryNode) adjustByChangeInfo(info *querypb.SealedSegmentsChangeInfo) error { + err := node.waitChangeInfo(info) + if err != nil { + log.Error("waitChangeInfo failed", zap.Any("error", err.Error())) + return err + } + + // For online segments: + for _, segmentInfo := range info.OnlineSegments { + // 1. update excluded segment, cluster have been loaded sealed segments, + // so we need to avoid getting growing segment from flow graph. + node.streaming.replica.addExcludedSegments(segmentInfo.CollectionID, []*datapb.SegmentInfo{ + { + ID: segmentInfo.SegmentID, + CollectionID: segmentInfo.CollectionID, + PartitionID: segmentInfo.PartitionID, + InsertChannel: segmentInfo.ChannelID, + NumOfRows: segmentInfo.NumRows, + // TODO: add status, remove query pb segment status, use common pb segment status? + DmlPosition: &internalpb.MsgPosition{ + // use max timestamp to filter out dm messages + Timestamp: typeutil.MaxTimestamp, + }, + }, + }) + // 2. delete growing segment because these segments are loaded in historical. + hasGrowingSegment := node.streaming.replica.hasSegment(segmentInfo.SegmentID) + if hasGrowingSegment { + err := node.streaming.replica.removeSegment(segmentInfo.SegmentID) + if err != nil { + return err + } + log.Debug("remove growing segment in adjustByChangeInfo", + zap.Any("collectionID", segmentInfo.CollectionID), + zap.Any("segmentID", segmentInfo.SegmentID), + zap.Any("infoID", info.Base.GetMsgID()), + ) + } + } + + // For offline segments: + for _, segment := range info.OfflineSegments { + // 1. load balance or compaction, remove old sealed segments. + if info.OfflineNodeID == Params.QueryNodeID { + err := node.historical.replica.removeSegment(segment.SegmentID) + if err != nil { + return err + } + } + } + return nil +} diff --git a/internal/querynode/query_node_test.go b/internal/querynode/query_node_test.go index 6b23dae4e2..0199ef9ffd 100644 --- a/internal/querynode/query_node_test.go +++ b/internal/querynode/query_node_test.go @@ -19,6 +19,7 @@ import ( "testing" "time" + "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" @@ -272,3 +273,138 @@ func TestQueryNode_init(t *testing.T) { err = node.Init() assert.NoError(t, err) } + +func genSimpleQueryNodeToTestWatchChangeInfo(ctx context.Context) (*QueryNode, error) { + node, err := genSimpleQueryNode(ctx) + if err != nil { + return nil, err + } + + err = node.queryService.addQueryCollection(defaultCollectionID) + if err != nil { + return nil, err + } + + qc, err := node.queryService.getQueryCollection(defaultCollectionID) + if err != nil { + return nil, err + } + err = qc.globalSegmentManager.addGlobalSegmentInfo(genSimpleSegmentInfo()) + if err != nil { + return nil, err + } + + return node, nil +} + +func TestQueryNode_waitChangeInfo(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + node, err := genSimpleQueryNodeToTestWatchChangeInfo(ctx) + assert.NoError(t, err) + + err = node.waitChangeInfo(genSimpleChangeInfo()) + assert.NoError(t, err) +} + +func TestQueryNode_adjustByChangeInfo(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + t.Run("test adjustByChangeInfo", func(t *testing.T) { + node, err := genSimpleQueryNodeToTestWatchChangeInfo(ctx) + assert.NoError(t, err) + + err = node.adjustByChangeInfo(genSimpleChangeInfo()) + assert.NoError(t, err) + }) + + t.Run("test adjustByChangeInfo no segment", func(t *testing.T) { + node, err := genSimpleQueryNodeToTestWatchChangeInfo(ctx) + assert.NoError(t, err) + + err = node.historical.replica.removeSegment(defaultSegmentID) + assert.NoError(t, err) + + info := genSimpleChangeInfo() + info.OnlineSegments = nil + info.OfflineNodeID = Params.QueryNodeID + + qc, err := node.queryService.getQueryCollection(defaultCollectionID) + assert.NoError(t, err) + qc.globalSegmentManager.removeGlobalSegmentInfo(defaultSegmentID) + + err = node.adjustByChangeInfo(info) + assert.Error(t, err) + }) +} + +func TestQueryNode_watchChangeInfo(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + t.Run("test watchChangeInfo", func(t *testing.T) { + node, err := genSimpleQueryNodeToTestWatchChangeInfo(ctx) + assert.NoError(t, err) + + go node.watchChangeInfo() + + info := genSimpleSegmentInfo() + value, err := proto.Marshal(info) + assert.NoError(t, err) + err = saveChangeInfo("0", string(value)) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + }) + + t.Run("test watchChangeInfo key error", func(t *testing.T) { + node, err := genSimpleQueryNodeToTestWatchChangeInfo(ctx) + assert.NoError(t, err) + + go node.watchChangeInfo() + + err = saveChangeInfo("*$&#%^^", "%EUY%&#^$%&@") + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + }) + + t.Run("test watchChangeInfo unmarshal error", func(t *testing.T) { + node, err := genSimpleQueryNodeToTestWatchChangeInfo(ctx) + assert.NoError(t, err) + + go node.watchChangeInfo() + + err = saveChangeInfo("0", "$%^$*&%^#$&*") + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + }) + + t.Run("test watchChangeInfo adjustByChangeInfo error", func(t *testing.T) { + node, err := genSimpleQueryNodeToTestWatchChangeInfo(ctx) + assert.NoError(t, err) + + err = node.historical.replica.removeSegment(defaultSegmentID) + assert.NoError(t, err) + + info := genSimpleChangeInfo() + info.OnlineSegments = nil + info.OfflineNodeID = Params.QueryNodeID + + qc, err := node.queryService.getQueryCollection(defaultCollectionID) + assert.NoError(t, err) + qc.globalSegmentManager.removeGlobalSegmentInfo(defaultSegmentID) + + go node.watchChangeInfo() + + value, err := proto.Marshal(info) + assert.NoError(t, err) + err = saveChangeInfo("0", string(value)) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + }) +} diff --git a/internal/util/typeutil/time.go b/internal/util/typeutil/time.go index ae31207672..0305642722 100644 --- a/internal/util/typeutil/time.go +++ b/internal/util/typeutil/time.go @@ -11,7 +11,12 @@ package typeutil -import "time" +import ( + "math" + "time" +) + +const MaxTimestamp = math.MaxUint64 // ZeroTime is a zero time. var ZeroTime = time.Time{}