milvus/internal/querycoord/task_scheduler_test.go
wayblink 46e0e2658b
Move datacoord related methods from meta to globalMetaBroker (#17812)
Signed-off-by: wayblink <anyang.wang@zilliz.com>
2022-06-27 21:04:17 +08:00

617 lines
18 KiB
Go

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querycoord
import (
"context"
"fmt"
"strconv"
"testing"
"time"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/stretchr/testify/assert"
)
// testTask is a stub trigger task used by the scheduler tests in this file.
// Its execute method fabricates child tasks according to baseMsg.MsgType.
type testTask struct {
baseTask
// baseMsg is the message header; msgBase, msgType and timestamp read from it.
baseMsg *commonpb.MsgBase
// cluster is handed unchanged to every child task built in execute.
cluster Cluster
// meta is handed unchanged to every child task built in execute.
meta Meta
// nodeID is set as the target node on the child load-segment and
// watch-dm-channel tasks built in execute.
nodeID int64
// binlogSize is the length of the random binlog path generated for the
// LoadCollection case in execute.
binlogSize int
}
// msgBase returns the message header this test task was constructed with.
func (tt *testTask) msgBase() *commonpb.MsgBase {
	header := tt.baseMsg
	return header
}
// marshal returns an empty, non-nil byte slice and a nil error; the test task
// has no request body to serialize.
func (tt *testTask) marshal() ([]byte, error) {
	payload := make([]byte, 0)
	return payload, nil
}
// msgType returns the MsgType stored in the task's message header.
// It dereferences baseMsg, so the header must have been set.
func (tt *testTask) msgType() commonpb.MsgType {
	header := tt.baseMsg
	return header.MsgType
}
// timestamp returns the Timestamp stored in the task's message header.
// It dereferences baseMsg, so the header must have been set.
func (tt *testTask) timestamp() Timestamp {
	header := tt.baseMsg
	return header.Timestamp
}
// preExecute calls setResultInfo(nil), emits a debug log line and returns nil.
// The ctx argument is not used.
func (tt *testTask) preExecute(ctx context.Context) error {
	var noErr error

	// Hand a nil result to the embedded base task before the task runs.
	tt.setResultInfo(noErr)
	log.Debug("test task preExecute...")

	return nil
}
// execute fabricates child tasks according to the message type in baseMsg:
//
//   - LoadCollection: ten loadSegmentTask children (segment IDs 0 through 9),
//     all sharing one field binlog whose log path is a random string of
//     binlogSize characters; each child gets tt as its parent task.
//   - LoadSegments: one loadSegmentTask child whose DstNodeID is tt.nodeID.
//   - WatchDmChannels: one watchDmChannelTask child whose NodeID is tt.nodeID.
//
// Any other message type adds no children. The ctx argument is not used and
// the method always returns nil.
func (tt *testTask) execute(ctx context.Context) error {
	log.Debug("test task execute...")

	// newChildBase builds the base state for one child task. Every call
	// creates a fresh task condition bound to the parent's context.
	newChildBase := func() *baseTask {
		return &baseTask{
			ctx:              tt.ctx,
			condition:        newTaskCondition(tt.ctx),
			triggerCondition: tt.triggerCondition,
		}
	}

	kind := tt.baseMsg.MsgType
	switch kind {
	case commonpb.MsgType_LoadCollection:
		// A single field binlog shared by all generated segments.
		fieldBinlogs := []*datapb.FieldBinlog{
			{
				FieldID: 0,
				Binlogs: []*datapb.Binlog{{LogPath: funcutil.RandomString(tt.binlogSize)}},
			},
		}

		const segmentCount = 10
		for segID := 0; segID < segmentCount; segID++ {
			info := &querypb.SegmentLoadInfo{
				SegmentID:    UniqueID(segID),
				PartitionID:  defaultPartitionID,
				CollectionID: defaultCollectionID,
				BinlogPaths:  fieldBinlogs,
			}
			info.SegmentSize = estimateSegmentSize(info)

			request := &querypb.LoadSegmentsRequest{
				Base:         &commonpb.MsgBase{MsgType: commonpb.MsgType_LoadSegments},
				Infos:        []*querypb.SegmentLoadInfo{info},
				CollectionID: defaultCollectionID,
			}
			child := &loadSegmentTask{
				baseTask:            newChildBase(),
				LoadSegmentsRequest: request,
				meta:                tt.meta,
				cluster:             tt.cluster,
				excludeNodeIDs:      []int64{},
			}
			child.setParentTask(tt)
			tt.addChildTask(child)
		}

	case commonpb.MsgType_LoadSegments:
		request := &querypb.LoadSegmentsRequest{
			Base:         &commonpb.MsgBase{MsgType: commonpb.MsgType_LoadSegments},
			DstNodeID:    tt.nodeID,
			CollectionID: defaultCollectionID,
		}
		tt.addChildTask(&loadSegmentTask{
			baseTask:            newChildBase(),
			LoadSegmentsRequest: request,
			meta:                tt.meta,
			cluster:             tt.cluster,
			excludeNodeIDs:      []int64{},
		})

	case commonpb.MsgType_WatchDmChannels:
		request := &querypb.WatchDmChannelsRequest{
			Base:   &commonpb.MsgBase{MsgType: commonpb.MsgType_WatchDmChannels},
			NodeID: tt.nodeID,
		}
		tt.addChildTask(&watchDmChannelTask{
			baseTask:               newChildBase(),
			WatchDmChannelsRequest: request,
			cluster:                tt.cluster,
			meta:                   tt.meta,
			excludeNodeIDs:         []int64{},
		})
	}

	return nil
}
// postExecute emits a debug log line and returns nil; it performs no other
// work. The ctx argument is not used.
func (tt *testTask) postExecute(ctx context.Context) error {
	var result error
	log.Debug("test task postExecute...")
	return result
}
// TestWatchQueryChannel_ClearEtcdInfoAfterAssignedNodeDown records how many
// keys exist under activeTaskPrefix, enqueues a WatchQueryChannels test task
// bound to a freshly started query node, stops that node and removes its
// session, and then expects the number of keys under activeTaskPrefix to
// return to the recorded baseline.
func TestWatchQueryChannel_ClearEtcdInfoAfterAssignedNodeDown(t *testing.T) {
	refreshParams()
	baseCtx := context.Background()

	queryCoord, err := startQueryCoord(baseCtx)
	assert.Nil(t, err)

	// Baseline taken before the task under test is enqueued.
	activeTaskIDKeys, _, err := queryCoord.scheduler.client.LoadWithPrefix(activeTaskPrefix)
	assert.Nil(t, err)

	queryNode, err := startQueryNodeServer(baseCtx)
	assert.Nil(t, err)

	nodeID := queryNode.queryNodeID
	waitQueryNodeOnline(queryCoord.cluster, nodeID)

	triggerTask := &testTask{
		baseTask: baseTask{
			ctx:              baseCtx,
			condition:        newTaskCondition(baseCtx),
			triggerCondition: querypb.TriggerCondition_GrpcRequest,
		},
		baseMsg: &commonpb.MsgBase{
			MsgType: commonpb.MsgType_WatchQueryChannels,
		},
		cluster: queryCoord.cluster,
		meta:    queryCoord.meta,
		nodeID:  nodeID,
	}
	queryCoord.scheduler.Enqueue(triggerTask)

	queryNode.stop()
	err = removeNodeSession(queryNode.queryNodeID)
	assert.Nil(t, err)

	// Poll with a pause between attempts and an upper bound on the total wait.
	// A tight, unbounded loop would hammer etcd and hang the whole test run
	// if the active-task keys were never cleaned up. A transient load error is
	// treated as "not yet" and retried until the deadline.
	assert.Eventually(t, func() bool {
		newActiveTaskIDKeys, _, loadErr := queryCoord.scheduler.client.LoadWithPrefix(activeTaskPrefix)
		return loadErr == nil && len(newActiveTaskIDKeys) == len(activeTaskIDKeys)
	}, 30*time.Second, 100*time.Millisecond, "active task keys were not cleared after the assigned node went down")

	queryCoord.Stop()
	err = removeAllSession()
	assert.Nil(t, err)
}
// TestUnMarshalTask checks, for every task kind listed below, that a task
// marshalled and stored in etcd can be read back and turned into a task of
// the same message type by TaskScheduler.unmarshalTask. The watchDmChannelTask
// case additionally expects unmarshalTask to fail when the data coord mock is
// switched to return errors.
func TestUnMarshalTask(t *testing.T) {
	refreshParams()
	etcdCli, err := etcd.GetEtcdClient(&Params.EtcdCfg)
	assert.Nil(t, err)
	defer etcdCli.Close()
	kv := etcdkv.NewEtcdKV(etcdCli, Params.EtcdCfg.MetaRootPath)

	baseCtx, cancel := context.WithCancel(context.Background())
	dataCoord := &dataCoordMock{}
	broker, err := newGlobalMetaBroker(baseCtx, nil, dataCoord, nil, nil)
	assert.Nil(t, err)

	taskScheduler := &TaskScheduler{
		ctx:    baseCtx,
		cancel: cancel,
		broker: broker,
	}
	// Deferred after etcdCli.Close, so the scheduler is still closed first.
	defer taskScheduler.Close()

	// persist marshals a task through the given marshal function, saves the
	// blob under key and returns the value read back from etcd. Removing the
	// key afterwards is left to the calling subtest.
	persist := func(t *testing.T, key string, marshal func() ([]byte, error)) string {
		blobs, err := marshal()
		assert.Nil(t, err)
		err = kv.Save(key, string(blobs))
		assert.Nil(t, err)
		value, err := kv.Load(key)
		assert.Nil(t, err)
		return value
	}

	t.Run("Test loadCollectionTask", func(t *testing.T) {
		loadTask := &loadCollectionTask{
			LoadCollectionRequest: &querypb.LoadCollectionRequest{
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_LoadCollection,
				},
				ReplicaNumber: 1,
			},
		}
		const key = "testMarshalLoadCollection"
		defer kv.RemoveWithPrefix(key)
		value := persist(t, key, loadTask.marshal)

		task, err := taskScheduler.unmarshalTask(1000, value)
		// Only inspect the task when unmarshalling succeeded; a nil task would panic.
		if assert.Nil(t, err) {
			assert.Equal(t, commonpb.MsgType_LoadCollection, task.msgType())
		}
	})

	t.Run("Test LoadPartitionsTask", func(t *testing.T) {
		loadTask := &loadPartitionTask{
			LoadPartitionsRequest: &querypb.LoadPartitionsRequest{
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_LoadPartitions,
				},
			},
		}
		const key = "testMarshalLoadPartition"
		defer kv.RemoveWithPrefix(key)
		value := persist(t, key, loadTask.marshal)

		task, err := taskScheduler.unmarshalTask(1001, value)
		if assert.Nil(t, err) {
			assert.Equal(t, commonpb.MsgType_LoadPartitions, task.msgType())
		}
	})

	t.Run("Test releaseCollectionTask", func(t *testing.T) {
		releaseTask := &releaseCollectionTask{
			ReleaseCollectionRequest: &querypb.ReleaseCollectionRequest{
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_ReleaseCollection,
				},
			},
		}
		const key = "testMarshalReleaseCollection"
		defer kv.RemoveWithPrefix(key)
		value := persist(t, key, releaseTask.marshal)

		task, err := taskScheduler.unmarshalTask(1002, value)
		if assert.Nil(t, err) {
			assert.Equal(t, commonpb.MsgType_ReleaseCollection, task.msgType())
		}
	})

	t.Run("Test releasePartitionTask", func(t *testing.T) {
		releaseTask := &releasePartitionTask{
			ReleasePartitionsRequest: &querypb.ReleasePartitionsRequest{
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_ReleasePartitions,
				},
			},
		}
		const key = "testMarshalReleasePartition"
		defer kv.RemoveWithPrefix(key)
		value := persist(t, key, releaseTask.marshal)

		task, err := taskScheduler.unmarshalTask(1003, value)
		if assert.Nil(t, err) {
			assert.Equal(t, commonpb.MsgType_ReleasePartitions, task.msgType())
		}
	})

	t.Run("Test loadSegmentTask", func(t *testing.T) {
		loadTask := &loadSegmentTask{
			LoadSegmentsRequest: &querypb.LoadSegmentsRequest{
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_LoadSegments,
				},
			},
		}
		const key = "testMarshalLoadSegment"
		defer kv.RemoveWithPrefix(key)
		value := persist(t, key, loadTask.marshal)

		task, err := taskScheduler.unmarshalTask(1004, value)
		if assert.Nil(t, err) {
			assert.Equal(t, commonpb.MsgType_LoadSegments, task.msgType())
		}
	})

	t.Run("Test releaseSegmentTask", func(t *testing.T) {
		releaseTask := &releaseSegmentTask{
			ReleaseSegmentsRequest: &querypb.ReleaseSegmentsRequest{
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_ReleaseSegments,
				},
			},
		}
		const key = "testMarshalReleaseSegment"
		defer kv.RemoveWithPrefix(key)
		value := persist(t, key, releaseTask.marshal)

		task, err := taskScheduler.unmarshalTask(1005, value)
		if assert.Nil(t, err) {
			assert.Equal(t, commonpb.MsgType_ReleaseSegments, task.msgType())
		}
	})

	t.Run("Test watchDmChannelTask", func(t *testing.T) {
		watchTask := &watchDmChannelTask{
			WatchDmChannelsRequest: &querypb.WatchDmChannelsRequest{
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_WatchDmChannels,
				},
			},
		}
		const key = "testMarshalWatchDmChannel"
		defer kv.RemoveWithPrefix(key)
		value := persist(t, key, watchTask.marshal)

		task, err := taskScheduler.unmarshalTask(1006, value)
		if assert.Nil(t, err) {
			assert.Equal(t, commonpb.MsgType_WatchDmChannels, task.msgType())
		}

		// With the data coord mock returning errors, unmarshalling the same
		// value is expected to fail and yield no task.
		dataCoord.returnError = true
		defer func() {
			dataCoord.returnError = false
		}()
		task2, err := taskScheduler.unmarshalTask(1006, value)
		assert.Error(t, err)
		assert.Nil(t, task2)
	})

	t.Run("Test watchDeltaChannelTask", func(t *testing.T) {
		watchTask := &watchDeltaChannelTask{
			WatchDeltaChannelsRequest: &querypb.WatchDeltaChannelsRequest{
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_WatchDeltaChannels,
				},
			},
		}
		const key = "testMarshalWatchDeltaChannel"
		defer kv.RemoveWithPrefix(key)
		value := persist(t, key, watchTask.marshal)

		task, err := taskScheduler.unmarshalTask(1007, value)
		if assert.Nil(t, err) {
			assert.Equal(t, commonpb.MsgType_WatchDeltaChannels, task.msgType())
		}
	})

	t.Run("Test loadBalanceTask", func(t *testing.T) {
		balanceTask := &loadBalanceTask{
			LoadBalanceRequest: &querypb.LoadBalanceRequest{
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_LoadBalanceSegments,
				},
			},
		}
		const key = "testMarshalLoadBalanceTask"
		defer kv.RemoveWithPrefix(key)
		value := persist(t, key, balanceTask.marshal)

		task, err := taskScheduler.unmarshalTask(1009, value)
		if assert.Nil(t, err) {
			assert.Equal(t, commonpb.MsgType_LoadBalanceSegments, task.msgType())
		}
	})

	t.Run("Test handoffTask", func(t *testing.T) {
		handoff := &handoffTask{
			HandoffSegmentsRequest: &querypb.HandoffSegmentsRequest{
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_HandoffSegments,
				},
			},
		}
		const key = "testMarshalHandoffTask"
		defer kv.RemoveWithPrefix(key)
		value := persist(t, key, handoff.marshal)

		task, err := taskScheduler.unmarshalTask(1010, value)
		if assert.Nil(t, err) {
			assert.Equal(t, commonpb.MsgType_HandoffSegments, task.msgType())
		}
	})
}
// TestReloadTaskFromKV seeds etcd with a trigger task (ID 100), an active
// task (ID 101) and a taskDone state entry for ID 100, runs reloadFromKV, and
// expects the trigger task popped from the queue to carry the taskDone state
// and exactly one child task.
func TestReloadTaskFromKV(t *testing.T) {
	refreshParams()
	etcdCli, err := etcd.GetEtcdClient(&Params.EtcdCfg)
	assert.Nil(t, err)
	defer etcdCli.Close()
	kv := etcdkv.NewEtcdKV(etcdCli, Params.EtcdCfg.MetaRootPath)

	baseCtx, cancel := context.WithCancel(context.Background())
	// Release the scheduler context when the test ends.
	defer cancel()
	taskScheduler := &TaskScheduler{
		ctx:              baseCtx,
		cancel:           cancel,
		client:           kv,
		triggerTaskQueue: newTaskQueue(),
	}

	kvs := make(map[string]string)

	triggerTask := &loadCollectionTask{
		LoadCollectionRequest: &querypb.LoadCollectionRequest{
			Base: &commonpb.MsgBase{
				Timestamp: 1,
				MsgType:   commonpb.MsgType_LoadCollection,
			},
			ReplicaNumber: 1,
		},
	}
	triggerBlobs, err := triggerTask.marshal()
	assert.Nil(t, err)
	triggerTaskKey := fmt.Sprintf("%s/%d", triggerTaskPrefix, 100)
	kvs[triggerTaskKey] = string(triggerBlobs)

	activeTask := &loadSegmentTask{
		LoadSegmentsRequest: &querypb.LoadSegmentsRequest{
			Base: &commonpb.MsgBase{
				Timestamp: 2,
				MsgType:   commonpb.MsgType_LoadSegments,
			},
		},
	}
	activeBlobs, err := activeTask.marshal()
	assert.Nil(t, err)
	activeTaskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, 101)
	kvs[activeTaskKey] = string(activeBlobs)

	// State entry for the trigger task (same ID as triggerTaskKey).
	stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, 100)
	kvs[stateKey] = strconv.Itoa(int(taskDone))

	err = kv.MultiSave(kvs)
	assert.Nil(t, err)

	taskScheduler.reloadFromKV()

	reloaded := taskScheduler.triggerTaskQueue.popTask()
	// Report a failed reload as a test failure instead of a nil dereference.
	if assert.NotNil(t, reloaded) {
		assert.Equal(t, taskDone, reloaded.getState())
		assert.Equal(t, 1, len(reloaded.getChildTask()))
	}
}
// Test_saveInternalTaskToEtcd runs processTask on a LoadCollection test task
// twice: first with a binlog path size of 3000000, where an error is expected,
// and then with the child tasks reset and a size of 500000, where success is
// expected.
func Test_saveInternalTaskToEtcd(t *testing.T) {
	refreshParams()
	ctx := context.Background()

	queryCoord, err := startQueryCoord(ctx)
	assert.Nil(t, err)
	defer queryCoord.Stop()

	trigger := &testTask{
		baseTask: baseTask{
			ctx:              ctx,
			condition:        newTaskCondition(ctx),
			triggerCondition: querypb.TriggerCondition_GrpcRequest,
			taskID:           100,
		},
		baseMsg: &commonpb.MsgBase{MsgType: commonpb.MsgType_LoadCollection},
		cluster: queryCoord.cluster,
		meta:    queryCoord.meta,
		nodeID:  defaultQueryNodeID,
	}
	sched := queryCoord.scheduler

	t.Run("Test SaveEtcdFail", func(t *testing.T) {
		// max send size limit of etcd is 2097152
		trigger.binlogSize = 3000000
		processErr := sched.processTask(trigger)
		assert.NotNil(t, processErr)
	})

	t.Run("Test SaveEtcdSuccess", func(t *testing.T) {
		// Drop the children produced by the previous run before retrying
		// with a smaller binlog path.
		trigger.childTasks = []task{}
		trigger.binlogSize = 500000
		processErr := sched.processTask(trigger)
		assert.Nil(t, processErr)
	})
}
// Test_generateDerivedInternalTasks builds a load-collection task with one
// load-segment child and one watch-dm-channel child, both bound to a running
// query node, and expects generateDerivedInternalTasks to return exactly one
// derived task: a watchDeltaChannelTask whose NodeID is that query node.
func Test_generateDerivedInternalTasks(t *testing.T) {
	refreshParams()
	baseCtx := context.Background()

	queryCoord, err := startQueryCoord(baseCtx)
	assert.Nil(t, err)

	node1, err := startQueryNodeServer(baseCtx)
	assert.Nil(t, err)
	waitQueryNodeOnline(queryCoord.cluster, node1.queryNodeID)

	// Register a delta channel for every vchannel reported by the broker.
	vChannelInfos, _, err := queryCoord.broker.getRecoveryInfo(baseCtx, defaultCollectionID, defaultPartitionID)
	assert.NoError(t, err)
	deltaChannelInfos := make([]*datapb.VchannelInfo, 0, len(vChannelInfos))
	for _, vChannel := range vChannelInfos {
		deltaInfo, convErr := generateWatchDeltaChannelInfo(vChannel)
		assert.NoError(t, convErr)
		deltaChannelInfos = append(deltaChannelInfos, deltaInfo)
	}
	queryCoord.meta.setDeltaChannel(defaultCollectionID, deltaChannelInfos)

	// Assemble the parent task and link both children to it.
	parent := genLoadCollectionTask(baseCtx, queryCoord)

	segmentChild := genLoadSegmentTask(baseCtx, queryCoord, node1.queryNodeID)
	parent.addChildTask(segmentChild)
	segmentChild.setParentTask(parent)

	channelChild := genWatchDmChannelTask(baseCtx, queryCoord, node1.queryNodeID)
	parent.addChildTask(channelChild)
	channelChild.setParentTask(parent)

	derivedTasks, err := generateDerivedInternalTasks(parent, queryCoord.meta, queryCoord.cluster)
	assert.Nil(t, err)
	assert.Equal(t, 1, len(derivedTasks))

	for _, derived := range derivedTasks {
		isDeltaWatch := derived.msgType() == commonpb.MsgType_WatchDeltaChannels
		assert.Equal(t, true, isDeltaWatch)
		if !isDeltaWatch {
			// The type assertion below is only valid for delta-channel tasks.
			continue
		}
		assert.Equal(t, node1.queryNodeID, derived.(*watchDeltaChannelTask).NodeID)
	}

	queryCoord.Stop()
	err = removeAllSession()
	assert.Nil(t, err)
}
// TestTaskScheduler_BindContext checks that the context returned by
// TaskScheduler.BindContext reports context.Canceled in three situations:
// its own cancel function is called, the input context is cancelled, or the
// scheduler's cancel function is called. The scheduler-cancel case runs last
// because it cancels the scheduler shared by all three subtests.
func TestTaskScheduler_BindContext(t *testing.T) {
	const (
		waitFor = time.Second
		tick    = 10 * time.Millisecond
	)

	schedCtx, schedCancel := context.WithCancel(context.Background())
	s := &TaskScheduler{
		ctx:    schedCtx,
		cancel: schedCancel,
	}

	// canceled builds the polling predicate for a given context.
	canceled := func(c context.Context) func() bool {
		return func() bool { return c.Err() == context.Canceled }
	}

	t.Run("normal finish", func(t *testing.T) {
		input, inputCancel := context.WithCancel(context.Background())
		defer inputCancel()

		bound, boundCancel := s.BindContext(input)
		boundCancel() // normal finish

		assert.Eventually(t, canceled(bound), waitFor, tick)
	})

	t.Run("input context canceled", func(t *testing.T) {
		input, inputCancel := context.WithCancel(context.Background())
		bound, boundCancel := s.BindContext(input)
		defer boundCancel()

		inputCancel() // input context cancel

		assert.Eventually(t, canceled(bound), waitFor, tick)
	})

	t.Run("scheduler context cancel", func(t *testing.T) {
		input, inputCancel := context.WithCancel(context.Background())
		defer inputCancel()
		bound, boundCancel := s.BindContext(input)
		defer boundCancel()

		s.cancel() // scheduler cancel

		assert.Eventually(t, canceled(bound), waitFor, tick)
	})
}