Dynamically add tSafe watcher, use sync.Cond instead of selectCase, add ref count (#8050)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
bigsheeper 2021-09-24 13:57:54 +08:00 committed by GitHub
parent 43432f47d7
commit 52126f2d5a
14 changed files with 368 additions and 112 deletions
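
For context before the per-file diffs: the core change replaces a reflect.Select over a dynamically rebuilt slice of watcher channels with a single sync.Cond that one goroutine per channel broadcasts on. Below is a minimal, self-contained sketch of that pattern; all names (notifier, addWatcher, wait) are illustrative, not taken from the commit.

package main

import (
	"fmt"
	"sync"
	"time"
)

// notifier fans any number of watcher channels into one sync.Cond.
type notifier struct {
	mu      sync.Mutex
	cond    *sync.Cond
	updated bool
}

func newNotifier() *notifier {
	n := &notifier{}
	n.cond = sync.NewCond(&n.mu)
	return n
}

// addWatcher may be called at any time; the old reflect.Select approach
// required rebuilding the whole watcherSelectCase slice instead.
func (n *notifier) addWatcher(ch <-chan bool) {
	go func() {
		for range ch {
			n.mu.Lock()
			n.updated = true
			n.cond.Broadcast()
			n.mu.Unlock()
		}
	}()
}

// wait blocks until some watcher has fired since the last wait.
func (n *notifier) wait() {
	n.mu.Lock()
	for !n.updated {
		n.cond.Wait()
	}
	n.updated = false
	n.mu.Unlock()
}

func main() {
	n := newNotifier()
	ch := make(chan bool, 1)
	n.addWatcher(ch)
	go func() { time.Sleep(10 * time.Millisecond); ch <- true }()
	n.wait()
	fmt.Println("got tSafe update")
}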

View File

@@ -179,9 +179,15 @@ func (dsService *dataSyncService) removePartitionFlowGraph(partitionID UniqueID)
defer dsService.mu.Unlock()
if _, ok := dsService.partitionFlowGraphs[partitionID]; ok {
for _, nodeFG := range dsService.partitionFlowGraphs[partitionID] {
for channel, nodeFG := range dsService.partitionFlowGraphs[partitionID] {
// close flow graph
nodeFG.close()
// remove tSafe record
// if there is no tSafe in tSafeReplica, don't return an error; just log it
err := dsService.tSafeReplica.removeRecord(channel, partitionID)
if err != nil {
log.Warn(err.Error())
}
}
dsService.partitionFlowGraphs[partitionID] = nil
}

View File

@@ -213,3 +213,25 @@ func TestDataSyncService_partitionFlowGraphs(t *testing.T) {
assert.Nil(t, fg)
assert.Error(t, err)
}
func TestDataSyncService_removePartitionFlowGraphs(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
t.Run("test no tSafe", func(t *testing.T) {
streaming, err := genSimpleStreaming(ctx)
assert.NoError(t, err)
fac, err := genFactory()
assert.NoError(t, err)
dataSyncService := newDataSyncService(ctx, streaming.replica, streaming.tSafeReplica, fac)
assert.NotNil(t, dataSyncService)
dataSyncService.addPartitionFlowGraph(defaultPartitionID, defaultPartitionID, []Channel{defaultVChannel})
err = dataSyncService.tSafeReplica.removeTSafe(defaultVChannel)
assert.NoError(t, err)
dataSyncService.removePartitionFlowGraph(defaultPartitionID)
})
}

View File

@@ -65,7 +65,10 @@ func (stNode *serviceTimeNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
} else {
id = stNode.collectionID
}
stNode.tSafeReplica.setTSafe(stNode.vChannel, id, serviceTimeMsg.timeRange.timestampMax)
err := stNode.tSafeReplica.setTSafe(stNode.vChannel, id, serviceTimeMsg.timeRange.timestampMax)
if err != nil {
log.Warn(err.Error())
}
//log.Debug("update tSafe:",
// zap.Int64("tSafe", int64(serviceTimeMsg.timeRange.timestampMax)),
// zap.Any("collectionID", stNode.collectionID),

View File

@@ -77,4 +77,18 @@ func TestServiceTimeNode_Operate(t *testing.T) {
in := []flowgraph.Msg{msg, msg}
node.Operate(in)
})
t.Run("test no tSafe", func(t *testing.T) {
node := genServiceTimeNode()
err := node.tSafeReplica.removeTSafe(defaultVChannel)
assert.NoError(t, err)
msg := &serviceTimeMsg{
timeRange: TimeRange{
timestampMin: 0,
timestampMax: 1000,
},
}
in := []flowgraph.Msg{msg, msg}
node.Operate(in)
})
}

View File

@@ -106,12 +106,26 @@ func (node *QueryNode) AddQueryChannel(ctx context.Context, in *queryPb.AddQuery
// add search collection
if !node.queryService.hasQueryCollection(collectionID) {
node.queryService.addQueryCollection(collectionID)
err := node.queryService.addQueryCollection(collectionID)
if err != nil {
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
}
return status, err
}
log.Debug("add query collection", zap.Any("collectionID", collectionID))
}
// add request channel
sc := node.queryService.queryCollections[in.CollectionID]
sc, err := node.queryService.getQueryCollection(in.CollectionID)
if err != nil {
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
}
return status, err
}
consumeChannels := []string{in.RequestChannelID}
//consumeSubName := Params.MsgChannelSubName
consumeSubName := Params.MsgChannelSubName + "-" + strconv.FormatInt(collectionID, 10) + "-" + strconv.Itoa(rand.Int())

View File

@@ -17,13 +17,14 @@ import (
"math/rand"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
queryPb "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/metricsinfo"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/stretchr/testify/assert"
)
func TestImpl_GetComponentStates(t *testing.T) {
@@ -108,6 +109,26 @@ func TestImpl_AddQueryChannel(t *testing.T) {
assert.Error(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
})
t.Run("test add query collection failed", func(t *testing.T) {
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
err = node.streaming.replica.removeCollection(defaultCollectionID)
assert.NoError(t, err)
req := &queryPb.AddQueryChannelRequest{
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels),
NodeID: 0,
CollectionID: defaultCollectionID,
RequestChannelID: genQueryChannel(),
ResultChannelID: genQueryResultChannel(),
}
status, err := node.AddQueryChannel(ctx, req)
assert.Error(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
})
}
func TestImpl_RemoveQueryChannel(t *testing.T) {

View File

@@ -14,9 +14,9 @@ package querynode
import (
"context"
"encoding/binary"
"errors"
"fmt"
"math"
"reflect"
"sync"
"unsafe"
@@ -50,15 +50,16 @@ type queryCollection struct {
cancel context.CancelFunc
collectionID UniqueID
collection *Collection
historical *historical
streaming *streaming
unsolvedMsgMu sync.Mutex // guards unsolvedMsg
unsolvedMsg []queryMsg
tSafeWatchers map[Channel]*tSafeWatcher
watcherSelectCase []reflect.SelectCase
tSafeWatchersMu sync.Mutex // guards tSafeWatchers
tSafeWatchers map[Channel]*tSafeWatcher
tSafeUpdate bool
watcherCond *sync.Cond
serviceableTimeMutex sync.Mutex // guards serviceableTime
serviceableTime Timestamp
@@ -83,25 +84,26 @@ func newQueryCollection(releaseCtx context.Context,
localChunkManager storage.ChunkManager,
remoteChunkManager storage.ChunkManager,
localCacheEnabled bool,
) *queryCollection {
) (*queryCollection, error) {
unsolvedMsg := make([]queryMsg, 0)
queryStream, _ := factory.NewQueryMsgStream(releaseCtx)
queryResultStream, _ := factory.NewQueryMsgStream(releaseCtx)
collection, _ := streaming.replica.getCollectionByID(collectionID)
condMu := sync.Mutex{}
qc := &queryCollection{
releaseCtx: releaseCtx,
cancel: cancel,
collectionID: collectionID,
collection: collection,
historical: historical,
streaming: streaming,
tSafeWatchers: make(map[Channel]*tSafeWatcher),
tSafeUpdate: false,
watcherCond: sync.NewCond(&condMu),
unsolvedMsg: unsolvedMsg,
@@ -113,8 +115,11 @@ func newQueryCollection(releaseCtx context.Context,
localCacheEnabled: localCacheEnabled,
}
qc.register()
return qc
err := qc.registerCollectionTSafe()
if err != nil {
return nil, err
}
return qc, nil
}
func (q *queryCollection) start() {
@@ -133,26 +138,61 @@ func (q *queryCollection) close() {
}
}
func (q *queryCollection) register() {
// registerCollectionTSafe registers a tSafe watcher for each of the collection's vChannels, if any exist
func (q *queryCollection) registerCollectionTSafe() error {
collection, err := q.streaming.replica.getCollectionByID(q.collectionID)
if err != nil {
log.Warn(err.Error())
return
return err
}
//TODO:: can't add new vChannel to selectCase
q.watcherSelectCase = make([]reflect.SelectCase, 0)
log.Debug("register tSafe watcher and init watcher select case",
zap.Any("collectionID", collection.ID()),
zap.Any("dml channels", collection.getVChannels()),
)
for _, channel := range collection.getVChannels() {
q.tSafeWatchers[channel] = newTSafeWatcher()
q.streaming.tSafeReplica.registerTSafeWatcher(channel, q.tSafeWatchers[channel])
q.watcherSelectCase = append(q.watcherSelectCase, reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(q.tSafeWatchers[channel].watcherChan()),
})
err = q.addTSafeWatcher(channel)
if err != nil {
return err
}
}
return nil
}
func (q *queryCollection) addTSafeWatcher(vChannel Channel) error {
q.tSafeWatchersMu.Lock()
defer q.tSafeWatchersMu.Unlock()
if _, ok := q.tSafeWatchers[vChannel]; ok {
err := errors.New(fmt.Sprintln("tSafeWatcher of queryCollection has been exists, ",
"collectionID = ", q.collectionID, ", ",
"channel = ", vChannel))
return err
}
q.tSafeWatchers[vChannel] = newTSafeWatcher()
err := q.streaming.tSafeReplica.registerTSafeWatcher(vChannel, q.tSafeWatchers[vChannel])
if err != nil {
return err
}
log.Debug("add tSafeWatcher to queryCollection",
zap.Any("collectionID", q.collectionID),
zap.Any("channel", vChannel),
)
go q.startWatcher(q.tSafeWatchers[vChannel].watcherChan())
return nil
}
// TODO: add stopWatcher(), add close() to tSafeWatcher
func (q *queryCollection) startWatcher(channel <-chan bool) {
for {
select {
case <-q.releaseCtx.Done():
return
case <-channel:
// TODO: check if channel is closed
q.watcherCond.L.Lock()
q.tSafeUpdate = true
q.watcherCond.Broadcast()
q.watcherCond.L.Unlock()
}
}
}
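
Design note on the hunk above: startWatcher sets tSafeUpdate under the Cond's lock before calling Broadcast, and waitNewTSafe (next hunk) re-checks the flag in a loop around Wait, so a broadcast that fires before a waiter blocks cannot be missed. This is the standard condition-variable predicate pattern, and unlike the old watcherSelectCase slice it needs no rebuilding when a new vChannel's watcher is added.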
@@ -171,22 +211,24 @@ func (q *queryCollection) popAllUnsolvedMsg() []queryMsg {
return ret
}
func (q *queryCollection) waitNewTSafe() Timestamp {
// block until any vChannel updating tSafe
_, _, recvOK := reflect.Select(q.watcherSelectCase)
if !recvOK {
//log.Warn("tSafe has been closed", zap.Any("collectionID", q.collectionID))
return Timestamp(math.MaxInt64)
func (q *queryCollection) waitNewTSafe() (Timestamp, error) {
q.watcherCond.L.Lock()
for !q.tSafeUpdate {
q.watcherCond.Wait()
}
q.watcherCond.L.Unlock()
//log.Debug("wait new tSafe", zap.Any("collectionID", s.collectionID))
t := Timestamp(math.MaxInt64)
for channel := range q.tSafeWatchers {
ts := q.streaming.tSafeReplica.getTSafe(channel)
ts, err := q.streaming.tSafeReplica.getTSafe(channel)
if err != nil {
return 0, err
}
if ts <= t {
t = ts
}
}
return t
return t, nil
}
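
The timestamp waitNewTSafe returns is the minimum tSafe across all watched channels, so a collection is serviceable only up to its slowest channel. A tiny self-contained sketch of that reduction (names hypothetical, not from the commit):

package main

import (
	"fmt"
	"math"
)

// minTSafe mirrors the loop in waitNewTSafe: serviceable time is the
// minimum tSafe across the collection's watched vChannels.
func minTSafe(tsafes map[string]uint64) uint64 {
	t := uint64(math.MaxUint64)
	for _, ts := range tsafes {
		if ts <= t {
			t = ts
		}
	}
	return t
}

func main() {
	fmt.Println(minTSafe(map[string]uint64{"ch-1": 1200, "ch-2": 950})) // 950
}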
func (q *queryCollection) getServiceableTime() Timestamp {
@@ -397,7 +439,11 @@ func (q *queryCollection) doUnsolvedQueryMsg() {
return
default:
//time.Sleep(10 * time.Millisecond)
serviceTime := q.waitNewTSafe()
serviceTime, err := q.waitNewTSafe()
if err != nil {
log.Error(err.Error())
return
}
//st, _ := tsoutil.ParseTS(serviceTime)
//log.Debug("get tSafe from flow graph",
// zap.Int64("collectionID", q.collectionID),
@@ -769,7 +815,12 @@ func (q *queryCollection) search(msg queryMsg) error {
searchTimestamp := searchMsg.BeginTs()
travelTimestamp := searchMsg.TravelTimestamp
schema, err := typeutil.CreateSchemaHelper(q.collection.schema)
collection, err := q.streaming.replica.getCollectionByID(searchMsg.CollectionID)
if err != nil {
return err
}
schema, err := typeutil.CreateSchemaHelper(collection.schema)
if err != nil {
return err
}
@@ -777,13 +828,13 @@ func (q *queryCollection) search(msg queryMsg) error {
var plan *SearchPlan
if searchMsg.GetDslType() == commonpb.DslType_BoolExprV1 {
expr := searchMsg.SerializedExprPlan
plan, err = createSearchPlanByExpr(q.collection, expr)
plan, err = createSearchPlanByExpr(collection, expr)
if err != nil {
return err
}
} else {
dsl := searchMsg.Dsl
plan, err = createSearchPlan(q.collection, dsl)
plan, err = createSearchPlan(collection, dsl)
if err != nil {
return err
}
@@ -821,13 +872,13 @@ func (q *queryCollection) search(msg queryMsg) error {
if len(searchMsg.PartitionIDs) > 0 {
globalSealedSegments = q.historical.getGlobalSegmentIDsByPartitionIds(searchMsg.PartitionIDs)
} else {
globalSealedSegments = q.historical.getGlobalSegmentIDsByCollectionID(q.collection.id)
globalSealedSegments = q.historical.getGlobalSegmentIDsByCollectionID(collection.id)
}
searchResults := make([]*SearchResult, 0)
// historical search
hisSearchResults, sealedSegmentSearched, err1 := q.historical.search(searchRequests, q.collection.id, searchMsg.PartitionIDs, plan, travelTimestamp)
hisSearchResults, sealedSegmentSearched, err1 := q.historical.search(searchRequests, collection.id, searchMsg.PartitionIDs, plan, travelTimestamp)
if err1 != nil {
log.Warn(err1.Error())
return err1
@@ -837,9 +888,9 @@ func (q *queryCollection) search(msg queryMsg) error {
// streaming search
var err2 error
for _, channel := range q.collection.getVChannels() {
for _, channel := range collection.getVChannels() {
var strSearchResults []*SearchResult
strSearchResults, err2 = q.streaming.search(searchRequests, q.collection.id, searchMsg.PartitionIDs, channel, plan, travelTimestamp)
strSearchResults, err2 = q.streaming.search(searchRequests, collection.id, searchMsg.PartitionIDs, channel, plan, travelTimestamp)
if err2 != nil {
log.Warn(err2.Error())
return err2
@@ -870,14 +921,14 @@ func (q *queryCollection) search(msg queryMsg) error {
SlicedOffset: 1,
SlicedNumCount: 1,
SealedSegmentIDsSearched: sealedSegmentSearched,
ChannelIDsSearched: q.collection.getVChannels(),
ChannelIDsSearched: collection.getVChannels(),
GlobalSealedSegmentIDs: globalSealedSegments,
},
}
log.Debug("QueryNode Empty SearchResultMsg",
zap.Any("collectionID", q.collection.id),
zap.Any("collectionID", collection.id),
zap.Any("msgID", searchMsg.ID()),
zap.Any("vChannels", q.collection.getVChannels()),
zap.Any("vChannels", collection.getVChannels()),
zap.Any("sealedSegmentSearched", sealedSegmentSearched),
)
err = q.publishQueryResult(searchResultMsg, searchMsg.CollectionID)
@@ -962,14 +1013,14 @@ func (q *queryCollection) search(msg queryMsg) error {
SlicedOffset: 1,
SlicedNumCount: 1,
SealedSegmentIDsSearched: sealedSegmentSearched,
ChannelIDsSearched: q.collection.getVChannels(),
ChannelIDsSearched: collection.getVChannels(),
GlobalSealedSegmentIDs: globalSealedSegments,
},
}
log.Debug("QueryNode SearchResultMsg",
zap.Any("collectionID", q.collection.id),
zap.Any("collectionID", collection.id),
zap.Any("msgID", searchMsg.ID()),
zap.Any("vChannels", q.collection.getVChannels()),
zap.Any("vChannels", collection.getVChannels()),
zap.Any("sealedSegmentSearched", sealedSegmentSearched),
)

View File

@@ -4,10 +4,8 @@ import (
"bytes"
"context"
"encoding/binary"
"errors"
"math"
"math/rand"
"reflect"
"testing"
"time"
@@ -52,7 +50,7 @@ func genSimpleQueryCollection(ctx context.Context, cancel context.CancelFunc) (*
return nil, err
}
queryCollection := newQueryCollection(ctx, cancel,
queryCollection, err := newQueryCollection(ctx, cancel,
defaultCollectionID,
historical,
streaming,
@@ -60,22 +58,15 @@ func genSimpleQueryCollection(ctx context.Context, cancel context.CancelFunc) (*
localCM,
remoteCM,
false)
if queryCollection == nil {
return nil, errors.New("nil simple query collection")
}
return queryCollection, nil
return queryCollection, err
}
func updateTSafe(queryCollection *queryCollection, timestamp Timestamp) {
// register
queryCollection.watcherSelectCase = make([]reflect.SelectCase, 0)
queryCollection.tSafeWatchers[defaultVChannel] = newTSafeWatcher()
queryCollection.streaming.tSafeReplica.addTSafe(defaultVChannel)
queryCollection.streaming.tSafeReplica.registerTSafeWatcher(defaultVChannel, queryCollection.tSafeWatchers[defaultVChannel])
queryCollection.watcherSelectCase = append(queryCollection.watcherSelectCase, reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(queryCollection.tSafeWatchers[defaultVChannel].watcherChan()),
})
queryCollection.addTSafeWatcher(defaultVChannel)
queryCollection.streaming.tSafeReplica.setTSafe(defaultVChannel, defaultCollectionID, timestamp)
}
@@ -125,7 +116,8 @@ func TestQueryCollection_withoutVChannel(t *testing.T) {
assert.Nil(t, err)
ctx, cancel := context.WithCancel(context.Background())
queryCollection := newQueryCollection(ctx, cancel, 0, historical, streaming, factory, nil, nil, false)
queryCollection, err := newQueryCollection(ctx, cancel, 0, historical, streaming, factory, nil, nil, false)
assert.NoError(t, err)
producerChannels := []string{"testResultChannel"}
queryCollection.queryResultMsgStream.AsProducer(producerChannels)
@@ -484,6 +476,15 @@ func TestQueryCollection_serviceableTime(t *testing.T) {
assert.Equal(t, st+gracefulTime, resST)
}
func TestQueryCollection_addTSafeWatcher(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
queryCollection, err := genSimpleQueryCollection(ctx, cancel)
assert.NoError(t, err)
queryCollection.addTSafeWatcher(defaultVChannel)
}
func TestQueryCollection_waitNewTSafe(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
@@ -493,7 +494,8 @@ func TestQueryCollection_waitNewTSafe(t *testing.T) {
timestamp := Timestamp(1000)
updateTSafe(queryCollection, timestamp)
resTimestamp := queryCollection.waitNewTSafe()
resTimestamp, err := queryCollection.waitNewTSafe()
assert.NoError(t, err)
assert.Equal(t, timestamp, resTimestamp)
}

View File

@@ -14,7 +14,10 @@ package querynode
import "C"
import (
"context"
"errors"
"fmt"
"strconv"
"sync"
"go.uber.org/zap"
@@ -31,7 +34,8 @@ type queryService struct {
historical *historical
streaming *streaming
queryCollections map[UniqueID]*queryCollection
queryCollectionMu sync.Mutex // guards queryCollections
queryCollections map[UniqueID]*queryCollection
factory msgstream.Factory
@@ -94,17 +98,22 @@ func (q *queryService) close() {
for collectionID := range q.queryCollections {
q.stopQueryCollection(collectionID)
}
q.queryCollectionMu.Lock()
q.queryCollections = make(map[UniqueID]*queryCollection)
q.queryCollectionMu.Unlock()
q.cancel()
}
func (q *queryService) addQueryCollection(collectionID UniqueID) {
func (q *queryService) addQueryCollection(collectionID UniqueID) error {
q.queryCollectionMu.Lock()
defer q.queryCollectionMu.Unlock()
if _, ok := q.queryCollections[collectionID]; ok {
log.Warn("query collection already exists", zap.Any("collectionID", collectionID))
return
err := errors.New(fmt.Sprintln("query collection already exists, collectionID = ", collectionID))
return err
}
ctx1, cancel := context.WithCancel(q.ctx)
qc := newQueryCollection(ctx1,
qc, err := newQueryCollection(ctx1,
cancel,
collectionID,
q.historical,
@@ -114,15 +123,33 @@ func (q *queryService) addQueryCollection(collectionID UniqueID) {
q.remoteChunkManager,
q.localCacheEnabled,
)
if err != nil {
return err
}
q.queryCollections[collectionID] = qc
return nil
}
func (q *queryService) hasQueryCollection(collectionID UniqueID) bool {
q.queryCollectionMu.Lock()
defer q.queryCollectionMu.Unlock()
_, ok := q.queryCollections[collectionID]
return ok
}
func (q *queryService) getQueryCollection(collectionID UniqueID) (*queryCollection, error) {
q.queryCollectionMu.Lock()
defer q.queryCollectionMu.Unlock()
_, ok := q.queryCollections[collectionID]
if ok {
return q.queryCollections[collectionID], nil
}
return nil, errors.New(fmt.Sprintln("queryCollection does not exist, collectionID = ", collectionID))
}
func (q *queryService) stopQueryCollection(collectionID UniqueID) {
q.queryCollectionMu.Lock()
defer q.queryCollectionMu.Unlock()
sc, ok := q.queryCollections[collectionID]
if !ok {
log.Warn("stopQueryCollection failed, collection doesn't exist", zap.Int64("collectionID", collectionID))

View File

@@ -154,7 +154,8 @@ func TestSearch_Search(t *testing.T) {
err = loadFields(segment, DIM, N)
assert.NoError(t, err)
node.queryService.addQueryCollection(collectionID)
err = node.queryService.addQueryCollection(collectionID)
assert.Error(t, err)
err = sendSearchRequest(node.queryNodeLoopCtx, DIM)
assert.NoError(t, err)
@@ -184,7 +185,8 @@ func TestSearch_SearchMultiSegments(t *testing.T) {
node.historical,
node.streaming,
msFactory)
node.queryService.addQueryCollection(collectionID)
err = node.queryService.addQueryCollection(collectionID)
assert.Error(t, err)
// load segments
err = node.historical.replica.addSegment(segmentID1, defaultPartitionID, collectionID, "", segmentTypeSealed, true)
@@ -227,13 +229,16 @@ func TestQueryService_addQueryCollection(t *testing.T) {
qs := newQueryService(ctx, his, str, fac)
assert.NotNil(t, qs)
qs.addQueryCollection(defaultCollectionID)
err = qs.addQueryCollection(defaultCollectionID)
assert.NoError(t, err)
assert.Len(t, qs.queryCollections, 1)
qs.addQueryCollection(defaultCollectionID)
err = qs.addQueryCollection(defaultCollectionID)
assert.Error(t, err)
assert.Len(t, qs.queryCollections, 1)
const invalidCollectionID = 10000
qs.addQueryCollection(invalidCollectionID)
assert.Len(t, qs.queryCollections, 2)
err = qs.addQueryCollection(invalidCollectionID)
assert.Error(t, err)
assert.Len(t, qs.queryCollections, 1)
}

View File

@@ -236,6 +236,18 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) error {
log.Debug("query node add collection flow graphs", zap.Any("channels", vChannels))
}
// add tSafe watcher if queryCollection exists
qc, err := w.node.queryService.getQueryCollection(collectionID)
if err == nil {
for _, channel := range vChannels {
err = qc.addTSafeWatcher(channel)
if err != nil {
// the tSafe watcher may already exist; this is not an error
log.Warn(err.Error())
}
}
}
// channels as consumer
var nodeFGs map[Channel]*queryNodeFlowGraph
if loadPartition {
@@ -467,7 +479,11 @@ func (r *releaseCollectionTask) Execute(ctx context.Context) error {
zap.Any("collectionID", r.req.CollectionID),
zap.Any("vChannel", channel),
)
r.node.streaming.tSafeReplica.removeTSafe(channel)
// if there is no tSafe in tSafeReplica, don't return an error; just log it
err = r.node.streaming.tSafeReplica.removeTSafe(channel)
if err != nil {
log.Warn(err.Error())
}
}
// remove excludedSegments record
@@ -561,7 +577,11 @@ func (r *releasePartitionsTask) Execute(ctx context.Context) error {
zap.Any("partitionID", id),
zap.Any("vChannel", channel),
)
r.node.streaming.tSafeReplica.removeTSafe(channel)
// if there is no tSafe in tSafeReplica, don't return an error; just log it
err = r.node.streaming.tSafeReplica.removeTSafe(channel)
if err != nil {
log.Warn(err.Error())
}
}
}

View File

@@ -16,6 +16,8 @@ import (
"math"
"sync"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
)
@@ -51,6 +53,7 @@ type tSafer interface {
registerTSafeWatcher(t *tSafeWatcher)
start()
close()
removeRecord(partitionID UniqueID)
}
type tSafeMsg struct {
@@ -89,7 +92,9 @@ func (ts *tSafe) start() {
for {
select {
case <-ts.ctx.Done():
log.Debug("tSafe context done")
log.Debug("tSafe context done",
zap.Any("channel", ts.channel),
)
return
case m := <-ts.tSafeChan:
ts.tSafeMu.Lock()
@@ -116,6 +121,21 @@ func (ts *tSafe) start() {
}()
}
// removeRecord deletes the record of an old partition that has been released.
// Without this, tSafe would stay pinned at the released partition's timestamp
// (tSafe is computed as the minimum timestamp across records), because that
// partition's flow graph has been closed and will never update tSafe again.
// removeRecord should be called whenever a flow graph is removed.
func (ts *tSafe) removeRecord(partitionID UniqueID) {
ts.tSafeMu.Lock()
defer ts.tSafeMu.Unlock()
log.Debug("remove tSafeRecord",
zap.Any("partitionID", partitionID),
)
delete(ts.tSafeRecord, partitionID)
}
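
The comment above states the invariant that motivates removeRecord: a channel's tSafe is the minimum over its per-partition records, so a record left behind by a released partition pins tSafe forever. A hedged sketch of the failure and the fix (types and names hypothetical; the commit keeps these inside tSafe):

package main

import (
	"fmt"
	"math"
)

// records maps partitionID -> last reported timestamp for one channel.
type records map[int64]uint64

// channelTSafe is the min over per-partition records.
func channelTSafe(r records) uint64 {
	t := uint64(math.MaxUint64)
	for _, ts := range r {
		if ts < t {
			t = ts
		}
	}
	return t
}

func main() {
	r := records{1: 100, 2: 5000} // partition 1's flow graph closed at ts=100
	fmt.Println(channelTSafe(r))  // 100: the stale record pins tSafe
	delete(r, 1)                  // what removeRecord does on partition release
	fmt.Println(channelTSafe(r))  // 5000: tSafe can advance again
}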
func (ts *tSafe) registerTSafeWatcher(t *tSafeWatcher) {
ts.tSafeMu.Lock()
defer ts.tSafeMu.Unlock()

View File

@@ -23,38 +23,48 @@ import (
// TSafeReplicaInterface is the interface wrapper of tSafeReplica
type TSafeReplicaInterface interface {
getTSafe(vChannel Channel) Timestamp
setTSafe(vChannel Channel, id UniqueID, timestamp Timestamp)
getTSafe(vChannel Channel) (Timestamp, error)
setTSafe(vChannel Channel, id UniqueID, timestamp Timestamp) error
addTSafe(vChannel Channel)
removeTSafe(vChannel Channel)
registerTSafeWatcher(vChannel Channel, watcher *tSafeWatcher)
removeTSafe(vChannel Channel) error
registerTSafeWatcher(vChannel Channel, watcher *tSafeWatcher) error
removeRecord(vChannel Channel, partitionID UniqueID) error
}
type tSafeRef struct {
tSafer tSafer
ref int
}
type tSafeReplica struct {
mu sync.Mutex // guards tSafes
tSafes map[string]tSafer // map[vChannel]tSafer
mu sync.Mutex // guards tSafes
tSafes map[Channel]*tSafeRef // map[vChannel]tSafeRef
}
func (t *tSafeReplica) getTSafe(vChannel Channel) Timestamp {
func (t *tSafeReplica) getTSafe(vChannel Channel) (Timestamp, error) {
t.mu.Lock()
defer t.mu.Unlock()
safer, err := t.getTSaferPrivate(vChannel)
if err != nil {
log.Warn("get tSafe failed", zap.Error(err))
return 0
//log.Warn("get tSafe failed",
// zap.Any("channel", vChannel),
// zap.Error(err),
//)
return 0, err
}
return safer.get()
return safer.get(), nil
}
func (t *tSafeReplica) setTSafe(vChannel Channel, id UniqueID, timestamp Timestamp) {
func (t *tSafeReplica) setTSafe(vChannel Channel, id UniqueID, timestamp Timestamp) error {
t.mu.Lock()
defer t.mu.Unlock()
safer, err := t.getTSaferPrivate(vChannel)
if err != nil {
log.Warn("set tSafe failed", zap.Error(err))
return
//log.Warn("set tSafe failed", zap.Error(err))
return err
}
safer.set(id, timestamp)
return nil
}
func (t *tSafeReplica) getTSaferPrivate(vChannel Channel) (tSafer, error) {
@@ -63,7 +73,7 @@ func (t *tSafeReplica) getTSaferPrivate(vChannel Channel) (tSafer, error) {
//log.Warn(err.Error())
return nil, err
}
return t.tSafes[vChannel], nil
return t.tSafes[vChannel].tSafer, nil
}
func (t *tSafeReplica) addTSafe(vChannel Channel) {
@@ -71,42 +81,74 @@ func (t *tSafeReplica) addTSafe(vChannel Channel) {
defer t.mu.Unlock()
ctx := context.Background()
if _, ok := t.tSafes[vChannel]; !ok {
t.tSafes[vChannel] = newTSafe(ctx, vChannel)
t.tSafes[vChannel].start()
log.Debug("add tSafe done", zap.Any("channel", vChannel))
t.tSafes[vChannel] = &tSafeRef{
tSafer: newTSafe(ctx, vChannel),
ref: 1,
}
t.tSafes[vChannel].tSafer.start()
log.Debug("add tSafe done",
zap.Any("channel", vChannel),
zap.Any("count", t.tSafes[vChannel].ref),
)
} else {
log.Warn("tSafe has been existed", zap.Any("channel", vChannel))
t.tSafes[vChannel].ref++
log.Debug("tSafe has been existed",
zap.Any("channel", vChannel),
zap.Any("count", t.tSafes[vChannel].ref),
)
}
}
func (t *tSafeReplica) removeTSafe(vChannel Channel) {
func (t *tSafeReplica) removeTSafe(vChannel Channel) error {
t.mu.Lock()
defer t.mu.Unlock()
safer, err := t.getTSaferPrivate(vChannel)
if err != nil {
return
if _, ok := t.tSafes[vChannel]; !ok {
return errors.New("tSafe not exist, vChannel = " + vChannel)
}
log.Debug("remove tSafe replica",
t.tSafes[vChannel].ref--
log.Debug("reduce tSafe reference count",
zap.Any("vChannel", vChannel),
zap.Any("count", t.tSafes[vChannel].ref),
)
safer.close()
delete(t.tSafes, vChannel)
if t.tSafes[vChannel].ref == 0 {
safer, err := t.getTSaferPrivate(vChannel)
if err != nil {
return err
}
log.Debug("remove tSafe replica",
zap.Any("vChannel", vChannel),
)
safer.close()
delete(t.tSafes, vChannel)
}
return nil
}
func (t *tSafeReplica) registerTSafeWatcher(vChannel Channel, watcher *tSafeWatcher) {
func (t *tSafeReplica) removeRecord(vChannel Channel, partitionID UniqueID) error {
t.mu.Lock()
defer t.mu.Unlock()
safer, err := t.getTSaferPrivate(vChannel)
if err != nil {
log.Warn("register tSafe watcher failed", zap.Error(err))
return
return err
}
safer.removeRecord(partitionID)
return nil
}
func (t *tSafeReplica) registerTSafeWatcher(vChannel Channel, watcher *tSafeWatcher) error {
t.mu.Lock()
defer t.mu.Unlock()
safer, err := t.getTSaferPrivate(vChannel)
if err != nil {
return err
}
safer.registerTSafeWatcher(watcher)
return nil
}
func newTSafeReplica() TSafeReplicaInterface {
var replica TSafeReplicaInterface = &tSafeReplica{
tSafes: make(map[string]tSafer),
tSafes: make(map[string]*tSafeRef),
}
return replica
}
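
The tSafeRef wrapper above turns addTSafe/removeTSafe into reference-counted acquire/release: a vChannel shared by a collection flow graph and a partition flow graph is only closed when its last user calls removeTSafe. A minimal sketch of the same pattern, with illustrative names:

package main

import (
	"errors"
	"fmt"
	"sync"
)

type ref struct{ count int }

type refMap struct {
	mu sync.Mutex
	m  map[string]*ref
}

func (r *refMap) add(ch string) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if v, ok := r.m[ch]; ok {
		v.count++ // shared channel: bump the count instead of re-creating
		return
	}
	r.m[ch] = &ref{count: 1}
}

func (r *refMap) remove(ch string) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	v, ok := r.m[ch]
	if !ok {
		return errors.New("tSafe does not exist: " + ch)
	}
	v.count--
	if v.count == 0 {
		delete(r.m, ch) // last user gone: actually release
	}
	return nil
}

func main() {
	r := &refMap{m: make(map[string]*ref)}
	r.add("ch-1")
	r.add("ch-1")
	_ = r.remove("ch-1")
	fmt.Println(len(r.m)) // 1: still referenced
	_ = r.remove("ch-1")
	fmt.Println(len(r.m)) // 0: released
}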

View File

@@ -23,30 +23,39 @@ func TestTSafeReplica_valid(t *testing.T) {
replica.addTSafe(defaultVChannel)
watcher := newTSafeWatcher()
replica.registerTSafeWatcher(defaultVChannel, watcher)
err := replica.registerTSafeWatcher(defaultVChannel, watcher)
assert.NoError(t, err)
timestamp := Timestamp(1000)
replica.setTSafe(defaultVChannel, defaultCollectionID, timestamp)
err = replica.setTSafe(defaultVChannel, defaultCollectionID, timestamp)
assert.NoError(t, err)
time.Sleep(20 * time.Millisecond)
resT := replica.getTSafe(defaultVChannel)
resT, err := replica.getTSafe(defaultVChannel)
assert.NoError(t, err)
assert.Equal(t, timestamp, resT)
replica.removeTSafe(defaultVChannel)
err = replica.removeTSafe(defaultVChannel)
assert.NoError(t, err)
}
func TestTSafeReplica_invalid(t *testing.T) {
replica := newTSafeReplica()
replica.addTSafe(defaultVChannel)
watcher := newTSafeWatcher()
replica.registerTSafeWatcher(defaultVChannel, watcher)
err := replica.registerTSafeWatcher(defaultVChannel, watcher)
assert.NoError(t, err)
timestamp := Timestamp(1000)
replica.setTSafe(defaultVChannel, defaultCollectionID, timestamp)
err = replica.setTSafe(defaultVChannel, defaultCollectionID, timestamp)
assert.NoError(t, err)
time.Sleep(20 * time.Millisecond)
resT := replica.getTSafe(defaultVChannel)
assert.Equal(t, Timestamp(0), resT)
resT, err := replica.getTSafe(defaultVChannel)
assert.NoError(t, err)
assert.Equal(t, timestamp, resT)
replica.removeTSafe(defaultVChannel)
err = replica.removeTSafe(defaultVChannel)
assert.NoError(t, err)
replica.addTSafe(defaultVChannel)
replica.addTSafe(defaultVChannel)