mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-11-30 10:59:32 +08:00
Dynamic add tSafe watcher, use sync.Cond instead of selectCase, add ref count (#8050)
Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
This commit is contained in:
parent
43432f47d7
commit
52126f2d5a
@ -179,9 +179,15 @@ func (dsService *dataSyncService) removePartitionFlowGraph(partitionID UniqueID)
|
||||
defer dsService.mu.Unlock()
|
||||
|
||||
if _, ok := dsService.partitionFlowGraphs[partitionID]; ok {
|
||||
for _, nodeFG := range dsService.partitionFlowGraphs[partitionID] {
|
||||
for channel, nodeFG := range dsService.partitionFlowGraphs[partitionID] {
|
||||
// close flow graph
|
||||
nodeFG.close()
|
||||
// remove tSafe record
|
||||
// no tSafe in tSafeReplica, don't return error
|
||||
err := dsService.tSafeReplica.removeRecord(channel, partitionID)
|
||||
if err != nil {
|
||||
log.Warn(err.Error())
|
||||
}
|
||||
}
|
||||
dsService.partitionFlowGraphs[partitionID] = nil
|
||||
}
|
||||
|
@ -213,3 +213,25 @@ func TestDataSyncService_partitionFlowGraphs(t *testing.T) {
|
||||
assert.Nil(t, fg)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDataSyncService_removePartitionFlowGraphs(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
t.Run("test no tSafe", func(t *testing.T) {
|
||||
streaming, err := genSimpleStreaming(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
fac, err := genFactory()
|
||||
assert.NoError(t, err)
|
||||
|
||||
dataSyncService := newDataSyncService(ctx, streaming.replica, streaming.tSafeReplica, fac)
|
||||
assert.NotNil(t, dataSyncService)
|
||||
|
||||
dataSyncService.addPartitionFlowGraph(defaultPartitionID, defaultPartitionID, []Channel{defaultVChannel})
|
||||
|
||||
err = dataSyncService.tSafeReplica.removeTSafe(defaultVChannel)
|
||||
assert.NoError(t, err)
|
||||
dataSyncService.removePartitionFlowGraph(defaultPartitionID)
|
||||
})
|
||||
}
|
||||
|
@ -65,7 +65,10 @@ func (stNode *serviceTimeNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
|
||||
} else {
|
||||
id = stNode.collectionID
|
||||
}
|
||||
stNode.tSafeReplica.setTSafe(stNode.vChannel, id, serviceTimeMsg.timeRange.timestampMax)
|
||||
err := stNode.tSafeReplica.setTSafe(stNode.vChannel, id, serviceTimeMsg.timeRange.timestampMax)
|
||||
if err != nil {
|
||||
log.Warn(err.Error())
|
||||
}
|
||||
//log.Debug("update tSafe:",
|
||||
// zap.Int64("tSafe", int64(serviceTimeMsg.timeRange.timestampMax)),
|
||||
// zap.Any("collectionID", stNode.collectionID),
|
||||
|
@ -77,4 +77,18 @@ func TestServiceTimeNode_Operate(t *testing.T) {
|
||||
in := []flowgraph.Msg{msg, msg}
|
||||
node.Operate(in)
|
||||
})
|
||||
|
||||
t.Run("test no tSafe", func(t *testing.T) {
|
||||
node := genServiceTimeNode()
|
||||
err := node.tSafeReplica.removeTSafe(defaultVChannel)
|
||||
assert.NoError(t, err)
|
||||
msg := &serviceTimeMsg{
|
||||
timeRange: TimeRange{
|
||||
timestampMin: 0,
|
||||
timestampMax: 1000,
|
||||
},
|
||||
}
|
||||
in := []flowgraph.Msg{msg, msg}
|
||||
node.Operate(in)
|
||||
})
|
||||
}
|
||||
|
@ -106,12 +106,26 @@ func (node *QueryNode) AddQueryChannel(ctx context.Context, in *queryPb.AddQuery
|
||||
|
||||
// add search collection
|
||||
if !node.queryService.hasQueryCollection(collectionID) {
|
||||
node.queryService.addQueryCollection(collectionID)
|
||||
err := node.queryService.addQueryCollection(collectionID)
|
||||
if err != nil {
|
||||
status := &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}
|
||||
return status, err
|
||||
}
|
||||
log.Debug("add query collection", zap.Any("collectionID", collectionID))
|
||||
}
|
||||
|
||||
// add request channel
|
||||
sc := node.queryService.queryCollections[in.CollectionID]
|
||||
sc, err := node.queryService.getQueryCollection(in.CollectionID)
|
||||
if err != nil {
|
||||
status := &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}
|
||||
return status, err
|
||||
}
|
||||
consumeChannels := []string{in.RequestChannelID}
|
||||
//consumeSubName := Params.MsgChannelSubName
|
||||
consumeSubName := Params.MsgChannelSubName + "-" + strconv.FormatInt(collectionID, 10) + "-" + strconv.Itoa(rand.Int())
|
||||
|
@ -17,13 +17,14 @@ import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
||||
queryPb "github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/util/metricsinfo"
|
||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestImpl_GetComponentStates(t *testing.T) {
|
||||
@ -108,6 +109,26 @@ func TestImpl_AddQueryChannel(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
|
||||
})
|
||||
|
||||
t.Run("test add query collection failed", func(t *testing.T) {
|
||||
node, err := genSimpleQueryNode(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = node.streaming.replica.removeCollection(defaultCollectionID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
req := &queryPb.AddQueryChannelRequest{
|
||||
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels),
|
||||
NodeID: 0,
|
||||
CollectionID: defaultCollectionID,
|
||||
RequestChannelID: genQueryChannel(),
|
||||
ResultChannelID: genQueryResultChannel(),
|
||||
}
|
||||
|
||||
status, err := node.AddQueryChannel(ctx, req)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
|
||||
})
|
||||
}
|
||||
|
||||
func TestImpl_RemoveQueryChannel(t *testing.T) {
|
||||
|
@ -14,9 +14,9 @@ package querynode
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
@ -50,15 +50,16 @@ type queryCollection struct {
|
||||
cancel context.CancelFunc
|
||||
|
||||
collectionID UniqueID
|
||||
collection *Collection
|
||||
historical *historical
|
||||
streaming *streaming
|
||||
|
||||
unsolvedMsgMu sync.Mutex // guards unsolvedMsg
|
||||
unsolvedMsg []queryMsg
|
||||
|
||||
tSafeWatchers map[Channel]*tSafeWatcher
|
||||
watcherSelectCase []reflect.SelectCase
|
||||
tSafeWatchersMu sync.Mutex // guards tSafeWatchers
|
||||
tSafeWatchers map[Channel]*tSafeWatcher
|
||||
tSafeUpdate bool
|
||||
watcherCond *sync.Cond
|
||||
|
||||
serviceableTimeMutex sync.Mutex // guards serviceableTime
|
||||
serviceableTime Timestamp
|
||||
@ -83,25 +84,26 @@ func newQueryCollection(releaseCtx context.Context,
|
||||
localChunkManager storage.ChunkManager,
|
||||
remoteChunkManager storage.ChunkManager,
|
||||
localCacheEnabled bool,
|
||||
) *queryCollection {
|
||||
) (*queryCollection, error) {
|
||||
|
||||
unsolvedMsg := make([]queryMsg, 0)
|
||||
|
||||
queryStream, _ := factory.NewQueryMsgStream(releaseCtx)
|
||||
queryResultStream, _ := factory.NewQueryMsgStream(releaseCtx)
|
||||
|
||||
collection, _ := streaming.replica.getCollectionByID(collectionID)
|
||||
condMu := sync.Mutex{}
|
||||
|
||||
qc := &queryCollection{
|
||||
releaseCtx: releaseCtx,
|
||||
cancel: cancel,
|
||||
|
||||
collectionID: collectionID,
|
||||
collection: collection,
|
||||
historical: historical,
|
||||
streaming: streaming,
|
||||
|
||||
tSafeWatchers: make(map[Channel]*tSafeWatcher),
|
||||
tSafeUpdate: false,
|
||||
watcherCond: sync.NewCond(&condMu),
|
||||
|
||||
unsolvedMsg: unsolvedMsg,
|
||||
|
||||
@ -113,8 +115,11 @@ func newQueryCollection(releaseCtx context.Context,
|
||||
localCacheEnabled: localCacheEnabled,
|
||||
}
|
||||
|
||||
qc.register()
|
||||
return qc
|
||||
err := qc.registerCollectionTSafe()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return qc, nil
|
||||
}
|
||||
|
||||
func (q *queryCollection) start() {
|
||||
@ -133,26 +138,61 @@ func (q *queryCollection) close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (q *queryCollection) register() {
|
||||
// registerCollectionTSafe registers tSafe watcher if vChannels exists
|
||||
func (q *queryCollection) registerCollectionTSafe() error {
|
||||
collection, err := q.streaming.replica.getCollectionByID(q.collectionID)
|
||||
if err != nil {
|
||||
log.Warn(err.Error())
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
//TODO:: can't add new vChannel to selectCase
|
||||
q.watcherSelectCase = make([]reflect.SelectCase, 0)
|
||||
log.Debug("register tSafe watcher and init watcher select case",
|
||||
zap.Any("collectionID", collection.ID()),
|
||||
zap.Any("dml channels", collection.getVChannels()),
|
||||
)
|
||||
for _, channel := range collection.getVChannels() {
|
||||
q.tSafeWatchers[channel] = newTSafeWatcher()
|
||||
q.streaming.tSafeReplica.registerTSafeWatcher(channel, q.tSafeWatchers[channel])
|
||||
q.watcherSelectCase = append(q.watcherSelectCase, reflect.SelectCase{
|
||||
Dir: reflect.SelectRecv,
|
||||
Chan: reflect.ValueOf(q.tSafeWatchers[channel].watcherChan()),
|
||||
})
|
||||
err = q.addTSafeWatcher(channel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *queryCollection) addTSafeWatcher(vChannel Channel) error {
|
||||
q.tSafeWatchersMu.Lock()
|
||||
defer q.tSafeWatchersMu.Unlock()
|
||||
if _, ok := q.tSafeWatchers[vChannel]; ok {
|
||||
err := errors.New(fmt.Sprintln("tSafeWatcher of queryCollection has been exists, ",
|
||||
"collectionID = ", q.collectionID, ", ",
|
||||
"channel = ", vChannel))
|
||||
return err
|
||||
}
|
||||
q.tSafeWatchers[vChannel] = newTSafeWatcher()
|
||||
err := q.streaming.tSafeReplica.registerTSafeWatcher(vChannel, q.tSafeWatchers[vChannel])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debug("add tSafeWatcher to queryCollection",
|
||||
zap.Any("collectionID", q.collectionID),
|
||||
zap.Any("channel", vChannel),
|
||||
)
|
||||
go q.startWatcher(q.tSafeWatchers[vChannel].watcherChan())
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: add stopWatcher(), add close() to tSafeWatcher
|
||||
func (q *queryCollection) startWatcher(channel <-chan bool) {
|
||||
for {
|
||||
select {
|
||||
case <-q.releaseCtx.Done():
|
||||
return
|
||||
case <-channel:
|
||||
// TODO: check if channel is closed
|
||||
q.watcherCond.L.Lock()
|
||||
q.tSafeUpdate = true
|
||||
q.watcherCond.Broadcast()
|
||||
q.watcherCond.L.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -171,22 +211,24 @@ func (q *queryCollection) popAllUnsolvedMsg() []queryMsg {
|
||||
return ret
|
||||
}
|
||||
|
||||
func (q *queryCollection) waitNewTSafe() Timestamp {
|
||||
// block until any vChannel updating tSafe
|
||||
_, _, recvOK := reflect.Select(q.watcherSelectCase)
|
||||
if !recvOK {
|
||||
//log.Warn("tSafe has been closed", zap.Any("collectionID", q.collectionID))
|
||||
return Timestamp(math.MaxInt64)
|
||||
func (q *queryCollection) waitNewTSafe() (Timestamp, error) {
|
||||
q.watcherCond.L.Lock()
|
||||
for !q.tSafeUpdate {
|
||||
q.watcherCond.Wait()
|
||||
}
|
||||
q.watcherCond.L.Unlock()
|
||||
//log.Debug("wait new tSafe", zap.Any("collectionID", s.collectionID))
|
||||
t := Timestamp(math.MaxInt64)
|
||||
for channel := range q.tSafeWatchers {
|
||||
ts := q.streaming.tSafeReplica.getTSafe(channel)
|
||||
ts, err := q.streaming.tSafeReplica.getTSafe(channel)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if ts <= t {
|
||||
t = ts
|
||||
}
|
||||
}
|
||||
return t
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (q *queryCollection) getServiceableTime() Timestamp {
|
||||
@ -397,7 +439,11 @@ func (q *queryCollection) doUnsolvedQueryMsg() {
|
||||
return
|
||||
default:
|
||||
//time.Sleep(10 * time.Millisecond)
|
||||
serviceTime := q.waitNewTSafe()
|
||||
serviceTime, err := q.waitNewTSafe()
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
return
|
||||
}
|
||||
//st, _ := tsoutil.ParseTS(serviceTime)
|
||||
//log.Debug("get tSafe from flow graph",
|
||||
// zap.Int64("collectionID", q.collectionID),
|
||||
@ -769,7 +815,12 @@ func (q *queryCollection) search(msg queryMsg) error {
|
||||
searchTimestamp := searchMsg.BeginTs()
|
||||
travelTimestamp := searchMsg.TravelTimestamp
|
||||
|
||||
schema, err := typeutil.CreateSchemaHelper(q.collection.schema)
|
||||
collection, err := q.streaming.replica.getCollectionByID(searchMsg.CollectionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
schema, err := typeutil.CreateSchemaHelper(collection.schema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -777,13 +828,13 @@ func (q *queryCollection) search(msg queryMsg) error {
|
||||
var plan *SearchPlan
|
||||
if searchMsg.GetDslType() == commonpb.DslType_BoolExprV1 {
|
||||
expr := searchMsg.SerializedExprPlan
|
||||
plan, err = createSearchPlanByExpr(q.collection, expr)
|
||||
plan, err = createSearchPlanByExpr(collection, expr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
dsl := searchMsg.Dsl
|
||||
plan, err = createSearchPlan(q.collection, dsl)
|
||||
plan, err = createSearchPlan(collection, dsl)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -821,13 +872,13 @@ func (q *queryCollection) search(msg queryMsg) error {
|
||||
if len(searchMsg.PartitionIDs) > 0 {
|
||||
globalSealedSegments = q.historical.getGlobalSegmentIDsByPartitionIds(searchMsg.PartitionIDs)
|
||||
} else {
|
||||
globalSealedSegments = q.historical.getGlobalSegmentIDsByCollectionID(q.collection.id)
|
||||
globalSealedSegments = q.historical.getGlobalSegmentIDsByCollectionID(collection.id)
|
||||
}
|
||||
|
||||
searchResults := make([]*SearchResult, 0)
|
||||
|
||||
// historical search
|
||||
hisSearchResults, sealedSegmentSearched, err1 := q.historical.search(searchRequests, q.collection.id, searchMsg.PartitionIDs, plan, travelTimestamp)
|
||||
hisSearchResults, sealedSegmentSearched, err1 := q.historical.search(searchRequests, collection.id, searchMsg.PartitionIDs, plan, travelTimestamp)
|
||||
if err1 != nil {
|
||||
log.Warn(err1.Error())
|
||||
return err1
|
||||
@ -837,9 +888,9 @@ func (q *queryCollection) search(msg queryMsg) error {
|
||||
|
||||
// streaming search
|
||||
var err2 error
|
||||
for _, channel := range q.collection.getVChannels() {
|
||||
for _, channel := range collection.getVChannels() {
|
||||
var strSearchResults []*SearchResult
|
||||
strSearchResults, err2 = q.streaming.search(searchRequests, q.collection.id, searchMsg.PartitionIDs, channel, plan, travelTimestamp)
|
||||
strSearchResults, err2 = q.streaming.search(searchRequests, collection.id, searchMsg.PartitionIDs, channel, plan, travelTimestamp)
|
||||
if err2 != nil {
|
||||
log.Warn(err2.Error())
|
||||
return err2
|
||||
@ -870,14 +921,14 @@ func (q *queryCollection) search(msg queryMsg) error {
|
||||
SlicedOffset: 1,
|
||||
SlicedNumCount: 1,
|
||||
SealedSegmentIDsSearched: sealedSegmentSearched,
|
||||
ChannelIDsSearched: q.collection.getVChannels(),
|
||||
ChannelIDsSearched: collection.getVChannels(),
|
||||
GlobalSealedSegmentIDs: globalSealedSegments,
|
||||
},
|
||||
}
|
||||
log.Debug("QueryNode Empty SearchResultMsg",
|
||||
zap.Any("collectionID", q.collection.id),
|
||||
zap.Any("collectionID", collection.id),
|
||||
zap.Any("msgID", searchMsg.ID()),
|
||||
zap.Any("vChannels", q.collection.getVChannels()),
|
||||
zap.Any("vChannels", collection.getVChannels()),
|
||||
zap.Any("sealedSegmentSearched", sealedSegmentSearched),
|
||||
)
|
||||
err = q.publishQueryResult(searchResultMsg, searchMsg.CollectionID)
|
||||
@ -962,14 +1013,14 @@ func (q *queryCollection) search(msg queryMsg) error {
|
||||
SlicedOffset: 1,
|
||||
SlicedNumCount: 1,
|
||||
SealedSegmentIDsSearched: sealedSegmentSearched,
|
||||
ChannelIDsSearched: q.collection.getVChannels(),
|
||||
ChannelIDsSearched: collection.getVChannels(),
|
||||
GlobalSealedSegmentIDs: globalSealedSegments,
|
||||
},
|
||||
}
|
||||
log.Debug("QueryNode SearchResultMsg",
|
||||
zap.Any("collectionID", q.collection.id),
|
||||
zap.Any("collectionID", collection.id),
|
||||
zap.Any("msgID", searchMsg.ID()),
|
||||
zap.Any("vChannels", q.collection.getVChannels()),
|
||||
zap.Any("vChannels", collection.getVChannels()),
|
||||
zap.Any("sealedSegmentSearched", sealedSegmentSearched),
|
||||
)
|
||||
|
||||
|
@ -4,10 +4,8 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"math"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -52,7 +50,7 @@ func genSimpleQueryCollection(ctx context.Context, cancel context.CancelFunc) (*
|
||||
return nil, err
|
||||
}
|
||||
|
||||
queryCollection := newQueryCollection(ctx, cancel,
|
||||
queryCollection, err := newQueryCollection(ctx, cancel,
|
||||
defaultCollectionID,
|
||||
historical,
|
||||
streaming,
|
||||
@ -60,22 +58,15 @@ func genSimpleQueryCollection(ctx context.Context, cancel context.CancelFunc) (*
|
||||
localCM,
|
||||
remoteCM,
|
||||
false)
|
||||
if queryCollection == nil {
|
||||
return nil, errors.New("nil simple query collection")
|
||||
}
|
||||
return queryCollection, nil
|
||||
return queryCollection, err
|
||||
}
|
||||
|
||||
func updateTSafe(queryCollection *queryCollection, timestamp Timestamp) {
|
||||
// register
|
||||
queryCollection.watcherSelectCase = make([]reflect.SelectCase, 0)
|
||||
queryCollection.tSafeWatchers[defaultVChannel] = newTSafeWatcher()
|
||||
queryCollection.streaming.tSafeReplica.addTSafe(defaultVChannel)
|
||||
queryCollection.streaming.tSafeReplica.registerTSafeWatcher(defaultVChannel, queryCollection.tSafeWatchers[defaultVChannel])
|
||||
queryCollection.watcherSelectCase = append(queryCollection.watcherSelectCase, reflect.SelectCase{
|
||||
Dir: reflect.SelectRecv,
|
||||
Chan: reflect.ValueOf(queryCollection.tSafeWatchers[defaultVChannel].watcherChan()),
|
||||
})
|
||||
queryCollection.addTSafeWatcher(defaultVChannel)
|
||||
|
||||
queryCollection.streaming.tSafeReplica.setTSafe(defaultVChannel, defaultCollectionID, timestamp)
|
||||
}
|
||||
@ -125,7 +116,8 @@ func TestQueryCollection_withoutVChannel(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
queryCollection := newQueryCollection(ctx, cancel, 0, historical, streaming, factory, nil, nil, false)
|
||||
queryCollection, err := newQueryCollection(ctx, cancel, 0, historical, streaming, factory, nil, nil, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
producerChannels := []string{"testResultChannel"}
|
||||
queryCollection.queryResultMsgStream.AsProducer(producerChannels)
|
||||
@ -484,6 +476,15 @@ func TestQueryCollection_serviceableTime(t *testing.T) {
|
||||
assert.Equal(t, st+gracefulTime, resST)
|
||||
}
|
||||
|
||||
func TestQueryCollection_addTSafeWatcher(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
queryCollection, err := genSimpleQueryCollection(ctx, cancel)
|
||||
assert.NoError(t, err)
|
||||
|
||||
queryCollection.addTSafeWatcher(defaultVChannel)
|
||||
}
|
||||
|
||||
func TestQueryCollection_waitNewTSafe(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
@ -493,7 +494,8 @@ func TestQueryCollection_waitNewTSafe(t *testing.T) {
|
||||
timestamp := Timestamp(1000)
|
||||
updateTSafe(queryCollection, timestamp)
|
||||
|
||||
resTimestamp := queryCollection.waitNewTSafe()
|
||||
resTimestamp, err := queryCollection.waitNewTSafe()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, timestamp, resTimestamp)
|
||||
}
|
||||
|
||||
|
@ -14,7 +14,10 @@ package querynode
|
||||
import "C"
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
@ -31,7 +34,8 @@ type queryService struct {
|
||||
historical *historical
|
||||
streaming *streaming
|
||||
|
||||
queryCollections map[UniqueID]*queryCollection
|
||||
queryCollectionMu sync.Mutex // guards queryCollections
|
||||
queryCollections map[UniqueID]*queryCollection
|
||||
|
||||
factory msgstream.Factory
|
||||
|
||||
@ -94,17 +98,22 @@ func (q *queryService) close() {
|
||||
for collectionID := range q.queryCollections {
|
||||
q.stopQueryCollection(collectionID)
|
||||
}
|
||||
q.queryCollectionMu.Lock()
|
||||
q.queryCollections = make(map[UniqueID]*queryCollection)
|
||||
q.queryCollectionMu.Unlock()
|
||||
q.cancel()
|
||||
}
|
||||
|
||||
func (q *queryService) addQueryCollection(collectionID UniqueID) {
|
||||
func (q *queryService) addQueryCollection(collectionID UniqueID) error {
|
||||
q.queryCollectionMu.Lock()
|
||||
defer q.queryCollectionMu.Unlock()
|
||||
if _, ok := q.queryCollections[collectionID]; ok {
|
||||
log.Warn("query collection already exists", zap.Any("collectionID", collectionID))
|
||||
return
|
||||
err := errors.New(fmt.Sprintln("query collection already exists, collectionID = ", collectionID))
|
||||
return err
|
||||
}
|
||||
ctx1, cancel := context.WithCancel(q.ctx)
|
||||
qc := newQueryCollection(ctx1,
|
||||
qc, err := newQueryCollection(ctx1,
|
||||
cancel,
|
||||
collectionID,
|
||||
q.historical,
|
||||
@ -114,15 +123,33 @@ func (q *queryService) addQueryCollection(collectionID UniqueID) {
|
||||
q.remoteChunkManager,
|
||||
q.localCacheEnabled,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
q.queryCollections[collectionID] = qc
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *queryService) hasQueryCollection(collectionID UniqueID) bool {
|
||||
q.queryCollectionMu.Lock()
|
||||
defer q.queryCollectionMu.Unlock()
|
||||
_, ok := q.queryCollections[collectionID]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (q *queryService) getQueryCollection(collectionID UniqueID) (*queryCollection, error) {
|
||||
q.queryCollectionMu.Lock()
|
||||
defer q.queryCollectionMu.Unlock()
|
||||
_, ok := q.queryCollections[collectionID]
|
||||
if ok {
|
||||
return q.queryCollections[collectionID], nil
|
||||
}
|
||||
return nil, errors.New(fmt.Sprintln("queryCollection not exists, collectionID = ", collectionID))
|
||||
}
|
||||
|
||||
func (q *queryService) stopQueryCollection(collectionID UniqueID) {
|
||||
q.queryCollectionMu.Lock()
|
||||
defer q.queryCollectionMu.Unlock()
|
||||
sc, ok := q.queryCollections[collectionID]
|
||||
if !ok {
|
||||
log.Warn("stopQueryCollection failed, collection doesn't exist", zap.Int64("collectionID", collectionID))
|
||||
|
@ -154,7 +154,8 @@ func TestSearch_Search(t *testing.T) {
|
||||
err = loadFields(segment, DIM, N)
|
||||
assert.NoError(t, err)
|
||||
|
||||
node.queryService.addQueryCollection(collectionID)
|
||||
err = node.queryService.addQueryCollection(collectionID)
|
||||
assert.Error(t, err)
|
||||
|
||||
err = sendSearchRequest(node.queryNodeLoopCtx, DIM)
|
||||
assert.NoError(t, err)
|
||||
@ -184,7 +185,8 @@ func TestSearch_SearchMultiSegments(t *testing.T) {
|
||||
node.historical,
|
||||
node.streaming,
|
||||
msFactory)
|
||||
node.queryService.addQueryCollection(collectionID)
|
||||
err = node.queryService.addQueryCollection(collectionID)
|
||||
assert.Error(t, err)
|
||||
|
||||
// load segments
|
||||
err = node.historical.replica.addSegment(segmentID1, defaultPartitionID, collectionID, "", segmentTypeSealed, true)
|
||||
@ -227,13 +229,16 @@ func TestQueryService_addQueryCollection(t *testing.T) {
|
||||
qs := newQueryService(ctx, his, str, fac)
|
||||
assert.NotNil(t, qs)
|
||||
|
||||
qs.addQueryCollection(defaultCollectionID)
|
||||
err = qs.addQueryCollection(defaultCollectionID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, qs.queryCollections, 1)
|
||||
|
||||
qs.addQueryCollection(defaultCollectionID)
|
||||
err = qs.addQueryCollection(defaultCollectionID)
|
||||
assert.Error(t, err)
|
||||
assert.Len(t, qs.queryCollections, 1)
|
||||
|
||||
const invalidCollectionID = 10000
|
||||
qs.addQueryCollection(invalidCollectionID)
|
||||
assert.Len(t, qs.queryCollections, 2)
|
||||
err = qs.addQueryCollection(invalidCollectionID)
|
||||
assert.Error(t, err)
|
||||
assert.Len(t, qs.queryCollections, 1)
|
||||
}
|
||||
|
@ -236,6 +236,18 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) error {
|
||||
log.Debug("query node add collection flow graphs", zap.Any("channels", vChannels))
|
||||
}
|
||||
|
||||
// add tSafe watcher if queryCollection exists
|
||||
qc, err := w.node.queryService.getQueryCollection(collectionID)
|
||||
if err == nil {
|
||||
for _, channel := range vChannels {
|
||||
err = qc.addTSafeWatcher(channel)
|
||||
if err != nil {
|
||||
// tSafe have been exist, not error
|
||||
log.Warn(err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// channels as consumer
|
||||
var nodeFGs map[Channel]*queryNodeFlowGraph
|
||||
if loadPartition {
|
||||
@ -467,7 +479,11 @@ func (r *releaseCollectionTask) Execute(ctx context.Context) error {
|
||||
zap.Any("collectionID", r.req.CollectionID),
|
||||
zap.Any("vChannel", channel),
|
||||
)
|
||||
r.node.streaming.tSafeReplica.removeTSafe(channel)
|
||||
// no tSafe in tSafeReplica, don't return error
|
||||
err = r.node.streaming.tSafeReplica.removeTSafe(channel)
|
||||
if err != nil {
|
||||
log.Warn(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// remove excludedSegments record
|
||||
@ -561,7 +577,11 @@ func (r *releasePartitionsTask) Execute(ctx context.Context) error {
|
||||
zap.Any("partitionID", id),
|
||||
zap.Any("vChannel", channel),
|
||||
)
|
||||
r.node.streaming.tSafeReplica.removeTSafe(channel)
|
||||
// no tSafe in tSafeReplica, don't return error
|
||||
err = r.node.streaming.tSafeReplica.removeTSafe(channel)
|
||||
if err != nil {
|
||||
log.Warn(err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -16,6 +16,8 @@ import (
|
||||
"math"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
)
|
||||
|
||||
@ -51,6 +53,7 @@ type tSafer interface {
|
||||
registerTSafeWatcher(t *tSafeWatcher)
|
||||
start()
|
||||
close()
|
||||
removeRecord(partitionID UniqueID)
|
||||
}
|
||||
|
||||
type tSafeMsg struct {
|
||||
@ -89,7 +92,9 @@ func (ts *tSafe) start() {
|
||||
for {
|
||||
select {
|
||||
case <-ts.ctx.Done():
|
||||
log.Debug("tSafe context done")
|
||||
log.Debug("tSafe context done",
|
||||
zap.Any("channel", ts.channel),
|
||||
)
|
||||
return
|
||||
case m := <-ts.tSafeChan:
|
||||
ts.tSafeMu.Lock()
|
||||
@ -116,6 +121,21 @@ func (ts *tSafe) start() {
|
||||
}()
|
||||
}
|
||||
|
||||
// removeRecord for deleting the old partition which has been released,
|
||||
// if we don't delete this, tSafe would always be the old partition's timestamp
|
||||
// (because we set tSafe to the minimum timestamp) from old partition
|
||||
// flow graph which has been closed and would not update tSafe any more.
|
||||
// removeRecord should be called when flow graph is been removed.
|
||||
func (ts *tSafe) removeRecord(partitionID UniqueID) {
|
||||
ts.tSafeMu.Lock()
|
||||
defer ts.tSafeMu.Unlock()
|
||||
|
||||
log.Debug("remove tSafeRecord",
|
||||
zap.Any("partitionID", partitionID),
|
||||
)
|
||||
delete(ts.tSafeRecord, partitionID)
|
||||
}
|
||||
|
||||
func (ts *tSafe) registerTSafeWatcher(t *tSafeWatcher) {
|
||||
ts.tSafeMu.Lock()
|
||||
defer ts.tSafeMu.Unlock()
|
||||
|
@ -23,38 +23,48 @@ import (
|
||||
|
||||
// TSafeReplicaInterface is the interface wrapper of tSafeReplica
|
||||
type TSafeReplicaInterface interface {
|
||||
getTSafe(vChannel Channel) Timestamp
|
||||
setTSafe(vChannel Channel, id UniqueID, timestamp Timestamp)
|
||||
getTSafe(vChannel Channel) (Timestamp, error)
|
||||
setTSafe(vChannel Channel, id UniqueID, timestamp Timestamp) error
|
||||
addTSafe(vChannel Channel)
|
||||
removeTSafe(vChannel Channel)
|
||||
registerTSafeWatcher(vChannel Channel, watcher *tSafeWatcher)
|
||||
removeTSafe(vChannel Channel) error
|
||||
registerTSafeWatcher(vChannel Channel, watcher *tSafeWatcher) error
|
||||
removeRecord(vChannel Channel, partitionID UniqueID) error
|
||||
}
|
||||
|
||||
type tSafeRef struct {
|
||||
tSafer tSafer
|
||||
ref int
|
||||
}
|
||||
|
||||
type tSafeReplica struct {
|
||||
mu sync.Mutex // guards tSafes
|
||||
tSafes map[string]tSafer // map[vChannel]tSafer
|
||||
mu sync.Mutex // guards tSafes
|
||||
tSafes map[Channel]*tSafeRef // map[vChannel]tSafeRef
|
||||
}
|
||||
|
||||
func (t *tSafeReplica) getTSafe(vChannel Channel) Timestamp {
|
||||
func (t *tSafeReplica) getTSafe(vChannel Channel) (Timestamp, error) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
safer, err := t.getTSaferPrivate(vChannel)
|
||||
if err != nil {
|
||||
log.Warn("get tSafe failed", zap.Error(err))
|
||||
return 0
|
||||
//log.Warn("get tSafe failed",
|
||||
// zap.Any("channel", vChannel),
|
||||
// zap.Error(err),
|
||||
//)
|
||||
return 0, err
|
||||
}
|
||||
return safer.get()
|
||||
return safer.get(), nil
|
||||
}
|
||||
|
||||
func (t *tSafeReplica) setTSafe(vChannel Channel, id UniqueID, timestamp Timestamp) {
|
||||
func (t *tSafeReplica) setTSafe(vChannel Channel, id UniqueID, timestamp Timestamp) error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
safer, err := t.getTSaferPrivate(vChannel)
|
||||
if err != nil {
|
||||
log.Warn("set tSafe failed", zap.Error(err))
|
||||
return
|
||||
//log.Warn("set tSafe failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
safer.set(id, timestamp)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tSafeReplica) getTSaferPrivate(vChannel Channel) (tSafer, error) {
|
||||
@ -63,7 +73,7 @@ func (t *tSafeReplica) getTSaferPrivate(vChannel Channel) (tSafer, error) {
|
||||
//log.Warn(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
return t.tSafes[vChannel], nil
|
||||
return t.tSafes[vChannel].tSafer, nil
|
||||
}
|
||||
|
||||
func (t *tSafeReplica) addTSafe(vChannel Channel) {
|
||||
@ -71,42 +81,74 @@ func (t *tSafeReplica) addTSafe(vChannel Channel) {
|
||||
defer t.mu.Unlock()
|
||||
ctx := context.Background()
|
||||
if _, ok := t.tSafes[vChannel]; !ok {
|
||||
t.tSafes[vChannel] = newTSafe(ctx, vChannel)
|
||||
t.tSafes[vChannel].start()
|
||||
log.Debug("add tSafe done", zap.Any("channel", vChannel))
|
||||
t.tSafes[vChannel] = &tSafeRef{
|
||||
tSafer: newTSafe(ctx, vChannel),
|
||||
ref: 1,
|
||||
}
|
||||
t.tSafes[vChannel].tSafer.start()
|
||||
log.Debug("add tSafe done",
|
||||
zap.Any("channel", vChannel),
|
||||
zap.Any("count", t.tSafes[vChannel].ref),
|
||||
)
|
||||
} else {
|
||||
log.Warn("tSafe has been existed", zap.Any("channel", vChannel))
|
||||
t.tSafes[vChannel].ref++
|
||||
log.Debug("tSafe has been existed",
|
||||
zap.Any("channel", vChannel),
|
||||
zap.Any("count", t.tSafes[vChannel].ref),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tSafeReplica) removeTSafe(vChannel Channel) {
|
||||
func (t *tSafeReplica) removeTSafe(vChannel Channel) error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
safer, err := t.getTSaferPrivate(vChannel)
|
||||
if err != nil {
|
||||
return
|
||||
if _, ok := t.tSafes[vChannel]; !ok {
|
||||
return errors.New("tSafe not exist, vChannel = " + vChannel)
|
||||
}
|
||||
log.Debug("remove tSafe replica",
|
||||
t.tSafes[vChannel].ref--
|
||||
log.Debug("reduce tSafe reference count",
|
||||
zap.Any("vChannel", vChannel),
|
||||
zap.Any("count", t.tSafes[vChannel].ref),
|
||||
)
|
||||
safer.close()
|
||||
delete(t.tSafes, vChannel)
|
||||
if t.tSafes[vChannel].ref == 0 {
|
||||
safer, err := t.getTSaferPrivate(vChannel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debug("remove tSafe replica",
|
||||
zap.Any("vChannel", vChannel),
|
||||
)
|
||||
safer.close()
|
||||
delete(t.tSafes, vChannel)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tSafeReplica) registerTSafeWatcher(vChannel Channel, watcher *tSafeWatcher) {
|
||||
func (t *tSafeReplica) removeRecord(vChannel Channel, partitionID UniqueID) error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
safer, err := t.getTSaferPrivate(vChannel)
|
||||
if err != nil {
|
||||
log.Warn("register tSafe watcher failed", zap.Error(err))
|
||||
return
|
||||
return err
|
||||
}
|
||||
safer.removeRecord(partitionID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tSafeReplica) registerTSafeWatcher(vChannel Channel, watcher *tSafeWatcher) error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
safer, err := t.getTSaferPrivate(vChannel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
safer.registerTSafeWatcher(watcher)
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTSafeReplica() TSafeReplicaInterface {
|
||||
var replica TSafeReplicaInterface = &tSafeReplica{
|
||||
tSafes: make(map[string]tSafer),
|
||||
tSafes: make(map[string]*tSafeRef),
|
||||
}
|
||||
return replica
|
||||
}
|
||||
|
@ -23,30 +23,39 @@ func TestTSafeReplica_valid(t *testing.T) {
|
||||
replica.addTSafe(defaultVChannel)
|
||||
|
||||
watcher := newTSafeWatcher()
|
||||
replica.registerTSafeWatcher(defaultVChannel, watcher)
|
||||
err := replica.registerTSafeWatcher(defaultVChannel, watcher)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timestamp := Timestamp(1000)
|
||||
replica.setTSafe(defaultVChannel, defaultCollectionID, timestamp)
|
||||
err = replica.setTSafe(defaultVChannel, defaultCollectionID, timestamp)
|
||||
assert.NoError(t, err)
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
resT := replica.getTSafe(defaultVChannel)
|
||||
resT, err := replica.getTSafe(defaultVChannel)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, timestamp, resT)
|
||||
|
||||
replica.removeTSafe(defaultVChannel)
|
||||
err = replica.removeTSafe(defaultVChannel)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestTSafeReplica_invalid(t *testing.T) {
|
||||
replica := newTSafeReplica()
|
||||
replica.addTSafe(defaultVChannel)
|
||||
|
||||
watcher := newTSafeWatcher()
|
||||
replica.registerTSafeWatcher(defaultVChannel, watcher)
|
||||
err := replica.registerTSafeWatcher(defaultVChannel, watcher)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timestamp := Timestamp(1000)
|
||||
replica.setTSafe(defaultVChannel, defaultCollectionID, timestamp)
|
||||
err = replica.setTSafe(defaultVChannel, defaultCollectionID, timestamp)
|
||||
assert.NoError(t, err)
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
resT := replica.getTSafe(defaultVChannel)
|
||||
assert.Equal(t, Timestamp(0), resT)
|
||||
resT, err := replica.getTSafe(defaultVChannel)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, timestamp, resT)
|
||||
|
||||
replica.removeTSafe(defaultVChannel)
|
||||
err = replica.removeTSafe(defaultVChannel)
|
||||
assert.NoError(t, err)
|
||||
|
||||
replica.addTSafe(defaultVChannel)
|
||||
replica.addTSafe(defaultVChannel)
|
||||
|
Loading…
Reference in New Issue
Block a user