Add max number of retries for interTask in querycoord (#8215)

Signed-off-by: xige-16 <xi.ge@zilliz.com>
This commit is contained in:
xige-16 2021-10-11 09:54:37 +08:00 committed by GitHub
parent bd3a8ed3cf
commit 235d736a49
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 1737 additions and 804 deletions

View File

@ -374,6 +374,7 @@ func (c *queryNodeCluster) releasePartitions(ctx context.Context, nodeID int64,
log.Debug("ReleasePartitions: queryNode release partitions error", zap.String("error", err.Error()))
return err
}
for _, partitionID := range in.PartitionIDs {
err = c.clusterMeta.releasePartition(in.CollectionID, partitionID)
if err != nil {

View File

@ -141,21 +141,23 @@ func (qc *QueryCoord) LoadCollection(ctx context.Context, req *querypb.LoadColle
return status, err
}
baseTask := newBaseTask(qc.loopCtx, querypb.TriggerCondition_grpcRequest)
loadCollectionTask := &LoadCollectionTask{
BaseTask: BaseTask{
ctx: qc.loopCtx,
Condition: NewTaskCondition(qc.loopCtx),
triggerCondition: querypb.TriggerCondition_grpcRequest,
},
BaseTask: baseTask,
LoadCollectionRequest: req,
rootCoord: qc.rootCoordClient,
dataCoord: qc.dataCoordClient,
cluster: qc.cluster,
meta: qc.meta,
}
qc.scheduler.Enqueue([]task{loadCollectionTask})
err := qc.scheduler.Enqueue(loadCollectionTask)
if err != nil {
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error()
return status, err
}
err := loadCollectionTask.WaitToFinish()
err = loadCollectionTask.WaitToFinish()
if err != nil {
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error()
@ -188,20 +190,22 @@ func (qc *QueryCoord) ReleaseCollection(ctx context.Context, req *querypb.Releas
return status, nil
}
baseTask := newBaseTask(qc.loopCtx, querypb.TriggerCondition_grpcRequest)
releaseCollectionTask := &ReleaseCollectionTask{
BaseTask: BaseTask{
ctx: qc.loopCtx,
Condition: NewTaskCondition(qc.loopCtx),
triggerCondition: querypb.TriggerCondition_grpcRequest,
},
BaseTask: baseTask,
ReleaseCollectionRequest: req,
cluster: qc.cluster,
meta: qc.meta,
rootCoord: qc.rootCoordClient,
}
qc.scheduler.Enqueue([]task{releaseCollectionTask})
err := qc.scheduler.Enqueue(releaseCollectionTask)
if err != nil {
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error()
return status, err
}
err := releaseCollectionTask.WaitToFinish()
err = releaseCollectionTask.WaitToFinish()
if err != nil {
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error()
@ -329,20 +333,22 @@ func (qc *QueryCoord) LoadPartitions(ctx context.Context, req *querypb.LoadParti
req.PartitionIDs = partitionIDsToLoad
}
baseTask := newBaseTask(qc.loopCtx, querypb.TriggerCondition_grpcRequest)
loadPartitionTask := &LoadPartitionTask{
BaseTask: BaseTask{
ctx: qc.loopCtx,
Condition: NewTaskCondition(qc.loopCtx),
triggerCondition: querypb.TriggerCondition_grpcRequest,
},
BaseTask: baseTask,
LoadPartitionsRequest: req,
dataCoord: qc.dataCoordClient,
cluster: qc.cluster,
meta: qc.meta,
}
qc.scheduler.Enqueue([]task{loadPartitionTask})
err := qc.scheduler.Enqueue(loadPartitionTask)
if err != nil {
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error()
return status, err
}
err := loadPartitionTask.WaitToFinish()
err = loadPartitionTask.WaitToFinish()
if err != nil {
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error()
@ -398,18 +404,20 @@ func (qc *QueryCoord) ReleasePartitions(ctx context.Context, req *querypb.Releas
}
req.PartitionIDs = toReleasedPartitions
baseTask := newBaseTask(qc.loopCtx, querypb.TriggerCondition_grpcRequest)
releasePartitionTask := &ReleasePartitionTask{
BaseTask: BaseTask{
ctx: qc.loopCtx,
Condition: NewTaskCondition(qc.loopCtx),
triggerCondition: querypb.TriggerCondition_grpcRequest,
},
BaseTask: baseTask,
ReleasePartitionsRequest: req,
cluster: qc.cluster,
}
qc.scheduler.Enqueue([]task{releasePartitionTask})
err := qc.scheduler.Enqueue(releasePartitionTask)
if err != nil {
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error()
return status, err
}
err := releasePartitionTask.WaitToFinish()
err = releasePartitionTask.WaitToFinish()
if err != nil {
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error()

View File

@ -13,6 +13,7 @@ package querycoord
import (
"context"
"encoding/json"
"errors"
"testing"
"time"
@ -328,6 +329,87 @@ func TestGrpcTask(t *testing.T) {
assert.Nil(t, err)
}
// TestGrpcTaskEnqueueFail checks that each QueryCoord gRPC task entry point
// (LoadPartitions / LoadCollection / ReleasePartitions / ReleaseCollection)
// returns ErrorCode_UnexpectedError together with a non-nil error when the
// scheduler's task-ID allocator fails, i.e. when Enqueue cannot accept the
// trigger task. It also verifies the same LoadCollection request succeeds
// once the working allocator is restored.
func TestGrpcTaskEnqueueFail(t *testing.T) {
	refreshParams()
	ctx := context.Background()
	queryCoord, err := startQueryCoord(ctx)
	assert.Nil(t, err)

	_, err = startQueryNodeServer(ctx)
	assert.Nil(t, err)

	// Remember the real allocator so it can be restored later, and build a
	// replacement that always fails so scheduler.Enqueue must return an error.
	taskIDAllocator := queryCoord.scheduler.taskIDAllocator
	failedAllocator := func() (UniqueID, error) {
		return 0, errors.New("scheduler failed to allocate ID")
	}

	queryCoord.scheduler.taskIDAllocator = failedAllocator

	t.Run("Test LoadPartition", func(t *testing.T) {
		status, err := queryCoord.LoadPartitions(ctx, &querypb.LoadPartitionsRequest{
			Base: &commonpb.MsgBase{
				MsgType: commonpb.MsgType_LoadPartitions,
			},
			CollectionID: defaultCollectionID,
			PartitionIDs: []UniqueID{defaultPartitionID},
			Schema:       genCollectionSchema(defaultCollectionID, false),
		})
		// Enqueue fails before any work happens, so the handler must report it.
		assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
		assert.NotNil(t, err)
	})

	t.Run("Test LoadCollection", func(t *testing.T) {
		status, err := queryCoord.LoadCollection(ctx, &querypb.LoadCollectionRequest{
			Base: &commonpb.MsgBase{
				MsgType: commonpb.MsgType_LoadCollection,
			},
			CollectionID: defaultCollectionID,
			Schema:       genCollectionSchema(defaultCollectionID, false),
		})
		assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
		assert.NotNil(t, err)
	})

	// Restore the working allocator: load the collection for real so the
	// release requests below have loaded state to operate on.
	queryCoord.scheduler.taskIDAllocator = taskIDAllocator
	status, err := queryCoord.LoadCollection(ctx, &querypb.LoadCollectionRequest{
		Base: &commonpb.MsgBase{
			MsgType: commonpb.MsgType_LoadCollection,
		},
		CollectionID: defaultCollectionID,
		Schema:       genCollectionSchema(defaultCollectionID, false),
	})
	assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
	assert.Nil(t, err)

	// Break the allocator again to exercise the release code paths.
	queryCoord.scheduler.taskIDAllocator = failedAllocator

	t.Run("Test ReleasePartition", func(t *testing.T) {
		status, err := queryCoord.ReleasePartitions(ctx, &querypb.ReleasePartitionsRequest{
			Base: &commonpb.MsgBase{
				MsgType: commonpb.MsgType_ReleasePartitions,
			},
			CollectionID: defaultCollectionID,
			PartitionIDs: []UniqueID{defaultPartitionID},
		})
		assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
		assert.NotNil(t, err)
	})

	t.Run("Test ReleaseCollection", func(t *testing.T) {
		status, err := queryCoord.ReleaseCollection(ctx, &querypb.ReleaseCollectionRequest{
			Base: &commonpb.MsgBase{
				MsgType: commonpb.MsgType_ReleaseCollection,
			},
			CollectionID: defaultCollectionID,
		})
		assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
		assert.NotNil(t, err)
	})

	queryCoord.Stop()
	err = removeAllSession()
	assert.Nil(t, err)
}
func TestLoadBalanceTask(t *testing.T) {
refreshParams()
baseCtx := context.Background()
@ -371,7 +453,7 @@ func TestLoadBalanceTask(t *testing.T) {
}
loadBalanceTask := &LoadBalanceTask{
BaseTask: BaseTask{
BaseTask: &BaseTask{
ctx: baseCtx,
Condition: NewTaskCondition(baseCtx),
triggerCondition: querypb.TriggerCondition_nodeDown,
@ -382,7 +464,7 @@ func TestLoadBalanceTask(t *testing.T) {
cluster: queryCoord.cluster,
meta: queryCoord.meta,
}
queryCoord.scheduler.Enqueue([]task{loadBalanceTask})
queryCoord.scheduler.Enqueue(loadBalanceTask)
res, err = queryCoord.ReleaseCollection(baseCtx, &querypb.ReleaseCollectionRequest{
Base: &commonpb.MsgBase{
@ -400,6 +482,7 @@ func TestLoadBalanceTask(t *testing.T) {
}
func TestGrpcTaskBeforeHealthy(t *testing.T) {
refreshParams()
ctx := context.Background()
unHealthyCoord, err := startUnHealthyQueryCoord(ctx)
assert.Nil(t, err)

View File

@ -423,25 +423,23 @@ func (m *MetaReplica) releaseCollection(collectionID UniqueID) error {
defer m.Unlock()
delete(m.collectionInfos, collectionID)
var err error
for id, info := range m.segmentInfos {
if info.CollectionID == collectionID {
err := removeSegmentInfo(id, m.client)
err = removeSegmentInfo(id, m.client)
if err != nil {
log.Error("remove segmentInfo error", zap.Any("error", err.Error()), zap.Int64("segmentID", id))
return err
log.Warn("remove segmentInfo error", zap.Any("error", err.Error()), zap.Int64("segmentID", id))
}
delete(m.segmentInfos, id)
}
}
delete(m.queryChannelInfos, collectionID)
err := removeGlobalCollectionInfo(collectionID, m.client)
err = removeGlobalCollectionInfo(collectionID, m.client)
if err != nil {
log.Error("remove collectionInfo error", zap.Any("error", err.Error()), zap.Int64("collectionID", collectionID))
return err
log.Warn("remove collectionInfo error", zap.Any("error", err.Error()), zap.Int64("collectionID", collectionID))
}
return nil
return err
}
func (m *MetaReplica) releasePartition(collectionID UniqueID, partitionID UniqueID) error {

View File

@ -214,7 +214,9 @@ func (rc *rootCoordMock) createCollection(collectionID UniqueID) {
if _, ok := rc.Col2partition[collectionID]; !ok {
rc.CollectionIDs = append(rc.CollectionIDs, collectionID)
rc.Col2partition[collectionID] = make([]UniqueID, 0)
partitionIDs := make([]UniqueID, 0)
partitionIDs = append(partitionIDs, defaultPartitionID+1)
rc.Col2partition[collectionID] = partitionIDs
}
}
@ -222,13 +224,30 @@ func (rc *rootCoordMock) createPartition(collectionID UniqueID, partitionID Uniq
rc.Lock()
defer rc.Unlock()
if _, ok := rc.Col2partition[collectionID]; ok {
rc.Col2partition[collectionID] = append(rc.Col2partition[collectionID], partitionID)
if partitionIDs, ok := rc.Col2partition[collectionID]; ok {
partitionExist := false
for _, id := range partitionIDs {
if id == partitionID {
partitionExist = true
break
}
}
if !partitionExist {
rc.Col2partition[collectionID] = append(rc.Col2partition[collectionID], partitionID)
}
return nil
}
return errors.New("collection not exist")
}
// CreatePartition is the rootCoordMock implementation of the RootCoord
// CreatePartition RPC. The request payload is ignored: the mock simply
// registers defaultPartitionID under defaultCollectionID and reports success.
func (rc *rootCoordMock) CreatePartition(ctx context.Context, req *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) {
	// NOTE(review): createPartition's "collection not exist" error is
	// discarded here — presumably fine for a mock; confirm no caller needs it.
	rc.createPartition(defaultCollectionID, defaultPartitionID)
	successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
	return successStatus, nil
}
func (rc *rootCoordMock) ShowPartitions(ctx context.Context, in *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) {
collectionID := in.CollectionID
status := &commonpb.Status{
@ -244,7 +263,6 @@ func (rc *rootCoordMock) ShowPartitions(ctx context.Context, in *milvuspb.ShowPa
}
rc.createCollection(collectionID)
rc.createPartition(collectionID, defaultPartitionID)
return &milvuspb.ShowPartitionsResponse{
Status: status,
@ -267,16 +285,17 @@ type dataCoordMock struct {
minioKV kv.BaseKV
collections []UniqueID
col2DmChannels map[UniqueID][]*datapb.VchannelInfo
partitionID2Segment map[UniqueID]UniqueID
Segment2Binlog map[UniqueID][]*datapb.SegmentBinlogs
assignedSegmentID UniqueID
partitionID2Segment map[UniqueID][]UniqueID
Segment2Binlog map[UniqueID]*datapb.SegmentBinlogs
baseSegmentID UniqueID
channelNumPerCol int
}
func newDataCoordMock(ctx context.Context) (*dataCoordMock, error) {
collectionIDs := make([]UniqueID, 0)
col2DmChannels := make(map[UniqueID][]*datapb.VchannelInfo)
partitionID2Segment := make(map[UniqueID]UniqueID)
segment2Binglog := make(map[UniqueID][]*datapb.SegmentBinlogs)
partitionID2Segments := make(map[UniqueID][]UniqueID)
segment2Binglog := make(map[UniqueID]*datapb.SegmentBinlogs)
// create minio client
option := &minioKV.Option{
@ -296,9 +315,10 @@ func newDataCoordMock(ctx context.Context) (*dataCoordMock, error) {
minioKV: kv,
collections: collectionIDs,
col2DmChannels: col2DmChannels,
partitionID2Segment: partitionID2Segment,
partitionID2Segment: partitionID2Segments,
Segment2Binlog: segment2Binglog,
assignedSegmentID: defaultSegmentID,
baseSegmentID: defaultSegmentID,
channelNumPerCol: 2,
}, nil
}
@ -306,28 +326,36 @@ func (data *dataCoordMock) GetRecoveryInfo(ctx context.Context, req *datapb.GetR
collectionID := req.CollectionID
partitionID := req.PartitionID
if _, ok := data.col2DmChannels[collectionID]; !ok {
segmentID := data.assignedSegmentID
data.partitionID2Segment[partitionID] = segmentID
fieldID2Paths, err := generateInsertBinLog(collectionID, partitionID, segmentID, "queryCoorf-mockDataCoord", data.minioKV)
if err != nil {
return nil, err
}
fieldBinlogs := make([]*datapb.FieldBinlog, 0)
for fieldID, path := range fieldID2Paths {
fieldBinlog := &datapb.FieldBinlog{
FieldID: fieldID,
Binlogs: []string{path},
if _, ok := data.partitionID2Segment[partitionID]; !ok {
segmentIDs := make([]UniqueID, 0)
for i := 0; i < data.channelNumPerCol; i++ {
segmentID := data.baseSegmentID
if _, ok := data.Segment2Binlog[segmentID]; !ok {
fieldID2Paths, err := generateInsertBinLog(collectionID, partitionID, segmentID, "queryCoorf-mockDataCoord", data.minioKV)
if err != nil {
return nil, err
}
fieldBinlogs := make([]*datapb.FieldBinlog, 0)
for fieldID, path := range fieldID2Paths {
fieldBinlog := &datapb.FieldBinlog{
FieldID: fieldID,
Binlogs: []string{path},
}
fieldBinlogs = append(fieldBinlogs, fieldBinlog)
}
segmentBinlog := &datapb.SegmentBinlogs{
SegmentID: segmentID,
FieldBinlogs: fieldBinlogs,
}
data.Segment2Binlog[segmentID] = segmentBinlog
}
fieldBinlogs = append(fieldBinlogs, fieldBinlog)
segmentIDs = append(segmentIDs, segmentID)
data.baseSegmentID++
}
data.Segment2Binlog[segmentID] = make([]*datapb.SegmentBinlogs, 0)
segmentBinlog := &datapb.SegmentBinlogs{
SegmentID: segmentID,
FieldBinlogs: fieldBinlogs,
}
data.Segment2Binlog[segmentID] = append(data.Segment2Binlog[segmentID], segmentBinlog)
data.partitionID2Segment[partitionID] = segmentIDs
}
if _, ok := data.col2DmChannels[collectionID]; !ok {
channelInfos := make([]*datapb.VchannelInfo, 0)
data.collections = append(data.collections, collectionID)
collectionName := funcutil.RandomString(8)
@ -339,20 +367,24 @@ func (data *dataCoordMock) GetRecoveryInfo(ctx context.Context, req *datapb.GetR
SeekPosition: &internalpb.MsgPosition{
ChannelName: vChannel,
},
FlushedSegments: []*datapb.SegmentInfo{{ID: segmentID}},
}
channelInfos = append(channelInfos, channelInfo)
}
data.col2DmChannels[collectionID] = channelInfos
}
segmentID := data.partitionID2Segment[partitionID]
binlogs := make([]*datapb.SegmentBinlogs, 0)
for _, segmentID := range data.partitionID2Segment[partitionID] {
if _, ok := data.Segment2Binlog[segmentID]; ok {
binlogs = append(binlogs, data.Segment2Binlog[segmentID])
}
}
return &datapb.GetRecoveryInfoResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
Channels: data.col2DmChannels[collectionID],
Binlogs: data.Segment2Binlog[segmentID],
Binlogs: binlogs,
}, nil
}

View File

@ -87,6 +87,8 @@ func (qc *QueryCoord) Register() error {
// Init function initializes the queryCoord's meta, cluster, etcdKV and task scheduler
func (qc *QueryCoord) Init() error {
log.Debug("query coordinator start init")
// connect to etcd
connectEtcdFn := func() error {
etcdKV, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath)
if err != nil {
@ -221,19 +223,17 @@ func (qc *QueryCoord) watchNodeLoop() {
SourceNodeIDs: offlineNodeIDs,
}
baseTask := newBaseTask(qc.loopCtx, querypb.TriggerCondition_nodeDown)
loadBalanceTask := &LoadBalanceTask{
BaseTask: BaseTask{
ctx: qc.loopCtx,
Condition: NewTaskCondition(qc.loopCtx),
triggerCondition: querypb.TriggerCondition_nodeDown,
},
BaseTask: baseTask,
LoadBalanceRequest: loadBalanceSegment,
rootCoord: qc.rootCoordClient,
dataCoord: qc.dataCoordClient,
cluster: qc.cluster,
meta: qc.meta,
}
qc.scheduler.Enqueue([]task{loadBalanceTask})
// TODO: handle the error returned by Enqueue
qc.scheduler.Enqueue(loadBalanceTask)
log.Debug("start a loadBalance task", zap.Any("task", loadBalanceTask))
}
@ -271,21 +271,19 @@ func (qc *QueryCoord) watchNodeLoop() {
BalanceReason: querypb.TriggerCondition_nodeDown,
}
baseTask := newBaseTask(qc.loopCtx, querypb.TriggerCondition_nodeDown)
loadBalanceTask := &LoadBalanceTask{
BaseTask: BaseTask{
ctx: qc.loopCtx,
Condition: NewTaskCondition(qc.loopCtx),
triggerCondition: querypb.TriggerCondition_nodeDown,
},
BaseTask: baseTask,
LoadBalanceRequest: loadBalanceSegment,
rootCoord: qc.rootCoordClient,
dataCoord: qc.dataCoordClient,
cluster: qc.cluster,
meta: qc.meta,
}
qc.scheduler.Enqueue([]task{loadBalanceTask})
log.Debug("start a loadBalance task", zap.Any("task", loadBalanceTask))
qc.metricsCacheManager.InvalidateSystemInfoMetrics()
// TODO: handle the error returned by Enqueue
qc.scheduler.Enqueue(loadBalanceTask)
log.Debug("start a loadBalance task", zap.Any("task", loadBalanceTask))
}
}
}

View File

@ -93,6 +93,11 @@ func startQueryCoord(ctx context.Context) (*QueryCoord, error) {
return coord, nil
}
// createDefaultPartition asks the query coordinator's root-coordinator client
// to register the default partition, returning any error from the RPC.
func createDefaultPartition(ctx context.Context, queryCoord *QueryCoord) error {
	if _, err := queryCoord.rootCoordClient.CreatePartition(ctx, nil); err != nil {
		return err
	}
	return nil
}
func startUnHealthyQueryCoord(ctx context.Context) (*QueryCoord, error) {
factory := msgstream.NewPmsFactory()

File diff suppressed because it is too large Load Diff

View File

@ -59,31 +59,29 @@ func (queue *TaskQueue) taskFull() bool {
return int64(queue.tasks.Len()) >= queue.maxTask
}
func (queue *TaskQueue) addTask(tasks []task) {
func (queue *TaskQueue) addTask(t task) {
queue.Lock()
defer queue.Unlock()
for _, t := range tasks {
if queue.tasks.Len() == 0 {
queue.taskChan <- 1
queue.tasks.PushBack(t)
if queue.tasks.Len() == 0 {
queue.taskChan <- 1
queue.tasks.PushBack(t)
return
}
for e := queue.tasks.Back(); e != nil; e = e.Prev() {
if t.TaskPriority() > e.Value.(task).TaskPriority() {
if e.Prev() == nil {
queue.taskChan <- 1
queue.tasks.InsertBefore(t, e)
break
}
continue
}
for e := queue.tasks.Back(); e != nil; e = e.Prev() {
if t.TaskPriority() > e.Value.(task).TaskPriority() {
if e.Prev() == nil {
queue.taskChan <- 1
queue.tasks.InsertBefore(t, e)
break
}
continue
}
//TODO:: take care of timestamp
queue.taskChan <- 1
queue.tasks.InsertAfter(t, e)
break
}
//TODO:: take care of timestamp
queue.taskChan <- 1
queue.tasks.InsertAfter(t, e)
break
}
}
@ -123,12 +121,13 @@ func NewTaskQueue() *TaskQueue {
// TaskScheduler controls the scheduling of trigger tasks and internal tasks
type TaskScheduler struct {
triggerTaskQueue *TaskQueue
activateTaskChan chan task
meta Meta
cluster Cluster
taskIDAllocator func() (UniqueID, error)
client *etcdkv.EtcdKV
triggerTaskQueue *TaskQueue
activateTaskChan chan task
meta Meta
cluster Cluster
taskIDAllocator func() (UniqueID, error)
client *etcdkv.EtcdKV
stopActivateTaskLoopChan chan int
rootCoord types.RootCoord
dataCoord types.DataCoord
@ -141,17 +140,20 @@ type TaskScheduler struct {
func NewTaskScheduler(ctx context.Context, meta Meta, cluster Cluster, kv *etcdkv.EtcdKV, rootCoord types.RootCoord, dataCoord types.DataCoord) (*TaskScheduler, error) {
ctx1, cancel := context.WithCancel(ctx)
taskChan := make(chan task, 1024)
stopTaskLoopChan := make(chan int, 1)
s := &TaskScheduler{
ctx: ctx1,
cancel: cancel,
meta: meta,
cluster: cluster,
activateTaskChan: taskChan,
client: kv,
rootCoord: rootCoord,
dataCoord: dataCoord,
ctx: ctx1,
cancel: cancel,
meta: meta,
cluster: cluster,
activateTaskChan: taskChan,
client: kv,
stopActivateTaskLoopChan: stopTaskLoopChan,
rootCoord: rootCoord,
dataCoord: dataCoord,
}
s.triggerTaskQueue = NewTaskQueue()
// init the task-ID allocator
etcdKV, err := tsoutil.NewTSOKVBase(Params.EtcdEndpoints, Params.KvRootPath, "queryCoordTaskID")
if err != nil {
return nil, err
@ -166,6 +168,7 @@ func NewTaskScheduler(ctx context.Context, meta Meta, cluster Cluster, kv *etcdk
}
err = s.reloadFromKV()
if err != nil {
log.Error("reload task from kv failed", zap.Error(err))
return nil, err
}
@ -192,7 +195,7 @@ func (scheduler *TaskScheduler) reloadFromKV() error {
if err != nil {
return err
}
t, err := scheduler.unmarshalTask(triggerTaskValues[index])
t, err := scheduler.unmarshalTask(taskID, triggerTaskValues[index])
if err != nil {
return err
}
@ -205,7 +208,7 @@ func (scheduler *TaskScheduler) reloadFromKV() error {
if err != nil {
return err
}
t, err := scheduler.unmarshalTask(activeTaskValues[index])
t, err := scheduler.unmarshalTask(taskID, activeTaskValues[index])
if err != nil {
return err
}
@ -232,15 +235,17 @@ func (scheduler *TaskScheduler) reloadFromKV() error {
}
var doneTriggerTask task = nil
for id, t := range triggerTasks {
if taskInfos[id] == taskDone {
for _, t := range triggerTasks {
if t.State() == taskDone {
doneTriggerTask = t
for _, childTask := range activeTasks {
childTask.SetParentTask(t) //replace child task after reScheduler
t.AddChildTask(childTask)
}
t.SetResultInfo(nil)
continue
}
scheduler.triggerTaskQueue.addTask([]task{t})
scheduler.triggerTaskQueue.addTask(t)
}
if doneTriggerTask != nil {
@ -250,26 +255,23 @@ func (scheduler *TaskScheduler) reloadFromKV() error {
return nil
}
func (scheduler *TaskScheduler) unmarshalTask(t string) (task, error) {
func (scheduler *TaskScheduler) unmarshalTask(taskID UniqueID, t string) (task, error) {
header := commonpb.MsgHeader{}
err := proto.Unmarshal([]byte(t), &header)
if err != nil {
return nil, fmt.Errorf("Failed to unmarshal message header, err %s ", err.Error())
}
var newTask task
baseTask := newBaseTask(scheduler.ctx, querypb.TriggerCondition_grpcRequest)
switch header.Base.MsgType {
case commonpb.MsgType_LoadCollection:
loadReq := querypb.LoadCollectionRequest{}
err = proto.Unmarshal([]byte(t), &loadReq)
if err != nil {
log.Error(err.Error())
return nil, err
}
loadCollectionTask := &LoadCollectionTask{
BaseTask: BaseTask{
ctx: scheduler.ctx,
Condition: NewTaskCondition(scheduler.ctx),
triggerCondition: querypb.TriggerCondition_grpcRequest,
},
BaseTask: baseTask,
LoadCollectionRequest: &loadReq,
rootCoord: scheduler.rootCoord,
dataCoord: scheduler.dataCoord,
@ -281,14 +283,10 @@ func (scheduler *TaskScheduler) unmarshalTask(t string) (task, error) {
loadReq := querypb.LoadPartitionsRequest{}
err = proto.Unmarshal([]byte(t), &loadReq)
if err != nil {
log.Error(err.Error())
return nil, err
}
loadPartitionTask := &LoadPartitionTask{
BaseTask: BaseTask{
ctx: scheduler.ctx,
Condition: NewTaskCondition(scheduler.ctx),
triggerCondition: querypb.TriggerCondition_grpcRequest,
},
BaseTask: baseTask,
LoadPartitionsRequest: &loadReq,
dataCoord: scheduler.dataCoord,
cluster: scheduler.cluster,
@ -299,14 +297,10 @@ func (scheduler *TaskScheduler) unmarshalTask(t string) (task, error) {
loadReq := querypb.ReleaseCollectionRequest{}
err = proto.Unmarshal([]byte(t), &loadReq)
if err != nil {
log.Error(err.Error())
return nil, err
}
releaseCollectionTask := &ReleaseCollectionTask{
BaseTask: BaseTask{
ctx: scheduler.ctx,
Condition: NewTaskCondition(scheduler.ctx),
triggerCondition: querypb.TriggerCondition_grpcRequest,
},
BaseTask: baseTask,
ReleaseCollectionRequest: &loadReq,
cluster: scheduler.cluster,
meta: scheduler.meta,
@ -317,96 +311,79 @@ func (scheduler *TaskScheduler) unmarshalTask(t string) (task, error) {
loadReq := querypb.ReleasePartitionsRequest{}
err = proto.Unmarshal([]byte(t), &loadReq)
if err != nil {
log.Error(err.Error())
return nil, err
}
releasePartitionTask := &ReleasePartitionTask{
BaseTask: BaseTask{
ctx: scheduler.ctx,
Condition: NewTaskCondition(scheduler.ctx),
triggerCondition: querypb.TriggerCondition_grpcRequest,
},
BaseTask: baseTask,
ReleasePartitionsRequest: &loadReq,
cluster: scheduler.cluster,
}
newTask = releasePartitionTask
case commonpb.MsgType_LoadSegments:
//TODO::trigger condition may be different
loadReq := querypb.LoadSegmentsRequest{}
err = proto.Unmarshal([]byte(t), &loadReq)
if err != nil {
log.Error(err.Error())
return nil, err
}
loadSegmentTask := &LoadSegmentTask{
BaseTask: BaseTask{
ctx: scheduler.ctx,
Condition: NewTaskCondition(scheduler.ctx),
triggerCondition: querypb.TriggerCondition_grpcRequest,
},
BaseTask: baseTask,
LoadSegmentsRequest: &loadReq,
cluster: scheduler.cluster,
meta: scheduler.meta,
excludeNodeIDs: []int64{},
}
newTask = loadSegmentTask
case commonpb.MsgType_ReleaseSegments:
//TODO::trigger condition may be different
loadReq := querypb.ReleaseSegmentsRequest{}
err = proto.Unmarshal([]byte(t), &loadReq)
if err != nil {
log.Error(err.Error())
return nil, err
}
releaseSegmentTask := &ReleaseSegmentTask{
BaseTask: BaseTask{
ctx: scheduler.ctx,
Condition: NewTaskCondition(scheduler.ctx),
triggerCondition: querypb.TriggerCondition_grpcRequest,
},
BaseTask: baseTask,
ReleaseSegmentsRequest: &loadReq,
cluster: scheduler.cluster,
}
newTask = releaseSegmentTask
case commonpb.MsgType_WatchDmChannels:
//TODO::trigger condition may be different
loadReq := querypb.WatchDmChannelsRequest{}
err = proto.Unmarshal([]byte(t), &loadReq)
if err != nil {
log.Error(err.Error())
return nil, err
}
watchDmChannelTask := &WatchDmChannelTask{
BaseTask: BaseTask{
ctx: scheduler.ctx,
Condition: NewTaskCondition(scheduler.ctx),
triggerCondition: querypb.TriggerCondition_grpcRequest,
},
BaseTask: baseTask,
WatchDmChannelsRequest: &loadReq,
cluster: scheduler.cluster,
meta: scheduler.meta,
excludeNodeIDs: []int64{},
}
newTask = watchDmChannelTask
case commonpb.MsgType_WatchQueryChannels:
//TODO::trigger condition may be different
loadReq := querypb.AddQueryChannelRequest{}
err = proto.Unmarshal([]byte(t), &loadReq)
if err != nil {
log.Error(err.Error())
return nil, err
}
watchQueryChannelTask := &WatchQueryChannelTask{
BaseTask: BaseTask{
ctx: scheduler.ctx,
Condition: NewTaskCondition(scheduler.ctx),
triggerCondition: querypb.TriggerCondition_grpcRequest,
},
BaseTask: baseTask,
AddQueryChannelRequest: &loadReq,
cluster: scheduler.cluster,
}
newTask = watchQueryChannelTask
case commonpb.MsgType_LoadBalanceSegments:
//TODO::trigger condition may be different
loadReq := querypb.LoadBalanceRequest{}
err = proto.Unmarshal([]byte(t), &loadReq)
if err != nil {
log.Error(err.Error())
return nil, err
}
loadBalanceTask := &LoadBalanceTask{
BaseTask: BaseTask{
ctx: scheduler.ctx,
Condition: NewTaskCondition(scheduler.ctx),
triggerCondition: loadReq.BalanceReason,
},
BaseTask: baseTask,
LoadBalanceRequest: &loadReq,
rootCoord: scheduler.rootCoord,
dataCoord: scheduler.dataCoord,
@ -420,105 +397,115 @@ func (scheduler *TaskScheduler) unmarshalTask(t string) (task, error) {
return nil, err
}
newTask.SetID(taskID)
return newTask, nil
}
// Enqueue pushs a trigger task to triggerTaskQueue and assigns task id
func (scheduler *TaskScheduler) Enqueue(tasks []task) {
for _, t := range tasks {
id, err := scheduler.taskIDAllocator()
if err != nil {
log.Error(err.Error())
}
t.SetID(id)
kvs := make(map[string]string)
taskKey := fmt.Sprintf("%s/%d", triggerTaskPrefix, t.ID())
blobs, err := t.Marshal()
if err != nil {
log.Error("error when save marshal task", zap.Int64("taskID", t.ID()), zap.String("error", err.Error()))
}
kvs[taskKey] = string(blobs)
stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID())
kvs[stateKey] = strconv.Itoa(int(taskUndo))
err = scheduler.client.MultiSave(kvs)
if err != nil {
log.Error("error when save trigger task to etcd", zap.Int64("taskID", t.ID()), zap.String("error", err.Error()))
}
log.Debug("EnQueue a triggerTask and save to etcd", zap.Int64("taskID", t.ID()))
t.SetState(taskUndo)
func (scheduler *TaskScheduler) Enqueue(t task) error {
id, err := scheduler.taskIDAllocator()
if err != nil {
log.Error("allocator trigger taskID failed", zap.Error(err))
return err
}
t.SetID(id)
kvs := make(map[string]string)
taskKey := fmt.Sprintf("%s/%d", triggerTaskPrefix, t.ID())
blobs, err := t.Marshal()
if err != nil {
log.Error("error when save marshal task", zap.Int64("taskID", t.ID()), zap.Error(err))
return err
}
kvs[taskKey] = string(blobs)
stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID())
kvs[stateKey] = strconv.Itoa(int(taskUndo))
err = scheduler.client.MultiSave(kvs)
if err != nil {
//TODO::clean etcd meta
log.Error("error when save trigger task to etcd", zap.Int64("taskID", t.ID()), zap.Error(err))
return err
}
t.SetState(taskUndo)
scheduler.triggerTaskQueue.addTask(t)
log.Debug("EnQueue a triggerTask and save to etcd", zap.Int64("taskID", t.ID()))
scheduler.triggerTaskQueue.addTask(tasks)
return nil
}
func (scheduler *TaskScheduler) processTask(t task) error {
var taskInfoKey string
// assign taskID for childTask and update triggerTask's childTask to etcd
updateKVFn := func(parentTask task) error {
kvs := make(map[string]string)
kvs[taskInfoKey] = strconv.Itoa(int(taskDone))
for _, childTask := range parentTask.GetChildTask() {
id, err := scheduler.taskIDAllocator()
if err != nil {
return err
}
childTask.SetID(id)
childTaskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, childTask.ID())
blobs, err := childTask.Marshal()
if err != nil {
return err
}
kvs[childTaskKey] = string(blobs)
stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, childTask.ID())
kvs[stateKey] = strconv.Itoa(int(taskUndo))
}
err := scheduler.client.MultiSave(kvs)
if err != nil {
return err
}
return nil
}
span, ctx := trace.StartSpanFromContext(t.TraceCtx(),
opentracing.Tags{
"Type": t.Type(),
"ID": t.ID(),
})
var err error
defer span.Finish()
defer func() {
//task postExecute
span.LogFields(oplog.Int64("processTask: scheduler process PostExecute", t.ID()))
t.PostExecute(ctx)
}()
// task preExecute
span.LogFields(oplog.Int64("processTask: scheduler process PreExecute", t.ID()))
t.PreExecute(ctx)
key := fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID())
err := scheduler.client.Save(key, strconv.Itoa(int(taskDoing)))
taskInfoKey = fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID())
err = scheduler.client.Save(taskInfoKey, strconv.Itoa(int(taskDoing)))
if err != nil {
log.Error("processTask: update task state err", zap.String("reason", err.Error()), zap.Int64("taskID", t.ID()))
trace.LogError(span, err)
t.SetResultInfo(err)
return err
}
t.SetState(taskDoing)
// task execute
span.LogFields(oplog.Int64("processTask: scheduler process Execute", t.ID()))
err = t.Execute(ctx)
if err != nil {
log.Debug("processTask: execute err", zap.String("reason", err.Error()), zap.Int64("taskID", t.ID()))
trace.LogError(span, err)
return err
}
for _, childTask := range t.GetChildTask() {
if childTask == nil {
log.Error("processTask: child task equal nil")
continue
}
id, err := scheduler.taskIDAllocator()
if err != nil {
return err
}
childTask.SetID(id)
kvs := make(map[string]string)
taskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, childTask.ID())
blobs, err := childTask.Marshal()
if err != nil {
log.Error("processTask: marshal task err", zap.String("reason", err.Error()))
trace.LogError(span, err)
return err
}
kvs[taskKey] = string(blobs)
stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, childTask.ID())
kvs[stateKey] = strconv.Itoa(int(taskUndo))
err = scheduler.client.MultiSave(kvs)
if err != nil {
log.Error("processTask: save active task info err", zap.String("reason", err.Error()))
trace.LogError(span, err)
return err
}
log.Debug("processTask: save active task to etcd", zap.Int64("parent taskID", t.ID()), zap.Int64("child taskID", childTask.ID()))
}
err = scheduler.client.Save(key, strconv.Itoa(int(taskDone)))
err = updateKVFn(t)
if err != nil {
log.Error("processTask: update task state err", zap.String("reason", err.Error()), zap.Int64("taskID", t.ID()))
trace.LogError(span, err)
t.SetResultInfo(err)
return err
}
log.Debug("processTask: update etcd success", zap.Int64("parent taskID", t.ID()))
if t.Type() == commonpb.MsgType_LoadCollection || t.Type() == commonpb.MsgType_LoadPartitions {
t.Notify(nil)
}
span.LogFields(oplog.Int64("processTask: scheduler process PostExecute", t.ID()))
t.PostExecute(ctx)
t.SetState(taskDone)
t.UpdateTaskProcess()
return nil
}
@ -526,140 +513,258 @@ func (scheduler *TaskScheduler) processTask(t task) error {
func (scheduler *TaskScheduler) scheduleLoop() {
defer scheduler.wg.Done()
activeTaskWg := &sync.WaitGroup{}
var triggerTask task
for {
var err error = nil
select {
case <-scheduler.ctx.Done():
return
case <-scheduler.triggerTaskQueue.Chan():
t := scheduler.triggerTaskQueue.PopTask()
log.Debug("scheduleLoop: pop a triggerTask from triggerTaskQueue", zap.Int64("taskID", t.ID()))
if t.State() < taskDone {
err = scheduler.processTask(t)
if err != nil {
log.Error("scheduleLoop: process task error", zap.Any("error", err.Error()))
t.Notify(err)
t.PostExecute(scheduler.ctx)
}
if t.Type() == commonpb.MsgType_LoadCollection || t.Type() == commonpb.MsgType_LoadPartitions {
t.Notify(err)
}
}
log.Debug("scheduleLoop: num of child task", zap.Int("num child task", len(t.GetChildTask())))
for _, childTask := range t.GetChildTask() {
if childTask != nil {
log.Debug("scheduleLoop: add a activate task to activateChan", zap.Int64("taskID", childTask.ID()))
scheduler.activateTaskChan <- childTask
activeTaskWg.Add(1)
go scheduler.waitActivateTaskDone(activeTaskWg, childTask)
}
}
activeTaskWg.Wait()
if t.Type() == commonpb.MsgType_LoadCollection || t.Type() == commonpb.MsgType_LoadPartitions {
t.PostExecute(scheduler.ctx)
processInternalTaskFn := func(activateTasks []task, triggerTask task) {
log.Debug("scheduleLoop: num of child task", zap.Int("num child task", len(activateTasks)))
for _, childTask := range activateTasks {
if childTask != nil {
log.Debug("scheduleLoop: add a activate task to activateChan", zap.Int64("taskID", childTask.ID()))
scheduler.activateTaskChan <- childTask
activeTaskWg.Add(1)
go scheduler.waitActivateTaskDone(activeTaskWg, childTask, triggerTask)
}
}
activeTaskWg.Wait()
}
keys := make([]string, 0)
taskKey := fmt.Sprintf("%s/%d", triggerTaskPrefix, t.ID())
rollBackInterTaskFn := func(triggerTask task, originInternalTasks []task, rollBackTasks []task) error {
saves := make(map[string]string)
removes := make([]string, 0)
childTaskIDs := make([]int64, 0)
for _, t := range originInternalTasks {
childTaskIDs = append(childTaskIDs, t.ID())
taskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, t.ID())
removes = append(removes, taskKey)
stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID())
removes = append(removes, stateKey)
}
for _, t := range rollBackTasks {
id, err := scheduler.taskIDAllocator()
if err != nil {
return err
}
t.SetID(id)
taskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, t.ID())
blobs, err := t.Marshal()
if err != nil {
return err
}
saves[taskKey] = string(blobs)
stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID())
saves[stateKey] = strconv.Itoa(int(taskUndo))
}
err := scheduler.client.MultiSaveAndRemove(saves, removes)
if err != nil {
return err
}
for _, taskID := range childTaskIDs {
triggerTask.RemoveChildTaskByID(taskID)
}
for _, t := range rollBackTasks {
triggerTask.AddChildTask(t)
}
return nil
}
removeTaskFromKVFn := func(triggerTask task) error {
keys := make([]string, 0)
taskKey := fmt.Sprintf("%s/%d", triggerTaskPrefix, triggerTask.ID())
stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, triggerTask.ID())
keys = append(keys, taskKey)
keys = append(keys, stateKey)
childTasks := triggerTask.GetChildTask()
for _, t := range childTasks {
taskKey = fmt.Sprintf("%s/%d", activeTaskPrefix, t.ID())
stateKey = fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID())
keys = append(keys, taskKey)
keys = append(keys, stateKey)
err = scheduler.client.MultiRemove(keys)
if err != nil {
log.Error("scheduleLoop: error when remove trigger task to etcd", zap.Int64("taskID", t.ID()))
t.Notify(err)
continue
}
err := scheduler.client.MultiRemove(keys)
if err != nil {
return err
}
return nil
}
for {
var err error
select {
case <-scheduler.ctx.Done():
scheduler.stopActivateTaskLoopChan <- 1
return
case <-scheduler.triggerTaskQueue.Chan():
triggerTask = scheduler.triggerTaskQueue.PopTask()
log.Debug("scheduleLoop: pop a triggerTask from triggerTaskQueue", zap.Int64("triggerTaskID", triggerTask.ID()))
alreadyNotify := true
if triggerTask.State() == taskUndo || triggerTask.State() == taskDoing {
err = scheduler.processTask(triggerTask)
if err != nil {
log.Debug("scheduleLoop: process triggerTask failed", zap.Int64("triggerTaskID", triggerTask.ID()), zap.Error(err))
alreadyNotify = false
}
}
if triggerTask.Type() != commonpb.MsgType_LoadCollection && triggerTask.Type() != commonpb.MsgType_LoadPartitions {
alreadyNotify = false
}
childTasks := triggerTask.GetChildTask()
if len(childTasks) != 0 {
activateTasks := make([]task, len(childTasks))
copy(activateTasks, childTasks)
processInternalTaskFn(activateTasks, triggerTask)
resultStatus := triggerTask.GetResultInfo()
if resultStatus.ErrorCode != commonpb.ErrorCode_Success {
rollBackTasks := triggerTask.RollBack(scheduler.ctx)
log.Debug("scheduleLoop: start rollBack after triggerTask failed",
zap.Int64("triggerTaskID", triggerTask.ID()),
zap.Any("rollBackTasks", rollBackTasks))
err = rollBackInterTaskFn(triggerTask, childTasks, rollBackTasks)
if err != nil {
log.Error("scheduleLoop: rollBackInternalTask error",
zap.Int64("triggerTaskID", triggerTask.ID()),
zap.Error(err))
triggerTask.SetResultInfo(err)
} else {
processInternalTaskFn(rollBackTasks, triggerTask)
}
}
}
err = removeTaskFromKVFn(triggerTask)
if err != nil {
log.Error("scheduleLoop: error when remove trigger and internal tasks from etcd", zap.Int64("triggerTaskID", triggerTask.ID()), zap.Error(err))
triggerTask.SetResultInfo(err)
} else {
log.Debug("scheduleLoop: trigger task done and delete from etcd", zap.Int64("triggerTaskID", triggerTask.ID()))
}
resultStatus := triggerTask.GetResultInfo()
if resultStatus.ErrorCode != commonpb.ErrorCode_Success {
triggerTask.SetState(taskFailed)
if !alreadyNotify {
triggerTask.Notify(errors.New(resultStatus.Reason))
}
} else {
triggerTask.UpdateTaskProcess()
triggerTask.SetState(taskExpired)
if !alreadyNotify {
triggerTask.Notify(nil)
}
}
log.Debug("scheduleLoop: trigger task done and delete from etcd", zap.Int64("taskID", t.ID()))
t.Notify(err)
}
}
}
func (scheduler *TaskScheduler) waitActivateTaskDone(wg *sync.WaitGroup, t task) {
func (scheduler *TaskScheduler) waitActivateTaskDone(wg *sync.WaitGroup, t task, triggerTask task) {
defer wg.Done()
err := t.WaitToFinish()
var err error
redoFunc1 := func() {
if !t.IsValid() || !t.IsRetryable() {
log.Debug("waitActivateTaskDone: reSchedule the activate task",
zap.Int64("taskID", t.ID()),
zap.Int64("triggerTaskID", triggerTask.ID()))
reScheduledTasks, err := t.Reschedule(scheduler.ctx)
if err != nil {
log.Error("waitActivateTaskDone: reschedule task error",
zap.Int64("taskID", t.ID()),
zap.Int64("triggerTaskID", triggerTask.ID()),
zap.Error(err))
triggerTask.SetResultInfo(err)
return
}
removes := make([]string, 0)
taskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, t.ID())
removes = append(removes, taskKey)
stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID())
removes = append(removes, stateKey)
saves := make(map[string]string)
for _, rt := range reScheduledTasks {
if rt != nil {
id, err := scheduler.taskIDAllocator()
if err != nil {
log.Error("waitActivateTaskDone: allocate id error",
zap.Int64("triggerTaskID", triggerTask.ID()),
zap.Error(err))
triggerTask.SetResultInfo(err)
return
}
rt.SetID(id)
log.Debug("waitActivateTaskDone: reScheduler set id", zap.Int64("id", rt.ID()))
taskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, rt.ID())
blobs, err := rt.Marshal()
if err != nil {
log.Error("waitActivateTaskDone: error when marshal active task",
zap.Int64("triggerTaskID", triggerTask.ID()),
zap.Error(err))
triggerTask.SetResultInfo(err)
return
}
saves[taskKey] = string(blobs)
stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, rt.ID())
saves[stateKey] = strconv.Itoa(int(taskUndo))
}
}
//TODO::queryNode auto watch queryChannel, then update etcd use same id directly
err = scheduler.client.MultiSaveAndRemove(saves, removes)
if err != nil {
log.Error("waitActivateTaskDone: error when save and remove task from etcd", zap.Int64("triggerTaskID", triggerTask.ID()))
triggerTask.SetResultInfo(err)
return
}
triggerTask.RemoveChildTaskByID(t.ID())
log.Debug("waitActivateTaskDone: delete failed active task and save reScheduled task to etcd",
zap.Int64("triggerTaskID", triggerTask.ID()),
zap.Int64("failed taskID", t.ID()),
zap.Any("reScheduled tasks", reScheduledTasks))
for _, rt := range reScheduledTasks {
if rt != nil {
triggerTask.AddChildTask(rt)
log.Debug("waitActivateTaskDone: add a reScheduled active task to activateChan", zap.Int64("taskID", rt.ID()))
scheduler.activateTaskChan <- rt
wg.Add(1)
go scheduler.waitActivateTaskDone(wg, rt, triggerTask)
}
}
//delete task from etcd
} else {
log.Debug("waitActivateTaskDone: retry the active task",
zap.Int64("taskID", t.ID()),
zap.Int64("triggerTaskID", triggerTask.ID()))
scheduler.activateTaskChan <- t
wg.Add(1)
go scheduler.waitActivateTaskDone(wg, t, triggerTask)
}
}
redoFunc2 := func(err error) {
if t.IsValid() {
if !t.IsRetryable() {
log.Error("waitActivateTaskDone: activate task failed after retry",
zap.Int64("taskID", t.ID()),
zap.Int64("triggerTaskID", triggerTask.ID()))
triggerTask.SetResultInfo(err)
return
}
log.Debug("waitActivateTaskDone: retry the active task",
zap.Int64("taskID", t.ID()),
zap.Int64("triggerTaskID", triggerTask.ID()))
scheduler.activateTaskChan <- t
wg.Add(1)
go scheduler.waitActivateTaskDone(wg, t, triggerTask)
}
}
err = t.WaitToFinish()
if err != nil {
log.Debug("waitActivateTaskDone: activate task return err", zap.Any("error", err.Error()), zap.Int64("taskID", t.ID()))
redoFunc1 := func() {
if !t.IsValid() {
reScheduledTasks, err := t.Reschedule()
if err != nil {
log.Error(err.Error())
return
}
removes := make([]string, 0)
taskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, t.ID())
removes = append(removes, taskKey)
stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID())
removes = append(removes, stateKey)
saves := make(map[string]string)
reSchedID := make([]int64, 0)
for _, rt := range reScheduledTasks {
if rt != nil {
id, err := scheduler.taskIDAllocator()
if err != nil {
log.Error(err.Error())
continue
}
rt.SetID(id)
taskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, rt.ID())
blobs, err := rt.Marshal()
if err != nil {
log.Error("waitActivateTaskDone: error when marshal active task")
continue
//TODO::xige-16 deal error when marshal task failed
}
saves[taskKey] = string(blobs)
stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, rt.ID())
saves[stateKey] = strconv.Itoa(int(taskUndo))
reSchedID = append(reSchedID, rt.ID())
}
}
err = scheduler.client.MultiSaveAndRemove(saves, removes)
if err != nil {
log.Error("waitActivateTaskDone: error when save and remove task from etcd")
//TODO::xige-16 deal error when save meta failed
}
log.Debug("waitActivateTaskDone: delete failed active task and save reScheduled task to etcd", zap.Int64("failed taskID", t.ID()), zap.Int64s("reScheduled taskIDs", reSchedID))
for _, rt := range reScheduledTasks {
if rt != nil {
log.Debug("waitActivateTaskDone: add a reScheduled active task to activateChan", zap.Int64("taskID", rt.ID()))
scheduler.activateTaskChan <- rt
wg.Add(1)
go scheduler.waitActivateTaskDone(wg, rt)
}
}
//delete task from etcd
} else {
log.Debug("waitActivateTaskDone: retry the active task", zap.Int64("taskID", t.ID()))
scheduler.activateTaskChan <- t
wg.Add(1)
go scheduler.waitActivateTaskDone(wg, t)
}
}
redoFunc2 := func() {
if t.IsValid() {
log.Debug("waitActivateTaskDone: retry the active task", zap.Int64("taskID", t.ID()))
scheduler.activateTaskChan <- t
wg.Add(1)
go scheduler.waitActivateTaskDone(wg, t)
} else {
removes := make([]string, 0)
taskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, t.ID())
removes = append(removes, taskKey)
stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID())
removes = append(removes, stateKey)
err = scheduler.client.MultiRemove(removes)
if err != nil {
log.Error("waitActivateTaskDone: error when remove task from etcd", zap.Int64("taskID", t.ID()))
}
}
}
log.Debug("waitActivateTaskDone: activate task return err",
zap.Int64("taskID", t.ID()),
zap.Int64("triggerTaskID", triggerTask.ID()),
zap.Error(err))
switch t.Type() {
case commonpb.MsgType_LoadSegments:
@ -667,48 +772,37 @@ func (scheduler *TaskScheduler) waitActivateTaskDone(wg *sync.WaitGroup, t task)
case commonpb.MsgType_WatchDmChannels:
redoFunc1()
case commonpb.MsgType_WatchQueryChannels:
redoFunc2()
redoFunc2(err)
case commonpb.MsgType_ReleaseSegments:
redoFunc2()
redoFunc2(err)
case commonpb.MsgType_ReleaseCollection:
redoFunc2()
redoFunc2(err)
case commonpb.MsgType_ReleasePartitions:
redoFunc2()
redoFunc2(err)
default:
//TODO:: case commonpb.MsgType_RemoveDmChannels:
}
} else {
keys := make([]string, 0)
taskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, t.ID())
stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID())
keys = append(keys, taskKey)
keys = append(keys, stateKey)
err = scheduler.client.MultiRemove(keys)
if err != nil {
log.Error("waitActivateTaskDone: error when remove task from etcd", zap.Int64("taskID", t.ID()))
}
log.Debug("waitActivateTaskDone: delete activate task from etcd", zap.Int64("taskID", t.ID()))
log.Debug("waitActivateTaskDone: one activate task done",
zap.Int64("taskID", t.ID()),
zap.Int64("triggerTaskID", triggerTask.ID()))
}
log.Debug("waitActivateTaskDone: one activate task done", zap.Int64("taskID", t.ID()))
}
func (scheduler *TaskScheduler) processActivateTaskLoop() {
defer scheduler.wg.Done()
for {
select {
case <-scheduler.ctx.Done():
case <-scheduler.stopActivateTaskLoopChan:
log.Debug("processActivateTaskLoop, ctx done")
return
case t := <-scheduler.activateTaskChan:
if t == nil {
log.Error("processActivateTaskLoop: pop a nil active task", zap.Int64("taskID", t.ID()))
continue
}
stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID())
err := scheduler.client.Save(stateKey, strconv.Itoa(int(taskDoing)))
if err != nil {
t.Notify(err)
continue
}
log.Debug("processActivateTaskLoop: pop a active task from activateChan", zap.Int64("taskID", t.ID()))
go func() {
err := scheduler.processTask(t)

View File

@ -49,6 +49,7 @@ func (tt *testTask) Timestamp() Timestamp {
}
func (tt *testTask) PreExecute(ctx context.Context) error {
tt.SetResultInfo(nil)
log.Debug("test task preExecute...")
return nil
}
@ -59,7 +60,7 @@ func (tt *testTask) Execute(ctx context.Context) error {
switch tt.baseMsg.MsgType {
case commonpb.MsgType_LoadSegments:
childTask := &LoadSegmentTask{
BaseTask: BaseTask{
BaseTask: &BaseTask{
ctx: tt.ctx,
Condition: NewTaskCondition(tt.ctx),
triggerCondition: tt.triggerCondition,
@ -70,13 +71,14 @@ func (tt *testTask) Execute(ctx context.Context) error {
},
NodeID: tt.nodeID,
},
meta: tt.meta,
cluster: tt.cluster,
meta: tt.meta,
cluster: tt.cluster,
excludeNodeIDs: []int64{},
}
tt.AddChildTask(childTask)
case commonpb.MsgType_WatchDmChannels:
childTask := &WatchDmChannelTask{
BaseTask: BaseTask{
BaseTask: &BaseTask{
ctx: tt.ctx,
Condition: NewTaskCondition(tt.ctx),
triggerCondition: tt.triggerCondition,
@ -87,13 +89,14 @@ func (tt *testTask) Execute(ctx context.Context) error {
},
NodeID: tt.nodeID,
},
cluster: tt.cluster,
meta: tt.meta,
cluster: tt.cluster,
meta: tt.meta,
excludeNodeIDs: []int64{},
}
tt.AddChildTask(childTask)
case commonpb.MsgType_WatchQueryChannels:
childTask := &WatchQueryChannelTask{
BaseTask: BaseTask{
BaseTask: &BaseTask{
ctx: tt.ctx,
Condition: NewTaskCondition(tt.ctx),
triggerCondition: tt.triggerCondition,
@ -129,12 +132,7 @@ func TestWatchQueryChannel_ClearEtcdInfoAfterAssignedNodeDown(t *testing.T) {
queryNode.addQueryChannels = returnFailedResult
nodeID := queryNode.queryNodeID
for {
_, err = queryCoord.cluster.getNodeByID(nodeID)
if err == nil {
break
}
}
waitQueryNodeOnline(queryCoord.cluster, nodeID)
testTask := &testTask{
BaseTask: BaseTask{
ctx: baseCtx,
@ -148,7 +146,7 @@ func TestWatchQueryChannel_ClearEtcdInfoAfterAssignedNodeDown(t *testing.T) {
meta: queryCoord.meta,
nodeID: nodeID,
}
queryCoord.scheduler.Enqueue([]task{testTask})
queryCoord.scheduler.Enqueue(testTask)
queryNode.stop()
err = removeNodeSession(queryNode.queryNodeID)
@ -169,7 +167,11 @@ func TestUnMarshalTask(t *testing.T) {
refreshParams()
kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath)
assert.Nil(t, err)
taskScheduler := &TaskScheduler{}
baseCtx, cancel := context.WithCancel(context.Background())
taskScheduler := &TaskScheduler{
ctx: baseCtx,
cancel: cancel,
}
t.Run("Test LoadCollectionTask", func(t *testing.T) {
loadTask := &LoadCollectionTask{
@ -187,7 +189,7 @@ func TestUnMarshalTask(t *testing.T) {
value, err := kv.Load("testMarshalLoadCollection")
assert.Nil(t, err)
task, err := taskScheduler.unmarshalTask(value)
task, err := taskScheduler.unmarshalTask(1000, value)
assert.Nil(t, err)
assert.Equal(t, task.Type(), commonpb.MsgType_LoadCollection)
})
@ -208,7 +210,7 @@ func TestUnMarshalTask(t *testing.T) {
value, err := kv.Load("testMarshalLoadPartition")
assert.Nil(t, err)
task, err := taskScheduler.unmarshalTask(value)
task, err := taskScheduler.unmarshalTask(1001, value)
assert.Nil(t, err)
assert.Equal(t, task.Type(), commonpb.MsgType_LoadPartitions)
})
@ -229,7 +231,7 @@ func TestUnMarshalTask(t *testing.T) {
value, err := kv.Load("testMarshalReleaseCollection")
assert.Nil(t, err)
task, err := taskScheduler.unmarshalTask(value)
task, err := taskScheduler.unmarshalTask(1002, value)
assert.Nil(t, err)
assert.Equal(t, task.Type(), commonpb.MsgType_ReleaseCollection)
})
@ -250,7 +252,7 @@ func TestUnMarshalTask(t *testing.T) {
value, err := kv.Load("testMarshalReleasePartition")
assert.Nil(t, err)
task, err := taskScheduler.unmarshalTask(value)
task, err := taskScheduler.unmarshalTask(1003, value)
assert.Nil(t, err)
assert.Equal(t, task.Type(), commonpb.MsgType_ReleasePartitions)
})
@ -271,7 +273,7 @@ func TestUnMarshalTask(t *testing.T) {
value, err := kv.Load("testMarshalLoadSegment")
assert.Nil(t, err)
task, err := taskScheduler.unmarshalTask(value)
task, err := taskScheduler.unmarshalTask(1004, value)
assert.Nil(t, err)
assert.Equal(t, task.Type(), commonpb.MsgType_LoadSegments)
})
@ -292,7 +294,7 @@ func TestUnMarshalTask(t *testing.T) {
value, err := kv.Load("testMarshalReleaseSegment")
assert.Nil(t, err)
task, err := taskScheduler.unmarshalTask(value)
task, err := taskScheduler.unmarshalTask(1005, value)
assert.Nil(t, err)
assert.Equal(t, task.Type(), commonpb.MsgType_ReleaseSegments)
})
@ -313,7 +315,7 @@ func TestUnMarshalTask(t *testing.T) {
value, err := kv.Load("testMarshalWatchDmChannel")
assert.Nil(t, err)
task, err := taskScheduler.unmarshalTask(value)
task, err := taskScheduler.unmarshalTask(1006, value)
assert.Nil(t, err)
assert.Equal(t, task.Type(), commonpb.MsgType_WatchDmChannels)
})
@ -334,7 +336,7 @@ func TestUnMarshalTask(t *testing.T) {
value, err := kv.Load("testMarshalWatchQueryChannel")
assert.Nil(t, err)
task, err := taskScheduler.unmarshalTask(value)
task, err := taskScheduler.unmarshalTask(1007, value)
assert.Nil(t, err)
assert.Equal(t, task.Type(), commonpb.MsgType_WatchQueryChannels)
})
@ -356,17 +358,22 @@ func TestUnMarshalTask(t *testing.T) {
value, err := kv.Load("testMarshalLoadBalanceTask")
assert.Nil(t, err)
task, err := taskScheduler.unmarshalTask(value)
task, err := taskScheduler.unmarshalTask(1008, value)
assert.Nil(t, err)
assert.Equal(t, task.Type(), commonpb.MsgType_LoadBalanceSegments)
})
taskScheduler.Close()
}
func TestReloadTaskFromKV(t *testing.T) {
refreshParams()
kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath)
assert.Nil(t, err)
baseCtx, cancel := context.WithCancel(context.Background())
taskScheduler := &TaskScheduler{
ctx: baseCtx,
cancel: cancel,
client: kv,
triggerTaskQueue: NewTaskQueue(),
}

View File

@ -17,9 +17,215 @@ import (
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
)
// genLoadCollectionTask builds a LoadCollectionTask for the default test
// collection, wired to the given QueryCoord's coordinator clients, cluster
// and meta.
func genLoadCollectionTask(ctx context.Context, queryCoord *QueryCoord) *LoadCollectionTask {
	base := newBaseTask(ctx, querypb.TriggerCondition_grpcRequest)
	loadReq := &querypb.LoadCollectionRequest{
		Base:         &commonpb.MsgBase{MsgType: commonpb.MsgType_LoadCollection},
		CollectionID: defaultCollectionID,
		Schema:       genCollectionSchema(defaultCollectionID, false),
	}
	return &LoadCollectionTask{
		BaseTask:              base,
		LoadCollectionRequest: loadReq,
		rootCoord:             queryCoord.rootCoordClient,
		dataCoord:             queryCoord.dataCoordClient,
		cluster:               queryCoord.cluster,
		meta:                  queryCoord.meta,
	}
}
// genLoadPartitionTask builds a LoadPartitionTask for the default test
// partition of the default collection.
func genLoadPartitionTask(ctx context.Context, queryCoord *QueryCoord) *LoadPartitionTask {
	base := newBaseTask(ctx, querypb.TriggerCondition_grpcRequest)
	loadReq := &querypb.LoadPartitionsRequest{
		Base:         &commonpb.MsgBase{MsgType: commonpb.MsgType_LoadPartitions},
		CollectionID: defaultCollectionID,
		PartitionIDs: []UniqueID{defaultPartitionID},
	}
	return &LoadPartitionTask{
		BaseTask:              base,
		LoadPartitionsRequest: loadReq,
		dataCoord:             queryCoord.dataCoordClient,
		cluster:               queryCoord.cluster,
		meta:                  queryCoord.meta,
	}
}
// genReleaseCollectionTask builds a ReleaseCollectionTask for the default
// test collection, wired to the given QueryCoord's rootCoord client, cluster
// and meta.
func genReleaseCollectionTask(ctx context.Context, queryCoord *QueryCoord) *ReleaseCollectionTask {
	base := newBaseTask(ctx, querypb.TriggerCondition_grpcRequest)
	releaseReq := &querypb.ReleaseCollectionRequest{
		Base:         &commonpb.MsgBase{MsgType: commonpb.MsgType_ReleaseCollection},
		CollectionID: defaultCollectionID,
	}
	return &ReleaseCollectionTask{
		BaseTask:                 base,
		ReleaseCollectionRequest: releaseReq,
		rootCoord:                queryCoord.rootCoordClient,
		cluster:                  queryCoord.cluster,
		meta:                     queryCoord.meta,
	}
}
// genReleasePartitionTask builds a ReleasePartitionTask for the default test
// partition of the default collection. Unlike the collection variant it only
// needs the cluster handle.
func genReleasePartitionTask(ctx context.Context, queryCoord *QueryCoord) *ReleasePartitionTask {
	base := newBaseTask(ctx, querypb.TriggerCondition_grpcRequest)
	releaseReq := &querypb.ReleasePartitionsRequest{
		Base:         &commonpb.MsgBase{MsgType: commonpb.MsgType_ReleasePartitions},
		CollectionID: defaultCollectionID,
		PartitionIDs: []UniqueID{defaultPartitionID},
	}
	return &ReleasePartitionTask{
		BaseTask:                 base,
		ReleasePartitionsRequest: releaseReq,
		cluster:                  queryCoord.cluster,
	}
}
// genReleaseSegmentTask builds a ReleaseSegmentTask that releases the default
// segment of the default partition on the given query node.
func genReleaseSegmentTask(ctx context.Context, queryCoord *QueryCoord, nodeID int64) *ReleaseSegmentTask {
	base := newBaseTask(ctx, querypb.TriggerCondition_grpcRequest)
	releaseReq := &querypb.ReleaseSegmentsRequest{
		Base:         &commonpb.MsgBase{MsgType: commonpb.MsgType_ReleaseSegments},
		NodeID:       nodeID,
		CollectionID: defaultCollectionID,
		PartitionIDs: []UniqueID{defaultPartitionID},
		SegmentIDs:   []UniqueID{defaultSegmentID},
	}
	return &ReleaseSegmentTask{
		BaseTask:               base,
		ReleaseSegmentsRequest: releaseReq,
		cluster:                queryCoord.cluster,
	}
}
// genWatchDmChannelTask builds a WatchDmChannelTask (taskID 100) targeting the
// given query node, attached as the child of an already-finished
// LoadCollectionTask (taskID 10) so that the parent/child layout mirrors what
// the scheduler would have produced. It also registers the default collection
// in meta as a side effect.
func genWatchDmChannelTask(ctx context.Context, queryCoord *QueryCoord, nodeID int64) *WatchDmChannelTask {
	schema := genCollectionSchema(defaultCollectionID, false)
	vChannelInfo := &datapb.VchannelInfo{
		CollectionID: defaultCollectionID,
		ChannelName:  "testDmChannel",
	}
	req := &querypb.WatchDmChannelsRequest{
		Base: &commonpb.MsgBase{
			MsgType: commonpb.MsgType_WatchDmChannels,
		},
		NodeID:       nodeID,
		CollectionID: defaultCollectionID,
		PartitionID:  defaultPartitionID,
		Schema:       schema,
		Infos:        []*datapb.VchannelInfo{vChannelInfo},
	}
	baseTask := newBaseTask(ctx, querypb.TriggerCondition_grpcRequest)
	// fixed IDs for reproducible tests: child task 100, parent trigger task 10
	baseTask.taskID = 100
	watchDmChannelTask := &WatchDmChannelTask{
		BaseTask:               baseTask,
		WatchDmChannelsRequest: req,
		cluster:                queryCoord.cluster,
		meta:                   queryCoord.meta,
		excludeNodeIDs:         []int64{},
	}
	parentReq := &querypb.LoadCollectionRequest{
		Base: &commonpb.MsgBase{
			MsgType: commonpb.MsgType_LoadCollection,
		},
		CollectionID: defaultCollectionID,
		Schema:       genCollectionSchema(defaultCollectionID, false),
	}
	baseParentTask := newBaseTask(ctx, querypb.TriggerCondition_grpcRequest)
	baseParentTask.taskID = 10
	parentTask := &LoadCollectionTask{
		BaseTask:              baseParentTask,
		LoadCollectionRequest: parentReq,
		rootCoord:             queryCoord.rootCoordClient,
		dataCoord:             queryCoord.dataCoordClient,
		meta:                  queryCoord.meta,
		cluster:               queryCoord.cluster,
	}
	// mark the parent as done/successful so only the child remains pending
	parentTask.SetState(taskDone)
	parentTask.SetResultInfo(nil)
	parentTask.AddChildTask(watchDmChannelTask)
	watchDmChannelTask.SetParentTask(parentTask)
	queryCoord.meta.addCollection(defaultCollectionID, schema)
	return watchDmChannelTask
}
// genLoadSegmentTask builds a LoadSegmentTask (taskID 100) that loads the
// default segment onto the given query node, attached as the child of an
// already-finished LoadCollectionTask (taskID 10) to mirror the scheduler's
// parent/child layout. It also registers the default collection in meta.
func genLoadSegmentTask(ctx context.Context, queryCoord *QueryCoord, nodeID int64) *LoadSegmentTask {
	schema := genCollectionSchema(defaultCollectionID, false)
	segmentInfo := &querypb.SegmentLoadInfo{
		SegmentID:    defaultSegmentID,
		PartitionID:  defaultPartitionID,
		CollectionID: defaultCollectionID,
	}
	req := &querypb.LoadSegmentsRequest{
		Base: &commonpb.MsgBase{
			MsgType: commonpb.MsgType_LoadSegments,
		},
		NodeID: nodeID,
		Schema: schema,
		Infos:  []*querypb.SegmentLoadInfo{segmentInfo},
	}
	baseTask := newBaseTask(ctx, querypb.TriggerCondition_grpcRequest)
	// fixed IDs for reproducible tests: child task 100, parent trigger task 10
	baseTask.taskID = 100
	loadSegmentTask := &LoadSegmentTask{
		BaseTask:            baseTask,
		LoadSegmentsRequest: req,
		cluster:             queryCoord.cluster,
		meta:                queryCoord.meta,
		excludeNodeIDs:      []int64{},
	}
	parentReq := &querypb.LoadCollectionRequest{
		Base: &commonpb.MsgBase{
			MsgType: commonpb.MsgType_LoadCollection,
		},
		CollectionID: defaultCollectionID,
		Schema:       genCollectionSchema(defaultCollectionID, false),
	}
	baseParentTask := newBaseTask(ctx, querypb.TriggerCondition_grpcRequest)
	baseParentTask.taskID = 10
	parentTask := &LoadCollectionTask{
		BaseTask:              baseParentTask,
		LoadCollectionRequest: parentReq,
		rootCoord:             queryCoord.rootCoordClient,
		dataCoord:             queryCoord.dataCoordClient,
		meta:                  queryCoord.meta,
		cluster:               queryCoord.cluster,
	}
	// mark the parent as done/successful so only the child remains pending
	parentTask.SetState(taskDone)
	parentTask.SetResultInfo(nil)
	parentTask.AddChildTask(loadSegmentTask)
	loadSegmentTask.SetParentTask(parentTask)
	queryCoord.meta.addCollection(defaultCollectionID, schema)
	return loadSegmentTask
}
// waitTaskFinalState blocks until t reports the expected state.
// NOTE(review): this is a busy-wait with no sleep or yield between polls, so
// it pins a CPU core while waiting — consider adding a short time.Sleep
// between checks.
func waitTaskFinalState(t task, state taskState) {
	for t.State() != state {
	}
}
func TestTriggerTask(t *testing.T) {
refreshParams()
ctx := context.Background()
@ -28,98 +234,32 @@ func TestTriggerTask(t *testing.T) {
node, err := startQueryNodeServer(ctx)
assert.Nil(t, err)
waitQueryNodeOnline(queryCoord.cluster, node.queryNodeID)
t.Run("Test LoadCollection", func(t *testing.T) {
req := &querypb.LoadCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadCollection,
},
CollectionID: defaultCollectionID,
Schema: genCollectionSchema(defaultCollectionID, false),
}
loadCollectionTask := &LoadCollectionTask{
BaseTask: BaseTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
triggerCondition: querypb.TriggerCondition_grpcRequest,
},
LoadCollectionRequest: req,
rootCoord: queryCoord.rootCoordClient,
dataCoord: queryCoord.dataCoordClient,
cluster: queryCoord.cluster,
meta: queryCoord.meta,
}
loadCollectionTask := genLoadCollectionTask(ctx, queryCoord)
err = queryCoord.scheduler.processTask(loadCollectionTask)
assert.Nil(t, err)
})
t.Run("Test ReleaseCollection", func(t *testing.T) {
req := &querypb.ReleaseCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ReleaseCollection,
},
CollectionID: defaultCollectionID,
}
loadCollectionTask := &ReleaseCollectionTask{
BaseTask: BaseTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
triggerCondition: querypb.TriggerCondition_grpcRequest,
},
ReleaseCollectionRequest: req,
rootCoord: queryCoord.rootCoordClient,
cluster: queryCoord.cluster,
meta: queryCoord.meta,
}
err = queryCoord.scheduler.processTask(loadCollectionTask)
releaseCollectionTask := genReleaseCollectionTask(ctx, queryCoord)
err = queryCoord.scheduler.processTask(releaseCollectionTask)
assert.Nil(t, err)
})
t.Run("Test LoadPartition", func(t *testing.T) {
req := &querypb.LoadPartitionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadPartitions,
},
CollectionID: defaultCollectionID,
PartitionIDs: []UniqueID{defaultPartitionID},
}
loadCollectionTask := &LoadPartitionTask{
BaseTask: BaseTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
triggerCondition: querypb.TriggerCondition_grpcRequest,
},
LoadPartitionsRequest: req,
dataCoord: queryCoord.dataCoordClient,
cluster: queryCoord.cluster,
meta: queryCoord.meta,
}
loadPartitionTask := genLoadPartitionTask(ctx, queryCoord)
err = queryCoord.scheduler.processTask(loadCollectionTask)
err = queryCoord.scheduler.processTask(loadPartitionTask)
assert.Nil(t, err)
})
t.Run("Test ReleasePartition", func(t *testing.T) {
req := &querypb.ReleasePartitionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ReleasePartitions,
},
CollectionID: defaultCollectionID,
PartitionIDs: []UniqueID{defaultPartitionID},
}
loadCollectionTask := &ReleasePartitionTask{
BaseTask: BaseTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
triggerCondition: querypb.TriggerCondition_grpcRequest,
},
ReleasePartitionsRequest: req,
cluster: queryCoord.cluster,
}
releasePartitionTask := genReleaseCollectionTask(ctx, queryCoord)
err = queryCoord.scheduler.processTask(loadCollectionTask)
err = queryCoord.scheduler.processTask(releasePartitionTask)
assert.Nil(t, err)
})
@ -128,3 +268,388 @@ func TestTriggerTask(t *testing.T) {
err = removeAllSession()
assert.Nil(t, err)
}
// Test_LoadCollectionAfterLoadPartition enqueues a load-partition task, then a
// load-collection task on the same collection, then a release-collection task,
// and expects all three to be accepted and the final release to succeed.
func Test_LoadCollectionAfterLoadPartition(t *testing.T) {
	refreshParams()
	ctx := context.Background()
	queryCoord, err := startQueryCoord(ctx)
	assert.Nil(t, err)
	node, err := startQueryNodeServer(ctx)
	assert.Nil(t, err)
	// the coordinator must see the node online before tasks can be assigned
	waitQueryNodeOnline(queryCoord.cluster, node.queryNodeID)
	loadPartitionTask := genLoadPartitionTask(ctx, queryCoord)
	err = queryCoord.scheduler.Enqueue(loadPartitionTask)
	assert.Nil(t, err)
	loadCollectionTask := genLoadCollectionTask(ctx, queryCoord)
	err = queryCoord.scheduler.Enqueue(loadCollectionTask)
	assert.Nil(t, err)
	releaseCollectionTask := genReleaseCollectionTask(ctx, queryCoord)
	err = queryCoord.scheduler.Enqueue(releaseCollectionTask)
	assert.Nil(t, err)
	// waiting on the last task also serializes behind the two load tasks,
	// since the scheduler processes trigger tasks one at a time
	err = releaseCollectionTask.WaitToFinish()
	assert.Nil(t, err)
	node.stop()
	queryCoord.Stop()
	err = removeAllSession()
	assert.Nil(t, err)
}
// Test_RepeatLoadCollection loads the same collection twice (with a new
// partition created between the two loads) and then releases it; all
// enqueues are expected to be accepted and the release to succeed.
func Test_RepeatLoadCollection(t *testing.T) {
	refreshParams()
	ctx := context.Background()
	queryCoord, err := startQueryCoord(ctx)
	assert.Nil(t, err)
	node, err := startQueryNodeServer(ctx)
	assert.Nil(t, err)
	// the coordinator must see the node online before tasks can be assigned
	waitQueryNodeOnline(queryCoord.cluster, node.queryNodeID)
	loadCollectionTask1 := genLoadCollectionTask(ctx, queryCoord)
	err = queryCoord.scheduler.Enqueue(loadCollectionTask1)
	assert.Nil(t, err)
	// add a partition so the repeated load sees a changed collection
	createDefaultPartition(ctx, queryCoord)
	loadCollectionTask2 := genLoadCollectionTask(ctx, queryCoord)
	err = queryCoord.scheduler.Enqueue(loadCollectionTask2)
	assert.Nil(t, err)
	releaseCollectionTask := genReleaseCollectionTask(ctx, queryCoord)
	err = queryCoord.scheduler.Enqueue(releaseCollectionTask)
	assert.Nil(t, err)
	err = releaseCollectionTask.WaitToFinish()
	assert.Nil(t, err)
	node.stop()
	queryCoord.Stop()
	err = removeAllSession()
	assert.Nil(t, err)
}
// Test_LoadCollectionAssignTaskFail enqueues a load-collection task with no
// query node running, so task assignment cannot succeed and WaitToFinish is
// expected to return an error.
func Test_LoadCollectionAssignTaskFail(t *testing.T) {
	refreshParams()
	ctx := context.Background()
	queryCoord, err := startQueryCoord(ctx)
	assert.Nil(t, err)
	// note: deliberately no startQueryNodeServer here
	loadCollectionTask := genLoadCollectionTask(ctx, queryCoord)
	err = queryCoord.scheduler.Enqueue(loadCollectionTask)
	assert.Nil(t, err)
	err = loadCollectionTask.WaitToFinish()
	assert.NotNil(t, err)
	queryCoord.Stop()
	err = removeAllSession()
	assert.Nil(t, err)
}
// Test_LoadCollectionExecuteFail makes the query node's loadSegment RPC fail,
// then enqueues a load-collection task and expects it to end in taskFailed
// (after the scheduler's retry/reschedule handling gives up).
func Test_LoadCollectionExecuteFail(t *testing.T) {
	refreshParams()
	ctx := context.Background()
	queryCoord, err := startQueryCoord(ctx)
	assert.Nil(t, err)
	node, err := startQueryNodeServer(ctx)
	assert.Nil(t, err)
	// force every loadSegment call on the node to return a failure status
	node.loadSegment = returnFailedResult
	waitQueryNodeOnline(queryCoord.cluster, node.queryNodeID)
	loadCollectionTask := genLoadCollectionTask(ctx, queryCoord)
	err = queryCoord.scheduler.Enqueue(loadCollectionTask)
	assert.Nil(t, err)
	waitTaskFinalState(loadCollectionTask, taskFailed)
	node.stop()
	queryCoord.Stop()
	err = removeAllSession()
	assert.Nil(t, err)
}
// Test_LoadPartitionAssignTaskFail enqueues a load-partition task with no
// query node running, so task assignment cannot succeed and WaitToFinish is
// expected to return an error.
func Test_LoadPartitionAssignTaskFail(t *testing.T) {
	refreshParams()
	ctx := context.Background()
	queryCoord, err := startQueryCoord(ctx)
	assert.Nil(t, err)
	// note: deliberately no startQueryNodeServer here
	loadPartitionTask := genLoadPartitionTask(ctx, queryCoord)
	err = queryCoord.scheduler.Enqueue(loadPartitionTask)
	assert.Nil(t, err)
	err = loadPartitionTask.WaitToFinish()
	assert.NotNil(t, err)
	queryCoord.Stop()
	err = removeAllSession()
	assert.Nil(t, err)
}
// Test_LoadPartitionExecuteFail verifies that a load-partition trigger task
// reaches taskFailed when the only query node keeps failing loadSegment RPCs.
func Test_LoadPartitionExecuteFail(t *testing.T) {
	refreshParams()
	ctx := context.Background()

	coord, err := startQueryCoord(ctx)
	assert.Nil(t, err)
	queryNode, err := startQueryNodeServer(ctx)
	assert.Nil(t, err)
	// Force every loadSegment call on this node to report failure.
	queryNode.loadSegment = returnFailedResult
	waitQueryNodeOnline(coord.cluster, queryNode.queryNodeID)

	task := genLoadPartitionTask(ctx, coord)
	err = coord.scheduler.Enqueue(task)
	assert.Nil(t, err)
	waitTaskFinalState(task, taskFailed)

	queryNode.stop()
	coord.Stop()
	err = removeAllSession()
	assert.Nil(t, err)
}
// Test_LoadPartitionExecuteFailAfterLoadCollection first loads a collection
// successfully, then makes watchDmChannels fail and expects the subsequent
// load-partition task to end in taskFailed.
func Test_LoadPartitionExecuteFailAfterLoadCollection(t *testing.T) {
	refreshParams()
	ctx := context.Background()

	coord, err := startQueryCoord(ctx)
	assert.Nil(t, err)
	queryNode, err := startQueryNodeServer(ctx)
	assert.Nil(t, err)
	waitQueryNodeOnline(coord.cluster, queryNode.queryNodeID)

	// Load the whole collection first; this should expire normally.
	collectionTask := genLoadCollectionTask(ctx, coord)
	err = coord.scheduler.Enqueue(collectionTask)
	assert.Nil(t, err)
	waitTaskFinalState(collectionTask, taskExpired)

	createDefaultPartition(ctx, coord)
	// Now break watchDmChannels so the partition load cannot complete.
	queryNode.watchDmChannels = returnFailedResult

	partitionTask := genLoadPartitionTask(ctx, coord)
	err = coord.scheduler.Enqueue(partitionTask)
	assert.Nil(t, err)
	waitTaskFinalState(partitionTask, taskFailed)

	queryNode.stop()
	coord.Stop()
	err = removeAllSession()
	assert.Nil(t, err)
}
// Test_ReleaseCollectionExecuteFail verifies that a release-collection task
// ends in taskFailed when the query node's releaseCollection RPC always fails.
func Test_ReleaseCollectionExecuteFail(t *testing.T) {
	refreshParams()
	ctx := context.Background()

	coord, err := startQueryCoord(ctx)
	assert.Nil(t, err)
	queryNode, err := startQueryNodeServer(ctx)
	assert.Nil(t, err)
	// Every releaseCollection call on this node returns a failed status.
	queryNode.releaseCollection = returnFailedResult
	waitQueryNodeOnline(coord.cluster, queryNode.queryNodeID)

	task := genReleaseCollectionTask(ctx, coord)
	err = coord.scheduler.Enqueue(task)
	assert.Nil(t, err)
	waitTaskFinalState(task, taskFailed)

	queryNode.stop()
	coord.Stop()
	err = removeAllSession()
	assert.Nil(t, err)
}
// Test_LoadSegmentReschedule runs two query nodes, one of which always fails
// loadSegment; the load-collection task should still expire successfully
// because the failing segment loads get rescheduled onto the healthy node.
func Test_LoadSegmentReschedule(t *testing.T) {
	refreshParams()
	ctx := context.Background()

	coord, err := startQueryCoord(ctx)
	assert.Nil(t, err)

	failingNode, err := startQueryNodeServer(ctx)
	assert.Nil(t, err)
	// This node rejects every loadSegment request.
	failingNode.loadSegment = returnFailedResult
	healthyNode, err := startQueryNodeServer(ctx)
	assert.Nil(t, err)
	waitQueryNodeOnline(coord.cluster, failingNode.queryNodeID)
	waitQueryNodeOnline(coord.cluster, healthyNode.queryNodeID)

	task := genLoadCollectionTask(ctx, coord)
	err = coord.scheduler.Enqueue(task)
	assert.Nil(t, err)
	waitTaskFinalState(task, taskExpired)

	failingNode.stop()
	healthyNode.stop()
	coord.Stop()
	err = removeAllSession()
	assert.Nil(t, err)
}
// Test_WatchDmChannelReschedule runs two query nodes, one of which always
// fails watchDmChannels; the load-collection task should still expire because
// the failing channel watches get rescheduled onto the healthy node.
func Test_WatchDmChannelReschedule(t *testing.T) {
	refreshParams()
	ctx := context.Background()

	coord, err := startQueryCoord(ctx)
	assert.Nil(t, err)

	failingNode, err := startQueryNodeServer(ctx)
	assert.Nil(t, err)
	// This node rejects every watchDmChannels request.
	failingNode.watchDmChannels = returnFailedResult
	healthyNode, err := startQueryNodeServer(ctx)
	assert.Nil(t, err)
	waitQueryNodeOnline(coord.cluster, failingNode.queryNodeID)
	waitQueryNodeOnline(coord.cluster, healthyNode.queryNodeID)

	task := genLoadCollectionTask(ctx, coord)
	err = coord.scheduler.Enqueue(task)
	assert.Nil(t, err)
	waitTaskFinalState(task, taskExpired)

	failingNode.stop()
	healthyNode.stop()
	coord.Stop()
	err = removeAllSession()
	assert.Nil(t, err)
}
// Test_ReleaseSegmentTask feeds a release-segment internal task directly into
// the scheduler's activation channel (bypassing the trigger queue) and expects
// it to reach taskDone.
func Test_ReleaseSegmentTask(t *testing.T) {
	refreshParams()
	ctx := context.Background()
	queryCoord, err := startQueryCoord(ctx)
	assert.Nil(t, err)

	node1, err := startQueryNodeServer(ctx)
	assert.Nil(t, err)
	waitQueryNodeOnline(queryCoord.cluster, node1.queryNodeID)

	releaseSegmentTask := genReleaseSegmentTask(ctx, queryCoord, node1.queryNodeID)
	// Push straight onto the activate channel: this test exercises only the
	// internal-task activation path, not trigger-task scheduling.
	queryCoord.scheduler.activateTaskChan <- releaseSegmentTask
	waitTaskFinalState(releaseSegmentTask, taskDone)

	// Stop the query node server before tearing down the coordinator; the
	// original test leaked it, unlike the sibling tests in this file.
	node1.stop()
	queryCoord.Stop()
	err = removeAllSession()
	assert.Nil(t, err)
}
// Test_RescheduleDmChannelWithWatchQueryChannel makes watchDmChannels fail on
// node1 and adds the parent load-collection trigger task directly to the
// trigger queue; rescheduling onto node2 should let the task expire normally.
func Test_RescheduleDmChannelWithWatchQueryChannel(t *testing.T) {
	refreshParams()
	ctx := context.Background()
	queryCoord, err := startQueryCoord(ctx)
	assert.Nil(t, err)

	node1, err := startQueryNodeServer(ctx)
	assert.Nil(t, err)
	node2, err := startQueryNodeServer(ctx)
	assert.Nil(t, err)
	waitQueryNodeOnline(queryCoord.cluster, node1.queryNodeID)
	waitQueryNodeOnline(queryCoord.cluster, node2.queryNodeID)

	// node1 rejects every watchDmChannels request, forcing a reschedule.
	node1.watchDmChannels = returnFailedResult

	watchDmChannelTask := genWatchDmChannelTask(ctx, queryCoord, node1.queryNodeID)
	loadCollectionTask := watchDmChannelTask.parentTask
	queryCoord.scheduler.triggerTaskQueue.addTask(loadCollectionTask)
	waitTaskFinalState(loadCollectionTask, taskExpired)

	// Stop both query node servers before tearing down the coordinator; the
	// original test leaked them, unlike the sibling tests in this file.
	node1.stop()
	node2.stop()
	queryCoord.Stop()
	err = removeAllSession()
	assert.Nil(t, err)
}
// Test_RescheduleSegmentWithWatchQueryChannel makes loadSegment fail on node1
// and adds the parent load-collection trigger task directly to the trigger
// queue; rescheduling onto node2 should let the task expire normally.
func Test_RescheduleSegmentWithWatchQueryChannel(t *testing.T) {
	refreshParams()
	ctx := context.Background()
	queryCoord, err := startQueryCoord(ctx)
	assert.Nil(t, err)

	node1, err := startQueryNodeServer(ctx)
	assert.Nil(t, err)
	node2, err := startQueryNodeServer(ctx)
	assert.Nil(t, err)
	waitQueryNodeOnline(queryCoord.cluster, node1.queryNodeID)
	waitQueryNodeOnline(queryCoord.cluster, node2.queryNodeID)

	// node1 rejects every loadSegment request, forcing a reschedule.
	node1.loadSegment = returnFailedResult

	loadSegmentTask := genLoadSegmentTask(ctx, queryCoord, node1.queryNodeID)
	loadCollectionTask := loadSegmentTask.parentTask
	queryCoord.scheduler.triggerTaskQueue.addTask(loadCollectionTask)
	waitTaskFinalState(loadCollectionTask, taskExpired)

	// Stop both query node servers before tearing down the coordinator; the
	// original test leaked them, unlike the sibling tests in this file.
	node1.stop()
	node2.stop()
	queryCoord.Stop()
	err = removeAllSession()
	assert.Nil(t, err)
}
// Test_RescheduleSegmentEndWithFail makes loadSegment fail on BOTH nodes, so
// rescheduling has no healthy target and the parent load-collection task must
// end in taskFailed.
func Test_RescheduleSegmentEndWithFail(t *testing.T) {
	refreshParams()
	ctx := context.Background()
	queryCoord, err := startQueryCoord(ctx)
	assert.Nil(t, err)

	node1, err := startQueryNodeServer(ctx)
	assert.Nil(t, err)
	node1.loadSegment = returnFailedResult
	node2, err := startQueryNodeServer(ctx)
	assert.Nil(t, err)
	node2.loadSegment = returnFailedResult
	waitQueryNodeOnline(queryCoord.cluster, node1.queryNodeID)
	waitQueryNodeOnline(queryCoord.cluster, node2.queryNodeID)

	loadSegmentTask := genLoadSegmentTask(ctx, queryCoord, node1.queryNodeID)
	loadCollectionTask := loadSegmentTask.parentTask
	queryCoord.scheduler.triggerTaskQueue.addTask(loadCollectionTask)
	// No node can load the segment, so the trigger task must fail.
	waitTaskFinalState(loadCollectionTask, taskFailed)

	// Stop both query node servers before tearing down the coordinator; the
	// original test leaked them, unlike the sibling tests in this file.
	node1.stop()
	node2.stop()
	queryCoord.Stop()
	err = removeAllSession()
	assert.Nil(t, err)
}
// Test_RescheduleDmChannelsEndWithFail makes watchDmChannels fail on BOTH
// nodes, so rescheduling has no healthy target and the parent load-collection
// task must end in taskFailed.
func Test_RescheduleDmChannelsEndWithFail(t *testing.T) {
	refreshParams()
	ctx := context.Background()
	queryCoord, err := startQueryCoord(ctx)
	assert.Nil(t, err)

	node1, err := startQueryNodeServer(ctx)
	assert.Nil(t, err)
	node1.watchDmChannels = returnFailedResult
	node2, err := startQueryNodeServer(ctx)
	assert.Nil(t, err)
	node2.watchDmChannels = returnFailedResult
	waitQueryNodeOnline(queryCoord.cluster, node1.queryNodeID)
	waitQueryNodeOnline(queryCoord.cluster, node2.queryNodeID)

	watchDmChannelTask := genWatchDmChannelTask(ctx, queryCoord, node1.queryNodeID)
	loadCollectionTask := watchDmChannelTask.parentTask
	queryCoord.scheduler.triggerTaskQueue.addTask(loadCollectionTask)
	// No node can watch the channels, so the trigger task must fail.
	waitTaskFinalState(loadCollectionTask, taskFailed)

	// Stop both query node servers before tearing down the coordinator; the
	// original test leaked them, unlike the sibling tests in this file.
	node1.stop()
	node2.stop()
	queryCoord.Stop()
	err = removeAllSession()
	assert.Nil(t, err)
}