enable look aside balancer on replica selection (#24791)
Signed-off-by: Wei Liu <wei.liu@zilliz.com>
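
The look-aside balancer scores every reachable replica from the cost metrics it last reported plus the NQ currently in flight to it, and routes each request to the node with the lowest score. A rough, self-contained sketch of that score, lifted from calculateScore in the new internal/proxy/look_aside_balancer.go below (the standalone score function and main here are illustrative only, not part of the change):

package main

import (
	"fmt"
	"math"
)

// score mirrors LookAsideBalancer.calculateScore: lower is better.
// A node that has never reported metrics is ranked by its in-flight NQ alone.
func score(responseTime, serviceTime, totalNQ, executingNQ int64) float64 {
	if responseTime == 0 {
		return float64(executingNQ)
	}
	return float64(responseTime) - 1/float64(serviceTime) +
		math.Pow(float64(1+totalNQ+executingNQ), 3.0)/float64(serviceTime)
}

func main() {
	// ResponseTime=5, ServiceTime=1, TotalNQ=1, nothing in flight:
	// 5 - 1 + (1+1+0)^3 = 12, the value asserted in TestCalculateScore below.
	fmt.Println(score(5, 1, 1, 0))
}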
This commit is contained in:
parent a413842e38
commit 46f7d903a3

Makefile (2 changed lines)
@@ -362,6 +362,6 @@ generate-mockery: getdeps
	$(PWD)/bin/mockery --dir=internal/datacoord --name=compactionPlanContext --filename=mock_compaction_plan_context.go --output=internal/datacoord --structname=MockCompactionPlanContext --with-expecter --inpackage
	$(PWD)/bin/mockery --dir=internal/datacoord --name=Handler --filename=mock_handler.go --output=internal/datacoord --structname=NMockHandler --with-expecter --inpackage
	#internal/proxy
	$(PWD)/bin/mockery --name=LBPolicy --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_lb_policy.go --structname=MockLBPolicy --with-expecter --outpkg=proxy
	$(PWD)/bin/mockery --name=LBPolicy --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_lb_policy.go --structname=MockLBPolicy --with-expecter --outpkg=proxy --inpackage
	$(PWD)/bin/mockery --name=LBBalancer --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_lb_balancer.go --structname=MockLBBalancer --with-expecter --outpkg=proxy --inpackage
	$(PWD)/bin/mockery --name=shardClientMgr --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_shardclient_manager.go --structname=MockShardClientManager --with-expecter --outpkg=proxy --inpackage
@@ -16,6 +16,11 @@

package proxy

import "github.com/milvus-io/milvus/internal/proto/internalpb"

type LBBalancer interface {
	SelectNode(availableNodes []int64, nq int64) (int64, error)
	CancelWorkload(node int64, nq int64)
	UpdateCostMetrics(node int64, cost *internalpb.CostAggregation)
	Close()
}
@@ -18,6 +18,7 @@ package proxy

import (
	"context"

	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/internal/types"
	"github.com/milvus-io/milvus/pkg/log"
	"github.com/milvus-io/milvus/pkg/util/merr"
@@ -48,6 +49,8 @@ type CollectionWorkLoad struct {

type LBPolicy interface {
	Execute(ctx context.Context, workload CollectionWorkLoad) error
	ExecuteWithRetry(ctx context.Context, workload ChannelWorkload) error
	UpdateCostMetrics(node int64, cost *internalpb.CostAggregation)
	Close()
}

type LBPolicyImpl struct {
@@ -55,7 +58,8 @@ type LBPolicyImpl struct {
	clientMgr shardClientMgr
}

func NewLBPolicyImpl(balancer LBBalancer, clientMgr shardClientMgr) *LBPolicyImpl {
func NewLBPolicyImpl(clientMgr shardClientMgr) *LBPolicyImpl {
	balancer := NewLookAsideBalancer(clientMgr)
	return &LBPolicyImpl{
		balancer:  balancer,
		clientMgr: clientMgr,
@@ -135,6 +139,9 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
				zap.Int64("nodeID", targetNode),
				zap.Error(err))
			excludeNodes.Insert(targetNode)

			// cancel work load which assign to the target node
			lb.balancer.CancelWorkload(targetNode, workload.nq)
			return merr.WrapErrShardDelegatorAccessFailed(workload.channel, err.Error())
		}

@@ -144,8 +151,11 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
				zap.Int64("nodeID", targetNode),
				zap.Error(err))
			excludeNodes.Insert(targetNode)
			lb.balancer.CancelWorkload(targetNode, workload.nq)
			return merr.WrapErrShardDelegatorAccessFailed(workload.channel, err.Error())
		}

		lb.balancer.CancelWorkload(targetNode, workload.nq)
		return nil
	}, retry.Attempts(workload.retryTimes))

@@ -179,3 +189,11 @@ func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad
	err = wg.Wait()
	return err
}

func (lb *LBPolicyImpl) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) {
	lb.balancer.UpdateCostMetrics(node, cost)
}

func (lb *LBPolicyImpl) Close() {
	lb.balancer.Close()
}
@@ -24,6 +24,7 @@ import (
	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/internal/proto/querypb"
	"github.com/milvus-io/milvus/internal/types"
	"github.com/milvus-io/milvus/pkg/common"
@@ -91,11 +92,13 @@ func (s *LBPolicySuite) SetupTest() {

	s.qn = types.NewMockQueryNode(s.T())
	s.qn.EXPECT().GetAddress().Return("localhost").Maybe()
	s.qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe()

	s.mgr = NewMockShardClientManager(s.T())
	s.mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe()
	s.lbBalancer = NewMockLBBalancer(s.T())
	s.lbPolicy = NewLBPolicyImpl(s.lbBalancer, s.mgr)
	s.lbPolicy = NewLBPolicyImpl(s.mgr)
	s.lbPolicy.balancer = s.lbBalancer

	err := InitMetaCache(context.Background(), s.rc, s.qc, s.mgr)
	s.NoError(err)
@@ -223,6 +226,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
	s.lbBalancer.ExpectedCalls = nil
	s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
	s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(1, nil)
	s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
	err := s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
		collection: s.collection,
		channel:    s.channels[0],
@@ -255,6 +259,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
	s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1)
	s.lbBalancer.ExpectedCalls = nil
	s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(1, nil)
	s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
	err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
		collection: s.collection,
		channel:    s.channels[0],
@@ -270,6 +275,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
	s.mgr.ExpectedCalls = nil
	s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1)
	s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
	s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
	err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
		collection: s.collection,
		channel:    s.channels[0],
@@ -287,6 +293,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
	s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
	s.lbBalancer.ExpectedCalls = nil
	s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(1, nil)
	s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
	counter := 0
	err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
		collection: s.collection,
@@ -310,6 +317,7 @@ func (s *LBPolicySuite) TestExecute() {
	// test all channel success
	s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
	s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(1, nil)
	s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
	err := s.lbPolicy.Execute(ctx, CollectionWorkLoad{
		collection: s.collection,
		nq:         1,
@@ -348,6 +356,11 @@ func (s *LBPolicySuite) TestExecute() {
	s.Error(err)
}

func (s *LBPolicySuite) TestUpdateCostMetrics() {
	s.lbBalancer.EXPECT().UpdateCostMetrics(mock.Anything, mock.Anything)
	s.lbPolicy.UpdateCostMetrics(1, &internalpb.CostAggregation{})
}

func TestLBPolicySuite(t *testing.T) {
	suite.Run(t, new(LBPolicySuite))
}
internal/proxy/look_aside_balancer.go (new file, 185 lines)
@@ -0,0 +1,185 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package proxy

import (
	"context"
	"math"
	"sync"
	"time"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/pkg/log"
	"github.com/milvus-io/milvus/pkg/util/merr"
	"github.com/milvus-io/milvus/pkg/util/typeutil"
	"go.uber.org/atomic"
	"go.uber.org/zap"
)

var (
	checkQueryNodeHealthInterval = 500 * time.Millisecond
)

type LookAsideBalancer struct {
	clientMgr shardClientMgr

	// query node -> workload latest metrics
	metricsMap *typeutil.ConcurrentMap[int64, *internalpb.CostAggregation]

	// query node -> last update metrics ts
	metricsUpdateTs *typeutil.ConcurrentMap[int64, int64]

	// query node -> total nq of requests which already send but response hasn't received
	executingTaskTotalNQ *typeutil.ConcurrentMap[int64, *atomic.Int64]

	unreachableQueryNodes *typeutil.ConcurrentSet[int64]

	closeCh   chan struct{}
	closeOnce sync.Once
	wg        sync.WaitGroup
}

func NewLookAsideBalancer(clientMgr shardClientMgr) *LookAsideBalancer {
	balancer := &LookAsideBalancer{
		clientMgr:             clientMgr,
		metricsMap:            typeutil.NewConcurrentMap[int64, *internalpb.CostAggregation](),
		metricsUpdateTs:       typeutil.NewConcurrentMap[int64, int64](),
		executingTaskTotalNQ:  typeutil.NewConcurrentMap[int64, *atomic.Int64](),
		unreachableQueryNodes: typeutil.NewConcurrentSet[int64](),
		closeCh:               make(chan struct{}),
	}

	balancer.wg.Add(1)
	go balancer.checkQueryNodeHealthLoop()
	return balancer
}

func (b *LookAsideBalancer) Close() {
	b.closeOnce.Do(func() {
		close(b.closeCh)
		b.wg.Wait()
	})
}

func (b *LookAsideBalancer) SelectNode(availableNodes []int64, cost int64) (int64, error) {
	targetNode := int64(-1)
	targetScore := float64(math.MaxFloat64)
	for _, node := range availableNodes {
		if b.unreachableQueryNodes.Contain(node) {
			continue
		}

		cost, _ := b.metricsMap.Get(node)
		executingNQ, ok := b.executingTaskTotalNQ.Get(node)
		if !ok {
			executingNQ = atomic.NewInt64(0)
			b.executingTaskTotalNQ.Insert(node, executingNQ)
		}

		score := b.calculateScore(cost, executingNQ.Load())
		if targetNode == -1 || score < targetScore {
			targetScore = score
			targetNode = node
		}
	}

	// update executing task cost
	totalNQ, ok := b.executingTaskTotalNQ.Get(targetNode)
	if !ok {
		totalNQ = atomic.NewInt64(0)
	}
	totalNQ.Add(cost)

	return targetNode, nil
}

// when task canceled, should reduce executing total nq cost
func (b *LookAsideBalancer) CancelWorkload(node int64, nq int64) {
	totalNQ, ok := b.executingTaskTotalNQ.Get(node)
	if ok {
		totalNQ.Sub(nq)
	}
}

// UpdateCostMetrics used for cache some metrics of recent search/query cost
func (b *LookAsideBalancer) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) {
	// cache the latest query node cost metrics for updating the score
	b.metricsMap.Insert(node, cost)
	b.metricsUpdateTs.Insert(node, time.Now().UnixMilli())
}

// calculateScore compute the query node's workload score
// https://www.usenix.org/conference/nsdi15/technical-sessions/presentation/suresh
func (b *LookAsideBalancer) calculateScore(cost *internalpb.CostAggregation, executingNQ int64) float64 {
	if cost == nil || cost.ResponseTime == 0 {
		return float64(executingNQ)
	}
	return float64(cost.ResponseTime) - float64(1)/float64(cost.ServiceTime) + math.Pow(float64(1+cost.TotalNQ+executingNQ), 3.0)/float64(cost.ServiceTime)
}

func (b *LookAsideBalancer) checkQueryNodeHealthLoop() {
	defer b.wg.Done()

	ticker := time.NewTicker(checkQueryNodeHealthInterval)
	defer ticker.Stop()
	log.Info("Start check query node health loop")
	for {
		select {
		case <-b.closeCh:
			log.Info("check query node health loop exit")
			return

		case <-ticker.C:
			now := time.Now().UnixMilli()
			b.metricsUpdateTs.Range(func(node int64, lastUpdateTs int64) bool {
				if now-lastUpdateTs > checkQueryNodeHealthInterval.Milliseconds() {
					ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
					defer cancel()

					checkHealthFailed := func(err error) bool {
						log.Warn("query node check health failed, add it to unreachable nodes list",
							zap.Int64("nodeID", node),
							zap.Error(err))
						b.unreachableQueryNodes.Insert(node)
						return true
					}

					qn, err := b.clientMgr.GetClient(ctx, node)
					if err != nil {
						return checkHealthFailed(err)
					}

					resp, err := qn.GetComponentStates(ctx)
					if err != nil {
						return checkHealthFailed(err)
					}

					if resp.GetState().GetStateCode() != commonpb.StateCode_Healthy {
						return checkHealthFailed(merr.WrapErrNodeOffline(node))
					}

					// check health successfully, update check health ts
					b.metricsUpdateTs.Insert(node, time.Now().Local().UnixMilli())
					b.unreachableQueryNodes.Remove(node)
				}

				return true
			})
		}
	}
}
internal/proxy/look_aside_balancer_test.go (new file, 292 lines)
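
Before the test file's contents, a minimal usage sketch of the balancer introduced above, inside the proxy package; the node IDs, NQ values, and surrounding control flow are illustrative only, and clientMgr stands for whatever shardClientMgr the caller already holds:

balancer := NewLookAsideBalancer(clientMgr)
defer balancer.Close()

// Query nodes report CostAggregation with each response; cache it for scoring.
balancer.UpdateCostMetrics(1, &internalpb.CostAggregation{ResponseTime: 5, ServiceTime: 1, TotalNQ: 1})

// Pick the cheapest reachable replica for a request with nq=10;
// SelectNode also reserves those 10 NQ against the chosen node.
node, err := balancer.SelectNode([]int64{1, 2, 3}, 10)
if err == nil {
	// ... send the request to node; afterwards release the reserved NQ.
	balancer.CancelWorkload(node, 10)
}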
@@ -0,0 +1,292 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package proxy

import (
	"testing"
	"time"

	"github.com/cockroachdb/errors"
	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/internal/types"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/suite"
	"go.uber.org/atomic"
)

type LookAsideBalancerSuite struct {
	suite.Suite

	clientMgr *MockShardClientManager
	balancer  *LookAsideBalancer
}

func (suite *LookAsideBalancerSuite) SetupTest() {
	suite.clientMgr = NewMockShardClientManager(suite.T())
	suite.balancer = NewLookAsideBalancer(suite.clientMgr)

	qn := types.NewMockQueryNode(suite.T())
	suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(1)).Return(qn, nil).Maybe()
	qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, errors.New("fake error")).Maybe()
}

func (suite *LookAsideBalancerSuite) TearDownTest() {
	suite.balancer.Close()
}

func (suite *LookAsideBalancerSuite) TestUpdateMetrics() {
	costMetrics := &internalpb.CostAggregation{
		ResponseTime: 5,
		ServiceTime:  1,
		TotalNQ:      1,
	}

	suite.balancer.UpdateCostMetrics(1, costMetrics)

	lastUpdateTs, ok := suite.balancer.metricsUpdateTs.Get(1)
	suite.True(ok)
	suite.True(time.Now().UnixMilli()-lastUpdateTs <= 5)
}

func (suite *LookAsideBalancerSuite) TestCalculateScore() {
	costMetrics1 := &internalpb.CostAggregation{
		ResponseTime: 5,
		ServiceTime:  1,
		TotalNQ:      1,
	}

	costMetrics2 := &internalpb.CostAggregation{
		ResponseTime: 5,
		ServiceTime:  2,
		TotalNQ:      1,
	}

	costMetrics3 := &internalpb.CostAggregation{
		ResponseTime: 10,
		ServiceTime:  1,
		TotalNQ:      1,
	}

	costMetrics4 := &internalpb.CostAggregation{
		ResponseTime: 5,
		ServiceTime:  1,
		TotalNQ:      0,
	}

	score1 := suite.balancer.calculateScore(costMetrics1, 0)
	score2 := suite.balancer.calculateScore(costMetrics2, 0)
	score3 := suite.balancer.calculateScore(costMetrics3, 0)
	score4 := suite.balancer.calculateScore(costMetrics4, 0)
	suite.Equal(float64(12), score1)
	suite.Equal(float64(8.5), score2)
	suite.Equal(float64(17), score3)
	suite.Equal(float64(5), score4)

	score5 := suite.balancer.calculateScore(costMetrics1, 5)
	score6 := suite.balancer.calculateScore(costMetrics2, 5)
	score7 := suite.balancer.calculateScore(costMetrics3, 5)
	score8 := suite.balancer.calculateScore(costMetrics4, 5)
	suite.Equal(float64(347), score5)
	suite.Equal(float64(176), score6)
	suite.Equal(float64(352), score7)
	suite.Equal(float64(220), score8)
}

func (suite *LookAsideBalancerSuite) TestSelectNode() {
	type testcase struct {
		name         string
		costMetrics  map[int64]*internalpb.CostAggregation
		executingNQ  map[int64]int64
		requestCount int
		result       map[int64]int64
	}

	cases := []testcase{
		{
			name: "each qn has same cost metrics",
			costMetrics: map[int64]*internalpb.CostAggregation{
				1: {
					ResponseTime: 5,
					ServiceTime:  1,
					TotalNQ:      0,
				},
				2: {
					ResponseTime: 5,
					ServiceTime:  1,
					TotalNQ:      0,
				},

				3: {
					ResponseTime: 5,
					ServiceTime:  1,
					TotalNQ:      0,
				},
			},

			executingNQ:  map[int64]int64{1: 0, 2: 0, 3: 0},
			requestCount: 100,
			result:       map[int64]int64{1: 34, 2: 33, 3: 33},
		},
		{
			name: "each qn has different service time",
			costMetrics: map[int64]*internalpb.CostAggregation{
				1: {
					ResponseTime: 30,
					ServiceTime:  20,
					TotalNQ:      0,
				},
				2: {
					ResponseTime: 50,
					ServiceTime:  40,
					TotalNQ:      0,
				},

				3: {
					ResponseTime: 70,
					ServiceTime:  60,
					TotalNQ:      0,
				},
			},

			executingNQ:  map[int64]int64{1: 0, 2: 0, 3: 0},
			requestCount: 100,
			result:       map[int64]int64{1: 27, 2: 34, 3: 39},
		},
		{
			name: "one qn has task in queue",
			costMetrics: map[int64]*internalpb.CostAggregation{
				1: {
					ResponseTime: 5,
					ServiceTime:  1,
					TotalNQ:      0,
				},
				2: {
					ResponseTime: 5,
					ServiceTime:  1,
					TotalNQ:      0,
				},

				3: {
					ResponseTime: 100,
					ServiceTime:  1,
					TotalNQ:      20,
				},
			},

			executingNQ:  map[int64]int64{1: 0, 2: 0, 3: 0},
			requestCount: 100,
			result:       map[int64]int64{1: 40, 2: 40, 3: 20},
		},

		{
			name: "qn with executing task",
			costMetrics: map[int64]*internalpb.CostAggregation{
				1: {
					ResponseTime: 5,
					ServiceTime:  1,
					TotalNQ:      0,
				},
				2: {
					ResponseTime: 5,
					ServiceTime:  1,
					TotalNQ:      0,
				},

				3: {
					ResponseTime: 5,
					ServiceTime:  1,
					TotalNQ:      0,
				},
			},

			executingNQ:  map[int64]int64{1: 0, 2: 0, 3: 20},
			requestCount: 100,
			result:       map[int64]int64{1: 40, 2: 40, 3: 20},
		},
		{
			name: "qn with empty metrics",
			costMetrics: map[int64]*internalpb.CostAggregation{
				1: {},
				2: {},
				3: {},
			},

			executingNQ:  map[int64]int64{1: 0, 2: 0, 3: 0},
			requestCount: 100,
			result:       map[int64]int64{1: 34, 2: 33, 3: 33},
		},
	}

	for _, c := range cases {
		suite.Run(c.name, func() {
			for node, cost := range c.costMetrics {
				suite.balancer.UpdateCostMetrics(node, cost)
			}

			for node, executingNQ := range c.executingNQ {
				suite.balancer.executingTaskTotalNQ.Insert(node, atomic.NewInt64(executingNQ))
			}

			counter := make(map[int64]int64)
			for i := 0; i < c.requestCount; i++ {
				node, err := suite.balancer.SelectNode([]int64{1, 2, 3}, 1)
				suite.NoError(err)
				counter[node]++
			}

			for node, result := range c.result {
				suite.Equal(result, counter[node])
			}
		})
	}
}

func (suite *LookAsideBalancerSuite) TestCancelWorkload() {
	node, err := suite.balancer.SelectNode([]int64{1, 2, 3}, 10)
	suite.NoError(err)
	suite.balancer.CancelWorkload(node, 10)

	executingNQ, ok := suite.balancer.executingTaskTotalNQ.Get(node)
	suite.True(ok)
	suite.Equal(int64(0), executingNQ.Load())
}

func (suite *LookAsideBalancerSuite) TestCheckHealthLoop() {
	qn2 := types.NewMockQueryNode(suite.T())
	suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(2)).Return(qn2, nil)
	qn2.EXPECT().GetComponentStates(mock.Anything).Return(&milvuspb.ComponentStates{
		State: &milvuspb.ComponentInfo{
			StateCode: commonpb.StateCode_Healthy,
		},
	}, nil)

	suite.balancer.metricsUpdateTs.Insert(1, time.Now().UnixMilli())
	suite.balancer.metricsUpdateTs.Insert(2, time.Now().UnixMilli())
	suite.Eventually(func() bool {
		return suite.balancer.unreachableQueryNodes.Contain(1)
	}, 2*time.Second, 100*time.Millisecond)

	suite.Eventually(func() bool {
		return !suite.balancer.unreachableQueryNodes.Contain(2)
	}, 3*time.Second, 100*time.Millisecond)
}

func TestLookAsideBalancerSuite(t *testing.T) {
	suite.Run(t, new(LookAsideBalancerSuite))
}
@@ -2,7 +2,10 @@

package proxy

import mock "github.com/stretchr/testify/mock"
import (
	internalpb "github.com/milvus-io/milvus/internal/proto/internalpb"
	mock "github.com/stretchr/testify/mock"
)

// MockLBBalancer is an autogenerated mock type for the LBBalancer type
type MockLBBalancer struct {
@@ -17,6 +20,72 @@ func (_m *MockLBBalancer) EXPECT() *MockLBBalancer_Expecter {
	return &MockLBBalancer_Expecter{mock: &_m.Mock}
}

// CancelWorkload provides a mock function with given fields: node, nq
func (_m *MockLBBalancer) CancelWorkload(node int64, nq int64) {
	_m.Called(node, nq)
}

// MockLBBalancer_CancelWorkload_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CancelWorkload'
type MockLBBalancer_CancelWorkload_Call struct {
	*mock.Call
}

// CancelWorkload is a helper method to define mock.On call
//  - node int64
//  - nq int64
func (_e *MockLBBalancer_Expecter) CancelWorkload(node interface{}, nq interface{}) *MockLBBalancer_CancelWorkload_Call {
	return &MockLBBalancer_CancelWorkload_Call{Call: _e.mock.On("CancelWorkload", node, nq)}
}

func (_c *MockLBBalancer_CancelWorkload_Call) Run(run func(node int64, nq int64)) *MockLBBalancer_CancelWorkload_Call {
	_c.Call.Run(func(args mock.Arguments) {
		run(args[0].(int64), args[1].(int64))
	})
	return _c
}

func (_c *MockLBBalancer_CancelWorkload_Call) Return() *MockLBBalancer_CancelWorkload_Call {
	_c.Call.Return()
	return _c
}

func (_c *MockLBBalancer_CancelWorkload_Call) RunAndReturn(run func(int64, int64)) *MockLBBalancer_CancelWorkload_Call {
	_c.Call.Return(run)
	return _c
}

// Close provides a mock function with given fields:
func (_m *MockLBBalancer) Close() {
	_m.Called()
}

// MockLBBalancer_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close'
type MockLBBalancer_Close_Call struct {
	*mock.Call
}

// Close is a helper method to define mock.On call
func (_e *MockLBBalancer_Expecter) Close() *MockLBBalancer_Close_Call {
	return &MockLBBalancer_Close_Call{Call: _e.mock.On("Close")}
}

func (_c *MockLBBalancer_Close_Call) Run(run func()) *MockLBBalancer_Close_Call {
	_c.Call.Run(func(args mock.Arguments) {
		run()
	})
	return _c
}

func (_c *MockLBBalancer_Close_Call) Return() *MockLBBalancer_Close_Call {
	_c.Call.Return()
	return _c
}

func (_c *MockLBBalancer_Close_Call) RunAndReturn(run func()) *MockLBBalancer_Close_Call {
	_c.Call.Return(run)
	return _c
}

// SelectNode provides a mock function with given fields: availableNodes, nq
func (_m *MockLBBalancer) SelectNode(availableNodes []int64, nq int64) (int64, error) {
	ret := _m.Called(availableNodes, nq)
@@ -70,6 +139,40 @@ func (_c *MockLBBalancer_SelectNode_Call) RunAndReturn(run func([]int64, int64)
	return _c
}

// UpdateCostMetrics provides a mock function with given fields: node, cost
func (_m *MockLBBalancer) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) {
	_m.Called(node, cost)
}

// MockLBBalancer_UpdateCostMetrics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCostMetrics'
type MockLBBalancer_UpdateCostMetrics_Call struct {
	*mock.Call
}

// UpdateCostMetrics is a helper method to define mock.On call
//  - node int64
//  - cost *internalpb.CostAggregation
func (_e *MockLBBalancer_Expecter) UpdateCostMetrics(node interface{}, cost interface{}) *MockLBBalancer_UpdateCostMetrics_Call {
	return &MockLBBalancer_UpdateCostMetrics_Call{Call: _e.mock.On("UpdateCostMetrics", node, cost)}
}

func (_c *MockLBBalancer_UpdateCostMetrics_Call) Run(run func(node int64, cost *internalpb.CostAggregation)) *MockLBBalancer_UpdateCostMetrics_Call {
	_c.Call.Run(func(args mock.Arguments) {
		run(args[0].(int64), args[1].(*internalpb.CostAggregation))
	})
	return _c
}

func (_c *MockLBBalancer_UpdateCostMetrics_Call) Return() *MockLBBalancer_UpdateCostMetrics_Call {
	_c.Call.Return()
	return _c
}

func (_c *MockLBBalancer_UpdateCostMetrics_Call) RunAndReturn(run func(int64, *internalpb.CostAggregation)) *MockLBBalancer_UpdateCostMetrics_Call {
	_c.Call.Return(run)
	return _c
}

type mockConstructorTestingTNewMockLBBalancer interface {
	mock.TestingT
	Cleanup(func())
@@ -5,6 +5,7 @@ package proxy

import (
	context "context"

	internalpb "github.com/milvus-io/milvus/internal/proto/internalpb"
	mock "github.com/stretchr/testify/mock"
)

@@ -64,13 +65,13 @@ func (_c *MockLBPolicy_Execute_Call) RunAndReturn(run func(context.Context, Coll
	return _c
}

// ExecuteWithRetry provides a mock function with given fields: ctx, workload, retryTimes
func (_m *MockLBPolicy) ExecuteWithRetry(ctx context.Context, workload ChannelWorkload, retryTimes uint) error {
	ret := _m.Called(ctx, workload, retryTimes)
// ExecuteWithRetry provides a mock function with given fields: ctx, workload
func (_m *MockLBPolicy) ExecuteWithRetry(ctx context.Context, workload ChannelWorkload) error {
	ret := _m.Called(ctx, workload)

	var r0 error
	if rf, ok := ret.Get(0).(func(context.Context, ChannelWorkload, uint) error); ok {
		r0 = rf(ctx, workload, retryTimes)
	if rf, ok := ret.Get(0).(func(context.Context, ChannelWorkload) error); ok {
		r0 = rf(ctx, workload)
	} else {
		r0 = ret.Error(0)
	}
@@ -86,14 +87,13 @@ type MockLBPolicy_ExecuteWithRetry_Call struct {
// ExecuteWithRetry is a helper method to define mock.On call
//  - ctx context.Context
//  - workload ChannelWorkload
//  - retryTimes uint
func (_e *MockLBPolicy_Expecter) ExecuteWithRetry(ctx interface{}, workload interface{}, retryTimes interface{}) *MockLBPolicy_ExecuteWithRetry_Call {
	return &MockLBPolicy_ExecuteWithRetry_Call{Call: _e.mock.On("ExecuteWithRetry", ctx, workload, retryTimes)}
func (_e *MockLBPolicy_Expecter) ExecuteWithRetry(ctx interface{}, workload interface{}) *MockLBPolicy_ExecuteWithRetry_Call {
	return &MockLBPolicy_ExecuteWithRetry_Call{Call: _e.mock.On("ExecuteWithRetry", ctx, workload)}
}

func (_c *MockLBPolicy_ExecuteWithRetry_Call) Run(run func(ctx context.Context, workload ChannelWorkload, retryTimes uint)) *MockLBPolicy_ExecuteWithRetry_Call {
func (_c *MockLBPolicy_ExecuteWithRetry_Call) Run(run func(ctx context.Context, workload ChannelWorkload)) *MockLBPolicy_ExecuteWithRetry_Call {
	_c.Call.Run(func(args mock.Arguments) {
		run(args[0].(context.Context), args[1].(ChannelWorkload), args[2].(uint))
		run(args[0].(context.Context), args[1].(ChannelWorkload))
	})
	return _c
}
@@ -103,7 +103,41 @@ func (_c *MockLBPolicy_ExecuteWithRetry_Call) Return(_a0 error) *MockLBPolicy_Ex
	return _c
}

func (_c *MockLBPolicy_ExecuteWithRetry_Call) RunAndReturn(run func(context.Context, ChannelWorkload, uint) error) *MockLBPolicy_ExecuteWithRetry_Call {
func (_c *MockLBPolicy_ExecuteWithRetry_Call) RunAndReturn(run func(context.Context, ChannelWorkload) error) *MockLBPolicy_ExecuteWithRetry_Call {
	_c.Call.Return(run)
	return _c
}

// UpdateCostMetrics provides a mock function with given fields: node, cost
func (_m *MockLBPolicy) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) {
	_m.Called(node, cost)
}

// MockLBPolicy_UpdateCostMetrics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCostMetrics'
type MockLBPolicy_UpdateCostMetrics_Call struct {
	*mock.Call
}

// UpdateCostMetrics is a helper method to define mock.On call
//  - node int64
//  - cost *internalpb.CostAggregation
func (_e *MockLBPolicy_Expecter) UpdateCostMetrics(node interface{}, cost interface{}) *MockLBPolicy_UpdateCostMetrics_Call {
	return &MockLBPolicy_UpdateCostMetrics_Call{Call: _e.mock.On("UpdateCostMetrics", node, cost)}
}

func (_c *MockLBPolicy_UpdateCostMetrics_Call) Run(run func(node int64, cost *internalpb.CostAggregation)) *MockLBPolicy_UpdateCostMetrics_Call {
	_c.Call.Run(func(args mock.Arguments) {
		run(args[0].(int64), args[1].(*internalpb.CostAggregation))
	})
	return _c
}

func (_c *MockLBPolicy_UpdateCostMetrics_Call) Return() *MockLBPolicy_UpdateCostMetrics_Call {
	_c.Call.Return()
	return _c
}

func (_c *MockLBPolicy_UpdateCostMetrics_Call) RunAndReturn(run func(int64, *internalpb.CostAggregation)) *MockLBPolicy_UpdateCostMetrics_Call {
	_c.Call.Return(run)
	return _c
}
@@ -127,7 +127,7 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
		searchResultCh:   make(chan *internalpb.SearchResults, n),
		shardMgr:         mgr,
		multiRateLimiter: NewMultiRateLimiter(),
		lbPolicy:         NewLBPolicyImpl(NewRoundRobinBalancer(), mgr),
		lbPolicy:         NewLBPolicyImpl(mgr),
	}
	node.UpdateStateCode(commonpb.StateCode_Abnormal)
	logutil.Logger(ctx).Debug("create a new Proxy instance", zap.Any("state", node.stateCode.Load()))
@@ -437,6 +437,10 @@ func (node *Proxy) Stop() error {
		node.chMgr.removeAllDMLStream()
	}

	if node.lbPolicy != nil {
		node.lbPolicy.Close()
	}

	// https://github.com/milvus-io/milvus/issues/12282
	node.UpdateStateCode(commonpb.StateCode_Abnormal)
@@ -16,38 +16,56 @@
package proxy

import (
	"sync"

	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/pkg/util/merr"
	"github.com/milvus-io/milvus/pkg/util/typeutil"
	"go.uber.org/atomic"
)

type RoundRobinBalancer struct {
	// request num send to each node
	mutex        sync.RWMutex
	nodeWorkload map[int64]int64
	nodeWorkload *typeutil.ConcurrentMap[int64, *atomic.Int64]
}

func NewRoundRobinBalancer() *RoundRobinBalancer {
	return &RoundRobinBalancer{
		nodeWorkload: make(map[int64]int64),
		nodeWorkload: typeutil.NewConcurrentMap[int64, *atomic.Int64](),
	}
}

func (b *RoundRobinBalancer) SelectNode(availableNodes []int64, workload int64) (int64, error) {
func (b *RoundRobinBalancer) SelectNode(availableNodes []int64, cost int64) (int64, error) {
	if len(availableNodes) == 0 {
		return -1, merr.ErrNoAvailableNode
	}
	b.mutex.Lock()
	defer b.mutex.Unlock()

	targetNode := int64(-1)
	targetNodeWorkload := int64(-1)
	var targetNodeWorkload *atomic.Int64
	for _, node := range availableNodes {
		if targetNodeWorkload == -1 || b.nodeWorkload[node] < targetNodeWorkload {
		workload, ok := b.nodeWorkload.Get(node)

		if !ok {
			workload = atomic.NewInt64(0)
			b.nodeWorkload.Insert(node, workload)
		}

		if targetNodeWorkload == nil || workload.Load() < targetNodeWorkload.Load() {
			targetNode = node
			targetNodeWorkload = b.nodeWorkload[node]
			targetNodeWorkload = workload
		}
	}

	b.nodeWorkload[targetNode] += workload
	targetNodeWorkload.Add(cost)
	return targetNode, nil
}

func (b *RoundRobinBalancer) CancelWorkload(node int64, nq int64) {
	load, ok := b.nodeWorkload.Get(node)

	if ok {
		load.Sub(nq)
	}
}

func (b *RoundRobinBalancer) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) {}

func (b *RoundRobinBalancer) Close() {}
|
||||
s.balancer.SelectNode(availableNodes, 1)
|
||||
s.balancer.SelectNode(availableNodes, 1)
|
||||
|
||||
s.Equal(int64(2), s.balancer.nodeWorkload[1])
|
||||
s.Equal(int64(2), s.balancer.nodeWorkload[2])
|
||||
workload, ok := s.balancer.nodeWorkload.Get(1)
|
||||
s.True(ok)
|
||||
s.Equal(int64(2), workload.Load())
|
||||
workload, ok = s.balancer.nodeWorkload.Get(1)
|
||||
s.True(ok)
|
||||
s.Equal(int64(2), workload.Load())
|
||||
|
||||
s.balancer.SelectNode(availableNodes, 3)
|
||||
s.balancer.SelectNode(availableNodes, 1)
|
||||
s.balancer.SelectNode(availableNodes, 1)
|
||||
s.balancer.SelectNode(availableNodes, 1)
|
||||
|
||||
s.Equal(int64(5), s.balancer.nodeWorkload[1])
|
||||
s.Equal(int64(5), s.balancer.nodeWorkload[2])
|
||||
workload, ok = s.balancer.nodeWorkload.Get(1)
|
||||
s.True(ok)
|
||||
s.Equal(int64(5), workload.Load())
|
||||
workload, ok = s.balancer.nodeWorkload.Get(1)
|
||||
s.True(ok)
|
||||
s.Equal(int64(5), workload.Load())
|
||||
}
|
||||
|
||||
func (s *RoundRobinBalancerSuite) TestNoAvailableNode() {
|
||||
@ -56,6 +64,17 @@ func (s *RoundRobinBalancerSuite) TestNoAvailableNode() {
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
func (s *RoundRobinBalancerSuite) TestCancelWorkload() {
|
||||
availableNodes := []int64{101}
|
||||
_, err := s.balancer.SelectNode(availableNodes, 5)
|
||||
s.NoError(err)
|
||||
workload, ok := s.balancer.nodeWorkload.Get(101)
|
||||
s.True(ok)
|
||||
s.Equal(int64(5), workload.Load())
|
||||
s.balancer.CancelWorkload(101, 5)
|
||||
s.Equal(int64(0), workload.Load())
|
||||
}
|
||||
|
||||
func TestRoundRobinBalancerSuite(t *testing.T) {
|
||||
suite.Run(t, new(RoundRobinBalancerSuite))
|
||||
}
|
||||
|
@@ -491,6 +491,7 @@ func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.Query

	log.Debug("get query result")
	t.resultBuf.Insert(result)
	t.lb.UpdateCostMetrics(nodeID, result.CostAggregation)
	return nil
}

@@ -58,6 +58,8 @@ func TestQueryTask_all(t *testing.T) {
		hitNum = 10
	)

	qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe()

	successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
	qc.EXPECT().Start().Return(nil)
	qc.EXPECT().Stop().Return(nil)
@@ -73,12 +75,10 @@ func TestQueryTask_all(t *testing.T) {
		},
	}, nil).Maybe()

	mockCreator := func(ctx context.Context, address string) (types.QueryNode, error) {
		return qn, nil
	}

	mgr := newShardClientMgr(withShardClientCreator(mockCreator))
	lb := NewLBPolicyImpl(NewRoundRobinBalancer(), mgr)
	mgr := NewMockShardClientManager(t)
	mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
	mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe()
	lb := NewLBPolicyImpl(mgr)

	rc.Start()
	defer rc.Stop()
@@ -217,10 +217,12 @@ func TestQueryTask_all(t *testing.T) {
	task.RetrieveRequest.OutputFieldsId = append(task.RetrieveRequest.OutputFieldsId, common.TimeStampField)
	task.ctx = ctx
	qn.ExpectedCalls = nil
	qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe()
	qn.EXPECT().Query(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))
	assert.Error(t, task.Execute(ctx))

	qn.ExpectedCalls = nil
	qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe()
	qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{
		Status: &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_NotShardLeader,
@@ -230,6 +232,7 @@ func TestQueryTask_all(t *testing.T) {
	assert.True(t, strings.Contains(err.Error(), errInvalidShardLeaders.Error()))

	qn.ExpectedCalls = nil
	qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe()
	qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{
		Status: &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_UnexpectedError,
@@ -238,6 +241,7 @@ func TestQueryTask_all(t *testing.T) {
	assert.Error(t, task.Execute(ctx))

	qn.ExpectedCalls = nil
	qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe()
	qn.EXPECT().Query(mock.Anything, mock.Anything).Return(result1, nil)
	assert.NoError(t, task.Execute(ctx))

@@ -521,6 +521,7 @@ func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.Que
		return fmt.Errorf("fail to Search, QueryNode ID=%d, reason=%s", nodeID, result.GetStatus().GetReason())
	}
	t.resultBuf.Insert(result)
	t.lb.UpdateCostMetrics(nodeID, result.CostAggregation)

	return nil
}
@@ -1543,12 +1543,12 @@ func TestSearchTask_ErrExecute(t *testing.T) {
		collectionName = t.Name() + funcutil.GenRandomStr()
	)

	mockCreator := func(ctx context.Context, address string) (types.QueryNode, error) {
		return qn, nil
	}
	qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe()

	mgr := newShardClientMgr(withShardClientCreator(mockCreator))
	lb := NewLBPolicyImpl(NewRoundRobinBalancer(), mgr)
	mgr := NewMockShardClientManager(t)
	mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
	mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe()
	lb := NewLBPolicyImpl(mgr)

	rc.Start()
	defer rc.Stop()
@@ -1661,6 +1661,7 @@ func TestSearchTask_ErrExecute(t *testing.T) {
	assert.Error(t, task.Execute(ctx))

	qn.ExpectedCalls = nil
	qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe()
	qn.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{
		Status: &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_NotShardLeader,
@@ -1670,6 +1671,7 @@ func TestSearchTask_ErrExecute(t *testing.T) {
	assert.True(t, strings.Contains(err.Error(), errInvalidShardLeaders.Error()))

	qn.ExpectedCalls = nil
	qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe()
	qn.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{
		Status: &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_UnexpectedError,
@@ -1678,6 +1680,7 @@ func TestSearchTask_ErrExecute(t *testing.T) {
	assert.Error(t, task.Execute(ctx))

	qn.ExpectedCalls = nil
	qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe()
	qn.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{
		Status: &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_Success,
@@ -77,11 +77,11 @@ func (s *StatisticTaskSuite) SetupTest() {
	s.rc.Start()
	s.qn = types.NewMockQueryNode(s.T())

	mockCreator := func(ctx context.Context, addr string) (types.QueryNode, error) {
		return s.qn, nil
	}
	mgr := newShardClientMgr(withShardClientCreator(mockCreator))
	s.lb = NewLBPolicyImpl(NewRoundRobinBalancer(), mgr)
	s.qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe()
	mgr := NewMockShardClientManager(s.T())
	mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil).Maybe()
	mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe()
	s.lb = NewLBPolicyImpl(mgr)

	err := InitMetaCache(context.Background(), s.rc, s.qc, mgr)
	s.NoError(err)
@@ -1017,6 +1017,12 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
		collector.Rate.Add(metricsinfo.NQPerSecond, 1)
		metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.QueryLabel).Add(float64(proto.Size(req)))
	}

	if ret.CostAggregation != nil {
		// update channel's response time
		currentTotalNQ := node.scheduler.GetWaitingTaskTotalNQ()
		ret.CostAggregation.TotalNQ = currentTotalNQ
	}
	return ret, nil
}