diff --git a/internal/querycoordv2/meta/segment_dist_manager.go b/internal/querycoordv2/meta/segment_dist_manager.go index d105ec7808..1cf7b46f19 100644 --- a/internal/querycoordv2/meta/segment_dist_manager.go +++ b/internal/querycoordv2/meta/segment_dist_manager.go @@ -20,42 +20,76 @@ import ( "sync" "github.com/golang/protobuf/proto" + "github.com/samber/lo" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/querypb" - . "github.com/milvus-io/milvus/pkg/util/typeutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) -type SegmentDistFilter func(s *Segment) bool +type SegmentDistFilter interface { + Match(s *Segment) bool + NodeIDs() ([]int64, bool) +} -func WithSegmentID(segmentID int64) SegmentDistFilter { - return func(s *Segment) bool { - return s.GetID() == segmentID - } +type SegmentDistFilterFunc func(s *Segment) bool + +func (f SegmentDistFilterFunc) Match(s *Segment) bool { + return f(s) +} + +func (f SegmentDistFilterFunc) NodeIDs() ([]int64, bool) { + return nil, false +} + +type ReplicaSegDistFilter struct { + *Replica +} + +func (f *ReplicaSegDistFilter) Match(s *Segment) bool { + return f.GetCollectionID() == s.GetCollectionID() && f.Contains(s.Node) +} + +func (f *ReplicaSegDistFilter) NodeIDs() ([]int64, bool) { + return f.GetNodes(), true } func WithReplica(replica *Replica) SegmentDistFilter { - return func(s *Segment) bool { - return replica.GetCollectionID() == s.GetCollectionID() && replica.Contains(s.Node) + return &ReplicaSegDistFilter{ + Replica: replica, } } +type NodeSegDistFilter int64 + +func (f NodeSegDistFilter) Match(s *Segment) bool { + return s.Node == int64(f) +} + +func (f NodeSegDistFilter) NodeIDs() ([]int64, bool) { + return []int64{int64(f)}, true +} + func WithNodeID(nodeID int64) SegmentDistFilter { - return func(s *Segment) bool { - return s.Node == nodeID - } + return NodeSegDistFilter(nodeID) } -func WithCollectionID(collectionID UniqueID) SegmentDistFilter { - return func(s *Segment) bool { +func WithSegmentID(segmentID int64) SegmentDistFilter { + return SegmentDistFilterFunc(func(s *Segment) bool { + return s.GetID() == segmentID + }) +} + +func WithCollectionID(collectionID typeutil.UniqueID) SegmentDistFilter { + return SegmentDistFilterFunc(func(s *Segment) bool { return s.CollectionID == collectionID - } + }) } func WithChannel(channelName string) SegmentDistFilter { - return func(s *Segment) bool { + return SegmentDistFilterFunc(func(s *Segment) bool { return s.GetInsertChannel() == channelName - } + }) } type Segment struct { @@ -84,16 +118,16 @@ type SegmentDistManager struct { rwmutex sync.RWMutex // nodeID -> []*Segment - segments map[UniqueID][]*Segment + segments map[typeutil.UniqueID][]*Segment } func NewSegmentDistManager() *SegmentDistManager { return &SegmentDistManager{ - segments: make(map[UniqueID][]*Segment), + segments: make(map[typeutil.UniqueID][]*Segment), } } -func (m *SegmentDistManager) Update(nodeID UniqueID, segments ...*Segment) { +func (m *SegmentDistManager) Update(nodeID typeutil.UniqueID, segments ...*Segment) { m.rwmutex.Lock() defer m.rwmutex.Unlock() @@ -108,17 +142,36 @@ func (m *SegmentDistManager) GetByFilter(filters ...SegmentDistFilter) []*Segmen m.rwmutex.RLock() defer m.rwmutex.RUnlock() + nodes := make(typeutil.Set[int64]) + var hasNodeIDs bool + + for _, filter := range filters { + if ids, ok := filter.NodeIDs(); ok { + nodes.Insert(ids...) + hasNodeIDs = true + } + } + mergedFilters := func(s *Segment) bool { for _, f := range filters { - if f != nil && !f(s) { + if f != nil && !f.Match(s) { return false } } return true } + var candidates [][]*Segment + if hasNodeIDs { + candidates = lo.Map(nodes.Collect(), func(nodeID int64, _ int) []*Segment { + return m.segments[nodeID] + }) + } else { + candidates = lo.Values(m.segments) + } + ret := make([]*Segment, 0) - for _, segments := range m.segments { + for _, segments := range candidates { for _, segment := range segments { if mergedFilters(segment) { ret = append(ret, segment)