milvus/internal/querycoordv2/services_test.go
yah01 11e4445ef7
Check whether segments are fully loaded while fetching shard leaders (#20991)
Signed-off-by: yah01 <yang.cen@zilliz.com>

Signed-off-by: yah01 <yang.cen@zilliz.com>
2022-12-06 18:05:18 +08:00

1184 lines
37 KiB
Go

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querycoordv2
import (
"context"
"encoding/json"
"testing"
"time"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus/internal/kv"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/balance"
"github.com/milvus-io/milvus/internal/querycoordv2/job"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/task"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/metricsinfo"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/samber/lo"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
)
// ServiceSuite is the testify suite for the Server's service handlers.
// It keeps static fixture data, the dependencies wired into the server
// (several of them mocks), and the server under test.
type ServiceSuite struct {
	suite.Suite

	// Fixture data; the maps below are keyed by collection ID.
	collections   []int64
	partitions    map[int64][]int64
	channels      map[int64][]string
	segments      map[int64]map[int64][]int64 // CollectionID, PartitionID -> Segments
	loadTypes     map[int64]querypb.LoadType
	replicaNumber map[int64]int32
	nodes         []int64

	// Dependencies handed to the server.
	kv            kv.MetaKv
	store         meta.Store
	dist          *meta.DistributionManager
	meta          *meta.Meta
	targetMgr     *meta.TargetManager
	broker        *meta.MockBroker
	cluster       *session.MockCluster
	nodeMgr       *session.NodeManager
	jobScheduler  *job.Scheduler
	taskScheduler *task.MockScheduler
	balancer      balance.Balance

	// Object under test.
	server *Server
}
// SetupSuite initializes the global parameters and fills in the static
// fixture data: two collections (one loaded as a whole collection with one
// replica, one loaded by partitions with three replicas), their partitions,
// channels and segments, and the list of node IDs.
func (suite *ServiceSuite) SetupSuite() {
	Params.Init()

	const (
		wholeCollection int64 = 1000 // loaded with LoadType_LoadCollection
		byPartitions    int64 = 1001 // loaded with LoadType_LoadPartition
	)

	suite.collections = []int64{wholeCollection, byPartitions}
	suite.partitions = map[int64][]int64{
		wholeCollection: {100, 101},
		byPartitions:    {102, 103},
	}
	suite.channels = map[int64][]string{
		wholeCollection: {"1000-dmc0", "1000-dmc1"},
		byPartitions:    {"1001-dmc0", "1001-dmc1"},
	}
	suite.segments = map[int64]map[int64][]int64{
		wholeCollection: {
			100: {1, 2},
			101: {3, 4},
		},
		byPartitions: {
			102: {5, 6},
			103: {7, 8},
		},
	}
	suite.loadTypes = map[int64]querypb.LoadType{
		wholeCollection: querypb.LoadType_LoadCollection,
		byPartitions:    querypb.LoadType_LoadPartition,
	}
	suite.replicaNumber = map[int64]int32{
		wholeCollection: 1,
		byPartitions:    3,
	}
	suite.nodes = []int64{
		1, 2, 3, 4, 5,
		101, 102, 103, 104, 105,
	}
}
// SetupTest builds a fresh set of dependencies and a new Server before each
// test: an etcd-backed kv and meta store, the distribution/meta/target
// managers, a node manager holding every fixture node, mocked broker, cluster
// and task scheduler, a started job scheduler, and a row-count based balancer.
// The server is finally marked healthy.
func (suite *ServiceSuite) SetupTest() {
	etcdCfg := params.GenerateEtcdConfig()
	etcdCli, err := etcd.GetEtcdClient(
		etcdCfg.UseEmbedEtcd.GetAsBool(),
		etcdCfg.EtcdUseSSL.GetAsBool(),
		etcdCfg.Endpoints.GetAsStrings(),
		etcdCfg.EtcdTLSCert.GetValue(),
		etcdCfg.EtcdTLSKey.GetValue(),
		etcdCfg.EtcdTLSCACert.GetValue(),
		etcdCfg.EtcdTLSMinVersion.GetValue(),
	)
	// Nothing below can work without an etcd client, so abort the test here.
	suite.Require().NoError(err)

	suite.kv = etcdkv.NewEtcdKV(etcdCli, etcdCfg.MetaRootPath.GetValue())
	suite.store = meta.NewMetaStore(suite.kv)
	suite.dist = meta.NewDistributionManager()
	suite.meta = meta.NewMeta(params.RandomIncrementIDAllocator(), suite.store)
	suite.broker = meta.NewMockBroker(suite.T())
	suite.targetMgr = meta.NewTargetManager(suite.broker, suite.meta)

	suite.nodeMgr = session.NewNodeManager()
	for i := range suite.nodes {
		suite.nodeMgr.Add(session.NewNodeInfo(suite.nodes[i], "localhost"))
	}

	suite.cluster = session.NewMockCluster(suite.T())
	suite.jobScheduler = job.NewScheduler()
	suite.taskScheduler = task.NewMockScheduler(suite.T())
	suite.jobScheduler.Start(context.Background())
	suite.balancer = balance.NewRowCountBasedBalancer(
		suite.taskScheduler,
		suite.nodeMgr,
		suite.dist,
		suite.meta,
		suite.targetMgr,
	)

	srv := &Server{
		kv:                  suite.kv,
		store:               suite.store,
		session:             sessionutil.NewSession(context.Background(), Params.EtcdCfg.MetaRootPath.GetValue(), etcdCli),
		metricsCacheManager: metricsinfo.NewMetricsCacheManager(),
		dist:                suite.dist,
		meta:                suite.meta,
		targetMgr:           suite.targetMgr,
		broker:              suite.broker,
		nodeMgr:             suite.nodeMgr,
		cluster:             suite.cluster,
		jobScheduler:        suite.jobScheduler,
		taskScheduler:       suite.taskScheduler,
		balancer:            suite.balancer,
	}
	suite.server = srv
	srv.UpdateStateCode(commonpb.StateCode_Healthy)
}
// TestShowCollections checks ShowCollections without a filter, with a single
// collection ID, and while the server is not in the healthy state.
func (suite *ServiceSuite) TestShowCollections() {
	suite.loadAll()

	var (
		ctx      = context.Background()
		srv      = suite.server
		expected = len(suite.collections)
	)

	// An empty request lists every loaded collection.
	request := &querypb.ShowCollectionsRequest{}
	rsp, err := srv.ShowCollections(ctx, request)
	suite.NoError(err)
	suite.Equal(commonpb.ErrorCode_Success, rsp.Status.ErrorCode)
	suite.Len(rsp.CollectionIDs, expected)
	for i := range suite.collections {
		suite.Contains(rsp.CollectionIDs, suite.collections[i])
	}

	// Asking for one collection returns exactly that collection.
	wanted := suite.collections[0]
	request.CollectionIDs = []int64{wanted}
	rsp, err = srv.ShowCollections(ctx, request)
	suite.NoError(err)
	suite.Equal(commonpb.ErrorCode_Success, rsp.Status.ErrorCode)
	suite.Len(rsp.CollectionIDs, 1)
	suite.Equal(wanted, rsp.CollectionIDs[0])

	// A server that is still initializing reports ErrNotHealthy in the status.
	srv.UpdateStateCode(commonpb.StateCode_Initializing)
	rsp, err = srv.ShowCollections(ctx, request)
	suite.NoError(err)
	suite.Contains(rsp.Status.Reason, ErrNotHealthy.Error())
}
// TestShowPartitions checks ShowPartitions for every fixture collection, both
// without a partition filter and with a single partition, and finally while
// the server is not in the healthy state.
func (suite *ServiceSuite) TestShowPartitions() {
	suite.loadAll()

	var (
		ctx = context.Background()
		srv = suite.server
	)

	for _, collectionID := range suite.collections {
		allPartitions := suite.partitions[collectionID]

		// No partition filter: all partitions of the collection come back.
		rsp, err := srv.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{
			CollectionID: collectionID,
		})
		suite.NoError(err)
		suite.Equal(commonpb.ErrorCode_Success, rsp.Status.ErrorCode)
		suite.Len(rsp.PartitionIDs, len(allPartitions))
		for _, partitionID := range allPartitions {
			suite.Contains(rsp.PartitionIDs, partitionID)
		}

		// Filter on the first partition only.
		firstOnly := allPartitions[0:1]
		rsp, err = srv.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{
			CollectionID: collectionID,
			PartitionIDs: firstOnly,
		})
		suite.NoError(err)
		suite.Equal(commonpb.ErrorCode_Success, rsp.Status.ErrorCode)
		suite.Len(rsp.PartitionIDs, 1)
		for _, partitionID := range firstOnly {
			suite.Contains(rsp.PartitionIDs, partitionID)
		}
	}

	// A server that is still initializing reports ErrNotHealthy in the status.
	unhealthyReq := &querypb.ShowPartitionsRequest{
		CollectionID: suite.collections[0],
	}
	srv.UpdateStateCode(commonpb.StateCode_Initializing)
	rsp, err := srv.ShowPartitions(ctx, unhealthyReq)
	suite.NoError(err)
	suite.Contains(rsp.Status.Reason, ErrNotHealthy.Error())
}
// TestLoadCollection loads every fixture collection through LoadCollection,
// repeats the same requests, and then checks the response of a server that is
// not in the healthy state.
func (suite *ServiceSuite) TestLoadCollection() {
	var (
		ctx = context.Background()
		srv = suite.server
	)

	// First load: the broker is asked for partitions and recovery info.
	for _, collectionID := range suite.collections {
		suite.broker.EXPECT().
			GetPartitions(mock.Anything, collectionID).
			Return(suite.partitions[collectionID], nil)
		suite.expectGetRecoverInfo(collectionID)

		status, err := srv.LoadCollection(ctx, &querypb.LoadCollectionRequest{
			CollectionID: collectionID,
		})
		suite.NoError(err)
		suite.Equal(commonpb.ErrorCode_Success, status.ErrorCode)
		suite.assertLoaded(collectionID)
	}

	// Repeating the same load request still succeeds.
	for _, collectionID := range suite.collections {
		status, err := srv.LoadCollection(ctx, &querypb.LoadCollectionRequest{
			CollectionID: collectionID,
		})
		suite.NoError(err)
		suite.Equal(commonpb.ErrorCode_Success, status.ErrorCode)
	}

	// A server that is still initializing reports ErrNotHealthy.
	srv.UpdateStateCode(commonpb.StateCode_Initializing)
	status, err := srv.LoadCollection(ctx, &querypb.LoadCollectionRequest{
		CollectionID: suite.collections[0],
	})
	suite.NoError(err)
	suite.Contains(status.Reason, ErrNotHealthy.Error())
}
// TestLoadCollectionFailed checks that LoadCollection is rejected with
// IllegalArgument / ErrLoadParameterMismatched when the replica number differs
// from the loaded one, or when the collection was loaded by partitions.
func (suite *ServiceSuite) TestLoadCollectionFailed() {
	suite.loadAll()

	var (
		ctx = context.Background()
		srv = suite.server
	)

	// Replica number differs from what was used for the original load.
	for _, collectionID := range suite.collections {
		status, err := srv.LoadCollection(ctx, &querypb.LoadCollectionRequest{
			CollectionID:  collectionID,
			ReplicaNumber: suite.replicaNumber[collectionID] + 1,
		})
		suite.NoError(err)
		suite.Equal(commonpb.ErrorCode_IllegalArgument, status.ErrorCode)
		suite.Contains(status.Reason, job.ErrLoadParameterMismatched.Error())
	}

	// Whole-collection load of a collection that was loaded by partitions.
	for _, collectionID := range suite.collections {
		if suite.loadTypes[collectionID] == querypb.LoadType_LoadPartition {
			status, err := srv.LoadCollection(ctx, &querypb.LoadCollectionRequest{
				CollectionID: collectionID,
			})
			suite.NoError(err)
			suite.Equal(commonpb.ErrorCode_IllegalArgument, status.ErrorCode)
			suite.Contains(status.Reason, job.ErrLoadParameterMismatched.Error())
		}
	}
}
// TestLoadPartition loads all partitions of every fixture collection through
// LoadPartitions, repeats the same requests, and then checks the response of
// a server that is not in the healthy state.
func (suite *ServiceSuite) TestLoadPartition() {
	var (
		ctx = context.Background()
		srv = suite.server
	)

	// First load: recovery info is fetched from the broker.
	for _, collectionID := range suite.collections {
		suite.expectGetRecoverInfo(collectionID)

		status, err := srv.LoadPartitions(ctx, &querypb.LoadPartitionsRequest{
			CollectionID: collectionID,
			PartitionIDs: suite.partitions[collectionID],
		})
		suite.NoError(err)
		suite.Equal(commonpb.ErrorCode_Success, status.ErrorCode)
		suite.assertLoaded(collectionID)
	}

	// Repeating the same load request still succeeds.
	for _, collectionID := range suite.collections {
		status, err := srv.LoadPartitions(ctx, &querypb.LoadPartitionsRequest{
			CollectionID: collectionID,
			PartitionIDs: suite.partitions[collectionID],
		})
		suite.NoError(err)
		suite.Equal(commonpb.ErrorCode_Success, status.ErrorCode)
	}

	// A server that is still initializing reports ErrNotHealthy.
	srv.UpdateStateCode(commonpb.StateCode_Initializing)
	firstCollection := suite.collections[0]
	status, err := srv.LoadPartitions(ctx, &querypb.LoadPartitionsRequest{
		CollectionID: firstCollection,
		PartitionIDs: suite.partitions[firstCollection],
	})
	suite.NoError(err)
	suite.Contains(status.Reason, ErrNotHealthy.Error())
}
// TestLoadPartitionFailed checks that LoadPartitions is rejected with
// IllegalArgument / ErrLoadParameterMismatched when:
//   - the replica number differs from the loaded one,
//   - the collection was loaded as a whole collection,
//   - the request names a partition beyond the ones already loaded.
func (suite *ServiceSuite) TestLoadPartitionFailed() {
	suite.loadAll()
	ctx := context.Background()
	server := suite.server

	// Test load with different replica number
	for _, collection := range suite.collections {
		req := &querypb.LoadPartitionsRequest{
			CollectionID:  collection,
			PartitionIDs:  suite.partitions[collection],
			ReplicaNumber: suite.replicaNumber[collection] + 1,
		}
		resp, err := server.LoadPartitions(ctx, req)
		suite.NoError(err)
		suite.Equal(commonpb.ErrorCode_IllegalArgument, resp.ErrorCode)
		suite.Contains(resp.Reason, job.ErrLoadParameterMismatched.Error())
	}

	// Test load with collection loaded
	for _, collection := range suite.collections {
		if suite.loadTypes[collection] != querypb.LoadType_LoadCollection {
			continue
		}
		req := &querypb.LoadPartitionsRequest{
			CollectionID: collection,
			PartitionIDs: suite.partitions[collection],
		}
		resp, err := server.LoadPartitions(ctx, req)
		suite.NoError(err)
		suite.Equal(commonpb.ErrorCode_IllegalArgument, resp.ErrorCode)
		suite.Contains(resp.Reason, job.ErrLoadParameterMismatched.Error())
	}

	// Test load with more partitions
	for _, collection := range suite.collections {
		if suite.loadTypes[collection] != querypb.LoadType_LoadPartition {
			continue
		}
		// Append to a copy: the fixture slice is shared by every test in the
		// suite, and appending to it directly could write into its backing
		// array if it ever has spare capacity.
		loaded := suite.partitions[collection]
		morePartitions := make([]int64, 0, len(loaded)+1)
		morePartitions = append(morePartitions, loaded...)
		morePartitions = append(morePartitions, 999)

		req := &querypb.LoadPartitionsRequest{
			CollectionID: collection,
			PartitionIDs: morePartitions,
		}
		resp, err := server.LoadPartitions(ctx, req)
		suite.NoError(err)
		suite.Equal(commonpb.ErrorCode_IllegalArgument, resp.ErrorCode)
		suite.Contains(resp.Reason, job.ErrLoadParameterMismatched.Error())
	}
}
// TestReleaseCollection releases every loaded collection, repeats the release
// requests, and then checks the response of a server that is not in the
// healthy state.
func (suite *ServiceSuite) TestReleaseCollection() {
	suite.loadAll()

	var (
		ctx = context.Background()
		srv = suite.server
	)

	// First release removes the collection and its targets.
	for _, collectionID := range suite.collections {
		status, err := srv.ReleaseCollection(ctx, &querypb.ReleaseCollectionRequest{
			CollectionID: collectionID,
		})
		suite.NoError(err)
		suite.Equal(commonpb.ErrorCode_Success, status.ErrorCode)
		suite.assertReleased(collectionID)
	}

	// Releasing an already released collection still succeeds.
	for _, collectionID := range suite.collections {
		status, err := srv.ReleaseCollection(ctx, &querypb.ReleaseCollectionRequest{
			CollectionID: collectionID,
		})
		suite.NoError(err)
		suite.Equal(commonpb.ErrorCode_Success, status.ErrorCode)
	}

	// A server that is still initializing reports ErrNotHealthy.
	srv.UpdateStateCode(commonpb.StateCode_Initializing)
	status, err := srv.ReleaseCollection(ctx, &querypb.ReleaseCollectionRequest{
		CollectionID: suite.collections[0],
	})
	suite.NoError(err)
	suite.Contains(status.Reason, ErrNotHealthy.Error())
}
// TestReleasePartition releases the first partition of every collection,
// twice in a row. Collections loaded as a whole collection are expected to
// answer UnexpectedError, collections loaded by partitions Success; in both
// cases the remaining partitions must still be loaded. Finally the response
// of a server that is not in the healthy state is checked.
func (suite *ServiceSuite) TestReleasePartition() {
	suite.loadAll()

	var (
		ctx = context.Background()
		srv = suite.server
	)

	// Round 0 is the initial release, round 1 releases the same partition again.
	for round := 0; round < 2; round++ {
		for _, collectionID := range suite.collections {
			partitionIDs := suite.partitions[collectionID]

			status, err := srv.ReleasePartitions(ctx, &querypb.ReleasePartitionsRequest{
				CollectionID: collectionID,
				PartitionIDs: partitionIDs[0:1],
			})
			suite.NoError(err)

			wantCode := commonpb.ErrorCode_Success
			if suite.loadTypes[collectionID] == querypb.LoadType_LoadCollection {
				wantCode = commonpb.ErrorCode_UnexpectedError
			}
			suite.Equal(wantCode, status.ErrorCode)
			suite.assertPartitionLoaded(collectionID, partitionIDs[1:]...)
		}
	}

	// A server that is still initializing reports ErrNotHealthy.
	srv.UpdateStateCode(commonpb.StateCode_Initializing)
	firstCollection := suite.collections[0]
	status, err := srv.ReleasePartitions(ctx, &querypb.ReleasePartitionsRequest{
		CollectionID: firstCollection,
		PartitionIDs: suite.partitions[firstCollection][0:1],
	})
	suite.NoError(err)
	suite.Contains(status.Reason, ErrNotHealthy.Error())
}
// TestGetPartitionStates checks that GetPartitionStates returns one
// description per requested partition, and that a server that is not in the
// healthy state reports ErrNotHealthy.
func (suite *ServiceSuite) TestGetPartitionStates() {
	suite.loadAll()

	var (
		ctx = context.Background()
		srv = suite.server
	)

	// One description is expected for each requested partition.
	for _, collectionID := range suite.collections {
		partitionIDs := suite.partitions[collectionID]
		rsp, err := srv.GetPartitionStates(ctx, &querypb.GetPartitionStatesRequest{
			CollectionID: collectionID,
			PartitionIDs: partitionIDs,
		})
		suite.NoError(err)
		suite.Equal(commonpb.ErrorCode_Success, rsp.Status.ErrorCode)
		suite.Len(rsp.PartitionDescriptions, len(partitionIDs))
	}

	// A server that is still initializing reports ErrNotHealthy in the status.
	srv.UpdateStateCode(commonpb.StateCode_Initializing)
	rsp, err := srv.GetPartitionStates(ctx, &querypb.GetPartitionStatesRequest{
		CollectionID: suite.collections[0],
	})
	suite.NoError(err)
	suite.Contains(rsp.Status.Reason, ErrNotHealthy.Error())
}
// TestGetSegmentInfo checks GetSegmentInfo without a segment filter and with
// an explicit list of all segments, and finally while the server is not in
// the healthy state.
func (suite *ServiceSuite) TestGetSegmentInfo() {
	suite.loadAll()

	var (
		ctx = context.Background()
		srv = suite.server
	)

	// No segment filter; the collection index doubles as the node ID passed
	// to updateSegmentDist.
	for idx, collectionID := range suite.collections {
		suite.updateSegmentDist(collectionID, int64(idx))
		rsp, err := srv.GetSegmentInfo(ctx, &querypb.GetSegmentInfoRequest{
			CollectionID: collectionID,
		})
		suite.NoError(err)
		suite.Equal(commonpb.ErrorCode_Success, rsp.Status.ErrorCode)
		suite.assertSegments(collectionID, rsp.GetInfos())
	}

	// Explicitly request every segment of the collection.
	for _, collectionID := range suite.collections {
		rsp, err := srv.GetSegmentInfo(ctx, &querypb.GetSegmentInfoRequest{
			CollectionID: collectionID,
			SegmentIDs:   suite.getAllSegments(collectionID),
		})
		suite.NoError(err)
		suite.Equal(commonpb.ErrorCode_Success, rsp.Status.ErrorCode)
		suite.assertSegments(collectionID, rsp.GetInfos())
	}

	// A server that is still initializing reports ErrNotHealthy in the status.
	srv.UpdateStateCode(commonpb.StateCode_Initializing)
	rsp, err := srv.GetSegmentInfo(ctx, &querypb.GetSegmentInfoRequest{
		CollectionID: suite.collections[0],
	})
	suite.NoError(err)
	suite.Contains(rsp.Status.Reason, ErrNotHealthy.Error())
}
// TestLoadBalance requests a balance of all segments of each collection from
// the first node of its first replica to the second one, and verifies that
// every task handed to the scheduler has a grow action on the destination
// followed by a reduce action on the source. It ends by checking the response
// of a server that is not in the healthy state.
func (suite *ServiceSuite) TestLoadBalance() {
	suite.loadAll()

	var (
		ctx = context.Background()
		srv = suite.server
	)

	for _, collectionID := range suite.collections {
		firstReplica := suite.meta.ReplicaManager.GetByCollection(collectionID)[0]
		from := firstReplica.GetNodes()[0]
		to := firstReplica.GetNodes()[1]

		suite.updateCollectionStatus(collectionID, querypb.LoadStatus_Loaded)
		suite.updateSegmentDist(collectionID, from)
		request := &querypb.LoadBalanceRequest{
			CollectionID:     collectionID,
			SourceNodeIDs:    []int64{from},
			DstNodeIDs:       []int64{to},
			SealedSegmentIDs: suite.getAllSegments(collectionID),
		}

		// Drop expectations left over from the previous collection.
		suite.taskScheduler.ExpectedCalls = make([]*mock.Call, 0)
		suite.taskScheduler.EXPECT().Add(mock.Anything).Run(func(balanceTask task.Task) {
			actions := balanceTask.Actions()
			suite.Len(actions, 2)
			grow := actions[0]
			reduce := actions[1]
			suite.Equal(to, grow.Node())
			suite.Equal(from, reduce.Node())
			balanceTask.Cancel()
		}).Return(nil)

		status, err := srv.LoadBalance(ctx, request)
		suite.NoError(err)
		suite.Equal(commonpb.ErrorCode_Success, status.ErrorCode)
		suite.taskScheduler.AssertExpectations(suite.T())
	}

	// A server that is still initializing reports ErrNotHealthy.
	srv.UpdateStateCode(commonpb.StateCode_Initializing)
	status, err := srv.LoadBalance(ctx, &querypb.LoadBalanceRequest{
		CollectionID:  suite.collections[0],
		SourceNodeIDs: []int64{1},
		DstNodeIDs:    []int64{100 + 1},
	})
	suite.NoError(err)
	suite.Contains(status.Reason, ErrNotHealthy.Error())
}
// TestLoadBalanceWithEmptySegmentList puts the segments of all collections on
// one source node and issues a LoadBalance request per collection without
// SealedSegmentIDs. Every task handed to the scheduler must only touch
// segments that belong to the collection named in the request.
func (suite *ServiceSuite) TestLoadBalanceWithEmptySegmentList() {
	suite.loadAll()

	var (
		ctx = context.Background()
		srv = suite.server

		from = int64(1001)
		to   = int64(1002)

		distSegments      = make([]*meta.Segment, 0)
		segmentsByCollect = make(map[int64][]int64)
	)

	// Register both nodes in the first replica of each collection and collect
	// the segment distribution for the source node.
	for _, collectionID := range suite.collections {
		firstReplica := suite.meta.ReplicaManager.GetByCollection(collectionID)[0]
		firstReplica.AddNode(from)
		firstReplica.AddNode(to)
		// Deferred on purpose: the nodes stay in the replica until the test ends.
		defer firstReplica.RemoveNode(from)
		defer firstReplica.RemoveNode(to)

		suite.updateCollectionStatus(collectionID, querypb.LoadStatus_Loaded)
		for partitionID, segmentIDs := range suite.segments[collectionID] {
			for _, segmentID := range segmentIDs {
				distSegments = append(distSegments,
					utils.CreateTestSegment(collectionID, partitionID, segmentID, from, 1, "test-channel"))
				segmentsByCollect[collectionID] = append(segmentsByCollect[collectionID], segmentID)
			}
		}
	}
	suite.dist.SegmentDistManager.Update(from, distSegments...)

	// Each request may only trigger balancing of its own collection's segments.
	for _, collectionID := range suite.collections {
		request := &querypb.LoadBalanceRequest{
			CollectionID:  collectionID,
			SourceNodeIDs: []int64{from},
			DstNodeIDs:    []int64{to},
		}

		// Drop expectations left over from the previous collection.
		suite.taskScheduler.ExpectedCalls = make([]*mock.Call, 0)
		suite.taskScheduler.EXPECT().Add(mock.Anything).Run(func(balanceTask task.Task) {
			actions := balanceTask.Actions()
			suite.Len(actions, 2)
			grow := actions[0].(*task.SegmentAction)
			reduce := actions[1].(*task.SegmentAction)
			owned := segmentsByCollect[collectionID]
			suite.True(lo.Contains(owned, grow.SegmentID()))
			suite.True(lo.Contains(owned, reduce.SegmentID()))
			suite.Equal(to, grow.Node())
			suite.Equal(from, reduce.Node())
			balanceTask.Cancel()
		}).Return(nil)

		status, err := srv.LoadBalance(ctx, request)
		suite.NoError(err)
		suite.Equal(commonpb.ErrorCode_Success, status.ErrorCode)
		suite.taskScheduler.AssertExpectations(suite.T())
	}
}
// TestLoadBalanceFailed covers the LoadBalance error paths: a request without
// a source node, a collection that is still loading, source and destination
// nodes taken from different replicas, and a balance task that ends with an
// error.
func (suite *ServiceSuite) TestLoadBalanceFailed() {
	suite.loadAll()

	var (
		ctx = context.Background()
		srv = suite.server
	)

	// No source node in the request.
	for _, collectionID := range suite.collections {
		firstReplica := suite.meta.ReplicaManager.GetByCollection(collectionID)[0]
		to := firstReplica.GetNodes()[1]
		request := &querypb.LoadBalanceRequest{
			CollectionID:     collectionID,
			DstNodeIDs:       []int64{to},
			SealedSegmentIDs: suite.getAllSegments(collectionID),
		}

		status, err := srv.LoadBalance(ctx, request)
		suite.NoError(err)
		suite.Equal(commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
		suite.Contains(status.Reason, "source nodes can only contain 1 node")
	}

	// Collection is still in the Loading state.
	for _, collectionID := range suite.collections {
		firstReplica := suite.meta.ReplicaManager.GetByCollection(collectionID)[0]
		from := firstReplica.GetNodes()[0]
		to := firstReplica.GetNodes()[1]
		suite.updateCollectionStatus(collectionID, querypb.LoadStatus_Loading)
		request := &querypb.LoadBalanceRequest{
			CollectionID:     collectionID,
			SourceNodeIDs:    []int64{from},
			DstNodeIDs:       []int64{to},
			SealedSegmentIDs: suite.getAllSegments(collectionID),
		}

		status, err := srv.LoadBalance(ctx, request)
		suite.NoError(err)
		suite.Equal(commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
		suite.Contains(status.Reason, "can't balance segments of not fully loaded collection")
	}

	// Source and destination belong to different replicas; this needs a
	// collection with more than one replica.
	for _, collectionID := range suite.collections {
		if suite.replicaNumber[collectionID] > 1 {
			replicas := suite.meta.ReplicaManager.GetByCollection(collectionID)
			from := replicas[0].GetNodes()[0]
			to := replicas[1].GetNodes()[0]
			suite.updateCollectionStatus(collectionID, querypb.LoadStatus_Loaded)
			suite.updateSegmentDist(collectionID, from)
			request := &querypb.LoadBalanceRequest{
				CollectionID:     collectionID,
				SourceNodeIDs:    []int64{from},
				DstNodeIDs:       []int64{to},
				SealedSegmentIDs: suite.getAllSegments(collectionID),
			}

			status, err := srv.LoadBalance(ctx, request)
			suite.NoError(err)
			suite.Equal(commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
			suite.Contains(status.Reason, "destination nodes have to be in the same replica of source node")
		}
	}

	// The scheduled balance task finishes with an error.
	for _, collectionID := range suite.collections {
		firstReplica := suite.meta.ReplicaManager.GetByCollection(collectionID)[0]
		from := firstReplica.GetNodes()[0]
		to := firstReplica.GetNodes()[1]
		suite.updateCollectionStatus(collectionID, querypb.LoadStatus_Loaded)
		suite.updateSegmentDist(collectionID, from)
		request := &querypb.LoadBalanceRequest{
			CollectionID:     collectionID,
			SourceNodeIDs:    []int64{from},
			DstNodeIDs:       []int64{to},
			SealedSegmentIDs: suite.getAllSegments(collectionID),
		}

		suite.taskScheduler.EXPECT().Add(mock.Anything).Run(func(failing task.Task) {
			failing.SetErr(task.ErrTaskCanceled)
			failing.Cancel()
		}).Return(nil)

		status, err := srv.LoadBalance(ctx, request)
		suite.NoError(err)
		suite.Equal(commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
		suite.Contains(status.Reason, "failed to balance segments")
		suite.Contains(status.Reason, task.ErrTaskCanceled.Error())
	}
}
// TestShowConfigurations checks that the pattern "Port" yields exactly the
// "querycoord.port" entry, and that a server that is not in the healthy state
// reports ErrNotHealthy.
func (suite *ServiceSuite) TestShowConfigurations() {
	var (
		ctx = context.Background()
		srv = suite.server
	)

	rsp, err := srv.ShowConfigurations(ctx, &internalpb.ShowConfigurationsRequest{
		Pattern: "Port",
	})
	suite.NoError(err)
	suite.Equal(commonpb.ErrorCode_Success, rsp.Status.ErrorCode)
	suite.Len(rsp.Configuations, 1)
	suite.Equal("querycoord.port", rsp.Configuations[0].Key)

	// A server that is still initializing reports ErrNotHealthy in the status.
	srv.UpdateStateCode(commonpb.StateCode_Initializing)
	rsp, err = srv.ShowConfigurations(ctx, &internalpb.ShowConfigurationsRequest{
		Pattern: "Port",
	})
	suite.NoError(err)
	suite.Contains(rsp.Status.Reason, ErrNotHealthy.Error())
}
// TestGetMetrics issues a "system_info" metrics request with every node
// mocked to answer successfully, and then checks the response of a server
// that is not in the healthy state.
func (suite *ServiceSuite) TestGetMetrics() {
	var (
		ctx = context.Background()
		srv = suite.server
	)

	// Each fixture node answers the metrics call with a success status.
	for _, nodeID := range suite.nodes {
		nodeRsp := &milvuspb.GetMetricsResponse{
			Status:        successStatus,
			ComponentName: "QueryNode",
		}
		suite.cluster.EXPECT().GetMetrics(ctx, nodeID, mock.Anything).Return(nodeRsp, nil)
	}

	payload, err := json.Marshal(map[string]string{
		metricsinfo.MetricTypeKey: "system_info",
	})
	suite.NoError(err)

	// A fresh request object is built for every call.
	newRequest := func() *milvuspb.GetMetricsRequest {
		return &milvuspb.GetMetricsRequest{
			Base:    &commonpb.MsgBase{},
			Request: string(payload),
		}
	}

	rsp, err := srv.GetMetrics(ctx, newRequest())
	suite.NoError(err)
	suite.Equal(commonpb.ErrorCode_Success, rsp.Status.ErrorCode)

	// A server that is still initializing reports ErrNotHealthy in the status.
	srv.UpdateStateCode(commonpb.StateCode_Initializing)
	rsp, err = srv.GetMetrics(ctx, newRequest())
	suite.NoError(err)
	suite.Contains(rsp.Status.Reason, ErrNotHealthy.Error())
}
// TestGetReplicas checks that GetReplicas returns as many replicas as each
// collection was loaded with, both without and with shard nodes requested,
// and that a server that is not in the healthy state reports ErrNotHealthy.
func (suite *ServiceSuite) TestGetReplicas() {
	suite.loadAll()

	var (
		ctx = context.Background()
		srv = suite.server
	)

	// Plain request, only the channel distribution is populated.
	for _, collectionID := range suite.collections {
		suite.updateChannelDist(collectionID)
		rsp, err := srv.GetReplicas(ctx, &milvuspb.GetReplicasRequest{
			CollectionID: collectionID,
		})
		suite.NoError(err)
		suite.Equal(commonpb.ErrorCode_Success, rsp.Status.ErrorCode)
		suite.EqualValues(suite.replicaNumber[collectionID], len(rsp.Replicas))
	}

	// Request with shard nodes; segments are placed on the first node of
	// every replica beforehand.
	for _, collectionID := range suite.collections {
		for _, replica := range suite.meta.ReplicaManager.GetByCollection(collectionID) {
			suite.updateSegmentDist(collectionID, replica.GetNodes()[0])
		}
		suite.updateChannelDist(collectionID)

		rsp, err := srv.GetReplicas(ctx, &milvuspb.GetReplicasRequest{
			CollectionID:   collectionID,
			WithShardNodes: true,
		})
		suite.NoError(err)
		suite.Equal(commonpb.ErrorCode_Success, rsp.Status.ErrorCode)
		suite.EqualValues(suite.replicaNumber[collectionID], len(rsp.Replicas))
	}

	// A server that is still initializing reports ErrNotHealthy in the status.
	srv.UpdateStateCode(commonpb.StateCode_Initializing)
	rsp, err := srv.GetReplicas(ctx, &milvuspb.GetReplicasRequest{
		CollectionID: suite.collections[0],
	})
	suite.NoError(err)
	suite.Contains(rsp.Status.Reason, ErrNotHealthy.Error())
}
// TestCheckHealth covers three CheckHealth scenarios:
//   - the server itself is not healthy,
//   - the server is healthy but every node reports an abnormal state,
//   - the server and every node are healthy.
func (suite *ServiceSuite) TestCheckHealth() {
	ctx := context.Background()
	server := suite.server

	// expectNodeStates makes each fixture node answer GetComponentStates
	// exactly once with the given state code and a success status.
	expectNodeStates := func(code commonpb.StateCode) {
		for _, node := range suite.nodes {
			suite.cluster.EXPECT().GetComponentStates(mock.Anything, node).Return(
				&milvuspb.ComponentStates{
					State:  &milvuspb.ComponentInfo{StateCode: code},
					Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
				},
				nil).Once()
		}
	}

	// Test for server is not healthy
	server.UpdateStateCode(commonpb.StateCode_Initializing)
	resp, err := server.CheckHealth(ctx, &milvuspb.CheckHealthRequest{})
	suite.NoError(err)
	// Boolean assertions avoid swapping testify's (expected, actual) order.
	suite.False(resp.IsHealthy)
	suite.NotEmpty(resp.Reasons)

	// Test for components state fail
	expectNodeStates(commonpb.StateCode_Abnormal)
	server.UpdateStateCode(commonpb.StateCode_Healthy)
	resp, err = server.CheckHealth(ctx, &milvuspb.CheckHealthRequest{})
	suite.NoError(err)
	suite.False(resp.IsHealthy)
	suite.NotEmpty(resp.Reasons)

	// Test for server is healthy
	expectNodeStates(commonpb.StateCode_Healthy)
	resp, err = server.CheckHealth(ctx, &milvuspb.CheckHealthRequest{})
	suite.NoError(err)
	suite.True(resp.IsHealthy)
	suite.Empty(resp.Reasons)
}
// TestGetShardLeaders checks that, with fresh heartbeats, GetShardLeaders
// returns one shard per channel and one leader node per replica, and that a
// server that is not in the healthy state reports ErrNotHealthy.
func (suite *ServiceSuite) TestGetShardLeaders() {
	suite.loadAll()

	var (
		ctx = context.Background()
		srv = suite.server
	)

	for _, collectionID := range suite.collections {
		suite.updateCollectionStatus(collectionID, querypb.LoadStatus_Loaded)
		suite.updateChannelDist(collectionID)
		request := &querypb.GetShardLeadersRequest{
			CollectionID: collectionID,
		}

		suite.fetchHeartbeats(time.Now())
		rsp, err := srv.GetShardLeaders(ctx, request)
		suite.NoError(err)
		suite.Equal(commonpb.ErrorCode_Success, rsp.Status.ErrorCode)
		suite.Len(rsp.Shards, len(suite.channels[collectionID]))

		wantLeaders := int(suite.replicaNumber[collectionID])
		for _, shard := range rsp.Shards {
			suite.Len(shard.NodeIds, wantLeaders)
		}
	}

	// A server that is still initializing reports ErrNotHealthy in the status.
	srv.UpdateStateCode(commonpb.StateCode_Initializing)
	rsp, err := srv.GetShardLeaders(ctx, &querypb.GetShardLeadersRequest{
		CollectionID: suite.collections[0],
	})
	suite.NoError(err)
	suite.Contains(rsp.Status.Reason, ErrNotHealthy.Error())
}
// TestGetShardLeadersFailed checks that GetShardLeaders answers
// NoReplicaAvailable when the last heartbeat is older than the configured
// availability interval, and when the channel distribution carries no
// segments.
func (suite *ServiceSuite) TestGetShardLeadersFailed() {
	suite.loadAll()

	var (
		ctx = context.Background()
		srv = suite.server
	)

	for _, collectionID := range suite.collections {
		suite.updateCollectionStatus(collectionID, querypb.LoadStatus_Loaded)
		suite.updateChannelDist(collectionID)
		request := &querypb.GetShardLeadersRequest{
			CollectionID: collectionID,
		}

		// Heartbeats just outside the availability interval.
		staleAt := time.Now().Add(-Params.QueryCoordCfg.HeartbeatAvailableInterval - 1)
		suite.fetchHeartbeats(staleAt)
		rsp, err := srv.GetShardLeaders(ctx, request)
		suite.NoError(err)
		suite.Equal(commonpb.ErrorCode_NoReplicaAvailable, rsp.Status.ErrorCode)

		// Fresh heartbeats, but the segments are not fully loaded.
		suite.updateChannelDistWithoutSegment(collectionID)
		suite.fetchHeartbeats(time.Now())
		rsp, err = srv.GetShardLeaders(ctx, request)
		suite.NoError(err)
		suite.Equal(commonpb.ErrorCode_NoReplicaAvailable, rsp.Status.ErrorCode)
	}
}
// loadAll loads every fixture collection through the job scheduler, using a
// load-collection job or a load-partitions job according to the collection's
// configured load type. After each job it verifies the recorded replica
// number and existence of the collection, then updates the collection's
// current target.
func (suite *ServiceSuite) loadAll() {
	ctx := context.Background()

	for _, collectionID := range suite.collections {
		replicaNum := suite.replicaNumber[collectionID]
		suite.expectGetRecoverInfo(collectionID)

		switch suite.loadTypes[collectionID] {
		case querypb.LoadType_LoadCollection:
			suite.broker.EXPECT().
				GetPartitions(mock.Anything, collectionID).
				Return(suite.partitions[collectionID], nil)

			loadJob := job.NewLoadCollectionJob(
				ctx,
				&querypb.LoadCollectionRequest{
					CollectionID:  collectionID,
					ReplicaNumber: replicaNum,
				},
				suite.dist,
				suite.meta,
				suite.targetMgr,
				suite.broker,
				suite.nodeMgr,
			)
			suite.jobScheduler.Add(loadJob)
			suite.NoError(loadJob.Wait())

			suite.EqualValues(replicaNum, suite.meta.GetReplicaNumber(collectionID))
			suite.True(suite.meta.Exist(collectionID))
			suite.NotNil(suite.meta.GetCollection(collectionID))
			suite.targetMgr.UpdateCollectionCurrentTarget(collectionID)

		default:
			loadJob := job.NewLoadPartitionJob(
				ctx,
				&querypb.LoadPartitionsRequest{
					CollectionID:  collectionID,
					PartitionIDs:  suite.partitions[collectionID],
					ReplicaNumber: replicaNum,
				},
				suite.dist,
				suite.meta,
				suite.targetMgr,
				suite.broker,
				suite.nodeMgr,
			)
			suite.jobScheduler.Add(loadJob)
			suite.NoError(loadJob.Wait())

			suite.EqualValues(replicaNum, suite.meta.GetReplicaNumber(collectionID))
			suite.True(suite.meta.Exist(collectionID))
			suite.NotNil(suite.meta.GetPartitionsByCollection(collectionID))
			suite.targetMgr.UpdateCollectionCurrentTarget(collectionID)
		}
	}
}
// assertLoaded asserts that the collection exists in meta and that each of
// its fixture channels and segments is present in the next target.
func (suite *ServiceSuite) assertLoaded(collection int64) {
	suite.True(suite.meta.Exist(collection))

	channelNames := suite.channels[collection]
	for i := range channelNames {
		suite.NotNil(suite.targetMgr.GetDmChannel(collection, channelNames[i], meta.NextTarget))
	}

	for _, segmentIDs := range suite.segments[collection] {
		for _, segmentID := range segmentIDs {
			suite.NotNil(suite.targetMgr.GetHistoricalSegment(collection, segmentID, meta.NextTarget))
		}
	}
}
// assertPartitionLoaded asserts that the collection exists in meta, that each
// of its fixture channels is present in the current target, and that every
// segment of the given partitions is present in the current target as well.
func (suite *ServiceSuite) assertPartitionLoaded(collection int64, partitions ...int64) {
	suite.True(suite.meta.Exist(collection))

	channelNames := suite.channels[collection]
	for i := range channelNames {
		suite.NotNil(suite.targetMgr.GetDmChannel(collection, channelNames[i], meta.CurrentTarget))
	}

	wanted := typeutil.NewUniqueSet(partitions...)
	for partitionID, segmentIDs := range suite.segments[collection] {
		// Only the partitions named by the caller are checked.
		if wanted.Contain(partitionID) {
			for _, segmentID := range segmentIDs {
				suite.NotNil(suite.targetMgr.GetHistoricalSegment(collection, segmentID, meta.CurrentTarget))
			}
		}
	}
}
// assertReleased asserts that the collection no longer exists in meta, that
// none of the suite's channels for it are in the current target, and that none
// of its segments are in either the current or the next target.
func (suite *ServiceSuite) assertReleased(collection int64) {
	suite.False(suite.meta.Exist(collection))

	for _, ch := range suite.channels[collection] {
		dmChannel := suite.targetMgr.GetDmChannel(collection, ch, meta.CurrentTarget)
		suite.Nil(dmChannel)
	}

	for _, segmentIDs := range suite.segments[collection] {
		for _, id := range segmentIDs {
			current := suite.targetMgr.GetHistoricalSegment(collection, id, meta.CurrentTarget)
			suite.Nil(current)

			next := suite.targetMgr.GetHistoricalSegment(collection, id, meta.NextTarget)
			suite.Nil(next)
		}
	}
}
// assertSegments checks that segments has as many entries as the suite knows
// segments for the collection, and that each entry's segment ID is one of
// them. It records an assertion failure and returns false on the first
// mismatch; it returns true when every check passes.
func (suite *ServiceSuite) assertSegments(collection int64, segments []*querypb.SegmentInfo) bool {
	expected := typeutil.NewUniqueSet(suite.getAllSegments(collection)...)

	sameSize := suite.Len(segments, expected.Len())
	if !sameSize {
		return false
	}

	for i := range segments {
		if known := suite.Contains(expected, segments[i].GetSegmentID()); !known {
			return false
		}
	}
	return true
}
// expectGetRecoverInfo registers one GetRecoveryInfo expectation on the mocked
// broker for each partition of the collection. Each expectation returns a
// VchannelInfo per suite channel and a SegmentBinlogs per segment of that
// partition, with a nil error.
func (suite *ServiceSuite) expectGetRecoverInfo(collection int64) {
	channelNames := suite.channels[collection]

	vChannels := make([]*datapb.VchannelInfo, 0, len(channelNames))
	for _, name := range channelNames {
		info := &datapb.VchannelInfo{
			CollectionID: collection,
			ChannelName:  name,
		}
		vChannels = append(vChannels, info)
	}

	for partitionID, segmentIDs := range suite.segments[collection] {
		binlogs := make([]*datapb.SegmentBinlogs, 0, len(segmentIDs))
		for _, id := range segmentIDs {
			// The insert channel is picked by segment ID parity, so this
			// presumably relies on the fixture providing at least two
			// channels per collection.
			binlogs = append(binlogs, &datapb.SegmentBinlogs{
				SegmentID:     id,
				InsertChannel: channelNames[id%2],
			})
		}

		call := suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, collection, partitionID)
		call.Return(vChannels, binlogs, nil)
	}
}
// getAllSegments returns the segment IDs of every partition of the collection
// as a single slice. The result is never nil; its order follows map iteration
// and is therefore unspecified.
func (suite *ServiceSuite) getAllSegments(collection int64) []int64 {
	byPartition := suite.segments[collection]

	all := make([]int64, 0)
	for partitionID := range byPartition {
		all = append(all, byPartition[partitionID]...)
	}
	return all
}
// updateSegmentDist builds a test segment for every segment of the collection,
// all placed on the given node, and stores them in the segment distribution
// manager as that node's segments.
func (suite *ServiceSuite) updateSegmentDist(collection, node int64) {
	const channel = "test-channel"

	nodeSegments := []*meta.Segment{}
	for partitionID, segmentIDs := range suite.segments[collection] {
		for _, segmentID := range segmentIDs {
			segment := utils.CreateTestSegment(collection, partitionID, segmentID, node, 1, channel)
			nodeSegments = append(nodeSegments, segment)
		}
	}

	suite.dist.SegmentDistManager.Update(node, nodeSegments...)
}
// updateChannelDist assigns the collection's channels to replica nodes. Within
// each replica, the n-th node visited receives the n-th channel, and iteration
// stops once every channel has been handed out. For each assignment it updates
// the channel distribution and publishes a leader view that lists every
// segment of the collection as residing on that node.
func (suite *ServiceSuite) updateChannelDist(collection int64) {
	channels := suite.channels[collection]
	channelCount := len(channels)
	segmentIDs := lo.Flatten(lo.Values(suite.segments[collection]))

	for _, replica := range suite.meta.ReplicaManager.GetByCollection(collection) {
		next := 0
		for _, node := range replica.GetNodes() {
			channel := channels[next]

			vchannel := &datapb.VchannelInfo{
				CollectionID: collection,
				ChannelName:  channel,
			}
			suite.dist.ChannelDistManager.Update(node, meta.DmChannelFromVChannel(vchannel))

			view := &meta.LeaderView{
				ID:           node,
				CollectionID: collection,
				Channel:      channel,
				// The version is stamped per segment with the current Unix time.
				Segments: lo.SliceToMap(segmentIDs, func(id int64) (int64, *querypb.SegmentDist) {
					dist := &querypb.SegmentDist{
						NodeID:  node,
						Version: time.Now().Unix(),
					}
					return id, dist
				}),
			}
			suite.dist.LeaderViewManager.Update(node, view)

			next++
			if next >= channelCount {
				break
			}
		}
	}
}
// updateChannelDistWithoutSegment assigns the collection's channels to replica
// nodes. Within each replica, the n-th node visited receives the n-th channel,
// and iteration stops once every channel has been handed out. The leader view
// published for each assignment carries no segments.
func (suite *ServiceSuite) updateChannelDistWithoutSegment(collection int64) {
	channels := suite.channels[collection]
	channelCount := len(channels)

	for _, replica := range suite.meta.ReplicaManager.GetByCollection(collection) {
		next := 0
		for _, node := range replica.GetNodes() {
			channel := channels[next]

			vchannel := &datapb.VchannelInfo{
				CollectionID: collection,
				ChannelName:  channel,
			}
			suite.dist.ChannelDistManager.Update(node, meta.DmChannelFromVChannel(vchannel))

			view := &meta.LeaderView{
				ID:           node,
				CollectionID: collection,
				Channel:      channel,
			}
			suite.dist.LeaderViewManager.Update(node, view)

			next++
			if next >= channelCount {
				break
			}
		}
	}
}
// updateCollectionStatus writes the given load status into meta. If meta has a
// collection entry for collectionID, a clone of it is updated; otherwise a
// clone of each partition entry of the collection is updated. The load
// percentage is set to 100 for LoadStatus_Loaded and to 0 for any other status.
func (suite *ServiceSuite) updateCollectionStatus(collectionID int64, status querypb.LoadStatus) {
	loaded := status == querypb.LoadStatus_Loaded

	if existing := suite.meta.GetCollection(collectionID); existing != nil {
		updated := existing.Clone()
		if loaded {
			updated.LoadPercentage = 100
		} else {
			updated.LoadPercentage = 0
		}
		updated.CollectionLoadInfo.Status = status
		suite.meta.UpdateCollection(updated)
		return
	}

	// No collection-level entry: fall back to the partition-level entries.
	for _, existing := range suite.meta.GetPartitionsByCollection(collectionID) {
		updated := existing.Clone()
		if loaded {
			updated.LoadPercentage = 100
		} else {
			updated.LoadPercentage = 0
		}
		updated.PartitionLoadInfo.Status = status
		suite.meta.UpdatePartition(updated)
	}
}
// fetchHeartbeats sets the last-heartbeat time of every node in suite.nodes to
// the given instant.
func (suite *ServiceSuite) fetchHeartbeats(time time.Time) {
	for _, nodeID := range suite.nodes {
		suite.nodeMgr.Get(nodeID).SetLastHeartbeat(time)
	}
}
func TestService(t *testing.T) {
suite.Run(t, new(ServiceSuite))
}