Fix segment change handling logic (#16695)

Dispatch SegmentChangeInfo to the ShardCluster leader
Hold segment removal until in-flight searches on those segments are done

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
congqixia 2022-04-27 22:23:46 +08:00 committed by GitHub
parent bb6cd4b484
commit 3a6db2faeb
6 changed files with 365 additions and 72 deletions
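
For orientation, the new dispatch path in condensed form (taken from the hunks below, with logging and error handling trimmed; not a verbatim copy of the committed code):

	// QueryNode no longer removes segments directly when a SealedSegmentsChangeInfo
	// arrives; each SegmentChangeInfo is validated and handed to the ShardCluster
	// that leads the corresponding vchannel.
	func (node *QueryNode) handleSealedSegmentsChangeInfo(info *querypb.SealedSegmentsChangeInfo) {
		for _, line := range info.GetInfos() {
			vchannel, err := validateChangeChannel(line) // all segments must share one DmChannel
			if err != nil {
				continue // invalid change info is skipped (and logged in the real code)
			}
			node.ShardClusterService.HandoffVChannelSegments(vchannel, line)
		}
	}

	// Only the node holding the ShardCluster for this vchannel (the shard leader)
	// performs the handoff; other nodes ignore the dispatch without error.
	func (s *ShardClusterService) HandoffVChannelSegments(vchannel string, info *querypb.SegmentChangeInfo) error {
		raw, ok := s.clusters.Load(vchannel)
		if !ok {
			return nil
		}
		return raw.(*ShardCluster).HandoffSegments(info)
	}

In ShardCluster.HandoffSegments (further below), the cluster first calls waitSegmentsNotInUse for the offline segments, so removal is held until in-flight searches finish, and only then asks the owning QueryNodes to release them.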


@@ -58,7 +58,6 @@ import (
 	"github.com/milvus-io/milvus/internal/util"
 	"github.com/milvus-io/milvus/internal/util/dependency"
 	"github.com/milvus-io/milvus/internal/util/paramtable"
-	"github.com/milvus-io/milvus/internal/util/retry"
 	"github.com/milvus-io/milvus/internal/util/sessionutil"
 	"github.com/milvus-io/milvus/internal/util/typeutil"
 )
@@ -452,12 +451,7 @@ func (node *QueryNode) watchChangeInfo() {
 						log.Warn("Unmarshal SealedSegmentsChangeInfo failed", zap.Any("error", err.Error()))
 						continue
 					}
-					go func() {
-						err = node.removeSegments(info)
-						if err != nil {
-							log.Warn("cleanup segments failed", zap.Any("error", err.Error()))
-						}
-					}()
+					go node.handleSealedSegmentsChangeInfo(info)
 				default:
 					// do nothing
 				}
@@ -466,58 +460,47 @@ func (node *QueryNode) watchChangeInfo() {
 	}
 }
 
-func (node *QueryNode) waitChangeInfo(segmentChangeInfos *querypb.SealedSegmentsChangeInfo) error {
-	fn := func() error {
-		/*
-			for _, info := range segmentChangeInfos.Infos {
-				canDoLoadBalance := true
-				// make sure all query channel already received segment location changes
-				// Check online segments:
-				for _, segmentInfo := range info.OnlineSegments {
-					if node.queryService.hasQueryCollection(segmentInfo.CollectionID) {
-						qc, err := node.queryService.getQueryCollection(segmentInfo.CollectionID)
-						if err != nil {
-							canDoLoadBalance = false
-							break
-						}
-						if info.OnlineNodeID == Params.QueryNodeCfg.QueryNodeID && !qc.globalSegmentManager.hasGlobalSealedSegment(segmentInfo.SegmentID) {
-							canDoLoadBalance = false
-							break
-						}
-					}
-				}
-				// Check offline segments:
-				for _, segmentInfo := range info.OfflineSegments {
-					if node.queryService.hasQueryCollection(segmentInfo.CollectionID) {
-						qc, err := node.queryService.getQueryCollection(segmentInfo.CollectionID)
-						if err != nil {
-							canDoLoadBalance = false
-							break
-						}
-						if info.OfflineNodeID == Params.QueryNodeCfg.QueryNodeID && qc.globalSegmentManager.hasGlobalSealedSegment(segmentInfo.SegmentID) {
-							canDoLoadBalance = false
-							break
-						}
-					}
-				}
-				if canDoLoadBalance {
-					return nil
-				}
-				return errors.New(fmt.Sprintln("waitChangeInfo failed, infoID = ", segmentChangeInfos.Base.GetMsgID()))
-			}
-		*/
-		return nil
-	}
-	return retry.Do(node.queryNodeLoopCtx, fn, retry.Attempts(50))
+func (node *QueryNode) handleSealedSegmentsChangeInfo(info *querypb.SealedSegmentsChangeInfo) {
+	for _, line := range info.GetInfos() {
+		vchannel, err := validateChangeChannel(line)
+		if err != nil {
+			log.Warn("failed to validate vchannel for SegmentChangeInfo", zap.Error(err))
+			continue
+		}
+
+		node.ShardClusterService.HandoffVChannelSegments(vchannel, line)
+	}
+}
+
+func validateChangeChannel(info *querypb.SegmentChangeInfo) (string, error) {
+	if len(info.GetOnlineSegments()) == 0 && len(info.GetOfflineSegments()) == 0 {
+		return "", errors.New("SegmentChangeInfo with no segments info")
+	}
+
+	var channelName string
+	for _, segment := range info.GetOnlineSegments() {
+		if channelName == "" {
+			channelName = segment.GetDmChannel()
+		}
+		if segment.GetDmChannel() != channelName {
+			return "", fmt.Errorf("found multiple channel names in one SegmentChangeInfo, channel1: %s, channel2: %s", channelName, segment.GetDmChannel())
+		}
+	}
+	for _, segment := range info.GetOfflineSegments() {
+		if channelName == "" {
+			channelName = segment.GetDmChannel()
+		}
+		if segment.GetDmChannel() != channelName {
+			return "", fmt.Errorf("found multiple channel names in one SegmentChangeInfo, channel1: %s, channel2: %s", channelName, segment.GetDmChannel())
+		}
+	}
+	return channelName, nil
 }
 
 // remove the segments since it's already compacted or balanced to other QueryNodes
 func (node *QueryNode) removeSegments(segmentChangeInfos *querypb.SealedSegmentsChangeInfo) error {
-	err := node.waitChangeInfo(segmentChangeInfos)
-	if err != nil {
-		return err
-	}
 	node.streaming.replica.queryLock()
 	node.historical.replica.queryLock()


@@ -31,6 +31,7 @@ import (
 	"github.com/golang/protobuf/proto"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 	"go.etcd.io/etcd/server/v3/embed"
 
 	"github.com/milvus-io/milvus/internal/util/dependency"
@@ -329,17 +330,6 @@ func genSimpleQueryNodeToTestWatchChangeInfo(ctx context.Context) (*QueryNode, error) {
 	return node, nil
 }
 
-func TestQueryNode_waitChangeInfo(t *testing.T) {
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-
-	node, err := genSimpleQueryNodeToTestWatchChangeInfo(ctx)
-	assert.NoError(t, err)
-
-	err = node.waitChangeInfo(genSimpleChangeInfo())
-	assert.NoError(t, err)
-}
-
 func TestQueryNode_adjustByChangeInfo(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
@@ -534,3 +524,139 @@ func TestQueryNode_watchService(t *testing.T) {
 		assert.True(t, flag)
 	})
 }
+
+func TestQueryNode_validateChangeChannel(t *testing.T) {
+	type testCase struct {
+		name                string
+		info                *querypb.SegmentChangeInfo
+		expectedError       bool
+		expectedChannelName string
+	}
+
+	cases := []testCase{
+		{
+			name:          "empty info",
+			info:          &querypb.SegmentChangeInfo{},
+			expectedError: true,
+		},
+		{
+			name: "normal segment change info",
+			info: &querypb.SegmentChangeInfo{
+				OnlineSegments: []*querypb.SegmentInfo{
+					{DmChannel: defaultDMLChannel},
+				},
+				OfflineSegments: []*querypb.SegmentInfo{
+					{DmChannel: defaultDMLChannel},
+				},
+			},
+			expectedError:       false,
+			expectedChannelName: defaultDMLChannel,
+		},
+		{
+			name: "empty offline change info",
+			info: &querypb.SegmentChangeInfo{
+				OnlineSegments: []*querypb.SegmentInfo{
+					{DmChannel: defaultDMLChannel},
+				},
+			},
+			expectedError:       false,
+			expectedChannelName: defaultDMLChannel,
+		},
+		{
+			name: "empty online change info",
+			info: &querypb.SegmentChangeInfo{
+				OfflineSegments: []*querypb.SegmentInfo{
+					{DmChannel: defaultDMLChannel},
+				},
+			},
+			expectedError:       false,
+			expectedChannelName: defaultDMLChannel,
+		},
+		{
+			name: "different channel in online",
+			info: &querypb.SegmentChangeInfo{
+				OnlineSegments: []*querypb.SegmentInfo{
+					{DmChannel: defaultDMLChannel},
+					{DmChannel: "other_channel"},
+				},
+			},
+			expectedError: true,
+		},
+		{
+			name: "different channel in offline",
+			info: &querypb.SegmentChangeInfo{
+				OnlineSegments: []*querypb.SegmentInfo{
+					{DmChannel: defaultDMLChannel},
+				},
+				OfflineSegments: []*querypb.SegmentInfo{
+					{DmChannel: "other_channel"},
+				},
+			},
+			expectedError: true,
+		},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			channelName, err := validateChangeChannel(tc.info)
+			if tc.expectedError {
+				assert.Error(t, err)
+			} else {
+				assert.NoError(t, err)
+				assert.Equal(t, tc.expectedChannelName, channelName)
+			}
+		})
+	}
+}
+
+func TestQueryNode_handleSealedSegmentsChangeInfo(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	qn, err := genSimpleQueryNode(ctx)
+	require.NoError(t, err)
+
+	t.Run("empty info", func(t *testing.T) {
+		assert.NotPanics(t, func() {
+			qn.handleSealedSegmentsChangeInfo(&querypb.SealedSegmentsChangeInfo{})
+		})
+		assert.NotPanics(t, func() {
+			qn.handleSealedSegmentsChangeInfo(nil)
+		})
+	})
+
+	t.Run("normal segment change info", func(t *testing.T) {
+		assert.NotPanics(t, func() {
+			qn.handleSealedSegmentsChangeInfo(&querypb.SealedSegmentsChangeInfo{
+				Infos: []*querypb.SegmentChangeInfo{
+					{
+						OnlineSegments: []*querypb.SegmentInfo{
+							{DmChannel: defaultDMLChannel},
+						},
+						OfflineSegments: []*querypb.SegmentInfo{
+							{DmChannel: defaultDMLChannel},
+						},
+					},
+				},
+			})
+		})
+	})
+
+	t.Run("bad change info", func(t *testing.T) {
+		assert.NotPanics(t, func() {
+			qn.handleSealedSegmentsChangeInfo(&querypb.SealedSegmentsChangeInfo{
+				Infos: []*querypb.SegmentChangeInfo{
+					{
+						OnlineSegments: []*querypb.SegmentInfo{
+							{DmChannel: defaultDMLChannel},
+						},
+						OfflineSegments: []*querypb.SegmentInfo{
+							{DmChannel: "other_channel"},
+						},
+					},
+				},
+			})
+		})
+	})
+}


@@ -27,6 +27,7 @@ import (
 	"github.com/milvus-io/milvus/internal/proto/commonpb"
 	"github.com/milvus-io/milvus/internal/proto/internalpb"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
+	"github.com/milvus-io/milvus/internal/util/errorutil"
 	"go.uber.org/atomic"
 	"go.uber.org/zap"
 )
@@ -78,6 +79,7 @@ type segmentEvent struct {
 
 type shardQueryNode interface {
 	Search(context.Context, *querypb.SearchRequest) (*internalpb.SearchResults, error)
 	Query(context.Context, *querypb.QueryRequest) (*internalpb.RetrieveResults, error)
+	ReleaseSegments(ctx context.Context, in *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error)
 	Stop() error
 }
@@ -363,7 +365,6 @@ func (sc *ShardCluster) watchNodes(evtCh <-chan nodeEvent) {
 	for {
 		select {
 		case evt, ok := <-evtCh:
-			log.Debug("node event", zap.Any("evt", evt))
 			if !ok {
 				log.Warn("ShardCluster node channel closed", zap.Int64("collectionID", sc.collectionID), zap.Int64("replicaID", sc.replicaID))
 				return
@@ -514,6 +515,8 @@ func (sc *ShardCluster) HandoffSegments(info *querypb.SegmentChangeInfo) error {
 		offlineSegments = append(offlineSegments, seg.GetSegmentID())
 	}
 	sc.waitSegmentsNotInUse(offlineSegments)
+
+	removes := make(map[int64][]int64) // nodeID => []segmentIDs
 	// remove offline segments record
 	for _, seg := range info.OfflineSegments {
 		// filter out segments not maintained in this cluster
@@ -521,11 +524,39 @@ func (sc *ShardCluster) HandoffSegments(info *querypb.SegmentChangeInfo) error {
 			continue
 		}
 		sc.removeSegment(segmentEvent{segmentID: seg.GetSegmentID(), nodeID: seg.GetNodeID()})
+
+		removes[seg.GetNodeID()] = append(removes[seg.GetNodeID()], seg.SegmentID)
+	}
+
+	var errs errorutil.ErrorList
+	// notify querynode(s) to release segments
+	for nodeID, segmentIDs := range removes {
+		node, ok := sc.getNode(nodeID)
+		if !ok {
+			log.Warn("node not in cluster", zap.Int64("nodeID", nodeID), zap.Int64("collectionID", sc.collectionID), zap.String("vchannel", sc.vchannelName))
+			errs = append(errs, fmt.Errorf("node not in cluster nodeID %d", nodeID))
+			continue
+		}
+		state, err := node.client.ReleaseSegments(context.Background(), &querypb.ReleaseSegmentsRequest{
+			CollectionID: sc.collectionID,
+			SegmentIDs:   segmentIDs,
+		})
+		if err != nil {
+			errs = append(errs, err)
+			continue
+		}
+		if state.GetErrorCode() != commonpb.ErrorCode_Success {
+			errs = append(errs, fmt.Errorf("Release segments failed with reason: %s", state.GetReason()))
+		}
 	}
 
 	// finish handoff and remove it from pending list
 	sc.finishHandoff(token)
+
+	if len(errs) > 0 {
+		return errs
+	}
+
 	return nil
 }
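
The aggregated error above uses errorutil.ErrorList, whose definition is not part of this diff. From the usage (declared empty, appended to, returned as error) it is assumed to be a slice of errors that itself implements the error interface, roughly:

	// Assumed shape of errorutil.ErrorList (not shown in this diff): a []error that
	// implements error by joining the messages of its elements.
	type ErrorList []error

	func (el ErrorList) Error() string {
		var builder strings.Builder
		for i, err := range el {
			if i > 0 {
				builder.WriteString(", ")
			}
			builder.WriteString(err.Error())
		}
		return builder.String()
	}

This pattern lets HandoffSegments keep notifying every node even when one ReleaseSegments call fails, and still surface all failures to the caller in a single return value.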


@@ -137,3 +137,14 @@ func (s *ShardClusterService) SyncReplicaSegments(vchannelName string, distribut
 	return nil
 }
+
+// HandoffVChannelSegments dispatches SegmentChangeInfo to related ShardCluster with VChannel
+func (s *ShardClusterService) HandoffVChannelSegments(vchannel string, info *querypb.SegmentChangeInfo) error {
+	raw, ok := s.clusters.Load(vchannel)
+	if !ok {
+		// not leader for this channel, ignore without error
+		return nil
+	}
+	sc := raw.(*ShardCluster)
+	return sc.HandoffSegments(info)
+}


@@ -88,3 +88,23 @@ func TestShardClusterService_SyncReplicaSegments(t *testing.T) {
 		assert.Equal(t, segmentStateLoaded, segment.state)
 	})
 }
+
+func TestShardClusterService_HandoffVChannelSegments(t *testing.T) {
+	qn, err := genSimpleQueryNode(context.Background())
+	require.NoError(t, err)
+
+	client := v3client.New(embedetcdServer.Server)
+	defer client.Close()
+	session := sessionutil.NewSession(context.Background(), "/by-dev/sessions/unittest/querynode/", client)
+	clusterService := newShardClusterService(client, session, qn)
+
+	err = clusterService.HandoffVChannelSegments(defaultDMLChannel, &querypb.SegmentChangeInfo{})
+	assert.NoError(t, err)
+
+	clusterService.addShardCluster(defaultCollectionID, defaultReplicaID, defaultDMLChannel)
+	//TODO change shardCluster to interface to mock test behavior
+	assert.NotPanics(t, func() {
+		err = clusterService.HandoffVChannelSegments(defaultDMLChannel, &querypb.SegmentChangeInfo{})
+		assert.NoError(t, err)
+	})
+}


@@ -22,6 +22,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/milvus-io/milvus/internal/proto/commonpb"
 	"github.com/milvus-io/milvus/internal/proto/internalpb"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
 	"github.com/stretchr/testify/assert"
@@ -47,10 +48,12 @@ func (m *mockSegmentDetector) watchSegments(collectionID int64, replicaID int64,
 }
 
 type mockShardQueryNode struct {
 	searchResult *internalpb.SearchResults
 	searchErr    error
 	queryResult  *internalpb.RetrieveResults
 	queryErr     error
+	releaseSegmentsResult *commonpb.Status
+	releaseSegmentsErr    error
 }
 
 func (m *mockShardQueryNode) Search(_ context.Context, _ *querypb.SearchRequest) (*internalpb.SearchResults, error) {
@@ -61,6 +64,10 @@ func (m *mockShardQueryNode) Query(_ context.Context, _ *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
 	return m.queryResult, m.queryErr
 }
 
+func (m *mockShardQueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
+	return m.releaseSegmentsResult, m.releaseSegmentsErr
+}
+
 func (m *mockShardQueryNode) Stop() error {
 	return nil
 }
@@ -1336,7 +1343,7 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
 			}, buildMockQueryNode)
 		defer sc.Close()
 
-		sc.HandoffSegments(&querypb.SegmentChangeInfo{
+		err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
 			OnlineSegments: []*querypb.SegmentInfo{
 				{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName},
 			},
@@ -1344,6 +1351,10 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
 				{SegmentID: 1, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName},
 			},
 		})
+		if err != nil {
+			t.Log(err.Error())
+		}
+		assert.NoError(t, err)
 
 		sc.mut.RLock()
 		_, has := sc.segments[1]
@@ -1383,7 +1394,7 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
 			}, buildMockQueryNode)
 		defer sc.Close()
 
-		sc.HandoffSegments(&querypb.SegmentChangeInfo{
+		err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
 			OnlineSegments: []*querypb.SegmentInfo{
 				{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName},
 				{SegmentID: 4, NodeID: 2, CollectionID: otherCollectionID, DmChannel: otherVchannelName},
@@ -1393,6 +1404,7 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
 				{SegmentID: 5, NodeID: 2, CollectionID: otherCollectionID, DmChannel: otherVchannelName},
 			},
 		})
+		assert.NoError(t, err)
 
 		sc.mut.RLock()
 		_, has := sc.segments[3]
@@ -1439,7 +1451,7 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
 		sig := make(chan struct{})
 		go func() {
-			sc.HandoffSegments(&querypb.SegmentChangeInfo{
+			err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
 				OnlineSegments: []*querypb.SegmentInfo{
 					{SegmentID: 3, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName},
 				},
@@ -1448,6 +1460,7 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
 				},
 			})
+			assert.NoError(t, err)
 
 			close(sig)
 		}()
@@ -1493,4 +1506,113 @@ func TestShardCluster_HandoffSegments(t *testing.T) {
 		assert.False(t, has)
 	})
+
+	t.Run("handoff from non-exist node", func(t *testing.T) {
+		nodeEvents := []nodeEvent{
+			{
+				nodeID:   1,
+				nodeAddr: "addr_1",
+			},
+			{
+				nodeID:   2,
+				nodeAddr: "addr_2",
+			},
+		}
+
+		segmentEvents := []segmentEvent{
+			{
+				segmentID: 1,
+				nodeID:    1,
+				state:     segmentStateLoaded,
+			},
+			{
+				segmentID: 2,
+				nodeID:    2,
+				state:     segmentStateLoaded,
+			},
+		}
+
+		evtCh := make(chan segmentEvent, 10)
+		sc := NewShardCluster(collectionID, replicaID, vchannelName,
+			&mockNodeDetector{
+				initNodes: nodeEvents,
+			}, &mockSegmentDetector{
+				initSegments: segmentEvents,
+				evtCh:        evtCh,
+			}, buildMockQueryNode)
+		defer sc.Close()
+
+		err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
+			OnlineSegments: []*querypb.SegmentInfo{
+				{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName},
+			},
+			OfflineSegments: []*querypb.SegmentInfo{
+				{SegmentID: 1, NodeID: 3, CollectionID: collectionID, DmChannel: vchannelName},
+			},
+		})
+
+		assert.Error(t, err)
+	})
+
+	t.Run("release failed", func(t *testing.T) {
+		nodeEvents := []nodeEvent{
+			{
+				nodeID:   1,
+				nodeAddr: "addr_1",
+			},
+			{
+				nodeID:   2,
+				nodeAddr: "addr_2",
+			},
+		}
+
+		segmentEvents := []segmentEvent{
+			{
+				segmentID: 1,
+				nodeID:    1,
+				state:     segmentStateLoaded,
+			},
+			{
+				segmentID: 2,
+				nodeID:    2,
+				state:     segmentStateLoaded,
+			},
+		}
+
+		evtCh := make(chan segmentEvent, 10)
+		mqn := &mockShardQueryNode{}
+		sc := NewShardCluster(collectionID, replicaID, vchannelName,
+			&mockNodeDetector{
+				initNodes: nodeEvents,
+			}, &mockSegmentDetector{
+				initSegments: segmentEvents,
+				evtCh:        evtCh,
+			}, func(nodeID int64, addr string) shardQueryNode {
+				return mqn
+			})
+		defer sc.Close()
+
+		mqn.releaseSegmentsErr = errors.New("mocked error")
+
+		err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
+			OfflineSegments: []*querypb.SegmentInfo{
+				{SegmentID: 1, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName},
+			},
+		})
+
+		assert.Error(t, err)
+
+		mqn.releaseSegmentsErr = nil
+		mqn.releaseSegmentsResult = &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_UnexpectedError,
+			Reason:    "mocked error",
+		}
+
+		err = sc.HandoffSegments(&querypb.SegmentChangeInfo{
+			OfflineSegments: []*querypb.SegmentInfo{
+				{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName},
+			},
+		})
+
+		assert.Error(t, err)
+	})
 }