Watch changeInfo in query node (#10045)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
This commit is contained in:
bigsheeper 2021-10-20 17:54:43 +08:00 committed by GitHub
parent 033079269a
commit c18fa9b785
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 330 additions and 1 deletions

View File

@ -68,6 +68,13 @@ func (g *globalSealedSegmentManager) getGlobalSegmentIDsByPartitionIds(partition
return resIDs return resIDs
} }
func (g *globalSealedSegmentManager) hasGlobalSegment(segmentID UniqueID) bool {
g.mu.Lock()
defer g.mu.Unlock()
_, ok := g.globalSealedSegments[segmentID]
return ok
}
func (g *globalSealedSegmentManager) removeGlobalSegmentInfo(segmentID UniqueID) { func (g *globalSealedSegmentManager) removeGlobalSegmentInfo(segmentID UniqueID) {
g.mu.Lock() g.mu.Lock()
defer g.mu.Unlock() defer g.mu.Unlock()

View File

@ -56,6 +56,9 @@ func TestGlobalSealedSegmentManager(t *testing.T) {
ids = manager.getGlobalSegmentIDs() ids = manager.getGlobalSegmentIDs()
assert.Len(t, ids, 0) assert.Len(t, ids, 0)
has := manager.hasGlobalSegment(defaultSegmentID)
assert.False(t, has)
segmentInfo.CollectionID = defaultCollectionID segmentInfo.CollectionID = defaultCollectionID
err = manager.addGlobalSegmentInfo(segmentInfo) err = manager.addGlobalSegmentInfo(segmentInfo)
assert.NoError(t, err) assert.NoError(t, err)

View File

@ -34,6 +34,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb" "github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/storage"
) )
@ -1248,6 +1249,33 @@ func consumeSimpleRetrieveResult(stream msgstream.MsgStream) (*msgstream.Retriev
return res.Msgs[0].(*msgstream.RetrieveResultMsg), nil return res.Msgs[0].(*msgstream.RetrieveResultMsg), nil
} }
func genSimpleChangeInfo() *querypb.SealedSegmentsChangeInfo {
return &querypb.SealedSegmentsChangeInfo{
Base: genCommonMsgBase(commonpb.MsgType_LoadBalanceSegments),
OnlineNodeID: Params.QueryNodeID,
OnlineSegments: []*querypb.SegmentInfo{
genSimpleSegmentInfo(),
},
OfflineNodeID: Params.QueryNodeID + 1,
OfflineSegments: []*querypb.SegmentInfo{
genSimpleSegmentInfo(),
},
}
}
func saveChangeInfo(key string, value string) error {
log.Debug(".. [query node unittest] Saving change info")
kv, err := genEtcdKV()
if err != nil {
return err
}
key = changeInfoMetaPrefix + "/" + key
return kv.Save(key, value)
}
// node // node
func genSimpleQueryNode(ctx context.Context) (*QueryNode, error) { func genSimpleQueryNode(ctx context.Context) (*QueryNode, error) {
fac, err := genFactory() fac, err := genFactory()
@ -1256,6 +1284,13 @@ func genSimpleQueryNode(ctx context.Context) (*QueryNode, error) {
} }
node := NewQueryNode(ctx, fac) node := NewQueryNode(ctx, fac)
etcdKV, err := genEtcdKV()
if err != nil {
return nil, err
}
node.etcdKV = etcdKV
streaming, err := genSimpleStreaming(ctx) streaming, err := genSimpleStreaming(ctx)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -27,25 +27,33 @@ import "C"
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"path/filepath"
"strconv" "strconv"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"unsafe" "unsafe"
"github.com/golang/protobuf/proto"
"go.etcd.io/etcd/api/v3/mvccpb"
"go.uber.org/zap" "go.uber.org/zap"
"github.com/milvus-io/milvus/internal/kv" "github.com/milvus-io/milvus/internal/kv"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/msgstream" "github.com/milvus-io/milvus/internal/msgstream"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/retry" "github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/typeutil" "github.com/milvus-io/milvus/internal/util/typeutil"
) )
const changeInfoMetaPrefix = "query-changeInfo"
// make sure QueryNode implements types.QueryNode // make sure QueryNode implements types.QueryNode
var _ types.QueryNode = (*QueryNode)(nil) var _ types.QueryNode = (*QueryNode)(nil)
@ -212,6 +220,7 @@ func (node *QueryNode) Start() error {
// start services // start services
go node.historical.start() go node.historical.start()
go node.watchChangeInfo()
Params.CreatedTime = time.Now() Params.CreatedTime = time.Now()
Params.UpdatedTime = time.Now() Params.UpdatedTime = time.Now()
@ -256,3 +265,137 @@ func (node *QueryNode) SetIndexCoord(index types.IndexCoord) error {
node.indexCoord = index node.indexCoord = index
return nil return nil
} }
func (node *QueryNode) watchChangeInfo() {
log.Debug("query node watchChangeInfo start")
watchChan := node.etcdKV.WatchWithPrefix(changeInfoMetaPrefix)
for {
select {
case <-node.queryNodeLoopCtx.Done():
log.Debug("query node watchChangeInfo close")
return
case resp := <-watchChan:
for _, event := range resp.Events {
switch event.Type {
case mvccpb.PUT:
infoID, err := strconv.ParseInt(filepath.Base(string(event.Kv.Key)), 10, 64)
if err != nil {
log.Warn("Parse SealedSegmentsChangeInfo id failed", zap.Any("error", err.Error()))
continue
}
log.Debug("get SealedSegmentsChangeInfo from etcd",
zap.Any("infoID", infoID),
)
info := &querypb.SealedSegmentsChangeInfo{}
err = proto.Unmarshal(event.Kv.Value, info)
if err != nil {
log.Warn("Unmarshal SealedSegmentsChangeInfo failed", zap.Any("error", err.Error()))
continue
}
go func() {
err = node.adjustByChangeInfo(info)
if err != nil {
log.Warn("adjustByChangeInfo failed", zap.Any("error", err.Error()))
}
}()
default:
// do nothing
}
}
}
}
}
func (node *QueryNode) waitChangeInfo(info *querypb.SealedSegmentsChangeInfo) error {
fn := func() error {
canDoLoadBalance := true
// Check online segments:
for _, segmentInfo := range info.OnlineSegments {
if node.queryService.hasQueryCollection(segmentInfo.CollectionID) {
qc, err := node.queryService.getQueryCollection(segmentInfo.CollectionID)
if err != nil {
canDoLoadBalance = false
break
}
if info.OnlineNodeID == Params.QueryNodeID && !qc.globalSegmentManager.hasGlobalSegment(segmentInfo.SegmentID) {
canDoLoadBalance = false
break
}
}
}
// Check offline segments:
for _, segmentInfo := range info.OfflineSegments {
if node.queryService.hasQueryCollection(segmentInfo.CollectionID) {
qc, err := node.queryService.getQueryCollection(segmentInfo.CollectionID)
if err != nil {
canDoLoadBalance = false
break
}
if info.OfflineNodeID == Params.QueryNodeID && qc.globalSegmentManager.hasGlobalSegment(segmentInfo.SegmentID) {
canDoLoadBalance = false
break
}
}
}
if canDoLoadBalance {
return nil
}
return errors.New(fmt.Sprintln("waitChangeInfo failed, infoID = ", info.Base.GetMsgID()))
}
return retry.Do(context.TODO(), fn, retry.Attempts(10))
}
func (node *QueryNode) adjustByChangeInfo(info *querypb.SealedSegmentsChangeInfo) error {
err := node.waitChangeInfo(info)
if err != nil {
log.Error("waitChangeInfo failed", zap.Any("error", err.Error()))
return err
}
// For online segments:
for _, segmentInfo := range info.OnlineSegments {
// 1. update excluded segment, cluster have been loaded sealed segments,
// so we need to avoid getting growing segment from flow graph.
node.streaming.replica.addExcludedSegments(segmentInfo.CollectionID, []*datapb.SegmentInfo{
{
ID: segmentInfo.SegmentID,
CollectionID: segmentInfo.CollectionID,
PartitionID: segmentInfo.PartitionID,
InsertChannel: segmentInfo.ChannelID,
NumOfRows: segmentInfo.NumRows,
// TODO: add status, remove query pb segment status, use common pb segment status?
DmlPosition: &internalpb.MsgPosition{
// use max timestamp to filter out dm messages
Timestamp: typeutil.MaxTimestamp,
},
},
})
// 2. delete growing segment because these segments are loaded in historical.
hasGrowingSegment := node.streaming.replica.hasSegment(segmentInfo.SegmentID)
if hasGrowingSegment {
err := node.streaming.replica.removeSegment(segmentInfo.SegmentID)
if err != nil {
return err
}
log.Debug("remove growing segment in adjustByChangeInfo",
zap.Any("collectionID", segmentInfo.CollectionID),
zap.Any("segmentID", segmentInfo.SegmentID),
zap.Any("infoID", info.Base.GetMsgID()),
)
}
}
// For offline segments:
for _, segment := range info.OfflineSegments {
// 1. load balance or compaction, remove old sealed segments.
if info.OfflineNodeID == Params.QueryNodeID {
err := node.historical.replica.removeSegment(segment.SegmentID)
if err != nil {
return err
}
}
}
return nil
}

View File

@ -19,6 +19,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
@ -272,3 +273,138 @@ func TestQueryNode_init(t *testing.T) {
err = node.Init() err = node.Init()
assert.NoError(t, err) assert.NoError(t, err)
} }
func genSimpleQueryNodeToTestWatchChangeInfo(ctx context.Context) (*QueryNode, error) {
node, err := genSimpleQueryNode(ctx)
if err != nil {
return nil, err
}
err = node.queryService.addQueryCollection(defaultCollectionID)
if err != nil {
return nil, err
}
qc, err := node.queryService.getQueryCollection(defaultCollectionID)
if err != nil {
return nil, err
}
err = qc.globalSegmentManager.addGlobalSegmentInfo(genSimpleSegmentInfo())
if err != nil {
return nil, err
}
return node, nil
}
func TestQueryNode_waitChangeInfo(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
node, err := genSimpleQueryNodeToTestWatchChangeInfo(ctx)
assert.NoError(t, err)
err = node.waitChangeInfo(genSimpleChangeInfo())
assert.NoError(t, err)
}
func TestQueryNode_adjustByChangeInfo(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
t.Run("test adjustByChangeInfo", func(t *testing.T) {
node, err := genSimpleQueryNodeToTestWatchChangeInfo(ctx)
assert.NoError(t, err)
err = node.adjustByChangeInfo(genSimpleChangeInfo())
assert.NoError(t, err)
})
t.Run("test adjustByChangeInfo no segment", func(t *testing.T) {
node, err := genSimpleQueryNodeToTestWatchChangeInfo(ctx)
assert.NoError(t, err)
err = node.historical.replica.removeSegment(defaultSegmentID)
assert.NoError(t, err)
info := genSimpleChangeInfo()
info.OnlineSegments = nil
info.OfflineNodeID = Params.QueryNodeID
qc, err := node.queryService.getQueryCollection(defaultCollectionID)
assert.NoError(t, err)
qc.globalSegmentManager.removeGlobalSegmentInfo(defaultSegmentID)
err = node.adjustByChangeInfo(info)
assert.Error(t, err)
})
}
func TestQueryNode_watchChangeInfo(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
t.Run("test watchChangeInfo", func(t *testing.T) {
node, err := genSimpleQueryNodeToTestWatchChangeInfo(ctx)
assert.NoError(t, err)
go node.watchChangeInfo()
info := genSimpleSegmentInfo()
value, err := proto.Marshal(info)
assert.NoError(t, err)
err = saveChangeInfo("0", string(value))
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
})
t.Run("test watchChangeInfo key error", func(t *testing.T) {
node, err := genSimpleQueryNodeToTestWatchChangeInfo(ctx)
assert.NoError(t, err)
go node.watchChangeInfo()
err = saveChangeInfo("*$&#%^^", "%EUY%&#^$%&@")
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
})
t.Run("test watchChangeInfo unmarshal error", func(t *testing.T) {
node, err := genSimpleQueryNodeToTestWatchChangeInfo(ctx)
assert.NoError(t, err)
go node.watchChangeInfo()
err = saveChangeInfo("0", "$%^$*&%^#$&*")
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
})
t.Run("test watchChangeInfo adjustByChangeInfo error", func(t *testing.T) {
node, err := genSimpleQueryNodeToTestWatchChangeInfo(ctx)
assert.NoError(t, err)
err = node.historical.replica.removeSegment(defaultSegmentID)
assert.NoError(t, err)
info := genSimpleChangeInfo()
info.OnlineSegments = nil
info.OfflineNodeID = Params.QueryNodeID
qc, err := node.queryService.getQueryCollection(defaultCollectionID)
assert.NoError(t, err)
qc.globalSegmentManager.removeGlobalSegmentInfo(defaultSegmentID)
go node.watchChangeInfo()
value, err := proto.Marshal(info)
assert.NoError(t, err)
err = saveChangeInfo("0", string(value))
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
})
}

View File

@ -11,7 +11,12 @@
package typeutil package typeutil
import "time" import (
"math"
"time"
)
const MaxTimestamp = math.MaxUint64
// ZeroTime is a zero time. // ZeroTime is a zero time.
var ZeroTime = time.Time{} var ZeroTime = time.Time{}