Add segment reference count and handle change info in ShardCluster (#16620)

Resolves #16619
Add a reference count for each search/query request.
For each SegmentChangeInfo (flow sketched below):
- Wait for all segments in the OnlineList to be loaded
- Add the handoff event to the pending list
- Wait until all segments in the OfflineList are no longer in use (reference count = 0)
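A minimal sketch of this flow, using the helpers introduced in this commit (waitSegmentsOnline, appendHandoff, waitSegmentsNotInUse, removeSegment, finishHandoff); the collectIDs helper is hypothetical and only keeps the sketch short:

func (sc *ShardCluster) handoffSketch(info *querypb.SegmentChangeInfo) {
	// 1. wait until every segment in OnlineSegments is loaded
	sc.waitSegmentsOnline(collectIDs(info.OnlineSegments))
	// 2. record the pending change so that new allocations skip the offline segments
	token := sc.appendHandoff(info)
	// 3. wait until no in-flight search/query still references the offline segments (inUse == 0)
	sc.waitSegmentsNotInUse(collectIDs(info.OfflineSegments))
	// 4. drop the offline segments and clear the pending entry
	for _, seg := range info.OfflineSegments {
		sc.removeSegment(segmentEvent{segmentID: seg.GetSegmentID(), nodeID: seg.GetNodeID()})
	}
	sc.finishHandoff(token)
}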

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
congqixia 2022-04-25 11:51:46 +08:00 committed by GitHub
parent 4ef2df8cb9
commit b99b65c26e
7 changed files with 605 additions and 51 deletions


@@ -20,10 +20,23 @@ type queryChannel struct {
streaming *streaming
queryMsgStream msgstream.MsgStream
shardCluster *ShardClusterService
asConsumeOnce sync.Once
closeOnce sync.Once
}
// NewQueryChannel create a query channel with provided shardCluster, query msgstream and collection id
func NewQueryChannel(collectionID int64, scs *ShardClusterService, qms msgstream.MsgStream, streaming *streaming) *queryChannel {
return &queryChannel{
closeCh: make(chan struct{}),
collectionID: collectionID,
streaming: streaming,
queryMsgStream: qms,
shardCluster: scs,
}
}
// AsConsumer do AsConsumer for query msgstream and seek if position is not nil
func (qc *queryChannel) AsConsumer(channelName string, subName string, position *internalpb.MsgPosition) error {
var err error
@@ -101,7 +114,9 @@ func (qc *queryChannel) adjustByChangeInfo(msg *msgstream.SealedSegmentsChangeIn
}
}
-// should handle segment change in shardCluster
// process change in shard cluster
qc.shardCluster.HandoffSegments(qc.collectionID, info)
// for OnlineSegments:
for _, segment := range info.OnlineSegments {
/*
@@ -124,12 +139,6 @@ func (qc *queryChannel) adjustByChangeInfo(msg *msgstream.SealedSegmentsChangeIn
},
})
}
-/*
-// for OfflineSegments:
-for _, segment := range info.OfflineSegments {
-// 1. update global sealed segments
-q.globalSegmentManager.removeGlobalSealedSegmentInfo(segment.SegmentID)
-}*/
log.Info("Successfully changed global sealed segment info ",
zap.Int64("collection ", qc.collectionID),


@@ -97,13 +97,7 @@ func TestQueryChannel_AsConsumer(t *testing.T) {
mqs := &mockQueryMsgStream{}
mqs.On("Close").Return()
-qc := &queryChannel{
-closeCh: make(chan struct{}),
-collectionID: defaultCollectionID,
-streaming: nil,
-queryMsgStream: mqs,
-}
qc := NewQueryChannel(defaultCollectionID, nil, mqs, nil)
mqs.On("AsConsumer", []string{defaultDMLChannel}, defaultSubName).Return()
@@ -122,13 +116,7 @@ func TestQueryChannel_AsConsumer(t *testing.T) {
mqs := &mockQueryMsgStream{}
mqs.On("Close").Return()
-qc := &queryChannel{
-closeCh: make(chan struct{}),
-collectionID: defaultCollectionID,
-streaming: nil,
-queryMsgStream: mqs,
-}
qc := NewQueryChannel(defaultCollectionID, nil, mqs, nil)
mqs.On("AsConsumer", []string{defaultDMLChannel}, defaultSubName).Return()
@@ -146,13 +134,8 @@ func TestQueryChannel_AsConsumer(t *testing.T) {
mqs := &mockQueryMsgStream{}
mqs.On("Close").Return()
-qc := &queryChannel{
-closeCh: make(chan struct{}),
-collectionID: defaultCollectionID,
-streaming: nil,
-queryMsgStream: mqs,
-}
qc := NewQueryChannel(defaultCollectionID, nil, mqs, nil)
msgID := make([]byte, 8)
rand.Read(msgID)
pos := &internalpb.MsgPosition{MsgID: msgID}


@@ -138,12 +138,7 @@ func (q *queryShardService) getQueryChannel(collectionID int64) *queryChannel {
qc, ok := q.queryChannels[collectionID]
if !ok {
queryStream, _ := q.factory.NewQueryMsgStream(q.ctx)
-qc = &queryChannel{
-closeCh: make(chan struct{}),
-collectionID: collectionID,
-queryMsgStream: queryStream,
-streaming: q.streaming,
-}
qc = NewQueryChannel(collectionID, q.shardClusterService, queryStream, q.streaming)
q.queryChannels[collectionID] = qc
}


@@ -92,6 +92,7 @@ type shardSegmentInfo struct {
partitionID int64
nodeID int64
state segmentState
inUse int32
}
// ShardNodeDetector provides method to detect node events
@@ -122,6 +123,10 @@ type ShardCluster struct {
mut sync.RWMutex
nodes map[int64]*shardNode // online nodes
segments map[int64]*shardSegmentInfo // shard segments
handoffs map[int32]*querypb.SegmentChangeInfo // current pending handoff
lastToken *atomic.Int32 // last token used for segment change info
segmentCond *sync.Cond // segment state change condition
rcCond *sync.Cond // segment rc change condition
closeOnce sync.Once
closeCh chan struct{}
@@ -143,10 +148,17 @@ func NewShardCluster(collectionID int64, replicaID int64, vchannelName string,
nodes: make(map[int64]*shardNode),
segments: make(map[int64]*shardSegmentInfo),
handoffs: make(map[int32]*querypb.SegmentChangeInfo),
lastToken: atomic.NewInt32(0),
closeCh: make(chan struct{}),
}
m := sync.Mutex{}
sc.segmentCond = sync.NewCond(&m)
m2 := sync.Mutex{}
sc.rcCond = sync.NewCond(&m2)
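// segmentCond is broadcast by updateSegment whenever a segment state changes, and
// rcCond is broadcast by finishUsage whenever a reference count drops; HandoffSegments
// waits on the former for OnlineSegments and on the latter for OfflineSegments.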
sc.init()
return sc
@@ -205,8 +217,15 @@ func (sc *ShardCluster) removeNode(evt nodeEvent) {
// updateSegment apply segment change to shard cluster
func (sc *ShardCluster) updateSegment(evt segmentEvent) {
log.Debug("ShardCluster update segment", zap.Int64("nodeID", evt.nodeID), zap.Int64("segmentID", evt.segmentID), zap.Int32("state", int32(evt.state)))
// notify handoff wait online if any
defer func() {
sc.segmentCond.L.Lock()
sc.segmentCond.Broadcast()
sc.segmentCond.L.Unlock()
}()
sc.mut.Lock()
defer sc.mut.Unlock()
@@ -255,6 +274,8 @@ func (sc *ShardCluster) transferSegment(old *shardSegmentInfo, evt segmentEvent)
// removeSegment removes segment from cluster
// should only applied in hand-off or load balance procedure
func (sc *ShardCluster) removeSegment(evt segmentEvent) {
log.Debug("ShardCluster remove segment", zap.Int64("nodeID", evt.nodeID), zap.Int64("segmentID", evt.segmentID), zap.Int32("state", int32(evt.state)))
sc.mut.Lock()
defer sc.mut.Unlock()
@@ -269,7 +290,6 @@ func (sc *ShardCluster) removeSegment(evt segmentEvent) {
return
}
-//TODO check handoff / load balance
delete(sc.segments, evt.segmentID)
}
@@ -331,7 +351,6 @@ func (sc *ShardCluster) watchSegments(evtCh <-chan segmentEvent) {
for {
select {
case evt, ok := <-evtCh:
-log.Debug("segment event", zap.Any("evt", evt))
if !ok {
log.Warn("ShardCluster segment channel closed", zap.Int64("collectionID", sc.collectionID), zap.Int64("replicaID", sc.replicaID))
return
@@ -381,20 +400,166 @@ func (sc *ShardCluster) getSegment(segmentID int64) (*shardSegmentInfo, bool) {
}
// segmentAllocations returns node to segments mappings.
// calling this function also increases the reference count of related segments.
func (sc *ShardCluster) segmentAllocations(partitionIDs []int64) map[int64][]int64 {
result := make(map[int64][]int64) // nodeID => segmentIDs
-sc.mut.RLock()
-defer sc.mut.RUnlock()
sc.mut.Lock()
defer sc.mut.Unlock()
for _, segment := range sc.segments {
if len(partitionIDs) > 0 && !inList(partitionIDs, segment.partitionID) {
continue
}
if sc.inHandoffOffline(segment.segmentID) {
log.Debug("segment ignore in pending offline list", zap.Int64("collectionID", sc.collectionID), zap.Int64("replicaID", sc.replicaID), zap.Int64("segmentID", segment.segmentID))
continue
}
// reference count ++
segment.inUse++
result[segment.nodeID] = append(result[segment.nodeID], segment.segmentID)
}
return result
}
// inHandoffOffline checks whether the segment is in the pending handoff offline list.
// Note that sc.mut Lock is assumed to be held outside of this function!
func (sc *ShardCluster) inHandoffOffline(segmentID int64) bool {
for _, handoff := range sc.handoffs {
for _, offlineSegment := range handoff.OfflineSegments {
if segmentID == offlineSegment.GetSegmentID() {
return true
}
}
}
return false
}
// finishUsage decreases the inUse count of provided segments
func (sc *ShardCluster) finishUsage(allocs map[int64][]int64) {
defer func() {
sc.rcCond.L.Lock()
sc.rcCond.Broadcast()
sc.rcCond.L.Unlock()
}()
sc.mut.Lock()
defer sc.mut.Unlock()
for _, segments := range allocs {
for _, segmentID := range segments {
segment, ok := sc.segments[segmentID]
if !ok || segment == nil {
// this shall not happen, since removing segment without decreasing rc to zero is illegal
log.Error("finishUsage with non-existing segment", zap.Int64("collectionID", sc.collectionID), zap.Int64("replicaID", sc.replicaID), zap.String("vchannel", sc.vchannelName), zap.Int64("segmentID", segmentID))
continue
}
// decrease the reference count
segment.inUse--
}
}
}
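// Typical caller pairing (as used by Search and Query below):
//   segAllocs := sc.segmentAllocations(partitionIDs) // inUse++ for every allocated segment
//   defer sc.finishUsage(segAllocs)                   // inUse-- once the request completes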
// HandoffSegments processes the handoff/load balance segments update procedure.
func (sc *ShardCluster) HandoffSegments(info *querypb.SegmentChangeInfo) error {
// wait until all OnlineSegments are loaded
onlineSegments := make([]int64, 0, len(info.OnlineSegments))
for _, seg := range info.OnlineSegments {
// filter out segments not maintained in this cluster
if seg.GetCollectionID() != sc.collectionID || seg.GetDmChannel() != sc.vchannelName {
continue
}
onlineSegments = append(onlineSegments, seg.GetSegmentID())
}
sc.waitSegmentsOnline(onlineSegments)
// add segmentChangeInfo to pending list
token := sc.appendHandoff(info)
// wait until all OfflineSegments are no longer in use
offlineSegments := make([]int64, 0, len(info.OfflineSegments))
for _, seg := range info.OfflineSegments {
offlineSegments = append(offlineSegments, seg.GetSegmentID())
}
sc.waitSegmentsNotInUse(offlineSegments)
// remove offline segments record
for _, seg := range info.OfflineSegments {
// filter out segments not maintained in this cluster
if seg.GetCollectionID() != sc.collectionID || seg.GetDmChannel() != sc.vchannelName {
continue
}
sc.removeSegment(segmentEvent{segmentID: seg.GetSegmentID(), nodeID: seg.GetNodeID()})
}
// finish handoff and remove it from pending list
sc.finishHandoff(token)
return nil
}
// appendHandoff adds the change info into pending list and returns the token.
func (sc *ShardCluster) appendHandoff(info *querypb.SegmentChangeInfo) int32 {
sc.mut.Lock()
defer sc.mut.Unlock()
token := sc.lastToken.Add(1)
sc.handoffs[token] = info
return token
}
// finishHandoff removes the handoff related to the token.
func (sc *ShardCluster) finishHandoff(token int32) {
sc.mut.Lock()
defer sc.mut.Unlock()
delete(sc.handoffs, token)
}
// waitSegmentsOnline waits until all provided segments are loaded.
func (sc *ShardCluster) waitSegmentsOnline(segments []int64) {
sc.segmentCond.L.Lock()
for !sc.segmentsOnline(segments) {
sc.segmentCond.Wait()
}
sc.segmentCond.L.Unlock()
}
// waitSegmentsNotInUse waits until all provided segments are no longer in use.
func (sc *ShardCluster) waitSegmentsNotInUse(segments []int64) {
sc.rcCond.L.Lock()
for sc.segmentsInUse(segments) {
sc.rcCond.Wait()
}
sc.rcCond.L.Unlock()
}
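// Both wait helpers above re-check their predicate in a loop: segmentCond is broadcast
// on any segment state change and rcCond on any reference count decrease, not only for
// the specific segments being waited on.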
// segmentsOnline checks whether all provided segment ids are in loaded state.
func (sc *ShardCluster) segmentsOnline(segmentIDs []int64) bool {
sc.mut.RLock()
defer sc.mut.RUnlock()
for _, segID := range segmentIDs {
segment, ok := sc.segments[segID]
if !ok || segment.state != segmentStateLoaded {
return false
}
}
return true
}
// segmentsInUse checks whether any of the provided segments is still in use.
func (sc *ShardCluster) segmentsInUse(segmentIDs []int64) bool {
sc.mut.RLock()
defer sc.mut.RUnlock()
for _, segID := range segmentIDs {
segment, ok := sc.segments[segID]
if !ok {
// ignore missing segments, since they might be in streaming
continue
}
if segment.inUse > 0 {
return true
}
}
return false
}
// Search performs search operation on shard cluster.
func (sc *ShardCluster) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) {
if sc.state.Load() != int32(available) {
@@ -405,16 +570,14 @@ func (sc *ShardCluster) Search(ctx context.Context, req *querypb.SearchRequest)
return nil, fmt.Errorf("ShardCluster for %s does not match to request channel :%s", sc.vchannelName, req.GetDmlChannel())
}
-//req.GetReq().GetPartitionIDs()
-// get node allocation
// get node allocation and maintains the inUse reference count
segAllocs := sc.segmentAllocations(req.GetReq().GetPartitionIDs())
defer sc.finishUsage(segAllocs)
log.Debug("cluster segment distribution", zap.Int("len", len(segAllocs)))
for nodeID, segmentIDs := range segAllocs {
log.Debug("segments distribution", zap.Int64("nodeID", nodeID), zap.Int64s("segments", segmentIDs))
}
-// TODO dispatch to local queryShardService query dml channel growing segments
// concurrent visiting nodes
var wg sync.WaitGroup
@@ -430,7 +593,7 @@ func (sc *ShardCluster) Search(ctx context.Context, req *querypb.SearchRequest)
nodeReq.SegmentIDs = segments
node, ok := sc.getNode(nodeID)
if !ok { // meta dismatch, report error
-return nil, fmt.Errorf("SharcCluster for %s replicaID %d is no available", sc.vchannelName, sc.replicaID)
return nil, fmt.Errorf("ShardCluster for %s replicaID %d is no available", sc.vchannelName, sc.replicaID)
}
wg.Add(1)
go func() {
@@ -466,10 +629,9 @@ func (sc *ShardCluster) Query(ctx context.Context, req *querypb.QueryRequest) ([
return nil, fmt.Errorf("ShardCluster for %s does not match to request channel :%s", sc.vchannelName, req.GetDmlChannel())
}
-// get node allocation
// get node allocation and maintains the inUse reference count
segAllocs := sc.segmentAllocations(req.GetReq().GetPartitionIDs())
defer sc.finishUsage(segAllocs)
-// TODO dispatch to local queryShardService query dml channel growing segments
// concurrent visiting nodes
var wg sync.WaitGroup


@@ -8,6 +8,7 @@ import (
"sync"
grpcquerynodeclient "github.com/milvus-io/milvus/internal/distributed/querynode/client"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
@@ -107,3 +108,20 @@ func (s *ShardClusterService) releaseCollection(collectionID int64) {
return true
})
}
// HandoffSegments dispatches the segment change info to the related shardClusters
func (s *ShardClusterService) HandoffSegments(collectionID int64, info *querypb.SegmentChangeInfo) {
var wg sync.WaitGroup
s.clusters.Range(func(k, v interface{}) bool {
cs := v.(*ShardCluster)
if cs.collectionID == collectionID {
wg.Add(1)
go func() {
defer wg.Done()
cs.HandoffSegments(info)
}()
}
return true
})
wg.Wait()
}
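// Dispatch path for reference: queryChannel.adjustByChangeInfo calls
// shardCluster.HandoffSegments(collectionID, info) (see the first hunk above), which fans
// the change info out to every ShardCluster of the collection concurrently and waits for
// all handoffs to finish before returning.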


@@ -4,8 +4,10 @@ import (
"context"
"testing"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.etcd.io/etcd/server/v3/etcdserver/api/v3client"
)
@@ -32,3 +34,19 @@ func TestShardClusterService(t *testing.T) {
err = clusterService.releaseShardCluster("non-exist-channel")
assert.Error(t, err)
}
func TestShardClusterService_HandoffSegments(t *testing.T) {
qn, err := genSimpleQueryNode(context.Background())
require.NoError(t, err)
client := v3client.New(embedetcdServer.Server)
defer client.Close()
session := sessionutil.NewSession(context.Background(), "/by-dev/sessions/unittest/querynode/", client)
clusterService := newShardClusterService(client, session, qn)
clusterService.addShardCluster(defaultCollectionID, defaultReplicaID, defaultDMLChannel)
//TODO change shardCluster to interface to mock test behavior
assert.NotPanics(t, func() {
clusterService.HandoffSegments(defaultCollectionID, &querypb.SegmentChangeInfo{})
})
}


@@ -1023,3 +1023,372 @@ func TestShardCluster_Query(t *testing.T) {
})
}
func TestShardCluster_ReferenceCount(t *testing.T) {
collectionID := int64(1)
vchannelName := "dml_1_1_v0"
replicaID := int64(0)
// ctx := context.Background()
t.Run("normal alloc & finish", func(t *testing.T) {
nodeEvents := []nodeEvent{
{
nodeID: 1,
nodeAddr: "addr_1",
},
{
nodeID: 2,
nodeAddr: "addr_2",
},
}
segmentEvents := []segmentEvent{
{
segmentID: 1,
nodeID: 1,
state: segmentStateLoaded,
},
{
segmentID: 2,
nodeID: 2,
state: segmentStateLoaded,
},
}
sc := NewShardCluster(collectionID, replicaID, vchannelName,
&mockNodeDetector{
initNodes: nodeEvents,
}, &mockSegmentDetector{
initSegments: segmentEvents,
}, buildMockQueryNode)
defer sc.Close()
allocs := sc.segmentAllocations(nil)
sc.mut.RLock()
for _, segment := range sc.segments {
assert.Greater(t, segment.inUse, int32(0))
}
sc.mut.RUnlock()
assert.True(t, sc.segmentsInUse([]int64{1, 2}))
assert.True(t, sc.segmentsInUse([]int64{1, 2, -1}))
sc.finishUsage(allocs)
sc.mut.RLock()
for _, segment := range sc.segments {
assert.EqualValues(t, segment.inUse, 0)
}
sc.mut.RUnlock()
assert.False(t, sc.segmentsInUse([]int64{1, 2}))
assert.False(t, sc.segmentsInUse([]int64{1, 2, -1}))
})
t.Run("alloc & finish with modified alloc", func(t *testing.T) {
nodeEvents := []nodeEvent{
{
nodeID: 1,
nodeAddr: "addr_1",
},
{
nodeID: 2,
nodeAddr: "addr_2",
},
}
segmentEvents := []segmentEvent{
{
segmentID: 1,
nodeID: 1,
state: segmentStateLoaded,
},
{
segmentID: 2,
nodeID: 2,
state: segmentStateLoaded,
},
}
sc := NewShardCluster(collectionID, replicaID, vchannelName,
&mockNodeDetector{
initNodes: nodeEvents,
}, &mockSegmentDetector{
initSegments: segmentEvents,
}, buildMockQueryNode)
defer sc.Close()
allocs := sc.segmentAllocations(nil)
sc.mut.RLock()
for _, segment := range sc.segments {
assert.Greater(t, segment.inUse, int32(0))
}
sc.mut.RUnlock()
for node, segments := range allocs {
segments = append(segments, -1) // add non-exist segment
// shall be ignored in finishUsage
allocs[node] = segments
}
assert.NotPanics(t, func() {
sc.finishUsage(allocs)
})
sc.mut.RLock()
for _, segment := range sc.segments {
assert.EqualValues(t, segment.inUse, 0)
}
sc.mut.RUnlock()
})
t.Run("wait segments online", func(t *testing.T) {
nodeEvents := []nodeEvent{
{
nodeID: 1,
nodeAddr: "addr_1",
},
{
nodeID: 2,
nodeAddr: "addr_2",
},
}
segmentEvents := []segmentEvent{
{
segmentID: 1,
nodeID: 1,
state: segmentStateLoaded,
},
{
segmentID: 2,
nodeID: 2,
state: segmentStateLoaded,
},
}
evtCh := make(chan segmentEvent, 10)
sc := NewShardCluster(collectionID, replicaID, vchannelName,
&mockNodeDetector{
initNodes: nodeEvents,
}, &mockSegmentDetector{
initSegments: segmentEvents,
evtCh: evtCh,
}, buildMockQueryNode)
defer sc.Close()
assert.True(t, sc.segmentsOnline([]int64{1, 2}))
assert.False(t, sc.segmentsOnline([]int64{1, 2, 3}))
sig := make(chan struct{})
go func() {
sc.waitSegmentsOnline([]int64{1, 2, 3})
close(sig)
}()
evtCh <- segmentEvent{
eventType: segmentAdd,
segmentID: 3,
nodeID: 1,
state: segmentStateLoaded,
}
<-sig
assert.True(t, sc.segmentsOnline([]int64{1, 2, 3}))
})
}
func TestShardCluster_HandoffSegments(t *testing.T) {
collectionID := int64(1)
otherCollectionID := int64(2)
vchannelName := "dml_1_1_v0"
otherVchannelName := "dml_1_2_v0"
replicaID := int64(0)
t.Run("handoff without using segments", func(t *testing.T) {
nodeEvents := []nodeEvent{
{
nodeID: 1,
nodeAddr: "addr_1",
},
{
nodeID: 2,
nodeAddr: "addr_2",
},
}
segmentEvents := []segmentEvent{
{
segmentID: 1,
nodeID: 1,
state: segmentStateLoaded,
},
{
segmentID: 2,
nodeID: 2,
state: segmentStateLoaded,
},
}
sc := NewShardCluster(collectionID, replicaID, vchannelName,
&mockNodeDetector{
initNodes: nodeEvents,
}, &mockSegmentDetector{
initSegments: segmentEvents,
}, buildMockQueryNode)
defer sc.Close()
sc.HandoffSegments(&querypb.SegmentChangeInfo{
OnlineSegments: []*querypb.SegmentInfo{
{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName},
},
OfflineSegments: []*querypb.SegmentInfo{
{SegmentID: 1, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName},
},
})
sc.mut.RLock()
_, has := sc.segments[1]
sc.mut.RUnlock()
assert.False(t, has)
})
t.Run("handoff with growing segment(segment not recorded)", func(t *testing.T) {
nodeEvents := []nodeEvent{
{
nodeID: 1,
nodeAddr: "addr_1",
},
{
nodeID: 2,
nodeAddr: "addr_2",
},
}
segmentEvents := []segmentEvent{
{
segmentID: 1,
nodeID: 1,
state: segmentStateLoaded,
},
{
segmentID: 2,
nodeID: 2,
state: segmentStateLoaded,
},
}
sc := NewShardCluster(collectionID, replicaID, vchannelName,
&mockNodeDetector{
initNodes: nodeEvents,
}, &mockSegmentDetector{
initSegments: segmentEvents,
}, buildMockQueryNode)
defer sc.Close()
sc.HandoffSegments(&querypb.SegmentChangeInfo{
OnlineSegments: []*querypb.SegmentInfo{
{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName},
{SegmentID: 4, NodeID: 2, CollectionID: otherCollectionID, DmChannel: otherVchannelName},
},
OfflineSegments: []*querypb.SegmentInfo{
{SegmentID: 3, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName},
{SegmentID: 5, NodeID: 2, CollectionID: otherCollectionID, DmChannel: otherVchannelName},
},
})
sc.mut.RLock()
_, has := sc.segments[3]
sc.mut.RUnlock()
assert.False(t, has)
})
t.Run("handoff wait online and usage", func(t *testing.T) {
nodeEvents := []nodeEvent{
{
nodeID: 1,
nodeAddr: "addr_1",
},
{
nodeID: 2,
nodeAddr: "addr_2",
},
}
segmentEvents := []segmentEvent{
{
segmentID: 1,
nodeID: 1,
state: segmentStateLoaded,
},
{
segmentID: 2,
nodeID: 2,
state: segmentStateLoaded,
},
}
evtCh := make(chan segmentEvent, 10)
sc := NewShardCluster(collectionID, replicaID, vchannelName,
&mockNodeDetector{
initNodes: nodeEvents,
}, &mockSegmentDetector{
initSegments: segmentEvents,
evtCh: evtCh,
}, buildMockQueryNode)
defer sc.Close()
// add rc to all segments
allocs := sc.segmentAllocations(nil)
sig := make(chan struct{})
go func() {
sc.HandoffSegments(&querypb.SegmentChangeInfo{
OnlineSegments: []*querypb.SegmentInfo{
{SegmentID: 3, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName},
},
OfflineSegments: []*querypb.SegmentInfo{
{SegmentID: 1, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName},
},
})
close(sig)
}()
sc.mut.RLock()
// still waiting online
assert.Equal(t, 0, len(sc.handoffs))
sc.mut.RUnlock()
evtCh <- segmentEvent{
eventType: segmentAdd,
segmentID: 3,
nodeID: 1,
state: segmentStateLoaded,
}
// wait for handoff appended into list
assert.Eventually(t, func() bool {
sc.mut.RLock()
defer sc.mut.RUnlock()
return len(sc.handoffs) > 0
}, time.Second, time.Millisecond*10)
tmpAllocs := sc.segmentAllocations(nil)
found := false
for _, segments := range tmpAllocs {
if inList(segments, int64(1)) {
found = true
break
}
}
// segment 1 shall not be allocated again!
assert.False(t, found)
sc.finishUsage(tmpAllocs)
// rc shall be 0 now
sc.finishUsage(allocs)
// wait handoff finished
<-sig
sc.mut.RLock()
_, has := sc.segments[1]
sc.mut.RUnlock()
assert.False(t, has)
})
}