Refine delegator lifetime control (#26881)

- Add SafeChan interface in lifetime package
- Embed SafeChan into the `Lifetime` interface
- Replace private lifetime struct in delegator package with `lifetime.Lifetime`
- Refine delegator on-going task lifetime control and wait all accepted task done
- Fix potential goroutine leakage from `waitTSafe` if delegator closed concurrently

/kind improvement

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
congqixia 2023-09-07 10:11:15 +08:00 committed by GitHub
parent 69bac68f8c
commit af5c01082b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 141 additions and 48 deletions

View File

@ -41,38 +41,13 @@ import (
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/lifetime"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
)
// lifetime tracks the delegator state as an atomic int32 and exposes a
// signal channel that is closed exactly once when the delegator shuts down.
type lifetime struct {
	state     atomic.Int32
	closeCh   chan struct{}
	closeOnce sync.Once
}

// SetState atomically records the given state value.
func (lt *lifetime) SetState(state int32) {
	lt.state.Store(state)
}

// GetState atomically reads the current state value.
func (lt *lifetime) GetState() int32 {
	return lt.state.Load()
}

// Close closes closeCh; closeOnce guarantees that repeated calls are
// safe no-ops and the channel is closed at most once.
func (lt *lifetime) Close() {
	lt.closeOnce.Do(func() {
		close(lt.closeCh)
	})
}

// newLifetime constructs a lifetime with an open close channel and the
// zero value as its initial state.
func newLifetime() *lifetime {
	result := &lifetime{
		closeCh: make(chan struct{}),
	}
	return result
}
// ShardDelegator is the interface definition.
type ShardDelegator interface {
Collection() int64
@ -106,6 +81,14 @@ const (
stopped
)
// notStopped reports whether the delegator lifetime state is anything
// other than stopped; used as the lifetime.Add predicate for background
// workers such as the tsafe watcher.
func notStopped(state int32) bool {
	if state == stopped {
		return false
	}
	return true
}
// isWorking reports whether the delegator lifetime state is working;
// used as the lifetime.Add predicate for serving Search/Query requests.
func isWorking(state int32) bool {
	if state == working {
		return true
	}
	return false
}
// shardDelegator maintains the shard distribution and streaming part of the data.
type shardDelegator struct {
// shard information attributes
@ -118,7 +101,7 @@ type shardDelegator struct {
workerManager cluster.Manager
lifetime *lifetime
lifetime lifetime.Lifetime[int32]
distribution *distribution
segmentManager segments.SegmentManager
@ -131,7 +114,6 @@ type shardDelegator struct {
factory msgstream.Factory
loader segments.Loader
wg sync.WaitGroup
tsCond *sync.Cond
latestTsafe *atomic.Uint64
}
@ -206,9 +188,10 @@ func (sd *shardDelegator) modifyQueryRequest(req *querypb.QueryRequest, scope qu
// Search preforms search operation on shard.
func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) {
log := sd.getLogger(ctx)
if !sd.Serviceable() {
if !sd.lifetime.Add(isWorking) {
return nil, errors.New("delegator is not serviceable")
}
defer sd.lifetime.Done()
if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) {
log.Warn("deletgator received search request not belongs to it",
@ -271,9 +254,10 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
// Query performs query operation on shard.
func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest) ([]*internalpb.RetrieveResults, error) {
log := sd.getLogger(ctx)
if !sd.Serviceable() {
if !sd.lifetime.Add(isWorking) {
return nil, errors.New("delegator is not serviceable")
}
defer sd.lifetime.Done()
if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) {
log.Warn("delegator received query request not belongs to it",
@ -335,9 +319,10 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest)
// GetStatistics returns statistics aggregated by delegator.
func (sd *shardDelegator) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) ([]*internalpb.GetStatisticsResponse, error) {
log := sd.getLogger(ctx)
if !sd.Serviceable() {
if !sd.lifetime.Add(isWorking) {
return nil, errors.New("delegator is not serviceable")
}
defer sd.lifetime.Done()
if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) {
log.Warn("deletgator received query request not belongs to it",
@ -510,7 +495,9 @@ func (sd *shardDelegator) waitTSafe(ctx context.Context, ts uint64) error {
sd.tsCond.L.Lock()
defer sd.tsCond.L.Unlock()
for sd.latestTsafe.Load() < ts && ctx.Err() == nil {
for sd.latestTsafe.Load() < ts &&
ctx.Err() == nil &&
sd.Serviceable() {
sd.tsCond.Wait()
}
close(ch)
@ -524,6 +511,9 @@ func (sd *shardDelegator) waitTSafe(ctx context.Context, ts uint64) error {
sd.tsCond.Broadcast()
return ctx.Err()
case <-ch:
if !sd.Serviceable() {
return merr.WrapErrChannelNotAvailable(sd.vchannelName, "delegator closed during wait tsafe")
}
return nil
}
}
@ -531,7 +521,7 @@ func (sd *shardDelegator) waitTSafe(ctx context.Context, ts uint64) error {
// watchTSafe is the worker function to update serviceable timestamp.
func (sd *shardDelegator) watchTSafe() {
defer sd.wg.Done()
defer sd.lifetime.Done()
listener := sd.tsafeManager.WatchChannel(sd.vchannelName)
sd.updateTSafe()
log := sd.getLogger(context.Background())
@ -544,7 +534,7 @@ func (sd *shardDelegator) watchTSafe() {
return
}
sd.updateTSafe()
case <-sd.lifetime.closeCh:
case <-sd.lifetime.CloseCh():
log.Info("updateTSafe quit")
// shard delegator closed
return
@ -570,7 +560,9 @@ func (sd *shardDelegator) updateTSafe() {
// Close shuts the delegator down: it marks the lifetime state as stopped,
// closes the lifetime signal channel so background workers (watchTSafe)
// exit, wakes every goroutine blocked in waitTSafe so it can observe the
// stopped state, and finally waits for all accepted tasks to finish.
//
// Note: the old sd.wg.Wait() call is removed here — the wg field no longer
// exists on shardDelegator; in-flight task draining is now handled by
// lifetime.Wait().
func (sd *shardDelegator) Close() {
	sd.lifetime.SetState(stopped)
	sd.lifetime.Close()
	// broadcast to all waitTSafe goroutines so they quit instead of
	// waiting on tsCond forever after shutdown
	sd.tsCond.Broadcast()
	sd.lifetime.Wait()
}
// NewShardDelegator creates a new ShardDelegator instance with all fields initialized.
@ -600,7 +592,7 @@ func NewShardDelegator(collectionID UniqueID, replicaID UniqueID, channel string
collection: collection,
segmentManager: manager.Segment,
workerManager: workerManager,
lifetime: newLifetime(),
lifetime: lifetime.NewLifetime(initializing),
distribution: NewDistribution(),
deleteBuffer: deletebuffer.NewDoubleCacheDeleteBuffer[*deletebuffer.Item](startTs, maxSegmentDeleteBuffer),
pkOracle: pkoracle.NewPkOracle(),
@ -611,8 +603,9 @@ func NewShardDelegator(collectionID UniqueID, replicaID UniqueID, channel string
}
m := sync.Mutex{}
sd.tsCond = sync.NewCond(&m)
sd.wg.Add(1)
go sd.watchTSafe()
if sd.lifetime.Add(notStopped) {
go sd.watchTSafe()
}
log.Info("finish build new shardDelegator")
return sd, nil
}

View File

@ -41,6 +41,7 @@ import (
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/lifetime"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"
@ -877,15 +878,16 @@ func TestDelegatorWatchTsafe(t *testing.T) {
sd := &shardDelegator{
tsafeManager: tsafeManager,
vchannelName: channelName,
lifetime: newLifetime(),
lifetime: lifetime.NewLifetime(initializing),
latestTsafe: atomic.NewUint64(0),
}
defer sd.Close()
m := sync.Mutex{}
sd.tsCond = sync.NewCond(&m)
sd.wg.Add(1)
go sd.watchTSafe()
if sd.lifetime.Add(notStopped) {
go sd.watchTSafe()
}
err := tsafeManager.Set(channelName, 200)
require.NoError(t, err)
@ -903,19 +905,20 @@ func TestDelegatorTSafeListenerClosed(t *testing.T) {
sd := &shardDelegator{
tsafeManager: tsafeManager,
vchannelName: channelName,
lifetime: newLifetime(),
lifetime: lifetime.NewLifetime(initializing),
latestTsafe: atomic.NewUint64(0),
}
defer sd.Close()
m := sync.Mutex{}
sd.tsCond = sync.NewCond(&m)
sd.wg.Add(1)
signal := make(chan struct{})
go func() {
sd.watchTSafe()
close(signal)
}()
if sd.lifetime.Add(notStopped) {
go func() {
sd.watchTSafe()
close(signal)
}()
}
select {
case <-signal:

View File

@ -23,6 +23,7 @@ import (
// Lifetime interface for lifetime control.
type Lifetime[T any] interface {
SafeChan
// SetState is the method to change lifetime state.
SetState(state T)
// GetState returns current state.
@ -43,13 +44,15 @@ var _ Lifetime[any] = (*lifetime[any])(nil)
// NewLifetime returns a new instance of Lifetime with init state and isHealthy logic.
// The embedded safeChan provides the close-signal channel with double-close protection.
func NewLifetime[T any](initState T) Lifetime[T] {
	// NOTE: a composite literal must set each field at most once; the state
	// field is initialized exactly once here.
	return &lifetime[T]{
		safeChan: newSafeChan(),
		state:    initState,
	}
}
// lifetime implementation of Lifetime.
// users shall not care about the internal fields of this struct.
type lifetime[T any] struct {
*safeChan
// wg is used for keeping record each running task.
wg sync.WaitGroup
// state is the "atomic" value to store component state.

View File

@ -0,0 +1,45 @@
package lifetime
import (
"sync"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
// SafeChan is the utility type combining chan struct{} & sync.Once.
// It provides double close protection internally.
type SafeChan interface {
	// IsClosed reports whether Close has already been invoked.
	IsClosed() bool
	// CloseCh exposes the underlying signal channel for use in select statements.
	CloseCh() <-chan struct{}
	// Close closes the signal channel; repeated calls are safe no-ops.
	Close()
}

// safeChan is the default SafeChan implementation: a struct{} channel
// guarded by a sync.Once so it is closed at most once.
type safeChan struct {
	once   sync.Once
	closed chan struct{}
}

// NewSafeChan returns a SafeChan with internal channel initialized
func NewSafeChan() SafeChan {
	return newSafeChan()
}

// newSafeChan builds the concrete *safeChan used within the package.
func newSafeChan() *safeChan {
	ch := make(chan struct{})
	return &safeChan{
		closed: ch,
	}
}

// CloseCh returns the channel that is closed once Close is called.
func (sc *safeChan) CloseCh() <-chan struct{} {
	return sc.closed
}

// IsClosed reports whether the signal channel has been closed.
func (sc *safeChan) IsClosed() bool {
	return typeutil.IsChanClosed(sc.closed)
}

// Close closes the signal channel; sync.Once makes double Close safe.
func (sc *safeChan) Close() {
	sc.once.Do(func() { close(sc.closed) })
}

View File

@ -0,0 +1,37 @@
package lifetime
import (
"testing"
"github.com/milvus-io/milvus/pkg/util/typeutil"
"github.com/stretchr/testify/suite"
)
// SafeChanSuite is the testify suite exercising the SafeChan utility type.
type SafeChanSuite struct {
	suite.Suite
}
// TestClose verifies the SafeChan close semantics: the channel reports
// open before Close, closed afterwards, and tolerates a double Close
// without panicking.
func (s *SafeChanSuite) TestClose() {
	sc := NewSafeChan()
	ch := sc.CloseCh()

	// before Close: both views must report "open"
	s.False(sc.IsClosed(), "IsClosed() shall return false before Close()")
	s.False(typeutil.IsChanClosed(ch), "CloseCh() returned channel shall not be closed before Close()")

	// first Close: must not panic and must close the channel
	s.NotPanics(sc.Close, "SafeChan shall not panic during first close")
	s.True(sc.IsClosed(), "IsClosed() shall return true after Close()")
	s.True(typeutil.IsChanClosed(ch), "CloseCh() returned channel shall be closed after Close()")

	// second Close: double-close protection must keep it a no-op
	s.NotPanics(sc.Close, "SafeChan shall not panic during second close")
	s.True(sc.IsClosed(), "IsClosed() shall return true after double Close()")
	s.True(typeutil.IsChanClosed(ch), "CloseCh() returned channel shall be still closed after double Close()")
}
// TestSafeChan runs the SafeChanSuite test suite.
func TestSafeChan(t *testing.T) {
	suite.Run(t, &SafeChanSuite{})
}

12
pkg/util/typeutil/chan.go Normal file
View File

@ -0,0 +1,12 @@
package typeutil
// IsChanClosed returns whether input signal channel is closed or not.
// this method accept `chan struct{}` type only in case of passing msg channels by mistake.
func IsChanClosed(ch <-chan struct{}) bool {
	// non-blocking receive: succeeds immediately on a closed channel,
	// falls through to default on an open one
	closed := false
	select {
	case <-ch:
		closed = true
	default:
	}
	return closed
}