mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-02 03:48:37 +08:00
Refine delegator lifetime control (#26881)
- Add SafeChan interface in lifetime package - Embed SafeChan into the Lifetime interface - Replace the private lifetime struct in the delegator package with lifetime.Lifetime - Refine delegator on-going task lifetime control and wait for all accepted tasks to finish - Fix potential goroutine leakage from waitTSafe if the delegator is closed concurrently /kind improvement Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
parent
69bac68f8c
commit
af5c01082b
@ -41,38 +41,13 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/lifetime"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
||||
"github.com/milvus-io/milvus/pkg/util/tsoutil"
|
||||
)
|
||||
|
||||
// lifetime is the delegator-private lifecycle helper. It pairs an atomically
// accessed state value with a close-notification channel that is safe to
// close more than once.
type lifetime struct {
	// state holds the current lifecycle state; the zero value is 0.
	state atomic.Int32
	// closeCh is closed (never sent on) by Close to signal shutdown.
	closeCh chan struct{}
	// closeOnce guards close(closeCh) so that repeated Close calls do not panic.
	closeOnce sync.Once
}

// SetState atomically stores state as the current lifecycle state.
func (lt *lifetime) SetState(state int32) {
	lt.state.Store(state)
}

// GetState atomically loads and returns the current lifecycle state.
func (lt *lifetime) GetState() int32 {
	return lt.state.Load()
}

// Close closes closeCh. Only the first call has an effect; later calls are
// no-ops because the close is wrapped in closeOnce.
func (lt *lifetime) Close() {
	lt.closeOnce.Do(func() {
		close(lt.closeCh)
	})
}

// newLifetime returns a lifetime whose closeCh is initialized and open and
// whose state is left at its zero value.
func newLifetime() *lifetime {
	return &lifetime{
		closeCh: make(chan struct{}),
	}
}
|
||||
|
||||
// ShardDelegator is the interface definition.
|
||||
type ShardDelegator interface {
|
||||
Collection() int64
|
||||
@ -106,6 +81,14 @@ const (
|
||||
stopped
|
||||
)
|
||||
|
||||
func notStopped(state int32) bool {
|
||||
return state != stopped
|
||||
}
|
||||
|
||||
func isWorking(state int32) bool {
|
||||
return state == working
|
||||
}
|
||||
|
||||
// shardDelegator maintains the shard distribution and streaming part of the data.
|
||||
type shardDelegator struct {
|
||||
// shard information attributes
|
||||
@ -118,7 +101,7 @@ type shardDelegator struct {
|
||||
|
||||
workerManager cluster.Manager
|
||||
|
||||
lifetime *lifetime
|
||||
lifetime lifetime.Lifetime[int32]
|
||||
|
||||
distribution *distribution
|
||||
segmentManager segments.SegmentManager
|
||||
@ -131,7 +114,6 @@ type shardDelegator struct {
|
||||
factory msgstream.Factory
|
||||
|
||||
loader segments.Loader
|
||||
wg sync.WaitGroup
|
||||
tsCond *sync.Cond
|
||||
latestTsafe *atomic.Uint64
|
||||
}
|
||||
@ -206,9 +188,10 @@ func (sd *shardDelegator) modifyQueryRequest(req *querypb.QueryRequest, scope qu
|
||||
// Search performs search operation on shard.
|
||||
func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) {
|
||||
log := sd.getLogger(ctx)
|
||||
if !sd.Serviceable() {
|
||||
if !sd.lifetime.Add(isWorking) {
|
||||
return nil, errors.New("delegator is not serviceable")
|
||||
}
|
||||
defer sd.lifetime.Done()
|
||||
|
||||
if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) {
|
||||
log.Warn("deletgator received search request not belongs to it",
|
||||
@ -271,9 +254,10 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
|
||||
// Query performs query operation on shard.
|
||||
func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest) ([]*internalpb.RetrieveResults, error) {
|
||||
log := sd.getLogger(ctx)
|
||||
if !sd.Serviceable() {
|
||||
if !sd.lifetime.Add(isWorking) {
|
||||
return nil, errors.New("delegator is not serviceable")
|
||||
}
|
||||
defer sd.lifetime.Done()
|
||||
|
||||
if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) {
|
||||
log.Warn("delegator received query request not belongs to it",
|
||||
@ -335,9 +319,10 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest)
|
||||
// GetStatistics returns statistics aggregated by delegator.
|
||||
func (sd *shardDelegator) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) ([]*internalpb.GetStatisticsResponse, error) {
|
||||
log := sd.getLogger(ctx)
|
||||
if !sd.Serviceable() {
|
||||
if !sd.lifetime.Add(isWorking) {
|
||||
return nil, errors.New("delegator is not serviceable")
|
||||
}
|
||||
defer sd.lifetime.Done()
|
||||
|
||||
if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) {
|
||||
log.Warn("deletgator received query request not belongs to it",
|
||||
@ -510,7 +495,9 @@ func (sd *shardDelegator) waitTSafe(ctx context.Context, ts uint64) error {
|
||||
sd.tsCond.L.Lock()
|
||||
defer sd.tsCond.L.Unlock()
|
||||
|
||||
for sd.latestTsafe.Load() < ts && ctx.Err() == nil {
|
||||
for sd.latestTsafe.Load() < ts &&
|
||||
ctx.Err() == nil &&
|
||||
sd.Serviceable() {
|
||||
sd.tsCond.Wait()
|
||||
}
|
||||
close(ch)
|
||||
@ -524,6 +511,9 @@ func (sd *shardDelegator) waitTSafe(ctx context.Context, ts uint64) error {
|
||||
sd.tsCond.Broadcast()
|
||||
return ctx.Err()
|
||||
case <-ch:
|
||||
if !sd.Serviceable() {
|
||||
return merr.WrapErrChannelNotAvailable(sd.vchannelName, "delegator closed during wait tsafe")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@ -531,7 +521,7 @@ func (sd *shardDelegator) waitTSafe(ctx context.Context, ts uint64) error {
|
||||
|
||||
// watchTSafe is the worker function to update serviceable timestamp.
|
||||
func (sd *shardDelegator) watchTSafe() {
|
||||
defer sd.wg.Done()
|
||||
defer sd.lifetime.Done()
|
||||
listener := sd.tsafeManager.WatchChannel(sd.vchannelName)
|
||||
sd.updateTSafe()
|
||||
log := sd.getLogger(context.Background())
|
||||
@ -544,7 +534,7 @@ func (sd *shardDelegator) watchTSafe() {
|
||||
return
|
||||
}
|
||||
sd.updateTSafe()
|
||||
case <-sd.lifetime.closeCh:
|
||||
case <-sd.lifetime.CloseCh():
|
||||
log.Info("updateTSafe quit")
|
||||
// shard delegator closed
|
||||
return
|
||||
@ -570,7 +560,9 @@ func (sd *shardDelegator) updateTSafe() {
|
||||
// Close shuts the delegator down: it sets the lifetime state to stopped,
// closes the lifetime, wakes every goroutine blocked on tsCond and then calls
// lifetime.Wait().
// NOTE(review): this span comes from a rendered diff, so pre-change and
// post-change statements appear side by side; confirm against the committed
// file which of them are live.
func (sd *shardDelegator) Close() {
|
||||
sd.lifetime.SetState(stopped)
|
||||
sd.lifetime.Close()
|
||||
sd.wg.Wait()
|
||||
// broadcast to all waitTsafe goroutines so that they re-check their wait condition and quit
// NOTE(review): Broadcast is issued without holding tsCond.L. Presumably a
// waiter that has already evaluated Serviceable() but has not yet entered
// tsCond.Wait() could miss this wake-up; confirm whether the lock should be
// taken around the broadcast.
|
||||
sd.tsCond.Broadcast()
|
||||
sd.lifetime.Wait()
|
||||
}
|
||||
|
||||
// NewShardDelegator creates a new ShardDelegator instance with all fields initialized.
|
||||
@ -600,7 +592,7 @@ func NewShardDelegator(collectionID UniqueID, replicaID UniqueID, channel string
|
||||
collection: collection,
|
||||
segmentManager: manager.Segment,
|
||||
workerManager: workerManager,
|
||||
lifetime: newLifetime(),
|
||||
lifetime: lifetime.NewLifetime(initializing),
|
||||
distribution: NewDistribution(),
|
||||
deleteBuffer: deletebuffer.NewDoubleCacheDeleteBuffer[*deletebuffer.Item](startTs, maxSegmentDeleteBuffer),
|
||||
pkOracle: pkoracle.NewPkOracle(),
|
||||
@ -611,8 +603,9 @@ func NewShardDelegator(collectionID UniqueID, replicaID UniqueID, channel string
|
||||
}
|
||||
m := sync.Mutex{}
|
||||
sd.tsCond = sync.NewCond(&m)
|
||||
sd.wg.Add(1)
|
||||
go sd.watchTSafe()
|
||||
if sd.lifetime.Add(notStopped) {
|
||||
go sd.watchTSafe()
|
||||
}
|
||||
log.Info("finish build new shardDelegator")
|
||||
return sd, nil
|
||||
}
|
||||
|
@ -41,6 +41,7 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/lifetime"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
@ -877,15 +878,16 @@ func TestDelegatorWatchTsafe(t *testing.T) {
|
||||
sd := &shardDelegator{
|
||||
tsafeManager: tsafeManager,
|
||||
vchannelName: channelName,
|
||||
lifetime: newLifetime(),
|
||||
lifetime: lifetime.NewLifetime(initializing),
|
||||
latestTsafe: atomic.NewUint64(0),
|
||||
}
|
||||
defer sd.Close()
|
||||
|
||||
m := sync.Mutex{}
|
||||
sd.tsCond = sync.NewCond(&m)
|
||||
sd.wg.Add(1)
|
||||
go sd.watchTSafe()
|
||||
if sd.lifetime.Add(notStopped) {
|
||||
go sd.watchTSafe()
|
||||
}
|
||||
|
||||
err := tsafeManager.Set(channelName, 200)
|
||||
require.NoError(t, err)
|
||||
@ -903,19 +905,20 @@ func TestDelegatorTSafeListenerClosed(t *testing.T) {
|
||||
sd := &shardDelegator{
|
||||
tsafeManager: tsafeManager,
|
||||
vchannelName: channelName,
|
||||
lifetime: newLifetime(),
|
||||
lifetime: lifetime.NewLifetime(initializing),
|
||||
latestTsafe: atomic.NewUint64(0),
|
||||
}
|
||||
defer sd.Close()
|
||||
|
||||
m := sync.Mutex{}
|
||||
sd.tsCond = sync.NewCond(&m)
|
||||
sd.wg.Add(1)
|
||||
signal := make(chan struct{})
|
||||
go func() {
|
||||
sd.watchTSafe()
|
||||
close(signal)
|
||||
}()
|
||||
if sd.lifetime.Add(notStopped) {
|
||||
go func() {
|
||||
sd.watchTSafe()
|
||||
close(signal)
|
||||
}()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-signal:
|
||||
|
@ -23,6 +23,7 @@ import (
|
||||
|
||||
// Lifetime interface for lifetime control.
|
||||
type Lifetime[T any] interface {
|
||||
SafeChan
|
||||
// SetState is the method to change lifetime state.
|
||||
SetState(state T)
|
||||
// GetState returns current state.
|
||||
@ -43,13 +44,15 @@ var _ Lifetime[any] = (*lifetime[any])(nil)
|
||||
// NewLifetime returns a new instance of Lifetime with init state and isHealthy logic.
|
||||
func NewLifetime[T any](initState T) Lifetime[T] {
|
||||
return &lifetime[T]{
|
||||
state: initState,
|
||||
safeChan: newSafeChan(),
|
||||
state: initState,
|
||||
}
|
||||
}
|
||||
|
||||
// lifetime implementation of Lifetime.
|
||||
// users shall not care about the internal fields of this struct.
|
||||
type lifetime[T any] struct {
|
||||
*safeChan
|
||||
// wg is used for keeping record each running task.
|
||||
wg sync.WaitGroup
|
||||
// state is the "atomic" value to store component state.
|
||||
|
45
pkg/util/lifetime/safe_chan.go
Normal file
45
pkg/util/lifetime/safe_chan.go
Normal file
@ -0,0 +1,45 @@
|
||||
package lifetime
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
// SafeChan is the utility type combining chan struct{} & sync.Once.
// It provides double close protection internally.
type SafeChan interface {
	// IsClosed reports whether Close has already taken effect.
	IsClosed() bool
	// CloseCh returns the receive-only signal channel that is closed by Close.
	CloseCh() <-chan struct{}
	// Close closes the signal channel; calling it more than once is safe.
	Close()
}
|
||||
|
||||
type safeChan struct {
|
||||
closed chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
// NewSafeChan returns a SafeChan with internal channel initialized
|
||||
func NewSafeChan() SafeChan {
|
||||
return newSafeChan()
|
||||
}
|
||||
|
||||
func newSafeChan() *safeChan {
|
||||
return &safeChan{
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (sc *safeChan) CloseCh() <-chan struct{} {
|
||||
return sc.closed
|
||||
}
|
||||
|
||||
func (sc *safeChan) IsClosed() bool {
|
||||
return typeutil.IsChanClosed(sc.closed)
|
||||
}
|
||||
|
||||
func (sc *safeChan) Close() {
|
||||
sc.once.Do(func() {
|
||||
close(sc.closed)
|
||||
})
|
||||
}
|
37
pkg/util/lifetime/safe_chan_test.go
Normal file
37
pkg/util/lifetime/safe_chan_test.go
Normal file
@ -0,0 +1,37 @@
|
||||
package lifetime
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type SafeChanSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func (s *SafeChanSuite) TestClose() {
|
||||
sc := NewSafeChan()
|
||||
|
||||
s.False(sc.IsClosed(), "IsClosed() shall return false before Close()")
|
||||
s.False(typeutil.IsChanClosed(sc.CloseCh()), "CloseCh() returned channel shall not be closed before Close()")
|
||||
|
||||
s.NotPanics(func() {
|
||||
sc.Close()
|
||||
}, "SafeChan shall not panic during first close")
|
||||
|
||||
s.True(sc.IsClosed(), "IsClosed() shall return true after Close()")
|
||||
s.True(typeutil.IsChanClosed(sc.CloseCh()), "CloseCh() returned channel shall be closed after Close()")
|
||||
|
||||
s.NotPanics(func() {
|
||||
sc.Close()
|
||||
}, "SafeChan shall not panic during second close")
|
||||
|
||||
s.True(sc.IsClosed(), "IsClosed() shall return true after double Close()")
|
||||
s.True(typeutil.IsChanClosed(sc.CloseCh()), "CloseCh() returned channel shall be still closed after double Close()")
|
||||
}
|
||||
|
||||
func TestSafeChan(t *testing.T) {
|
||||
suite.Run(t, new(SafeChanSuite))
|
||||
}
|
12
pkg/util/typeutil/chan.go
Normal file
12
pkg/util/typeutil/chan.go
Normal file
@ -0,0 +1,12 @@
|
||||
package typeutil
|
||||
|
||||
// IsChanClosed reports whether the input signal channel is closed.
//
// The check is a non-blocking receive, so it never blocks the caller:
//   - a closed channel yields true on every call;
//   - an open channel with nothing to receive, or a nil channel, yields false.
//
// Because any successful receive counts as "closed", a channel that still has
// a value pending would also be reported as closed (and that value would be
// consumed). The parameter is therefore restricted to `<-chan struct{}` so
// that message-carrying channels are not passed in by mistake; use it only
// with channels that are closed and never sent on.
func IsChanClosed(ch <-chan struct{}) bool {
	select {
	case <-ch:
		return true
	default:
		return false
	}
}
|
Loading…
Reference in New Issue
Block a user