mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-02 11:59:00 +08:00
Add segment reference count and handles change info in ShardCluster (#16620)
Resolves #16619 Add reference count for each search/query request For SegmentChangeInfo - Wait all segments in OnlineList to be loaded - Add handoff event into pending list - Wait all segments in OfflineList is not used (reference count = 0) Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
parent
4ef2df8cb9
commit
b99b65c26e
@ -20,10 +20,23 @@ type queryChannel struct {
|
|||||||
|
|
||||||
streaming *streaming
|
streaming *streaming
|
||||||
queryMsgStream msgstream.MsgStream
|
queryMsgStream msgstream.MsgStream
|
||||||
|
shardCluster *ShardClusterService
|
||||||
asConsumeOnce sync.Once
|
asConsumeOnce sync.Once
|
||||||
closeOnce sync.Once
|
closeOnce sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewQueryChannel create a query channel with provided shardCluster, query msgstream and collection id
|
||||||
|
func NewQueryChannel(collectionID int64, scs *ShardClusterService, qms msgstream.MsgStream, streaming *streaming) *queryChannel {
|
||||||
|
return &queryChannel{
|
||||||
|
closeCh: make(chan struct{}),
|
||||||
|
collectionID: collectionID,
|
||||||
|
|
||||||
|
streaming: streaming,
|
||||||
|
queryMsgStream: qms,
|
||||||
|
shardCluster: scs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// AsConsumer do AsConsumer for query msgstream and seek if position is not nil
|
// AsConsumer do AsConsumer for query msgstream and seek if position is not nil
|
||||||
func (qc *queryChannel) AsConsumer(channelName string, subName string, position *internalpb.MsgPosition) error {
|
func (qc *queryChannel) AsConsumer(channelName string, subName string, position *internalpb.MsgPosition) error {
|
||||||
var err error
|
var err error
|
||||||
@ -101,7 +114,9 @@ func (qc *queryChannel) adjustByChangeInfo(msg *msgstream.SealedSegmentsChangeIn
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// should handle segment change in shardCluster
|
// process change in shard cluster
|
||||||
|
qc.shardCluster.HandoffSegments(qc.collectionID, info)
|
||||||
|
|
||||||
// for OnlineSegments:
|
// for OnlineSegments:
|
||||||
for _, segment := range info.OnlineSegments {
|
for _, segment := range info.OnlineSegments {
|
||||||
/*
|
/*
|
||||||
@ -124,12 +139,6 @@ func (qc *queryChannel) adjustByChangeInfo(msg *msgstream.SealedSegmentsChangeIn
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
/*
|
|
||||||
// for OfflineSegments:
|
|
||||||
for _, segment := range info.OfflineSegments {
|
|
||||||
// 1. update global sealed segments
|
|
||||||
q.globalSegmentManager.removeGlobalSealedSegmentInfo(segment.SegmentID)
|
|
||||||
}*/
|
|
||||||
|
|
||||||
log.Info("Successfully changed global sealed segment info ",
|
log.Info("Successfully changed global sealed segment info ",
|
||||||
zap.Int64("collection ", qc.collectionID),
|
zap.Int64("collection ", qc.collectionID),
|
||||||
|
@ -97,13 +97,7 @@ func TestQueryChannel_AsConsumer(t *testing.T) {
|
|||||||
mqs := &mockQueryMsgStream{}
|
mqs := &mockQueryMsgStream{}
|
||||||
mqs.On("Close").Return()
|
mqs.On("Close").Return()
|
||||||
|
|
||||||
qc := &queryChannel{
|
qc := NewQueryChannel(defaultCollectionID, nil, mqs, nil)
|
||||||
closeCh: make(chan struct{}),
|
|
||||||
collectionID: defaultCollectionID,
|
|
||||||
|
|
||||||
streaming: nil,
|
|
||||||
queryMsgStream: mqs,
|
|
||||||
}
|
|
||||||
|
|
||||||
mqs.On("AsConsumer", []string{defaultDMLChannel}, defaultSubName).Return()
|
mqs.On("AsConsumer", []string{defaultDMLChannel}, defaultSubName).Return()
|
||||||
|
|
||||||
@ -122,13 +116,7 @@ func TestQueryChannel_AsConsumer(t *testing.T) {
|
|||||||
mqs := &mockQueryMsgStream{}
|
mqs := &mockQueryMsgStream{}
|
||||||
mqs.On("Close").Return()
|
mqs.On("Close").Return()
|
||||||
|
|
||||||
qc := &queryChannel{
|
qc := NewQueryChannel(defaultCollectionID, nil, mqs, nil)
|
||||||
closeCh: make(chan struct{}),
|
|
||||||
collectionID: defaultCollectionID,
|
|
||||||
|
|
||||||
streaming: nil,
|
|
||||||
queryMsgStream: mqs,
|
|
||||||
}
|
|
||||||
|
|
||||||
mqs.On("AsConsumer", []string{defaultDMLChannel}, defaultSubName).Return()
|
mqs.On("AsConsumer", []string{defaultDMLChannel}, defaultSubName).Return()
|
||||||
|
|
||||||
@ -146,13 +134,8 @@ func TestQueryChannel_AsConsumer(t *testing.T) {
|
|||||||
mqs := &mockQueryMsgStream{}
|
mqs := &mockQueryMsgStream{}
|
||||||
mqs.On("Close").Return()
|
mqs.On("Close").Return()
|
||||||
|
|
||||||
qc := &queryChannel{
|
qc := NewQueryChannel(defaultCollectionID, nil, mqs, nil)
|
||||||
closeCh: make(chan struct{}),
|
|
||||||
collectionID: defaultCollectionID,
|
|
||||||
|
|
||||||
streaming: nil,
|
|
||||||
queryMsgStream: mqs,
|
|
||||||
}
|
|
||||||
msgID := make([]byte, 8)
|
msgID := make([]byte, 8)
|
||||||
rand.Read(msgID)
|
rand.Read(msgID)
|
||||||
pos := &internalpb.MsgPosition{MsgID: msgID}
|
pos := &internalpb.MsgPosition{MsgID: msgID}
|
||||||
|
@ -138,12 +138,7 @@ func (q *queryShardService) getQueryChannel(collectionID int64) *queryChannel {
|
|||||||
qc, ok := q.queryChannels[collectionID]
|
qc, ok := q.queryChannels[collectionID]
|
||||||
if !ok {
|
if !ok {
|
||||||
queryStream, _ := q.factory.NewQueryMsgStream(q.ctx)
|
queryStream, _ := q.factory.NewQueryMsgStream(q.ctx)
|
||||||
qc = &queryChannel{
|
qc = NewQueryChannel(collectionID, q.shardClusterService, queryStream, q.streaming)
|
||||||
closeCh: make(chan struct{}),
|
|
||||||
collectionID: collectionID,
|
|
||||||
queryMsgStream: queryStream,
|
|
||||||
streaming: q.streaming,
|
|
||||||
}
|
|
||||||
q.queryChannels[collectionID] = qc
|
q.queryChannels[collectionID] = qc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -92,6 +92,7 @@ type shardSegmentInfo struct {
|
|||||||
partitionID int64
|
partitionID int64
|
||||||
nodeID int64
|
nodeID int64
|
||||||
state segmentState
|
state segmentState
|
||||||
|
inUse int32
|
||||||
}
|
}
|
||||||
|
|
||||||
// ShardNodeDetector provides method to detect node events
|
// ShardNodeDetector provides method to detect node events
|
||||||
@ -119,9 +120,13 @@ type ShardCluster struct {
|
|||||||
segmentDetector ShardSegmentDetector
|
segmentDetector ShardSegmentDetector
|
||||||
nodeBuilder ShardNodeBuilder
|
nodeBuilder ShardNodeBuilder
|
||||||
|
|
||||||
mut sync.RWMutex
|
mut sync.RWMutex
|
||||||
nodes map[int64]*shardNode // online nodes
|
nodes map[int64]*shardNode // online nodes
|
||||||
segments map[int64]*shardSegmentInfo // shard segments
|
segments map[int64]*shardSegmentInfo // shard segments
|
||||||
|
handoffs map[int32]*querypb.SegmentChangeInfo // current pending handoff
|
||||||
|
lastToken *atomic.Int32 // last token used for segment change info
|
||||||
|
segmentCond *sync.Cond // segment state change condition
|
||||||
|
rcCond *sync.Cond // segment rc change condition
|
||||||
|
|
||||||
closeOnce sync.Once
|
closeOnce sync.Once
|
||||||
closeCh chan struct{}
|
closeCh chan struct{}
|
||||||
@ -141,12 +146,19 @@ func NewShardCluster(collectionID int64, replicaID int64, vchannelName string,
|
|||||||
segmentDetector: segmentDetector,
|
segmentDetector: segmentDetector,
|
||||||
nodeBuilder: nodeBuilder,
|
nodeBuilder: nodeBuilder,
|
||||||
|
|
||||||
nodes: make(map[int64]*shardNode),
|
nodes: make(map[int64]*shardNode),
|
||||||
segments: make(map[int64]*shardSegmentInfo),
|
segments: make(map[int64]*shardSegmentInfo),
|
||||||
|
handoffs: make(map[int32]*querypb.SegmentChangeInfo),
|
||||||
|
lastToken: atomic.NewInt32(0),
|
||||||
|
|
||||||
closeCh: make(chan struct{}),
|
closeCh: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m := sync.Mutex{}
|
||||||
|
sc.segmentCond = sync.NewCond(&m)
|
||||||
|
m2 := sync.Mutex{}
|
||||||
|
sc.rcCond = sync.NewCond(&m2)
|
||||||
|
|
||||||
sc.init()
|
sc.init()
|
||||||
|
|
||||||
return sc
|
return sc
|
||||||
@ -205,8 +217,15 @@ func (sc *ShardCluster) removeNode(evt nodeEvent) {
|
|||||||
|
|
||||||
// updateSegment apply segment change to shard cluster
|
// updateSegment apply segment change to shard cluster
|
||||||
func (sc *ShardCluster) updateSegment(evt segmentEvent) {
|
func (sc *ShardCluster) updateSegment(evt segmentEvent) {
|
||||||
|
|
||||||
log.Debug("ShardCluster update segment", zap.Int64("nodeID", evt.nodeID), zap.Int64("segmentID", evt.segmentID), zap.Int32("state", int32(evt.state)))
|
log.Debug("ShardCluster update segment", zap.Int64("nodeID", evt.nodeID), zap.Int64("segmentID", evt.segmentID), zap.Int32("state", int32(evt.state)))
|
||||||
|
|
||||||
|
// notify handoff wait online if any
|
||||||
|
defer func() {
|
||||||
|
sc.segmentCond.L.Lock()
|
||||||
|
sc.segmentCond.Broadcast()
|
||||||
|
sc.segmentCond.L.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
sc.mut.Lock()
|
sc.mut.Lock()
|
||||||
defer sc.mut.Unlock()
|
defer sc.mut.Unlock()
|
||||||
|
|
||||||
@ -255,6 +274,8 @@ func (sc *ShardCluster) transferSegment(old *shardSegmentInfo, evt segmentEvent)
|
|||||||
// removeSegment removes segment from cluster
|
// removeSegment removes segment from cluster
|
||||||
// should only applied in hand-off or load balance procedure
|
// should only applied in hand-off or load balance procedure
|
||||||
func (sc *ShardCluster) removeSegment(evt segmentEvent) {
|
func (sc *ShardCluster) removeSegment(evt segmentEvent) {
|
||||||
|
log.Debug("ShardCluster remove segment", zap.Int64("nodeID", evt.nodeID), zap.Int64("segmentID", evt.segmentID), zap.Int32("state", int32(evt.state)))
|
||||||
|
|
||||||
sc.mut.Lock()
|
sc.mut.Lock()
|
||||||
defer sc.mut.Unlock()
|
defer sc.mut.Unlock()
|
||||||
|
|
||||||
@ -269,7 +290,6 @@ func (sc *ShardCluster) removeSegment(evt segmentEvent) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO check handoff / load balance
|
|
||||||
delete(sc.segments, evt.segmentID)
|
delete(sc.segments, evt.segmentID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -331,7 +351,6 @@ func (sc *ShardCluster) watchSegments(evtCh <-chan segmentEvent) {
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case evt, ok := <-evtCh:
|
case evt, ok := <-evtCh:
|
||||||
log.Debug("segment event", zap.Any("evt", evt))
|
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Warn("ShardCluster segment channel closed", zap.Int64("collectionID", sc.collectionID), zap.Int64("replicaID", sc.replicaID))
|
log.Warn("ShardCluster segment channel closed", zap.Int64("collectionID", sc.collectionID), zap.Int64("replicaID", sc.replicaID))
|
||||||
return
|
return
|
||||||
@ -381,20 +400,166 @@ func (sc *ShardCluster) getSegment(segmentID int64) (*shardSegmentInfo, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// segmentAllocations returns node to segments mappings.
|
// segmentAllocations returns node to segments mappings.
|
||||||
|
// calling this function also increases the reference count of related segments.
|
||||||
func (sc *ShardCluster) segmentAllocations(partitionIDs []int64) map[int64][]int64 {
|
func (sc *ShardCluster) segmentAllocations(partitionIDs []int64) map[int64][]int64 {
|
||||||
result := make(map[int64][]int64) // nodeID => segmentIDs
|
result := make(map[int64][]int64) // nodeID => segmentIDs
|
||||||
sc.mut.RLock()
|
sc.mut.Lock()
|
||||||
defer sc.mut.RUnlock()
|
defer sc.mut.Unlock()
|
||||||
|
|
||||||
for _, segment := range sc.segments {
|
for _, segment := range sc.segments {
|
||||||
if len(partitionIDs) > 0 && !inList(partitionIDs, segment.partitionID) {
|
if len(partitionIDs) > 0 && !inList(partitionIDs, segment.partitionID) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if sc.inHandoffOffline(segment.segmentID) {
|
||||||
|
log.Debug("segment ignore in pending offline list", zap.Int64("collectionID", sc.collectionID), zap.Int64("replicaID", sc.replicaID), zap.Int64("segmentID", segment.segmentID))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// reference count ++
|
||||||
|
segment.inUse++
|
||||||
result[segment.nodeID] = append(result[segment.nodeID], segment.segmentID)
|
result[segment.nodeID] = append(result[segment.nodeID], segment.segmentID)
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// inHandoffOffline checks whether segment is pending handoff offline list
|
||||||
|
// Note that sc.mut Lock is assumed to be hold outside of this function!
|
||||||
|
func (sc *ShardCluster) inHandoffOffline(segmentID int64) bool {
|
||||||
|
for _, handoff := range sc.handoffs {
|
||||||
|
for _, offlineSegment := range handoff.OfflineSegments {
|
||||||
|
if segmentID == offlineSegment.GetSegmentID() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// finishUsage decreases the inUse count of provided segments
|
||||||
|
func (sc *ShardCluster) finishUsage(allocs map[int64][]int64) {
|
||||||
|
defer func() {
|
||||||
|
sc.rcCond.L.Lock()
|
||||||
|
sc.rcCond.Broadcast()
|
||||||
|
sc.rcCond.L.Unlock()
|
||||||
|
}()
|
||||||
|
sc.mut.Lock()
|
||||||
|
defer sc.mut.Unlock()
|
||||||
|
for _, segments := range allocs {
|
||||||
|
for _, segmentID := range segments {
|
||||||
|
segment, ok := sc.segments[segmentID]
|
||||||
|
if !ok || segment == nil {
|
||||||
|
// this shall not happen, since removing segment without decreasing rc to zero is illegal
|
||||||
|
log.Error("finishUsage with non-existing segment", zap.Int64("collectionID", sc.collectionID), zap.Int64("replicaID", sc.replicaID), zap.String("vchannel", sc.vchannelName), zap.Int64("segmentID", segmentID))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// decrease the reference count
|
||||||
|
segment.inUse--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandoffSegments processes the handoff/load balance segments update procedure.
|
||||||
|
func (sc *ShardCluster) HandoffSegments(info *querypb.SegmentChangeInfo) error {
|
||||||
|
// wait for all OnlineSegment is loaded
|
||||||
|
onlineSegments := make([]int64, 0, len(info.OnlineSegments))
|
||||||
|
for _, seg := range info.OnlineSegments {
|
||||||
|
// filter out segments not maintained in this cluster
|
||||||
|
if seg.GetCollectionID() != sc.collectionID || seg.GetDmChannel() != sc.vchannelName {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
onlineSegments = append(onlineSegments, seg.GetSegmentID())
|
||||||
|
}
|
||||||
|
sc.waitSegmentsOnline(onlineSegments)
|
||||||
|
|
||||||
|
// add segmentChangeInfo to pending list
|
||||||
|
token := sc.appendHandoff(info)
|
||||||
|
|
||||||
|
// wait for all OfflineSegments is not in use
|
||||||
|
offlineSegments := make([]int64, 0, len(info.OfflineSegments))
|
||||||
|
for _, seg := range info.OfflineSegments {
|
||||||
|
offlineSegments = append(offlineSegments, seg.GetSegmentID())
|
||||||
|
}
|
||||||
|
sc.waitSegmentsNotInUse(offlineSegments)
|
||||||
|
// remove offline segments record
|
||||||
|
for _, seg := range info.OfflineSegments {
|
||||||
|
// filter out segments not maintained in this cluster
|
||||||
|
if seg.GetCollectionID() != sc.collectionID || seg.GetDmChannel() != sc.vchannelName {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
sc.removeSegment(segmentEvent{segmentID: seg.GetSegmentID(), nodeID: seg.GetNodeID()})
|
||||||
|
}
|
||||||
|
|
||||||
|
// finish handoff and remove it from pending list
|
||||||
|
sc.finishHandoff(token)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// appendHandoff adds the change info into pending list and returns the token.
|
||||||
|
func (sc *ShardCluster) appendHandoff(info *querypb.SegmentChangeInfo) int32 {
|
||||||
|
sc.mut.Lock()
|
||||||
|
defer sc.mut.Unlock()
|
||||||
|
|
||||||
|
token := sc.lastToken.Add(1)
|
||||||
|
sc.handoffs[token] = info
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
|
||||||
|
// finishHandoff removes the handoff related to the token.
|
||||||
|
func (sc *ShardCluster) finishHandoff(token int32) {
|
||||||
|
sc.mut.Lock()
|
||||||
|
defer sc.mut.Unlock()
|
||||||
|
|
||||||
|
delete(sc.handoffs, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitSegmentsOnline waits until all provided segments is loaded.
|
||||||
|
func (sc *ShardCluster) waitSegmentsOnline(segments []int64) {
|
||||||
|
sc.segmentCond.L.Lock()
|
||||||
|
for !sc.segmentsOnline(segments) {
|
||||||
|
sc.segmentCond.Wait()
|
||||||
|
}
|
||||||
|
sc.segmentCond.L.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitSegmentsNotInUse waits until all provided segments is not in use.
|
||||||
|
func (sc *ShardCluster) waitSegmentsNotInUse(segments []int64) {
|
||||||
|
sc.rcCond.L.Lock()
|
||||||
|
for sc.segmentsInUse(segments) {
|
||||||
|
sc.rcCond.Wait()
|
||||||
|
}
|
||||||
|
sc.rcCond.L.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkOnline checks whether all segment ids provided in online state.
|
||||||
|
func (sc *ShardCluster) segmentsOnline(segmentIDs []int64) bool {
|
||||||
|
sc.mut.RLock()
|
||||||
|
defer sc.mut.RUnlock()
|
||||||
|
for _, segID := range segmentIDs {
|
||||||
|
segment, ok := sc.segments[segID]
|
||||||
|
if !ok || segment.state != segmentStateLoaded {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// segmentsInUse checks whether all segment ids provided still in use.
|
||||||
|
func (sc *ShardCluster) segmentsInUse(segmentIDs []int64) bool {
|
||||||
|
sc.mut.RLock()
|
||||||
|
defer sc.mut.RUnlock()
|
||||||
|
for _, segID := range segmentIDs {
|
||||||
|
segment, ok := sc.segments[segID]
|
||||||
|
if !ok {
|
||||||
|
// ignore missing segments, since they might be in streaming
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if segment.inUse > 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// Search preforms search operation on shard cluster.
|
// Search preforms search operation on shard cluster.
|
||||||
func (sc *ShardCluster) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) {
|
func (sc *ShardCluster) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) {
|
||||||
if sc.state.Load() != int32(available) {
|
if sc.state.Load() != int32(available) {
|
||||||
@ -405,16 +570,14 @@ func (sc *ShardCluster) Search(ctx context.Context, req *querypb.SearchRequest)
|
|||||||
return nil, fmt.Errorf("ShardCluster for %s does not match to request channel :%s", sc.vchannelName, req.GetDmlChannel())
|
return nil, fmt.Errorf("ShardCluster for %s does not match to request channel :%s", sc.vchannelName, req.GetDmlChannel())
|
||||||
}
|
}
|
||||||
|
|
||||||
//req.GetReq().GetPartitionIDs()
|
// get node allocation and maintains the inUse reference count
|
||||||
|
|
||||||
// get node allocation
|
|
||||||
segAllocs := sc.segmentAllocations(req.GetReq().GetPartitionIDs())
|
segAllocs := sc.segmentAllocations(req.GetReq().GetPartitionIDs())
|
||||||
|
defer sc.finishUsage(segAllocs)
|
||||||
|
|
||||||
log.Debug("cluster segment distribution", zap.Int("len", len(segAllocs)))
|
log.Debug("cluster segment distribution", zap.Int("len", len(segAllocs)))
|
||||||
for nodeID, segmentIDs := range segAllocs {
|
for nodeID, segmentIDs := range segAllocs {
|
||||||
log.Debug("segments distribution", zap.Int64("nodeID", nodeID), zap.Int64s("segments", segmentIDs))
|
log.Debug("segments distribution", zap.Int64("nodeID", nodeID), zap.Int64s("segments", segmentIDs))
|
||||||
}
|
}
|
||||||
// TODO dispatch to local queryShardService query dml channel growing segments
|
|
||||||
|
|
||||||
// concurrent visiting nodes
|
// concurrent visiting nodes
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
@ -430,7 +593,7 @@ func (sc *ShardCluster) Search(ctx context.Context, req *querypb.SearchRequest)
|
|||||||
nodeReq.SegmentIDs = segments
|
nodeReq.SegmentIDs = segments
|
||||||
node, ok := sc.getNode(nodeID)
|
node, ok := sc.getNode(nodeID)
|
||||||
if !ok { // meta dismatch, report error
|
if !ok { // meta dismatch, report error
|
||||||
return nil, fmt.Errorf("SharcCluster for %s replicaID %d is no available", sc.vchannelName, sc.replicaID)
|
return nil, fmt.Errorf("ShardCluster for %s replicaID %d is no available", sc.vchannelName, sc.replicaID)
|
||||||
}
|
}
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
@ -466,10 +629,9 @@ func (sc *ShardCluster) Query(ctx context.Context, req *querypb.QueryRequest) ([
|
|||||||
return nil, fmt.Errorf("ShardCluster for %s does not match to request channel :%s", sc.vchannelName, req.GetDmlChannel())
|
return nil, fmt.Errorf("ShardCluster for %s does not match to request channel :%s", sc.vchannelName, req.GetDmlChannel())
|
||||||
}
|
}
|
||||||
|
|
||||||
// get node allocation
|
// get node allocation and maintains the inUse reference count
|
||||||
segAllocs := sc.segmentAllocations(req.GetReq().GetPartitionIDs())
|
segAllocs := sc.segmentAllocations(req.GetReq().GetPartitionIDs())
|
||||||
|
defer sc.finishUsage(segAllocs)
|
||||||
// TODO dispatch to local queryShardService query dml channel growing segments
|
|
||||||
|
|
||||||
// concurrent visiting nodes
|
// concurrent visiting nodes
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
grpcquerynodeclient "github.com/milvus-io/milvus/internal/distributed/querynode/client"
|
grpcquerynodeclient "github.com/milvus-io/milvus/internal/distributed/querynode/client"
|
||||||
|
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||||
"github.com/milvus-io/milvus/internal/util"
|
"github.com/milvus-io/milvus/internal/util"
|
||||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||||
@ -107,3 +108,20 @@ func (s *ShardClusterService) releaseCollection(collectionID int64) {
|
|||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HandoffSegments dispatch segmentChangeInfo to related shardClusters
|
||||||
|
func (s *ShardClusterService) HandoffSegments(collectionID int64, info *querypb.SegmentChangeInfo) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
s.clusters.Range(func(k, v interface{}) bool {
|
||||||
|
cs := v.(*ShardCluster)
|
||||||
|
if cs.collectionID == collectionID {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
cs.HandoffSegments(info)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
@ -4,8 +4,10 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"go.etcd.io/etcd/server/v3/etcdserver/api/v3client"
|
"go.etcd.io/etcd/server/v3/etcdserver/api/v3client"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -32,3 +34,19 @@ func TestShardClusterService(t *testing.T) {
|
|||||||
err = clusterService.releaseShardCluster("non-exist-channel")
|
err = clusterService.releaseShardCluster("non-exist-channel")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestShardClusterService_HandoffSegments(t *testing.T) {
|
||||||
|
qn, err := genSimpleQueryNode(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
client := v3client.New(embedetcdServer.Server)
|
||||||
|
defer client.Close()
|
||||||
|
session := sessionutil.NewSession(context.Background(), "/by-dev/sessions/unittest/querynode/", client)
|
||||||
|
clusterService := newShardClusterService(client, session, qn)
|
||||||
|
|
||||||
|
clusterService.addShardCluster(defaultCollectionID, defaultReplicaID, defaultDMLChannel)
|
||||||
|
//TODO change shardCluster to interface to mock test behavior
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
clusterService.HandoffSegments(defaultCollectionID, &querypb.SegmentChangeInfo{})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -1023,3 +1023,372 @@ func TestShardCluster_Query(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestShardCluster_ReferenceCount(t *testing.T) {
|
||||||
|
collectionID := int64(1)
|
||||||
|
vchannelName := "dml_1_1_v0"
|
||||||
|
replicaID := int64(0)
|
||||||
|
// ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("normal alloc & finish", func(t *testing.T) {
|
||||||
|
nodeEvents := []nodeEvent{
|
||||||
|
{
|
||||||
|
nodeID: 1,
|
||||||
|
nodeAddr: "addr_1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
nodeID: 2,
|
||||||
|
nodeAddr: "addr_2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
segmentEvents := []segmentEvent{
|
||||||
|
{
|
||||||
|
segmentID: 1,
|
||||||
|
nodeID: 1,
|
||||||
|
state: segmentStateLoaded,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
segmentID: 2,
|
||||||
|
nodeID: 2,
|
||||||
|
state: segmentStateLoaded,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
sc := NewShardCluster(collectionID, replicaID, vchannelName,
|
||||||
|
&mockNodeDetector{
|
||||||
|
initNodes: nodeEvents,
|
||||||
|
}, &mockSegmentDetector{
|
||||||
|
initSegments: segmentEvents,
|
||||||
|
}, buildMockQueryNode)
|
||||||
|
defer sc.Close()
|
||||||
|
|
||||||
|
allocs := sc.segmentAllocations(nil)
|
||||||
|
|
||||||
|
sc.mut.RLock()
|
||||||
|
for _, segment := range sc.segments {
|
||||||
|
assert.Greater(t, segment.inUse, int32(0))
|
||||||
|
}
|
||||||
|
sc.mut.RUnlock()
|
||||||
|
|
||||||
|
assert.True(t, sc.segmentsInUse([]int64{1, 2}))
|
||||||
|
assert.True(t, sc.segmentsInUse([]int64{1, 2, -1}))
|
||||||
|
|
||||||
|
sc.finishUsage(allocs)
|
||||||
|
sc.mut.RLock()
|
||||||
|
for _, segment := range sc.segments {
|
||||||
|
assert.EqualValues(t, segment.inUse, 0)
|
||||||
|
}
|
||||||
|
sc.mut.RUnlock()
|
||||||
|
|
||||||
|
assert.False(t, sc.segmentsInUse([]int64{1, 2}))
|
||||||
|
assert.False(t, sc.segmentsInUse([]int64{1, 2, -1}))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("alloc & finish with modified alloc", func(t *testing.T) {
|
||||||
|
nodeEvents := []nodeEvent{
|
||||||
|
{
|
||||||
|
nodeID: 1,
|
||||||
|
nodeAddr: "addr_1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
nodeID: 2,
|
||||||
|
nodeAddr: "addr_2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
segmentEvents := []segmentEvent{
|
||||||
|
{
|
||||||
|
segmentID: 1,
|
||||||
|
nodeID: 1,
|
||||||
|
state: segmentStateLoaded,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
segmentID: 2,
|
||||||
|
nodeID: 2,
|
||||||
|
state: segmentStateLoaded,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
sc := NewShardCluster(collectionID, replicaID, vchannelName,
|
||||||
|
&mockNodeDetector{
|
||||||
|
initNodes: nodeEvents,
|
||||||
|
}, &mockSegmentDetector{
|
||||||
|
initSegments: segmentEvents,
|
||||||
|
}, buildMockQueryNode)
|
||||||
|
defer sc.Close()
|
||||||
|
|
||||||
|
allocs := sc.segmentAllocations(nil)
|
||||||
|
|
||||||
|
sc.mut.RLock()
|
||||||
|
for _, segment := range sc.segments {
|
||||||
|
assert.Greater(t, segment.inUse, int32(0))
|
||||||
|
}
|
||||||
|
sc.mut.RUnlock()
|
||||||
|
|
||||||
|
for node, segments := range allocs {
|
||||||
|
segments = append(segments, -1) // add non-exist segment
|
||||||
|
// shall be ignored in finishUsage
|
||||||
|
allocs[node] = segments
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
sc.finishUsage(allocs)
|
||||||
|
})
|
||||||
|
sc.mut.RLock()
|
||||||
|
for _, segment := range sc.segments {
|
||||||
|
assert.EqualValues(t, segment.inUse, 0)
|
||||||
|
}
|
||||||
|
sc.mut.RUnlock()
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("wait segments online", func(t *testing.T) {
|
||||||
|
nodeEvents := []nodeEvent{
|
||||||
|
{
|
||||||
|
nodeID: 1,
|
||||||
|
nodeAddr: "addr_1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
nodeID: 2,
|
||||||
|
nodeAddr: "addr_2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
segmentEvents := []segmentEvent{
|
||||||
|
{
|
||||||
|
segmentID: 1,
|
||||||
|
nodeID: 1,
|
||||||
|
state: segmentStateLoaded,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
segmentID: 2,
|
||||||
|
nodeID: 2,
|
||||||
|
state: segmentStateLoaded,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
evtCh := make(chan segmentEvent, 10)
|
||||||
|
sc := NewShardCluster(collectionID, replicaID, vchannelName,
|
||||||
|
&mockNodeDetector{
|
||||||
|
initNodes: nodeEvents,
|
||||||
|
}, &mockSegmentDetector{
|
||||||
|
initSegments: segmentEvents,
|
||||||
|
evtCh: evtCh,
|
||||||
|
}, buildMockQueryNode)
|
||||||
|
defer sc.Close()
|
||||||
|
|
||||||
|
assert.True(t, sc.segmentsOnline([]int64{1, 2}))
|
||||||
|
assert.False(t, sc.segmentsOnline([]int64{1, 2, 3}))
|
||||||
|
|
||||||
|
sig := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
sc.waitSegmentsOnline([]int64{1, 2, 3})
|
||||||
|
close(sig)
|
||||||
|
}()
|
||||||
|
|
||||||
|
evtCh <- segmentEvent{
|
||||||
|
eventType: segmentAdd,
|
||||||
|
segmentID: 3,
|
||||||
|
nodeID: 1,
|
||||||
|
state: segmentStateLoaded,
|
||||||
|
}
|
||||||
|
|
||||||
|
<-sig
|
||||||
|
assert.True(t, sc.segmentsOnline([]int64{1, 2, 3}))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShardCluster_HandoffSegments(t *testing.T) {
|
||||||
|
collectionID := int64(1)
|
||||||
|
otherCollectionID := int64(2)
|
||||||
|
vchannelName := "dml_1_1_v0"
|
||||||
|
otherVchannelName := "dml_1_2_v0"
|
||||||
|
replicaID := int64(0)
|
||||||
|
|
||||||
|
t.Run("handoff without using segments", func(t *testing.T) {
|
||||||
|
nodeEvents := []nodeEvent{
|
||||||
|
{
|
||||||
|
nodeID: 1,
|
||||||
|
nodeAddr: "addr_1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
nodeID: 2,
|
||||||
|
nodeAddr: "addr_2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
segmentEvents := []segmentEvent{
|
||||||
|
{
|
||||||
|
segmentID: 1,
|
||||||
|
nodeID: 1,
|
||||||
|
state: segmentStateLoaded,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
segmentID: 2,
|
||||||
|
nodeID: 2,
|
||||||
|
state: segmentStateLoaded,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
sc := NewShardCluster(collectionID, replicaID, vchannelName,
|
||||||
|
&mockNodeDetector{
|
||||||
|
initNodes: nodeEvents,
|
||||||
|
}, &mockSegmentDetector{
|
||||||
|
initSegments: segmentEvents,
|
||||||
|
}, buildMockQueryNode)
|
||||||
|
defer sc.Close()
|
||||||
|
|
||||||
|
sc.HandoffSegments(&querypb.SegmentChangeInfo{
|
||||||
|
OnlineSegments: []*querypb.SegmentInfo{
|
||||||
|
{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName},
|
||||||
|
},
|
||||||
|
OfflineSegments: []*querypb.SegmentInfo{
|
||||||
|
{SegmentID: 1, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
sc.mut.RLock()
|
||||||
|
_, has := sc.segments[1]
|
||||||
|
sc.mut.RUnlock()
|
||||||
|
|
||||||
|
assert.False(t, has)
|
||||||
|
})
|
||||||
|
t.Run("handoff with growing segment(segment not recorded)", func(t *testing.T) {
|
||||||
|
nodeEvents := []nodeEvent{
|
||||||
|
{
|
||||||
|
nodeID: 1,
|
||||||
|
nodeAddr: "addr_1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
nodeID: 2,
|
||||||
|
nodeAddr: "addr_2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
segmentEvents := []segmentEvent{
|
||||||
|
{
|
||||||
|
segmentID: 1,
|
||||||
|
nodeID: 1,
|
||||||
|
state: segmentStateLoaded,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
segmentID: 2,
|
||||||
|
nodeID: 2,
|
||||||
|
state: segmentStateLoaded,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
sc := NewShardCluster(collectionID, replicaID, vchannelName,
|
||||||
|
&mockNodeDetector{
|
||||||
|
initNodes: nodeEvents,
|
||||||
|
}, &mockSegmentDetector{
|
||||||
|
initSegments: segmentEvents,
|
||||||
|
}, buildMockQueryNode)
|
||||||
|
defer sc.Close()
|
||||||
|
|
||||||
|
sc.HandoffSegments(&querypb.SegmentChangeInfo{
|
||||||
|
OnlineSegments: []*querypb.SegmentInfo{
|
||||||
|
{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName},
|
||||||
|
{SegmentID: 4, NodeID: 2, CollectionID: otherCollectionID, DmChannel: otherVchannelName},
|
||||||
|
},
|
||||||
|
OfflineSegments: []*querypb.SegmentInfo{
|
||||||
|
{SegmentID: 3, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName},
|
||||||
|
{SegmentID: 5, NodeID: 2, CollectionID: otherCollectionID, DmChannel: otherVchannelName},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
sc.mut.RLock()
|
||||||
|
_, has := sc.segments[3]
|
||||||
|
sc.mut.RUnlock()
|
||||||
|
|
||||||
|
assert.False(t, has)
|
||||||
|
})
|
||||||
|
t.Run("handoff wait online and usage", func(t *testing.T) {
|
||||||
|
nodeEvents := []nodeEvent{
|
||||||
|
{
|
||||||
|
nodeID: 1,
|
||||||
|
nodeAddr: "addr_1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
nodeID: 2,
|
||||||
|
nodeAddr: "addr_2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
segmentEvents := []segmentEvent{
|
||||||
|
{
|
||||||
|
segmentID: 1,
|
||||||
|
nodeID: 1,
|
||||||
|
state: segmentStateLoaded,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
segmentID: 2,
|
||||||
|
nodeID: 2,
|
||||||
|
state: segmentStateLoaded,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
evtCh := make(chan segmentEvent, 10)
|
||||||
|
sc := NewShardCluster(collectionID, replicaID, vchannelName,
|
||||||
|
&mockNodeDetector{
|
||||||
|
initNodes: nodeEvents,
|
||||||
|
}, &mockSegmentDetector{
|
||||||
|
initSegments: segmentEvents,
|
||||||
|
evtCh: evtCh,
|
||||||
|
}, buildMockQueryNode)
|
||||||
|
defer sc.Close()
|
||||||
|
|
||||||
|
// add rc to all segments
|
||||||
|
allocs := sc.segmentAllocations(nil)
|
||||||
|
|
||||||
|
sig := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
sc.HandoffSegments(&querypb.SegmentChangeInfo{
|
||||||
|
OnlineSegments: []*querypb.SegmentInfo{
|
||||||
|
{SegmentID: 3, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName},
|
||||||
|
},
|
||||||
|
OfflineSegments: []*querypb.SegmentInfo{
|
||||||
|
{SegmentID: 1, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
close(sig)
|
||||||
|
}()
|
||||||
|
|
||||||
|
sc.mut.RLock()
|
||||||
|
// still waiting online
|
||||||
|
assert.Equal(t, 0, len(sc.handoffs))
|
||||||
|
sc.mut.RUnlock()
|
||||||
|
|
||||||
|
evtCh <- segmentEvent{
|
||||||
|
eventType: segmentAdd,
|
||||||
|
segmentID: 3,
|
||||||
|
nodeID: 1,
|
||||||
|
state: segmentStateLoaded,
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait for handoff appended into list
|
||||||
|
assert.Eventually(t, func() bool {
|
||||||
|
sc.mut.RLock()
|
||||||
|
defer sc.mut.RUnlock()
|
||||||
|
return len(sc.handoffs) > 0
|
||||||
|
}, time.Second, time.Millisecond*10)
|
||||||
|
|
||||||
|
tmpAllocs := sc.segmentAllocations(nil)
|
||||||
|
found := false
|
||||||
|
for _, segments := range tmpAllocs {
|
||||||
|
if inList(segments, int64(1)) {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// segment 1 shall not be allocated again!
|
||||||
|
assert.False(t, found)
|
||||||
|
sc.finishUsage(tmpAllocs)
|
||||||
|
// rc shall be 0 now
|
||||||
|
sc.finishUsage(allocs)
|
||||||
|
|
||||||
|
// wait handoff finished
|
||||||
|
<-sig
|
||||||
|
|
||||||
|
sc.mut.RLock()
|
||||||
|
_, has := sc.segments[1]
|
||||||
|
sc.mut.RUnlock()
|
||||||
|
|
||||||
|
assert.False(t, has)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user