diff --git a/Makefile b/Makefile index 43b22c65b9..890a86d884 100644 --- a/Makefile +++ b/Makefile @@ -362,6 +362,6 @@ generate-mockery: getdeps $(PWD)/bin/mockery --dir=internal/datacoord --name=compactionPlanContext --filename=mock_compaction_plan_context.go --output=internal/datacoord --structname=MockCompactionPlanContext --with-expecter --inpackage $(PWD)/bin/mockery --dir=internal/datacoord --name=Handler --filename=mock_handler.go --output=internal/datacoord --structname=NMockHandler --with-expecter --inpackage #internal/proxy - $(PWD)/bin/mockery --name=LBPolicy --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_lb_policy.go --structname=MockLBPolicy --with-expecter --outpkg=proxy + $(PWD)/bin/mockery --name=LBPolicy --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_lb_policy.go --structname=MockLBPolicy --with-expecter --outpkg=proxy --inpackage $(PWD)/bin/mockery --name=LBBalancer --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_lb_balancer.go --structname=MockLBBalancer --with-expecter --outpkg=proxy --inpackage $(PWD)/bin/mockery --name=shardClientMgr --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_shardclient_manager.go --structname=MockShardClientManager --with-expecter --outpkg=proxy --inpackage \ No newline at end of file diff --git a/internal/proxy/lb_balancer.go b/internal/proxy/lb_balancer.go index 15b397d187..8747df9bb9 100644 --- a/internal/proxy/lb_balancer.go +++ b/internal/proxy/lb_balancer.go @@ -16,6 +16,11 @@ package proxy +import "github.com/milvus-io/milvus/internal/proto/internalpb" + type LBBalancer interface { SelectNode(availableNodes []int64, nq int64) (int64, error) + CancelWorkload(node int64, nq int64) + UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) + Close() } diff --git a/internal/proxy/lb_policy.go b/internal/proxy/lb_policy.go index e2d092c199..d5e9fa2b79 100644 --- a/internal/proxy/lb_policy.go +++ b/internal/proxy/lb_policy.go @@ -18,6 +18,7 @@ package proxy import ( "context" + "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/merr" @@ -48,6 +49,8 @@ type CollectionWorkLoad struct { type LBPolicy interface { Execute(ctx context.Context, workload CollectionWorkLoad) error ExecuteWithRetry(ctx context.Context, workload ChannelWorkload) error + UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) + Close() } type LBPolicyImpl struct { @@ -55,7 +58,8 @@ type LBPolicyImpl struct { clientMgr shardClientMgr } -func NewLBPolicyImpl(balancer LBBalancer, clientMgr shardClientMgr) *LBPolicyImpl { +func NewLBPolicyImpl(clientMgr shardClientMgr) *LBPolicyImpl { + balancer := NewLookAsideBalancer(clientMgr) return &LBPolicyImpl{ balancer: balancer, clientMgr: clientMgr, @@ -135,6 +139,9 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo zap.Int64("nodeID", targetNode), zap.Error(err)) excludeNodes.Insert(targetNode) + + // cancel work load which assign to the target node + lb.balancer.CancelWorkload(targetNode, workload.nq) return merr.WrapErrShardDelegatorAccessFailed(workload.channel, err.Error()) } @@ -144,8 +151,11 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo zap.Int64("nodeID", targetNode), zap.Error(err)) excludeNodes.Insert(targetNode) + lb.balancer.CancelWorkload(targetNode, workload.nq) return 
merr.WrapErrShardDelegatorAccessFailed(workload.channel, err.Error()) } + + lb.balancer.CancelWorkload(targetNode, workload.nq) return nil }, retry.Attempts(workload.retryTimes)) @@ -179,3 +189,11 @@ func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad err = wg.Wait() return err } + +func (lb *LBPolicyImpl) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) { + lb.balancer.UpdateCostMetrics(node, cost) +} + +func (lb *LBPolicyImpl) Close() { + lb.balancer.Close() +} diff --git a/internal/proxy/lb_policy_test.go b/internal/proxy/lb_policy_test.go index 34a051419b..982bcedaf9 100644 --- a/internal/proxy/lb_policy_test.go +++ b/internal/proxy/lb_policy_test.go @@ -24,6 +24,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/common" @@ -91,11 +92,13 @@ func (s *LBPolicySuite) SetupTest() { s.qn = types.NewMockQueryNode(s.T()) s.qn.EXPECT().GetAddress().Return("localhost").Maybe() + s.qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe() s.mgr = NewMockShardClientManager(s.T()) s.mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe() s.lbBalancer = NewMockLBBalancer(s.T()) - s.lbPolicy = NewLBPolicyImpl(s.lbBalancer, s.mgr) + s.lbPolicy = NewLBPolicyImpl(s.mgr) + s.lbPolicy.balancer = s.lbBalancer err := InitMetaCache(context.Background(), s.rc, s.qc, s.mgr) s.NoError(err) @@ -223,6 +226,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { s.lbBalancer.ExpectedCalls = nil s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(1, nil) + s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything) err := s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{ collection: s.collection, channel: s.channels[0], @@ -255,6 +259,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1) s.lbBalancer.ExpectedCalls = nil s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(1, nil) + s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything) err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{ collection: s.collection, channel: s.channels[0], @@ -270,6 +275,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { s.mgr.ExpectedCalls = nil s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) + s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything) err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{ collection: s.collection, channel: s.channels[0], @@ -287,6 +293,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) s.lbBalancer.ExpectedCalls = nil s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(1, nil) + s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything) counter := 0 err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{ collection: s.collection, @@ -310,6 +317,7 @@ func (s *LBPolicySuite) TestExecute() { // test all channel success 
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(1, nil) + s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything) err := s.lbPolicy.Execute(ctx, CollectionWorkLoad{ collection: s.collection, nq: 1, @@ -348,6 +356,11 @@ func (s *LBPolicySuite) TestExecute() { s.Error(err) } +func (s *LBPolicySuite) TestUpdateCostMetrics() { + s.lbBalancer.EXPECT().UpdateCostMetrics(mock.Anything, mock.Anything) + s.lbPolicy.UpdateCostMetrics(1, &internalpb.CostAggregation{}) +} + func TestLBPolicySuite(t *testing.T) { suite.Run(t, new(LBPolicySuite)) } diff --git a/internal/proxy/look_aside_balancer.go b/internal/proxy/look_aside_balancer.go new file mode 100644 index 0000000000..899b1c2e67 --- /dev/null +++ b/internal/proxy/look_aside_balancer.go @@ -0,0 +1,185 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "context" + "math" + "sync" + "time" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/typeutil" + "go.uber.org/atomic" + "go.uber.org/zap" +) + +var ( + checkQueryNodeHealthInterval = 500 * time.Millisecond +) + +type LookAsideBalancer struct { + clientMgr shardClientMgr + + // query node -> workload latest metrics + metricsMap *typeutil.ConcurrentMap[int64, *internalpb.CostAggregation] + + // query node -> last update metrics ts + metricsUpdateTs *typeutil.ConcurrentMap[int64, int64] + + // query node -> total nq of requests which already send but response hasn't received + executingTaskTotalNQ *typeutil.ConcurrentMap[int64, *atomic.Int64] + + unreachableQueryNodes *typeutil.ConcurrentSet[int64] + + closeCh chan struct{} + closeOnce sync.Once + wg sync.WaitGroup +} + +func NewLookAsideBalancer(clientMgr shardClientMgr) *LookAsideBalancer { + balancer := &LookAsideBalancer{ + clientMgr: clientMgr, + metricsMap: typeutil.NewConcurrentMap[int64, *internalpb.CostAggregation](), + metricsUpdateTs: typeutil.NewConcurrentMap[int64, int64](), + executingTaskTotalNQ: typeutil.NewConcurrentMap[int64, *atomic.Int64](), + unreachableQueryNodes: typeutil.NewConcurrentSet[int64](), + closeCh: make(chan struct{}), + } + + balancer.wg.Add(1) + go balancer.checkQueryNodeHealthLoop() + return balancer +} + +func (b *LookAsideBalancer) Close() { + b.closeOnce.Do(func() { + close(b.closeCh) + b.wg.Wait() + }) +} + +func (b *LookAsideBalancer) SelectNode(availableNodes []int64, cost int64) (int64, error) { + targetNode := int64(-1) + targetScore := float64(math.MaxFloat64) + for _, node := range availableNodes { + if b.unreachableQueryNodes.Contain(node) { + 
continue + } + + cost, _ := b.metricsMap.Get(node) + executingNQ, ok := b.executingTaskTotalNQ.Get(node) + if !ok { + executingNQ = atomic.NewInt64(0) + b.executingTaskTotalNQ.Insert(node, executingNQ) + } + + score := b.calculateScore(cost, executingNQ.Load()) + if targetNode == -1 || score < targetScore { + targetScore = score + targetNode = node + } + } + + // update executing task cost + totalNQ, ok := b.executingTaskTotalNQ.Get(targetNode) + if !ok { + totalNQ = atomic.NewInt64(0) + } + totalNQ.Add(cost) + + return targetNode, nil +} + +// when task canceled, should reduce executing total nq cost +func (b *LookAsideBalancer) CancelWorkload(node int64, nq int64) { + totalNQ, ok := b.executingTaskTotalNQ.Get(node) + if ok { + totalNQ.Sub(nq) + } +} + +// UpdateCostMetrics used for cache some metrics of recent search/query cost +func (b *LookAsideBalancer) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) { + // cache the latest query node cost metrics for updating the score + b.metricsMap.Insert(node, cost) + b.metricsUpdateTs.Insert(node, time.Now().UnixMilli()) +} + +// calculateScore compute the query node's workload score +// https://www.usenix.org/conference/nsdi15/technical-sessions/presentation/suresh +func (b *LookAsideBalancer) calculateScore(cost *internalpb.CostAggregation, executingNQ int64) float64 { + if cost == nil || cost.ResponseTime == 0 { + return float64(executingNQ) + } + return float64(cost.ResponseTime) - float64(1)/float64(cost.ServiceTime) + math.Pow(float64(1+cost.TotalNQ+executingNQ), 3.0)/float64(cost.ServiceTime) +} + +func (b *LookAsideBalancer) checkQueryNodeHealthLoop() { + defer b.wg.Done() + + ticker := time.NewTicker(checkQueryNodeHealthInterval) + defer ticker.Stop() + log.Info("Start check query node health loop") + for { + select { + case <-b.closeCh: + log.Info("check query node health loop exit") + return + + case <-ticker.C: + now := time.Now().UnixMilli() + b.metricsUpdateTs.Range(func(node int64, lastUpdateTs int64) bool { + if now-lastUpdateTs > checkQueryNodeHealthInterval.Milliseconds() { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + checkHealthFailed := func(err error) bool { + log.Warn("query node check health failed, add it to unreachable nodes list", + zap.Int64("nodeID", node), + zap.Error(err)) + b.unreachableQueryNodes.Insert(node) + return true + } + + qn, err := b.clientMgr.GetClient(ctx, node) + if err != nil { + return checkHealthFailed(err) + } + + resp, err := qn.GetComponentStates(ctx) + if err != nil { + return checkHealthFailed(err) + } + + if resp.GetState().GetStateCode() != commonpb.StateCode_Healthy { + return checkHealthFailed(merr.WrapErrNodeOffline(node)) + } + + // check health successfully, update check health ts + b.metricsUpdateTs.Insert(node, time.Now().Local().UnixMilli()) + b.unreachableQueryNodes.Remove(node) + } + + return true + }) + } + } +} diff --git a/internal/proxy/look_aside_balancer_test.go b/internal/proxy/look_aside_balancer_test.go new file mode 100644 index 0000000000..07d4658f18 --- /dev/null +++ b/internal/proxy/look_aside_balancer_test.go @@ -0,0 +1,292 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "testing" + "time" + + "github.com/cockroachdb/errors" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/types" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "go.uber.org/atomic" +) + +type LookAsideBalancerSuite struct { + suite.Suite + + clientMgr *MockShardClientManager + balancer *LookAsideBalancer +} + +func (suite *LookAsideBalancerSuite) SetupTest() { + suite.clientMgr = NewMockShardClientManager(suite.T()) + suite.balancer = NewLookAsideBalancer(suite.clientMgr) + + qn := types.NewMockQueryNode(suite.T()) + suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(1)).Return(qn, nil).Maybe() + qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, errors.New("fake error")).Maybe() +} + +func (suite *LookAsideBalancerSuite) TearDownTest() { + suite.balancer.Close() +} + +func (suite *LookAsideBalancerSuite) TestUpdateMetrics() { + costMetrics := &internalpb.CostAggregation{ + ResponseTime: 5, + ServiceTime: 1, + TotalNQ: 1, + } + + suite.balancer.UpdateCostMetrics(1, costMetrics) + + lastUpdateTs, ok := suite.balancer.metricsUpdateTs.Get(1) + suite.True(ok) + suite.True(time.Now().UnixMilli()-lastUpdateTs <= 5) +} + +func (suite *LookAsideBalancerSuite) TestCalculateScore() { + costMetrics1 := &internalpb.CostAggregation{ + ResponseTime: 5, + ServiceTime: 1, + TotalNQ: 1, + } + + costMetrics2 := &internalpb.CostAggregation{ + ResponseTime: 5, + ServiceTime: 2, + TotalNQ: 1, + } + + costMetrics3 := &internalpb.CostAggregation{ + ResponseTime: 10, + ServiceTime: 1, + TotalNQ: 1, + } + + costMetrics4 := &internalpb.CostAggregation{ + ResponseTime: 5, + ServiceTime: 1, + TotalNQ: 0, + } + + score1 := suite.balancer.calculateScore(costMetrics1, 0) + score2 := suite.balancer.calculateScore(costMetrics2, 0) + score3 := suite.balancer.calculateScore(costMetrics3, 0) + score4 := suite.balancer.calculateScore(costMetrics4, 0) + suite.Equal(float64(12), score1) + suite.Equal(float64(8.5), score2) + suite.Equal(float64(17), score3) + suite.Equal(float64(5), score4) + + score5 := suite.balancer.calculateScore(costMetrics1, 5) + score6 := suite.balancer.calculateScore(costMetrics2, 5) + score7 := suite.balancer.calculateScore(costMetrics3, 5) + score8 := suite.balancer.calculateScore(costMetrics4, 5) + suite.Equal(float64(347), score5) + suite.Equal(float64(176), score6) + suite.Equal(float64(352), score7) + suite.Equal(float64(220), score8) +} + +func (suite *LookAsideBalancerSuite) TestSelectNode() { + type testcase struct { + name string + costMetrics map[int64]*internalpb.CostAggregation + executingNQ map[int64]int64 + requestCount int + result map[int64]int64 + } + + cases := []testcase{ + { + name: "each qn has same cost metrics", + costMetrics: map[int64]*internalpb.CostAggregation{ + 1: { + ResponseTime: 5, + ServiceTime: 1, + TotalNQ: 0, + }, + 2: { + ResponseTime: 5, + ServiceTime: 1, + TotalNQ: 0, + }, + + 3: { + ResponseTime: 5, + ServiceTime: 1, + 
TotalNQ: 0, + }, + }, + + executingNQ: map[int64]int64{1: 0, 2: 0, 3: 0}, + requestCount: 100, + result: map[int64]int64{1: 34, 2: 33, 3: 33}, + }, + { + name: "each qn has different service time", + costMetrics: map[int64]*internalpb.CostAggregation{ + 1: { + ResponseTime: 30, + ServiceTime: 20, + TotalNQ: 0, + }, + 2: { + ResponseTime: 50, + ServiceTime: 40, + TotalNQ: 0, + }, + + 3: { + ResponseTime: 70, + ServiceTime: 60, + TotalNQ: 0, + }, + }, + + executingNQ: map[int64]int64{1: 0, 2: 0, 3: 0}, + requestCount: 100, + result: map[int64]int64{1: 27, 2: 34, 3: 39}, + }, + { + name: "one qn has task in queue", + costMetrics: map[int64]*internalpb.CostAggregation{ + 1: { + ResponseTime: 5, + ServiceTime: 1, + TotalNQ: 0, + }, + 2: { + ResponseTime: 5, + ServiceTime: 1, + TotalNQ: 0, + }, + + 3: { + ResponseTime: 100, + ServiceTime: 1, + TotalNQ: 20, + }, + }, + + executingNQ: map[int64]int64{1: 0, 2: 0, 3: 0}, + requestCount: 100, + result: map[int64]int64{1: 40, 2: 40, 3: 20}, + }, + + { + name: "qn with executing task", + costMetrics: map[int64]*internalpb.CostAggregation{ + 1: { + ResponseTime: 5, + ServiceTime: 1, + TotalNQ: 0, + }, + 2: { + ResponseTime: 5, + ServiceTime: 1, + TotalNQ: 0, + }, + + 3: { + ResponseTime: 5, + ServiceTime: 1, + TotalNQ: 0, + }, + }, + + executingNQ: map[int64]int64{1: 0, 2: 0, 3: 20}, + requestCount: 100, + result: map[int64]int64{1: 40, 2: 40, 3: 20}, + }, + { + name: "qn with empty metrics", + costMetrics: map[int64]*internalpb.CostAggregation{ + 1: {}, + 2: {}, + 3: {}, + }, + + executingNQ: map[int64]int64{1: 0, 2: 0, 3: 0}, + requestCount: 100, + result: map[int64]int64{1: 34, 2: 33, 3: 33}, + }, + } + + for _, c := range cases { + suite.Run(c.name, func() { + for node, cost := range c.costMetrics { + suite.balancer.UpdateCostMetrics(node, cost) + } + + for node, executingNQ := range c.executingNQ { + suite.balancer.executingTaskTotalNQ.Insert(node, atomic.NewInt64(executingNQ)) + } + + counter := make(map[int64]int64) + for i := 0; i < c.requestCount; i++ { + node, err := suite.balancer.SelectNode([]int64{1, 2, 3}, 1) + suite.NoError(err) + counter[node]++ + } + + for node, result := range c.result { + suite.Equal(result, counter[node]) + } + }) + } +} + +func (suite *LookAsideBalancerSuite) TestCancelWorkload() { + node, err := suite.balancer.SelectNode([]int64{1, 2, 3}, 10) + suite.NoError(err) + suite.balancer.CancelWorkload(node, 10) + + executingNQ, ok := suite.balancer.executingTaskTotalNQ.Get(node) + suite.True(ok) + suite.Equal(int64(0), executingNQ.Load()) +} + +func (suite *LookAsideBalancerSuite) TestCheckHealthLoop() { + qn2 := types.NewMockQueryNode(suite.T()) + suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(2)).Return(qn2, nil) + qn2.EXPECT().GetComponentStates(mock.Anything).Return(&milvuspb.ComponentStates{ + State: &milvuspb.ComponentInfo{ + StateCode: commonpb.StateCode_Healthy, + }, + }, nil) + + suite.balancer.metricsUpdateTs.Insert(1, time.Now().UnixMilli()) + suite.balancer.metricsUpdateTs.Insert(2, time.Now().UnixMilli()) + suite.Eventually(func() bool { + return suite.balancer.unreachableQueryNodes.Contain(1) + }, 2*time.Second, 100*time.Millisecond) + + suite.Eventually(func() bool { + return !suite.balancer.unreachableQueryNodes.Contain(2) + }, 3*time.Second, 100*time.Millisecond) +} + +func TestLookAsideBalancerSuite(t *testing.T) { + suite.Run(t, new(LookAsideBalancerSuite)) +} diff --git a/internal/proxy/mock_lb_balancer.go b/internal/proxy/mock_lb_balancer.go index 95b964ec98..0a0550b48b 100644 --- 
a/internal/proxy/mock_lb_balancer.go +++ b/internal/proxy/mock_lb_balancer.go @@ -2,7 +2,10 @@ package proxy -import mock "github.com/stretchr/testify/mock" +import ( + internalpb "github.com/milvus-io/milvus/internal/proto/internalpb" + mock "github.com/stretchr/testify/mock" +) // MockLBBalancer is an autogenerated mock type for the LBBalancer type type MockLBBalancer struct { @@ -17,6 +20,72 @@ func (_m *MockLBBalancer) EXPECT() *MockLBBalancer_Expecter { return &MockLBBalancer_Expecter{mock: &_m.Mock} } +// CancelWorkload provides a mock function with given fields: node, nq +func (_m *MockLBBalancer) CancelWorkload(node int64, nq int64) { + _m.Called(node, nq) +} + +// MockLBBalancer_CancelWorkload_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CancelWorkload' +type MockLBBalancer_CancelWorkload_Call struct { + *mock.Call +} + +// CancelWorkload is a helper method to define mock.On call +// - node int64 +// - nq int64 +func (_e *MockLBBalancer_Expecter) CancelWorkload(node interface{}, nq interface{}) *MockLBBalancer_CancelWorkload_Call { + return &MockLBBalancer_CancelWorkload_Call{Call: _e.mock.On("CancelWorkload", node, nq)} +} + +func (_c *MockLBBalancer_CancelWorkload_Call) Run(run func(node int64, nq int64)) *MockLBBalancer_CancelWorkload_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64), args[1].(int64)) + }) + return _c +} + +func (_c *MockLBBalancer_CancelWorkload_Call) Return() *MockLBBalancer_CancelWorkload_Call { + _c.Call.Return() + return _c +} + +func (_c *MockLBBalancer_CancelWorkload_Call) RunAndReturn(run func(int64, int64)) *MockLBBalancer_CancelWorkload_Call { + _c.Call.Return(run) + return _c +} + +// Close provides a mock function with given fields: +func (_m *MockLBBalancer) Close() { + _m.Called() +} + +// MockLBBalancer_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type MockLBBalancer_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *MockLBBalancer_Expecter) Close() *MockLBBalancer_Close_Call { + return &MockLBBalancer_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *MockLBBalancer_Close_Call) Run(run func()) *MockLBBalancer_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockLBBalancer_Close_Call) Return() *MockLBBalancer_Close_Call { + _c.Call.Return() + return _c +} + +func (_c *MockLBBalancer_Close_Call) RunAndReturn(run func()) *MockLBBalancer_Close_Call { + _c.Call.Return(run) + return _c +} + // SelectNode provides a mock function with given fields: availableNodes, nq func (_m *MockLBBalancer) SelectNode(availableNodes []int64, nq int64) (int64, error) { ret := _m.Called(availableNodes, nq) @@ -70,6 +139,40 @@ func (_c *MockLBBalancer_SelectNode_Call) RunAndReturn(run func([]int64, int64) return _c } +// UpdateCostMetrics provides a mock function with given fields: node, cost +func (_m *MockLBBalancer) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) { + _m.Called(node, cost) +} + +// MockLBBalancer_UpdateCostMetrics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCostMetrics' +type MockLBBalancer_UpdateCostMetrics_Call struct { + *mock.Call +} + +// UpdateCostMetrics is a helper method to define mock.On call +// - node int64 +// - cost *internalpb.CostAggregation +func (_e *MockLBBalancer_Expecter) UpdateCostMetrics(node interface{}, cost interface{}) 
*MockLBBalancer_UpdateCostMetrics_Call { + return &MockLBBalancer_UpdateCostMetrics_Call{Call: _e.mock.On("UpdateCostMetrics", node, cost)} +} + +func (_c *MockLBBalancer_UpdateCostMetrics_Call) Run(run func(node int64, cost *internalpb.CostAggregation)) *MockLBBalancer_UpdateCostMetrics_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64), args[1].(*internalpb.CostAggregation)) + }) + return _c +} + +func (_c *MockLBBalancer_UpdateCostMetrics_Call) Return() *MockLBBalancer_UpdateCostMetrics_Call { + _c.Call.Return() + return _c +} + +func (_c *MockLBBalancer_UpdateCostMetrics_Call) RunAndReturn(run func(int64, *internalpb.CostAggregation)) *MockLBBalancer_UpdateCostMetrics_Call { + _c.Call.Return(run) + return _c +} + type mockConstructorTestingTNewMockLBBalancer interface { mock.TestingT Cleanup(func()) diff --git a/internal/proxy/mock_lb_policy.go b/internal/proxy/mock_lb_policy.go index 62647d86ed..c6d907635f 100644 --- a/internal/proxy/mock_lb_policy.go +++ b/internal/proxy/mock_lb_policy.go @@ -5,6 +5,7 @@ package proxy import ( context "context" + internalpb "github.com/milvus-io/milvus/internal/proto/internalpb" mock "github.com/stretchr/testify/mock" ) @@ -64,13 +65,13 @@ func (_c *MockLBPolicy_Execute_Call) RunAndReturn(run func(context.Context, Coll return _c } -// ExecuteWithRetry provides a mock function with given fields: ctx, workload, retryTimes -func (_m *MockLBPolicy) ExecuteWithRetry(ctx context.Context, workload ChannelWorkload, retryTimes uint) error { - ret := _m.Called(ctx, workload, retryTimes) +// ExecuteWithRetry provides a mock function with given fields: ctx, workload +func (_m *MockLBPolicy) ExecuteWithRetry(ctx context.Context, workload ChannelWorkload) error { + ret := _m.Called(ctx, workload) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, ChannelWorkload, uint) error); ok { - r0 = rf(ctx, workload, retryTimes) + if rf, ok := ret.Get(0).(func(context.Context, ChannelWorkload) error); ok { + r0 = rf(ctx, workload) } else { r0 = ret.Error(0) } @@ -86,14 +87,13 @@ type MockLBPolicy_ExecuteWithRetry_Call struct { // ExecuteWithRetry is a helper method to define mock.On call // - ctx context.Context // - workload ChannelWorkload -// - retryTimes uint -func (_e *MockLBPolicy_Expecter) ExecuteWithRetry(ctx interface{}, workload interface{}, retryTimes interface{}) *MockLBPolicy_ExecuteWithRetry_Call { - return &MockLBPolicy_ExecuteWithRetry_Call{Call: _e.mock.On("ExecuteWithRetry", ctx, workload, retryTimes)} +func (_e *MockLBPolicy_Expecter) ExecuteWithRetry(ctx interface{}, workload interface{}) *MockLBPolicy_ExecuteWithRetry_Call { + return &MockLBPolicy_ExecuteWithRetry_Call{Call: _e.mock.On("ExecuteWithRetry", ctx, workload)} } -func (_c *MockLBPolicy_ExecuteWithRetry_Call) Run(run func(ctx context.Context, workload ChannelWorkload, retryTimes uint)) *MockLBPolicy_ExecuteWithRetry_Call { +func (_c *MockLBPolicy_ExecuteWithRetry_Call) Run(run func(ctx context.Context, workload ChannelWorkload)) *MockLBPolicy_ExecuteWithRetry_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(ChannelWorkload), args[2].(uint)) + run(args[0].(context.Context), args[1].(ChannelWorkload)) }) return _c } @@ -103,7 +103,41 @@ func (_c *MockLBPolicy_ExecuteWithRetry_Call) Return(_a0 error) *MockLBPolicy_Ex return _c } -func (_c *MockLBPolicy_ExecuteWithRetry_Call) RunAndReturn(run func(context.Context, ChannelWorkload, uint) error) *MockLBPolicy_ExecuteWithRetry_Call { +func (_c *MockLBPolicy_ExecuteWithRetry_Call) 
RunAndReturn(run func(context.Context, ChannelWorkload) error) *MockLBPolicy_ExecuteWithRetry_Call { + _c.Call.Return(run) + return _c +} + +// UpdateCostMetrics provides a mock function with given fields: node, cost +func (_m *MockLBPolicy) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) { + _m.Called(node, cost) +} + +// MockLBPolicy_UpdateCostMetrics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCostMetrics' +type MockLBPolicy_UpdateCostMetrics_Call struct { + *mock.Call +} + +// UpdateCostMetrics is a helper method to define mock.On call +// - node int64 +// - cost *internalpb.CostAggregation +func (_e *MockLBPolicy_Expecter) UpdateCostMetrics(node interface{}, cost interface{}) *MockLBPolicy_UpdateCostMetrics_Call { + return &MockLBPolicy_UpdateCostMetrics_Call{Call: _e.mock.On("UpdateCostMetrics", node, cost)} +} + +func (_c *MockLBPolicy_UpdateCostMetrics_Call) Run(run func(node int64, cost *internalpb.CostAggregation)) *MockLBPolicy_UpdateCostMetrics_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64), args[1].(*internalpb.CostAggregation)) + }) + return _c +} + +func (_c *MockLBPolicy_UpdateCostMetrics_Call) Return() *MockLBPolicy_UpdateCostMetrics_Call { + _c.Call.Return() + return _c +} + +func (_c *MockLBPolicy_UpdateCostMetrics_Call) RunAndReturn(run func(int64, *internalpb.CostAggregation)) *MockLBPolicy_UpdateCostMetrics_Call { _c.Call.Return(run) return _c } diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 649248eb68..a5c91c7a4b 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -127,7 +127,7 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) { searchResultCh: make(chan *internalpb.SearchResults, n), shardMgr: mgr, multiRateLimiter: NewMultiRateLimiter(), - lbPolicy: NewLBPolicyImpl(NewRoundRobinBalancer(), mgr), + lbPolicy: NewLBPolicyImpl(mgr), } node.UpdateStateCode(commonpb.StateCode_Abnormal) logutil.Logger(ctx).Debug("create a new Proxy instance", zap.Any("state", node.stateCode.Load())) @@ -437,6 +437,10 @@ func (node *Proxy) Stop() error { node.chMgr.removeAllDMLStream() } + if node.lbPolicy != nil { + node.lbPolicy.Close() + } + // https://github.com/milvus-io/milvus/issues/12282 node.UpdateStateCode(commonpb.StateCode_Abnormal) diff --git a/internal/proxy/roundrobin_balancer.go b/internal/proxy/roundrobin_balancer.go index 41ea1427be..3ab0d78231 100644 --- a/internal/proxy/roundrobin_balancer.go +++ b/internal/proxy/roundrobin_balancer.go @@ -16,38 +16,56 @@ package proxy import ( - "sync" - + "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/typeutil" + "go.uber.org/atomic" ) type RoundRobinBalancer struct { // request num send to each node - mutex sync.RWMutex - nodeWorkload map[int64]int64 + nodeWorkload *typeutil.ConcurrentMap[int64, *atomic.Int64] } func NewRoundRobinBalancer() *RoundRobinBalancer { return &RoundRobinBalancer{ - nodeWorkload: make(map[int64]int64), + nodeWorkload: typeutil.NewConcurrentMap[int64, *atomic.Int64](), } } -func (b *RoundRobinBalancer) SelectNode(availableNodes []int64, workload int64) (int64, error) { +func (b *RoundRobinBalancer) SelectNode(availableNodes []int64, cost int64) (int64, error) { if len(availableNodes) == 0 { return -1, merr.ErrNoAvailableNode } - b.mutex.Lock() - defer b.mutex.Unlock() + targetNode := int64(-1) - targetNodeWorkload := int64(-1) + var targetNodeWorkload 
*atomic.Int64 for _, node := range availableNodes { - if targetNodeWorkload == -1 || b.nodeWorkload[node] < targetNodeWorkload { + workload, ok := b.nodeWorkload.Get(node) + + if !ok { + workload = atomic.NewInt64(0) + b.nodeWorkload.Insert(node, workload) + } + + if targetNodeWorkload == nil || workload.Load() < targetNodeWorkload.Load() { targetNode = node - targetNodeWorkload = b.nodeWorkload[node] + targetNodeWorkload = workload } } - b.nodeWorkload[targetNode] += workload + targetNodeWorkload.Add(cost) return targetNode, nil } + +func (b *RoundRobinBalancer) CancelWorkload(node int64, nq int64) { + load, ok := b.nodeWorkload.Get(node) + + if ok { + load.Sub(nq) + } +} + +func (b *RoundRobinBalancer) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) {} + +func (b *RoundRobinBalancer) Close() {} diff --git a/internal/proxy/roundrobin_balancer_test.go b/internal/proxy/roundrobin_balancer_test.go index 2c37fca9c4..71a7b7ac42 100644 --- a/internal/proxy/roundrobin_balancer_test.go +++ b/internal/proxy/roundrobin_balancer_test.go @@ -38,16 +38,24 @@ func (s *RoundRobinBalancerSuite) TestRoundRobin() { s.balancer.SelectNode(availableNodes, 1) s.balancer.SelectNode(availableNodes, 1) - s.Equal(int64(2), s.balancer.nodeWorkload[1]) - s.Equal(int64(2), s.balancer.nodeWorkload[2]) + workload, ok := s.balancer.nodeWorkload.Get(1) + s.True(ok) + s.Equal(int64(2), workload.Load()) + workload, ok = s.balancer.nodeWorkload.Get(1) + s.True(ok) + s.Equal(int64(2), workload.Load()) s.balancer.SelectNode(availableNodes, 3) s.balancer.SelectNode(availableNodes, 1) s.balancer.SelectNode(availableNodes, 1) s.balancer.SelectNode(availableNodes, 1) - s.Equal(int64(5), s.balancer.nodeWorkload[1]) - s.Equal(int64(5), s.balancer.nodeWorkload[2]) + workload, ok = s.balancer.nodeWorkload.Get(1) + s.True(ok) + s.Equal(int64(5), workload.Load()) + workload, ok = s.balancer.nodeWorkload.Get(1) + s.True(ok) + s.Equal(int64(5), workload.Load()) } func (s *RoundRobinBalancerSuite) TestNoAvailableNode() { @@ -56,6 +64,17 @@ func (s *RoundRobinBalancerSuite) TestNoAvailableNode() { s.Error(err) } +func (s *RoundRobinBalancerSuite) TestCancelWorkload() { + availableNodes := []int64{101} + _, err := s.balancer.SelectNode(availableNodes, 5) + s.NoError(err) + workload, ok := s.balancer.nodeWorkload.Get(101) + s.True(ok) + s.Equal(int64(5), workload.Load()) + s.balancer.CancelWorkload(101, 5) + s.Equal(int64(0), workload.Load()) +} + func TestRoundRobinBalancerSuite(t *testing.T) { suite.Run(t, new(RoundRobinBalancerSuite)) } diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index d9ad59b5c6..06db426e37 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -491,6 +491,7 @@ func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.Query log.Debug("get query result") t.resultBuf.Insert(result) + t.lb.UpdateCostMetrics(nodeID, result.CostAggregation) return nil } diff --git a/internal/proxy/task_query_test.go b/internal/proxy/task_query_test.go index 38e0d2c8ce..3168b17a38 100644 --- a/internal/proxy/task_query_test.go +++ b/internal/proxy/task_query_test.go @@ -58,6 +58,8 @@ func TestQueryTask_all(t *testing.T) { hitNum = 10 ) + qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe() + successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} qc.EXPECT().Start().Return(nil) qc.EXPECT().Stop().Return(nil) @@ -73,12 +75,10 @@ func TestQueryTask_all(t *testing.T) { }, }, nil).Maybe() - mockCreator := func(ctx 
context.Context, address string) (types.QueryNode, error) { - return qn, nil - } - - mgr := newShardClientMgr(withShardClientCreator(mockCreator)) - lb := NewLBPolicyImpl(NewRoundRobinBalancer(), mgr) + mgr := NewMockShardClientManager(t) + mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe() + mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe() + lb := NewLBPolicyImpl(mgr) rc.Start() defer rc.Stop() @@ -217,10 +217,12 @@ func TestQueryTask_all(t *testing.T) { task.RetrieveRequest.OutputFieldsId = append(task.RetrieveRequest.OutputFieldsId, common.TimeStampField) task.ctx = ctx qn.ExpectedCalls = nil + qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe() qn.EXPECT().Query(mock.Anything, mock.Anything).Return(nil, errors.New("mock error")) assert.Error(t, task.Execute(ctx)) qn.ExpectedCalls = nil + qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe() qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_NotShardLeader, @@ -230,6 +232,7 @@ func TestQueryTask_all(t *testing.T) { assert.True(t, strings.Contains(err.Error(), errInvalidShardLeaders.Error())) qn.ExpectedCalls = nil + qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe() qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -238,6 +241,7 @@ func TestQueryTask_all(t *testing.T) { assert.Error(t, task.Execute(ctx)) qn.ExpectedCalls = nil + qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe() qn.EXPECT().Query(mock.Anything, mock.Anything).Return(result1, nil) assert.NoError(t, task.Execute(ctx)) diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index d5f96ce932..5d0c065af0 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -521,6 +521,7 @@ func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.Que return fmt.Errorf("fail to Search, QueryNode ID=%d, reason=%s", nodeID, result.GetStatus().GetReason()) } t.resultBuf.Insert(result) + t.lb.UpdateCostMetrics(nodeID, result.CostAggregation) return nil } diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 1bc571e8c4..3525626d1d 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -1543,12 +1543,12 @@ func TestSearchTask_ErrExecute(t *testing.T) { collectionName = t.Name() + funcutil.GenRandomStr() ) - mockCreator := func(ctx context.Context, address string) (types.QueryNode, error) { - return qn, nil - } + qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe() - mgr := newShardClientMgr(withShardClientCreator(mockCreator)) - lb := NewLBPolicyImpl(NewRoundRobinBalancer(), mgr) + mgr := NewMockShardClientManager(t) + mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe() + mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe() + lb := NewLBPolicyImpl(mgr) rc.Start() defer rc.Stop() @@ -1661,6 +1661,7 @@ func TestSearchTask_ErrExecute(t *testing.T) { assert.Error(t, task.Execute(ctx)) qn.ExpectedCalls = nil + qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe() qn.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_NotShardLeader, @@ -1670,6 +1671,7 @@ func 
TestSearchTask_ErrExecute(t *testing.T) { assert.True(t, strings.Contains(err.Error(), errInvalidShardLeaders.Error())) qn.ExpectedCalls = nil + qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe() qn.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -1678,6 +1680,7 @@ func TestSearchTask_ErrExecute(t *testing.T) { assert.Error(t, task.Execute(ctx)) qn.ExpectedCalls = nil + qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe() qn.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, diff --git a/internal/proxy/task_statistic_test.go b/internal/proxy/task_statistic_test.go index 994ce00cd8..26793ed94c 100644 --- a/internal/proxy/task_statistic_test.go +++ b/internal/proxy/task_statistic_test.go @@ -77,11 +77,11 @@ func (s *StatisticTaskSuite) SetupTest() { s.rc.Start() s.qn = types.NewMockQueryNode(s.T()) - mockCreator := func(ctx context.Context, addr string) (types.QueryNode, error) { - return s.qn, nil - } - mgr := newShardClientMgr(withShardClientCreator(mockCreator)) - s.lb = NewLBPolicyImpl(NewRoundRobinBalancer(), mgr) + s.qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe() + mgr := NewMockShardClientManager(s.T()) + mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil).Maybe() + mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe() + s.lb = NewLBPolicyImpl(mgr) err := InitMetaCache(context.Background(), s.rc, s.qc, mgr) s.NoError(err) diff --git a/internal/querynodev2/services.go b/internal/querynodev2/services.go index a9c01254bc..ff0fb9d084 100644 --- a/internal/querynodev2/services.go +++ b/internal/querynodev2/services.go @@ -1017,6 +1017,12 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i collector.Rate.Add(metricsinfo.NQPerSecond, 1) metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.QueryLabel).Add(float64(proto.Size(req))) } + + if ret.CostAggregation != nil { + // update channel's response time + currentTotalNQ := node.scheduler.GetWaitingTaskTotalNQ() + ret.CostAggregation.TotalNQ = currentTotalNQ + } return ret, nil }
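
Note on the scoring formula (illustration only, not part of the patch): the new LookAsideBalancer ranks query nodes with the C3-style score computed in calculateScore above, i.e. responseTime - 1/serviceTime + (1 + totalNQ + executingNQ)^3 / serviceTime, falling back to the node's executing NQ alone when no cost metrics have been reported yet. Below is a minimal standalone Go sketch of that formula, checked against the values asserted in TestCalculateScore; the score helper and main function are hypothetical names used purely for this example.

package main

import (
	"fmt"
	"math"
)

// score mirrors LookAsideBalancer.calculateScore from this patch:
// responseTime - 1/serviceTime + (1 + totalNQ + executingNQ)^3 / serviceTime.
// When no recent metrics exist (responseTime == 0), the balancer falls back to
// the number of NQ currently executing on the node.
func score(responseTime, serviceTime, totalNQ, executingNQ int64) float64 {
	if responseTime == 0 {
		return float64(executingNQ)
	}
	return float64(responseTime) - 1/float64(serviceTime) +
		math.Pow(float64(1+totalNQ+executingNQ), 3)/float64(serviceTime)
}

func main() {
	// Values taken from TestCalculateScore in this patch.
	fmt.Println(score(5, 1, 1, 0)) // 12
	fmt.Println(score(5, 2, 1, 0)) // 8.5
	fmt.Println(score(5, 1, 1, 5)) // 347
}

Lower scores win in SelectNode, so a node with a long queue (high TotalNQ or many in-flight requests) is penalized cubically, which is what produces the skewed request distribution asserted in TestSelectNode.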