diff --git a/internal/datanode/services.go b/internal/datanode/services.go index ee96396106..5e0458b911 100644 --- a/internal/datanode/services.go +++ b/internal/datanode/services.go @@ -495,15 +495,24 @@ func (node *DataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest) return returnFailFunc("failed to get collection info for collection ID", err) } - // the colInfo doesn't have a collect database name(it is empty). use the database name passed from rootcoord. - partitions, err := node.getPartitions(ctx, req.GetImportTask().GetDatabaseName(), colInfo.GetCollectionName()) - if err != nil { - return returnFailFunc("failed to get partition id list", err) - } - - partitionIDs, err := importutil.DeduceTargetPartitions(partitions, colInfo.GetSchema(), req.GetImportTask().GetPartitionId()) - if err != nil { - return returnFailFunc("failed to decude target partitions", err) + var partitionIDs []int64 + if req.GetImportTask().GetPartitionId() == 0 { + if !typeutil.HasPartitionKey(colInfo.GetSchema()) { + err = errors.New("try auto-distribute data but the collection has no partition key") + return returnFailFunc(err.Error(), err) + } + // TODO: prefer to set partitionIDs in coord instead of get here. + // the colInfo doesn't have a correct database name(it is empty). use the database name passed from rootcoord. + partitions, err := node.getPartitions(ctx, req.GetImportTask().GetDatabaseName(), colInfo.GetCollectionName()) + if err != nil { + return returnFailFunc("failed to get partition id list", err) + } + _, partitionIDs, err = typeutil.RearrangePartitionsForPartitionKey(partitions) + if err != nil { + return returnFailFunc("failed to rearrange target partitions", err) + } + } else { + partitionIDs = []int64{req.GetImportTask().GetPartitionId()} } collectionInfo, err := importutil.NewCollectionInfo(colInfo.GetSchema(), colInfo.GetShardsNum(), partitionIDs) diff --git a/internal/datanode/services_test.go b/internal/datanode/services_test.go index f6ef782b52..28cea854fb 100644 --- a/internal/datanode/services_test.go +++ b/internal/datanode/services_test.go @@ -455,6 +455,18 @@ func (s *DataNodeServicesSuite) TestImport() { s.Assert().NoError(err) s.Assert().True(merr.Ok(stat)) s.Assert().Equal("", stat.GetReason()) + + reqWithoutPartition := &datapb.ImportTaskRequest{ + ImportTask: &datapb.ImportTask{ + CollectionId: 100, + ChannelNames: []string{chName1, chName2}, + Files: []string{filePath}, + RowBased: true, + }, + } + stat2, err := s.node.Import(context.WithValue(s.ctx, ctxKey{}, ""), reqWithoutPartition) + s.Assert().NoError(err) + s.Assert().False(merr.Ok(stat2)) }) s.Run("Test Import bad flow graph", func() { diff --git a/internal/rootcoord/root_coord.go b/internal/rootcoord/root_coord.go index e770b9d750..3ffc8c7d73 100644 --- a/internal/rootcoord/root_coord.go +++ b/internal/rootcoord/root_coord.go @@ -1847,20 +1847,7 @@ func (c *Core) Import(ctx context.Context, req *milvuspb.ImportRequest) (*milvus return nil, err } - // Backup tool call import must with a partition name, each time restore a partition isBackUp := importutil.IsBackup(req.GetOptions()) - if isBackUp { - if len(req.GetPartitionName()) == 0 { - log.Info("partition name not specified when backup recovery", - zap.String("collectionName", req.GetCollectionName())) - ret := &milvuspb.ImportResponse{ - Status: failStatus(commonpb.ErrorCode_UnexpectedError, - "partition name not specified when backup"), - } - return ret, nil - } - } - cID := colInfo.CollectionID req.ChannelNames = c.meta.GetCollectionVirtualChannels(cID) @@ -1872,24 +1859,45 @@ func (c *Core) Import(ctx context.Context, req *milvuspb.ImportRequest) (*milvus } } - // If has partition key and not backup/restore mode, don't allow user to specify partition name - if hasPartitionKey && !isBackUp && req.GetPartitionName() != "" { - msg := "not allow to set partition name for collection with partition key" - log.Warn(msg, zap.String("collection name", req.GetCollectionName())) - return nil, errors.New(msg) - } - // Get partition ID by partition name var pID UniqueID - if !hasPartitionKey { - if req.GetPartitionName() == "" { - req.PartitionName = Params.CommonCfg.DefaultPartitionName.GetValue() + if isBackUp { + // Currently, Backup tool call import must with a partition name, each time restore a partition + if req.GetPartitionName() != "" { + if pID, err = c.meta.GetPartitionByName(cID, req.GetPartitionName(), typeutil.MaxTimestamp); err != nil { + log.Warn("failed to get partition ID from its name", zap.String("partition name", req.GetPartitionName()), zap.Error(err)) + return &milvuspb.ImportResponse{ + Status: merr.Status(merr.WrapBulkInsertPartitionNotFound(req.GetCollectionName(), req.GetPartitionName())), + }, nil + } + } else { + log.Info("partition name not specified when backup recovery", + zap.String("collectionName", req.GetCollectionName())) + return &milvuspb.ImportResponse{ + Status: merr.Status(merr.WrapBadBulkInsertRequest("partition name not specified when backup")), + }, nil } - if pID, err = c.meta.GetPartitionByName(cID, req.GetPartitionName(), typeutil.MaxTimestamp); err != nil { - log.Warn("failed to get partition ID from its name", - zap.String("partition name", req.GetPartitionName()), - zap.Error(err)) - return nil, err + } else { + if hasPartitionKey { + if req.GetPartitionName() != "" { + msg := "not allow to set partition name for collection with partition key" + log.Warn(msg, zap.String("collection name", req.GetCollectionName())) + return &milvuspb.ImportResponse{ + Status: merr.Status(merr.WrapBadBulkInsertRequest(msg)), + }, nil + } + } else { + if req.GetPartitionName() == "" { + req.PartitionName = Params.CommonCfg.DefaultPartitionName.GetValue() + } + if pID, err = c.meta.GetPartitionByName(cID, req.GetPartitionName(), typeutil.MaxTimestamp); err != nil { + log.Warn("failed to get partition ID from its name", + zap.String("partition name", req.GetPartitionName()), + zap.Error(err)) + return &milvuspb.ImportResponse{ + Status: merr.Status(merr.WrapBulkInsertPartitionNotFound(req.GetCollectionName(), req.GetPartitionName())), + }, nil + } } } diff --git a/internal/rootcoord/root_coord_test.go b/internal/rootcoord/root_coord_test.go index 93a5412eea..3ba037c6aa 100644 --- a/internal/rootcoord/root_coord_test.go +++ b/internal/rootcoord/root_coord_test.go @@ -19,6 +19,7 @@ package rootcoord import ( "context" "fmt" + "github.com/milvus-io/milvus/pkg/common" "math/rand" "os" "sync" @@ -1052,10 +1053,11 @@ func TestCore_Import(t *testing.T) { meta.GetPartitionByNameFunc = func(collID UniqueID, partitionName string, ts Timestamp) (UniqueID, error) { return 0, errors.New("mock GetPartitionByNameFunc error") } - _, err := c.Import(ctx, &milvuspb.ImportRequest{ + resp, err := c.Import(ctx, &milvuspb.ImportRequest{ CollectionName: "a-good-name", }) - assert.Error(t, err) + assert.NoError(t, err) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrBulkInsertPartitionNotFound) }) t.Run("normal case", func(t *testing.T) { @@ -1099,7 +1101,7 @@ func TestCore_Import(t *testing.T) { }, }) assert.NotNil(t, resp) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrBadBulkInsertRequest) }) // Remove the following case after bulkinsert can support partition key @@ -1152,11 +1154,69 @@ func TestCore_Import(t *testing.T) { meta.GetCollectionByNameFunc = func(ctx context.Context, collectionName string, ts Timestamp) (*model.Collection, error) { return coll.Clone(), nil } - _, err := c.Import(ctx, &milvuspb.ImportRequest{ + resp, err := c.Import(ctx, &milvuspb.ImportRequest{ CollectionName: "a-good-name", PartitionName: "p1", }) - assert.Error(t, err) + assert.NoError(t, err) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrBadBulkInsertRequest) + }) + + t.Run("backup should set partition name", func(t *testing.T) { + ctx := context.Background() + c := newTestCore(withHealthyCode(), + withMeta(meta)) + meta.GetCollectionIDByNameFunc = func(name string) (UniqueID, error) { + return 100, nil + } + meta.GetCollectionVirtualChannelsFunc = func(colID int64) []string { + return []string{"ch-1", "ch-2"} + } + meta.GetPartitionByNameFunc = func(collID UniqueID, partitionName string, ts Timestamp) (UniqueID, error) { + return 101, nil + } + coll := &model.Collection{ + CollectionID: 100, + Name: "a-good-name", + Fields: []*model.Field{ + { + FieldID: 101, + Name: "test_field_name_1", + IsPrimaryKey: false, + IsPartitionKey: true, + DataType: schemapb.DataType_Int64, + }, + }, + } + meta.GetCollectionByNameFunc = func(ctx context.Context, collectionName string, ts Timestamp) (*model.Collection, error) { + return coll.Clone(), nil + } + resp1, err := c.Import(ctx, &milvuspb.ImportRequest{ + CollectionName: "a-good-name", + Options: []*commonpb.KeyValuePair{ + { + Key: importutil.BackupFlag, + Value: "true", + }, + }, + }) + assert.NoError(t, err) + assert.ErrorIs(t, merr.Error(resp1.GetStatus()), merr.ErrBadBulkInsertRequest) + + meta.GetPartitionByNameFunc = func(collID UniqueID, partitionName string, ts Timestamp) (UniqueID, error) { + return common.InvalidPartitionID, fmt.Errorf("partition ID not found for partition name '%s'", partitionName) + } + resp2, err := c.Import(ctx, &milvuspb.ImportRequest{ + CollectionName: "a-good-name", + PartitionName: "a-bad-name", + Options: []*commonpb.KeyValuePair{ + { + Key: importutil.BackupFlag, + Value: "true", + }, + }, + }) + assert.ErrorIs(t, merr.Error(resp2.GetStatus()), merr.ErrBulkInsertPartitionNotFound) }) } diff --git a/internal/util/importutil/binlog_adapter.go b/internal/util/importutil/binlog_adapter.go index a257ace5d9..b06428e587 100644 --- a/internal/util/importutil/binlog_adapter.go +++ b/internal/util/importutil/binlog_adapter.go @@ -81,12 +81,6 @@ func NewBinlogAdapter(ctx context.Context, return nil, errors.New("collection schema is nil") } - // binlog import doesn't support partition key, the caller must specify one partition for importing - if len(collectionInfo.PartitionIDs) != 1 { - log.Warn("Binlog adapter: target partition must be only one", zap.Int("partitions", len(collectionInfo.PartitionIDs))) - return nil, errors.New("target partition must be only one") - } - if chunkManager == nil { log.Warn("Binlog adapter: chunk manager pointer is nil") return nil, errors.New("chunk manager pointer is nil") diff --git a/internal/util/importutil/collection_info.go b/internal/util/importutil/collection_info.go index 00a38e5190..f7fc31270e 100644 --- a/internal/util/importutil/collection_info.go +++ b/internal/util/importutil/collection_info.go @@ -23,14 +23,13 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/pkg/util/typeutil" ) type CollectionInfo struct { Schema *schemapb.CollectionSchema ShardNum int32 - PartitionIDs []int64 // target partitions of bulkinsert, one partition for non-partition-key collection, or all partiitons for partition-key collection + PartitionIDs []int64 // target partitions of bulkinsert PrimaryKey *schemapb.FieldSchema PartitionKey *schemapb.FieldSchema @@ -39,20 +38,6 @@ type CollectionInfo struct { Name2FieldID map[string]int64 // this member is for Numpy file name validation and JSON row validation } -func DeduceTargetPartitions(partitions map[string]int64, collectionSchema *schemapb.CollectionSchema, defaultPartition int64) ([]int64, error) { - // if no partition key, rutrn the default partition ID as target partition - _, err := typeutil.GetPartitionKeyFieldSchema(collectionSchema) - if err != nil { - return []int64{defaultPartition}, nil - } - - _, partitionIDs, err := typeutil.RearrangePartitionsForPartitionKey(partitions) - if err != nil { - return nil, err - } - return partitionIDs, nil -} - func NewCollectionInfo(collectionSchema *schemapb.CollectionSchema, shardNum int32, partitionIDs []int64, diff --git a/internal/util/importutil/collection_info_test.go b/internal/util/importutil/collection_info_test.go index 3ae97699eb..71994e6b74 100644 --- a/internal/util/importutil/collection_info_test.go +++ b/internal/util/importutil/collection_info_test.go @@ -23,30 +23,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) -func Test_DeduceTargetPartitions(t *testing.T) { - schema := sampleSchema() - partitions := map[string]int64{ - "part_0": 100, - "part_1": 200, - } - partitionIDs, err := DeduceTargetPartitions(partitions, schema, int64(1)) - assert.NoError(t, err) - assert.Equal(t, 1, len(partitionIDs)) - assert.Equal(t, int64(1), partitionIDs[0]) - - schema.Fields[7].IsPartitionKey = true - partitionIDs, err = DeduceTargetPartitions(partitions, schema, int64(1)) - assert.NoError(t, err) - assert.Equal(t, len(partitions), len(partitionIDs)) - - partitions = map[string]int64{ - "part_a": 100, - } - partitionIDs, err = DeduceTargetPartitions(partitions, schema, int64(1)) - assert.Error(t, err) - assert.Nil(t, partitionIDs) -} - func Test_CollectionInfoNew(t *testing.T) { t.Run("succeed", func(t *testing.T) { info, err := NewCollectionInfo(sampleSchema(), 2, []int64{1}) diff --git a/pkg/util/merr/errors.go b/pkg/util/merr/errors.go index ca4b05b07d..7e9edcadad 100644 --- a/pkg/util/merr/errors.go +++ b/pkg/util/merr/errors.go @@ -115,6 +115,10 @@ var ( ErrInvalidSearchResult = newMilvusError("fail to parse search result", 1805, false) ErrCheckPrimaryKey = newMilvusError("please check the primary key and its' type can only in [int, string]", 1806, false) + // bulkinsert related + ErrBadBulkInsertRequest = newMilvusError("bad bulkinsert request", 1900, false) + ErrBulkInsertPartitionNotFound = newMilvusError("partition not found during bulkinsert", 1901, false) + // Segcore related ErrSegcore = newMilvusError("segcore error", 2000, false) diff --git a/pkg/util/merr/errors_test.go b/pkg/util/merr/errors_test.go index 40c74ba907..d841313364 100644 --- a/pkg/util/merr/errors_test.go +++ b/pkg/util/merr/errors_test.go @@ -133,6 +133,10 @@ func (s *ErrSuite) TestWrap() { // field related s.ErrorIs(WrapErrFieldNotFound("meta", "failed to get field"), ErrFieldNotFound) + + // bulkinsert related + s.ErrorIs(WrapBadBulkInsertRequest("fail reason"), ErrBadBulkInsertRequest) + s.ErrorIs(WrapBulkInsertPartitionNotFound("hello_milvus", "notexist"), ErrBulkInsertPartitionNotFound) } func (s *ErrSuite) TestOldCode() { diff --git a/pkg/util/merr/utils.go b/pkg/util/merr/utils.go index 1c53e7cd7e..6226658533 100644 --- a/pkg/util/merr/utils.go +++ b/pkg/util/merr/utils.go @@ -593,3 +593,11 @@ func WrapErrFieldNotFound[T any](field T, msg ...string) error { func wrapWithField(err error, name string, value any) error { return errors.Wrapf(err, "%s=%v", name, value) } + +func WrapBadBulkInsertRequest(msg ...string) error { + return errors.Wrap(ErrBadBulkInsertRequest, strings.Join(msg, "; ")) +} + +func WrapBulkInsertPartitionNotFound(collection any, partition any) error { + return errors.Wrapf(ErrBulkInsertPartitionNotFound, "collection=%s, partition=%s", collection, partition) +} diff --git a/pkg/util/typeutil/schema.go b/pkg/util/typeutil/schema.go index 6100ae49be..bacee9df8d 100644 --- a/pkg/util/typeutil/schema.go +++ b/pkg/util/typeutil/schema.go @@ -792,6 +792,16 @@ func GetPartitionKeyFieldSchema(schema *schemapb.CollectionSchema) (*schemapb.Fi return nil, errors.New("partition key field is not found") } +// HasPartitionKey check if a collection schema has PartitionKey field +func HasPartitionKey(schema *schemapb.CollectionSchema) bool { + for _, fieldSchema := range schema.Fields { + if fieldSchema.IsPartitionKey { + return true + } + } + return false +} + // GetPrimaryFieldData get primary field data from all field data inserted from sdk func GetPrimaryFieldData(datas []*schemapb.FieldData, primaryFieldSchema *schemapb.FieldSchema) (*schemapb.FieldData, error) { primaryFieldID := primaryFieldSchema.FieldID diff --git a/pkg/util/typeutil/schema_test.go b/pkg/util/typeutil/schema_test.go index ce0ac41b10..7a50acadd0 100644 --- a/pkg/util/typeutil/schema_test.go +++ b/pkg/util/typeutil/schema_test.go @@ -834,11 +834,16 @@ func TestGetPrimaryFieldSchema(t *testing.T) { // no primary field error _, err := GetPrimaryFieldSchema(schema) assert.Error(t, err) - int64Field.IsPrimaryKey = true primaryField, err := GetPrimaryFieldSchema(schema) assert.NoError(t, err) assert.Equal(t, schemapb.DataType_Int64, primaryField.DataType) + + hasPartitionKey := HasPartitionKey(schema) + assert.False(t, hasPartitionKey) + int64Field.IsPartitionKey = true + hasPartitionKey2 := HasPartitionKey(schema) + assert.True(t, hasPartitionKey2) } func TestGetPK(t *testing.T) {