diff --git a/internal/querycoord/impl.go b/internal/querycoord/impl.go index 7fdba0da64..04b13b332f 100644 --- a/internal/querycoord/impl.go +++ b/internal/querycoord/impl.go @@ -203,6 +203,42 @@ func (qc *QueryCoord) ShowCollections(ctx context.Context, req *querypb.ShowColl }, nil } +func handleLoadError(err error, loadType querypb.LoadType, msgID, collectionID UniqueID, partitionIDs []UniqueID) (*commonpb.Status, error) { + status := &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + } + if errors.Is(err, ErrCollectionLoaded) { + log.Info("collection or partitions has already been loaded, return load success directly", + zap.String("loadType", loadType.String()), + zap.String("role", typeutil.QueryCoordRole), + zap.Int64("collectionID", collectionID), + zap.Int64s("partitionIDs", partitionIDs), + zap.Int64("msgID", msgID)) + + metrics.QueryCoordLoadCount.WithLabelValues(metrics.SuccessLabel).Inc() + return status, nil + } else if errors.Is(err, ErrLoadParametersMismatch) { + status.ErrorCode = commonpb.ErrorCode_IllegalArgument + status.Reason = err.Error() + + metrics.QueryCoordLoadCount.WithLabelValues(metrics.FailLabel).Inc() + return status, nil + } else { + log.Error("load collection or partitions to query nodes failed", + zap.String("loadType", loadType.String()), + zap.String("role", typeutil.QueryCoordRole), + zap.Int64("collectionID", collectionID), + zap.Int64s("partitionIDs", partitionIDs), + zap.Int64("msgID", msgID), + zap.Error(err)) + status.ErrorCode = commonpb.ErrorCode_UnexpectedError + status.Reason = err.Error() + + metrics.QueryCoordLoadCount.WithLabelValues(metrics.FailLabel).Inc() + return status, nil + } +} + // LoadCollection loads all the sealed segments of this collection to queryNodes, and assigns watchDmChannelRequest to queryNodes func (qc *QueryCoord) LoadCollection(ctx context.Context, req *querypb.LoadCollectionRequest) (*commonpb.Status, error) { metrics.QueryCoordLoadCount.WithLabelValues(metrics.TotalLabel).Inc() @@ -235,6 +271,19 @@ func (qc *QueryCoord) LoadCollection(ctx context.Context, req *querypb.LoadColle cluster: qc.cluster, meta: qc.meta, } + + LastTaskType := qc.scheduler.triggerTaskQueue.willLoadOrRelease(req.GetCollectionID()) + if LastTaskType == commonpb.MsgType_LoadCollection { + // collection will be loaded, remove idempotent loadCollection task, return success directly + return status, nil + } + if LastTaskType != commonpb.MsgType_ReleaseCollection { + err := checkLoadCollection(req, qc.meta) + if err != nil { + return handleLoadError(err, querypb.LoadType_LoadCollection, req.GetBase().GetMsgID(), req.GetCollectionID(), nil) + } + } + err := qc.scheduler.Enqueue(loadCollectionTask) if err != nil { log.Error("loadCollectionRequest failed to add execute task to scheduler", @@ -251,32 +300,7 @@ func (qc *QueryCoord) LoadCollection(ctx context.Context, req *querypb.LoadColle err = loadCollectionTask.waitToFinish() if err != nil { - if errors.Is(err, ErrCollectionLoaded) { - log.Info("collection has already been loaded, return load success directly", - zap.String("role", typeutil.QueryCoordRole), - zap.Int64("collectionID", collectionID), - zap.Int64("msgID", req.Base.MsgID)) - - metrics.QueryCoordLoadCount.WithLabelValues(metrics.SuccessLabel).Inc() - return status, nil - } else if errors.Is(err, ErrLoadParametersMismatch) { - status.ErrorCode = commonpb.ErrorCode_IllegalArgument - status.Reason = err.Error() - - metrics.QueryCoordLoadCount.WithLabelValues(metrics.FailLabel).Inc() - return status, nil - } else { - log.Error("load collection to query nodes failed", - zap.String("role", typeutil.QueryCoordRole), - zap.Int64("collectionID", collectionID), - zap.Int64("msgID", req.Base.MsgID), - zap.Error(err)) - status.ErrorCode = commonpb.ErrorCode_UnexpectedError - status.Reason = err.Error() - - metrics.QueryCoordLoadCount.WithLabelValues(metrics.FailLabel).Inc() - return status, nil - } + return handleLoadError(err, querypb.LoadType_LoadCollection, req.GetBase().GetMsgID(), req.GetCollectionID(), nil) } log.Info("loadCollectionRequest completed", @@ -508,6 +532,19 @@ func (qc *QueryCoord) LoadPartitions(ctx context.Context, req *querypb.LoadParti cluster: qc.cluster, meta: qc.meta, } + + LastTaskType := qc.scheduler.triggerTaskQueue.willLoadOrRelease(req.GetCollectionID()) + if LastTaskType == commonpb.MsgType_LoadPartitions { + // partitions will be loaded, remove idempotent loadPartition task, return success directly + return status, nil + } + if LastTaskType != commonpb.MsgType_ReleasePartitions { + err := checkLoadPartition(req, qc.meta) + if err != nil { + return handleLoadError(err, querypb.LoadType_LoadPartition, req.GetBase().GetMsgID(), req.GetCollectionID(), req.GetPartitionIDs()) + } + } + err := qc.scheduler.Enqueue(loadPartitionTask) if err != nil { log.Error("loadPartitionRequest failed to add execute task to scheduler", @@ -525,34 +562,7 @@ func (qc *QueryCoord) LoadPartitions(ctx context.Context, req *querypb.LoadParti err = loadPartitionTask.waitToFinish() if err != nil { - if errors.Is(err, ErrCollectionLoaded) { - log.Info("loadPartitionRequest completed, all partitions to load have already been loaded into memory", - zap.String("role", typeutil.QueryCoordRole), - zap.Int64("collectionID", req.CollectionID), - zap.Int64s("partitionIDs", req.PartitionIDs), - zap.Int64("msgID", req.Base.MsgID)) - - metrics.QueryCoordLoadCount.WithLabelValues(metrics.SuccessLabel).Inc() - return status, nil - } else if errors.Is(err, ErrLoadParametersMismatch) { - status.ErrorCode = commonpb.ErrorCode_IllegalArgument - status.Reason = err.Error() - - metrics.QueryCoordLoadCount.WithLabelValues(metrics.FailLabel).Inc() - return status, nil - } else { - status.ErrorCode = commonpb.ErrorCode_UnexpectedError - status.Reason = err.Error() - log.Error("loadPartitionRequest failed", - zap.String("role", typeutil.QueryCoordRole), - zap.Int64("collectionID", req.CollectionID), - zap.Int64s("partitionIDs", partitionIDs), - zap.Int64("msgID", req.Base.MsgID), - zap.Error(err)) - - metrics.QueryCoordLoadCount.WithLabelValues(metrics.FailLabel).Inc() - return status, nil - } + return handleLoadError(err, querypb.LoadType_LoadPartition, req.GetBase().GetMsgID(), req.GetCollectionID(), req.GetPartitionIDs()) } log.Info("loadPartitionRequest completed", diff --git a/internal/querycoord/task.go b/internal/querycoord/task.go index 7eb62d3b86..d7746ed2ee 100644 --- a/internal/querycoord/task.go +++ b/internal/querycoord/task.go @@ -345,37 +345,22 @@ func (lct *loadCollectionTask) updateTaskProcess() { //this function shall just calculate intermediate progress } -func (lct *loadCollectionTask) preExecute(ctx context.Context) error { - if lct.ReplicaNumber < 1 { - log.Warn("replicaNumber is less than 1 for load collection request, will set it to 1", - zap.Int32("replicaNumber", lct.ReplicaNumber)) - lct.ReplicaNumber = 1 - } - - collectionID := lct.CollectionID - schema := lct.Schema - - lct.setResultInfo(nil) - - collectionInfo, err := lct.meta.getCollectionInfoByID(collectionID) +func checkLoadCollection(req *querypb.LoadCollectionRequest, meta Meta) error { + collectionID := req.CollectionID + collectionInfo, err := meta.getCollectionInfoByID(collectionID) if err == nil { // if collection has been loaded by load collection request, return success if collectionInfo.LoadType == querypb.LoadType_LoadCollection { - if collectionInfo.ReplicaNumber != lct.ReplicaNumber { + if collectionInfo.ReplicaNumber != req.ReplicaNumber { msg := fmt.Sprintf("collection has already been loaded, and the number of replicas %v is not same as the request's %v. Should release first then reload with the new number of replicas", collectionInfo.ReplicaNumber, - lct.ReplicaNumber) + req.ReplicaNumber) log.Warn(msg, zap.String("role", typeutil.QueryCoordRole), zap.Int64("collectionID", collectionID), - zap.Int64("msgID", lct.Base.MsgID), + zap.Int64("msgID", req.Base.MsgID), zap.Int32("collectionReplicaNumber", collectionInfo.ReplicaNumber), - zap.Int32("requestReplicaNumber", lct.ReplicaNumber)) - - lct.result = &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_IllegalArgument, - Reason: msg, - } + zap.Int32("requestReplicaNumber", req.ReplicaNumber)) return fmt.Errorf(msg+" [%w]", ErrLoadParametersMismatch) } @@ -390,15 +375,35 @@ func (lct *loadCollectionTask) preExecute(ctx context.Context) error { zap.String("role", typeutil.QueryCoordRole), zap.Int64("collectionID", collectionID), zap.Int64s("loaded partitionIDs", collectionInfo.PartitionIDs), - zap.Int64("msgID", lct.Base.MsgID), + zap.Int64("msgID", req.Base.MsgID), zap.Error(err)) - lct.setResultInfo(err) metrics.QueryCoordLoadCount.WithLabelValues(metrics.FailLabel).Inc() return fmt.Errorf(err.Error()+" [%w]", ErrLoadParametersMismatch) } } + return nil +} + +func (lct *loadCollectionTask) preExecute(ctx context.Context) error { + if lct.ReplicaNumber < 1 { + log.Warn("replicaNumber is less than 1 for load collection request, will set it to 1", + zap.Int32("replicaNumber", lct.ReplicaNumber)) + lct.ReplicaNumber = 1 + } + + collectionID := lct.CollectionID + schema := lct.Schema + + lct.setResultInfo(nil) + + err := checkLoadCollection(lct.LoadCollectionRequest, lct.meta) + if err != nil { + lct.setResultInfo(err) + return err + } + log.Info("start do loadCollectionTask", zap.Int64("taskID", lct.getTaskID()), zap.Int64("msgID", lct.GetBase().GetMsgID()), @@ -837,44 +842,31 @@ func (lpt *loadPartitionTask) updateTaskProcess() { //this function shall just calculate intermediate progress } -func (lpt *loadPartitionTask) preExecute(context.Context) error { - if lpt.ReplicaNumber < 1 { - log.Warn("replicaNumber is less than 1 for load partitions request, will set it to 1", - zap.Int32("replicaNumber", lpt.ReplicaNumber)) - lpt.ReplicaNumber = 1 - } - - lpt.setResultInfo(nil) - - collectionID := lpt.CollectionID - collectionInfo, err := lpt.meta.getCollectionInfoByID(collectionID) +func checkLoadPartition(req *querypb.LoadPartitionsRequest, meta Meta) error { + collectionID := req.CollectionID + collectionInfo, err := meta.getCollectionInfoByID(collectionID) if err == nil { // if the collection has been loaded into memory by load collection request, return error // should release collection first, then load partitions again if collectionInfo.LoadType == querypb.LoadType_LoadCollection { err = fmt.Errorf("collection %d has been loaded into QueryNode, please release collection firstly", collectionID) - lpt.setResultInfo(err) + return fmt.Errorf(err.Error()+" [%w]", ErrLoadParametersMismatch) } else if collectionInfo.LoadType == querypb.LoadType_LoadPartition { - if collectionInfo.ReplicaNumber != lpt.ReplicaNumber { + if collectionInfo.ReplicaNumber != req.ReplicaNumber { msg := fmt.Sprintf("partitions has already been loaded, and the number of replicas %v is not same as the request's %v. Should release first then reload with the new number of replicas", collectionInfo.ReplicaNumber, - lpt.ReplicaNumber) + req.ReplicaNumber) log.Warn(msg, zap.String("role", typeutil.QueryCoordRole), zap.Int64("collectionID", collectionID), - zap.Int64("msgID", lpt.Base.MsgID), + zap.Int64("msgID", req.Base.MsgID), zap.Int32("collectionReplicaNumber", collectionInfo.ReplicaNumber), - zap.Int32("requestReplicaNumber", lpt.ReplicaNumber)) - - lpt.result = &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_IllegalArgument, - Reason: msg, - } + zap.Int32("requestReplicaNumber", req.ReplicaNumber)) return fmt.Errorf(msg+" [%w]", ErrLoadParametersMismatch) } - for _, toLoadPartitionID := range lpt.PartitionIDs { + for _, toLoadPartitionID := range req.PartitionIDs { needLoad := true for _, loadedPartitionID := range collectionInfo.PartitionIDs { if toLoadPartitionID == loadedPartitionID { @@ -887,35 +879,42 @@ func (lpt *loadPartitionTask) preExecute(context.Context) error { // should release partitions first, then load partitions again err = fmt.Errorf("some partitions %v of collection %d has been loaded into QueryNode, please release partitions firstly", collectionInfo.PartitionIDs, collectionID) - lpt.setResultInfo(err) + return fmt.Errorf(err.Error()+" [%w]", ErrLoadParametersMismatch) } } } - if lpt.result.ErrorCode != commonpb.ErrorCode_Success { - log.Warn("loadPartitionRequest failed", - zap.String("role", typeutil.QueryCoordRole), - zap.Int64("collectionID", collectionID), - zap.Int64s("partitionIDs", lpt.PartitionIDs), - zap.Int64("msgID", lpt.Base.MsgID), - zap.Error(err)) - - return fmt.Errorf(err.Error()+" [%w]", ErrLoadParametersMismatch) - } - log.Info("loadPartitionRequest completed, all partitions to load have already been loaded into memory", zap.String("role", typeutil.QueryCoordRole), - zap.Int64("collectionID", lpt.CollectionID), - zap.Int64s("partitionIDs", lpt.PartitionIDs), - zap.Int64("msgID", lpt.Base.MsgID)) + zap.Int64("collectionID", req.CollectionID), + zap.Int64s("partitionIDs", req.PartitionIDs), + zap.Int64("msgID", req.Base.MsgID)) return ErrCollectionLoaded } + return nil +} + +func (lpt *loadPartitionTask) preExecute(context.Context) error { + if lpt.ReplicaNumber < 1 { + log.Warn("replicaNumber is less than 1 for load partitions request, will set it to 1", + zap.Int32("replicaNumber", lpt.ReplicaNumber)) + lpt.ReplicaNumber = 1 + } + + lpt.setResultInfo(nil) + + err := checkLoadPartition(lpt.LoadPartitionsRequest, lpt.meta) + if err != nil { + lpt.setResultInfo(err) + return err + } + log.Info("start do loadPartitionTask", zap.Int64("taskID", lpt.getTaskID()), zap.Int64("msgID", lpt.GetBase().GetMsgID()), - zap.Int64("collectionID", collectionID)) + zap.Int64("collectionID", lpt.GetCollectionID())) return nil } diff --git a/internal/querycoord/task_scheduler.go b/internal/querycoord/task_scheduler.go index b685c70141..ad982ec200 100644 --- a/internal/querycoord/task_scheduler.go +++ b/internal/querycoord/task_scheduler.go @@ -64,6 +64,34 @@ func (queue *taskQueue) taskFull() bool { return int64(queue.tasks.Len()) >= queue.maxTask } +func (queue *taskQueue) willLoadOrRelease(collectionID UniqueID) commonpb.MsgType { + queue.Lock() + defer queue.Unlock() + // check the last task of this collection is load task or release task + for e := queue.tasks.Back(); e != nil; e = e.Prev() { + msgType := e.Value.(task).msgType() + switch msgType { + case commonpb.MsgType_LoadCollection: + if e.Value.(task).(*loadCollectionTask).GetCollectionID() == collectionID { + return msgType + } + case commonpb.MsgType_LoadPartitions: + if e.Value.(task).(*loadPartitionTask).GetCollectionID() == collectionID { + return msgType + } + case commonpb.MsgType_ReleaseCollection: + if e.Value.(task).(*releaseCollectionTask).GetCollectionID() == collectionID { + return msgType + } + case commonpb.MsgType_ReleasePartitions: + if e.Value.(task).(*releasePartitionTask).GetCollectionID() == collectionID { + return msgType + } + } + } + return commonpb.MsgType_Undefined +} + func (queue *taskQueue) addTask(t task) { queue.Lock() defer queue.Unlock() diff --git a/internal/querycoord/task_scheduler_test.go b/internal/querycoord/task_scheduler_test.go index d305192a2d..d05f49e6ae 100644 --- a/internal/querycoord/task_scheduler_test.go +++ b/internal/querycoord/task_scheduler_test.go @@ -614,3 +614,50 @@ func TestTaskScheduler_BindContext(t *testing.T) { }, time.Second, time.Millisecond*10) }) } + +func TestTaskScheduler_willLoadOrRelease(t *testing.T) { + ctx := context.Background() + queryCoord := &QueryCoord{} + + loadCollectionTask := genLoadCollectionTask(ctx, queryCoord) + loadPartitionTask := genLoadPartitionTask(ctx, queryCoord) + releaseCollectionTask := genReleaseCollectionTask(ctx, queryCoord) + releasePartitionTask := genReleasePartitionTask(ctx, queryCoord) + + queue := newTaskQueue() + queue.tasks.PushBack(loadCollectionTask) + queue.tasks.PushBack(loadPartitionTask) + queue.tasks.PushBack(releaseCollectionTask) + queue.tasks.PushBack(releasePartitionTask) + queue.tasks.PushBack(loadCollectionTask) + loadCollectionTask.CollectionID++ + queue.tasks.PushBack(loadCollectionTask) // add other collection's task + loadCollectionTask.CollectionID = defaultCollectionID + + taskType := queue.willLoadOrRelease(defaultCollectionID) + assert.Equal(t, commonpb.MsgType_LoadCollection, taskType) + + queue.tasks.PushBack(loadPartitionTask) + taskType = queue.willLoadOrRelease(defaultCollectionID) + assert.Equal(t, commonpb.MsgType_LoadPartitions, taskType) + + queue.tasks.PushBack(releaseCollectionTask) + taskType = queue.willLoadOrRelease(defaultCollectionID) + assert.Equal(t, commonpb.MsgType_ReleaseCollection, taskType) + + queue.tasks.PushBack(releasePartitionTask) + taskType = queue.willLoadOrRelease(defaultCollectionID) + assert.Equal(t, commonpb.MsgType_ReleasePartitions, taskType) + + loadSegmentTask := &loadSegmentTask{ + LoadSegmentsRequest: &querypb.LoadSegmentsRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_LoadSegments, + }, + }, + } + queue.tasks.PushBack(loadSegmentTask) + taskType = queue.willLoadOrRelease(defaultCollectionID) + // should be the last release or load for collection or partition + assert.Equal(t, commonpb.MsgType_ReleasePartitions, taskType) +} diff --git a/internal/querycoord/task_test.go b/internal/querycoord/task_test.go index 01a2ff1c2e..0489f8f71c 100644 --- a/internal/querycoord/task_test.go +++ b/internal/querycoord/task_test.go @@ -18,6 +18,7 @@ package querycoord import ( "context" + "errors" "math/rand" "testing" "time" @@ -402,6 +403,53 @@ func Test_LoadCollectionAfterLoadPartition(t *testing.T) { assert.Nil(t, err) } +func TestLoadCollection_CheckLoadCollection(t *testing.T) { + refreshParams() + ctx := context.Background() + queryCoord, err := startQueryCoord(ctx) + assert.NoError(t, err) + + node, err := startQueryNodeServer(ctx) + assert.NoError(t, err) + waitQueryNodeOnline(queryCoord.cluster, node.queryNodeID) + + loadCollectionTask1 := genLoadCollectionTask(ctx, queryCoord) + + err = checkLoadCollection(loadCollectionTask1.LoadCollectionRequest, loadCollectionTask1.meta) + assert.NoError(t, err) // Collection not loaded + + err = queryCoord.scheduler.Enqueue(loadCollectionTask1) + assert.NoError(t, err) + + err = loadCollectionTask1.waitToFinish() + assert.NoError(t, err) + + loadCollectionTask2 := genLoadCollectionTask(ctx, queryCoord) + err = checkLoadCollection(loadCollectionTask2.LoadCollectionRequest, loadCollectionTask2.meta) + assert.Error(t, err) // Collection loaded + assert.True(t, errors.Is(err, ErrCollectionLoaded)) + + loadCollectionTask3 := genLoadCollectionTask(ctx, queryCoord) + loadCollectionTask3.ReplicaNumber++ + err = checkLoadCollection(loadCollectionTask3.LoadCollectionRequest, loadCollectionTask3.meta) + assert.Error(t, err) // replica number mismatch + assert.True(t, errors.Is(err, ErrLoadParametersMismatch)) + + loadCollectionTask4 := genLoadCollectionTask(ctx, queryCoord) + err = loadCollectionTask4.meta.releaseCollection(loadCollectionTask4.CollectionID) + assert.NoError(t, err) + err = loadCollectionTask4.meta.addCollection(loadCollectionTask4.CollectionID, querypb.LoadType_LoadPartition, loadCollectionTask4.Schema) + assert.NoError(t, err) + err = checkLoadCollection(loadCollectionTask4.LoadCollectionRequest, loadCollectionTask4.meta) + assert.Error(t, err) // wrong load type, partition loaded before + assert.True(t, errors.Is(err, ErrLoadParametersMismatch)) + + node.stop() + queryCoord.Stop() + err = removeAllSession() + assert.NoError(t, err) +} + func Test_RepeatLoadCollection(t *testing.T) { refreshParams() ctx := context.Background() @@ -518,6 +566,43 @@ func Test_LoadPartitionAssignTaskFail(t *testing.T) { assert.Nil(t, err) } +func TestLoadPartition_CheckLoadPartition(t *testing.T) { + refreshParams() + ctx := context.Background() + queryCoord, err := startQueryCoord(ctx) + assert.NoError(t, err) + + node, err := startQueryNodeServer(ctx) + assert.NoError(t, err) + waitQueryNodeOnline(queryCoord.cluster, node.queryNodeID) + + loadPartitionTask1 := genLoadPartitionTask(ctx, queryCoord) + err = checkLoadPartition(loadPartitionTask1.LoadPartitionsRequest, loadPartitionTask1.meta) + assert.NoError(t, err) // partition not load + + err = queryCoord.scheduler.Enqueue(loadPartitionTask1) + assert.NoError(t, err) + + err = loadPartitionTask1.waitToFinish() + assert.NoError(t, err) + + loadPartitionTask2 := genLoadPartitionTask(ctx, queryCoord) + err = checkLoadPartition(loadPartitionTask2.LoadPartitionsRequest, loadPartitionTask2.meta) + assert.Error(t, err) // partition loaded + assert.True(t, errors.Is(err, ErrCollectionLoaded)) + + loadPartitionTask3 := genLoadPartitionTask(ctx, queryCoord) + loadPartitionTask3.ReplicaNumber++ + err = checkLoadPartition(loadPartitionTask3.LoadPartitionsRequest, loadPartitionTask3.meta) + assert.Error(t, err) // replica number mismatch + assert.True(t, errors.Is(err, ErrLoadParametersMismatch)) + + node.stop() + queryCoord.Stop() + err = removeAllSession() + assert.NoError(t, err) +} + func Test_LoadPartitionExecuteFail(t *testing.T) { refreshParams() ctx := context.Background()