diff --git a/internal/datacoord/server.go b/internal/datacoord/server.go index d8d1730b5d..b7e890a6b1 100644 --- a/internal/datacoord/server.go +++ b/internal/datacoord/server.go @@ -406,6 +406,7 @@ func (s *Server) initServiceDiscovery() error { s.cluster.Startup(datanodes) + // TODO implement rewatch logic s.eventCh = s.session.WatchServices(typeutil.DataNodeRole, rev+1, nil) return nil } @@ -607,7 +608,13 @@ func (s *Server) watchService(ctx context.Context) { return case event, ok := <-s.eventCh: if !ok { - //TODO add retry logic + // ErrCompacted in handled inside SessionWatcher + // So there is some other error occurred, closing DataCoord server + logutil.Logger(s.ctx).Error("watch service channel closed", zap.Int64("serverID", s.session.ServerID)) + go s.Stop() + if s.session.TriggerKill { + syscall.Kill(syscall.Getpid(), syscall.SIGINT) + } return } if err := s.handleSessionEvent(ctx, event); err != nil { @@ -620,7 +627,6 @@ func (s *Server) watchService(ctx context.Context) { } } } - } // handles session events - DataNodes Add/Del diff --git a/internal/datacoord/server_test.go b/internal/datacoord/server_test.go index 563150bc94..1e942389cc 100644 --- a/internal/datacoord/server_test.go +++ b/internal/datacoord/server_test.go @@ -22,9 +22,11 @@ import ( "fmt" "math/rand" "os" + "os/signal" "path" "strconv" "sync/atomic" + "syscall" "testing" "time" @@ -611,25 +613,40 @@ func TestGetFlushedSegments(t *testing.T) { } func TestService_WatchServices(t *testing.T) { + sc := make(chan os.Signal, 1) + signal.Notify(sc, syscall.SIGINT) + defer signal.Reset(syscall.SIGINT) factory := msgstream.NewPmsFactory() svr := CreateServer(context.TODO(), factory) + svr.session = &sessionutil.Session{ + TriggerKill: true, + } svr.serverLoopWg.Add(1) ech := make(chan *sessionutil.SessionEvent) svr.eventCh = ech flag := false - signal := make(chan struct{}, 1) + closed := false + sigDone := make(chan struct{}, 1) + sigQuit := make(chan struct{}, 1) go func() { svr.watchService(context.Background()) flag = true - signal <- struct{}{} + sigDone <- struct{}{} + }() + go func() { + <-sc + closed = true + sigQuit <- struct{}{} }() close(ech) - <-signal + <-sigDone + <-sigQuit assert.True(t, flag) + assert.True(t, closed) ech = make(chan *sessionutil.SessionEvent) @@ -641,12 +658,12 @@ func TestService_WatchServices(t *testing.T) { go func() { svr.watchService(ctx) flag = true - signal <- struct{}{} + sigDone <- struct{}{} }() ech <- nil cancel() - <-signal + <-sigDone assert.True(t, flag) } diff --git a/internal/indexcoord/index_coord.go b/internal/indexcoord/index_coord.go index 1b64af275d..25d81659af 100644 --- a/internal/indexcoord/index_coord.go +++ b/internal/indexcoord/index_coord.go @@ -200,6 +200,7 @@ func (i *IndexCoord) Init() error { } log.Debug("IndexCoord", zap.Int("IndexNode number", len(i.nodeManager.nodeClients))) + // TODO silverxia add Rewatch logic i.eventChan = i.session.WatchServices(typeutil.IndexNodeRole, revision+1, nil) nodeTasks := i.metaTable.GetNodeTaskStats() for nodeID, taskNum := range nodeTasks { @@ -758,7 +759,12 @@ func (i *IndexCoord) watchNodeLoop() { return case event, ok := <-i.eventChan: if !ok { - //TODO silverxia add retry + // ErrCompacted is handled inside SessionWatcher + log.Error("Session Watcher channel closed", zap.Int64("server id", i.session.ServerID)) + go i.Stop() + if i.session.TriggerKill { + syscall.Kill(syscall.Getpid(), syscall.SIGINT) + } return } log.Debug("IndexCoord watchNodeLoop event updated") diff --git a/internal/indexcoord/index_coord_test.go b/internal/indexcoord/index_coord_test.go index eea740d32d..51331552e2 100644 --- a/internal/indexcoord/index_coord_test.go +++ b/internal/indexcoord/index_coord_test.go @@ -19,7 +19,10 @@ package indexcoord import ( "context" "math/rand" + "os" + "os/signal" "sync" + "syscall" "testing" "time" @@ -227,21 +230,37 @@ func TestIndexCoord_watchNodeLoop(t *testing.T) { loopWg: sync.WaitGroup{}, loopCtx: context.Background(), eventChan: ech, + session: &sessionutil.Session{ + TriggerKill: true, + ServerID: 0, + }, } in.loopWg.Add(1) flag := false - signal := make(chan struct{}, 1) + closed := false + sigDone := make(chan struct{}, 1) + sigQuit := make(chan struct{}, 1) + sc := make(chan os.Signal, 1) + signal.Notify(sc, syscall.SIGINT) + defer signal.Reset(syscall.SIGINT) + go func() { in.watchNodeLoop() flag = true - signal <- struct{}{} + sigDone <- struct{}{} + }() + go func() { + <-sc + closed = true + sigQuit <- struct{}{} }() close(ech) - <-signal + <-sigDone + <-sigQuit assert.True(t, flag) - + assert.True(t, closed) } func TestIndexCoord_GetComponentStates(t *testing.T) { diff --git a/internal/querycoord/query_coord.go b/internal/querycoord/query_coord.go index 5dc414489a..ebdad00315 100644 --- a/internal/querycoord/query_coord.go +++ b/internal/querycoord/query_coord.go @@ -367,13 +367,24 @@ func (qc *QueryCoord) watchNodeLoop() { log.Debug("start a loadBalance task", zap.Any("task", loadBalanceTask)) } + // TODO silverxia add Rewatch logic qc.eventChan = qc.session.WatchServices(typeutil.QueryNodeRole, qc.cluster.getSessionVersion()+1, nil) + qc.handleNodeEvent(ctx) +} + +func (qc *QueryCoord) handleNodeEvent(ctx context.Context) { for { select { case <-ctx.Done(): return case event, ok := <-qc.eventChan: if !ok { + // ErrCompacted is handled inside SessionWatcher + log.Error("Session Watcher channel closed", zap.Int64("server id", qc.session.ServerID)) + go qc.Stop() + if qc.session.TriggerKill { + syscall.Kill(syscall.Getpid(), syscall.SIGINT) + } return } switch event.EventType { diff --git a/internal/querycoord/query_coord_test.go b/internal/querycoord/query_coord_test.go index 78f451281c..57c5a0901c 100644 --- a/internal/querycoord/query_coord_test.go +++ b/internal/querycoord/query_coord_test.go @@ -22,7 +22,9 @@ import ( "fmt" "math/rand" "os" + "os/signal" "strconv" + "syscall" "testing" "time" @@ -243,6 +245,44 @@ func TestWatchNodeLoop(t *testing.T) { }) } +func TestHandleNodeEventClosed(t *testing.T) { + ech := make(chan *sessionutil.SessionEvent) + qc := &QueryCoord{ + eventChan: ech, + session: &sessionutil.Session{ + TriggerKill: true, + ServerID: 0, + }, + } + flag := false + closed := false + + sigDone := make(chan struct{}, 1) + sigQuit := make(chan struct{}, 1) + sc := make(chan os.Signal, 1) + signal.Notify(sc, syscall.SIGINT) + + defer signal.Reset(syscall.SIGINT) + + go func() { + qc.handleNodeEvent(context.Background()) + flag = true + sigDone <- struct{}{} + }() + + go func() { + <-sc + closed = true + sigQuit <- struct{}{} + }() + + close(ech) + <-sigDone + <-sigQuit + assert.True(t, flag) + assert.True(t, closed) +} + func TestHandoffSegmentLoop(t *testing.T) { refreshParams() baseCtx := context.Background() diff --git a/internal/querynode/query_node.go b/internal/querynode/query_node.go index 35277199e7..a0ce55dc86 100644 --- a/internal/querynode/query_node.go +++ b/internal/querynode/query_node.go @@ -217,7 +217,13 @@ func (node *QueryNode) watchService(ctx context.Context) { return case event, ok := <-node.eventCh: if !ok { - //TODO add retry logic + // ErrCompacted is handled inside SessionWatcher + log.Error("Session Watcher channel closed", zap.Int64("server id", node.session.ServerID)) + // need to call stop in separate goroutine + go node.Stop() + if node.session.TriggerKill { + syscall.Kill(syscall.Getpid(), syscall.SIGINT) + } return } if err := node.handleSessionEvent(ctx, event); err != nil { diff --git a/internal/querynode/query_node_test.go b/internal/querynode/query_node_test.go index 97fff55bdd..bf1767ac9f 100644 --- a/internal/querynode/query_node_test.go +++ b/internal/querynode/query_node_test.go @@ -20,8 +20,10 @@ import ( "context" "math/rand" "os" + "os/signal" "strconv" "sync" + "syscall" "testing" "time" @@ -36,6 +38,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/etcd" + "github.com/milvus-io/milvus/internal/util/sessionutil" ) // mock of query coordinator client @@ -425,3 +428,80 @@ func TestQueryNode_watchChangeInfo(t *testing.T) { }) wg.Wait() } + +func TestQueryNode_watchService(t *testing.T) { + t.Run("watch channel closed", func(t *testing.T) { + ech := make(chan *sessionutil.SessionEvent) + qn := &QueryNode{ + session: &sessionutil.Session{ + TriggerKill: true, + ServerID: 0, + }, + wg: sync.WaitGroup{}, + eventCh: ech, + queryNodeLoopCancel: func() {}, + } + flag := false + closed := false + + sigDone := make(chan struct{}, 1) + sigQuit := make(chan struct{}, 1) + sc := make(chan os.Signal, 1) + signal.Notify(sc, syscall.SIGINT) + + defer signal.Reset(syscall.SIGINT) + + qn.wg.Add(1) + + go func() { + qn.watchService(context.Background()) + flag = true + sigDone <- struct{}{} + }() + go func() { + <-sc + closed = true + sigQuit <- struct{}{} + }() + + close(ech) + <-sigDone + <-sigQuit + assert.True(t, flag) + assert.True(t, closed) + }) + + t.Run("context done", func(t *testing.T) { + ech := make(chan *sessionutil.SessionEvent) + qn := &QueryNode{ + session: &sessionutil.Session{ + TriggerKill: true, + ServerID: 0, + }, + wg: sync.WaitGroup{}, + eventCh: ech, + } + flag := false + + sigDone := make(chan struct{}, 1) + sc := make(chan os.Signal, 1) + signal.Notify(sc, syscall.SIGINT) + + defer signal.Reset(syscall.SIGINT) + + qn.wg.Add(1) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + qn.watchService(ctx) + flag = true + sigDone <- struct{}{} + }() + + assert.False(t, flag) + cancel() + <-sigDone + assert.True(t, flag) + }) +} diff --git a/internal/util/sessionutil/session_util.go b/internal/util/sessionutil/session_util.go index 20759c5231..699894e9b5 100644 --- a/internal/util/sessionutil/session_util.go +++ b/internal/util/sessionutil/session_util.go @@ -403,21 +403,18 @@ func (w *sessionWatcher) handleWatchErr(err error) error { return err } - // rewatch is nil, no logic to handle - if w.rewatch == nil { - log.Warn("Watch service with ErrCompacted but no rewatch logic provided") - close(w.eventCh) - return err - } - sessions, revision, err := w.s.GetSessions(w.prefix) if err != nil { log.Warn("GetSession before rewatch failed", zap.String("prefix", w.prefix), zap.Error(err)) close(w.eventCh) return err } - - err = w.rewatch(sessions) + // rewatch is nil, no logic to handle + if w.rewatch == nil { + log.Warn("Watch service with ErrCompacted but no rewatch logic provided") + } else { + err = w.rewatch(sessions) + } if err != nil { log.Warn("WatchServices rewatch failed", zap.String("prefix", w.prefix), zap.Error(err)) close(w.eventCh) diff --git a/internal/util/sessionutil/session_util_test.go b/internal/util/sessionutil/session_util_test.go index a3fe708d33..f545171791 100644 --- a/internal/util/sessionutil/session_util_test.go +++ b/internal/util/sessionutil/session_util_test.go @@ -18,7 +18,6 @@ import ( "github.com/stretchr/testify/require" "go.etcd.io/etcd/api/v3/mvccpb" - v3rpc "go.etcd.io/etcd/api/v3/v3rpc/rpctypes" clientv3 "go.etcd.io/etcd/client/v3" ) @@ -303,8 +302,7 @@ func TestWatcherHandleWatchResp(t *testing.T) { CompactRevision: 1, } err := w.handleWatchResponse(wresp) - assert.Error(t, err) - assert.Equal(t, v3rpc.ErrCompacted, err) + assert.NoError(t, err) }) t.Run("err compacted resp, valid Rewatch", func(t *testing.T) { @@ -327,6 +325,19 @@ func TestWatcherHandleWatchResp(t *testing.T) { assert.Error(t, err) }) + t.Run("err handled but rewatch failed", func(t *testing.T) { + w := getWatcher(s, func(sessions map[string]*Session) error { + return errors.New("mocked") + }) + wresp := clientv3.WatchResponse{ + CompactRevision: 1, + } + err := w.handleWatchResponse(wresp) + t.Log(err.Error()) + + assert.Error(t, err) + }) + t.Run("err handled but list failed", func(t *testing.T) { s := NewSession(ctx, "/by-dev/session-ut", etcdCli) s.etcdCli.Close() @@ -341,17 +352,6 @@ func TestWatcherHandleWatchResp(t *testing.T) { assert.Error(t, err) }) - t.Run("err handled but rewatch failed", func(t *testing.T) { - w := getWatcher(s, func(sessions map[string]*Session) error { - return errors.New("mocked") - }) - wresp := clientv3.WatchResponse{ - CompactRevision: 1, - } - err := w.handleWatchResponse(wresp) - - assert.Error(t, err) - }) } func TestSessionRevoke(t *testing.T) {