mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-02 11:59:00 +08:00
Fix handling segment change logic (#16695)
Dispatch segmentChangeInfo to ShardCluster leader Hold segment remove before search is done Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
parent
bb6cd4b484
commit
3a6db2faeb
@ -58,7 +58,6 @@ import (
|
|||||||
"github.com/milvus-io/milvus/internal/util"
|
"github.com/milvus-io/milvus/internal/util"
|
||||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||||
"github.com/milvus-io/milvus/internal/util/retry"
|
|
||||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||||
)
|
)
|
||||||
@ -452,12 +451,7 @@ func (node *QueryNode) watchChangeInfo() {
|
|||||||
log.Warn("Unmarshal SealedSegmentsChangeInfo failed", zap.Any("error", err.Error()))
|
log.Warn("Unmarshal SealedSegmentsChangeInfo failed", zap.Any("error", err.Error()))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
go func() {
|
go node.handleSealedSegmentsChangeInfo(info)
|
||||||
err = node.removeSegments(info)
|
|
||||||
if err != nil {
|
|
||||||
log.Warn("cleanup segments failed", zap.Any("error", err.Error()))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
default:
|
default:
|
||||||
// do nothing
|
// do nothing
|
||||||
}
|
}
|
||||||
@ -466,58 +460,47 @@ func (node *QueryNode) watchChangeInfo() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (node *QueryNode) waitChangeInfo(segmentChangeInfos *querypb.SealedSegmentsChangeInfo) error {
|
func (node *QueryNode) handleSealedSegmentsChangeInfo(info *querypb.SealedSegmentsChangeInfo) {
|
||||||
fn := func() error {
|
for _, line := range info.GetInfos() {
|
||||||
/*
|
vchannel, err := validateChangeChannel(line)
|
||||||
for _, info := range segmentChangeInfos.Infos {
|
if err != nil {
|
||||||
canDoLoadBalance := true
|
log.Warn("failed to validate vchannel for SegmentChangeInfo", zap.Error(err))
|
||||||
// make sure all query channel already received segment location changes
|
continue
|
||||||
// Check online segments:
|
}
|
||||||
for _, segmentInfo := range info.OnlineSegments {
|
|
||||||
if node.queryService.hasQueryCollection(segmentInfo.CollectionID) {
|
node.ShardClusterService.HandoffVChannelSegments(vchannel, line)
|
||||||
qc, err := node.queryService.getQueryCollection(segmentInfo.CollectionID)
|
}
|
||||||
if err != nil {
|
}
|
||||||
canDoLoadBalance = false
|
|
||||||
break
|
func validateChangeChannel(info *querypb.SegmentChangeInfo) (string, error) {
|
||||||
}
|
if len(info.GetOnlineSegments()) == 0 && len(info.GetOfflineSegments()) == 0 {
|
||||||
if info.OnlineNodeID == Params.QueryNodeCfg.QueryNodeID && !qc.globalSegmentManager.hasGlobalSealedSegment(segmentInfo.SegmentID) {
|
return "", errors.New("SegmentChangeInfo with no segments info")
|
||||||
canDoLoadBalance = false
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Check offline segments:
|
|
||||||
for _, segmentInfo := range info.OfflineSegments {
|
|
||||||
if node.queryService.hasQueryCollection(segmentInfo.CollectionID) {
|
|
||||||
qc, err := node.queryService.getQueryCollection(segmentInfo.CollectionID)
|
|
||||||
if err != nil {
|
|
||||||
canDoLoadBalance = false
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if info.OfflineNodeID == Params.QueryNodeCfg.QueryNodeID && qc.globalSegmentManager.hasGlobalSealedSegment(segmentInfo.SegmentID) {
|
|
||||||
canDoLoadBalance = false
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if canDoLoadBalance {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return errors.New(fmt.Sprintln("waitChangeInfo failed, infoID = ", segmentChangeInfos.Base.GetMsgID()))
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return retry.Do(node.queryNodeLoopCtx, fn, retry.Attempts(50))
|
var channelName string
|
||||||
|
|
||||||
|
for _, segment := range info.GetOnlineSegments() {
|
||||||
|
if channelName == "" {
|
||||||
|
channelName = segment.GetDmChannel()
|
||||||
|
}
|
||||||
|
if segment.GetDmChannel() != channelName {
|
||||||
|
return "", fmt.Errorf("found multilple channel name in one SegmentChangeInfo, channel1: %s, channel 2:%s", channelName, segment.GetDmChannel())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, segment := range info.GetOfflineSegments() {
|
||||||
|
if channelName == "" {
|
||||||
|
channelName = segment.GetDmChannel()
|
||||||
|
}
|
||||||
|
if segment.GetDmChannel() != channelName {
|
||||||
|
return "", fmt.Errorf("found multilple channel name in one SegmentChangeInfo, channel1: %s, channel 2:%s", channelName, segment.GetDmChannel())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return channelName, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove the segments since it's already compacted or balanced to other QueryNodes
|
// remove the segments since it's already compacted or balanced to other QueryNodes
|
||||||
func (node *QueryNode) removeSegments(segmentChangeInfos *querypb.SealedSegmentsChangeInfo) error {
|
func (node *QueryNode) removeSegments(segmentChangeInfos *querypb.SealedSegmentsChangeInfo) error {
|
||||||
err := node.waitChangeInfo(segmentChangeInfos)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
node.streaming.replica.queryLock()
|
node.streaming.replica.queryLock()
|
||||||
node.historical.replica.queryLock()
|
node.historical.replica.queryLock()
|
||||||
|
@ -31,6 +31,7 @@ import (
|
|||||||
|
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"go.etcd.io/etcd/server/v3/embed"
|
"go.etcd.io/etcd/server/v3/embed"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||||
@ -329,17 +330,6 @@ func genSimpleQueryNodeToTestWatchChangeInfo(ctx context.Context) (*QueryNode, e
|
|||||||
return node, nil
|
return node, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQueryNode_waitChangeInfo(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
node, err := genSimpleQueryNodeToTestWatchChangeInfo(ctx)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
err = node.waitChangeInfo(genSimpleChangeInfo())
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryNode_adjustByChangeInfo(t *testing.T) {
|
func TestQueryNode_adjustByChangeInfo(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@ -534,3 +524,139 @@ func TestQueryNode_watchService(t *testing.T) {
|
|||||||
assert.True(t, flag)
|
assert.True(t, flag)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQueryNode_validateChangeChannel(t *testing.T) {
|
||||||
|
|
||||||
|
type testCase struct {
|
||||||
|
name string
|
||||||
|
info *querypb.SegmentChangeInfo
|
||||||
|
expectedError bool
|
||||||
|
expectedChannelName string
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []testCase{
|
||||||
|
{
|
||||||
|
name: "empty info",
|
||||||
|
info: &querypb.SegmentChangeInfo{},
|
||||||
|
expectedError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "normal segment change info",
|
||||||
|
info: &querypb.SegmentChangeInfo{
|
||||||
|
OnlineSegments: []*querypb.SegmentInfo{
|
||||||
|
{DmChannel: defaultDMLChannel},
|
||||||
|
},
|
||||||
|
OfflineSegments: []*querypb.SegmentInfo{
|
||||||
|
{DmChannel: defaultDMLChannel},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedError: false,
|
||||||
|
expectedChannelName: defaultDMLChannel,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty offline change info",
|
||||||
|
info: &querypb.SegmentChangeInfo{
|
||||||
|
OnlineSegments: []*querypb.SegmentInfo{
|
||||||
|
{DmChannel: defaultDMLChannel},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedError: false,
|
||||||
|
expectedChannelName: defaultDMLChannel,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty online change info",
|
||||||
|
info: &querypb.SegmentChangeInfo{
|
||||||
|
OfflineSegments: []*querypb.SegmentInfo{
|
||||||
|
{DmChannel: defaultDMLChannel},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedError: false,
|
||||||
|
expectedChannelName: defaultDMLChannel,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "different channel in online",
|
||||||
|
info: &querypb.SegmentChangeInfo{
|
||||||
|
OnlineSegments: []*querypb.SegmentInfo{
|
||||||
|
{DmChannel: defaultDMLChannel},
|
||||||
|
{DmChannel: "other_channel"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "different channel in offline",
|
||||||
|
info: &querypb.SegmentChangeInfo{
|
||||||
|
OnlineSegments: []*querypb.SegmentInfo{
|
||||||
|
{DmChannel: defaultDMLChannel},
|
||||||
|
},
|
||||||
|
OfflineSegments: []*querypb.SegmentInfo{
|
||||||
|
{DmChannel: "other_channel"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
channelName, err := validateChangeChannel(tc.info)
|
||||||
|
if tc.expectedError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tc.expectedChannelName, channelName)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryNode_handleSealedSegmentsChangeInfo(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
qn, err := genSimpleQueryNode(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Run("empty info", func(t *testing.T) {
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
qn.handleSealedSegmentsChangeInfo(&querypb.SealedSegmentsChangeInfo{})
|
||||||
|
})
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
qn.handleSealedSegmentsChangeInfo(nil)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("normal segment change info", func(t *testing.T) {
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
qn.handleSealedSegmentsChangeInfo(&querypb.SealedSegmentsChangeInfo{
|
||||||
|
Infos: []*querypb.SegmentChangeInfo{
|
||||||
|
{
|
||||||
|
OnlineSegments: []*querypb.SegmentInfo{
|
||||||
|
{DmChannel: defaultDMLChannel},
|
||||||
|
},
|
||||||
|
OfflineSegments: []*querypb.SegmentInfo{
|
||||||
|
{DmChannel: defaultDMLChannel},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("bad change info", func(t *testing.T) {
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
qn.handleSealedSegmentsChangeInfo(&querypb.SealedSegmentsChangeInfo{
|
||||||
|
Infos: []*querypb.SegmentChangeInfo{
|
||||||
|
{
|
||||||
|
OnlineSegments: []*querypb.SegmentInfo{
|
||||||
|
{DmChannel: defaultDMLChannel},
|
||||||
|
},
|
||||||
|
OfflineSegments: []*querypb.SegmentInfo{
|
||||||
|
{DmChannel: "other_channel"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -27,6 +27,7 @@ import (
|
|||||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/errorutil"
|
||||||
"go.uber.org/atomic"
|
"go.uber.org/atomic"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
@ -78,6 +79,7 @@ type segmentEvent struct {
|
|||||||
type shardQueryNode interface {
|
type shardQueryNode interface {
|
||||||
Search(context.Context, *querypb.SearchRequest) (*internalpb.SearchResults, error)
|
Search(context.Context, *querypb.SearchRequest) (*internalpb.SearchResults, error)
|
||||||
Query(context.Context, *querypb.QueryRequest) (*internalpb.RetrieveResults, error)
|
Query(context.Context, *querypb.QueryRequest) (*internalpb.RetrieveResults, error)
|
||||||
|
ReleaseSegments(ctx context.Context, in *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error)
|
||||||
Stop() error
|
Stop() error
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -363,7 +365,6 @@ func (sc *ShardCluster) watchNodes(evtCh <-chan nodeEvent) {
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case evt, ok := <-evtCh:
|
case evt, ok := <-evtCh:
|
||||||
log.Debug("node event", zap.Any("evt", evt))
|
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Warn("ShardCluster node channel closed", zap.Int64("collectionID", sc.collectionID), zap.Int64("replicaID", sc.replicaID))
|
log.Warn("ShardCluster node channel closed", zap.Int64("collectionID", sc.collectionID), zap.Int64("replicaID", sc.replicaID))
|
||||||
return
|
return
|
||||||
@ -514,6 +515,8 @@ func (sc *ShardCluster) HandoffSegments(info *querypb.SegmentChangeInfo) error {
|
|||||||
offlineSegments = append(offlineSegments, seg.GetSegmentID())
|
offlineSegments = append(offlineSegments, seg.GetSegmentID())
|
||||||
}
|
}
|
||||||
sc.waitSegmentsNotInUse(offlineSegments)
|
sc.waitSegmentsNotInUse(offlineSegments)
|
||||||
|
|
||||||
|
removes := make(map[int64][]int64) // nodeID => []segmentIDs
|
||||||
// remove offline segments record
|
// remove offline segments record
|
||||||
for _, seg := range info.OfflineSegments {
|
for _, seg := range info.OfflineSegments {
|
||||||
// filter out segments not maintained in this cluster
|
// filter out segments not maintained in this cluster
|
||||||
@ -521,11 +524,39 @@ func (sc *ShardCluster) HandoffSegments(info *querypb.SegmentChangeInfo) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
sc.removeSegment(segmentEvent{segmentID: seg.GetSegmentID(), nodeID: seg.GetNodeID()})
|
sc.removeSegment(segmentEvent{segmentID: seg.GetSegmentID(), nodeID: seg.GetNodeID()})
|
||||||
|
|
||||||
|
removes[seg.GetNodeID()] = append(removes[seg.GetNodeID()], seg.SegmentID)
|
||||||
|
}
|
||||||
|
|
||||||
|
var errs errorutil.ErrorList
|
||||||
|
|
||||||
|
// notify querynode(s) to release segments
|
||||||
|
for nodeID, segmentIDs := range removes {
|
||||||
|
node, ok := sc.getNode(nodeID)
|
||||||
|
if !ok {
|
||||||
|
log.Warn("node not in cluster", zap.Int64("nodeID", nodeID), zap.Int64("collectionID", sc.collectionID), zap.String("vchannel", sc.vchannelName))
|
||||||
|
errs = append(errs, fmt.Errorf("node not in cluster nodeID %d", nodeID))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
state, err := node.client.ReleaseSegments(context.Background(), &querypb.ReleaseSegmentsRequest{
|
||||||
|
CollectionID: sc.collectionID,
|
||||||
|
SegmentIDs: segmentIDs,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
errs = append(errs, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if state.GetErrorCode() != commonpb.ErrorCode_Success {
|
||||||
|
errs = append(errs, fmt.Errorf("Release segments failed with reason: %s", state.GetReason()))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// finish handoff and remove it from pending list
|
// finish handoff and remove it from pending list
|
||||||
sc.finishHandoff(token)
|
sc.finishHandoff(token)
|
||||||
|
|
||||||
|
if len(errs) > 0 {
|
||||||
|
return errs
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -137,3 +137,14 @@ func (s *ShardClusterService) SyncReplicaSegments(vchannelName string, distribut
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HandoffVChannelSegments dispatches SegmentChangeInfo to related ShardCluster with VChannel
|
||||||
|
func (s *ShardClusterService) HandoffVChannelSegments(vchannel string, info *querypb.SegmentChangeInfo) error {
|
||||||
|
raw, ok := s.clusters.Load(vchannel)
|
||||||
|
if !ok {
|
||||||
|
// not leader for this channel, ignore without error
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
sc := raw.(*ShardCluster)
|
||||||
|
return sc.HandoffSegments(info)
|
||||||
|
}
|
||||||
|
@ -88,3 +88,23 @@ func TestShardClusterService_SyncReplicaSegments(t *testing.T) {
|
|||||||
assert.Equal(t, segmentStateLoaded, segment.state)
|
assert.Equal(t, segmentStateLoaded, segment.state)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestShardClusterService_HandoffVChannelSegments(t *testing.T) {
|
||||||
|
qn, err := genSimpleQueryNode(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
client := v3client.New(embedetcdServer.Server)
|
||||||
|
defer client.Close()
|
||||||
|
session := sessionutil.NewSession(context.Background(), "/by-dev/sessions/unittest/querynode/", client)
|
||||||
|
clusterService := newShardClusterService(client, session, qn)
|
||||||
|
|
||||||
|
err = clusterService.HandoffVChannelSegments(defaultDMLChannel, &querypb.SegmentChangeInfo{})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
clusterService.addShardCluster(defaultCollectionID, defaultReplicaID, defaultDMLChannel)
|
||||||
|
//TODO change shardCluster to interface to mock test behavior
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
err = clusterService.HandoffVChannelSegments(defaultDMLChannel, &querypb.SegmentChangeInfo{})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -22,6 +22,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@ -47,10 +48,12 @@ func (m *mockSegmentDetector) watchSegments(collectionID int64, replicaID int64,
|
|||||||
}
|
}
|
||||||
|
|
||||||
type mockShardQueryNode struct {
|
type mockShardQueryNode struct {
|
||||||
searchResult *internalpb.SearchResults
|
searchResult *internalpb.SearchResults
|
||||||
searchErr error
|
searchErr error
|
||||||
queryResult *internalpb.RetrieveResults
|
queryResult *internalpb.RetrieveResults
|
||||||
queryErr error
|
queryErr error
|
||||||
|
releaseSegmentsResult *commonpb.Status
|
||||||
|
releaseSegmentsErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockShardQueryNode) Search(_ context.Context, _ *querypb.SearchRequest) (*internalpb.SearchResults, error) {
|
func (m *mockShardQueryNode) Search(_ context.Context, _ *querypb.SearchRequest) (*internalpb.SearchResults, error) {
|
||||||
@ -61,6 +64,10 @@ func (m *mockShardQueryNode) Query(_ context.Context, _ *querypb.QueryRequest) (
|
|||||||
return m.queryResult, m.queryErr
|
return m.queryResult, m.queryErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockShardQueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
|
||||||
|
return m.releaseSegmentsResult, m.releaseSegmentsErr
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockShardQueryNode) Stop() error {
|
func (m *mockShardQueryNode) Stop() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -1336,7 +1343,7 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
|
|||||||
}, buildMockQueryNode)
|
}, buildMockQueryNode)
|
||||||
defer sc.Close()
|
defer sc.Close()
|
||||||
|
|
||||||
sc.HandoffSegments(&querypb.SegmentChangeInfo{
|
err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
|
||||||
OnlineSegments: []*querypb.SegmentInfo{
|
OnlineSegments: []*querypb.SegmentInfo{
|
||||||
{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName},
|
{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName},
|
||||||
},
|
},
|
||||||
@ -1344,6 +1351,10 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
|
|||||||
{SegmentID: 1, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName},
|
{SegmentID: 1, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Log(err.Error())
|
||||||
|
}
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
sc.mut.RLock()
|
sc.mut.RLock()
|
||||||
_, has := sc.segments[1]
|
_, has := sc.segments[1]
|
||||||
@ -1383,7 +1394,7 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
|
|||||||
}, buildMockQueryNode)
|
}, buildMockQueryNode)
|
||||||
defer sc.Close()
|
defer sc.Close()
|
||||||
|
|
||||||
sc.HandoffSegments(&querypb.SegmentChangeInfo{
|
err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
|
||||||
OnlineSegments: []*querypb.SegmentInfo{
|
OnlineSegments: []*querypb.SegmentInfo{
|
||||||
{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName},
|
{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName},
|
||||||
{SegmentID: 4, NodeID: 2, CollectionID: otherCollectionID, DmChannel: otherVchannelName},
|
{SegmentID: 4, NodeID: 2, CollectionID: otherCollectionID, DmChannel: otherVchannelName},
|
||||||
@ -1393,6 +1404,7 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
|
|||||||
{SegmentID: 5, NodeID: 2, CollectionID: otherCollectionID, DmChannel: otherVchannelName},
|
{SegmentID: 5, NodeID: 2, CollectionID: otherCollectionID, DmChannel: otherVchannelName},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
sc.mut.RLock()
|
sc.mut.RLock()
|
||||||
_, has := sc.segments[3]
|
_, has := sc.segments[3]
|
||||||
@ -1439,7 +1451,7 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
|
|||||||
|
|
||||||
sig := make(chan struct{})
|
sig := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
sc.HandoffSegments(&querypb.SegmentChangeInfo{
|
err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
|
||||||
OnlineSegments: []*querypb.SegmentInfo{
|
OnlineSegments: []*querypb.SegmentInfo{
|
||||||
{SegmentID: 3, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName},
|
{SegmentID: 3, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName},
|
||||||
},
|
},
|
||||||
@ -1448,6 +1460,7 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
close(sig)
|
close(sig)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@ -1493,4 +1506,113 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
|
|||||||
|
|
||||||
assert.False(t, has)
|
assert.False(t, has)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("handoff from non-exist node", func(t *testing.T) {
|
||||||
|
nodeEvents := []nodeEvent{
|
||||||
|
{
|
||||||
|
nodeID: 1,
|
||||||
|
nodeAddr: "addr_1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
nodeID: 2,
|
||||||
|
nodeAddr: "addr_2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
segmentEvents := []segmentEvent{
|
||||||
|
{
|
||||||
|
segmentID: 1,
|
||||||
|
nodeID: 1,
|
||||||
|
state: segmentStateLoaded,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
segmentID: 2,
|
||||||
|
nodeID: 2,
|
||||||
|
state: segmentStateLoaded,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
evtCh := make(chan segmentEvent, 10)
|
||||||
|
sc := NewShardCluster(collectionID, replicaID, vchannelName,
|
||||||
|
&mockNodeDetector{
|
||||||
|
initNodes: nodeEvents,
|
||||||
|
}, &mockSegmentDetector{
|
||||||
|
initSegments: segmentEvents,
|
||||||
|
evtCh: evtCh,
|
||||||
|
}, buildMockQueryNode)
|
||||||
|
defer sc.Close()
|
||||||
|
|
||||||
|
err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
|
||||||
|
OnlineSegments: []*querypb.SegmentInfo{
|
||||||
|
{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName},
|
||||||
|
},
|
||||||
|
OfflineSegments: []*querypb.SegmentInfo{
|
||||||
|
{SegmentID: 1, NodeID: 3, CollectionID: collectionID, DmChannel: vchannelName},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("release failed", func(t *testing.T) {
|
||||||
|
nodeEvents := []nodeEvent{
|
||||||
|
{
|
||||||
|
nodeID: 1,
|
||||||
|
nodeAddr: "addr_1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
nodeID: 2,
|
||||||
|
nodeAddr: "addr_2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
segmentEvents := []segmentEvent{
|
||||||
|
{
|
||||||
|
segmentID: 1,
|
||||||
|
nodeID: 1,
|
||||||
|
state: segmentStateLoaded,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
segmentID: 2,
|
||||||
|
nodeID: 2,
|
||||||
|
state: segmentStateLoaded,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
evtCh := make(chan segmentEvent, 10)
|
||||||
|
mqn := &mockShardQueryNode{}
|
||||||
|
sc := NewShardCluster(collectionID, replicaID, vchannelName,
|
||||||
|
&mockNodeDetector{
|
||||||
|
initNodes: nodeEvents,
|
||||||
|
}, &mockSegmentDetector{
|
||||||
|
initSegments: segmentEvents,
|
||||||
|
evtCh: evtCh,
|
||||||
|
}, func(nodeID int64, addr string) shardQueryNode {
|
||||||
|
return mqn
|
||||||
|
})
|
||||||
|
defer sc.Close()
|
||||||
|
|
||||||
|
mqn.releaseSegmentsErr = errors.New("mocked error")
|
||||||
|
|
||||||
|
err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
|
||||||
|
OfflineSegments: []*querypb.SegmentInfo{
|
||||||
|
{SegmentID: 1, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
mqn.releaseSegmentsErr = nil
|
||||||
|
mqn.releaseSegmentsResult = &commonpb.Status{
|
||||||
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||||
|
Reason: "mocked error",
|
||||||
|
}
|
||||||
|
|
||||||
|
err = sc.HandoffSegments(&querypb.SegmentChangeInfo{
|
||||||
|
OfflineSegments: []*querypb.SegmentInfo{
|
||||||
|
{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user