mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-02 11:59:00 +08:00
Fix handling segment change logic (#16695)
Dispatch segmentChangeInfo to ShardCluster leader Hold segment remove before search is done Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
parent
bb6cd4b484
commit
3a6db2faeb
@ -58,7 +58,6 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/util"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/milvus-io/milvus/internal/util/retry"
|
||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
@ -452,12 +451,7 @@ func (node *QueryNode) watchChangeInfo() {
|
||||
log.Warn("Unmarshal SealedSegmentsChangeInfo failed", zap.Any("error", err.Error()))
|
||||
continue
|
||||
}
|
||||
go func() {
|
||||
err = node.removeSegments(info)
|
||||
if err != nil {
|
||||
log.Warn("cleanup segments failed", zap.Any("error", err.Error()))
|
||||
}
|
||||
}()
|
||||
go node.handleSealedSegmentsChangeInfo(info)
|
||||
default:
|
||||
// do nothing
|
||||
}
|
||||
@ -466,58 +460,47 @@ func (node *QueryNode) watchChangeInfo() {
|
||||
}
|
||||
}
|
||||
|
||||
func (node *QueryNode) waitChangeInfo(segmentChangeInfos *querypb.SealedSegmentsChangeInfo) error {
|
||||
fn := func() error {
|
||||
/*
|
||||
for _, info := range segmentChangeInfos.Infos {
|
||||
canDoLoadBalance := true
|
||||
// make sure all query channel already received segment location changes
|
||||
// Check online segments:
|
||||
for _, segmentInfo := range info.OnlineSegments {
|
||||
if node.queryService.hasQueryCollection(segmentInfo.CollectionID) {
|
||||
qc, err := node.queryService.getQueryCollection(segmentInfo.CollectionID)
|
||||
if err != nil {
|
||||
canDoLoadBalance = false
|
||||
break
|
||||
}
|
||||
if info.OnlineNodeID == Params.QueryNodeCfg.QueryNodeID && !qc.globalSegmentManager.hasGlobalSealedSegment(segmentInfo.SegmentID) {
|
||||
canDoLoadBalance = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
// Check offline segments:
|
||||
for _, segmentInfo := range info.OfflineSegments {
|
||||
if node.queryService.hasQueryCollection(segmentInfo.CollectionID) {
|
||||
qc, err := node.queryService.getQueryCollection(segmentInfo.CollectionID)
|
||||
if err != nil {
|
||||
canDoLoadBalance = false
|
||||
break
|
||||
}
|
||||
if info.OfflineNodeID == Params.QueryNodeCfg.QueryNodeID && qc.globalSegmentManager.hasGlobalSealedSegment(segmentInfo.SegmentID) {
|
||||
canDoLoadBalance = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if canDoLoadBalance {
|
||||
return nil
|
||||
}
|
||||
return errors.New(fmt.Sprintln("waitChangeInfo failed, infoID = ", segmentChangeInfos.Base.GetMsgID()))
|
||||
}
|
||||
*/
|
||||
return nil
|
||||
func (node *QueryNode) handleSealedSegmentsChangeInfo(info *querypb.SealedSegmentsChangeInfo) {
|
||||
for _, line := range info.GetInfos() {
|
||||
vchannel, err := validateChangeChannel(line)
|
||||
if err != nil {
|
||||
log.Warn("failed to validate vchannel for SegmentChangeInfo", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
node.ShardClusterService.HandoffVChannelSegments(vchannel, line)
|
||||
}
|
||||
}
|
||||
|
||||
func validateChangeChannel(info *querypb.SegmentChangeInfo) (string, error) {
|
||||
if len(info.GetOnlineSegments()) == 0 && len(info.GetOfflineSegments()) == 0 {
|
||||
return "", errors.New("SegmentChangeInfo with no segments info")
|
||||
}
|
||||
|
||||
return retry.Do(node.queryNodeLoopCtx, fn, retry.Attempts(50))
|
||||
var channelName string
|
||||
|
||||
for _, segment := range info.GetOnlineSegments() {
|
||||
if channelName == "" {
|
||||
channelName = segment.GetDmChannel()
|
||||
}
|
||||
if segment.GetDmChannel() != channelName {
|
||||
return "", fmt.Errorf("found multilple channel name in one SegmentChangeInfo, channel1: %s, channel 2:%s", channelName, segment.GetDmChannel())
|
||||
}
|
||||
}
|
||||
for _, segment := range info.GetOfflineSegments() {
|
||||
if channelName == "" {
|
||||
channelName = segment.GetDmChannel()
|
||||
}
|
||||
if segment.GetDmChannel() != channelName {
|
||||
return "", fmt.Errorf("found multilple channel name in one SegmentChangeInfo, channel1: %s, channel 2:%s", channelName, segment.GetDmChannel())
|
||||
}
|
||||
}
|
||||
|
||||
return channelName, nil
|
||||
}
|
||||
|
||||
// remove the segments since it's already compacted or balanced to other QueryNodes
|
||||
func (node *QueryNode) removeSegments(segmentChangeInfos *querypb.SealedSegmentsChangeInfo) error {
|
||||
err := node.waitChangeInfo(segmentChangeInfos)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
node.streaming.replica.queryLock()
|
||||
node.historical.replica.queryLock()
|
||||
|
@ -31,6 +31,7 @@ import (
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.etcd.io/etcd/server/v3/embed"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
@ -329,17 +330,6 @@ func genSimpleQueryNodeToTestWatchChangeInfo(ctx context.Context) (*QueryNode, e
|
||||
return node, nil
|
||||
}
|
||||
|
||||
func TestQueryNode_waitChangeInfo(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
node, err := genSimpleQueryNodeToTestWatchChangeInfo(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = node.waitChangeInfo(genSimpleChangeInfo())
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestQueryNode_adjustByChangeInfo(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
@ -534,3 +524,139 @@ func TestQueryNode_watchService(t *testing.T) {
|
||||
assert.True(t, flag)
|
||||
})
|
||||
}
|
||||
|
||||
func TestQueryNode_validateChangeChannel(t *testing.T) {
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
info *querypb.SegmentChangeInfo
|
||||
expectedError bool
|
||||
expectedChannelName string
|
||||
}
|
||||
|
||||
cases := []testCase{
|
||||
{
|
||||
name: "empty info",
|
||||
info: &querypb.SegmentChangeInfo{},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "normal segment change info",
|
||||
info: &querypb.SegmentChangeInfo{
|
||||
OnlineSegments: []*querypb.SegmentInfo{
|
||||
{DmChannel: defaultDMLChannel},
|
||||
},
|
||||
OfflineSegments: []*querypb.SegmentInfo{
|
||||
{DmChannel: defaultDMLChannel},
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
expectedChannelName: defaultDMLChannel,
|
||||
},
|
||||
{
|
||||
name: "empty offline change info",
|
||||
info: &querypb.SegmentChangeInfo{
|
||||
OnlineSegments: []*querypb.SegmentInfo{
|
||||
{DmChannel: defaultDMLChannel},
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
expectedChannelName: defaultDMLChannel,
|
||||
},
|
||||
{
|
||||
name: "empty online change info",
|
||||
info: &querypb.SegmentChangeInfo{
|
||||
OfflineSegments: []*querypb.SegmentInfo{
|
||||
{DmChannel: defaultDMLChannel},
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
expectedChannelName: defaultDMLChannel,
|
||||
},
|
||||
{
|
||||
name: "different channel in online",
|
||||
info: &querypb.SegmentChangeInfo{
|
||||
OnlineSegments: []*querypb.SegmentInfo{
|
||||
{DmChannel: defaultDMLChannel},
|
||||
{DmChannel: "other_channel"},
|
||||
},
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "different channel in offline",
|
||||
info: &querypb.SegmentChangeInfo{
|
||||
OnlineSegments: []*querypb.SegmentInfo{
|
||||
{DmChannel: defaultDMLChannel},
|
||||
},
|
||||
OfflineSegments: []*querypb.SegmentInfo{
|
||||
{DmChannel: "other_channel"},
|
||||
},
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
channelName, err := validateChangeChannel(tc.info)
|
||||
if tc.expectedError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.expectedChannelName, channelName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryNode_handleSealedSegmentsChangeInfo(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
qn, err := genSimpleQueryNode(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("empty info", func(t *testing.T) {
|
||||
assert.NotPanics(t, func() {
|
||||
qn.handleSealedSegmentsChangeInfo(&querypb.SealedSegmentsChangeInfo{})
|
||||
})
|
||||
assert.NotPanics(t, func() {
|
||||
qn.handleSealedSegmentsChangeInfo(nil)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("normal segment change info", func(t *testing.T) {
|
||||
assert.NotPanics(t, func() {
|
||||
qn.handleSealedSegmentsChangeInfo(&querypb.SealedSegmentsChangeInfo{
|
||||
Infos: []*querypb.SegmentChangeInfo{
|
||||
{
|
||||
OnlineSegments: []*querypb.SegmentInfo{
|
||||
{DmChannel: defaultDMLChannel},
|
||||
},
|
||||
OfflineSegments: []*querypb.SegmentInfo{
|
||||
{DmChannel: defaultDMLChannel},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("bad change info", func(t *testing.T) {
|
||||
assert.NotPanics(t, func() {
|
||||
qn.handleSealedSegmentsChangeInfo(&querypb.SealedSegmentsChangeInfo{
|
||||
Infos: []*querypb.SegmentChangeInfo{
|
||||
{
|
||||
OnlineSegments: []*querypb.SegmentInfo{
|
||||
{DmChannel: defaultDMLChannel},
|
||||
},
|
||||
OfflineSegments: []*querypb.SegmentInfo{
|
||||
{DmChannel: "other_channel"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
})
|
||||
}
|
||||
|
@ -27,6 +27,7 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/util/errorutil"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@ -78,6 +79,7 @@ type segmentEvent struct {
|
||||
type shardQueryNode interface {
|
||||
Search(context.Context, *querypb.SearchRequest) (*internalpb.SearchResults, error)
|
||||
Query(context.Context, *querypb.QueryRequest) (*internalpb.RetrieveResults, error)
|
||||
ReleaseSegments(ctx context.Context, in *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error)
|
||||
Stop() error
|
||||
}
|
||||
|
||||
@ -363,7 +365,6 @@ func (sc *ShardCluster) watchNodes(evtCh <-chan nodeEvent) {
|
||||
for {
|
||||
select {
|
||||
case evt, ok := <-evtCh:
|
||||
log.Debug("node event", zap.Any("evt", evt))
|
||||
if !ok {
|
||||
log.Warn("ShardCluster node channel closed", zap.Int64("collectionID", sc.collectionID), zap.Int64("replicaID", sc.replicaID))
|
||||
return
|
||||
@ -514,6 +515,8 @@ func (sc *ShardCluster) HandoffSegments(info *querypb.SegmentChangeInfo) error {
|
||||
offlineSegments = append(offlineSegments, seg.GetSegmentID())
|
||||
}
|
||||
sc.waitSegmentsNotInUse(offlineSegments)
|
||||
|
||||
removes := make(map[int64][]int64) // nodeID => []segmentIDs
|
||||
// remove offline segments record
|
||||
for _, seg := range info.OfflineSegments {
|
||||
// filter out segments not maintained in this cluster
|
||||
@ -521,11 +524,39 @@ func (sc *ShardCluster) HandoffSegments(info *querypb.SegmentChangeInfo) error {
|
||||
continue
|
||||
}
|
||||
sc.removeSegment(segmentEvent{segmentID: seg.GetSegmentID(), nodeID: seg.GetNodeID()})
|
||||
|
||||
removes[seg.GetNodeID()] = append(removes[seg.GetNodeID()], seg.SegmentID)
|
||||
}
|
||||
|
||||
var errs errorutil.ErrorList
|
||||
|
||||
// notify querynode(s) to release segments
|
||||
for nodeID, segmentIDs := range removes {
|
||||
node, ok := sc.getNode(nodeID)
|
||||
if !ok {
|
||||
log.Warn("node not in cluster", zap.Int64("nodeID", nodeID), zap.Int64("collectionID", sc.collectionID), zap.String("vchannel", sc.vchannelName))
|
||||
errs = append(errs, fmt.Errorf("node not in cluster nodeID %d", nodeID))
|
||||
continue
|
||||
}
|
||||
state, err := node.client.ReleaseSegments(context.Background(), &querypb.ReleaseSegmentsRequest{
|
||||
CollectionID: sc.collectionID,
|
||||
SegmentIDs: segmentIDs,
|
||||
})
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
continue
|
||||
}
|
||||
if state.GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
errs = append(errs, fmt.Errorf("Release segments failed with reason: %s", state.GetReason()))
|
||||
}
|
||||
}
|
||||
|
||||
// finish handoff and remove it from pending list
|
||||
sc.finishHandoff(token)
|
||||
|
||||
if len(errs) > 0 {
|
||||
return errs
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -137,3 +137,14 @@ func (s *ShardClusterService) SyncReplicaSegments(vchannelName string, distribut
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandoffVChannelSegments dispatches SegmentChangeInfo to related ShardCluster with VChannel
|
||||
func (s *ShardClusterService) HandoffVChannelSegments(vchannel string, info *querypb.SegmentChangeInfo) error {
|
||||
raw, ok := s.clusters.Load(vchannel)
|
||||
if !ok {
|
||||
// not leader for this channel, ignore without error
|
||||
return nil
|
||||
}
|
||||
sc := raw.(*ShardCluster)
|
||||
return sc.HandoffSegments(info)
|
||||
}
|
||||
|
@ -88,3 +88,23 @@ func TestShardClusterService_SyncReplicaSegments(t *testing.T) {
|
||||
assert.Equal(t, segmentStateLoaded, segment.state)
|
||||
})
|
||||
}
|
||||
|
||||
func TestShardClusterService_HandoffVChannelSegments(t *testing.T) {
|
||||
qn, err := genSimpleQueryNode(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
client := v3client.New(embedetcdServer.Server)
|
||||
defer client.Close()
|
||||
session := sessionutil.NewSession(context.Background(), "/by-dev/sessions/unittest/querynode/", client)
|
||||
clusterService := newShardClusterService(client, session, qn)
|
||||
|
||||
err = clusterService.HandoffVChannelSegments(defaultDMLChannel, &querypb.SegmentChangeInfo{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
clusterService.addShardCluster(defaultCollectionID, defaultReplicaID, defaultDMLChannel)
|
||||
//TODO change shardCluster to interface to mock test behavior
|
||||
assert.NotPanics(t, func() {
|
||||
err = clusterService.HandoffVChannelSegments(defaultDMLChannel, &querypb.SegmentChangeInfo{})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
@ -22,6 +22,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -47,10 +48,12 @@ func (m *mockSegmentDetector) watchSegments(collectionID int64, replicaID int64,
|
||||
}
|
||||
|
||||
type mockShardQueryNode struct {
|
||||
searchResult *internalpb.SearchResults
|
||||
searchErr error
|
||||
queryResult *internalpb.RetrieveResults
|
||||
queryErr error
|
||||
searchResult *internalpb.SearchResults
|
||||
searchErr error
|
||||
queryResult *internalpb.RetrieveResults
|
||||
queryErr error
|
||||
releaseSegmentsResult *commonpb.Status
|
||||
releaseSegmentsErr error
|
||||
}
|
||||
|
||||
func (m *mockShardQueryNode) Search(_ context.Context, _ *querypb.SearchRequest) (*internalpb.SearchResults, error) {
|
||||
@ -61,6 +64,10 @@ func (m *mockShardQueryNode) Query(_ context.Context, _ *querypb.QueryRequest) (
|
||||
return m.queryResult, m.queryErr
|
||||
}
|
||||
|
||||
func (m *mockShardQueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
|
||||
return m.releaseSegmentsResult, m.releaseSegmentsErr
|
||||
}
|
||||
|
||||
func (m *mockShardQueryNode) Stop() error {
|
||||
return nil
|
||||
}
|
||||
@ -1336,7 +1343,7 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
|
||||
}, buildMockQueryNode)
|
||||
defer sc.Close()
|
||||
|
||||
sc.HandoffSegments(&querypb.SegmentChangeInfo{
|
||||
err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
|
||||
OnlineSegments: []*querypb.SegmentInfo{
|
||||
{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName},
|
||||
},
|
||||
@ -1344,6 +1351,10 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
|
||||
{SegmentID: 1, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
sc.mut.RLock()
|
||||
_, has := sc.segments[1]
|
||||
@ -1383,7 +1394,7 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
|
||||
}, buildMockQueryNode)
|
||||
defer sc.Close()
|
||||
|
||||
sc.HandoffSegments(&querypb.SegmentChangeInfo{
|
||||
err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
|
||||
OnlineSegments: []*querypb.SegmentInfo{
|
||||
{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName},
|
||||
{SegmentID: 4, NodeID: 2, CollectionID: otherCollectionID, DmChannel: otherVchannelName},
|
||||
@ -1393,6 +1404,7 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
|
||||
{SegmentID: 5, NodeID: 2, CollectionID: otherCollectionID, DmChannel: otherVchannelName},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
sc.mut.RLock()
|
||||
_, has := sc.segments[3]
|
||||
@ -1439,7 +1451,7 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
|
||||
|
||||
sig := make(chan struct{})
|
||||
go func() {
|
||||
sc.HandoffSegments(&querypb.SegmentChangeInfo{
|
||||
err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
|
||||
OnlineSegments: []*querypb.SegmentInfo{
|
||||
{SegmentID: 3, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName},
|
||||
},
|
||||
@ -1448,6 +1460,7 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
|
||||
},
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
close(sig)
|
||||
}()
|
||||
|
||||
@ -1493,4 +1506,113 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
|
||||
|
||||
assert.False(t, has)
|
||||
})
|
||||
|
||||
t.Run("handoff from non-exist node", func(t *testing.T) {
|
||||
nodeEvents := []nodeEvent{
|
||||
{
|
||||
nodeID: 1,
|
||||
nodeAddr: "addr_1",
|
||||
},
|
||||
{
|
||||
nodeID: 2,
|
||||
nodeAddr: "addr_2",
|
||||
},
|
||||
}
|
||||
|
||||
segmentEvents := []segmentEvent{
|
||||
{
|
||||
segmentID: 1,
|
||||
nodeID: 1,
|
||||
state: segmentStateLoaded,
|
||||
},
|
||||
{
|
||||
segmentID: 2,
|
||||
nodeID: 2,
|
||||
state: segmentStateLoaded,
|
||||
},
|
||||
}
|
||||
evtCh := make(chan segmentEvent, 10)
|
||||
sc := NewShardCluster(collectionID, replicaID, vchannelName,
|
||||
&mockNodeDetector{
|
||||
initNodes: nodeEvents,
|
||||
}, &mockSegmentDetector{
|
||||
initSegments: segmentEvents,
|
||||
evtCh: evtCh,
|
||||
}, buildMockQueryNode)
|
||||
defer sc.Close()
|
||||
|
||||
err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
|
||||
OnlineSegments: []*querypb.SegmentInfo{
|
||||
{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName},
|
||||
},
|
||||
OfflineSegments: []*querypb.SegmentInfo{
|
||||
{SegmentID: 1, NodeID: 3, CollectionID: collectionID, DmChannel: vchannelName},
|
||||
},
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("release failed", func(t *testing.T) {
|
||||
nodeEvents := []nodeEvent{
|
||||
{
|
||||
nodeID: 1,
|
||||
nodeAddr: "addr_1",
|
||||
},
|
||||
{
|
||||
nodeID: 2,
|
||||
nodeAddr: "addr_2",
|
||||
},
|
||||
}
|
||||
|
||||
segmentEvents := []segmentEvent{
|
||||
{
|
||||
segmentID: 1,
|
||||
nodeID: 1,
|
||||
state: segmentStateLoaded,
|
||||
},
|
||||
{
|
||||
segmentID: 2,
|
||||
nodeID: 2,
|
||||
state: segmentStateLoaded,
|
||||
},
|
||||
}
|
||||
evtCh := make(chan segmentEvent, 10)
|
||||
mqn := &mockShardQueryNode{}
|
||||
sc := NewShardCluster(collectionID, replicaID, vchannelName,
|
||||
&mockNodeDetector{
|
||||
initNodes: nodeEvents,
|
||||
}, &mockSegmentDetector{
|
||||
initSegments: segmentEvents,
|
||||
evtCh: evtCh,
|
||||
}, func(nodeID int64, addr string) shardQueryNode {
|
||||
return mqn
|
||||
})
|
||||
defer sc.Close()
|
||||
|
||||
mqn.releaseSegmentsErr = errors.New("mocked error")
|
||||
|
||||
err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
|
||||
OfflineSegments: []*querypb.SegmentInfo{
|
||||
{SegmentID: 1, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName},
|
||||
},
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
|
||||
mqn.releaseSegmentsErr = nil
|
||||
mqn.releaseSegmentsResult = &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "mocked error",
|
||||
}
|
||||
|
||||
err = sc.HandoffSegments(&querypb.SegmentChangeInfo{
|
||||
OfflineSegments: []*querypb.SegmentInfo{
|
||||
{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName},
|
||||
},
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
|
||||
})
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user