Check load request both in impl and task (#18276)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
This commit is contained in:
bigsheeper 2022-07-15 10:24:28 +08:00 committed by GitHub
parent 73aef14820
commit d278213e4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 285 additions and 116 deletions

View File

@ -203,6 +203,42 @@ func (qc *QueryCoord) ShowCollections(ctx context.Context, req *querypb.ShowColl
}, nil
}
// handleLoadError converts an error produced while checking or executing a
// load request into the status returned to the RPC caller, and records the
// outcome in the QueryCoordLoadCount metric.
//
//   - An error matching ErrCollectionLoaded is reported as success.
//   - An error matching ErrLoadParametersMismatch is reported as
//     ErrorCode_IllegalArgument with the error text as the reason.
//   - Any other error is logged and reported as ErrorCode_UnexpectedError.
//
// loadType, msgID, collectionID and partitionIDs are used only as log fields.
// The returned error is always nil; the failure is carried by the status.
// err is expected to be non-nil: the fall-through path calls err.Error().
func handleLoadError(err error, loadType querypb.LoadType, msgID, collectionID UniqueID, partitionIDs []UniqueID) (*commonpb.Status, error) {
	switch {
	case errors.Is(err, ErrCollectionLoaded):
		// Repeating an already satisfied load is not a failure.
		log.Info("collection or partitions has already been loaded, return load success directly",
			zap.String("loadType", loadType.String()),
			zap.String("role", typeutil.QueryCoordRole),
			zap.Int64("collectionID", collectionID),
			zap.Int64s("partitionIDs", partitionIDs),
			zap.Int64("msgID", msgID))
		metrics.QueryCoordLoadCount.WithLabelValues(metrics.SuccessLabel).Inc()
		return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil

	case errors.Is(err, ErrLoadParametersMismatch):
		rejected := &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_IllegalArgument,
			Reason:    err.Error(),
		}
		metrics.QueryCoordLoadCount.WithLabelValues(metrics.FailLabel).Inc()
		return rejected, nil
	}

	// Everything else is an unexpected failure and is the only case logged
	// at error level.
	log.Error("load collection or partitions to query nodes failed",
		zap.String("loadType", loadType.String()),
		zap.String("role", typeutil.QueryCoordRole),
		zap.Int64("collectionID", collectionID),
		zap.Int64s("partitionIDs", partitionIDs),
		zap.Int64("msgID", msgID),
		zap.Error(err))
	failed := &commonpb.Status{
		ErrorCode: commonpb.ErrorCode_UnexpectedError,
		Reason:    err.Error(),
	}
	metrics.QueryCoordLoadCount.WithLabelValues(metrics.FailLabel).Inc()
	return failed, nil
}
// LoadCollection loads all the sealed segments of this collection to queryNodes, and assigns watchDmChannelRequest to queryNodes
func (qc *QueryCoord) LoadCollection(ctx context.Context, req *querypb.LoadCollectionRequest) (*commonpb.Status, error) {
metrics.QueryCoordLoadCount.WithLabelValues(metrics.TotalLabel).Inc()
@ -235,6 +271,19 @@ func (qc *QueryCoord) LoadCollection(ctx context.Context, req *querypb.LoadColle
cluster: qc.cluster,
meta: qc.meta,
}
LastTaskType := qc.scheduler.triggerTaskQueue.willLoadOrRelease(req.GetCollectionID())
if LastTaskType == commonpb.MsgType_LoadCollection {
// collection will be loaded, remove idempotent loadCollection task, return success directly
return status, nil
}
if LastTaskType != commonpb.MsgType_ReleaseCollection {
err := checkLoadCollection(req, qc.meta)
if err != nil {
return handleLoadError(err, querypb.LoadType_LoadCollection, req.GetBase().GetMsgID(), req.GetCollectionID(), nil)
}
}
err := qc.scheduler.Enqueue(loadCollectionTask)
if err != nil {
log.Error("loadCollectionRequest failed to add execute task to scheduler",
@ -251,32 +300,7 @@ func (qc *QueryCoord) LoadCollection(ctx context.Context, req *querypb.LoadColle
err = loadCollectionTask.waitToFinish()
if err != nil {
if errors.Is(err, ErrCollectionLoaded) {
log.Info("collection has already been loaded, return load success directly",
zap.String("role", typeutil.QueryCoordRole),
zap.Int64("collectionID", collectionID),
zap.Int64("msgID", req.Base.MsgID))
metrics.QueryCoordLoadCount.WithLabelValues(metrics.SuccessLabel).Inc()
return status, nil
} else if errors.Is(err, ErrLoadParametersMismatch) {
status.ErrorCode = commonpb.ErrorCode_IllegalArgument
status.Reason = err.Error()
metrics.QueryCoordLoadCount.WithLabelValues(metrics.FailLabel).Inc()
return status, nil
} else {
log.Error("load collection to query nodes failed",
zap.String("role", typeutil.QueryCoordRole),
zap.Int64("collectionID", collectionID),
zap.Int64("msgID", req.Base.MsgID),
zap.Error(err))
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error()
metrics.QueryCoordLoadCount.WithLabelValues(metrics.FailLabel).Inc()
return status, nil
}
return handleLoadError(err, querypb.LoadType_LoadCollection, req.GetBase().GetMsgID(), req.GetCollectionID(), nil)
}
log.Info("loadCollectionRequest completed",
@ -508,6 +532,19 @@ func (qc *QueryCoord) LoadPartitions(ctx context.Context, req *querypb.LoadParti
cluster: qc.cluster,
meta: qc.meta,
}
LastTaskType := qc.scheduler.triggerTaskQueue.willLoadOrRelease(req.GetCollectionID())
if LastTaskType == commonpb.MsgType_LoadPartitions {
// partitions will be loaded, remove idempotent loadPartition task, return success directly
return status, nil
}
if LastTaskType != commonpb.MsgType_ReleasePartitions {
err := checkLoadPartition(req, qc.meta)
if err != nil {
return handleLoadError(err, querypb.LoadType_LoadPartition, req.GetBase().GetMsgID(), req.GetCollectionID(), req.GetPartitionIDs())
}
}
err := qc.scheduler.Enqueue(loadPartitionTask)
if err != nil {
log.Error("loadPartitionRequest failed to add execute task to scheduler",
@ -525,34 +562,7 @@ func (qc *QueryCoord) LoadPartitions(ctx context.Context, req *querypb.LoadParti
err = loadPartitionTask.waitToFinish()
if err != nil {
if errors.Is(err, ErrCollectionLoaded) {
log.Info("loadPartitionRequest completed, all partitions to load have already been loaded into memory",
zap.String("role", typeutil.QueryCoordRole),
zap.Int64("collectionID", req.CollectionID),
zap.Int64s("partitionIDs", req.PartitionIDs),
zap.Int64("msgID", req.Base.MsgID))
metrics.QueryCoordLoadCount.WithLabelValues(metrics.SuccessLabel).Inc()
return status, nil
} else if errors.Is(err, ErrLoadParametersMismatch) {
status.ErrorCode = commonpb.ErrorCode_IllegalArgument
status.Reason = err.Error()
metrics.QueryCoordLoadCount.WithLabelValues(metrics.FailLabel).Inc()
return status, nil
} else {
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error()
log.Error("loadPartitionRequest failed",
zap.String("role", typeutil.QueryCoordRole),
zap.Int64("collectionID", req.CollectionID),
zap.Int64s("partitionIDs", partitionIDs),
zap.Int64("msgID", req.Base.MsgID),
zap.Error(err))
metrics.QueryCoordLoadCount.WithLabelValues(metrics.FailLabel).Inc()
return status, nil
}
return handleLoadError(err, querypb.LoadType_LoadPartition, req.GetBase().GetMsgID(), req.GetCollectionID(), req.GetPartitionIDs())
}
log.Info("loadPartitionRequest completed",

View File

@ -345,37 +345,22 @@ func (lct *loadCollectionTask) updateTaskProcess() {
//this function shall just calculate intermediate progress
}
func (lct *loadCollectionTask) preExecute(ctx context.Context) error {
if lct.ReplicaNumber < 1 {
log.Warn("replicaNumber is less than 1 for load collection request, will set it to 1",
zap.Int32("replicaNumber", lct.ReplicaNumber))
lct.ReplicaNumber = 1
}
collectionID := lct.CollectionID
schema := lct.Schema
lct.setResultInfo(nil)
collectionInfo, err := lct.meta.getCollectionInfoByID(collectionID)
func checkLoadCollection(req *querypb.LoadCollectionRequest, meta Meta) error {
collectionID := req.CollectionID
collectionInfo, err := meta.getCollectionInfoByID(collectionID)
if err == nil {
// if collection has been loaded by load collection request, return success
if collectionInfo.LoadType == querypb.LoadType_LoadCollection {
if collectionInfo.ReplicaNumber != lct.ReplicaNumber {
if collectionInfo.ReplicaNumber != req.ReplicaNumber {
msg := fmt.Sprintf("collection has already been loaded, and the number of replicas %v is not same as the request's %v. Should release first then reload with the new number of replicas",
collectionInfo.ReplicaNumber,
lct.ReplicaNumber)
req.ReplicaNumber)
log.Warn(msg,
zap.String("role", typeutil.QueryCoordRole),
zap.Int64("collectionID", collectionID),
zap.Int64("msgID", lct.Base.MsgID),
zap.Int64("msgID", req.Base.MsgID),
zap.Int32("collectionReplicaNumber", collectionInfo.ReplicaNumber),
zap.Int32("requestReplicaNumber", lct.ReplicaNumber))
lct.result = &commonpb.Status{
ErrorCode: commonpb.ErrorCode_IllegalArgument,
Reason: msg,
}
zap.Int32("requestReplicaNumber", req.ReplicaNumber))
return fmt.Errorf(msg+" [%w]", ErrLoadParametersMismatch)
}
@ -390,15 +375,35 @@ func (lct *loadCollectionTask) preExecute(ctx context.Context) error {
zap.String("role", typeutil.QueryCoordRole),
zap.Int64("collectionID", collectionID),
zap.Int64s("loaded partitionIDs", collectionInfo.PartitionIDs),
zap.Int64("msgID", lct.Base.MsgID),
zap.Int64("msgID", req.Base.MsgID),
zap.Error(err))
lct.setResultInfo(err)
metrics.QueryCoordLoadCount.WithLabelValues(metrics.FailLabel).Inc()
return fmt.Errorf(err.Error()+" [%w]", ErrLoadParametersMismatch)
}
}
return nil
}
func (lct *loadCollectionTask) preExecute(ctx context.Context) error {
if lct.ReplicaNumber < 1 {
log.Warn("replicaNumber is less than 1 for load collection request, will set it to 1",
zap.Int32("replicaNumber", lct.ReplicaNumber))
lct.ReplicaNumber = 1
}
collectionID := lct.CollectionID
schema := lct.Schema
lct.setResultInfo(nil)
err := checkLoadCollection(lct.LoadCollectionRequest, lct.meta)
if err != nil {
lct.setResultInfo(err)
return err
}
log.Info("start do loadCollectionTask",
zap.Int64("taskID", lct.getTaskID()),
zap.Int64("msgID", lct.GetBase().GetMsgID()),
@ -837,44 +842,31 @@ func (lpt *loadPartitionTask) updateTaskProcess() {
//this function shall just calculate intermediate progress
}
func (lpt *loadPartitionTask) preExecute(context.Context) error {
if lpt.ReplicaNumber < 1 {
log.Warn("replicaNumber is less than 1 for load partitions request, will set it to 1",
zap.Int32("replicaNumber", lpt.ReplicaNumber))
lpt.ReplicaNumber = 1
}
lpt.setResultInfo(nil)
collectionID := lpt.CollectionID
collectionInfo, err := lpt.meta.getCollectionInfoByID(collectionID)
func checkLoadPartition(req *querypb.LoadPartitionsRequest, meta Meta) error {
collectionID := req.CollectionID
collectionInfo, err := meta.getCollectionInfoByID(collectionID)
if err == nil {
// if the collection has been loaded into memory by load collection request, return error
// should release collection first, then load partitions again
if collectionInfo.LoadType == querypb.LoadType_LoadCollection {
err = fmt.Errorf("collection %d has been loaded into QueryNode, please release collection firstly", collectionID)
lpt.setResultInfo(err)
return fmt.Errorf(err.Error()+" [%w]", ErrLoadParametersMismatch)
} else if collectionInfo.LoadType == querypb.LoadType_LoadPartition {
if collectionInfo.ReplicaNumber != lpt.ReplicaNumber {
if collectionInfo.ReplicaNumber != req.ReplicaNumber {
msg := fmt.Sprintf("partitions has already been loaded, and the number of replicas %v is not same as the request's %v. Should release first then reload with the new number of replicas",
collectionInfo.ReplicaNumber,
lpt.ReplicaNumber)
req.ReplicaNumber)
log.Warn(msg,
zap.String("role", typeutil.QueryCoordRole),
zap.Int64("collectionID", collectionID),
zap.Int64("msgID", lpt.Base.MsgID),
zap.Int64("msgID", req.Base.MsgID),
zap.Int32("collectionReplicaNumber", collectionInfo.ReplicaNumber),
zap.Int32("requestReplicaNumber", lpt.ReplicaNumber))
lpt.result = &commonpb.Status{
ErrorCode: commonpb.ErrorCode_IllegalArgument,
Reason: msg,
}
zap.Int32("requestReplicaNumber", req.ReplicaNumber))
return fmt.Errorf(msg+" [%w]", ErrLoadParametersMismatch)
}
for _, toLoadPartitionID := range lpt.PartitionIDs {
for _, toLoadPartitionID := range req.PartitionIDs {
needLoad := true
for _, loadedPartitionID := range collectionInfo.PartitionIDs {
if toLoadPartitionID == loadedPartitionID {
@ -887,35 +879,42 @@ func (lpt *loadPartitionTask) preExecute(context.Context) error {
// should release partitions first, then load partitions again
err = fmt.Errorf("some partitions %v of collection %d has been loaded into QueryNode, please release partitions firstly",
collectionInfo.PartitionIDs, collectionID)
lpt.setResultInfo(err)
return fmt.Errorf(err.Error()+" [%w]", ErrLoadParametersMismatch)
}
}
}
if lpt.result.ErrorCode != commonpb.ErrorCode_Success {
log.Warn("loadPartitionRequest failed",
zap.String("role", typeutil.QueryCoordRole),
zap.Int64("collectionID", collectionID),
zap.Int64s("partitionIDs", lpt.PartitionIDs),
zap.Int64("msgID", lpt.Base.MsgID),
zap.Error(err))
return fmt.Errorf(err.Error()+" [%w]", ErrLoadParametersMismatch)
}
log.Info("loadPartitionRequest completed, all partitions to load have already been loaded into memory",
zap.String("role", typeutil.QueryCoordRole),
zap.Int64("collectionID", lpt.CollectionID),
zap.Int64s("partitionIDs", lpt.PartitionIDs),
zap.Int64("msgID", lpt.Base.MsgID))
zap.Int64("collectionID", req.CollectionID),
zap.Int64s("partitionIDs", req.PartitionIDs),
zap.Int64("msgID", req.Base.MsgID))
return ErrCollectionLoaded
}
return nil
}
func (lpt *loadPartitionTask) preExecute(context.Context) error {
if lpt.ReplicaNumber < 1 {
log.Warn("replicaNumber is less than 1 for load partitions request, will set it to 1",
zap.Int32("replicaNumber", lpt.ReplicaNumber))
lpt.ReplicaNumber = 1
}
lpt.setResultInfo(nil)
err := checkLoadPartition(lpt.LoadPartitionsRequest, lpt.meta)
if err != nil {
lpt.setResultInfo(err)
return err
}
log.Info("start do loadPartitionTask",
zap.Int64("taskID", lpt.getTaskID()),
zap.Int64("msgID", lpt.GetBase().GetMsgID()),
zap.Int64("collectionID", collectionID))
zap.Int64("collectionID", lpt.GetCollectionID()))
return nil
}

View File

@ -64,6 +64,34 @@ func (queue *taskQueue) taskFull() bool {
return int64(queue.tasks.Len()) >= queue.maxTask
}
// willLoadOrRelease reports what the most recently queued load or release
// task will do to the given collection.
//
// The queue is walked from back to front while holding the queue lock. The
// message type of the first LoadCollection, LoadPartitions, ReleaseCollection
// or ReleasePartitions task whose collection ID equals collectionID is
// returned. Tasks of any other message type, and tasks for other collections,
// are skipped. If nothing matches, commonpb.MsgType_Undefined is returned.
func (queue *taskQueue) willLoadOrRelease(collectionID UniqueID) commonpb.MsgType {
	queue.Lock()
	defer queue.Unlock()

	for elem := queue.tasks.Back(); elem != nil; elem = elem.Prev() {
		queued := elem.Value.(task)
		kind := queued.msgType()

		// Each relevant message type maps to one concrete task type; the
		// collection ID is read from that concrete type.
		sameCollection := false
		switch kind {
		case commonpb.MsgType_LoadCollection:
			sameCollection = queued.(*loadCollectionTask).GetCollectionID() == collectionID
		case commonpb.MsgType_LoadPartitions:
			sameCollection = queued.(*loadPartitionTask).GetCollectionID() == collectionID
		case commonpb.MsgType_ReleaseCollection:
			sameCollection = queued.(*releaseCollectionTask).GetCollectionID() == collectionID
		case commonpb.MsgType_ReleasePartitions:
			sameCollection = queued.(*releasePartitionTask).GetCollectionID() == collectionID
		}

		if sameCollection {
			return kind
		}
	}

	return commonpb.MsgType_Undefined
}
func (queue *taskQueue) addTask(t task) {
queue.Lock()
defer queue.Unlock()

View File

@ -614,3 +614,50 @@ func TestTaskScheduler_BindContext(t *testing.T) {
}, time.Second, time.Millisecond*10)
})
}
// TestTaskScheduler_willLoadOrRelease verifies that willLoadOrRelease returns
// the message type of the last queued load/release task for a collection,
// skipping tasks of other collections and tasks of unrelated message types.
func TestTaskScheduler_willLoadOrRelease(t *testing.T) {
	ctx := context.Background()
	queryCoord := &QueryCoord{}

	loadColTask := genLoadCollectionTask(ctx, queryCoord)
	loadPartTask := genLoadPartitionTask(ctx, queryCoord)
	releaseColTask := genReleaseCollectionTask(ctx, queryCoord)
	releasePartTask := genReleasePartitionTask(ctx, queryCoord)

	// The task for the "other" collection must be its own object. The queue
	// stores pointers, so bumping and then restoring CollectionID on
	// loadColTask would leave no entry for a different collection in the
	// queue, and the skip path would go untested.
	// NOTE(review): assumes genLoadCollectionTask builds a fresh request on
	// every call — confirm.
	otherColTask := genLoadCollectionTask(ctx, queryCoord)
	otherColTask.CollectionID++

	queue := newTaskQueue()
	queue.tasks.PushBack(loadColTask)
	queue.tasks.PushBack(loadPartTask)
	queue.tasks.PushBack(releaseColTask)
	queue.tasks.PushBack(releasePartTask)
	queue.tasks.PushBack(loadColTask)
	queue.tasks.PushBack(otherColTask) // other collection's task, must be skipped

	taskType := queue.willLoadOrRelease(defaultCollectionID)
	assert.Equal(t, commonpb.MsgType_LoadCollection, taskType)

	queue.tasks.PushBack(loadPartTask)
	taskType = queue.willLoadOrRelease(defaultCollectionID)
	assert.Equal(t, commonpb.MsgType_LoadPartitions, taskType)

	queue.tasks.PushBack(releaseColTask)
	taskType = queue.willLoadOrRelease(defaultCollectionID)
	assert.Equal(t, commonpb.MsgType_ReleaseCollection, taskType)

	queue.tasks.PushBack(releasePartTask)
	taskType = queue.willLoadOrRelease(defaultCollectionID)
	assert.Equal(t, commonpb.MsgType_ReleasePartitions, taskType)

	// A task of an unrelated message type at the back of the queue must not
	// change the answer.
	segmentTask := &loadSegmentTask{
		LoadSegmentsRequest: &querypb.LoadSegmentsRequest{
			Base: &commonpb.MsgBase{
				MsgType: commonpb.MsgType_LoadSegments,
			},
		},
	}
	queue.tasks.PushBack(segmentTask)
	taskType = queue.willLoadOrRelease(defaultCollectionID)
	// should be the last release or load for collection or partition
	assert.Equal(t, commonpb.MsgType_ReleasePartitions, taskType)
}

View File

@ -18,6 +18,7 @@ package querycoord
import (
"context"
"errors"
"math/rand"
"testing"
"time"
@ -402,6 +403,53 @@ func Test_LoadCollectionAfterLoadPartition(t *testing.T) {
assert.Nil(t, err)
}
// TestLoadCollection_CheckLoadCollection exercises checkLoadCollection against
// a running query coordinator: before any load, after a completed load with
// an identical request, with a different replica number, and after the
// collection has been re-registered in meta with the LoadPartition load type.
func TestLoadCollection_CheckLoadCollection(t *testing.T) {
	refreshParams()
	ctx := context.Background()
	coord, err := startQueryCoord(ctx)
	assert.NoError(t, err)

	queryNode, err := startQueryNodeServer(ctx)
	assert.NoError(t, err)
	waitQueryNodeOnline(coord.cluster, queryNode.queryNodeID)

	// verify runs the check against the task's own request and meta.
	verify := func(lct *loadCollectionTask) error {
		return checkLoadCollection(lct.LoadCollectionRequest, lct.meta)
	}

	// Nothing is loaded yet, so the check passes.
	initialLoad := genLoadCollectionTask(ctx, coord)
	err = verify(initialLoad)
	assert.NoError(t, err)

	// Run the load to completion.
	err = coord.scheduler.Enqueue(initialLoad)
	assert.NoError(t, err)
	err = initialLoad.waitToFinish()
	assert.NoError(t, err)

	// An identical request is now reported as already loaded.
	repeatedLoad := genLoadCollectionTask(ctx, coord)
	err = verify(repeatedLoad)
	assert.Error(t, err)
	assert.True(t, errors.Is(err, ErrCollectionLoaded))

	// A different replica number is a parameter mismatch.
	replicaChanged := genLoadCollectionTask(ctx, coord)
	replicaChanged.ReplicaNumber++
	err = verify(replicaChanged)
	assert.Error(t, err)
	assert.True(t, errors.Is(err, ErrLoadParametersMismatch))

	// Re-register the collection in meta with the LoadPartition load type;
	// a load-collection request is then a mismatch as well.
	afterPartitionLoad := genLoadCollectionTask(ctx, coord)
	err = afterPartitionLoad.meta.releaseCollection(afterPartitionLoad.CollectionID)
	assert.NoError(t, err)
	err = afterPartitionLoad.meta.addCollection(afterPartitionLoad.CollectionID, querypb.LoadType_LoadPartition, afterPartitionLoad.Schema)
	assert.NoError(t, err)
	err = verify(afterPartitionLoad)
	assert.Error(t, err)
	assert.True(t, errors.Is(err, ErrLoadParametersMismatch))

	queryNode.stop()
	coord.Stop()
	err = removeAllSession()
	assert.NoError(t, err)
}
func Test_RepeatLoadCollection(t *testing.T) {
refreshParams()
ctx := context.Background()
@ -518,6 +566,43 @@ func Test_LoadPartitionAssignTaskFail(t *testing.T) {
assert.Nil(t, err)
}
// TestLoadPartition_CheckLoadPartition exercises checkLoadPartition against a
// running query coordinator: before any load, after a completed load with an
// identical request, and with a different replica number.
func TestLoadPartition_CheckLoadPartition(t *testing.T) {
	refreshParams()
	ctx := context.Background()
	coord, err := startQueryCoord(ctx)
	assert.NoError(t, err)

	queryNode, err := startQueryNodeServer(ctx)
	assert.NoError(t, err)
	waitQueryNodeOnline(coord.cluster, queryNode.queryNodeID)

	// verify runs the check against the task's own request and meta.
	verify := func(lpt *loadPartitionTask) error {
		return checkLoadPartition(lpt.LoadPartitionsRequest, lpt.meta)
	}

	// Nothing is loaded yet, so the check passes.
	initialLoad := genLoadPartitionTask(ctx, coord)
	err = verify(initialLoad)
	assert.NoError(t, err)

	// Run the load to completion.
	err = coord.scheduler.Enqueue(initialLoad)
	assert.NoError(t, err)
	err = initialLoad.waitToFinish()
	assert.NoError(t, err)

	// An identical request is now reported as already loaded.
	repeatedLoad := genLoadPartitionTask(ctx, coord)
	err = verify(repeatedLoad)
	assert.Error(t, err)
	assert.True(t, errors.Is(err, ErrCollectionLoaded))

	// A different replica number is a parameter mismatch.
	replicaChanged := genLoadPartitionTask(ctx, coord)
	replicaChanged.ReplicaNumber++
	err = verify(replicaChanged)
	assert.Error(t, err)
	assert.True(t, errors.Is(err, ErrLoadParametersMismatch))

	queryNode.stop()
	coord.Stop()
	err = removeAllSession()
	assert.NoError(t, err)
}
func Test_LoadPartitionExecuteFail(t *testing.T) {
refreshParams()
ctx := context.Background()