diff --git a/internal/proxy/lb_policy.go b/internal/proxy/lb_policy.go index 1e130baa03..e0bb9794d2 100644 --- a/internal/proxy/lb_policy.go +++ b/internal/proxy/lb_policy.go @@ -61,77 +61,75 @@ type LBPolicy interface { Close() } +const ( + RoundRobin = "round_robin" + LookAside = "look_aside" +) + type LBPolicyImpl struct { - balancer LBBalancer - clientMgr shardClientMgr + getBalancer func() LBBalancer + clientMgr shardClientMgr + balancerMap map[string]LBBalancer } func NewLBPolicyImpl(clientMgr shardClientMgr) *LBPolicyImpl { - balancePolicy := params.Params.ProxyCfg.ReplicaSelectionPolicy.GetValue() + balancerMap := make(map[string]LBBalancer) + balancerMap[LookAside] = NewLookAsideBalancer(clientMgr) + balancerMap[RoundRobin] = NewRoundRobinBalancer() - var balancer LBBalancer - switch balancePolicy { - case "round_robin": - log.Info("use round_robin policy on replica selection") - balancer = NewRoundRobinBalancer() - default: - log.Info("use look_aside policy on replica selection") - balancer = NewLookAsideBalancer(clientMgr) + getBalancer := func() LBBalancer { + balancePolicy := params.Params.ProxyCfg.ReplicaSelectionPolicy.GetValue() + if _, ok := balancerMap[balancePolicy]; !ok { + return balancerMap[LookAside] + } + return balancerMap[balancePolicy] } return &LBPolicyImpl{ - balancer: balancer, - clientMgr: clientMgr, + getBalancer: getBalancer, + clientMgr: clientMgr, + balancerMap: balancerMap, } } func (lb *LBPolicyImpl) Start(ctx context.Context) { - lb.balancer.Start(ctx) + for _, lb := range lb.balancerMap { + lb.Start(ctx) + } } // try to select the best node from the available nodes -func (lb *LBPolicyImpl) selectNode(ctx context.Context, workload ChannelWorkload, excludeNodes typeutil.UniqueSet) (int64, error) { - log := log.Ctx(ctx).With( - zap.Int64("collectionID", workload.collectionID), - zap.String("collectionName", workload.collectionName), - zap.String("channelName", workload.channel), - ) - - filterAvailableNodes := func(node int64, _ int) bool { - return !excludeNodes.Contain(node) - } - - getShardLeaders := func() ([]int64, error) { +func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, workload ChannelWorkload, excludeNodes typeutil.UniqueSet) (int64, error) { + availableNodes := lo.FilterMap(workload.shardLeaders, func(node int64, _ int) (int64, bool) { return node, !excludeNodes.Contain(node) }) + targetNode, err := balancer.SelectNode(ctx, availableNodes, workload.nq) + if err != nil { + log := log.Ctx(ctx) + globalMetaCache.DeprecateShardCache(workload.db, workload.collectionName) shardLeaders, err := globalMetaCache.GetShards(ctx, false, workload.db, workload.collectionName, workload.collectionID) if err != nil { - return nil, err - } - - return lo.Map(shardLeaders[workload.channel], func(node nodeInfo, _ int) int64 { return node.nodeID }), nil - } - - availableNodes := lo.Filter(workload.shardLeaders, filterAvailableNodes) - targetNode, err := lb.balancer.SelectNode(ctx, availableNodes, workload.nq) - if err != nil { - globalMetaCache.DeprecateShardCache(workload.db, workload.collectionName) - nodes, err := getShardLeaders() - if err != nil || len(nodes) == 0 { log.Warn("failed to get shard delegator", + zap.Int64("collectionID", workload.collectionID), + zap.String("channelName", workload.channel), zap.Error(err)) return -1, err } - availableNodes := lo.Filter(nodes, filterAvailableNodes) + availableNodes := lo.FilterMap(shardLeaders[workload.channel], func(node nodeInfo, _ int) (int64, bool) { return node.nodeID, !excludeNodes.Contain(node.nodeID) }) if len(availableNodes) == 0 { + nodes := lo.Map(shardLeaders[workload.channel], func(node nodeInfo, _ int) int64 { return node.nodeID }) log.Warn("no available shard delegator found", + zap.Int64("collectionID", workload.collectionID), + zap.String("channelName", workload.channel), zap.Int64s("nodes", nodes), zap.Int64s("excluded", excludeNodes.Collect())) return -1, merr.WrapErrChannelNotAvailable("no available shard delegator found") } - targetNode, err = lb.balancer.SelectNode(ctx, availableNodes, workload.nq) + targetNode, err = balancer.SelectNode(ctx, availableNodes, workload.nq) if err != nil { log.Warn("failed to select shard", + zap.Int64("collectionID", workload.collectionID), + zap.String("channelName", workload.channel), zap.Int64s("availableNodes", availableNodes), zap.Error(err)) return -1, err @@ -144,17 +142,15 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, workload ChannelWorkload // ExecuteWithRetry will choose a qn to execute the workload, and retry if failed, until reach the max retryTimes. func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWorkload) error { excludeNodes := typeutil.NewUniqueSet() - log := log.Ctx(ctx).With( - zap.Int64("collectionID", workload.collectionID), - zap.String("collectionName", workload.collectionName), - zap.String("channelName", workload.channel), - ) var lastErr error err := retry.Do(ctx, func() error { - targetNode, err := lb.selectNode(ctx, workload, excludeNodes) + balancer := lb.getBalancer() + targetNode, err := lb.selectNode(ctx, balancer, workload, excludeNodes) if err != nil { log.Warn("failed to select node for shard", + zap.Int64("collectionID", workload.collectionID), + zap.String("channelName", workload.channel), zap.Int64("nodeID", targetNode), zap.Error(err), ) @@ -163,16 +159,18 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo } return err } + // cancel work load which assign to the target node + defer balancer.CancelWorkload(targetNode, workload.nq) client, err := lb.clientMgr.GetClient(ctx, targetNode) if err != nil { log.Warn("search/query channel failed, node not available", + zap.Int64("collectionID", workload.collectionID), + zap.String("channelName", workload.channel), zap.Int64("nodeID", targetNode), zap.Error(err)) excludeNodes.Insert(targetNode) - // cancel work load which assign to the target node - lb.balancer.CancelWorkload(targetNode, workload.nq) lastErr = errors.Wrapf(err, "failed to get delegator %d for channel %s", targetNode, workload.channel) return lastErr } @@ -180,16 +178,15 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo err = workload.exec(ctx, targetNode, client, workload.channel) if err != nil { log.Warn("search/query channel failed", + zap.Int64("collectionID", workload.collectionID), + zap.String("channelName", workload.channel), zap.Int64("nodeID", targetNode), zap.Error(err)) excludeNodes.Insert(targetNode) - lb.balancer.CancelWorkload(targetNode, workload.nq) - lastErr = errors.Wrapf(err, "failed to search/query delegator %d for channel %s", targetNode, workload.channel) return lastErr } - lb.balancer.CancelWorkload(targetNode, workload.nq) return nil }, retry.Attempts(workload.retryTimes)) @@ -232,9 +229,11 @@ func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad } func (lb *LBPolicyImpl) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) { - lb.balancer.UpdateCostMetrics(node, cost) + lb.getBalancer().UpdateCostMetrics(node, cost) } func (lb *LBPolicyImpl) Close() { - lb.balancer.Close() + for _, lb := range lb.balancerMap { + lb.Close() + } } diff --git a/internal/proxy/lb_policy_test.go b/internal/proxy/lb_policy_test.go index a83e1a5c06..0f0f8e4688 100644 --- a/internal/proxy/lb_policy_test.go +++ b/internal/proxy/lb_policy_test.go @@ -101,7 +101,9 @@ func (s *LBPolicySuite) SetupTest() { s.lbBalancer.EXPECT().Start(context.Background()).Maybe() s.lbPolicy = NewLBPolicyImpl(s.mgr) s.lbPolicy.Start(context.Background()) - s.lbPolicy.balancer = s.lbBalancer + s.lbPolicy.getBalancer = func() LBBalancer { + return s.lbBalancer + } err := InitMetaCache(context.Background(), s.rc, s.qc, s.mgr) s.NoError(err) @@ -163,7 +165,7 @@ func (s *LBPolicySuite) loadCollection() { func (s *LBPolicySuite) TestSelectNode() { ctx := context.Background() s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(5, nil) - targetNode, err := s.lbPolicy.selectNode(ctx, ChannelWorkload{ + targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{ db: dbName, collectionName: s.collectionName, collectionID: s.collectionID, @@ -178,7 +180,7 @@ func (s *LBPolicySuite) TestSelectNode() { s.lbBalancer.ExpectedCalls = nil s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, errors.New("fake err")).Times(1) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(3, nil) - targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{ + targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{ db: dbName, collectionName: s.collectionName, collectionID: s.collectionID, @@ -192,7 +194,7 @@ func (s *LBPolicySuite) TestSelectNode() { // test select node always fails, expected failure s.lbBalancer.ExpectedCalls = nil s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable) - targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{ + targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{ db: dbName, collectionName: s.collectionName, collectionID: s.collectionID, @@ -206,7 +208,7 @@ func (s *LBPolicySuite) TestSelectNode() { // test all nodes has been excluded, expected failure s.lbBalancer.ExpectedCalls = nil s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable) - targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{ + targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{ db: dbName, collectionName: s.collectionName, collectionID: s.collectionID, @@ -222,7 +224,7 @@ func (s *LBPolicySuite) TestSelectNode() { s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable) s.qc.ExpectedCalls = nil s.qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(nil, merr.ErrServiceUnavailable) - targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{ + targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{ db: dbName, collectionName: s.collectionName, collectionID: s.collectionID, @@ -419,17 +421,17 @@ func (s *LBPolicySuite) TestUpdateCostMetrics() { func (s *LBPolicySuite) TestNewLBPolicy() { policy := NewLBPolicyImpl(s.mgr) - s.Equal(reflect.TypeOf(policy.balancer).String(), "*proxy.LookAsideBalancer") + s.Equal(reflect.TypeOf(policy.getBalancer()).String(), "*proxy.LookAsideBalancer") policy.Close() Params.Save(Params.ProxyCfg.ReplicaSelectionPolicy.Key, "round_robin") policy = NewLBPolicyImpl(s.mgr) - s.Equal(reflect.TypeOf(policy.balancer).String(), "*proxy.RoundRobinBalancer") + s.Equal(reflect.TypeOf(policy.getBalancer()).String(), "*proxy.RoundRobinBalancer") policy.Close() Params.Save(Params.ProxyCfg.ReplicaSelectionPolicy.Key, "look_aside") policy = NewLBPolicyImpl(s.mgr) - s.Equal(reflect.TypeOf(policy.balancer).String(), "*proxy.LookAsideBalancer") + s.Equal(reflect.TypeOf(policy.getBalancer()).String(), "*proxy.LookAsideBalancer") policy.Close() } diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go index 6f6e7e61fb..46461efa97 100644 --- a/internal/proxy/meta_cache.go +++ b/internal/proxy/meta_cache.go @@ -949,11 +949,6 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col zap.String("collectionName", collectionName), zap.Int64("collectionID", collectionID)) - info, err := m.getFullCollectionInfo(ctx, database, collectionName, collectionID) - if err != nil { - return nil, err - } - cacheShardLeaders, ok := m.getCollectionShardLeader(database, collectionName) if withCache { if ok { @@ -965,6 +960,12 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc() log.Info("no shard cache for collection, try to get shard leaders from QueryCoord") } + + info, err := m.getFullCollectionInfo(ctx, database, collectionName, collectionID) + if err != nil { + return nil, err + } + req := &querypb.GetShardLeadersRequest{ Base: commonpbutil.NewMsgBase( commonpbutil.WithMsgType(commonpb.MsgType_GetShardLeaders), diff --git a/internal/proxy/roundrobin_balancer.go b/internal/proxy/roundrobin_balancer.go index bd54f0f82a..983514d6ac 100644 --- a/internal/proxy/roundrobin_balancer.go +++ b/internal/proxy/roundrobin_balancer.go @@ -22,18 +22,14 @@ import ( "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/pkg/util/merr" - "github.com/milvus-io/milvus/pkg/util/typeutil" ) type RoundRobinBalancer struct { - // request num send to each node - nodeWorkload *typeutil.ConcurrentMap[int64, *atomic.Int64] + idx atomic.Int64 } func NewRoundRobinBalancer() *RoundRobinBalancer { - return &RoundRobinBalancer{ - nodeWorkload: typeutil.NewConcurrentMap[int64, *atomic.Int64](), - } + return &RoundRobinBalancer{} } func (b *RoundRobinBalancer) SelectNode(ctx context.Context, availableNodes []int64, cost int64) (int64, error) { @@ -41,32 +37,11 @@ func (b *RoundRobinBalancer) SelectNode(ctx context.Context, availableNodes []in return -1, merr.ErrNodeNotAvailable } - targetNode := int64(-1) - var targetNodeWorkload *atomic.Int64 - for _, node := range availableNodes { - workload, ok := b.nodeWorkload.Get(node) - - if !ok { - workload = atomic.NewInt64(0) - b.nodeWorkload.Insert(node, workload) - } - - if targetNodeWorkload == nil || workload.Load() < targetNodeWorkload.Load() { - targetNode = node - targetNodeWorkload = workload - } - } - - targetNodeWorkload.Add(cost) - return targetNode, nil + idx := b.idx.Inc() + return availableNodes[int(idx)%len(availableNodes)], nil } func (b *RoundRobinBalancer) CancelWorkload(node int64, nq int64) { - load, ok := b.nodeWorkload.Get(node) - - if ok { - load.Sub(nq) - } } func (b *RoundRobinBalancer) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) {} diff --git a/internal/proxy/roundrobin_balancer_test.go b/internal/proxy/roundrobin_balancer_test.go index 8840099bc2..f30cbf3e36 100644 --- a/internal/proxy/roundrobin_balancer_test.go +++ b/internal/proxy/roundrobin_balancer_test.go @@ -20,6 +20,8 @@ import ( "testing" "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus/pkg/util/merr" ) type RoundRobinBalancerSuite struct { @@ -33,48 +35,34 @@ func (s *RoundRobinBalancerSuite) SetupTest() { s.balancer.Start(context.Background()) } -func (s *RoundRobinBalancerSuite) TestRoundRobin() { - availableNodes := []int64{1, 2} - s.balancer.SelectNode(context.TODO(), availableNodes, 1) - s.balancer.SelectNode(context.TODO(), availableNodes, 1) - s.balancer.SelectNode(context.TODO(), availableNodes, 1) - s.balancer.SelectNode(context.TODO(), availableNodes, 1) +func TestSelectNode(t *testing.T) { + balancer := NewRoundRobinBalancer() - workload, ok := s.balancer.nodeWorkload.Get(1) - s.True(ok) - s.Equal(int64(2), workload.Load()) - workload, ok = s.balancer.nodeWorkload.Get(1) - s.True(ok) - s.Equal(int64(2), workload.Load()) + // Test case 1: Empty availableNodes + _, err1 := balancer.SelectNode(context.Background(), []int64{}, 0) + if err1 != merr.ErrNodeNotAvailable { + t.Errorf("Expected ErrNodeNotAvailable, got %v", err1) + } - s.balancer.SelectNode(context.TODO(), availableNodes, 3) - s.balancer.SelectNode(context.TODO(), availableNodes, 1) - s.balancer.SelectNode(context.TODO(), availableNodes, 1) - s.balancer.SelectNode(context.TODO(), availableNodes, 1) + // Test case 2: Non-empty availableNodes + availableNodes := []int64{1, 2, 3} + selectedNode2, err2 := balancer.SelectNode(context.Background(), availableNodes, 0) + if err2 != nil { + t.Errorf("Expected no error, got %v", err2) + } + if selectedNode2 < 1 || selectedNode2 > 3 { + t.Errorf("Expected a node in the range [1, 3], got %d", selectedNode2) + } - workload, ok = s.balancer.nodeWorkload.Get(1) - s.True(ok) - s.Equal(int64(5), workload.Load()) - workload, ok = s.balancer.nodeWorkload.Get(1) - s.True(ok) - s.Equal(int64(5), workload.Load()) -} - -func (s *RoundRobinBalancerSuite) TestNoAvailableNode() { - availableNodes := []int64{} - _, err := s.balancer.SelectNode(context.TODO(), availableNodes, 1) - s.Error(err) -} - -func (s *RoundRobinBalancerSuite) TestCancelWorkload() { - availableNodes := []int64{101} - _, err := s.balancer.SelectNode(context.TODO(), availableNodes, 5) - s.NoError(err) - workload, ok := s.balancer.nodeWorkload.Get(101) - s.True(ok) - s.Equal(int64(5), workload.Load()) - s.balancer.CancelWorkload(101, 5) - s.Equal(int64(0), workload.Load()) + // Test case 3: Boundary case + availableNodes = []int64{1} + selectedNode3, err3 := balancer.SelectNode(context.Background(), availableNodes, 0) + if err3 != nil { + t.Errorf("Expected no error, got %v", err3) + } + if selectedNode3 != 1 { + t.Errorf("Expected 1, got %d", selectedNode3) + } } func TestRoundRobinBalancerSuite(t *testing.T) {