Mirror of https://gitee.com/milvus-io/milvus.git (synced 2024-12-02 11:59:00 +08:00)
Add segment reference count and handles change info in ShardCluster (#16620)
Resolves #16619

Add a reference count for each search/query request.

For SegmentChangeInfo:
- Wait for all segments in OnlineList to be loaded
- Add the handoff event into the pending list
- Wait until all segments in OfflineList are no longer in use (reference count = 0)

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
parent 4ef2df8cb9
commit b99b65c26e
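The sequence described in the commit message (wait for OnlineList segments to load, park the change info as pending, wait for OfflineList segments to reach reference count 0, then remove them) can be sketched as follows. This is a simplified, self-contained illustration with hypothetical types (changeInfo, shard); the real implementation is ShardCluster.HandoffSegments further down in this diff.

package main

import "fmt"

// changeInfo is a hypothetical stand-in for querypb.SegmentChangeInfo.
type changeInfo struct {
    online  []int64 // segments that must be loaded before the handoff proceeds
    offline []int64 // segments that must reach reference count 0 before removal
}

// shard is a hypothetical stand-in for the shard cluster state.
type shard struct {
    loaded  map[int64]bool
    refCnt  map[int64]int
    pending []changeInfo
}

func (s *shard) handoff(info changeInfo) {
    for _, id := range info.online {
        if !s.loaded[id] { // the real code blocks on a condition variable here
            fmt.Printf("would wait until segment %d is loaded\n", id)
        }
    }
    s.pending = append(s.pending, info) // park the change info so new requests skip offline segments
    for _, id := range info.offline {
        if s.refCnt[id] > 0 { // the real code blocks until running requests release their references
            fmt.Printf("would wait until segment %d is unused\n", id)
        }
        delete(s.loaded, id) // drop the offline segment record
    }
    s.pending = s.pending[:len(s.pending)-1] // handoff done, remove it from the pending list
}

func main() {
    s := &shard{loaded: map[int64]bool{1: true, 2: true}, refCnt: map[int64]int{1: 0}}
    s.handoff(changeInfo{online: []int64{2}, offline: []int64{1}})
    fmt.Println(s.loaded) // map[2:true]
}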
@@ -20,10 +20,23 @@ type queryChannel struct {
    streaming      *streaming
    queryMsgStream msgstream.MsgStream
    shardCluster   *ShardClusterService
    asConsumeOnce  sync.Once
    closeOnce      sync.Once
}

// NewQueryChannel create a query channel with provided shardCluster, query msgstream and collection id
func NewQueryChannel(collectionID int64, scs *ShardClusterService, qms msgstream.MsgStream, streaming *streaming) *queryChannel {
    return &queryChannel{
        closeCh:      make(chan struct{}),
        collectionID: collectionID,

        streaming:      streaming,
        queryMsgStream: qms,
        shardCluster:   scs,
    }
}

// AsConsumer do AsConsumer for query msgstream and seek if position is not nil
func (qc *queryChannel) AsConsumer(channelName string, subName string, position *internalpb.MsgPosition) error {
    var err error

@@ -101,7 +114,9 @@ func (qc *queryChannel) adjustByChangeInfo(msg *msgstream.SealedSegmentsChangeIn
        }
    }

    // should handle segment change in shardCluster
    // process change in shard cluster
    qc.shardCluster.HandoffSegments(qc.collectionID, info)

    // for OnlineSegments:
    for _, segment := range info.OnlineSegments {
        /*

@@ -124,12 +139,6 @@ func (qc *queryChannel) adjustByChangeInfo(msg *msgstream.SealedSegmentsChangeIn
            },
        })
    }
    /*
        // for OfflineSegments:
        for _, segment := range info.OfflineSegments {
            // 1. update global sealed segments
            q.globalSegmentManager.removeGlobalSealedSegmentInfo(segment.SegmentID)
        }*/

    log.Info("Successfully changed global sealed segment info ",
        zap.Int64("collection ", qc.collectionID),
@@ -97,13 +97,7 @@ func TestQueryChannel_AsConsumer(t *testing.T) {
        mqs := &mockQueryMsgStream{}
        mqs.On("Close").Return()

        qc := &queryChannel{
            closeCh:      make(chan struct{}),
            collectionID: defaultCollectionID,

            streaming:      nil,
            queryMsgStream: mqs,
        }
        qc := NewQueryChannel(defaultCollectionID, nil, mqs, nil)

        mqs.On("AsConsumer", []string{defaultDMLChannel}, defaultSubName).Return()

@@ -122,13 +116,7 @@ func TestQueryChannel_AsConsumer(t *testing.T) {
        mqs := &mockQueryMsgStream{}
        mqs.On("Close").Return()

        qc := &queryChannel{
            closeCh:      make(chan struct{}),
            collectionID: defaultCollectionID,

            streaming:      nil,
            queryMsgStream: mqs,
        }
        qc := NewQueryChannel(defaultCollectionID, nil, mqs, nil)

        mqs.On("AsConsumer", []string{defaultDMLChannel}, defaultSubName).Return()

@@ -146,13 +134,8 @@ func TestQueryChannel_AsConsumer(t *testing.T) {
        mqs := &mockQueryMsgStream{}
        mqs.On("Close").Return()

        qc := &queryChannel{
            closeCh:      make(chan struct{}),
            collectionID: defaultCollectionID,
        qc := NewQueryChannel(defaultCollectionID, nil, mqs, nil)

            streaming:      nil,
            queryMsgStream: mqs,
        }
        msgID := make([]byte, 8)
        rand.Read(msgID)
        pos := &internalpb.MsgPosition{MsgID: msgID}
@@ -138,12 +138,7 @@ func (q *queryShardService) getQueryChannel(collectionID int64) *queryChannel {
    qc, ok := q.queryChannels[collectionID]
    if !ok {
        queryStream, _ := q.factory.NewQueryMsgStream(q.ctx)
        qc = &queryChannel{
            closeCh:        make(chan struct{}),
            collectionID:   collectionID,
            queryMsgStream: queryStream,
            streaming:      q.streaming,
        }
        qc = NewQueryChannel(collectionID, q.shardClusterService, queryStream, q.streaming)
        q.queryChannels[collectionID] = qc
    }
@@ -92,6 +92,7 @@ type shardSegmentInfo struct {
    partitionID int64
    nodeID      int64
    state       segmentState
    inUse       int32
}

// ShardNodeDetector provides method to detect node events

@@ -119,9 +120,13 @@ type ShardCluster struct {
    segmentDetector ShardSegmentDetector
    nodeBuilder     ShardNodeBuilder

    mut      sync.RWMutex
    nodes    map[int64]*shardNode        // online nodes
    segments map[int64]*shardSegmentInfo // shard segments
    mut       sync.RWMutex
    nodes     map[int64]*shardNode                 // online nodes
    segments  map[int64]*shardSegmentInfo          // shard segments
    handoffs  map[int32]*querypb.SegmentChangeInfo // current pending handoff
    lastToken *atomic.Int32                        // last token used for segment change info

    segmentCond *sync.Cond // segment state change condition
    rcCond      *sync.Cond // segment rc change condition

    closeOnce sync.Once
    closeCh   chan struct{}

@@ -141,12 +146,19 @@ func NewShardCluster(collectionID int64, replicaID int64, vchannelName string,
        segmentDetector: segmentDetector,
        nodeBuilder:     nodeBuilder,

        nodes:    make(map[int64]*shardNode),
        segments: make(map[int64]*shardSegmentInfo),
        nodes:     make(map[int64]*shardNode),
        segments:  make(map[int64]*shardSegmentInfo),
        handoffs:  make(map[int32]*querypb.SegmentChangeInfo),
        lastToken: atomic.NewInt32(0),

        closeCh: make(chan struct{}),
    }

    m := sync.Mutex{}
    sc.segmentCond = sync.NewCond(&m)
    m2 := sync.Mutex{}
    sc.rcCond = sync.NewCond(&m2)

    sc.init()

    return sc
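The segmentCond and rcCond fields above follow the standard sync.Cond idiom: a waiter re-checks its predicate in a loop around Wait, and whoever changes the underlying state takes the condition's lock and calls Broadcast. A minimal, self-contained sketch of that idiom (illustrative only, not the Milvus code itself):

package main

import (
    "fmt"
    "sync"
)

func main() {
    cond := sync.NewCond(&sync.Mutex{})
    loaded := map[int64]bool{}

    done := make(chan struct{})
    go func() {
        // waiter: block until segment 3 is observed as loaded
        cond.L.Lock()
        for !loaded[3] { // always re-check the predicate after waking up
            cond.Wait()
        }
        cond.L.Unlock()
        close(done)
    }()

    // updater: change the state, then wake every waiter,
    // mirroring the Broadcast calls in updateSegment and finishUsage
    cond.L.Lock()
    loaded[3] = true
    cond.Broadcast()
    cond.L.Unlock()

    <-done
    fmt.Println("segment 3 observed as loaded")
}

Broadcast rather than Signal suits this case: several pending handoffs may each be waiting on a different segment, and every one of them needs a chance to re-check its own predicate.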
@@ -205,8 +217,15 @@ func (sc *ShardCluster) removeNode(evt nodeEvent) {

// updateSegment apply segment change to shard cluster
func (sc *ShardCluster) updateSegment(evt segmentEvent) {
    log.Debug("ShardCluster update segment", zap.Int64("nodeID", evt.nodeID), zap.Int64("segmentID", evt.segmentID), zap.Int32("state", int32(evt.state)))
    // notify handoff wait online if any
    defer func() {
        sc.segmentCond.L.Lock()
        sc.segmentCond.Broadcast()
        sc.segmentCond.L.Unlock()
    }()

    sc.mut.Lock()
    defer sc.mut.Unlock()

@@ -255,6 +274,8 @@ func (sc *ShardCluster) transferSegment(old *shardSegmentInfo, evt segmentEvent)
// removeSegment removes segment from cluster
// should only applied in hand-off or load balance procedure
func (sc *ShardCluster) removeSegment(evt segmentEvent) {
    log.Debug("ShardCluster remove segment", zap.Int64("nodeID", evt.nodeID), zap.Int64("segmentID", evt.segmentID), zap.Int32("state", int32(evt.state)))

    sc.mut.Lock()
    defer sc.mut.Unlock()

@@ -269,7 +290,6 @@ func (sc *ShardCluster) removeSegment(evt segmentEvent) {
        return
    }

    //TODO check handoff / load balance
    delete(sc.segments, evt.segmentID)
}

@@ -331,7 +351,6 @@ func (sc *ShardCluster) watchSegments(evtCh <-chan segmentEvent) {
    for {
        select {
        case evt, ok := <-evtCh:
            log.Debug("segment event", zap.Any("evt", evt))
            if !ok {
                log.Warn("ShardCluster segment channel closed", zap.Int64("collectionID", sc.collectionID), zap.Int64("replicaID", sc.replicaID))
                return

@@ -381,20 +400,166 @@ func (sc *ShardCluster) getSegment(segmentID int64) (*shardSegmentInfo, bool) {
}

// segmentAllocations returns node to segments mappings.
// calling this function also increases the reference count of related segments.
func (sc *ShardCluster) segmentAllocations(partitionIDs []int64) map[int64][]int64 {
    result := make(map[int64][]int64) // nodeID => segmentIDs
    sc.mut.RLock()
    defer sc.mut.RUnlock()
    sc.mut.Lock()
    defer sc.mut.Unlock()

    for _, segment := range sc.segments {
        if len(partitionIDs) > 0 && !inList(partitionIDs, segment.partitionID) {
            continue
        }
        if sc.inHandoffOffline(segment.segmentID) {
            log.Debug("segment ignore in pending offline list", zap.Int64("collectionID", sc.collectionID), zap.Int64("replicaID", sc.replicaID), zap.Int64("segmentID", segment.segmentID))
            continue
        }
        // reference count ++
        segment.inUse++
        result[segment.nodeID] = append(result[segment.nodeID], segment.segmentID)
    }
    return result
}

// inHandoffOffline checks whether segment is pending handoff offline list
// Note that sc.mut Lock is assumed to be hold outside of this function!
func (sc *ShardCluster) inHandoffOffline(segmentID int64) bool {
    for _, handoff := range sc.handoffs {
        for _, offlineSegment := range handoff.OfflineSegments {
            if segmentID == offlineSegment.GetSegmentID() {
                return true
            }
        }
    }
    return false
}

// finishUsage decreases the inUse count of provided segments
func (sc *ShardCluster) finishUsage(allocs map[int64][]int64) {
    defer func() {
        sc.rcCond.L.Lock()
        sc.rcCond.Broadcast()
        sc.rcCond.L.Unlock()
    }()
    sc.mut.Lock()
    defer sc.mut.Unlock()
    for _, segments := range allocs {
        for _, segmentID := range segments {
            segment, ok := sc.segments[segmentID]
            if !ok || segment == nil {
                // this shall not happen, since removing segment without decreasing rc to zero is illegal
                log.Error("finishUsage with non-existing segment", zap.Int64("collectionID", sc.collectionID), zap.Int64("replicaID", sc.replicaID), zap.String("vchannel", sc.vchannelName), zap.Int64("segmentID", segmentID))
                continue
            }
            // decrease the reference count
            segment.inUse--
        }
    }
}

// HandoffSegments processes the handoff/load balance segments update procedure.
func (sc *ShardCluster) HandoffSegments(info *querypb.SegmentChangeInfo) error {
    // wait for all OnlineSegment is loaded
    onlineSegments := make([]int64, 0, len(info.OnlineSegments))
    for _, seg := range info.OnlineSegments {
        // filter out segments not maintained in this cluster
        if seg.GetCollectionID() != sc.collectionID || seg.GetDmChannel() != sc.vchannelName {
            continue
        }
        onlineSegments = append(onlineSegments, seg.GetSegmentID())
    }
    sc.waitSegmentsOnline(onlineSegments)

    // add segmentChangeInfo to pending list
    token := sc.appendHandoff(info)

    // wait for all OfflineSegments is not in use
    offlineSegments := make([]int64, 0, len(info.OfflineSegments))
    for _, seg := range info.OfflineSegments {
        offlineSegments = append(offlineSegments, seg.GetSegmentID())
    }
    sc.waitSegmentsNotInUse(offlineSegments)
    // remove offline segments record
    for _, seg := range info.OfflineSegments {
        // filter out segments not maintained in this cluster
        if seg.GetCollectionID() != sc.collectionID || seg.GetDmChannel() != sc.vchannelName {
            continue
        }
        sc.removeSegment(segmentEvent{segmentID: seg.GetSegmentID(), nodeID: seg.GetNodeID()})
    }

    // finish handoff and remove it from pending list
    sc.finishHandoff(token)

    return nil
}

// appendHandoff adds the change info into pending list and returns the token.
func (sc *ShardCluster) appendHandoff(info *querypb.SegmentChangeInfo) int32 {
    sc.mut.Lock()
    defer sc.mut.Unlock()

    token := sc.lastToken.Add(1)
    sc.handoffs[token] = info
    return token
}

// finishHandoff removes the handoff related to the token.
func (sc *ShardCluster) finishHandoff(token int32) {
    sc.mut.Lock()
    defer sc.mut.Unlock()

    delete(sc.handoffs, token)
}

// waitSegmentsOnline waits until all provided segments is loaded.
func (sc *ShardCluster) waitSegmentsOnline(segments []int64) {
    sc.segmentCond.L.Lock()
    for !sc.segmentsOnline(segments) {
        sc.segmentCond.Wait()
    }
    sc.segmentCond.L.Unlock()
}

// waitSegmentsNotInUse waits until all provided segments is not in use.
func (sc *ShardCluster) waitSegmentsNotInUse(segments []int64) {
    sc.rcCond.L.Lock()
    for sc.segmentsInUse(segments) {
        sc.rcCond.Wait()
    }
    sc.rcCond.L.Unlock()
}

// checkOnline checks whether all segment ids provided in online state.
func (sc *ShardCluster) segmentsOnline(segmentIDs []int64) bool {
    sc.mut.RLock()
    defer sc.mut.RUnlock()
    for _, segID := range segmentIDs {
        segment, ok := sc.segments[segID]
        if !ok || segment.state != segmentStateLoaded {
            return false
        }
    }
    return true
}

// segmentsInUse checks whether all segment ids provided still in use.
func (sc *ShardCluster) segmentsInUse(segmentIDs []int64) bool {
    sc.mut.RLock()
    defer sc.mut.RUnlock()
    for _, segID := range segmentIDs {
        segment, ok := sc.segments[segID]
        if !ok {
            // ignore missing segments, since they might be in streaming
            continue
        }
        if segment.inUse > 0 {
            return true
        }
    }
    return false
}
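With these helpers in place, the request path is expected to pair every allocation with a release, typically via defer, so a reference cannot leak on an early error return; Search and Query below do exactly this. A hypothetical caller-side wrapper (not part of this commit, assuming only the methods defined above) would look like:

// doWithSegments is a hypothetical helper showing the intended allocate/release pairing.
func (sc *ShardCluster) doWithSegments(partitionIDs []int64, visit func(nodeID int64, segmentIDs []int64) error) error {
    allocs := sc.segmentAllocations(partitionIDs) // inUse++ for every selected segment
    defer sc.finishUsage(allocs)                  // inUse--, then Broadcast so pending handoffs can proceed

    for nodeID, segmentIDs := range allocs {
        if err := visit(nodeID, segmentIDs); err != nil {
            return err // the deferred finishUsage still releases the references
        }
    }
    return nil
}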
// Search preforms search operation on shard cluster.
func (sc *ShardCluster) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) {
    if sc.state.Load() != int32(available) {

@@ -405,16 +570,14 @@ func (sc *ShardCluster) Search(ctx context.Context, req *querypb.SearchRequest)
        return nil, fmt.Errorf("ShardCluster for %s does not match to request channel :%s", sc.vchannelName, req.GetDmlChannel())
    }

    //req.GetReq().GetPartitionIDs()

    // get node allocation
    // get node allocation and maintains the inUse reference count
    segAllocs := sc.segmentAllocations(req.GetReq().GetPartitionIDs())
    defer sc.finishUsage(segAllocs)

    log.Debug("cluster segment distribution", zap.Int("len", len(segAllocs)))
    for nodeID, segmentIDs := range segAllocs {
        log.Debug("segments distribution", zap.Int64("nodeID", nodeID), zap.Int64s("segments", segmentIDs))
    }
    // TODO dispatch to local queryShardService query dml channel growing segments

    // concurrent visiting nodes
    var wg sync.WaitGroup

@@ -430,7 +593,7 @@ func (sc *ShardCluster) Search(ctx context.Context, req *querypb.SearchRequest)
        nodeReq.SegmentIDs = segments
        node, ok := sc.getNode(nodeID)
        if !ok { // meta dismatch, report error
            return nil, fmt.Errorf("SharcCluster for %s replicaID %d is no available", sc.vchannelName, sc.replicaID)
            return nil, fmt.Errorf("ShardCluster for %s replicaID %d is no available", sc.vchannelName, sc.replicaID)
        }
        wg.Add(1)
        go func() {

@@ -466,10 +629,9 @@ func (sc *ShardCluster) Query(ctx context.Context, req *querypb.QueryRequest) ([
        return nil, fmt.Errorf("ShardCluster for %s does not match to request channel :%s", sc.vchannelName, req.GetDmlChannel())
    }

    // get node allocation
    // get node allocation and maintains the inUse reference count
    segAllocs := sc.segmentAllocations(req.GetReq().GetPartitionIDs())

    // TODO dispatch to local queryShardService query dml channel growing segments
    defer sc.finishUsage(segAllocs)

    // concurrent visiting nodes
    var wg sync.WaitGroup
@@ -8,6 +8,7 @@ import (
    "sync"

    grpcquerynodeclient "github.com/milvus-io/milvus/internal/distributed/querynode/client"
    "github.com/milvus-io/milvus/internal/proto/querypb"
    "github.com/milvus-io/milvus/internal/util"
    "github.com/milvus-io/milvus/internal/util/sessionutil"
    "github.com/milvus-io/milvus/internal/util/typeutil"

@@ -107,3 +108,20 @@ func (s *ShardClusterService) releaseCollection(collectionID int64) {
        return true
    })
}

// HandoffSegments dispatch segmentChangeInfo to related shardClusters
func (s *ShardClusterService) HandoffSegments(collectionID int64, info *querypb.SegmentChangeInfo) {
    var wg sync.WaitGroup
    s.clusters.Range(func(k, v interface{}) bool {
        cs := v.(*ShardCluster)
        if cs.collectionID == collectionID {
            wg.Add(1)
            go func() {
                defer wg.Done()
                cs.HandoffSegments(info)
            }()
        }
        return true
    })
    wg.Wait()
}
@@ -4,8 +4,10 @@ import (
    "context"
    "testing"

    "github.com/milvus-io/milvus/internal/proto/querypb"
    "github.com/milvus-io/milvus/internal/util/sessionutil"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
    "go.etcd.io/etcd/server/v3/etcdserver/api/v3client"
)

@@ -32,3 +34,19 @@ func TestShardClusterService(t *testing.T) {
    err = clusterService.releaseShardCluster("non-exist-channel")
    assert.Error(t, err)
}

func TestShardClusterService_HandoffSegments(t *testing.T) {
    qn, err := genSimpleQueryNode(context.Background())
    require.NoError(t, err)

    client := v3client.New(embedetcdServer.Server)
    defer client.Close()
    session := sessionutil.NewSession(context.Background(), "/by-dev/sessions/unittest/querynode/", client)
    clusterService := newShardClusterService(client, session, qn)

    clusterService.addShardCluster(defaultCollectionID, defaultReplicaID, defaultDMLChannel)
    //TODO change shardCluster to interface to mock test behavior
    assert.NotPanics(t, func() {
        clusterService.HandoffSegments(defaultCollectionID, &querypb.SegmentChangeInfo{})
    })
}
@@ -1023,3 +1023,372 @@ func TestShardCluster_Query(t *testing.T) {
    })

}

func TestShardCluster_ReferenceCount(t *testing.T) {
    collectionID := int64(1)
    vchannelName := "dml_1_1_v0"
    replicaID := int64(0)
    // ctx := context.Background()

    t.Run("normal alloc & finish", func(t *testing.T) {
        nodeEvents := []nodeEvent{
            {
                nodeID:   1,
                nodeAddr: "addr_1",
            },
            {
                nodeID:   2,
                nodeAddr: "addr_2",
            },
        }

        segmentEvents := []segmentEvent{
            {
                segmentID: 1,
                nodeID:    1,
                state:     segmentStateLoaded,
            },
            {
                segmentID: 2,
                nodeID:    2,
                state:     segmentStateLoaded,
            },
        }
        sc := NewShardCluster(collectionID, replicaID, vchannelName,
            &mockNodeDetector{
                initNodes: nodeEvents,
            }, &mockSegmentDetector{
                initSegments: segmentEvents,
            }, buildMockQueryNode)
        defer sc.Close()

        allocs := sc.segmentAllocations(nil)

        sc.mut.RLock()
        for _, segment := range sc.segments {
            assert.Greater(t, segment.inUse, int32(0))
        }
        sc.mut.RUnlock()

        assert.True(t, sc.segmentsInUse([]int64{1, 2}))
        assert.True(t, sc.segmentsInUse([]int64{1, 2, -1}))

        sc.finishUsage(allocs)
        sc.mut.RLock()
        for _, segment := range sc.segments {
            assert.EqualValues(t, segment.inUse, 0)
        }
        sc.mut.RUnlock()

        assert.False(t, sc.segmentsInUse([]int64{1, 2}))
        assert.False(t, sc.segmentsInUse([]int64{1, 2, -1}))
    })

    t.Run("alloc & finish with modified alloc", func(t *testing.T) {
        nodeEvents := []nodeEvent{
            {
                nodeID:   1,
                nodeAddr: "addr_1",
            },
            {
                nodeID:   2,
                nodeAddr: "addr_2",
            },
        }

        segmentEvents := []segmentEvent{
            {
                segmentID: 1,
                nodeID:    1,
                state:     segmentStateLoaded,
            },
            {
                segmentID: 2,
                nodeID:    2,
                state:     segmentStateLoaded,
            },
        }
        sc := NewShardCluster(collectionID, replicaID, vchannelName,
            &mockNodeDetector{
                initNodes: nodeEvents,
            }, &mockSegmentDetector{
                initSegments: segmentEvents,
            }, buildMockQueryNode)
        defer sc.Close()

        allocs := sc.segmentAllocations(nil)

        sc.mut.RLock()
        for _, segment := range sc.segments {
            assert.Greater(t, segment.inUse, int32(0))
        }
        sc.mut.RUnlock()

        for node, segments := range allocs {
            segments = append(segments, -1) // add non-exist segment
            // shall be ignored in finishUsage
            allocs[node] = segments
        }

        assert.NotPanics(t, func() {
            sc.finishUsage(allocs)
        })
        sc.mut.RLock()
        for _, segment := range sc.segments {
            assert.EqualValues(t, segment.inUse, 0)
        }
        sc.mut.RUnlock()
    })

    t.Run("wait segments online", func(t *testing.T) {
        nodeEvents := []nodeEvent{
            {
                nodeID:   1,
                nodeAddr: "addr_1",
            },
            {
                nodeID:   2,
                nodeAddr: "addr_2",
            },
        }

        segmentEvents := []segmentEvent{
            {
                segmentID: 1,
                nodeID:    1,
                state:     segmentStateLoaded,
            },
            {
                segmentID: 2,
                nodeID:    2,
                state:     segmentStateLoaded,
            },
        }
        evtCh := make(chan segmentEvent, 10)
        sc := NewShardCluster(collectionID, replicaID, vchannelName,
            &mockNodeDetector{
                initNodes: nodeEvents,
            }, &mockSegmentDetector{
                initSegments: segmentEvents,
                evtCh:        evtCh,
            }, buildMockQueryNode)
        defer sc.Close()

        assert.True(t, sc.segmentsOnline([]int64{1, 2}))
        assert.False(t, sc.segmentsOnline([]int64{1, 2, 3}))

        sig := make(chan struct{})
        go func() {
            sc.waitSegmentsOnline([]int64{1, 2, 3})
            close(sig)
        }()

        evtCh <- segmentEvent{
            eventType: segmentAdd,
            segmentID: 3,
            nodeID:    1,
            state:     segmentStateLoaded,
        }

        <-sig
        assert.True(t, sc.segmentsOnline([]int64{1, 2, 3}))
    })
}

func TestShardCluster_HandoffSegments(t *testing.T) {
    collectionID := int64(1)
    otherCollectionID := int64(2)
    vchannelName := "dml_1_1_v0"
    otherVchannelName := "dml_1_2_v0"
    replicaID := int64(0)

    t.Run("handoff without using segments", func(t *testing.T) {
        nodeEvents := []nodeEvent{
            {
                nodeID:   1,
                nodeAddr: "addr_1",
            },
            {
                nodeID:   2,
                nodeAddr: "addr_2",
            },
        }

        segmentEvents := []segmentEvent{
            {
                segmentID: 1,
                nodeID:    1,
                state:     segmentStateLoaded,
            },
            {
                segmentID: 2,
                nodeID:    2,
                state:     segmentStateLoaded,
            },
        }
        sc := NewShardCluster(collectionID, replicaID, vchannelName,
            &mockNodeDetector{
                initNodes: nodeEvents,
            }, &mockSegmentDetector{
                initSegments: segmentEvents,
            }, buildMockQueryNode)
        defer sc.Close()

        sc.HandoffSegments(&querypb.SegmentChangeInfo{
            OnlineSegments: []*querypb.SegmentInfo{
                {SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName},
            },
            OfflineSegments: []*querypb.SegmentInfo{
                {SegmentID: 1, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName},
            },
        })

        sc.mut.RLock()
        _, has := sc.segments[1]
        sc.mut.RUnlock()

        assert.False(t, has)
    })
    t.Run("handoff with growing segment(segment not recorded)", func(t *testing.T) {
        nodeEvents := []nodeEvent{
            {
                nodeID:   1,
                nodeAddr: "addr_1",
            },
            {
                nodeID:   2,
                nodeAddr: "addr_2",
            },
        }

        segmentEvents := []segmentEvent{
            {
                segmentID: 1,
                nodeID:    1,
                state:     segmentStateLoaded,
            },
            {
                segmentID: 2,
                nodeID:    2,
                state:     segmentStateLoaded,
            },
        }
        sc := NewShardCluster(collectionID, replicaID, vchannelName,
            &mockNodeDetector{
                initNodes: nodeEvents,
            }, &mockSegmentDetector{
                initSegments: segmentEvents,
            }, buildMockQueryNode)
        defer sc.Close()

        sc.HandoffSegments(&querypb.SegmentChangeInfo{
            OnlineSegments: []*querypb.SegmentInfo{
                {SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName},
                {SegmentID: 4, NodeID: 2, CollectionID: otherCollectionID, DmChannel: otherVchannelName},
            },
            OfflineSegments: []*querypb.SegmentInfo{
                {SegmentID: 3, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName},
                {SegmentID: 5, NodeID: 2, CollectionID: otherCollectionID, DmChannel: otherVchannelName},
            },
        })

        sc.mut.RLock()
        _, has := sc.segments[3]
        sc.mut.RUnlock()

        assert.False(t, has)
    })
    t.Run("handoff wait online and usage", func(t *testing.T) {
        nodeEvents := []nodeEvent{
            {
                nodeID:   1,
                nodeAddr: "addr_1",
            },
            {
                nodeID:   2,
                nodeAddr: "addr_2",
            },
        }

        segmentEvents := []segmentEvent{
            {
                segmentID: 1,
                nodeID:    1,
                state:     segmentStateLoaded,
            },
            {
                segmentID: 2,
                nodeID:    2,
                state:     segmentStateLoaded,
            },
        }
        evtCh := make(chan segmentEvent, 10)
        sc := NewShardCluster(collectionID, replicaID, vchannelName,
            &mockNodeDetector{
                initNodes: nodeEvents,
            }, &mockSegmentDetector{
                initSegments: segmentEvents,
                evtCh:        evtCh,
            }, buildMockQueryNode)
        defer sc.Close()

        // add rc to all segments
        allocs := sc.segmentAllocations(nil)

        sig := make(chan struct{})
        go func() {
            sc.HandoffSegments(&querypb.SegmentChangeInfo{
                OnlineSegments: []*querypb.SegmentInfo{
                    {SegmentID: 3, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName},
                },
                OfflineSegments: []*querypb.SegmentInfo{
                    {SegmentID: 1, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName},
                },
            })

            close(sig)
        }()

        sc.mut.RLock()
        // still waiting online
        assert.Equal(t, 0, len(sc.handoffs))
        sc.mut.RUnlock()

        evtCh <- segmentEvent{
            eventType: segmentAdd,
            segmentID: 3,
            nodeID:    1,
            state:     segmentStateLoaded,
        }

        // wait for handoff appended into list
        assert.Eventually(t, func() bool {
            sc.mut.RLock()
            defer sc.mut.RUnlock()
            return len(sc.handoffs) > 0
        }, time.Second, time.Millisecond*10)

        tmpAllocs := sc.segmentAllocations(nil)
        found := false
        for _, segments := range tmpAllocs {
            if inList(segments, int64(1)) {
                found = true
                break
            }
        }
        // segment 1 shall not be allocated again!
        assert.False(t, found)
        sc.finishUsage(tmpAllocs)
        // rc shall be 0 now
        sc.finishUsage(allocs)

        // wait handoff finished
        <-sig

        sc.mut.RLock()
        _, has := sc.segments[1]
        sc.mut.RUnlock()

        assert.False(t, has)
    })
}