From b71a404776216ec11f551e02974a4b2dbabc3438 Mon Sep 17 00:00:00 2001 From: "yihao.dai" Date: Mon, 10 Jun 2024 21:50:29 +0800 Subject: [PATCH] fix: Check if the import job exists (#33672) (#33673) issue: https://github.com/milvus-io/milvus/issues/33671 pr: https://github.com/milvus-io/milvus/pull/33672 --------- Signed-off-by: bigsheeper --- internal/datacoord/import_util.go | 3 +++ internal/datacoord/import_util_test.go | 6 ++++++ internal/datacoord/services.go | 4 ++++ internal/datacoord/services_test.go | 9 ++++++++- internal/datanode/importv2/util_test.go | 2 +- 5 files changed, 22 insertions(+), 2 deletions(-) diff --git a/internal/datacoord/import_util.go b/internal/datacoord/import_util.go index 95abf147f3..fe0c1cbde0 100644 --- a/internal/datacoord/import_util.go +++ b/internal/datacoord/import_util.go @@ -382,6 +382,9 @@ func getImportingProgress(jobID int64, imeta ImportMeta, meta *meta) (float32, i func GetJobProgress(jobID int64, imeta ImportMeta, meta *meta) (int64, internalpb.ImportJobState, int64, int64, string) { job := imeta.GetJob(jobID) + if job == nil { + return 0, internalpb.ImportJobState_Failed, 0, 0, fmt.Sprintf("import job does not exist, jobID=%d", jobID) + } switch job.GetState() { case internalpb.ImportJobState_Pending: progress := getPendingProgress(jobID, imeta) diff --git a/internal/datacoord/import_util_test.go b/internal/datacoord/import_util_test.go index 4ca55f5439..998b741683 100644 --- a/internal/datacoord/import_util_test.go +++ b/internal/datacoord/import_util_test.go @@ -538,6 +538,12 @@ func TestImportUtil_GetImportProgress(t *testing.T) { assert.Equal(t, internalpb.ImportJobState_Failed, state) assert.Equal(t, mockErr, reason) + // job does not exist + progress, state, _, _, reason = GetJobProgress(-1, imeta, meta) + assert.Equal(t, int64(0), progress) + assert.Equal(t, internalpb.ImportJobState_Failed, state) + assert.NotEqual(t, "", reason) + // pending state err = imeta.UpdateJob(job.GetJobID(), UpdateJobState(internalpb.ImportJobState_Pending)) assert.NoError(t, err) diff --git a/internal/datacoord/services.go b/internal/datacoord/services.go index 84d32b7a85..50e62b6a94 100644 --- a/internal/datacoord/services.go +++ b/internal/datacoord/services.go @@ -1758,6 +1758,10 @@ func (s *Server) GetImportProgress(ctx context.Context, in *internalpb.GetImport return resp, nil } job := s.importMeta.GetJob(jobID) + if job == nil { + resp.Status = merr.Status(merr.WrapErrImportFailed(fmt.Sprintf("import job does not exist, jobID=%d", jobID))) + return resp, nil + } progress, state, importedRows, totalRows, reason := GetJobProgress(jobID, s.importMeta, s.meta) resp.State = state resp.Reason = reason diff --git a/internal/datacoord/services_test.go b/internal/datacoord/services_test.go index 5454c08e93..0c7728f2eb 100644 --- a/internal/datacoord/services_test.go +++ b/internal/datacoord/services_test.go @@ -1698,7 +1698,7 @@ func TestImportV2(t *testing.T) { assert.NoError(t, err) assert.True(t, errors.Is(merr.Error(resp.GetStatus()), merr.ErrImportFailed)) - // normal case + // job does not exist catalog := mocks.NewDataCoordCatalog(t) catalog.EXPECT().ListImportJobs().Return(nil, nil) catalog.EXPECT().ListPreImportTasks().Return(nil, nil) @@ -1706,6 +1706,13 @@ func TestImportV2(t *testing.T) { catalog.EXPECT().SaveImportJob(mock.Anything).Return(nil) s.importMeta, err = NewImportMeta(catalog) assert.NoError(t, err) + resp, err = s.GetImportProgress(ctx, &internalpb.GetImportProgressRequest{ + JobID: "-1", + }) + assert.NoError(t, err) + assert.True(t, errors.Is(merr.Error(resp.GetStatus()), merr.ErrImportFailed)) + + // normal case var job ImportJob = &importJob{ ImportJob: &datapb.ImportJob{ JobID: 0, diff --git a/internal/datanode/importv2/util_test.go b/internal/datanode/importv2/util_test.go index 15c1c5cf81..b1cca451e3 100644 --- a/internal/datanode/importv2/util_test.go +++ b/internal/datanode/importv2/util_test.go @@ -152,7 +152,7 @@ func Test_PickSegment(t *testing.T) { importedSize := map[int64]int{} totalSize := 8 * 1024 * 1024 * 1024 - batchSize := 16 * 1024 * 1024 + batchSize := 1 * 1024 * 1024 for totalSize > 0 { picked := PickSegment(task.req.GetRequestSegments(), vchannel, partitionID)