mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-01 19:39:21 +08:00
fix: use the default partition for the limit quota when the request partition name is empty (#38005)
- issue: #37685 Signed-off-by: SimFG <bang.fu@zilliz.com>
This commit is contained in:
parent
49ee46ec1d
commit
302650ae0e
@ -264,6 +264,7 @@ type partitionInfo struct {
|
|||||||
partitionID typeutil.UniqueID
|
partitionID typeutil.UniqueID
|
||||||
createdTimestamp uint64
|
createdTimestamp uint64
|
||||||
createdUtcTimestamp uint64
|
createdUtcTimestamp uint64
|
||||||
|
isDefault bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (info *collectionInfo) isCollectionCached() bool {
|
func (info *collectionInfo) isCollectionCached() bool {
|
||||||
@ -427,12 +428,14 @@ func (m *MetaCache) update(ctx context.Context, database, collectionName string,
|
|||||||
return nil, merr.WrapErrParameterInvalidMsg("partition names and timestamps number is not aligned, response: %s", partitions.String())
|
return nil, merr.WrapErrParameterInvalidMsg("partition names and timestamps number is not aligned, response: %s", partitions.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defaultPartitionName := Params.CommonCfg.DefaultPartitionName.GetValue()
|
||||||
infos := lo.Map(partitions.GetPartitionIDs(), func(partitionID int64, idx int) *partitionInfo {
|
infos := lo.Map(partitions.GetPartitionIDs(), func(partitionID int64, idx int) *partitionInfo {
|
||||||
return &partitionInfo{
|
return &partitionInfo{
|
||||||
name: partitions.PartitionNames[idx],
|
name: partitions.PartitionNames[idx],
|
||||||
partitionID: partitions.PartitionIDs[idx],
|
partitionID: partitions.PartitionIDs[idx],
|
||||||
createdTimestamp: partitions.CreatedTimestamps[idx],
|
createdTimestamp: partitions.CreatedTimestamps[idx],
|
||||||
createdUtcTimestamp: partitions.CreatedUtcTimestamps[idx],
|
createdUtcTimestamp: partitions.CreatedUtcTimestamps[idx],
|
||||||
|
isDefault: partitions.PartitionNames[idx] == defaultPartitionName,
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -630,6 +633,14 @@ func (m *MetaCache) GetPartitionInfo(ctx context.Context, database, collectionNa
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if partitionName == "" {
|
||||||
|
for _, info := range partitions.partitionInfos {
|
||||||
|
if info.isDefault {
|
||||||
|
return info, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
info, ok := partitions.name2Info[partitionName]
|
info, ok := partitions.name2Info[partitionName]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, merr.WrapErrPartitionNotFound(partitionName)
|
return nil, merr.WrapErrPartitionNotFound(partitionName)
|
||||||
|
@ -84,7 +84,13 @@ func getCollectionAndPartitionID(ctx context.Context, r reqPartName) (int64, map
|
|||||||
return 0, nil, err
|
return 0, nil, err
|
||||||
}
|
}
|
||||||
if r.GetPartitionName() == "" {
|
if r.GetPartitionName() == "" {
|
||||||
return db.dbID, map[int64][]int64{collectionID: {}}, nil
|
collectionSchema, err := globalMetaCache.GetCollectionSchema(ctx, r.GetDbName(), r.GetCollectionName())
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, err
|
||||||
|
}
|
||||||
|
if collectionSchema.IsPartitionKeyCollection() {
|
||||||
|
return db.dbID, map[int64][]int64{collectionID: {}}, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
part, err := globalMetaCache.GetPartitionInfo(ctx, r.GetDbName(), r.GetCollectionName(), r.GetPartitionName())
|
part, err := globalMetaCache.GetPartitionInfo(ctx, r.GetDbName(), r.GetCollectionName(), r.GetPartitionName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -299,6 +299,7 @@ func TestRateLimitInterceptor(t *testing.T) {
|
|||||||
dbID: 100,
|
dbID: 100,
|
||||||
createdTimestamp: 1,
|
createdTimestamp: 1,
|
||||||
}, nil)
|
}, nil)
|
||||||
|
mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{}, nil)
|
||||||
globalMetaCache = mockCache
|
globalMetaCache = mockCache
|
||||||
|
|
||||||
limiter := limiterMock{rate: 100}
|
limiter := limiterMock{rate: 100}
|
||||||
@ -437,6 +438,41 @@ func TestGetInfo(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("fail to get collection schema", func(t *testing.T) {
|
||||||
|
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
|
||||||
|
dbID: 100,
|
||||||
|
createdTimestamp: 1,
|
||||||
|
}, nil).Once()
|
||||||
|
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil).Once()
|
||||||
|
mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("mock error")).Once()
|
||||||
|
|
||||||
|
_, _, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
|
||||||
|
DbName: "foo",
|
||||||
|
CollectionName: "coo",
|
||||||
|
})
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("partition key mode", func(t *testing.T) {
|
||||||
|
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
|
||||||
|
dbID: 100,
|
||||||
|
createdTimestamp: 1,
|
||||||
|
}, nil).Once()
|
||||||
|
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil).Once()
|
||||||
|
mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{
|
||||||
|
hasPartitionKeyField: true,
|
||||||
|
}, nil).Once()
|
||||||
|
|
||||||
|
db, col2par, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
|
||||||
|
DbName: "foo",
|
||||||
|
CollectionName: "coo",
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(100), db)
|
||||||
|
assert.NotNil(t, col2par[1])
|
||||||
|
assert.Equal(t, 0, len(col2par[1]))
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("fail to get partition", func(t *testing.T) {
|
t.Run("fail to get partition", func(t *testing.T) {
|
||||||
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
|
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
|
||||||
dbID: 100,
|
dbID: 100,
|
||||||
@ -467,11 +503,12 @@ func TestGetInfo(t *testing.T) {
|
|||||||
dbID: 100,
|
dbID: 100,
|
||||||
createdTimestamp: 1,
|
createdTimestamp: 1,
|
||||||
}, nil).Times(3)
|
}, nil).Times(3)
|
||||||
|
mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{}, nil).Times(1)
|
||||||
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(10), nil).Times(3)
|
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(10), nil).Times(3)
|
||||||
mockCache.EXPECT().GetPartitionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&partitionInfo{
|
mockCache.EXPECT().GetPartitionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&partitionInfo{
|
||||||
name: "p1",
|
name: "p1",
|
||||||
partitionID: 100,
|
partitionID: 100,
|
||||||
}, nil).Twice()
|
}, nil).Times(3)
|
||||||
{
|
{
|
||||||
db, col2par, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
|
db, col2par, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
|
||||||
DbName: "foo",
|
DbName: "foo",
|
||||||
@ -491,7 +528,7 @@ func TestGetInfo(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, int64(100), db)
|
assert.Equal(t, int64(100), db)
|
||||||
assert.NotNil(t, col2par[10])
|
assert.NotNil(t, col2par[10])
|
||||||
assert.Equal(t, 0, len(col2par[10]))
|
assert.Equal(t, int64(100), col2par[10][0])
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
db, col2par, err := getCollectionAndPartitionIDs(ctx, &milvuspb.SearchRequest{
|
db, col2par, err := getCollectionAndPartitionIDs(ctx, &milvuspb.SearchRequest{
|
||||||
|
@ -202,7 +202,12 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
|
|||||||
// insert to _default partition
|
// insert to _default partition
|
||||||
partitionTag := it.insertMsg.GetPartitionName()
|
partitionTag := it.insertMsg.GetPartitionName()
|
||||||
if len(partitionTag) <= 0 {
|
if len(partitionTag) <= 0 {
|
||||||
partitionTag = Params.CommonCfg.DefaultPartitionName.GetValue()
|
pinfo, err := globalMetaCache.GetPartitionInfo(ctx, it.insertMsg.GetDbName(), collectionName, "")
|
||||||
|
if err != nil {
|
||||||
|
log.Warn("get partition info failed", zap.String("collectionName", collectionName), zap.Error(err))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
partitionTag = pinfo.name
|
||||||
it.insertMsg.PartitionName = partitionTag
|
it.insertMsg.PartitionName = partitionTag
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3651,6 +3651,204 @@ func TestPartitionKey(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDefaultPartition(t *testing.T) {
|
||||||
|
rc := NewRootCoordMock()
|
||||||
|
|
||||||
|
defer rc.Close()
|
||||||
|
qc := getQueryCoordClient()
|
||||||
|
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
mgr := newShardClientMgr()
|
||||||
|
err := InitMetaCache(ctx, rc, qc, mgr)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
shardsNum := common.DefaultShardsNum
|
||||||
|
prefix := "TestInsertTaskWithPartitionKey"
|
||||||
|
collectionName := prefix + funcutil.GenRandomStr()
|
||||||
|
|
||||||
|
fieldName2Type := make(map[string]schemapb.DataType)
|
||||||
|
fieldName2Type["int64_field"] = schemapb.DataType_Int64
|
||||||
|
fieldName2Type["varChar_field"] = schemapb.DataType_VarChar
|
||||||
|
fieldName2Type["fvec_field"] = schemapb.DataType_FloatVector
|
||||||
|
schema := constructCollectionSchemaByDataType(collectionName, fieldName2Type, "int64_field", false)
|
||||||
|
marshaledSchema, err := proto.Marshal(schema)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
t.Run("create collection", func(t *testing.T) {
|
||||||
|
createCollectionTask := &createCollectionTask{
|
||||||
|
Condition: NewTaskCondition(ctx),
|
||||||
|
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
|
||||||
|
Base: &commonpb.MsgBase{
|
||||||
|
MsgID: UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()),
|
||||||
|
Timestamp: Timestamp(time.Now().UnixNano()),
|
||||||
|
},
|
||||||
|
DbName: "",
|
||||||
|
CollectionName: collectionName,
|
||||||
|
Schema: marshaledSchema,
|
||||||
|
ShardsNum: shardsNum,
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
rootCoord: rc,
|
||||||
|
result: nil,
|
||||||
|
schema: nil,
|
||||||
|
}
|
||||||
|
err = createCollectionTask.PreExecute(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
err = createCollectionTask.Execute(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
|
||||||
|
factory := newSimpleMockMsgStreamFactory()
|
||||||
|
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
|
||||||
|
defer chMgr.removeAllDMLStream()
|
||||||
|
|
||||||
|
_, err = chMgr.getOrCreateDmlStream(collectionID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
pchans, err := chMgr.getChannels(collectionID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
interval := time.Millisecond * 10
|
||||||
|
tso := newMockTsoAllocator()
|
||||||
|
|
||||||
|
ticker := newChannelsTimeTicker(ctx, interval, []string{}, newGetStatisticsFunc(pchans), tso)
|
||||||
|
_ = ticker.start()
|
||||||
|
defer ticker.close()
|
||||||
|
|
||||||
|
idAllocator, err := allocator.NewIDAllocator(ctx, rc, paramtable.GetNodeID())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
_ = idAllocator.Start()
|
||||||
|
defer idAllocator.Close()
|
||||||
|
|
||||||
|
segAllocator, err := newSegIDAssigner(ctx, &mockDataCoord{expireTime: Timestamp(2500)}, getLastTick1)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
segAllocator.Init()
|
||||||
|
_ = segAllocator.Start()
|
||||||
|
defer segAllocator.Close()
|
||||||
|
|
||||||
|
nb := 10
|
||||||
|
fieldID := common.StartOfUserFieldID
|
||||||
|
fieldDatas := make([]*schemapb.FieldData, 0)
|
||||||
|
for fieldName, dataType := range fieldName2Type {
|
||||||
|
fieldData := generateFieldData(dataType, fieldName, nb)
|
||||||
|
fieldData.FieldId = int64(fieldID)
|
||||||
|
fieldDatas = append(fieldDatas, generateFieldData(dataType, fieldName, nb))
|
||||||
|
fieldID++
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Insert", func(t *testing.T) {
|
||||||
|
it := &insertTask{
|
||||||
|
insertMsg: &BaseInsertTask{
|
||||||
|
BaseMsg: msgstream.BaseMsg{},
|
||||||
|
InsertRequest: &msgpb.InsertRequest{
|
||||||
|
Base: &commonpb.MsgBase{
|
||||||
|
MsgType: commonpb.MsgType_Insert,
|
||||||
|
MsgID: 0,
|
||||||
|
SourceID: paramtable.GetNodeID(),
|
||||||
|
},
|
||||||
|
CollectionName: collectionName,
|
||||||
|
FieldsData: fieldDatas,
|
||||||
|
NumRows: uint64(nb),
|
||||||
|
Version: msgpb.InsertDataVersion_ColumnBased,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
Condition: NewTaskCondition(ctx),
|
||||||
|
ctx: ctx,
|
||||||
|
result: &milvuspb.MutationResult{
|
||||||
|
Status: merr.Success(),
|
||||||
|
IDs: nil,
|
||||||
|
SuccIndex: nil,
|
||||||
|
ErrIndex: nil,
|
||||||
|
Acknowledged: false,
|
||||||
|
InsertCnt: 0,
|
||||||
|
DeleteCnt: 0,
|
||||||
|
UpsertCnt: 0,
|
||||||
|
Timestamp: 0,
|
||||||
|
},
|
||||||
|
idAllocator: idAllocator,
|
||||||
|
segIDAssigner: segAllocator,
|
||||||
|
chMgr: chMgr,
|
||||||
|
chTicker: ticker,
|
||||||
|
vChannels: nil,
|
||||||
|
pChannels: nil,
|
||||||
|
schema: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
it.insertMsg.PartitionName = ""
|
||||||
|
assert.NoError(t, it.OnEnqueue())
|
||||||
|
assert.NoError(t, it.PreExecute(ctx))
|
||||||
|
assert.NoError(t, it.Execute(ctx))
|
||||||
|
assert.NoError(t, it.PostExecute(ctx))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Upsert", func(t *testing.T) {
|
||||||
|
hash := testutils.GenerateHashKeys(nb)
|
||||||
|
ut := &upsertTask{
|
||||||
|
ctx: ctx,
|
||||||
|
Condition: NewTaskCondition(ctx),
|
||||||
|
baseMsg: msgstream.BaseMsg{
|
||||||
|
HashValues: hash,
|
||||||
|
},
|
||||||
|
req: &milvuspb.UpsertRequest{
|
||||||
|
Base: commonpbutil.NewMsgBase(
|
||||||
|
commonpbutil.WithMsgType(commonpb.MsgType_Upsert),
|
||||||
|
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
||||||
|
),
|
||||||
|
CollectionName: collectionName,
|
||||||
|
FieldsData: fieldDatas,
|
||||||
|
NumRows: uint32(nb),
|
||||||
|
},
|
||||||
|
|
||||||
|
result: &milvuspb.MutationResult{
|
||||||
|
Status: merr.Success(),
|
||||||
|
IDs: &schemapb.IDs{
|
||||||
|
IdField: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
idAllocator: idAllocator,
|
||||||
|
segIDAssigner: segAllocator,
|
||||||
|
chMgr: chMgr,
|
||||||
|
chTicker: ticker,
|
||||||
|
}
|
||||||
|
|
||||||
|
ut.req.PartitionName = ""
|
||||||
|
assert.NoError(t, ut.OnEnqueue())
|
||||||
|
assert.NoError(t, ut.PreExecute(ctx))
|
||||||
|
assert.NoError(t, ut.Execute(ctx))
|
||||||
|
assert.NoError(t, ut.PostExecute(ctx))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("delete", func(t *testing.T) {
|
||||||
|
dt := &deleteTask{
|
||||||
|
Condition: NewTaskCondition(ctx),
|
||||||
|
req: &milvuspb.DeleteRequest{
|
||||||
|
CollectionName: collectionName,
|
||||||
|
Expr: "int64_field in [0, 1]",
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
primaryKeys: &schemapb.IDs{
|
||||||
|
IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{0, 1}}},
|
||||||
|
},
|
||||||
|
idAllocator: idAllocator,
|
||||||
|
chMgr: chMgr,
|
||||||
|
chTicker: ticker,
|
||||||
|
collectionID: collectionID,
|
||||||
|
vChannels: []string{"test-channel"},
|
||||||
|
}
|
||||||
|
|
||||||
|
dt.req.PartitionName = ""
|
||||||
|
assert.NoError(t, dt.PreExecute(ctx))
|
||||||
|
assert.NoError(t, dt.Execute(ctx))
|
||||||
|
assert.NoError(t, dt.PostExecute(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestClusteringKey(t *testing.T) {
|
func TestClusteringKey(t *testing.T) {
|
||||||
rc := NewRootCoordMock()
|
rc := NewRootCoordMock()
|
||||||
|
|
||||||
|
@ -317,8 +317,12 @@ func (it *upsertTask) PreExecute(ctx context.Context) error {
|
|||||||
// insert to _default partition
|
// insert to _default partition
|
||||||
partitionTag := it.req.GetPartitionName()
|
partitionTag := it.req.GetPartitionName()
|
||||||
if len(partitionTag) <= 0 {
|
if len(partitionTag) <= 0 {
|
||||||
partitionTag = Params.CommonCfg.DefaultPartitionName.GetValue()
|
pinfo, err := globalMetaCache.GetPartitionInfo(ctx, it.req.GetDbName(), collectionName, "")
|
||||||
it.req.PartitionName = partitionTag
|
if err != nil {
|
||||||
|
log.Warn("get partition info failed", zap.String("collectionName", collectionName), zap.Error(err))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
it.req.PartitionName = pinfo.name
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user