enhance: Decouple shard client manager from shard cache (#37371)

issue: #37115
The old implementation updated the shard cache and the shard client manager at the same time, which caused many corner cases due to concurrency issues without locking.

This PR decouples the shard client manager from the shard cache, so only the shard cache is updated when delegators change. The shard client manager is now responsible for always returning the right client, creating a new one if it does not exist. To guard against client leaks, the shard client manager purges unused clients asynchronously every 10 minutes.
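The sketch below illustrates, under simplified assumptions, the get-or-create plus periodic-purge pattern described above: GetClient lazily creates a client for an unknown node, and a background loop drops clients for nodes that the shard cache no longer lists as delegators. The names here (clientManager, shardClient, nodeInfo, liveNodes) are illustrative stand-ins, not the exact types or signatures from the diff below.

package main

import (
	"fmt"
	"sync"
	"time"
)

// nodeInfo and shardClient are simplified stand-ins for the proxy types in the diff.
type nodeInfo struct {
	nodeID  int64
	address string
}

type shardClient struct {
	info   nodeInfo
	refCnt int // guarded by clientManager.mu in this sketch (atomic in the real code)
}

type clientManager struct {
	mu        sync.Mutex
	clients   map[int64]*shardClient
	closeCh   chan struct{}
	liveNodes func() map[int64]nodeInfo // stands in for asking the shard cache which delegators are live
}

func newClientManager(liveNodes func() map[int64]nodeInfo) *clientManager {
	m := &clientManager{
		clients:   make(map[int64]*shardClient),
		closeCh:   make(chan struct{}),
		liveNodes: liveNodes,
	}
	go m.purgeLoop(10 * time.Minute)
	return m
}

// GetClient always returns a client for the node, creating one lazily if none exists.
func (m *clientManager) GetClient(info nodeInfo) *shardClient {
	m.mu.Lock()
	defer m.mu.Unlock()
	c, ok := m.clients[info.nodeID]
	if !ok {
		c = &shardClient{info: info}
		m.clients[info.nodeID] = c
	}
	c.refCnt++
	return c
}

// ReleaseClient drops one reference taken by GetClient.
func (m *clientManager) ReleaseClient(nodeID int64) {
	m.mu.Lock()
	defer m.mu.Unlock()
	if c, ok := m.clients[nodeID]; ok && c.refCnt > 0 {
		c.refCnt--
	}
}

// purgeLoop periodically drops clients whose node is no longer a shard leader,
// so leaked clients are eventually reclaimed even without explicit cleanup.
func (m *clientManager) purgeLoop(interval time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-m.closeCh:
			return
		case <-ticker.C:
			live := m.liveNodes()
			m.mu.Lock()
			for id := range m.clients {
				if _, ok := live[id]; !ok {
					delete(m.clients, id)
				}
			}
			m.mu.Unlock()
		}
	}
}

func main() {
	m := newClientManager(func() map[int64]nodeInfo { return map[int64]nodeInfo{} })
	c := m.GetClient(nodeInfo{nodeID: 1, address: "localhost:21123"})
	fmt.Println(c.info.nodeID)
	m.ReleaseClient(1)
	close(m.closeCh)
}

The design point is that the shard cache becomes the single source of truth for delegator membership, while the client manager only tracks connections and reclaims them lazily, so the two no longer have to be updated together under one lock.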

---------

Signed-off-by: Wei Liu <wei.liu@zilliz.com>
wei liu 2024-11-12 10:30:28 +08:00 committed by GitHub
parent f5b06a3c9f
commit 2a4c00de9d
18 changed files with 462 additions and 322 deletions

View File

@ -23,6 +23,7 @@ import (
)
type LBBalancer interface {
RegisterNodeInfo(nodeInfos []nodeInfo)
SelectNode(ctx context.Context, availableNodes []int64, nq int64) (int64, error)
CancelWorkload(node int64, nq int64)
UpdateCostMetrics(node int64, cost *internalpb.CostAggregation)

View File

@ -39,7 +39,7 @@ type ChannelWorkload struct {
collectionName string
collectionID int64
channel string
shardLeaders []int64
shardLeaders []nodeInfo
nq int64
exec executeFunc
retryTimes uint
@ -116,9 +116,20 @@ func (lb *LBPolicyImpl) GetShardLeaders(ctx context.Context, dbName string, coll
}
// try to select the best node from the available nodes
func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, workload ChannelWorkload, excludeNodes typeutil.UniqueSet) (int64, error) {
availableNodes := lo.FilterMap(workload.shardLeaders, func(node int64, _ int) (int64, bool) { return node, !excludeNodes.Contain(node) })
targetNode, err := balancer.SelectNode(ctx, availableNodes, workload.nq)
func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, workload ChannelWorkload, excludeNodes typeutil.UniqueSet) (nodeInfo, error) {
filterDelegator := func(nodes []nodeInfo) map[int64]nodeInfo {
ret := make(map[int64]nodeInfo)
for _, node := range nodes {
if !excludeNodes.Contain(node.nodeID) {
ret[node.nodeID] = node
}
}
return ret
}
availableNodes := filterDelegator(workload.shardLeaders)
balancer.RegisterNodeInfo(lo.Values(availableNodes))
targetNode, err := balancer.SelectNode(ctx, lo.Keys(availableNodes), workload.nq)
if err != nil {
log := log.Ctx(ctx)
globalMetaCache.DeprecateShardCache(workload.db, workload.collectionName)
@ -128,32 +139,33 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, wor
zap.Int64("collectionID", workload.collectionID),
zap.String("channelName", workload.channel),
zap.Error(err))
return -1, err
return nodeInfo{}, err
}
availableNodes := lo.FilterMap(shardLeaders[workload.channel], func(node nodeInfo, _ int) (int64, bool) { return node.nodeID, !excludeNodes.Contain(node.nodeID) })
availableNodes = filterDelegator(shardLeaders[workload.channel])
if len(availableNodes) == 0 {
nodes := lo.Map(shardLeaders[workload.channel], func(node nodeInfo, _ int) int64 { return node.nodeID })
log.Warn("no available shard delegator found",
zap.Int64("collectionID", workload.collectionID),
zap.String("channelName", workload.channel),
zap.Int64s("nodes", nodes),
zap.Int64s("availableNodes", lo.Keys(availableNodes)),
zap.Int64s("excluded", excludeNodes.Collect()))
return -1, merr.WrapErrChannelNotAvailable("no available shard delegator found")
return nodeInfo{}, merr.WrapErrChannelNotAvailable("no available shard delegator found")
}
targetNode, err = balancer.SelectNode(ctx, availableNodes, workload.nq)
balancer.RegisterNodeInfo(lo.Values(availableNodes))
targetNode, err = balancer.SelectNode(ctx, lo.Keys(availableNodes), workload.nq)
if err != nil {
log.Warn("failed to select shard",
zap.Int64("collectionID", workload.collectionID),
zap.String("channelName", workload.channel),
zap.Int64s("availableNodes", availableNodes),
zap.Int64s("availableNodes", lo.Keys(availableNodes)),
zap.Int64s("excluded", excludeNodes.Collect()),
zap.Error(err))
return -1, err
return nodeInfo{}, err
}
}
return targetNode, nil
return availableNodes[targetNode], nil
}
// ExecuteWithRetry will choose a qn to execute the workload, and retry if failed, until reach the max retryTimes.
@ -168,7 +180,7 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
log.Warn("failed to select node for shard",
zap.Int64("collectionID", workload.collectionID),
zap.String("channelName", workload.channel),
zap.Int64("nodeID", targetNode),
zap.Int64("nodeID", targetNode.nodeID),
zap.Error(err),
)
if lastErr != nil {
@ -177,29 +189,30 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
return err
}
// cancel work load which assign to the target node
defer balancer.CancelWorkload(targetNode, workload.nq)
defer balancer.CancelWorkload(targetNode.nodeID, workload.nq)
client, err := lb.clientMgr.GetClient(ctx, targetNode)
if err != nil {
log.Warn("search/query channel failed, node not available",
zap.Int64("collectionID", workload.collectionID),
zap.String("channelName", workload.channel),
zap.Int64("nodeID", targetNode),
zap.Int64("nodeID", targetNode.nodeID),
zap.Error(err))
excludeNodes.Insert(targetNode)
excludeNodes.Insert(targetNode.nodeID)
lastErr = errors.Wrapf(err, "failed to get delegator %d for channel %s", targetNode, workload.channel)
return lastErr
}
defer lb.clientMgr.ReleaseClientRef(targetNode.nodeID)
err = workload.exec(ctx, targetNode, client, workload.channel)
err = workload.exec(ctx, targetNode.nodeID, client, workload.channel)
if err != nil {
log.Warn("search/query channel failed",
zap.Int64("collectionID", workload.collectionID),
zap.String("channelName", workload.channel),
zap.Int64("nodeID", targetNode),
zap.Int64("nodeID", targetNode.nodeID),
zap.Error(err))
excludeNodes.Insert(targetNode)
excludeNodes.Insert(targetNode.nodeID)
lastErr = errors.Wrapf(err, "failed to search/query delegator %d for channel %s", targetNode, workload.channel)
return lastErr
}
@ -221,9 +234,9 @@ func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad
// let every request could retry at least twice, which could retry after update shard leader cache
retryTimes := Params.ProxyCfg.RetryTimesOnReplica.GetAsInt()
wg, ctx := errgroup.WithContext(ctx)
for channel, nodes := range dml2leaders {
channel := channel
nodes := lo.Map(nodes, func(node nodeInfo, _ int) int64 { return node.nodeID })
for k, v := range dml2leaders {
channel := k
nodes := v
channelRetryTimes := retryTimes
if len(nodes) > 0 {
channelRetryTimes *= len(nodes)

View File

@ -49,7 +49,8 @@ type LBPolicySuite struct {
lbBalancer *MockLBBalancer
lbPolicy *LBPolicyImpl
nodes []int64
nodeIDs []int64
nodes []nodeInfo
channels []string
qnList []*mocks.MockQueryNode
@ -62,7 +63,14 @@ func (s *LBPolicySuite) SetupSuite() {
}
func (s *LBPolicySuite) SetupTest() {
s.nodes = []int64{1, 2, 3, 4, 5}
s.nodeIDs = make([]int64, 0)
for i := 1; i <= 5; i++ {
s.nodeIDs = append(s.nodeIDs, int64(i))
s.nodes = append(s.nodes, nodeInfo{
nodeID: int64(i),
address: "localhost",
})
}
s.channels = []string{"channel1", "channel2"}
successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
qc := mocks.NewMockQueryCoordClient(s.T())
@ -74,12 +82,12 @@ func (s *LBPolicySuite) SetupTest() {
Shards: []*querypb.ShardLeadersList{
{
ChannelName: s.channels[0],
NodeIds: s.nodes,
NodeIds: s.nodeIDs,
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002", "localhost:9003", "localhost:9004"},
},
{
ChannelName: s.channels[1],
NodeIds: s.nodes,
NodeIds: s.nodeIDs,
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002", "localhost:9003", "localhost:9004"},
},
},
@ -96,7 +104,6 @@ func (s *LBPolicySuite) SetupTest() {
s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
s.mgr = NewMockShardClientManager(s.T())
s.mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe()
s.lbBalancer = NewMockLBBalancer(s.T())
s.lbBalancer.EXPECT().Start(context.Background()).Maybe()
s.lbPolicy = NewLBPolicyImpl(s.mgr)
@ -164,6 +171,7 @@ func (s *LBPolicySuite) loadCollection() {
func (s *LBPolicySuite) TestSelectNode() {
ctx := context.Background()
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(5, nil)
targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
@ -174,10 +182,11 @@ func (s *LBPolicySuite) TestSelectNode() {
nq: 1,
}, typeutil.NewUniqueSet())
s.NoError(err)
s.Equal(int64(5), targetNode)
s.Equal(int64(5), targetNode.nodeID)
// test select node failed, then update shard leader cache and retry, expect success
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, errors.New("fake err")).Times(1)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(3, nil)
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
@ -185,28 +194,29 @@ func (s *LBPolicySuite) TestSelectNode() {
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
shardLeaders: []int64{},
shardLeaders: []nodeInfo{},
nq: 1,
}, typeutil.NewUniqueSet())
s.NoError(err)
s.Equal(int64(3), targetNode)
s.Equal(int64(3), targetNode.nodeID)
// test select node always fails, expected failure
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
shardLeaders: []int64{},
shardLeaders: []nodeInfo{},
nq: 1,
}, typeutil.NewUniqueSet())
s.ErrorIs(err, merr.ErrNodeNotAvailable)
s.Equal(int64(-1), targetNode)
// test all nodes has been excluded, expected failure
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
@ -215,12 +225,12 @@ func (s *LBPolicySuite) TestSelectNode() {
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
}, typeutil.NewUniqueSet(s.nodes...))
}, typeutil.NewUniqueSet(s.nodeIDs...))
s.ErrorIs(err, merr.ErrChannelNotAvailable)
s.Equal(int64(-1), targetNode)
// test get shard leaders failed, retry to select node failed
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
s.qc.ExpectedCalls = nil
s.qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(nil, merr.ErrServiceUnavailable)
@ -233,7 +243,6 @@ func (s *LBPolicySuite) TestSelectNode() {
nq: 1,
}, typeutil.NewUniqueSet())
s.ErrorIs(err, merr.ErrServiceUnavailable)
s.Equal(int64(-1), targetNode)
}
func (s *LBPolicySuite) TestExecuteWithRetry() {
@ -241,7 +250,9 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
// test execute success
s.lbBalancer.ExpectedCalls = nil
s.mgr.EXPECT().ReleaseClientRef(mock.Anything)
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err := s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
@ -260,6 +271,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
// test select node failed, expected error
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
db: dbName,
@ -277,8 +289,10 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
// test get client failed, and retry failed, expected success
s.mgr.ExpectedCalls = nil
s.mgr.EXPECT().ReleaseClientRef(mock.Anything)
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1)
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
@ -296,8 +310,10 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
s.Error(err)
s.mgr.ExpectedCalls = nil
s.mgr.EXPECT().ReleaseClientRef(mock.Anything)
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1)
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
db: dbName,
@ -315,8 +331,10 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
// test exec failed, then retry success
s.mgr.ExpectedCalls = nil
s.mgr.EXPECT().ReleaseClientRef(mock.Anything)
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
counter := 0
@ -341,6 +359,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
// test exec timeout
s.mgr.ExpectedCalls = nil
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.mgr.EXPECT().ReleaseClientRef(mock.Anything)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
s.qn.EXPECT().Search(mock.Anything, mock.Anything).Return(nil, context.Canceled).Times(1)
@ -365,7 +384,9 @@ func (s *LBPolicySuite) TestExecute() {
ctx := context.Background()
mockErr := errors.New("mock error")
// test all channel success
s.mgr.EXPECT().ReleaseClientRef(mock.Anything)
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err := s.lbPolicy.Execute(ctx, CollectionWorkLoad{

View File

@ -44,7 +44,8 @@ type CostMetrics struct {
type LookAsideBalancer struct {
clientMgr shardClientMgr
metricsMap *typeutil.ConcurrentMap[int64, *CostMetrics]
knownNodeInfos *typeutil.ConcurrentMap[int64, nodeInfo]
metricsMap *typeutil.ConcurrentMap[int64, *CostMetrics]
// query node id -> number of consecutive heartbeat failures
failedHeartBeatCounter *typeutil.ConcurrentMap[int64, *atomic.Int64]
@ -64,6 +65,7 @@ type LookAsideBalancer struct {
func NewLookAsideBalancer(clientMgr shardClientMgr) *LookAsideBalancer {
balancer := &LookAsideBalancer{
clientMgr: clientMgr,
knownNodeInfos: typeutil.NewConcurrentMap[int64, nodeInfo](),
metricsMap: typeutil.NewConcurrentMap[int64, *CostMetrics](),
failedHeartBeatCounter: typeutil.NewConcurrentMap[int64, *atomic.Int64](),
closeCh: make(chan struct{}),
@ -88,6 +90,12 @@ func (b *LookAsideBalancer) Close() {
})
}
func (b *LookAsideBalancer) RegisterNodeInfo(nodeInfos []nodeInfo) {
for _, node := range nodeInfos {
b.knownNodeInfos.Insert(node.nodeID, node)
}
}
func (b *LookAsideBalancer) SelectNode(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
targetNode := int64(-1)
defer func() {
@ -233,9 +241,10 @@ func (b *LookAsideBalancer) checkQueryNodeHealthLoop(ctx context.Context) {
case <-ticker.C:
var futures []*conc.Future[any]
now := time.Now()
b.metricsMap.Range(func(node int64, metrics *CostMetrics) bool {
b.knownNodeInfos.Range(func(node int64, info nodeInfo) bool {
futures = append(futures, pool.Submit(func() (any, error) {
if now.UnixMilli()-metrics.ts.Load() > checkHealthInterval.Milliseconds() {
metrics, ok := b.metricsMap.Get(node)
if !ok || now.UnixMilli()-metrics.ts.Load() > checkHealthInterval.Milliseconds() {
checkTimeout := Params.ProxyCfg.HealthCheckTimeout.GetAsDuration(time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), checkTimeout)
defer cancel()
@ -244,13 +253,14 @@ func (b *LookAsideBalancer) checkQueryNodeHealthLoop(ctx context.Context) {
panic("let it panic")
}
qn, err := b.clientMgr.GetClient(ctx, node)
qn, err := b.clientMgr.GetClient(ctx, info)
if err != nil {
// get client from clientMgr failed, which means this qn isn't a shard leader anymore, skip its health check
b.trySetQueryNodeUnReachable(node, err)
log.RatedInfo(10, "get client failed", zap.Int64("node", node), zap.Error(err))
return struct{}{}, nil
}
defer b.clientMgr.ReleaseClientRef(node)
resp, err := qn.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{})
if err != nil {
@ -304,6 +314,7 @@ func (b *LookAsideBalancer) trySetQueryNodeUnReachable(node int64, err error) {
zap.Int64("nodeID", node))
// stop the heartbeat
b.metricsMap.Remove(node)
b.knownNodeInfos.Remove(node)
return
}

View File

@ -30,6 +30,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/util/merr"
)
@ -42,11 +43,12 @@ type LookAsideBalancerSuite struct {
func (suite *LookAsideBalancerSuite) SetupTest() {
suite.clientMgr = NewMockShardClientManager(suite.T())
suite.clientMgr.EXPECT().ReleaseClientRef(mock.Anything).Maybe()
suite.balancer = NewLookAsideBalancer(suite.clientMgr)
suite.balancer.Start(context.Background())
qn := mocks.NewMockQueryNodeClient(suite.T())
suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(1)).Return(qn, nil).Maybe()
suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Maybe()
}
@ -298,22 +300,46 @@ func (suite *LookAsideBalancerSuite) TestCancelWorkload() {
}
func (suite *LookAsideBalancerSuite) TestCheckHealthLoop() {
qn := mocks.NewMockQueryNodeClient(suite.T())
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Maybe()
qn2 := mocks.NewMockQueryNodeClient(suite.T())
suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(2)).Return(qn2, nil).Maybe()
qn2.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{
StateCode: commonpb.StateCode_Healthy,
},
}, nil).Maybe()
suite.clientMgr.ExpectedCalls = nil
suite.clientMgr.EXPECT().ReleaseClientRef(mock.Anything).Maybe()
suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ni nodeInfo) (types.QueryNodeClient, error) {
if ni.nodeID == 1 {
return qn, nil
}
if ni.nodeID == 2 {
return qn2, nil
}
return nil, errors.New("unexpected node")
}).Maybe()
metrics1 := &CostMetrics{}
metrics1.ts.Store(time.Now().UnixMilli())
metrics1.unavailable.Store(true)
suite.balancer.metricsMap.Insert(1, metrics1)
suite.balancer.RegisterNodeInfo([]nodeInfo{
{
nodeID: 1,
},
})
metrics2 := &CostMetrics{}
metrics2.ts.Store(time.Now().UnixMilli())
metrics2.unavailable.Store(true)
suite.balancer.metricsMap.Insert(2, metrics2)
suite.balancer.knownNodeInfos.Insert(2, nodeInfo{})
suite.balancer.RegisterNodeInfo([]nodeInfo{
{
nodeID: 2,
},
})
suite.Eventually(func() bool {
metrics, ok := suite.balancer.metricsMap.Get(1)
return ok && metrics.unavailable.Load()
@ -339,10 +365,16 @@ func (suite *LookAsideBalancerSuite) TestGetClientFailed() {
metrics1.ts.Store(time.Now().UnixMilli())
metrics1.unavailable.Store(true)
suite.balancer.metricsMap.Insert(2, metrics1)
suite.balancer.RegisterNodeInfo([]nodeInfo{
{
nodeID: 2,
},
})
// test get shard client from client mgr return nil
suite.clientMgr.ExpectedCalls = nil
suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(2)).Return(nil, errors.New("shard client not found"))
suite.clientMgr.EXPECT().ReleaseClientRef(mock.Anything).Maybe()
suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("shard client not found"))
// expected stopping the health check after failure times reaching the limit
suite.Eventually(func() bool {
return !suite.balancer.metricsMap.Contain(2)
@ -352,7 +384,9 @@ func (suite *LookAsideBalancerSuite) TestGetClientFailed() {
func (suite *LookAsideBalancerSuite) TestNodeRecover() {
// mock qn down for a while and then recover
qn3 := mocks.NewMockQueryNodeClient(suite.T())
suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(3)).Return(qn3, nil)
suite.clientMgr.ExpectedCalls = nil
suite.clientMgr.EXPECT().ReleaseClientRef(mock.Anything)
suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn3, nil)
qn3.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{
StateCode: commonpb.StateCode_Abnormal,
@ -368,6 +402,11 @@ func (suite *LookAsideBalancerSuite) TestNodeRecover() {
metrics1 := &CostMetrics{}
metrics1.ts.Store(time.Now().UnixMilli())
suite.balancer.metricsMap.Insert(3, metrics1)
suite.balancer.RegisterNodeInfo([]nodeInfo{
{
nodeID: 3,
},
})
suite.Eventually(func() bool {
metrics, ok := suite.balancer.metricsMap.Get(3)
return ok && metrics.unavailable.Load()
@ -384,7 +423,9 @@ func (suite *LookAsideBalancerSuite) TestNodeOffline() {
Params.Save(Params.ProxyCfg.HealthCheckTimeout.Key, "1000")
// mock qn down for a while and then recover
qn3 := mocks.NewMockQueryNodeClient(suite.T())
suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(3)).Return(qn3, nil)
suite.clientMgr.ExpectedCalls = nil
suite.clientMgr.EXPECT().ReleaseClientRef(mock.Anything)
suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn3, nil)
qn3.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{
StateCode: commonpb.StateCode_Abnormal,
@ -394,6 +435,11 @@ func (suite *LookAsideBalancerSuite) TestNodeOffline() {
metrics1 := &CostMetrics{}
metrics1.ts.Store(time.Now().UnixMilli())
suite.balancer.metricsMap.Insert(3, metrics1)
suite.balancer.RegisterNodeInfo([]nodeInfo{
{
nodeID: 3,
},
})
suite.Eventually(func() bool {
metrics, ok := suite.balancer.metricsMap.Get(3)
return ok && metrics.unavailable.Load()

View File

@ -72,6 +72,7 @@ type Cache interface {
GetShards(ctx context.Context, withCache bool, database, collectionName string, collectionID int64) (map[string][]nodeInfo, error)
DeprecateShardCache(database, collectionName string)
InvalidateShardLeaderCache(collections []int64)
ListShardLocation() map[int64]nodeInfo
RemoveCollection(ctx context.Context, database, collectionName string)
RemoveCollectionsByID(ctx context.Context, collectionID UniqueID) []string
RemovePartition(ctx context.Context, database, collectionName string, partitionName string)
@ -288,9 +289,7 @@ func (info *collectionInfo) isCollectionCached() bool {
// shardLeaders wraps shard leader mapping for iteration.
type shardLeaders struct {
idx *atomic.Int64
deprecated *atomic.Bool
idx *atomic.Int64
collectionID int64
shardLeaders map[string][]nodeInfo
}
@ -419,19 +418,19 @@ func (m *MetaCache) getCollection(database, collectionName string, collectionID
return nil, false
}
func (m *MetaCache) getCollectionShardLeader(database, collectionName string) (*shardLeaders, bool) {
func (m *MetaCache) getCollectionShardLeader(database, collectionName string) *shardLeaders {
m.leaderMut.RLock()
defer m.leaderMut.RUnlock()
db, ok := m.collLeader[database]
if !ok {
return nil, false
return nil
}
if leaders, ok := db[collectionName]; ok {
return leaders, !leaders.deprecated.Load()
return leaders
}
return nil, false
return nil
}
func (m *MetaCache) update(ctx context.Context, database, collectionName string, collectionID UniqueID) (*collectionInfo, error) {
@ -957,9 +956,9 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
zap.String("collectionName", collectionName),
zap.Int64("collectionID", collectionID))
cacheShardLeaders, ok := m.getCollectionShardLeader(database, collectionName)
cacheShardLeaders := m.getCollectionShardLeader(database, collectionName)
if withCache {
if ok {
if cacheShardLeaders != nil {
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheHitLabel).Inc()
iterator := cacheShardLeaders.GetReader()
return iterator.Shuffle(), nil
@ -995,11 +994,9 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
newShardLeaders := &shardLeaders{
collectionID: info.collID,
shardLeaders: shards,
deprecated: atomic.NewBool(false),
idx: atomic.NewInt64(0),
}
// lock leader
m.leaderMut.Lock()
if _, ok := m.collLeader[database]; !ok {
m.collLeader[database] = make(map[string]*shardLeaders)
@ -1008,15 +1005,6 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
iterator := newShardLeaders.GetReader()
ret := iterator.Shuffle()
oldLeaders := make(map[string][]nodeInfo)
if cacheShardLeaders != nil {
oldLeaders = cacheShardLeaders.shardLeaders
}
// update refcnt in shardClientMgr
// update shard leader's just create a empty client pool
// and init new client will be execute in getClient
_ = m.shardMgr.UpdateShardLeaders(oldLeaders, ret)
m.leaderMut.Unlock()
metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
@ -1042,23 +1030,50 @@ func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) m
// DeprecateShardCache clear the shard leader cache of a collection
func (m *MetaCache) DeprecateShardCache(database, collectionName string) {
log.Info("clearing shard cache for collection", zap.String("collectionName", collectionName))
if shards, ok := m.getCollectionShardLeader(database, collectionName); ok {
shards.deprecated.Store(true)
m.leaderMut.Lock()
defer m.leaderMut.Unlock()
dbInfo, ok := m.collLeader[database]
if ok {
delete(dbInfo, collectionName)
if len(dbInfo) == 0 {
delete(m.collLeader, database)
}
}
}
// used for garbage collecting shard clients
func (m *MetaCache) ListShardLocation() map[int64]nodeInfo {
m.leaderMut.RLock()
defer m.leaderMut.RUnlock()
shardLeaderInfo := make(map[int64]nodeInfo)
for _, dbInfo := range m.collLeader {
for _, shardLeaders := range dbInfo {
for _, nodeInfos := range shardLeaders.shardLeaders {
for _, node := range nodeInfos {
shardLeaderInfo[node.nodeID] = node
}
}
}
}
return shardLeaderInfo
}
func (m *MetaCache) InvalidateShardLeaderCache(collections []int64) {
log.Info("Invalidate shard cache for collections", zap.Int64s("collectionIDs", collections))
m.leaderMut.Lock()
defer m.leaderMut.Unlock()
collectionSet := typeutil.NewUniqueSet(collections...)
for _, db := range m.collLeader {
for _, shardLeaders := range db {
for dbName, dbInfo := range m.collLeader {
for collectionName, shardLeaders := range dbInfo {
if collectionSet.Contain(shardLeaders.collectionID) {
shardLeaders.deprecated.Store(true)
delete(dbInfo, collectionName)
}
}
if len(dbInfo) == 0 {
delete(m.collLeader, dbName)
}
}
}

View File

@ -805,7 +805,6 @@ func TestGlobalMetaCache_ShuffleShardLeaders(t *testing.T) {
},
}
sl := &shardLeaders{
deprecated: uatomic.NewBool(false),
idx: uatomic.NewInt64(5),
shardLeaders: shards,
}

View File

@ -981,6 +981,53 @@ func (_c *MockCache_InvalidateShardLeaderCache_Call) RunAndReturn(run func([]int
return _c
}
// ListShardLocation provides a mock function with given fields:
func (_m *MockCache) ListShardLocation() map[int64]nodeInfo {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for ListShardLocation")
}
var r0 map[int64]nodeInfo
if rf, ok := ret.Get(0).(func() map[int64]nodeInfo); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[int64]nodeInfo)
}
}
return r0
}
// MockCache_ListShardLocation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListShardLocation'
type MockCache_ListShardLocation_Call struct {
*mock.Call
}
// ListShardLocation is a helper method to define mock.On call
func (_e *MockCache_Expecter) ListShardLocation() *MockCache_ListShardLocation_Call {
return &MockCache_ListShardLocation_Call{Call: _e.mock.On("ListShardLocation")}
}
func (_c *MockCache_ListShardLocation_Call) Run(run func()) *MockCache_ListShardLocation_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockCache_ListShardLocation_Call) Return(_a0 map[int64]nodeInfo) *MockCache_ListShardLocation_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockCache_ListShardLocation_Call) RunAndReturn(run func() map[int64]nodeInfo) *MockCache_ListShardLocation_Call {
_c.Call.Return(run)
return _c
}
// RefreshPolicyInfo provides a mock function with given fields: op
func (_m *MockCache) RefreshPolicyInfo(op typeutil.CacheOp) error {
ret := _m.Called(op)

View File

@ -88,6 +88,39 @@ func (_c *MockLBBalancer_Close_Call) RunAndReturn(run func()) *MockLBBalancer_Cl
return _c
}
// RegisterNodeInfo provides a mock function with given fields: nodeInfos
func (_m *MockLBBalancer) RegisterNodeInfo(nodeInfos []nodeInfo) {
_m.Called(nodeInfos)
}
// MockLBBalancer_RegisterNodeInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RegisterNodeInfo'
type MockLBBalancer_RegisterNodeInfo_Call struct {
*mock.Call
}
// RegisterNodeInfo is a helper method to define mock.On call
// - nodeInfos []nodeInfo
func (_e *MockLBBalancer_Expecter) RegisterNodeInfo(nodeInfos interface{}) *MockLBBalancer_RegisterNodeInfo_Call {
return &MockLBBalancer_RegisterNodeInfo_Call{Call: _e.mock.On("RegisterNodeInfo", nodeInfos)}
}
func (_c *MockLBBalancer_RegisterNodeInfo_Call) Run(run func(nodeInfos []nodeInfo)) *MockLBBalancer_RegisterNodeInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]nodeInfo))
})
return _c
}
func (_c *MockLBBalancer_RegisterNodeInfo_Call) Return() *MockLBBalancer_RegisterNodeInfo_Call {
_c.Call.Return()
return _c
}
func (_c *MockLBBalancer_RegisterNodeInfo_Call) RunAndReturn(run func([]nodeInfo)) *MockLBBalancer_RegisterNodeInfo_Call {
_c.Call.Return(run)
return _c
}
// SelectNode provides a mock function with given fields: ctx, availableNodes, nq
func (_m *MockLBBalancer) SelectNode(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
ret := _m.Called(ctx, availableNodes, nq)

View File

@ -54,9 +54,9 @@ func (_c *MockShardClientManager_Close_Call) RunAndReturn(run func()) *MockShard
return _c
}
// GetClient provides a mock function with given fields: ctx, nodeID
func (_m *MockShardClientManager) GetClient(ctx context.Context, nodeID int64) (types.QueryNodeClient, error) {
ret := _m.Called(ctx, nodeID)
// GetClient provides a mock function with given fields: ctx, nodeInfo1
func (_m *MockShardClientManager) GetClient(ctx context.Context, nodeInfo1 nodeInfo) (types.QueryNodeClient, error) {
ret := _m.Called(ctx, nodeInfo1)
if len(ret) == 0 {
panic("no return value specified for GetClient")
@ -64,19 +64,19 @@ func (_m *MockShardClientManager) GetClient(ctx context.Context, nodeID int64) (
var r0 types.QueryNodeClient
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, int64) (types.QueryNodeClient, error)); ok {
return rf(ctx, nodeID)
if rf, ok := ret.Get(0).(func(context.Context, nodeInfo) (types.QueryNodeClient, error)); ok {
return rf(ctx, nodeInfo1)
}
if rf, ok := ret.Get(0).(func(context.Context, int64) types.QueryNodeClient); ok {
r0 = rf(ctx, nodeID)
if rf, ok := ret.Get(0).(func(context.Context, nodeInfo) types.QueryNodeClient); ok {
r0 = rf(ctx, nodeInfo1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(types.QueryNodeClient)
}
}
if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok {
r1 = rf(ctx, nodeID)
if rf, ok := ret.Get(1).(func(context.Context, nodeInfo) error); ok {
r1 = rf(ctx, nodeInfo1)
} else {
r1 = ret.Error(1)
}
@ -91,14 +91,14 @@ type MockShardClientManager_GetClient_Call struct {
// GetClient is a helper method to define mock.On call
// - ctx context.Context
// - nodeID int64
func (_e *MockShardClientManager_Expecter) GetClient(ctx interface{}, nodeID interface{}) *MockShardClientManager_GetClient_Call {
return &MockShardClientManager_GetClient_Call{Call: _e.mock.On("GetClient", ctx, nodeID)}
// - nodeInfo1 nodeInfo
func (_e *MockShardClientManager_Expecter) GetClient(ctx interface{}, nodeInfo1 interface{}) *MockShardClientManager_GetClient_Call {
return &MockShardClientManager_GetClient_Call{Call: _e.mock.On("GetClient", ctx, nodeInfo1)}
}
func (_c *MockShardClientManager_GetClient_Call) Run(run func(ctx context.Context, nodeID int64)) *MockShardClientManager_GetClient_Call {
func (_c *MockShardClientManager_GetClient_Call) Run(run func(ctx context.Context, nodeInfo1 nodeInfo)) *MockShardClientManager_GetClient_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64))
run(args[0].(context.Context), args[1].(nodeInfo))
})
return _c
}
@ -108,7 +108,40 @@ func (_c *MockShardClientManager_GetClient_Call) Return(_a0 types.QueryNodeClien
return _c
}
func (_c *MockShardClientManager_GetClient_Call) RunAndReturn(run func(context.Context, int64) (types.QueryNodeClient, error)) *MockShardClientManager_GetClient_Call {
func (_c *MockShardClientManager_GetClient_Call) RunAndReturn(run func(context.Context, nodeInfo) (types.QueryNodeClient, error)) *MockShardClientManager_GetClient_Call {
_c.Call.Return(run)
return _c
}
// ReleaseClientRef provides a mock function with given fields: nodeID
func (_m *MockShardClientManager) ReleaseClientRef(nodeID int64) {
_m.Called(nodeID)
}
// MockShardClientManager_ReleaseClientRef_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReleaseClientRef'
type MockShardClientManager_ReleaseClientRef_Call struct {
*mock.Call
}
// ReleaseClientRef is a helper method to define mock.On call
// - nodeID int64
func (_e *MockShardClientManager_Expecter) ReleaseClientRef(nodeID interface{}) *MockShardClientManager_ReleaseClientRef_Call {
return &MockShardClientManager_ReleaseClientRef_Call{Call: _e.mock.On("ReleaseClientRef", nodeID)}
}
func (_c *MockShardClientManager_ReleaseClientRef_Call) Run(run func(nodeID int64)) *MockShardClientManager_ReleaseClientRef_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64))
})
return _c
}
func (_c *MockShardClientManager_ReleaseClientRef_Call) Return() *MockShardClientManager_ReleaseClientRef_Call {
_c.Call.Return()
return _c
}
func (_c *MockShardClientManager_ReleaseClientRef_Call) RunAndReturn(run func(int64)) *MockShardClientManager_ReleaseClientRef_Call {
_c.Call.Return(run)
return _c
}
@ -146,53 +179,6 @@ func (_c *MockShardClientManager_SetClientCreatorFunc_Call) RunAndReturn(run fun
return _c
}
// UpdateShardLeaders provides a mock function with given fields: oldLeaders, newLeaders
func (_m *MockShardClientManager) UpdateShardLeaders(oldLeaders map[string][]nodeInfo, newLeaders map[string][]nodeInfo) error {
ret := _m.Called(oldLeaders, newLeaders)
if len(ret) == 0 {
panic("no return value specified for UpdateShardLeaders")
}
var r0 error
if rf, ok := ret.Get(0).(func(map[string][]nodeInfo, map[string][]nodeInfo) error); ok {
r0 = rf(oldLeaders, newLeaders)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockShardClientManager_UpdateShardLeaders_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateShardLeaders'
type MockShardClientManager_UpdateShardLeaders_Call struct {
*mock.Call
}
// UpdateShardLeaders is a helper method to define mock.On call
// - oldLeaders map[string][]nodeInfo
// - newLeaders map[string][]nodeInfo
func (_e *MockShardClientManager_Expecter) UpdateShardLeaders(oldLeaders interface{}, newLeaders interface{}) *MockShardClientManager_UpdateShardLeaders_Call {
return &MockShardClientManager_UpdateShardLeaders_Call{Call: _e.mock.On("UpdateShardLeaders", oldLeaders, newLeaders)}
}
func (_c *MockShardClientManager_UpdateShardLeaders_Call) Run(run func(oldLeaders map[string][]nodeInfo, newLeaders map[string][]nodeInfo)) *MockShardClientManager_UpdateShardLeaders_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(map[string][]nodeInfo), args[1].(map[string][]nodeInfo))
})
return _c
}
func (_c *MockShardClientManager_UpdateShardLeaders_Call) Return(_a0 error) *MockShardClientManager_UpdateShardLeaders_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockShardClientManager_UpdateShardLeaders_Call) RunAndReturn(run func(map[string][]nodeInfo, map[string][]nodeInfo) error) *MockShardClientManager_UpdateShardLeaders_Call {
_c.Call.Return(run)
return _c
}
// NewMockShardClientManager creates a new instance of MockShardClientManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockShardClientManager(t interface {

View File

@ -32,6 +32,8 @@ func NewRoundRobinBalancer() *RoundRobinBalancer {
return &RoundRobinBalancer{}
}
func (b *RoundRobinBalancer) RegisterNodeInfo(nodeInfos []nodeInfo) {}
func (b *RoundRobinBalancer) SelectNode(ctx context.Context, availableNodes []int64, cost int64) (int64, error) {
if len(availableNodes) == 0 {
return -1, merr.ErrNodeNotAvailable

View File

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"sync"
"time"
"github.com/cockroachdb/errors"
"go.uber.org/atomic"
@ -32,7 +33,6 @@ type shardClient struct {
sync.RWMutex
info nodeInfo
isClosed bool
refCnt int
clients []types.QueryNodeClient
idx atomic.Int64
poolSize int
@ -40,13 +40,15 @@ type shardClient struct {
initialized atomic.Bool
creator queryNodeCreatorFunc
refCnt *atomic.Int64
}
func (n *shardClient) getClient(ctx context.Context) (types.QueryNodeClient, error) {
if !n.initialized.Load() {
n.Lock()
if !n.initialized.Load() {
if err := n.initClients(); err != nil {
if err := n.initClients(ctx); err != nil {
n.Unlock()
return nil, err
}
@ -55,28 +57,34 @@ func (n *shardClient) getClient(ctx context.Context) (types.QueryNodeClient, err
n.Unlock()
}
n.RLock()
defer n.RUnlock()
if n.isClosed {
return nil, errClosed
// Attempt to get a connection from the idle connection pool, supporting context cancellation
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
client, err := n.roundRobinSelectClient()
if err != nil {
return nil, err
}
n.IncRef()
return client, nil
}
idx := n.idx.Inc()
return n.clients[int(idx)%n.poolSize], nil
}
func (n *shardClient) inc() {
n.Lock()
defer n.Unlock()
if n.isClosed {
return
func (n *shardClient) DecRef() bool {
if n.refCnt.Dec() == 0 {
n.Close()
return true
}
n.refCnt++
return false
}
func (n *shardClient) IncRef() {
n.refCnt.Inc()
}
func (n *shardClient) close() {
n.isClosed = true
n.refCnt = 0
for _, client := range n.clients {
if err := client.Close(); err != nil {
@ -86,50 +94,36 @@ func (n *shardClient) close() {
n.clients = nil
}
func (n *shardClient) dec() bool {
n.Lock()
defer n.Unlock()
if n.isClosed {
return true
}
if n.refCnt > 0 {
n.refCnt--
}
if n.refCnt == 0 {
n.close()
}
return n.refCnt == 0
}
func (n *shardClient) Close() {
n.Lock()
defer n.Unlock()
n.close()
}
func newPoolingShardClient(info *nodeInfo, creator queryNodeCreatorFunc) (*shardClient, error) {
func newShardClient(info nodeInfo, creator queryNodeCreatorFunc) (*shardClient, error) {
num := paramtable.Get().ProxyCfg.QueryNodePoolingSize.GetAsInt()
if num <= 0 {
num = 1
}
return &shardClient{
info: nodeInfo{
nodeID: info.nodeID,
address: info.address,
},
refCnt: 1,
pooling: true,
creator: creator,
poolSize: num,
creator: creator,
refCnt: atomic.NewInt64(1),
}, nil
}
func (n *shardClient) initClients() error {
num := paramtable.Get().ProxyCfg.QueryNodePoolingSize.GetAsInt()
if num <= 0 {
num = 1
}
clients := make([]types.QueryNodeClient, 0, num)
for i := 0; i < num; i++ {
client, err := n.creator(context.Background(), n.info.address, n.info.nodeID)
func (n *shardClient) initClients(ctx context.Context) error {
clients := make([]types.QueryNodeClient, 0, n.poolSize)
for i := 0; i < n.poolSize; i++ {
client, err := n.creator(ctx, n.info.address, n.info.nodeID)
if err != nil {
// roll back already created clients
for _, c := range clients[:i] {
// Roll back already created clients
for _, c := range clients {
c.Close()
}
return errors.Wrap(err, fmt.Sprintf("create client for node=%d failed", n.info.nodeID))
@ -138,13 +132,29 @@ func (n *shardClient) initClients() error {
}
n.clients = clients
n.poolSize = num
return nil
}
// roundRobinSelectClient selects a client in a round-robin manner
func (n *shardClient) roundRobinSelectClient() (types.QueryNodeClient, error) {
n.Lock()
defer n.Unlock()
if n.isClosed {
return nil, errClosed
}
if len(n.clients) == 0 {
return nil, errors.New("no available clients")
}
nextClientIndex := n.idx.Inc() % int64(len(n.clients))
nextClient := n.clients[nextClientIndex]
return nextClient, nil
}
type shardClientMgr interface {
GetClient(ctx context.Context, nodeID UniqueID) (types.QueryNodeClient, error)
UpdateShardLeaders(oldLeaders map[string][]nodeInfo, newLeaders map[string][]nodeInfo) error
GetClient(ctx context.Context, nodeInfo nodeInfo) (types.QueryNodeClient, error)
ReleaseClientRef(nodeID int64)
Close()
SetClientCreatorFunc(creator queryNodeCreatorFunc)
}
@ -155,6 +165,8 @@ type shardClientMgrImpl struct {
data map[UniqueID]*shardClient
}
clientCreator queryNodeCreatorFunc
closeCh chan struct{}
}
// SessionOpt provides a way to set params in SessionManager
@ -176,10 +188,13 @@ func newShardClientMgr(options ...shardClientMgrOpt) *shardClientMgrImpl {
data map[UniqueID]*shardClient
}{data: make(map[UniqueID]*shardClient)},
clientCreator: defaultQueryNodeClientCreator,
closeCh: make(chan struct{}),
}
for _, opt := range options {
opt(s)
}
go s.PurgeClient()
return s
}
@ -187,79 +202,65 @@ func (c *shardClientMgrImpl) SetClientCreatorFunc(creator queryNodeCreatorFunc)
c.clientCreator = creator
}
// Warning this method may modify parameter `oldLeaders`
func (c *shardClientMgrImpl) UpdateShardLeaders(oldLeaders map[string][]nodeInfo, newLeaders map[string][]nodeInfo) error {
oldLocalMap := make(map[UniqueID]*nodeInfo)
for _, nodes := range oldLeaders {
for i := range nodes {
n := &nodes[i]
_, ok := oldLocalMap[n.nodeID]
if !ok {
oldLocalMap[n.nodeID] = n
}
}
}
newLocalMap := make(map[UniqueID]*nodeInfo)
for _, nodes := range newLeaders {
for i := range nodes {
n := &nodes[i]
_, ok := oldLocalMap[n.nodeID]
if !ok {
_, ok2 := newLocalMap[n.nodeID]
if !ok2 {
newLocalMap[n.nodeID] = n
}
}
delete(oldLocalMap, n.nodeID)
}
}
c.clients.Lock()
defer c.clients.Unlock()
for _, node := range newLocalMap {
client, ok := c.clients.data[node.nodeID]
if ok {
client.inc()
} else {
// context.Background() is useless
// TODO QueryNode NewClient remove ctx parameter
// TODO Remove Init && Start interface in QueryNode client
if c.clientCreator == nil {
return fmt.Errorf("clientCreator function is nil")
}
client, err := newPoolingShardClient(node, c.clientCreator)
if err != nil {
return err
}
c.clients.data[node.nodeID] = client
}
}
for _, node := range oldLocalMap {
client, ok := c.clients.data[node.nodeID]
if ok && client.dec() {
delete(c.clients.data, node.nodeID)
}
}
return nil
}
func (c *shardClientMgrImpl) GetClient(ctx context.Context, nodeID UniqueID) (types.QueryNodeClient, error) {
func (c *shardClientMgrImpl) GetClient(ctx context.Context, info nodeInfo) (types.QueryNodeClient, error) {
c.clients.RLock()
client, ok := c.clients.data[nodeID]
client, ok := c.clients.data[info.nodeID]
c.clients.RUnlock()
if !ok {
return nil, fmt.Errorf("can not find client of node %d", nodeID)
c.clients.Lock()
// Check again after acquiring the lock
client, ok = c.clients.data[info.nodeID]
if !ok {
// Create a new client if it doesn't exist
newClient, err := newShardClient(info, c.clientCreator)
if err != nil {
c.clients.Unlock()
return nil, err
}
c.clients.data[info.nodeID] = newClient
client = newClient
}
c.clients.Unlock()
}
return client.getClient(ctx)
}
func (c *shardClientMgrImpl) PurgeClient() {
ticker := time.NewTicker(600 * time.Second)
defer ticker.Stop()
for {
select {
case <-c.closeCh:
return
case <-ticker.C:
shardLocations := globalMetaCache.ListShardLocation()
c.clients.Lock()
for nodeID, client := range c.clients.data {
if _, ok := shardLocations[nodeID]; !ok {
client.DecRef()
delete(c.clients.data, nodeID)
}
}
c.clients.Unlock()
}
}
}
func (c *shardClientMgrImpl) ReleaseClientRef(nodeID int64) {
c.clients.RLock()
defer c.clients.RUnlock()
if client, ok := c.clients.data[nodeID]; ok {
client.DecRef()
}
}
// Close release clients
func (c *shardClientMgrImpl) Close() {
c.clients.Lock()
defer c.clients.Unlock()
close(c.closeCh)
for _, s := range c.clients.data {
s.Close()
}

View File

@ -8,94 +8,59 @@ import (
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
func genShardLeaderInfo(channel string, leaderIDs []UniqueID) map[string][]nodeInfo {
leaders := make(map[string][]nodeInfo)
nodeInfos := make([]nodeInfo, len(leaderIDs))
for i, id := range leaderIDs {
nodeInfos[i] = nodeInfo{
nodeID: id,
address: "fake",
}
func TestShardClientMgr(t *testing.T) {
ctx := context.Background()
nodeInfo := nodeInfo{
nodeID: 1,
}
leaders[channel] = nodeInfos
return leaders
}
func TestShardClientMgr_UpdateShardLeaders_CreatorNil(t *testing.T) {
mgr := newShardClientMgr(withShardClientCreator(nil))
mgr.clientCreator = nil
leaders := genShardLeaderInfo("c1", []UniqueID{1, 2, 3})
err := mgr.UpdateShardLeaders(nil, leaders)
assert.Error(t, err)
}
func TestShardClientMgr_UpdateShardLeaders_Empty(t *testing.T) {
mockCreator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
return &mocks.MockQueryNodeClient{}, nil
qn := mocks.NewMockQueryNodeClient(t)
qn.EXPECT().Close().Return(nil)
creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
return qn, nil
}
mgr := newShardClientMgr(withShardClientCreator(mockCreator))
_, err := mgr.GetClient(context.Background(), UniqueID(1))
assert.Error(t, err)
err = mgr.UpdateShardLeaders(nil, nil)
assert.NoError(t, err)
_, err = mgr.GetClient(context.Background(), UniqueID(1))
assert.Error(t, err)
leaders := genShardLeaderInfo("c1", []UniqueID{1, 2, 3})
err = mgr.UpdateShardLeaders(leaders, nil)
assert.NoError(t, err)
}
func TestShardClientMgr_UpdateShardLeaders_NonEmpty(t *testing.T) {
mgr := newShardClientMgr()
leaders := genShardLeaderInfo("c1", []UniqueID{1, 2, 3})
err := mgr.UpdateShardLeaders(nil, leaders)
assert.NoError(t, err)
mgr.SetClientCreatorFunc(creator)
_, err := mgr.GetClient(ctx, nodeInfo)
assert.Nil(t, err)
_, err = mgr.GetClient(context.Background(), UniqueID(1))
assert.NoError(t, err)
newLeaders := genShardLeaderInfo("c1", []UniqueID{2, 3})
err = mgr.UpdateShardLeaders(leaders, newLeaders)
assert.NoError(t, err)
_, err = mgr.GetClient(context.Background(), UniqueID(1))
assert.Error(t, err)
mgr.ReleaseClientRef(1)
assert.Equal(t, len(mgr.clients.data), 1)
mgr.Close()
assert.Equal(t, len(mgr.clients.data), 0)
}
func TestShardClientMgr_UpdateShardLeaders_Ref(t *testing.T) {
mgr := newShardClientMgr()
leaders := genShardLeaderInfo("c1", []UniqueID{1, 2, 3})
for i := 0; i < 2; i++ {
err := mgr.UpdateShardLeaders(nil, leaders)
assert.NoError(t, err)
func TestShardClient(t *testing.T) {
nodeInfo := nodeInfo{
nodeID: 1,
}
partLeaders := genShardLeaderInfo("c1", []UniqueID{1})
qn := mocks.NewMockQueryNodeClient(t)
qn.EXPECT().Close().Return(nil)
creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
return qn, nil
}
shardClient, err := newShardClient(nodeInfo, creator)
assert.Nil(t, err)
assert.Equal(t, len(shardClient.clients), 0)
assert.Equal(t, int64(1), shardClient.refCnt.Load())
assert.Equal(t, false, shardClient.initialized.Load())
_, err := mgr.GetClient(context.Background(), UniqueID(1))
assert.NoError(t, err)
ctx := context.Background()
_, err = shardClient.getClient(ctx)
assert.Nil(t, err)
assert.Equal(t, len(shardClient.clients), paramtable.Get().ProxyCfg.QueryNodePoolingSize.GetAsInt())
assert.Equal(t, int64(2), shardClient.refCnt.Load())
assert.Equal(t, true, shardClient.initialized.Load())
err = mgr.UpdateShardLeaders(partLeaders, nil)
assert.NoError(t, err)
shardClient.DecRef()
assert.Equal(t, int64(1), shardClient.refCnt.Load())
_, err = mgr.GetClient(context.Background(), UniqueID(1))
assert.NoError(t, err)
err = mgr.UpdateShardLeaders(partLeaders, nil)
assert.NoError(t, err)
_, err = mgr.GetClient(context.Background(), UniqueID(1))
assert.Error(t, err)
_, err = mgr.GetClient(context.Background(), UniqueID(2))
assert.NoError(t, err)
_, err = mgr.GetClient(context.Background(), UniqueID(3))
assert.NoError(t, err)
shardClient.DecRef()
assert.Equal(t, int64(0), shardClient.refCnt.Load())
assert.Equal(t, true, shardClient.isClosed)
}

View File

@ -33,12 +33,13 @@ func RoundRobinPolicy(
leaders := dml2leaders[channel]
for _, target := range leaders {
qn, err := mgr.GetClient(ctx, target.nodeID)
qn, err := mgr.GetClient(ctx, target)
if err != nil {
log.Warn("query channel failed, node not available", zap.String("channel", channel), zap.Int64("nodeID", target.nodeID), zap.Error(err))
combineErr = merr.Combine(combineErr, err)
continue
}
defer mgr.ReleaseClientRef(target.nodeID)
err = query(ctx, target.nodeID, qn, channel)
if err != nil {
log.Warn("query channel failed", zap.String("channel", channel), zap.Int64("nodeID", target.nodeID), zap.Error(err))

View File

@ -26,7 +26,6 @@ func TestRoundRobinPolicy(t *testing.T) {
"c2": {{nodeID: 0, address: "fake"}, {nodeID: 2, address: "fake"}, {nodeID: 3, address: "fake"}},
"c3": {{nodeID: 1, address: "fake"}, {nodeID: 3, address: "fake"}, {nodeID: 4, address: "fake"}},
}
mgr.UpdateShardLeaders(nil, shard2leaders)
querier := &mockQuery{}
querier.init()

View File

@ -79,8 +79,8 @@ func TestQueryTask_all(t *testing.T) {
}, nil).Maybe()
mgr := NewMockShardClientManager(t)
mgr.EXPECT().ReleaseClientRef(mock.Anything)
mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe()
lb := NewLBPolicyImpl(mgr)
defer rc.Close()

View File

@ -2111,8 +2111,8 @@ func TestSearchTask_ErrExecute(t *testing.T) {
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
mgr := NewMockShardClientManager(t)
mgr.EXPECT().ReleaseClientRef(mock.Anything)
mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe()
lb := NewLBPolicyImpl(mgr)
defer qc.Close()

View File

@ -80,8 +80,8 @@ func (s *StatisticTaskSuite) SetupTest() {
s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
mgr := NewMockShardClientManager(s.T())
mgr.EXPECT().ReleaseClientRef(mock.Anything).Maybe()
mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil).Maybe()
mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe()
s.lb = NewLBPolicyImpl(mgr)
err := InitMetaCache(context.Background(), s.rc, s.qc, mgr)