From af5c01082bbd9c41d8515e252e5ba0231bae420d Mon Sep 17 00:00:00 2001 From: congqixia Date: Thu, 7 Sep 2023 10:11:15 +0800 Subject: [PATCH] Refine delegator lifetime control (#26881) - Add SafeChan interface in lifetime package - Embed SafeChan into the `Lifetime` interface - Replace private lifetime struct in delegator package with `lifetime.Lifetime` - Refine delegator on-going task lifetime control and wait all accepted task done - Fix potential goroutine leakage from `waitTSafe` if delegator closed concurrently /kind improvement Signed-off-by: Congqi Xia --- internal/querynodev2/delegator/delegator.go | 69 +++++++++---------- .../querynodev2/delegator/delegator_test.go | 21 +++--- pkg/util/lifetime/lifetime.go | 5 +- pkg/util/lifetime/safe_chan.go | 45 ++++++++++++ pkg/util/lifetime/safe_chan_test.go | 37 ++++++++++ pkg/util/typeutil/chan.go | 12 ++++ 6 files changed, 141 insertions(+), 48 deletions(-) create mode 100644 pkg/util/lifetime/safe_chan.go create mode 100644 pkg/util/lifetime/safe_chan_test.go create mode 100644 pkg/util/typeutil/chan.go diff --git a/internal/querynodev2/delegator/delegator.go b/internal/querynodev2/delegator/delegator.go index 3c4300d656..c0c532efa3 100644 --- a/internal/querynodev2/delegator/delegator.go +++ b/internal/querynodev2/delegator/delegator.go @@ -41,38 +41,13 @@ import ( "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/lifetime" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/timerecord" "github.com/milvus-io/milvus/pkg/util/tsoutil" ) -type lifetime struct { - state atomic.Int32 - closeCh chan struct{} - closeOnce sync.Once -} - -func (lt *lifetime) SetState(state int32) { - lt.state.Store(state) -} - -func (lt *lifetime) GetState() int32 { - return lt.state.Load() -} - -func (lt *lifetime) Close() { - lt.closeOnce.Do(func() { - close(lt.closeCh) - }) -} - 
-func newLifetime() *lifetime { - return &lifetime{ - closeCh: make(chan struct{}), - } -} - // ShardDelegator is the interface definition. type ShardDelegator interface { Collection() int64 @@ -106,6 +81,14 @@ const ( stopped ) +func notStopped(state int32) bool { + return state != stopped +} + +func isWorking(state int32) bool { + return state == working +} + // shardDelegator maintains the shard distribution and streaming part of the data. type shardDelegator struct { // shard information attributes @@ -118,7 +101,7 @@ type shardDelegator struct { workerManager cluster.Manager - lifetime *lifetime + lifetime lifetime.Lifetime[int32] distribution *distribution segmentManager segments.SegmentManager @@ -131,7 +114,6 @@ type shardDelegator struct { factory msgstream.Factory loader segments.Loader - wg sync.WaitGroup tsCond *sync.Cond latestTsafe *atomic.Uint64 } @@ -206,9 +188,10 @@ func (sd *shardDelegator) modifyQueryRequest(req *querypb.QueryRequest, scope qu // Search preforms search operation on shard. func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) { log := sd.getLogger(ctx) - if !sd.Serviceable() { + if !sd.lifetime.Add(isWorking) { return nil, errors.New("delegator is not serviceable") } + defer sd.lifetime.Done() if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) { log.Warn("deletgator received search request not belongs to it", @@ -271,9 +254,10 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest // Query performs query operation on shard. 
func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest) ([]*internalpb.RetrieveResults, error) { log := sd.getLogger(ctx) - if !sd.Serviceable() { + if !sd.lifetime.Add(isWorking) { return nil, errors.New("delegator is not serviceable") } + defer sd.lifetime.Done() if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) { log.Warn("delegator received query request not belongs to it", @@ -335,9 +319,10 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest) // GetStatistics returns statistics aggregated by delegator. func (sd *shardDelegator) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) ([]*internalpb.GetStatisticsResponse, error) { log := sd.getLogger(ctx) - if !sd.Serviceable() { + if !sd.lifetime.Add(isWorking) { return nil, errors.New("delegator is not serviceable") } + defer sd.lifetime.Done() if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) { log.Warn("deletgator received query request not belongs to it", @@ -510,7 +495,9 @@ func (sd *shardDelegator) waitTSafe(ctx context.Context, ts uint64) error { sd.tsCond.L.Lock() defer sd.tsCond.L.Unlock() - for sd.latestTsafe.Load() < ts && ctx.Err() == nil { + for sd.latestTsafe.Load() < ts && + ctx.Err() == nil && + sd.Serviceable() { sd.tsCond.Wait() } close(ch) @@ -524,6 +511,9 @@ func (sd *shardDelegator) waitTSafe(ctx context.Context, ts uint64) error { sd.tsCond.Broadcast() return ctx.Err() case <-ch: + if !sd.Serviceable() { + return merr.WrapErrChannelNotAvailable(sd.vchannelName, "delegator closed during wait tsafe") + } return nil } } @@ -531,7 +521,7 @@ func (sd *shardDelegator) waitTSafe(ctx context.Context, ts uint64) error { // watchTSafe is the worker function to update serviceable timestamp. 
func (sd *shardDelegator) watchTSafe() { - defer sd.wg.Done() + defer sd.lifetime.Done() listener := sd.tsafeManager.WatchChannel(sd.vchannelName) sd.updateTSafe() log := sd.getLogger(context.Background()) @@ -544,7 +534,7 @@ func (sd *shardDelegator) watchTSafe() { return } sd.updateTSafe() - case <-sd.lifetime.closeCh: + case <-sd.lifetime.CloseCh(): log.Info("updateTSafe quit") // shard delegator closed return @@ -570,7 +560,9 @@ func (sd *shardDelegator) updateTSafe() { func (sd *shardDelegator) Close() { sd.lifetime.SetState(stopped) sd.lifetime.Close() - sd.wg.Wait() + // broadcast to all waitTsafe goroutine to quit + sd.tsCond.Broadcast() + sd.lifetime.Wait() } // NewShardDelegator creates a new ShardDelegator instance with all fields initialized. @@ -600,7 +592,7 @@ func NewShardDelegator(collectionID UniqueID, replicaID UniqueID, channel string collection: collection, segmentManager: manager.Segment, workerManager: workerManager, - lifetime: newLifetime(), + lifetime: lifetime.NewLifetime(initializing), distribution: NewDistribution(), deleteBuffer: deletebuffer.NewDoubleCacheDeleteBuffer[*deletebuffer.Item](startTs, maxSegmentDeleteBuffer), pkOracle: pkoracle.NewPkOracle(), @@ -611,8 +603,9 @@ func NewShardDelegator(collectionID UniqueID, replicaID UniqueID, channel string } m := sync.Mutex{} sd.tsCond = sync.NewCond(&m) - sd.wg.Add(1) - go sd.watchTSafe() + if sd.lifetime.Add(notStopped) { + go sd.watchTSafe() + } log.Info("finish build new shardDelegator") return sd, nil } diff --git a/internal/querynodev2/delegator/delegator_test.go b/internal/querynodev2/delegator/delegator_test.go index d3de8c5468..33ef077e61 100644 --- a/internal/querynodev2/delegator/delegator_test.go +++ b/internal/querynodev2/delegator/delegator_test.go @@ -41,6 +41,7 @@ import ( "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/util/lifetime" 
"github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -877,15 +878,16 @@ func TestDelegatorWatchTsafe(t *testing.T) { sd := &shardDelegator{ tsafeManager: tsafeManager, vchannelName: channelName, - lifetime: newLifetime(), + lifetime: lifetime.NewLifetime(initializing), latestTsafe: atomic.NewUint64(0), } defer sd.Close() m := sync.Mutex{} sd.tsCond = sync.NewCond(&m) - sd.wg.Add(1) - go sd.watchTSafe() + if sd.lifetime.Add(notStopped) { + go sd.watchTSafe() + } err := tsafeManager.Set(channelName, 200) require.NoError(t, err) @@ -903,19 +905,20 @@ func TestDelegatorTSafeListenerClosed(t *testing.T) { sd := &shardDelegator{ tsafeManager: tsafeManager, vchannelName: channelName, - lifetime: newLifetime(), + lifetime: lifetime.NewLifetime(initializing), latestTsafe: atomic.NewUint64(0), } defer sd.Close() m := sync.Mutex{} sd.tsCond = sync.NewCond(&m) - sd.wg.Add(1) signal := make(chan struct{}) - go func() { - sd.watchTSafe() - close(signal) - }() + if sd.lifetime.Add(notStopped) { + go func() { + sd.watchTSafe() + close(signal) + }() + } select { case <-signal: diff --git a/pkg/util/lifetime/lifetime.go b/pkg/util/lifetime/lifetime.go index 80f8db8e34..1154888201 100644 --- a/pkg/util/lifetime/lifetime.go +++ b/pkg/util/lifetime/lifetime.go @@ -23,6 +23,7 @@ import ( // Lifetime interface for lifetime control. type Lifetime[T any] interface { + SafeChan // SetState is the method to change lifetime state. SetState(state T) // GetState returns current state. @@ -43,13 +44,15 @@ var _ Lifetime[any] = (*lifetime[any])(nil) // NewLifetime returns a new instance of Lifetime with init state and isHealthy logic. func NewLifetime[T any](initState T) Lifetime[T] { return &lifetime[T]{ - state: initState, + safeChan: newSafeChan(), + state: initState, } } // lifetime implementation of Lifetime. // users shall not care about the internal fields of this struct. 
type lifetime[T any] struct { + *safeChan // wg is used for keeping record each running task. wg sync.WaitGroup // state is the "atomic" value to store component state. diff --git a/pkg/util/lifetime/safe_chan.go b/pkg/util/lifetime/safe_chan.go new file mode 100644 index 0000000000..ac877c215f --- /dev/null +++ b/pkg/util/lifetime/safe_chan.go @@ -0,0 +1,45 @@ +package lifetime + +import ( + "sync" + + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +// SafeChan is the utility type combining chan struct{} & sync.Once. +// It provides double close protection internally. +type SafeChan interface { + IsClosed() bool + CloseCh() <-chan struct{} + Close() +} + +type safeChan struct { + closed chan struct{} + once sync.Once +} + +// NewSafeChan returns a SafeChan with internal channel initialized +func NewSafeChan() SafeChan { + return newSafeChan() +} + +func newSafeChan() *safeChan { + return &safeChan{ + closed: make(chan struct{}), + } +} + +func (sc *safeChan) CloseCh() <-chan struct{} { + return sc.closed +} + +func (sc *safeChan) IsClosed() bool { + return typeutil.IsChanClosed(sc.closed) +} + +func (sc *safeChan) Close() { + sc.once.Do(func() { + close(sc.closed) + }) +} diff --git a/pkg/util/lifetime/safe_chan_test.go b/pkg/util/lifetime/safe_chan_test.go new file mode 100644 index 0000000000..98ddce20b0 --- /dev/null +++ b/pkg/util/lifetime/safe_chan_test.go @@ -0,0 +1,37 @@ +package lifetime + +import ( + "testing" + + "github.com/milvus-io/milvus/pkg/util/typeutil" + "github.com/stretchr/testify/suite" +) + +type SafeChanSuite struct { + suite.Suite +} + +func (s *SafeChanSuite) TestClose() { + sc := NewSafeChan() + + s.False(sc.IsClosed(), "IsClosed() shall return false before Close()") + s.False(typeutil.IsChanClosed(sc.CloseCh()), "CloseCh() returned channel shall not be closed before Close()") + + s.NotPanics(func() { + sc.Close() + }, "SafeChan shall not panic during first close") + + s.True(sc.IsClosed(), "IsClosed() shall return true after 
Close()") + s.True(typeutil.IsChanClosed(sc.CloseCh()), "CloseCh() returned channel shall be closed after Close()") + + s.NotPanics(func() { + sc.Close() + }, "SafeChan shall not panic during second close") + + s.True(sc.IsClosed(), "IsClosed() shall return true after double Close()") + s.True(typeutil.IsChanClosed(sc.CloseCh()), "CloseCh() returned channel shall be still closed after double Close()") +} + +func TestSafeChan(t *testing.T) { + suite.Run(t, new(SafeChanSuite)) +} diff --git a/pkg/util/typeutil/chan.go b/pkg/util/typeutil/chan.go new file mode 100644 index 0000000000..1b33626d36 --- /dev/null +++ b/pkg/util/typeutil/chan.go @@ -0,0 +1,12 @@ +package typeutil + +// IsChanClosed returns whether input signal channel is closed or not. +// this method accept `chan struct{}` type only in case of passing msg channels by mistake. +func IsChanClosed(ch <-chan struct{}) bool { + select { + case <-ch: + return true + default: + return false + } +}