From c97563590c5747f0f034919d1c98bbb8ea7d09eb Mon Sep 17 00:00:00 2001 From: Xiaofan <83447078+xiaofan-luan@users.noreply.github.com> Date: Mon, 18 Jul 2022 13:06:28 +0800 Subject: [PATCH] fix missing handling channels while a node down (#18250) Signed-off-by: xiaofan-luan --- internal/querycoord/channel_unsubscribe.go | 179 ++++---- .../querycoord/channel_unsubscribe_test.go | 21 +- internal/querycoord/cluster.go | 13 +- internal/querycoord/cluster_test.go | 14 +- internal/querycoord/const.go | 24 + internal/querycoord/handoff_handler.go | 41 +- internal/querycoord/meta.go | 9 +- internal/querycoord/query_coord.go | 432 +++++++++--------- internal/querycoord/query_coord_test.go | 6 +- internal/querycoord/querynode_test.go | 1 + internal/querycoord/segment_allocator_test.go | 4 +- internal/querycoord/task.go | 42 +- internal/querycoord/task_scheduler.go | 79 +++- internal/querycoord/task_test.go | 89 +--- internal/util/paramtable/component_param.go | 17 + 15 files changed, 493 insertions(+), 478 deletions(-) create mode 100644 internal/querycoord/const.go diff --git a/internal/querycoord/channel_unsubscribe.go b/internal/querycoord/channel_unsubscribe.go index 690ec6ff74..f2733e0a34 100644 --- a/internal/querycoord/channel_unsubscribe.go +++ b/internal/querycoord/channel_unsubscribe.go @@ -17,10 +17,10 @@ package querycoord import ( - "container/list" "context" "fmt" "sync" + "time" "github.com/golang/protobuf/proto" "go.uber.org/zap" @@ -33,35 +33,32 @@ import ( "github.com/milvus-io/milvus/internal/util/funcutil" ) -const ( - unsubscribeChannelInfoPrefix = "queryCoord-unsubscribeChannelInfo" -) - -type channelUnsubscribeHandler struct { +type ChannelCleaner struct { ctx context.Context cancel context.CancelFunc kvClient *etcdkv.EtcdKV factory msgstream.Factory - mut sync.RWMutex // mutex for channelInfos, since container/list is not goroutine-safe - channelInfos *list.List - downNodeChan chan int64 + taskMutex sync.RWMutex // mutex for channelInfos, since container/list is not goroutine-safe + // nodeID, UnsubscribeChannelInfo + tasks map[int64]*querypb.UnsubscribeChannelInfo + notify chan struct{} + closed bool wg sync.WaitGroup } // newChannelUnsubscribeHandler create a new handler service to unsubscribe channels -func newChannelUnsubscribeHandler(ctx context.Context, kv *etcdkv.EtcdKV, factory dependency.Factory) (*channelUnsubscribeHandler, error) { +func NewChannelCleaner(ctx context.Context, kv *etcdkv.EtcdKV, factory dependency.Factory) (*ChannelCleaner, error) { childCtx, cancel := context.WithCancel(ctx) - handler := &channelUnsubscribeHandler{ + handler := &ChannelCleaner{ ctx: childCtx, cancel: cancel, kvClient: kv, factory: factory, - channelInfos: list.New(), - //TODO:: if the query nodes that are down exceed 1024, query coord will not be able to restart - downNodeChan: make(chan int64, 1024), + tasks: make(map[int64]*querypb.UnsubscribeChannelInfo, 1024), + notify: make(chan struct{}, 1024), } err := handler.reloadFromKV() @@ -72,103 +69,129 @@ func newChannelUnsubscribeHandler(ctx context.Context, kv *etcdkv.EtcdKV, factor return handler, nil } -// appendUnsubInfo pushes unsub info safely -func (csh *channelUnsubscribeHandler) appendUnsubInfo(info *querypb.UnsubscribeChannelInfo) { - csh.mut.Lock() - defer csh.mut.Unlock() - csh.channelInfos.PushBack(info) -} - // reloadFromKV reload unsolved channels to unsubscribe -func (csh *channelUnsubscribeHandler) reloadFromKV() error { +func (cleaner *ChannelCleaner) reloadFromKV() error { log.Info("start reload unsubscribe channelInfo from kv") - _, channelInfoValues, err := csh.kvClient.LoadWithPrefix(unsubscribeChannelInfoPrefix) + cleaner.taskMutex.Lock() + defer cleaner.taskMutex.Unlock() + _, channelInfoValues, err := cleaner.kvClient.LoadWithPrefix(unsubscribeChannelInfoPrefix) if err != nil { return err } for _, value := range channelInfoValues { - channelInfo := &querypb.UnsubscribeChannelInfo{} - err = proto.Unmarshal([]byte(value), channelInfo) + info := &querypb.UnsubscribeChannelInfo{} + err = proto.Unmarshal([]byte(value), info) if err != nil { return err } - csh.appendUnsubInfo(channelInfo) - csh.downNodeChan <- channelInfo.NodeID + cleaner.tasks[info.NodeID] = info } - + cleaner.notify <- struct{}{} + log.Info("successufully reload unsubscribe channelInfo from kv", zap.Int("unhandled", len(channelInfoValues))) return nil } // addUnsubscribeChannelInfo add channel info to handler service, and persistent to etcd -func (csh *channelUnsubscribeHandler) addUnsubscribeChannelInfo(info *querypb.UnsubscribeChannelInfo) { +func (cleaner *ChannelCleaner) addUnsubscribeChannelInfo(info *querypb.UnsubscribeChannelInfo) { + if len(info.CollectionChannels) == 0 { + return + } nodeID := info.NodeID + cleaner.taskMutex.Lock() + defer cleaner.taskMutex.Unlock() + if cleaner.closed { + return + } + + _, ok := cleaner.tasks[nodeID] + if ok { + log.Info("duplicate add unsubscribe channel, ignore..", zap.Int64("nodeID", nodeID)) + return + } + channelInfoValue, err := proto.Marshal(info) if err != nil { panic(err) } - // when queryCoord is restarted multiple times, the nodeID of added channelInfo may be the same - hasEnqueue := false - // reduce the lock range to iteration here, since `addUnsubscribeChannelInfo` is called one by one - csh.mut.RLock() - for e := csh.channelInfos.Back(); e != nil; e = e.Prev() { - if e.Value.(*querypb.UnsubscribeChannelInfo).NodeID == nodeID { - hasEnqueue = true - } - } - csh.mut.RUnlock() - if !hasEnqueue { - channelInfoKey := fmt.Sprintf("%s/%d", unsubscribeChannelInfoPrefix, nodeID) - err = csh.kvClient.Save(channelInfoKey, string(channelInfoValue)) - if err != nil { - panic(err) - } - csh.appendUnsubInfo(info) - csh.downNodeChan <- info.NodeID - log.Info("add unsubscribeChannelInfo to handler", zap.Int64("nodeID", info.NodeID)) + //TODO, we don't even need unsubscribeChannelInfoPrefix, each time we just call addUnsubscribeChannelInfo when querycoord restard + channelInfoKey := fmt.Sprintf("%s/%d", unsubscribeChannelInfoPrefix, nodeID) + err = cleaner.kvClient.Save(channelInfoKey, string(channelInfoValue)) + if err != nil { + panic(err) } + cleaner.tasks[info.NodeID] = info + cleaner.notify <- struct{}{} + log.Info("successfully add unsubscribeChannelInfo to handler", zap.Int64("nodeID", info.NodeID), zap.Any("channels", info.CollectionChannels)) } // handleChannelUnsubscribeLoop handle the unsubscription of channels which query node has watched -func (csh *channelUnsubscribeHandler) handleChannelUnsubscribeLoop() { - defer csh.wg.Done() +func (cleaner *ChannelCleaner) handleChannelCleanLoop() { + defer cleaner.wg.Done() + + ticker := time.NewTicker(time.Second * 1) + defer ticker.Stop() for { select { - case <-csh.ctx.Done(): - log.Info("channelUnsubscribeHandler ctx done, handleChannelUnsubscribeLoop end") + case <-cleaner.ctx.Done(): + log.Info("channelUnsubscribeHandler ctx done, handleChannelCleanLoop end") return - case <-csh.downNodeChan: - csh.mut.RLock() - e := csh.channelInfos.Front() - channelInfo := csh.channelInfos.Front().Value.(*querypb.UnsubscribeChannelInfo) - csh.mut.RUnlock() - nodeID := channelInfo.NodeID - for _, collectionChannels := range channelInfo.CollectionChannels { - collectionID := collectionChannels.CollectionID - subName := funcutil.GenChannelSubName(Params.CommonCfg.QueryNodeSubName, collectionID, nodeID) - msgstream.UnsubscribeChannels(csh.ctx, csh.factory, subName, collectionChannels.Channels) + case _, ok := <-cleaner.notify: + if ok { + cleaner.taskMutex.Lock() + for segmentID := range cleaner.tasks { + cleaner.process(segmentID) + } + cleaner.taskMutex.Unlock() } - channelInfoKey := fmt.Sprintf("%s/%d", unsubscribeChannelInfoPrefix, nodeID) - err := csh.kvClient.Remove(channelInfoKey) - if err != nil { - log.Error("remove unsubscribe channelInfo from etcd failed", zap.Int64("nodeID", nodeID)) - panic(err) + case <-ticker.C: + cleaner.taskMutex.Lock() + for segmentID := range cleaner.tasks { + cleaner.process(segmentID) } - - csh.mut.Lock() - csh.channelInfos.Remove(e) - csh.mut.Unlock() - log.Info("unsubscribe channels success", zap.Int64("nodeID", nodeID)) + cleaner.taskMutex.Unlock() } } } -func (csh *channelUnsubscribeHandler) start() { - csh.wg.Add(1) - go csh.handleChannelUnsubscribeLoop() +func (cleaner *ChannelCleaner) process(nodeID int64) error { + log.Info("start to handle channel clean", zap.Int64("nodeID", nodeID)) + channelInfo := cleaner.tasks[nodeID] + for _, collectionChannels := range channelInfo.CollectionChannels { + collectionID := collectionChannels.CollectionID + subName := funcutil.GenChannelSubName(Params.CommonCfg.QueryNodeSubName, collectionID, nodeID) + // should be ok if we call unsubscribe multiple times + msgstream.UnsubscribeChannels(cleaner.ctx, cleaner.factory, subName, collectionChannels.Channels) + } + channelInfoKey := fmt.Sprintf("%s/%d", unsubscribeChannelInfoPrefix, nodeID) + err := cleaner.kvClient.Remove(channelInfoKey) + if err != nil { + log.Warn("remove unsubscribe channelInfo from etcd failed", zap.Int64("nodeID", nodeID)) + return err + } + delete(cleaner.tasks, nodeID) + log.Info("unsubscribe channels success", zap.Int64("nodeID", nodeID)) + return nil } -func (csh *channelUnsubscribeHandler) close() { - csh.cancel() - csh.wg.Wait() +// check if there exists any unsubscribe task for specified channel +func (cleaner *ChannelCleaner) isNodeChannelCleanHandled(nodeID UniqueID) bool { + cleaner.taskMutex.RLock() + defer cleaner.taskMutex.RUnlock() + _, ok := cleaner.tasks[nodeID] + return !ok +} + +func (cleaner *ChannelCleaner) start() { + cleaner.wg.Add(1) + go cleaner.handleChannelCleanLoop() +} + +func (cleaner *ChannelCleaner) close() { + cleaner.taskMutex.Lock() + cleaner.closed = true + close(cleaner.notify) + cleaner.taskMutex.Unlock() + cleaner.cancel() + cleaner.wg.Wait() } diff --git a/internal/querycoord/channel_unsubscribe_test.go b/internal/querycoord/channel_unsubscribe_test.go index f43a443c2c..0f848e5838 100644 --- a/internal/querycoord/channel_unsubscribe_test.go +++ b/internal/querycoord/channel_unsubscribe_test.go @@ -49,9 +49,10 @@ func Test_HandlerReloadFromKV(t *testing.T) { assert.Nil(t, err) factory := dependency.NewDefaultFactory(true) - handler, err := newChannelUnsubscribeHandler(baseCtx, kv, factory) + cleaner, err := NewChannelCleaner(baseCtx, kv, factory) assert.Nil(t, err) - assert.Equal(t, 1, len(handler.downNodeChan)) + + assert.False(t, cleaner.isNodeChannelCleanHandled(defaultQueryNodeID)) cancel() } @@ -64,7 +65,7 @@ func Test_AddUnsubscribeChannelInfo(t *testing.T) { defer etcdCli.Close() kv := etcdkv.NewEtcdKV(etcdCli, Params.EtcdCfg.MetaRootPath) factory := dependency.NewDefaultFactory(true) - handler, err := newChannelUnsubscribeHandler(baseCtx, kv, factory) + cleaner, err := NewChannelCleaner(baseCtx, kv, factory) assert.Nil(t, err) collectionChannels := &querypb.UnsubscribeChannels{ @@ -76,14 +77,12 @@ func Test_AddUnsubscribeChannelInfo(t *testing.T) { CollectionChannels: []*querypb.UnsubscribeChannels{collectionChannels}, } - handler.addUnsubscribeChannelInfo(unsubscribeChannelInfo) - frontValue := handler.channelInfos.Front() - assert.NotNil(t, frontValue) - assert.Equal(t, defaultQueryNodeID, frontValue.Value.(*querypb.UnsubscribeChannelInfo).NodeID) + cleaner.addUnsubscribeChannelInfo(unsubscribeChannelInfo) + assert.Equal(t, len(cleaner.tasks), 1) // repeat nodeID which has down - handler.addUnsubscribeChannelInfo(unsubscribeChannelInfo) - assert.Equal(t, 1, len(handler.downNodeChan)) + cleaner.addUnsubscribeChannelInfo(unsubscribeChannelInfo) + assert.Equal(t, len(cleaner.tasks), 1) cancel() } @@ -96,7 +95,7 @@ func Test_HandleChannelUnsubscribeLoop(t *testing.T) { defer etcdCli.Close() kv := etcdkv.NewEtcdKV(etcdCli, Params.EtcdCfg.MetaRootPath) factory := dependency.NewDefaultFactory(true) - handler, err := newChannelUnsubscribeHandler(baseCtx, kv, factory) + handler, err := NewChannelCleaner(baseCtx, kv, factory) assert.Nil(t, err) collectionChannels := &querypb.UnsubscribeChannels{ @@ -116,7 +115,7 @@ func Test_HandleChannelUnsubscribeLoop(t *testing.T) { handler.start() for { - _, err = kv.Load(channelInfoKey) + _, err := kv.Load(channelInfoKey) if err != nil { break } diff --git a/internal/querycoord/cluster.go b/internal/querycoord/cluster.go index 57a10d25dd..057208875e 100644 --- a/internal/querycoord/cluster.go +++ b/internal/querycoord/cluster.go @@ -42,10 +42,6 @@ import ( "github.com/milvus-io/milvus/internal/util/typeutil" ) -const ( - queryNodeInfoPrefix = "queryCoord-queryNodeInfo" -) - // Cluster manages all query node connections and grpc requests type Cluster interface { // Collection/Parition @@ -107,14 +103,14 @@ type queryNodeCluster struct { sync.RWMutex clusterMeta Meta - handler *channelUnsubscribeHandler + cleaner *ChannelCleaner nodes map[int64]Node newNodeFn newQueryNodeFn segmentAllocator SegmentAllocatePolicy channelAllocator ChannelAllocatePolicy } -func newQueryNodeCluster(ctx context.Context, clusterMeta Meta, kv *etcdkv.EtcdKV, newNodeFn newQueryNodeFn, session *sessionutil.Session, handler *channelUnsubscribeHandler) (Cluster, error) { +func newQueryNodeCluster(ctx context.Context, clusterMeta Meta, kv *etcdkv.EtcdKV, newNodeFn newQueryNodeFn, session *sessionutil.Session, cleaner *ChannelCleaner) (Cluster, error) { childCtx, cancel := context.WithCancel(ctx) nodes := make(map[int64]Node) c := &queryNodeCluster{ @@ -123,7 +119,7 @@ func newQueryNodeCluster(ctx context.Context, clusterMeta Meta, kv *etcdkv.EtcdK client: kv, session: session, clusterMeta: clusterMeta, - handler: handler, + cleaner: cleaner, nodes: nodes, newNodeFn: newNodeFn, segmentAllocator: defaultSegAllocatePolicy(), @@ -543,13 +539,14 @@ func (c *queryNodeCluster) setNodeState(nodeID int64, node Node, state nodeState // 2.add unsubscribed channels to handler, handler will auto unsubscribe channel if len(unsubscribeChannelInfo.CollectionChannels) != 0 { - c.handler.addUnsubscribeChannelInfo(unsubscribeChannelInfo) + c.cleaner.addUnsubscribeChannelInfo(unsubscribeChannelInfo) } } node.setState(state) } +// TODO, registerNode return error is not handled correctly func (c *queryNodeCluster) RegisterNode(ctx context.Context, session *sessionutil.Session, id UniqueID, state nodeState) error { c.Lock() defer c.Unlock() diff --git a/internal/querycoord/cluster_test.go b/internal/querycoord/cluster_test.go index 53cef1ad88..ec590f794e 100644 --- a/internal/querycoord/cluster_test.go +++ b/internal/querycoord/cluster_test.go @@ -369,7 +369,7 @@ func TestReloadClusterFromKV(t *testing.T) { clusterSession.Init(typeutil.QueryCoordRole, Params.QueryCoordCfg.Address, true, false) clusterSession.Register() factory := dependency.NewDefaultFactory(true) - handler, err := newChannelUnsubscribeHandler(ctx, kv, factory) + cleaner, err := NewChannelCleaner(ctx, kv, factory) assert.Nil(t, err) id := UniqueID(rand.Int31()) idAllocator := func() (UniqueID, error) { @@ -381,7 +381,7 @@ func TestReloadClusterFromKV(t *testing.T) { cluster := &queryNodeCluster{ client: kv, - handler: handler, + cleaner: cleaner, clusterMeta: meta, nodes: make(map[int64]Node), newNodeFn: newQueryNodeTest, @@ -439,7 +439,7 @@ func TestGrpcRequest(t *testing.T) { err = meta.setDeltaChannel(defaultCollectionID, deltaChannelInfo) assert.Nil(t, err) - handler, err := newChannelUnsubscribeHandler(baseCtx, kv, factory) + cleaner, err := NewChannelCleaner(baseCtx, kv, factory) assert.Nil(t, err) var cluster Cluster = &queryNodeCluster{ @@ -447,7 +447,7 @@ func TestGrpcRequest(t *testing.T) { cancel: cancel, client: kv, clusterMeta: meta, - handler: handler, + cleaner: cleaner, nodes: make(map[int64]Node), newNodeFn: newQueryNodeTest, session: clusterSession, @@ -609,7 +609,7 @@ func TestSetNodeState(t *testing.T) { meta, err := newMeta(baseCtx, kv, factory, idAllocator) assert.Nil(t, err) - handler, err := newChannelUnsubscribeHandler(baseCtx, kv, factory) + cleaner, err := NewChannelCleaner(baseCtx, kv, factory) assert.Nil(t, err) cluster := &queryNodeCluster{ @@ -617,7 +617,7 @@ func TestSetNodeState(t *testing.T) { cancel: cancel, client: kv, clusterMeta: meta, - handler: handler, + cleaner: cleaner, nodes: make(map[int64]Node), newNodeFn: newQueryNodeTest, session: clusterSession, @@ -647,7 +647,7 @@ func TestSetNodeState(t *testing.T) { nodeInfo, err := cluster.GetNodeInfoByID(node.queryNodeID) assert.Nil(t, err) cluster.setNodeState(node.queryNodeID, nodeInfo, offline) - assert.Equal(t, 1, len(handler.downNodeChan)) + assert.Equal(t, 1, len(cleaner.tasks)) node.stop() removeAllSession() diff --git a/internal/querycoord/const.go b/internal/querycoord/const.go new file mode 100644 index 0000000000..da1d322fb0 --- /dev/null +++ b/internal/querycoord/const.go @@ -0,0 +1,24 @@ +package querycoord + +import "time" + +const ( + collectionMetaPrefix = "queryCoord-collectionMeta" + dmChannelMetaPrefix = "queryCoord-dmChannelWatchInfo" + deltaChannelMetaPrefix = "queryCoord-deltaChannel" + ReplicaMetaPrefix = "queryCoord-ReplicaMeta" + + // TODO, we shouldn't separate querycoord tasks to 3 meta keys, there should only one with different states, otherwise there will be a high possibility to be inconsitent + triggerTaskPrefix = "queryCoord-triggerTask" + activeTaskPrefix = "queryCoord-activeTask" + taskInfoPrefix = "queryCoord-taskInfo" + + queryNodeInfoPrefix = "queryCoord-queryNodeInfo" + // TODO, remove unsubscribe + unsubscribeChannelInfoPrefix = "queryCoord-unsubscribeChannelInfo" + timeoutForRPC = 10 * time.Second + // MaxSendSizeToEtcd is the default limit size of etcd messages that can be sent and received + // MaxSendSizeToEtcd = 2097152 + // Limit size of every loadSegmentReq to 200k + MaxSendSizeToEtcd = 200000 +) diff --git a/internal/querycoord/handoff_handler.go b/internal/querycoord/handoff_handler.go index 51dfe0369c..b17ed16ecb 100644 --- a/internal/querycoord/handoff_handler.go +++ b/internal/querycoord/handoff_handler.go @@ -76,7 +76,8 @@ type HandoffHandler struct { taskMutex sync.Mutex tasks map[int64]*HandOffTask - notify chan bool + notify chan struct{} + closed bool meta Meta scheduler *TaskScheduler @@ -96,7 +97,7 @@ func newHandoffHandler(ctx context.Context, client kv.MetaKv, meta Meta, cluster client: client, tasks: make(map[int64]*HandOffTask, 1024), - notify: make(chan bool, 1024), + notify: make(chan struct{}, 1024), meta: meta, scheduler: scheduler, @@ -119,8 +120,11 @@ func (handler *HandoffHandler) Start() { } func (handler *HandoffHandler) Stop() { - handler.cancel() + handler.taskMutex.Lock() + handler.closed = true close(handler.notify) + handler.taskMutex.Unlock() + handler.cancel() handler.wg.Wait() } @@ -193,15 +197,19 @@ func (handler *HandoffHandler) verifyRequest(req *querypb.SegmentInfo) (bool, *q func (handler *HandoffHandler) enqueue(req *querypb.SegmentInfo) { handler.taskMutex.Lock() defer handler.taskMutex.Unlock() + if handler.closed { + return + } handler.tasks[req.SegmentID] = &HandOffTask{ req, handoffTaskInit, false, } - handler.notify <- false + handler.notify <- struct{}{} } func (handler *HandoffHandler) schedule() { defer handler.wg.Done() - timer := time.NewTimer(time.Second * 5) + ticker := time.NewTicker(time.Second * 5) + defer ticker.Stop() for { select { case <-handler.ctx.Done(): @@ -209,17 +217,21 @@ func (handler *HandoffHandler) schedule() { case _, ok := <-handler.notify: if ok { handler.taskMutex.Lock() + if len(handler.tasks) != 0 { + log.Info("handoff task scheduled: ", zap.Int("task number", len(handler.tasks))) + for segmentID := range handler.tasks { + handler.process(segmentID) + } + } + handler.taskMutex.Unlock() + } + case <-ticker.C: + handler.taskMutex.Lock() + if len(handler.tasks) != 0 { log.Info("handoff task scheduled: ", zap.Int("task number", len(handler.tasks))) for segmentID := range handler.tasks { handler.process(segmentID) } - handler.taskMutex.Unlock() - } - case <-timer.C: - handler.taskMutex.Lock() - log.Info("handoff task scheduled: ", zap.Int("task number", len(handler.tasks))) - for segmentID := range handler.tasks { - handler.process(segmentID) } handler.taskMutex.Unlock() } @@ -228,6 +240,9 @@ func (handler *HandoffHandler) schedule() { // must hold the lock func (handler *HandoffHandler) process(segmentID int64) error { + if handler.closed { + return nil + } task := handler.tasks[segmentID] // if task is cancel and success, clean up switch task.state { @@ -267,7 +282,7 @@ func (handler *HandoffHandler) process(segmentID int64) error { task.segmentInfo, handoffTaskReady, true, } handler.tasks[task.segmentInfo.SegmentID] = task - handler.notify <- false + handler.notify <- struct{}{} log.Info("HandoffHandler: enqueue indexed segments", zap.Int64("segmentID", task.segmentInfo.SegmentID)) } diff --git a/internal/querycoord/meta.go b/internal/querycoord/meta.go index faefafbd35..08b682f3d4 100644 --- a/internal/querycoord/meta.go +++ b/internal/querycoord/meta.go @@ -44,13 +44,6 @@ import ( "github.com/milvus-io/milvus/internal/util/funcutil" ) -const ( - collectionMetaPrefix = "queryCoord-collectionMeta" - dmChannelMetaPrefix = "queryCoord-dmChannelWatchInfo" - deltaChannelMetaPrefix = "queryCoord-deltaChannel" - ReplicaMetaPrefix = "queryCoord-ReplicaMeta" -) - type col2SegmentInfos = map[UniqueID][]*querypb.SegmentInfo type col2SealedSegmentChangeInfos = map[UniqueID]*querypb.SealedSegmentsChangeInfo @@ -537,7 +530,7 @@ func (m *MetaReplica) releaseCollection(collectionID UniqueID) error { } } m.segmentsInfo.mu.Unlock() - + log.Info("successfully release collection from meta", zap.Int64("collectionID", collectionID)) return nil } diff --git a/internal/querycoord/query_coord.go b/internal/querycoord/query_coord.go index c9ff67acf8..64a7d22d5d 100644 --- a/internal/querycoord/query_coord.go +++ b/internal/querycoord/query_coord.go @@ -41,6 +41,7 @@ import ( "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/milvuspb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/types" @@ -74,7 +75,7 @@ type QueryCoord struct { queryCoordID uint64 meta Meta cluster Cluster - handler *channelUnsubscribeHandler + channelCleaner *ChannelCleaner newNodeFn newQueryNodeFn scheduler *TaskScheduler idAllocator func() (UniqueID, error) @@ -91,6 +92,7 @@ type QueryCoord struct { session *sessionutil.Session eventChan <-chan *sessionutil.SessionEvent offlineNodesChan chan UniqueID + offlineNodes map[UniqueID]struct{} stateCode atomic.Value @@ -165,14 +167,14 @@ func (qc *QueryCoord) Init() error { } // init channelUnsubscribeHandler - qc.handler, initError = newChannelUnsubscribeHandler(qc.loopCtx, qc.kvClient, qc.factory) + qc.channelCleaner, initError = NewChannelCleaner(qc.loopCtx, qc.kvClient, qc.factory) if initError != nil { log.Error("query coordinator init channelUnsubscribeHandler failed", zap.Error(initError)) return } // init cluster - qc.cluster, initError = newQueryNodeCluster(qc.loopCtx, qc.meta, qc.kvClient, qc.newNodeFn, qc.session, qc.handler) + qc.cluster, initError = newQueryNodeCluster(qc.loopCtx, qc.meta, qc.kvClient, qc.newNodeFn, qc.session, qc.channelCleaner) if initError != nil { log.Error("query coordinator init cluster failed", zap.Error(initError)) return @@ -226,12 +228,15 @@ func (qc *QueryCoord) Start() error { qc.handoffHandler.Start() log.Info("start index checker ...") - qc.handler.start() - log.Info("start channel unsubscribe loop ...") + qc.channelCleaner.start() + log.Info("start channel cleaner loop ...") Params.QueryCoordCfg.CreatedTime = time.Now() Params.QueryCoordCfg.UpdatedTime = time.Now() + qc.loopWg.Add(1) + go qc.offlineNodeLoop() + qc.loopWg.Add(1) go qc.watchNodeLoop() @@ -262,9 +267,9 @@ func (qc *QueryCoord) Stop() error { qc.handoffHandler.Stop() } - if qc.handler != nil { - log.Info("close channel unsubscribe loop...") - qc.handler.close() + if qc.channelCleaner != nil { + log.Info("close channel cleaner loop...") + qc.channelCleaner.close() } if qc.loopCancel != nil { @@ -292,7 +297,8 @@ func NewQueryCoord(ctx context.Context, factory dependency.Factory) (*QueryCoord loopCancel: cancel, factory: factory, newNodeFn: newQueryNode, - offlineNodesChan: make(chan UniqueID, 100), + offlineNodesChan: make(chan UniqueID, 256), + offlineNodes: make(map[UniqueID]struct{}, 256), } service.UpdateStateCode(internalpb.StateCode_Abnormal) @@ -340,39 +346,15 @@ func (qc *QueryCoord) watchNodeLoop() { defer qc.loopWg.Done() log.Info("QueryCoord start watch node loop") - onlineNodes := qc.cluster.OnlineNodeIDs() - for _, node := range onlineNodes { - if err := qc.allocateNode(node); err != nil { - log.Warn("unable to allcoate node", zap.Int64("nodeID", node), zap.Error(err)) + // the only judgement of processing a offline node is 1) etcd queryNodeInfoPrefix exist 2) the querynode session not exist + offlineNodes := qc.cluster.OfflineNodeIDs() + if len(offlineNodes) != 0 { + log.Warn("find querynode down while coord not alive", zap.Any("nodeIDs", offlineNodes)) + for node := range offlineNodes { + qc.offlineNodesChan <- UniqueID(node) } } - go qc.loadBalanceNodeLoop(ctx) - offlineNodes := make(typeutil.UniqueSet) - collections := qc.meta.showCollections() - for _, collection := range collections { - for _, replicaID := range collection.ReplicaIds { - replica, err := qc.meta.getReplicaByID(replicaID) - if err != nil { - log.Warn("failed to get replica", - zap.Int64("replicaID", replicaID), - zap.Error(err)) - continue - } - - for _, node := range replica.NodeIds { - ok, err := qc.cluster.IsOnline(node) - if err != nil || !ok { - offlineNodes.Insert(node) - } - } - } - } - - for node := range offlineNodes { - qc.offlineNodesChan <- node - } - // TODO silverxia add Rewatch logic qc.eventChan = qc.session.WatchServices(typeutil.QueryNodeRole, qc.cluster.GetSessionVersion()+1, nil) qc.handleNodeEvent(ctx) @@ -442,64 +424,79 @@ func (qc *QueryCoord) handleNodeEvent(ctx context.Context) { } } -func (qc *QueryCoord) loadBalanceNodeLoop(ctx context.Context) { - const LoadBalanceRetryAfter = 100 * time.Millisecond +func (qc *QueryCoord) offlineNodeLoop() { + ctx, cancel := context.WithCancel(qc.loopCtx) + defer cancel() + defer qc.loopWg.Done() + ticker := time.NewTicker(time.Millisecond * 100) + defer ticker.Stop() for { select { case <-ctx.Done(): + log.Info("offline node loop exit") return - case node := <-qc.offlineNodesChan: - loadBalanceSegment := &querypb.LoadBalanceRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_LoadBalanceSegments, - SourceID: qc.session.ServerID, - }, - SourceNodeIDs: []int64{node}, - BalanceReason: querypb.TriggerCondition_NodeDown, - } - - baseTask := newBaseTaskWithRetry(qc.loopCtx, querypb.TriggerCondition_NodeDown, 0) - loadBalanceTask := &loadBalanceTask{ - baseTask: baseTask, - LoadBalanceRequest: loadBalanceSegment, - broker: qc.broker, - cluster: qc.cluster, - meta: qc.meta, - } - qc.metricsCacheManager.InvalidateSystemInfoMetrics() - - err := qc.scheduler.Enqueue(loadBalanceTask) - if err != nil { - log.Warn("failed to enqueue LoadBalance task into the scheduler", - zap.Int64("nodeID", node), - zap.Error(err)) - qc.offlineNodesChan <- node - time.Sleep(LoadBalanceRetryAfter) - continue - } - - log.Info("start a loadBalance task", - zap.Int64("nodeID", node), - zap.Int64("taskID", loadBalanceTask.getTaskID())) - - err = loadBalanceTask.waitToFinish() - if err != nil { - log.Warn("failed to process LoadBalance task", - zap.Int64("nodeID", node), - zap.Error(err)) - qc.offlineNodesChan <- node - time.Sleep(LoadBalanceRetryAfter) - continue - } - - log.Info("LoadBalance task done, offline node is removed", - zap.Int64("nodeID", node)) + qc.offlineNodes[node] = struct{}{} + qc.processOfflineNodes() + case <-ticker.C: + qc.processOfflineNodes() } } } +func (qc *QueryCoord) processOfflineNodes() { + for node := range qc.offlineNodes { + // check if all channel unsubscribe is handled, if not wait for next cycle + if !qc.channelCleaner.isNodeChannelCleanHandled(node) { + log.Info("node channel is not cleaned, skip offline processing", zap.Int64("node", node)) + continue + } + + loadBalanceSegment := &querypb.LoadBalanceRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_LoadBalanceSegments, + SourceID: qc.session.ServerID, + }, + SourceNodeIDs: []int64{node}, + BalanceReason: querypb.TriggerCondition_NodeDown, + } + + baseTask := newBaseTaskWithRetry(qc.loopCtx, querypb.TriggerCondition_NodeDown, 0) + loadBalanceTask := &loadBalanceTask{ + baseTask: baseTask, + LoadBalanceRequest: loadBalanceSegment, + broker: qc.broker, + cluster: qc.cluster, + meta: qc.meta, + } + qc.metricsCacheManager.InvalidateSystemInfoMetrics() + + err := qc.scheduler.Enqueue(loadBalanceTask) + if err != nil { + log.Warn("failed to enqueue LoadBalance task into the scheduler", + zap.Int64("nodeID", node), + zap.Error(err)) + continue + } + + log.Info("start a loadBalance task", + zap.Int64("nodeID", node), + zap.Int64("taskID", loadBalanceTask.getTaskID())) + + err = loadBalanceTask.waitToFinish() + if err != nil { + log.Warn("failed to process LoadBalance task", + zap.Int64("nodeID", node), + zap.Error(err)) + continue + } + + delete(qc.offlineNodes, node) + log.Info("LoadBalance task done, offline node is removed", zap.Int64("nodeID", node)) + } +} + func (qc *QueryCoord) handoffNotificationLoop() { ctx, cancel := context.WithCancel(qc.loopCtx) @@ -566,164 +563,175 @@ func (qc *QueryCoord) loadBalanceSegmentLoop() { timer := time.NewTicker(time.Duration(Params.QueryCoordCfg.BalanceIntervalSeconds) * time.Second) - var collectionInfos []*querypb.CollectionInfo - pos := 0 - for { select { case <-ctx.Done(): return case <-timer.C: - if pos == len(collectionInfos) { - pos = 0 - collectionInfos = qc.meta.showCollections() + startTs := time.Now() + // do not trigger load balance if task queue is not empty + if !qc.scheduler.taskEmpty() { + continue } + + collectionInfos := qc.meta.showCollections() + // shuffle to avoid always balance the same collections + rand.Seed(time.Now().UnixNano()) + rand.Shuffle(len(collectionInfos), func(i, j int) { + collectionInfos[i], collectionInfos[j] = collectionInfos[j], collectionInfos[i] + }) + // get mem info of online nodes from cluster nodeID2MemUsageRate := make(map[int64]float64) nodeID2MemUsage := make(map[int64]uint64) nodeID2TotalMem := make(map[int64]uint64) loadBalanceTasks := make([]*loadBalanceTask, 0) // balance at most 20 collections in a round - for i := 0; pos < len(collectionInfos) && i < 20; i, pos = i+1, pos+1 { - info := collectionInfos[pos] + for i := 0; i < len(collectionInfos) && i < 20; i++ { + info := collectionInfos[i] replicas, err := qc.meta.getReplicasByCollectionID(info.GetCollectionID()) if err != nil { log.Warn("unable to get replicas of collection", zap.Int64("collectionID", info.GetCollectionID())) continue } for _, replica := range replicas { - // auto balance is executed on replica level - onlineNodeIDs := replica.GetNodeIds() - if len(onlineNodeIDs) == 0 { - log.Error("loadBalanceSegmentLoop: there are no online QueryNode to balance", zap.Int64("collection", replica.CollectionID), zap.Int64("replica", replica.ReplicaID)) - continue - } - var availableNodeIDs []int64 - nodeID2SegmentInfos := make(map[int64]map[UniqueID]*querypb.SegmentInfo) - for _, nodeID := range onlineNodeIDs { - if _, ok := nodeID2MemUsage[nodeID]; !ok { - nodeInfo, err := qc.cluster.GetNodeInfoByID(nodeID) - if err != nil { - log.Warn("loadBalanceSegmentLoop: get node info from QueryNode failed", - zap.Int64("nodeID", nodeID), zap.Int64("collection", replica.CollectionID), zap.Int64("replica", replica.ReplicaID), - zap.Error(err)) - continue - } - nodeID2MemUsageRate[nodeID] = nodeInfo.(*queryNode).memUsageRate - nodeID2MemUsage[nodeID] = nodeInfo.(*queryNode).memUsage - nodeID2TotalMem[nodeID] = nodeInfo.(*queryNode).totalMem - } - - updateSegmentInfoDone := true - leastSegmentInfos := make(map[UniqueID]*querypb.SegmentInfo) - segmentInfos := qc.meta.getSegmentInfosByNodeAndCollection(nodeID, replica.GetCollectionID()) - for _, segmentInfo := range segmentInfos { - leastInfo, err := qc.cluster.GetSegmentInfoByID(ctx, segmentInfo.SegmentID) - if err != nil { - log.Warn("loadBalanceSegmentLoop: get segment info from QueryNode failed", zap.Int64("nodeID", nodeID), - zap.Int64("collection", replica.CollectionID), zap.Int64("replica", replica.ReplicaID), - zap.Error(err)) - updateSegmentInfoDone = false - break - } - leastSegmentInfos[segmentInfo.SegmentID] = leastInfo - } - if updateSegmentInfoDone { - availableNodeIDs = append(availableNodeIDs, nodeID) - nodeID2SegmentInfos[nodeID] = leastSegmentInfos - } - } - log.Info("loadBalanceSegmentLoop: memory usage rate of all online QueryNode", zap.Int64("collection", replica.CollectionID), - zap.Int64("replica", replica.ReplicaID), zap.Any("mem rate", nodeID2MemUsageRate)) - if len(availableNodeIDs) <= 1 { - log.Info("loadBalanceSegmentLoop: there are too few available query nodes to balance", - zap.Int64("collection", replica.CollectionID), zap.Int64("replica", replica.ReplicaID), - zap.Int64s("onlineNodeIDs", onlineNodeIDs), zap.Int64s("availableNodeIDs", availableNodeIDs)) - continue - } - - // check which nodes need balance and determine which segments on these nodes need to be migrated to other nodes - memoryInsufficient := false - for { - sort.Slice(availableNodeIDs, func(i, j int) bool { - return nodeID2MemUsageRate[availableNodeIDs[i]] > nodeID2MemUsageRate[availableNodeIDs[j]] - }) - - // the memoryUsageRate of the sourceNode is higher than other query node - sourceNodeID := availableNodeIDs[0] - dstNodeID := availableNodeIDs[len(availableNodeIDs)-1] - - memUsageRateDiff := nodeID2MemUsageRate[sourceNodeID] - nodeID2MemUsageRate[dstNodeID] - if nodeID2MemUsageRate[sourceNodeID] <= Params.QueryCoordCfg.OverloadedMemoryThresholdPercentage && - memUsageRateDiff <= Params.QueryCoordCfg.MemoryUsageMaxDifferencePercentage { - break - } - // if memoryUsageRate of source node is greater than 90%, and the max memUsageDiff is greater than 30% - // then migrate the segments on source node to other query nodes - segmentInfos := nodeID2SegmentInfos[sourceNodeID] - // select the segment that needs balance on the source node - selectedSegmentInfo, err := chooseSegmentToBalance(sourceNodeID, dstNodeID, segmentInfos, nodeID2MemUsage, nodeID2TotalMem, nodeID2MemUsageRate) - if err != nil { - // no enough memory on query nodes to balance, then notify proxy to stop insert - memoryInsufficient = true - break - } - if selectedSegmentInfo == nil { - break - } - // select a segment to balance successfully, then recursive traversal whether there are other segments that can balance - req := &querypb.LoadBalanceRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_LoadBalanceSegments, - }, - BalanceReason: querypb.TriggerCondition_LoadBalance, - SourceNodeIDs: []UniqueID{sourceNodeID}, - DstNodeIDs: []UniqueID{dstNodeID}, - SealedSegmentIDs: []UniqueID{selectedSegmentInfo.SegmentID}, - } - baseTask := newBaseTask(qc.loopCtx, querypb.TriggerCondition_LoadBalance) - balanceTask := &loadBalanceTask{ - baseTask: baseTask, - LoadBalanceRequest: req, - broker: qc.broker, - cluster: qc.cluster, - meta: qc.meta, - } - log.Info("loadBalanceSegmentLoop: generate a loadBalance task", - zap.Int64("collection", replica.CollectionID), zap.Int64("replica", replica.ReplicaID), - zap.Any("task", balanceTask)) - loadBalanceTasks = append(loadBalanceTasks, balanceTask) - nodeID2MemUsage[sourceNodeID] -= uint64(selectedSegmentInfo.MemSize) - nodeID2MemUsage[dstNodeID] += uint64(selectedSegmentInfo.MemSize) - nodeID2MemUsageRate[sourceNodeID] = float64(nodeID2MemUsage[sourceNodeID]) / float64(nodeID2TotalMem[sourceNodeID]) - nodeID2MemUsageRate[dstNodeID] = float64(nodeID2MemUsage[dstNodeID]) / float64(nodeID2TotalMem[dstNodeID]) - delete(nodeID2SegmentInfos[sourceNodeID], selectedSegmentInfo.SegmentID) - nodeID2SegmentInfos[dstNodeID][selectedSegmentInfo.SegmentID] = selectedSegmentInfo - continue - } - if memoryInsufficient { - // no enough memory on query nodes to balance, then notify proxy to stop insert - //TODO:: xige-16 - log.Warn("loadBalanceSegmentLoop: QueryNode has insufficient memory, stop inserting data", zap.Int64("collection", replica.CollectionID), zap.Int64("replica", replica.ReplicaID)) - } + loadBalanceTasks = append(loadBalanceTasks, qc.balanceReplica(ctx, replica, nodeID2MemUsageRate, nodeID2MemUsage, nodeID2TotalMem)...) } } for _, t := range loadBalanceTasks { - qc.scheduler.Enqueue(t) - err := t.waitToFinish() + err := qc.scheduler.Enqueue(t) + if err != nil { + log.Error("loadBalanceSegmentLoop: balance task enqueue failed", zap.Any("task", t), zap.Error(err)) + continue + } + err = t.waitToFinish() if err != nil { // if failed, wait for next balance loop // it may be that the collection/partition of the balanced segment has been released // it also may be other abnormal errors - log.Error("loadBalanceSegmentLoop: balance task execute failed", zap.Any("task", t)) + log.Error("loadBalanceSegmentLoop: balance task execute failed", zap.Any("task", t), zap.Error(err)) } else { log.Info("loadBalanceSegmentLoop: balance task execute success", zap.Any("task", t)) } } + log.Info("finish balance loop successfully", zap.Duration("time spent", time.Since(startTs))) } } } +// TODO balance replica need to be optimized, we can not get segment info in evert balance round +func (qc *QueryCoord) balanceReplica(ctx context.Context, replica *milvuspb.ReplicaInfo, nodeID2MemUsageRate map[int64]float64, + nodeID2MemUsage map[int64]uint64, nodeID2TotalMem map[int64]uint64) []*loadBalanceTask { + loadBalanceTasks := make([]*loadBalanceTask, 0) + // auto balance is executed on replica level + onlineNodeIDs := replica.GetNodeIds() + if len(onlineNodeIDs) == 0 { + log.Error("loadBalanceSegmentLoop: there are no online QueryNode to balance", zap.Int64("collection", replica.CollectionID), zap.Int64("replica", replica.ReplicaID)) + return loadBalanceTasks + } + var availableNodeIDs []int64 + nodeID2SegmentInfos := make(map[int64]map[UniqueID]*querypb.SegmentInfo) + for _, nodeID := range onlineNodeIDs { + if _, ok := nodeID2MemUsage[nodeID]; !ok { + nodeInfo, err := qc.cluster.GetNodeInfoByID(nodeID) + if err != nil { + log.Warn("loadBalanceSegmentLoop: get node info from QueryNode failed", + zap.Int64("nodeID", nodeID), zap.Int64("collection", replica.CollectionID), zap.Int64("replica", replica.ReplicaID), + zap.Error(err)) + continue + } + nodeID2MemUsageRate[nodeID] = nodeInfo.(*queryNode).memUsageRate + nodeID2MemUsage[nodeID] = nodeInfo.(*queryNode).memUsage + nodeID2TotalMem[nodeID] = nodeInfo.(*queryNode).totalMem + } + + updateSegmentInfoDone := true + leastSegmentInfos := make(map[UniqueID]*querypb.SegmentInfo) + segmentInfos := qc.meta.getSegmentInfosByNodeAndCollection(nodeID, replica.GetCollectionID()) + for _, segmentInfo := range segmentInfos { + leastInfo, err := qc.cluster.GetSegmentInfoByID(ctx, segmentInfo.SegmentID) + if err != nil { + log.Warn("loadBalanceSegmentLoop: get segment info from QueryNode failed", zap.Int64("nodeID", nodeID), + zap.Int64("collection", replica.CollectionID), zap.Int64("replica", replica.ReplicaID), + zap.Error(err)) + updateSegmentInfoDone = false + break + } + leastSegmentInfos[segmentInfo.SegmentID] = leastInfo + } + if updateSegmentInfoDone { + availableNodeIDs = append(availableNodeIDs, nodeID) + nodeID2SegmentInfos[nodeID] = leastSegmentInfos + } + } + log.Info("loadBalanceSegmentLoop: memory usage rate of all online QueryNode", zap.Int64("collection", replica.CollectionID), + zap.Int64("replica", replica.ReplicaID), zap.Any("mem rate", nodeID2MemUsageRate)) + if len(availableNodeIDs) <= 1 { + log.Info("loadBalanceSegmentLoop: there are too few available query nodes to balance", + zap.Int64("collection", replica.CollectionID), zap.Int64("replica", replica.ReplicaID), + zap.Int64s("onlineNodeIDs", onlineNodeIDs), zap.Int64s("availableNodeIDs", availableNodeIDs)) + return loadBalanceTasks + } + + // check which nodes need balance and determine which segments on these nodes need to be migrated to other nodes + for { + sort.Slice(availableNodeIDs, func(i, j int) bool { + return nodeID2MemUsageRate[availableNodeIDs[i]] > nodeID2MemUsageRate[availableNodeIDs[j]] + }) + + // the memoryUsageRate of the sourceNode is higher than other query node + sourceNodeID := availableNodeIDs[0] + dstNodeID := availableNodeIDs[len(availableNodeIDs)-1] + + memUsageRateDiff := nodeID2MemUsageRate[sourceNodeID] - nodeID2MemUsageRate[dstNodeID] + if nodeID2MemUsageRate[sourceNodeID] <= Params.QueryCoordCfg.OverloadedMemoryThresholdPercentage && + memUsageRateDiff <= Params.QueryCoordCfg.MemoryUsageMaxDifferencePercentage { + break + } + // if memoryUsageRate of source node is greater than 90%, and the max memUsageDiff is greater than 30% + // then migrate the segments on source node to other query nodes + segmentInfos := nodeID2SegmentInfos[sourceNodeID] + // select the segment that needs balance on the source node + selectedSegmentInfo, err := chooseSegmentToBalance(sourceNodeID, dstNodeID, segmentInfos, nodeID2MemUsage, nodeID2TotalMem, nodeID2MemUsageRate) + if err != nil { + break + } + if selectedSegmentInfo == nil { + break + } + // select a segment to balance successfully, then recursive traversal whether there are other segments that can balance + req := &querypb.LoadBalanceRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_LoadBalanceSegments, + }, + BalanceReason: querypb.TriggerCondition_LoadBalance, + SourceNodeIDs: []UniqueID{sourceNodeID}, + DstNodeIDs: []UniqueID{dstNodeID}, + SealedSegmentIDs: []UniqueID{selectedSegmentInfo.SegmentID}, + } + baseTask := newBaseTask(qc.loopCtx, querypb.TriggerCondition_LoadBalance) + balanceTask := &loadBalanceTask{ + baseTask: baseTask, + LoadBalanceRequest: req, + broker: qc.broker, + cluster: qc.cluster, + meta: qc.meta, + } + log.Info("loadBalanceSegmentLoop: generate a loadBalance task", + zap.Int64("collection", replica.CollectionID), zap.Int64("replica", replica.ReplicaID), + zap.Any("task", balanceTask)) + loadBalanceTasks = append(loadBalanceTasks, balanceTask) + nodeID2MemUsage[sourceNodeID] -= uint64(selectedSegmentInfo.MemSize) + nodeID2MemUsage[dstNodeID] += uint64(selectedSegmentInfo.MemSize) + nodeID2MemUsageRate[sourceNodeID] = float64(nodeID2MemUsage[sourceNodeID]) / float64(nodeID2TotalMem[sourceNodeID]) + nodeID2MemUsageRate[dstNodeID] = float64(nodeID2MemUsage[dstNodeID]) / float64(nodeID2TotalMem[dstNodeID]) + delete(nodeID2SegmentInfos[sourceNodeID], selectedSegmentInfo.SegmentID) + nodeID2SegmentInfos[dstNodeID][selectedSegmentInfo.SegmentID] = selectedSegmentInfo + continue + } + return loadBalanceTasks +} + func chooseSegmentToBalance(sourceNodeID int64, dstNodeID int64, segmentInfos map[UniqueID]*querypb.SegmentInfo, nodeID2MemUsage map[int64]uint64, diff --git a/internal/querycoord/query_coord_test.go b/internal/querycoord/query_coord_test.go index 514a6a7e26..76c20e4506 100644 --- a/internal/querycoord/query_coord_test.go +++ b/internal/querycoord/query_coord_test.go @@ -59,6 +59,7 @@ func refreshParams() { Params.CommonCfg.QueryCoordTimeTick = Params.CommonCfg.QueryCoordTimeTick + suffix Params.EtcdCfg.MetaRootPath = Params.EtcdCfg.MetaRootPath + suffix GlobalSegmentInfos = make(map[UniqueID]*querypb.SegmentInfo) + Params.QueryCoordCfg.RetryInterval = int64(1 * time.Millisecond) } func TestMain(m *testing.M) { @@ -565,10 +566,8 @@ func TestHandoffSegmentLoop(t *testing.T) { func TestLoadBalanceSegmentLoop(t *testing.T) { refreshParams() defer removeAllSession() - Params.QueryCoordCfg.BalanceIntervalSeconds = 10 baseCtx := context.Background() - queryCoord, err := startQueryCoord(baseCtx) assert.Nil(t, err) queryCoord.cluster.(*queryNodeCluster).segmentAllocator = shuffleSegmentsToQueryNode @@ -623,8 +622,7 @@ func TestLoadBalanceSegmentLoop(t *testing.T) { queryNode1.getMetrics = returnSuccessGetMetricsResult break } - - time.Sleep(time.Second) + time.Sleep(100 * time.Millisecond) } } diff --git a/internal/querycoord/querynode_test.go b/internal/querycoord/querynode_test.go index be4225cc28..269e40c1c7 100644 --- a/internal/querycoord/querynode_test.go +++ b/internal/querycoord/querynode_test.go @@ -286,6 +286,7 @@ func TestSealedSegmentChangeAfterQueryNodeStop(t *testing.T) { if recoverDone { break } + time.Sleep(100 * time.Millisecond) } queryCoord.Stop() diff --git a/internal/querycoord/segment_allocator_test.go b/internal/querycoord/segment_allocator_test.go index 587e35217a..9f37089cec 100644 --- a/internal/querycoord/segment_allocator_test.go +++ b/internal/querycoord/segment_allocator_test.go @@ -52,14 +52,14 @@ func TestShuffleSegmentsToQueryNode(t *testing.T) { } meta, err := newMeta(baseCtx, kv, factory, idAllocator) assert.Nil(t, err) - handler, err := newChannelUnsubscribeHandler(baseCtx, kv, factory) + cleaner, err := NewChannelCleaner(baseCtx, kv, factory) assert.Nil(t, err) cluster := &queryNodeCluster{ ctx: baseCtx, cancel: cancel, client: kv, clusterMeta: meta, - handler: handler, + cleaner: cleaner, nodes: make(map[int64]Node), newNodeFn: newQueryNodeTest, session: clusterSession, diff --git a/internal/querycoord/task.go b/internal/querycoord/task.go index d7746ed2ee..694a6e540b 100644 --- a/internal/querycoord/task.go +++ b/internal/querycoord/task.go @@ -22,6 +22,7 @@ import ( "fmt" "sort" "sync" + "sync/atomic" "time" "github.com/golang/protobuf/proto" @@ -39,24 +40,6 @@ import ( "github.com/milvus-io/milvus/internal/util/typeutil" ) -const timeoutForRPC = 10 * time.Second - -const ( - triggerTaskPrefix = "queryCoord-triggerTask" - activeTaskPrefix = "queryCoord-activeTask" - taskInfoPrefix = "queryCoord-taskInfo" - loadBalanceInfoPrefix = "queryCoord-loadBalanceInfo" -) - -const ( - // MaxRetryNum is the maximum number of times that each task can be retried - MaxRetryNum = 5 - // MaxSendSizeToEtcd is the default limit size of etcd messages that can be sent and received - // MaxSendSizeToEtcd = 2097152 - // Limit size of every loadSegmentReq to 200k - MaxSendSizeToEtcd = 200000 -) - var ( ErrCollectionLoaded = errors.New("CollectionLoaded") ErrLoadParametersMismatch = errors.New("LoadParametersMismatch") @@ -115,9 +98,7 @@ type baseTask struct { resultMu sync.RWMutex state taskState stateMu sync.RWMutex - retryCount int - retryMu sync.RWMutex - //sync.RWMutex + retryCount int32 taskID UniqueID triggerCondition querypb.TriggerCondition @@ -138,7 +119,7 @@ func newBaseTask(ctx context.Context, triggerType querypb.TriggerCondition) *bas cancel: cancel, condition: condition, state: taskUndo, - retryCount: MaxRetryNum, + retryCount: Params.QueryCoordCfg.RetryNum, triggerCondition: triggerType, childTasks: []task{}, timeRecorder: timerecord.NewTimeRecorder("QueryCoordBaseTask"), @@ -147,7 +128,7 @@ func newBaseTask(ctx context.Context, triggerType querypb.TriggerCondition) *bas return baseTask } -func newBaseTaskWithRetry(ctx context.Context, triggerType querypb.TriggerCondition, retryCount int) *baseTask { +func newBaseTaskWithRetry(ctx context.Context, triggerType querypb.TriggerCondition, retryCount int32) *baseTask { baseTask := newBaseTask(ctx, triggerType) baseTask.retryCount = retryCount return baseTask @@ -252,16 +233,11 @@ func (bt *baseTask) setState(state taskState) { } func (bt *baseTask) isRetryable() bool { - bt.retryMu.RLock() - defer bt.retryMu.RUnlock() - return bt.retryCount > 0 + return atomic.LoadInt32(&bt.retryCount) > 0 } func (bt *baseTask) reduceRetryCount() { - bt.retryMu.Lock() - defer bt.retryMu.Unlock() - - bt.retryCount-- + atomic.AddInt32(&bt.retryCount, -1) } func (bt *baseTask) setResultInfo(err error) { @@ -733,7 +709,7 @@ func (rct *releaseCollectionTask) updateTaskProcess() { func (rct *releaseCollectionTask) preExecute(context.Context) error { collectionID := rct.CollectionID rct.setResultInfo(nil) - log.Info("start do releaseCollectionTask", + log.Info("pre execute releaseCollectionTask", zap.Int64("taskID", rct.getTaskID()), zap.Int64("msgID", rct.GetBase().GetMsgID()), zap.Int64("collectionID", collectionID)) @@ -767,14 +743,14 @@ func (rct *releaseCollectionTask) execute(ctx context.Context) error { } rct.addChildTask(releaseCollectionTask) - log.Info("releaseCollectionTask: add a releaseCollectionTask to releaseCollectionTask's childTask", zap.Any("task", releaseCollectionTask)) + log.Info("releaseCollectionTask: add a releaseCollectionTask to releaseCollectionTask's childTask", zap.Any("task", releaseCollectionTask), zap.Int64("NodeID", nodeID)) } } else { // If the node crashed or be offline, the loaded segments are lost defer rct.reduceRetryCount() err := rct.cluster.ReleaseCollection(ctx, rct.NodeID, rct.ReleaseCollectionRequest) if err != nil { - log.Warn("releaseCollectionTask: release collection end, node occur error", zap.Int64("collectionID", collectionID), zap.Int64("nodeID", rct.NodeID)) + log.Warn("releaseCollectionTask: release collection end, node occur error", zap.Int64("collectionID", collectionID), zap.Int64("nodeID", rct.NodeID), zap.Error(err)) // after release failed, the task will always redo // if the query node happens to be down, the node release was judged to have succeeded return err diff --git a/internal/querycoord/task_scheduler.go b/internal/querycoord/task_scheduler.go index ad982ec200..4b4ca0d7cb 100644 --- a/internal/querycoord/task_scheduler.go +++ b/internal/querycoord/task_scheduler.go @@ -22,9 +22,9 @@ import ( "errors" "fmt" "path/filepath" - "reflect" "strconv" "sync" + "time" "github.com/golang/protobuf/proto" "github.com/opentracing/opentracing-go" @@ -170,7 +170,11 @@ type TaskScheduler struct { broker *globalMetaBroker - wg sync.WaitGroup + wg sync.WaitGroup + + closed bool + closeMutex sync.Mutex + ctx context.Context cancel context.CancelFunc } @@ -231,6 +235,7 @@ func (scheduler *TaskScheduler) reloadFromKV() error { if err != nil { return err } + log.Info("find one trigger task key", zap.Int64("id", taskID), zap.Any("task", t)) triggerTasks[taskID] = t } @@ -244,6 +249,7 @@ func (scheduler *TaskScheduler) reloadFromKV() error { if err != nil { return err } + log.Info("find one active task key", zap.Int64("id", taskID), zap.Any("task", t)) activeTasks[taskID] = t } @@ -461,6 +467,11 @@ func (scheduler *TaskScheduler) unmarshalTask(taskID UniqueID, t string) (task, // Enqueue pushs a trigger task to triggerTaskQueue and assigns task id func (scheduler *TaskScheduler) Enqueue(t task) error { + scheduler.closeMutex.Lock() + defer scheduler.closeMutex.Unlock() + if scheduler.closed { + return fmt.Errorf("querycoord task scheduler is already closed") + } // TODO, loadbalance, handoff and other task may not want to be persisted id, err := scheduler.taskIDAllocator() if err != nil { @@ -486,13 +497,13 @@ func (scheduler *TaskScheduler) Enqueue(t task) error { } t.setState(taskUndo) scheduler.triggerTaskQueue.addTask(t) - log.Debug("EnQueue a triggerTask and save to etcd", zap.Int64("taskID", t.getTaskID())) + log.Info("EnQueue a triggerTask and save to etcd", zap.Int64("taskID", t.getTaskID()), zap.Any("task", t)) return nil } func (scheduler *TaskScheduler) processTask(t task) error { - log.Info("begin to process task", zap.Int64("taskID", t.getTaskID()), zap.String("task", reflect.TypeOf(t).String())) + log.Info("begin to process task", zap.Int64("taskID", t.getTaskID()), zap.Any("task", t)) var taskInfoKey string // assign taskID for childTask and update triggerTask's childTask to etcd updateKVFn := func(parentTask task) error { @@ -566,8 +577,7 @@ func (scheduler *TaskScheduler) processTask(t task) error { span.LogFields(oplog.Int64("processTask: scheduler process PreExecute", t.getTaskID())) err = t.preExecute(ctx) if err != nil { - log.Error("failed to preExecute task", - zap.Error(err)) + log.Error("failed to preExecute task", zap.Int64("taskID", t.getTaskID()), zap.Error(err)) t.setResultInfo(err) return err } @@ -575,6 +585,7 @@ func (scheduler *TaskScheduler) processTask(t task) error { err = scheduler.client.Save(taskInfoKey, strconv.Itoa(int(taskDoing))) if err != nil { trace.LogError(span, err) + log.Warn("failed to save task info ", zap.Int64("taskID", t.getTaskID()), zap.Error(err)) t.setResultInfo(err) return err } @@ -584,13 +595,13 @@ func (scheduler *TaskScheduler) processTask(t task) error { span.LogFields(oplog.Int64("processTask: scheduler process Execute", t.getTaskID())) err = t.execute(ctx) if err != nil { - log.Warn("failed to execute task", zap.Error(err)) + log.Warn("failed to execute task", zap.Int64("taskID", t.getTaskID()), zap.Error(err)) trace.LogError(span, err) return err } err = updateKVFn(t) if err != nil { - log.Warn("failed to execute task", zap.Error(err)) + log.Warn("failed to update kv", zap.Int64("taskID", t.getTaskID()), zap.Error(err)) trace.LogError(span, err) t.setResultInfo(err) return err @@ -618,7 +629,6 @@ func (scheduler *TaskScheduler) scheduleLoop() { ) for _, childTask := range activateTasks { if childTask != nil { - log.Debug("scheduleLoop: add an activate task to activateChan", zap.Int64("taskID", childTask.getTaskID())) scheduler.activateTaskChan <- childTask activeTaskWg.Add(1) go scheduler.waitActivateTaskDone(activeTaskWg, childTask, triggerTask) @@ -658,6 +668,15 @@ func (scheduler *TaskScheduler) scheduleLoop() { select { case <-scheduler.ctx.Done(): scheduler.stopActivateTaskLoopChan <- 1 + // drain all trigger task queue + triggerTask = scheduler.triggerTaskQueue.popTask() + for triggerTask != nil { + log.Warn("scheduler exit, set all trigger task queue to error and notify", zap.Int64("taskID", triggerTask.getTaskID())) + err := fmt.Errorf("scheduler exiting error") + triggerTask.setResultInfo(err) + triggerTask.notify(err) + triggerTask = scheduler.triggerTaskQueue.popTask() + } return case <-scheduler.triggerTaskQueue.Chan(): triggerTask = scheduler.triggerTaskQueue.popTask() @@ -679,6 +698,7 @@ func (scheduler *TaskScheduler) scheduleLoop() { triggerTask.setState(taskExpired) if errors.Is(err, ErrLoadParametersMismatch) { + log.Warn("hit param error when load ", zap.Int64("taskId", triggerTask.getTaskID()), zap.Any("task", triggerTask)) triggerTask.setState(taskFailed) } @@ -725,6 +745,7 @@ func (scheduler *TaskScheduler) scheduleLoop() { if triggerTask.getResultInfo().ErrorCode == commonpb.ErrorCode_Success || triggerTask.getTriggerCondition() == querypb.TriggerCondition_NodeDown { err = updateSegmentInfoFromTask(scheduler.ctx, triggerTask, scheduler.meta) if err != nil { + log.Warn("failed to update segment info", zap.Int64("taskID", triggerTask.getTaskID()), zap.Error(err)) triggerTask.setResultInfo(err) } } @@ -766,6 +787,7 @@ func (scheduler *TaskScheduler) scheduleLoop() { resultStatus := triggerTask.getResultInfo() if resultStatus.ErrorCode != commonpb.ErrorCode_Success { + log.Warn("task states not succeed", zap.Int64("taskId", triggerTask.getTaskID()), zap.Any("task", triggerTask), zap.Any("status", resultStatus)) triggerTask.setState(taskFailed) if !alreadyNotify { triggerTask.notify(errors.New(resultStatus.Reason)) @@ -836,7 +858,7 @@ func (scheduler *TaskScheduler) waitActivateTaskDone(wg *sync.WaitGroup, t task, } err = scheduler.client.MultiSaveAndRemove(saves, removes) if err != nil { - log.Error("waitActivateTaskDone: error when save and remove task from etcd", zap.Int64("triggerTaskID", triggerTask.getTaskID())) + log.Warn("waitActivateTaskDone: error when save and remove task from etcd", zap.Int64("triggerTaskID", triggerTask.getTaskID()), zap.Error(err)) triggerTask.setResultInfo(err) return } @@ -846,13 +868,16 @@ func (scheduler *TaskScheduler) waitActivateTaskDone(wg *sync.WaitGroup, t task, zap.Int64("failed taskID", t.getTaskID()), zap.Any("reScheduled tasks", reScheduledTasks)) - for _, rt := range reScheduledTasks { - if rt != nil { - triggerTask.addChildTask(rt) - log.Info("waitActivateTaskDone: add a reScheduled active task to activateChan", zap.Int64("taskID", rt.getTaskID())) - scheduler.activateTaskChan <- rt + for _, t := range reScheduledTasks { + if t != nil { + triggerTask.addChildTask(t) + log.Info("waitActivateTaskDone: add a reScheduled active task to activateChan", zap.Int64("taskID", t.getTaskID())) + go func() { + time.Sleep(time.Duration(Params.QueryCoordCfg.RetryInterval)) + scheduler.activateTaskChan <- t + }() wg.Add(1) - go scheduler.waitActivateTaskDone(wg, rt, triggerTask) + go scheduler.waitActivateTaskDone(wg, t, triggerTask) } } //delete task from etcd @@ -860,7 +885,10 @@ func (scheduler *TaskScheduler) waitActivateTaskDone(wg *sync.WaitGroup, t task, log.Info("waitActivateTaskDone: retry the active task", zap.Int64("taskID", t.getTaskID()), zap.Int64("triggerTaskID", triggerTask.getTaskID())) - scheduler.activateTaskChan <- t + go func() { + time.Sleep(time.Duration(Params.QueryCoordCfg.RetryInterval)) + scheduler.activateTaskChan <- t + }() wg.Add(1) go scheduler.waitActivateTaskDone(wg, t, triggerTask) } @@ -871,14 +899,20 @@ func (scheduler *TaskScheduler) waitActivateTaskDone(wg *sync.WaitGroup, t task, if !t.isRetryable() { log.Error("waitActivateTaskDone: activate task failed after retry", zap.Int64("taskID", t.getTaskID()), - zap.Int64("triggerTaskID", triggerTask.getTaskID())) + zap.Int64("triggerTaskID", triggerTask.getTaskID()), + zap.Error(err), + ) triggerTask.setResultInfo(err) return } log.Info("waitActivateTaskDone: retry the active task", zap.Int64("taskID", t.getTaskID()), zap.Int64("triggerTaskID", triggerTask.getTaskID())) - scheduler.activateTaskChan <- t + + go func() { + time.Sleep(time.Duration(Params.QueryCoordCfg.RetryInterval)) + scheduler.activateTaskChan <- t + }() wg.Add(1) go scheduler.waitActivateTaskDone(wg, t, triggerTask) } @@ -950,12 +984,19 @@ func (scheduler *TaskScheduler) Start() error { // Close function stops the scheduleLoop and the processActivateTaskLoop func (scheduler *TaskScheduler) Close() { + scheduler.closeMutex.Lock() + defer scheduler.closeMutex.Unlock() + scheduler.closed = true if scheduler.cancel != nil { scheduler.cancel() } scheduler.wg.Wait() } +func (scheduler *TaskScheduler) taskEmpty() bool { + return scheduler.triggerTaskQueue.taskEmpty() +} + // BindContext binds input context with shceduler context. // the result context will be canceled when either context is done. func (scheduler *TaskScheduler) BindContext(ctx context.Context) (context.Context, context.CancelFunc) { diff --git a/internal/querycoord/task_test.go b/internal/querycoord/task_test.go index 0489f8f71c..b842a6eb94 100644 --- a/internal/querycoord/task_test.go +++ b/internal/querycoord/task_test.go @@ -23,9 +23,6 @@ import ( "testing" "time" - "github.com/milvus-io/milvus/internal/util/dependency" - "github.com/milvus-io/milvus/internal/util/etcd" - "github.com/stretchr/testify/assert" "go.uber.org/zap" @@ -281,9 +278,12 @@ func waitTaskFinalState(t task, state taskState) { break } - log.Debug("task state not match es", + log.Debug("task state not matches", + zap.Int64("task ID", t.getTaskID()), zap.Int("actual", int(currentState)), zap.Int("expected", int(state))) + + time.Sleep(100 * time.Millisecond) } } @@ -661,6 +661,7 @@ func Test_LoadPartitionExecuteFailAfterLoadCollection(t *testing.T) { func Test_ReleaseCollectionExecuteFail(t *testing.T) { refreshParams() + Params.QueryCoordCfg.RetryInterval = int64(100 * time.Millisecond) ctx := context.Background() queryCoord, err := startQueryCoord(ctx) assert.Nil(t, err) @@ -677,6 +678,7 @@ func Test_ReleaseCollectionExecuteFail(t *testing.T) { releaseCollectionTask := genReleaseCollectionTask(ctx, queryCoord) notify := make(chan struct{}) go func() { + time.Sleep(100 * time.Millisecond) waitTaskFinalState(releaseCollectionTask, taskDone) node.setRPCInterface(&node.releaseCollection, returnSuccessResult) waitTaskFinalState(releaseCollectionTask, taskExpired) @@ -1446,82 +1448,3 @@ func TestUpdateTaskProcessWhenWatchDmChannel(t *testing.T) { err = removeAllSession() assert.Nil(t, err) } - -func startMockCoord(ctx context.Context) (*QueryCoord, error) { - factory := dependency.NewDefaultFactory(true) - - coord, err := NewQueryCoordTest(ctx, factory) - if err != nil { - return nil, err - } - - rootCoord := newRootCoordMock(ctx) - rootCoord.createCollection(defaultCollectionID) - rootCoord.createPartition(defaultCollectionID, defaultPartitionID) - - dataCoord := &dataCoordMock{ - collections: make([]UniqueID, 0), - col2DmChannels: make(map[UniqueID][]*datapb.VchannelInfo), - partitionID2Segment: make(map[UniqueID][]UniqueID), - Segment2Binlog: make(map[UniqueID]*datapb.SegmentBinlogs), - baseSegmentID: defaultSegmentID, - channelNumPerCol: defaultChannelNum, - segmentState: commonpb.SegmentState_Flushed, - errLevel: 1, - segmentRefCount: make(map[int64]int), - } - indexCoord, err := newIndexCoordMock(queryCoordTestDir) - if err != nil { - return nil, err - } - - coord.SetRootCoord(rootCoord) - coord.SetDataCoord(dataCoord) - coord.SetIndexCoord(indexCoord) - etcd, err := etcd.GetEtcdClient(&Params.EtcdCfg) - if err != nil { - return nil, err - } - coord.SetEtcdClient(etcd) - err = coord.Init() - if err != nil { - return nil, err - } - err = coord.Start() - if err != nil { - return nil, err - } - err = coord.Register() - if err != nil { - return nil, err - } - return coord, nil -} - -func Test_LoadSegment(t *testing.T) { - refreshParams() - ctx := context.Background() - queryCoord, err := startMockCoord(ctx) - assert.Nil(t, err) - - node1, err := startQueryNodeServer(ctx) - assert.Nil(t, err) - - waitQueryNodeOnline(queryCoord.cluster, node1.queryNodeID) - - loadSegmentTask := genLoadSegmentTask(ctx, queryCoord, node1.queryNodeID) - - loadCollectionTask := loadSegmentTask.parentTask - queryCoord.scheduler.triggerTaskQueue.addTask(loadCollectionTask) - - // 1. Acquire segment reference lock failed, and reschedule task. - // 2. Acquire segment reference lock successfully, but release reference lock failed, and retry release the lock. - // 3. Release segment reference lock successfully, and task done. - waitTaskFinalState(loadSegmentTask, taskDone) - - err = queryCoord.Stop() - assert.Nil(t, err) - - err = removeAllSession() - assert.Nil(t, err) -} diff --git a/internal/util/paramtable/component_param.go b/internal/util/paramtable/component_param.go index 52506b7288..94f42b6134 100644 --- a/internal/util/paramtable/component_param.go +++ b/internal/util/paramtable/component_param.go @@ -578,6 +578,10 @@ type queryCoordConfig struct { CreatedTime time.Time UpdatedTime time.Time + //---- Task --- + RetryNum int32 + RetryInterval int64 + //---- Handoff --- AutoHandoff bool @@ -591,6 +595,11 @@ type queryCoordConfig struct { func (p *queryCoordConfig) init(base *BaseTable) { p.Base = base p.NodeID.Store(UniqueID(0)) + + //---- Task --- + p.initTaskRetryNum() + p.initTaskRetryInterval() + //---- Handoff --- p.initAutoHandoff() @@ -601,6 +610,14 @@ func (p *queryCoordConfig) init(base *BaseTable) { p.initMemoryUsageMaxDifferencePercentage() } +func (p *queryCoordConfig) initTaskRetryNum() { + p.RetryNum = p.Base.ParseInt32WithDefault("queryCoord.task.retrynum", 5) +} + +func (p *queryCoordConfig) initTaskRetryInterval() { + p.RetryInterval = p.Base.ParseInt64WithDefault("queryCoord.task.retryinterval", int64(10*time.Second)) +} + func (p *queryCoordConfig) initAutoHandoff() { handoff, err := p.Base.Load("queryCoord.autoHandoff") if err != nil {