diff --git a/internal/datacoord/import_job.go b/internal/datacoord/import_job.go
index 08f5503d68..4ae8cdecc8 100644
--- a/internal/datacoord/import_job.go
+++ b/internal/datacoord/import_job.go
@@ -61,6 +61,12 @@ func WithoutJobStates(states ...internalpb.ImportJobState) ImportJobFilter {
     }
 }

+func WithDbID(DbID int64) ImportJobFilter {
+    return func(job ImportJob) bool {
+        return job.GetDbID() == DbID
+    }
+}
+
 type UpdateJobAction func(job ImportJob)

 func UpdateJobState(state internalpb.ImportJobState) UpdateJobAction {
@@ -100,6 +106,7 @@ func UpdateJobCompleteTime(completeTime string) UpdateJobAction {

 type ImportJob interface {
     GetJobID() int64
+    GetDbID() int64
     GetCollectionID() int64
     GetCollectionName() string
     GetPartitionIDs() []int64
diff --git a/internal/datacoord/services.go b/internal/datacoord/services.go
index a2e546ffd9..dc80b6d3bd 100644
--- a/internal/datacoord/services.go
+++ b/internal/datacoord/services.go
@@ -1631,7 +1631,9 @@ func (s *Server) ImportV2(ctx context.Context, in *internalpb.ImportRequestInter
         Status: merr.Success(),
     }

-    log := log.With(zap.Int64("collection", in.GetCollectionID()),
+    log := log.With(
+        zap.Int64("dbID", in.GetDbID()),
+        zap.Int64("collection", in.GetCollectionID()),
         zap.Int64s("partitions", in.GetPartitionIDs()),
         zap.Strings("channels", in.GetChannelNames()))
     log.Info("receive import request", zap.Any("files", in.GetFiles()))
@@ -1697,6 +1699,7 @@ func (s *Server) ImportV2(ctx context.Context, in *internalpb.ImportRequestInter
     job := &importJob{
         ImportJob: &datapb.ImportJob{
             JobID:          idStart,
+            DbID:           in.GetDbID(),
             CollectionID:   in.GetCollectionID(),
             CollectionName: in.GetCollectionName(),
             PartitionIDs:   in.GetPartitionIDs(),
@@ -1723,7 +1726,7 @@ func (s *Server) ImportV2(ctx context.Context, in *internalpb.ImportRequestInter
 }

 func (s *Server) GetImportProgress(ctx context.Context, in *internalpb.GetImportProgressRequest) (*internalpb.GetImportProgressResponse, error) {
-    log := log.With(zap.String("jobID", in.GetJobID()))
+    log := log.With(zap.String("jobID", in.GetJobID()), zap.Int64("dbID", in.GetDbID()))
     if err := merr.CheckHealthy(s.GetStateCode()); err != nil {
         return &internalpb.GetImportProgressResponse{
             Status: merr.Status(err),
@@ -1743,6 +1746,10 @@ func (s *Server) GetImportProgress(ctx context.Context, in *internalpb.GetImport
         resp.Status = merr.Status(merr.WrapErrImportFailed(fmt.Sprintf("import job does not exist, jobID=%d", jobID)))
         return resp, nil
     }
+    if job.GetDbID() != 0 && job.GetDbID() != in.GetDbID() {
+        resp.Status = merr.Status(merr.WrapErrImportFailed(fmt.Sprintf("import job does not exist, jobID=%d, dbID=%d", jobID, in.GetDbID())))
+        return resp, nil
+    }
     progress, state, importedRows, totalRows, reason := GetJobProgress(jobID, s.importMeta, s.meta)
     resp.State = state
     resp.Reason = reason
@@ -1773,11 +1780,14 @@ func (s *Server) ListImports(ctx context.Context, req *internalpb.ListImportsReq
     }

     var jobs []ImportJob
-    if req.GetCollectionID() != 0 {
-        jobs = s.importMeta.GetJobBy(WithCollectionID(req.GetCollectionID()))
-    } else {
-        jobs = s.importMeta.GetJobBy()
+    filters := make([]ImportJobFilter, 0)
+    if req.GetDbID() != 0 {
+        filters = append(filters, WithDbID(req.GetDbID()))
     }
+    if req.GetCollectionID() != 0 {
+        filters = append(filters, WithCollectionID(req.GetCollectionID()))
+    }
+    jobs = s.importMeta.GetJobBy(filters...)

     for _, job := range jobs {
         progress, state, _, _, reason := GetJobProgress(job.GetJobID(), s.importMeta, s.meta)
@@ -1787,5 +1797,7 @@ func (s *Server) ListImports(ctx context.Context, req *internalpb.ListImportsReq
         resp.Progresses = append(resp.Progresses, progress)
         resp.CollectionNames = append(resp.CollectionNames, job.GetCollectionName())
     }
+    log.Info("ListImports done", zap.Int64("collectionID", req.GetCollectionID()),
+        zap.Int64("dbID", req.GetDbID()), zap.Any("resp", resp))
     return resp, nil
 }
diff --git a/internal/datacoord/services_test.go b/internal/datacoord/services_test.go
index 90a2372b0c..205ace7dc1 100644
--- a/internal/datacoord/services_test.go
+++ b/internal/datacoord/services_test.go
@@ -1711,9 +1711,10 @@ func TestImportV2(t *testing.T) {
         assert.NoError(t, err)
         assert.True(t, errors.Is(merr.Error(resp.GetStatus()), merr.ErrImportFailed))

-        // normal case
+        // db does not exist
         var job ImportJob = &importJob{
             ImportJob: &datapb.ImportJob{
+                DbID:   1,
                 JobID:  0,
                 Schema: &schemapb.CollectionSchema{},
                 State:  internalpb.ImportJobState_Failed,
@@ -1722,12 +1723,31 @@ func TestImportV2(t *testing.T) {
         err = s.importMeta.AddJob(job)
         assert.NoError(t, err)
         resp, err = s.GetImportProgress(ctx, &internalpb.GetImportProgressRequest{
+            DbID:  2,
+            JobID: "0",
+        })
+        assert.NoError(t, err)
+        assert.True(t, errors.Is(merr.Error(resp.GetStatus()), merr.ErrImportFailed))
+
+        // normal case
+        job = &importJob{
+            ImportJob: &datapb.ImportJob{
+                DbID:   1,
+                JobID:  0,
+                Schema: &schemapb.CollectionSchema{},
+                State:  internalpb.ImportJobState_Pending,
+            },
+        }
+        err = s.importMeta.AddJob(job)
+        assert.NoError(t, err)
+        resp, err = s.GetImportProgress(ctx, &internalpb.GetImportProgressRequest{
+            DbID:  1,
             JobID: "0",
         })
         assert.NoError(t, err)
         assert.Equal(t, int32(0), resp.GetStatus().GetCode())
-        assert.Equal(t, int64(0), resp.GetProgress())
-        assert.Equal(t, internalpb.ImportJobState_Failed, resp.GetState())
+        assert.Equal(t, int64(10), resp.GetProgress())
+        assert.Equal(t, internalpb.ImportJobState_Pending, resp.GetState())
     })

     t.Run("ListImports", func(t *testing.T) {
@@ -1750,6 +1770,7 @@ func TestImportV2(t *testing.T) {
         assert.NoError(t, err)
         var job ImportJob = &importJob{
             ImportJob: &datapb.ImportJob{
+                DbID:         2,
                 JobID:        0,
                 CollectionID: 1,
                 Schema:       &schemapb.CollectionSchema{},
@@ -1766,7 +1787,20 @@ func TestImportV2(t *testing.T) {
         }
         err = s.importMeta.AddTask(task)
         assert.NoError(t, err)
+        // db id not match
         resp, err = s.ListImports(ctx, &internalpb.ListImportsRequestInternal{
+            DbID:         3,
+            CollectionID: 1,
+        })
+        assert.NoError(t, err)
+        assert.Equal(t, int32(0), resp.GetStatus().GetCode())
+        assert.Equal(t, 0, len(resp.GetJobIDs()))
+        assert.Equal(t, 0, len(resp.GetStates()))
+        assert.Equal(t, 0, len(resp.GetReasons()))
+        assert.Equal(t, 0, len(resp.GetProgresses()))
+        // db id match
+        resp, err = s.ListImports(ctx, &internalpb.ListImportsRequestInternal{
+            DbID:         2,
             CollectionID: 1,
         })
         assert.NoError(t, err)
diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go
index 4d351f744f..40a9a9b812 100644
--- a/internal/distributed/proxy/httpserver/handler_v2.go
+++ b/internal/distributed/proxy/httpserver/handler_v2.go
@@ -139,8 +139,8 @@ func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) {

     router.POST(ImportJobCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &OptionalCollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listImportJob)))))
     router.POST(ImportJobCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &ImportReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.createImportJob)))))
-    router.POST(ImportJobCategory+GetProgressAction, timeoutMiddleware(wrapperPost(func() any { return &JobIDReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getImportJobProcess)))))
-    router.POST(ImportJobCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &JobIDReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getImportJobProcess)))))
+    router.POST(ImportJobCategory+GetProgressAction, timeoutMiddleware(wrapperPost(func() any { return &GetImportReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getImportJobProcess)))))
+    router.POST(ImportJobCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &GetImportReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getImportJobProcess)))))
 }

 type (
diff --git a/internal/distributed/proxy/httpserver/request_v2.go b/internal/distributed/proxy/httpserver/request_v2.go
index e7f1364a70..ad517fc21a 100644
--- a/internal/distributed/proxy/httpserver/request_v2.go
+++ b/internal/distributed/proxy/httpserver/request_v2.go
@@ -94,11 +94,14 @@ func (req *ImportReq) GetOptions() map[string]string {
     return req.Options
 }

-type JobIDReq struct {
-    JobID string `json:"jobId" binding:"required"`
+type GetImportReq struct {
+    DbName string `json:"dbName"`
+    JobID  string `json:"jobId" binding:"required"`
 }

-func (req *JobIDReq) GetJobID() string { return req.JobID }
+func (req *GetImportReq) GetJobID() string { return req.JobID }
+
+func (req *GetImportReq) GetDbName() string { return req.DbName }

 type QueryReqV2 struct {
     DbName string `json:"dbName"`
diff --git a/internal/proto/internal.proto b/internal/proto/internal.proto
index 86bc546523..cc6b752b1b 100644
--- a/internal/proto/internal.proto
+++ b/internal/proto/internal.proto
@@ -336,6 +336,7 @@ message ImportResponse {
 message GetImportProgressRequest {
   string db_name = 1;
   string jobID = 2;
+  int64 dbID = 3;
 }

 message ImportTaskProgress {
diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go
index 49cdaa7f7c..527f858010 100644
--- a/internal/proxy/impl.go
+++ b/internal/proxy/impl.go
@@ -6158,6 +6158,7 @@ func (node *Proxy) ImportV2(ctx context.Context, req *internalpb.ImportRequest)
         return &internalpb.ImportResponse{Status: merr.Status(err)}, nil
     }
     log := log.Ctx(ctx).With(
+        zap.String("dbName", req.GetDbName()),
         zap.String("collectionName", req.GetCollectionName()),
         zap.String("partition name", req.GetPartitionName()),
         zap.Any("files", req.GetFiles()),
@@ -6183,6 +6184,11 @@ func (node *Proxy) ImportV2(ctx context.Context, req *internalpb.ImportRequest)
         }
     }()

+    dbInfo, err := globalMetaCache.GetDatabaseInfo(ctx, req.GetDbName())
+    if err != nil {
+        resp.Status = merr.Status(err)
+        return resp, nil
+    }
     collectionID, err := globalMetaCache.GetCollectionID(ctx, req.GetDbName(), req.GetCollectionName())
     if err != nil {
         resp.Status = merr.Status(err)
@@ -6293,6 +6299,7 @@ func (node *Proxy) ImportV2(ctx context.Context, req *internalpb.ImportRequest)
         }
     }
     importRequest := &internalpb.ImportRequestInternal{
+        DbID:           dbInfo.dbID,
         CollectionID:   collectionID,
         CollectionName: req.GetCollectionName(),
         PartitionIDs:   partitionIDs,
@@ -6317,14 +6324,28 @@ func (node *Proxy) GetImportProgress(ctx context.Context, req *internalpb.GetImp
         }, nil
     }

     log := log.Ctx(ctx).With(
+        zap.String("dbName", req.GetDbName()),
         zap.String("jobID", req.GetJobID()),
     )
+
+    resp := &internalpb.GetImportProgressResponse{
+        Status: merr.Success(),
+    }
+
     method := "GetImportProgress"
     tr := timerecord.NewTimeRecorder(method)
     log.Info(rpcReceived(method))
+    // Fill db id for datacoord.
+    dbInfo, err := globalMetaCache.GetDatabaseInfo(ctx, req.GetDbName())
+    if err != nil {
+        resp.Status = merr.Status(err)
+        return resp, nil
+    }
+    req.DbID = dbInfo.dbID
+
     nodeID := fmt.Sprint(paramtable.GetNodeID())
-    resp, err := node.dataCoord.GetImportProgress(ctx, req)
+    resp, err = node.dataCoord.GetImportProgress(ctx, req)
     if resp.GetStatus().GetCode() != 0 || err != nil {
         log.Warn("get import progress failed", zap.String("reason", resp.GetStatus().GetReason()), zap.Error(err))
         metrics.ProxyFunctionCall.WithLabelValues(nodeID, method, metrics.FailLabel, req.GetDbName(), "").Inc()
@@ -6361,6 +6382,11 @@ func (node *Proxy) ListImports(ctx context.Context, req *internalpb.ListImportsR
         err          error
         collectionID UniqueID
     )
+    dbInfo, err := globalMetaCache.GetDatabaseInfo(ctx, req.GetDbName())
+    if err != nil {
+        resp.Status = merr.Status(err)
+        return resp, nil
+    }
     if req.GetCollectionName() != "" {
         collectionID, err = globalMetaCache.GetCollectionID(ctx, req.GetDbName(), req.GetCollectionName())
         if err != nil {
@@ -6369,7 +6395,9 @@ func (node *Proxy) ListImports(ctx context.Context, req *internalpb.ListImportsR
             return resp, nil
         }
     }
+
     resp, err = node.dataCoord.ListImports(ctx, &internalpb.ListImportsRequestInternal{
+        DbID:         dbInfo.dbID,
         CollectionID: collectionID,
     })
     if resp.GetStatus().GetCode() != 0 || err != nil {
diff --git a/internal/proxy/impl_test.go b/internal/proxy/impl_test.go
index 6104b99be5..054e9e01b0 100644
--- a/internal/proxy/impl_test.go
+++ b/internal/proxy/impl_test.go
@@ -1596,8 +1596,17 @@ func TestProxy_ImportV2(t *testing.T) {
         assert.NotEqual(t, int32(0), rsp.GetStatus().GetCode())
         node.UpdateStateCode(commonpb.StateCode_Healthy)

-        // no such collection
+        // no such database
         mc := NewMockCache(t)
+        mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(nil, mockErr)
+        globalMetaCache = mc
+        rsp, err = node.ImportV2(ctx, &internalpb.ImportRequest{CollectionName: "aaa"})
+        assert.NoError(t, err)
+        assert.NotEqual(t, int32(0), rsp.GetStatus().GetCode())
+
+        // no such collection
+        mc = NewMockCache(t)
+        mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil)
         mc.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, mockErr)
         globalMetaCache = mc
         rsp, err = node.ImportV2(ctx, &internalpb.ImportRequest{CollectionName: "aaa"})
@@ -1606,6 +1615,7 @@ func TestProxy_ImportV2(t *testing.T) {

         // get schema failed
         mc = NewMockCache(t)
+        mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil)
         mc.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, nil)
         mc.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(nil, mockErr)
         globalMetaCache = mc
@@ -1615,6 +1625,7 @@ func TestProxy_ImportV2(t *testing.T) {

         // get channel failed
         mc = NewMockCache(t)
+        mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil)
         mc.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, nil)
         mc.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{
             CollectionSchema: &schemapb.CollectionSchema{Fields: []*schemapb.FieldSchema{
@@ -1639,6 +1650,7 @@ func TestProxy_ImportV2(t *testing.T) {

         // get partitions failed
         mc = NewMockCache(t)
+        mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil)
         mc.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, nil)
         mc.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{
             CollectionSchema: &schemapb.CollectionSchema{Fields: []*schemapb.FieldSchema{
@@ -1653,6 +1665,7 @@ func TestProxy_ImportV2(t *testing.T) {

         // get partitionID failed
         mc = NewMockCache(t)
+        mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil)
         mc.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, nil)
         mc.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{
             CollectionSchema: &schemapb.CollectionSchema{},
@@ -1665,6 +1678,7 @@ func TestProxy_ImportV2(t *testing.T) {

         // no file
         mc = NewMockCache(t)
+        mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil)
         mc.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, nil)
         mc.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{
             CollectionSchema: &schemapb.CollectionSchema{},
@@ -1711,7 +1725,18 @@ func TestProxy_ImportV2(t *testing.T) {
         assert.NotEqual(t, int32(0), rsp.GetStatus().GetCode())
         node.UpdateStateCode(commonpb.StateCode_Healthy)

+        // no such database
+        mc := NewMockCache(t)
+        mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(nil, mockErr)
+        globalMetaCache = mc
+        rsp, err = node.GetImportProgress(ctx, &internalpb.GetImportProgressRequest{})
+        assert.NoError(t, err)
+        assert.NotEqual(t, int32(0), rsp.GetStatus().GetCode())
+
         // normal case
+        mc = NewMockCache(t)
+        mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil)
+        globalMetaCache = mc
         dataCoord := mocks.NewMockDataCoordClient(t)
         dataCoord.EXPECT().GetImportProgress(mock.Anything, mock.Anything).Return(nil, nil)
         node.dataCoord = dataCoord
@@ -1729,8 +1754,19 @@ func TestProxy_ImportV2(t *testing.T) {
         assert.NotEqual(t, int32(0), rsp.GetStatus().GetCode())
         node.UpdateStateCode(commonpb.StateCode_Healthy)

-        // normal case
+        // no such database
         mc := NewMockCache(t)
+        mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(nil, mockErr)
+        globalMetaCache = mc
+        rsp, err = node.ListImports(ctx, &internalpb.ListImportsRequest{
+            CollectionName: "col",
+        })
+        assert.NoError(t, err)
+        assert.NotEqual(t, int32(0), rsp.GetStatus().GetCode())
+
+        // normal case
+        mc = NewMockCache(t)
+        mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil)
         mc.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, nil)
         globalMetaCache = mc
         dataCoord := mocks.NewMockDataCoordClient(t)
diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go
index c500bff646..f106d6c387 100644
--- a/internal/proxy/proxy_test.go
+++ b/internal/proxy/proxy_test.go
@@ -4570,6 +4570,7 @@ func TestProxy_Import(t *testing.T) {
         proxy.UpdateStateCode(commonpb.StateCode_Healthy)

         mc := NewMockCache(t)
+        mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil)
         mc.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, nil)
         mc.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{
             CollectionSchema: &schemapb.CollectionSchema{},
@@ -4610,6 +4611,10 @@ func TestProxy_Import(t *testing.T) {
         proxy := &Proxy{}
         proxy.UpdateStateCode(commonpb.StateCode_Healthy)

+        mc := NewMockCache(t)
+        mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil)
+        globalMetaCache = mc
+
         dataCoord := mocks.NewMockDataCoordClient(t)
         dataCoord.EXPECT().GetImportProgress(mock.Anything, mock.Anything).Return(&internalpb.GetImportProgressResponse{
             Status: merr.Success(),
@@ -4635,6 +4640,10 @@ func TestProxy_Import(t *testing.T) {
         proxy := &Proxy{}
         proxy.UpdateStateCode(commonpb.StateCode_Healthy)

+        mc := NewMockCache(t)
+        mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil)
+        globalMetaCache = mc
+
         dataCoord := mocks.NewMockDataCoordClient(t)
         dataCoord.EXPECT().ListImports(mock.Anything, mock.Anything).Return(&internalpb.ListImportsResponse{
             Status: merr.Success(),
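
A note on the filter pattern used by the ListImports change above: WithDbID composes with the existing WithCollectionID the same way the patched handler builds its filter slice. The standalone Go sketch below is illustrative only and not part of the patch; the local names job, filter, and getJobBy are simplified stand-ins for datacoord's ImportJob, ImportJobFilter, and importMeta.GetJobBy, under the assumption that GetJobBy keeps a job only when every supplied filter matches.

// Illustrative sketch only (not from the patch): simplified local types show
// how WithDbID and WithCollectionID are composed by a GetJobBy-style lookup.
package main

import "fmt"

type job struct {
    jobID        int64
    dbID         int64
    collectionID int64
}

// filter mirrors the ImportJobFilter idea: a predicate over a job.
type filter func(j job) bool

func withDbID(dbID int64) filter {
    return func(j job) bool { return j.dbID == dbID }
}

func withCollectionID(collectionID int64) filter {
    return func(j job) bool { return j.collectionID == collectionID }
}

// getJobBy keeps only the jobs that satisfy every supplied filter,
// which is the behaviour ListImports relies on when passing filters...
func getJobBy(jobs []job, filters ...filter) []job {
    out := make([]job, 0, len(jobs))
    for _, j := range jobs {
        keep := true
        for _, f := range filters {
            if !f(j) {
                keep = false
                break
            }
        }
        if keep {
            out = append(out, j)
        }
    }
    return out
}

func main() {
    jobs := []job{
        {jobID: 0, dbID: 1, collectionID: 100},
        {jobID: 1, dbID: 2, collectionID: 100},
    }
    // As in the patched ListImports, a filter is only added when the request
    // actually carries the field; a zero dbID or collectionID means "no filter".
    filters := make([]filter, 0)
    if dbID := int64(2); dbID != 0 {
        filters = append(filters, withDbID(dbID))
    }
    if collectionID := int64(100); collectionID != 0 {
        filters = append(filters, withCollectionID(collectionID))
    }
    fmt.Println(getJobBy(jobs, filters...)) // prints [{1 2 100}]
}

The same zero-value convention presumably explains why the new GetImportProgress guard only rejects a mismatch when job.GetDbID() != 0: a job stored without a DbID keeps matching requests from any database.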