enhance: refactor segment dist manager interface (#31073)

issue: #31091
This PR adds a `GetByFilter` interface to the segment dist manager, replacing the assortment of specialized getter functions.

Signed-off-by: Wei Liu <wei.liu@zilliz.com>
wei liu 2024-03-08 16:29:01 +08:00 committed by GitHub
parent ff80d2fd8c
commit efe8cecc88
14 changed files with 98 additions and 126 deletions
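
The shape of the refactor, condensed: each specialized getter becomes a composition of With* filter options passed to a single GetByFilter entry point. Below is a minimal standalone sketch of the pattern for orientation; the toy types and the free-standing GetByFilter function are illustrative stand-ins, not the actual Milvus code in this diff.

package main

import "fmt"

// Segment is a trimmed-down stand-in for meta.Segment.
type Segment struct {
	ID           int64
	CollectionID int64
	Node         int64
}

// SegmentDistFilter mirrors the predicate type introduced by this commit.
type SegmentDistFilter func(s *Segment) bool

func WithCollectionID(collectionID int64) SegmentDistFilter {
	return func(s *Segment) bool { return s.CollectionID == collectionID }
}

func WithNodeID(nodeID int64) SegmentDistFilter {
	return func(s *Segment) bool { return s.Node == nodeID }
}

// GetByFilter keeps the segments that match every non-nil filter,
// so passing no filters (or a nil one) returns everything.
func GetByFilter(segments []*Segment, filters ...SegmentDistFilter) []*Segment {
	ret := make([]*Segment, 0)
	for _, s := range segments {
		match := true
		for _, f := range filters {
			if f != nil && !f(s) {
				match = false
				break
			}
		}
		if match {
			ret = append(ret, s)
		}
	}
	return ret
}

func main() {
	dist := []*Segment{
		{ID: 1, CollectionID: 100, Node: 1},
		{ID: 2, CollectionID: 100, Node: 2},
		{ID: 3, CollectionID: 200, Node: 1},
	}
	// Old style: GetByCollectionAndNode(100, 1); new style below.
	for _, s := range GetByFilter(dist, WithCollectionID(100), WithNodeID(1)) {
		fmt.Println(s.ID) // prints 1
	}
}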

View File

@@ -520,14 +520,14 @@ func (b *MultiTargetBalancer) genSegmentPlan(replica *meta.Replica) []SegmentAss
nodeSegments := make(map[int64][]*meta.Segment)
globalNodeSegments := make(map[int64][]*meta.Segment)
for _, node := range replica.Nodes {
-dist := b.dist.SegmentDistManager.GetByCollectionAndNode(replica.GetCollectionID(), node)
+dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.CollectionID), meta.WithNodeID(node))
segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool {
return b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil &&
b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.NextTarget) != nil &&
segment.GetLevel() != datapb.SegmentLevel_L0
})
nodeSegments[node] = segments
-globalNodeSegments[node] = b.dist.SegmentDistManager.GetByNode(node)
+globalNodeSegments[node] = b.dist.SegmentDistManager.GetByFilter(meta.WithNodeID(node))
}
return b.genPlanByDistributions(nodeSegments, globalNodeSegments)

View File

@@ -110,7 +110,7 @@ func (b *RowCountBasedBalancer) convertToNodeItemsBySegment(nodeIDs []int64) []*
node := nodeInfo.ID()
// calculate sealed segment row count on node
-segments := b.dist.SegmentDistManager.GetByNode(node)
+segments := b.dist.SegmentDistManager.GetByFilter(meta.WithNodeID(node))
rowcnt := 0
for _, s := range segments {
rowcnt += int(s.GetNumOfRows())
@@ -204,7 +204,7 @@ func (b *RowCountBasedBalancer) BalanceReplica(replica *meta.Replica) ([]Segment
func (b *RowCountBasedBalancer) genStoppingSegmentPlan(replica *meta.Replica, onlineNodes []int64, offlineNodes []int64) []SegmentAssignPlan {
segmentPlans := make([]SegmentAssignPlan, 0)
for _, nodeID := range offlineNodes {
-dist := b.dist.SegmentDistManager.GetByCollectionAndNode(replica.GetCollectionID(), nodeID)
+dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(nodeID))
segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool {
return b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil &&
b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.NextTarget) != nil &&
@@ -227,7 +227,7 @@ func (b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNode
segmentDist := make(map[int64][]*meta.Segment)
totalRowCount := 0
for _, node := range onlineNodes {
-dist := b.dist.SegmentDistManager.GetByCollectionAndNode(replica.GetCollectionID(), node)
+dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(node))
segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool {
return b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil &&
b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.NextTarget) != nil &&
@@ -271,7 +271,7 @@ func (b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNode
segmentsToMove = lo.Filter(segmentsToMove, func(s *meta.Segment, _ int) bool {
// if the segment are redundant, skip it's balance for now
-return len(b.dist.SegmentDistManager.Get(s.GetID())) == 1
+return len(b.dist.SegmentDistManager.GetByFilter(meta.WithSegmentID(s.GetID()))) == 1
})
if len(nodesWithLessRow) == 0 || len(segmentsToMove) == 0 {

View File

@@ -141,7 +141,7 @@ func (b *ScoreBasedBalancer) convertToNodeItems(collectionID int64, nodeIDs []in
func (b *ScoreBasedBalancer) calculateScore(collectionID, nodeID int64) int {
rowCount := 0
// calculate global sealed segment row count
-globalSegments := b.dist.SegmentDistManager.GetByNode(nodeID)
+globalSegments := b.dist.SegmentDistManager.GetByFilter(meta.WithNodeID(nodeID))
for _, s := range globalSegments {
rowCount += int(s.GetNumOfRows())
}
@@ -154,7 +154,7 @@ func (b *ScoreBasedBalancer) calculateScore(collectionID, nodeID int64) int {
collectionRowCount := 0
// calculate collection sealed segment row count
-collectionSegments := b.dist.SegmentDistManager.GetByCollectionAndNode(collectionID, nodeID)
+collectionSegments := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(collectionID), meta.WithNodeID(nodeID))
for _, s := range collectionSegments {
collectionRowCount += int(s.GetNumOfRows())
}
@@ -235,7 +235,7 @@ func (b *ScoreBasedBalancer) BalanceReplica(replica *meta.Replica) ([]SegmentAss
func (b *ScoreBasedBalancer) genStoppingSegmentPlan(replica *meta.Replica, onlineNodes []int64, offlineNodes []int64) []SegmentAssignPlan {
segmentPlans := make([]SegmentAssignPlan, 0)
for _, nodeID := range offlineNodes {
-dist := b.dist.SegmentDistManager.GetByCollectionAndNode(replica.GetCollectionID(), nodeID)
+dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(nodeID))
segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool {
return b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil &&
b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.NextTarget) != nil &&
@@ -258,7 +258,7 @@ func (b *ScoreBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNodes [
// list all segment which could be balanced, and calculate node's score
for _, node := range onlineNodes {
-dist := b.dist.SegmentDistManager.GetByCollectionAndNode(replica.GetCollectionID(), node)
+dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(node))
segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool {
return b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil &&
b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.NextTarget) != nil &&
@@ -298,7 +298,7 @@ func (b *ScoreBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNodes [
// if the segment are redundant, skip it's balance for now
segmentsToMove = lo.Filter(segmentsToMove, func(s *meta.Segment, _ int) bool {
-return len(b.dist.SegmentDistManager.Get(s.GetID())) == 1
+return len(b.dist.SegmentDistManager.GetByFilter(meta.WithSegmentID(s.GetID()))) == 1
})
if len(segmentsToMove) == 0 {

View File

@@ -159,7 +159,7 @@ func PrintCurrentReplicaDist(replica *meta.Replica,
distInfo += fmt.Sprintf("[nodeID:%d, ", normalNodeID)
distInfo += "loaded-segments:["
nodeRowSum := int64(0)
-normalNodeSegments := segmentDistMgr.GetByNode(normalNodeID)
+normalNodeSegments := segmentDistMgr.GetByFilter(meta.WithNodeID(normalNodeID))
for _, normalNodeSegment := range normalNodeSegments {
nodeRowSum += normalNodeSegment.GetNumOfRows()
}

View File

@@ -153,7 +153,7 @@ func (c *IndexChecker) checkSegment(ctx context.Context, segment *meta.Segment,
func (c *IndexChecker) getSealedSegmentsDist(replica *meta.Replica) []*meta.Segment {
var ret []*meta.Segment
for _, node := range replica.GetNodes() {
-ret = append(ret, c.dist.SegmentDistManager.GetByCollectionAndNode(replica.CollectionID, node)...)
+ret = append(ret, c.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(node))...)
}
return ret
}

View File

@@ -100,7 +100,7 @@ func (c *LeaderChecker) Check(ctx context.Context) []task.Task {
leaderViews := c.dist.LeaderViewManager.GetByCollectionAndNode(replica.GetCollectionID(), node)
for ch, leaderView := range leaderViews {
-dist := c.dist.SegmentDistManager.GetByShardWithReplica(ch, replica)
+dist := c.dist.SegmentDistManager.GetByFilter(meta.WithChannel(ch), meta.WithReplica(replica))
tasks = append(tasks, c.findNeedLoadedSegments(ctx, replica.ID, leaderView, dist)...)
tasks = append(tasks, c.findNeedRemovedSegments(ctx, replica.ID, leaderView, dist)...)
}

View File

@@ -95,7 +95,7 @@ func (c *SegmentChecker) Check(ctx context.Context) []task.Task {
}
// find already released segments which are not contained in target
-segments := c.dist.SegmentDistManager.GetAll()
+segments := c.dist.SegmentDistManager.GetByFilter(nil)
released := utils.FilterReleased(segments, collectionIDs)
reduceTasks := c.createSegmentReduceTasks(ctx, released, -1, querypb.DataScope_Historical)
task.SetReason("collection released", reduceTasks...)
@@ -271,7 +271,7 @@ func (c *SegmentChecker) getSealedSegmentDiff(
func (c *SegmentChecker) getSealedSegmentsDist(replica *meta.Replica) []*meta.Segment {
ret := make([]*meta.Segment, 0)
for _, node := range replica.GetNodes() {
-ret = append(ret, c.dist.SegmentDistManager.GetByCollectionAndNode(replica.CollectionID, node)...)
+ret = append(ret, c.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(node))...)
}
return ret
}
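
A note on the GetAll() replacement above: GetByFilter(nil) passes a single nil filter, and the new implementation (in the segment dist manager diff further down) applies only non-nil filters, so the call matches every segment, exactly like the removed GetAll(). A tiny standalone sketch of that nil-skipping check, with toy types:

package main

import "fmt"

type Segment struct{ ID int64 }

type SegmentDistFilter func(s *Segment) bool

// matches mirrors the loop inside the new GetByFilter: nil filters are
// skipped, so a lone nil filter (or none at all) matches every segment.
func matches(s *Segment, filters ...SegmentDistFilter) bool {
	for _, f := range filters {
		if f != nil && !f(s) {
			return false
		}
	}
	return true
}

func main() {
	fmt.Println(matches(&Segment{ID: 42}, nil)) // true
}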

View File

@@ -61,7 +61,7 @@ func (s *Server) checkAnyReplicaAvailable(collectionID int64) bool {
}
func (s *Server) getCollectionSegmentInfo(collection int64) []*querypb.SegmentInfo {
-segments := s.dist.SegmentDistManager.GetByCollection(collection)
+segments := s.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(collection))
currentTargetSegmentsMap := s.targetMgr.GetSealedSegmentsByCollection(collection, meta.CurrentTarget)
infos := make(map[int64]*querypb.SegmentInfo)
for _, segment := range segments {
@@ -107,7 +107,7 @@ func (s *Server) balanceSegments(ctx context.Context, req *querypb.LoadBalanceRe
toBalance := typeutil.NewSet[*meta.Segment]()
// Only balance segments in targets
-segments := s.dist.SegmentDistManager.GetByCollectionAndNode(req.GetCollectionID(), srcNode)
+segments := s.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(srcNode))
segments = lo.Filter(segments, func(segment *meta.Segment, _ int) bool {
return s.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil
})
@@ -321,7 +321,7 @@ func (s *Server) fillReplicaInfo(replica *meta.Replica, withShardNodes bool) (*m
}
var segments []*meta.Segment
if withShardNodes {
-segments = s.dist.SegmentDistManager.GetByCollection(replica.GetCollectionID())
+segments = s.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()))
}
for _, channel := range channels {

View File

@@ -42,7 +42,7 @@ func waitCollectionReleased(dist *meta.DistributionManager, checkerController *c
for {
var (
channels []*meta.DmChannel
-segments []*meta.Segment = dist.SegmentDistManager.GetByCollection(collection)
+segments []*meta.Segment = dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(collection))
)
if partitionSet.Len() > 0 {
segments = lo.Filter(segments, func(segment *meta.Segment, _ int) bool {

View File

@@ -26,6 +26,38 @@ import (
. "github.com/milvus-io/milvus/pkg/util/typeutil"
)
+type SegmentDistFilter func(s *Segment) bool
+func WithSegmentID(segmentID int64) SegmentDistFilter {
+return func(s *Segment) bool {
+return s.GetID() == segmentID
+}
+}
+func WithReplica(replica *Replica) SegmentDistFilter {
+return func(s *Segment) bool {
+return replica.GetCollectionID() == s.GetCollectionID() && replica.Contains(s.Node)
+}
+}
+func WithNodeID(nodeID int64) SegmentDistFilter {
+return func(s *Segment) bool {
+return s.Node == nodeID
+}
+}
+func WithCollectionID(collectionID UniqueID) SegmentDistFilter {
+return func(s *Segment) bool {
+return s.CollectionID == collectionID
+}
+}
+func WithChannel(channelName string) SegmentDistFilter {
+return func(s *Segment) bool {
+return s.GetInsertChannel() == channelName
+}
+}
type Segment struct {
*datapb.SegmentInfo
Node int64 // Node the segment is in
@@ -71,14 +103,21 @@ func (m *SegmentDistManager) Update(nodeID UniqueID, segments ...*Segment) {
m.segments[nodeID] = segments
}
-func (m *SegmentDistManager) Get(id UniqueID) []*Segment {
+// GetByFilter return segment list which match all given filters
+func (m *SegmentDistManager) GetByFilter(filters ...SegmentDistFilter) []*Segment {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
ret := make([]*Segment, 0)
for _, segments := range m.segments {
for _, segment := range segments {
-if segment.GetID() == id {
+allMatch := true
+for _, f := range filters {
+if f != nil && !f(segment) {
+allMatch = false
+}
+}
+if allMatch {
ret = append(ret, segment)
}
}
@@ -86,100 +125,7 @@ func (m *SegmentDistManager) Get(id UniqueID) []*Segment {
return ret
}
-// GetAll returns all segments
-func (m *SegmentDistManager) GetAll() []*Segment {
-m.rwmutex.RLock()
-defer m.rwmutex.RUnlock()
-ret := make([]*Segment, 0)
-for _, segments := range m.segments {
-ret = append(ret, segments...)
-}
-return ret
-}
-// func (m *SegmentDistManager) Remove(ids ...UniqueID) {
-// m.rwmutex.Lock()
-// defer m.rwmutex.Unlock()
-// for _, id := range ids {
-// delete(m.segments, id)
-// }
-// }
-// GetByNode returns all segments of the given node.
-func (m *SegmentDistManager) GetByNode(nodeID UniqueID) []*Segment {
-m.rwmutex.RLock()
-defer m.rwmutex.RUnlock()
-return m.segments[nodeID]
-}
-// GetByCollection returns all segments of the given collection.
-func (m *SegmentDistManager) GetByCollection(collectionID UniqueID) []*Segment {
-m.rwmutex.RLock()
-defer m.rwmutex.RUnlock()
-ret := make([]*Segment, 0)
-for _, segments := range m.segments {
-for _, segment := range segments {
-if segment.CollectionID == collectionID {
-ret = append(ret, segment)
-}
-}
-}
-return ret
-}
-// GetByShard returns all segments of the given collection.
-func (m *SegmentDistManager) GetByShard(shard string) []*Segment {
-m.rwmutex.RLock()
-defer m.rwmutex.RUnlock()
-ret := make([]*Segment, 0)
-for _, segments := range m.segments {
-for _, segment := range segments {
-if segment.GetInsertChannel() == shard {
-ret = append(ret, segment)
-}
-}
-}
-return ret
-}
-// GetByShard returns all segments of the given collection.
-func (m *SegmentDistManager) GetByShardWithReplica(shard string, replica *Replica) []*Segment {
-m.rwmutex.RLock()
-defer m.rwmutex.RUnlock()
-ret := make([]*Segment, 0)
-for nodeID, segments := range m.segments {
-if !replica.Contains(nodeID) {
-continue
-}
-for _, segment := range segments {
-if segment.GetInsertChannel() == shard {
-ret = append(ret, segment)
-}
-}
-}
-return ret
-}
-// GetByCollectionAndNode returns all segments of the given collection and node.
-func (m *SegmentDistManager) GetByCollectionAndNode(collectionID, nodeID UniqueID) []*Segment {
-m.rwmutex.RLock()
-defer m.rwmutex.RUnlock()
-ret := make([]*Segment, 0)
-for _, segment := range m.segments[nodeID] {
-if segment.CollectionID == collectionID {
-ret = append(ret, segment)
-}
-}
-return ret
-}
// return node list which contains the given segmentID
func (m *SegmentDistManager) GetSegmentDist(segmentID int64) []int64 {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()

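Worth noting at the end of this file's diff: Get, GetAll, and the five GetBy* variants above collapse into the single GetByFilter method, and a future query shape needs only a new filter rather than a new manager method. As an illustration only, a hypothetical level filter (not part of this commit) would compose with the existing options; it assumes the GetLevel accessor and datapb.SegmentLevel enum already used by the balancers above:

// Hypothetical filter, for illustration: selects segments by level at
// the call site instead of adding another GetBy* manager method.
func WithSegmentLevel(level datapb.SegmentLevel) SegmentDistFilter {
	return func(s *Segment) bool {
		return s.GetLevel() == level
	}
}

// Usage sketch, composed with filters that this commit does define:
// segments := m.GetByFilter(WithCollectionID(collectionID), WithSegmentLevel(datapb.SegmentLevel_L1))
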
View File

@@ -22,6 +22,8 @@ import (
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type SegmentDistManagerSuite struct {
@@ -89,38 +91,62 @@ func (suite *SegmentDistManagerSuite) TestGetBy() {
dist := suite.dist
// Test GetByNode
for _, node := range suite.nodes {
-segments := dist.GetByNode(node)
+segments := dist.GetByFilter(WithNodeID(node))
suite.AssertNode(segments, node)
}
// Test GetByShard
for _, shard := range []string{"dmc0", "dmc1"} {
-segments := dist.GetByShard(shard)
+segments := dist.GetByFilter(WithChannel(shard))
suite.AssertShard(segments, shard)
}
// Test GetByCollection
-segments := dist.GetByCollection(suite.collection)
+segments := dist.GetByFilter(WithCollectionID(suite.collection))
suite.Len(segments, 8)
suite.AssertCollection(segments, suite.collection)
-segments = dist.GetByCollection(-1)
+segments = dist.GetByFilter(WithCollectionID(-1))
suite.Len(segments, 0)
// Test GetByNodeAndCollection
// 1. Valid node and valid collection
for _, node := range suite.nodes {
-segments := dist.GetByCollectionAndNode(suite.collection, node)
+segments := dist.GetByFilter(WithCollectionID(suite.collection), WithNodeID(node))
suite.AssertNode(segments, node)
suite.AssertCollection(segments, suite.collection)
}
// 2. Valid node and invalid collection
-segments = dist.GetByCollectionAndNode(-1, suite.nodes[1])
+segments = dist.GetByFilter(WithCollectionID(-1), WithNodeID(suite.nodes[1]))
suite.Len(segments, 0)
// 3. Invalid node and valid collection
-segments = dist.GetByCollectionAndNode(suite.collection, -1)
+segments = dist.GetByFilter(WithCollectionID(suite.collection), WithNodeID(-1))
suite.Len(segments, 0)
+// Test GetBy With Wrong Replica
+replica := &Replica{
+Replica: &querypb.Replica{
+ID: 1,
+CollectionID: suite.collection + 1,
+Nodes: []int64{suite.nodes[0]},
+},
+nodes: typeutil.NewUniqueSet(suite.nodes[0]),
+}
+segments = dist.GetByFilter(WithReplica(replica))
+suite.Len(segments, 0)
+// Test GetBy With Correct Replica
+replica = &Replica{
+Replica: &querypb.Replica{
+ID: 1,
+CollectionID: suite.collection,
+Nodes: []int64{suite.nodes[0]},
+},
+nodes: typeutil.NewUniqueSet(suite.nodes[0]),
+}
+segments = dist.GetByFilter(WithReplica(replica))
+suite.Len(segments, 2)
}
func (suite *SegmentDistManagerSuite) AssertIDs(segments []*Segment, ids ...int64) bool {

View File

@@ -99,7 +99,7 @@ func (ob *ReplicaObserver) checkNodesInReplica() {
for node := range outboundNodes {
channels := ob.distMgr.ChannelDistManager.GetByCollectionAndNode(collectionID, node)
-segments := ob.distMgr.SegmentDistManager.GetByCollectionAndNode(collectionID, node)
+segments := ob.distMgr.SegmentDistManager.GetByFilter(meta.WithCollectionID(collectionID), meta.WithNodeID(node))
if len(channels) == 0 && len(segments) == 0 {
replica.RemoveNode(node)

View File

@@ -501,7 +501,7 @@ func (s *Server) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfo
infos = s.getCollectionSegmentInfo(req.GetCollectionID())
} else {
for _, segmentID := range req.GetSegmentIDs() {
-segments := s.dist.SegmentDistManager.Get(segmentID)
+segments := s.dist.SegmentDistManager.GetByFilter(meta.WithSegmentID(segmentID))
if len(segments) == 0 {
err := merr.WrapErrSegmentNotLoaded(segmentID)
msg := fmt.Sprintf("segment %v not found in any node", segmentID)

View File

@@ -121,7 +121,7 @@ func (action *SegmentAction) IsFinished(distMgr *meta.DistributionManager) bool
// the leader should return a map of segment ID to list of nodes,
// now, we just always commit the release task to executor once.
// NOTE: DO NOT create a task containing release action and the action is not the last action
-sealed := distMgr.SegmentDistManager.GetByNode(action.Node())
+sealed := distMgr.SegmentDistManager.GetByFilter(meta.WithNodeID(action.Node()))
growing := distMgr.LeaderViewManager.GetSegmentByNode(action.Node())
segments := make([]int64, 0, len(sealed)+len(growing))
for _, segment := range sealed {