[Cherry-Pick] Support Database (#24769)

Support Database(#23742)
Fix nonexistent db error for FlushAll (#24222)
Fix check collection limits fails (#24235)
backward compatibility with empty DB name (#24317)
Fix GetFlushAllState with DB (#24347)
Remove db from global meta cache after drop database (#24474)
Fix db name is empty for describe collection response (#24603)
Add RBAC for Database API (#24653)
Fix missed loading of same-name collections during the recovery stage (#24941)

RBAC supports Database validation (#23609)
Fix listing grants with db returning empty results (#23922)
Optimize PrivilegeAll permission check (#23972)
Add the default db value for the rbac request (#24307)

Signed-off-by: jaime <yun.zhang@zilliz.com>
Co-authored-by: SimFG <bang.fu@zilliz.com>
Co-authored-by: longjiquan <jiquan.long@zilliz.com>
This commit is contained in:
jaime 2023-06-25 17:20:43 +08:00 committed by GitHub
parent 23492fed99
commit 18df2ba6fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
148 changed files with 8232 additions and 3029 deletions

View File

@ -345,8 +345,6 @@ generate-mockery: getdeps
$(PWD)/bin/mockery --name=Store --dir=$(PWD)/internal/querycoordv2/meta --output=$(PWD)/internal/querycoordv2/meta --filename=mock_store.go --with-expecter --structname=MockStore --outpkg=meta --inpackage $(PWD)/bin/mockery --name=Store --dir=$(PWD)/internal/querycoordv2/meta --output=$(PWD)/internal/querycoordv2/meta --filename=mock_store.go --with-expecter --structname=MockStore --outpkg=meta --inpackage
$(PWD)/bin/mockery --name=Balance --dir=$(PWD)/internal/querycoordv2/balance --output=$(PWD)/internal/querycoordv2/balance --filename=mock_balancer.go --with-expecter --structname=MockBalancer --outpkg=balance --inpackage $(PWD)/bin/mockery --name=Balance --dir=$(PWD)/internal/querycoordv2/balance --output=$(PWD)/internal/querycoordv2/balance --filename=mock_balancer.go --with-expecter --structname=MockBalancer --outpkg=balance --inpackage
$(PWD)/bin/mockery --name=Controller --dir=$(PWD)/internal/querycoordv2/dist --output=$(PWD)/internal/querycoordv2/dist --filename=mock_controller.go --with-expecter --structname=MockController --outpkg=dist --inpackage $(PWD)/bin/mockery --name=Controller --dir=$(PWD)/internal/querycoordv2/dist --output=$(PWD)/internal/querycoordv2/dist --filename=mock_controller.go --with-expecter --structname=MockController --outpkg=dist --inpackage
# internal/querynode
$(PWD)/bin/mockery --name=TSafeReplicaInterface --dir=$(PWD)/internal/querynode --output=$(PWD)/internal/querynode --filename=mock_tsafe_replica_test.go --with-expecter --structname=MockTSafeReplicaInterface --outpkg=querynode --inpackage
# internal/rootcoord # internal/rootcoord
$(PWD)/bin/mockery --name=IMetaTable --dir=$(PWD)/internal/rootcoord --output=$(PWD)/internal/rootcoord/mocks --filename=meta_table.go --with-expecter --outpkg=mockrootcoord $(PWD)/bin/mockery --name=IMetaTable --dir=$(PWD)/internal/rootcoord --output=$(PWD)/internal/rootcoord/mocks --filename=meta_table.go --with-expecter --outpkg=mockrootcoord
$(PWD)/bin/mockery --name=GarbageCollector --dir=$(PWD)/internal/rootcoord --output=$(PWD)/internal/rootcoord/mocks --filename=garbage_collector.go --with-expecter --outpkg=mockrootcoord $(PWD)/bin/mockery --name=GarbageCollector --dir=$(PWD)/internal/rootcoord --output=$(PWD)/internal/rootcoord/mocks --filename=garbage_collector.go --with-expecter --outpkg=mockrootcoord

View File

@ -3,21 +3,22 @@ package meta
import ( import (
"fmt" "fmt"
"github.com/milvus-io/milvus/cmd/tools/migration/legacy"
"github.com/milvus-io/milvus/cmd/tools/migration/legacy/legacypb"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus/internal/metastore/kv/rootcoord"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/cmd/tools/migration/legacy"
"github.com/milvus-io/milvus/cmd/tools/migration/legacy/legacypb"
"github.com/milvus-io/milvus/internal/metastore/kv/rootcoord"
"github.com/milvus-io/milvus/internal/metastore/model"
pb "github.com/milvus-io/milvus/internal/proto/etcdpb" pb "github.com/milvus-io/milvus/internal/proto/etcdpb"
"github.com/milvus-io/milvus/pkg/util"
) )
type FieldIndexesWithSchema struct { type FieldIndexesWithSchema struct {
indexes []*pb.FieldIndexInfo indexes []*pb.FieldIndexInfo
schema *schemapb.CollectionSchema schema *schemapb.CollectionSchema
} }
type FieldIndexes210 map[UniqueID]*FieldIndexesWithSchema // coll_id -> field indexes. type FieldIndexes210 map[UniqueID]*FieldIndexesWithSchema // coll_id -> field indexes.
type TtCollectionsMeta210 map[UniqueID]map[Timestamp]*pb.CollectionInfo // coll_id -> ts -> coll type TtCollectionsMeta210 map[UniqueID]map[Timestamp]*pb.CollectionInfo // coll_id -> ts -> coll
@ -163,7 +164,7 @@ func (meta *TtCollectionsMeta210) GenerateSaves() map[string]string {
var err error var err error
for collection := range *meta { for collection := range *meta {
for ts := range (*meta)[collection] { for ts := range (*meta)[collection] {
k := rootcoord.ComposeSnapshotKey(rootcoord.SnapshotPrefix, rootcoord.BuildCollectionKey(collection), rootcoord.SnapshotsSep, ts) k := rootcoord.ComposeSnapshotKey(rootcoord.SnapshotPrefix, rootcoord.BuildCollectionKey(util.NonDBID, collection), rootcoord.SnapshotsSep, ts)
record := (*meta)[collection][ts] record := (*meta)[collection][ts]
if record == nil { if record == nil {
v = rootcoord.ConstructTombstone() v = rootcoord.ConstructTombstone()
@ -189,7 +190,7 @@ func (meta *CollectionsMeta210) GenerateSaves() map[string]string {
var err error var err error
for collection := range *meta { for collection := range *meta {
record := (*meta)[collection] record := (*meta)[collection]
k := rootcoord.BuildCollectionKey(collection) k := rootcoord.BuildCollectionKey(util.NonDBID, collection)
if record == nil { if record == nil {
v = rootcoord.ConstructTombstone() v = rootcoord.ConstructTombstone()
} else { } else {

View File

@ -3,12 +3,14 @@ package meta
import ( import (
"github.com/blang/semver/v4" "github.com/blang/semver/v4"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus/cmd/tools/migration/versions" "github.com/milvus-io/milvus/cmd/tools/migration/versions"
"github.com/milvus-io/milvus/internal/metastore/kv/datacoord" "github.com/milvus-io/milvus/internal/metastore/kv/datacoord"
"github.com/milvus-io/milvus/internal/metastore/kv/querycoord" "github.com/milvus-io/milvus/internal/metastore/kv/querycoord"
"github.com/milvus-io/milvus/internal/metastore/kv/rootcoord" "github.com/milvus-io/milvus/internal/metastore/kv/rootcoord"
"github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/pkg/util"
) )
type TtCollectionsMeta220 map[UniqueID]map[Timestamp]*model.Collection // coll_id -> ts -> coll type TtCollectionsMeta220 map[UniqueID]map[Timestamp]*model.Collection // coll_id -> ts -> coll
@ -40,7 +42,7 @@ func (meta *TtCollectionsMeta220) GenerateSaves(sourceVersion semver.Version) (m
for collectionID := range *meta { for collectionID := range *meta {
for ts := range (*meta)[collectionID] { for ts := range (*meta)[collectionID] {
ckey := rootcoord.BuildCollectionKey(collectionID) ckey := rootcoord.BuildCollectionKey(util.NonDBID, collectionID)
key := rootcoord.ComposeSnapshotKey(rootcoord.SnapshotPrefix, ckey, rootcoord.SnapshotsSep, ts) key := rootcoord.ComposeSnapshotKey(rootcoord.SnapshotPrefix, ckey, rootcoord.SnapshotsSep, ts)
collection := (*meta)[collectionID][ts] collection := (*meta)[collectionID][ts]
var value string var value string
@ -87,7 +89,7 @@ func (meta *CollectionsMeta220) GenerateSaves(sourceVersion semver.Version) (map
} }
for collectionID := range *meta { for collectionID := range *meta {
ckey := rootcoord.BuildCollectionKey(collectionID) ckey := rootcoord.BuildCollectionKey(util.NonDBID, collectionID)
collection := (*meta)[collectionID] collection := (*meta)[collectionID]
var value string var value string
if collection == nil { if collection == nil {

View File

@ -143,6 +143,7 @@ natsmq:
# Related configuration of rootCoord, used to handle data definition language (DDL) and data control language (DCL) requests # Related configuration of rootCoord, used to handle data definition language (DDL) and data control language (DCL) requests
rootCoord: rootCoord:
dmlChannelNum: 16 # The number of dml channels created at system startup dmlChannelNum: 16 # The number of dml channels created at system startup
maxDatabaseNum: 64 # Maximum number of databases
maxPartitionNum: 4096 # Maximum number of partitions in a collection maxPartitionNum: 4096 # Maximum number of partitions in a collection
minSegmentSizeToEnableIndex: 1024 # It's a threshold. When the segment size is less than this value, the segment will not be indexed minSegmentSizeToEnableIndex: 1024 # It's a threshold. When the segment size is less than this value, the segment will not be indexed
importTaskExpiration: 900 # (in seconds) Duration after which an import task will expire (be killed). Default 900 seconds (15 minutes). importTaskExpiration: 900 # (in seconds) Duration after which an import task will expire (be killed). Default 900 seconds (15 minutes).
@ -461,7 +462,7 @@ quotaAndLimits:
enabled: true # `true` to enable quota and limits, `false` to disable. enabled: true # `true` to enable quota and limits, `false` to disable.
limits: limits:
maxCollectionNum: 65536 maxCollectionNum: 65536
maxCollectionNumPerDB: 64 maxCollectionNumPerDB: 65536
# quotaCenterCollectInterval is the time interval that quotaCenter # quotaCenterCollectInterval is the time interval that quotaCenter
# collects metrics from Proxies, Query cluster and Data cluster. # collects metrics from Proxies, Query cluster and Data cluster.
# seconds, (0 ~ 65536) # seconds, (0 ~ 65536)

View File

@ -144,7 +144,7 @@ func (_c *NMockHandler_GetCollection_Call) Return(_a0 *collectionInfo, _a1 error
return _c return _c
} }
// GetDataVChanPositions provides a mock function with given fields: channel, partitionID // GetDataVChanPositions provides a mock function with given fields: ch, partitionID
func (_m *NMockHandler) GetDataVChanPositions(ch *channel, partitionID int64) *datapb.VchannelInfo { func (_m *NMockHandler) GetDataVChanPositions(ch *channel, partitionID int64) *datapb.VchannelInfo {
ret := _m.Called(ch, partitionID) ret := _m.Called(ch, partitionID)
@ -166,13 +166,13 @@ type NMockHandler_GetDataVChanPositions_Call struct {
} }
// GetDataVChanPositions is a helper method to define mock.On call // GetDataVChanPositions is a helper method to define mock.On call
// - channel *channel // - ch *channel
// - partitionID int64 // - partitionID int64
func (_e *NMockHandler_Expecter) GetDataVChanPositions(channel interface{}, partitionID interface{}) *NMockHandler_GetDataVChanPositions_Call { func (_e *NMockHandler_Expecter) GetDataVChanPositions(ch interface{}, partitionID interface{}) *NMockHandler_GetDataVChanPositions_Call {
return &NMockHandler_GetDataVChanPositions_Call{Call: _e.mock.On("GetDataVChanPositions", channel, partitionID)} return &NMockHandler_GetDataVChanPositions_Call{Call: _e.mock.On("GetDataVChanPositions", ch, partitionID)}
} }
func (_c *NMockHandler_GetDataVChanPositions_Call) Run(run func(channel *channel, partitionID int64)) *NMockHandler_GetDataVChanPositions_Call { func (_c *NMockHandler_GetDataVChanPositions_Call) Run(run func(ch *channel, partitionID int64)) *NMockHandler_GetDataVChanPositions_Call {
_c.Call.Run(func(args mock.Arguments) { _c.Call.Run(func(args mock.Arguments) {
run(args[0].(*channel), args[1].(int64)) run(args[0].(*channel), args[1].(int64))
}) })
@ -184,7 +184,7 @@ func (_c *NMockHandler_GetDataVChanPositions_Call) Return(_a0 *datapb.VchannelIn
return _c return _c
} }
// GetQueryVChanPositions provides a mock function with given fields: channel, partitionIDs // GetQueryVChanPositions provides a mock function with given fields: ch, partitionIDs
func (_m *NMockHandler) GetQueryVChanPositions(ch *channel, partitionIDs ...int64) *datapb.VchannelInfo { func (_m *NMockHandler) GetQueryVChanPositions(ch *channel, partitionIDs ...int64) *datapb.VchannelInfo {
_va := make([]interface{}, len(partitionIDs)) _va := make([]interface{}, len(partitionIDs))
for _i := range partitionIDs { for _i := range partitionIDs {
@ -213,14 +213,14 @@ type NMockHandler_GetQueryVChanPositions_Call struct {
} }
// GetQueryVChanPositions is a helper method to define mock.On call // GetQueryVChanPositions is a helper method to define mock.On call
// - channel *channel // - ch *channel
// - partitionIDs ...int64 // - partitionIDs ...int64
func (_e *NMockHandler_Expecter) GetQueryVChanPositions(channel interface{}, partitionIDs ...interface{}) *NMockHandler_GetQueryVChanPositions_Call { func (_e *NMockHandler_Expecter) GetQueryVChanPositions(ch interface{}, partitionIDs ...interface{}) *NMockHandler_GetQueryVChanPositions_Call {
return &NMockHandler_GetQueryVChanPositions_Call{Call: _e.mock.On("GetQueryVChanPositions", return &NMockHandler_GetQueryVChanPositions_Call{Call: _e.mock.On("GetQueryVChanPositions",
append([]interface{}{channel}, partitionIDs...)...)} append([]interface{}{ch}, partitionIDs...)...)}
} }
func (_c *NMockHandler_GetQueryVChanPositions_Call) Run(run func(channel *channel, partitionIDs ...int64)) *NMockHandler_GetQueryVChanPositions_Call { func (_c *NMockHandler_GetQueryVChanPositions_Call) Run(run func(ch *channel, partitionIDs ...int64)) *NMockHandler_GetQueryVChanPositions_Call {
_c.Call.Run(func(args mock.Arguments) { _c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]int64, len(args)-1) variadicArgs := make([]int64, len(args)-1)
for i, a := range args[1:] { for i, a := range args[1:] {

View File

@ -453,6 +453,18 @@ func (m *mockRootCoordService) ShowCollections(ctx context.Context, req *milvusp
}, nil }, nil
} }
// CreateDatabase is a stub for the RootCoord database API; the tests using
// this mock never exercise it, so it panics if called.
func (m *mockRootCoordService) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) {
	panic("not implemented") // TODO: Implement
}

// DropDatabase is a stub for the RootCoord database API; panics when invoked.
func (m *mockRootCoordService) DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) {
	panic("not implemented") // TODO: Implement
}

// ListDatabases is a stub for the RootCoord database API; panics when invoked.
func (m *mockRootCoordService) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) {
	panic("not implemented") // TODO: Implement
}
func (m *mockRootCoordService) AlterCollection(ctx context.Context, request *milvuspb.AlterCollectionRequest) (*commonpb.Status, error) { func (m *mockRootCoordService) AlterCollection(ctx context.Context, request *milvuspb.AlterCollectionRequest) (*commonpb.Status, error) {
panic("not implemented") // TODO: Implement panic("not implemented") // TODO: Implement
} }

View File

@ -988,7 +988,7 @@ func (s *Server) loadCollectionFromRootCoord(ctx context.Context, collectionID i
commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection), commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection),
commonpbutil.WithSourceID(paramtable.GetNodeID()), commonpbutil.WithSourceID(paramtable.GetNodeID()),
), ),
DbName: "", // please do not specify the collection name alone after database feature.
CollectionID: collectionID, CollectionID: collectionID,
}) })
if err = VerifyResponse(resp, err); err != nil { if err = VerifyResponse(resp, err); err != nil {
@ -1000,9 +1000,12 @@ func (s *Server) loadCollectionFromRootCoord(ctx context.Context, collectionID i
commonpbutil.WithMsgID(0), commonpbutil.WithMsgID(0),
commonpbutil.WithSourceID(paramtable.GetNodeID()), commonpbutil.WithSourceID(paramtable.GetNodeID()),
), ),
DbName: "", // please do not specify the collection name alone after database feature.
CollectionName: resp.Schema.Name, /*
CollectionID: resp.CollectionID, DbName: "",
CollectionName: resp.Schema.Name,
*/
CollectionID: resp.CollectionID,
}) })
if err = VerifyResponse(presp, err); err != nil { if err = VerifyResponse(presp, err); err != nil {
log.Error("show partitions error", zap.String("collectionName", resp.Schema.Name), log.Error("show partitions error", zap.String("collectionName", resp.Schema.Name),

View File

@ -3455,21 +3455,24 @@ func TestGetFlushAllState(t *testing.T) {
ChannelCPs []Timestamp ChannelCPs []Timestamp
FlushAllTs Timestamp FlushAllTs Timestamp
ServerIsHealthy bool ServerIsHealthy bool
ListDatabaseFailed bool
ShowCollectionFailed bool ShowCollectionFailed bool
DescribeCollectionFailed bool DescribeCollectionFailed bool
ExpectedSuccess bool ExpectedSuccess bool
ExpectedFlushed bool ExpectedFlushed bool
}{ }{
{"test FlushAll flushed", []Timestamp{100, 200}, 99, {"test FlushAll flushed", []Timestamp{100, 200}, 99,
true, false, false, true, true}, true, false, false, false, true, true},
{"test FlushAll not flushed", []Timestamp{100, 200}, 150, {"test FlushAll not flushed", []Timestamp{100, 200}, 150,
true, false, false, true, false}, true, false, false, false, true, false},
{"test Sever is not healthy", nil, 0, {"test Sever is not healthy", nil, 0,
false, false, false, false, false}, false, false, false, false, false, false},
{"test ListDatabase failed", nil, 0,
true, true, false, false, false, false},
{"test ShowCollections failed", nil, 0, {"test ShowCollections failed", nil, 0,
true, true, false, false, false}, true, false, true, false, false, false},
{"test DescribeCollection failed", nil, 0, {"test DescribeCollection failed", nil, 0,
true, false, true, false, false}, true, false, false, true, false, false},
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.testName, func(t *testing.T) { t.Run(test.testName, func(t *testing.T) {
@ -3483,6 +3486,19 @@ func TestGetFlushAllState(t *testing.T) {
var err error var err error
svr.meta = &meta{} svr.meta = &meta{}
svr.rootCoordClient = mocks.NewRootCoord(t) svr.rootCoordClient = mocks.NewRootCoord(t)
if test.ListDatabaseFailed {
svr.rootCoordClient.(*mocks.RootCoord).EXPECT().ListDatabases(mock.Anything, mock.Anything).
Return(&milvuspb.ListDatabasesResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError},
}, nil).Maybe()
} else {
svr.rootCoordClient.(*mocks.RootCoord).EXPECT().ListDatabases(mock.Anything, mock.Anything).
Return(&milvuspb.ListDatabasesResponse{
DbNames: []string{"db1"},
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
}, nil).Maybe()
}
if test.ShowCollectionFailed { if test.ShowCollectionFailed {
svr.rootCoordClient.(*mocks.RootCoord).EXPECT().ShowCollections(mock.Anything, mock.Anything). svr.rootCoordClient.(*mocks.RootCoord).EXPECT().ShowCollections(mock.Anything, mock.Anything).
Return(&milvuspb.ShowCollectionsResponse{ Return(&milvuspb.ShowCollectionsResponse{

View File

@ -648,6 +648,7 @@ func (s *Server) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInf
commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection), commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection),
commonpbutil.WithSourceID(paramtable.GetNodeID()), commonpbutil.WithSourceID(paramtable.GetNodeID()),
), ),
// please do not specify the collection name alone after database feature.
CollectionID: collectionID, CollectionID: collectionID,
}) })
if err = VerifyResponse(dresp, err); err != nil { if err = VerifyResponse(dresp, err); err != nil {
@ -785,6 +786,7 @@ func (s *Server) GetRecoveryInfoV2(ctx context.Context, req *datapb.GetRecoveryI
commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection), commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection),
commonpbutil.WithSourceID(paramtable.GetNodeID()), commonpbutil.WithSourceID(paramtable.GetNodeID()),
), ),
// please do not specify the collection name alone after database feature.
CollectionID: collectionID, CollectionID: collectionID,
}) })
if err = VerifyResponse(dresp, err); err != nil { if err = VerifyResponse(dresp, err); err != nil {
@ -1299,35 +1301,49 @@ func (s *Server) GetFlushAllState(ctx context.Context, req *milvuspb.GetFlushAll
return resp, nil return resp, nil
} }
showColRsp, err := s.rootCoordClient.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{ dbsRsp, err := s.rootCoordClient.ListDatabases(ctx, &milvuspb.ListDatabasesRequest{
Base: commonpbutil.NewMsgBase( Base: commonpbutil.NewMsgBase(commonpbutil.WithMsgType(commonpb.MsgType_ListDatabases)),
commonpbutil.WithMsgType(commonpb.MsgType_ShowCollections), })
)}) if err = VerifyResponse(dbsRsp, err); err != nil {
if err = VerifyResponse(showColRsp, err); err != nil { log.Warn("failed to ListDatabases", zap.Error(err))
log.Warn("failed to ShowCollections", zap.Error(err))
resp.Status.Reason = err.Error() resp.Status.Reason = err.Error()
return resp, nil return resp, nil
} }
for _, collection := range showColRsp.GetCollectionIds() { for _, dbName := range dbsRsp.DbNames {
describeColRsp, err := s.rootCoordClient.DescribeCollectionInternal(ctx, &milvuspb.DescribeCollectionRequest{ showColRsp, err := s.rootCoordClient.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{
Base: commonpbutil.NewMsgBase( Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection), commonpbutil.WithMsgType(commonpb.MsgType_ShowCollections),
), ),
CollectionID: collection, DbName: dbName,
}) })
if err = VerifyResponse(describeColRsp, err); err != nil { if err = VerifyResponse(showColRsp, err); err != nil {
log.Warn("failed to DescribeCollectionInternal", zap.Error(err)) log.Warn("failed to ShowCollections", zap.Error(err))
resp.Status.Reason = err.Error() resp.Status.Reason = err.Error()
return resp, nil return resp, nil
} }
for _, channel := range describeColRsp.GetVirtualChannelNames() {
channelCP := s.meta.GetChannelCheckpoint(channel) for _, collection := range showColRsp.GetCollectionIds() {
if channelCP == nil || channelCP.GetTimestamp() < req.GetFlushAllTs() { describeColRsp, err := s.rootCoordClient.DescribeCollectionInternal(ctx, &milvuspb.DescribeCollectionRequest{
resp.Flushed = false Base: commonpbutil.NewMsgBase(
resp.Status.ErrorCode = commonpb.ErrorCode_Success commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection),
),
// please do not specify the collection name alone after database feature.
CollectionID: collection,
})
if err = VerifyResponse(describeColRsp, err); err != nil {
log.Warn("failed to DescribeCollectionInternal", zap.Error(err))
resp.Status.Reason = err.Error()
return resp, nil return resp, nil
} }
for _, channel := range describeColRsp.GetVirtualChannelNames() {
channelCP := s.meta.GetChannelCheckpoint(channel)
if channelCP == nil || channelCP.GetTimestamp() < req.GetFlushAllTs() {
resp.Flushed = false
resp.Status.ErrorCode = commonpb.ErrorCode_Success
return resp, nil
}
}
} }
} }
resp.Flushed = true resp.Flushed = true

View File

@ -54,7 +54,7 @@ type MockAllocator_Alloc_Call struct {
} }
// Alloc is a helper method to define mock.On call // Alloc is a helper method to define mock.On call
// - count uint32 // - count uint32
func (_e *MockAllocator_Expecter) Alloc(count interface{}) *MockAllocator_Alloc_Call { func (_e *MockAllocator_Expecter) Alloc(count interface{}) *MockAllocator_Alloc_Call {
return &MockAllocator_Alloc_Call{Call: _e.mock.On("Alloc", count)} return &MockAllocator_Alloc_Call{Call: _e.mock.On("Alloc", count)}
} }
@ -170,8 +170,8 @@ type MockAllocator_GetGenerator_Call struct {
} }
// GetGenerator is a helper method to define mock.On call // GetGenerator is a helper method to define mock.On call
// - count int // - count int
// - done <-chan struct{} // - done <-chan struct{}
func (_e *MockAllocator_Expecter) GetGenerator(count interface{}, done interface{}) *MockAllocator_GetGenerator_Call { func (_e *MockAllocator_Expecter) GetGenerator(count interface{}, done interface{}) *MockAllocator_GetGenerator_Call {
return &MockAllocator_GetGenerator_Call{Call: _e.mock.On("GetGenerator", count, done)} return &MockAllocator_GetGenerator_Call{Call: _e.mock.On("GetGenerator", count, done)}
} }

View File

@ -67,7 +67,7 @@ func (mService *metaService) getCollectionInfo(ctx context.Context, collID Uniqu
commonpbutil.WithMsgID(0), //GOOSE TODO commonpbutil.WithMsgID(0), //GOOSE TODO
commonpbutil.WithSourceID(paramtable.GetNodeID()), commonpbutil.WithSourceID(paramtable.GetNodeID()),
), ),
DbName: "default", // GOOSE TODO // please do not specify the collection name alone after database feature.
CollectionID: collID, CollectionID: collID,
TimeStamp: timestamp, TimeStamp: timestamp,
} }

View File

@ -49,6 +49,18 @@ func (m *mockProxyComponent) Dummy(ctx context.Context, request *milvuspb.DummyR
var emptyBody = &gin.H{} var emptyBody = &gin.H{}
var testStatus = &commonpb.Status{Reason: "ok"} var testStatus = &commonpb.Status{Reason: "ok"}
// CreateDatabase returns the shared canned testStatus so HTTP-handler tests
// can verify the response is passed through unchanged.
func (m *mockProxyComponent) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) {
	return testStatus, nil
}

// DropDatabase returns the shared canned testStatus for handler tests.
func (m *mockProxyComponent) DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) {
	return testStatus, nil
}

// ListDatabases returns an empty response carrying the canned testStatus.
func (m *mockProxyComponent) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) {
	return &milvuspb.ListDatabasesResponse{Status: testStatus}, nil
}
func (m *mockProxyComponent) CreateCollection(ctx context.Context, request *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) { func (m *mockProxyComponent) CreateCollection(ctx context.Context, request *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) {
return testStatus, nil return testStatus, nil
} }

View File

@ -172,6 +172,7 @@ func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) {
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
otelgrpc.UnaryServerInterceptor(opts...), otelgrpc.UnaryServerInterceptor(opts...),
grpc_auth.UnaryServerInterceptor(proxy.AuthenticationInterceptor), grpc_auth.UnaryServerInterceptor(proxy.AuthenticationInterceptor),
proxy.DatabaseInterceptor(),
proxy.UnaryServerHookInterceptor(), proxy.UnaryServerHookInterceptor(),
proxy.UnaryServerInterceptor(proxy.PrivilegeInterceptor), proxy.UnaryServerInterceptor(proxy.PrivilegeInterceptor),
logutil.UnaryTraceLoggerInterceptor, logutil.UnaryTraceLoggerInterceptor,
@ -946,25 +947,14 @@ func (s *Server) ListClientInfos(ctx context.Context, req *proxypb.ListClientInf
return s.proxy.ListClientInfos(ctx, req) return s.proxy.ListClientInfos(ctx, req)
} }
func (s *Server) CreateDatabase(ctx context.Context, req *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) { func (s *Server) CreateDatabase(ctx context.Context, request *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) {
return &commonpb.Status{ return s.proxy.CreateDatabase(ctx, request)
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "TODO: implement me @jaime",
}, nil
} }
func (s *Server) DropDatabase(ctx context.Context, req *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) { func (s *Server) DropDatabase(ctx context.Context, request *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) {
return &commonpb.Status{ return s.proxy.DropDatabase(ctx, request)
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "TODO: implement me @jaime",
}, nil
} }
func (s *Server) ListDatabases(ctx context.Context, req *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) { func (s *Server) ListDatabases(ctx context.Context, request *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) {
return &milvuspb.ListDatabasesResponse{ return s.proxy.ListDatabases(ctx, request)
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "TODO: implement me @jaime",
},
}, nil
} }

View File

@ -127,6 +127,18 @@ func (m *MockRootCoord) Register() error {
return m.regErr return m.regErr
} }
// CreateDatabase is a no-op mock satisfying the RootCoord interface; it
// returns (nil, nil) like the other stubbed methods on MockRootCoord.
func (m *MockRootCoord) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) {
	return nil, nil
}

// DropDatabase is a no-op mock returning (nil, nil).
func (m *MockRootCoord) DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) {
	return nil, nil
}

// ListDatabases is a no-op mock returning (nil, nil).
func (m *MockRootCoord) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) {
	return nil, nil
}
func (m *MockRootCoord) CreateCollection(ctx context.Context, req *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) { func (m *MockRootCoord) CreateCollection(ctx context.Context, req *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) {
return nil, nil return nil, nil
} }
@ -526,6 +538,18 @@ func (m *MockProxy) InvalidateCollectionMetaCache(ctx context.Context, request *
return nil, nil return nil, nil
} }
// CreateDatabase is a no-op mock satisfying the Proxy interface; it returns
// (nil, nil) in line with the other stubbed methods on MockProxy.
func (m *MockProxy) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) {
	return nil, nil
}

// DropDatabase is a no-op mock returning (nil, nil).
func (m *MockProxy) DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) {
	return nil, nil
}

// ListDatabases is a no-op mock returning (nil, nil).
func (m *MockProxy) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) {
	return nil, nil
}
func (m *MockProxy) CreateCollection(ctx context.Context, request *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) { func (m *MockProxy) CreateCollection(ctx context.Context, request *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) {
return nil, nil return nil, nil
} }
@ -1341,6 +1365,20 @@ func Test_NewServer(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
}) })
t.Run("CreateDatabase", func(t *testing.T) {
_, err := server.CreateDatabase(ctx, nil)
assert.Nil(t, err)
})
t.Run("DropDatabase", func(t *testing.T) {
_, err := server.DropDatabase(ctx, nil)
assert.Nil(t, err)
})
t.Run("ListDatabase", func(t *testing.T) {
_, err := server.ListDatabases(ctx, nil)
assert.Nil(t, err)
})
err = server.Stop() err = server.Stop()
assert.NoError(t, err) assert.NoError(t, err)

View File

@ -620,3 +620,60 @@ func (c *Client) RenameCollection(ctx context.Context, req *milvuspb.RenameColle
return client.RenameCollection(ctx, req) return client.RenameCollection(ctx, req)
}) })
} }
// CreateDatabase forwards a CreateDatabase request to the root coordinator
// over gRPC, retrying through the underlying resilient client.
func (c *Client) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) {
	// Clone first so the caller's request is never mutated by the base update.
	in = typeutil.Clone(in)
	commonpbutil.UpdateMsgBase(
		in.GetBase(),
		commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.sess.ServerID)),
	)
	// The call closure bails out early if the context is already invalid.
	call := func(client rootcoordpb.RootCoordClient) (any, error) {
		if !funcutil.CheckCtxValid(ctx) {
			return nil, ctx.Err()
		}
		return client.CreateDatabase(ctx, in)
	}
	resp, err := c.grpcClient.ReCall(ctx, call)
	if err != nil || resp == nil {
		return nil, err
	}
	return resp.(*commonpb.Status), err
}
// DropDatabase forwards a DropDatabase request to the root coordinator over
// gRPC, retrying through the underlying resilient client.
func (c *Client) DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) {
	// Clone first so the caller's request is never mutated by the base update.
	in = typeutil.Clone(in)
	commonpbutil.UpdateMsgBase(
		in.GetBase(),
		commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.sess.ServerID)),
	)
	// The call closure bails out early if the context is already invalid.
	call := func(client rootcoordpb.RootCoordClient) (any, error) {
		if !funcutil.CheckCtxValid(ctx) {
			return nil, ctx.Err()
		}
		return client.DropDatabase(ctx, in)
	}
	resp, err := c.grpcClient.ReCall(ctx, call)
	if err != nil || resp == nil {
		return nil, err
	}
	return resp.(*commonpb.Status), err
}
// ListDatabases forwards a ListDatabases request to the root coordinator over
// gRPC, retrying through the underlying resilient client.
func (c *Client) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) {
	// Clone first so the caller's request is never mutated by the base update.
	in = typeutil.Clone(in)
	commonpbutil.UpdateMsgBase(
		in.GetBase(),
		commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.sess.ServerID)),
	)
	// The call closure bails out early if the context is already invalid.
	call := func(client rootcoordpb.RootCoordClient) (any, error) {
		if !funcutil.CheckCtxValid(ctx) {
			return nil, ctx.Err()
		}
		return client.ListDatabases(ctx, in)
	}
	resp, err := c.grpcClient.ReCall(ctx, call)
	if err != nil || resp == nil {
		return nil, err
	}
	return resp.(*milvuspb.ListDatabasesResponse), err
}

View File

@ -248,6 +248,18 @@ func Test_NewClient(t *testing.T) {
r, err := client.CheckHealth(ctx, nil) r, err := client.CheckHealth(ctx, nil)
retCheck(retNotNil, r, err) retCheck(retNotNil, r, err)
} }
{
r, err := client.CreateDatabase(ctx, nil)
retCheck(retNotNil, r, err)
}
{
r, err := client.DropDatabase(ctx, nil)
retCheck(retNotNil, r, err)
}
{
r, err := client.ListDatabases(ctx, nil)
retCheck(retNotNil, r, err)
}
} }
client.grpcClient = &mock.GRPCClientBase[rootcoordpb.RootCoordClient]{ client.grpcClient = &mock.GRPCClientBase[rootcoordpb.RootCoordClient]{
@ -450,6 +462,18 @@ func Test_NewClient(t *testing.T) {
rTimeout, err := client.CheckHealth(shortCtx, nil) rTimeout, err := client.CheckHealth(shortCtx, nil)
retCheck(rTimeout, err) retCheck(rTimeout, err)
} }
{
rTimeout, err := client.CreateDatabase(shortCtx, nil)
retCheck(rTimeout, err)
}
{
rTimeout, err := client.DropDatabase(shortCtx, nil)
retCheck(rTimeout, err)
}
{
rTimeout, err := client.ListDatabases(shortCtx, nil)
retCheck(rTimeout, err)
}
// clean up // clean up
err = client.Stop() err = client.Stop()
assert.NoError(t, err) assert.NoError(t, err)

View File

@ -69,6 +69,18 @@ type Server struct {
newQueryCoordClient func(string, *clientv3.Client) types.QueryCoord newQueryCoordClient func(string, *clientv3.Client) types.QueryCoord
} }
// CreateDatabase forwards the database-creation request to the embedded RootCoord.
func (s *Server) CreateDatabase(ctx context.Context, request *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) {
	status, err := s.rootCoord.CreateDatabase(ctx, request)
	return status, err
}
// DropDatabase forwards the database-drop request to the embedded RootCoord.
func (s *Server) DropDatabase(ctx context.Context, request *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) {
	status, err := s.rootCoord.DropDatabase(ctx, request)
	return status, err
}
// ListDatabases forwards the database-listing request to the embedded RootCoord.
func (s *Server) ListDatabases(ctx context.Context, request *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) {
	resp, err := s.rootCoord.ListDatabases(ctx, request)
	return resp, err
}
func (s *Server) CheckHealth(ctx context.Context, request *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { func (s *Server) CheckHealth(ctx context.Context, request *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) {
return s.rootCoord.CheckHealth(ctx, request) return s.rootCoord.CheckHealth(ctx, request)
} }

View File

@ -43,6 +43,20 @@ type mockCore struct {
types.RootCoordComponent types.RootCoordComponent
} }
// CreateDatabase is a test stub that always reports success.
func (m *mockCore) CreateDatabase(ctx context.Context, request *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) {
	status := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
	return status, nil
}
// DropDatabase is a test stub that always reports success.
func (m *mockCore) DropDatabase(ctx context.Context, request *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) {
	status := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
	return status, nil
}
// ListDatabases is a test stub that returns an empty, successful response.
func (m *mockCore) ListDatabases(ctx context.Context, request *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) {
	resp := &milvuspb.ListDatabasesResponse{
		Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
	}
	return resp, nil
}
func (m *mockCore) RenameCollection(ctx context.Context, request *milvuspb.RenameCollectionRequest) (*commonpb.Status, error) { func (m *mockCore) RenameCollection(ctx context.Context, request *milvuspb.RenameCollectionRequest) (*commonpb.Status, error) {
return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil
} }
@ -52,6 +66,7 @@ func (m *mockCore) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthReq
IsHealthy: true, IsHealthy: true,
}, nil }, nil
} }
func (m *mockCore) UpdateStateCode(commonpb.StateCode) { func (m *mockCore) UpdateStateCode(commonpb.StateCode) {
} }
@ -194,6 +209,23 @@ func TestRun(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
}) })
t.Run("CreateDatabase", func(t *testing.T) {
ret, err := svr.CreateDatabase(ctx, nil)
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, ret.ErrorCode)
})
t.Run("DropDatabase", func(t *testing.T) {
ret, err := svr.DropDatabase(ctx, nil)
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, ret.ErrorCode)
})
t.Run("ListDatabases", func(t *testing.T) {
ret, err := svr.ListDatabases(ctx, nil)
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, ret.Status.ErrorCode)
})
err = svr.Stop() err = svr.Stop()
assert.NoError(t, err) assert.NoError(t, err)
} }

View File

@ -13,22 +13,26 @@ import (
//go:generate mockery --name=RootCoordCatalog //go:generate mockery --name=RootCoordCatalog
type RootCoordCatalog interface { type RootCoordCatalog interface {
CreateDatabase(ctx context.Context, db *model.Database, ts typeutil.Timestamp) error
DropDatabase(ctx context.Context, dbID int64, ts typeutil.Timestamp) error
ListDatabases(ctx context.Context, ts typeutil.Timestamp) ([]*model.Database, error)
CreateCollection(ctx context.Context, collectionInfo *model.Collection, ts typeutil.Timestamp) error CreateCollection(ctx context.Context, collectionInfo *model.Collection, ts typeutil.Timestamp) error
GetCollectionByID(ctx context.Context, collectionID typeutil.UniqueID, ts typeutil.Timestamp) (*model.Collection, error) GetCollectionByID(ctx context.Context, dbID int64, ts typeutil.Timestamp, collectionID typeutil.UniqueID) (*model.Collection, error)
GetCollectionByName(ctx context.Context, collectionName string, ts typeutil.Timestamp) (*model.Collection, error) GetCollectionByName(ctx context.Context, dbID int64, collectionName string, ts typeutil.Timestamp) (*model.Collection, error)
ListCollections(ctx context.Context, ts typeutil.Timestamp) (map[string]*model.Collection, error) ListCollections(ctx context.Context, dbID int64, ts typeutil.Timestamp) ([]*model.Collection, error)
CollectionExists(ctx context.Context, collectionID typeutil.UniqueID, ts typeutil.Timestamp) bool CollectionExists(ctx context.Context, dbID int64, collectionID typeutil.UniqueID, ts typeutil.Timestamp) bool
DropCollection(ctx context.Context, collectionInfo *model.Collection, ts typeutil.Timestamp) error DropCollection(ctx context.Context, collectionInfo *model.Collection, ts typeutil.Timestamp) error
AlterCollection(ctx context.Context, oldColl *model.Collection, newColl *model.Collection, alterType AlterType, ts typeutil.Timestamp) error AlterCollection(ctx context.Context, oldColl *model.Collection, newColl *model.Collection, alterType AlterType, ts typeutil.Timestamp) error
CreatePartition(ctx context.Context, partition *model.Partition, ts typeutil.Timestamp) error CreatePartition(ctx context.Context, dbID int64, partition *model.Partition, ts typeutil.Timestamp) error
DropPartition(ctx context.Context, collectionID typeutil.UniqueID, partitionID typeutil.UniqueID, ts typeutil.Timestamp) error DropPartition(ctx context.Context, dbID int64, collectionID typeutil.UniqueID, partitionID typeutil.UniqueID, ts typeutil.Timestamp) error
AlterPartition(ctx context.Context, oldPart *model.Partition, newPart *model.Partition, alterType AlterType, ts typeutil.Timestamp) error AlterPartition(ctx context.Context, dbID int64, oldPart *model.Partition, newPart *model.Partition, alterType AlterType, ts typeutil.Timestamp) error
CreateAlias(ctx context.Context, alias *model.Alias, ts typeutil.Timestamp) error CreateAlias(ctx context.Context, alias *model.Alias, ts typeutil.Timestamp) error
DropAlias(ctx context.Context, alias string, ts typeutil.Timestamp) error DropAlias(ctx context.Context, dbID int64, alias string, ts typeutil.Timestamp) error
AlterAlias(ctx context.Context, alias *model.Alias, ts typeutil.Timestamp) error AlterAlias(ctx context.Context, alias *model.Alias, ts typeutil.Timestamp) error
ListAliases(ctx context.Context, ts typeutil.Timestamp) ([]*model.Alias, error) ListAliases(ctx context.Context, dbID int64, ts typeutil.Timestamp) ([]*model.Alias, error)
// GetCredential gets the credential info for the username, returns error if no credential exists for this username. // GetCredential gets the credential info for the username, returns error if no credential exists for this username.
GetCredential(ctx context.Context, username string) (*model.Credential, error) GetCredential(ctx context.Context, username string) (*model.Credential, error)

View File

@ -34,6 +34,22 @@ func NewTableCatalog(txImpl dbmodel.ITransaction, metaDomain dbmodel.IMetaDomain
} }
} }
// CreateDatabase records a new database in this catalog.
//
// TODO: not implemented for this metastore backend yet — currently a
// no-op that reports success.
func (tc *Catalog) CreateDatabase(ctx context.Context, db *model.Database, ts typeutil.Timestamp) error {
	// TODO: persist the database entry once this backend supports databases.
	return nil
}
// DropDatabase removes the database identified by dbID from this catalog.
//
// TODO: not implemented for this metastore backend yet — currently a
// no-op that reports success.
func (tc *Catalog) DropDatabase(ctx context.Context, dbID int64, ts typeutil.Timestamp) error {
	// TODO: delete the database entry once this backend supports databases.
	return nil
}
// ListDatabases returns all databases known to this catalog.
//
// TODO: not implemented for this metastore backend yet — always returns
// an empty (non-nil) slice.
func (tc *Catalog) ListDatabases(ctx context.Context, ts typeutil.Timestamp) ([]*model.Database, error) {
	// TODO: query the database entries once this backend supports databases.
	return make([]*model.Database, 0), nil
}
func (tc *Catalog) CreateCollection(ctx context.Context, collection *model.Collection, ts typeutil.Timestamp) error { func (tc *Catalog) CreateCollection(ctx context.Context, collection *model.Collection, ts typeutil.Timestamp) error {
tenantID := contextutil.TenantID(ctx) tenantID := contextutil.TenantID(ctx)
@ -151,7 +167,7 @@ func (tc *Catalog) CreateCollection(ctx context.Context, collection *model.Colle
}) })
} }
func (tc *Catalog) GetCollectionByID(ctx context.Context, collectionID typeutil.UniqueID, ts typeutil.Timestamp) (*model.Collection, error) { func (tc *Catalog) GetCollectionByID(ctx context.Context, dbID int64, ts typeutil.Timestamp, collectionID typeutil.UniqueID) (*model.Collection, error) {
tenantID := contextutil.TenantID(ctx) tenantID := contextutil.TenantID(ctx)
// get latest timestamp less than or equals to param ts // get latest timestamp less than or equals to param ts
@ -215,7 +231,7 @@ func (tc *Catalog) populateCollection(ctx context.Context, collectionID typeutil
return mCollection, nil return mCollection, nil
} }
func (tc *Catalog) GetCollectionByName(ctx context.Context, collectionName string, ts typeutil.Timestamp) (*model.Collection, error) { func (tc *Catalog) GetCollectionByName(ctx context.Context, dbID int64, collectionName string, ts typeutil.Timestamp) (*model.Collection, error) {
tenantID := contextutil.TenantID(ctx) tenantID := contextutil.TenantID(ctx)
// Since collection name will not change for different ts // Since collection name will not change for different ts
@ -224,7 +240,7 @@ func (tc *Catalog) GetCollectionByName(ctx context.Context, collectionName strin
return nil, err return nil, err
} }
return tc.GetCollectionByID(ctx, collectionID, ts) return tc.GetCollectionByID(ctx, dbID, ts, collectionID)
} }
// ListCollections For time travel (ts > 0), find only one record respectively for each collection no matter `is_deleted` is true or false // ListCollections For time travel (ts > 0), find only one record respectively for each collection no matter `is_deleted` is true or false
@ -234,7 +250,7 @@ func (tc *Catalog) GetCollectionByName(ctx context.Context, collectionName strin
// [collection3, t3, is_deleted=false] // [collection3, t3, is_deleted=false]
// t1, t2, t3 are the largest timestamp that less than or equal to @param ts // t1, t2, t3 are the largest timestamp that less than or equal to @param ts
// the final result will only return collection2 and collection3 since collection1 is deleted // the final result will only return collection2 and collection3 since collection1 is deleted
func (tc *Catalog) ListCollections(ctx context.Context, ts typeutil.Timestamp) (map[string]*model.Collection, error) { func (tc *Catalog) ListCollections(ctx context.Context, dbID int64, ts typeutil.Timestamp) ([]*model.Collection, error) {
tenantID := contextutil.TenantID(ctx) tenantID := contextutil.TenantID(ctx)
// 1. find each collection_id with latest ts <= @param ts // 1. find each collection_id with latest ts <= @param ts
@ -243,7 +259,7 @@ func (tc *Catalog) ListCollections(ctx context.Context, ts typeutil.Timestamp) (
return nil, err return nil, err
} }
if len(cidTsPairs) == 0 { if len(cidTsPairs) == 0 {
return map[string]*model.Collection{}, nil return make([]*model.Collection, 0), nil
} }
// 2. populate each collection // 2. populate each collection
@ -268,16 +284,10 @@ func (tc *Catalog) ListCollections(ctx context.Context, ts typeutil.Timestamp) (
log.Error("list collections by collection_id & ts pair failed", zap.Uint64("ts", ts), zap.Error(err)) log.Error("list collections by collection_id & ts pair failed", zap.Uint64("ts", ts), zap.Error(err))
return nil, err return nil, err
} }
return collections, nil
r := map[string]*model.Collection{}
for _, c := range collections {
r[c.Name] = c
}
return r, nil
} }
func (tc *Catalog) CollectionExists(ctx context.Context, collectionID typeutil.UniqueID, ts typeutil.Timestamp) bool { func (tc *Catalog) CollectionExists(ctx context.Context, dbID int64, collectionID typeutil.UniqueID, ts typeutil.Timestamp) bool {
tenantID := contextutil.TenantID(ctx) tenantID := contextutil.TenantID(ctx)
// get latest timestamp less than or equals to param ts // get latest timestamp less than or equals to param ts
@ -431,7 +441,7 @@ func (tc *Catalog) AlterCollection(ctx context.Context, oldColl *model.Collectio
return fmt.Errorf("altering collection doesn't support %s", alterType.String()) return fmt.Errorf("altering collection doesn't support %s", alterType.String())
} }
func (tc *Catalog) CreatePartition(ctx context.Context, partition *model.Partition, ts typeutil.Timestamp) error { func (tc *Catalog) CreatePartition(ctx context.Context, dbID int64, partition *model.Partition, ts typeutil.Timestamp) error {
tenantID := contextutil.TenantID(ctx) tenantID := contextutil.TenantID(ctx)
p := &dbmodel.Partition{ p := &dbmodel.Partition{
@ -452,7 +462,7 @@ func (tc *Catalog) CreatePartition(ctx context.Context, partition *model.Partiti
return nil return nil
} }
func (tc *Catalog) DropPartition(ctx context.Context, collectionID typeutil.UniqueID, partitionID typeutil.UniqueID, ts typeutil.Timestamp) error { func (tc *Catalog) DropPartition(ctx context.Context, dbID int64, collectionID typeutil.UniqueID, partitionID typeutil.UniqueID, ts typeutil.Timestamp) error {
tenantID := contextutil.TenantID(ctx) tenantID := contextutil.TenantID(ctx)
p := &dbmodel.Partition{ p := &dbmodel.Partition{
@ -488,7 +498,7 @@ func (tc *Catalog) alterModifyPartition(ctx context.Context, oldPart *model.Part
return tc.metaDomain.PartitionDb(ctx).Update(p) return tc.metaDomain.PartitionDb(ctx).Update(p)
} }
func (tc *Catalog) AlterPartition(ctx context.Context, oldPart *model.Partition, newPart *model.Partition, alterType metastore.AlterType, ts typeutil.Timestamp) error { func (tc *Catalog) AlterPartition(ctx context.Context, dbID int64, oldPart *model.Partition, newPart *model.Partition, alterType metastore.AlterType, ts typeutil.Timestamp) error {
if alterType == metastore.MODIFY { if alterType == metastore.MODIFY {
return tc.alterModifyPartition(ctx, oldPart, newPart, ts) return tc.alterModifyPartition(ctx, oldPart, newPart, ts)
} }
@ -513,7 +523,7 @@ func (tc *Catalog) CreateAlias(ctx context.Context, alias *model.Alias, ts typeu
return nil return nil
} }
func (tc *Catalog) DropAlias(ctx context.Context, alias string, ts typeutil.Timestamp) error { func (tc *Catalog) DropAlias(ctx context.Context, dbID int64, alias string, ts typeutil.Timestamp) error {
tenantID := contextutil.TenantID(ctx) tenantID := contextutil.TenantID(ctx)
collectionID, err := tc.metaDomain.CollAliasDb(ctx).GetCollectionIDByAlias(tenantID, alias, ts) collectionID, err := tc.metaDomain.CollAliasDb(ctx).GetCollectionIDByAlias(tenantID, alias, ts)
@ -549,7 +559,7 @@ func (tc *Catalog) AlterAlias(ctx context.Context, alias *model.Alias, ts typeut
} }
// ListAliases query collection ID and aliases only, other information are not needed // ListAliases query collection ID and aliases only, other information are not needed
func (tc *Catalog) ListAliases(ctx context.Context, ts typeutil.Timestamp) ([]*model.Alias, error) { func (tc *Catalog) ListAliases(ctx context.Context, dbID int64, ts typeutil.Timestamp) ([]*model.Alias, error) {
tenantID := contextutil.TenantID(ctx) tenantID := contextutil.TenantID(ctx)
// 1. find each collection with latest ts // 1. find each collection with latest ts
@ -964,7 +974,7 @@ func (tc *Catalog) ListPolicy(ctx context.Context, tenant string) ([]string, err
} }
for _, grantID := range grantIDs { for _, grantID := range grantIDs {
policies = append(policies, policies = append(policies,
funcutil.PolicyForPrivilege(grant.Role.Name, grant.Object, grant.ObjectName, grantID.Privilege)) funcutil.PolicyForPrivilege(grant.Role.Name, grant.Object, grant.ObjectName, grantID.Privilege, "default"))
} }
} }

View File

@ -20,6 +20,7 @@ import (
"github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/metastore/model"
pb "github.com/milvus-io/milvus/internal/proto/etcdpb" pb "github.com/milvus-io/milvus/internal/proto/etcdpb"
"github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util"
"github.com/milvus-io/milvus/pkg/util/contextutil" "github.com/milvus-io/milvus/pkg/util/contextutil"
"github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/typeutil" "github.com/milvus-io/milvus/pkg/util/typeutil"
@ -36,6 +37,8 @@ const (
segmentID1 = typeutil.UniqueID(2000) segmentID1 = typeutil.UniqueID(2000)
indexBuildID1 = typeutil.UniqueID(3000) indexBuildID1 = typeutil.UniqueID(3000)
testDb = int64(1000)
collName1 = "test_collection_name_1" collName1 = "test_collection_name_1"
collAlias1 = "test_collection_alias_1" collAlias1 = "test_collection_alias_1"
collAlias2 = "test_collection_alias_2" collAlias2 = "test_collection_alias_2"
@ -272,7 +275,7 @@ func TestTableCatalog_GetCollectionByID(t *testing.T) {
indexDbMock.On("Get", tenantID, collID1).Return(indexes, nil).Once() indexDbMock.On("Get", tenantID, collID1).Return(indexes, nil).Once()
// actual // actual
res, gotErr := mockCatalog.GetCollectionByID(ctx, collID1, ts) res, gotErr := mockCatalog.GetCollectionByID(ctx, util.NonDBID, ts, collID1)
// collection basic info // collection basic info
require.Equal(t, nil, gotErr) require.Equal(t, nil, gotErr)
require.Equal(t, coll.TenantID, res.TenantID) require.Equal(t, coll.TenantID, res.TenantID)
@ -307,7 +310,7 @@ func TestTableCatalog_GetCollectionByID_UnmarshalStartPositionsError(t *testing.
indexDbMock.On("Get", tenantID, collID1).Return(nil, nil).Once() indexDbMock.On("Get", tenantID, collID1).Return(nil, nil).Once()
// actual // actual
res, gotErr := mockCatalog.GetCollectionByID(ctx, collID1, ts) res, gotErr := mockCatalog.GetCollectionByID(ctx, util.NonDBID, ts, collID1)
require.Nil(t, res) require.Nil(t, res)
require.Error(t, gotErr) require.Error(t, gotErr)
} }
@ -319,7 +322,7 @@ func TestTableCatalog_GetCollectionByID_SelectCollError(t *testing.T) {
collDbMock.On("Get", tenantID, collID1, ts).Return(nil, errTest).Once() collDbMock.On("Get", tenantID, collID1, ts).Return(nil, errTest).Once()
// actual // actual
res, gotErr := mockCatalog.GetCollectionByID(ctx, collID1, ts) res, gotErr := mockCatalog.GetCollectionByID(ctx, util.NonDBID, ts, collID1)
require.Nil(t, res) require.Nil(t, res)
require.Error(t, gotErr) require.Error(t, gotErr)
} }
@ -341,7 +344,7 @@ func TestTableCatalog_GetCollectionByID_SelectFieldError(t *testing.T) {
fieldDbMock.On("GetByCollectionID", tenantID, collID1, ts).Return(nil, errTest).Once() fieldDbMock.On("GetByCollectionID", tenantID, collID1, ts).Return(nil, errTest).Once()
// actual // actual
res, gotErr := mockCatalog.GetCollectionByID(ctx, collID1, ts) res, gotErr := mockCatalog.GetCollectionByID(ctx, util.NonDBID, ts, collID1)
require.Nil(t, res) require.Nil(t, res)
require.Error(t, gotErr) require.Error(t, gotErr)
} }
@ -364,7 +367,7 @@ func TestTableCatalog_GetCollectionByID_SelectPartitionError(t *testing.T) {
partitionDbMock.On("GetByCollectionID", tenantID, collID1, ts).Return(nil, errTest).Once() partitionDbMock.On("GetByCollectionID", tenantID, collID1, ts).Return(nil, errTest).Once()
// actual // actual
res, gotErr := mockCatalog.GetCollectionByID(ctx, collID1, ts) res, gotErr := mockCatalog.GetCollectionByID(ctx, util.NonDBID, ts, collID1)
require.Nil(t, res) require.Nil(t, res)
require.Error(t, gotErr) require.Error(t, gotErr)
} }
@ -388,7 +391,7 @@ func TestTableCatalog_GetCollectionByID_SelectChannelError(t *testing.T) {
collChannelDbMock.On("GetByCollectionID", tenantID, collID1, ts).Return(nil, errTest).Once() collChannelDbMock.On("GetByCollectionID", tenantID, collID1, ts).Return(nil, errTest).Once()
// actual // actual
res, gotErr := mockCatalog.GetCollectionByID(ctx, collID1, ts) res, gotErr := mockCatalog.GetCollectionByID(ctx, util.NonDBID, ts, collID1)
require.Nil(t, res) require.Nil(t, res)
require.Error(t, gotErr) require.Error(t, gotErr)
} }
@ -449,7 +452,7 @@ func TestTableCatalog_GetCollectionByName(t *testing.T) {
indexDbMock.On("Get", tenantID, collID1).Return(indexes, nil).Once() indexDbMock.On("Get", tenantID, collID1).Return(indexes, nil).Once()
// actual // actual
res, gotErr := mockCatalog.GetCollectionByName(ctx, collName1, ts) res, gotErr := mockCatalog.GetCollectionByName(ctx, util.NonDBID, collName1, ts)
// collection basic info // collection basic info
require.Equal(t, nil, gotErr) require.Equal(t, nil, gotErr)
require.Equal(t, coll.TenantID, res.TenantID) require.Equal(t, coll.TenantID, res.TenantID)
@ -471,7 +474,7 @@ func TestTableCatalog_GetCollectionByName_SelectCollIDError(t *testing.T) {
collDbMock.On("GetCollectionIDByName", tenantID, collName1, ts).Return(typeutil.UniqueID(0), errTest).Once() collDbMock.On("GetCollectionIDByName", tenantID, collName1, ts).Return(typeutil.UniqueID(0), errTest).Once()
// actual // actual
res, gotErr := mockCatalog.GetCollectionByName(ctx, collName1, ts) res, gotErr := mockCatalog.GetCollectionByName(ctx, util.NonDBID, collName1, ts)
require.Nil(t, res) require.Nil(t, res)
require.Error(t, gotErr) require.Error(t, gotErr)
} }
@ -531,21 +534,21 @@ func TestTableCatalog_ListCollections(t *testing.T) {
indexDbMock.On("Get", tenantID, collID1).Return(indexes, nil).Once() indexDbMock.On("Get", tenantID, collID1).Return(indexes, nil).Once()
// actual // actual
res, gotErr := mockCatalog.ListCollections(ctx, ts) res, gotErr := mockCatalog.ListCollections(ctx, util.NonDBID, ts)
// collection basic info // collection basic info
require.Equal(t, nil, gotErr) require.Equal(t, nil, gotErr)
require.Equal(t, 1, len(res)) require.Equal(t, 1, len(res))
require.Equal(t, coll.TenantID, res[coll.CollectionName].TenantID) require.Equal(t, coll.TenantID, res[0].TenantID)
require.Equal(t, coll.CollectionID, res[coll.CollectionName].CollectionID) require.Equal(t, coll.CollectionID, res[0].CollectionID)
require.Equal(t, coll.CollectionName, res[coll.CollectionName].Name) require.Equal(t, coll.CollectionName, res[0].Name)
require.Equal(t, coll.AutoID, res[coll.CollectionName].AutoID) require.Equal(t, coll.AutoID, res[0].AutoID)
require.Equal(t, coll.Ts, res[coll.CollectionName].CreateTime) require.Equal(t, coll.Ts, res[0].CreateTime)
require.Empty(t, res[coll.CollectionName].StartPositions) require.Empty(t, res[0].StartPositions)
// partitions/fields/channels // partitions/fields/channels
require.NotEmpty(t, res[coll.CollectionName].Partitions) require.NotEmpty(t, res[0].Partitions)
require.NotEmpty(t, res[coll.CollectionName].Fields) require.NotEmpty(t, res[0].Fields)
require.NotEmpty(t, res[coll.CollectionName].VirtualChannelNames) require.NotEmpty(t, res[0].VirtualChannelNames)
require.NotEmpty(t, res[coll.CollectionName].PhysicalChannelNames) require.NotEmpty(t, res[0].PhysicalChannelNames)
} }
func TestTableCatalog_CollectionExists(t *testing.T) { func TestTableCatalog_CollectionExists(t *testing.T) {
@ -561,7 +564,7 @@ func TestTableCatalog_CollectionExists(t *testing.T) {
collDbMock.On("Get", tenantID, collID1, resultTs).Return(coll, nil).Once() collDbMock.On("Get", tenantID, collID1, resultTs).Return(coll, nil).Once()
// actual // actual
res := mockCatalog.CollectionExists(ctx, collID1, ts) res := mockCatalog.CollectionExists(ctx, util.NonDBID, collID1, ts)
require.True(t, res) require.True(t, res)
} }
@ -579,7 +582,7 @@ func TestTableCatalog_CollectionExists_IsDeletedTrue(t *testing.T) {
collDbMock.On("Get", tenantID, collID1, resultTs).Return(coll, nil).Once() collDbMock.On("Get", tenantID, collID1, resultTs).Return(coll, nil).Once()
// actual // actual
res := mockCatalog.CollectionExists(ctx, collID1, ts) res := mockCatalog.CollectionExists(ctx, util.NonDBID, collID1, ts)
require.False(t, res) require.False(t, res)
} }
@ -591,7 +594,7 @@ func TestTableCatalog_CollectionExists_CollNotExists(t *testing.T) {
collDbMock.On("Get", tenantID, collID1, resultTs).Return(nil, nil).Once() collDbMock.On("Get", tenantID, collID1, resultTs).Return(nil, nil).Once()
// actual // actual
res := mockCatalog.CollectionExists(ctx, collID1, ts) res := mockCatalog.CollectionExists(ctx, util.NonDBID, collID1, ts)
require.False(t, res) require.False(t, res)
} }
@ -601,7 +604,7 @@ func TestTableCatalog_CollectionExists_GetCidTsError(t *testing.T) {
collDbMock.On("GetCollectionIDTs", tenantID, collID1, ts).Return(nil, errTest).Once() collDbMock.On("GetCollectionIDTs", tenantID, collID1, ts).Return(nil, errTest).Once()
// actual // actual
res := mockCatalog.CollectionExists(ctx, collID1, ts) res := mockCatalog.CollectionExists(ctx, util.NonDBID, collID1, ts)
require.False(t, res) require.False(t, res)
} }
@ -836,7 +839,7 @@ func TestTableCatalog_CreatePartition(t *testing.T) {
partitionDbMock.On("Insert", mock.Anything).Return(nil).Once() partitionDbMock.On("Insert", mock.Anything).Return(nil).Once()
// actual // actual
gotErr := mockCatalog.CreatePartition(ctx, partition, ts) gotErr := mockCatalog.CreatePartition(ctx, util.NonDBID, partition, ts)
require.Equal(t, nil, gotErr) require.Equal(t, nil, gotErr)
} }
@ -853,7 +856,7 @@ func TestTableCatalog_CreatePartition_InsertPartitionError(t *testing.T) {
partitionDbMock.On("Insert", mock.Anything).Return(errTest).Once() partitionDbMock.On("Insert", mock.Anything).Return(errTest).Once()
// actual // actual
gotErr := mockCatalog.CreatePartition(ctx, partition, ts) gotErr := mockCatalog.CreatePartition(ctx, util.NonDBID, partition, ts)
require.Error(t, gotErr) require.Error(t, gotErr)
} }
@ -862,7 +865,7 @@ func TestTableCatalog_DropPartition_TsNot0(t *testing.T) {
partitionDbMock.On("Insert", mock.Anything).Return(nil).Once() partitionDbMock.On("Insert", mock.Anything).Return(nil).Once()
// actual // actual
gotErr := mockCatalog.DropPartition(ctx, collID1, partitionID1, ts) gotErr := mockCatalog.DropPartition(ctx, util.NonDBID, collID1, partitionID1, ts)
require.NoError(t, gotErr) require.NoError(t, gotErr)
} }
@ -872,7 +875,7 @@ func TestTableCatalog_DropPartition_TsNot0_PartitionInsertError(t *testing.T) {
partitionDbMock.On("Insert", mock.Anything).Return(errTest).Once() partitionDbMock.On("Insert", mock.Anything).Return(errTest).Once()
// actual // actual
gotErr := mockCatalog.DropPartition(ctx, collID1, partitionID1, ts) gotErr := mockCatalog.DropPartition(ctx, util.NonDBID, collID1, partitionID1, ts)
require.Error(t, gotErr) require.Error(t, gotErr)
} }
@ -894,7 +897,7 @@ func TestCatalog_AlterPartition(t *testing.T) {
partitionDbMock.On("Update", mock.Anything).Return(nil).Once() partitionDbMock.On("Update", mock.Anything).Return(nil).Once()
gotErr := mockCatalog.AlterPartition(ctx, partition, newPartition, metastore.MODIFY, ts) gotErr := mockCatalog.AlterPartition(ctx, util.NonDBID, partition, newPartition, metastore.MODIFY, ts)
require.NoError(t, gotErr) require.NoError(t, gotErr)
} }
@ -907,10 +910,10 @@ func TestCatalog_AlterPartition_TsNot0_AlterTypeError(t *testing.T) {
State: pb.PartitionState_PartitionCreated, State: pb.PartitionState_PartitionCreated,
} }
gotErr := mockCatalog.AlterPartition(ctx, partition, partition, metastore.ADD, ts) gotErr := mockCatalog.AlterPartition(ctx, util.NonDBID, partition, partition, metastore.ADD, ts)
require.Error(t, gotErr) require.Error(t, gotErr)
gotErr = mockCatalog.AlterPartition(ctx, partition, partition, metastore.DELETE, ts) gotErr = mockCatalog.AlterPartition(ctx, util.NonDBID, partition, partition, metastore.DELETE, ts)
require.Error(t, gotErr) require.Error(t, gotErr)
} }
@ -928,7 +931,7 @@ func TestCatalog_AlterPartition_TsNot0_PartitionInsertError(t *testing.T) {
partitionDbMock.On("Update", mock.Anything).Return(errTest).Once() partitionDbMock.On("Update", mock.Anything).Return(errTest).Once()
// actual // actual
gotErr := mockCatalog.AlterPartition(ctx, partition, partition, metastore.MODIFY, ts) gotErr := mockCatalog.AlterPartition(ctx, util.NonDBID, partition, partition, metastore.MODIFY, ts)
require.Error(t, gotErr) require.Error(t, gotErr)
} }
@ -967,7 +970,7 @@ func TestTableCatalog_DropAlias_TsNot0(t *testing.T) {
aliasDbMock.On("Insert", mock.Anything).Return(nil).Once() aliasDbMock.On("Insert", mock.Anything).Return(nil).Once()
// actual // actual
gotErr := mockCatalog.DropAlias(ctx, collAlias1, ts) gotErr := mockCatalog.DropAlias(ctx, testDb, collAlias1, ts)
require.NoError(t, gotErr) require.NoError(t, gotErr)
} }
@ -977,7 +980,7 @@ func TestTableCatalog_DropAlias_TsNot0_SelectCollectionIDByAliasError(t *testing
aliasDbMock.On("GetCollectionIDByAlias", tenantID, collAlias1, ts).Return(typeutil.UniqueID(0), errTest).Once() aliasDbMock.On("GetCollectionIDByAlias", tenantID, collAlias1, ts).Return(typeutil.UniqueID(0), errTest).Once()
// actual // actual
gotErr := mockCatalog.DropAlias(ctx, collAlias1, ts) gotErr := mockCatalog.DropAlias(ctx, testDb, collAlias1, ts)
require.Error(t, gotErr) require.Error(t, gotErr)
} }
@ -988,7 +991,7 @@ func TestTableCatalog_DropAlias_TsNot0_InsertIndexError(t *testing.T) {
aliasDbMock.On("Insert", mock.Anything).Return(errTest).Once() aliasDbMock.On("Insert", mock.Anything).Return(errTest).Once()
// actual // actual
gotErr := mockCatalog.DropAlias(ctx, collAlias1, ts) gotErr := mockCatalog.DropAlias(ctx, testDb, collAlias1, ts)
require.Error(t, gotErr) require.Error(t, gotErr)
} }
@ -1026,7 +1029,7 @@ func TestTableCatalog_ListAliases(t *testing.T) {
aliasDbMock.On("List", tenantID, cidTsPairs).Return(collAliases, nil).Once() aliasDbMock.On("List", tenantID, cidTsPairs).Return(collAliases, nil).Once()
// actual // actual
res, gotErr := mockCatalog.ListAliases(ctx, ts) res, gotErr := mockCatalog.ListAliases(ctx, testDb, ts)
require.Equal(t, nil, gotErr) require.Equal(t, nil, gotErr)
require.Equal(t, out, res) require.Equal(t, out, res)
} }
@ -1036,7 +1039,7 @@ func TestTableCatalog_ListAliases_NoResult(t *testing.T) {
aliasDbMock.On("ListCollectionIDTs", tenantID, ts).Return(nil, nil).Once() aliasDbMock.On("ListCollectionIDTs", tenantID, ts).Return(nil, nil).Once()
// actual // actual
res, gotErr := mockCatalog.ListAliases(ctx, ts) res, gotErr := mockCatalog.ListAliases(ctx, testDb, ts)
require.Equal(t, nil, gotErr) require.Equal(t, nil, gotErr)
require.Empty(t, res) require.Empty(t, res)
} }
@ -1047,7 +1050,7 @@ func TestTableCatalog_ListAliases_ListCidTsError(t *testing.T) {
aliasDbMock.On("ListCollectionIDTs", tenantID, ts).Return(nil, errTest).Once() aliasDbMock.On("ListCollectionIDTs", tenantID, ts).Return(nil, errTest).Once()
// actual // actual
res, gotErr := mockCatalog.ListAliases(ctx, ts) res, gotErr := mockCatalog.ListAliases(ctx, testDb, ts)
require.Nil(t, res) require.Nil(t, res)
require.Error(t, gotErr) require.Error(t, gotErr)
} }
@ -1060,7 +1063,7 @@ func TestTableCatalog_ListAliases_SelectAliasError(t *testing.T) {
aliasDbMock.On("List", tenantID, mock.Anything).Return(nil, errTest).Once() aliasDbMock.On("List", tenantID, mock.Anything).Return(nil, errTest).Once()
// actual // actual
res, gotErr := mockCatalog.ListAliases(ctx, ts) res, gotErr := mockCatalog.ListAliases(ctx, testDb, ts)
require.Nil(t, res) require.Nil(t, res)
require.Error(t, gotErr) require.Error(t, gotErr)
} }
@ -1972,7 +1975,7 @@ func TestTableCatalog_ListPolicy(t *testing.T) {
policies, err = mockCatalog.ListPolicy(ctx, tenantID) policies, err = mockCatalog.ListPolicy(ctx, tenantID)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 3, len(policies)) require.Equal(t, 3, len(policies))
require.Equal(t, funcutil.PolicyForPrivilege(roleName1, object1, objectName1, privilege1), policies[0]) require.Equal(t, funcutil.PolicyForPrivilege(roleName1, object1, objectName1, privilege1, util.DefaultDBName), policies[0])
grantDbMock.On("GetGrants", tenantID, int64(0), "", "").Return(nil, errors.New("test error")).Once() grantDbMock.On("GetGrants", tenantID, int64(0), "", "").Return(nil, errors.New("test error")).Once()
_, err = mockCatalog.ListPolicy(ctx, tenantID) _, err = mockCatalog.ListPolicy(ctx, tenantID)

View File

@ -6,8 +6,6 @@ import (
"fmt" "fmt"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/kv" "github.com/milvus-io/milvus/internal/kv"
@ -22,6 +20,7 @@ import (
"github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/etcd"
"github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/typeutil" "github.com/milvus-io/milvus/pkg/util/typeutil"
"go.uber.org/zap"
) )
const ( const (
@ -32,12 +31,16 @@ const (
// prefix/partitions/collection_id/partition_id -> PartitionInfo // prefix/partitions/collection_id/partition_id -> PartitionInfo
// prefix/aliases/alias_name -> AliasInfo // prefix/aliases/alias_name -> AliasInfo
// prefix/fields/collection_id/field_id -> FieldSchema // prefix/fields/collection_id/field_id -> FieldSchema
type Catalog struct { type Catalog struct {
Txn kv.TxnKV Txn kv.TxnKV
Snapshot kv.SnapShotKV Snapshot kv.SnapShotKV
} }
func BuildCollectionKey(collectionID typeutil.UniqueID) string { func BuildCollectionKey(dbID typeutil.UniqueID, collectionID typeutil.UniqueID) string {
if dbID != util.NonDBID {
return BuildCollectionKeyWithDBID(dbID, collectionID)
}
return fmt.Sprintf("%s/%d", CollectionMetaPrefix, collectionID) return fmt.Sprintf("%s/%d", CollectionMetaPrefix, collectionID)
} }
@ -65,6 +68,21 @@ func BuildAliasKey(aliasName string) string {
return fmt.Sprintf("%s/%s", AliasMetaPrefix, aliasName) return fmt.Sprintf("%s/%s", AliasMetaPrefix, aliasName)
} }
func BuildAliasKeyWithDB(dbID int64, aliasName string) string {
k := BuildAliasKey(aliasName)
if dbID == util.NonDBID {
return k
}
return fmt.Sprintf("%s/%s/%d/%s", DatabaseMetaPrefix, Aliases, dbID, aliasName)
}
func BuildAliasPrefixWithDB(dbID int64) string {
if dbID == util.NonDBID {
return AliasMetaPrefix
}
return fmt.Sprintf("%s/%s/%d", DatabaseMetaPrefix, Aliases, dbID)
}
func batchMultiSaveAndRemoveWithPrefix(snapshot kv.SnapShotKV, maxTxnNum int, saves map[string]string, removals []string, ts typeutil.Timestamp) error { func batchMultiSaveAndRemoveWithPrefix(snapshot kv.SnapShotKV, maxTxnNum int, saves map[string]string, removals []string, ts typeutil.Timestamp) error {
saveFn := func(partialKvs map[string]string) error { saveFn := func(partialKvs map[string]string) error {
return snapshot.MultiSave(partialKvs, ts) return snapshot.MultiSave(partialKvs, ts)
@ -79,12 +97,45 @@ func batchMultiSaveAndRemoveWithPrefix(snapshot kv.SnapShotKV, maxTxnNum int, sa
return etcd.RemoveByBatch(removals, removeFn) return etcd.RemoveByBatch(removals, removeFn)
} }
func (kc *Catalog) CreateDatabase(ctx context.Context, db *model.Database, ts typeutil.Timestamp) error {
key := BuildDatabaseKey(db.ID)
dbInfo := model.MarshalDatabaseModel(db)
v, err := proto.Marshal(dbInfo)
if err != nil {
return err
}
return kc.Snapshot.Save(key, string(v), ts)
}
func (kc *Catalog) DropDatabase(ctx context.Context, dbID int64, ts typeutil.Timestamp) error {
key := BuildDatabaseKey(dbID)
return kc.Snapshot.MultiSaveAndRemoveWithPrefix(nil, []string{key}, ts)
}
func (kc *Catalog) ListDatabases(ctx context.Context, ts typeutil.Timestamp) ([]*model.Database, error) {
_, vals, err := kc.Snapshot.LoadWithPrefix(DBInfoMetaPrefix, ts)
if err != nil {
return nil, err
}
dbs := make([]*model.Database, 0, len(vals))
for _, val := range vals {
dbMeta := &pb.DatabaseInfo{}
err := proto.Unmarshal([]byte(val), dbMeta)
if err != nil {
return nil, err
}
dbs = append(dbs, model.UnmarshalDatabaseModel(dbMeta))
}
return dbs, nil
}
func (kc *Catalog) CreateCollection(ctx context.Context, coll *model.Collection, ts typeutil.Timestamp) error { func (kc *Catalog) CreateCollection(ctx context.Context, coll *model.Collection, ts typeutil.Timestamp) error {
if coll.State != pb.CollectionState_CollectionCreating { if coll.State != pb.CollectionState_CollectionCreating {
return fmt.Errorf("cannot create collection with state: %s, collection: %s", coll.State.String(), coll.Name) return fmt.Errorf("cannot create collection with state: %s, collection: %s", coll.State.String(), coll.Name)
} }
k1 := BuildCollectionKey(coll.CollectionID) k1 := BuildCollectionKey(coll.DBID, coll.CollectionID)
collInfo := model.MarshalCollectionModel(coll) collInfo := model.MarshalCollectionModel(coll)
v1, err := proto.Marshal(collInfo) v1, err := proto.Marshal(collInfo)
if err != nil { if err != nil {
@ -128,17 +179,17 @@ func (kc *Catalog) CreateCollection(ctx context.Context, coll *model.Collection,
} }
// Though batchSave is not atomic enough, we can promise the atomicity outside. // Though batchSave is not atomic enough, we can promise the atomicity outside.
// Recovering from failure, if we found collection is creating, we should removing all these related meta. // Recovering from failure, if we found collection is creating, we should remove all these related meta.
return etcd.SaveByBatchWithLimit(kvs, maxTxnNum/2, func(partialKvs map[string]string) error { return etcd.SaveByBatchWithLimit(kvs, maxTxnNum/2, func(partialKvs map[string]string) error {
return kc.Snapshot.MultiSave(partialKvs, ts) return kc.Snapshot.MultiSave(partialKvs, ts)
}) })
} }
func (kc *Catalog) loadCollection(ctx context.Context, collectionID typeutil.UniqueID, ts typeutil.Timestamp) (*pb.CollectionInfo, error) { func (kc *Catalog) loadCollectionFromDb(ctx context.Context, dbID int64, collectionID typeutil.UniqueID, ts typeutil.Timestamp) (*pb.CollectionInfo, error) {
collKey := BuildCollectionKey(collectionID) collKey := BuildCollectionKey(dbID, collectionID)
collVal, err := kc.Snapshot.Load(collKey, ts) collVal, err := kc.Snapshot.Load(collKey, ts)
if err != nil { if err != nil {
return nil, common.NewCollectionNotExistError(fmt.Sprintf("collection not found: %d", collectionID)) return nil, common.NewCollectionNotExistError(fmt.Sprintf("collection not found: %d, error: %s", collectionID, err.Error()))
} }
collMeta := &pb.CollectionInfo{} collMeta := &pb.CollectionInfo{}
@ -146,6 +197,21 @@ func (kc *Catalog) loadCollection(ctx context.Context, collectionID typeutil.Uni
return collMeta, err return collMeta, err
} }
func (kc *Catalog) loadCollectionFromDefaultDb(ctx context.Context, collectionID typeutil.UniqueID, ts typeutil.Timestamp) (*pb.CollectionInfo, error) {
if info, err := kc.loadCollectionFromDb(ctx, util.DefaultDBID, collectionID, ts); err == nil {
return info, nil
}
// get collection from older version.
return kc.loadCollectionFromDb(ctx, util.NonDBID, collectionID, ts)
}
func (kc *Catalog) loadCollection(ctx context.Context, dbID int64, collectionID typeutil.UniqueID, ts typeutil.Timestamp) (*pb.CollectionInfo, error) {
if isDefaultDB(dbID) {
return kc.loadCollectionFromDefaultDb(ctx, collectionID, ts)
}
return kc.loadCollectionFromDb(ctx, dbID, collectionID, ts)
}
func partitionVersionAfter210(collMeta *pb.CollectionInfo) bool { func partitionVersionAfter210(collMeta *pb.CollectionInfo) bool {
return len(collMeta.GetPartitionIDs()) <= 0 && return len(collMeta.GetPartitionIDs()) <= 0 &&
len(collMeta.GetPartitionNames()) <= 0 && len(collMeta.GetPartitionNames()) <= 0 &&
@ -160,8 +226,8 @@ func partitionExistByName(collMeta *pb.CollectionInfo, partitionName string) boo
return funcutil.SliceContain(collMeta.GetPartitionNames(), partitionName) return funcutil.SliceContain(collMeta.GetPartitionNames(), partitionName)
} }
func (kc *Catalog) CreatePartition(ctx context.Context, partition *model.Partition, ts typeutil.Timestamp) error { func (kc *Catalog) CreatePartition(ctx context.Context, dbID int64, partition *model.Partition, ts typeutil.Timestamp) error {
collMeta, err := kc.loadCollection(ctx, partition.CollectionID, ts) collMeta, err := kc.loadCollection(ctx, dbID, partition.CollectionID, ts)
if err != nil { if err != nil {
return err return err
} }
@ -190,7 +256,8 @@ func (kc *Catalog) CreatePartition(ctx context.Context, partition *model.Partiti
collMeta.PartitionNames = append(collMeta.PartitionNames, partition.PartitionName) collMeta.PartitionNames = append(collMeta.PartitionNames, partition.PartitionName)
collMeta.PartitionCreatedTimestamps = append(collMeta.PartitionCreatedTimestamps, partition.PartitionCreatedTimestamp) collMeta.PartitionCreatedTimestamps = append(collMeta.PartitionCreatedTimestamps, partition.PartitionCreatedTimestamp)
k := BuildCollectionKey(partition.CollectionID) // this partition exists in older version, should be also changed in place.
k := BuildCollectionKey(util.NonDBID, partition.CollectionID)
v, err := proto.Marshal(collMeta) v, err := proto.Marshal(collMeta)
if err != nil { if err != nil {
return err return err
@ -200,14 +267,15 @@ func (kc *Catalog) CreatePartition(ctx context.Context, partition *model.Partiti
func (kc *Catalog) CreateAlias(ctx context.Context, alias *model.Alias, ts typeutil.Timestamp) error { func (kc *Catalog) CreateAlias(ctx context.Context, alias *model.Alias, ts typeutil.Timestamp) error {
oldKBefore210 := BuildAliasKey210(alias.Name) oldKBefore210 := BuildAliasKey210(alias.Name)
k := BuildAliasKey(alias.Name) oldKeyWithoutDb := BuildAliasKey(alias.Name)
k := BuildAliasKeyWithDB(alias.DbID, alias.Name)
aliasInfo := model.MarshalAliasModel(alias) aliasInfo := model.MarshalAliasModel(alias)
v, err := proto.Marshal(aliasInfo) v, err := proto.Marshal(aliasInfo)
if err != nil { if err != nil {
return err return err
} }
kvs := map[string]string{k: string(v)} kvs := map[string]string{k: string(v)}
return kc.Snapshot.MultiSaveAndRemoveWithPrefix(kvs, []string{oldKBefore210}, ts) return kc.Snapshot.MultiSaveAndRemoveWithPrefix(kvs, []string{oldKBefore210, oldKeyWithoutDb}, ts)
} }
func (kc *Catalog) CreateCredential(ctx context.Context, credential *model.Credential) error { func (kc *Catalog) CreateCredential(ctx context.Context, credential *model.Credential) error {
@ -295,9 +363,8 @@ func (kc *Catalog) appendPartitionAndFieldsInfo(ctx context.Context, collMeta *p
return collection, nil return collection, nil
} }
func (kc *Catalog) GetCollectionByID(ctx context.Context, collectionID typeutil.UniqueID, func (kc *Catalog) GetCollectionByID(ctx context.Context, dbID int64, ts typeutil.Timestamp, collectionID typeutil.UniqueID) (*model.Collection, error) {
ts typeutil.Timestamp) (*model.Collection, error) { collMeta, err := kc.loadCollection(ctx, dbID, collectionID, ts)
collMeta, err := kc.loadCollection(ctx, collectionID, ts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -305,8 +372,8 @@ func (kc *Catalog) GetCollectionByID(ctx context.Context, collectionID typeutil.
return kc.appendPartitionAndFieldsInfo(ctx, collMeta, ts) return kc.appendPartitionAndFieldsInfo(ctx, collMeta, ts)
} }
func (kc *Catalog) CollectionExists(ctx context.Context, collectionID typeutil.UniqueID, ts typeutil.Timestamp) bool { func (kc *Catalog) CollectionExists(ctx context.Context, dbID int64, collectionID typeutil.UniqueID, ts typeutil.Timestamp) bool {
_, err := kc.GetCollectionByID(ctx, collectionID, ts) _, err := kc.GetCollectionByID(ctx, dbID, ts, collectionID)
return err == nil return err == nil
} }
@ -336,12 +403,14 @@ func (kc *Catalog) AlterAlias(ctx context.Context, alias *model.Alias, ts typeut
} }
func (kc *Catalog) DropCollection(ctx context.Context, collectionInfo *model.Collection, ts typeutil.Timestamp) error { func (kc *Catalog) DropCollection(ctx context.Context, collectionInfo *model.Collection, ts typeutil.Timestamp) error {
collectionKey := BuildCollectionKey(collectionInfo.CollectionID) collectionKeys := []string{BuildCollectionKey(collectionInfo.DBID, collectionInfo.CollectionID)}
var delMetakeysSnap []string var delMetakeysSnap []string
for _, alias := range collectionInfo.Aliases { for _, alias := range collectionInfo.Aliases {
delMetakeysSnap = append(delMetakeysSnap, delMetakeysSnap = append(delMetakeysSnap,
BuildAliasKey210(alias), BuildAliasKey210(alias),
BuildAliasKey(alias),
BuildAliasKeyWithDB(collectionInfo.DBID, alias),
) )
} }
// Snapshot will list all (k, v) pairs and then use Txn.MultiSave to save tombstone for these keys when it prepares // Snapshot will list all (k, v) pairs and then use Txn.MultiSave to save tombstone for these keys when it prepares
@ -364,7 +433,7 @@ func (kc *Catalog) DropCollection(ctx context.Context, collectionInfo *model.Col
} }
// if we found collection dropping, we should try removing related resources. // if we found collection dropping, we should try removing related resources.
return kc.Snapshot.MultiSaveAndRemoveWithPrefix(nil, []string{collectionKey}, ts) return kc.Snapshot.MultiSaveAndRemoveWithPrefix(nil, collectionKeys, ts)
} }
func (kc *Catalog) alterModifyCollection(oldColl *model.Collection, newColl *model.Collection, ts typeutil.Timestamp) error { func (kc *Catalog) alterModifyCollection(oldColl *model.Collection, newColl *model.Collection, ts typeutil.Timestamp) error {
@ -372,6 +441,7 @@ func (kc *Catalog) alterModifyCollection(oldColl *model.Collection, newColl *mod
return fmt.Errorf("altering tenant id or collection id is forbidden") return fmt.Errorf("altering tenant id or collection id is forbidden")
} }
oldCollClone := oldColl.Clone() oldCollClone := oldColl.Clone()
oldCollClone.DBID = newColl.DBID
oldCollClone.Name = newColl.Name oldCollClone.Name = newColl.Name
oldCollClone.Description = newColl.Description oldCollClone.Description = newColl.Description
oldCollClone.AutoID = newColl.AutoID oldCollClone.AutoID = newColl.AutoID
@ -382,7 +452,7 @@ func (kc *Catalog) alterModifyCollection(oldColl *model.Collection, newColl *mod
oldCollClone.CreateTime = newColl.CreateTime oldCollClone.CreateTime = newColl.CreateTime
oldCollClone.ConsistencyLevel = newColl.ConsistencyLevel oldCollClone.ConsistencyLevel = newColl.ConsistencyLevel
oldCollClone.State = newColl.State oldCollClone.State = newColl.State
key := BuildCollectionKey(oldColl.CollectionID) key := BuildCollectionKey(newColl.DBID, oldColl.CollectionID)
value, err := proto.Marshal(model.MarshalCollectionModel(oldCollClone)) value, err := proto.Marshal(model.MarshalCollectionModel(oldCollClone))
if err != nil { if err != nil {
return err return err
@ -414,7 +484,7 @@ func (kc *Catalog) alterModifyPartition(oldPart *model.Partition, newPart *model
return kc.Snapshot.Save(key, string(value), ts) return kc.Snapshot.Save(key, string(value), ts)
} }
func (kc *Catalog) AlterPartition(ctx context.Context, oldPart *model.Partition, newPart *model.Partition, alterType metastore.AlterType, ts typeutil.Timestamp) error { func (kc *Catalog) AlterPartition(ctx context.Context, dbID int64, oldPart *model.Partition, newPart *model.Partition, alterType metastore.AlterType, ts typeutil.Timestamp) error {
if alterType == metastore.MODIFY { if alterType == metastore.MODIFY {
return kc.alterModifyPartition(oldPart, newPart, ts) return kc.alterModifyPartition(oldPart, newPart, ts)
} }
@ -442,8 +512,8 @@ func dropPartition(collMeta *pb.CollectionInfo, partitionID typeutil.UniqueID) {
} }
} }
func (kc *Catalog) DropPartition(ctx context.Context, collectionID typeutil.UniqueID, partitionID typeutil.UniqueID, ts typeutil.Timestamp) error { func (kc *Catalog) DropPartition(ctx context.Context, dbID int64, collectionID typeutil.UniqueID, partitionID typeutil.UniqueID, ts typeutil.Timestamp) error {
collMeta, err := kc.loadCollection(ctx, collectionID, ts) collMeta, err := kc.loadCollection(ctx, dbID, collectionID, ts)
if err != nil { if err != nil {
return err return err
} }
@ -453,7 +523,7 @@ func (kc *Catalog) DropPartition(ctx context.Context, collectionID typeutil.Uniq
return kc.Snapshot.MultiSaveAndRemoveWithPrefix(nil, []string{k}, ts) return kc.Snapshot.MultiSaveAndRemoveWithPrefix(nil, []string{k}, ts)
} }
k := BuildCollectionKey(collectionID) k := BuildCollectionKey(util.NonDBID, collectionID)
dropPartition(collMeta, partitionID) dropPartition(collMeta, partitionID)
v, err := proto.Marshal(collMeta) v, err := proto.Marshal(collMeta)
if err != nil { if err != nil {
@ -473,14 +543,16 @@ func (kc *Catalog) DropCredential(ctx context.Context, username string) error {
return nil return nil
} }
func (kc *Catalog) DropAlias(ctx context.Context, alias string, ts typeutil.Timestamp) error { func (kc *Catalog) DropAlias(ctx context.Context, dbID int64, alias string, ts typeutil.Timestamp) error {
oldKBefore210 := BuildAliasKey210(alias) oldKBefore210 := BuildAliasKey210(alias)
k := BuildAliasKey(alias) oldKeyWithoutDb := BuildAliasKey(alias)
return kc.Snapshot.MultiSaveAndRemoveWithPrefix(nil, []string{k, oldKBefore210}, ts) k := BuildAliasKeyWithDB(dbID, alias)
return kc.Snapshot.MultiSaveAndRemoveWithPrefix(nil, []string{k, oldKeyWithoutDb, oldKBefore210}, ts)
} }
func (kc *Catalog) GetCollectionByName(ctx context.Context, collectionName string, ts typeutil.Timestamp) (*model.Collection, error) { func (kc *Catalog) GetCollectionByName(ctx context.Context, dbID int64, collectionName string, ts typeutil.Timestamp) (*model.Collection, error) {
_, vals, err := kc.Snapshot.LoadWithPrefix(CollectionMetaPrefix, ts) prefix := getDatabasePrefix(dbID)
_, vals, err := kc.Snapshot.LoadWithPrefix(prefix, ts)
if err != nil { if err != nil {
log.Warn("get collection meta fail", zap.String("collectionName", collectionName), zap.Error(err)) log.Warn("get collection meta fail", zap.String("collectionName", collectionName), zap.Error(err))
return nil, err return nil, err
@ -495,24 +567,25 @@ func (kc *Catalog) GetCollectionByName(ctx context.Context, collectionName strin
} }
if colMeta.Schema.Name == collectionName { if colMeta.Schema.Name == collectionName {
// compatibility handled by kc.GetCollectionByID. // compatibility handled by kc.GetCollectionByID.
return kc.GetCollectionByID(ctx, colMeta.GetID(), ts) return kc.GetCollectionByID(ctx, dbID, ts, colMeta.GetID())
} }
} }
return nil, common.NewCollectionNotExistError(fmt.Sprintf("can't find collection: %s, at timestamp = %d", collectionName, ts)) return nil, common.NewCollectionNotExistError(fmt.Sprintf("can't find collection: %s, at timestamp = %d", collectionName, ts))
} }
func (kc *Catalog) ListCollections(ctx context.Context, ts typeutil.Timestamp) (map[string]*model.Collection, error) { func (kc *Catalog) ListCollections(ctx context.Context, dbID int64, ts typeutil.Timestamp) ([]*model.Collection, error) {
_, vals, err := kc.Snapshot.LoadWithPrefix(CollectionMetaPrefix, ts) prefix := getDatabasePrefix(dbID)
_, vals, err := kc.Snapshot.LoadWithPrefix(prefix, ts)
if err != nil { if err != nil {
log.Error("get collections meta fail", log.Error("get collections meta fail",
zap.String("prefix", CollectionMetaPrefix), zap.String("prefix", prefix),
zap.Uint64("timestamp", ts), zap.Uint64("timestamp", ts),
zap.Error(err)) zap.Error(err))
return nil, err return nil, err
} }
colls := make(map[string]*model.Collection) colls := make([]*model.Collection, 0, len(vals))
for _, val := range vals { for _, val := range vals {
collMeta := pb.CollectionInfo{} collMeta := pb.CollectionInfo{}
err := proto.Unmarshal([]byte(val), &collMeta) err := proto.Unmarshal([]byte(val), &collMeta)
@ -524,7 +597,7 @@ func (kc *Catalog) ListCollections(ctx context.Context, ts typeutil.Timestamp) (
if err != nil { if err != nil {
return nil, err return nil, err
} }
colls[collMeta.Schema.Name] = collection colls = append(colls, collection)
} }
return colls, nil return colls, nil
@ -547,13 +620,15 @@ func (kc *Catalog) listAliasesBefore210(ctx context.Context, ts typeutil.Timesta
Name: coll.GetSchema().GetName(), Name: coll.GetSchema().GetName(),
CollectionID: coll.GetID(), CollectionID: coll.GetID(),
CreatedTime: 0, // not accurate. CreatedTime: 0, // not accurate.
DbID: coll.DbId,
}) })
} }
return aliases, nil return aliases, nil
} }
func (kc *Catalog) listAliasesAfter210(ctx context.Context, ts typeutil.Timestamp) ([]*model.Alias, error) { func (kc *Catalog) listAliasesAfter210WithDb(ctx context.Context, dbID int64, ts typeutil.Timestamp) ([]*model.Alias, error) {
_, values, err := kc.Snapshot.LoadWithPrefix(AliasMetaPrefix, ts) prefix := BuildAliasPrefixWithDB(dbID)
_, values, err := kc.Snapshot.LoadWithPrefix(prefix, ts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -569,24 +644,37 @@ func (kc *Catalog) listAliasesAfter210(ctx context.Context, ts typeutil.Timestam
Name: info.GetAliasName(), Name: info.GetAliasName(),
CollectionID: info.GetCollectionId(), CollectionID: info.GetCollectionId(),
CreatedTime: info.GetCreatedTime(), CreatedTime: info.GetCreatedTime(),
DbID: dbID,
}) })
} }
return aliases, nil return aliases, nil
} }
func (kc *Catalog) ListAliases(ctx context.Context, ts typeutil.Timestamp) ([]*model.Alias, error) { func (kc *Catalog) listAliasesInDefaultDb(ctx context.Context, ts typeutil.Timestamp) ([]*model.Alias, error) {
aliases1, err := kc.listAliasesBefore210(ctx, ts) aliases1, err := kc.listAliasesBefore210(ctx, ts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
aliases2, err := kc.listAliasesAfter210(ctx, ts) aliases2, err := kc.listAliasesAfter210WithDb(ctx, util.DefaultDBID, ts)
if err != nil {
return nil, err
}
aliases3, err := kc.listAliasesAfter210WithDb(ctx, util.NonDBID, ts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
aliases := append(aliases1, aliases2...) aliases := append(aliases1, aliases2...)
aliases = append(aliases, aliases3...)
return aliases, nil return aliases, nil
} }
func (kc *Catalog) ListAliases(ctx context.Context, dbID int64, ts typeutil.Timestamp) ([]*model.Alias, error) {
if !isDefaultDB(dbID) {
return kc.listAliasesAfter210WithDb(ctx, dbID, ts)
}
return kc.listAliasesInDefaultDb(ctx, ts)
}
func (kc *Catalog) ListCredentials(ctx context.Context) ([]string, error) { func (kc *Catalog) ListCredentials(ctx context.Context) ([]string, error) {
keys, _, err := kc.Txn.LoadWithPrefix(CredentialPrefix) keys, _, err := kc.Txn.LoadWithPrefix(CredentialPrefix)
if err != nil { if err != nil {
@ -614,7 +702,7 @@ func (kc *Catalog) save(k string) error {
} }
if err == nil { if err == nil {
log.Debug("the key has existed", zap.String("key", k)) log.Debug("the key has existed", zap.String("key", k))
return common.NewIgnorableError(fmt.Errorf("the key[%s] is existed", k)) return common.NewIgnorableError(fmt.Errorf("the key[%s] has existed", k))
} }
return kc.Txn.Save(k, "") return kc.Txn.Save(k, "")
} }
@ -815,37 +903,41 @@ func (kc *Catalog) ListUser(ctx context.Context, tenant string, entity *milvuspb
func (kc *Catalog) AlterGrant(ctx context.Context, tenant string, entity *milvuspb.GrantEntity, operateType milvuspb.OperatePrivilegeType) error { func (kc *Catalog) AlterGrant(ctx context.Context, tenant string, entity *milvuspb.GrantEntity, operateType milvuspb.OperatePrivilegeType) error {
var ( var (
privilegeName = entity.Grantor.Privilege.Name privilegeName = entity.Grantor.Privilege.Name
k = funcutil.HandleTenantForEtcdKey(GranteePrefix, tenant, fmt.Sprintf("%s/%s/%s", entity.Role.Name, entity.Object.Name, entity.ObjectName)) k = funcutil.HandleTenantForEtcdKey(GranteePrefix, tenant, fmt.Sprintf("%s/%s/%s", entity.Role.Name, entity.Object.Name, funcutil.CombineObjectName(entity.DbName, entity.ObjectName)))
idStr string idStr string
v string v string
err error err error
) )
v, err = kc.Txn.Load(k) // Compatible with logic without db
if err != nil { if entity.DbName == util.DefaultDBName {
if common.IsKeyNotExistError(err) { v, err = kc.Txn.Load(funcutil.HandleTenantForEtcdKey(GranteePrefix, tenant, fmt.Sprintf("%s/%s/%s", entity.Role.Name, entity.Object.Name, entity.ObjectName)))
log.Debug("not found the privilege entity", zap.String("key", k), zap.Any("type", operateType)) if err == nil {
idStr = v
} }
if funcutil.IsRevoke(operateType) { }
if common.IsKeyNotExistError(err) { if idStr == "" {
return common.NewIgnorableError(fmt.Errorf("the grant[%s] isn't existed", k)) if v, err = kc.Txn.Load(k); err == nil {
idStr = v
} else {
log.Warn("fail to load grant privilege entity", zap.String("key", k), zap.Any("type", operateType), zap.Error(err))
if funcutil.IsRevoke(operateType) {
if common.IsKeyNotExistError(err) {
return common.NewIgnorableError(fmt.Errorf("the grant[%s] isn't existed", k))
}
return err
}
if !common.IsKeyNotExistError(err) {
return err
} }
log.Warn("fail to load grant privilege entity", zap.String("key", k), zap.Any("type", operateType), zap.Error(err))
return err
}
if !common.IsKeyNotExistError(err) {
log.Warn("fail to load grant privilege entity", zap.String("key", k), zap.Any("type", operateType), zap.Error(err))
return err
}
idStr = crypto.MD5(k) idStr = crypto.MD5(k)
err = kc.Txn.Save(k, idStr) err = kc.Txn.Save(k, idStr)
if err != nil { if err != nil {
log.Error("fail to allocate id when altering the grant", zap.Error(err)) log.Error("fail to allocate id when altering the grant", zap.Error(err))
return err return err
}
} }
} else {
idStr = v
} }
k = funcutil.HandleTenantForEtcdKey(GranteeIDPrefix, tenant, fmt.Sprintf("%s/%s", idStr, privilegeName)) k = funcutil.HandleTenantForEtcdKey(GranteeIDPrefix, tenant, fmt.Sprintf("%s/%s", idStr, privilegeName))
_, err = kc.Txn.Load(k) _, err = kc.Txn.Load(k)
@ -882,6 +974,11 @@ func (kc *Catalog) ListGrant(ctx context.Context, tenant string, entity *milvusp
var granteeKey string var granteeKey string
appendGrantEntity := func(v string, object string, objectName string) error { appendGrantEntity := func(v string, object string, objectName string) error {
dbName := ""
dbName, objectName = funcutil.SplitObjectName(objectName)
if dbName != entity.DbName {
return nil
}
granteeIDKey := funcutil.HandleTenantForEtcdKey(GranteeIDPrefix, tenant, v) granteeIDKey := funcutil.HandleTenantForEtcdKey(GranteeIDPrefix, tenant, v)
keys, values, err := kc.Txn.LoadWithPrefix(granteeIDKey) keys, values, err := kc.Txn.LoadWithPrefix(granteeIDKey)
if err != nil { if err != nil {
@ -902,6 +999,7 @@ func (kc *Catalog) ListGrant(ctx context.Context, tenant string, entity *milvusp
Role: &milvuspb.RoleEntity{Name: entity.Role.Name}, Role: &milvuspb.RoleEntity{Name: entity.Role.Name},
Object: &milvuspb.ObjectEntity{Name: object}, Object: &milvuspb.ObjectEntity{Name: object},
ObjectName: objectName, ObjectName: objectName,
DbName: dbName,
Grantor: &milvuspb.GrantorEntity{ Grantor: &milvuspb.GrantorEntity{
User: &milvuspb.UserEntity{Name: values[i]}, User: &milvuspb.UserEntity{Name: values[i]},
Privilege: &milvuspb.PrivilegeEntity{Name: privilegeName}, Privilege: &milvuspb.PrivilegeEntity{Name: privilegeName},
@ -912,13 +1010,24 @@ func (kc *Catalog) ListGrant(ctx context.Context, tenant string, entity *milvusp
} }
if !funcutil.IsEmptyString(entity.ObjectName) && entity.Object != nil && !funcutil.IsEmptyString(entity.Object.Name) { if !funcutil.IsEmptyString(entity.ObjectName) && entity.Object != nil && !funcutil.IsEmptyString(entity.Object.Name) {
granteeKey = funcutil.HandleTenantForEtcdKey(GranteePrefix, tenant, fmt.Sprintf("%s/%s/%s", entity.Role.Name, entity.Object.Name, entity.ObjectName)) if entity.DbName == util.DefaultDBName {
granteeKey = funcutil.HandleTenantForEtcdKey(GranteePrefix, tenant, fmt.Sprintf("%s/%s/%s", entity.Role.Name, entity.Object.Name, entity.ObjectName))
v, err := kc.Txn.Load(granteeKey)
if err == nil {
err = appendGrantEntity(v, entity.Object.Name, entity.ObjectName)
if err == nil {
return entities, nil
}
}
}
granteeKey = funcutil.HandleTenantForEtcdKey(GranteePrefix, tenant, fmt.Sprintf("%s/%s/%s", entity.Role.Name, entity.Object.Name, funcutil.CombineObjectName(entity.DbName, entity.ObjectName)))
v, err := kc.Txn.Load(granteeKey) v, err := kc.Txn.Load(granteeKey)
if err != nil { if err != nil {
log.Error("fail to load the grant privilege entity", zap.String("key", granteeKey), zap.Error(err)) log.Error("fail to load the grant privilege entity", zap.String("key", granteeKey), zap.Error(err))
return entities, err return entities, err
} }
err = appendGrantEntity(v, entity.Object.Name, entity.ObjectName) err = appendGrantEntity(v, entity.Object.Name, funcutil.CombineObjectName(entity.DbName, entity.ObjectName))
if err != nil { if err != nil {
return entities, err return entities, err
} }
@ -984,8 +1093,9 @@ func (kc *Catalog) ListPolicy(ctx context.Context, tenant string) ([]string, err
log.Warn("invalid grantee id", zap.String("string", idKey), zap.String("sub_string", granteeIDKey)) log.Warn("invalid grantee id", zap.String("string", idKey), zap.String("sub_string", granteeIDKey))
continue continue
} }
dbName, objectName := funcutil.SplitObjectName(grantInfos[2])
grantInfoStrs = append(grantInfoStrs, grantInfoStrs = append(grantInfoStrs,
funcutil.PolicyForPrivilege(grantInfos[0], grantInfos[1], grantInfos[2], granteeIDInfos[0])) funcutil.PolicyForPrivilege(grantInfos[0], grantInfos[1], objectName, granteeIDInfos[0], dbName))
} }
} }
return grantInfoStrs, nil return grantInfoStrs, nil
@ -1014,3 +1124,10 @@ func (kc *Catalog) ListUserRole(ctx context.Context, tenant string) ([]string, e
func (kc *Catalog) Close() { func (kc *Catalog) Close() {
// do nothing // do nothing
} }
func isDefaultDB(dbID int64) bool {
if dbID == util.DefaultDBID || dbID == util.NonDBID {
return true
}
return false
}

View File

@ -25,6 +25,7 @@ import (
pb "github.com/milvus-io/milvus/internal/proto/etcdpb" pb "github.com/milvus-io/milvus/internal/proto/etcdpb"
"github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util"
"github.com/milvus-io/milvus/pkg/util/crypto" "github.com/milvus-io/milvus/pkg/util/crypto"
"github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/typeutil" "github.com/milvus-io/milvus/pkg/util/typeutil"
@ -43,6 +44,10 @@ var (
} }
) )
const (
testDb = 1000
)
func TestCatalog_ListCollections(t *testing.T) { func TestCatalog_ListCollections(t *testing.T) {
ctx := context.Background() ctx := context.Background()
@ -70,6 +75,18 @@ func TestCatalog_ListCollections(t *testing.T) {
{}, {},
}, },
}, },
State: pb.CollectionState_CollectionCreated,
}
coll3 := &pb.CollectionInfo{
ID: 3,
Schema: &schemapb.CollectionSchema{
Name: "c1",
Fields: []*schemapb.FieldSchema{
{},
},
},
State: pb.CollectionState_CollectionDropping,
} }
targetErr := errors.New("fail") targetErr := errors.New("fail")
@ -81,7 +98,7 @@ func TestCatalog_ListCollections(t *testing.T) {
Return(nil, nil, targetErr) Return(nil, nil, targetErr)
kc := Catalog{Snapshot: kv} kc := Catalog{Snapshot: kv}
ret, err := kc.ListCollections(ctx, ts) ret, err := kc.ListCollections(ctx, util.NonDBID, ts)
assert.ErrorIs(t, err, targetErr) assert.ErrorIs(t, err, targetErr)
assert.Nil(t, ret) assert.Nil(t, ret)
}) })
@ -93,7 +110,7 @@ func TestCatalog_ListCollections(t *testing.T) {
bColl, err := proto.Marshal(coll2) bColl, err := proto.Marshal(coll2)
assert.NoError(t, err) assert.NoError(t, err)
kv.On("LoadWithPrefix", CollectionMetaPrefix, ts). kv.On("LoadWithPrefix", CollectionMetaPrefix, ts).
Return(nil, []string{string(bColl)}, nil) Return([]string{"key"}, []string{string(bColl)}, nil)
kv.On("LoadWithPrefix", mock.MatchedBy( kv.On("LoadWithPrefix", mock.MatchedBy(
func(prefix string) bool { func(prefix string) bool {
return strings.HasPrefix(prefix, PartitionMetaPrefix) return strings.HasPrefix(prefix, PartitionMetaPrefix)
@ -101,7 +118,7 @@ func TestCatalog_ListCollections(t *testing.T) {
Return(nil, nil, targetErr) Return(nil, nil, targetErr)
kc := Catalog{Snapshot: kv} kc := Catalog{Snapshot: kv}
ret, err := kc.ListCollections(ctx, ts) ret, err := kc.ListCollections(ctx, util.NonDBID, ts)
assert.ErrorIs(t, err, targetErr) assert.ErrorIs(t, err, targetErr)
assert.Nil(t, ret) assert.Nil(t, ret)
}) })
@ -113,7 +130,7 @@ func TestCatalog_ListCollections(t *testing.T) {
bColl, err := proto.Marshal(coll2) bColl, err := proto.Marshal(coll2)
assert.NoError(t, err) assert.NoError(t, err)
kv.On("LoadWithPrefix", CollectionMetaPrefix, ts). kv.On("LoadWithPrefix", CollectionMetaPrefix, ts).
Return(nil, []string{string(bColl)}, nil) Return([]string{"key"}, []string{string(bColl)}, nil)
partitionMeta := &pb.PartitionInfo{} partitionMeta := &pb.PartitionInfo{}
pm, err := proto.Marshal(partitionMeta) pm, err := proto.Marshal(partitionMeta)
@ -123,7 +140,7 @@ func TestCatalog_ListCollections(t *testing.T) {
func(prefix string) bool { func(prefix string) bool {
return strings.HasPrefix(prefix, PartitionMetaPrefix) return strings.HasPrefix(prefix, PartitionMetaPrefix)
}), ts). }), ts).
Return(nil, []string{string(pm)}, nil) Return([]string{"key"}, []string{string(pm)}, nil)
kv.On("LoadWithPrefix", mock.MatchedBy( kv.On("LoadWithPrefix", mock.MatchedBy(
func(prefix string) bool { func(prefix string) bool {
@ -132,7 +149,7 @@ func TestCatalog_ListCollections(t *testing.T) {
Return(nil, nil, targetErr) Return(nil, nil, targetErr)
kc := Catalog{Snapshot: kv} kc := Catalog{Snapshot: kv}
ret, err := kc.ListCollections(ctx, ts) ret, err := kc.ListCollections(ctx, util.NonDBID, ts)
assert.ErrorIs(t, err, targetErr) assert.ErrorIs(t, err, targetErr)
assert.Nil(t, ret) assert.Nil(t, ret)
}) })
@ -144,24 +161,23 @@ func TestCatalog_ListCollections(t *testing.T) {
bColl, err := proto.Marshal(coll1) bColl, err := proto.Marshal(coll1)
assert.NoError(t, err) assert.NoError(t, err)
kv.On("LoadWithPrefix", CollectionMetaPrefix, ts). kv.On("LoadWithPrefix", CollectionMetaPrefix, ts).
Return(nil, []string{string(bColl)}, nil) Return([]string{"key"}, []string{string(bColl)}, nil)
kc := Catalog{Snapshot: kv} kc := Catalog{Snapshot: kv}
ret, err := kc.ListCollections(ctx, ts) ret, err := kc.ListCollections(ctx, util.NonDBID, ts)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 1, len(ret)) assert.Equal(t, 1, len(ret))
modCol := ret["c1"] assert.Equal(t, coll1.ID, ret[0].CollectionID)
assert.Equal(t, coll1.ID, modCol.CollectionID)
}) })
t.Run("list collection ok for the newest version", func(t *testing.T) { t.Run("list collection with db", func(t *testing.T) {
kv := mocks.NewSnapShotKV(t) kv := mocks.NewSnapShotKV(t)
ts := uint64(1) ts := uint64(1)
bColl, err := proto.Marshal(coll2) bColl, err := proto.Marshal(coll2)
assert.NoError(t, err) assert.NoError(t, err)
kv.On("LoadWithPrefix", CollectionMetaPrefix, ts). kv.On("LoadWithPrefix", BuildDatabasePrefixWithDBID(testDb), ts).
Return(nil, []string{string(bColl)}, nil) Return([]string{"key"}, []string{string(bColl)}, nil)
partitionMeta := &pb.PartitionInfo{} partitionMeta := &pb.PartitionInfo{}
pm, err := proto.Marshal(partitionMeta) pm, err := proto.Marshal(partitionMeta)
@ -171,7 +187,7 @@ func TestCatalog_ListCollections(t *testing.T) {
func(prefix string) bool { func(prefix string) bool {
return strings.HasPrefix(prefix, PartitionMetaPrefix) return strings.HasPrefix(prefix, PartitionMetaPrefix)
}), ts). }), ts).
Return(nil, []string{string(pm)}, nil) Return([]string{"key"}, []string{string(pm)}, nil)
fieldMeta := &schemapb.FieldSchema{} fieldMeta := &schemapb.FieldSchema{}
fm, err := proto.Marshal(fieldMeta) fm, err := proto.Marshal(fieldMeta)
@ -181,14 +197,56 @@ func TestCatalog_ListCollections(t *testing.T) {
func(prefix string) bool { func(prefix string) bool {
return strings.HasPrefix(prefix, FieldMetaPrefix) return strings.HasPrefix(prefix, FieldMetaPrefix)
}), ts). }), ts).
Return(nil, []string{string(fm)}, nil) Return([]string{"key"}, []string{string(fm)}, nil)
kc := Catalog{Snapshot: kv} kc := Catalog{Snapshot: kv}
ret, err := kc.ListCollections(ctx, ts) ret, err := kc.ListCollections(ctx, testDb, ts)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, ret) assert.NotNil(t, ret)
assert.Equal(t, 1, len(ret)) assert.Equal(t, 1, len(ret))
}) })
t.Run("list collection ok for the newest version", func(t *testing.T) {
kv := mocks.NewSnapShotKV(t)
ts := uint64(1)
bColl, err := proto.Marshal(coll2)
assert.NoError(t, err)
aColl, err := proto.Marshal(coll3)
assert.NoError(t, err)
kv.On("LoadWithPrefix", CollectionMetaPrefix, ts).
Return([]string{"key", "key2"}, []string{string(bColl), string(aColl)}, nil)
partitionMeta := &pb.PartitionInfo{}
pm, err := proto.Marshal(partitionMeta)
assert.NoError(t, err)
kv.On("LoadWithPrefix", mock.MatchedBy(
func(prefix string) bool {
return strings.HasPrefix(prefix, PartitionMetaPrefix)
}), ts).
Return([]string{"key"}, []string{string(pm)}, nil)
fieldMeta := &schemapb.FieldSchema{}
fm, err := proto.Marshal(fieldMeta)
assert.NoError(t, err)
kv.On("LoadWithPrefix", mock.MatchedBy(
func(prefix string) bool {
return strings.HasPrefix(prefix, FieldMetaPrefix)
}), ts).
Return([]string{"key"}, []string{string(fm)}, nil)
kc := Catalog{Snapshot: kv}
ret, err := kc.ListCollections(ctx, util.NonDBID, ts)
assert.NoError(t, err)
assert.NotNil(t, ret)
assert.Equal(t, 2, len(ret))
assert.Equal(t, int64(2), ret[0].CollectionID)
assert.Equal(t, int64(3), ret[1].CollectionID)
})
} }
func TestCatalog_loadCollection(t *testing.T) { func TestCatalog_loadCollection(t *testing.T) {
@ -199,7 +257,7 @@ func TestCatalog_loadCollection(t *testing.T) {
return "", errors.New("mock") return "", errors.New("mock")
} }
kc := Catalog{Snapshot: snapshot} kc := Catalog{Snapshot: snapshot}
_, err := kc.loadCollection(ctx, 1, 0) _, err := kc.loadCollection(ctx, testDb, 1, 0)
assert.Error(t, err) assert.Error(t, err)
}) })
@ -210,7 +268,7 @@ func TestCatalog_loadCollection(t *testing.T) {
return "not in pb format", nil return "not in pb format", nil
} }
kc := Catalog{Snapshot: snapshot} kc := Catalog{Snapshot: snapshot}
_, err := kc.loadCollection(ctx, 1, 0) _, err := kc.loadCollection(ctx, testDb, 1, 0)
assert.Error(t, err) assert.Error(t, err)
}) })
@ -224,7 +282,7 @@ func TestCatalog_loadCollection(t *testing.T) {
return string(value), nil return string(value), nil
} }
kc := Catalog{Snapshot: snapshot} kc := Catalog{Snapshot: snapshot}
got, err := kc.loadCollection(ctx, 1, 0) got, err := kc.loadCollection(ctx, 0, 1, 0)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, got.GetID(), coll.GetID()) assert.Equal(t, got.GetID(), coll.GetID())
}) })
@ -307,11 +365,11 @@ func TestCatalog_GetCollectionByID(t *testing.T) {
}, },
) )
coll, err := c.GetCollectionByID(ctx, 1, 1) coll, err := c.GetCollectionByID(ctx, 0, 1, 1)
assert.Error(t, err) assert.Error(t, err)
assert.Nil(t, coll) assert.Nil(t, coll)
coll, err = c.GetCollectionByID(ctx, 1, 10000) coll, err = c.GetCollectionByID(ctx, 0, 10000, 1)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, coll) assert.NotNil(t, coll)
} }
@ -324,7 +382,7 @@ func TestCatalog_CreatePartitionV2(t *testing.T) {
return "", errors.New("mock") return "", errors.New("mock")
} }
kc := Catalog{Snapshot: snapshot} kc := Catalog{Snapshot: snapshot}
err := kc.CreatePartition(ctx, &model.Partition{}, 0) err := kc.CreatePartition(ctx, 0, &model.Partition{}, 0)
assert.Error(t, err) assert.Error(t, err)
}) })
@ -345,13 +403,13 @@ func TestCatalog_CreatePartitionV2(t *testing.T) {
kc := Catalog{Snapshot: snapshot} kc := Catalog{Snapshot: snapshot}
err = kc.CreatePartition(ctx, &model.Partition{}, 0) err = kc.CreatePartition(ctx, 0, &model.Partition{}, 0)
assert.Error(t, err) assert.Error(t, err)
snapshot.SaveFunc = func(key string, value string, ts typeutil.Timestamp) error { snapshot.SaveFunc = func(key string, value string, ts typeutil.Timestamp) error {
return nil return nil
} }
err = kc.CreatePartition(ctx, &model.Partition{}, 0) err = kc.CreatePartition(ctx, 0, &model.Partition{}, 0)
assert.NoError(t, err) assert.NoError(t, err)
}) })
@ -370,7 +428,7 @@ func TestCatalog_CreatePartitionV2(t *testing.T) {
kc := Catalog{Snapshot: snapshot} kc := Catalog{Snapshot: snapshot}
err = kc.CreatePartition(ctx, &model.Partition{PartitionID: partID}, 0) err = kc.CreatePartition(ctx, 0, &model.Partition{PartitionID: partID}, 0)
assert.Error(t, err) assert.Error(t, err)
}) })
@ -389,7 +447,7 @@ func TestCatalog_CreatePartitionV2(t *testing.T) {
kc := Catalog{Snapshot: snapshot} kc := Catalog{Snapshot: snapshot}
err = kc.CreatePartition(ctx, &model.Partition{PartitionName: partition}, 0) err = kc.CreatePartition(ctx, 0, &model.Partition{PartitionName: partition}, 0)
assert.Error(t, err) assert.Error(t, err)
}) })
@ -414,13 +472,13 @@ func TestCatalog_CreatePartitionV2(t *testing.T) {
kc := Catalog{Snapshot: snapshot} kc := Catalog{Snapshot: snapshot}
err = kc.CreatePartition(ctx, &model.Partition{}, 0) err = kc.CreatePartition(ctx, 0, &model.Partition{}, 0)
assert.Error(t, err) assert.Error(t, err)
snapshot.SaveFunc = func(key string, value string, ts typeutil.Timestamp) error { snapshot.SaveFunc = func(key string, value string, ts typeutil.Timestamp) error {
return nil return nil
} }
err = kc.CreatePartition(ctx, &model.Partition{}, 0) err = kc.CreatePartition(ctx, 0, &model.Partition{}, 0)
assert.NoError(t, err) assert.NoError(t, err)
}) })
} }
@ -613,7 +671,7 @@ func TestCatalog_DropPartitionV2(t *testing.T) {
kc := Catalog{Snapshot: snapshot} kc := Catalog{Snapshot: snapshot}
err := kc.DropPartition(ctx, 100, 101, 0) err := kc.DropPartition(ctx, 0, 100, 101, 0)
assert.Error(t, err) assert.Error(t, err)
}) })
@ -634,13 +692,13 @@ func TestCatalog_DropPartitionV2(t *testing.T) {
kc := Catalog{Snapshot: snapshot} kc := Catalog{Snapshot: snapshot}
err = kc.DropPartition(ctx, 100, 101, 0) err = kc.DropPartition(ctx, 0, 100, 101, 0)
assert.Error(t, err) assert.Error(t, err)
snapshot.MultiSaveAndRemoveWithPrefixFunc = func(saves map[string]string, removals []string, ts typeutil.Timestamp) error { snapshot.MultiSaveAndRemoveWithPrefixFunc = func(saves map[string]string, removals []string, ts typeutil.Timestamp) error {
return nil return nil
} }
err = kc.DropPartition(ctx, 100, 101, 0) err = kc.DropPartition(ctx, 0, 100, 101, 0)
assert.NoError(t, err) assert.NoError(t, err)
}) })
@ -665,13 +723,13 @@ func TestCatalog_DropPartitionV2(t *testing.T) {
kc := Catalog{Snapshot: snapshot} kc := Catalog{Snapshot: snapshot}
err = kc.DropPartition(ctx, 100, 101, 0) err = kc.DropPartition(ctx, 0, 100, 101, 0)
assert.Error(t, err) assert.Error(t, err)
snapshot.SaveFunc = func(key string, value string, ts typeutil.Timestamp) error { snapshot.SaveFunc = func(key string, value string, ts typeutil.Timestamp) error {
return nil return nil
} }
err = kc.DropPartition(ctx, 100, 102, 0) err = kc.DropPartition(ctx, 0, 100, 102, 0)
assert.NoError(t, err) assert.NoError(t, err)
}) })
} }
@ -686,13 +744,13 @@ func TestCatalog_DropAliasV2(t *testing.T) {
kc := Catalog{Snapshot: snapshot} kc := Catalog{Snapshot: snapshot}
err := kc.DropAlias(ctx, "alias", 0) err := kc.DropAlias(ctx, testDb, "alias", 0)
assert.Error(t, err) assert.Error(t, err)
snapshot.MultiSaveAndRemoveWithPrefixFunc = func(saves map[string]string, removals []string, ts typeutil.Timestamp) error { snapshot.MultiSaveAndRemoveWithPrefixFunc = func(saves map[string]string, removals []string, ts typeutil.Timestamp) error {
return nil return nil
} }
err = kc.DropAlias(ctx, "alias", 0) err = kc.DropAlias(ctx, testDb, "alias", 0)
assert.NoError(t, err) assert.NoError(t, err)
} }
@ -757,7 +815,7 @@ func TestCatalog_listAliasesAfter210(t *testing.T) {
kc := Catalog{Snapshot: snapshot} kc := Catalog{Snapshot: snapshot}
_, err := kc.listAliasesAfter210(ctx, 0) _, err := kc.listAliasesAfter210WithDb(ctx, testDb, 0)
assert.Error(t, err) assert.Error(t, err)
}) })
@ -771,7 +829,7 @@ func TestCatalog_listAliasesAfter210(t *testing.T) {
kc := Catalog{Snapshot: snapshot} kc := Catalog{Snapshot: snapshot}
_, err := kc.listAliasesAfter210(ctx, 0) _, err := kc.listAliasesAfter210WithDb(ctx, testDb, 0)
assert.Error(t, err) assert.Error(t, err)
}) })
@ -789,7 +847,7 @@ func TestCatalog_listAliasesAfter210(t *testing.T) {
kc := Catalog{Snapshot: snapshot} kc := Catalog{Snapshot: snapshot}
got, err := kc.listAliasesAfter210(ctx, 0) got, err := kc.listAliasesAfter210WithDb(ctx, testDb, 0)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 1, len(got)) assert.Equal(t, 1, len(got))
assert.Equal(t, int64(100), got[0].CollectionID) assert.Equal(t, int64(100), got[0].CollectionID)
@ -807,7 +865,7 @@ func TestCatalog_ListAliasesV2(t *testing.T) {
kc := Catalog{Snapshot: snapshot} kc := Catalog{Snapshot: snapshot}
_, err := kc.ListAliases(ctx, 0) _, err := kc.ListAliases(ctx, testDb, 0)
assert.Error(t, err) assert.Error(t, err)
}) })
@ -823,41 +881,44 @@ func TestCatalog_ListAliasesV2(t *testing.T) {
if key == AliasMetaPrefix { if key == AliasMetaPrefix {
return nil, nil, errors.New("mock") return nil, nil, errors.New("mock")
} }
if strings.Contains(key, DatabaseMetaPrefix) {
return nil, nil, errors.New("mock")
}
return []string{"key"}, []string{string(value)}, nil return []string{"key"}, []string{string(value)}, nil
} }
kc := Catalog{Snapshot: snapshot} kc := Catalog{Snapshot: snapshot}
_, err = kc.ListAliases(ctx, 0) _, err = kc.ListAliases(ctx, util.NonDBID, 0)
assert.Error(t, err)
_, err = kc.ListAliases(ctx, testDb, 0)
assert.Error(t, err) assert.Error(t, err)
}) })
t.Run("normal case", func(t *testing.T) { t.Run("normal case", func(t *testing.T) {
ctx := context.Background() ctx := context.Background()
coll := &pb.CollectionInfo{Schema: &schemapb.CollectionSchema{Name: "alias1"}, ID: 100, ShardsNum: 50}
value, err := proto.Marshal(coll)
assert.NoError(t, err)
alias := &pb.AliasInfo{CollectionId: 101, AliasName: "alias2"} alias := &pb.AliasInfo{CollectionId: 101, AliasName: "alias2"}
value2, err := proto.Marshal(alias) value2, err := proto.Marshal(alias)
assert.NoError(t, err) assert.NoError(t, err)
snapshot := kv.NewMockSnapshotKV() snapshot := kv.NewMockSnapshotKV()
snapshot.LoadWithPrefixFunc = func(key string, ts typeutil.Timestamp) ([]string, []string, error) { snapshot.LoadWithPrefixFunc = func(key string, ts typeutil.Timestamp) ([]string, []string, error) {
if key == AliasMetaPrefix { dbStr := fmt.Sprintf("%d", testDb)
if strings.Contains(key, dbStr) && strings.Contains(key, Aliases) {
return []string{"key1"}, []string{string(value2)}, nil return []string{"key1"}, []string{string(value2)}, nil
} }
return []string{"key"}, []string{string(value)}, nil return []string{}, []string{}, nil
} }
kc := Catalog{Snapshot: snapshot} kc := Catalog{Snapshot: snapshot}
got, err := kc.ListAliases(ctx, 0) got, err := kc.ListAliases(ctx, testDb, 0)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 2, len(got)) assert.Equal(t, 1, len(got))
assert.Equal(t, "alias1", got[0].Name) assert.Equal(t, "alias2", got[0].Name)
assert.Equal(t, "alias2", got[1].Name)
}) })
} }
@ -928,7 +989,7 @@ func TestCatalog_AlterCollection(t *testing.T) {
newC := &model.Collection{CollectionID: collectionID, State: pb.CollectionState_CollectionCreated} newC := &model.Collection{CollectionID: collectionID, State: pb.CollectionState_CollectionCreated}
err := kc.AlterCollection(ctx, oldC, newC, metastore.MODIFY, 0) err := kc.AlterCollection(ctx, oldC, newC, metastore.MODIFY, 0)
assert.NoError(t, err) assert.NoError(t, err)
key := BuildCollectionKey(collectionID) key := BuildCollectionKey(0, collectionID)
value, ok := kvs[key] value, ok := kvs[key]
assert.True(t, ok) assert.True(t, ok)
var collPb pb.CollectionInfo var collPb pb.CollectionInfo
@ -953,14 +1014,14 @@ func TestCatalog_AlterPartition(t *testing.T) {
t.Run("add", func(t *testing.T) { t.Run("add", func(t *testing.T) {
kc := &Catalog{} kc := &Catalog{}
ctx := context.Background() ctx := context.Background()
err := kc.AlterPartition(ctx, nil, nil, metastore.ADD, 0) err := kc.AlterPartition(ctx, testDb, nil, nil, metastore.ADD, 0)
assert.Error(t, err) assert.Error(t, err)
}) })
t.Run("delete", func(t *testing.T) { t.Run("delete", func(t *testing.T) {
kc := &Catalog{} kc := &Catalog{}
ctx := context.Background() ctx := context.Background()
err := kc.AlterPartition(ctx, nil, nil, metastore.DELETE, 0) err := kc.AlterPartition(ctx, testDb, nil, nil, metastore.DELETE, 0)
assert.Error(t, err) assert.Error(t, err)
}) })
@ -977,7 +1038,7 @@ func TestCatalog_AlterPartition(t *testing.T) {
var partitionID int64 = 2 var partitionID int64 = 2
oldP := &model.Partition{PartitionID: partitionID, CollectionID: collectionID, State: pb.PartitionState_PartitionCreating} oldP := &model.Partition{PartitionID: partitionID, CollectionID: collectionID, State: pb.PartitionState_PartitionCreating}
newP := &model.Partition{PartitionID: partitionID, CollectionID: collectionID, State: pb.PartitionState_PartitionCreated} newP := &model.Partition{PartitionID: partitionID, CollectionID: collectionID, State: pb.PartitionState_PartitionCreated}
err := kc.AlterPartition(ctx, oldP, newP, metastore.MODIFY, 0) err := kc.AlterPartition(ctx, testDb, oldP, newP, metastore.MODIFY, 0)
assert.NoError(t, err) assert.NoError(t, err)
key := BuildPartitionKey(collectionID, partitionID) key := BuildPartitionKey(collectionID, partitionID)
value, ok := kvs[key] value, ok := kvs[key]
@ -995,7 +1056,7 @@ func TestCatalog_AlterPartition(t *testing.T) {
var collectionID int64 = 1 var collectionID int64 = 1
oldP := &model.Partition{PartitionID: 1, CollectionID: collectionID, State: pb.PartitionState_PartitionCreating} oldP := &model.Partition{PartitionID: 1, CollectionID: collectionID, State: pb.PartitionState_PartitionCreating}
newP := &model.Partition{PartitionID: 2, CollectionID: collectionID, State: pb.PartitionState_PartitionCreated} newP := &model.Partition{PartitionID: 2, CollectionID: collectionID, State: pb.PartitionState_PartitionCreated}
err := kc.AlterPartition(ctx, oldP, newP, metastore.MODIFY, 0) err := kc.AlterPartition(ctx, testDb, oldP, newP, metastore.MODIFY, 0)
assert.Error(t, err) assert.Error(t, err)
}) })
} }
@ -1929,6 +1990,7 @@ func TestRBAC_Grant(t *testing.T) {
validRole = "role1" validRole = "role1"
invalidRole = "role2" invalidRole = "role2"
keyNotExistRole = "role3" keyNotExistRole = "role3"
errorSaveRole = "role100"
validUser = "user1" validUser = "user1"
invalidUser = "user2" invalidUser = "user2"
@ -1948,9 +2010,14 @@ func TestRBAC_Grant(t *testing.T) {
validRoleValue := crypto.MD5(validRoleKey) validRoleValue := crypto.MD5(validRoleKey)
invalidRoleKey := funcutil.HandleTenantForEtcdKey(GranteePrefix, tenant, fmt.Sprintf("%s/%s/%s", invalidRole, object, objName)) invalidRoleKey := funcutil.HandleTenantForEtcdKey(GranteePrefix, tenant, fmt.Sprintf("%s/%s/%s", invalidRole, object, objName))
invalidRoleKeyWithDb := funcutil.HandleTenantForEtcdKey(GranteePrefix, tenant, fmt.Sprintf("%s/%s/%s", invalidRole, object, funcutil.CombineObjectName(util.DefaultDBName, objName)))
keyNotExistRoleKey := funcutil.HandleTenantForEtcdKey(GranteePrefix, tenant, fmt.Sprintf("%s/%s/%s", keyNotExistRole, object, objName)) keyNotExistRoleKey := funcutil.HandleTenantForEtcdKey(GranteePrefix, tenant, fmt.Sprintf("%s/%s/%s", keyNotExistRole, object, objName))
keyNotExistRoleValue := crypto.MD5(keyNotExistRoleKey) keyNotExistRoleKeyWithDb := funcutil.HandleTenantForEtcdKey(GranteePrefix, tenant, fmt.Sprintf("%s/%s/%s", keyNotExistRole, object, funcutil.CombineObjectName(util.DefaultDBName, objName)))
keyNotExistRoleValueWithDb := crypto.MD5(keyNotExistRoleKeyWithDb)
errorSaveRoleKey := funcutil.HandleTenantForEtcdKey(GranteePrefix, tenant, fmt.Sprintf("%s/%s/%s", errorSaveRole, object, objName))
errorSaveRoleKeyWithDb := funcutil.HandleTenantForEtcdKey(GranteePrefix, tenant, fmt.Sprintf("%s/%s/%s", errorSaveRole, object, funcutil.CombineObjectName(util.DefaultDBName, objName)))
// Mock return in kv_catalog.go:AlterGrant:L815 // Mock return in kv_catalog.go:AlterGrant:L815
kvmock.EXPECT().Load(validRoleKey).Call. kvmock.EXPECT().Load(validRoleKey).Call.
@ -1960,16 +2027,33 @@ func TestRBAC_Grant(t *testing.T) {
Return("", func(key string) error { Return("", func(key string) error {
return fmt.Errorf("mock load error, key=%s", key) return fmt.Errorf("mock load error, key=%s", key)
}) })
kvmock.EXPECT().Load(invalidRoleKeyWithDb).Call.
Return("", func(key string) error {
return fmt.Errorf("mock load error, key=%s", key)
})
kvmock.EXPECT().Load(keyNotExistRoleKey).Call. kvmock.EXPECT().Load(keyNotExistRoleKey).Call.
Return("", func(key string) error { Return("", func(key string) error {
return common.NewKeyNotExistError(key) return common.NewKeyNotExistError(key)
}) })
kvmock.EXPECT().Save(keyNotExistRoleKey, mock.Anything).Return(nil) kvmock.EXPECT().Load(keyNotExistRoleKeyWithDb).Call.
Return("", func(key string) error {
return common.NewKeyNotExistError(key)
})
kvmock.EXPECT().Load(errorSaveRoleKey).Call.
Return("", func(key string) error {
return common.NewKeyNotExistError(key)
})
kvmock.EXPECT().Load(errorSaveRoleKeyWithDb).Call.
Return("", func(key string) error {
return common.NewKeyNotExistError(key)
})
kvmock.EXPECT().Save(keyNotExistRoleKeyWithDb, mock.Anything).Return(nil)
kvmock.EXPECT().Save(errorSaveRoleKeyWithDb, mock.Anything).Return(errors.New("mock save error role"))
validPrivilegeKey := funcutil.HandleTenantForEtcdKey(GranteeIDPrefix, tenant, fmt.Sprintf("%s/%s", validRoleValue, validPrivilege)) validPrivilegeKey := funcutil.HandleTenantForEtcdKey(GranteeIDPrefix, tenant, fmt.Sprintf("%s/%s", validRoleValue, validPrivilege))
invalidPrivilegeKey := funcutil.HandleTenantForEtcdKey(GranteeIDPrefix, tenant, fmt.Sprintf("%s/%s", validRoleValue, invalidPrivilege)) invalidPrivilegeKey := funcutil.HandleTenantForEtcdKey(GranteeIDPrefix, tenant, fmt.Sprintf("%s/%s", validRoleValue, invalidPrivilege))
keyNotExistPrivilegeKey := funcutil.HandleTenantForEtcdKey(GranteeIDPrefix, tenant, fmt.Sprintf("%s/%s", validRoleValue, keyNotExistPrivilege)) keyNotExistPrivilegeKey := funcutil.HandleTenantForEtcdKey(GranteeIDPrefix, tenant, fmt.Sprintf("%s/%s", validRoleValue, keyNotExistPrivilege))
keyNotExistPrivilegeKey2 := funcutil.HandleTenantForEtcdKey(GranteeIDPrefix, tenant, fmt.Sprintf("%s/%s", keyNotExistRoleValue, keyNotExistPrivilege2)) keyNotExistPrivilegeKey2WithDb := funcutil.HandleTenantForEtcdKey(GranteeIDPrefix, tenant, fmt.Sprintf("%s/%s", keyNotExistRoleValueWithDb, keyNotExistPrivilege2))
// Mock return in kv_catalog.go:AlterGrant:L838 // Mock return in kv_catalog.go:AlterGrant:L838
kvmock.EXPECT().Load(validPrivilegeKey).Call.Return("", nil) kvmock.EXPECT().Load(validPrivilegeKey).Call.Return("", nil)
@ -1981,7 +2065,7 @@ func TestRBAC_Grant(t *testing.T) {
Return("", func(key string) error { Return("", func(key string) error {
return common.NewKeyNotExistError(key) return common.NewKeyNotExistError(key)
}) })
kvmock.EXPECT().Load(keyNotExistPrivilegeKey2).Call. kvmock.EXPECT().Load(keyNotExistPrivilegeKey2WithDb).Call.
Return("", func(key string) error { Return("", func(key string) error {
return common.NewKeyNotExistError(key) return common.NewKeyNotExistError(key)
}) })
@ -2010,9 +2094,10 @@ func TestRBAC_Grant(t *testing.T) {
{false, validUser, invalidRole, invalidPrivilege, false, "grant invalid role with invalid privilege"}, {false, validUser, invalidRole, invalidPrivilege, false, "grant invalid role with invalid privilege"},
{false, validUser, invalidRole, validPrivilege, false, "grant invalid role with valid privilege"}, {false, validUser, invalidRole, validPrivilege, false, "grant invalid role with valid privilege"},
{false, validUser, invalidRole, keyNotExistPrivilege, false, "grant invalid role with not exist privilege"}, {false, validUser, invalidRole, keyNotExistPrivilege, false, "grant invalid role with not exist privilege"},
{false, validUser, errorSaveRole, keyNotExistPrivilege, false, "grant error role with not exist privilege"},
// not exist role // not exist role
{false, validUser, keyNotExistRole, validPrivilege, true, "grant not exist role with exist privilege"}, {false, validUser, keyNotExistRole, validPrivilege, true, "grant not exist role with exist privilege"},
{true, validUser, keyNotExistRole, keyNotExistPrivilege2, false, "grant not exist role with exist privilege"}, {true, validUser, keyNotExistRole, keyNotExistPrivilege2, false, "grant not exist role with not exist privilege"},
} }
for _, test := range tests { for _, test := range tests {
@ -2021,6 +2106,7 @@ func TestRBAC_Grant(t *testing.T) {
Role: &milvuspb.RoleEntity{Name: test.roleName}, Role: &milvuspb.RoleEntity{Name: test.roleName},
Object: &milvuspb.ObjectEntity{Name: object}, Object: &milvuspb.ObjectEntity{Name: object},
ObjectName: objName, ObjectName: objName,
DbName: util.DefaultDBName,
Grantor: &milvuspb.GrantorEntity{ Grantor: &milvuspb.GrantorEntity{
User: &milvuspb.UserEntity{Name: test.userName}, User: &milvuspb.UserEntity{Name: test.userName},
Privilege: &milvuspb.PrivilegeEntity{Name: test.privilegeName}}, Privilege: &milvuspb.PrivilegeEntity{Name: test.privilegeName}},
@ -2077,6 +2163,7 @@ func TestRBAC_Grant(t *testing.T) {
Role: &milvuspb.RoleEntity{Name: test.roleName}, Role: &milvuspb.RoleEntity{Name: test.roleName},
Object: &milvuspb.ObjectEntity{Name: object}, Object: &milvuspb.ObjectEntity{Name: object},
ObjectName: objName, ObjectName: objName,
DbName: util.DefaultDBName,
Grantor: &milvuspb.GrantorEntity{ Grantor: &milvuspb.GrantorEntity{
User: &milvuspb.UserEntity{Name: test.userName}, User: &milvuspb.UserEntity{Name: test.userName},
Privilege: &milvuspb.PrivilegeEntity{Name: test.privilegeName}}, Privilege: &milvuspb.PrivilegeEntity{Name: test.privilegeName}},
@ -2141,6 +2228,10 @@ func TestRBAC_Grant(t *testing.T) {
fmt.Sprintf("%s/%s/%s", "role1", "obj1", "obj_name1")) fmt.Sprintf("%s/%s/%s", "role1", "obj1", "obj_name1"))
kvmock.EXPECT().Load(validGranteeKey).Call. kvmock.EXPECT().Load(validGranteeKey).Call.
Return(func(key string) string { return crypto.MD5(key) }, nil) Return(func(key string) string { return crypto.MD5(key) }, nil)
validGranteeKey2 := funcutil.HandleTenantForEtcdKey(GranteePrefix, tenant,
fmt.Sprintf("%s/%s/%s", "role1", "obj1", "foo.obj_name2"))
kvmock.EXPECT().Load(validGranteeKey2).Call.
Return(func(key string) string { return crypto.MD5(key) }, nil)
kvmock.EXPECT().Load(mock.Anything).Call. kvmock.EXPECT().Load(mock.Anything).Call.
Return("", errors.New("mock Load error")) Return("", errors.New("mock Load error"))
@ -2191,10 +2282,23 @@ func TestRBAC_Grant(t *testing.T) {
Object: &milvuspb.ObjectEntity{Name: "obj1"}, Object: &milvuspb.ObjectEntity{Name: "obj1"},
ObjectName: "obj_name1", ObjectName: "obj_name1",
Role: &milvuspb.RoleEntity{Name: "role1"}}, "valid role with valid entity"}, Role: &milvuspb.RoleEntity{Name: "role1"}}, "valid role with valid entity"},
{true, &milvuspb.GrantEntity{
Object: &milvuspb.ObjectEntity{Name: "obj1"},
ObjectName: "obj_name2",
DbName: "foo",
Role: &milvuspb.RoleEntity{Name: "role1"}}, "valid role and dbName with valid entity"},
{false, &milvuspb.GrantEntity{
Object: &milvuspb.ObjectEntity{Name: "obj1"},
ObjectName: "obj_name2",
DbName: "foo2",
Role: &milvuspb.RoleEntity{Name: "role1"}}, "valid role and invalid dbName with valid entity"},
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.description, func(t *testing.T) { t.Run(test.description, func(t *testing.T) {
if test.entity.DbName == "" {
test.entity.DbName = util.DefaultDBName
}
grants, err := c.ListGrant(ctx, tenant, test.entity) grants, err := c.ListGrant(ctx, tenant, test.entity)
if test.isValid { if test.isValid {
assert.NoError(t, err) assert.NoError(t, err)
@ -2290,10 +2394,10 @@ func TestRBAC_Grant(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 4, len(policy)) assert.Equal(t, 4, len(policy))
ps := []string{ ps := []string{
funcutil.PolicyForPrivilege("role1", "obj1", "obj_name1", "PrivilegeLoad"), funcutil.PolicyForPrivilege("role1", "obj1", "obj_name1", "PrivilegeLoad", "default"),
funcutil.PolicyForPrivilege("role1", "obj1", "obj_name1", "PrivilegeRelease"), funcutil.PolicyForPrivilege("role1", "obj1", "obj_name1", "PrivilegeRelease", "default"),
funcutil.PolicyForPrivilege("role2", "obj2", "obj_name2", "PrivilegeLoad"), funcutil.PolicyForPrivilege("role2", "obj2", "obj_name2", "PrivilegeLoad", "default"),
funcutil.PolicyForPrivilege("role2", "obj2", "obj_name2", "PrivilegeRelease"), funcutil.PolicyForPrivilege("role2", "obj2", "obj_name2", "PrivilegeRelease", "default"),
} }
assert.ElementsMatch(t, ps, policy) assert.ElementsMatch(t, ps, policy)
} else { } else {

View File

@ -1,9 +1,19 @@
package rootcoord package rootcoord
import (
"fmt"
"github.com/milvus-io/milvus/pkg/util"
)
const ( const (
// ComponentPrefix prefix for rootcoord component // ComponentPrefix prefix for rootcoord component
ComponentPrefix = "root-coord" ComponentPrefix = "root-coord"
DatabaseMetaPrefix = ComponentPrefix + "/database"
DBInfoMetaPrefix = DatabaseMetaPrefix + "/db-info"
CollectionInfoMetaPrefix = DatabaseMetaPrefix + "/collection-info"
// CollectionMetaPrefix prefix for collection meta // CollectionMetaPrefix prefix for collection meta
CollectionMetaPrefix = ComponentPrefix + "/collection" CollectionMetaPrefix = ComponentPrefix + "/collection"
@ -16,6 +26,7 @@ const (
SnapshotsSep = "_ts" SnapshotsSep = "_ts"
SnapshotPrefix = "snapshots" SnapshotPrefix = "snapshots"
Aliases = "aliases"
// CommonCredentialPrefix subpath for common credential // CommonCredentialPrefix subpath for common credential
/* #nosec G101 */ /* #nosec G101 */
@ -39,3 +50,22 @@ const (
// GranteeIDPrefix prefix for mapping among privilege and grantor // GranteeIDPrefix prefix for mapping among privilege and grantor
GranteeIDPrefix = ComponentPrefix + CommonCredentialPrefix + "/grantee-id" GranteeIDPrefix = ComponentPrefix + CommonCredentialPrefix + "/grantee-id"
) )
func BuildDatabasePrefixWithDBID(dbID int64) string {
return fmt.Sprintf("%s/%d", CollectionInfoMetaPrefix, dbID)
}
func BuildCollectionKeyWithDBID(dbID int64, collectionID int64) string {
return fmt.Sprintf("%s/%d/%d", CollectionInfoMetaPrefix, dbID, collectionID)
}
func BuildDatabaseKey(dbID int64) string {
return fmt.Sprintf("%s/%d", DBInfoMetaPrefix, dbID)
}
func getDatabasePrefix(dbID int64) string {
if dbID != util.NonDBID {
return BuildDatabasePrefixWithDBID(dbID)
}
return CollectionMetaPrefix
}

View File

@ -74,13 +74,13 @@ func (_m *RootCoordCatalog) AlterGrant(ctx context.Context, tenant string, entit
return r0 return r0
} }
// AlterPartition provides a mock function with given fields: ctx, oldPart, newPart, alterType, ts // AlterPartition provides a mock function with given fields: ctx, dbID, oldPart, newPart, alterType, ts
func (_m *RootCoordCatalog) AlterPartition(ctx context.Context, oldPart *model.Partition, newPart *model.Partition, alterType metastore.AlterType, ts uint64) error { func (_m *RootCoordCatalog) AlterPartition(ctx context.Context, dbID int64, oldPart *model.Partition, newPart *model.Partition, alterType metastore.AlterType, ts uint64) error {
ret := _m.Called(ctx, oldPart, newPart, alterType, ts) ret := _m.Called(ctx, dbID, oldPart, newPart, alterType, ts)
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *model.Partition, *model.Partition, metastore.AlterType, uint64) error); ok { if rf, ok := ret.Get(0).(func(context.Context, int64, *model.Partition, *model.Partition, metastore.AlterType, uint64) error); ok {
r0 = rf(ctx, oldPart, newPart, alterType, ts) r0 = rf(ctx, dbID, oldPart, newPart, alterType, ts)
} else { } else {
r0 = ret.Error(0) r0 = ret.Error(0)
} }
@ -107,13 +107,13 @@ func (_m *RootCoordCatalog) Close() {
_m.Called() _m.Called()
} }
// CollectionExists provides a mock function with given fields: ctx, collectionID, ts // CollectionExists provides a mock function with given fields: ctx, dbID, collectionID, ts
func (_m *RootCoordCatalog) CollectionExists(ctx context.Context, collectionID int64, ts uint64) bool { func (_m *RootCoordCatalog) CollectionExists(ctx context.Context, dbID int64, collectionID int64, ts uint64) bool {
ret := _m.Called(ctx, collectionID, ts) ret := _m.Called(ctx, dbID, collectionID, ts)
var r0 bool var r0 bool
if rf, ok := ret.Get(0).(func(context.Context, int64, uint64) bool); ok { if rf, ok := ret.Get(0).(func(context.Context, int64, int64, uint64) bool); ok {
r0 = rf(ctx, collectionID, ts) r0 = rf(ctx, dbID, collectionID, ts)
} else { } else {
r0 = ret.Get(0).(bool) r0 = ret.Get(0).(bool)
} }
@ -163,13 +163,27 @@ func (_m *RootCoordCatalog) CreateCredential(ctx context.Context, credential *mo
return r0 return r0
} }
// CreatePartition provides a mock function with given fields: ctx, partition, ts // CreateDatabase provides a mock function with given fields: ctx, db, ts
func (_m *RootCoordCatalog) CreatePartition(ctx context.Context, partition *model.Partition, ts uint64) error { func (_m *RootCoordCatalog) CreateDatabase(ctx context.Context, db *model.Database, ts uint64) error {
ret := _m.Called(ctx, partition, ts) ret := _m.Called(ctx, db, ts)
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *model.Partition, uint64) error); ok { if rf, ok := ret.Get(0).(func(context.Context, *model.Database, uint64) error); ok {
r0 = rf(ctx, partition, ts) r0 = rf(ctx, db, ts)
} else {
r0 = ret.Error(0)
}
return r0
}
// CreatePartition provides a mock function with given fields: ctx, dbID, partition, ts
func (_m *RootCoordCatalog) CreatePartition(ctx context.Context, dbID int64, partition *model.Partition, ts uint64) error {
ret := _m.Called(ctx, dbID, partition, ts)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, int64, *model.Partition, uint64) error); ok {
r0 = rf(ctx, dbID, partition, ts)
} else { } else {
r0 = ret.Error(0) r0 = ret.Error(0)
} }
@ -205,13 +219,13 @@ func (_m *RootCoordCatalog) DeleteGrant(ctx context.Context, tenant string, role
return r0 return r0
} }
// DropAlias provides a mock function with given fields: ctx, alias, ts // DropAlias provides a mock function with given fields: ctx, dbID, alias, ts
func (_m *RootCoordCatalog) DropAlias(ctx context.Context, alias string, ts uint64) error { func (_m *RootCoordCatalog) DropAlias(ctx context.Context, dbID int64, alias string, ts uint64) error {
ret := _m.Called(ctx, alias, ts) ret := _m.Called(ctx, dbID, alias, ts)
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, uint64) error); ok { if rf, ok := ret.Get(0).(func(context.Context, int64, string, uint64) error); ok {
r0 = rf(ctx, alias, ts) r0 = rf(ctx, dbID, alias, ts)
} else { } else {
r0 = ret.Error(0) r0 = ret.Error(0)
} }
@ -247,13 +261,27 @@ func (_m *RootCoordCatalog) DropCredential(ctx context.Context, username string)
return r0 return r0
} }
// DropPartition provides a mock function with given fields: ctx, collectionID, partitionID, ts // DropDatabase provides a mock function with given fields: ctx, dbID, ts
func (_m *RootCoordCatalog) DropPartition(ctx context.Context, collectionID int64, partitionID int64, ts uint64) error { func (_m *RootCoordCatalog) DropDatabase(ctx context.Context, dbID int64, ts uint64) error {
ret := _m.Called(ctx, collectionID, partitionID, ts) ret := _m.Called(ctx, dbID, ts)
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(context.Context, int64, int64, uint64) error); ok { if rf, ok := ret.Get(0).(func(context.Context, int64, uint64) error); ok {
r0 = rf(ctx, collectionID, partitionID, ts) r0 = rf(ctx, dbID, ts)
} else {
r0 = ret.Error(0)
}
return r0
}
// DropPartition provides a mock function with given fields: ctx, dbID, collectionID, partitionID, ts
func (_m *RootCoordCatalog) DropPartition(ctx context.Context, dbID int64, collectionID int64, partitionID int64, ts uint64) error {
ret := _m.Called(ctx, dbID, collectionID, partitionID, ts)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, int64, int64, int64, uint64) error); ok {
r0 = rf(ctx, dbID, collectionID, partitionID, ts)
} else { } else {
r0 = ret.Error(0) r0 = ret.Error(0)
} }
@ -275,13 +303,13 @@ func (_m *RootCoordCatalog) DropRole(ctx context.Context, tenant string, roleNam
return r0 return r0
} }
// GetCollectionByID provides a mock function with given fields: ctx, collectionID, ts // GetCollectionByID provides a mock function with given fields: ctx, dbID, ts, collectionID
func (_m *RootCoordCatalog) GetCollectionByID(ctx context.Context, collectionID int64, ts uint64) (*model.Collection, error) { func (_m *RootCoordCatalog) GetCollectionByID(ctx context.Context, dbID int64, ts uint64, collectionID int64) (*model.Collection, error) {
ret := _m.Called(ctx, collectionID, ts) ret := _m.Called(ctx, dbID, ts, collectionID)
var r0 *model.Collection var r0 *model.Collection
if rf, ok := ret.Get(0).(func(context.Context, int64, uint64) *model.Collection); ok { if rf, ok := ret.Get(0).(func(context.Context, int64, uint64, int64) *model.Collection); ok {
r0 = rf(ctx, collectionID, ts) r0 = rf(ctx, dbID, ts, collectionID)
} else { } else {
if ret.Get(0) != nil { if ret.Get(0) != nil {
r0 = ret.Get(0).(*model.Collection) r0 = ret.Get(0).(*model.Collection)
@ -289,8 +317,8 @@ func (_m *RootCoordCatalog) GetCollectionByID(ctx context.Context, collectionID
} }
var r1 error var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int64, uint64) error); ok { if rf, ok := ret.Get(1).(func(context.Context, int64, uint64, int64) error); ok {
r1 = rf(ctx, collectionID, ts) r1 = rf(ctx, dbID, ts, collectionID)
} else { } else {
r1 = ret.Error(1) r1 = ret.Error(1)
} }
@ -298,13 +326,13 @@ func (_m *RootCoordCatalog) GetCollectionByID(ctx context.Context, collectionID
return r0, r1 return r0, r1
} }
// GetCollectionByName provides a mock function with given fields: ctx, collectionName, ts // GetCollectionByName provides a mock function with given fields: ctx, dbID, collectionName, ts
func (_m *RootCoordCatalog) GetCollectionByName(ctx context.Context, collectionName string, ts uint64) (*model.Collection, error) { func (_m *RootCoordCatalog) GetCollectionByName(ctx context.Context, dbID int64, collectionName string, ts uint64) (*model.Collection, error) {
ret := _m.Called(ctx, collectionName, ts) ret := _m.Called(ctx, dbID, collectionName, ts)
var r0 *model.Collection var r0 *model.Collection
if rf, ok := ret.Get(0).(func(context.Context, string, uint64) *model.Collection); ok { if rf, ok := ret.Get(0).(func(context.Context, int64, string, uint64) *model.Collection); ok {
r0 = rf(ctx, collectionName, ts) r0 = rf(ctx, dbID, collectionName, ts)
} else { } else {
if ret.Get(0) != nil { if ret.Get(0) != nil {
r0 = ret.Get(0).(*model.Collection) r0 = ret.Get(0).(*model.Collection)
@ -312,8 +340,8 @@ func (_m *RootCoordCatalog) GetCollectionByName(ctx context.Context, collectionN
} }
var r1 error var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string, uint64) error); ok { if rf, ok := ret.Get(1).(func(context.Context, int64, string, uint64) error); ok {
r1 = rf(ctx, collectionName, ts) r1 = rf(ctx, dbID, collectionName, ts)
} else { } else {
r1 = ret.Error(1) r1 = ret.Error(1)
} }
@ -344,13 +372,13 @@ func (_m *RootCoordCatalog) GetCredential(ctx context.Context, username string)
return r0, r1 return r0, r1
} }
// ListAliases provides a mock function with given fields: ctx, ts // ListAliases provides a mock function with given fields: ctx, dbID, ts
func (_m *RootCoordCatalog) ListAliases(ctx context.Context, ts uint64) ([]*model.Alias, error) { func (_m *RootCoordCatalog) ListAliases(ctx context.Context, dbID int64, ts uint64) ([]*model.Alias, error) {
ret := _m.Called(ctx, ts) ret := _m.Called(ctx, dbID, ts)
var r0 []*model.Alias var r0 []*model.Alias
if rf, ok := ret.Get(0).(func(context.Context, uint64) []*model.Alias); ok { if rf, ok := ret.Get(0).(func(context.Context, int64, uint64) []*model.Alias); ok {
r0 = rf(ctx, ts) r0 = rf(ctx, dbID, ts)
} else { } else {
if ret.Get(0) != nil { if ret.Get(0) != nil {
r0 = ret.Get(0).([]*model.Alias) r0 = ret.Get(0).([]*model.Alias)
@ -358,8 +386,8 @@ func (_m *RootCoordCatalog) ListAliases(ctx context.Context, ts uint64) ([]*mode
} }
var r1 error var r1 error
if rf, ok := ret.Get(1).(func(context.Context, uint64) error); ok { if rf, ok := ret.Get(1).(func(context.Context, int64, uint64) error); ok {
r1 = rf(ctx, ts) r1 = rf(ctx, dbID, ts)
} else { } else {
r1 = ret.Error(1) r1 = ret.Error(1)
} }
@ -367,22 +395,22 @@ func (_m *RootCoordCatalog) ListAliases(ctx context.Context, ts uint64) ([]*mode
return r0, r1 return r0, r1
} }
// ListCollections provides a mock function with given fields: ctx, ts // ListCollections provides a mock function with given fields: ctx, dbID, ts
func (_m *RootCoordCatalog) ListCollections(ctx context.Context, ts uint64) (map[string]*model.Collection, error) { func (_m *RootCoordCatalog) ListCollections(ctx context.Context, dbID int64, ts uint64) ([]*model.Collection, error) {
ret := _m.Called(ctx, ts) ret := _m.Called(ctx, dbID, ts)
var r0 map[string]*model.Collection var r0 []*model.Collection
if rf, ok := ret.Get(0).(func(context.Context, uint64) map[string]*model.Collection); ok { if rf, ok := ret.Get(0).(func(context.Context, int64, uint64) []*model.Collection); ok {
r0 = rf(ctx, ts) r0 = rf(ctx, dbID, ts)
} else { } else {
if ret.Get(0) != nil { if ret.Get(0) != nil {
r0 = ret.Get(0).(map[string]*model.Collection) r0 = ret.Get(0).([]*model.Collection)
} }
} }
var r1 error var r1 error
if rf, ok := ret.Get(1).(func(context.Context, uint64) error); ok { if rf, ok := ret.Get(1).(func(context.Context, int64, uint64) error); ok {
r1 = rf(ctx, ts) r1 = rf(ctx, dbID, ts)
} else { } else {
r1 = ret.Error(1) r1 = ret.Error(1)
} }
@ -413,6 +441,29 @@ func (_m *RootCoordCatalog) ListCredentials(ctx context.Context) ([]string, erro
return r0, r1 return r0, r1
} }
// NOTE(review): mockery-generated mock (see `make generate-mockery`); do not
// hand-edit — regenerate when the RootCoordCatalog interface changes.
// ListDatabases provides a mock function with given fields: ctx, ts
func (_m *RootCoordCatalog) ListDatabases(ctx context.Context, ts uint64) ([]*model.Database, error) {
	ret := _m.Called(ctx, ts)

	var r0 []*model.Database
	if rf, ok := ret.Get(0).(func(context.Context, uint64) []*model.Database); ok {
		r0 = rf(ctx, ts)
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).([]*model.Database)
		}
	}

	var r1 error
	if rf, ok := ret.Get(1).(func(context.Context, uint64) error); ok {
		r1 = rf(ctx, ts)
	} else {
		r1 = ret.Error(1)
	}

	return r0, r1
}
// ListGrant provides a mock function with given fields: ctx, tenant, entity // ListGrant provides a mock function with given fields: ctx, tenant, entity
func (_m *RootCoordCatalog) ListGrant(ctx context.Context, tenant string, entity *milvuspb.GrantEntity) ([]*milvuspb.GrantEntity, error) { func (_m *RootCoordCatalog) ListGrant(ctx context.Context, tenant string, entity *milvuspb.GrantEntity) ([]*milvuspb.GrantEntity, error) {
ret := _m.Called(ctx, tenant, entity) ret := _m.Called(ctx, tenant, entity)

View File

@ -7,6 +7,7 @@ type Alias struct {
CollectionID int64 CollectionID int64
CreatedTime uint64 CreatedTime uint64
State pb.AliasState State pb.AliasState
DbID int64
} }
func (a Alias) Available() bool { func (a Alias) Available() bool {
@ -19,12 +20,14 @@ func (a Alias) Clone() *Alias {
CollectionID: a.CollectionID, CollectionID: a.CollectionID,
CreatedTime: a.CreatedTime, CreatedTime: a.CreatedTime,
State: a.State, State: a.State,
DbID: a.DbID,
} }
} }
func (a Alias) Equal(other Alias) bool { func (a Alias) Equal(other Alias) bool {
return a.Name == other.Name && return a.Name == other.Name &&
a.CollectionID == other.CollectionID a.CollectionID == other.CollectionID &&
a.DbID == other.DbID
} }
func MarshalAliasModel(alias *Alias) *pb.AliasInfo { func MarshalAliasModel(alias *Alias) *pb.AliasInfo {
@ -33,6 +36,7 @@ func MarshalAliasModel(alias *Alias) *pb.AliasInfo {
CollectionId: alias.CollectionID, CollectionId: alias.CollectionID,
CreatedTime: alias.CreatedTime, CreatedTime: alias.CreatedTime,
State: alias.State, State: alias.State,
DbId: alias.DbID,
} }
} }
@ -42,5 +46,6 @@ func UnmarshalAliasModel(info *pb.AliasInfo) *Alias {
CollectionID: info.GetCollectionId(), CollectionID: info.GetCollectionId(),
CreatedTime: info.GetCreatedTime(), CreatedTime: info.GetCreatedTime(),
State: info.GetState(), State: info.GetState(),
DbID: info.GetDbId(),
} }
} }

View File

@ -10,7 +10,7 @@ import (
type Collection struct { type Collection struct {
TenantID string TenantID string
DBName string // TODO: @jiquan.long please help to assign, persistent and check DBID int64
CollectionID int64 CollectionID int64
Partitions []*Partition Partitions []*Partition
Name string Name string
@ -36,6 +36,7 @@ func (c Collection) Available() bool {
func (c Collection) Clone() *Collection { func (c Collection) Clone() *Collection {
return &Collection{ return &Collection{
TenantID: c.TenantID, TenantID: c.TenantID,
DBID: c.DBID,
CollectionID: c.CollectionID, CollectionID: c.CollectionID,
Name: c.Name, Name: c.Name,
Description: c.Description, Description: c.Description,
@ -64,6 +65,7 @@ func (c Collection) GetPartitionNum(filterUnavailable bool) int {
func (c Collection) Equal(other Collection) bool { func (c Collection) Equal(other Collection) bool {
return c.TenantID == other.TenantID && return c.TenantID == other.TenantID &&
c.DBID == other.DBID &&
CheckPartitionsEqual(c.Partitions, other.Partitions) && CheckPartitionsEqual(c.Partitions, other.Partitions) &&
c.Name == other.Name && c.Name == other.Name &&
c.Description == other.Description && c.Description == other.Description &&
@ -92,6 +94,7 @@ func UnmarshalCollectionModel(coll *pb.CollectionInfo) *Collection {
return &Collection{ return &Collection{
CollectionID: coll.ID, CollectionID: coll.ID,
DBID: coll.DbId,
Name: coll.Schema.Name, Name: coll.Schema.Name,
Description: coll.Schema.Description, Description: coll.Schema.Description,
AutoID: coll.Schema.AutoID, AutoID: coll.Schema.AutoID,
@ -157,6 +160,7 @@ func marshalCollectionModelWithConfig(coll *Collection, c *config) *pb.Collectio
collectionPb := &pb.CollectionInfo{ collectionPb := &pb.CollectionInfo{
ID: coll.CollectionID, ID: coll.CollectionID,
DbId: coll.DBID,
Schema: collSchema, Schema: collSchema,
CreateTime: coll.CreateTime, CreateTime: coll.CreateTime,
VirtualChannelNames: coll.VirtualChannelNames, VirtualChannelNames: coll.VirtualChannelNames,

View File

@ -0,0 +1,79 @@
package model
import (
"time"
pb "github.com/milvus-io/milvus/internal/proto/etcdpb"
"github.com/milvus-io/milvus/pkg/util"
)
// Database is the in-memory metadata model of a Milvus database.
type Database struct {
	TenantID    string           // owning tenant ID
	ID          int64            // unique database ID
	Name        string           // database name
	State       pb.DatabaseState // lifecycle state; only DatabaseCreated counts as available
	CreatedTime uint64           // creation time; NewDatabase stamps time.Now().UnixNano()
}
// NewDatabase builds a Database model with the given id, name and state,
// stamping CreatedTime with the current wall-clock time in nanoseconds.
// (Fixes the misspelled parameter name `sate` -> `state`; Go parameter
// names are not part of the call interface, so callers are unaffected.)
func NewDatabase(id int64, name string, state pb.DatabaseState) *Database {
	return &Database{
		ID:          id,
		Name:        name,
		State:       state,
		CreatedTime: uint64(time.Now().UnixNano()),
	}
}
// NewDefaultDatabase returns the built-in default database
// (util.DefaultDBID / util.DefaultDBName) already in the Created state.
func NewDefaultDatabase() *Database {
	return NewDatabase(util.DefaultDBID, util.DefaultDBName, pb.DatabaseState_DatabaseCreated)
}
// Available reports whether the database is usable, i.e. fully created.
func (c Database) Available() bool {
	return c.State == pb.DatabaseState_DatabaseCreated
}
// Clone returns an independent copy of the database model. Every field is a
// value type (no slices, maps or pointers), so a plain struct copy — the value
// receiver itself — is already a complete copy.
func (c Database) Clone() *Database {
	cloned := c
	return &cloned
}
// Equal reports whether two database models carry identical metadata.
// All comparisons are pure field equalities, so evaluation order is irrelevant.
func (c Database) Equal(other Database) bool {
	return c.ID == other.ID &&
		c.Name == other.Name &&
		c.TenantID == other.TenantID &&
		c.State == other.State &&
		c.CreatedTime == other.CreatedTime
}
// MarshalDatabaseModel converts the in-memory model into its etcdpb wire
// representation. A nil model maps to a nil message.
func MarshalDatabaseModel(db *Database) *pb.DatabaseInfo {
	if db == nil {
		return nil
	}

	info := &pb.DatabaseInfo{
		Id:          db.ID,
		Name:        db.Name,
		TenantId:    db.TenantID,
		State:       db.State,
		CreatedTime: db.CreatedTime,
	}
	return info
}
// UnmarshalDatabaseModel converts an etcdpb wire message back into the
// in-memory model. A nil message maps to a nil model.
func UnmarshalDatabaseModel(info *pb.DatabaseInfo) *Database {
	if info == nil {
		return nil
	}

	db := &Database{
		ID:          info.GetId(),
		Name:        info.GetName(),
		TenantID:    info.GetTenantId(),
		State:       info.GetState(),
		CreatedTime: info.GetCreatedTime(),
	}
	return db
}

View File

@ -0,0 +1,49 @@
package model
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/proto/etcdpb"
)
// Shared fixtures: dbPB and dbModel describe the same database, so the
// marshal/unmarshal tests below can assert exact round-trip equality.
var (
	dbPB = &etcdpb.DatabaseInfo{
		TenantId:    "1",
		Name:        "test",
		Id:          1,
		CreatedTime: 1,
		State:       etcdpb.DatabaseState_DatabaseCreated,
	}

	dbModel = &Database{
		TenantID:    "1",
		Name:        "test",
		ID:          1,
		CreatedTime: 1,
		State:       etcdpb.DatabaseState_DatabaseCreated,
	}
)
// TestMarshalDatabaseModel checks the model→proto conversion and the nil shortcut.
func TestMarshalDatabaseModel(t *testing.T) {
	assert.Nil(t, MarshalDatabaseModel(nil))
	assert.Equal(t, dbPB, MarshalDatabaseModel(dbModel))
}
// TestUnmarshalDatabaseModel checks the proto→model conversion and the nil shortcut.
func TestUnmarshalDatabaseModel(t *testing.T) {
	assert.Nil(t, UnmarshalDatabaseModel(nil))
	assert.Equal(t, dbModel, UnmarshalDatabaseModel(dbPB))
}
// TestDatabaseCloneAndEqual verifies that Clone produces an identical copy and
// that Equal reports the clone as equal to the original. (The original test,
// despite its name, never exercised Equal.)
func TestDatabaseCloneAndEqual(t *testing.T) {
	clone := dbModel.Clone()
	assert.Equal(t, dbModel, clone)
	assert.True(t, dbModel.Equal(*clone))
}
// TestDatabaseAvailable checks that a model in the Created state — including
// the built-in default database — reports itself as available.
func TestDatabaseAvailable(t *testing.T) {
	assert.True(t, dbModel.Available())
	assert.True(t, NewDefaultDatabase().Available())
}

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,4 @@
// Code generated by mockery v2.14.0. DO NOT EDIT. // Code generated by mockery v2.16.0. DO NOT EDIT.
package mocks package mocks
@ -407,6 +407,53 @@ func (_c *RootCoord_CreateCredential_Call) Return(_a0 *commonpb.Status, _a1 erro
return _c return _c
} }
// NOTE(review): mockery-generated (expecter style, mockery v2.16.0); do not
// hand-edit — regenerate when the RootCoord interface changes.
// CreateDatabase provides a mock function with given fields: ctx, req
func (_m *RootCoord) CreateDatabase(ctx context.Context, req *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) {
	ret := _m.Called(ctx, req)

	var r0 *commonpb.Status
	if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateDatabaseRequest) *commonpb.Status); ok {
		r0 = rf(ctx, req)
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(*commonpb.Status)
		}
	}

	var r1 error
	if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreateDatabaseRequest) error); ok {
		r1 = rf(ctx, req)
	} else {
		r1 = ret.Error(1)
	}

	return r0, r1
}

// RootCoord_CreateDatabase_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateDatabase'
type RootCoord_CreateDatabase_Call struct {
	*mock.Call
}

// CreateDatabase is a helper method to define mock.On call
//   - ctx context.Context
//   - req *milvuspb.CreateDatabaseRequest
func (_e *RootCoord_Expecter) CreateDatabase(ctx interface{}, req interface{}) *RootCoord_CreateDatabase_Call {
	return &RootCoord_CreateDatabase_Call{Call: _e.mock.On("CreateDatabase", ctx, req)}
}

func (_c *RootCoord_CreateDatabase_Call) Run(run func(ctx context.Context, req *milvuspb.CreateDatabaseRequest)) *RootCoord_CreateDatabase_Call {
	_c.Call.Run(func(args mock.Arguments) {
		run(args[0].(context.Context), args[1].(*milvuspb.CreateDatabaseRequest))
	})
	return _c
}

func (_c *RootCoord_CreateDatabase_Call) Return(_a0 *commonpb.Status, _a1 error) *RootCoord_CreateDatabase_Call {
	_c.Call.Return(_a0, _a1)
	return _c
}
// CreatePartition provides a mock function with given fields: ctx, req // CreatePartition provides a mock function with given fields: ctx, req
func (_m *RootCoord) CreatePartition(ctx context.Context, req *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) { func (_m *RootCoord) CreatePartition(ctx context.Context, req *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) {
ret := _m.Called(ctx, req) ret := _m.Called(ctx, req)
@ -736,6 +783,53 @@ func (_c *RootCoord_DropCollection_Call) Return(_a0 *commonpb.Status, _a1 error)
return _c return _c
} }
// NOTE(review): mockery-generated (expecter style, mockery v2.16.0); do not
// hand-edit — regenerate when the RootCoord interface changes.
// DropDatabase provides a mock function with given fields: ctx, req
func (_m *RootCoord) DropDatabase(ctx context.Context, req *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) {
	ret := _m.Called(ctx, req)

	var r0 *commonpb.Status
	if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropDatabaseRequest) *commonpb.Status); ok {
		r0 = rf(ctx, req)
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(*commonpb.Status)
		}
	}

	var r1 error
	if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DropDatabaseRequest) error); ok {
		r1 = rf(ctx, req)
	} else {
		r1 = ret.Error(1)
	}

	return r0, r1
}

// RootCoord_DropDatabase_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropDatabase'
type RootCoord_DropDatabase_Call struct {
	*mock.Call
}

// DropDatabase is a helper method to define mock.On call
//   - ctx context.Context
//   - req *milvuspb.DropDatabaseRequest
func (_e *RootCoord_Expecter) DropDatabase(ctx interface{}, req interface{}) *RootCoord_DropDatabase_Call {
	return &RootCoord_DropDatabase_Call{Call: _e.mock.On("DropDatabase", ctx, req)}
}

func (_c *RootCoord_DropDatabase_Call) Run(run func(ctx context.Context, req *milvuspb.DropDatabaseRequest)) *RootCoord_DropDatabase_Call {
	_c.Call.Run(func(args mock.Arguments) {
		run(args[0].(context.Context), args[1].(*milvuspb.DropDatabaseRequest))
	})
	return _c
}

func (_c *RootCoord_DropDatabase_Call) Return(_a0 *commonpb.Status, _a1 error) *RootCoord_DropDatabase_Call {
	_c.Call.Return(_a0, _a1)
	return _c
}
// DropPartition provides a mock function with given fields: ctx, req // DropPartition provides a mock function with given fields: ctx, req
func (_m *RootCoord) DropPartition(ctx context.Context, req *milvuspb.DropPartitionRequest) (*commonpb.Status, error) { func (_m *RootCoord) DropPartition(ctx context.Context, req *milvuspb.DropPartitionRequest) (*commonpb.Status, error) {
ret := _m.Called(ctx, req) ret := _m.Called(ctx, req)
@ -1380,6 +1474,53 @@ func (_c *RootCoord_ListCredUsers_Call) Return(_a0 *milvuspb.ListCredUsersRespon
return _c return _c
} }
// NOTE(review): mockery-generated (expecter style, mockery v2.16.0); do not
// hand-edit — regenerate when the RootCoord interface changes.
// ListDatabases provides a mock function with given fields: ctx, req
func (_m *RootCoord) ListDatabases(ctx context.Context, req *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) {
	ret := _m.Called(ctx, req)

	var r0 *milvuspb.ListDatabasesResponse
	if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListDatabasesRequest) *milvuspb.ListDatabasesResponse); ok {
		r0 = rf(ctx, req)
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(*milvuspb.ListDatabasesResponse)
		}
	}

	var r1 error
	if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ListDatabasesRequest) error); ok {
		r1 = rf(ctx, req)
	} else {
		r1 = ret.Error(1)
	}

	return r0, r1
}

// RootCoord_ListDatabases_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListDatabases'
type RootCoord_ListDatabases_Call struct {
	*mock.Call
}

// ListDatabases is a helper method to define mock.On call
//   - ctx context.Context
//   - req *milvuspb.ListDatabasesRequest
func (_e *RootCoord_Expecter) ListDatabases(ctx interface{}, req interface{}) *RootCoord_ListDatabases_Call {
	return &RootCoord_ListDatabases_Call{Call: _e.mock.On("ListDatabases", ctx, req)}
}

func (_c *RootCoord_ListDatabases_Call) Run(run func(ctx context.Context, req *milvuspb.ListDatabasesRequest)) *RootCoord_ListDatabases_Call {
	_c.Call.Run(func(args mock.Arguments) {
		run(args[0].(context.Context), args[1].(*milvuspb.ListDatabasesRequest))
	})
	return _c
}

func (_c *RootCoord_ListDatabases_Call) Return(_a0 *milvuspb.ListDatabasesResponse, _a1 error) *RootCoord_ListDatabases_Call {
	_c.Call.Return(_a0, _a1)
	return _c
}
// ListImportTasks provides a mock function with given fields: ctx, req // ListImportTasks provides a mock function with given fields: ctx, req
func (_m *RootCoord) ListImportTasks(ctx context.Context, req *milvuspb.ListImportTasksRequest) (*milvuspb.ListImportTasksResponse, error) { func (_m *RootCoord) ListImportTasks(ctx context.Context, req *milvuspb.ListImportTasksRequest) (*milvuspb.ListImportTasksResponse, error) {
ret := _m.Called(ctx, req) ret := _m.Called(ctx, req)

View File

@ -20,6 +20,14 @@ message FieldIndexInfo{
int64 indexID = 2; int64 indexID = 2;
} }
// DatabaseState is the lifecycle state of a database record.
// Only DatabaseCreated is treated as available by the rootcoord model layer.
enum DatabaseState {
    DatabaseUnknown = 0;
    DatabaseCreated = 1;
    DatabaseCreating = 2;
    DatabaseDropping = 3;
    DatabaseDropped = 4;
}
enum CollectionState { enum CollectionState {
CollectionCreated = 0; CollectionCreated = 0;
CollectionCreating = 1; CollectionCreating = 1;
@ -60,6 +68,7 @@ message CollectionInfo {
common.ConsistencyLevel consistency_level = 12; common.ConsistencyLevel consistency_level = 12;
CollectionState state = 13; // To keep compatible with older version, default state is `Created`. CollectionState state = 13; // To keep compatible with older version, default state is `Created`.
repeated common.KeyValuePair properties = 14; repeated common.KeyValuePair properties = 14;
int64 db_id = 15;
} }
message PartitionInfo { message PartitionInfo {
@ -75,6 +84,15 @@ message AliasInfo {
int64 collection_id = 2; int64 collection_id = 2;
uint64 created_time = 3; uint64 created_time = 3;
AliasState state = 4; // To keep compatible with older version, default state is `Created`. AliasState state = 4; // To keep compatible with older version, default state is `Created`.
int64 db_id = 5;
}
// DatabaseInfo is the persisted metadata record of a database.
message DatabaseInfo {
    string tenant_id = 1;     // owning tenant
    string name = 2;          // database name
    int64 id = 3;             // unique database ID
    DatabaseState state = 4;  // lifecycle state
    uint64 created_time = 5;  // creation timestamp
}
message SegmentIndexInfo { message SegmentIndexInfo {

View File

@ -22,6 +22,40 @@ var _ = math.Inf
// proto package needs to be updated. // proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
// NOTE(review): protoc-generated from etcd_meta.proto; do not hand-edit —
// regenerate with protoc when the .proto changes.
// DatabaseState mirrors the etcd_meta.proto DatabaseState enum.
type DatabaseState int32

const (
	DatabaseState_DatabaseUnknown  DatabaseState = 0
	DatabaseState_DatabaseCreated  DatabaseState = 1
	DatabaseState_DatabaseCreating DatabaseState = 2
	DatabaseState_DatabaseDropping DatabaseState = 3
	DatabaseState_DatabaseDropped  DatabaseState = 4
)

// DatabaseState_name maps enum values to their proto identifiers.
var DatabaseState_name = map[int32]string{
	0: "DatabaseUnknown",
	1: "DatabaseCreated",
	2: "DatabaseCreating",
	3: "DatabaseDropping",
	4: "DatabaseDropped",
}

// DatabaseState_value is the reverse mapping of DatabaseState_name.
var DatabaseState_value = map[string]int32{
	"DatabaseUnknown":  0,
	"DatabaseCreated":  1,
	"DatabaseCreating": 2,
	"DatabaseDropping": 3,
	"DatabaseDropped":  4,
}

func (x DatabaseState) String() string {
	return proto.EnumName(DatabaseState_name, int32(x))
}

func (DatabaseState) EnumDescriptor() ([]byte, []int) {
	return fileDescriptor_975d306d62b73e88, []int{0}
}
type CollectionState int32 type CollectionState int32
const ( const (
@ -50,7 +84,7 @@ func (x CollectionState) String() string {
} }
func (CollectionState) EnumDescriptor() ([]byte, []int) { func (CollectionState) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_975d306d62b73e88, []int{0} return fileDescriptor_975d306d62b73e88, []int{1}
} }
type PartitionState int32 type PartitionState int32
@ -81,7 +115,7 @@ func (x PartitionState) String() string {
} }
func (PartitionState) EnumDescriptor() ([]byte, []int) { func (PartitionState) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_975d306d62b73e88, []int{1} return fileDescriptor_975d306d62b73e88, []int{2}
} }
type AliasState int32 type AliasState int32
@ -112,7 +146,7 @@ func (x AliasState) String() string {
} }
func (AliasState) EnumDescriptor() ([]byte, []int) { func (AliasState) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_975d306d62b73e88, []int{2} return fileDescriptor_975d306d62b73e88, []int{3}
} }
type IndexInfo struct { type IndexInfo struct {
@ -252,6 +286,7 @@ type CollectionInfo struct {
ConsistencyLevel commonpb.ConsistencyLevel `protobuf:"varint,12,opt,name=consistency_level,json=consistencyLevel,proto3,enum=milvus.proto.common.ConsistencyLevel" json:"consistency_level,omitempty"` ConsistencyLevel commonpb.ConsistencyLevel `protobuf:"varint,12,opt,name=consistency_level,json=consistencyLevel,proto3,enum=milvus.proto.common.ConsistencyLevel" json:"consistency_level,omitempty"`
State CollectionState `protobuf:"varint,13,opt,name=state,proto3,enum=milvus.proto.etcd.CollectionState" json:"state,omitempty"` State CollectionState `protobuf:"varint,13,opt,name=state,proto3,enum=milvus.proto.etcd.CollectionState" json:"state,omitempty"`
Properties []*commonpb.KeyValuePair `protobuf:"bytes,14,rep,name=properties,proto3" json:"properties,omitempty"` Properties []*commonpb.KeyValuePair `protobuf:"bytes,14,rep,name=properties,proto3" json:"properties,omitempty"`
DbId int64 `protobuf:"varint,15,opt,name=db_id,json=dbId,proto3" json:"db_id,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"` XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"` XXX_sizecache int32 `json:"-"`
@ -380,6 +415,13 @@ func (m *CollectionInfo) GetProperties() []*commonpb.KeyValuePair {
return nil return nil
} }
// GetDbId returns the owning database ID of the collection, or 0 when the
// message is nil (protoc-generated nil-safe getter).
func (m *CollectionInfo) GetDbId() int64 {
	if m != nil {
		return m.DbId
	}
	return 0
}
type PartitionInfo struct { type PartitionInfo struct {
PartitionID int64 `protobuf:"varint,1,opt,name=partitionID,proto3" json:"partitionID,omitempty"` PartitionID int64 `protobuf:"varint,1,opt,name=partitionID,proto3" json:"partitionID,omitempty"`
PartitionName string `protobuf:"bytes,2,opt,name=partitionName,proto3" json:"partitionName,omitempty"` PartitionName string `protobuf:"bytes,2,opt,name=partitionName,proto3" json:"partitionName,omitempty"`
@ -456,6 +498,7 @@ type AliasInfo struct {
CollectionId int64 `protobuf:"varint,2,opt,name=collection_id,json=collectionId,proto3" json:"collection_id,omitempty"` CollectionId int64 `protobuf:"varint,2,opt,name=collection_id,json=collectionId,proto3" json:"collection_id,omitempty"`
CreatedTime uint64 `protobuf:"varint,3,opt,name=created_time,json=createdTime,proto3" json:"created_time,omitempty"` CreatedTime uint64 `protobuf:"varint,3,opt,name=created_time,json=createdTime,proto3" json:"created_time,omitempty"`
State AliasState `protobuf:"varint,4,opt,name=state,proto3,enum=milvus.proto.etcd.AliasState" json:"state,omitempty"` State AliasState `protobuf:"varint,4,opt,name=state,proto3,enum=milvus.proto.etcd.AliasState" json:"state,omitempty"`
DbId int64 `protobuf:"varint,5,opt,name=db_id,json=dbId,proto3" json:"db_id,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"` XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"` XXX_sizecache int32 `json:"-"`
@ -514,6 +557,84 @@ func (m *AliasInfo) GetState() AliasState {
return AliasState_AliasCreated return AliasState_AliasCreated
} }
// GetDbId returns the owning database ID of the alias, or 0 when the
// message is nil (protoc-generated nil-safe getter).
func (m *AliasInfo) GetDbId() int64 {
	if m != nil {
		return m.DbId
	}
	return 0
}
// NOTE(review): protoc-generated from etcd_meta.proto (message DatabaseInfo);
// do not hand-edit — regenerate with protoc when the .proto changes.
// DatabaseInfo is the wire representation of a persisted database record.
type DatabaseInfo struct {
	TenantId             string        `protobuf:"bytes,1,opt,name=tenant_id,json=tenantId,proto3" json:"tenant_id,omitempty"`
	Name                 string        `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"`
	Id                   int64         `protobuf:"varint,3,opt,name=id,proto3" json:"id,omitempty"`
	State                DatabaseState `protobuf:"varint,4,opt,name=state,proto3,enum=milvus.proto.etcd.DatabaseState" json:"state,omitempty"`
	CreatedTime          uint64        `protobuf:"varint,5,opt,name=created_time,json=createdTime,proto3" json:"created_time,omitempty"`
	XXX_NoUnkeyedLiteral struct{}      `json:"-"`
	XXX_unrecognized     []byte        `json:"-"`
	XXX_sizecache        int32         `json:"-"`
}

func (m *DatabaseInfo) Reset()         { *m = DatabaseInfo{} }
func (m *DatabaseInfo) String() string { return proto.CompactTextString(m) }
func (*DatabaseInfo) ProtoMessage()    {}
func (*DatabaseInfo) Descriptor() ([]byte, []int) {
	return fileDescriptor_975d306d62b73e88, []int{5}
}

func (m *DatabaseInfo) XXX_Unmarshal(b []byte) error {
	return xxx_messageInfo_DatabaseInfo.Unmarshal(m, b)
}
func (m *DatabaseInfo) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
	return xxx_messageInfo_DatabaseInfo.Marshal(b, m, deterministic)
}
func (m *DatabaseInfo) XXX_Merge(src proto.Message) {
	xxx_messageInfo_DatabaseInfo.Merge(m, src)
}
func (m *DatabaseInfo) XXX_Size() int {
	return xxx_messageInfo_DatabaseInfo.Size(m)
}
func (m *DatabaseInfo) XXX_DiscardUnknown() {
	xxx_messageInfo_DatabaseInfo.DiscardUnknown(m)
}

var xxx_messageInfo_DatabaseInfo proto.InternalMessageInfo

// The getters below are protoc-generated nil-safe accessors: each returns the
// field value, or the zero value when the receiver is nil.

func (m *DatabaseInfo) GetTenantId() string {
	if m != nil {
		return m.TenantId
	}
	return ""
}

func (m *DatabaseInfo) GetName() string {
	if m != nil {
		return m.Name
	}
	return ""
}

func (m *DatabaseInfo) GetId() int64 {
	if m != nil {
		return m.Id
	}
	return 0
}

func (m *DatabaseInfo) GetState() DatabaseState {
	if m != nil {
		return m.State
	}
	return DatabaseState_DatabaseUnknown
}

func (m *DatabaseInfo) GetCreatedTime() uint64 {
	if m != nil {
		return m.CreatedTime
	}
	return 0
}
type SegmentIndexInfo struct { type SegmentIndexInfo struct {
CollectionID int64 `protobuf:"varint,1,opt,name=collectionID,proto3" json:"collectionID,omitempty"` CollectionID int64 `protobuf:"varint,1,opt,name=collectionID,proto3" json:"collectionID,omitempty"`
PartitionID int64 `protobuf:"varint,2,opt,name=partitionID,proto3" json:"partitionID,omitempty"` PartitionID int64 `protobuf:"varint,2,opt,name=partitionID,proto3" json:"partitionID,omitempty"`
@ -532,7 +653,7 @@ func (m *SegmentIndexInfo) Reset() { *m = SegmentIndexInfo{} }
func (m *SegmentIndexInfo) String() string { return proto.CompactTextString(m) } func (m *SegmentIndexInfo) String() string { return proto.CompactTextString(m) }
func (*SegmentIndexInfo) ProtoMessage() {} func (*SegmentIndexInfo) ProtoMessage() {}
func (*SegmentIndexInfo) Descriptor() ([]byte, []int) { func (*SegmentIndexInfo) Descriptor() ([]byte, []int) {
return fileDescriptor_975d306d62b73e88, []int{5} return fileDescriptor_975d306d62b73e88, []int{6}
} }
func (m *SegmentIndexInfo) XXX_Unmarshal(b []byte) error { func (m *SegmentIndexInfo) XXX_Unmarshal(b []byte) error {
@ -626,7 +747,7 @@ func (m *CollectionMeta) Reset() { *m = CollectionMeta{} }
func (m *CollectionMeta) String() string { return proto.CompactTextString(m) } func (m *CollectionMeta) String() string { return proto.CompactTextString(m) }
func (*CollectionMeta) ProtoMessage() {} func (*CollectionMeta) ProtoMessage() {}
func (*CollectionMeta) Descriptor() ([]byte, []int) { func (*CollectionMeta) Descriptor() ([]byte, []int) {
return fileDescriptor_975d306d62b73e88, []int{6} return fileDescriptor_975d306d62b73e88, []int{7}
} }
func (m *CollectionMeta) XXX_Unmarshal(b []byte) error { func (m *CollectionMeta) XXX_Unmarshal(b []byte) error {
@ -706,7 +827,7 @@ func (m *CredentialInfo) Reset() { *m = CredentialInfo{} }
func (m *CredentialInfo) String() string { return proto.CompactTextString(m) } func (m *CredentialInfo) String() string { return proto.CompactTextString(m) }
func (*CredentialInfo) ProtoMessage() {} func (*CredentialInfo) ProtoMessage() {}
func (*CredentialInfo) Descriptor() ([]byte, []int) { func (*CredentialInfo) Descriptor() ([]byte, []int) {
return fileDescriptor_975d306d62b73e88, []int{7} return fileDescriptor_975d306d62b73e88, []int{8}
} }
func (m *CredentialInfo) XXX_Unmarshal(b []byte) error { func (m *CredentialInfo) XXX_Unmarshal(b []byte) error {
@ -763,6 +884,7 @@ func (m *CredentialInfo) GetSha256Password() string {
} }
func init() { func init() {
proto.RegisterEnum("milvus.proto.etcd.DatabaseState", DatabaseState_name, DatabaseState_value)
proto.RegisterEnum("milvus.proto.etcd.CollectionState", CollectionState_name, CollectionState_value) proto.RegisterEnum("milvus.proto.etcd.CollectionState", CollectionState_name, CollectionState_value)
proto.RegisterEnum("milvus.proto.etcd.PartitionState", PartitionState_name, PartitionState_value) proto.RegisterEnum("milvus.proto.etcd.PartitionState", PartitionState_name, PartitionState_value)
proto.RegisterEnum("milvus.proto.etcd.AliasState", AliasState_name, AliasState_value) proto.RegisterEnum("milvus.proto.etcd.AliasState", AliasState_name, AliasState_value)
@ -771,6 +893,7 @@ func init() {
proto.RegisterType((*CollectionInfo)(nil), "milvus.proto.etcd.CollectionInfo") proto.RegisterType((*CollectionInfo)(nil), "milvus.proto.etcd.CollectionInfo")
proto.RegisterType((*PartitionInfo)(nil), "milvus.proto.etcd.PartitionInfo") proto.RegisterType((*PartitionInfo)(nil), "milvus.proto.etcd.PartitionInfo")
proto.RegisterType((*AliasInfo)(nil), "milvus.proto.etcd.AliasInfo") proto.RegisterType((*AliasInfo)(nil), "milvus.proto.etcd.AliasInfo")
proto.RegisterType((*DatabaseInfo)(nil), "milvus.proto.etcd.DatabaseInfo")
proto.RegisterType((*SegmentIndexInfo)(nil), "milvus.proto.etcd.SegmentIndexInfo") proto.RegisterType((*SegmentIndexInfo)(nil), "milvus.proto.etcd.SegmentIndexInfo")
proto.RegisterType((*CollectionMeta)(nil), "milvus.proto.etcd.CollectionMeta") proto.RegisterType((*CollectionMeta)(nil), "milvus.proto.etcd.CollectionMeta")
proto.RegisterType((*CredentialInfo)(nil), "milvus.proto.etcd.CredentialInfo") proto.RegisterType((*CredentialInfo)(nil), "milvus.proto.etcd.CredentialInfo")
@ -779,69 +902,77 @@ func init() {
func init() { proto.RegisterFile("etcd_meta.proto", fileDescriptor_975d306d62b73e88) } func init() { proto.RegisterFile("etcd_meta.proto", fileDescriptor_975d306d62b73e88) }
var fileDescriptor_975d306d62b73e88 = []byte{ var fileDescriptor_975d306d62b73e88 = []byte{
// 1020 bytes of a gzipped FileDescriptorProto // 1139 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xbc, 0x55, 0xcb, 0x8e, 0xdc, 0x44, 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xbc, 0x56, 0xcd, 0x72, 0xe3, 0x44,
0x14, 0x8d, 0xdb, 0xfd, 0xf2, 0xed, 0xc7, 0x74, 0x17, 0xc9, 0xc8, 0x19, 0x12, 0x70, 0x1a, 0x02, 0x10, 0x5e, 0x59, 0xb6, 0x63, 0xb5, 0x7f, 0x33, 0xfb, 0x53, 0xda, 0xec, 0x2e, 0x68, 0x0d, 0x0b,
0x56, 0xa4, 0xcc, 0x88, 0x19, 0x5e, 0x1b, 0x10, 0x61, 0xac, 0x48, 0x2d, 0x20, 0x6a, 0x79, 0x46, 0xae, 0x54, 0x6d, 0x52, 0x24, 0xb0, 0x70, 0x81, 0x62, 0x89, 0x6b, 0xab, 0x5c, 0xc0, 0x96, 0x4b,
0x59, 0xb0, 0xb1, 0xaa, 0xed, 0x9a, 0xee, 0x42, 0x7e, 0xc9, 0x55, 0x3d, 0x30, 0x7f, 0xc0, 0x9f, 0x09, 0x7b, 0xe0, 0xa2, 0x1a, 0x4b, 0x93, 0x78, 0x40, 0x1a, 0xa9, 0x34, 0xe3, 0x2c, 0xe1, 0x09,
0xf0, 0x09, 0x7c, 0x01, 0x5f, 0xc3, 0x9a, 0x15, 0x1b, 0x54, 0x55, 0x7e, 0x77, 0x0f, 0x62, 0xc5, 0x38, 0xf2, 0x1c, 0xbc, 0x00, 0x17, 0xae, 0x3c, 0x0d, 0x67, 0xee, 0xd4, 0xcc, 0xe8, 0xdf, 0x0e,
0xce, 0xf7, 0x54, 0xdd, 0x5b, 0xf7, 0xdc, 0xc7, 0x31, 0x1c, 0x11, 0xee, 0x07, 0x5e, 0x44, 0x38, 0xc5, 0x89, 0x9b, 0xfb, 0xd3, 0x74, 0x4f, 0x7f, 0xdd, 0xdf, 0x74, 0x1b, 0xc6, 0x44, 0xf8, 0x81,
0x3e, 0x4d, 0xb3, 0x84, 0x27, 0x68, 0x1e, 0xd1, 0xf0, 0x76, 0xc7, 0x94, 0x75, 0x2a, 0x4e, 0x4f, 0x17, 0x11, 0x81, 0x8f, 0x92, 0x34, 0x16, 0x31, 0xda, 0x8f, 0x68, 0x78, 0xbd, 0xe1, 0xda, 0x3a,
0xc6, 0x7e, 0x12, 0x45, 0x49, 0xac, 0xa0, 0x93, 0x31, 0xf3, 0xb7, 0x24, 0xca, 0xaf, 0x2f, 0xfe, 0x92, 0x5f, 0x0f, 0x06, 0x7e, 0x1c, 0x45, 0x31, 0xd3, 0xd0, 0xc1, 0x80, 0xfb, 0x6b, 0x12, 0x65,
0xd0, 0xc0, 0x58, 0xc6, 0x01, 0xf9, 0x65, 0x19, 0xdf, 0x24, 0xe8, 0x29, 0x00, 0x15, 0x86, 0x17, 0xc7, 0xa7, 0x7f, 0x1a, 0x60, 0x2d, 0x58, 0x40, 0x7e, 0x5a, 0xb0, 0xcb, 0x18, 0x3d, 0x01, 0xa0,
0xe3, 0x88, 0x98, 0x9a, 0xa5, 0xd9, 0x86, 0x6b, 0x48, 0xe4, 0x0d, 0x8e, 0x08, 0x32, 0x61, 0x20, 0xd2, 0xf0, 0x18, 0x8e, 0x88, 0x6d, 0x38, 0xc6, 0xcc, 0x72, 0x2d, 0x85, 0xbc, 0xc6, 0x11, 0x41,
0x8d, 0xa5, 0x63, 0x76, 0x2c, 0xcd, 0xd6, 0xdd, 0xc2, 0x44, 0x0e, 0x8c, 0x95, 0x63, 0x8a, 0x33, 0x36, 0xec, 0x29, 0x63, 0x31, 0xb7, 0x5b, 0x8e, 0x31, 0x33, 0xdd, 0xdc, 0x44, 0x73, 0x18, 0x68,
0x1c, 0x31, 0x53, 0xb7, 0x74, 0x7b, 0x74, 0xfe, 0xec, 0xb4, 0x91, 0x4c, 0x9e, 0xc6, 0x77, 0xe4, 0xc7, 0x04, 0xa7, 0x38, 0xe2, 0xb6, 0xe9, 0x98, 0xb3, 0xfe, 0xc9, 0xd3, 0xa3, 0x5a, 0x32, 0x59,
0xee, 0x2d, 0x0e, 0x77, 0x64, 0x85, 0x69, 0xe6, 0x8e, 0xa4, 0xdb, 0x4a, 0x7a, 0x89, 0xf8, 0x01, 0x1a, 0x5f, 0x93, 0x9b, 0x37, 0x38, 0xdc, 0x90, 0x25, 0xa6, 0xa9, 0xdb, 0x57, 0x6e, 0x4b, 0xe5,
0x09, 0x09, 0x27, 0x81, 0xd9, 0xb5, 0x34, 0x7b, 0xe8, 0x16, 0x26, 0x7a, 0x1f, 0x46, 0x7e, 0x46, 0x25, 0xe3, 0x07, 0x24, 0x24, 0x82, 0x04, 0x76, 0xdb, 0x31, 0x66, 0x3d, 0x37, 0x37, 0xd1, 0xbb,
0x30, 0x27, 0x1e, 0xa7, 0x11, 0x31, 0x7b, 0x96, 0x66, 0x77, 0x5d, 0x50, 0xd0, 0x35, 0x8d, 0xc8, 0xd0, 0xf7, 0x53, 0x82, 0x05, 0xf1, 0x04, 0x8d, 0x88, 0xdd, 0x71, 0x8c, 0x59, 0xdb, 0x05, 0x0d,
0xc2, 0x81, 0xe9, 0x6b, 0x4a, 0xc2, 0xa0, 0xe2, 0x62, 0xc2, 0xe0, 0x86, 0x86, 0x24, 0x58, 0x3a, 0x5d, 0xd0, 0x88, 0x4c, 0xe7, 0x30, 0x7a, 0x45, 0x49, 0x18, 0x94, 0x5c, 0x6c, 0xd8, 0xbb, 0xa4,
0x92, 0x88, 0xee, 0x16, 0xe6, 0xfd, 0x34, 0x16, 0x7f, 0xf7, 0x60, 0x7a, 0x99, 0x84, 0x21, 0xf1, 0x21, 0x09, 0x16, 0x73, 0x45, 0xc4, 0x74, 0x73, 0xf3, 0x76, 0x1a, 0xd3, 0x5f, 0xbb, 0x30, 0x3a,
0x39, 0x4d, 0x62, 0x19, 0x66, 0x0a, 0x9d, 0x32, 0x42, 0x67, 0xe9, 0xa0, 0xaf, 0xa0, 0xaf, 0x0a, 0x8b, 0xc3, 0x90, 0xf8, 0x82, 0xc6, 0x4c, 0x85, 0x19, 0x41, 0xab, 0x88, 0xd0, 0x5a, 0xcc, 0xd1,
0x28, 0x7d, 0x47, 0xe7, 0xcf, 0x9b, 0x1c, 0xf3, 0xe2, 0x56, 0x41, 0xae, 0x24, 0xe0, 0xe6, 0x4e, 0xe7, 0xd0, 0xd5, 0x05, 0x54, 0xbe, 0xfd, 0x93, 0x67, 0x75, 0x8e, 0x59, 0x71, 0xcb, 0x20, 0xe7,
0x6d, 0x22, 0x7a, 0x9b, 0x08, 0x5a, 0xc0, 0x38, 0xc5, 0x19, 0xa7, 0x32, 0x01, 0x87, 0x99, 0x5d, 0x0a, 0x70, 0x33, 0xa7, 0x26, 0x11, 0xb3, 0x49, 0x04, 0x4d, 0x61, 0x90, 0xe0, 0x54, 0x50, 0x95,
0x4b, 0xb7, 0x75, 0xb7, 0x81, 0xa1, 0x8f, 0x60, 0x5a, 0xda, 0xa2, 0x31, 0xcc, 0xec, 0x59, 0xba, 0xc0, 0x9c, 0xdb, 0x6d, 0xc7, 0x9c, 0x99, 0x6e, 0x0d, 0x43, 0x1f, 0xc0, 0xa8, 0xb0, 0x65, 0x63,
0x6d, 0xb8, 0x2d, 0x14, 0xbd, 0x86, 0xc9, 0x8d, 0x28, 0x8a, 0x27, 0xf9, 0x11, 0x66, 0xf6, 0x0f, 0xb8, 0xdd, 0x71, 0xcc, 0x99, 0xe5, 0x36, 0x50, 0xf4, 0x0a, 0x86, 0x97, 0xb2, 0x28, 0x9e, 0xe2,
0xb5, 0x45, 0xcc, 0xc8, 0x69, 0xb3, 0x78, 0xee, 0xf8, 0xa6, 0xb4, 0x09, 0x43, 0xe7, 0xf0, 0xe8, 0x47, 0xb8, 0xdd, 0xdd, 0xd5, 0x16, 0xa9, 0x91, 0xa3, 0x7a, 0xf1, 0xdc, 0xc1, 0x65, 0x61, 0x13,
0x96, 0x66, 0x7c, 0x87, 0x43, 0xcf, 0xdf, 0xe2, 0x38, 0x26, 0xa1, 0x1c, 0x10, 0x66, 0x0e, 0xe4, 0x8e, 0x4e, 0xe0, 0xfe, 0x35, 0x4d, 0xc5, 0x06, 0x87, 0x9e, 0xbf, 0xc6, 0x8c, 0x91, 0x50, 0x09,
0xb3, 0xef, 0xe4, 0x87, 0x97, 0xea, 0x4c, 0xbd, 0xfd, 0x29, 0x1c, 0xa7, 0xdb, 0x3b, 0x46, 0xfd, 0x84, 0xdb, 0x7b, 0xea, 0xda, 0xbb, 0xd9, 0xc7, 0x33, 0xfd, 0x4d, 0xdf, 0xfd, 0x31, 0x3c, 0x48,
0x3d, 0xa7, 0xa1, 0x74, 0x7a, 0x58, 0x9c, 0x36, 0xbc, 0xbe, 0x81, 0x27, 0x25, 0x07, 0x4f, 0x55, 0xd6, 0x37, 0x9c, 0xfa, 0x5b, 0x4e, 0x3d, 0xe5, 0x74, 0x2f, 0xff, 0x5a, 0xf3, 0xfa, 0x12, 0x1e,
0x25, 0x90, 0x95, 0x62, 0x1c, 0x47, 0x29, 0x33, 0x0d, 0x4b, 0xb7, 0xbb, 0xee, 0x49, 0x79, 0xe7, 0x17, 0x1c, 0x3c, 0x5d, 0x95, 0x40, 0x55, 0x8a, 0x0b, 0x1c, 0x25, 0xdc, 0xb6, 0x1c, 0x73, 0xd6,
0x52, 0x5d, 0xb9, 0x2e, 0x6f, 0x88, 0x11, 0x66, 0x5b, 0x9c, 0x05, 0xcc, 0x8b, 0x77, 0x91, 0x09, 0x76, 0x0f, 0x8a, 0x33, 0x67, 0xfa, 0xc8, 0x45, 0x71, 0x42, 0x4a, 0x98, 0xaf, 0x71, 0x1a, 0x70,
0x96, 0x66, 0xf7, 0x5c, 0x43, 0x21, 0x6f, 0x76, 0x11, 0x5a, 0xc2, 0x11, 0xe3, 0x38, 0xe3, 0x5e, 0x8f, 0x6d, 0x22, 0x1b, 0x1c, 0x63, 0xd6, 0x71, 0x2d, 0x8d, 0xbc, 0xde, 0x44, 0x68, 0x01, 0x63,
0x9a, 0x30, 0x19, 0x81, 0x99, 0x23, 0x59, 0x14, 0xeb, 0xbe, 0x59, 0x75, 0x30, 0xc7, 0x72, 0x54, 0x2e, 0x70, 0x2a, 0xbc, 0x24, 0xe6, 0x2a, 0x02, 0xb7, 0xfb, 0xaa, 0x28, 0xce, 0x6d, 0x5a, 0x9d,
0xa7, 0xd2, 0x71, 0x55, 0xf8, 0x21, 0x17, 0xe6, 0x7e, 0x12, 0x33, 0xca, 0x38, 0x89, 0xfd, 0x3b, 0x63, 0x81, 0x95, 0x54, 0x47, 0xca, 0x71, 0x99, 0xfb, 0x21, 0x17, 0xf6, 0xfd, 0x98, 0x71, 0xca,
0x2f, 0x24, 0xb7, 0x24, 0x34, 0xc7, 0x96, 0x66, 0x4f, 0xdb, 0x43, 0x91, 0x07, 0xbb, 0xac, 0x6e, 0x05, 0x61, 0xfe, 0x8d, 0x17, 0x92, 0x6b, 0x12, 0xda, 0x03, 0xc7, 0x98, 0x8d, 0x9a, 0xa2, 0xc8,
0x7f, 0x2f, 0x2e, 0xbb, 0x33, 0xbf, 0x85, 0xa0, 0x2f, 0xa1, 0xc7, 0x38, 0xe6, 0xc4, 0x9c, 0xc8, 0x82, 0x9d, 0x95, 0xa7, 0xbf, 0x91, 0x87, 0xdd, 0x89, 0xdf, 0x40, 0xd0, 0x67, 0xd0, 0xe1, 0x02,
0x38, 0x8b, 0x03, 0x9d, 0xaa, 0x8d, 0x96, 0xb8, 0xe9, 0x2a, 0x07, 0xf4, 0x0a, 0x20, 0xcd, 0x92, 0x0b, 0x62, 0x0f, 0x55, 0x9c, 0xe9, 0x8e, 0x4e, 0x55, 0xa4, 0x25, 0x4f, 0xba, 0xda, 0x01, 0xbd,
0x94, 0x64, 0x9c, 0x12, 0x66, 0x4e, 0xff, 0xeb, 0xfe, 0xd5, 0x9c, 0x16, 0x7f, 0x69, 0x30, 0x59, 0x04, 0x48, 0xd2, 0x38, 0x21, 0xa9, 0xa0, 0x84, 0xdb, 0xa3, 0xff, 0xfa, 0xfe, 0x2a, 0x4e, 0xe8,
0x95, 0x73, 0x26, 0x86, 0xdf, 0x82, 0x51, 0x6d, 0xf0, 0xf2, 0x2d, 0xa8, 0x43, 0xe8, 0x43, 0x98, 0x2e, 0x74, 0x82, 0x95, 0x47, 0x03, 0x7b, 0xac, 0xd4, 0xde, 0x0e, 0x56, 0x8b, 0x60, 0xfa, 0xb7,
0x34, 0x86, 0x4e, 0x6e, 0x85, 0xe1, 0x36, 0x41, 0xf4, 0x35, 0xbc, 0xfb, 0x2f, 0x6d, 0xcd, 0xb7, 0x01, 0xc3, 0x65, 0x21, 0x3e, 0xf9, 0x22, 0x1c, 0xe8, 0x57, 0xd4, 0x98, 0x3d, 0x8d, 0x2a, 0x84,
0xe0, 0xf1, 0xbd, 0x5d, 0x45, 0x1f, 0xc0, 0xc4, 0x2f, 0x69, 0x7b, 0x54, 0xc9, 0x83, 0xee, 0x8e, 0xde, 0x87, 0x61, 0x4d, 0x89, 0xea, 0xa9, 0x58, 0x6e, 0x1d, 0x44, 0x5f, 0xc0, 0xa3, 0x7f, 0xe9,
0x2b, 0x70, 0x19, 0xa0, 0x2f, 0x8a, 0xda, 0xf5, 0x64, 0xed, 0x0e, 0x4d, 0x79, 0xc9, 0xae, 0x5e, 0x75, 0xf6, 0x34, 0x1e, 0xde, 0xda, 0x6a, 0xf4, 0x1e, 0x0c, 0xfd, 0xa2, 0x16, 0x32, 0xed, 0xb6,
0xba, 0xc5, 0x6f, 0x1a, 0x18, 0xaf, 0x42, 0x8a, 0x59, 0xa1, 0x81, 0x58, 0x18, 0x0d, 0x0d, 0x94, 0xca, 0x64, 0x50, 0x82, 0x8b, 0x00, 0x7d, 0x9a, 0x17, 0xb4, 0xa3, 0x0a, 0xba, 0x4b, 0xfa, 0x05,
0x88, 0xa4, 0xb2, 0x97, 0x4a, 0xe7, 0x40, 0x2a, 0xcf, 0x60, 0x5c, 0x67, 0x99, 0x13, 0xcc, 0x37, 0xbb, 0x6a, 0x3d, 0xa7, 0x7f, 0x18, 0x60, 0xbd, 0x0c, 0x29, 0xe6, 0xf9, 0x60, 0xc4, 0xd2, 0xa8,
0x5f, 0xf2, 0x42, 0x17, 0x45, 0xb6, 0x5d, 0x99, 0xed, 0xd3, 0x03, 0xd9, 0xca, 0x9c, 0x1a, 0x99, 0x0d, 0x46, 0x85, 0x28, 0x2a, 0x5b, 0xa9, 0xb4, 0x76, 0xa4, 0xf2, 0x14, 0x06, 0x55, 0x96, 0x19,
0xfe, 0xda, 0x81, 0xd9, 0x15, 0xd9, 0x44, 0x24, 0xe6, 0x95, 0xd0, 0x2d, 0xa0, 0xfe, 0x78, 0xd1, 0xc1, 0x6c, 0x1c, 0x28, 0x5e, 0xe8, 0x34, 0xcf, 0xb6, 0xad, 0xb2, 0x7d, 0xb2, 0x23, 0x5b, 0x95,
0xa5, 0x06, 0xd6, 0x6e, 0x64, 0x67, 0xbf, 0x91, 0x4f, 0xc0, 0x60, 0x79, 0x64, 0x47, 0xe6, 0xab, 0x53, 0xad, 0xf3, 0x45, 0xdb, 0x3a, 0x95, 0xb6, 0xfd, 0x66, 0xc0, 0x40, 0x2a, 0x77, 0x85, 0x39,
0xbb, 0x15, 0xa0, 0xc4, 0x54, 0x28, 0x82, 0x93, 0x97, 0xbe, 0x30, 0xeb, 0x62, 0xda, 0x6b, 0xfe, 0x51, 0x0c, 0x1e, 0x81, 0x25, 0x08, 0xc3, 0x4c, 0xc8, 0x93, 0x9a, 0x40, 0x4f, 0x03, 0x8b, 0x00,
0x13, 0x4c, 0x18, 0xac, 0x77, 0x54, 0xfa, 0xf4, 0xd5, 0x49, 0x6e, 0x8a, 0xf2, 0x90, 0x18, 0xaf, 0x21, 0x68, 0xb3, 0xb2, 0x4f, 0xea, 0xb7, 0x1c, 0x7c, 0x34, 0x50, 0x49, 0x9a, 0x6e, 0x8b, 0x06,
0x43, 0xa2, 0x84, 0xc9, 0x1c, 0x48, 0xb1, 0x1f, 0x29, 0x4c, 0x12, 0x6b, 0xeb, 0xe4, 0x70, 0x4f, 0xe8, 0x45, 0x3d, 0x37, 0x67, 0x47, 0x6e, 0xf9, 0x85, 0xb5, 0xf4, 0x9a, 0xb4, 0x3b, 0x5b, 0xb4,
0xf0, 0xff, 0xd4, 0xea, 0x52, 0xfd, 0x03, 0xe1, 0xf8, 0x7f, 0x97, 0xea, 0xf7, 0x00, 0xca, 0x0a, 0xa7, 0xbf, 0xb4, 0x60, 0x72, 0x4e, 0xae, 0x22, 0xc2, 0x44, 0x39, 0xbf, 0xa7, 0x50, 0x2d, 0x5f,
0x15, 0x42, 0x5d, 0x43, 0xd0, 0xf3, 0x9a, 0x4c, 0x7b, 0x1c, 0x6f, 0x0a, 0x99, 0xae, 0x96, 0xe3, 0xae, 0xb3, 0x1a, 0xd6, 0x94, 0x62, 0x6b, 0x5b, 0x8a, 0x8f, 0xc1, 0xe2, 0x59, 0xe4, 0x79, 0x46,
0x1a, 0x6f, 0xd8, 0x9e, 0xe2, 0xf7, 0xf7, 0x15, 0x7f, 0xf1, 0xbb, 0x60, 0x9b, 0x91, 0x80, 0xc4, 0xa6, 0x04, 0xf4, 0x8e, 0x90, 0x83, 0x6e, 0x9e, 0x89, 0x27, 0x37, 0xab, 0x3b, 0xa2, 0x53, 0x5f,
0x9c, 0xe2, 0x50, 0xb6, 0xfd, 0x04, 0x86, 0x3b, 0x46, 0xb2, 0xda, 0x94, 0x96, 0x36, 0x7a, 0x09, 0x75, 0x36, 0xec, 0xad, 0x36, 0x54, 0xf9, 0x74, 0xf5, 0x97, 0xcc, 0x94, 0x4c, 0x09, 0xc3, 0xab,
0x88, 0xc4, 0x7e, 0x76, 0x97, 0x8a, 0x09, 0x4c, 0x31, 0x63, 0x3f, 0x27, 0x59, 0x90, 0xaf, 0xe6, 0x90, 0xe8, 0x79, 0x6b, 0xef, 0xa9, 0x1d, 0xd6, 0xd7, 0x98, 0x22, 0xd6, 0x1c, 0xff, 0xbd, 0xad,
0xbc, 0x3c, 0x59, 0xe5, 0x07, 0xe8, 0x18, 0xfa, 0x9c, 0xc4, 0x38, 0xe6, 0x92, 0xa4, 0xe1, 0xe6, 0x3d, 0xf6, 0x97, 0x51, 0xdd, 0x40, 0xdf, 0x12, 0x81, 0xff, 0xf7, 0x0d, 0xf4, 0x0e, 0x40, 0x51,
0x16, 0x7a, 0x0c, 0x43, 0xca, 0x3c, 0xb6, 0x4b, 0x49, 0x56, 0xfc, 0x90, 0x29, 0xbb, 0x12, 0x26, 0xa1, 0x7c, 0xff, 0x54, 0x10, 0xf4, 0xac, 0xb2, 0x7d, 0x3c, 0x81, 0xaf, 0xf2, 0xed, 0x53, 0x3e,
0xfa, 0x18, 0x8e, 0xd8, 0x16, 0x9f, 0x7f, 0xf6, 0x79, 0x15, 0xbe, 0x27, 0x7d, 0xa7, 0x0a, 0x2e, 0xef, 0x0b, 0x7c, 0xc5, 0xb7, 0x16, 0x59, 0x77, 0x7b, 0x91, 0x4d, 0x7f, 0x97, 0x6c, 0x53, 0x12,
0x62, 0xbf, 0x48, 0xe0, 0xa8, 0xa5, 0x58, 0xe8, 0x11, 0xcc, 0x2b, 0x28, 0xdf, 0xf5, 0xd9, 0x03, 0x10, 0x26, 0x28, 0x0e, 0x55, 0xdb, 0x0f, 0xa0, 0xb7, 0xe1, 0x24, 0xad, 0xbc, 0xb3, 0xc2, 0x46,
0x74, 0x0c, 0xa8, 0x05, 0xd3, 0x78, 0x33, 0xd3, 0x9a, 0xb8, 0x93, 0x25, 0x69, 0x2a, 0xf0, 0x4e, 0xcf, 0x01, 0x11, 0xe6, 0xa7, 0x37, 0x89, 0x14, 0x53, 0x82, 0x39, 0x7f, 0x1b, 0xa7, 0x41, 0x26,
0x33, 0x8c, 0xc4, 0x49, 0x30, 0xd3, 0x5f, 0xfc, 0x04, 0xd3, 0xe6, 0x9a, 0xa3, 0x87, 0x30, 0x5b, 0xda, 0xfd, 0xe2, 0xcb, 0x32, 0xfb, 0x80, 0x1e, 0x40, 0x57, 0x2b, 0x5c, 0x91, 0xb4, 0xdc, 0xcc,
0xb5, 0xa4, 0x65, 0xf6, 0x40, 0xb8, 0x37, 0x51, 0xf5, 0x5a, 0x1d, 0xae, 0x3d, 0x56, 0x8f, 0x51, 0x42, 0x0f, 0xa1, 0x47, 0xb9, 0xc7, 0x37, 0x09, 0x49, 0xf3, 0xff, 0x19, 0x94, 0x9f, 0x4b, 0x13,
0xbd, 0xf5, 0x16, 0xa0, 0x5a, 0x52, 0x34, 0x83, 0xb1, 0xb4, 0xaa, 0x37, 0xe6, 0x30, 0xa9, 0x10, 0x7d, 0x08, 0x63, 0xbe, 0xc6, 0x27, 0x9f, 0xbc, 0x28, 0xc3, 0x77, 0x94, 0xef, 0x48, 0xc3, 0x79,
0x15, 0xbf, 0x80, 0x6a, 0xb1, 0x0b, 0xbf, 0x32, 0xee, 0xb7, 0x17, 0x3f, 0x7e, 0xb2, 0xa1, 0x7c, 0xec, 0xc3, 0x9f, 0x61, 0x58, 0x53, 0x3b, 0xba, 0x0b, 0xe3, 0x1c, 0xf8, 0x8e, 0xfd, 0xc8, 0xe2,
0xbb, 0x5b, 0x0b, 0xcd, 0x3e, 0x53, 0x53, 0xfb, 0x92, 0x26, 0xf9, 0xd7, 0x19, 0x8d, 0xb9, 0x68, 0xb7, 0x6c, 0x72, 0xa7, 0x0a, 0x66, 0xe3, 0x6b, 0x62, 0xa0, 0x7b, 0x30, 0xa9, 0x81, 0x94, 0x5d,
0x74, 0x78, 0x26, 0x07, 0xf9, 0x4c, 0x88, 0x45, 0xba, 0x5e, 0xf7, 0xa5, 0x75, 0xf1, 0x4f, 0x00, 0x4d, 0x5a, 0x55, 0x74, 0x9e, 0xc6, 0x49, 0x22, 0x51, 0xb3, 0x1a, 0x40, 0xa1, 0x24, 0x98, 0xb4,
0x00, 0x00, 0xff, 0xff, 0xf9, 0x76, 0x1c, 0x4f, 0x13, 0x0a, 0x00, 0x00, 0x0f, 0x63, 0x18, 0x37, 0x96, 0x00, 0xba, 0x0f, 0xfb, 0x25, 0x94, 0x5f, 0x75, 0x07, 0x3d, 0x00,
0xd4, 0x80, 0x65, 0x58, 0xa3, 0x8e, 0x17, 0xd7, 0xb5, 0xea, 0x61, 0xf2, 0x0b, 0xcd, 0xc3, 0x1f,
0x60, 0x54, 0x1f, 0x92, 0x32, 0xdb, 0x65, 0x63, 0x30, 0x4f, 0xee, 0x48, 0xf7, 0x3a, 0xaa, 0x6f,
0xab, 0xc2, 0x95, 0xcb, 0xaa, 0x31, 0xca, 0xbb, 0xde, 0x00, 0x94, 0x23, 0x0e, 0x4d, 0x60, 0xa0,
0xac, 0xf2, 0x8e, 0x7d, 0x18, 0x96, 0x88, 0x8e, 0x9f, 0x43, 0x95, 0xd8, 0xb9, 0x5f, 0x11, 0xf7,
0xab, 0xd3, 0xef, 0x3f, 0xba, 0xa2, 0x62, 0xbd, 0x59, 0xc9, 0x35, 0x78, 0xac, 0x5f, 0xcc, 0x73,
0x1a, 0x67, 0xbf, 0x8e, 0x29, 0x13, 0x52, 0x64, 0xe1, 0xb1, 0x7a, 0x44, 0xc7, 0x72, 0x9c, 0x25,
0xab, 0x55, 0x57, 0x59, 0xa7, 0xff, 0x04, 0x00, 0x00, 0xff, 0xff, 0xee, 0x8d, 0xf3, 0xf4, 0x66,
0x0b, 0x00, 0x00,
} }

View File

@ -141,6 +141,10 @@ service RootCoord {
rpc CheckHealth(milvus.CheckHealthRequest) returns (milvus.CheckHealthResponse) {} rpc CheckHealth(milvus.CheckHealthRequest) returns (milvus.CheckHealthResponse) {}
rpc RenameCollection(milvus.RenameCollectionRequest) returns (common.Status) {} rpc RenameCollection(milvus.RenameCollectionRequest) returns (common.Status) {}
rpc CreateDatabase(milvus.CreateDatabaseRequest) returns (common.Status) {}
rpc DropDatabase(milvus.DropDatabaseRequest) returns (common.Status) {}
rpc ListDatabases(milvus.ListDatabasesRequest) returns (milvus.ListDatabasesResponse) {}
} }
message AllocTimestampRequest { message AllocTimestampRequest {

View File

@ -674,106 +674,108 @@ func init() {
func init() { proto.RegisterFile("root_coord.proto", fileDescriptor_4513485a144f6b06) } func init() { proto.RegisterFile("root_coord.proto", fileDescriptor_4513485a144f6b06) }
var fileDescriptor_4513485a144f6b06 = []byte{ var fileDescriptor_4513485a144f6b06 = []byte{
// 1570 bytes of a gzipped FileDescriptorProto // 1616 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xbc, 0x58, 0xeb, 0x92, 0xd3, 0x36, 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xbc, 0x58, 0x6d, 0x93, 0xda, 0xb6,
0x14, 0x26, 0x09, 0x7b, 0x3b, 0xc9, 0x26, 0x8b, 0x86, 0x4b, 0x1a, 0x68, 0x1b, 0x02, 0x2d, 0xe1, 0x16, 0x0e, 0x90, 0xdd, 0x85, 0x03, 0x0b, 0x1b, 0x4d, 0x5e, 0xb8, 0x24, 0xf7, 0x5e, 0x42, 0xf6,
0x96, 0xa5, 0xcb, 0x0c, 0xa5, 0xfc, 0x63, 0x13, 0x66, 0xc9, 0xb4, 0x3b, 0x6c, 0x1d, 0xe8, 0xd0, 0xde, 0x90, 0x37, 0x36, 0xdd, 0xcc, 0xa4, 0x69, 0xbe, 0x65, 0x21, 0xb3, 0x61, 0xda, 0x9d, 0x6c,
0xcb, 0x4e, 0xaa, 0xd8, 0x87, 0xc4, 0xb3, 0x8e, 0x15, 0x2c, 0x65, 0x2f, 0xd3, 0x1f, 0x9d, 0xce, 0x4d, 0xd2, 0xa6, 0x2f, 0x3b, 0x54, 0xd8, 0x0a, 0x78, 0xd6, 0x58, 0xc4, 0x12, 0xfb, 0x32, 0xfd,
0xf4, 0x7f, 0xdf, 0xa9, 0x7d, 0x87, 0xbe, 0x40, 0x5f, 0xa4, 0x23, 0x5f, 0x14, 0xdb, 0xb1, 0xb3, 0xd0, 0xe9, 0x4c, 0xbf, 0xf7, 0x3f, 0xb5, 0x3f, 0xa5, 0xff, 0xa1, 0x9f, 0x3b, 0xb2, 0x6c, 0x61,
0x5e, 0xa0, 0xfd, 0x67, 0x49, 0x9f, 0xbe, 0xef, 0xe8, 0x1c, 0x1d, 0x49, 0xc7, 0xb0, 0xe1, 0x30, 0x1b, 0x1b, 0x4c, 0x92, 0xf6, 0x9b, 0x25, 0x3d, 0x7a, 0x9e, 0xa3, 0x73, 0x74, 0x8e, 0x24, 0xc3,
0x26, 0xfa, 0x3a, 0x63, 0x8e, 0xd1, 0x9a, 0x38, 0x4c, 0x30, 0x72, 0x79, 0x6c, 0x5a, 0x87, 0x53, 0x96, 0x43, 0x29, 0xef, 0xeb, 0x94, 0x3a, 0x46, 0x6b, 0xe2, 0x50, 0x4e, 0xd1, 0xd5, 0xb1, 0x69,
0xee, 0xb5, 0x5a, 0x72, 0xd8, 0x1d, 0xad, 0x95, 0x74, 0x36, 0x1e, 0x33, 0xdb, 0xeb, 0xaf, 0x95, 0x9d, 0x4c, 0x99, 0x6c, 0xb5, 0xc4, 0xb0, 0x3b, 0x5a, 0x2b, 0xe9, 0x74, 0x3c, 0xa6, 0xb6, 0xec,
0xc2, 0xa8, 0x5a, 0xd9, 0xb4, 0x05, 0x3a, 0x36, 0xb5, 0xfc, 0x76, 0x71, 0xe2, 0xb0, 0xe3, 0x13, 0xaf, 0x95, 0x82, 0xa8, 0x5a, 0xd9, 0xb4, 0x39, 0x71, 0x6c, 0x6c, 0x79, 0xed, 0xe2, 0xc4, 0xa1,
0xbf, 0x51, 0x41, 0xa1, 0x1b, 0xfd, 0x31, 0x0a, 0xea, 0x75, 0x34, 0xfa, 0x70, 0xe9, 0xa9, 0x65, 0x67, 0xe7, 0x5e, 0xa3, 0x42, 0xb8, 0x6e, 0xf4, 0xc7, 0x84, 0x63, 0xd9, 0xd1, 0xe8, 0xc3, 0x95,
0x31, 0xfd, 0xa5, 0x39, 0x46, 0x2e, 0xe8, 0x78, 0xa2, 0xe1, 0xdb, 0x29, 0x72, 0x41, 0x1e, 0xc0, 0x67, 0x96, 0x45, 0xf5, 0x57, 0xe6, 0x98, 0x30, 0x8e, 0xc7, 0x13, 0x8d, 0xbc, 0x9b, 0x12, 0xc6,
0xf9, 0x01, 0xe5, 0x58, 0xcd, 0xd5, 0x73, 0xcd, 0xe2, 0xd6, 0xb5, 0x56, 0xc4, 0x12, 0x5f, 0x7e, 0xd1, 0x43, 0xb8, 0x38, 0xc0, 0x8c, 0x54, 0x33, 0xf5, 0x4c, 0xb3, 0xb8, 0x7b, 0xa3, 0x15, 0xb2,
0x97, 0x0f, 0xb7, 0x29, 0x47, 0xcd, 0x45, 0x92, 0x8b, 0xb0, 0xa4, 0xb3, 0xa9, 0x2d, 0xaa, 0x85, 0xc4, 0x93, 0x3f, 0x60, 0xc3, 0x3d, 0xcc, 0x88, 0xe6, 0x22, 0xd1, 0x65, 0x58, 0xd3, 0xe9, 0xd4,
0x7a, 0xae, 0xb9, 0xae, 0x79, 0x8d, 0xc6, 0x6f, 0x39, 0xb8, 0x1c, 0x57, 0xe0, 0x13, 0x66, 0x73, 0xe6, 0xd5, 0x5c, 0x3d, 0xd3, 0xdc, 0xd4, 0x64, 0xa3, 0xf1, 0x73, 0x06, 0xae, 0x46, 0x15, 0xd8,
0x24, 0x0f, 0x61, 0x99, 0x0b, 0x2a, 0xa6, 0xdc, 0x17, 0xb9, 0x9a, 0x28, 0xd2, 0x73, 0x21, 0x9a, 0x84, 0xda, 0x8c, 0xa0, 0x47, 0xb0, 0xce, 0x38, 0xe6, 0x53, 0xe6, 0x89, 0x5c, 0x8f, 0x15, 0xe9,
0x0f, 0x25, 0xd7, 0x60, 0x4d, 0x04, 0x4c, 0xd5, 0x7c, 0x3d, 0xd7, 0x3c, 0xaf, 0xcd, 0x3a, 0x52, 0xb9, 0x10, 0xcd, 0x83, 0xa2, 0x1b, 0x50, 0xe0, 0x3e, 0x53, 0x35, 0x5b, 0xcf, 0x34, 0x2f, 0x6a,
0x6c, 0x78, 0x0d, 0x65, 0xd7, 0x84, 0x6e, 0xe7, 0x03, 0xac, 0x2e, 0x1f, 0x66, 0xb6, 0xa0, 0xa2, 0xb3, 0x8e, 0x04, 0x1b, 0xde, 0x40, 0xd9, 0x35, 0xa1, 0xdb, 0xf9, 0x08, 0xab, 0xcb, 0x06, 0x99,
0x98, 0xdf, 0x67, 0x55, 0x65, 0xc8, 0x77, 0x3b, 0x2e, 0x75, 0x41, 0xcb, 0x77, 0x3b, 0x29, 0xeb, 0x2d, 0xa8, 0x28, 0xe6, 0x0f, 0x59, 0x55, 0x19, 0xb2, 0xdd, 0x8e, 0x4b, 0x9d, 0xd3, 0xb2, 0xdd,
0xf8, 0x33, 0x0f, 0xa5, 0xee, 0x78, 0xc2, 0x1c, 0xa1, 0x21, 0x9f, 0x5a, 0xe2, 0xdd, 0xb4, 0xae, 0x4e, 0xc2, 0x3a, 0x7e, 0xcb, 0x42, 0xa9, 0x3b, 0x9e, 0x50, 0x87, 0x6b, 0x84, 0x4d, 0x2d, 0xfe,
0xc0, 0x8a, 0xa0, 0xfc, 0xa0, 0x6f, 0x1a, 0xbe, 0xe0, 0xb2, 0x6c, 0x76, 0x0d, 0xf2, 0x29, 0x14, 0x7e, 0x5a, 0xd7, 0x60, 0x83, 0x63, 0x76, 0xdc, 0x37, 0x0d, 0x4f, 0x70, 0x5d, 0x34, 0xbb, 0x06,
0x0d, 0x2a, 0xa8, 0xcd, 0x0c, 0x94, 0x83, 0x05, 0x77, 0x10, 0x82, 0xae, 0xae, 0x41, 0x1e, 0xc1, 0xfa, 0x2f, 0x14, 0x0d, 0xcc, 0xb1, 0x4d, 0x0d, 0x22, 0x06, 0x73, 0xee, 0x20, 0xf8, 0x5d, 0x5d,
0x92, 0xe4, 0xc0, 0xea, 0xf9, 0x7a, 0xae, 0x59, 0xde, 0xaa, 0x27, 0xaa, 0x79, 0x06, 0x4a, 0x4d, 0x03, 0x3d, 0x86, 0x35, 0xc1, 0x41, 0xaa, 0x17, 0xeb, 0x99, 0x66, 0x79, 0xb7, 0x1e, 0xab, 0x26,
0xd4, 0x3c, 0x38, 0xa9, 0xc1, 0x2a, 0xc7, 0xe1, 0x18, 0x6d, 0xc1, 0xab, 0x4b, 0xf5, 0x42, 0xb3, 0x0d, 0x14, 0x9a, 0x44, 0x93, 0x70, 0x54, 0x83, 0x3c, 0x23, 0xc3, 0x31, 0xb1, 0x39, 0xab, 0xae,
0xa0, 0xa9, 0x36, 0xf9, 0x08, 0x56, 0xe9, 0x54, 0xb0, 0xbe, 0x69, 0xf0, 0xea, 0xb2, 0x3b, 0xb6, 0xd5, 0x73, 0xcd, 0x9c, 0xa6, 0xda, 0xe8, 0x5f, 0x90, 0xc7, 0x53, 0x4e, 0xfb, 0xa6, 0xc1, 0xaa,
0x22, 0xdb, 0x5d, 0x83, 0x93, 0xab, 0xb0, 0xe6, 0xb0, 0xa3, 0xbe, 0xe7, 0x88, 0x15, 0xd7, 0x9a, 0xeb, 0xee, 0xd8, 0x86, 0x68, 0x77, 0x0d, 0x86, 0xae, 0x43, 0xc1, 0xa1, 0xa7, 0x7d, 0xe9, 0x88,
0x55, 0x87, 0x1d, 0xb5, 0x65, 0x9b, 0x7c, 0x09, 0x4b, 0xa6, 0xfd, 0x86, 0xf1, 0xea, 0x6a, 0xbd, 0x0d, 0xd7, 0x9a, 0xbc, 0x43, 0x4f, 0xdb, 0xa2, 0x8d, 0x3e, 0x85, 0x35, 0xd3, 0x7e, 0x4b, 0x59,
0xd0, 0x2c, 0x6e, 0x5d, 0x4f, 0xb4, 0xe5, 0x6b, 0x3c, 0xf9, 0x8e, 0x5a, 0x53, 0xdc, 0xa3, 0xa6, 0x35, 0x5f, 0xcf, 0x35, 0x8b, 0xbb, 0x37, 0x63, 0x6d, 0xf9, 0x9c, 0x9c, 0x7f, 0x85, 0xad, 0x29,
0xa3, 0x79, 0xf8, 0xc6, 0x1f, 0x39, 0xb8, 0xd2, 0x41, 0xae, 0x3b, 0xe6, 0x00, 0x7b, 0xbe, 0x15, 0x39, 0xc4, 0xa6, 0xa3, 0x49, 0x7c, 0xe3, 0xd7, 0x0c, 0x5c, 0xeb, 0x10, 0xa6, 0x3b, 0xe6, 0x80,
0xef, 0xbe, 0x2d, 0x1a, 0x50, 0xd2, 0x99, 0x65, 0xa1, 0x2e, 0x4c, 0x66, 0xab, 0x10, 0x46, 0xfa, 0xf4, 0x3c, 0x2b, 0xde, 0x7f, 0x5b, 0x34, 0xa0, 0xa4, 0x53, 0xcb, 0x22, 0x3a, 0x37, 0xa9, 0xad,
0xc8, 0x27, 0x00, 0xfe, 0x72, 0xbb, 0x1d, 0x5e, 0x2d, 0xb8, 0x8b, 0x0c, 0xf5, 0x34, 0xa6, 0x50, 0x42, 0x18, 0xea, 0x43, 0xff, 0x01, 0xf0, 0x96, 0xdb, 0xed, 0xb0, 0x6a, 0xce, 0x5d, 0x64, 0xa0,
0xf1, 0x0d, 0x91, 0xc4, 0x5d, 0xfb, 0x0d, 0x9b, 0xa3, 0xcd, 0x25, 0xd0, 0xd6, 0xa1, 0x38, 0xa1, 0xa7, 0x31, 0x85, 0x8a, 0x67, 0x88, 0x20, 0xee, 0xda, 0x6f, 0xe9, 0x1c, 0x6d, 0x26, 0x86, 0xb6,
0x8e, 0x30, 0x23, 0xca, 0xe1, 0x2e, 0x99, 0x2b, 0x4a, 0xc6, 0x0f, 0xe7, 0xac, 0xa3, 0xf1, 0x4f, 0x0e, 0xc5, 0x09, 0x76, 0xb8, 0x19, 0x52, 0x0e, 0x76, 0x89, 0x5c, 0x51, 0x32, 0x5e, 0x38, 0x67,
0x1e, 0x4a, 0xbe, 0xae, 0xd4, 0xe4, 0xa4, 0x03, 0x6b, 0x72, 0x4d, 0x7d, 0xe9, 0x27, 0xdf, 0x05, 0x1d, 0x8d, 0x3f, 0xb2, 0x50, 0xf2, 0x74, 0x85, 0x26, 0x43, 0x1d, 0x28, 0x88, 0x35, 0xf5, 0x85,
0xb7, 0x5a, 0xc9, 0x27, 0x50, 0x2b, 0x66, 0xb0, 0xb6, 0x3a, 0x08, 0x4c, 0xef, 0x40, 0xd1, 0xb4, 0x9f, 0x3c, 0x17, 0xdc, 0x6e, 0xc5, 0x57, 0xa0, 0x56, 0xc4, 0x60, 0x2d, 0x3f, 0xf0, 0x4d, 0xef,
0x0d, 0x3c, 0xee, 0x7b, 0xe1, 0xc9, 0xbb, 0xe1, 0xb9, 0x11, 0xe5, 0x91, 0xa7, 0x50, 0x4b, 0x69, 0x40, 0xd1, 0xb4, 0x0d, 0x72, 0xd6, 0x97, 0xe1, 0xc9, 0xba, 0xe1, 0xb9, 0x15, 0xe6, 0x11, 0x55,
0x1b, 0x78, 0xec, 0x72, 0x80, 0x19, 0x7c, 0x72, 0x82, 0x70, 0x01, 0x8f, 0x85, 0x43, 0xfb, 0x61, 0xa8, 0xa5, 0xb4, 0x0d, 0x72, 0xe6, 0x72, 0x80, 0xe9, 0x7f, 0x32, 0x44, 0xe0, 0x12, 0x39, 0xe3,
0xae, 0x82, 0xcb, 0xf5, 0xd5, 0x29, 0x36, 0xb9, 0x04, 0xad, 0x67, 0x72, 0xb6, 0xe2, 0xe6, 0xcf, 0x0e, 0xee, 0x07, 0xb9, 0x72, 0x2e, 0xd7, 0x67, 0x4b, 0x6c, 0x72, 0x09, 0x5a, 0xcf, 0xc5, 0x6c,
0x6c, 0xe1, 0x9c, 0x68, 0x15, 0x8c, 0xf6, 0xd6, 0x7e, 0x86, 0x8b, 0x49, 0x40, 0xb2, 0x01, 0x85, 0xc5, 0xcd, 0x9e, 0xdb, 0xdc, 0x39, 0xd7, 0x2a, 0x24, 0xdc, 0x5b, 0xfb, 0x01, 0x2e, 0xc7, 0x01,
0x03, 0x3c, 0xf1, 0xdd, 0x2e, 0x3f, 0xc9, 0x16, 0x2c, 0x1d, 0xca, 0xad, 0xe4, 0xfa, 0x79, 0x6e, 0xd1, 0x16, 0xe4, 0x8e, 0xc9, 0xb9, 0xe7, 0x76, 0xf1, 0x89, 0x76, 0x61, 0xed, 0x44, 0x6c, 0x25,
0x6f, 0xb8, 0x0b, 0x9a, 0xad, 0xc4, 0x83, 0x3e, 0xc9, 0x3f, 0xce, 0x35, 0xfe, 0xca, 0x43, 0x75, 0xd7, 0xcf, 0x73, 0x7b, 0xc3, 0x5d, 0xd0, 0x6c, 0x25, 0x12, 0xfa, 0x34, 0xfb, 0x24, 0xd3, 0xf8,
0x7e, 0xbb, 0xbd, 0xcf, 0x59, 0x91, 0x65, 0xcb, 0x0d, 0x61, 0xdd, 0x0f, 0x74, 0xc4, 0x75, 0xdb, 0x3d, 0x0b, 0xd5, 0xf9, 0xed, 0xf6, 0x21, 0xb5, 0x22, 0xcd, 0x96, 0x1b, 0xc2, 0xa6, 0x17, 0xe8,
0x69, 0xae, 0x4b, 0xb3, 0x30, 0xe2, 0x53, 0xcf, 0x87, 0x25, 0x1e, 0xea, 0xaa, 0x21, 0x5c, 0x98, 0x90, 0xeb, 0xf6, 0x92, 0x5c, 0x97, 0x64, 0x61, 0xc8, 0xa7, 0xd2, 0x87, 0x25, 0x16, 0xe8, 0xaa,
0x83, 0x24, 0x78, 0xef, 0x49, 0xd4, 0x7b, 0x37, 0xb3, 0x84, 0x30, 0xec, 0x45, 0x03, 0x2e, 0xee, 0x11, 0xb8, 0x34, 0x07, 0x89, 0xf1, 0xde, 0xd3, 0xb0, 0xf7, 0xb6, 0xd3, 0x84, 0x30, 0xe8, 0x45,
0xa0, 0x68, 0x3b, 0x68, 0xa0, 0x2d, 0x4c, 0x6a, 0xbd, 0x7b, 0xc2, 0xd6, 0x60, 0x75, 0xca, 0xe5, 0x03, 0x2e, 0xef, 0x13, 0xde, 0x76, 0x88, 0x41, 0x6c, 0x6e, 0x62, 0xeb, 0xfd, 0x13, 0xb6, 0x06,
0xfd, 0x38, 0xf6, 0x8c, 0x59, 0xd3, 0x54, 0xbb, 0xf1, 0x7b, 0x0e, 0x2e, 0xc5, 0x64, 0xde, 0x27, 0xf9, 0x29, 0x13, 0xe7, 0xe3, 0x58, 0x1a, 0x53, 0xd0, 0x54, 0xbb, 0xf1, 0x4b, 0x06, 0xae, 0x44,
0x50, 0x0b, 0xa4, 0xe4, 0xd8, 0x84, 0x72, 0x7e, 0xc4, 0x1c, 0xef, 0xa0, 0x5d, 0xd3, 0x54, 0x7b, 0x64, 0x3e, 0x24, 0x50, 0x0b, 0xa4, 0xc4, 0xd8, 0x04, 0x33, 0x76, 0x4a, 0x1d, 0x59, 0x68, 0x0b,
0xeb, 0xef, 0x06, 0xac, 0x69, 0x8c, 0x89, 0xb6, 0x74, 0x09, 0xb1, 0x80, 0x48, 0x9b, 0xd8, 0x78, 0x9a, 0x6a, 0xef, 0xfe, 0xb9, 0x0d, 0x05, 0x8d, 0x52, 0xde, 0x16, 0x2e, 0x41, 0x16, 0x20, 0x61,
0xc2, 0x6c, 0xb4, 0xbd, 0x83, 0x95, 0x93, 0x56, 0xd4, 0x00, 0xbf, 0x31, 0x0f, 0xf4, 0x1d, 0x55, 0x13, 0x1d, 0x4f, 0xa8, 0x4d, 0x6c, 0x59, 0x58, 0x19, 0x6a, 0x85, 0x0d, 0xf0, 0x1a, 0xf3, 0x40,
0xbb, 0x99, 0x88, 0x8f, 0x81, 0x1b, 0xe7, 0xc8, 0xd8, 0x55, 0x93, 0x77, 0xf5, 0x4b, 0x53, 0x3f, 0xcf, 0x51, 0xb5, 0xed, 0x58, 0x7c, 0x04, 0xdc, 0xb8, 0x80, 0xc6, 0xae, 0x9a, 0x38, 0xab, 0x5f,
0x68, 0x8f, 0xa8, 0x6d, 0xa3, 0x45, 0x1e, 0x44, 0x67, 0xab, 0x17, 0xc6, 0x3c, 0x34, 0xd0, 0xbb, 0x99, 0xfa, 0x71, 0x7b, 0x84, 0x6d, 0x9b, 0x58, 0xe8, 0x61, 0x78, 0xb6, 0xba, 0x61, 0xcc, 0x43,
0x91, 0xa8, 0xd7, 0x13, 0x8e, 0x69, 0x0f, 0x03, 0xaf, 0x36, 0xce, 0x91, 0xb7, 0x6e, 0x5c, 0xa5, 0x7d, 0xbd, 0x5b, 0xb1, 0x7a, 0x3d, 0xee, 0x98, 0xf6, 0xd0, 0xf7, 0x6a, 0xe3, 0x02, 0x7a, 0xe7,
0xba, 0xc9, 0x85, 0xa9, 0xf3, 0x40, 0x70, 0x2b, 0x5d, 0x70, 0x0e, 0x7c, 0x46, 0xc9, 0x3e, 0x6c, 0xc6, 0x55, 0xa8, 0x9b, 0x8c, 0x9b, 0x3a, 0xf3, 0x05, 0x77, 0x93, 0x05, 0xe7, 0xc0, 0x2b, 0x4a,
0xb4, 0x1d, 0xa4, 0x02, 0xdb, 0x2a, 0x61, 0xc8, 0xbd, 0x64, 0xef, 0xc4, 0x60, 0x81, 0xd0, 0xa2, 0xf6, 0x61, 0xab, 0xed, 0x10, 0xcc, 0x49, 0x5b, 0x25, 0x0c, 0xba, 0x1f, 0xef, 0x9d, 0x08, 0xcc,
0xe0, 0x37, 0xce, 0x91, 0x1f, 0xa1, 0xdc, 0x71, 0xd8, 0x24, 0x44, 0x7f, 0x27, 0x91, 0x3e, 0x0a, 0x17, 0x5a, 0x14, 0xfc, 0xc6, 0x05, 0xf4, 0x1d, 0x94, 0x3b, 0x0e, 0x9d, 0x04, 0xe8, 0xef, 0xc6,
0xca, 0x48, 0xde, 0x87, 0xf5, 0xe7, 0x94, 0x87, 0xb8, 0x6f, 0x27, 0x72, 0x47, 0x30, 0x01, 0xf5, 0xd2, 0x87, 0x41, 0x29, 0xc9, 0xfb, 0xb0, 0xf9, 0x02, 0xb3, 0x00, 0xf7, 0x9d, 0x58, 0xee, 0x10,
0xf5, 0x44, 0xe8, 0x36, 0x63, 0x56, 0xc8, 0x3d, 0x47, 0x40, 0x82, 0xc3, 0x20, 0xa4, 0x92, 0xbc, 0xc6, 0xa7, 0xbe, 0x19, 0x0b, 0xdd, 0xa3, 0xd4, 0x0a, 0xb8, 0xe7, 0x14, 0x90, 0x5f, 0x0c, 0x02,
0xdd, 0xe6, 0x81, 0x81, 0xd4, 0x66, 0x66, 0xbc, 0x12, 0xfe, 0x15, 0x6a, 0xf3, 0xe3, 0x5d, 0x3f, 0x2a, 0xf1, 0xdb, 0x6d, 0x1e, 0xe8, 0x4b, 0xed, 0xa4, 0xc6, 0x2b, 0xe1, 0x9f, 0xa0, 0x36, 0x3f,
0xf0, 0xff, 0x87, 0x01, 0xaf, 0xa0, 0xe8, 0x45, 0xfc, 0xa9, 0x65, 0x52, 0x4e, 0x6e, 0x2d, 0xd8, 0xde, 0xf5, 0x02, 0xff, 0x4f, 0x18, 0xf0, 0x1a, 0x8a, 0x32, 0xe2, 0xcf, 0x2c, 0x13, 0x33, 0x74,
0x13, 0x2e, 0x22, 0x63, 0xc4, 0xbe, 0x85, 0x35, 0x19, 0x69, 0x8f, 0xf4, 0xb3, 0xd4, 0x9d, 0x70, 0x7b, 0xc1, 0x9e, 0x70, 0x11, 0x29, 0x23, 0xf6, 0x25, 0x14, 0x44, 0xa4, 0x25, 0xe9, 0xff, 0x12,
0x16, 0xca, 0x1e, 0xc0, 0x53, 0x4b, 0xa0, 0xe3, 0x71, 0x7e, 0x9e, 0xc8, 0x39, 0x03, 0x64, 0x24, 0x77, 0xc2, 0x2a, 0x94, 0x3d, 0x80, 0x67, 0x16, 0x27, 0x8e, 0xe4, 0xfc, 0x7f, 0x2c, 0xe7, 0x0c,
0xb5, 0xa1, 0xd2, 0x1b, 0xc9, 0xd7, 0x55, 0xe0, 0x1a, 0x4e, 0xee, 0x26, 0x67, 0x54, 0x14, 0x15, 0x90, 0x92, 0xd4, 0x86, 0x4a, 0x6f, 0x24, 0x6e, 0x57, 0xbe, 0x6b, 0x18, 0xba, 0x17, 0x9f, 0x51,
0xd0, 0xdf, 0xcb, 0x06, 0x56, 0xee, 0xde, 0x97, 0x4f, 0x67, 0x81, 0x4e, 0x68, 0x97, 0xdd, 0x4d, 0x61, 0x94, 0x4f, 0x7f, 0x3f, 0x1d, 0x58, 0xb9, 0xfb, 0x48, 0x5c, 0x9d, 0x39, 0x71, 0x02, 0xbb,
0x5f, 0xc9, 0x99, 0x13, 0x65, 0x1f, 0x2a, 0x5e, 0xac, 0xf6, 0x82, 0x07, 0x51, 0x0a, 0x7d, 0x0c, 0xec, 0x5e, 0xf2, 0x4a, 0x56, 0x4e, 0x94, 0x23, 0xa8, 0xc8, 0x58, 0x1d, 0xfa, 0x17, 0xa2, 0x04,
0x95, 0x91, 0xfe, 0x7b, 0x58, 0x97, 0x51, 0x9b, 0x91, 0xdf, 0x4e, 0x8d, 0xec, 0x59, 0xa9, 0xf7, 0xfa, 0x08, 0x2a, 0x25, 0xfd, 0x37, 0xb0, 0x29, 0xa2, 0x36, 0x23, 0xbf, 0x93, 0x18, 0xd9, 0x55,
0xa1, 0xf4, 0x9c, 0xf2, 0x19, 0x73, 0x33, 0x2d, 0xc3, 0xe7, 0x88, 0x33, 0x25, 0xf8, 0x01, 0x94, 0xa9, 0x8f, 0xa0, 0xf4, 0x02, 0xb3, 0x19, 0x73, 0x33, 0x29, 0xc3, 0xe7, 0x88, 0x53, 0x25, 0xf8,
0x65, 0x50, 0xd4, 0x64, 0x9e, 0x72, 0x3c, 0x45, 0x41, 0x81, 0xc4, 0xdd, 0x4c, 0x58, 0x25, 0xc6, 0x31, 0x94, 0x45, 0x50, 0xd4, 0x64, 0x96, 0x50, 0x9e, 0xc2, 0x20, 0x5f, 0xe2, 0x5e, 0x2a, 0xac,
0xe1, 0x72, 0x74, 0x4c, 0x25, 0xf4, 0x7f, 0x28, 0x8a, 0x50, 0x92, 0x63, 0xc1, 0x5b, 0x26, 0xc5, 0x12, 0x63, 0x70, 0x35, 0x3c, 0xa6, 0x12, 0xfa, 0x6f, 0x14, 0x25, 0x50, 0x12, 0x63, 0xfe, 0x5d,
0x81, 0x61, 0x48, 0x20, 0x74, 0x3b, 0x03, 0x32, 0x74, 0x77, 0x95, 0xa3, 0x85, 0x2d, 0xb9, 0x9f, 0x26, 0xc1, 0x81, 0x41, 0x88, 0x2f, 0x74, 0x27, 0x05, 0x32, 0x70, 0x76, 0x95, 0xc3, 0x0f, 0x5b,
0xf6, 0xac, 0x49, 0x2c, 0xb1, 0x6b, 0xad, 0xac, 0x70, 0x25, 0xf9, 0x13, 0xac, 0xf8, 0xe5, 0x66, 0xf4, 0x20, 0xe9, 0x5a, 0x13, 0xfb, 0xc4, 0xae, 0xb5, 0xd2, 0xc2, 0x95, 0xe4, 0xf7, 0xb0, 0xe1,
0x3c, 0xeb, 0x63, 0x93, 0x55, 0xa5, 0x5b, 0xbb, 0x75, 0x2a, 0x4e, 0xb1, 0x53, 0xb8, 0xf4, 0x6a, 0x3d, 0x37, 0xa3, 0x59, 0x1f, 0x99, 0xac, 0x5e, 0xba, 0xb5, 0xdb, 0x4b, 0x71, 0x8a, 0x1d, 0xc3,
0x62, 0xc8, 0x2b, 0xcf, 0xbb, 0x58, 0x83, 0xab, 0x3d, 0xbe, 0xb7, 0xd5, 0x6d, 0x1c, 0xc3, 0xed, 0x95, 0xd7, 0x13, 0x43, 0x1c, 0x79, 0xf2, 0x60, 0xf5, 0x8f, 0xf6, 0xe8, 0xde, 0x56, 0xa7, 0x71,
0xf2, 0xe1, 0x69, 0x7b, 0xdb, 0x81, 0x8f, 0xbb, 0xf6, 0x21, 0xb5, 0x4c, 0x23, 0x72, 0xb3, 0xee, 0x04, 0x77, 0xc0, 0x86, 0xcb, 0xf6, 0xb6, 0x03, 0xff, 0xee, 0xda, 0x27, 0xd8, 0x32, 0x8d, 0xd0,
0xa2, 0xa0, 0x6d, 0xaa, 0x8f, 0x30, 0x7e, 0xf1, 0x7b, 0xff, 0x2e, 0xa2, 0x53, 0x14, 0x38, 0x63, 0xc9, 0x7a, 0x40, 0x38, 0x6e, 0x63, 0x7d, 0x44, 0xa2, 0x07, 0xbf, 0xfc, 0x77, 0x11, 0x9e, 0xa2,
0x3e, 0xfd, 0x02, 0xc4, 0x3b, 0x85, 0xec, 0x37, 0xe6, 0x70, 0xea, 0x50, 0x6f, 0xd3, 0xa7, 0x3d, 0xc0, 0x29, 0xf3, 0xe9, 0x47, 0x40, 0xb2, 0x0a, 0xd9, 0x6f, 0xcd, 0xe1, 0xd4, 0xc1, 0x72, 0xd3,
0x69, 0xe6, 0xa1, 0x81, 0xcc, 0x17, 0x67, 0x98, 0x11, 0x7a, 0x6d, 0xc0, 0x0e, 0x8a, 0x5d, 0x14, 0x27, 0x5d, 0x69, 0xe6, 0xa1, 0xbe, 0xcc, 0x27, 0x2b, 0xcc, 0x08, 0xdc, 0x36, 0x60, 0x9f, 0xf0,
0x8e, 0xa9, 0xa7, 0x1d, 0xd5, 0x33, 0x40, 0x4a, 0xd0, 0x12, 0x70, 0x4a, 0xa0, 0x07, 0xcb, 0x5e, 0x03, 0xc2, 0x1d, 0x53, 0x4f, 0x2a, 0xd5, 0x33, 0x40, 0x42, 0xd0, 0x62, 0x70, 0x4a, 0xa0, 0x07,
0xc5, 0x4d, 0x1a, 0x89, 0x93, 0x82, 0xff, 0x05, 0x8b, 0xde, 0x48, 0xea, 0x9f, 0x42, 0xe8, 0x8c, 0xeb, 0xf2, 0xc5, 0x8d, 0x1a, 0xb1, 0x93, 0xfc, 0xff, 0x05, 0x8b, 0xee, 0x48, 0xea, 0x9f, 0x42,
0xd8, 0x41, 0x11, 0xaa, 0xe4, 0x53, 0xd2, 0x35, 0x0a, 0x5a, 0x9c, 0xae, 0x71, 0xac, 0x12, 0xb3, 0xa0, 0x46, 0xec, 0x13, 0x1e, 0x78, 0xc9, 0x27, 0xa4, 0x6b, 0x18, 0xb4, 0x38, 0x5d, 0xa3, 0x58,
0xa1, 0xf2, 0x8d, 0xc9, 0xfd, 0xc1, 0x97, 0x94, 0x1f, 0xa4, 0x5d, 0x3c, 0x31, 0xd4, 0xe2, 0x8b, 0x25, 0x66, 0x43, 0xe5, 0x0b, 0x93, 0x79, 0x83, 0xaf, 0x30, 0x3b, 0x4e, 0x3a, 0x78, 0x22, 0xa8,
0x67, 0x0e, 0x1c, 0xf2, 0x58, 0x49, 0x43, 0x39, 0xe0, 0xfb, 0x2d, 0xb5, 0x18, 0x09, 0xff, 0x6a, 0xc5, 0x07, 0xcf, 0x1c, 0x38, 0xe0, 0xb1, 0x92, 0x46, 0xc4, 0x80, 0xe7, 0xb7, 0xc4, 0xc7, 0x48,
0x39, 0x6d, 0x93, 0xbd, 0x56, 0xaf, 0x4a, 0x55, 0x3c, 0xc4, 0x2f, 0xfb, 0x59, 0xda, 0x28, 0x88, 0xf0, 0x57, 0xcb, 0xb2, 0x4d, 0xf6, 0x46, 0xdd, 0x2a, 0xd5, 0xe3, 0x21, 0x7a, 0xd8, 0xcf, 0xd2,
0xac, 0x73, 0x32, 0x30, 0xfb, 0x59, 0xf9, 0xa1, 0x99, 0xfb, 0xb0, 0xd1, 0x41, 0x0b, 0x23, 0xcc, 0x46, 0x41, 0xc4, 0x3b, 0x27, 0x05, 0xb3, 0x97, 0x95, 0x1f, 0x9b, 0xb9, 0x0f, 0x5b, 0x1d, 0x62,
0xf7, 0x52, 0xde, 0x4d, 0x51, 0x58, 0xc6, 0xcc, 0x1b, 0xc1, 0xba, 0x0c, 0x83, 0x9c, 0xf7, 0x8a, 0x91, 0x10, 0xf3, 0xfd, 0x84, 0x7b, 0x53, 0x18, 0x96, 0x32, 0xf3, 0x46, 0xb0, 0x29, 0xc2, 0x20,
0xa3, 0xc3, 0x53, 0x2e, 0xc9, 0x08, 0x26, 0xa0, 0xbe, 0x93, 0x05, 0x1a, 0xda, 0x43, 0xeb, 0x91, 0xe6, 0xbd, 0x66, 0xc4, 0x61, 0x09, 0x87, 0x64, 0x08, 0xe3, 0x53, 0xdf, 0x4d, 0x03, 0x0d, 0xec,
0xc2, 0x2d, 0xbe, 0x8e, 0x59, 0x50, 0x93, 0xca, 0xc8, 0xda, 0xfd, 0x8c, 0xe8, 0xd0, 0x1e, 0x02, 0xa1, 0xcd, 0xd0, 0xc3, 0x2d, 0xba, 0x8e, 0x59, 0x50, 0xe3, 0x9e, 0x91, 0xb5, 0x07, 0x29, 0xd1,
0x2f, 0xdc, 0x1a, 0xb3, 0x30, 0x25, 0xad, 0x67, 0x80, 0x8c, 0xee, 0x7a, 0x01, 0xab, 0xf2, 0xbd, 0x81, 0x3d, 0x04, 0x32, 0xdc, 0x1a, 0xb5, 0x48, 0x42, 0x5a, 0xcf, 0x00, 0x29, 0xdd, 0xf5, 0x12,
0xe0, 0x52, 0xde, 0x4c, 0x7d, 0x4e, 0x9c, 0x81, 0x70, 0x1f, 0x2a, 0x2f, 0x26, 0xe8, 0x50, 0x81, 0xf2, 0xe2, 0xbe, 0xe0, 0x52, 0x6e, 0x27, 0x5e, 0x27, 0x56, 0x20, 0x3c, 0x82, 0xca, 0xcb, 0x09,
0xd2, 0x5f, 0x2e, 0x6f, 0x72, 0x66, 0xc5, 0x50, 0x99, 0x6b, 0x11, 0xe8, 0xa1, 0x3c, 0xc1, 0x17, 0x71, 0x30, 0x27, 0xc2, 0x5f, 0x2e, 0x6f, 0x7c, 0x66, 0x45, 0x50, 0xa9, 0xdf, 0x22, 0xd0, 0x23,
0x38, 0x61, 0x06, 0x58, 0x7c, 0xb6, 0x85, 0x71, 0xe1, 0xc3, 0xd3, 0xeb, 0x97, 0x86, 0x2d, 0x14, 0xa2, 0x82, 0x2f, 0x70, 0xc2, 0x0c, 0xb0, 0xb8, 0xb6, 0x05, 0x71, 0xc1, 0xe2, 0x29, 0xfb, 0x85,
0x70, 0x2d, 0xcf, 0x20, 0xe0, 0xe1, 0xc2, 0xb5, 0xa0, 0xbf, 0xf4, 0x3d, 0xc7, 0x3c, 0x34, 0x2d, 0x61, 0x0b, 0x05, 0x5c, 0xcb, 0x53, 0x08, 0x48, 0x5c, 0xf0, 0x2d, 0xe8, 0x2d, 0xfd, 0xd0, 0x31,
0x1c, 0x62, 0x4a, 0x06, 0xc4, 0x61, 0x19, 0x5d, 0x34, 0x80, 0xa2, 0x27, 0xbc, 0xe3, 0x50, 0x5b, 0x4f, 0x4c, 0x8b, 0x0c, 0x49, 0x42, 0x06, 0x44, 0x61, 0x29, 0x5d, 0x34, 0x80, 0xa2, 0x14, 0xde,
0x90, 0x45, 0xa6, 0xb9, 0x88, 0x80, 0xb6, 0x79, 0x3a, 0x50, 0x2d, 0x42, 0x07, 0x90, 0x69, 0xb1, 0x77, 0xb0, 0xcd, 0xd1, 0x22, 0xd3, 0x5c, 0x84, 0x4f, 0xdb, 0x5c, 0x0e, 0x54, 0x8b, 0xd0, 0x01,
0xc7, 0x2c, 0x53, 0x3f, 0x89, 0x3f, 0x76, 0xd4, 0xd1, 0x30, 0x83, 0xa4, 0x3c, 0x76, 0x12, 0x91, 0x44, 0x5a, 0x1c, 0x52, 0xcb, 0xd4, 0xcf, 0xa3, 0x97, 0x1d, 0x55, 0x1a, 0x66, 0x90, 0x84, 0xcb,
0x4a, 0x64, 0x00, 0xc5, 0xf6, 0x08, 0xf5, 0x83, 0xe7, 0x48, 0x2d, 0x31, 0x4a, 0x2b, 0x8e, 0x66, 0x4e, 0x2c, 0x52, 0x89, 0x0c, 0xa0, 0xd8, 0x1e, 0x11, 0xfd, 0xf8, 0x05, 0xc1, 0x16, 0x1f, 0x25,
0x88, 0xc5, 0x0b, 0x89, 0x00, 0xc3, 0xd1, 0xd0, 0xd0, 0xa6, 0xe3, 0xd3, 0x2b, 0xf3, 0x38, 0x2c, 0x3d, 0x8e, 0x66, 0x88, 0xc5, 0x0b, 0x09, 0x01, 0x83, 0xd1, 0xd0, 0x88, 0x8d, 0xc7, 0xcb, 0x5f,
0x5b, 0x34, 0xb6, 0x1f, 0xff, 0xf0, 0x68, 0x68, 0x8a, 0xd1, 0x74, 0x20, 0x47, 0x36, 0x3d, 0xe8, 0xe6, 0x51, 0x58, 0xfa, 0x97, 0xb9, 0x4c, 0xca, 0x0e, 0xe6, 0xd8, 0xfd, 0x1b, 0x74, 0x77, 0x41,
0x7d, 0x93, 0xf9, 0x5f, 0x9b, 0x81, 0x07, 0x36, 0xdd, 0xd9, 0x9b, 0xea, 0x14, 0x98, 0x0c, 0x06, 0xe6, 0xfa, 0xa0, 0x94, 0xe4, 0x5f, 0x43, 0x49, 0xa4, 0xa7, 0xa2, 0x6e, 0x26, 0x66, 0xf0, 0x8a,
0xcb, 0x6e, 0xd7, 0xc3, 0x7f, 0x03, 0x00, 0x00, 0xff, 0xff, 0xd8, 0x8d, 0x70, 0xb9, 0xa3, 0x19, 0xc4, 0x5e, 0x15, 0xf5, 0x67, 0x2d, 0xaa, 0xa2, 0x0a, 0xb3, 0xbc, 0x8a, 0x06, 0xa0, 0x7e, 0x00,
0x00, 0x00, 0xf6, 0x9e, 0x7c, 0xfb, 0x78, 0x68, 0xf2, 0xd1, 0x74, 0x20, 0x6c, 0xd8, 0x91, 0xe0, 0x07, 0x26,
0xf5, 0xbe, 0x76, 0xfc, 0x1d, 0xb2, 0xe3, 0x92, 0xed, 0xa8, 0x2a, 0x39, 0x19, 0x0c, 0xd6, 0xdd,
0xae, 0x47, 0x7f, 0x05, 0x00, 0x00, 0xff, 0xff, 0xcc, 0x48, 0x0a, 0xdf, 0xc3, 0x1a, 0x00, 0x00,
} }
// Reference imports to suppress errors if they are not otherwise used. // Reference imports to suppress errors if they are not otherwise used.
@ -884,6 +886,9 @@ type RootCoordClient interface {
ListPolicy(ctx context.Context, in *internalpb.ListPolicyRequest, opts ...grpc.CallOption) (*internalpb.ListPolicyResponse, error) ListPolicy(ctx context.Context, in *internalpb.ListPolicyRequest, opts ...grpc.CallOption) (*internalpb.ListPolicyResponse, error)
CheckHealth(ctx context.Context, in *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) CheckHealth(ctx context.Context, in *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error)
RenameCollection(ctx context.Context, in *milvuspb.RenameCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) RenameCollection(ctx context.Context, in *milvuspb.RenameCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error)
CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRequest, opts ...grpc.CallOption) (*commonpb.Status, error)
DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseRequest, opts ...grpc.CallOption) (*commonpb.Status, error)
ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRequest, opts ...grpc.CallOption) (*milvuspb.ListDatabasesResponse, error)
} }
type rootCoordClient struct { type rootCoordClient struct {
@ -1290,6 +1295,33 @@ func (c *rootCoordClient) RenameCollection(ctx context.Context, in *milvuspb.Ren
return out, nil return out, nil
} }
func (c *rootCoordClient) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
out := new(commonpb.Status)
err := c.cc.Invoke(ctx, "/milvus.proto.rootcoord.RootCoord/CreateDatabase", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *rootCoordClient) DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
out := new(commonpb.Status)
err := c.cc.Invoke(ctx, "/milvus.proto.rootcoord.RootCoord/DropDatabase", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *rootCoordClient) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRequest, opts ...grpc.CallOption) (*milvuspb.ListDatabasesResponse, error) {
out := new(milvuspb.ListDatabasesResponse)
err := c.cc.Invoke(ctx, "/milvus.proto.rootcoord.RootCoord/ListDatabases", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// RootCoordServer is the server API for RootCoord service. // RootCoordServer is the server API for RootCoord service.
type RootCoordServer interface { type RootCoordServer interface {
GetComponentStates(context.Context, *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) GetComponentStates(context.Context, *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error)
@ -1388,6 +1420,9 @@ type RootCoordServer interface {
ListPolicy(context.Context, *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) ListPolicy(context.Context, *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error)
CheckHealth(context.Context, *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) CheckHealth(context.Context, *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error)
RenameCollection(context.Context, *milvuspb.RenameCollectionRequest) (*commonpb.Status, error) RenameCollection(context.Context, *milvuspb.RenameCollectionRequest) (*commonpb.Status, error)
CreateDatabase(context.Context, *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error)
DropDatabase(context.Context, *milvuspb.DropDatabaseRequest) (*commonpb.Status, error)
ListDatabases(context.Context, *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error)
} }
// UnimplementedRootCoordServer can be embedded to have forward compatible implementations. // UnimplementedRootCoordServer can be embedded to have forward compatible implementations.
@ -1526,6 +1561,15 @@ func (*UnimplementedRootCoordServer) CheckHealth(ctx context.Context, req *milvu
func (*UnimplementedRootCoordServer) RenameCollection(ctx context.Context, req *milvuspb.RenameCollectionRequest) (*commonpb.Status, error) { func (*UnimplementedRootCoordServer) RenameCollection(ctx context.Context, req *milvuspb.RenameCollectionRequest) (*commonpb.Status, error) {
return nil, status.Errorf(codes.Unimplemented, "method RenameCollection not implemented") return nil, status.Errorf(codes.Unimplemented, "method RenameCollection not implemented")
} }
func (*UnimplementedRootCoordServer) CreateDatabase(ctx context.Context, req *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) {
return nil, status.Errorf(codes.Unimplemented, "method CreateDatabase not implemented")
}
func (*UnimplementedRootCoordServer) DropDatabase(ctx context.Context, req *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) {
return nil, status.Errorf(codes.Unimplemented, "method DropDatabase not implemented")
}
func (*UnimplementedRootCoordServer) ListDatabases(ctx context.Context, req *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method ListDatabases not implemented")
}
func RegisterRootCoordServer(s *grpc.Server, srv RootCoordServer) { func RegisterRootCoordServer(s *grpc.Server, srv RootCoordServer) {
s.RegisterService(&_RootCoord_serviceDesc, srv) s.RegisterService(&_RootCoord_serviceDesc, srv)
@ -2323,6 +2367,60 @@ func _RootCoord_RenameCollection_Handler(srv interface{}, ctx context.Context, d
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
func _RootCoord_CreateDatabase_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(milvuspb.CreateDatabaseRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(RootCoordServer).CreateDatabase(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/milvus.proto.rootcoord.RootCoord/CreateDatabase",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(RootCoordServer).CreateDatabase(ctx, req.(*milvuspb.CreateDatabaseRequest))
}
return interceptor(ctx, in, info, handler)
}
func _RootCoord_DropDatabase_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(milvuspb.DropDatabaseRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(RootCoordServer).DropDatabase(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/milvus.proto.rootcoord.RootCoord/DropDatabase",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(RootCoordServer).DropDatabase(ctx, req.(*milvuspb.DropDatabaseRequest))
}
return interceptor(ctx, in, info, handler)
}
func _RootCoord_ListDatabases_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(milvuspb.ListDatabasesRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(RootCoordServer).ListDatabases(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/milvus.proto.rootcoord.RootCoord/ListDatabases",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(RootCoordServer).ListDatabases(ctx, req.(*milvuspb.ListDatabasesRequest))
}
return interceptor(ctx, in, info, handler)
}
var _RootCoord_serviceDesc = grpc.ServiceDesc{ var _RootCoord_serviceDesc = grpc.ServiceDesc{
ServiceName: "milvus.proto.rootcoord.RootCoord", ServiceName: "milvus.proto.rootcoord.RootCoord",
HandlerType: (*RootCoordServer)(nil), HandlerType: (*RootCoordServer)(nil),
@ -2503,6 +2601,18 @@ var _RootCoord_serviceDesc = grpc.ServiceDesc{
MethodName: "RenameCollection", MethodName: "RenameCollection",
Handler: _RootCoord_RenameCollection_Handler, Handler: _RootCoord_RenameCollection_Handler,
}, },
{
MethodName: "CreateDatabase",
Handler: _RootCoord_CreateDatabase_Handler,
},
{
MethodName: "DropDatabase",
Handler: _RootCoord_DropDatabase_Handler,
},
{
MethodName: "ListDatabases",
Handler: _RootCoord_ListDatabases_Handler,
},
}, },
Streams: []grpc.StreamDesc{}, Streams: []grpc.StreamDesc{},
Metadata: "root_coord.proto", Metadata: "root_coord.proto",

View File

@ -0,0 +1,244 @@
package proxy
import (
"context"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
)
// DatabaseInterceptor fill dbname into request based on kv pair <"dbname": "xx"> in header
func DatabaseInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
filledCtx, filledReq := fillDatabase(ctx, req)
return handler(filledCtx, filledReq)
}
}
func fillDatabase(ctx context.Context, req interface{}) (context.Context, interface{}) {
switch r := req.(type) {
case *milvuspb.CreateCollectionRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.DropCollectionRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.HasCollectionRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.LoadCollectionRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.ReleaseCollectionRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.DescribeCollectionRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.GetStatisticsRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.GetCollectionStatisticsRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.ShowCollectionsRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.AlterCollectionRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.CreatePartitionRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.DropPartitionRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.HasPartitionRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.LoadPartitionsRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.ReleasePartitionsRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.GetPartitionStatisticsRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.ShowPartitionsRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.GetLoadingProgressRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.GetLoadStateRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.CreateIndexRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.DescribeIndexRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.DropIndexRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.GetIndexBuildProgressRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.GetIndexStateRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.InsertRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.DeleteRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.SearchRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.FlushRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.QueryRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.CreateAliasRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.DropAliasRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.AlterAliasRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.ImportRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.ListImportTasksRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.RenameCollectionRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.TransferReplicaRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.GetPersistentSegmentInfoRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.GetQuerySegmentInfoRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.LoadBalanceRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.GetReplicasRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.OperatePrivilegeRequest:
if r.Entity != nil && r.Entity.DbName == "" {
r.Entity.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.SelectGrantRequest:
if r.Entity != nil && r.Entity.DbName == "" {
r.Entity.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.GetIndexStatisticsRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.UpsertRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
default:
return ctx, req
}
}

View File

@ -0,0 +1,137 @@
package proxy
import (
"context"
"testing"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/pkg/util"
)
func TestDatabaseInterceptor(t *testing.T) {
ctx := context.Background()
interceptor := DatabaseInterceptor()
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return "", nil
}
t.Run("empty md", func(t *testing.T) {
req := &milvuspb.CreateCollectionRequest{}
_, err := interceptor(ctx, req, &grpc.UnaryServerInfo{}, handler)
assert.NoError(t, err)
assert.Equal(t, util.DefaultDBName, req.GetDbName())
})
t.Run("with invalid metadata", func(t *testing.T) {
md := metadata.Pairs("xxx", "yyy")
ctx = metadata.NewIncomingContext(ctx, md)
req := &milvuspb.CreateCollectionRequest{}
_, err := interceptor(ctx, req, &grpc.UnaryServerInfo{}, handler)
assert.NoError(t, err)
assert.Equal(t, util.DefaultDBName, req.GetDbName())
})
t.Run("empty req", func(t *testing.T) {
md := metadata.Pairs("xxx", "yyy")
ctx = metadata.NewIncomingContext(ctx, md)
_, err := interceptor(ctx, "", &grpc.UnaryServerInfo{}, handler)
assert.NoError(t, err)
})
t.Run("test ok for all request", func(t *testing.T) {
availableReqs := []proto.Message{
&milvuspb.CreateCollectionRequest{},
&milvuspb.DropCollectionRequest{},
&milvuspb.HasCollectionRequest{},
&milvuspb.LoadCollectionRequest{},
&milvuspb.ReleaseCollectionRequest{},
&milvuspb.DescribeCollectionRequest{},
&milvuspb.GetStatisticsRequest{},
&milvuspb.GetCollectionStatisticsRequest{},
&milvuspb.ShowCollectionsRequest{},
&milvuspb.AlterCollectionRequest{},
&milvuspb.CreatePartitionRequest{},
&milvuspb.DropPartitionRequest{},
&milvuspb.HasPartitionRequest{},
&milvuspb.LoadPartitionsRequest{},
&milvuspb.ReleasePartitionsRequest{},
&milvuspb.GetPartitionStatisticsRequest{},
&milvuspb.ShowPartitionsRequest{},
&milvuspb.GetLoadingProgressRequest{},
&milvuspb.GetLoadStateRequest{},
&milvuspb.CreateIndexRequest{},
&milvuspb.DescribeIndexRequest{},
&milvuspb.DropIndexRequest{},
&milvuspb.GetIndexBuildProgressRequest{},
&milvuspb.GetIndexStateRequest{},
&milvuspb.InsertRequest{},
&milvuspb.DeleteRequest{},
&milvuspb.SearchRequest{},
&milvuspb.FlushRequest{},
&milvuspb.QueryRequest{},
&milvuspb.CreateAliasRequest{},
&milvuspb.DropAliasRequest{},
&milvuspb.AlterAliasRequest{},
&milvuspb.GetPersistentSegmentInfoRequest{},
&milvuspb.GetQuerySegmentInfoRequest{},
&milvuspb.LoadBalanceRequest{},
&milvuspb.GetReplicasRequest{},
&milvuspb.ImportRequest{},
&milvuspb.RenameCollectionRequest{},
&milvuspb.TransferReplicaRequest{},
&milvuspb.ListImportTasksRequest{},
&milvuspb.OperatePrivilegeRequest{Entity: &milvuspb.GrantEntity{}},
&milvuspb.SelectGrantRequest{Entity: &milvuspb.GrantEntity{}},
}
md := metadata.Pairs(util.HeaderDBName, "db")
ctx = metadata.NewIncomingContext(ctx, md)
for _, req := range availableReqs {
before, err := proto.Marshal(req)
assert.NoError(t, err)
_, err = interceptor(ctx, req, &grpc.UnaryServerInfo{}, handler)
assert.NoError(t, err)
after, err := proto.Marshal(req)
assert.NoError(t, err)
assert.True(t, len(after) > len(before))
}
unavailableReqs := []proto.Message{
&milvuspb.GetMetricsRequest{},
&milvuspb.DummyRequest{},
&milvuspb.CalcDistanceRequest{},
&milvuspb.FlushAllRequest{},
&milvuspb.GetCompactionStateRequest{},
&milvuspb.ManualCompactionRequest{},
&milvuspb.GetCompactionPlansRequest{},
&milvuspb.GetFlushStateRequest{},
&milvuspb.GetFlushAllStateRequest{},
&milvuspb.GetImportStateRequest{},
}
for _, req := range unavailableReqs {
before, err := proto.Marshal(req)
assert.NoError(t, err)
_, err = interceptor(ctx, req, &grpc.UnaryServerInfo{}, handler)
assert.NoError(t, err)
after, err := proto.Marshal(req)
assert.NoError(t, err)
if len(after) != len(before) {
t.Errorf("req has been modified:%s", req.String())
}
}
})
}

View File

@ -4,6 +4,10 @@ import (
"context" "context"
"fmt" "fmt"
"plugin" "plugin"
"strconv"
"strings"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"go.uber.org/zap" "go.uber.org/zap"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -86,6 +90,7 @@ func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor {
log.Info("hook mock", zap.String("user", getCurrentUser(ctx)), log.Info("hook mock", zap.String("user", getCurrentUser(ctx)),
zap.String("full method", fullMethod), zap.Error(err)) zap.String("full method", fullMethod), zap.Error(err))
metrics.ProxyHookFunc.WithLabelValues(metrics.HookMock, fullMethod).Inc() metrics.ProxyHookFunc.WithLabelValues(metrics.HookMock, fullMethod).Inc()
updateProxyFunctionCallMetric(fullMethod)
return mockResp, err return mockResp, err
} }
@ -93,6 +98,7 @@ func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor {
log.Warn("hook before error", zap.String("user", getCurrentUser(ctx)), zap.String("full method", fullMethod), log.Warn("hook before error", zap.String("user", getCurrentUser(ctx)), zap.String("full method", fullMethod),
zap.Any("request", req), zap.Error(err)) zap.Any("request", req), zap.Error(err))
metrics.ProxyHookFunc.WithLabelValues(metrics.HookBefore, fullMethod).Inc() metrics.ProxyHookFunc.WithLabelValues(metrics.HookBefore, fullMethod).Inc()
updateProxyFunctionCallMetric(fullMethod)
return nil, err return nil, err
} }
realResp, realErr = handler(newCtx, req) realResp, realErr = handler(newCtx, req)
@ -100,12 +106,22 @@ func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor {
log.Warn("hook after error", zap.String("user", getCurrentUser(ctx)), zap.String("full method", fullMethod), log.Warn("hook after error", zap.String("user", getCurrentUser(ctx)), zap.String("full method", fullMethod),
zap.Any("request", req), zap.Error(err)) zap.Any("request", req), zap.Error(err))
metrics.ProxyHookFunc.WithLabelValues(metrics.HookAfter, fullMethod).Inc() metrics.ProxyHookFunc.WithLabelValues(metrics.HookAfter, fullMethod).Inc()
updateProxyFunctionCallMetric(fullMethod)
return nil, err return nil, err
} }
return realResp, realErr return realResp, realErr
} }
} }
func updateProxyFunctionCallMetric(fullMethod string) {
if fullMethod == "" {
return
}
method := strings.Split(fullMethod, "/")[0]
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc()
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc()
}
func getCurrentUser(ctx context.Context) string { func getCurrentUser(ctx context.Context) string {
username, err := GetCurUserFromContext(ctx) username, err := GetCurUserFromContext(ctx)
if err != nil { if err != nil {

View File

@ -81,6 +81,9 @@ func TestHookInterceptor(t *testing.T) {
info = &grpc.UnaryServerInfo{ info = &grpc.UnaryServerInfo{
FullMethod: "test", FullMethod: "test",
} }
emptyFullMethod = &grpc.UnaryServerInfo{
FullMethod: "",
}
interceptor = UnaryServerHookInterceptor() interceptor = UnaryServerHookInterceptor()
mockHoo = mockHook{mockRes: "mock", mockErr: errors.New("mock")} mockHoo = mockHook{mockRes: "mock", mockErr: errors.New("mock")}
r = &req{method: "req"} r = &req{method: "req"}
@ -98,6 +101,11 @@ func TestHookInterceptor(t *testing.T) {
}) })
assert.Equal(t, res, mockHoo.mockRes) assert.Equal(t, res, mockHoo.mockRes)
assert.Equal(t, err, mockHoo.mockErr) assert.Equal(t, err, mockHoo.mockErr)
res, err = interceptor(ctx, "request", emptyFullMethod, func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
assert.Equal(t, res, mockHoo.mockRes)
assert.Equal(t, err, mockHoo.mockErr)
hoo = beforeHoo hoo = beforeHoo
_, err = interceptor(ctx, r, info, func(ctx context.Context, req interface{}) (interface{}, error) { _, err = interceptor(ctx, r, info, func(ctx context.Context, req interface{}) (interface{}, error) {

View File

@ -58,6 +58,7 @@ import (
) )
const moduleName = "Proxy" const moduleName = "Proxy"
const SlowReadSpan = time.Second * 5 const SlowReadSpan = time.Second * 5
// UpdateStateCode updates the state code of Proxy. // UpdateStateCode updates the state code of Proxy.
@ -129,7 +130,7 @@ func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *p
var aliasName []string var aliasName []string
if globalMetaCache != nil { if globalMetaCache != nil {
if collectionName != "" { if collectionName != "" {
globalMetaCache.RemoveCollection(ctx, collectionName) // no need to return error, though collection may be not cached globalMetaCache.RemoveCollection(ctx, request.GetDbName(), collectionName) // no need to return error, though collection may be not cached
} }
if request.CollectionID != UniqueID(0) { if request.CollectionID != UniqueID(0) {
aliasName = globalMetaCache.RemoveCollectionsByID(ctx, collectionID) aliasName = globalMetaCache.RemoveCollectionsByID(ctx, collectionID)
@ -152,6 +153,160 @@ func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *p
}, nil }, nil
} }
func (node *Proxy) CreateDatabase(ctx context.Context, request *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) {
if !node.checkHealthy() {
return unhealthyStatus(), nil
}
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-CreateDatabase")
defer sp.End()
method := "CreateDatabase"
tr := timerecord.NewTimeRecorder(method)
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc()
cct := &createDatabaseTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
CreateDatabaseRequest: request,
rootCoord: node.rootCoord,
}
log := log.With(zap.String("traceID", sp.SpanContext().TraceID().String()),
zap.String("role", typeutil.ProxyRole),
zap.String("dbName", request.DbName))
log.Info(rpcReceived(method))
if err := node.sched.ddQueue.Enqueue(cct); err != nil {
log.Warn(rpcFailedToEnqueue(method), zap.Error(err))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel).Inc()
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
}, nil
}
log.Info(rpcEnqueued(method))
if err := cct.WaitToFinish(); err != nil {
log.Warn(rpcFailedToWaitToFinish(method), zap.Error(err))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc()
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
}, nil
}
log.Info(rpcDone(method))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.SuccessLabel).Inc()
metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
return cct.result, nil
}
// DropDatabase removes an existing database through rootCoord. The request is
// wrapped in a dropDatabaseTask, enqueued on the DDL queue, and awaited;
// scheduling or execution errors are reported in the returned Status (the
// error return stays nil so gRPC transports the status, not a transport error).
func (node *Proxy) DropDatabase(ctx context.Context, request *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) {
	if !node.checkHealthy() {
		return unhealthyStatus(), nil
	}

	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-DropDatabase")
	defer sp.End()

	method := "DropDatabase"
	tr := timerecord.NewTimeRecorder(method)
	// Compute the node-ID metric label once instead of on every metric call.
	nodeID := strconv.FormatInt(paramtable.GetNodeID(), 10)
	metrics.ProxyFunctionCall.WithLabelValues(nodeID, method, metrics.TotalLabel).Inc()

	dct := &dropDatabaseTask{
		ctx:                 ctx,
		Condition:           NewTaskCondition(ctx),
		DropDatabaseRequest: request,
		rootCoord:           node.rootCoord,
	}

	log := log.With(zap.String("traceID", sp.SpanContext().TraceID().String()),
		zap.String("role", typeutil.ProxyRole),
		zap.String("dbName", request.DbName))

	log.Info(rpcReceived(method))

	if err := node.sched.ddQueue.Enqueue(dct); err != nil {
		log.Warn(rpcFailedToEnqueue(method), zap.Error(err))
		metrics.ProxyFunctionCall.WithLabelValues(nodeID, method, metrics.AbandonLabel).Inc()
		return &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_UnexpectedError,
			Reason:    err.Error(),
		}, nil
	}

	log.Info(rpcEnqueued(method))

	if err := dct.WaitToFinish(); err != nil {
		log.Warn(rpcFailedToWaitToFinish(method), zap.Error(err))
		metrics.ProxyFunctionCall.WithLabelValues(nodeID, method, metrics.FailLabel).Inc()
		return &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_UnexpectedError,
			Reason:    err.Error(),
		}, nil
	}

	log.Info(rpcDone(method))
	metrics.ProxyFunctionCall.WithLabelValues(nodeID, method, metrics.SuccessLabel).Inc()
	metrics.ProxyReqLatency.WithLabelValues(nodeID, method).Observe(float64(tr.ElapseSpan().Milliseconds()))
	return dct.result, nil
}
// ListDatabases returns the names of all databases known to rootCoord. The
// request is wrapped in a listDatabaseTask, enqueued on the DDL queue, and
// awaited; scheduling or execution errors are reported in the response Status
// (the error return stays nil so gRPC transports the status, not a transport
// error).
func (node *Proxy) ListDatabases(ctx context.Context, request *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) {
	resp := &milvuspb.ListDatabasesResponse{}
	if !node.checkHealthy() {
		resp.Status = unhealthyStatus()
		return resp, nil
	}

	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-ListDatabases")
	defer sp.End()

	method := "ListDatabases"
	tr := timerecord.NewTimeRecorder(method)
	// Compute the node-ID metric label once instead of on every metric call.
	nodeID := strconv.FormatInt(paramtable.GetNodeID(), 10)
	metrics.ProxyFunctionCall.WithLabelValues(nodeID, method, metrics.TotalLabel).Inc()

	dct := &listDatabaseTask{
		ctx:                  ctx,
		Condition:            NewTaskCondition(ctx),
		ListDatabasesRequest: request,
		rootCoord:            node.rootCoord,
	}

	log := log.With(zap.String("traceID", sp.SpanContext().TraceID().String()),
		zap.String("role", typeutil.ProxyRole))

	log.Info(rpcReceived(method))

	if err := node.sched.ddQueue.Enqueue(dct); err != nil {
		log.Warn(rpcFailedToEnqueue(method), zap.Error(err))
		metrics.ProxyFunctionCall.WithLabelValues(nodeID, method, metrics.AbandonLabel).Inc()
		resp.Status = &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_UnexpectedError,
			Reason:    err.Error(),
		}
		return resp, nil
	}

	log.Info(rpcEnqueued(method))

	if err := dct.WaitToFinish(); err != nil {
		log.Warn(rpcFailedToWaitToFinish(method), zap.Error(err))
		metrics.ProxyFunctionCall.WithLabelValues(nodeID, method, metrics.FailLabel).Inc()
		resp.Status = &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_UnexpectedError,
			Reason:    err.Error(),
		}
		return resp, nil
	}

	log.Info(rpcDone(method), zap.Int("num of db", len(dct.result.DbNames)))
	metrics.ProxyFunctionCall.WithLabelValues(nodeID, method, metrics.SuccessLabel).Inc()
	metrics.ProxyReqLatency.WithLabelValues(nodeID, method).Observe(float64(tr.ElapseSpan().Milliseconds()))
	return dct.result, nil
}
// CreateCollection create a collection by the schema. // CreateCollection create a collection by the schema.
// TODO(dragondriver): add more detailed ut for ConsistencyLevel, should we support multiple consistency level in Proxy? // TODO(dragondriver): add more detailed ut for ConsistencyLevel, should we support multiple consistency level in Proxy?
func (node *Proxy) CreateCollection(ctx context.Context, request *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) { func (node *Proxy) CreateCollection(ctx context.Context, request *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) {
@ -570,7 +725,9 @@ func (node *Proxy) DescribeCollection(ctx context.Context, request *milvuspb.Des
log.Debug("DescribeCollection done", log.Debug("DescribeCollection done",
zap.Uint64("BeginTS", dct.BeginTs()), zap.Uint64("BeginTS", dct.BeginTs()),
zap.Uint64("EndTS", dct.EndTs())) zap.Uint64("EndTS", dct.EndTs()),
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
metrics.SuccessLabel).Inc() metrics.SuccessLabel).Inc()
@ -1467,7 +1624,7 @@ func (node *Proxy) GetLoadingProgress(ctx context.Context, request *milvuspb.Get
if err := validateCollectionName(request.CollectionName); err != nil { if err := validateCollectionName(request.CollectionName); err != nil {
return getErrResponse(err), nil return getErrResponse(err), nil
} }
collectionID, err := globalMetaCache.GetCollectionID(ctx, request.CollectionName) collectionID, err := globalMetaCache.GetCollectionID(ctx, request.GetDbName(), request.CollectionName)
if err != nil { if err != nil {
return getErrResponse(err), nil return getErrResponse(err), nil
} }
@ -1565,7 +1722,7 @@ func (node *Proxy) GetLoadState(ctx context.Context, request *milvuspb.GetLoadSt
metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds())) metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
}() }()
collectionID, err := globalMetaCache.GetCollectionID(ctx, request.CollectionName) collectionID, err := globalMetaCache.GetCollectionID(ctx, request.GetDbName(), request.CollectionName)
if err != nil { if err != nil {
successResponse.State = commonpb.LoadState_LoadStateNotExist successResponse.State = commonpb.LoadState_LoadStateNotExist
return successResponse, nil return successResponse, nil
@ -2132,6 +2289,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
commonpbutil.WithMsgID(0), commonpbutil.WithMsgID(0),
commonpbutil.WithSourceID(paramtable.GetNodeID()), commonpbutil.WithSourceID(paramtable.GetNodeID()),
), ),
DbName: request.GetDbName(),
CollectionName: request.CollectionName, CollectionName: request.CollectionName,
PartitionName: request.PartitionName, PartitionName: request.PartitionName,
FieldsData: request.FieldsData, FieldsData: request.FieldsData,
@ -2346,24 +2504,18 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest)
metrics.UpsertLabel, request.GetCollectionName()).Add(float64(proto.Size(request))) metrics.UpsertLabel, request.GetCollectionName()).Add(float64(proto.Size(request)))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc() metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc()
request.Base = commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_Upsert),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
)
it := &upsertTask{ it := &upsertTask{
baseMsg: msgstream.BaseMsg{ baseMsg: msgstream.BaseMsg{
HashValues: request.HashKeys, HashValues: request.HashKeys,
}, },
ctx: ctx, ctx: ctx,
Condition: NewTaskCondition(ctx), Condition: NewTaskCondition(ctx),
req: request,
req: &milvuspb.UpsertRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_Upsert),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
CollectionName: request.CollectionName,
PartitionName: request.PartitionName,
FieldsData: request.FieldsData,
NumRows: request.NumRows,
},
result: &milvuspb.MutationResult{ result: &milvuspb.MutationResult{
Status: &commonpb.Status{ Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success, ErrorCode: commonpb.ErrorCode_Success,
@ -3106,39 +3258,48 @@ func (node *Proxy) FlushAll(ctx context.Context, _ *milvuspb.FlushAllRequest) (*
} }
log.Info(rpcReceived("FlushAll")) log.Info(rpcReceived("FlushAll"))
// Flush all collections to accelerate the flushAll progress hasError := func(status *commonpb.Status, err error) bool {
showColRsp, err := node.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{ if err != nil {
Base: commonpbutil.NewMsgBase(commonpbutil.WithMsgType(commonpb.MsgType_ShowCollections)), resp.Status = &commonpb.Status{
}) ErrorCode: commonpb.ErrorCode_UnexpectedError,
if err != nil { Reason: err.Error(),
resp.Status = &commonpb.Status{ }
ErrorCode: commonpb.ErrorCode_UnexpectedError, log.Warn("FlushAll failed", zap.String("err", err.Error()))
Reason: err.Error(), return true
} }
log.Warn("FlushAll failed", zap.String("err", err.Error())) if status.ErrorCode != commonpb.ErrorCode_Success {
return resp, nil log.Warn("FlushAll failed", zap.String("err", status.GetReason()))
} resp.Status = status
if showColRsp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { return true
log.Warn("FlushAll failed", zap.String("err", showColRsp.GetStatus().GetReason()))
resp.Status = showColRsp.GetStatus()
return resp, nil
}
flushRsp, err := node.Flush(ctx, &milvuspb.FlushRequest{
Base: commonpbutil.NewMsgBase(commonpbutil.WithMsgType(commonpb.MsgType_Flush)),
CollectionNames: showColRsp.GetCollectionNames(),
})
if err != nil {
resp.Status = &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
} }
log.Warn("FlushAll failed", zap.String("err", err.Error())) return false
}
dbsRsp, err := node.rootCoord.ListDatabases(ctx, &milvuspb.ListDatabasesRequest{
Base: commonpbutil.NewMsgBase(commonpbutil.WithMsgType(commonpb.MsgType_ListDatabases)),
})
if hasError(dbsRsp.GetStatus(), err) {
return resp, nil return resp, nil
} }
if flushRsp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("FlushAll failed", zap.String("err", flushRsp.GetStatus().GetReason())) for _, dbName := range dbsRsp.DbNames {
resp.Status = flushRsp.GetStatus() // Flush all collections to accelerate the flushAll progress
return resp, nil showColRsp, err := node.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{
Base: commonpbutil.NewMsgBase(commonpbutil.WithMsgType(commonpb.MsgType_ShowCollections)),
DbName: dbName,
})
if hasError(showColRsp.GetStatus(), err) {
return resp, nil
}
flushRsp, err := node.Flush(ctx, &milvuspb.FlushRequest{
Base: commonpbutil.NewMsgBase(commonpbutil.WithMsgType(commonpb.MsgType_Flush)),
DbName: dbName,
CollectionNames: showColRsp.GetCollectionNames(),
})
if hasError(flushRsp.GetStatus(), err) {
return resp, nil
}
} }
// allocate current ts as FlushAllTs // allocate current ts as FlushAllTs
@ -3194,7 +3355,7 @@ func (node *Proxy) GetPersistentSegmentInfo(ctx context.Context, req *milvuspb.G
metrics.TotalLabel).Inc() metrics.TotalLabel).Inc()
// list segments // list segments
collectionID, err := globalMetaCache.GetCollectionID(ctx, req.GetCollectionName()) collectionID, err := globalMetaCache.GetCollectionID(ctx, req.GetDbName(), req.GetCollectionName())
if err != nil { if err != nil {
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc()
resp.Status.Reason = fmt.Errorf("getCollectionID failed, err:%w", err).Error() resp.Status.Reason = fmt.Errorf("getCollectionID failed, err:%w", err).Error()
@ -3284,7 +3445,7 @@ func (node *Proxy) GetQuerySegmentInfo(ctx context.Context, req *milvuspb.GetQue
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
metrics.TotalLabel).Inc() metrics.TotalLabel).Inc()
collID, err := globalMetaCache.GetCollectionID(ctx, req.CollectionName) collID, err := globalMetaCache.GetCollectionID(ctx, req.GetDbName(), req.CollectionName)
if err != nil { if err != nil {
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc()
resp.Status.Reason = err.Error() resp.Status.Reason = err.Error()
@ -3587,7 +3748,7 @@ func (node *Proxy) LoadBalance(ctx context.Context, req *milvuspb.LoadBalanceReq
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
} }
collectionID, err := globalMetaCache.GetCollectionID(ctx, req.GetCollectionName()) collectionID, err := globalMetaCache.GetCollectionID(ctx, req.GetDbName(), req.GetCollectionName())
if err != nil { if err != nil {
log.Warn("failed to get collection id", log.Warn("failed to get collection id",
zap.String("collection name", req.GetCollectionName()), zap.String("collection name", req.GetCollectionName()),
@ -3649,7 +3810,7 @@ func (node *Proxy) GetReplicas(ctx context.Context, req *milvuspb.GetReplicasReq
) )
if req.GetCollectionName() != "" { if req.GetCollectionName() != "" {
req.CollectionID, _ = globalMetaCache.GetCollectionID(ctx, req.GetCollectionName()) req.CollectionID, _ = globalMetaCache.GetCollectionID(ctx, req.GetDbName(), req.GetCollectionName())
} }
r, err := node.queryCoord.GetReplicas(ctx, req) r, err := node.queryCoord.GetReplicas(ctx, req)
@ -5126,26 +5287,3 @@ func (node *Proxy) ListClientInfos(ctx context.Context, req *proxypb.ListClientI
ClientInfos: clients, ClientInfos: clients,
}, nil }, nil
} }
func (node *Proxy) CreateDatabase(ctx context.Context, req *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "TODO: implement me @jaime",
}, nil
}
func (node *Proxy) DropDatabase(ctx context.Context, req *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "TODO: implement me @jaime",
}, nil
}
func (node *Proxy) ListDatabases(ctx context.Context, req *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) {
return &milvuspb.ListDatabasesResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "TODO: implement me @jaime",
},
}, nil
}

View File

@ -20,8 +20,6 @@ import (
"context" "context"
"testing" "testing"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
@ -32,12 +30,12 @@ import (
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
) )
func TestProxy_InvalidateCollectionMetaCache_remove_stream(t *testing.T) { func TestProxy_InvalidateCollectionMetaCache_remove_stream(t *testing.T) {
@ -368,6 +366,7 @@ func TestProxy_InvalidResourceGroupName(t *testing.T) {
func TestProxy_FlushAll(t *testing.T) { func TestProxy_FlushAll(t *testing.T) {
factory := dependency.NewDefaultFactory(true) factory := dependency.NewDefaultFactory(true)
ctx := context.Background() ctx := context.Background()
paramtable.Init()
node, err := NewProxy(ctx, factory) node, err := NewProxy(ctx, factory)
assert.NoError(t, err) assert.NoError(t, err)
@ -375,6 +374,8 @@ func TestProxy_FlushAll(t *testing.T) {
node.tsoAllocator = &timestampAllocator{ node.tsoAllocator = &timestampAllocator{
tso: newMockTimestampAllocatorInterface(), tso: newMockTimestampAllocatorInterface(),
} }
Params.Save(Params.ProxyCfg.MaxTaskNum.Key, "1000")
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory) node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
assert.NoError(t, err) assert.NoError(t, err)
err = node.sched.Start() err = node.sched.Start()
@ -384,17 +385,26 @@ func TestProxy_FlushAll(t *testing.T) {
node.rootCoord = mocks.NewRootCoord(t) node.rootCoord = mocks.NewRootCoord(t)
// set expectations // set expectations
cache := newMockCache() cache := NewMockCache(t)
getIDFunc := func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) { cache.On("GetCollectionID",
return UniqueID(0), nil mock.Anything, // context.Context
} mock.AnythingOfType("string"),
cache.getIDFunc = getIDFunc mock.AnythingOfType("string"),
).Return(UniqueID(0), nil).Once()
cache.On("RemoveDatabase",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
).Maybe()
globalMetaCache = cache globalMetaCache = cache
successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
node.dataCoord.(*mocks.DataCoord).EXPECT().Flush(mock.Anything, mock.Anything). node.dataCoord.(*mocks.DataCoord).EXPECT().Flush(mock.Anything, mock.Anything).
Return(&datapb.FlushResponse{Status: successStatus}, nil).Maybe() Return(&datapb.FlushResponse{Status: successStatus}, nil).Maybe()
node.rootCoord.(*mocks.RootCoord).EXPECT().ShowCollections(mock.Anything, mock.Anything). node.rootCoord.(*mocks.RootCoord).EXPECT().ShowCollections(mock.Anything, mock.Anything).
Return(&milvuspb.ShowCollectionsResponse{Status: successStatus, CollectionNames: []string{"col-0"}}, nil).Maybe() Return(&milvuspb.ShowCollectionsResponse{Status: successStatus, CollectionNames: []string{"col-0"}}, nil).Maybe()
node.rootCoord.(*mocks.RootCoord).EXPECT().ListDatabases(mock.Anything, mock.Anything).
Return(&milvuspb.ListDatabasesResponse{Status: successStatus, DbNames: []string{"default"}}, nil).Maybe()
t.Run("FlushAll", func(t *testing.T) { t.Run("FlushAll", func(t *testing.T) {
resp, err := node.FlushAll(ctx, &milvuspb.FlushAllRequest{}) resp, err := node.FlushAll(ctx, &milvuspb.FlushAllRequest{})
@ -411,13 +421,19 @@ func TestProxy_FlushAll(t *testing.T) {
}) })
t.Run("FlushAll failed, get id failed", func(t *testing.T) { t.Run("FlushAll failed, get id failed", func(t *testing.T) {
globalMetaCache.(*mockCache).getIDFunc = func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) { globalMetaCache.(*MockCache).On("GetCollectionID",
return 0, errors.New("mock error") mock.Anything, // context.Context
} mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(UniqueID(0), errors.New("mock error")).Once()
resp, err := node.FlushAll(ctx, &milvuspb.FlushAllRequest{}) resp, err := node.FlushAll(ctx, &milvuspb.FlushAllRequest{})
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_UnexpectedError) assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_UnexpectedError)
globalMetaCache.(*mockCache).getIDFunc = getIDFunc globalMetaCache.(*MockCache).On("GetCollectionID",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(UniqueID(0), nil).Once()
}) })
t.Run("FlushAll failed, DataCoord flush failed", func(t *testing.T) { t.Run("FlushAll failed, DataCoord flush failed", func(t *testing.T) {
@ -436,6 +452,8 @@ func TestProxy_FlushAll(t *testing.T) {
t.Run("FlushAll failed, RootCoord showCollections failed", func(t *testing.T) { t.Run("FlushAll failed, RootCoord showCollections failed", func(t *testing.T) {
node.rootCoord.(*mocks.RootCoord).ExpectedCalls = nil node.rootCoord.(*mocks.RootCoord).ExpectedCalls = nil
node.rootCoord.(*mocks.RootCoord).EXPECT().ListDatabases(mock.Anything, mock.Anything).
Return(&milvuspb.ListDatabasesResponse{Status: successStatus, DbNames: []string{"default"}}, nil).Maybe()
node.rootCoord.(*mocks.RootCoord).EXPECT().ShowCollections(mock.Anything, mock.Anything). node.rootCoord.(*mocks.RootCoord).EXPECT().ShowCollections(mock.Anything, mock.Anything).
Return(&milvuspb.ShowCollectionsResponse{ Return(&milvuspb.ShowCollectionsResponse{
Status: &commonpb.Status{ Status: &commonpb.Status{
@ -447,6 +465,20 @@ func TestProxy_FlushAll(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_UnexpectedError) assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_UnexpectedError)
}) })
t.Run("FlushAll failed, RootCoord showCollections failed", func(t *testing.T) {
node.rootCoord.(*mocks.RootCoord).ExpectedCalls = nil
node.rootCoord.(*mocks.RootCoord).EXPECT().ListDatabases(mock.Anything, mock.Anything).
Return(&milvuspb.ListDatabasesResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "mock err",
},
}, nil).Maybe()
resp, err := node.FlushAll(ctx, &milvuspb.FlushAllRequest{})
assert.NoError(t, err)
assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_UnexpectedError)
})
} }
func TestProxy_GetFlushAllState(t *testing.T) { func TestProxy_GetFlushAllState(t *testing.T) {
@ -615,3 +647,172 @@ func TestProxy_ListClientInfos(t *testing.T) {
}) })
} }
// TestProxyCreateDatabase covers Proxy.CreateDatabase: the unhealthy-state
// short-circuit, a rootCoord failure surfaced as an UnexpectedError status,
// and the success path.
func TestProxyCreateDatabase(t *testing.T) {
	paramtable.Init()

	t.Run("not healthy", func(t *testing.T) {
		node := &Proxy{session: &sessionutil.Session{ServerID: 1}}
		node.stateCode.Store(commonpb.StateCode_Abnormal)
		ctx := context.Background()
		resp, err := node.CreateDatabase(ctx, &milvuspb.CreateDatabaseRequest{})
		assert.NoError(t, err)
		assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetErrorCode())
	})

	factory := dependency.NewDefaultFactory(true)
	ctx := context.Background()

	node, err := NewProxy(ctx, factory)
	assert.NoError(t, err)
	node.tsoAllocator = &timestampAllocator{
		tso: newMockTimestampAllocatorInterface(),
	}
	node.multiRateLimiter = NewMultiRateLimiter()
	node.stateCode.Store(commonpb.StateCode_Healthy)
	node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
	// Verify the scheduler was constructed before dereferencing it; the
	// previous ordering called setMaxTaskNum first, which would nil-panic
	// instead of failing the assertion on a construction error.
	assert.NoError(t, err)
	node.sched.ddQueue.setMaxTaskNum(10)
	err = node.sched.Start()
	assert.NoError(t, err)
	defer node.sched.Close()

	t.Run("create database fail", func(t *testing.T) {
		rc := mocks.NewRootCoord(t)
		rc.On("CreateDatabase", mock.Anything, mock.Anything).
			Return(nil, errors.New("fail"))
		node.rootCoord = rc
		ctx := context.Background()
		resp, err := node.CreateDatabase(ctx, &milvuspb.CreateDatabaseRequest{DbName: "db"})
		assert.NoError(t, err)
		assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetErrorCode())
	})

	t.Run("create database ok", func(t *testing.T) {
		rc := mocks.NewRootCoord(t)
		rc.On("CreateDatabase", mock.Anything, mock.Anything).
			Return(&commonpb.Status{
				ErrorCode: commonpb.ErrorCode_Success,
			}, nil)
		node.rootCoord = rc
		node.stateCode.Store(commonpb.StateCode_Healthy)
		ctx := context.Background()
		resp, err := node.CreateDatabase(ctx, &milvuspb.CreateDatabaseRequest{DbName: "db"})
		assert.NoError(t, err)
		assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
	})
}
// TestProxyDropDatabase covers Proxy.DropDatabase: the unhealthy-state
// short-circuit, a rootCoord failure surfaced as an UnexpectedError status,
// and the success path.
func TestProxyDropDatabase(t *testing.T) {
	paramtable.Init()

	t.Run("not healthy", func(t *testing.T) {
		node := &Proxy{session: &sessionutil.Session{ServerID: 1}}
		node.stateCode.Store(commonpb.StateCode_Abnormal)
		ctx := context.Background()
		resp, err := node.DropDatabase(ctx, &milvuspb.DropDatabaseRequest{})
		assert.NoError(t, err)
		assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetErrorCode())
	})

	factory := dependency.NewDefaultFactory(true)
	ctx := context.Background()

	node, err := NewProxy(ctx, factory)
	assert.NoError(t, err)
	node.tsoAllocator = &timestampAllocator{
		tso: newMockTimestampAllocatorInterface(),
	}
	node.multiRateLimiter = NewMultiRateLimiter()
	node.stateCode.Store(commonpb.StateCode_Healthy)
	node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
	// Verify the scheduler was constructed before dereferencing it; the
	// previous ordering called setMaxTaskNum first, which would nil-panic
	// instead of failing the assertion on a construction error.
	assert.NoError(t, err)
	node.sched.ddQueue.setMaxTaskNum(10)
	err = node.sched.Start()
	assert.NoError(t, err)
	defer node.sched.Close()

	t.Run("drop database fail", func(t *testing.T) {
		rc := mocks.NewRootCoord(t)
		rc.On("DropDatabase", mock.Anything, mock.Anything).
			Return(nil, errors.New("fail"))
		node.rootCoord = rc
		ctx := context.Background()
		resp, err := node.DropDatabase(ctx, &milvuspb.DropDatabaseRequest{DbName: "db"})
		assert.NoError(t, err)
		assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetErrorCode())
	})

	t.Run("drop database ok", func(t *testing.T) {
		rc := mocks.NewRootCoord(t)
		rc.On("DropDatabase", mock.Anything, mock.Anything).
			Return(&commonpb.Status{
				ErrorCode: commonpb.ErrorCode_Success,
			}, nil)
		node.rootCoord = rc
		node.stateCode.Store(commonpb.StateCode_Healthy)
		ctx := context.Background()
		resp, err := node.DropDatabase(ctx, &milvuspb.DropDatabaseRequest{DbName: "db"})
		assert.NoError(t, err)
		assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
	})
}
// TestProxyListDatabase covers Proxy.ListDatabases: the unhealthy-state
// short-circuit, a rootCoord failure surfaced as an UnexpectedError status,
// and the success path.
func TestProxyListDatabase(t *testing.T) {
	paramtable.Init()

	t.Run("not healthy", func(t *testing.T) {
		node := &Proxy{session: &sessionutil.Session{ServerID: 1}}
		node.stateCode.Store(commonpb.StateCode_Abnormal)
		ctx := context.Background()
		resp, err := node.ListDatabases(ctx, &milvuspb.ListDatabasesRequest{})
		assert.NoError(t, err)
		assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode())
	})

	factory := dependency.NewDefaultFactory(true)
	ctx := context.Background()

	node, err := NewProxy(ctx, factory)
	assert.NoError(t, err)
	node.tsoAllocator = &timestampAllocator{
		tso: newMockTimestampAllocatorInterface(),
	}
	node.multiRateLimiter = NewMultiRateLimiter()
	node.stateCode.Store(commonpb.StateCode_Healthy)
	node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
	// Verify the scheduler was constructed before dereferencing it; the
	// previous ordering called setMaxTaskNum first, which would nil-panic
	// instead of failing the assertion on a construction error.
	assert.NoError(t, err)
	node.sched.ddQueue.setMaxTaskNum(10)
	err = node.sched.Start()
	assert.NoError(t, err)
	defer node.sched.Close()

	t.Run("list database fail", func(t *testing.T) {
		rc := mocks.NewRootCoord(t)
		rc.On("ListDatabases", mock.Anything, mock.Anything).
			Return(nil, errors.New("fail"))
		node.rootCoord = rc
		ctx := context.Background()
		resp, err := node.ListDatabases(ctx, &milvuspb.ListDatabasesRequest{})
		assert.NoError(t, err)
		assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode())
	})

	t.Run("list database ok", func(t *testing.T) {
		rc := mocks.NewRootCoord(t)
		rc.On("ListDatabases", mock.Anything, mock.Anything).
			Return(&milvuspb.ListDatabasesResponse{
				Status: &commonpb.Status{
					ErrorCode: commonpb.ErrorCode_Success,
				}}, nil)
		node.rootCoord = rc
		node.stateCode.Store(commonpb.StateCode_Healthy)
		ctx := context.Background()
		resp, err := node.ListDatabases(ctx, &milvuspb.ListDatabasesRequest{})
		assert.NoError(t, err)
		assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
	})
}

View File

@ -18,20 +18,22 @@ package proxy
import ( import (
"context" "context"
"github.com/samber/lo"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/retry" "github.com/milvus-io/milvus/pkg/util/retry"
"github.com/milvus-io/milvus/pkg/util/typeutil" "github.com/milvus-io/milvus/pkg/util/typeutil"
"github.com/samber/lo"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
) )
type executeFunc func(context.Context, UniqueID, types.QueryNode, ...string) error type executeFunc func(context.Context, UniqueID, types.QueryNode, ...string) error
type ChannelWorkload struct { type ChannelWorkload struct {
db string
collection string collection string
channel string channel string
shardLeaders []int64 shardLeaders []int64
@ -41,6 +43,7 @@ type ChannelWorkload struct {
} }
type CollectionWorkLoad struct { type CollectionWorkLoad struct {
db string
collection string collection string
nq int64 nq int64
exec executeFunc exec executeFunc
@ -77,20 +80,20 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, workload ChannelWorkload
return !excludeNodes.Contain(node) return !excludeNodes.Contain(node)
} }
getShardLeaders := func(collection string, channel string) ([]int64, error) { getShardLeaders := func() ([]int64, error) {
shardLeaders, err := globalMetaCache.GetShards(ctx, false, collection) shardLeaders, err := globalMetaCache.GetShards(ctx, false, workload.db, workload.collection)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return lo.Map(shardLeaders[channel], func(node nodeInfo, _ int) int64 { return node.nodeID }), nil return lo.Map(shardLeaders[workload.channel], func(node nodeInfo, _ int) int64 { return node.nodeID }), nil
} }
availableNodes := lo.Filter(workload.shardLeaders, filterAvailableNodes) availableNodes := lo.Filter(workload.shardLeaders, filterAvailableNodes)
targetNode, err := lb.balancer.SelectNode(availableNodes, workload.nq) targetNode, err := lb.balancer.SelectNode(availableNodes, workload.nq)
if err != nil { if err != nil {
globalMetaCache.DeprecateShardCache(workload.collection) globalMetaCache.DeprecateShardCache(workload.db, workload.collection)
nodes, err := getShardLeaders(workload.collection, workload.channel) nodes, err := getShardLeaders()
if err != nil || len(nodes) == 0 { if err != nil || len(nodes) == 0 {
log.Warn("failed to get shard delegator", log.Warn("failed to get shard delegator",
zap.Error(err)) zap.Error(err))
@ -164,7 +167,7 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
// Execute will execute collection workload in parallel // Execute will execute collection workload in parallel
func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad) error { func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad) error {
dml2leaders, err := globalMetaCache.GetShards(ctx, true, workload.collection) dml2leaders, err := globalMetaCache.GetShards(ctx, true, workload.db, workload.collection)
if err != nil { if err != nil {
return err return err
} }
@ -175,6 +178,7 @@ func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad
nodes := lo.Map(nodes, func(node nodeInfo, _ int) int64 { return node.nodeID }) nodes := lo.Map(nodes, func(node nodeInfo, _ int) int64 { return node.nodeID })
wg.Go(func() error { wg.Go(func() error {
err := lb.ExecuteWithRetry(ctx, ChannelWorkload{ err := lb.ExecuteWithRetry(ctx, ChannelWorkload{
db: workload.db,
collection: workload.collection, collection: workload.collection,
channel: channel, channel: channel,
shardLeaders: nodes, shardLeaders: nodes,

View File

@ -21,6 +21,10 @@ import (
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"go.uber.org/atomic"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
@ -31,9 +35,6 @@ import (
"github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil" "github.com/milvus-io/milvus/pkg/util/typeutil"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"go.uber.org/atomic"
) )
type LBPolicySuite struct { type LBPolicySuite struct {
@ -129,6 +130,7 @@ func (s *LBPolicySuite) loadCollection() {
Condition: NewTaskCondition(ctx), Condition: NewTaskCondition(ctx),
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{ CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
CollectionName: s.collection, CollectionName: s.collection,
DbName: dbName,
Schema: marshaledSchema, Schema: marshaledSchema,
ShardsNum: common.DefaultShardsNum, ShardsNum: common.DefaultShardsNum,
}, },
@ -141,7 +143,7 @@ func (s *LBPolicySuite) loadCollection() {
s.NoError(createColT.Execute(ctx)) s.NoError(createColT.Execute(ctx))
s.NoError(createColT.PostExecute(ctx)) s.NoError(createColT.PostExecute(ctx))
collectionID, err := globalMetaCache.GetCollectionID(ctx, s.collection) collectionID, err := globalMetaCache.GetCollectionID(ctx, dbName, s.collection)
s.NoError(err) s.NoError(err)
status, err := s.qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{ status, err := s.qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
@ -159,6 +161,7 @@ func (s *LBPolicySuite) TestSelectNode() {
ctx := context.Background() ctx := context.Background()
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(5, nil) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(5, nil)
targetNode, err := s.lbPolicy.selectNode(ctx, ChannelWorkload{ targetNode, err := s.lbPolicy.selectNode(ctx, ChannelWorkload{
db: dbName,
collection: s.collection, collection: s.collection,
channel: s.channels[0], channel: s.channels[0],
shardLeaders: s.nodes, shardLeaders: s.nodes,
@ -172,6 +175,7 @@ func (s *LBPolicySuite) TestSelectNode() {
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(-1, errors.New("fake err")).Times(1) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(-1, errors.New("fake err")).Times(1)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(3, nil) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(3, nil)
targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{ targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
db: dbName,
collection: s.collection, collection: s.collection,
channel: s.channels[0], channel: s.channels[0],
shardLeaders: []int64{}, shardLeaders: []int64{},
@ -184,6 +188,7 @@ func (s *LBPolicySuite) TestSelectNode() {
s.lbBalancer.ExpectedCalls = nil s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(-1, merr.ErrNoAvailableNode) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(-1, merr.ErrNoAvailableNode)
targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{ targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
db: dbName,
collection: s.collection, collection: s.collection,
channel: s.channels[0], channel: s.channels[0],
shardLeaders: []int64{}, shardLeaders: []int64{},
@ -196,6 +201,7 @@ func (s *LBPolicySuite) TestSelectNode() {
s.lbBalancer.ExpectedCalls = nil s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(-1, merr.ErrNoAvailableNode) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(-1, merr.ErrNoAvailableNode)
targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{ targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
db: dbName,
collection: s.collection, collection: s.collection,
channel: s.channels[0], channel: s.channels[0],
shardLeaders: s.nodes, shardLeaders: s.nodes,
@ -210,6 +216,7 @@ func (s *LBPolicySuite) TestSelectNode() {
s.qc.ExpectedCalls = nil s.qc.ExpectedCalls = nil
s.qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(nil, merr.ErrNoAvailableNodeInReplica) s.qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(nil, merr.ErrNoAvailableNodeInReplica)
targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{ targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
db: dbName,
collection: s.collection, collection: s.collection,
channel: s.channels[0], channel: s.channels[0],
shardLeaders: s.nodes, shardLeaders: s.nodes,
@ -228,6 +235,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(1, nil) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(1, nil)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything) s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err := s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{ err := s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
db: dbName,
collection: s.collection, collection: s.collection,
channel: s.channels[0], channel: s.channels[0],
shardLeaders: s.nodes, shardLeaders: s.nodes,
@ -243,6 +251,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
s.lbBalancer.ExpectedCalls = nil s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(-1, merr.ErrNoAvailableNode) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(-1, merr.ErrNoAvailableNode)
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{ err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
db: dbName,
collection: s.collection, collection: s.collection,
channel: s.channels[0], channel: s.channels[0],
shardLeaders: s.nodes, shardLeaders: s.nodes,
@ -261,6 +270,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(1, nil) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(1, nil)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything) s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{ err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
db: dbName,
collection: s.collection, collection: s.collection,
channel: s.channels[0], channel: s.channels[0],
shardLeaders: s.nodes, shardLeaders: s.nodes,
@ -277,6 +287,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything) s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{ err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
db: dbName,
collection: s.collection, collection: s.collection,
channel: s.channels[0], channel: s.channels[0],
shardLeaders: s.nodes, shardLeaders: s.nodes,
@ -296,6 +307,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything) s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
counter := 0 counter := 0
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{ err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
db: dbName,
collection: s.collection, collection: s.collection,
channel: s.channels[0], channel: s.channels[0],
shardLeaders: s.nodes, shardLeaders: s.nodes,
@ -319,6 +331,7 @@ func (s *LBPolicySuite) TestExecute() {
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(1, nil) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(1, nil)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything) s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err := s.lbPolicy.Execute(ctx, CollectionWorkLoad{ err := s.lbPolicy.Execute(ctx, CollectionWorkLoad{
db: dbName,
collection: s.collection, collection: s.collection,
nq: 1, nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error { exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {
@ -330,6 +343,7 @@ func (s *LBPolicySuite) TestExecute() {
// test some channel failed // test some channel failed
counter := atomic.NewInt64(0) counter := atomic.NewInt64(0)
err = s.lbPolicy.Execute(ctx, CollectionWorkLoad{ err = s.lbPolicy.Execute(ctx, CollectionWorkLoad{
db: dbName,
collection: s.collection, collection: s.collection,
nq: 1, nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error { exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {
@ -344,9 +358,10 @@ func (s *LBPolicySuite) TestExecute() {
// test get shard leader failed // test get shard leader failed
s.qc.ExpectedCalls = nil s.qc.ExpectedCalls = nil
globalMetaCache.DeprecateShardCache(s.collection) globalMetaCache.DeprecateShardCache(dbName, s.collection)
s.qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")) s.qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(nil, errors.New("fake error"))
err = s.lbPolicy.Execute(ctx, CollectionWorkLoad{ err = s.lbPolicy.Execute(ctx, CollectionWorkLoad{
db: dbName,
collection: s.collection, collection: s.collection,
nq: 1, nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error { exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {

View File

@ -26,13 +26,6 @@ import (
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/samber/lo" "github.com/samber/lo"
"github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"go.uber.org/atomic" "go.uber.org/atomic"
"go.uber.org/zap" "go.uber.org/zap"
@ -42,38 +35,45 @@ import (
"github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/retry" "github.com/milvus-io/milvus/pkg/util/retry"
"github.com/milvus-io/milvus/pkg/util/timerecord" "github.com/milvus-io/milvus/pkg/util/timerecord"
"github.com/milvus-io/milvus/pkg/util/typeutil" "github.com/milvus-io/milvus/pkg/util/typeutil"
) )
// Cache is the interface for system meta data cache // Cache is the interface for system meta data cache
//
//go:generate mockery --name=Cache --filename=mock_cache_test.go --outpkg=proxy --output=. --inpackage --structname=MockCache --with-expecter
type Cache interface { type Cache interface {
// GetCollectionID get collection's id by name. // GetCollectionID get collection's id by name.
GetCollectionID(ctx context.Context, collectionName string) (typeutil.UniqueID, error) GetCollectionID(ctx context.Context, database, collectionName string) (typeutil.UniqueID, error)
// GetCollectionName get collection's name by id // GetDatabaseAndCollectionName get collection's name and database by id
GetCollectionName(ctx context.Context, collectionID int64) (string, error) GetDatabaseAndCollectionName(ctx context.Context, collectionID int64) (string, string, error)
// GetCollectionInfo get collection's information by name, such as collection id, schema, and etc. // GetCollectionInfo get collection's information by name, such as collection id, schema, and etc.
GetCollectionInfo(ctx context.Context, collectionName string) (*collectionInfo, error) GetCollectionInfo(ctx context.Context, database, collectionName string) (*collectionInfo, error)
// GetPartitionID get partition's identifier of specific collection. // GetPartitionID get partition's identifier of specific collection.
GetPartitionID(ctx context.Context, collectionName string, partitionName string) (typeutil.UniqueID, error) GetPartitionID(ctx context.Context, database, collectionName string, partitionName string) (typeutil.UniqueID, error)
// GetPartitions get all partitions' id of specific collection. // GetPartitions get all partitions' id of specific collection.
GetPartitions(ctx context.Context, collectionName string) (map[string]typeutil.UniqueID, error) GetPartitions(ctx context.Context, database, collectionName string) (map[string]typeutil.UniqueID, error)
// GetPartitionInfo get partition's info. // GetPartitionInfo get partition's info.
GetPartitionInfo(ctx context.Context, collectionName string, partitionName string) (*partitionInfo, error) GetPartitionInfo(ctx context.Context, database, collectionName string, partitionName string) (*partitionInfo, error)
// GetCollectionSchema get collection's schema. // GetCollectionSchema get collection's schema.
GetCollectionSchema(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) GetCollectionSchema(ctx context.Context, database, collectionName string) (*schemapb.CollectionSchema, error)
GetShards(ctx context.Context, withCache bool, collectionName string) (map[string][]nodeInfo, error) GetShards(ctx context.Context, withCache bool, database, collectionName string) (map[string][]nodeInfo, error)
DeprecateShardCache(collectionName string) DeprecateShardCache(database, collectionName string)
expireShardLeaderCache(ctx context.Context) expireShardLeaderCache(ctx context.Context)
RemoveCollection(ctx context.Context, collectionName string) RemoveCollection(ctx context.Context, database, collectionName string)
RemoveCollectionsByID(ctx context.Context, collectionID UniqueID) []string RemoveCollectionsByID(ctx context.Context, collectionID UniqueID) []string
RemovePartition(ctx context.Context, collectionName string, partitionName string) RemovePartition(ctx context.Context, database, collectionName string, partitionName string)
// GetCredentialInfo operate credential cache // GetCredentialInfo operate credential cache
GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error) GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error)
@ -84,6 +84,8 @@ type Cache interface {
GetUserRole(username string) []string GetUserRole(username string) []string
RefreshPolicyInfo(op typeutil.CacheOp) error RefreshPolicyInfo(op typeutil.CacheOp) error
InitPolicyInfo(info []string, userRoles []string) InitPolicyInfo(info []string, userRoles []string)
RemoveDatabase(ctx context.Context, database string)
} }
type collectionInfo struct { type collectionInfo struct {
@ -95,6 +97,7 @@ type collectionInfo struct {
createdTimestamp uint64 createdTimestamp uint64
createdUtcTimestamp uint64 createdUtcTimestamp uint64
consistencyLevel commonpb.ConsistencyLevel consistencyLevel commonpb.ConsistencyLevel
database string
} }
func (info *collectionInfo) isCollectionCached() bool { func (info *collectionInfo) isCollectionCached() bool {
@ -169,7 +172,7 @@ type MetaCache struct {
rootCoord types.RootCoord rootCoord types.RootCoord
queryCoord types.QueryCoord queryCoord types.QueryCoord
collInfo map[string]*collectionInfo collInfo map[string]map[string]*collectionInfo // database -> collection -> collection_info
credMap map[string]*internalpb.CredentialInfo // cache for credential, lazy load credMap map[string]*internalpb.CredentialInfo // cache for credential, lazy load
privilegeInfos map[string]struct{} // privileges cache privilegeInfos map[string]struct{} // privileges cache
userToRoles map[string]map[string]struct{} // user to role cache userToRoles map[string]map[string]struct{} // user to role cache
@ -207,7 +210,7 @@ func NewMetaCache(rootCoord types.RootCoord, queryCoord types.QueryCoord, shardM
return &MetaCache{ return &MetaCache{
rootCoord: rootCoord, rootCoord: rootCoord,
queryCoord: queryCoord, queryCoord: queryCoord,
collInfo: map[string]*collectionInfo{}, collInfo: map[string]map[string]*collectionInfo{},
credMap: map[string]*internalpb.CredentialInfo{}, credMap: map[string]*internalpb.CredentialInfo{},
shardMgr: shardMgr, shardMgr: shardMgr,
privilegeInfos: map[string]struct{}{}, privilegeInfos: map[string]struct{}{},
@ -216,23 +219,30 @@ func NewMetaCache(rootCoord types.RootCoord, queryCoord types.QueryCoord, shardM
} }
// GetCollectionID returns the corresponding collection id for provided collection name // GetCollectionID returns the corresponding collection id for provided collection name
func (m *MetaCache) GetCollectionID(ctx context.Context, collectionName string) (typeutil.UniqueID, error) { func (m *MetaCache) GetCollectionID(ctx context.Context, database, collectionName string) (typeutil.UniqueID, error) {
m.mu.RLock() m.mu.RLock()
collInfo, ok := m.collInfo[collectionName]
var ok bool
var collInfo *collectionInfo
db, dbOk := m.collInfo[database]
if dbOk && db != nil {
collInfo, ok = db[collectionName]
}
if !ok || !collInfo.isCollectionCached() { if !ok || !collInfo.isCollectionCached() {
metrics.ProxyCacheStatsCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), "GeCollectionID", metrics.CacheMissLabel).Inc() metrics.ProxyCacheStatsCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), "GeCollectionID", metrics.CacheMissLabel).Inc()
tr := timerecord.NewTimeRecorder("UpdateCache") tr := timerecord.NewTimeRecorder("UpdateCache")
m.mu.RUnlock() m.mu.RUnlock()
coll, err := m.describeCollection(ctx, collectionName, 0) coll, err := m.describeCollection(ctx, database, collectionName, 0)
if err != nil { if err != nil {
return 0, err return 0, err
} }
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
m.updateCollection(coll, collectionName) m.updateCollection(coll, database, collectionName)
metrics.ProxyUpdateCacheLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(tr.ElapseSpan().Milliseconds())) metrics.ProxyUpdateCacheLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(tr.ElapseSpan().Milliseconds()))
collInfo = m.collInfo[collectionName] collInfo = m.collInfo[database][collectionName]
return collInfo.collID, nil return collInfo.collID, nil
} }
defer m.mu.RUnlock() defer m.mu.RUnlock()
@ -241,14 +251,16 @@ func (m *MetaCache) GetCollectionID(ctx context.Context, collectionName string)
return collInfo.collID, nil return collInfo.collID, nil
} }
// GetCollectionName returns the corresponding collection name for provided collection id // GetDatabaseAndCollectionName returns the corresponding collection name for provided collection id
func (m *MetaCache) GetCollectionName(ctx context.Context, collectionID int64) (string, error) { func (m *MetaCache) GetDatabaseAndCollectionName(ctx context.Context, collectionID int64) (string, string, error) {
m.mu.RLock() m.mu.RLock()
var collInfo *collectionInfo var collInfo *collectionInfo
for _, coll := range m.collInfo { for _, db := range m.collInfo {
if coll.collID == collectionID { for _, coll := range db {
collInfo = coll if coll.collID == collectionID {
break collInfo = coll
break
}
} }
} }
@ -256,40 +268,46 @@ func (m *MetaCache) GetCollectionName(ctx context.Context, collectionID int64) (
metrics.ProxyCacheStatsCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), "GeCollectionName", metrics.CacheMissLabel).Inc() metrics.ProxyCacheStatsCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), "GeCollectionName", metrics.CacheMissLabel).Inc()
tr := timerecord.NewTimeRecorder("UpdateCache") tr := timerecord.NewTimeRecorder("UpdateCache")
m.mu.RUnlock() m.mu.RUnlock()
coll, err := m.describeCollection(ctx, "", collectionID) coll, err := m.describeCollection(ctx, "", "", collectionID)
if err != nil { if err != nil {
return "", err return "", "", err
} }
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
m.updateCollection(coll, coll.Schema.Name)
m.updateCollection(coll, coll.GetDbName(), coll.Schema.Name)
metrics.ProxyUpdateCacheLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(tr.ElapseSpan().Milliseconds())) metrics.ProxyUpdateCacheLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(tr.ElapseSpan().Milliseconds()))
return coll.Schema.Name, nil return coll.GetDbName(), coll.Schema.Name, nil
} }
defer m.mu.RUnlock() defer m.mu.RUnlock()
metrics.ProxyCacheStatsCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), "GeCollectionName", metrics.CacheHitLabel).Inc() metrics.ProxyCacheStatsCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), "GeCollectionName", metrics.CacheHitLabel).Inc()
return collInfo.schema.Name, nil return collInfo.database, collInfo.schema.Name, nil
} }
// GetCollectionInfo returns the collection information related to provided collection name // GetCollectionInfo returns the collection information related to provided collection name
// If the information is not found, proxy will try to fetch information for other source (RootCoord for now) // If the information is not found, proxy will try to fetch information for other source (RootCoord for now)
func (m *MetaCache) GetCollectionInfo(ctx context.Context, collectionName string) (*collectionInfo, error) { func (m *MetaCache) GetCollectionInfo(ctx context.Context, database, collectionName string) (*collectionInfo, error) {
m.mu.RLock() m.mu.RLock()
var collInfo *collectionInfo var collInfo *collectionInfo
collInfo, ok := m.collInfo[collectionName] var ok bool
db, dbOk := m.collInfo[database]
if dbOk {
collInfo, ok = db[collectionName]
}
m.mu.RUnlock() m.mu.RUnlock()
if !ok || !collInfo.isCollectionCached() { if !ok || !collInfo.isCollectionCached() {
tr := timerecord.NewTimeRecorder("UpdateCache") tr := timerecord.NewTimeRecorder("UpdateCache")
metrics.ProxyCacheStatsCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), "GetCollectionInfo", metrics.CacheMissLabel).Inc() metrics.ProxyCacheStatsCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), "GetCollectionInfo", metrics.CacheMissLabel).Inc()
coll, err := m.describeCollection(ctx, collectionName, 0) coll, err := m.describeCollection(ctx, database, collectionName, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
m.mu.Lock() m.mu.Lock()
m.updateCollection(coll, collectionName) m.updateCollection(coll, database, collectionName)
collInfo = m.collInfo[collectionName] collInfo = m.collInfo[database][collectionName]
m.mu.Unlock() m.mu.Unlock()
metrics.ProxyUpdateCacheLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(tr.ElapseSpan().Milliseconds())) metrics.ProxyUpdateCacheLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(tr.ElapseSpan().Milliseconds()))
} }
@ -298,15 +316,21 @@ func (m *MetaCache) GetCollectionInfo(ctx context.Context, collectionName string
return collInfo, nil return collInfo, nil
} }
func (m *MetaCache) GetCollectionSchema(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) { func (m *MetaCache) GetCollectionSchema(ctx context.Context, database, collectionName string) (*schemapb.CollectionSchema, error) {
m.mu.RLock() m.mu.RLock()
collInfo, ok := m.collInfo[collectionName] var collInfo *collectionInfo
var ok bool
db, dbOk := m.collInfo[database]
if dbOk {
collInfo, ok = db[collectionName]
}
if !ok || !collInfo.isCollectionCached() { if !ok || !collInfo.isCollectionCached() {
metrics.ProxyCacheStatsCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), "GetCollectionSchema", metrics.CacheMissLabel).Inc() metrics.ProxyCacheStatsCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), "GetCollectionSchema", metrics.CacheMissLabel).Inc()
tr := timerecord.NewTimeRecorder("UpdateCache") tr := timerecord.NewTimeRecorder("UpdateCache")
m.mu.RUnlock() m.mu.RUnlock()
coll, err := m.describeCollection(ctx, collectionName, 0) coll, err := m.describeCollection(ctx, database, collectionName, 0)
if err != nil { if err != nil {
log.Warn("Failed to load collection from rootcoord ", log.Warn("Failed to load collection from rootcoord ",
zap.String("collection name ", collectionName), zap.String("collection name ", collectionName),
@ -315,8 +339,8 @@ func (m *MetaCache) GetCollectionSchema(ctx context.Context, collectionName stri
} }
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
m.updateCollection(coll, collectionName) m.updateCollection(coll, database, collectionName)
collInfo = m.collInfo[collectionName] collInfo = m.collInfo[database][collectionName]
metrics.ProxyUpdateCacheLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(tr.ElapseSpan().Milliseconds())) metrics.ProxyUpdateCacheLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(tr.ElapseSpan().Milliseconds()))
log.Debug("Reload collection from root coordinator ", log.Debug("Reload collection from root coordinator ",
zap.String("collection name", collectionName), zap.String("collection name", collectionName),
@ -329,37 +353,56 @@ func (m *MetaCache) GetCollectionSchema(ctx context.Context, collectionName stri
return collInfo.schema, nil return collInfo.schema, nil
} }
func (m *MetaCache) updateCollection(coll *milvuspb.DescribeCollectionResponse, collectionName string) { func (m *MetaCache) updateCollection(coll *milvuspb.DescribeCollectionResponse, database, collectionName string) {
_, ok := m.collInfo[collectionName] _, dbOk := m.collInfo[database]
if !ok { if !dbOk {
m.collInfo[collectionName] = &collectionInfo{} m.collInfo[database] = make(map[string]*collectionInfo)
} }
m.collInfo[collectionName].schema = coll.Schema
m.collInfo[collectionName].collID = coll.CollectionID _, ok := m.collInfo[database][collectionName]
m.collInfo[collectionName].createdTimestamp = coll.CreatedTimestamp if !ok {
m.collInfo[collectionName].createdUtcTimestamp = coll.CreatedUtcTimestamp m.collInfo[database][collectionName] = &collectionInfo{}
m.collInfo[collectionName].consistencyLevel = coll.ConsistencyLevel }
m.collInfo[database][collectionName].schema = coll.Schema
m.collInfo[database][collectionName].collID = coll.CollectionID
m.collInfo[database][collectionName].createdTimestamp = coll.CreatedTimestamp
m.collInfo[database][collectionName].createdUtcTimestamp = coll.CreatedUtcTimestamp
} }
func (m *MetaCache) GetPartitionID(ctx context.Context, collectionName string, partitionName string) (typeutil.UniqueID, error) { func (m *MetaCache) GetPartitionID(ctx context.Context, database, collectionName string, partitionName string) (typeutil.UniqueID, error) {
partInfo, err := m.GetPartitionInfo(ctx, collectionName, partitionName) partInfo, err := m.GetPartitionInfo(ctx, database, collectionName, partitionName)
if err != nil { if err != nil {
return 0, err return 0, err
} }
return partInfo.partitionID, nil return partInfo.partitionID, nil
} }
func (m *MetaCache) GetPartitions(ctx context.Context, collectionName string) (map[string]typeutil.UniqueID, error) { func (m *MetaCache) GetPartitions(ctx context.Context, database, collectionName string) (map[string]typeutil.UniqueID, error) {
collInfo, err := m.GetCollectionInfo(ctx, collectionName) _, err := m.GetCollectionID(ctx, database, collectionName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
m.mu.RLock()
var collInfo *collectionInfo
var ok bool
db, dbOk := m.collInfo[database]
if dbOk {
collInfo, ok = db[collectionName]
}
if !ok {
m.mu.RUnlock()
return nil, fmt.Errorf("can't find collection name:%s", collectionName)
}
if collInfo.partInfo == nil || len(collInfo.partInfo) == 0 { if collInfo.partInfo == nil || len(collInfo.partInfo) == 0 {
tr := timerecord.NewTimeRecorder("UpdateCache") tr := timerecord.NewTimeRecorder("UpdateCache")
metrics.ProxyCacheStatsCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), "GetPartitions", metrics.CacheMissLabel).Inc() metrics.ProxyCacheStatsCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), "GetPartitions", metrics.CacheMissLabel).Inc()
m.mu.RUnlock()
partitions, err := m.showPartitions(ctx, collectionName) partitions, err := m.showPartitions(ctx, database, collectionName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -367,20 +410,21 @@ func (m *MetaCache) GetPartitions(ctx context.Context, collectionName string) (m
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
err = m.updatePartitions(partitions, collectionName) err = m.updatePartitions(partitions, database, collectionName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
metrics.ProxyUpdateCacheLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(tr.ElapseSpan().Milliseconds())) metrics.ProxyUpdateCacheLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(tr.ElapseSpan().Milliseconds()))
log.Debug("proxy", zap.Any("GetPartitions:partitions after update", partitions), zap.String("collectionName", collectionName)) log.Debug("proxy", zap.Any("GetPartitions:partitions after update", partitions), zap.String("collectionName", collectionName))
ret := make(map[string]typeutil.UniqueID) ret := make(map[string]typeutil.UniqueID)
partInfo := m.collInfo[collectionName].partInfo partInfo := m.collInfo[database][collectionName].partInfo
for k, v := range partInfo { for k, v := range partInfo {
ret[k] = v.partitionID ret[k] = v.partitionID
} }
return ret, nil return ret, nil
} }
defer m.mu.RUnlock()
metrics.ProxyCacheStatsCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), "GetPartitions", metrics.CacheHitLabel).Inc() metrics.ProxyCacheStatsCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), "GetPartitions", metrics.CacheHitLabel).Inc()
ret := make(map[string]typeutil.UniqueID) ret := make(map[string]typeutil.UniqueID)
@ -392,30 +436,46 @@ func (m *MetaCache) GetPartitions(ctx context.Context, collectionName string) (m
return ret, nil return ret, nil
} }
func (m *MetaCache) GetPartitionInfo(ctx context.Context, collectionName string, partitionName string) (*partitionInfo, error) { func (m *MetaCache) GetPartitionInfo(ctx context.Context, database, collectionName string, partitionName string) (*partitionInfo, error) {
collInfo, err := m.GetCollectionInfo(ctx, collectionName) _, err := m.GetCollectionID(ctx, database, collectionName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
m.mu.RLock()
var collInfo *collectionInfo
var ok bool
db, dbOk := m.collInfo[database]
if dbOk {
collInfo, ok = db[collectionName]
}
if !ok {
m.mu.RUnlock()
return nil, fmt.Errorf("can't find collection name:%s", collectionName)
}
var partInfo *partitionInfo
partInfo, ok = collInfo.partInfo[partitionName]
m.mu.RUnlock()
partInfo, ok := collInfo.partInfo[partitionName]
if !ok { if !ok {
tr := timerecord.NewTimeRecorder("UpdateCache") tr := timerecord.NewTimeRecorder("UpdateCache")
metrics.ProxyCacheStatsCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), "GetPartitionInfo", metrics.CacheMissLabel).Inc() metrics.ProxyCacheStatsCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), "GetPartitionInfo", metrics.CacheMissLabel).Inc()
partitions, err := m.showPartitions(ctx, collectionName) partitions, err := m.showPartitions(ctx, database, collectionName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
err = m.updatePartitions(partitions, collectionName) err = m.updatePartitions(partitions, database, collectionName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
metrics.ProxyUpdateCacheLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(tr.ElapseSpan().Milliseconds())) metrics.ProxyUpdateCacheLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(tr.ElapseSpan().Milliseconds()))
log.Debug("proxy", zap.Any("GetPartitionID:partitions after update", partitions), zap.String("collectionName", collectionName)) log.Debug("proxy", zap.Any("GetPartitionID:partitions after update", partitions), zap.String("collectionName", collectionName))
partInfo, ok = m.collInfo[collectionName].partInfo[partitionName] partInfo, ok = m.collInfo[database][collectionName].partInfo[partitionName]
if !ok { if !ok {
return nil, merr.WrapErrPartitionNotFound(partitionName) return nil, merr.WrapErrPartitionNotFound(partitionName)
} }
@ -429,11 +489,12 @@ func (m *MetaCache) GetPartitionInfo(ctx context.Context, collectionName string,
} }
// Get the collection information from rootcoord. // Get the collection information from rootcoord.
func (m *MetaCache) describeCollection(ctx context.Context, collectionName string, collectionID int64) (*milvuspb.DescribeCollectionResponse, error) { func (m *MetaCache) describeCollection(ctx context.Context, database, collectionName string, collectionID int64) (*milvuspb.DescribeCollectionResponse, error) {
req := &milvuspb.DescribeCollectionRequest{ req := &milvuspb.DescribeCollectionRequest{
Base: commonpbutil.NewMsgBase( Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection), commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection),
), ),
DbName: database,
CollectionName: collectionName, CollectionName: collectionName,
CollectionID: collectionID, CollectionID: collectionID,
} }
@ -459,6 +520,7 @@ func (m *MetaCache) describeCollection(ctx context.Context, collectionName strin
CreatedTimestamp: coll.CreatedTimestamp, CreatedTimestamp: coll.CreatedTimestamp,
CreatedUtcTimestamp: coll.CreatedUtcTimestamp, CreatedUtcTimestamp: coll.CreatedUtcTimestamp,
ConsistencyLevel: coll.ConsistencyLevel, ConsistencyLevel: coll.ConsistencyLevel,
DbName: coll.GetDbName(),
} }
for _, field := range coll.Schema.Fields { for _, field := range coll.Schema.Fields {
if field.FieldID >= common.StartOfUserFieldID { if field.FieldID >= common.StartOfUserFieldID {
@ -468,11 +530,12 @@ func (m *MetaCache) describeCollection(ctx context.Context, collectionName strin
return resp, nil return resp, nil
} }
func (m *MetaCache) showPartitions(ctx context.Context, collectionName string) (*milvuspb.ShowPartitionsResponse, error) { func (m *MetaCache) showPartitions(ctx context.Context, dbName string, collectionName string) (*milvuspb.ShowPartitionsResponse, error) {
req := &milvuspb.ShowPartitionsRequest{ req := &milvuspb.ShowPartitionsRequest{
Base: commonpbutil.NewMsgBase( Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_ShowPartitions), commonpbutil.WithMsgType(commonpb.MsgType_ShowPartitions),
), ),
DbName: dbName,
CollectionName: collectionName, CollectionName: collectionName,
} }
@ -492,14 +555,19 @@ func (m *MetaCache) showPartitions(ctx context.Context, collectionName string) (
return partitions, nil return partitions, nil
} }
func (m *MetaCache) updatePartitions(partitions *milvuspb.ShowPartitionsResponse, collectionName string) error { func (m *MetaCache) updatePartitions(partitions *milvuspb.ShowPartitionsResponse, database, collectionName string) error {
_, ok := m.collInfo[collectionName] _, dbOk := m.collInfo[database]
if !dbOk {
m.collInfo[database] = make(map[string]*collectionInfo)
}
_, ok := m.collInfo[database][collectionName]
if !ok { if !ok {
m.collInfo[collectionName] = &collectionInfo{ m.collInfo[database][collectionName] = &collectionInfo{
partInfo: map[string]*partitionInfo{}, partInfo: map[string]*partitionInfo{},
} }
} }
partInfo := m.collInfo[collectionName].partInfo partInfo := m.collInfo[database][collectionName].partInfo
if partInfo == nil { if partInfo == nil {
partInfo = map[string]*partitionInfo{} partInfo = map[string]*partitionInfo{}
} }
@ -518,37 +586,50 @@ func (m *MetaCache) updatePartitions(partitions *milvuspb.ShowPartitionsResponse
} }
} }
} }
m.collInfo[collectionName].partInfo = partInfo m.collInfo[database][collectionName].partInfo = partInfo
return nil return nil
} }
func (m *MetaCache) RemoveCollection(ctx context.Context, collectionName string) { func (m *MetaCache) RemoveCollection(ctx context.Context, database, collectionName string) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
delete(m.collInfo, collectionName) _, dbOk := m.collInfo[database]
if dbOk {
delete(m.collInfo[database], collectionName)
}
} }
func (m *MetaCache) RemoveCollectionsByID(ctx context.Context, collectionID UniqueID) []string { func (m *MetaCache) RemoveCollectionsByID(ctx context.Context, collectionID UniqueID) []string {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
var collNames []string var collNames []string
for k, v := range m.collInfo { for database, db := range m.collInfo {
if v.collID == collectionID { for k, v := range db {
delete(m.collInfo, k) if v.collID == collectionID {
collNames = append(collNames, k) delete(m.collInfo[database], k)
collNames = append(collNames, k)
}
} }
} }
return collNames return collNames
} }
func (m *MetaCache) RemovePartition(ctx context.Context, collectionName, partitionName string) { func (m *MetaCache) RemovePartition(ctx context.Context, database, collectionName, partitionName string) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
_, ok := m.collInfo[collectionName]
var ok bool
db, dbOk := m.collInfo[database]
if dbOk {
_, ok = db[collectionName]
}
if !ok { if !ok {
return return
} }
partInfo := m.collInfo[collectionName].partInfo
partInfo := m.collInfo[database][collectionName].partInfo
if partInfo == nil { if partInfo == nil {
return return
} }
@ -605,8 +686,8 @@ func (m *MetaCache) UpdateCredential(credInfo *internalpb.CredentialInfo) {
} }
// GetShards update cache if withCache == false // GetShards update cache if withCache == false
func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionName string) (map[string][]nodeInfo, error) { func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, collectionName string) (map[string][]nodeInfo, error) {
info, err := m.GetCollectionInfo(ctx, collectionName) info, err := m.GetCollectionInfo(ctx, database, collectionName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -660,7 +741,7 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionNam
shards := parseShardLeaderList2QueryNode(resp.GetShards()) shards := parseShardLeaderList2QueryNode(resp.GetShards())
info, err = m.GetCollectionInfo(ctx, collectionName) info, err = m.GetCollectionInfo(ctx, database, collectionName)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get shards, collection %s not found", collectionName) return nil, fmt.Errorf("failed to get shards, collection %s not found", collectionName)
} }
@ -704,10 +785,18 @@ func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) m
} }
// DeprecateShardCache clear the shard leader cache of a collection // DeprecateShardCache clear the shard leader cache of a collection
func (m *MetaCache) DeprecateShardCache(collectionName string) { func (m *MetaCache) DeprecateShardCache(database, collectionName string) {
log.Info("clearing shard cache for collection", zap.String("collectionName", collectionName)) log.Info("clearing shard cache for collection", zap.String("collectionName", collectionName))
m.mu.RLock() m.mu.RLock()
info, ok := m.collInfo[collectionName] var info *collectionInfo
var ok bool
db, dbOk := m.collInfo[database]
if !dbOk {
m.mu.RUnlock()
log.Warn("not found database", zap.String("dbName", database))
return
}
info, ok = db[collectionName]
m.mu.RUnlock() m.mu.RUnlock()
if ok { if ok {
info.deprecateLeaderCache() info.deprecateLeaderCache()
@ -727,10 +816,13 @@ func (m *MetaCache) expireShardLeaderCache(ctx context.Context) {
return return
case <-ticker.C: case <-ticker.C:
m.mu.RLock() m.mu.RLock()
log.Info("expire all shard leader cache", for database, db := range m.collInfo {
zap.Strings("collections", lo.Keys(m.collInfo))) log.Info("expire all shard leader cache",
for _, info := range m.collInfo { zap.String("database", database),
info.deprecateLeaderCache() zap.Strings("collections", lo.Keys(db)))
for _, info := range db {
info.deprecateLeaderCache()
}
} }
m.mu.RUnlock() m.mu.RUnlock()
} }
@ -804,3 +896,9 @@ func (m *MetaCache) RefreshPolicyInfo(op typeutil.CacheOp) error {
} }
return nil return nil
} }
func (m *MetaCache) RemoveDatabase(ctx context.Context, database string) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.collInfo, database)
}

View File

@ -25,14 +25,10 @@ import (
"time" "time"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
uatomic "go.uber.org/atomic"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
uatomic "go.uber.org/atomic"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
@ -43,9 +39,13 @@ import (
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/crypto" "github.com/milvus-io/milvus/pkg/util/crypto"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil" "github.com/milvus-io/milvus/pkg/util/typeutil"
) )
var dbName = GetCurDBNameFromContextOrDefault(context.Background())
type MockRootCoordClientInterface struct { type MockRootCoordClientInterface struct {
types.RootCoord types.RootCoord
Error bool Error bool
@ -126,6 +126,7 @@ func (m *MockRootCoordClientInterface) DescribeCollection(ctx context.Context, i
AutoID: true, AutoID: true,
Name: "collection1", Name: "collection1",
}, },
DbName: dbName,
}, nil }, nil
} }
if in.CollectionName == "collection2" || in.CollectionID == 2 { if in.CollectionName == "collection2" || in.CollectionID == 2 {
@ -138,6 +139,7 @@ func (m *MockRootCoordClientInterface) DescribeCollection(ctx context.Context, i
AutoID: true, AutoID: true,
Name: "collection2", Name: "collection2",
}, },
DbName: dbName,
}, nil }, nil
} }
if in.CollectionName == "errorCollection" { if in.CollectionName == "errorCollection" {
@ -149,6 +151,7 @@ func (m *MockRootCoordClientInterface) DescribeCollection(ctx context.Context, i
Schema: &schemapb.CollectionSchema{ Schema: &schemapb.CollectionSchema{
AutoID: true, AutoID: true,
}, },
DbName: dbName,
}, nil }, nil
} }
@ -215,13 +218,13 @@ func TestMetaCache_GetCollection(t *testing.T) {
err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) err := InitMetaCache(ctx, rootCoord, queryCoord, mgr)
assert.NoError(t, err) assert.NoError(t, err)
id, err := globalMetaCache.GetCollectionID(ctx, "collection1") id, err := globalMetaCache.GetCollectionID(ctx, dbName, "collection1")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, id, typeutil.UniqueID(1)) assert.Equal(t, id, typeutil.UniqueID(1))
assert.Equal(t, rootCoord.GetAccessCount(), 1) assert.Equal(t, rootCoord.GetAccessCount(), 1)
// should'nt be accessed to remote root coord. // should'nt be accessed to remote root coord.
schema, err := globalMetaCache.GetCollectionSchema(ctx, "collection1") schema, err := globalMetaCache.GetCollectionSchema(ctx, dbName, "collection1")
assert.Equal(t, rootCoord.GetAccessCount(), 1) assert.Equal(t, rootCoord.GetAccessCount(), 1)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, schema, &schemapb.CollectionSchema{ assert.Equal(t, schema, &schemapb.CollectionSchema{
@ -229,11 +232,11 @@ func TestMetaCache_GetCollection(t *testing.T) {
Fields: []*schemapb.FieldSchema{}, Fields: []*schemapb.FieldSchema{},
Name: "collection1", Name: "collection1",
}) })
id, err = globalMetaCache.GetCollectionID(ctx, "collection2") id, err = globalMetaCache.GetCollectionID(ctx, dbName, "collection2")
assert.Equal(t, rootCoord.GetAccessCount(), 2) assert.Equal(t, rootCoord.GetAccessCount(), 2)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, id, typeutil.UniqueID(2)) assert.Equal(t, id, typeutil.UniqueID(2))
schema, err = globalMetaCache.GetCollectionSchema(ctx, "collection2") schema, err = globalMetaCache.GetCollectionSchema(ctx, dbName, "collection2")
assert.Equal(t, rootCoord.GetAccessCount(), 2) assert.Equal(t, rootCoord.GetAccessCount(), 2)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, schema, &schemapb.CollectionSchema{ assert.Equal(t, schema, &schemapb.CollectionSchema{
@ -243,11 +246,11 @@ func TestMetaCache_GetCollection(t *testing.T) {
}) })
// test to get from cache, this should trigger root request // test to get from cache, this should trigger root request
id, err = globalMetaCache.GetCollectionID(ctx, "collection1") id, err = globalMetaCache.GetCollectionID(ctx, dbName, "collection1")
assert.Equal(t, rootCoord.GetAccessCount(), 2) assert.Equal(t, rootCoord.GetAccessCount(), 2)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, id, typeutil.UniqueID(1)) assert.Equal(t, id, typeutil.UniqueID(1))
schema, err = globalMetaCache.GetCollectionSchema(ctx, "collection1") schema, err = globalMetaCache.GetCollectionSchema(ctx, dbName, "collection1")
assert.Equal(t, rootCoord.GetAccessCount(), 2) assert.Equal(t, rootCoord.GetAccessCount(), 2)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, schema, &schemapb.CollectionSchema{ assert.Equal(t, schema, &schemapb.CollectionSchema{
@ -266,13 +269,14 @@ func TestMetaCache_GetCollectionName(t *testing.T) {
err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) err := InitMetaCache(ctx, rootCoord, queryCoord, mgr)
assert.NoError(t, err) assert.NoError(t, err)
collection, err := globalMetaCache.GetCollectionName(ctx, 1) db, collection, err := globalMetaCache.GetDatabaseAndCollectionName(ctx, 1)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, db, dbName)
assert.Equal(t, collection, "collection1") assert.Equal(t, collection, "collection1")
assert.Equal(t, rootCoord.GetAccessCount(), 1) assert.Equal(t, rootCoord.GetAccessCount(), 1)
// should'nt be accessed to remote root coord. // should'nt be accessed to remote root coord.
schema, err := globalMetaCache.GetCollectionSchema(ctx, "collection1") schema, err := globalMetaCache.GetCollectionSchema(ctx, dbName, "collection1")
assert.Equal(t, rootCoord.GetAccessCount(), 1) assert.Equal(t, rootCoord.GetAccessCount(), 1)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, schema, &schemapb.CollectionSchema{ assert.Equal(t, schema, &schemapb.CollectionSchema{
@ -280,11 +284,11 @@ func TestMetaCache_GetCollectionName(t *testing.T) {
Fields: []*schemapb.FieldSchema{}, Fields: []*schemapb.FieldSchema{},
Name: "collection1", Name: "collection1",
}) })
collection, err = globalMetaCache.GetCollectionName(ctx, 1) _, collection, err = globalMetaCache.GetDatabaseAndCollectionName(ctx, 1)
assert.Equal(t, rootCoord.GetAccessCount(), 1) assert.Equal(t, rootCoord.GetAccessCount(), 1)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, collection, "collection1") assert.Equal(t, collection, "collection1")
schema, err = globalMetaCache.GetCollectionSchema(ctx, "collection2") schema, err = globalMetaCache.GetCollectionSchema(ctx, dbName, "collection2")
assert.Equal(t, rootCoord.GetAccessCount(), 2) assert.Equal(t, rootCoord.GetAccessCount(), 2)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, schema, &schemapb.CollectionSchema{ assert.Equal(t, schema, &schemapb.CollectionSchema{
@ -294,11 +298,11 @@ func TestMetaCache_GetCollectionName(t *testing.T) {
}) })
// test to get from cache, this should trigger root request // test to get from cache, this should trigger root request
collection, err = globalMetaCache.GetCollectionName(ctx, 1) _, collection, err = globalMetaCache.GetDatabaseAndCollectionName(ctx, 1)
assert.Equal(t, rootCoord.GetAccessCount(), 2) assert.Equal(t, rootCoord.GetAccessCount(), 2)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, collection, "collection1") assert.Equal(t, collection, "collection1")
schema, err = globalMetaCache.GetCollectionSchema(ctx, "collection1") schema, err = globalMetaCache.GetCollectionSchema(ctx, dbName, "collection1")
assert.Equal(t, rootCoord.GetAccessCount(), 2) assert.Equal(t, rootCoord.GetAccessCount(), 2)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, schema, &schemapb.CollectionSchema{ assert.Equal(t, schema, &schemapb.CollectionSchema{
@ -317,13 +321,13 @@ func TestMetaCache_GetCollectionFailure(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
rootCoord.Error = true rootCoord.Error = true
schema, err := globalMetaCache.GetCollectionSchema(ctx, "collection1") schema, err := globalMetaCache.GetCollectionSchema(ctx, dbName, "collection1")
assert.Error(t, err) assert.Error(t, err)
assert.Nil(t, schema) assert.Nil(t, schema)
rootCoord.Error = false rootCoord.Error = false
schema, err = globalMetaCache.GetCollectionSchema(ctx, "collection1") schema, err = globalMetaCache.GetCollectionSchema(ctx, dbName, "collection1")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, schema, &schemapb.CollectionSchema{ assert.Equal(t, schema, &schemapb.CollectionSchema{
AutoID: true, AutoID: true,
@ -349,10 +353,10 @@ func TestMetaCache_GetNonExistCollection(t *testing.T) {
err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) err := InitMetaCache(ctx, rootCoord, queryCoord, mgr)
assert.NoError(t, err) assert.NoError(t, err)
id, err := globalMetaCache.GetCollectionID(ctx, "collection3") id, err := globalMetaCache.GetCollectionID(ctx, dbName, "collection3")
assert.Error(t, err) assert.Error(t, err)
assert.Equal(t, id, int64(0)) assert.Equal(t, id, int64(0))
schema, err := globalMetaCache.GetCollectionSchema(ctx, "collection3") schema, err := globalMetaCache.GetCollectionSchema(ctx, dbName, "collection3")
assert.Error(t, err) assert.Error(t, err)
assert.Nil(t, schema) assert.Nil(t, schema)
} }
@ -365,16 +369,16 @@ func TestMetaCache_GetPartitionID(t *testing.T) {
err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) err := InitMetaCache(ctx, rootCoord, queryCoord, mgr)
assert.NoError(t, err) assert.NoError(t, err)
id, err := globalMetaCache.GetPartitionID(ctx, "collection1", "par1") id, err := globalMetaCache.GetPartitionID(ctx, dbName, "collection1", "par1")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, id, typeutil.UniqueID(1)) assert.Equal(t, id, typeutil.UniqueID(1))
id, err = globalMetaCache.GetPartitionID(ctx, "collection1", "par2") id, err = globalMetaCache.GetPartitionID(ctx, dbName, "collection1", "par2")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, id, typeutil.UniqueID(2)) assert.Equal(t, id, typeutil.UniqueID(2))
id, err = globalMetaCache.GetPartitionID(ctx, "collection2", "par1") id, err = globalMetaCache.GetPartitionID(ctx, dbName, "collection2", "par1")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, id, typeutil.UniqueID(3)) assert.Equal(t, id, typeutil.UniqueID(3))
id, err = globalMetaCache.GetPartitionID(ctx, "collection2", "par2") id, err = globalMetaCache.GetPartitionID(ctx, dbName, "collection2", "par2")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, id, typeutil.UniqueID(4)) assert.Equal(t, id, typeutil.UniqueID(4))
} }
@ -393,7 +397,7 @@ func TestMetaCache_ConcurrentTest1(t *testing.T) {
defer wg.Done() defer wg.Done()
for i := 0; i < cnt; i++ { for i := 0; i < cnt; i++ {
//GetCollectionSchema will never fail //GetCollectionSchema will never fail
schema, err := globalMetaCache.GetCollectionSchema(ctx, "collection1") schema, err := globalMetaCache.GetCollectionSchema(ctx, dbName, "collection1")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, schema, &schemapb.CollectionSchema{ assert.Equal(t, schema, &schemapb.CollectionSchema{
AutoID: true, AutoID: true,
@ -408,7 +412,7 @@ func TestMetaCache_ConcurrentTest1(t *testing.T) {
defer wg.Done() defer wg.Done()
for i := 0; i < cnt; i++ { for i := 0; i < cnt; i++ {
//GetPartitions may fail //GetPartitions may fail
globalMetaCache.GetPartitions(ctx, "collection1") globalMetaCache.GetPartitions(ctx, dbName, "collection1")
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
} }
} }
@ -417,7 +421,7 @@ func TestMetaCache_ConcurrentTest1(t *testing.T) {
defer wg.Done() defer wg.Done()
for i := 0; i < cnt; i++ { for i := 0; i < cnt; i++ {
//periodically invalid collection cache //periodically invalid collection cache
globalMetaCache.RemoveCollection(ctx, "collection1") globalMetaCache.RemoveCollection(ctx, dbName, "collection1")
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
} }
} }
@ -442,24 +446,24 @@ func TestMetaCache_GetPartitionError(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
// Test the case where ShowPartitionsResponse is not aligned // Test the case where ShowPartitionsResponse is not aligned
id, err := globalMetaCache.GetPartitionID(ctx, "errorCollection", "par1") id, err := globalMetaCache.GetPartitionID(ctx, dbName, "errorCollection", "par1")
assert.Error(t, err) assert.Error(t, err)
log.Debug(err.Error()) log.Debug(err.Error())
assert.Equal(t, id, typeutil.UniqueID(0)) assert.Equal(t, id, typeutil.UniqueID(0))
partitions, err2 := globalMetaCache.GetPartitions(ctx, "errorCollection") partitions, err2 := globalMetaCache.GetPartitions(ctx, dbName, "errorCollection")
assert.NotNil(t, err2) assert.NotNil(t, err2)
log.Debug(err.Error()) log.Debug(err.Error())
assert.Equal(t, len(partitions), 0) assert.Equal(t, len(partitions), 0)
// Test non existed tables // Test non existed tables
id, err = globalMetaCache.GetPartitionID(ctx, "nonExisted", "par1") id, err = globalMetaCache.GetPartitionID(ctx, dbName, "nonExisted", "par1")
assert.Error(t, err) assert.Error(t, err)
log.Debug(err.Error()) log.Debug(err.Error())
assert.Equal(t, id, typeutil.UniqueID(0)) assert.Equal(t, id, typeutil.UniqueID(0))
// Test non existed partition // Test non existed partition
id, err = globalMetaCache.GetPartitionID(ctx, "collection1", "par3") id, err = globalMetaCache.GetPartitionID(ctx, dbName, "collection1", "par3")
assert.Error(t, err) assert.Error(t, err)
log.Debug(err.Error()) log.Debug(err.Error())
assert.Equal(t, id, typeutil.UniqueID(0)) assert.Equal(t, id, typeutil.UniqueID(0))
@ -483,7 +487,7 @@ func TestMetaCache_GetShards(t *testing.T) {
defer qc.Stop() defer qc.Stop()
t.Run("No collection in meta cache", func(t *testing.T) { t.Run("No collection in meta cache", func(t *testing.T) {
shards, err := globalMetaCache.GetShards(ctx, true, "non-exists") shards, err := globalMetaCache.GetShards(ctx, true, dbName, "non-exists")
assert.Error(t, err) assert.Error(t, err)
assert.Empty(t, shards) assert.Empty(t, shards)
}) })
@ -498,7 +502,7 @@ func TestMetaCache_GetShards(t *testing.T) {
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
}, nil) }, nil)
shards, err := globalMetaCache.GetShards(ctx, false, collectionName) shards, err := globalMetaCache.GetShards(ctx, false, dbName, collectionName)
assert.Error(t, err) assert.Error(t, err)
assert.Empty(t, shards) assert.Empty(t, shards)
}) })
@ -519,7 +523,7 @@ func TestMetaCache_GetShards(t *testing.T) {
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
}, nil) }, nil)
shards, err := globalMetaCache.GetShards(ctx, true, collectionName) shards, err := globalMetaCache.GetShards(ctx, true, dbName, collectionName)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotEmpty(t, shards) assert.NotEmpty(t, shards)
assert.Equal(t, 1, len(shards)) assert.Equal(t, 1, len(shards))
@ -532,7 +536,7 @@ func TestMetaCache_GetShards(t *testing.T) {
Reason: "not implemented", Reason: "not implemented",
}, },
}, nil) }, nil)
shards, err = globalMetaCache.GetShards(ctx, true, collectionName) shards, err = globalMetaCache.GetShards(ctx, true, dbName, collectionName)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotEmpty(t, shards) assert.NotEmpty(t, shards)
@ -559,11 +563,11 @@ func TestMetaCache_ClearShards(t *testing.T) {
defer qc.Stop() defer qc.Stop()
t.Run("Clear with no collection info", func(t *testing.T) { t.Run("Clear with no collection info", func(t *testing.T) {
globalMetaCache.DeprecateShardCache("collection_not_exist") globalMetaCache.DeprecateShardCache(dbName, "collection_not_exist")
}) })
t.Run("Clear valid collection empty cache", func(t *testing.T) { t.Run("Clear valid collection empty cache", func(t *testing.T) {
globalMetaCache.DeprecateShardCache(collectionName) globalMetaCache.DeprecateShardCache(dbName, collectionName)
}) })
t.Run("Clear valid collection valid cache", func(t *testing.T) { t.Run("Clear valid collection valid cache", func(t *testing.T) {
@ -583,13 +587,13 @@ func TestMetaCache_ClearShards(t *testing.T) {
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
}, nil) }, nil)
shards, err := globalMetaCache.GetShards(ctx, true, collectionName) shards, err := globalMetaCache.GetShards(ctx, true, dbName, collectionName)
require.NoError(t, err) require.NoError(t, err)
require.NotEmpty(t, shards) require.NotEmpty(t, shards)
require.Equal(t, 1, len(shards)) require.Equal(t, 1, len(shards))
require.Equal(t, 3, len(shards["channel-1"])) require.Equal(t, 3, len(shards["channel-1"]))
globalMetaCache.DeprecateShardCache(collectionName) globalMetaCache.DeprecateShardCache(dbName, collectionName)
qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{ qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
Status: &commonpb.Status{ Status: &commonpb.Status{
@ -597,7 +601,7 @@ func TestMetaCache_ClearShards(t *testing.T) {
Reason: "not implemented", Reason: "not implemented",
}, },
}, nil) }, nil)
shards, err = globalMetaCache.GetShards(ctx, true, collectionName) shards, err = globalMetaCache.GetShards(ctx, true, dbName, collectionName)
assert.Error(t, err) assert.Error(t, err)
assert.Empty(t, shards) assert.Empty(t, shards)
}) })
@ -701,26 +705,26 @@ func TestMetaCache_RemoveCollection(t *testing.T) {
InMemoryPercentages: []int64{100, 50}, InMemoryPercentages: []int64{100, 50},
}, nil) }, nil)
_, err = globalMetaCache.GetCollectionInfo(ctx, "collection1") _, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1")
assert.NoError(t, err) assert.NoError(t, err)
// no collectionInfo of collection1, should access RootCoord // no collectionInfo of collection1, should access RootCoord
assert.Equal(t, rootCoord.GetAccessCount(), 1) assert.Equal(t, rootCoord.GetAccessCount(), 1)
_, err = globalMetaCache.GetCollectionInfo(ctx, "collection1") _, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1")
assert.NoError(t, err) assert.NoError(t, err)
// shouldn't access RootCoord again // shouldn't access RootCoord again
assert.Equal(t, rootCoord.GetAccessCount(), 1) assert.Equal(t, rootCoord.GetAccessCount(), 1)
globalMetaCache.RemoveCollection(ctx, "collection1") globalMetaCache.RemoveCollection(ctx, dbName, "collection1")
// no collectionInfo of collection2, should access RootCoord // no collectionInfo of collection2, should access RootCoord
_, err = globalMetaCache.GetCollectionInfo(ctx, "collection1") _, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1")
assert.NoError(t, err) assert.NoError(t, err)
// shouldn't access RootCoord again // shouldn't access RootCoord again
assert.Equal(t, rootCoord.GetAccessCount(), 2) assert.Equal(t, rootCoord.GetAccessCount(), 2)
globalMetaCache.RemoveCollectionsByID(ctx, UniqueID(1)) globalMetaCache.RemoveCollectionsByID(ctx, UniqueID(1))
// no collectionInfo of collection2, should access RootCoord // no collectionInfo of collection2, should access RootCoord
_, err = globalMetaCache.GetCollectionInfo(ctx, "collection1") _, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1")
assert.NoError(t, err) assert.NoError(t, err)
// shouldn't access RootCoord again // shouldn't access RootCoord again
assert.Equal(t, rootCoord.GetAccessCount(), 3) assert.Equal(t, rootCoord.GetAccessCount(), 3)
@ -756,7 +760,7 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) {
}, },
}, },
}, nil) }, nil)
nodeInfos, err := globalMetaCache.GetShards(ctx, true, "collection1") nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1")
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, nodeInfos["channel-1"], 3) assert.Len(t, nodeInfos["channel-1"], 3)
@ -775,7 +779,7 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) {
}, nil) }, nil)
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {
nodeInfos, err := globalMetaCache.GetShards(ctx, true, "collection1") nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1")
assert.NoError(t, err) assert.NoError(t, err)
return len(nodeInfos["channel-1"]) == 2 return len(nodeInfos["channel-1"]) == 2
}, 3*time.Second, 1*time.Second) }, 3*time.Second, 1*time.Second)
@ -795,7 +799,7 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) {
}, nil) }, nil)
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {
nodeInfos, err := globalMetaCache.GetShards(ctx, true, "collection1") nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1")
assert.NoError(t, err) assert.NoError(t, err)
return len(nodeInfos["channel-1"]) == 3 return len(nodeInfos["channel-1"]) == 3
}, 3*time.Second, 1*time.Second) }, 3*time.Second, 1*time.Second)
@ -820,7 +824,7 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) {
}, nil) }, nil)
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {
nodeInfos, err := globalMetaCache.GetShards(ctx, true, "collection1") nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1")
assert.NoError(t, err) assert.NoError(t, err)
return len(nodeInfos["channel-1"]) == 3 && len(nodeInfos["channel-2"]) == 3 return len(nodeInfos["channel-1"]) == 3 && len(nodeInfos["channel-2"]) == 3
}, 3*time.Second, 1*time.Second) }, 3*time.Second, 1*time.Second)

View File

@ -1,90 +1,863 @@
// Code generated by mockery v2.16.0. DO NOT EDIT.
package proxy package proxy
import ( import (
"context" "context"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/util/typeutil" "github.com/milvus-io/milvus/pkg/util/typeutil"
) )
type getCollectionIDFunc func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) // MockCache is an autogenerated mock type for the Cache type
type getCollectionNameFunc func(ctx context.Context, collectionID int64) (string, error) type MockCache struct {
type getCollectionSchemaFunc func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) mock.Mock
type getCollectionInfoFunc func(ctx context.Context, collectionName string) (*collectionInfo, error)
type getUserRoleFunc func(username string) []string
type getPartitionIDFunc func(ctx context.Context, collectionName string, partitionName string) (typeutil.UniqueID, error)
type mockCache struct {
Cache
getIDFunc getCollectionIDFunc
getNameFunc getCollectionNameFunc
getSchemaFunc getCollectionSchemaFunc
getInfoFunc getCollectionInfoFunc
getUserRoleFunc getUserRoleFunc
getPartitionIDFunc getPartitionIDFunc
} }
func (m *mockCache) GetCollectionID(ctx context.Context, collectionName string) (typeutil.UniqueID, error) { type MockCache_Expecter struct {
if m.getIDFunc != nil { mock *mock.Mock
return m.getIDFunc(ctx, collectionName) }
func (_m *MockCache) EXPECT() *MockCache_Expecter {
return &MockCache_Expecter{mock: &_m.Mock}
}
// DeprecateShardCache provides a mock function with given fields: database, collectionName
func (_m *MockCache) DeprecateShardCache(database string, collectionName string) {
_m.Called(database, collectionName)
}
// MockCache_DeprecateShardCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeprecateShardCache'
type MockCache_DeprecateShardCache_Call struct {
*mock.Call
}
// DeprecateShardCache is a helper method to define mock.On call
// - database string
// - collectionName string
func (_e *MockCache_Expecter) DeprecateShardCache(database interface{}, collectionName interface{}) *MockCache_DeprecateShardCache_Call {
return &MockCache_DeprecateShardCache_Call{Call: _e.mock.On("DeprecateShardCache", database, collectionName)}
}
func (_c *MockCache_DeprecateShardCache_Call) Run(run func(database string, collectionName string)) *MockCache_DeprecateShardCache_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string), args[1].(string))
})
return _c
}
func (_c *MockCache_DeprecateShardCache_Call) Return() *MockCache_DeprecateShardCache_Call {
_c.Call.Return()
return _c
}
// GetCollectionID provides a mock function with given fields: ctx, database, collectionName
func (_m *MockCache) GetCollectionID(ctx context.Context, database string, collectionName string) (int64, error) {
ret := _m.Called(ctx, database, collectionName)
var r0 int64
if rf, ok := ret.Get(0).(func(context.Context, string, string) int64); ok {
r0 = rf(ctx, database, collectionName)
} else {
r0 = ret.Get(0).(int64)
} }
return 0, nil
}
func (m *mockCache) GetCollectionName(ctx context.Context, collectionID int64) (string, error) { var r1 error
if m.getIDFunc != nil { if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
return m.getNameFunc(ctx, collectionID) r1 = rf(ctx, database, collectionName)
} else {
r1 = ret.Error(1)
} }
return "", nil
return r0, r1
} }
func (m *mockCache) GetCollectionSchema(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) { // MockCache_GetCollectionID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollectionID'
if m.getSchemaFunc != nil { type MockCache_GetCollectionID_Call struct {
return m.getSchemaFunc(ctx, collectionName) *mock.Call
}
// GetCollectionID is a helper method to define mock.On call
// - ctx context.Context
// - database string
// - collectionName string
func (_e *MockCache_Expecter) GetCollectionID(ctx interface{}, database interface{}, collectionName interface{}) *MockCache_GetCollectionID_Call {
return &MockCache_GetCollectionID_Call{Call: _e.mock.On("GetCollectionID", ctx, database, collectionName)}
}
func (_c *MockCache_GetCollectionID_Call) Run(run func(ctx context.Context, database string, collectionName string)) *MockCache_GetCollectionID_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(string))
})
return _c
}
func (_c *MockCache_GetCollectionID_Call) Return(_a0 int64, _a1 error) *MockCache_GetCollectionID_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// GetCollectionInfo provides a mock function with given fields: ctx, database, collectionName
func (_m *MockCache) GetCollectionInfo(ctx context.Context, database string, collectionName string) (*collectionInfo, error) {
ret := _m.Called(ctx, database, collectionName)
var r0 *collectionInfo
if rf, ok := ret.Get(0).(func(context.Context, string, string) *collectionInfo); ok {
r0 = rf(ctx, database, collectionName)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*collectionInfo)
}
} }
return nil, nil
}
func (m *mockCache) GetCollectionInfo(ctx context.Context, collectionName string) (*collectionInfo, error) { var r1 error
if m.getInfoFunc != nil { if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
return m.getInfoFunc(ctx, collectionName) r1 = rf(ctx, database, collectionName)
} else {
r1 = ret.Error(1)
} }
return nil, nil
return r0, r1
} }
func (m *mockCache) RemoveCollection(ctx context.Context, collectionName string) { // MockCache_GetCollectionInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollectionInfo'
type MockCache_GetCollectionInfo_Call struct {
*mock.Call
} }
func (m *mockCache) GetPartitionID(ctx context.Context, collectionName string, partitionName string) (typeutil.UniqueID, error) { // GetCollectionInfo is a helper method to define mock.On call
if m.getPartitionIDFunc != nil { // - ctx context.Context
return m.getPartitionIDFunc(ctx, collectionName, partitionName) // - database string
// - collectionName string
func (_e *MockCache_Expecter) GetCollectionInfo(ctx interface{}, database interface{}, collectionName interface{}) *MockCache_GetCollectionInfo_Call {
return &MockCache_GetCollectionInfo_Call{Call: _e.mock.On("GetCollectionInfo", ctx, database, collectionName)}
}
func (_c *MockCache_GetCollectionInfo_Call) Run(run func(ctx context.Context, database string, collectionName string)) *MockCache_GetCollectionInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(string))
})
return _c
}
func (_c *MockCache_GetCollectionInfo_Call) Return(_a0 *collectionInfo, _a1 error) *MockCache_GetCollectionInfo_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// GetCollectionSchema provides a mock function with given fields: ctx, database, collectionName
func (_m *MockCache) GetCollectionSchema(ctx context.Context, database string, collectionName string) (*schemapb.CollectionSchema, error) {
ret := _m.Called(ctx, database, collectionName)
var r0 *schemapb.CollectionSchema
if rf, ok := ret.Get(0).(func(context.Context, string, string) *schemapb.CollectionSchema); ok {
r0 = rf(ctx, database, collectionName)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*schemapb.CollectionSchema)
}
} }
return 0, nil
}
func (m *mockCache) GetUserRole(username string) []string { var r1 error
if m.getUserRoleFunc != nil { if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
return m.getUserRoleFunc(username) r1 = rf(ctx, database, collectionName)
} else {
r1 = ret.Error(1)
} }
return []string{}
return r0, r1
} }
func (m *mockCache) setGetIDFunc(f getCollectionIDFunc) { // MockCache_GetCollectionSchema_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollectionSchema'
m.getIDFunc = f type MockCache_GetCollectionSchema_Call struct {
*mock.Call
} }
func (m *mockCache) setGetSchemaFunc(f getCollectionSchemaFunc) { // GetCollectionSchema is a helper method to define mock.On call
m.getSchemaFunc = f // - ctx context.Context
// - database string
// - collectionName string
func (_e *MockCache_Expecter) GetCollectionSchema(ctx interface{}, database interface{}, collectionName interface{}) *MockCache_GetCollectionSchema_Call {
return &MockCache_GetCollectionSchema_Call{Call: _e.mock.On("GetCollectionSchema", ctx, database, collectionName)}
} }
func (m *mockCache) setGetInfoFunc(f getCollectionInfoFunc) { func (_c *MockCache_GetCollectionSchema_Call) Run(run func(ctx context.Context, database string, collectionName string)) *MockCache_GetCollectionSchema_Call {
m.getInfoFunc = f _c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(string))
})
return _c
} }
func (m *mockCache) setGetPartitionIDFunc(f getPartitionIDFunc) { func (_c *MockCache_GetCollectionSchema_Call) Return(_a0 *schemapb.CollectionSchema, _a1 error) *MockCache_GetCollectionSchema_Call {
m.getPartitionIDFunc = f _c.Call.Return(_a0, _a1)
return _c
} }
func newMockCache() *mockCache { // GetCredentialInfo provides a mock function with given fields: ctx, username
return &mockCache{} func (_m *MockCache) GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error) {
ret := _m.Called(ctx, username)
var r0 *internalpb.CredentialInfo
if rf, ok := ret.Get(0).(func(context.Context, string) *internalpb.CredentialInfo); ok {
r0 = rf(ctx, username)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*internalpb.CredentialInfo)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, username)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockCache_GetCredentialInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCredentialInfo'
type MockCache_GetCredentialInfo_Call struct {
*mock.Call
}
// GetCredentialInfo is a helper method to define mock.On call
// - ctx context.Context
// - username string
func (_e *MockCache_Expecter) GetCredentialInfo(ctx interface{}, username interface{}) *MockCache_GetCredentialInfo_Call {
return &MockCache_GetCredentialInfo_Call{Call: _e.mock.On("GetCredentialInfo", ctx, username)}
}
func (_c *MockCache_GetCredentialInfo_Call) Run(run func(ctx context.Context, username string)) *MockCache_GetCredentialInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string))
})
return _c
}
func (_c *MockCache_GetCredentialInfo_Call) Return(_a0 *internalpb.CredentialInfo, _a1 error) *MockCache_GetCredentialInfo_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// GetDatabaseAndCollectionName provides a mock function with given fields: ctx, collectionID
func (_m *MockCache) GetDatabaseAndCollectionName(ctx context.Context, collectionID int64) (string, string, error) {
ret := _m.Called(ctx, collectionID)
var r0 string
if rf, ok := ret.Get(0).(func(context.Context, int64) string); ok {
r0 = rf(ctx, collectionID)
} else {
r0 = ret.Get(0).(string)
}
var r1 string
if rf, ok := ret.Get(1).(func(context.Context, int64) string); ok {
r1 = rf(ctx, collectionID)
} else {
r1 = ret.Get(1).(string)
}
var r2 error
if rf, ok := ret.Get(2).(func(context.Context, int64) error); ok {
r2 = rf(ctx, collectionID)
} else {
r2 = ret.Error(2)
}
return r0, r1, r2
}
// MockCache_GetDatabaseAndCollectionName_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDatabaseAndCollectionName'
type MockCache_GetDatabaseAndCollectionName_Call struct {
*mock.Call
}
// GetDatabaseAndCollectionName is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
func (_e *MockCache_Expecter) GetDatabaseAndCollectionName(ctx interface{}, collectionID interface{}) *MockCache_GetDatabaseAndCollectionName_Call {
return &MockCache_GetDatabaseAndCollectionName_Call{Call: _e.mock.On("GetDatabaseAndCollectionName", ctx, collectionID)}
}
func (_c *MockCache_GetDatabaseAndCollectionName_Call) Run(run func(ctx context.Context, collectionID int64)) *MockCache_GetDatabaseAndCollectionName_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64))
})
return _c
}
func (_c *MockCache_GetDatabaseAndCollectionName_Call) Return(_a0 string, _a1 string, _a2 error) *MockCache_GetDatabaseAndCollectionName_Call {
_c.Call.Return(_a0, _a1, _a2)
return _c
}
// GetPartitionID provides a mock function with given fields: ctx, database, collectionName, partitionName
func (_m *MockCache) GetPartitionID(ctx context.Context, database string, collectionName string, partitionName string) (int64, error) {
ret := _m.Called(ctx, database, collectionName, partitionName)
var r0 int64
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) int64); ok {
r0 = rf(ctx, database, collectionName, partitionName)
} else {
r0 = ret.Get(0).(int64)
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok {
r1 = rf(ctx, database, collectionName, partitionName)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockCache_GetPartitionID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetPartitionID'
type MockCache_GetPartitionID_Call struct {
*mock.Call
}
// GetPartitionID is a helper method to define mock.On call
// - ctx context.Context
// - database string
// - collectionName string
// - partitionName string
func (_e *MockCache_Expecter) GetPartitionID(ctx interface{}, database interface{}, collectionName interface{}, partitionName interface{}) *MockCache_GetPartitionID_Call {
return &MockCache_GetPartitionID_Call{Call: _e.mock.On("GetPartitionID", ctx, database, collectionName, partitionName)}
}
func (_c *MockCache_GetPartitionID_Call) Run(run func(ctx context.Context, database string, collectionName string, partitionName string)) *MockCache_GetPartitionID_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string))
})
return _c
}
func (_c *MockCache_GetPartitionID_Call) Return(_a0 int64, _a1 error) *MockCache_GetPartitionID_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// GetPartitionInfo provides a mock function with given fields: ctx, database, collectionName, partitionName
func (_m *MockCache) GetPartitionInfo(ctx context.Context, database string, collectionName string, partitionName string) (*partitionInfo, error) {
ret := _m.Called(ctx, database, collectionName, partitionName)
var r0 *partitionInfo
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) *partitionInfo); ok {
r0 = rf(ctx, database, collectionName, partitionName)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*partitionInfo)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok {
r1 = rf(ctx, database, collectionName, partitionName)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockCache_GetPartitionInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetPartitionInfo'
type MockCache_GetPartitionInfo_Call struct {
*mock.Call
}
// GetPartitionInfo is a helper method to define mock.On call
// - ctx context.Context
// - database string
// - collectionName string
// - partitionName string
func (_e *MockCache_Expecter) GetPartitionInfo(ctx interface{}, database interface{}, collectionName interface{}, partitionName interface{}) *MockCache_GetPartitionInfo_Call {
return &MockCache_GetPartitionInfo_Call{Call: _e.mock.On("GetPartitionInfo", ctx, database, collectionName, partitionName)}
}
func (_c *MockCache_GetPartitionInfo_Call) Run(run func(ctx context.Context, database string, collectionName string, partitionName string)) *MockCache_GetPartitionInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string))
})
return _c
}
func (_c *MockCache_GetPartitionInfo_Call) Return(_a0 *partitionInfo, _a1 error) *MockCache_GetPartitionInfo_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// GetPartitions provides a mock function with given fields: ctx, database, collectionName
func (_m *MockCache) GetPartitions(ctx context.Context, database string, collectionName string) (map[string]int64, error) {
ret := _m.Called(ctx, database, collectionName)
var r0 map[string]int64
if rf, ok := ret.Get(0).(func(context.Context, string, string) map[string]int64); ok {
r0 = rf(ctx, database, collectionName)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[string]int64)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
r1 = rf(ctx, database, collectionName)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockCache_GetPartitions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetPartitions'
type MockCache_GetPartitions_Call struct {
*mock.Call
}
// GetPartitions is a helper method to define mock.On call
// - ctx context.Context
// - database string
// - collectionName string
func (_e *MockCache_Expecter) GetPartitions(ctx interface{}, database interface{}, collectionName interface{}) *MockCache_GetPartitions_Call {
return &MockCache_GetPartitions_Call{Call: _e.mock.On("GetPartitions", ctx, database, collectionName)}
}
func (_c *MockCache_GetPartitions_Call) Run(run func(ctx context.Context, database string, collectionName string)) *MockCache_GetPartitions_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(string))
})
return _c
}
func (_c *MockCache_GetPartitions_Call) Return(_a0 map[string]int64, _a1 error) *MockCache_GetPartitions_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// GetPrivilegeInfo provides a mock function with given fields: ctx
func (_m *MockCache) GetPrivilegeInfo(ctx context.Context) []string {
ret := _m.Called(ctx)
var r0 []string
if rf, ok := ret.Get(0).(func(context.Context) []string); ok {
r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]string)
}
}
return r0
}
// MockCache_GetPrivilegeInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetPrivilegeInfo'
type MockCache_GetPrivilegeInfo_Call struct {
*mock.Call
}
// GetPrivilegeInfo is a helper method to define mock.On call
// - ctx context.Context
func (_e *MockCache_Expecter) GetPrivilegeInfo(ctx interface{}) *MockCache_GetPrivilegeInfo_Call {
return &MockCache_GetPrivilegeInfo_Call{Call: _e.mock.On("GetPrivilegeInfo", ctx)}
}
func (_c *MockCache_GetPrivilegeInfo_Call) Run(run func(ctx context.Context)) *MockCache_GetPrivilegeInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context))
})
return _c
}
func (_c *MockCache_GetPrivilegeInfo_Call) Return(_a0 []string) *MockCache_GetPrivilegeInfo_Call {
_c.Call.Return(_a0)
return _c
}
// GetShards provides a mock function with given fields: ctx, withCache, database, collectionName
func (_m *MockCache) GetShards(ctx context.Context, withCache bool, database string, collectionName string) (map[string][]nodeInfo, error) {
ret := _m.Called(ctx, withCache, database, collectionName)
var r0 map[string][]nodeInfo
if rf, ok := ret.Get(0).(func(context.Context, bool, string, string) map[string][]nodeInfo); ok {
r0 = rf(ctx, withCache, database, collectionName)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[string][]nodeInfo)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, bool, string, string) error); ok {
r1 = rf(ctx, withCache, database, collectionName)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockCache_GetShards_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetShards'
type MockCache_GetShards_Call struct {
*mock.Call
}
// GetShards is a helper method to define mock.On call
// - ctx context.Context
// - withCache bool
// - database string
// - collectionName string
func (_e *MockCache_Expecter) GetShards(ctx interface{}, withCache interface{}, database interface{}, collectionName interface{}) *MockCache_GetShards_Call {
return &MockCache_GetShards_Call{Call: _e.mock.On("GetShards", ctx, withCache, database, collectionName)}
}
func (_c *MockCache_GetShards_Call) Run(run func(ctx context.Context, withCache bool, database string, collectionName string)) *MockCache_GetShards_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(bool), args[2].(string), args[3].(string))
})
return _c
}
func (_c *MockCache_GetShards_Call) Return(_a0 map[string][]nodeInfo, _a1 error) *MockCache_GetShards_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// GetUserRole provides a mock function with given fields: username
func (_m *MockCache) GetUserRole(username string) []string {
ret := _m.Called(username)
var r0 []string
if rf, ok := ret.Get(0).(func(string) []string); ok {
r0 = rf(username)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]string)
}
}
return r0
}
// MockCache_GetUserRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetUserRole'
type MockCache_GetUserRole_Call struct {
*mock.Call
}
// GetUserRole is a helper method to define mock.On call
// - username string
func (_e *MockCache_Expecter) GetUserRole(username interface{}) *MockCache_GetUserRole_Call {
return &MockCache_GetUserRole_Call{Call: _e.mock.On("GetUserRole", username)}
}
func (_c *MockCache_GetUserRole_Call) Run(run func(username string)) *MockCache_GetUserRole_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string))
})
return _c
}
func (_c *MockCache_GetUserRole_Call) Return(_a0 []string) *MockCache_GetUserRole_Call {
_c.Call.Return(_a0)
return _c
}
// InitPolicyInfo provides a mock function with given fields: info, userRoles
func (_m *MockCache) InitPolicyInfo(info []string, userRoles []string) {
_m.Called(info, userRoles)
}
// MockCache_InitPolicyInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InitPolicyInfo'
type MockCache_InitPolicyInfo_Call struct {
*mock.Call
}
// InitPolicyInfo is a helper method to define mock.On call
// - info []string
// - userRoles []string
func (_e *MockCache_Expecter) InitPolicyInfo(info interface{}, userRoles interface{}) *MockCache_InitPolicyInfo_Call {
return &MockCache_InitPolicyInfo_Call{Call: _e.mock.On("InitPolicyInfo", info, userRoles)}
}
func (_c *MockCache_InitPolicyInfo_Call) Run(run func(info []string, userRoles []string)) *MockCache_InitPolicyInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]string), args[1].([]string))
})
return _c
}
func (_c *MockCache_InitPolicyInfo_Call) Return() *MockCache_InitPolicyInfo_Call {
_c.Call.Return()
return _c
}
// RefreshPolicyInfo provides a mock function with given fields: op
func (_m *MockCache) RefreshPolicyInfo(op typeutil.CacheOp) error {
ret := _m.Called(op)
var r0 error
if rf, ok := ret.Get(0).(func(typeutil.CacheOp) error); ok {
r0 = rf(op)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockCache_RefreshPolicyInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RefreshPolicyInfo'
type MockCache_RefreshPolicyInfo_Call struct {
*mock.Call
}
// RefreshPolicyInfo is a helper method to define mock.On call
// - op typeutil.CacheOp
func (_e *MockCache_Expecter) RefreshPolicyInfo(op interface{}) *MockCache_RefreshPolicyInfo_Call {
return &MockCache_RefreshPolicyInfo_Call{Call: _e.mock.On("RefreshPolicyInfo", op)}
}
func (_c *MockCache_RefreshPolicyInfo_Call) Run(run func(op typeutil.CacheOp)) *MockCache_RefreshPolicyInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(typeutil.CacheOp))
})
return _c
}
func (_c *MockCache_RefreshPolicyInfo_Call) Return(_a0 error) *MockCache_RefreshPolicyInfo_Call {
_c.Call.Return(_a0)
return _c
}
// RemoveCollection provides a mock function with given fields: ctx, database, collectionName
func (_m *MockCache) RemoveCollection(ctx context.Context, database string, collectionName string) {
_m.Called(ctx, database, collectionName)
}
// MockCache_RemoveCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveCollection'
type MockCache_RemoveCollection_Call struct {
*mock.Call
}
// RemoveCollection is a helper method to define mock.On call
// - ctx context.Context
// - database string
// - collectionName string
func (_e *MockCache_Expecter) RemoveCollection(ctx interface{}, database interface{}, collectionName interface{}) *MockCache_RemoveCollection_Call {
return &MockCache_RemoveCollection_Call{Call: _e.mock.On("RemoveCollection", ctx, database, collectionName)}
}
func (_c *MockCache_RemoveCollection_Call) Run(run func(ctx context.Context, database string, collectionName string)) *MockCache_RemoveCollection_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(string))
})
return _c
}
func (_c *MockCache_RemoveCollection_Call) Return() *MockCache_RemoveCollection_Call {
_c.Call.Return()
return _c
}
// RemoveCollectionsByID provides a mock function with given fields: ctx, collectionID
func (_m *MockCache) RemoveCollectionsByID(ctx context.Context, collectionID int64) []string {
ret := _m.Called(ctx, collectionID)
var r0 []string
if rf, ok := ret.Get(0).(func(context.Context, int64) []string); ok {
r0 = rf(ctx, collectionID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]string)
}
}
return r0
}
// MockCache_RemoveCollectionsByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveCollectionsByID'
type MockCache_RemoveCollectionsByID_Call struct {
*mock.Call
}
// RemoveCollectionsByID is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
func (_e *MockCache_Expecter) RemoveCollectionsByID(ctx interface{}, collectionID interface{}) *MockCache_RemoveCollectionsByID_Call {
return &MockCache_RemoveCollectionsByID_Call{Call: _e.mock.On("RemoveCollectionsByID", ctx, collectionID)}
}
func (_c *MockCache_RemoveCollectionsByID_Call) Run(run func(ctx context.Context, collectionID int64)) *MockCache_RemoveCollectionsByID_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64))
})
return _c
}
func (_c *MockCache_RemoveCollectionsByID_Call) Return(_a0 []string) *MockCache_RemoveCollectionsByID_Call {
_c.Call.Return(_a0)
return _c
}
// RemoveCredential provides a mock function with given fields: username
func (_m *MockCache) RemoveCredential(username string) {
_m.Called(username)
}
// MockCache_RemoveCredential_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveCredential'
type MockCache_RemoveCredential_Call struct {
*mock.Call
}
// RemoveCredential is a helper method to define mock.On call
// - username string
func (_e *MockCache_Expecter) RemoveCredential(username interface{}) *MockCache_RemoveCredential_Call {
return &MockCache_RemoveCredential_Call{Call: _e.mock.On("RemoveCredential", username)}
}
func (_c *MockCache_RemoveCredential_Call) Run(run func(username string)) *MockCache_RemoveCredential_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string))
})
return _c
}
func (_c *MockCache_RemoveCredential_Call) Return() *MockCache_RemoveCredential_Call {
_c.Call.Return()
return _c
}
// RemoveDatabase provides a mock function with given fields: ctx, database
func (_m *MockCache) RemoveDatabase(ctx context.Context, database string) {
_m.Called(ctx, database)
}
// MockCache_RemoveDatabase_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveDatabase'
type MockCache_RemoveDatabase_Call struct {
*mock.Call
}
// RemoveDatabase is a helper method to define mock.On call
// - ctx context.Context
// - database string
func (_e *MockCache_Expecter) RemoveDatabase(ctx interface{}, database interface{}) *MockCache_RemoveDatabase_Call {
return &MockCache_RemoveDatabase_Call{Call: _e.mock.On("RemoveDatabase", ctx, database)}
}
func (_c *MockCache_RemoveDatabase_Call) Run(run func(ctx context.Context, database string)) *MockCache_RemoveDatabase_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string))
})
return _c
}
func (_c *MockCache_RemoveDatabase_Call) Return() *MockCache_RemoveDatabase_Call {
_c.Call.Return()
return _c
}
// RemovePartition provides a mock function with given fields: ctx, database, collectionName, partitionName
func (_m *MockCache) RemovePartition(ctx context.Context, database string, collectionName string, partitionName string) {
_m.Called(ctx, database, collectionName, partitionName)
}
// MockCache_RemovePartition_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemovePartition'
type MockCache_RemovePartition_Call struct {
*mock.Call
}
// RemovePartition is a helper method to define mock.On call
// - ctx context.Context
// - database string
// - collectionName string
// - partitionName string
func (_e *MockCache_Expecter) RemovePartition(ctx interface{}, database interface{}, collectionName interface{}, partitionName interface{}) *MockCache_RemovePartition_Call {
return &MockCache_RemovePartition_Call{Call: _e.mock.On("RemovePartition", ctx, database, collectionName, partitionName)}
}
func (_c *MockCache_RemovePartition_Call) Run(run func(ctx context.Context, database string, collectionName string, partitionName string)) *MockCache_RemovePartition_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string))
})
return _c
}
func (_c *MockCache_RemovePartition_Call) Return() *MockCache_RemovePartition_Call {
_c.Call.Return()
return _c
}
// UpdateCredential provides a mock function with given fields: credInfo
func (_m *MockCache) UpdateCredential(credInfo *internalpb.CredentialInfo) {
_m.Called(credInfo)
}
// MockCache_UpdateCredential_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCredential'
type MockCache_UpdateCredential_Call struct {
*mock.Call
}
// UpdateCredential is a helper method to define mock.On call
// - credInfo *internalpb.CredentialInfo
func (_e *MockCache_Expecter) UpdateCredential(credInfo interface{}) *MockCache_UpdateCredential_Call {
return &MockCache_UpdateCredential_Call{Call: _e.mock.On("UpdateCredential", credInfo)}
}
func (_c *MockCache_UpdateCredential_Call) Run(run func(credInfo *internalpb.CredentialInfo)) *MockCache_UpdateCredential_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(*internalpb.CredentialInfo))
})
return _c
}
func (_c *MockCache_UpdateCredential_Call) Return() *MockCache_UpdateCredential_Call {
_c.Call.Return()
return _c
}
// expireShardLeaderCache provides a mock function with given fields: ctx
func (_m *MockCache) expireShardLeaderCache(ctx context.Context) {
_m.Called(ctx)
}
// MockCache_expireShardLeaderCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'expireShardLeaderCache'
type MockCache_expireShardLeaderCache_Call struct {
*mock.Call
}
// expireShardLeaderCache is a helper method to define mock.On call
//   - ctx context.Context
func (_e *MockCache_Expecter) expireShardLeaderCache(ctx interface{}) *MockCache_expireShardLeaderCache_Call {
	// Register the expectation, then wrap the raw *mock.Call in the typed
	// call object so Run/Return carry explicit signatures.
	call := _e.mock.On("expireShardLeaderCache", ctx)
	return &MockCache_expireShardLeaderCache_Call{Call: call}
}
// Run installs a callback invoked with the typed context each time
// expireShardLeaderCache is called; returns _c for chaining.
func (_c *MockCache_expireShardLeaderCache_Call) Run(run func(ctx context.Context)) *MockCache_expireShardLeaderCache_Call {
	_c.Call.Run(func(args mock.Arguments) {
		// Single argument: the context; assertion panics on a type mismatch,
		// matching generated-mock behavior.
		ctx := args[0].(context.Context)
		run(ctx)
	})
	return _c
}
// Return finalizes the expectation; expireShardLeaderCache has no return
// values, so this only records completion and returns _c for chaining.
func (_c *MockCache_expireShardLeaderCache_Call) Return() *MockCache_expireShardLeaderCache_Call {
	_c.Call.Return()
	return _c
}
// mockConstructorTestingTNewMockCache is the minimal testing interface
// NewMockCache needs: failure reporting (mock.TestingT) plus Cleanup
// registration for the end-of-test expectations check.
type mockConstructorTestingTNewMockCache interface {
	mock.TestingT
	Cleanup(func())
}
// NewMockCache creates a new instance of MockCache. It also registers a
// testing interface on the mock and a cleanup function to assert the mocks
// expectations when the test finishes.
func NewMockCache(t mockConstructorTestingTNewMockCache) *MockCache {
	// Local named m (not "mock") to avoid shadowing the testify mock package.
	m := &MockCache{}
	m.Mock.Test(t)
	t.Cleanup(func() {
		m.AssertExpectations(t)
	})
	return m
} }

View File

@ -116,7 +116,7 @@ func repackInsertDataByPartition(ctx context.Context,
} }
} }
partitionID, err := globalMetaCache.GetPartitionID(ctx, insertMsg.CollectionName, partitionName) partitionID, err := globalMetaCache.GetPartitionID(ctx, insertMsg.GetDbName(), insertMsg.CollectionName, partitionName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -221,7 +221,7 @@ func repackInsertDataWithPartitionKey(ctx context.Context,
} }
channel2RowOffsets := assignChannelsByPK(result.IDs, channelNames, insertMsg) channel2RowOffsets := assignChannelsByPK(result.IDs, channelNames, insertMsg)
partitionNames, err := getDefaultPartitionNames(ctx, insertMsg.CollectionName) partitionNames, err := getDefaultPartitionNames(ctx, insertMsg.GetDbName(), insertMsg.CollectionName)
if err != nil { if err != nil {
log.Warn("get default partition names failed in partition key mode", log.Warn("get default partition names failed in partition key mode",
zap.String("collection name", insertMsg.CollectionName), zap.String("collection name", insertMsg.CollectionName),

View File

@ -22,6 +22,7 @@ import (
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
@ -47,8 +48,14 @@ func TestRepackInsertData(t *testing.T) {
rc.Start() rc.Start()
defer rc.Stop() defer rc.Stop()
err := InitMetaCache(ctx, rc, nil, nil) cache := NewMockCache(t)
assert.NoError(t, err) cache.On("GetPartitionID",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(int64(1), nil)
globalMetaCache = cache
idAllocator, err := allocator.NewIDAllocator(ctx, rc, paramtable.GetNodeID()) idAllocator, err := allocator.NewIDAllocator(ctx, rc, paramtable.GetNodeID())
assert.NoError(t, err) assert.NoError(t, err)
@ -139,10 +146,10 @@ func TestRepackInsertDataWithPartitionKey(t *testing.T) {
nb := 10 nb := 10
hash := generateHashKeys(nb) hash := generateHashKeys(nb)
prefix := "TestRepackInsertData" prefix := "TestRepackInsertData"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr() collectionName := prefix + funcutil.GenRandomStr()
ctx := context.Background() ctx := context.Background()
dbName := GetCurDBNameFromContextOrDefault(ctx)
rc := NewRootCoordMock() rc := NewRootCoordMock()
rc.Start() rc.Start()

View File

@ -9,14 +9,15 @@ import (
"github.com/casbin/casbin/v2" "github.com/casbin/casbin/v2"
"github.com/casbin/casbin/v2/model" "github.com/casbin/casbin/v2/model"
jsonadapter "github.com/casbin/json-adapter/v2" jsonadapter "github.com/casbin/json-adapter/v2"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"go.uber.org/zap" "go.uber.org/zap"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util"
"github.com/milvus-io/milvus/pkg/util/funcutil"
) )
type PrivilegeFunc func(ctx context.Context, req interface{}) (context.Context, error) type PrivilegeFunc func(ctx context.Context, req interface{}) (context.Context, error)
@ -36,7 +37,7 @@ p = sub, obj, act
e = some(where (p.eft == allow)) e = some(where (p.eft == allow))
[matchers] [matchers]
m = r.sub == p.sub && globMatch(r.obj, p.obj) && globMatch(r.act, p.act) || r.sub == "admin" || (r.sub == p.sub && p.act == "PrivilegeAll") m = r.sub == p.sub && globMatch(r.obj, p.obj) && globMatch(r.act, p.act) || r.sub == "admin" || (r.sub == p.sub && dbMatch(r.obj, p.obj) && p.act == "PrivilegeAll")
` `
) )
@ -94,10 +95,12 @@ func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context
objectNameIndexs := privilegeExt.ObjectNameIndexs objectNameIndexs := privilegeExt.ObjectNameIndexs
objectNames := funcutil.GetObjectNames(req, objectNameIndexs) objectNames := funcutil.GetObjectNames(req, objectNameIndexs)
objectPrivilege := privilegeExt.ObjectPrivilege.String() objectPrivilege := privilegeExt.ObjectPrivilege.String()
dbName := GetCurDBNameFromContextOrDefault(ctx)
policyInfo := strings.Join(globalMetaCache.GetPrivilegeInfo(ctx), ",") policyInfo := strings.Join(globalMetaCache.GetPrivilegeInfo(ctx), ",")
log := log.With(zap.String("username", username), zap.Strings("role_names", roleNames), log := log.With(zap.String("username", username), zap.Strings("role_names", roleNames),
zap.String("object_type", objectType), zap.String("object_privilege", objectPrivilege), zap.String("object_type", objectType), zap.String("object_privilege", objectPrivilege),
zap.String("db_name", dbName),
zap.Int32("object_index", objectNameIndex), zap.String("object_name", objectName), zap.Int32("object_index", objectNameIndex), zap.String("object_name", objectName),
zap.Int32("object_indexs", objectNameIndexs), zap.Strings("object_names", objectNames), zap.Int32("object_indexs", objectNameIndexs), zap.Strings("object_names", objectNames),
zap.String("policy_info", policyInfo)) zap.String("policy_info", policyInfo))
@ -112,9 +115,10 @@ func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context
log.Warn("NewEnforcer fail", zap.String("policy", policy), zap.Error(err)) log.Warn("NewEnforcer fail", zap.String("policy", policy), zap.Error(err))
return ctx, err return ctx, err
} }
e.AddFunction("dbMatch", DBMatchFunc)
for _, roleName := range roleNames { for _, roleName := range roleNames {
permitFunc := func(resName string) (bool, error) { permitFunc := func(resName string) (bool, error) {
object := funcutil.PolicyForResource(objectType, resName) object := funcutil.PolicyForResource(dbName, objectType, resName)
isPermit, err := e.Enforce(roleName, object, objectPrivilege) isPermit, err := e.Enforce(roleName, object, objectPrivilege)
if err != nil { if err != nil {
return false, err return false, err
@ -167,3 +171,13 @@ func isCurUserObject(objectType string, curUser string, object string) bool {
} }
return curUser == object return curUser == object
} }
func DBMatchFunc(args ...interface{}) (interface{}, error) {
name1 := args[0].(string)
name2 := args[1].(string)
db1, _ := funcutil.SplitObjectName(name1[strings.Index(name1, "-")+1:])
db2, _ := funcutil.SplitObjectName(name2[strings.Index(name2, "-")+1:])
return db1 == db2, nil
}

View File

@ -54,11 +54,11 @@ func TestPrivilegeInterceptor(t *testing.T) {
ErrorCode: commonpb.ErrorCode_Success, ErrorCode: commonpb.ErrorCode_Success,
}, },
PolicyInfos: []string{ PolicyInfos: []string{
funcutil.PolicyForPrivilege("role1", commonpb.ObjectType_Collection.String(), "col1", commonpb.ObjectPrivilege_PrivilegeLoad.String()), funcutil.PolicyForPrivilege("role1", commonpb.ObjectType_Collection.String(), "col1", commonpb.ObjectPrivilege_PrivilegeLoad.String(), "default"),
funcutil.PolicyForPrivilege("role1", commonpb.ObjectType_Collection.String(), "col1", commonpb.ObjectPrivilege_PrivilegeFlush.String()), funcutil.PolicyForPrivilege("role1", commonpb.ObjectType_Collection.String(), "col1", commonpb.ObjectPrivilege_PrivilegeGetLoadState.String(), "default"),
funcutil.PolicyForPrivilege("role1", commonpb.ObjectType_Collection.String(), "col1", commonpb.ObjectPrivilege_PrivilegeGetLoadState.String()), funcutil.PolicyForPrivilege("role1", commonpb.ObjectType_Collection.String(), "col1", commonpb.ObjectPrivilege_PrivilegeGetLoadingProgress.String(), "default"),
funcutil.PolicyForPrivilege("role1", commonpb.ObjectType_Collection.String(), "col1", commonpb.ObjectPrivilege_PrivilegeGetLoadingProgress.String()), funcutil.PolicyForPrivilege("role1", commonpb.ObjectType_Collection.String(), "col1", commonpb.ObjectPrivilege_PrivilegeFlush.String(), "default"),
funcutil.PolicyForPrivilege("role2", commonpb.ObjectType_Global.String(), "*", commonpb.ObjectPrivilege_PrivilegeAll.String()), funcutil.PolicyForPrivilege("role2", commonpb.ObjectType_Global.String(), "*", commonpb.ObjectPrivilege_PrivilegeAll.String(), "default"),
}, },
UserRoles: []string{ UserRoles: []string{
funcutil.EncodeUserRoleCache("alice", "role1"), funcutil.EncodeUserRoleCache("alice", "role1"),
@ -137,6 +137,12 @@ func TestPrivilegeInterceptor(t *testing.T) {
}) })
assert.NoError(t, err) assert.NoError(t, err)
_, err = PrivilegeInterceptor(GetContextWithDB(context.Background(), "fooo:123456", "foo"), &milvuspb.LoadCollectionRequest{
DbName: "db_test",
CollectionName: "col1",
})
assert.NotNil(t, err)
g := sync.WaitGroup{} g := sync.WaitGroup{}
for i := 0; i < 20; i++ { for i := 0; i < 20; i++ {
g.Add(1) g.Add(1)
@ -179,12 +185,12 @@ func TestResourceGroupPrivilege(t *testing.T) {
ErrorCode: commonpb.ErrorCode_Success, ErrorCode: commonpb.ErrorCode_Success,
}, },
PolicyInfos: []string{ PolicyInfos: []string{
funcutil.PolicyForPrivilege("role1", commonpb.ObjectType_Global.String(), "*", commonpb.ObjectPrivilege_PrivilegeCreateResourceGroup.String()), funcutil.PolicyForPrivilege("role1", commonpb.ObjectType_Global.String(), "*", commonpb.ObjectPrivilege_PrivilegeCreateResourceGroup.String(), "default"),
funcutil.PolicyForPrivilege("role1", commonpb.ObjectType_Global.String(), "*", commonpb.ObjectPrivilege_PrivilegeDropResourceGroup.String()), funcutil.PolicyForPrivilege("role1", commonpb.ObjectType_Global.String(), "*", commonpb.ObjectPrivilege_PrivilegeDropResourceGroup.String(), "default"),
funcutil.PolicyForPrivilege("role1", commonpb.ObjectType_Global.String(), "*", commonpb.ObjectPrivilege_PrivilegeDescribeResourceGroup.String()), funcutil.PolicyForPrivilege("role1", commonpb.ObjectType_Global.String(), "*", commonpb.ObjectPrivilege_PrivilegeDescribeResourceGroup.String(), "default"),
funcutil.PolicyForPrivilege("role1", commonpb.ObjectType_Global.String(), "*", commonpb.ObjectPrivilege_PrivilegeListResourceGroups.String()), funcutil.PolicyForPrivilege("role1", commonpb.ObjectType_Global.String(), "*", commonpb.ObjectPrivilege_PrivilegeListResourceGroups.String(), "default"),
funcutil.PolicyForPrivilege("role1", commonpb.ObjectType_Global.String(), "*", commonpb.ObjectPrivilege_PrivilegeTransferNode.String()), funcutil.PolicyForPrivilege("role1", commonpb.ObjectType_Global.String(), "*", commonpb.ObjectPrivilege_PrivilegeTransferNode.String(), "default"),
funcutil.PolicyForPrivilege("role1", commonpb.ObjectType_Global.String(), "*", commonpb.ObjectPrivilege_PrivilegeTransferReplica.String()), funcutil.PolicyForPrivilege("role1", commonpb.ObjectType_Global.String(), "*", commonpb.ObjectPrivilege_PrivilegeTransferReplica.String(), "default"),
}, },
UserRoles: []string{ UserRoles: []string{
funcutil.EncodeUserRoleCache("fooo", "role1"), funcutil.EncodeUserRoleCache("fooo", "role1"),

View File

@ -27,15 +27,15 @@ import (
"time" "time"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/internal/util/componentutil"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/importutil"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/tracer"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
@ -53,9 +53,14 @@ import (
"github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/util/componentutil"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/importutil"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/tracer"
"github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util"
"github.com/milvus-io/milvus/pkg/util/crypto" "github.com/milvus-io/milvus/pkg/util/crypto"
"github.com/milvus-io/milvus/pkg/util/distance" "github.com/milvus-io/milvus/pkg/util/distance"
@ -64,12 +69,6 @@ import (
"github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil" "github.com/milvus-io/milvus/pkg/util/typeutil"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
) )
const ( const (
@ -512,7 +511,7 @@ func TestProxy(t *testing.T) {
prefix := "test_proxy_" prefix := "test_proxy_"
partitionPrefix := "test_proxy_partition_" partitionPrefix := "test_proxy_partition_"
dbName := "" dbName := GetCurDBNameFromContextOrDefault(ctx)
collectionName := prefix + funcutil.GenRandomStr() collectionName := prefix + funcutil.GenRandomStr()
otherCollectionName := collectionName + "_other_" + funcutil.GenRandomStr() otherCollectionName := collectionName + "_other_" + funcutil.GenRandomStr()
partitionName := partitionPrefix + funcutil.GenRandomStr() partitionName := partitionPrefix + funcutil.GenRandomStr()
@ -687,6 +686,7 @@ func TestProxy(t *testing.T) {
Base: nil, Base: nil,
CollectionName: collectionName, CollectionName: collectionName,
Alias: "alias", Alias: "alias",
DbName: dbName,
} }
resp, err := proxy.CreateAlias(ctx, aliasReq) resp, err := proxy.CreateAlias(ctx, aliasReq)
assert.NoError(t, err) assert.NoError(t, err)
@ -712,6 +712,7 @@ func TestProxy(t *testing.T) {
Base: nil, Base: nil,
CollectionName: collectionName, CollectionName: collectionName,
Alias: "alias", Alias: "alias",
DbName: dbName,
} }
resp, err := proxy.AlterAlias(ctx, alterReq) resp, err := proxy.AlterAlias(ctx, alterReq)
assert.NoError(t, err) assert.NoError(t, err)
@ -744,8 +745,9 @@ func TestProxy(t *testing.T) {
defer wg.Done() defer wg.Done()
// drop alias // drop alias
resp, err := proxy.DropAlias(ctx, &milvuspb.DropAliasRequest{ resp, err := proxy.DropAlias(ctx, &milvuspb.DropAliasRequest{
Base: nil, Base: nil,
Alias: "alias", Alias: "alias",
DbName: dbName,
}) })
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
@ -791,7 +793,7 @@ func TestProxy(t *testing.T) {
wg.Add(1) wg.Add(1)
t.Run("describe collection", func(t *testing.T) { t.Run("describe collection", func(t *testing.T) {
defer wg.Done() defer wg.Done()
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName) collectionID, err := globalMetaCache.GetCollectionID(ctx, dbName, collectionName)
assert.NoError(t, err) assert.NoError(t, err)
resp, err := proxy.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{ resp, err := proxy.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{
@ -863,7 +865,7 @@ func TestProxy(t *testing.T) {
resp, err := proxy.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{ resp, err := proxy.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{
Base: nil, Base: nil,
DbName: dbName, DbName: dbName,
CollectionName: "cn", CollectionName: collectionName,
Properties: []*commonpb.KeyValuePair{ Properties: []*commonpb.KeyValuePair{
{ {
Key: common.CollectionTTLConfigKey, Key: common.CollectionTTLConfigKey,
@ -872,7 +874,7 @@ func TestProxy(t *testing.T) {
}, },
}) })
assert.NoError(t, err) assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.ErrorCode) assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
}) })
wg.Add(1) wg.Add(1)
@ -968,7 +970,7 @@ func TestProxy(t *testing.T) {
wg.Add(1) wg.Add(1)
t.Run("show partitions", func(t *testing.T) { t.Run("show partitions", func(t *testing.T) {
defer wg.Done() defer wg.Done()
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName) collectionID, err := globalMetaCache.GetCollectionID(ctx, dbName, collectionName)
assert.NoError(t, err) assert.NoError(t, err)
resp, err := proxy.ShowPartitions(ctx, &milvuspb.ShowPartitionsRequest{ resp, err := proxy.ShowPartitions(ctx, &milvuspb.ShowPartitionsRequest{
@ -986,6 +988,7 @@ func TestProxy(t *testing.T) {
{ {
stateResp, err := proxy.GetLoadState(ctx, &milvuspb.GetLoadStateRequest{ stateResp, err := proxy.GetLoadState(ctx, &milvuspb.GetLoadStateRequest{
DbName: dbName,
CollectionName: collectionName, CollectionName: collectionName,
PartitionNames: resp.PartitionNames, PartitionNames: resp.PartitionNames,
}) })
@ -1161,6 +1164,7 @@ func TestProxy(t *testing.T) {
defer wg.Done() defer wg.Done()
{ {
stateResp, err := proxy.GetLoadState(ctx, &milvuspb.GetLoadStateRequest{ stateResp, err := proxy.GetLoadState(ctx, &milvuspb.GetLoadStateRequest{
DbName: dbName,
CollectionName: collectionName, CollectionName: collectionName,
}) })
assert.NoError(t, err) assert.NoError(t, err)
@ -1216,8 +1220,8 @@ func TestProxy(t *testing.T) {
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
counter++ counter++
} }
assert.True(t, loaded)
}) })
assert.True(t, loaded)
wg.Add(1) wg.Add(1)
t.Run("show in-memory collections", func(t *testing.T) { t.Run("show in-memory collections", func(t *testing.T) {
@ -1259,6 +1263,7 @@ func TestProxy(t *testing.T) {
{ {
progressResp, err := proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{ progressResp, err := proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
DbName: dbName,
CollectionName: collectionName, CollectionName: collectionName,
}) })
assert.NoError(t, err) assert.NoError(t, err)
@ -1268,6 +1273,7 @@ func TestProxy(t *testing.T) {
{ {
progressResp, err := proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{ progressResp, err := proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
DbName: dbName,
CollectionName: otherCollectionName, CollectionName: otherCollectionName,
}) })
assert.NoError(t, err) assert.NoError(t, err)
@ -1277,6 +1283,7 @@ func TestProxy(t *testing.T) {
{ {
stateResp, err := proxy.GetLoadState(ctx, &milvuspb.GetLoadStateRequest{ stateResp, err := proxy.GetLoadState(ctx, &milvuspb.GetLoadStateRequest{
DbName: dbName,
CollectionName: otherCollectionName, CollectionName: otherCollectionName,
}) })
assert.NoError(t, err) assert.NoError(t, err)
@ -1289,7 +1296,7 @@ func TestProxy(t *testing.T) {
t.Run("get replicas", func(t *testing.T) { t.Run("get replicas", func(t *testing.T) {
defer wg.Done() defer wg.Done()
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName) collectionID, err := globalMetaCache.GetCollectionID(ctx, dbName, collectionName)
assert.NoError(t, err) assert.NoError(t, err)
resp, err := proxy.GetReplicas(ctx, &milvuspb.GetReplicasRequest{ resp, err := proxy.GetReplicas(ctx, &milvuspb.GetReplicasRequest{
@ -1669,6 +1676,7 @@ func TestProxy(t *testing.T) {
t.Run("test import", func(t *testing.T) { t.Run("test import", func(t *testing.T) {
defer wg.Done() defer wg.Done()
req := &milvuspb.ImportRequest{ req := &milvuspb.ImportRequest{
DbName: dbName,
CollectionName: collectionName, CollectionName: collectionName,
Files: []string{"f1.json"}, Files: []string{"f1.json"},
} }
@ -1709,7 +1717,7 @@ func TestProxy(t *testing.T) {
wg.Add(1) wg.Add(1)
t.Run("release collection", func(t *testing.T) { t.Run("release collection", func(t *testing.T) {
defer wg.Done() defer wg.Done()
_, err := globalMetaCache.GetCollectionID(ctx, collectionName) _, err := globalMetaCache.GetCollectionID(ctx, dbName, collectionName)
assert.NoError(t, err) assert.NoError(t, err)
resp, err := proxy.ReleaseCollection(ctx, &milvuspb.ReleaseCollectionRequest{ resp, err := proxy.ReleaseCollection(ctx, &milvuspb.ReleaseCollectionRequest{
@ -1750,7 +1758,7 @@ func TestProxy(t *testing.T) {
wg.Add(1) wg.Add(1)
t.Run("load partitions", func(t *testing.T) { t.Run("load partitions", func(t *testing.T) {
defer wg.Done() defer wg.Done()
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName) collectionID, err := globalMetaCache.GetCollectionID(ctx, dbName, collectionName)
assert.NoError(t, err) assert.NoError(t, err)
resp, err := proxy.LoadPartitions(ctx, &milvuspb.LoadPartitionsRequest{ resp, err := proxy.LoadPartitions(ctx, &milvuspb.LoadPartitionsRequest{
@ -1823,7 +1831,7 @@ func TestProxy(t *testing.T) {
wg.Add(1) wg.Add(1)
t.Run("show in-memory partitions", func(t *testing.T) { t.Run("show in-memory partitions", func(t *testing.T) {
defer wg.Done() defer wg.Done()
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName) collectionID, err := globalMetaCache.GetCollectionID(ctx, dbName, collectionName)
assert.NoError(t, err) assert.NoError(t, err)
resp, err := proxy.ShowPartitions(ctx, &milvuspb.ShowPartitionsRequest{ resp, err := proxy.ShowPartitions(ctx, &milvuspb.ShowPartitionsRequest{
@ -1865,6 +1873,7 @@ func TestProxy(t *testing.T) {
{ {
resp, err := proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{ resp, err := proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
DbName: dbName,
CollectionName: collectionName, CollectionName: collectionName,
PartitionNames: []string{partitionName}, PartitionNames: []string{partitionName},
}) })
@ -1875,6 +1884,7 @@ func TestProxy(t *testing.T) {
{ {
resp, err := proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{ resp, err := proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
DbName: dbName,
CollectionName: collectionName, CollectionName: collectionName,
PartitionNames: []string{otherPartitionName}, PartitionNames: []string{otherPartitionName},
}) })
@ -1980,7 +1990,7 @@ func TestProxy(t *testing.T) {
wg.Add(1) wg.Add(1)
t.Run("show in-memory partitions after release partition", func(t *testing.T) { t.Run("show in-memory partitions after release partition", func(t *testing.T) {
defer wg.Done() defer wg.Done()
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName) collectionID, err := globalMetaCache.GetCollectionID(ctx, dbName, collectionName)
assert.NoError(t, err) assert.NoError(t, err)
resp, err := proxy.ShowPartitions(ctx, &milvuspb.ShowPartitionsRequest{ resp, err := proxy.ShowPartitions(ctx, &milvuspb.ShowPartitionsRequest{
@ -2058,7 +2068,7 @@ func TestProxy(t *testing.T) {
wg.Add(1) wg.Add(1)
t.Run("show partitions after drop partition", func(t *testing.T) { t.Run("show partitions after drop partition", func(t *testing.T) {
defer wg.Done() defer wg.Done()
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName) collectionID, err := globalMetaCache.GetCollectionID(ctx, dbName, collectionName)
assert.NoError(t, err) assert.NoError(t, err)
resp, err := proxy.ShowPartitions(ctx, &milvuspb.ShowPartitionsRequest{ resp, err := proxy.ShowPartitions(ctx, &milvuspb.ShowPartitionsRequest{
@ -2118,7 +2128,7 @@ func TestProxy(t *testing.T) {
wg.Add(1) wg.Add(1)
t.Run("drop collection", func(t *testing.T) { t.Run("drop collection", func(t *testing.T) {
defer wg.Done() defer wg.Done()
_, err := globalMetaCache.GetCollectionID(ctx, collectionName) _, err := globalMetaCache.GetCollectionID(ctx, dbName, collectionName)
assert.NoError(t, err) assert.NoError(t, err)
resp, err := proxy.DropCollection(ctx, &milvuspb.DropCollectionRequest{ resp, err := proxy.DropCollection(ctx, &milvuspb.DropCollectionRequest{
@ -4090,8 +4100,6 @@ func TestProxy_Import(t *testing.T) {
defer wg.Done() defer wg.Done()
proxy := &Proxy{} proxy := &Proxy{}
proxy.UpdateStateCode(commonpb.StateCode_Healthy) proxy.UpdateStateCode(commonpb.StateCode_Healthy)
cache := newMockCache()
globalMetaCache = cache
chMgr := newMockChannelsMgr() chMgr := newMockChannelsMgr()
proxy.chMgr = chMgr proxy.chMgr = chMgr
rc := newMockRootCoord() rc := newMockRootCoord()
@ -4112,8 +4120,6 @@ func TestProxy_Import(t *testing.T) {
defer wg.Done() defer wg.Done()
proxy := &Proxy{} proxy := &Proxy{}
proxy.UpdateStateCode(commonpb.StateCode_Healthy) proxy.UpdateStateCode(commonpb.StateCode_Healthy)
cache := newMockCache()
globalMetaCache = cache
chMgr := newMockChannelsMgr() chMgr := newMockChannelsMgr()
proxy.chMgr = chMgr proxy.chMgr = chMgr
rc := newMockRootCoord() rc := newMockRootCoord()
@ -4134,8 +4140,6 @@ func TestProxy_Import(t *testing.T) {
defer wg.Done() defer wg.Done()
proxy := &Proxy{} proxy := &Proxy{}
proxy.UpdateStateCode(commonpb.StateCode_Healthy) proxy.UpdateStateCode(commonpb.StateCode_Healthy)
cache := newMockCache()
globalMetaCache = cache
chMgr := newMockChannelsMgr() chMgr := newMockChannelsMgr()
proxy.chMgr = chMgr proxy.chMgr = chMgr
rc := newMockRootCoord() rc := newMockRootCoord()
@ -4213,13 +4217,18 @@ func TestProxy_GetStatistics(t *testing.T) {
func TestProxy_GetLoadState(t *testing.T) { func TestProxy_GetLoadState(t *testing.T) {
originCache := globalMetaCache originCache := globalMetaCache
m := newMockCache() m := NewMockCache(t)
m.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) { m.On("GetCollectionID",
return 1, nil mock.Anything, // context.Context
}) mock.AnythingOfType("string"),
m.setGetPartitionIDFunc(func(ctx context.Context, collectionName string, partitionName string) (typeutil.UniqueID, error) { mock.AnythingOfType("string"),
return 2, nil ).Return(UniqueID(1), nil)
}) m.On("GetPartitionID",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(UniqueID(2), nil)
globalMetaCache = m globalMetaCache = m
defer func() { defer func() {
globalMetaCache = originCache globalMetaCache = originCache

View File

@ -53,40 +53,50 @@ func RateLimitInterceptor(limiter types.Limiter) grpc.UnaryServerInterceptor {
func getRequestInfo(req interface{}) (int64, internalpb.RateType, int, error) { func getRequestInfo(req interface{}) (int64, internalpb.RateType, int, error) {
switch r := req.(type) { switch r := req.(type) {
case *milvuspb.InsertRequest: case *milvuspb.InsertRequest:
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetCollectionName()) collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return collectionID, internalpb.RateType_DMLInsert, proto.Size(r), nil return collectionID, internalpb.RateType_DMLInsert, proto.Size(r), nil
case *milvuspb.DeleteRequest: case *milvuspb.DeleteRequest:
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetCollectionName()) collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return collectionID, internalpb.RateType_DMLDelete, proto.Size(r), nil return collectionID, internalpb.RateType_DMLDelete, proto.Size(r), nil
case *milvuspb.ImportRequest: case *milvuspb.ImportRequest:
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetCollectionName()) collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return collectionID, internalpb.RateType_DMLBulkLoad, proto.Size(r), nil return collectionID, internalpb.RateType_DMLBulkLoad, proto.Size(r), nil
case *milvuspb.SearchRequest: case *milvuspb.SearchRequest:
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetCollectionName()) collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return collectionID, internalpb.RateType_DQLSearch, int(r.GetNq()), nil return collectionID, internalpb.RateType_DQLSearch, int(r.GetNq()), nil
case *milvuspb.QueryRequest: case *milvuspb.QueryRequest:
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetCollectionName()) collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return collectionID, internalpb.RateType_DQLQuery, 1, nil // we regard the nq of query as equivalent to 1. return collectionID, internalpb.RateType_DQLQuery, 1, nil // think of the query request's nq as 1
case *milvuspb.CreateCollectionRequest: case *milvuspb.CreateCollectionRequest:
return 0, internalpb.RateType_DDLCollection, 1, nil collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return collectionID, internalpb.RateType_DDLCollection, 1, nil
case *milvuspb.DropCollectionRequest: case *milvuspb.DropCollectionRequest:
return 0, internalpb.RateType_DDLCollection, 1, nil collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return collectionID, internalpb.RateType_DDLCollection, 1, nil
case *milvuspb.LoadCollectionRequest: case *milvuspb.LoadCollectionRequest:
return 0, internalpb.RateType_DDLCollection, 1, nil collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return collectionID, internalpb.RateType_DDLCollection, 1, nil
case *milvuspb.ReleaseCollectionRequest: case *milvuspb.ReleaseCollectionRequest:
return 0, internalpb.RateType_DDLCollection, 1, nil collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return collectionID, internalpb.RateType_DDLCollection, 1, nil
case *milvuspb.CreatePartitionRequest: case *milvuspb.CreatePartitionRequest:
return 0, internalpb.RateType_DDLPartition, 1, nil collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return collectionID, internalpb.RateType_DDLPartition, 1, nil
case *milvuspb.DropPartitionRequest: case *milvuspb.DropPartitionRequest:
return 0, internalpb.RateType_DDLPartition, 1, nil collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return collectionID, internalpb.RateType_DDLPartition, 1, nil
case *milvuspb.LoadPartitionsRequest: case *milvuspb.LoadPartitionsRequest:
return 0, internalpb.RateType_DDLPartition, 1, nil collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return collectionID, internalpb.RateType_DDLPartition, 1, nil
case *milvuspb.ReleasePartitionsRequest: case *milvuspb.ReleasePartitionsRequest:
return 0, internalpb.RateType_DDLPartition, 1, nil collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return collectionID, internalpb.RateType_DDLPartition, 1, nil
case *milvuspb.CreateIndexRequest: case *milvuspb.CreateIndexRequest:
return 0, internalpb.RateType_DDLIndex, 1, nil collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return collectionID, internalpb.RateType_DDLIndex, 1, nil
case *milvuspb.DropIndexRequest: case *milvuspb.DropIndexRequest:
return 0, internalpb.RateType_DDLIndex, 1, nil collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return collectionID, internalpb.RateType_DDLIndex, 1, nil
case *milvuspb.FlushRequest: case *milvuspb.FlushRequest:
return 0, internalpb.RateType_DDLFlush, 1, nil return 0, internalpb.RateType_DDLFlush, 1, nil
case *milvuspb.ManualCompactionRequest: case *milvuspb.ManualCompactionRequest:

View File

@ -22,6 +22,7 @@ import (
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@ -48,7 +49,13 @@ func (l *limiterMock) Check(collection int64, rt internalpb.RateType, n int) com
func TestRateLimitInterceptor(t *testing.T) { func TestRateLimitInterceptor(t *testing.T) {
t.Run("test getRequestInfo", func(t *testing.T) { t.Run("test getRequestInfo", func(t *testing.T) {
globalMetaCache = newMockCache() mockCache := NewMockCache(t)
mockCache.On("GetCollectionID",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(int64(0), nil)
globalMetaCache = mockCache
collection, rt, size, err := getRequestInfo(&milvuspb.InsertRequest{}) collection, rt, size, err := getRequestInfo(&milvuspb.InsertRequest{})
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, proto.Size(&milvuspb.InsertRequest{}), size) assert.Equal(t, proto.Size(&milvuspb.InsertRequest{}), size)
@ -173,6 +180,14 @@ func TestRateLimitInterceptor(t *testing.T) {
}) })
t.Run("test RateLimitInterceptor", func(t *testing.T) { t.Run("test RateLimitInterceptor", func(t *testing.T) {
mockCache := NewMockCache(t)
mockCache.On("GetCollectionID",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(int64(0), nil)
globalMetaCache = mockCache
limiter := limiterMock{rate: 100} limiter := limiterMock{rate: 100}
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return &milvuspb.MutationResult{ return &milvuspb.MutationResult{

View File

@ -24,28 +24,21 @@ import (
"time" "time"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/uniquegenerator"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus/pkg/util/milvuserrors"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/util/typeutil"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/milvuserrors"
"github.com/milvus-io/milvus/pkg/util/typeutil"
"github.com/milvus-io/milvus/pkg/util/uniquegenerator"
) )
type collectionMeta struct { type collectionMeta struct {
@ -385,7 +378,6 @@ func (coord *RootCoordMock) CreateCollection(ctx context.Context, req *milvuspb.
coord.collID2Partitions[collID].partitionID2Meta[id] = partitionMeta{} coord.collID2Partitions[collID].partitionID2Meta[id] = partitionMeta{}
} }
} else { } else {
id := UniqueID(idGenerator.GetInt()) id := UniqueID(idGenerator.GetInt())
coord.collID2Partitions[collID].partitionName2ID[defaultPartitionName] = id coord.collID2Partitions[collID].partitionName2ID[defaultPartitionName] = id
coord.collID2Partitions[collID].partitionID2Name[id] = defaultPartitionName coord.collID2Partitions[collID].partitionID2Name[id] = defaultPartitionName
@ -1163,6 +1155,18 @@ func (coord *RootCoordMock) AlterCollection(ctx context.Context, request *milvus
return &commonpb.Status{}, nil return &commonpb.Status{}, nil
} }
// CreateDatabase is a no-op mock that always reports success.
// The ErrorCode is set explicitly to mirror the sibling DropDatabase stub,
// which already returns an explicit Success status.
func (coord *RootCoordMock) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) {
	return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil
}
// DropDatabase is a no-op mock that always reports success.
func (coord *RootCoordMock) DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) {
	status := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
	return status, nil
}
// ListDatabases is a no-op mock that returns an empty database list.
// The Status field is populated with Success so callers that read
// resp.Status fields directly (rather than through nil-safe getters)
// cannot dereference a nil pointer.
func (coord *RootCoordMock) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) {
	return &milvuspb.ListDatabasesResponse{
		Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
	}, nil
}
func (coord *RootCoordMock) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { func (coord *RootCoordMock) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) {
if coord.checkHealthFunc != nil { if coord.checkHealthFunc != nil {
return coord.checkHealthFunc(ctx, req) return coord.checkHealthFunc(ctx, req)
@ -1175,10 +1179,15 @@ func (coord *RootCoordMock) RenameCollection(ctx context.Context, req *milvuspb.
} }
type DescribeCollectionFunc func(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) type DescribeCollectionFunc func(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error)
type ShowPartitionsFunc func(ctx context.Context, request *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) type ShowPartitionsFunc func(ctx context.Context, request *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error)
type ShowSegmentsFunc func(ctx context.Context, request *milvuspb.ShowSegmentsRequest) (*milvuspb.ShowSegmentsResponse, error) type ShowSegmentsFunc func(ctx context.Context, request *milvuspb.ShowSegmentsRequest) (*milvuspb.ShowSegmentsResponse, error)
type DescribeSegmentsFunc func(ctx context.Context, request *rootcoordpb.DescribeSegmentsRequest) (*rootcoordpb.DescribeSegmentsResponse, error) type DescribeSegmentsFunc func(ctx context.Context, request *rootcoordpb.DescribeSegmentsRequest) (*rootcoordpb.DescribeSegmentsResponse, error)
type ImportFunc func(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) type ImportFunc func(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error)
type DropCollectionFunc func(ctx context.Context, request *milvuspb.DropCollectionRequest) (*commonpb.Status, error) type DropCollectionFunc func(ctx context.Context, request *milvuspb.DropCollectionRequest) (*commonpb.Status, error)
type GetGetCredentialFunc func(ctx context.Context, req *rootcoordpb.GetCredentialRequest) (*rootcoordpb.GetCredentialResponse, error) type GetGetCredentialFunc func(ctx context.Context, req *rootcoordpb.GetCredentialRequest) (*rootcoordpb.GetCredentialResponse, error)
@ -1255,6 +1264,18 @@ func (m *mockRootCoord) CheckHealth(ctx context.Context, req *milvuspb.CheckHeal
}, nil }, nil
} }
// CreateDatabase is a stub that always returns an empty (success) status.
func (m *mockRootCoord) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) {
	return &commonpb.Status{}, nil
}
// DropDatabase is a stub that always returns an empty (success) status.
func (m *mockRootCoord) DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) {
	return &commonpb.Status{}, nil
}
// ListDatabases is a stub that returns an empty response.
// NOTE(review): the returned response has a nil Status — fine for callers
// using nil-safe proto getters; confirm no test reads resp.Status directly.
func (m *mockRootCoord) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) {
	return &milvuspb.ListDatabasesResponse{}, nil
}
func newMockRootCoord() *mockRootCoord { func newMockRootCoord() *mockRootCoord {
return &mockRootCoord{} return &mockRootCoord{}
} }

View File

@ -82,6 +82,10 @@ const (
ListResourceGroupsTaskName = "ListResourceGroupsTask" ListResourceGroupsTaskName = "ListResourceGroupsTask"
DescribeResourceGroupTaskName = "DescribeResourceGroupTask" DescribeResourceGroupTaskName = "DescribeResourceGroupTask"
// Names reported by the database DDL tasks via Task.Name().
CreateDatabaseTaskName = "CreateDatabaseTask" // was "CreateCollectionTask" — copy-paste slip
DropDatabaseTaskName   = "DropDatabaseTaskName"
ListDatabaseTaskName   = "ListDatabaseTaskName"
// minFloat32 minimum float. // minFloat32 minimum float.
minFloat32 = -1 * float32(math.MaxFloat32) minFloat32 = -1 * float32(math.MaxFloat32)
) )
@ -513,6 +517,7 @@ func (dct *describeCollectionTask) Execute(ctx context.Context) error {
VirtualChannelNames: nil, VirtualChannelNames: nil,
PhysicalChannelNames: nil, PhysicalChannelNames: nil,
CollectionName: dct.GetCollectionName(), CollectionName: dct.GetCollectionName(),
DbName: dct.GetDbName(),
} }
result, err := dct.rootCoord.DescribeCollection(ctx, dct.DescribeCollectionRequest) result, err := dct.rootCoord.DescribeCollection(ctx, dct.DescribeCollectionRequest)
@ -537,6 +542,7 @@ func (dct *describeCollectionTask) Execute(ctx context.Context) error {
dct.result.ConsistencyLevel = result.ConsistencyLevel dct.result.ConsistencyLevel = result.ConsistencyLevel
dct.result.Aliases = result.Aliases dct.result.Aliases = result.Aliases
dct.result.Properties = result.Properties dct.result.Properties = result.Properties
dct.result.DbName = result.GetDbName()
dct.result.NumPartitions = result.NumPartitions dct.result.NumPartitions = result.NumPartitions
for _, field := range result.Schema.Fields { for _, field := range result.Schema.Fields {
if field.IsDynamic { if field.IsDynamic {
@ -649,7 +655,7 @@ func (sct *showCollectionsTask) Execute(ctx context.Context) error {
} }
collectionIDs := make([]UniqueID, 0) collectionIDs := make([]UniqueID, 0)
for _, collectionName := range sct.CollectionNames { for _, collectionName := range sct.CollectionNames {
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName) collectionID, err := globalMetaCache.GetCollectionID(ctx, sct.GetDbName(), collectionName)
if err != nil { if err != nil {
log.Debug("Failed to get collection id.", zap.Any("collectionName", collectionName), log.Debug("Failed to get collection id.", zap.Any("collectionName", collectionName),
zap.Any("requestID", sct.Base.MsgID), zap.Any("requestType", "showCollections")) zap.Any("requestID", sct.Base.MsgID), zap.Any("requestType", "showCollections"))
@ -703,7 +709,7 @@ func (sct *showCollectionsTask) Execute(ctx context.Context) error {
zap.Any("requestID", sct.Base.MsgID), zap.Any("requestType", "showCollections")) zap.Any("requestID", sct.Base.MsgID), zap.Any("requestType", "showCollections"))
continue continue
} }
collectionInfo, err := globalMetaCache.GetCollectionInfo(ctx, collectionName) collectionInfo, err := globalMetaCache.GetCollectionInfo(ctx, sct.GetDbName(), collectionName)
if err != nil { if err != nil {
log.Debug("Failed to get collection info.", zap.Any("collectionName", collectionName), log.Debug("Failed to get collection info.", zap.Any("collectionName", collectionName),
zap.Any("requestID", sct.Base.MsgID), zap.Any("requestType", "showCollections")) zap.Any("requestID", sct.Base.MsgID), zap.Any("requestType", "showCollections"))
@ -844,7 +850,7 @@ func (cpt *createPartitionTask) PreExecute(ctx context.Context) error {
return err return err
} }
partitionKeyMode, err := isPartitionKeyMode(ctx, collName) partitionKeyMode, err := isPartitionKeyMode(ctx, cpt.GetDbName(), collName)
if err != nil { if err != nil {
return err return err
} }
@ -930,7 +936,7 @@ func (dpt *dropPartitionTask) PreExecute(ctx context.Context) error {
return err return err
} }
partitionKeyMode, err := isPartitionKeyMode(ctx, collName) partitionKeyMode, err := isPartitionKeyMode(ctx, dpt.GetDbName(), collName)
if err != nil { if err != nil {
return err return err
} }
@ -942,11 +948,11 @@ func (dpt *dropPartitionTask) PreExecute(ctx context.Context) error {
return err return err
} }
collID, err := globalMetaCache.GetCollectionID(ctx, dpt.GetCollectionName()) collID, err := globalMetaCache.GetCollectionID(ctx, dpt.GetDbName(), dpt.GetCollectionName())
if err != nil { if err != nil {
return err return err
} }
partID, err := globalMetaCache.GetPartitionID(ctx, dpt.GetCollectionName(), dpt.GetPartitionName()) partID, err := globalMetaCache.GetPartitionID(ctx, dpt.GetDbName(), dpt.GetCollectionName(), dpt.GetPartitionName())
if err != nil { if err != nil {
if errors.Is(merr.ErrPartitionNotFound, err) { if errors.Is(merr.ErrPartitionNotFound, err) {
return nil return nil
@ -1143,7 +1149,7 @@ func (spt *showPartitionsTask) Execute(ctx context.Context) error {
if spt.GetType() == milvuspb.ShowType_InMemory { if spt.GetType() == milvuspb.ShowType_InMemory {
collectionName := spt.CollectionName collectionName := spt.CollectionName
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName) collectionID, err := globalMetaCache.GetCollectionID(ctx, spt.GetDbName(), collectionName)
if err != nil { if err != nil {
log.Debug("Failed to get collection id.", zap.Any("collectionName", collectionName), log.Debug("Failed to get collection id.", zap.Any("collectionName", collectionName),
zap.Any("requestID", spt.Base.MsgID), zap.Any("requestType", "showPartitions")) zap.Any("requestID", spt.Base.MsgID), zap.Any("requestType", "showPartitions"))
@ -1156,7 +1162,7 @@ func (spt *showPartitionsTask) Execute(ctx context.Context) error {
} }
partitionIDs := make([]UniqueID, 0) partitionIDs := make([]UniqueID, 0)
for _, partitionName := range spt.PartitionNames { for _, partitionName := range spt.PartitionNames {
partitionID, err := globalMetaCache.GetPartitionID(ctx, collectionName, partitionName) partitionID, err := globalMetaCache.GetPartitionID(ctx, spt.GetDbName(), collectionName, partitionName)
if err != nil { if err != nil {
log.Debug("Failed to get partition id.", zap.Any("partitionName", partitionName), log.Debug("Failed to get partition id.", zap.Any("partitionName", partitionName),
zap.Any("requestID", spt.Base.MsgID), zap.Any("requestType", "showPartitions")) zap.Any("requestID", spt.Base.MsgID), zap.Any("requestType", "showPartitions"))
@ -1202,7 +1208,7 @@ func (spt *showPartitionsTask) Execute(ctx context.Context) error {
zap.Any("requestID", spt.Base.MsgID), zap.Any("requestType", "showPartitions")) zap.Any("requestID", spt.Base.MsgID), zap.Any("requestType", "showPartitions"))
return errors.New("failed to show partitions") return errors.New("failed to show partitions")
} }
partitionInfo, err := globalMetaCache.GetPartitionInfo(ctx, collectionName, partitionName) partitionInfo, err := globalMetaCache.GetPartitionInfo(ctx, spt.GetDbName(), collectionName, partitionName)
if err != nil { if err != nil {
log.Debug("Failed to get partition id.", zap.Any("partitionName", partitionName), log.Debug("Failed to get partition id.", zap.Any("partitionName", partitionName),
zap.Any("requestID", spt.Base.MsgID), zap.Any("requestType", "showPartitions")) zap.Any("requestID", spt.Base.MsgID), zap.Any("requestType", "showPartitions"))
@ -1281,7 +1287,7 @@ func (ft *flushTask) Execute(ctx context.Context) error {
flushColl2Segments := make(map[string]*schemapb.LongArray) flushColl2Segments := make(map[string]*schemapb.LongArray)
coll2SealTimes := make(map[string]int64) coll2SealTimes := make(map[string]int64)
for _, collName := range ft.CollectionNames { for _, collName := range ft.CollectionNames {
collID, err := globalMetaCache.GetCollectionID(ctx, collName) collID, err := globalMetaCache.GetCollectionID(ctx, ft.GetDbName(), collName)
if err != nil { if err != nil {
return err return err
} }
@ -1290,7 +1296,6 @@ func (ft *flushTask) Execute(ctx context.Context) error {
ft.Base, ft.Base,
commonpbutil.WithMsgType(commonpb.MsgType_Flush), commonpbutil.WithMsgType(commonpb.MsgType_Flush),
), ),
DbID: 0,
CollectionID: collID, CollectionID: collID,
IsImport: false, IsImport: false,
} }
@ -1310,7 +1315,7 @@ func (ft *flushTask) Execute(ctx context.Context) error {
ErrorCode: commonpb.ErrorCode_Success, ErrorCode: commonpb.ErrorCode_Success,
Reason: "", Reason: "",
}, },
DbName: "", DbName: ft.GetDbName(),
CollSegIDs: coll2Segments, CollSegIDs: coll2Segments,
FlushCollSegIDs: flushColl2Segments, FlushCollSegIDs: flushColl2Segments,
CollSealTimes: coll2SealTimes, CollSealTimes: coll2SealTimes,
@ -1391,7 +1396,7 @@ func (lct *loadCollectionTask) PreExecute(ctx context.Context) error {
} }
func (lct *loadCollectionTask) Execute(ctx context.Context) (err error) { func (lct *loadCollectionTask) Execute(ctx context.Context) (err error) {
collID, err := globalMetaCache.GetCollectionID(ctx, lct.CollectionName) collID, err := globalMetaCache.GetCollectionID(ctx, lct.GetDbName(), lct.CollectionName)
log := log.Ctx(ctx).With( log := log.Ctx(ctx).With(
zap.String("role", typeutil.ProxyRole), zap.String("role", typeutil.ProxyRole),
@ -1403,7 +1408,7 @@ func (lct *loadCollectionTask) Execute(ctx context.Context) (err error) {
} }
lct.collectionID = collID lct.collectionID = collID
collSchema, err := globalMetaCache.GetCollectionSchema(ctx, lct.CollectionName) collSchema, err := globalMetaCache.GetCollectionSchema(ctx, lct.GetDbName(), lct.CollectionName)
if err != nil { if err != nil {
return err return err
} }
@ -1457,7 +1462,7 @@ func (lct *loadCollectionTask) Execute(ctx context.Context) (err error) {
} }
func (lct *loadCollectionTask) PostExecute(ctx context.Context) error { func (lct *loadCollectionTask) PostExecute(ctx context.Context) error {
collID, err := globalMetaCache.GetCollectionID(ctx, lct.CollectionName) collID, err := globalMetaCache.GetCollectionID(ctx, lct.GetDbName(), lct.CollectionName)
log.Ctx(ctx).Debug("loadCollectionTask PostExecute", log.Ctx(ctx).Debug("loadCollectionTask PostExecute",
zap.String("role", typeutil.ProxyRole), zap.String("role", typeutil.ProxyRole),
zap.Int64("collectionID", collID)) zap.Int64("collectionID", collID))
@ -1529,7 +1534,7 @@ func (rct *releaseCollectionTask) PreExecute(ctx context.Context) error {
} }
func (rct *releaseCollectionTask) Execute(ctx context.Context) (err error) { func (rct *releaseCollectionTask) Execute(ctx context.Context) (err error) {
collID, err := globalMetaCache.GetCollectionID(ctx, rct.CollectionName) collID, err := globalMetaCache.GetCollectionID(ctx, rct.GetDbName(), rct.CollectionName)
if err != nil { if err != nil {
return err return err
} }
@ -1545,13 +1550,13 @@ func (rct *releaseCollectionTask) Execute(ctx context.Context) (err error) {
rct.result, err = rct.queryCoord.ReleaseCollection(ctx, request) rct.result, err = rct.queryCoord.ReleaseCollection(ctx, request)
globalMetaCache.RemoveCollection(ctx, rct.CollectionName) globalMetaCache.RemoveCollection(ctx, rct.GetDbName(), rct.CollectionName)
return err return err
} }
func (rct *releaseCollectionTask) PostExecute(ctx context.Context) error { func (rct *releaseCollectionTask) PostExecute(ctx context.Context) error {
globalMetaCache.DeprecateShardCache(rct.CollectionName) globalMetaCache.DeprecateShardCache(rct.GetDbName(), rct.CollectionName)
return nil return nil
} }
@ -1613,7 +1618,7 @@ func (lpt *loadPartitionsTask) PreExecute(ctx context.Context) error {
return err return err
} }
partitionKeyMode, err := isPartitionKeyMode(ctx, collName) partitionKeyMode, err := isPartitionKeyMode(ctx, lpt.GetDbName(), collName)
if err != nil { if err != nil {
return err return err
} }
@ -1626,12 +1631,12 @@ func (lpt *loadPartitionsTask) PreExecute(ctx context.Context) error {
func (lpt *loadPartitionsTask) Execute(ctx context.Context) error { func (lpt *loadPartitionsTask) Execute(ctx context.Context) error {
var partitionIDs []int64 var partitionIDs []int64
collID, err := globalMetaCache.GetCollectionID(ctx, lpt.CollectionName) collID, err := globalMetaCache.GetCollectionID(ctx, lpt.GetDbName(), lpt.CollectionName)
if err != nil { if err != nil {
return err return err
} }
lpt.collectionID = collID lpt.collectionID = collID
collSchema, err := globalMetaCache.GetCollectionSchema(ctx, lpt.CollectionName) collSchema, err := globalMetaCache.GetCollectionSchema(ctx, lpt.GetDbName(), lpt.CollectionName)
if err != nil { if err != nil {
return err return err
} }
@ -1663,7 +1668,7 @@ func (lpt *loadPartitionsTask) Execute(ctx context.Context) error {
return errors.New(errMsg) return errors.New(errMsg)
} }
for _, partitionName := range lpt.PartitionNames { for _, partitionName := range lpt.PartitionNames {
partitionID, err := globalMetaCache.GetPartitionID(ctx, lpt.CollectionName, partitionName) partitionID, err := globalMetaCache.GetPartitionID(ctx, lpt.GetDbName(), lpt.CollectionName, partitionName)
if err != nil { if err != nil {
return err return err
} }
@ -1751,7 +1756,7 @@ func (rpt *releasePartitionsTask) PreExecute(ctx context.Context) error {
return err return err
} }
partitionKeyMode, err := isPartitionKeyMode(ctx, collName) partitionKeyMode, err := isPartitionKeyMode(ctx, rpt.GetDbName(), collName)
if err != nil { if err != nil {
return err return err
} }
@ -1764,13 +1769,13 @@ func (rpt *releasePartitionsTask) PreExecute(ctx context.Context) error {
func (rpt *releasePartitionsTask) Execute(ctx context.Context) (err error) { func (rpt *releasePartitionsTask) Execute(ctx context.Context) (err error) {
var partitionIDs []int64 var partitionIDs []int64
collID, err := globalMetaCache.GetCollectionID(ctx, rpt.CollectionName) collID, err := globalMetaCache.GetCollectionID(ctx, rpt.GetDbName(), rpt.CollectionName)
if err != nil { if err != nil {
return err return err
} }
rpt.collectionID = collID rpt.collectionID = collID
for _, partitionName := range rpt.PartitionNames { for _, partitionName := range rpt.PartitionNames {
partitionID, err := globalMetaCache.GetPartitionID(ctx, rpt.CollectionName, partitionName) partitionID, err := globalMetaCache.GetPartitionID(ctx, rpt.GetDbName(), rpt.CollectionName, partitionName)
if err != nil { if err != nil {
return err return err
} }
@ -1790,7 +1795,7 @@ func (rpt *releasePartitionsTask) Execute(ctx context.Context) (err error) {
} }
func (rpt *releasePartitionsTask) PostExecute(ctx context.Context) error { func (rpt *releasePartitionsTask) PostExecute(ctx context.Context) error {
globalMetaCache.DeprecateShardCache(rpt.CollectionName) globalMetaCache.DeprecateShardCache(rpt.GetDbName(), rpt.CollectionName)
return nil return nil
} }
@ -2210,7 +2215,7 @@ func (t *DescribeResourceGroupTask) Execute(ctx context.Context) error {
} }
getCollectionNameFunc := func(value int32, key int64) string { getCollectionNameFunc := func(value int32, key int64) string {
name, err := globalMetaCache.GetCollectionName(ctx, key) _, name, err := globalMetaCache.GetDatabaseAndCollectionName(ctx, key)
if err != nil { if err != nil {
// unreachable logic path // unreachable logic path
return "unavailable_collection" return "unavailable_collection"
@ -2365,7 +2370,7 @@ func (t *TransferReplicaTask) PreExecute(ctx context.Context) error {
func (t *TransferReplicaTask) Execute(ctx context.Context) error { func (t *TransferReplicaTask) Execute(ctx context.Context) error {
var err error var err error
collID, err := globalMetaCache.GetCollectionID(ctx, t.CollectionName) collID, err := globalMetaCache.GetCollectionID(ctx, t.GetDbName(), t.CollectionName)
if err != nil { if err != nil {
return err return err
} }

View File

@ -0,0 +1,205 @@
package proxy
import (
"context"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
// createDatabaseTask wraps a milvuspb.CreateDatabaseRequest and forwards it
// to RootCoord through the proxy task queue.
type createDatabaseTask struct {
	Condition
	*milvuspb.CreateDatabaseRequest
	ctx       context.Context
	rootCoord types.RootCoord  // coordinator that performs the actual create
	result    *commonpb.Status // populated by Execute
}
// TraceCtx returns the context the task was created with.
func (cdt *createDatabaseTask) TraceCtx() context.Context {
	return cdt.ctx
}

// ID returns the message ID from the task's base header.
func (cdt *createDatabaseTask) ID() UniqueID {
	return cdt.Base.MsgID
}

// SetID stores the message ID assigned to this task.
func (cdt *createDatabaseTask) SetID(uid UniqueID) {
	cdt.Base.MsgID = uid
}

// Name returns the task name constant.
func (cdt *createDatabaseTask) Name() string {
	return CreateDatabaseTaskName
}

// Type returns the message type carried in the base header.
func (cdt *createDatabaseTask) Type() commonpb.MsgType {
	return cdt.Base.MsgType
}

// BeginTs returns the task timestamp; begin and end both read Base.Timestamp.
func (cdt *createDatabaseTask) BeginTs() Timestamp {
	return cdt.Base.Timestamp
}

// EndTs returns the task timestamp; identical to BeginTs.
func (cdt *createDatabaseTask) EndTs() Timestamp {
	return cdt.Base.Timestamp
}

// SetTs sets the task timestamp.
func (cdt *createDatabaseTask) SetTs(ts Timestamp) {
	cdt.Base.Timestamp = ts
}
// OnEnqueue initializes the message base with the CreateDatabase type and
// this node's ID when the task enters the queue.
func (cdt *createDatabaseTask) OnEnqueue() error {
	base := commonpbutil.NewMsgBase()
	base.MsgType = commonpb.MsgType_CreateDatabase
	base.SourceID = paramtable.GetNodeID()
	cdt.Base = base
	return nil
}
// PreExecute validates the database name before the request is forwarded.
func (cdt *createDatabaseTask) PreExecute(ctx context.Context) error {
	return ValidateDatabaseName(cdt.GetDbName())
}
// Execute forwards the CreateDatabase request to RootCoord and records the
// returned status. The previous pre-assignment of an UnexpectedError status
// was a dead store — it was unconditionally overwritten by the call below —
// so it has been removed.
func (cdt *createDatabaseTask) Execute(ctx context.Context) error {
	var err error
	cdt.result, err = cdt.rootCoord.CreateDatabase(ctx, cdt.CreateDatabaseRequest)
	return err
}
// PostExecute is a no-op; nothing to clean up after a create.
func (cdt *createDatabaseTask) PostExecute(ctx context.Context) error {
	return nil
}
// dropDatabaseTask wraps a milvuspb.DropDatabaseRequest and forwards it to
// RootCoord, invalidating the proxy's database cache on success.
type dropDatabaseTask struct {
	Condition
	*milvuspb.DropDatabaseRequest
	ctx       context.Context
	rootCoord types.RootCoord  // coordinator that performs the actual drop
	result    *commonpb.Status // populated by Execute
}
// TraceCtx returns the context the task was created with.
func (ddt *dropDatabaseTask) TraceCtx() context.Context {
	return ddt.ctx
}

// ID returns the message ID from the task's base header.
func (ddt *dropDatabaseTask) ID() UniqueID {
	return ddt.Base.MsgID
}

// SetID stores the message ID assigned to this task.
func (ddt *dropDatabaseTask) SetID(uid UniqueID) {
	ddt.Base.MsgID = uid
}
// Name returns the task name constant.
// Fixed: previously returned DropCollectionTaskName (copy-paste slip from
// the collection task), which misreported this task's name in logs/metrics.
func (ddt *dropDatabaseTask) Name() string {
	return DropDatabaseTaskName
}
// Type returns the message type carried in the base header.
func (ddt *dropDatabaseTask) Type() commonpb.MsgType {
	return ddt.Base.MsgType
}

// BeginTs returns the task timestamp; begin and end both read Base.Timestamp.
func (ddt *dropDatabaseTask) BeginTs() Timestamp {
	return ddt.Base.Timestamp
}

// EndTs returns the task timestamp; identical to BeginTs.
func (ddt *dropDatabaseTask) EndTs() Timestamp {
	return ddt.Base.Timestamp
}

// SetTs sets the task timestamp.
func (ddt *dropDatabaseTask) SetTs(ts Timestamp) {
	ddt.Base.Timestamp = ts
}
// OnEnqueue initializes the message base with the DropDatabase type and
// this node's ID when the task enters the queue.
func (ddt *dropDatabaseTask) OnEnqueue() error {
	base := commonpbutil.NewMsgBase()
	base.MsgType = commonpb.MsgType_DropDatabase
	base.SourceID = paramtable.GetNodeID()
	ddt.Base = base
	return nil
}
// PreExecute validates the database name before the request is forwarded.
func (ddt *dropDatabaseTask) PreExecute(ctx context.Context) error {
	return ValidateDatabaseName(ddt.GetDbName())
}
// Execute forwards the DropDatabase request to RootCoord and, on success,
// evicts the dropped database from the proxy's global meta cache so stale
// entries are not served. The previous pre-assignment of an UnexpectedError
// status was a dead store (unconditionally overwritten by the call below)
// and has been removed.
func (ddt *dropDatabaseTask) Execute(ctx context.Context) error {
	var err error
	ddt.result, err = ddt.rootCoord.DropDatabase(ctx, ddt.DropDatabaseRequest)
	if ddt.result != nil && ddt.result.ErrorCode == commonpb.ErrorCode_Success {
		globalMetaCache.RemoveDatabase(ctx, ddt.DbName)
	}
	return err
}
// PostExecute is a no-op; cache eviction already happened in Execute.
func (ddt *dropDatabaseTask) PostExecute(ctx context.Context) error {
	return nil
}
// listDatabaseTask wraps a milvuspb.ListDatabasesRequest and forwards it
// to RootCoord through the proxy task queue.
type listDatabaseTask struct {
	Condition
	*milvuspb.ListDatabasesRequest
	ctx       context.Context
	rootCoord types.RootCoord               // coordinator that serves the listing
	result    *milvuspb.ListDatabasesResponse // populated by Execute
}
// TraceCtx returns the context the task was created with.
func (ldt *listDatabaseTask) TraceCtx() context.Context {
	return ldt.ctx
}

// ID returns the message ID from the task's base header.
func (ldt *listDatabaseTask) ID() UniqueID {
	return ldt.Base.MsgID
}

// SetID stores the message ID assigned to this task.
func (ldt *listDatabaseTask) SetID(uid UniqueID) {
	ldt.Base.MsgID = uid
}

// Name returns the task name constant.
func (ldt *listDatabaseTask) Name() string {
	return ListDatabaseTaskName
}

// Type returns the message type carried in the base header.
func (ldt *listDatabaseTask) Type() commonpb.MsgType {
	return ldt.Base.MsgType
}

// BeginTs returns the task timestamp; begin and end both read Base.Timestamp.
func (ldt *listDatabaseTask) BeginTs() Timestamp {
	return ldt.Base.Timestamp
}

// EndTs returns the task timestamp; identical to BeginTs.
func (ldt *listDatabaseTask) EndTs() Timestamp {
	return ldt.Base.Timestamp
}

// SetTs sets the task timestamp.
func (ldt *listDatabaseTask) SetTs(ts Timestamp) {
	ldt.Base.Timestamp = ts
}
// OnEnqueue initializes the message base with the ListDatabases type and
// this node's ID when the task enters the queue.
func (ldt *listDatabaseTask) OnEnqueue() error {
	base := commonpbutil.NewMsgBase()
	base.MsgType = commonpb.MsgType_ListDatabases
	base.SourceID = paramtable.GetNodeID()
	ldt.Base = base
	return nil
}
// PreExecute has nothing to validate — listing takes no database name.
func (ldt *listDatabaseTask) PreExecute(ctx context.Context) error {
	return nil
}
// Execute forwards the ListDatabases request to RootCoord and stores the
// response. The previous pre-built response carrying an UnexpectedError
// status was a dead store — it was unconditionally overwritten by the call
// below — so it has been removed.
func (ldt *listDatabaseTask) Execute(ctx context.Context) error {
	var err error
	ldt.result, err = ldt.rootCoord.ListDatabases(ctx, ldt.ListDatabasesRequest)
	return err
}
// PostExecute is a no-op; the response needs no post-processing.
func (ldt *listDatabaseTask) PostExecute(ctx context.Context) error {
	return nil
}

View File

@ -0,0 +1,154 @@
package proxy
import (
"context"
"testing"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
)
// TestCreateDatabaseTask exercises the createDatabaseTask lifecycle against
// the RootCoordMock: accessor values, Execute, OnEnqueue, and PreExecute
// validation failure for an illegal database name.
func TestCreateDatabaseTask(t *testing.T) {
	paramtable.Init()
	rc := NewRootCoordMock()
	rc.Start()
	defer rc.Stop()
	ctx := context.Background()
	task := &createDatabaseTask{
		Condition: NewTaskCondition(ctx),
		CreateDatabaseRequest: &milvuspb.CreateDatabaseRequest{
			Base: &commonpb.MsgBase{
				MsgType:   commonpb.MsgType_CreateDatabase,
				MsgID:     100,
				Timestamp: 100,
			},
			DbName: "db",
		},
		ctx:       ctx,
		rootCoord: rc,
		result:    nil,
	}
	t.Run("ok", func(t *testing.T) {
		err := task.PreExecute(ctx)
		assert.NoError(t, err)
		// Accessors reflect the MsgBase supplied above.
		assert.Equal(t, commonpb.MsgType_CreateDatabase, task.Type())
		assert.Equal(t, UniqueID(100), task.ID())
		assert.Equal(t, Timestamp(100), task.BeginTs())
		assert.Equal(t, Timestamp(100), task.EndTs())
		assert.Equal(t, "db", task.GetDbName())
		err = task.Execute(ctx)
		assert.NoError(t, err)
		// OnEnqueue replaces the base: fresh MsgID (0) and this node's SourceID.
		err = task.OnEnqueue()
		assert.NoError(t, err)
		assert.Equal(t, paramtable.GetNodeID(), task.GetBase().GetSourceID())
		assert.Equal(t, UniqueID(0), task.ID())
	})
	t.Run("pre execute fail", func(t *testing.T) {
		// Illegal characters in the db name must be rejected by PreExecute.
		task.DbName = "#0xc0de"
		err := task.PreExecute(ctx)
		assert.Error(t, err)
	})
}
// TestDropDatabaseTask exercises the dropDatabaseTask lifecycle against the
// RootCoordMock, stubbing the meta cache so Execute's RemoveDatabase call
// (made on success) is absorbed by the mock.
func TestDropDatabaseTask(t *testing.T) {
	paramtable.Init()
	rc := NewRootCoordMock()
	rc.Start()
	defer rc.Stop()
	ctx := context.Background()
	task := &dropDatabaseTask{
		Condition: NewTaskCondition(ctx),
		DropDatabaseRequest: &milvuspb.DropDatabaseRequest{
			Base: &commonpb.MsgBase{
				MsgType:   commonpb.MsgType_DropDatabase,
				MsgID:     100,
				Timestamp: 100,
			},
			DbName: "db",
		},
		ctx:       ctx,
		rootCoord: rc,
		result:    nil,
	}
	// .Maybe() because RemoveDatabase only fires when the drop succeeds.
	cache := NewMockCache(t)
	cache.On("RemoveDatabase",
		mock.Anything, // context.Context
		mock.AnythingOfType("string"),
	).Maybe()
	globalMetaCache = cache
	t.Run("ok", func(t *testing.T) {
		err := task.PreExecute(ctx)
		assert.NoError(t, err)
		// Accessors reflect the MsgBase supplied above.
		assert.Equal(t, commonpb.MsgType_DropDatabase, task.Type())
		assert.Equal(t, UniqueID(100), task.ID())
		assert.Equal(t, Timestamp(100), task.BeginTs())
		assert.Equal(t, Timestamp(100), task.EndTs())
		assert.Equal(t, "db", task.GetDbName())
		err = task.Execute(ctx)
		assert.NoError(t, err)
		// OnEnqueue replaces the base: fresh MsgID (0) and this node's SourceID.
		err = task.OnEnqueue()
		assert.NoError(t, err)
		assert.Equal(t, paramtable.GetNodeID(), task.GetBase().GetSourceID())
		assert.Equal(t, UniqueID(0), task.ID())
	})
	t.Run("pre execute fail", func(t *testing.T) {
		// Illegal characters in the db name must be rejected by PreExecute.
		task.DbName = "#0xc0de"
		err := task.PreExecute(ctx)
		assert.Error(t, err)
	})
}
// TestListDatabaseTask exercises the listDatabaseTask lifecycle against the
// RootCoordMock: accessors, Execute (result must be non-nil), and OnEnqueue.
func TestListDatabaseTask(t *testing.T) {
	paramtable.Init()
	rc := NewRootCoordMock()
	rc.Start()
	defer rc.Stop()
	ctx := context.Background()
	task := &listDatabaseTask{
		Condition: NewTaskCondition(ctx),
		ListDatabasesRequest: &milvuspb.ListDatabasesRequest{
			Base: &commonpb.MsgBase{
				MsgType:   commonpb.MsgType_ListDatabases,
				MsgID:     100,
				Timestamp: 100,
			},
		},
		ctx:       ctx,
		rootCoord: rc,
		result:    nil,
	}
	t.Run("ok", func(t *testing.T) {
		err := task.PreExecute(ctx)
		assert.NoError(t, err)
		// Accessors reflect the MsgBase supplied above.
		assert.Equal(t, commonpb.MsgType_ListDatabases, task.Type())
		assert.Equal(t, UniqueID(100), task.ID())
		assert.Equal(t, Timestamp(100), task.BeginTs())
		assert.Equal(t, Timestamp(100), task.EndTs())
		err = task.Execute(ctx)
		assert.NoError(t, err)
		assert.NotNil(t, task.result)
		// OnEnqueue replaces the base: fresh MsgID (0) and this node's SourceID.
		err = task.OnEnqueue()
		assert.NoError(t, err)
		assert.Equal(t, paramtable.GetNodeID(), task.GetBase().GetSourceID())
		assert.Equal(t, UniqueID(0), task.ID())
	})
}

View File

@ -82,7 +82,7 @@ func (dt *deleteTask) OnEnqueue() error {
} }
func (dt *deleteTask) setChannels() error { func (dt *deleteTask) setChannels() error {
collID, err := globalMetaCache.GetCollectionID(dt.ctx, dt.deleteMsg.CollectionName) collID, err := globalMetaCache.GetCollectionID(dt.ctx, dt.deleteMsg.GetDbName(), dt.deleteMsg.CollectionName)
if err != nil { if err != nil {
return err return err
} }
@ -168,7 +168,7 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error {
log.Info("Invalid collection name", zap.String("collectionName", collName), zap.Error(err)) log.Info("Invalid collection name", zap.String("collectionName", collName), zap.Error(err))
return err return err
} }
collID, err := globalMetaCache.GetCollectionID(ctx, collName) collID, err := globalMetaCache.GetCollectionID(ctx, dt.deleteMsg.GetDbName(), collName)
if err != nil { if err != nil {
log.Info("Failed to get collection id", zap.String("collectionName", collName), zap.Error(err)) log.Info("Failed to get collection id", zap.String("collectionName", collName), zap.Error(err))
return err return err
@ -176,7 +176,7 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error {
dt.deleteMsg.CollectionID = collID dt.deleteMsg.CollectionID = collID
dt.collectionID = collID dt.collectionID = collID
partitionKeyMode, err := isPartitionKeyMode(ctx, collName) partitionKeyMode, err := isPartitionKeyMode(ctx, dt.deleteMsg.GetDbName(), dt.deleteMsg.CollectionName)
if err != nil { if err != nil {
return err return err
} }
@ -191,7 +191,7 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error {
log.Info("Invalid partition name", zap.String("partitionName", partName), zap.Error(err)) log.Info("Invalid partition name", zap.String("partitionName", partName), zap.Error(err))
return err return err
} }
partID, err := globalMetaCache.GetPartitionID(ctx, collName, partName) partID, err := globalMetaCache.GetPartitionID(ctx, dt.deleteMsg.GetDbName(), collName, partName)
if err != nil { if err != nil {
log.Info("Failed to get partition id", zap.String("collectionName", collName), zap.String("partitionName", partName), zap.Error(err)) log.Info("Failed to get partition id", zap.String("collectionName", collName), zap.String("partitionName", partName), zap.Error(err))
return err return err
@ -201,7 +201,7 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error {
dt.deleteMsg.PartitionID = common.InvalidPartitionID dt.deleteMsg.PartitionID = common.InvalidPartitionID
} }
schema, err := globalMetaCache.GetCollectionSchema(ctx, collName) schema, err := globalMetaCache.GetCollectionSchema(ctx, dt.deleteMsg.GetDbName(), collName)
if err != nil { if err != nil {
log.Info("Failed to get collection schema", zap.String("collectionName", collName), zap.Error(err)) log.Info("Failed to get collection schema", zap.String("collectionName", collName), zap.Error(err))
return err return err

View File

@ -6,12 +6,12 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/typeutil"
) )
func Test_getPrimaryKeysFromExpr(t *testing.T) { func Test_getPrimaryKeysFromExpr(t *testing.T) {
@ -48,10 +48,12 @@ func TestDeleteTask(t *testing.T) {
collectionID := UniqueID(0) collectionID := UniqueID(0)
collectionName := "col-0" collectionName := "col-0"
channels := []pChan{"mock-chan-0", "mock-chan-1"} channels := []pChan{"mock-chan-0", "mock-chan-1"}
cache := newMockCache() cache := NewMockCache(t)
cache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) { cache.On("GetCollectionID",
return collectionID, nil mock.Anything, // context.Context
}) mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(collectionID, nil)
globalMetaCache = cache globalMetaCache = cache
chMgr := newMockChannelsMgr() chMgr := newMockChannelsMgr()
chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) { chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {

View File

@ -261,7 +261,7 @@ func (cit *createIndexTask) parseIndexParams() error {
} }
func (cit *createIndexTask) getIndexedField(ctx context.Context) (*schemapb.FieldSchema, error) { func (cit *createIndexTask) getIndexedField(ctx context.Context) (*schemapb.FieldSchema, error) {
schema, err := globalMetaCache.GetCollectionSchema(ctx, cit.req.GetCollectionName()) schema, err := globalMetaCache.GetCollectionSchema(ctx, cit.req.GetDbName(), cit.req.GetCollectionName())
if err != nil { if err != nil {
log.Error("failed to get collection schema", zap.Error(err)) log.Error("failed to get collection schema", zap.Error(err))
return nil, fmt.Errorf("failed to get collection schema: %s", err) return nil, fmt.Errorf("failed to get collection schema: %s", err)
@ -345,7 +345,7 @@ func (cit *createIndexTask) PreExecute(ctx context.Context) error {
collName := cit.req.GetCollectionName() collName := cit.req.GetCollectionName()
collID, err := globalMetaCache.GetCollectionID(ctx, collName) collID, err := globalMetaCache.GetCollectionID(ctx, cit.req.GetDbName(), collName)
if err != nil { if err != nil {
return err return err
} }
@ -459,7 +459,7 @@ func (dit *describeIndexTask) PreExecute(ctx context.Context) error {
return err return err
} }
collID, err := globalMetaCache.GetCollectionID(ctx, dit.CollectionName) collID, err := globalMetaCache.GetCollectionID(ctx, dit.GetDbName(), dit.CollectionName)
if err != nil { if err != nil {
return err return err
} }
@ -468,7 +468,7 @@ func (dit *describeIndexTask) PreExecute(ctx context.Context) error {
} }
func (dit *describeIndexTask) Execute(ctx context.Context) error { func (dit *describeIndexTask) Execute(ctx context.Context) error {
schema, err := globalMetaCache.GetCollectionSchema(ctx, dit.GetCollectionName()) schema, err := globalMetaCache.GetCollectionSchema(ctx, dit.GetDbName(), dit.GetCollectionName())
if err != nil { if err != nil {
log.Error("failed to get collection schema", zap.Error(err)) log.Error("failed to get collection schema", zap.Error(err))
return fmt.Errorf("failed to get collection schema: %s", err) return fmt.Errorf("failed to get collection schema: %s", err)
@ -577,7 +577,7 @@ func (dit *getIndexStatisticsTask) PreExecute(ctx context.Context) error {
return err return err
} }
collID, err := globalMetaCache.GetCollectionID(ctx, dit.CollectionName) collID, err := globalMetaCache.GetCollectionID(ctx, dit.GetDbName(), dit.CollectionName)
if err != nil { if err != nil {
return err return err
} }
@ -586,7 +586,7 @@ func (dit *getIndexStatisticsTask) PreExecute(ctx context.Context) error {
} }
func (dit *getIndexStatisticsTask) Execute(ctx context.Context) error { func (dit *getIndexStatisticsTask) Execute(ctx context.Context) error {
schema, err := globalMetaCache.GetCollectionSchema(ctx, dit.GetCollectionName()) schema, err := globalMetaCache.GetCollectionSchema(ctx, dit.GetDbName(), dit.GetCollectionName())
if err != nil { if err != nil {
log.Error("failed to get collection schema", zap.String("collection_name", dit.GetCollectionName()), zap.Error(err)) log.Error("failed to get collection schema", zap.String("collection_name", dit.GetCollectionName()), zap.Error(err))
return fmt.Errorf("failed to get collection schema: %s", dit.GetCollectionName()) return fmt.Errorf("failed to get collection schema: %s", dit.GetCollectionName())
@ -700,7 +700,7 @@ func (dit *dropIndexTask) PreExecute(ctx context.Context) error {
} }
} }
collID, err := globalMetaCache.GetCollectionID(ctx, dit.CollectionName) collID, err := globalMetaCache.GetCollectionID(ctx, dit.GetDbName(), dit.CollectionName)
if err != nil { if err != nil {
return err return err
} }
@ -801,7 +801,7 @@ func (gibpt *getIndexBuildProgressTask) PreExecute(ctx context.Context) error {
func (gibpt *getIndexBuildProgressTask) Execute(ctx context.Context) error { func (gibpt *getIndexBuildProgressTask) Execute(ctx context.Context) error {
collectionName := gibpt.CollectionName collectionName := gibpt.CollectionName
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName) collectionID, err := globalMetaCache.GetCollectionID(ctx, gibpt.GetDbName(), collectionName)
if err != nil { // err is not nil if collection not exists if err != nil { // err is not nil if collection not exists
return err return err
} }
@ -893,11 +893,7 @@ func (gist *getIndexStateTask) PreExecute(ctx context.Context) error {
} }
func (gist *getIndexStateTask) Execute(ctx context.Context) error { func (gist *getIndexStateTask) Execute(ctx context.Context) error {
collectionID, err := globalMetaCache.GetCollectionID(ctx, gist.GetDbName(), gist.CollectionName)
if gist.IndexName == "" {
gist.IndexName = Params.CommonCfg.DefaultIndexName.GetValue()
}
collectionID, err := globalMetaCache.GetCollectionID(ctx, gist.CollectionName)
if err != nil { if err != nil {
return err return err
} }

View File

@ -22,8 +22,6 @@ import (
"os" "os"
"testing" "testing"
"github.com/milvus-io/milvus/pkg/config"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
@ -35,9 +33,9 @@ import (
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/config"
"github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@ -129,10 +127,12 @@ func TestDropIndexTask_PreExecute(t *testing.T) {
dc := NewDataCoordMock() dc := NewDataCoordMock()
ctx := context.Background() ctx := context.Background()
mockCache := newMockCache() mockCache := NewMockCache(t)
mockCache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) { mockCache.On("GetCollectionID",
return collectionID, nil mock.Anything, // context.Context
}) mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(collectionID, nil)
globalMetaCache = mockCache globalMetaCache = mockCache
dit := dropIndexTask{ dit := dropIndexTask{
@ -161,18 +161,22 @@ func TestDropIndexTask_PreExecute(t *testing.T) {
}) })
t.Run("get collectionID error", func(t *testing.T) { t.Run("get collectionID error", func(t *testing.T) {
mockCache := newMockCache() mockCache := NewMockCache(t)
mockCache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) { mockCache.On("GetCollectionID",
return 0, errors.New("error") mock.Anything, // context.Context
}) mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(UniqueID(0), errors.New("error"))
globalMetaCache = mockCache globalMetaCache = mockCache
err := dit.PreExecute(ctx) err := dit.PreExecute(ctx)
assert.Error(t, err) assert.Error(t, err)
}) })
mockCache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) { mockCache.On("GetCollectionID",
return collectionID, nil mock.Anything, // context.Context
}) mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(collectionID, nil)
globalMetaCache = mockCache globalMetaCache = mockCache
t.Run("coll has been loaded", func(t *testing.T) { t.Run("coll has been loaded", func(t *testing.T) {
@ -238,13 +242,18 @@ func TestCreateIndexTask_PreExecute(t *testing.T) {
dc := NewDataCoordMock() dc := NewDataCoordMock()
ctx := context.Background() ctx := context.Background()
mockCache := newMockCache() mockCache := NewMockCache(t)
mockCache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) { mockCache.On("GetCollectionID",
return collectionID, nil mock.Anything, // context.Context
}) mock.AnythingOfType("string"),
mockCache.setGetSchemaFunc(func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) { mock.AnythingOfType("string"),
return newTestSchema(), nil ).Return(collectionID, nil)
}) mockCache.On("GetCollectionSchema",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(newTestSchema(), nil)
globalMetaCache = mockCache globalMetaCache = mockCache
cit := createIndexTask{ cit := createIndexTask{

View File

@ -74,7 +74,7 @@ func (it *insertTask) EndTs() Timestamp {
} }
func (it *insertTask) setChannels() error { func (it *insertTask) setChannels() error {
collID, err := globalMetaCache.GetCollectionID(it.ctx, it.insertMsg.CollectionName) collID, err := globalMetaCache.GetCollectionID(it.ctx, it.insertMsg.GetDbName(), it.insertMsg.CollectionName)
if err != nil { if err != nil {
return err return err
} }
@ -114,7 +114,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
return err return err
} }
schema, err := globalMetaCache.GetCollectionSchema(ctx, collectionName) schema, err := globalMetaCache.GetCollectionSchema(ctx, it.insertMsg.GetDbName(), collectionName)
if err != nil { if err != nil {
log.Warn("get collection schema from global meta cache failed", zap.String("collectionName", collectionName), zap.Error(err)) log.Warn("get collection schema from global meta cache failed", zap.String("collectionName", collectionName), zap.Error(err))
return err return err
@ -174,7 +174,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
return err return err
} }
partitionKeyMode, err := isPartitionKeyMode(ctx, collectionName) partitionKeyMode, err := isPartitionKeyMode(ctx, it.insertMsg.GetDbName(), collectionName)
if err != nil { if err != nil {
log.Warn("check partition key mode failed", zap.String("collection name", collectionName), zap.Error(err)) log.Warn("check partition key mode failed", zap.String("collection name", collectionName), zap.Error(err))
return err return err
@ -218,7 +218,7 @@ func (it *insertTask) Execute(ctx context.Context) error {
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute insert %d", it.ID())) tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute insert %d", it.ID()))
collectionName := it.insertMsg.CollectionName collectionName := it.insertMsg.CollectionName
collID, err := globalMetaCache.GetCollectionID(ctx, collectionName) collID, err := globalMetaCache.GetCollectionID(it.ctx, it.insertMsg.GetDbName(), it.insertMsg.CollectionName)
if err != nil { if err != nil {
return err return err
} }

View File

@ -7,12 +7,12 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/typeutil"
) )
func TestInsertTask_CheckAligned(t *testing.T) { func TestInsertTask_CheckAligned(t *testing.T) {
@ -230,10 +230,12 @@ func TestInsertTask(t *testing.T) {
collectionID := UniqueID(0) collectionID := UniqueID(0)
collectionName := "col-0" collectionName := "col-0"
channels := []pChan{"mock-chan-0", "mock-chan-1"} channels := []pChan{"mock-chan-0", "mock-chan-1"}
cache := newMockCache() cache := NewMockCache(t)
cache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) { cache.On("GetCollectionID",
return collectionID, nil mock.Anything, // context.Context
}) mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(collectionID, nil)
globalMetaCache = cache globalMetaCache = cache
chMgr := newMockChannelsMgr() chMgr := newMockChannelsMgr()
chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) { chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {

View File

@ -6,32 +6,29 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
"github.com/samber/lo"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/internal/parser/planparserv2" "github.com/golang/protobuf/proto"
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil" "github.com/samber/lo"
"github.com/milvus-io/milvus/pkg/common"
"go.uber.org/zap" "go.uber.org/zap"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
"github.com/milvus-io/milvus/pkg/util/typeutil"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
) )
const ( const (
@ -248,7 +245,7 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
} }
log.Debug("Validate collectionName.") log.Debug("Validate collectionName.")
collID, err := globalMetaCache.GetCollectionID(ctx, collectionName) collID, err := globalMetaCache.GetCollectionID(ctx, t.request.GetDbName(), collectionName)
if err != nil { if err != nil {
log.Warn("Failed to get collection id.", zap.String("collectionName", collectionName), zap.Error(err)) log.Warn("Failed to get collection id.", zap.String("collectionName", collectionName), zap.Error(err))
return err return err
@ -256,7 +253,7 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
t.CollectionID = collID t.CollectionID = collID
log.Debug("Get collection ID by name", zap.Int64("collectionID", t.CollectionID)) log.Debug("Get collection ID by name", zap.Int64("collectionID", t.CollectionID))
t.partitionKeyMode, err = isPartitionKeyMode(ctx, collectionName) t.partitionKeyMode, err = isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName)
if err != nil { if err != nil {
log.Warn("check partition key mode failed", zap.Int64("collectionID", t.CollectionID), zap.Error(err)) log.Warn("check partition key mode failed", zap.Int64("collectionID", t.CollectionID), zap.Error(err))
return err return err
@ -308,7 +305,7 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
t.queryParams = queryParams t.queryParams = queryParams
t.RetrieveRequest.Limit = queryParams.limit + queryParams.offset t.RetrieveRequest.Limit = queryParams.limit + queryParams.offset
schema, _ := globalMetaCache.GetCollectionSchema(ctx, collectionName) schema, _ := globalMetaCache.GetCollectionSchema(ctx, t.request.GetDbName(), t.collectionName)
t.schema = schema t.schema = schema
if t.ids != nil { if t.ids != nil {
@ -332,7 +329,7 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
return err return err
} }
partitionKeys := ParsePartitionKeys(expr) partitionKeys := ParsePartitionKeys(expr)
hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.CollectionName, partitionKeys) hashedPartitionNames, err := assignPartitionKeys(ctx, "", t.request.CollectionName, partitionKeys)
if err != nil { if err != nil {
return err return err
} }
@ -394,6 +391,7 @@ func (t *queryTask) Execute(ctx context.Context) error {
t.resultBuf = typeutil.NewConcurrentSet[*internalpb.RetrieveResults]() t.resultBuf = typeutil.NewConcurrentSet[*internalpb.RetrieveResults]()
err := t.lb.Execute(ctx, CollectionWorkLoad{ err := t.lb.Execute(ctx, CollectionWorkLoad{
db: t.request.GetDbName(),
collection: t.collectionName, collection: t.collectionName,
nq: 1, nq: 1,
exec: t.queryShard, exec: t.queryShard,
@ -466,12 +464,12 @@ func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.Query
result, err := qn.Query(ctx, req) result, err := qn.Query(ctx, req)
if err != nil { if err != nil {
log.Warn("QueryNode query return error", zap.Error(err)) log.Warn("QueryNode query return error", zap.Error(err))
globalMetaCache.DeprecateShardCache(t.collectionName) globalMetaCache.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
return err return err
} }
if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader { if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
log.Warn("QueryNode is not shardLeader") log.Warn("QueryNode is not shardLeader")
globalMetaCache.DeprecateShardCache(t.collectionName) globalMetaCache.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
return errInvalidShardLeaders return errInvalidShardLeaders
} }
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {

View File

@ -120,7 +120,7 @@ func TestQueryTask_all(t *testing.T) {
require.NoError(t, createColT.Execute(ctx)) require.NoError(t, createColT.Execute(ctx))
require.NoError(t, createColT.PostExecute(ctx)) require.NoError(t, createColT.PostExecute(ctx))
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName) collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
assert.NoError(t, err) assert.NoError(t, err)
status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{ status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{

View File

@ -25,12 +25,12 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
) )
func TestBaseTaskQueue(t *testing.T) { func TestBaseTaskQueue(t *testing.T) {
@ -567,10 +567,12 @@ func TestTaskScheduler_concurrentPushAndPop(t *testing.T) {
collectionID := UniqueID(0) collectionID := UniqueID(0)
collectionName := "col-0" collectionName := "col-0"
channels := []pChan{"mock-chan-0", "mock-chan-1"} channels := []pChan{"mock-chan-0", "mock-chan-1"}
cache := newMockCache() cache := NewMockCache(t)
cache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) { cache.On("GetCollectionID",
return collectionID, nil mock.Anything, // context.Context
}) mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(collectionID, nil)
globalMetaCache = cache globalMetaCache = cache
tsoAllocatorIns := newMockTsoAllocator() tsoAllocatorIns := newMockTsoAllocator()
factory := newSimpleMockMsgStreamFactory() factory := newSimpleMockMsgStreamFactory()

View File

@ -75,7 +75,7 @@ func getPartitionIDs(ctx context.Context, collectionName string, partitionNames
} }
} }
partitionsMap, err := globalMetaCache.GetPartitions(ctx, collectionName) partitionsMap, err := globalMetaCache.GetPartitions(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -213,16 +213,16 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
collectionName := t.request.CollectionName collectionName := t.request.CollectionName
t.collectionName = collectionName t.collectionName = collectionName
collID, err := globalMetaCache.GetCollectionID(ctx, collectionName) collID, err := globalMetaCache.GetCollectionID(ctx, t.request.GetDbName(), collectionName)
if err != nil { // err is not nil if collection not exists if err != nil { // err is not nil if collection not exists
return err return err
} }
t.SearchRequest.DbID = 0 // todo t.SearchRequest.DbID = 0 // todo
t.SearchRequest.CollectionID = collID t.SearchRequest.CollectionID = collID
t.schema, _ = globalMetaCache.GetCollectionSchema(ctx, collectionName) t.schema, _ = globalMetaCache.GetCollectionSchema(ctx, t.request.GetDbName(), collectionName)
partitionKeyMode, err := isPartitionKeyMode(ctx, collectionName) partitionKeyMode, err := isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName)
if err != nil { if err != nil {
return err return err
} }
@ -305,7 +305,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
return err return err
} }
partitionKeys := ParsePartitionKeys(expr) partitionKeys := ParsePartitionKeys(expr)
hashedPartitionNames, err := assignPartitionKeys(ctx, collectionName, partitionKeys) hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), collectionName, partitionKeys)
if err != nil { if err != nil {
return err return err
} }
@ -352,7 +352,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
if err != nil { if err != nil {
return err return err
} }
collectionInfo, err2 := globalMetaCache.GetCollectionInfo(ctx, collectionName) collectionInfo, err2 := globalMetaCache.GetCollectionInfo(ctx, t.request.GetDbName(), collectionName)
if err2 != nil { if err2 != nil {
log.Ctx(ctx).Debug("Proxy::searchTask::PreExecute failed to GetCollectionInfo from cache", log.Ctx(ctx).Debug("Proxy::searchTask::PreExecute failed to GetCollectionInfo from cache",
zap.Any("collectionName", collectionName), zap.Error(err2)) zap.Any("collectionName", collectionName), zap.Error(err2))
@ -404,6 +404,7 @@ func (t *searchTask) Execute(ctx context.Context) error {
t.resultBuf = typeutil.NewConcurrentSet[*internalpb.SearchResults]() t.resultBuf = typeutil.NewConcurrentSet[*internalpb.SearchResults]()
err := t.lb.Execute(ctx, CollectionWorkLoad{ err := t.lb.Execute(ctx, CollectionWorkLoad{
db: t.request.GetDbName(),
collection: t.collectionName, collection: t.collectionName,
nq: t.Nq, nq: t.Nq,
exec: t.searchShard, exec: t.searchShard,

View File

@ -1603,7 +1603,7 @@ func TestSearchTask_ErrExecute(t *testing.T) {
require.NoError(t, createColT.Execute(ctx)) require.NoError(t, createColT.Execute(ctx))
require.NoError(t, createColT.PostExecute(ctx)) require.NoError(t, createColT.PostExecute(ctx))
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName) collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
assert.NoError(t, err) assert.NoError(t, err)
successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}

View File

@ -110,7 +110,7 @@ func (g *getStatisticsTask) PreExecute(ctx context.Context) error {
g.Base.MsgType = commonpb.MsgType_GetPartitionStatistics g.Base.MsgType = commonpb.MsgType_GetPartitionStatistics
g.Base.SourceID = paramtable.GetNodeID() g.Base.SourceID = paramtable.GetNodeID()
collID, err := globalMetaCache.GetCollectionID(ctx, g.collectionName) collID, err := globalMetaCache.GetCollectionID(ctx, g.request.GetDbName(), g.collectionName)
if err != nil { // err is not nil if collection not exists if err != nil { // err is not nil if collection not exists
return err return err
} }
@ -266,6 +266,7 @@ func (g *getStatisticsTask) getStatisticsFromQueryNode(ctx context.Context) erro
g.resultBuf = typeutil.NewConcurrentSet[*internalpb.GetStatisticsResponse]() g.resultBuf = typeutil.NewConcurrentSet[*internalpb.GetStatisticsResponse]()
} }
err := g.lb.Execute(ctx, CollectionWorkLoad{ err := g.lb.Execute(ctx, CollectionWorkLoad{
db: g.request.GetDbName(),
collection: g.collectionName, collection: g.collectionName,
nq: 1, nq: 1,
exec: g.getStatisticsShard, exec: g.getStatisticsShard,
@ -292,21 +293,21 @@ func (g *getStatisticsTask) getStatisticsShard(ctx context.Context, nodeID int64
zap.Int64("nodeID", nodeID), zap.Int64("nodeID", nodeID),
zap.Strings("channels", channelIDs), zap.Strings("channels", channelIDs),
zap.Error(err)) zap.Error(err))
globalMetaCache.DeprecateShardCache(g.collectionName) globalMetaCache.DeprecateShardCache(g.request.GetDbName(), g.collectionName)
return err return err
} }
if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader { if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
log.Warn("QueryNode is not shardLeader", log.Warn("QueryNode is not shardLeader",
zap.Int64("nodeID", nodeID), zap.Int64("nodeID", nodeID),
zap.Strings("channels", channelIDs)) zap.Strings("channels", channelIDs))
globalMetaCache.DeprecateShardCache(g.collectionName) globalMetaCache.DeprecateShardCache(g.request.GetDbName(), g.collectionName)
return errInvalidShardLeaders return errInvalidShardLeaders
} }
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("QueryNode statistic result error", log.Warn("QueryNode statistic result error",
zap.Int64("nodeID", nodeID), zap.Int64("nodeID", nodeID),
zap.String("reason", result.GetStatus().GetReason())) zap.String("reason", result.GetStatus().GetReason()))
globalMetaCache.DeprecateShardCache(g.collectionName) globalMetaCache.DeprecateShardCache(g.request.GetDbName(), g.collectionName)
return fmt.Errorf("fail to get statistic, QueryNode ID=%d, reason=%s", nodeID, result.GetStatus().GetReason()) return fmt.Errorf("fail to get statistic, QueryNode ID=%d, reason=%s", nodeID, result.GetStatus().GetReason())
} }
g.resultBuf.Insert(result) g.resultBuf.Insert(result)
@ -321,7 +322,7 @@ func checkFullLoaded(ctx context.Context, qc types.QueryCoord, collectionName st
var unloadPartitionIDs []UniqueID var unloadPartitionIDs []UniqueID
// TODO: Consider to check if partition loaded from cache to save rpc. // TODO: Consider to check if partition loaded from cache to save rpc.
info, err := globalMetaCache.GetCollectionInfo(ctx, collectionName) info, err := globalMetaCache.GetCollectionInfo(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("GetCollectionInfo failed, collection = %s, err = %s", collectionName, err) return nil, nil, fmt.Errorf("GetCollectionInfo failed, collection = %s, err = %s", collectionName, err)
} }
@ -643,7 +644,7 @@ func (g *getCollectionStatisticsTask) PreExecute(ctx context.Context) error {
} }
func (g *getCollectionStatisticsTask) Execute(ctx context.Context) error { func (g *getCollectionStatisticsTask) Execute(ctx context.Context) error {
collID, err := globalMetaCache.GetCollectionID(ctx, g.CollectionName) collID, err := globalMetaCache.GetCollectionID(ctx, g.GetDbName(), g.CollectionName)
if err != nil { if err != nil {
return err return err
} }
@ -731,12 +732,12 @@ func (g *getPartitionStatisticsTask) PreExecute(ctx context.Context) error {
} }
func (g *getPartitionStatisticsTask) Execute(ctx context.Context) error { func (g *getPartitionStatisticsTask) Execute(ctx context.Context) error {
collID, err := globalMetaCache.GetCollectionID(ctx, g.CollectionName) collID, err := globalMetaCache.GetCollectionID(ctx, g.GetDbName(), g.CollectionName)
if err != nil { if err != nil {
return err return err
} }
g.collectionID = collID g.collectionID = collID
partitionID, err := globalMetaCache.GetPartitionID(ctx, g.CollectionName, g.PartitionName) partitionID, err := globalMetaCache.GetPartitionID(ctx, g.GetDbName(), g.CollectionName, g.PartitionName)
if err != nil { if err != nil {
return err return err
} }

View File

@ -124,7 +124,7 @@ func (s *StatisticTaskSuite) loadCollection() {
s.NoError(createColT.Execute(ctx)) s.NoError(createColT.Execute(ctx))
s.NoError(createColT.PostExecute(ctx)) s.NoError(createColT.PostExecute(ctx))
collectionID, err := globalMetaCache.GetCollectionID(ctx, s.collection) collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), s.collection)
s.NoError(err) s.NoError(err)
status, err := s.qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{ status, err := s.qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{

View File

@ -21,7 +21,6 @@ import (
"context" "context"
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"math/rand" "math/rand"
"strconv" "strconv"
"testing" "testing"
@ -932,7 +931,7 @@ func TestHasCollectionTask(t *testing.T) {
assert.Equal(t, false, task.result.Value) assert.Equal(t, false, task.result.Value)
// createCollection in RootCood and fill GlobalMetaCache // createCollection in RootCood and fill GlobalMetaCache
rc.CreateCollection(ctx, createColReq) rc.CreateCollection(ctx, createColReq)
globalMetaCache.GetCollectionID(ctx, collectionName) globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
// success to drop collection // success to drop collection
err = task.Execute(ctx) err = task.Execute(ctx)
@ -1051,7 +1050,7 @@ func TestDescribeCollectionTask_ShardsNum1(t *testing.T) {
} }
rc.CreateCollection(ctx, createColReq) rc.CreateCollection(ctx, createColReq)
globalMetaCache.GetCollectionID(ctx, collectionName) globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
//CreateCollection //CreateCollection
task := &describeCollectionTask{ task := &describeCollectionTask{
@ -1115,7 +1114,7 @@ func TestDescribeCollectionTask_EnableDynamicSchema(t *testing.T) {
} }
rc.CreateCollection(ctx, createColReq) rc.CreateCollection(ctx, createColReq)
globalMetaCache.GetCollectionID(ctx, collectionName) globalMetaCache.GetCollectionID(ctx, dbName, collectionName)
//CreateCollection //CreateCollection
task := &describeCollectionTask{ task := &describeCollectionTask{
@ -1178,7 +1177,7 @@ func TestDescribeCollectionTask_ShardsNum2(t *testing.T) {
} }
rc.CreateCollection(ctx, createColReq) rc.CreateCollection(ctx, createColReq)
globalMetaCache.GetCollectionID(ctx, collectionName) globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
//CreateCollection //CreateCollection
task := &describeCollectionTask{ task := &describeCollectionTask{
@ -1276,10 +1275,25 @@ func TestDropPartitionTask(t *testing.T) {
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
}, nil) }, nil)
mockCache := newMockCache()
mockCache.setGetPartitionIDFunc(func(ctx context.Context, collectionName string, partitionName string) (typeutil.UniqueID, error) { mockCache := NewMockCache(t)
return 1, nil mockCache.On("GetCollectionID",
}) mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(UniqueID(1), nil)
mockCache.On("GetPartitionID",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(UniqueID(1), nil)
mockCache.On("GetCollectionSchema",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(&schemapb.CollectionSchema{}, nil)
globalMetaCache = mockCache globalMetaCache = mockCache
task := &dropPartitionTask{ task := &dropPartitionTask{
@ -1319,13 +1333,18 @@ func TestDropPartitionTask(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
t.Run("get collectionID error", func(t *testing.T) { t.Run("get collectionID error", func(t *testing.T) {
mockCache := newMockCache() mockCache := NewMockCache(t)
mockCache.setGetPartitionIDFunc(func(ctx context.Context, collectionName string, partitionName string) (typeutil.UniqueID, error) { mockCache.On("GetCollectionID",
return 1, nil mock.Anything, // context.Context
}) mock.AnythingOfType("string"),
mockCache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) { mock.AnythingOfType("string"),
return 0, errors.New("error") ).Return(UniqueID(1), errors.New("error"))
}) mockCache.On("GetCollectionSchema",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(&schemapb.CollectionSchema{}, nil)
globalMetaCache = mockCache globalMetaCache = mockCache
task.PartitionName = "partition1" task.PartitionName = "partition1"
err = task.PreExecute(ctx) err = task.PreExecute(ctx)
@ -1335,13 +1354,24 @@ func TestDropPartitionTask(t *testing.T) {
t.Run("partition not exist", func(t *testing.T) { t.Run("partition not exist", func(t *testing.T) {
task.PartitionName = "partition2" task.PartitionName = "partition2"
mockCache := newMockCache() mockCache := NewMockCache(t)
mockCache.setGetPartitionIDFunc(func(ctx context.Context, collectionName string, partitionName string) (typeutil.UniqueID, error) { mockCache.On("GetPartitionID",
return 0, merr.WrapErrPartitionNotFound(partitionName) mock.Anything, // context.Context
}) mock.AnythingOfType("string"),
mockCache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) { mock.AnythingOfType("string"),
return 1, nil mock.AnythingOfType("string"),
}) ).Return(UniqueID(0), merr.WrapErrPartitionNotFound(partitionName))
mockCache.On("GetCollectionID",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(UniqueID(1), nil)
mockCache.On("GetCollectionSchema",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(&schemapb.CollectionSchema{}, nil)
globalMetaCache = mockCache globalMetaCache = mockCache
err = task.PreExecute(ctx) err = task.PreExecute(ctx)
assert.NoError(t, err) assert.NoError(t, err)
@ -1350,13 +1380,24 @@ func TestDropPartitionTask(t *testing.T) {
t.Run("get partition error", func(t *testing.T) { t.Run("get partition error", func(t *testing.T) {
task.PartitionName = "partition3" task.PartitionName = "partition3"
mockCache := newMockCache() mockCache := NewMockCache(t)
mockCache.setGetPartitionIDFunc(func(ctx context.Context, collectionName string, partitionName string) (typeutil.UniqueID, error) { mockCache.On("GetPartitionID",
return 0, errors.New("error") mock.Anything, // context.Context
}) mock.AnythingOfType("string"),
mockCache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) { mock.AnythingOfType("string"),
return 1, nil mock.AnythingOfType("string"),
}) ).Return(UniqueID(0), errors.New("error"))
mockCache.On("GetCollectionID",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(UniqueID(1), nil)
mockCache.On("GetCollectionSchema",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(&schemapb.CollectionSchema{}, nil)
globalMetaCache = mockCache globalMetaCache = mockCache
err = task.PreExecute(ctx) err = task.PreExecute(ctx)
assert.Error(t, err) assert.Error(t, err)
@ -1536,7 +1577,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
}) })
}) })
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName) collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
assert.NoError(t, err) assert.NoError(t, err)
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc) dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
@ -1790,7 +1831,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
}) })
}) })
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName) collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
assert.NoError(t, err) assert.NoError(t, err)
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc) dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
@ -2194,27 +2235,30 @@ func Test_createIndexTask_getIndexedField(t *testing.T) {
} }
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
cache := newMockCache() cache := NewMockCache(t)
cache.setGetSchemaFunc(func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) { cache.On("GetCollectionSchema",
return &schemapb.CollectionSchema{ mock.Anything, // context.Context
Fields: []*schemapb.FieldSchema{ mock.AnythingOfType("string"),
{ mock.AnythingOfType("string"),
FieldID: 100, ).Return(&schemapb.CollectionSchema{
Name: fieldName, Fields: []*schemapb.FieldSchema{
IsPrimaryKey: false, {
DataType: schemapb.DataType_FloatVector, FieldID: 100,
TypeParams: nil, Name: fieldName,
IndexParams: []*commonpb.KeyValuePair{ IsPrimaryKey: false,
{ DataType: schemapb.DataType_FloatVector,
Key: common.DimKey, TypeParams: nil,
Value: "128", IndexParams: []*commonpb.KeyValuePair{
}, {
Key: "dim",
Value: "128",
}, },
AutoID: false,
}, },
AutoID: false,
}, },
}, nil },
}) }, nil)
globalMetaCache = cache globalMetaCache = cache
field, err := cit.getIndexedField(context.Background()) field, err := cit.getIndexedField(context.Background())
assert.NoError(t, err) assert.NoError(t, err)
@ -2222,45 +2266,51 @@ func Test_createIndexTask_getIndexedField(t *testing.T) {
}) })
t.Run("schema not found", func(t *testing.T) { t.Run("schema not found", func(t *testing.T) {
cache := newMockCache() cache := NewMockCache(t)
cache.setGetSchemaFunc(func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) { cache.On("GetCollectionSchema",
return nil, errors.New("mock") mock.Anything, // context.Context
}) mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(nil, errors.New("mock"))
globalMetaCache = cache globalMetaCache = cache
_, err := cit.getIndexedField(context.Background()) _, err := cit.getIndexedField(context.Background())
assert.Error(t, err) assert.Error(t, err)
}) })
t.Run("invalid schema", func(t *testing.T) { t.Run("invalid schema", func(t *testing.T) {
cache := newMockCache() cache := NewMockCache(t)
cache.setGetSchemaFunc(func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) { cache.On("GetCollectionSchema",
return &schemapb.CollectionSchema{ mock.Anything, // context.Context
Fields: []*schemapb.FieldSchema{ mock.AnythingOfType("string"),
{ mock.AnythingOfType("string"),
Name: fieldName, ).Return(&schemapb.CollectionSchema{
}, Fields: []*schemapb.FieldSchema{
{ {
Name: fieldName, // duplicate Name: fieldName,
},
}, },
}, nil {
}) Name: fieldName, // duplicate
},
},
}, nil)
globalMetaCache = cache globalMetaCache = cache
_, err := cit.getIndexedField(context.Background()) _, err := cit.getIndexedField(context.Background())
assert.Error(t, err) assert.Error(t, err)
}) })
t.Run("field not found", func(t *testing.T) { t.Run("field not found", func(t *testing.T) {
cache := newMockCache() cache := NewMockCache(t)
cache.setGetSchemaFunc(func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) { cache.On("GetCollectionSchema",
return &schemapb.CollectionSchema{ mock.Anything, // context.Context
Fields: []*schemapb.FieldSchema{ mock.AnythingOfType("string"),
{ mock.AnythingOfType("string"),
Name: fieldName + fieldName, ).Return(&schemapb.CollectionSchema{
}, Fields: []*schemapb.FieldSchema{
{
Name: fieldName + fieldName,
}, },
}, nil },
}) }, nil)
globalMetaCache = cache globalMetaCache = cache
_, err := cit.getIndexedField(context.Background()) _, err := cit.getIndexedField(context.Background())
assert.Error(t, err) assert.Error(t, err)
@ -2392,30 +2442,34 @@ func Test_createIndexTask_PreExecute(t *testing.T) {
} }
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
cache := newMockCache() cache := NewMockCache(t)
cache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) { cache.On("GetCollectionID",
return 100, nil mock.Anything, // context.Context
}) mock.AnythingOfType("string"),
cache.setGetSchemaFunc(func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) { mock.AnythingOfType("string"),
return &schemapb.CollectionSchema{ ).Return(UniqueID(100), nil)
Fields: []*schemapb.FieldSchema{ cache.On("GetCollectionSchema",
{ mock.Anything, // context.Context
FieldID: 100, mock.AnythingOfType("string"),
Name: fieldName, mock.AnythingOfType("string"),
IsPrimaryKey: false, ).Return(&schemapb.CollectionSchema{
DataType: schemapb.DataType_FloatVector, Fields: []*schemapb.FieldSchema{
TypeParams: nil, {
IndexParams: []*commonpb.KeyValuePair{ FieldID: 100,
{ Name: fieldName,
Key: common.DimKey, IsPrimaryKey: false,
Value: "128", DataType: schemapb.DataType_FloatVector,
}, TypeParams: nil,
IndexParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "128",
}, },
AutoID: false,
}, },
AutoID: false,
}, },
}, nil },
}) }, nil)
globalMetaCache = cache globalMetaCache = cache
cit.req.ExtraParams = []*commonpb.KeyValuePair{ cit.req.ExtraParams = []*commonpb.KeyValuePair{
{ {
@ -2435,19 +2489,23 @@ func Test_createIndexTask_PreExecute(t *testing.T) {
}) })
t.Run("collection not found", func(t *testing.T) { t.Run("collection not found", func(t *testing.T) {
cache := newMockCache() cache := NewMockCache(t)
cache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) { cache.On("GetCollectionID",
return 0, errors.New("mock") mock.Anything, // context.Context
}) mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(UniqueID(0), errors.New("mock"))
globalMetaCache = cache globalMetaCache = cache
assert.Error(t, cit.PreExecute(context.Background())) assert.Error(t, cit.PreExecute(context.Background()))
}) })
t.Run("index name length exceed 255", func(t *testing.T) { t.Run("index name length exceed 255", func(t *testing.T) {
cache := newMockCache() cache := NewMockCache(t)
cache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) { cache.On("GetCollectionID",
return 100, nil mock.Anything, // context.Context
}) mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(UniqueID(100), nil)
globalMetaCache = cache globalMetaCache = cache
for i := 0; i < 256; i++ { for i := 0; i < 256; i++ {
@ -2871,7 +2929,7 @@ func TestTransferReplicaTask(t *testing.T) {
mgr := newShardClientMgr() mgr := newShardClientMgr()
InitMetaCache(ctx, rc, qc, mgr) InitMetaCache(ctx, rc, qc, mgr)
// make it avoid remote call on rc // make it avoid remote call on rc
globalMetaCache.GetCollectionSchema(context.Background(), "collection1") globalMetaCache.GetCollectionSchema(context.Background(), GetCurDBNameFromContextOrDefault(ctx), "collection1")
req := &milvuspb.TransferReplicaRequest{ req := &milvuspb.TransferReplicaRequest{
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
@ -2966,8 +3024,8 @@ func TestDescribeResourceGroupTask(t *testing.T) {
mgr := newShardClientMgr() mgr := newShardClientMgr()
InitMetaCache(ctx, rc, qc, mgr) InitMetaCache(ctx, rc, qc, mgr)
// make it avoid remote call on rc // make it avoid remote call on rc
globalMetaCache.GetCollectionSchema(context.Background(), "collection1") globalMetaCache.GetCollectionSchema(context.Background(), GetCurDBNameFromContextOrDefault(ctx), "collection1")
globalMetaCache.GetCollectionSchema(context.Background(), "collection2") globalMetaCache.GetCollectionSchema(context.Background(), GetCurDBNameFromContextOrDefault(ctx), "collection2")
req := &milvuspb.DescribeResourceGroupRequest{ req := &milvuspb.DescribeResourceGroupRequest{
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
@ -3014,8 +3072,8 @@ func TestDescribeResourceGroupTaskFailed(t *testing.T) {
mgr := newShardClientMgr() mgr := newShardClientMgr()
InitMetaCache(ctx, rc, qc, mgr) InitMetaCache(ctx, rc, qc, mgr)
// make it avoid remote call on rc // make it avoid remote call on rc
globalMetaCache.GetCollectionSchema(context.Background(), "collection1") globalMetaCache.GetCollectionSchema(context.Background(), GetCurDBNameFromContextOrDefault(ctx), "collection1")
globalMetaCache.GetCollectionSchema(context.Background(), "collection2") globalMetaCache.GetCollectionSchema(context.Background(), GetCurDBNameFromContextOrDefault(ctx), "collection2")
req := &milvuspb.DescribeResourceGroupRequest{ req := &milvuspb.DescribeResourceGroupRequest{
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
@ -3180,7 +3238,7 @@ func TestCreateCollectionTaskWithPartitionKey(t *testing.T) {
// check default partitions // check default partitions
err = InitMetaCache(ctx, rc, nil, nil) err = InitMetaCache(ctx, rc, nil, nil)
assert.NoError(t, err) assert.NoError(t, err)
partitionNames, err := getDefaultPartitionNames(ctx, task.CollectionName) partitionNames, err := getDefaultPartitionNames(ctx, "", task.CollectionName)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, task.GetNumPartitions(), int64(len(partitionNames))) assert.Equal(t, task.GetNumPartitions(), int64(len(partitionNames)))
@ -3310,7 +3368,7 @@ func TestPartitionKey(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
}) })
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName) collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
assert.NoError(t, err) assert.NoError(t, err)
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc) dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
@ -3341,7 +3399,7 @@ func TestPartitionKey(t *testing.T) {
_ = segAllocator.Start() _ = segAllocator.Start()
defer segAllocator.Close() defer segAllocator.Close()
partitionNames, err := getDefaultPartitionNames(ctx, collectionName) partitionNames, err := getDefaultPartitionsInPartitionKeyMode(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, common.DefaultPartitionsWithPartitionKey, int64(len(partitionNames))) assert.Equal(t, common.DefaultPartitionsWithPartitionKey, int64(len(partitionNames)))

View File

@ -115,7 +115,7 @@ func (it *upsertTask) getPChanStats() (map[pChan]pChanStatistics, error) {
} }
func (it *upsertTask) setChannels() error { func (it *upsertTask) setChannels() error {
collID, err := globalMetaCache.GetCollectionID(it.ctx, it.req.CollectionName) collID, err := globalMetaCache.GetCollectionID(it.ctx, it.req.GetDbName(), it.req.CollectionName)
if err != nil { if err != nil {
return err return err
} }
@ -224,7 +224,7 @@ func (it *upsertTask) deletePreExecute(ctx context.Context) error {
log.Info("Invalid collection name", zap.Error(err)) log.Info("Invalid collection name", zap.Error(err))
return err return err
} }
collID, err := globalMetaCache.GetCollectionID(ctx, collName) collID, err := globalMetaCache.GetCollectionID(ctx, it.req.GetDbName(), collName)
if err != nil { if err != nil {
log.Info("Failed to get collection id", zap.Error(err)) log.Info("Failed to get collection id", zap.Error(err))
return err return err
@ -244,7 +244,7 @@ func (it *upsertTask) deletePreExecute(ctx context.Context) error {
log.Warn("Invalid partition name", zap.String("partitionName", partName), zap.Error(err)) log.Warn("Invalid partition name", zap.String("partitionName", partName), zap.Error(err))
return err return err
} }
partID, err := globalMetaCache.GetPartitionID(ctx, collName, partName) partID, err := globalMetaCache.GetPartitionID(ctx, it.req.GetDbName(), collName, partName)
if err != nil { if err != nil {
log.Warn("Failed to get partition id", zap.String("collectionName", collName), zap.String("partitionName", partName), zap.Error(err)) log.Warn("Failed to get partition id", zap.String("collectionName", collName), zap.String("partitionName", partName), zap.Error(err))
return err return err
@ -277,7 +277,7 @@ func (it *upsertTask) PreExecute(ctx context.Context) error {
Timestamp: it.EndTs(), Timestamp: it.EndTs(),
} }
schema, err := globalMetaCache.GetCollectionSchema(ctx, collectionName) schema, err := globalMetaCache.GetCollectionSchema(ctx, it.req.GetDbName(), collectionName)
if err != nil { if err != nil {
log.Warn("Failed to get collection schema", log.Warn("Failed to get collection schema",
zap.String("collectionName", collectionName), zap.String("collectionName", collectionName),
@ -286,7 +286,7 @@ func (it *upsertTask) PreExecute(ctx context.Context) error {
} }
it.schema = schema it.schema = schema
it.partitionKeyMode, err = isPartitionKeyMode(ctx, collectionName) it.partitionKeyMode, err = isPartitionKeyMode(ctx, it.req.GetDbName(), collectionName)
if err != nil { if err != nil {
log.Warn("check partition key mode failed", log.Warn("check partition key mode failed",
zap.String("collectionName", collectionName), zap.String("collectionName", collectionName),
@ -319,6 +319,7 @@ func (it *upsertTask) PreExecute(ctx context.Context) error {
FieldsData: it.req.FieldsData, FieldsData: it.req.FieldsData,
NumRows: uint64(it.req.NumRows), NumRows: uint64(it.req.NumRows),
Version: msgpb.InsertDataVersion_ColumnBased, Version: msgpb.InsertDataVersion_ColumnBased,
DbName: it.req.DbName,
}, },
}, },
DeleteMsg: &msgstream.DeleteMsg{ DeleteMsg: &msgstream.DeleteMsg{
@ -364,7 +365,7 @@ func (it *upsertTask) insertExecute(ctx context.Context, msgPack *msgstream.MsgP
defer tr.Elapse("insert execute done when insertExecute") defer tr.Elapse("insert execute done when insertExecute")
collectionName := it.upsertMsg.InsertMsg.CollectionName collectionName := it.upsertMsg.InsertMsg.CollectionName
collID, err := globalMetaCache.GetCollectionID(ctx, collectionName) collID, err := globalMetaCache.GetCollectionID(ctx, it.req.GetDbName(), collectionName)
if err != nil { if err != nil {
return err return err
} }

View File

@ -21,6 +21,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
@ -28,7 +29,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
) )
func TestUpsertTask_CheckAligned(t *testing.T) { func TestUpsertTask_CheckAligned(t *testing.T) {
@ -299,11 +299,14 @@ func TestUpsertTask(t *testing.T) {
collectionID := UniqueID(0) collectionID := UniqueID(0)
collectionName := "col-0" collectionName := "col-0"
channels := []pChan{"mock-chan-0", "mock-chan-1"} channels := []pChan{"mock-chan-0", "mock-chan-1"}
cache := newMockCache() cache := NewMockCache(t)
cache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) { cache.On("GetCollectionID",
return collectionID, nil mock.Anything, // context.Context
}) mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(collectionID, nil)
globalMetaCache = cache globalMetaCache = cache
chMgr := newMockChannelsMgr() chMgr := newMockChannelsMgr()
chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) { chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
return channels, nil return channels, nil

View File

@ -165,6 +165,32 @@ func ValidateResourceGroupName(entity string) error {
return nil return nil
} }
func ValidateDatabaseName(dbName string) error {
if dbName == "" {
return merr.WrapErrInvalidedDatabaseName(dbName, "database name couldn't be empty")
}
if len(dbName) > Params.ProxyCfg.MaxNameLength.GetAsInt() {
return merr.WrapErrInvalidedDatabaseName(dbName,
fmt.Sprintf("the length of a database name must be less than %d characters", Params.ProxyCfg.MaxNameLength.GetAsInt()))
}
firstChar := dbName[0]
if firstChar != '_' && !isAlpha(firstChar) {
return merr.WrapErrInvalidedDatabaseName(dbName,
"the first character of a database name must be an underscore or letter")
}
for i := 1; i < len(dbName); i++ {
c := dbName[i]
if c != '_' && !isAlpha(c) && !isNumber(c) {
return merr.WrapErrInvalidedDatabaseName(dbName,
"database name can only contain numbers, letters and underscores")
}
}
return nil
}
// ValidateCollectionAlias returns true if collAlias is a valid alias name for collection, otherwise returns false. // ValidateCollectionAlias returns true if collAlias is a valid alias name for collection, otherwise returns false.
func ValidateCollectionAlias(collAlias string) error { func ValidateCollectionAlias(collAlias string) error {
return validateCollectionNameOrAlias(collAlias, "alias") return validateCollectionNameOrAlias(collAlias, "alias")
@ -810,6 +836,18 @@ func GetCurUserFromContext(ctx context.Context) (string, error) {
return username, nil return username, nil
} }
func GetCurDBNameFromContextOrDefault(ctx context.Context) string {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return util.DefaultDBName
}
dbNameData := md[strings.ToLower(util.HeaderDBName)]
if len(dbNameData) < 1 || dbNameData[0] == "" {
return util.DefaultDBName
}
return dbNameData[0]
}
func GetRole(username string) ([]string, error) { func GetRole(username string) ([]string, error) {
if globalMetaCache == nil { if globalMetaCache == nil {
return []string{}, merr.WrapErrServiceUnavailable("internal: Milvus Proxy is not ready yet. please wait") return []string{}, merr.WrapErrServiceUnavailable("internal: Milvus Proxy is not ready yet. please wait")
@ -1141,7 +1179,7 @@ func getCollectionProgress(
resp, err := queryCoord.ShowCollections(ctx, &querypb.ShowCollectionsRequest{ resp, err := queryCoord.ShowCollections(ctx, &querypb.ShowCollectionsRequest{
Base: commonpbutil.UpdateMsgBase( Base: commonpbutil.UpdateMsgBase(
msgBase, msgBase,
commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection), commonpbutil.WithMsgType(commonpb.MsgType_ShowCollections),
), ),
CollectionIDs: []int64{collectionID}, CollectionIDs: []int64{collectionID},
}) })
@ -1191,14 +1229,16 @@ func getPartitionProgress(
partitionIDs := make([]int64, 0) partitionIDs := make([]int64, 0)
for _, partitionName := range partitionNames { for _, partitionName := range partitionNames {
var partitionID int64 var partitionID int64
partitionID, err = globalMetaCache.GetPartitionID(ctx, collectionName, partitionName) partitionID, err = globalMetaCache.GetPartitionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName, partitionName)
if err != nil { if err != nil {
return return
} }
IDs2Names[partitionID] = partitionName IDs2Names[partitionID] = partitionName
partitionIDs = append(partitionIDs, partitionID) partitionIDs = append(partitionIDs, partitionID)
} }
resp, err := queryCoord.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{
var resp *querypb.ShowPartitionsResponse
resp, err = queryCoord.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{
Base: commonpbutil.UpdateMsgBase( Base: commonpbutil.UpdateMsgBase(
msgBase, msgBase,
commonpbutil.WithMsgType(commonpb.MsgType_ShowPartitions), commonpbutil.WithMsgType(commonpb.MsgType_ShowPartitions),
@ -1253,8 +1293,8 @@ func getPartitionProgress(
return return
} }
func isPartitionKeyMode(ctx context.Context, colName string) (bool, error) { func isPartitionKeyMode(ctx context.Context, dbName string, colName string) (bool, error) {
colSchema, err := globalMetaCache.GetCollectionSchema(ctx, colName) colSchema, err := globalMetaCache.GetCollectionSchema(ctx, dbName, colName)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -1269,8 +1309,34 @@ func isPartitionKeyMode(ctx context.Context, colName string) (bool, error) {
} }
// getDefaultPartitionNames only used in partition key mode // getDefaultPartitionNames only used in partition key mode
func getDefaultPartitionNames(ctx context.Context, collectionName string) ([]string, error) { func getDefaultPartitionsInPartitionKeyMode(ctx context.Context, dbName string, collectionName string) ([]string, error) {
partitions, err := globalMetaCache.GetPartitions(ctx, collectionName) partitions, err := globalMetaCache.GetPartitions(ctx, dbName, collectionName)
if err != nil {
return nil, err
}
// Make sure the order of the partition names got every time is the same
partitionNames := make([]string, len(partitions))
for partitionName := range partitions {
splits := strings.Split(partitionName, "_")
if len(splits) < 2 {
err = fmt.Errorf("bad default partion name in partition ket mode: %s", partitionName)
return nil, err
}
index, err := strconv.ParseInt(splits[len(splits)-1], 10, 64)
if err != nil {
return nil, err
}
partitionNames[index] = partitionName
}
return partitionNames, nil
}
// getDefaultPartitionNames only used in partition key mode
func getDefaultPartitionNames(ctx context.Context, dbName string, collectionName string) ([]string, error) {
partitions, err := globalMetaCache.GetPartitions(ctx, dbName, collectionName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1311,13 +1377,13 @@ func assignChannelsByPK(pks *schemapb.IDs, channelNames []string, insertMsg *msg
return channel2RowOffsets return channel2RowOffsets
} }
func assignPartitionKeys(ctx context.Context, collName string, keys []*planpb.GenericValue) ([]string, error) { func assignPartitionKeys(ctx context.Context, dbName string, collName string, keys []*planpb.GenericValue) ([]string, error) {
partitionNames, err := getDefaultPartitionNames(ctx, collName) partitionNames, err := getDefaultPartitionNames(ctx, dbName, collName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
schema, err := globalMetaCache.GetCollectionSchema(ctx, collName) schema, err := globalMetaCache.GetCollectionSchema(ctx, dbName, collName)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -25,9 +25,6 @@ import (
"time" "time"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
@ -40,9 +37,12 @@ import (
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util"
"github.com/milvus-io/milvus/pkg/util/crypto" "github.com/milvus-io/milvus/pkg/util/crypto"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
"github.com/milvus-io/milvus/pkg/util/typeutil" "github.com/milvus-io/milvus/pkg/util/typeutil"
) )
@ -98,6 +98,31 @@ func TestValidateResourceGroupName(t *testing.T) {
} }
} }
func TestValidateDatabaseName(t *testing.T) {
assert.Nil(t, ValidateDatabaseName("dbname"))
assert.Nil(t, ValidateDatabaseName("_123abc"))
assert.Nil(t, ValidateDatabaseName("abc123_"))
longName := make([]byte, 512)
for i := 0; i < len(longName); i++ {
longName[i] = 'a'
}
invalidNames := []string{
"123abc",
"$abc",
"abc$",
"_12 ac",
" ",
"",
string(longName),
"中文",
}
for _, name := range invalidNames {
assert.Error(t, ValidateDatabaseName(name))
}
}
func TestValidatePartitionTag(t *testing.T) { func TestValidatePartitionTag(t *testing.T) {
assert.Nil(t, validatePartitionTag("abc", true)) assert.Nil(t, validatePartitionTag("abc", true))
assert.Nil(t, validatePartitionTag("123abc", true)) assert.Nil(t, validatePartitionTag("123abc", true))
@ -725,6 +750,18 @@ func GetContext(ctx context.Context, originValue string) context.Context {
return metadata.NewIncomingContext(ctx, md) return metadata.NewIncomingContext(ctx, md)
} }
func GetContextWithDB(ctx context.Context, originValue string, dbName string) context.Context {
authKey := strings.ToLower(util.HeaderAuthorize)
authValue := crypto.Base64Encode(originValue)
dbKey := strings.ToLower(util.HeaderDBName)
contextMap := map[string]string{
authKey: authValue,
dbKey: dbName,
}
md := metadata.New(contextMap)
return metadata.NewIncomingContext(ctx, md)
}
func TestGetCurUserFromContext(t *testing.T) { func TestGetCurUserFromContext(t *testing.T) {
_, err := GetCurUserFromContext(context.Background()) _, err := GetCurUserFromContext(context.Background())
assert.Error(t, err) assert.Error(t, err)
@ -742,18 +779,38 @@ func TestGetCurUserFromContext(t *testing.T) {
assert.Equal(t, "root", username) assert.Equal(t, "root", username)
} }
func TestGetCurDBNameFromContext(t *testing.T) {
dbName := GetCurDBNameFromContextOrDefault(context.Background())
assert.Equal(t, util.DefaultDBName, dbName)
dbName = GetCurDBNameFromContextOrDefault(metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{})))
assert.Equal(t, util.DefaultDBName, dbName)
dbNameKey := strings.ToLower(util.HeaderDBName)
dbNameValue := "foodb"
contextMap := map[string]string{
dbNameKey: dbNameValue,
}
md := metadata.New(contextMap)
dbName = GetCurDBNameFromContextOrDefault(metadata.NewIncomingContext(context.Background(), md))
assert.Equal(t, dbNameValue, dbName)
}
func TestGetRole(t *testing.T) { func TestGetRole(t *testing.T) {
globalMetaCache = nil globalMetaCache = nil
_, err := GetRole("foo") _, err := GetRole("foo")
assert.Error(t, err) assert.Error(t, err)
globalMetaCache = &mockCache{ mockCache := NewMockCache(t)
getUserRoleFunc: func(username string) []string { mockCache.On("GetUserRole",
if username == "root" { mock.AnythingOfType("string"),
return []string{"role1", "admin", "role2"} ).Return(func(username string) []string {
} if username == "root" {
return []string{"role1"} return []string{"role1", "admin", "role2"}
}, }
} return []string{"role1"}
})
globalMetaCache = mockCache
roles, err := GetRole("root") roles, err := GetRole("root")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 3, len(roles)) assert.Equal(t, 3, len(roles))

View File

@ -42,8 +42,8 @@ type MockBalancer_AssignChannel_Call struct {
} }
// AssignChannel is a helper method to define mock.On call // AssignChannel is a helper method to define mock.On call
// - channels []*meta.DmChannel // - channels []*meta.DmChannel
// - nodes []int64 // - nodes []int64
func (_e *MockBalancer_Expecter) AssignChannel(channels interface{}, nodes interface{}) *MockBalancer_AssignChannel_Call { func (_e *MockBalancer_Expecter) AssignChannel(channels interface{}, nodes interface{}) *MockBalancer_AssignChannel_Call {
return &MockBalancer_AssignChannel_Call{Call: _e.mock.On("AssignChannel", channels, nodes)} return &MockBalancer_AssignChannel_Call{Call: _e.mock.On("AssignChannel", channels, nodes)}
} }
@ -82,9 +82,9 @@ type MockBalancer_AssignSegment_Call struct {
} }
// AssignSegment is a helper method to define mock.On call // AssignSegment is a helper method to define mock.On call
// - collectionID int64 // - collectionID int64
// - segments []*meta.Segment // - segments []*meta.Segment
// - nodes []int64 // - nodes []int64
func (_e *MockBalancer_Expecter) AssignSegment(collectionID interface{}, segments interface{}, nodes interface{}) *MockBalancer_AssignSegment_Call { func (_e *MockBalancer_Expecter) AssignSegment(collectionID interface{}, segments interface{}, nodes interface{}) *MockBalancer_AssignSegment_Call {
return &MockBalancer_AssignSegment_Call{Call: _e.mock.On("AssignSegment", collectionID, segments, nodes)} return &MockBalancer_AssignSegment_Call{Call: _e.mock.On("AssignSegment", collectionID, segments, nodes)}
} }
@ -132,7 +132,7 @@ type MockBalancer_BalanceReplica_Call struct {
} }
// BalanceReplica is a helper method to define mock.On call // BalanceReplica is a helper method to define mock.On call
// - replica *meta.Replica // - replica *meta.Replica
func (_e *MockBalancer_Expecter) BalanceReplica(replica interface{}) *MockBalancer_BalanceReplica_Call { func (_e *MockBalancer_Expecter) BalanceReplica(replica interface{}) *MockBalancer_BalanceReplica_Call {
return &MockBalancer_BalanceReplica_Call{Call: _e.mock.On("BalanceReplica", replica)} return &MockBalancer_BalanceReplica_Call{Call: _e.mock.On("BalanceReplica", replica)}
} }

View File

@ -32,7 +32,7 @@ type MockController_Remove_Call struct {
} }
// Remove is a helper method to define mock.On call // Remove is a helper method to define mock.On call
// - nodeID int64 // - nodeID int64
func (_e *MockController_Expecter) Remove(nodeID interface{}) *MockController_Remove_Call { func (_e *MockController_Expecter) Remove(nodeID interface{}) *MockController_Remove_Call {
return &MockController_Remove_Call{Call: _e.mock.On("Remove", nodeID)} return &MockController_Remove_Call{Call: _e.mock.On("Remove", nodeID)}
} }
@ -60,8 +60,8 @@ type MockController_StartDistInstance_Call struct {
} }
// StartDistInstance is a helper method to define mock.On call // StartDistInstance is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
func (_e *MockController_Expecter) StartDistInstance(ctx interface{}, nodeID interface{}) *MockController_StartDistInstance_Call { func (_e *MockController_Expecter) StartDistInstance(ctx interface{}, nodeID interface{}) *MockController_StartDistInstance_Call {
return &MockController_StartDistInstance_Call{Call: _e.mock.On("StartDistInstance", ctx, nodeID)} return &MockController_StartDistInstance_Call{Call: _e.mock.On("StartDistInstance", ctx, nodeID)}
} }
@ -116,7 +116,7 @@ type MockController_SyncAll_Call struct {
} }
// SyncAll is a helper method to define mock.On call // SyncAll is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
func (_e *MockController_Expecter) SyncAll(ctx interface{}) *MockController_SyncAll_Call { func (_e *MockController_Expecter) SyncAll(ctx interface{}) *MockController_SyncAll_Call {
return &MockController_SyncAll_Call{Call: _e.mock.On("SyncAll", ctx)} return &MockController_SyncAll_Call{Call: _e.mock.On("SyncAll", ctx)}
} }

View File

@ -74,6 +74,7 @@ func (broker *CoordinatorBroker) GetCollectionSchema(ctx context.Context, collec
Base: commonpbutil.NewMsgBase( Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection), commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection),
), ),
// please do not specify the collection name alone after database feature.
CollectionID: collectionID, CollectionID: collectionID,
} }
resp, err := broker.rootCoord.DescribeCollection(ctx, req) resp, err := broker.rootCoord.DescribeCollection(ctx, req)
@ -101,6 +102,7 @@ func (broker *CoordinatorBroker) GetPartitions(ctx context.Context, collectionID
Base: commonpbutil.NewMsgBase( Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_ShowPartitions), commonpbutil.WithMsgType(commonpb.MsgType_ShowPartitions),
), ),
// please do not specify the collection name alone after database feature.
CollectionID: collectionID, CollectionID: collectionID,
} }
resp, err := broker.rootCoord.ShowPartitions(ctx, req) resp, err := broker.rootCoord.ShowPartitions(ctx, req)

View File

@ -4,9 +4,9 @@ package meta
import ( import (
context "context" context "context"
"github.com/milvus-io/milvus/internal/proto/indexpb"
datapb "github.com/milvus-io/milvus/internal/proto/datapb" datapb "github.com/milvus-io/milvus/internal/proto/datapb"
indexpb "github.com/milvus-io/milvus/internal/proto/indexpb"
mock "github.com/stretchr/testify/mock" mock "github.com/stretchr/testify/mock"
@ -28,102 +28,7 @@ func (_m *MockBroker) EXPECT() *MockBroker_Expecter {
return &MockBroker_Expecter{mock: &_m.Mock} return &MockBroker_Expecter{mock: &_m.Mock}
} }
// GetCollectionSchema provides a mock function with given fields: ctx, collectionID // DescribeIndex provides a mock function with given fields: ctx, collectionID
func (_m *MockBroker) GetCollectionSchema(ctx context.Context, collectionID int64) (*schemapb.CollectionSchema, error) {
ret := _m.Called(ctx, collectionID)
var r0 *schemapb.CollectionSchema
if rf, ok := ret.Get(0).(func(context.Context, int64) *schemapb.CollectionSchema); ok {
r0 = rf(ctx, collectionID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*schemapb.CollectionSchema)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok {
r1 = rf(ctx, collectionID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockBroker_GetCollectionSchema_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollectionSchema'
type MockBroker_GetCollectionSchema_Call struct {
*mock.Call
}
// GetCollectionSchema is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
func (_e *MockBroker_Expecter) GetCollectionSchema(ctx interface{}, collectionID interface{}) *MockBroker_GetCollectionSchema_Call {
return &MockBroker_GetCollectionSchema_Call{Call: _e.mock.On("GetCollectionSchema", ctx, collectionID)}
}
func (_c *MockBroker_GetCollectionSchema_Call) Run(run func(ctx context.Context, collectionID int64)) *MockBroker_GetCollectionSchema_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64))
})
return _c
}
func (_c *MockBroker_GetCollectionSchema_Call) Return(_a0 *schemapb.CollectionSchema, _a1 error) *MockBroker_GetCollectionSchema_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// GetIndexInfo provides a mock function with given fields: ctx, collectionID, segmentID
func (_m *MockBroker) GetIndexInfo(ctx context.Context, collectionID int64, segmentID int64) ([]*querypb.FieldIndexInfo, error) {
ret := _m.Called(ctx, collectionID, segmentID)
var r0 []*querypb.FieldIndexInfo
if rf, ok := ret.Get(0).(func(context.Context, int64, int64) []*querypb.FieldIndexInfo); ok {
r0 = rf(ctx, collectionID, segmentID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*querypb.FieldIndexInfo)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int64, int64) error); ok {
r1 = rf(ctx, collectionID, segmentID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockBroker_GetIndexInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetIndexInfo'
type MockBroker_GetIndexInfo_Call struct {
*mock.Call
}
// GetIndexInfo is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - segmentID int64
func (_e *MockBroker_Expecter) GetIndexInfo(ctx interface{}, collectionID interface{}, segmentID interface{}) *MockBroker_GetIndexInfo_Call {
return &MockBroker_GetIndexInfo_Call{Call: _e.mock.On("GetIndexInfo", ctx, collectionID, segmentID)}
}
func (_c *MockBroker_GetIndexInfo_Call) Run(run func(ctx context.Context, collectionID int64, segmentID int64)) *MockBroker_GetIndexInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64), args[2].(int64))
})
return _c
}
func (_c *MockBroker_GetIndexInfo_Call) Return(_a0 []*querypb.FieldIndexInfo, _a1 error) *MockBroker_GetIndexInfo_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// DescribeIndex provides a mock function with given fields: ctx, collectionID, segmentID
func (_m *MockBroker) DescribeIndex(ctx context.Context, collectionID int64) ([]*indexpb.IndexInfo, error) { func (_m *MockBroker) DescribeIndex(ctx context.Context, collectionID int64) ([]*indexpb.IndexInfo, error) {
ret := _m.Called(ctx, collectionID) ret := _m.Called(ctx, collectionID)
@ -170,6 +75,101 @@ func (_c *MockBroker_DescribeIndex_Call) Return(_a0 []*indexpb.IndexInfo, _a1 er
return _c return _c
} }
// GetCollectionSchema provides a mock function with given fields: ctx, collectionID
func (_m *MockBroker) GetCollectionSchema(ctx context.Context, collectionID int64) (*schemapb.CollectionSchema, error) {
ret := _m.Called(ctx, collectionID)
var r0 *schemapb.CollectionSchema
if rf, ok := ret.Get(0).(func(context.Context, int64) *schemapb.CollectionSchema); ok {
r0 = rf(ctx, collectionID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*schemapb.CollectionSchema)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok {
r1 = rf(ctx, collectionID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockBroker_GetCollectionSchema_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollectionSchema'
type MockBroker_GetCollectionSchema_Call struct {
*mock.Call
}
// GetCollectionSchema is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
func (_e *MockBroker_Expecter) GetCollectionSchema(ctx interface{}, collectionID interface{}) *MockBroker_GetCollectionSchema_Call {
return &MockBroker_GetCollectionSchema_Call{Call: _e.mock.On("GetCollectionSchema", ctx, collectionID)}
}
func (_c *MockBroker_GetCollectionSchema_Call) Run(run func(ctx context.Context, collectionID int64)) *MockBroker_GetCollectionSchema_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64))
})
return _c
}
func (_c *MockBroker_GetCollectionSchema_Call) Return(_a0 *schemapb.CollectionSchema, _a1 error) *MockBroker_GetCollectionSchema_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// GetIndexInfo provides a mock function with given fields: ctx, collectionID, segmentID
func (_m *MockBroker) GetIndexInfo(ctx context.Context, collectionID int64, segmentID int64) ([]*querypb.FieldIndexInfo, error) {
ret := _m.Called(ctx, collectionID, segmentID)
var r0 []*querypb.FieldIndexInfo
if rf, ok := ret.Get(0).(func(context.Context, int64, int64) []*querypb.FieldIndexInfo); ok {
r0 = rf(ctx, collectionID, segmentID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*querypb.FieldIndexInfo)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int64, int64) error); ok {
r1 = rf(ctx, collectionID, segmentID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockBroker_GetIndexInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetIndexInfo'
type MockBroker_GetIndexInfo_Call struct {
*mock.Call
}
// GetIndexInfo is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - segmentID int64
func (_e *MockBroker_Expecter) GetIndexInfo(ctx interface{}, collectionID interface{}, segmentID interface{}) *MockBroker_GetIndexInfo_Call {
return &MockBroker_GetIndexInfo_Call{Call: _e.mock.On("GetIndexInfo", ctx, collectionID, segmentID)}
}
func (_c *MockBroker_GetIndexInfo_Call) Run(run func(ctx context.Context, collectionID int64, segmentID int64)) *MockBroker_GetIndexInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64), args[2].(int64))
})
return _c
}
func (_c *MockBroker_GetIndexInfo_Call) Return(_a0 []*querypb.FieldIndexInfo, _a1 error) *MockBroker_GetIndexInfo_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// GetPartitions provides a mock function with given fields: ctx, collectionID // GetPartitions provides a mock function with given fields: ctx, collectionID
func (_m *MockBroker) GetPartitions(ctx context.Context, collectionID int64) ([]int64, error) { func (_m *MockBroker) GetPartitions(ctx context.Context, collectionID int64) ([]int64, error) {
ret := _m.Called(ctx, collectionID) ret := _m.Called(ctx, collectionID)
@ -199,8 +199,8 @@ type MockBroker_GetPartitions_Call struct {
} }
// GetPartitions is a helper method to define mock.On call // GetPartitions is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - collectionID int64 // - collectionID int64
func (_e *MockBroker_Expecter) GetPartitions(ctx interface{}, collectionID interface{}) *MockBroker_GetPartitions_Call { func (_e *MockBroker_Expecter) GetPartitions(ctx interface{}, collectionID interface{}) *MockBroker_GetPartitions_Call {
return &MockBroker_GetPartitions_Call{Call: _e.mock.On("GetPartitions", ctx, collectionID)} return &MockBroker_GetPartitions_Call{Call: _e.mock.On("GetPartitions", ctx, collectionID)}
} }
@ -255,9 +255,9 @@ type MockBroker_GetRecoveryInfo_Call struct {
} }
// GetRecoveryInfo is a helper method to define mock.On call // GetRecoveryInfo is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - collectionID int64 // - collectionID int64
// - partitionID int64 // - partitionID int64
func (_e *MockBroker_Expecter) GetRecoveryInfo(ctx interface{}, collectionID interface{}, partitionID interface{}) *MockBroker_GetRecoveryInfo_Call { func (_e *MockBroker_Expecter) GetRecoveryInfo(ctx interface{}, collectionID interface{}, partitionID interface{}) *MockBroker_GetRecoveryInfo_Call {
return &MockBroker_GetRecoveryInfo_Call{Call: _e.mock.On("GetRecoveryInfo", ctx, collectionID, partitionID)} return &MockBroker_GetRecoveryInfo_Call{Call: _e.mock.On("GetRecoveryInfo", ctx, collectionID, partitionID)}
} }
@ -319,9 +319,9 @@ type MockBroker_GetRecoveryInfoV2_Call struct {
} }
// GetRecoveryInfoV2 is a helper method to define mock.On call // GetRecoveryInfoV2 is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - collectionID int64 // - collectionID int64
// - partitionIDs ...int64 // - partitionIDs ...int64
func (_e *MockBroker_Expecter) GetRecoveryInfoV2(ctx interface{}, collectionID interface{}, partitionIDs ...interface{}) *MockBroker_GetRecoveryInfoV2_Call { func (_e *MockBroker_Expecter) GetRecoveryInfoV2(ctx interface{}, collectionID interface{}, partitionIDs ...interface{}) *MockBroker_GetRecoveryInfoV2_Call {
return &MockBroker_GetRecoveryInfoV2_Call{Call: _e.mock.On("GetRecoveryInfoV2", return &MockBroker_GetRecoveryInfoV2_Call{Call: _e.mock.On("GetRecoveryInfoV2",
append([]interface{}{ctx, collectionID}, partitionIDs...)...)} append([]interface{}{ctx, collectionID}, partitionIDs...)...)}
@ -381,8 +381,8 @@ type MockBroker_GetSegmentInfo_Call struct {
} }
// GetSegmentInfo is a helper method to define mock.On call // GetSegmentInfo is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - segmentID ...int64 // - segmentID ...int64
func (_e *MockBroker_Expecter) GetSegmentInfo(ctx interface{}, segmentID ...interface{}) *MockBroker_GetSegmentInfo_Call { func (_e *MockBroker_Expecter) GetSegmentInfo(ctx interface{}, segmentID ...interface{}) *MockBroker_GetSegmentInfo_Call {
return &MockBroker_GetSegmentInfo_Call{Call: _e.mock.On("GetSegmentInfo", return &MockBroker_GetSegmentInfo_Call{Call: _e.mock.On("GetSegmentInfo",
append([]interface{}{ctx}, segmentID...)...)} append([]interface{}{ctx}, segmentID...)...)}

View File

@ -220,7 +220,7 @@ type MockStore_ReleaseCollection_Call struct {
} }
// ReleaseCollection is a helper method to define mock.On call // ReleaseCollection is a helper method to define mock.On call
// - collection int64 // - collection int64
func (_e *MockStore_Expecter) ReleaseCollection(collection interface{}) *MockStore_ReleaseCollection_Call { func (_e *MockStore_Expecter) ReleaseCollection(collection interface{}) *MockStore_ReleaseCollection_Call {
return &MockStore_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", collection)} return &MockStore_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", collection)}
} }
@ -264,8 +264,8 @@ type MockStore_ReleasePartition_Call struct {
} }
// ReleasePartition is a helper method to define mock.On call // ReleasePartition is a helper method to define mock.On call
// - collection int64 // - collection int64
// - partitions ...int64 // - partitions ...int64
func (_e *MockStore_Expecter) ReleasePartition(collection interface{}, partitions ...interface{}) *MockStore_ReleasePartition_Call { func (_e *MockStore_Expecter) ReleasePartition(collection interface{}, partitions ...interface{}) *MockStore_ReleasePartition_Call {
return &MockStore_ReleasePartition_Call{Call: _e.mock.On("ReleasePartition", return &MockStore_ReleasePartition_Call{Call: _e.mock.On("ReleasePartition",
append([]interface{}{collection}, partitions...)...)} append([]interface{}{collection}, partitions...)...)}
@ -309,8 +309,8 @@ type MockStore_ReleaseReplica_Call struct {
} }
// ReleaseReplica is a helper method to define mock.On call // ReleaseReplica is a helper method to define mock.On call
// - collection int64 // - collection int64
// - replica int64 // - replica int64
func (_e *MockStore_Expecter) ReleaseReplica(collection interface{}, replica interface{}) *MockStore_ReleaseReplica_Call { func (_e *MockStore_Expecter) ReleaseReplica(collection interface{}, replica interface{}) *MockStore_ReleaseReplica_Call {
return &MockStore_ReleaseReplica_Call{Call: _e.mock.On("ReleaseReplica", collection, replica)} return &MockStore_ReleaseReplica_Call{Call: _e.mock.On("ReleaseReplica", collection, replica)}
} }
@ -347,7 +347,7 @@ type MockStore_ReleaseReplicas_Call struct {
} }
// ReleaseReplicas is a helper method to define mock.On call // ReleaseReplicas is a helper method to define mock.On call
// - collectionID int64 // - collectionID int64
func (_e *MockStore_Expecter) ReleaseReplicas(collectionID interface{}) *MockStore_ReleaseReplicas_Call { func (_e *MockStore_Expecter) ReleaseReplicas(collectionID interface{}) *MockStore_ReleaseReplicas_Call {
return &MockStore_ReleaseReplicas_Call{Call: _e.mock.On("ReleaseReplicas", collectionID)} return &MockStore_ReleaseReplicas_Call{Call: _e.mock.On("ReleaseReplicas", collectionID)}
} }
@ -384,7 +384,7 @@ type MockStore_RemoveResourceGroup_Call struct {
} }
// RemoveResourceGroup is a helper method to define mock.On call // RemoveResourceGroup is a helper method to define mock.On call
// - rgName string // - rgName string
func (_e *MockStore_Expecter) RemoveResourceGroup(rgName interface{}) *MockStore_RemoveResourceGroup_Call { func (_e *MockStore_Expecter) RemoveResourceGroup(rgName interface{}) *MockStore_RemoveResourceGroup_Call {
return &MockStore_RemoveResourceGroup_Call{Call: _e.mock.On("RemoveResourceGroup", rgName)} return &MockStore_RemoveResourceGroup_Call{Call: _e.mock.On("RemoveResourceGroup", rgName)}
} }
@ -428,8 +428,8 @@ type MockStore_SaveCollection_Call struct {
} }
// SaveCollection is a helper method to define mock.On call // SaveCollection is a helper method to define mock.On call
// - collection *querypb.CollectionLoadInfo // - collection *querypb.CollectionLoadInfo
// - partitions ...*querypb.PartitionLoadInfo // - partitions ...*querypb.PartitionLoadInfo
func (_e *MockStore_Expecter) SaveCollection(collection interface{}, partitions ...interface{}) *MockStore_SaveCollection_Call { func (_e *MockStore_Expecter) SaveCollection(collection interface{}, partitions ...interface{}) *MockStore_SaveCollection_Call {
return &MockStore_SaveCollection_Call{Call: _e.mock.On("SaveCollection", return &MockStore_SaveCollection_Call{Call: _e.mock.On("SaveCollection",
append([]interface{}{collection}, partitions...)...)} append([]interface{}{collection}, partitions...)...)}
@ -479,7 +479,7 @@ type MockStore_SavePartition_Call struct {
} }
// SavePartition is a helper method to define mock.On call // SavePartition is a helper method to define mock.On call
// - info ...*querypb.PartitionLoadInfo // - info ...*querypb.PartitionLoadInfo
func (_e *MockStore_Expecter) SavePartition(info ...interface{}) *MockStore_SavePartition_Call { func (_e *MockStore_Expecter) SavePartition(info ...interface{}) *MockStore_SavePartition_Call {
return &MockStore_SavePartition_Call{Call: _e.mock.On("SavePartition", return &MockStore_SavePartition_Call{Call: _e.mock.On("SavePartition",
append([]interface{}{}, info...)...)} append([]interface{}{}, info...)...)}
@ -523,7 +523,7 @@ type MockStore_SaveReplica_Call struct {
} }
// SaveReplica is a helper method to define mock.On call // SaveReplica is a helper method to define mock.On call
// - replica *querypb.Replica // - replica *querypb.Replica
func (_e *MockStore_Expecter) SaveReplica(replica interface{}) *MockStore_SaveReplica_Call { func (_e *MockStore_Expecter) SaveReplica(replica interface{}) *MockStore_SaveReplica_Call {
return &MockStore_SaveReplica_Call{Call: _e.mock.On("SaveReplica", replica)} return &MockStore_SaveReplica_Call{Call: _e.mock.On("SaveReplica", replica)}
} }
@ -566,7 +566,7 @@ type MockStore_SaveResourceGroup_Call struct {
} }
// SaveResourceGroup is a helper method to define mock.On call // SaveResourceGroup is a helper method to define mock.On call
// - rgs ...*querypb.ResourceGroup // - rgs ...*querypb.ResourceGroup
func (_e *MockStore_Expecter) SaveResourceGroup(rgs ...interface{}) *MockStore_SaveResourceGroup_Call { func (_e *MockStore_Expecter) SaveResourceGroup(rgs ...interface{}) *MockStore_SaveResourceGroup_Call {
return &MockStore_SaveResourceGroup_Call{Call: _e.mock.On("SaveResourceGroup", return &MockStore_SaveResourceGroup_Call{Call: _e.mock.On("SaveResourceGroup",
append([]interface{}{}, rgs...)...)} append([]interface{}{}, rgs...)...)}

View File

@ -1,4 +1,4 @@
// Code generated by mockery v2.21.1. DO NOT EDIT. // Code generated by mockery v2.16.0. DO NOT EDIT.
package mocks package mocks
@ -34,10 +34,6 @@ func (_m *MockQueryNodeServer) Delete(_a0 context.Context, _a1 *querypb.DeleteRe
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1)
var r0 *commonpb.Status var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.DeleteRequest) (*commonpb.Status, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.DeleteRequest) *commonpb.Status); ok { if rf, ok := ret.Get(0).(func(context.Context, *querypb.DeleteRequest) *commonpb.Status); ok {
r0 = rf(_a0, _a1) r0 = rf(_a0, _a1)
} else { } else {
@ -46,6 +42,7 @@ func (_m *MockQueryNodeServer) Delete(_a0 context.Context, _a1 *querypb.DeleteRe
} }
} }
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.DeleteRequest) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *querypb.DeleteRequest) error); ok {
r1 = rf(_a0, _a1) r1 = rf(_a0, _a1)
} else { } else {
@ -61,8 +58,8 @@ type MockQueryNodeServer_Delete_Call struct {
} }
// Delete is a helper method to define mock.On call // Delete is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.DeleteRequest // - _a1 *querypb.DeleteRequest
func (_e *MockQueryNodeServer_Expecter) Delete(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_Delete_Call { func (_e *MockQueryNodeServer_Expecter) Delete(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_Delete_Call {
return &MockQueryNodeServer_Delete_Call{Call: _e.mock.On("Delete", _a0, _a1)} return &MockQueryNodeServer_Delete_Call{Call: _e.mock.On("Delete", _a0, _a1)}
} }
@ -79,20 +76,11 @@ func (_c *MockQueryNodeServer_Delete_Call) Return(_a0 *commonpb.Status, _a1 erro
return _c return _c
} }
func (_c *MockQueryNodeServer_Delete_Call) RunAndReturn(run func(context.Context, *querypb.DeleteRequest) (*commonpb.Status, error)) *MockQueryNodeServer_Delete_Call {
_c.Call.Return(run)
return _c
}
// GetComponentStates provides a mock function with given fields: _a0, _a1 // GetComponentStates provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) GetComponentStates(_a0 context.Context, _a1 *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { func (_m *MockQueryNodeServer) GetComponentStates(_a0 context.Context, _a1 *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) {
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1)
var r0 *milvuspb.ComponentStates var r0 *milvuspb.ComponentStates
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest) *milvuspb.ComponentStates); ok { if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest) *milvuspb.ComponentStates); ok {
r0 = rf(_a0, _a1) r0 = rf(_a0, _a1)
} else { } else {
@ -101,6 +89,7 @@ func (_m *MockQueryNodeServer) GetComponentStates(_a0 context.Context, _a1 *milv
} }
} }
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetComponentStatesRequest) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetComponentStatesRequest) error); ok {
r1 = rf(_a0, _a1) r1 = rf(_a0, _a1)
} else { } else {
@ -116,8 +105,8 @@ type MockQueryNodeServer_GetComponentStates_Call struct {
} }
// GetComponentStates is a helper method to define mock.On call // GetComponentStates is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *milvuspb.GetComponentStatesRequest // - _a1 *milvuspb.GetComponentStatesRequest
func (_e *MockQueryNodeServer_Expecter) GetComponentStates(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetComponentStates_Call { func (_e *MockQueryNodeServer_Expecter) GetComponentStates(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetComponentStates_Call {
return &MockQueryNodeServer_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", _a0, _a1)} return &MockQueryNodeServer_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", _a0, _a1)}
} }
@ -134,20 +123,11 @@ func (_c *MockQueryNodeServer_GetComponentStates_Call) Return(_a0 *milvuspb.Comp
return _c return _c
} }
func (_c *MockQueryNodeServer_GetComponentStates_Call) RunAndReturn(run func(context.Context, *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error)) *MockQueryNodeServer_GetComponentStates_Call {
_c.Call.Return(run)
return _c
}
// GetDataDistribution provides a mock function with given fields: _a0, _a1 // GetDataDistribution provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) GetDataDistribution(_a0 context.Context, _a1 *querypb.GetDataDistributionRequest) (*querypb.GetDataDistributionResponse, error) { func (_m *MockQueryNodeServer) GetDataDistribution(_a0 context.Context, _a1 *querypb.GetDataDistributionRequest) (*querypb.GetDataDistributionResponse, error) {
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1)
var r0 *querypb.GetDataDistributionResponse var r0 *querypb.GetDataDistributionResponse
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetDataDistributionRequest) (*querypb.GetDataDistributionResponse, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetDataDistributionRequest) *querypb.GetDataDistributionResponse); ok { if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetDataDistributionRequest) *querypb.GetDataDistributionResponse); ok {
r0 = rf(_a0, _a1) r0 = rf(_a0, _a1)
} else { } else {
@ -156,6 +136,7 @@ func (_m *MockQueryNodeServer) GetDataDistribution(_a0 context.Context, _a1 *que
} }
} }
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetDataDistributionRequest) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetDataDistributionRequest) error); ok {
r1 = rf(_a0, _a1) r1 = rf(_a0, _a1)
} else { } else {
@ -171,8 +152,8 @@ type MockQueryNodeServer_GetDataDistribution_Call struct {
} }
// GetDataDistribution is a helper method to define mock.On call // GetDataDistribution is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.GetDataDistributionRequest // - _a1 *querypb.GetDataDistributionRequest
func (_e *MockQueryNodeServer_Expecter) GetDataDistribution(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetDataDistribution_Call { func (_e *MockQueryNodeServer_Expecter) GetDataDistribution(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetDataDistribution_Call {
return &MockQueryNodeServer_GetDataDistribution_Call{Call: _e.mock.On("GetDataDistribution", _a0, _a1)} return &MockQueryNodeServer_GetDataDistribution_Call{Call: _e.mock.On("GetDataDistribution", _a0, _a1)}
} }
@ -189,20 +170,11 @@ func (_c *MockQueryNodeServer_GetDataDistribution_Call) Return(_a0 *querypb.GetD
return _c return _c
} }
func (_c *MockQueryNodeServer_GetDataDistribution_Call) RunAndReturn(run func(context.Context, *querypb.GetDataDistributionRequest) (*querypb.GetDataDistributionResponse, error)) *MockQueryNodeServer_GetDataDistribution_Call {
_c.Call.Return(run)
return _c
}
// GetMetrics provides a mock function with given fields: _a0, _a1 // GetMetrics provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) GetMetrics(_a0 context.Context, _a1 *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { func (_m *MockQueryNodeServer) GetMetrics(_a0 context.Context, _a1 *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1)
var r0 *milvuspb.GetMetricsResponse var r0 *milvuspb.GetMetricsResponse
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest) *milvuspb.GetMetricsResponse); ok { if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest) *milvuspb.GetMetricsResponse); ok {
r0 = rf(_a0, _a1) r0 = rf(_a0, _a1)
} else { } else {
@ -211,6 +183,7 @@ func (_m *MockQueryNodeServer) GetMetrics(_a0 context.Context, _a1 *milvuspb.Get
} }
} }
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetMetricsRequest) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetMetricsRequest) error); ok {
r1 = rf(_a0, _a1) r1 = rf(_a0, _a1)
} else { } else {
@ -226,8 +199,8 @@ type MockQueryNodeServer_GetMetrics_Call struct {
} }
// GetMetrics is a helper method to define mock.On call // GetMetrics is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *milvuspb.GetMetricsRequest // - _a1 *milvuspb.GetMetricsRequest
func (_e *MockQueryNodeServer_Expecter) GetMetrics(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetMetrics_Call { func (_e *MockQueryNodeServer_Expecter) GetMetrics(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetMetrics_Call {
return &MockQueryNodeServer_GetMetrics_Call{Call: _e.mock.On("GetMetrics", _a0, _a1)} return &MockQueryNodeServer_GetMetrics_Call{Call: _e.mock.On("GetMetrics", _a0, _a1)}
} }
@ -244,20 +217,11 @@ func (_c *MockQueryNodeServer_GetMetrics_Call) Return(_a0 *milvuspb.GetMetricsRe
return _c return _c
} }
func (_c *MockQueryNodeServer_GetMetrics_Call) RunAndReturn(run func(context.Context, *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error)) *MockQueryNodeServer_GetMetrics_Call {
_c.Call.Return(run)
return _c
}
// GetSegmentInfo provides a mock function with given fields: _a0, _a1 // GetSegmentInfo provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) GetSegmentInfo(_a0 context.Context, _a1 *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) { func (_m *MockQueryNodeServer) GetSegmentInfo(_a0 context.Context, _a1 *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1)
var r0 *querypb.GetSegmentInfoResponse var r0 *querypb.GetSegmentInfoResponse
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetSegmentInfoRequest) *querypb.GetSegmentInfoResponse); ok { if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetSegmentInfoRequest) *querypb.GetSegmentInfoResponse); ok {
r0 = rf(_a0, _a1) r0 = rf(_a0, _a1)
} else { } else {
@ -266,6 +230,7 @@ func (_m *MockQueryNodeServer) GetSegmentInfo(_a0 context.Context, _a1 *querypb.
} }
} }
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetSegmentInfoRequest) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetSegmentInfoRequest) error); ok {
r1 = rf(_a0, _a1) r1 = rf(_a0, _a1)
} else { } else {
@ -281,8 +246,8 @@ type MockQueryNodeServer_GetSegmentInfo_Call struct {
} }
// GetSegmentInfo is a helper method to define mock.On call // GetSegmentInfo is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.GetSegmentInfoRequest // - _a1 *querypb.GetSegmentInfoRequest
func (_e *MockQueryNodeServer_Expecter) GetSegmentInfo(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetSegmentInfo_Call { func (_e *MockQueryNodeServer_Expecter) GetSegmentInfo(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetSegmentInfo_Call {
return &MockQueryNodeServer_GetSegmentInfo_Call{Call: _e.mock.On("GetSegmentInfo", _a0, _a1)} return &MockQueryNodeServer_GetSegmentInfo_Call{Call: _e.mock.On("GetSegmentInfo", _a0, _a1)}
} }
@ -299,20 +264,11 @@ func (_c *MockQueryNodeServer_GetSegmentInfo_Call) Return(_a0 *querypb.GetSegmen
return _c return _c
} }
func (_c *MockQueryNodeServer_GetSegmentInfo_Call) RunAndReturn(run func(context.Context, *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error)) *MockQueryNodeServer_GetSegmentInfo_Call {
_c.Call.Return(run)
return _c
}
// GetStatistics provides a mock function with given fields: _a0, _a1 // GetStatistics provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) GetStatistics(_a0 context.Context, _a1 *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error) { func (_m *MockQueryNodeServer) GetStatistics(_a0 context.Context, _a1 *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error) {
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1)
var r0 *internalpb.GetStatisticsResponse var r0 *internalpb.GetStatisticsResponse
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetStatisticsRequest) *internalpb.GetStatisticsResponse); ok { if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetStatisticsRequest) *internalpb.GetStatisticsResponse); ok {
r0 = rf(_a0, _a1) r0 = rf(_a0, _a1)
} else { } else {
@ -321,6 +277,7 @@ func (_m *MockQueryNodeServer) GetStatistics(_a0 context.Context, _a1 *querypb.G
} }
} }
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetStatisticsRequest) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetStatisticsRequest) error); ok {
r1 = rf(_a0, _a1) r1 = rf(_a0, _a1)
} else { } else {
@ -336,8 +293,8 @@ type MockQueryNodeServer_GetStatistics_Call struct {
} }
// GetStatistics is a helper method to define mock.On call // GetStatistics is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.GetStatisticsRequest // - _a1 *querypb.GetStatisticsRequest
func (_e *MockQueryNodeServer_Expecter) GetStatistics(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetStatistics_Call { func (_e *MockQueryNodeServer_Expecter) GetStatistics(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetStatistics_Call {
return &MockQueryNodeServer_GetStatistics_Call{Call: _e.mock.On("GetStatistics", _a0, _a1)} return &MockQueryNodeServer_GetStatistics_Call{Call: _e.mock.On("GetStatistics", _a0, _a1)}
} }
@ -354,20 +311,11 @@ func (_c *MockQueryNodeServer_GetStatistics_Call) Return(_a0 *internalpb.GetStat
return _c return _c
} }
func (_c *MockQueryNodeServer_GetStatistics_Call) RunAndReturn(run func(context.Context, *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error)) *MockQueryNodeServer_GetStatistics_Call {
_c.Call.Return(run)
return _c
}
// GetStatisticsChannel provides a mock function with given fields: _a0, _a1 // GetStatisticsChannel provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) GetStatisticsChannel(_a0 context.Context, _a1 *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { func (_m *MockQueryNodeServer) GetStatisticsChannel(_a0 context.Context, _a1 *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) {
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1)
var r0 *milvuspb.StringResponse var r0 *milvuspb.StringResponse
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest) *milvuspb.StringResponse); ok { if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest) *milvuspb.StringResponse); ok {
r0 = rf(_a0, _a1) r0 = rf(_a0, _a1)
} else { } else {
@ -376,6 +324,7 @@ func (_m *MockQueryNodeServer) GetStatisticsChannel(_a0 context.Context, _a1 *in
} }
} }
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetStatisticsChannelRequest) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetStatisticsChannelRequest) error); ok {
r1 = rf(_a0, _a1) r1 = rf(_a0, _a1)
} else { } else {
@ -391,8 +340,8 @@ type MockQueryNodeServer_GetStatisticsChannel_Call struct {
} }
// GetStatisticsChannel is a helper method to define mock.On call // GetStatisticsChannel is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *internalpb.GetStatisticsChannelRequest // - _a1 *internalpb.GetStatisticsChannelRequest
func (_e *MockQueryNodeServer_Expecter) GetStatisticsChannel(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetStatisticsChannel_Call { func (_e *MockQueryNodeServer_Expecter) GetStatisticsChannel(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetStatisticsChannel_Call {
return &MockQueryNodeServer_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", _a0, _a1)} return &MockQueryNodeServer_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", _a0, _a1)}
} }
@ -409,20 +358,11 @@ func (_c *MockQueryNodeServer_GetStatisticsChannel_Call) Return(_a0 *milvuspb.St
return _c return _c
} }
func (_c *MockQueryNodeServer_GetStatisticsChannel_Call) RunAndReturn(run func(context.Context, *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error)) *MockQueryNodeServer_GetStatisticsChannel_Call {
_c.Call.Return(run)
return _c
}
// GetTimeTickChannel provides a mock function with given fields: _a0, _a1 // GetTimeTickChannel provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) GetTimeTickChannel(_a0 context.Context, _a1 *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error) { func (_m *MockQueryNodeServer) GetTimeTickChannel(_a0 context.Context, _a1 *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error) {
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1)
var r0 *milvuspb.StringResponse var r0 *milvuspb.StringResponse
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetTimeTickChannelRequest) *milvuspb.StringResponse); ok { if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetTimeTickChannelRequest) *milvuspb.StringResponse); ok {
r0 = rf(_a0, _a1) r0 = rf(_a0, _a1)
} else { } else {
@ -431,6 +371,7 @@ func (_m *MockQueryNodeServer) GetTimeTickChannel(_a0 context.Context, _a1 *inte
} }
} }
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetTimeTickChannelRequest) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetTimeTickChannelRequest) error); ok {
r1 = rf(_a0, _a1) r1 = rf(_a0, _a1)
} else { } else {
@ -446,8 +387,8 @@ type MockQueryNodeServer_GetTimeTickChannel_Call struct {
} }
// GetTimeTickChannel is a helper method to define mock.On call // GetTimeTickChannel is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *internalpb.GetTimeTickChannelRequest // - _a1 *internalpb.GetTimeTickChannelRequest
func (_e *MockQueryNodeServer_Expecter) GetTimeTickChannel(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetTimeTickChannel_Call { func (_e *MockQueryNodeServer_Expecter) GetTimeTickChannel(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetTimeTickChannel_Call {
return &MockQueryNodeServer_GetTimeTickChannel_Call{Call: _e.mock.On("GetTimeTickChannel", _a0, _a1)} return &MockQueryNodeServer_GetTimeTickChannel_Call{Call: _e.mock.On("GetTimeTickChannel", _a0, _a1)}
} }
@ -464,20 +405,11 @@ func (_c *MockQueryNodeServer_GetTimeTickChannel_Call) Return(_a0 *milvuspb.Stri
return _c return _c
} }
func (_c *MockQueryNodeServer_GetTimeTickChannel_Call) RunAndReturn(run func(context.Context, *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error)) *MockQueryNodeServer_GetTimeTickChannel_Call {
_c.Call.Return(run)
return _c
}
// LoadPartitions provides a mock function with given fields: _a0, _a1 // LoadPartitions provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) LoadPartitions(_a0 context.Context, _a1 *querypb.LoadPartitionsRequest) (*commonpb.Status, error) { func (_m *MockQueryNodeServer) LoadPartitions(_a0 context.Context, _a1 *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1)
var r0 *commonpb.Status var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadPartitionsRequest) (*commonpb.Status, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadPartitionsRequest) *commonpb.Status); ok { if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadPartitionsRequest) *commonpb.Status); ok {
r0 = rf(_a0, _a1) r0 = rf(_a0, _a1)
} else { } else {
@ -486,6 +418,7 @@ func (_m *MockQueryNodeServer) LoadPartitions(_a0 context.Context, _a1 *querypb.
} }
} }
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.LoadPartitionsRequest) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *querypb.LoadPartitionsRequest) error); ok {
r1 = rf(_a0, _a1) r1 = rf(_a0, _a1)
} else { } else {
@ -501,8 +434,8 @@ type MockQueryNodeServer_LoadPartitions_Call struct {
} }
// LoadPartitions is a helper method to define mock.On call // LoadPartitions is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.LoadPartitionsRequest // - _a1 *querypb.LoadPartitionsRequest
func (_e *MockQueryNodeServer_Expecter) LoadPartitions(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_LoadPartitions_Call { func (_e *MockQueryNodeServer_Expecter) LoadPartitions(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_LoadPartitions_Call {
return &MockQueryNodeServer_LoadPartitions_Call{Call: _e.mock.On("LoadPartitions", _a0, _a1)} return &MockQueryNodeServer_LoadPartitions_Call{Call: _e.mock.On("LoadPartitions", _a0, _a1)}
} }
@ -519,20 +452,11 @@ func (_c *MockQueryNodeServer_LoadPartitions_Call) Return(_a0 *commonpb.Status,
return _c return _c
} }
func (_c *MockQueryNodeServer_LoadPartitions_Call) RunAndReturn(run func(context.Context, *querypb.LoadPartitionsRequest) (*commonpb.Status, error)) *MockQueryNodeServer_LoadPartitions_Call {
_c.Call.Return(run)
return _c
}
// LoadSegments provides a mock function with given fields: _a0, _a1 // LoadSegments provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) LoadSegments(_a0 context.Context, _a1 *querypb.LoadSegmentsRequest) (*commonpb.Status, error) { func (_m *MockQueryNodeServer) LoadSegments(_a0 context.Context, _a1 *querypb.LoadSegmentsRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1)
var r0 *commonpb.Status var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadSegmentsRequest) (*commonpb.Status, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadSegmentsRequest) *commonpb.Status); ok { if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadSegmentsRequest) *commonpb.Status); ok {
r0 = rf(_a0, _a1) r0 = rf(_a0, _a1)
} else { } else {
@ -541,6 +465,7 @@ func (_m *MockQueryNodeServer) LoadSegments(_a0 context.Context, _a1 *querypb.Lo
} }
} }
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.LoadSegmentsRequest) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *querypb.LoadSegmentsRequest) error); ok {
r1 = rf(_a0, _a1) r1 = rf(_a0, _a1)
} else { } else {
@ -556,8 +481,8 @@ type MockQueryNodeServer_LoadSegments_Call struct {
} }
// LoadSegments is a helper method to define mock.On call // LoadSegments is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.LoadSegmentsRequest // - _a1 *querypb.LoadSegmentsRequest
func (_e *MockQueryNodeServer_Expecter) LoadSegments(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_LoadSegments_Call { func (_e *MockQueryNodeServer_Expecter) LoadSegments(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_LoadSegments_Call {
return &MockQueryNodeServer_LoadSegments_Call{Call: _e.mock.On("LoadSegments", _a0, _a1)} return &MockQueryNodeServer_LoadSegments_Call{Call: _e.mock.On("LoadSegments", _a0, _a1)}
} }
@ -574,20 +499,11 @@ func (_c *MockQueryNodeServer_LoadSegments_Call) Return(_a0 *commonpb.Status, _a
return _c return _c
} }
func (_c *MockQueryNodeServer_LoadSegments_Call) RunAndReturn(run func(context.Context, *querypb.LoadSegmentsRequest) (*commonpb.Status, error)) *MockQueryNodeServer_LoadSegments_Call {
_c.Call.Return(run)
return _c
}
// Query provides a mock function with given fields: _a0, _a1 // Query provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) Query(_a0 context.Context, _a1 *querypb.QueryRequest) (*internalpb.RetrieveResults, error) { func (_m *MockQueryNodeServer) Query(_a0 context.Context, _a1 *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1)
var r0 *internalpb.RetrieveResults var r0 *internalpb.RetrieveResults
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest) (*internalpb.RetrieveResults, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest) *internalpb.RetrieveResults); ok { if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest) *internalpb.RetrieveResults); ok {
r0 = rf(_a0, _a1) r0 = rf(_a0, _a1)
} else { } else {
@ -596,6 +512,7 @@ func (_m *MockQueryNodeServer) Query(_a0 context.Context, _a1 *querypb.QueryRequ
} }
} }
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.QueryRequest) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *querypb.QueryRequest) error); ok {
r1 = rf(_a0, _a1) r1 = rf(_a0, _a1)
} else { } else {
@ -611,8 +528,8 @@ type MockQueryNodeServer_Query_Call struct {
} }
// Query is a helper method to define mock.On call // Query is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.QueryRequest // - _a1 *querypb.QueryRequest
func (_e *MockQueryNodeServer_Expecter) Query(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_Query_Call { func (_e *MockQueryNodeServer_Expecter) Query(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_Query_Call {
return &MockQueryNodeServer_Query_Call{Call: _e.mock.On("Query", _a0, _a1)} return &MockQueryNodeServer_Query_Call{Call: _e.mock.On("Query", _a0, _a1)}
} }
@ -629,20 +546,11 @@ func (_c *MockQueryNodeServer_Query_Call) Return(_a0 *internalpb.RetrieveResults
return _c return _c
} }
func (_c *MockQueryNodeServer_Query_Call) RunAndReturn(run func(context.Context, *querypb.QueryRequest) (*internalpb.RetrieveResults, error)) *MockQueryNodeServer_Query_Call {
_c.Call.Return(run)
return _c
}
// QuerySegments provides a mock function with given fields: _a0, _a1 // QuerySegments provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) QuerySegments(_a0 context.Context, _a1 *querypb.QueryRequest) (*internalpb.RetrieveResults, error) { func (_m *MockQueryNodeServer) QuerySegments(_a0 context.Context, _a1 *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1)
var r0 *internalpb.RetrieveResults var r0 *internalpb.RetrieveResults
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest) (*internalpb.RetrieveResults, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest) *internalpb.RetrieveResults); ok { if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest) *internalpb.RetrieveResults); ok {
r0 = rf(_a0, _a1) r0 = rf(_a0, _a1)
} else { } else {
@ -651,6 +559,7 @@ func (_m *MockQueryNodeServer) QuerySegments(_a0 context.Context, _a1 *querypb.Q
} }
} }
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.QueryRequest) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *querypb.QueryRequest) error); ok {
r1 = rf(_a0, _a1) r1 = rf(_a0, _a1)
} else { } else {
@ -666,8 +575,8 @@ type MockQueryNodeServer_QuerySegments_Call struct {
} }
// QuerySegments is a helper method to define mock.On call // QuerySegments is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.QueryRequest // - _a1 *querypb.QueryRequest
func (_e *MockQueryNodeServer_Expecter) QuerySegments(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_QuerySegments_Call { func (_e *MockQueryNodeServer_Expecter) QuerySegments(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_QuerySegments_Call {
return &MockQueryNodeServer_QuerySegments_Call{Call: _e.mock.On("QuerySegments", _a0, _a1)} return &MockQueryNodeServer_QuerySegments_Call{Call: _e.mock.On("QuerySegments", _a0, _a1)}
} }
@ -684,20 +593,11 @@ func (_c *MockQueryNodeServer_QuerySegments_Call) Return(_a0 *internalpb.Retriev
return _c return _c
} }
func (_c *MockQueryNodeServer_QuerySegments_Call) RunAndReturn(run func(context.Context, *querypb.QueryRequest) (*internalpb.RetrieveResults, error)) *MockQueryNodeServer_QuerySegments_Call {
_c.Call.Return(run)
return _c
}
// ReleaseCollection provides a mock function with given fields: _a0, _a1 // ReleaseCollection provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) ReleaseCollection(_a0 context.Context, _a1 *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) { func (_m *MockQueryNodeServer) ReleaseCollection(_a0 context.Context, _a1 *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1)
var r0 *commonpb.Status var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleaseCollectionRequest) (*commonpb.Status, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleaseCollectionRequest) *commonpb.Status); ok { if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleaseCollectionRequest) *commonpb.Status); ok {
r0 = rf(_a0, _a1) r0 = rf(_a0, _a1)
} else { } else {
@ -706,6 +606,7 @@ func (_m *MockQueryNodeServer) ReleaseCollection(_a0 context.Context, _a1 *query
} }
} }
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.ReleaseCollectionRequest) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *querypb.ReleaseCollectionRequest) error); ok {
r1 = rf(_a0, _a1) r1 = rf(_a0, _a1)
} else { } else {
@ -721,8 +622,8 @@ type MockQueryNodeServer_ReleaseCollection_Call struct {
} }
// ReleaseCollection is a helper method to define mock.On call // ReleaseCollection is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.ReleaseCollectionRequest // - _a1 *querypb.ReleaseCollectionRequest
func (_e *MockQueryNodeServer_Expecter) ReleaseCollection(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_ReleaseCollection_Call { func (_e *MockQueryNodeServer_Expecter) ReleaseCollection(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_ReleaseCollection_Call {
return &MockQueryNodeServer_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", _a0, _a1)} return &MockQueryNodeServer_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", _a0, _a1)}
} }
@ -739,20 +640,11 @@ func (_c *MockQueryNodeServer_ReleaseCollection_Call) Return(_a0 *commonpb.Statu
return _c return _c
} }
func (_c *MockQueryNodeServer_ReleaseCollection_Call) RunAndReturn(run func(context.Context, *querypb.ReleaseCollectionRequest) (*commonpb.Status, error)) *MockQueryNodeServer_ReleaseCollection_Call {
_c.Call.Return(run)
return _c
}
// ReleasePartitions provides a mock function with given fields: _a0, _a1 // ReleasePartitions provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) ReleasePartitions(_a0 context.Context, _a1 *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) { func (_m *MockQueryNodeServer) ReleasePartitions(_a0 context.Context, _a1 *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1)
var r0 *commonpb.Status var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleasePartitionsRequest) (*commonpb.Status, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleasePartitionsRequest) *commonpb.Status); ok { if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleasePartitionsRequest) *commonpb.Status); ok {
r0 = rf(_a0, _a1) r0 = rf(_a0, _a1)
} else { } else {
@ -761,6 +653,7 @@ func (_m *MockQueryNodeServer) ReleasePartitions(_a0 context.Context, _a1 *query
} }
} }
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.ReleasePartitionsRequest) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *querypb.ReleasePartitionsRequest) error); ok {
r1 = rf(_a0, _a1) r1 = rf(_a0, _a1)
} else { } else {
@ -776,8 +669,8 @@ type MockQueryNodeServer_ReleasePartitions_Call struct {
} }
// ReleasePartitions is a helper method to define mock.On call // ReleasePartitions is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.ReleasePartitionsRequest // - _a1 *querypb.ReleasePartitionsRequest
func (_e *MockQueryNodeServer_Expecter) ReleasePartitions(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_ReleasePartitions_Call { func (_e *MockQueryNodeServer_Expecter) ReleasePartitions(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_ReleasePartitions_Call {
return &MockQueryNodeServer_ReleasePartitions_Call{Call: _e.mock.On("ReleasePartitions", _a0, _a1)} return &MockQueryNodeServer_ReleasePartitions_Call{Call: _e.mock.On("ReleasePartitions", _a0, _a1)}
} }
@ -794,20 +687,11 @@ func (_c *MockQueryNodeServer_ReleasePartitions_Call) Return(_a0 *commonpb.Statu
return _c return _c
} }
func (_c *MockQueryNodeServer_ReleasePartitions_Call) RunAndReturn(run func(context.Context, *querypb.ReleasePartitionsRequest) (*commonpb.Status, error)) *MockQueryNodeServer_ReleasePartitions_Call {
_c.Call.Return(run)
return _c
}
// ReleaseSegments provides a mock function with given fields: _a0, _a1 // ReleaseSegments provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) ReleaseSegments(_a0 context.Context, _a1 *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) { func (_m *MockQueryNodeServer) ReleaseSegments(_a0 context.Context, _a1 *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1)
var r0 *commonpb.Status var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleaseSegmentsRequest) *commonpb.Status); ok { if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleaseSegmentsRequest) *commonpb.Status); ok {
r0 = rf(_a0, _a1) r0 = rf(_a0, _a1)
} else { } else {
@ -816,6 +700,7 @@ func (_m *MockQueryNodeServer) ReleaseSegments(_a0 context.Context, _a1 *querypb
} }
} }
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.ReleaseSegmentsRequest) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *querypb.ReleaseSegmentsRequest) error); ok {
r1 = rf(_a0, _a1) r1 = rf(_a0, _a1)
} else { } else {
@ -831,8 +716,8 @@ type MockQueryNodeServer_ReleaseSegments_Call struct {
} }
// ReleaseSegments is a helper method to define mock.On call // ReleaseSegments is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.ReleaseSegmentsRequest // - _a1 *querypb.ReleaseSegmentsRequest
func (_e *MockQueryNodeServer_Expecter) ReleaseSegments(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_ReleaseSegments_Call { func (_e *MockQueryNodeServer_Expecter) ReleaseSegments(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_ReleaseSegments_Call {
return &MockQueryNodeServer_ReleaseSegments_Call{Call: _e.mock.On("ReleaseSegments", _a0, _a1)} return &MockQueryNodeServer_ReleaseSegments_Call{Call: _e.mock.On("ReleaseSegments", _a0, _a1)}
} }
@ -849,20 +734,11 @@ func (_c *MockQueryNodeServer_ReleaseSegments_Call) Return(_a0 *commonpb.Status,
return _c return _c
} }
func (_c *MockQueryNodeServer_ReleaseSegments_Call) RunAndReturn(run func(context.Context, *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error)) *MockQueryNodeServer_ReleaseSegments_Call {
_c.Call.Return(run)
return _c
}
// Search provides a mock function with given fields: _a0, _a1 // Search provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) Search(_a0 context.Context, _a1 *querypb.SearchRequest) (*internalpb.SearchResults, error) { func (_m *MockQueryNodeServer) Search(_a0 context.Context, _a1 *querypb.SearchRequest) (*internalpb.SearchResults, error) {
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1)
var r0 *internalpb.SearchResults var r0 *internalpb.SearchResults
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SearchRequest) (*internalpb.SearchResults, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SearchRequest) *internalpb.SearchResults); ok { if rf, ok := ret.Get(0).(func(context.Context, *querypb.SearchRequest) *internalpb.SearchResults); ok {
r0 = rf(_a0, _a1) r0 = rf(_a0, _a1)
} else { } else {
@ -871,6 +747,7 @@ func (_m *MockQueryNodeServer) Search(_a0 context.Context, _a1 *querypb.SearchRe
} }
} }
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.SearchRequest) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *querypb.SearchRequest) error); ok {
r1 = rf(_a0, _a1) r1 = rf(_a0, _a1)
} else { } else {
@ -886,8 +763,8 @@ type MockQueryNodeServer_Search_Call struct {
} }
// Search is a helper method to define mock.On call // Search is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.SearchRequest // - _a1 *querypb.SearchRequest
func (_e *MockQueryNodeServer_Expecter) Search(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_Search_Call { func (_e *MockQueryNodeServer_Expecter) Search(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_Search_Call {
return &MockQueryNodeServer_Search_Call{Call: _e.mock.On("Search", _a0, _a1)} return &MockQueryNodeServer_Search_Call{Call: _e.mock.On("Search", _a0, _a1)}
} }
@ -904,20 +781,11 @@ func (_c *MockQueryNodeServer_Search_Call) Return(_a0 *internalpb.SearchResults,
return _c return _c
} }
func (_c *MockQueryNodeServer_Search_Call) RunAndReturn(run func(context.Context, *querypb.SearchRequest) (*internalpb.SearchResults, error)) *MockQueryNodeServer_Search_Call {
_c.Call.Return(run)
return _c
}
// SearchSegments provides a mock function with given fields: _a0, _a1 // SearchSegments provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) SearchSegments(_a0 context.Context, _a1 *querypb.SearchRequest) (*internalpb.SearchResults, error) { func (_m *MockQueryNodeServer) SearchSegments(_a0 context.Context, _a1 *querypb.SearchRequest) (*internalpb.SearchResults, error) {
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1)
var r0 *internalpb.SearchResults var r0 *internalpb.SearchResults
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SearchRequest) (*internalpb.SearchResults, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SearchRequest) *internalpb.SearchResults); ok { if rf, ok := ret.Get(0).(func(context.Context, *querypb.SearchRequest) *internalpb.SearchResults); ok {
r0 = rf(_a0, _a1) r0 = rf(_a0, _a1)
} else { } else {
@ -926,6 +794,7 @@ func (_m *MockQueryNodeServer) SearchSegments(_a0 context.Context, _a1 *querypb.
} }
} }
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.SearchRequest) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *querypb.SearchRequest) error); ok {
r1 = rf(_a0, _a1) r1 = rf(_a0, _a1)
} else { } else {
@ -941,8 +810,8 @@ type MockQueryNodeServer_SearchSegments_Call struct {
} }
// SearchSegments is a helper method to define mock.On call // SearchSegments is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.SearchRequest // - _a1 *querypb.SearchRequest
func (_e *MockQueryNodeServer_Expecter) SearchSegments(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_SearchSegments_Call { func (_e *MockQueryNodeServer_Expecter) SearchSegments(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_SearchSegments_Call {
return &MockQueryNodeServer_SearchSegments_Call{Call: _e.mock.On("SearchSegments", _a0, _a1)} return &MockQueryNodeServer_SearchSegments_Call{Call: _e.mock.On("SearchSegments", _a0, _a1)}
} }
@ -959,20 +828,11 @@ func (_c *MockQueryNodeServer_SearchSegments_Call) Return(_a0 *internalpb.Search
return _c return _c
} }
func (_c *MockQueryNodeServer_SearchSegments_Call) RunAndReturn(run func(context.Context, *querypb.SearchRequest) (*internalpb.SearchResults, error)) *MockQueryNodeServer_SearchSegments_Call {
_c.Call.Return(run)
return _c
}
// ShowConfigurations provides a mock function with given fields: _a0, _a1 // ShowConfigurations provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) ShowConfigurations(_a0 context.Context, _a1 *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { func (_m *MockQueryNodeServer) ShowConfigurations(_a0 context.Context, _a1 *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) {
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1)
var r0 *internalpb.ShowConfigurationsResponse var r0 *internalpb.ShowConfigurationsResponse
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest) *internalpb.ShowConfigurationsResponse); ok { if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest) *internalpb.ShowConfigurationsResponse); ok {
r0 = rf(_a0, _a1) r0 = rf(_a0, _a1)
} else { } else {
@ -981,6 +841,7 @@ func (_m *MockQueryNodeServer) ShowConfigurations(_a0 context.Context, _a1 *inte
} }
} }
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *internalpb.ShowConfigurationsRequest) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *internalpb.ShowConfigurationsRequest) error); ok {
r1 = rf(_a0, _a1) r1 = rf(_a0, _a1)
} else { } else {
@ -996,8 +857,8 @@ type MockQueryNodeServer_ShowConfigurations_Call struct {
} }
// ShowConfigurations is a helper method to define mock.On call // ShowConfigurations is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *internalpb.ShowConfigurationsRequest // - _a1 *internalpb.ShowConfigurationsRequest
func (_e *MockQueryNodeServer_Expecter) ShowConfigurations(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_ShowConfigurations_Call { func (_e *MockQueryNodeServer_Expecter) ShowConfigurations(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_ShowConfigurations_Call {
return &MockQueryNodeServer_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", _a0, _a1)} return &MockQueryNodeServer_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", _a0, _a1)}
} }
@ -1014,20 +875,11 @@ func (_c *MockQueryNodeServer_ShowConfigurations_Call) Return(_a0 *internalpb.Sh
return _c return _c
} }
func (_c *MockQueryNodeServer_ShowConfigurations_Call) RunAndReturn(run func(context.Context, *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error)) *MockQueryNodeServer_ShowConfigurations_Call {
_c.Call.Return(run)
return _c
}
// SyncDistribution provides a mock function with given fields: _a0, _a1 // SyncDistribution provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) SyncDistribution(_a0 context.Context, _a1 *querypb.SyncDistributionRequest) (*commonpb.Status, error) { func (_m *MockQueryNodeServer) SyncDistribution(_a0 context.Context, _a1 *querypb.SyncDistributionRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1)
var r0 *commonpb.Status var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SyncDistributionRequest) (*commonpb.Status, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SyncDistributionRequest) *commonpb.Status); ok { if rf, ok := ret.Get(0).(func(context.Context, *querypb.SyncDistributionRequest) *commonpb.Status); ok {
r0 = rf(_a0, _a1) r0 = rf(_a0, _a1)
} else { } else {
@ -1036,6 +888,7 @@ func (_m *MockQueryNodeServer) SyncDistribution(_a0 context.Context, _a1 *queryp
} }
} }
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.SyncDistributionRequest) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *querypb.SyncDistributionRequest) error); ok {
r1 = rf(_a0, _a1) r1 = rf(_a0, _a1)
} else { } else {
@ -1051,8 +904,8 @@ type MockQueryNodeServer_SyncDistribution_Call struct {
} }
// SyncDistribution is a helper method to define mock.On call // SyncDistribution is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.SyncDistributionRequest // - _a1 *querypb.SyncDistributionRequest
func (_e *MockQueryNodeServer_Expecter) SyncDistribution(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_SyncDistribution_Call { func (_e *MockQueryNodeServer_Expecter) SyncDistribution(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_SyncDistribution_Call {
return &MockQueryNodeServer_SyncDistribution_Call{Call: _e.mock.On("SyncDistribution", _a0, _a1)} return &MockQueryNodeServer_SyncDistribution_Call{Call: _e.mock.On("SyncDistribution", _a0, _a1)}
} }
@ -1069,20 +922,11 @@ func (_c *MockQueryNodeServer_SyncDistribution_Call) Return(_a0 *commonpb.Status
return _c return _c
} }
func (_c *MockQueryNodeServer_SyncDistribution_Call) RunAndReturn(run func(context.Context, *querypb.SyncDistributionRequest) (*commonpb.Status, error)) *MockQueryNodeServer_SyncDistribution_Call {
_c.Call.Return(run)
return _c
}
// SyncReplicaSegments provides a mock function with given fields: _a0, _a1 // SyncReplicaSegments provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) SyncReplicaSegments(_a0 context.Context, _a1 *querypb.SyncReplicaSegmentsRequest) (*commonpb.Status, error) { func (_m *MockQueryNodeServer) SyncReplicaSegments(_a0 context.Context, _a1 *querypb.SyncReplicaSegmentsRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1)
var r0 *commonpb.Status var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SyncReplicaSegmentsRequest) (*commonpb.Status, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SyncReplicaSegmentsRequest) *commonpb.Status); ok { if rf, ok := ret.Get(0).(func(context.Context, *querypb.SyncReplicaSegmentsRequest) *commonpb.Status); ok {
r0 = rf(_a0, _a1) r0 = rf(_a0, _a1)
} else { } else {
@ -1091,6 +935,7 @@ func (_m *MockQueryNodeServer) SyncReplicaSegments(_a0 context.Context, _a1 *que
} }
} }
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.SyncReplicaSegmentsRequest) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *querypb.SyncReplicaSegmentsRequest) error); ok {
r1 = rf(_a0, _a1) r1 = rf(_a0, _a1)
} else { } else {
@ -1106,8 +951,8 @@ type MockQueryNodeServer_SyncReplicaSegments_Call struct {
} }
// SyncReplicaSegments is a helper method to define mock.On call // SyncReplicaSegments is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.SyncReplicaSegmentsRequest // - _a1 *querypb.SyncReplicaSegmentsRequest
func (_e *MockQueryNodeServer_Expecter) SyncReplicaSegments(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_SyncReplicaSegments_Call { func (_e *MockQueryNodeServer_Expecter) SyncReplicaSegments(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_SyncReplicaSegments_Call {
return &MockQueryNodeServer_SyncReplicaSegments_Call{Call: _e.mock.On("SyncReplicaSegments", _a0, _a1)} return &MockQueryNodeServer_SyncReplicaSegments_Call{Call: _e.mock.On("SyncReplicaSegments", _a0, _a1)}
} }
@ -1124,20 +969,11 @@ func (_c *MockQueryNodeServer_SyncReplicaSegments_Call) Return(_a0 *commonpb.Sta
return _c return _c
} }
func (_c *MockQueryNodeServer_SyncReplicaSegments_Call) RunAndReturn(run func(context.Context, *querypb.SyncReplicaSegmentsRequest) (*commonpb.Status, error)) *MockQueryNodeServer_SyncReplicaSegments_Call {
_c.Call.Return(run)
return _c
}
// UnsubDmChannel provides a mock function with given fields: _a0, _a1 // UnsubDmChannel provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) UnsubDmChannel(_a0 context.Context, _a1 *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) { func (_m *MockQueryNodeServer) UnsubDmChannel(_a0 context.Context, _a1 *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1)
var r0 *commonpb.Status var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.UnsubDmChannelRequest) (*commonpb.Status, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.UnsubDmChannelRequest) *commonpb.Status); ok { if rf, ok := ret.Get(0).(func(context.Context, *querypb.UnsubDmChannelRequest) *commonpb.Status); ok {
r0 = rf(_a0, _a1) r0 = rf(_a0, _a1)
} else { } else {
@ -1146,6 +982,7 @@ func (_m *MockQueryNodeServer) UnsubDmChannel(_a0 context.Context, _a1 *querypb.
} }
} }
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.UnsubDmChannelRequest) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *querypb.UnsubDmChannelRequest) error); ok {
r1 = rf(_a0, _a1) r1 = rf(_a0, _a1)
} else { } else {
@ -1161,8 +998,8 @@ type MockQueryNodeServer_UnsubDmChannel_Call struct {
} }
// UnsubDmChannel is a helper method to define mock.On call // UnsubDmChannel is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.UnsubDmChannelRequest // - _a1 *querypb.UnsubDmChannelRequest
func (_e *MockQueryNodeServer_Expecter) UnsubDmChannel(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_UnsubDmChannel_Call { func (_e *MockQueryNodeServer_Expecter) UnsubDmChannel(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_UnsubDmChannel_Call {
return &MockQueryNodeServer_UnsubDmChannel_Call{Call: _e.mock.On("UnsubDmChannel", _a0, _a1)} return &MockQueryNodeServer_UnsubDmChannel_Call{Call: _e.mock.On("UnsubDmChannel", _a0, _a1)}
} }
@ -1179,20 +1016,11 @@ func (_c *MockQueryNodeServer_UnsubDmChannel_Call) Return(_a0 *commonpb.Status,
return _c return _c
} }
func (_c *MockQueryNodeServer_UnsubDmChannel_Call) RunAndReturn(run func(context.Context, *querypb.UnsubDmChannelRequest) (*commonpb.Status, error)) *MockQueryNodeServer_UnsubDmChannel_Call {
_c.Call.Return(run)
return _c
}
// WatchDmChannels provides a mock function with given fields: _a0, _a1 // WatchDmChannels provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) WatchDmChannels(_a0 context.Context, _a1 *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) { func (_m *MockQueryNodeServer) WatchDmChannels(_a0 context.Context, _a1 *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1)
var r0 *commonpb.Status var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.WatchDmChannelsRequest) (*commonpb.Status, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.WatchDmChannelsRequest) *commonpb.Status); ok { if rf, ok := ret.Get(0).(func(context.Context, *querypb.WatchDmChannelsRequest) *commonpb.Status); ok {
r0 = rf(_a0, _a1) r0 = rf(_a0, _a1)
} else { } else {
@ -1201,6 +1029,7 @@ func (_m *MockQueryNodeServer) WatchDmChannels(_a0 context.Context, _a1 *querypb
} }
} }
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.WatchDmChannelsRequest) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *querypb.WatchDmChannelsRequest) error); ok {
r1 = rf(_a0, _a1) r1 = rf(_a0, _a1)
} else { } else {
@ -1216,8 +1045,8 @@ type MockQueryNodeServer_WatchDmChannels_Call struct {
} }
// WatchDmChannels is a helper method to define mock.On call // WatchDmChannels is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.WatchDmChannelsRequest // - _a1 *querypb.WatchDmChannelsRequest
func (_e *MockQueryNodeServer_Expecter) WatchDmChannels(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_WatchDmChannels_Call { func (_e *MockQueryNodeServer_Expecter) WatchDmChannels(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_WatchDmChannels_Call {
return &MockQueryNodeServer_WatchDmChannels_Call{Call: _e.mock.On("WatchDmChannels", _a0, _a1)} return &MockQueryNodeServer_WatchDmChannels_Call{Call: _e.mock.On("WatchDmChannels", _a0, _a1)}
} }
@ -1234,11 +1063,6 @@ func (_c *MockQueryNodeServer_WatchDmChannels_Call) Return(_a0 *commonpb.Status,
return _c return _c
} }
func (_c *MockQueryNodeServer_WatchDmChannels_Call) RunAndReturn(run func(context.Context, *querypb.WatchDmChannelsRequest) (*commonpb.Status, error)) *MockQueryNodeServer_WatchDmChannels_Call {
_c.Call.Return(run)
return _c
}
type mockConstructorTestingTNewMockQueryNodeServer interface { type mockConstructorTestingTNewMockQueryNodeServer interface {
mock.TestingT mock.TestingT
Cleanup(func()) Cleanup(func())

View File

@ -23,6 +23,11 @@ import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/mock"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/zap"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
@ -30,10 +35,6 @@ import (
"github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/typeutil" "github.com/milvus-io/milvus/pkg/util/typeutil"
"github.com/stretchr/testify/mock"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/zap"
"google.golang.org/grpc"
) )
type MockQueryNode struct { type MockQueryNode struct {

View File

@ -56,8 +56,8 @@ type MockCluster_GetComponentStates_Call struct {
} }
// GetComponentStates is a helper method to define mock.On call // GetComponentStates is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
func (_e *MockCluster_Expecter) GetComponentStates(ctx interface{}, nodeID interface{}) *MockCluster_GetComponentStates_Call { func (_e *MockCluster_Expecter) GetComponentStates(ctx interface{}, nodeID interface{}) *MockCluster_GetComponentStates_Call {
return &MockCluster_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", ctx, nodeID)} return &MockCluster_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", ctx, nodeID)}
} }
@ -103,9 +103,9 @@ type MockCluster_GetDataDistribution_Call struct {
} }
// GetDataDistribution is a helper method to define mock.On call // GetDataDistribution is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
// - req *querypb.GetDataDistributionRequest // - req *querypb.GetDataDistributionRequest
func (_e *MockCluster_Expecter) GetDataDistribution(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_GetDataDistribution_Call { func (_e *MockCluster_Expecter) GetDataDistribution(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_GetDataDistribution_Call {
return &MockCluster_GetDataDistribution_Call{Call: _e.mock.On("GetDataDistribution", ctx, nodeID, req)} return &MockCluster_GetDataDistribution_Call{Call: _e.mock.On("GetDataDistribution", ctx, nodeID, req)}
} }
@ -151,9 +151,9 @@ type MockCluster_GetMetrics_Call struct {
} }
// GetMetrics is a helper method to define mock.On call // GetMetrics is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
// - req *milvuspb.GetMetricsRequest // - req *milvuspb.GetMetricsRequest
func (_e *MockCluster_Expecter) GetMetrics(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_GetMetrics_Call { func (_e *MockCluster_Expecter) GetMetrics(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_GetMetrics_Call {
return &MockCluster_GetMetrics_Call{Call: _e.mock.On("GetMetrics", ctx, nodeID, req)} return &MockCluster_GetMetrics_Call{Call: _e.mock.On("GetMetrics", ctx, nodeID, req)}
} }
@ -199,9 +199,9 @@ type MockCluster_LoadPartitions_Call struct {
} }
// LoadPartitions is a helper method to define mock.On call // LoadPartitions is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
// - req *querypb.LoadPartitionsRequest // - req *querypb.LoadPartitionsRequest
func (_e *MockCluster_Expecter) LoadPartitions(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_LoadPartitions_Call { func (_e *MockCluster_Expecter) LoadPartitions(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_LoadPartitions_Call {
return &MockCluster_LoadPartitions_Call{Call: _e.mock.On("LoadPartitions", ctx, nodeID, req)} return &MockCluster_LoadPartitions_Call{Call: _e.mock.On("LoadPartitions", ctx, nodeID, req)}
} }
@ -247,9 +247,9 @@ type MockCluster_LoadSegments_Call struct {
} }
// LoadSegments is a helper method to define mock.On call // LoadSegments is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
// - req *querypb.LoadSegmentsRequest // - req *querypb.LoadSegmentsRequest
func (_e *MockCluster_Expecter) LoadSegments(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_LoadSegments_Call { func (_e *MockCluster_Expecter) LoadSegments(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_LoadSegments_Call {
return &MockCluster_LoadSegments_Call{Call: _e.mock.On("LoadSegments", ctx, nodeID, req)} return &MockCluster_LoadSegments_Call{Call: _e.mock.On("LoadSegments", ctx, nodeID, req)}
} }
@ -295,9 +295,9 @@ type MockCluster_ReleasePartitions_Call struct {
} }
// ReleasePartitions is a helper method to define mock.On call // ReleasePartitions is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
// - req *querypb.ReleasePartitionsRequest // - req *querypb.ReleasePartitionsRequest
func (_e *MockCluster_Expecter) ReleasePartitions(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_ReleasePartitions_Call { func (_e *MockCluster_Expecter) ReleasePartitions(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_ReleasePartitions_Call {
return &MockCluster_ReleasePartitions_Call{Call: _e.mock.On("ReleasePartitions", ctx, nodeID, req)} return &MockCluster_ReleasePartitions_Call{Call: _e.mock.On("ReleasePartitions", ctx, nodeID, req)}
} }
@ -343,9 +343,9 @@ type MockCluster_ReleaseSegments_Call struct {
} }
// ReleaseSegments is a helper method to define mock.On call // ReleaseSegments is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
// - req *querypb.ReleaseSegmentsRequest // - req *querypb.ReleaseSegmentsRequest
func (_e *MockCluster_Expecter) ReleaseSegments(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_ReleaseSegments_Call { func (_e *MockCluster_Expecter) ReleaseSegments(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_ReleaseSegments_Call {
return &MockCluster_ReleaseSegments_Call{Call: _e.mock.On("ReleaseSegments", ctx, nodeID, req)} return &MockCluster_ReleaseSegments_Call{Call: _e.mock.On("ReleaseSegments", ctx, nodeID, req)}
} }
@ -373,7 +373,7 @@ type MockCluster_Start_Call struct {
} }
// Start is a helper method to define mock.On call // Start is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
func (_e *MockCluster_Expecter) Start(ctx interface{}) *MockCluster_Start_Call { func (_e *MockCluster_Expecter) Start(ctx interface{}) *MockCluster_Start_Call {
return &MockCluster_Start_Call{Call: _e.mock.On("Start", ctx)} return &MockCluster_Start_Call{Call: _e.mock.On("Start", ctx)}
} }
@ -446,9 +446,9 @@ type MockCluster_SyncDistribution_Call struct {
} }
// SyncDistribution is a helper method to define mock.On call // SyncDistribution is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
// - req *querypb.SyncDistributionRequest // - req *querypb.SyncDistributionRequest
func (_e *MockCluster_Expecter) SyncDistribution(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_SyncDistribution_Call { func (_e *MockCluster_Expecter) SyncDistribution(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_SyncDistribution_Call {
return &MockCluster_SyncDistribution_Call{Call: _e.mock.On("SyncDistribution", ctx, nodeID, req)} return &MockCluster_SyncDistribution_Call{Call: _e.mock.On("SyncDistribution", ctx, nodeID, req)}
} }
@ -494,9 +494,9 @@ type MockCluster_UnsubDmChannel_Call struct {
} }
// UnsubDmChannel is a helper method to define mock.On call // UnsubDmChannel is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
// - req *querypb.UnsubDmChannelRequest // - req *querypb.UnsubDmChannelRequest
func (_e *MockCluster_Expecter) UnsubDmChannel(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_UnsubDmChannel_Call { func (_e *MockCluster_Expecter) UnsubDmChannel(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_UnsubDmChannel_Call {
return &MockCluster_UnsubDmChannel_Call{Call: _e.mock.On("UnsubDmChannel", ctx, nodeID, req)} return &MockCluster_UnsubDmChannel_Call{Call: _e.mock.On("UnsubDmChannel", ctx, nodeID, req)}
} }
@ -542,9 +542,9 @@ type MockCluster_WatchDmChannels_Call struct {
} }
// WatchDmChannels is a helper method to define mock.On call // WatchDmChannels is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
// - req *querypb.WatchDmChannelsRequest // - req *querypb.WatchDmChannelsRequest
func (_e *MockCluster_Expecter) WatchDmChannels(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_WatchDmChannels_Call { func (_e *MockCluster_Expecter) WatchDmChannels(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_WatchDmChannels_Call {
return &MockCluster_WatchDmChannels_Call{Call: _e.mock.On("WatchDmChannels", ctx, nodeID, req)} return &MockCluster_WatchDmChannels_Call{Call: _e.mock.On("WatchDmChannels", ctx, nodeID, req)}
} }

View File

@ -41,7 +41,7 @@ type MockScheduler_Add_Call struct {
} }
// Add is a helper method to define mock.On call // Add is a helper method to define mock.On call
// - task Task // - task Task
func (_e *MockScheduler_Expecter) Add(task interface{}) *MockScheduler_Add_Call { func (_e *MockScheduler_Expecter) Add(task interface{}) *MockScheduler_Add_Call {
return &MockScheduler_Add_Call{Call: _e.mock.On("Add", task)} return &MockScheduler_Add_Call{Call: _e.mock.On("Add", task)}
} }
@ -69,7 +69,7 @@ type MockScheduler_AddExecutor_Call struct {
} }
// AddExecutor is a helper method to define mock.On call // AddExecutor is a helper method to define mock.On call
// - nodeID int64 // - nodeID int64
func (_e *MockScheduler_Expecter) AddExecutor(nodeID interface{}) *MockScheduler_AddExecutor_Call { func (_e *MockScheduler_Expecter) AddExecutor(nodeID interface{}) *MockScheduler_AddExecutor_Call {
return &MockScheduler_AddExecutor_Call{Call: _e.mock.On("AddExecutor", nodeID)} return &MockScheduler_AddExecutor_Call{Call: _e.mock.On("AddExecutor", nodeID)}
} }
@ -97,7 +97,7 @@ type MockScheduler_Dispatch_Call struct {
} }
// Dispatch is a helper method to define mock.On call // Dispatch is a helper method to define mock.On call
// - node int64 // - node int64
func (_e *MockScheduler_Expecter) Dispatch(node interface{}) *MockScheduler_Dispatch_Call { func (_e *MockScheduler_Expecter) Dispatch(node interface{}) *MockScheduler_Dispatch_Call {
return &MockScheduler_Dispatch_Call{Call: _e.mock.On("Dispatch", node)} return &MockScheduler_Dispatch_Call{Call: _e.mock.On("Dispatch", node)}
} }
@ -170,7 +170,7 @@ type MockScheduler_GetNodeChannelDelta_Call struct {
} }
// GetNodeChannelDelta is a helper method to define mock.On call // GetNodeChannelDelta is a helper method to define mock.On call
// - nodeID int64 // - nodeID int64
func (_e *MockScheduler_Expecter) GetNodeChannelDelta(nodeID interface{}) *MockScheduler_GetNodeChannelDelta_Call { func (_e *MockScheduler_Expecter) GetNodeChannelDelta(nodeID interface{}) *MockScheduler_GetNodeChannelDelta_Call {
return &MockScheduler_GetNodeChannelDelta_Call{Call: _e.mock.On("GetNodeChannelDelta", nodeID)} return &MockScheduler_GetNodeChannelDelta_Call{Call: _e.mock.On("GetNodeChannelDelta", nodeID)}
} }
@ -207,7 +207,7 @@ type MockScheduler_GetNodeSegmentDelta_Call struct {
} }
// GetNodeSegmentDelta is a helper method to define mock.On call // GetNodeSegmentDelta is a helper method to define mock.On call
// - nodeID int64 // - nodeID int64
func (_e *MockScheduler_Expecter) GetNodeSegmentDelta(nodeID interface{}) *MockScheduler_GetNodeSegmentDelta_Call { func (_e *MockScheduler_Expecter) GetNodeSegmentDelta(nodeID interface{}) *MockScheduler_GetNodeSegmentDelta_Call {
return &MockScheduler_GetNodeSegmentDelta_Call{Call: _e.mock.On("GetNodeSegmentDelta", nodeID)} return &MockScheduler_GetNodeSegmentDelta_Call{Call: _e.mock.On("GetNodeSegmentDelta", nodeID)}
} }
@ -271,7 +271,7 @@ type MockScheduler_RemoveByNode_Call struct {
} }
// RemoveByNode is a helper method to define mock.On call // RemoveByNode is a helper method to define mock.On call
// - node int64 // - node int64
func (_e *MockScheduler_Expecter) RemoveByNode(node interface{}) *MockScheduler_RemoveByNode_Call { func (_e *MockScheduler_Expecter) RemoveByNode(node interface{}) *MockScheduler_RemoveByNode_Call {
return &MockScheduler_RemoveByNode_Call{Call: _e.mock.On("RemoveByNode", node)} return &MockScheduler_RemoveByNode_Call{Call: _e.mock.On("RemoveByNode", node)}
} }
@ -299,7 +299,7 @@ type MockScheduler_RemoveExecutor_Call struct {
} }
// RemoveExecutor is a helper method to define mock.On call // RemoveExecutor is a helper method to define mock.On call
// - nodeID int64 // - nodeID int64
func (_e *MockScheduler_Expecter) RemoveExecutor(nodeID interface{}) *MockScheduler_RemoveExecutor_Call { func (_e *MockScheduler_Expecter) RemoveExecutor(nodeID interface{}) *MockScheduler_RemoveExecutor_Call {
return &MockScheduler_RemoveExecutor_Call{Call: _e.mock.On("RemoveExecutor", nodeID)} return &MockScheduler_RemoveExecutor_Call{Call: _e.mock.On("RemoveExecutor", nodeID)}
} }
@ -327,7 +327,7 @@ type MockScheduler_Start_Call struct {
} }
// Start is a helper method to define mock.On call // Start is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
func (_e *MockScheduler_Expecter) Start(ctx interface{}) *MockScheduler_Start_Call { func (_e *MockScheduler_Expecter) Start(ctx interface{}) *MockScheduler_Start_Call {
return &MockScheduler_Start_Call{Call: _e.mock.On("Start", ctx)} return &MockScheduler_Start_Call{Call: _e.mock.On("Start", ctx)}
} }

View File

@ -46,7 +46,7 @@ type MockManager_GetWorker_Call struct {
} }
// GetWorker is a helper method to define mock.On call // GetWorker is a helper method to define mock.On call
// - nodeID int64 // - nodeID int64
func (_e *MockManager_Expecter) GetWorker(nodeID interface{}) *MockManager_GetWorker_Call { func (_e *MockManager_Expecter) GetWorker(nodeID interface{}) *MockManager_GetWorker_Call {
return &MockManager_GetWorker_Call{Call: _e.mock.On("GetWorker", nodeID)} return &MockManager_GetWorker_Call{Call: _e.mock.On("GetWorker", nodeID)}
} }

View File

@ -1,4 +1,4 @@
// Code generated by mockery v2.21.1. DO NOT EDIT. // Code generated by mockery v2.16.0. DO NOT EDIT.
package cluster package cluster
@ -44,8 +44,8 @@ type MockWorker_Delete_Call struct {
} }
// Delete is a helper method to define mock.On call // Delete is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - req *querypb.DeleteRequest // - req *querypb.DeleteRequest
func (_e *MockWorker_Expecter) Delete(ctx interface{}, req interface{}) *MockWorker_Delete_Call { func (_e *MockWorker_Expecter) Delete(ctx interface{}, req interface{}) *MockWorker_Delete_Call {
return &MockWorker_Delete_Call{Call: _e.mock.On("Delete", ctx, req)} return &MockWorker_Delete_Call{Call: _e.mock.On("Delete", ctx, req)}
} }
@ -62,20 +62,11 @@ func (_c *MockWorker_Delete_Call) Return(_a0 error) *MockWorker_Delete_Call {
return _c return _c
} }
func (_c *MockWorker_Delete_Call) RunAndReturn(run func(context.Context, *querypb.DeleteRequest) error) *MockWorker_Delete_Call {
_c.Call.Return(run)
return _c
}
// GetStatistics provides a mock function with given fields: ctx, req // GetStatistics provides a mock function with given fields: ctx, req
func (_m *MockWorker) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error) { func (_m *MockWorker) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error) {
ret := _m.Called(ctx, req) ret := _m.Called(ctx, req)
var r0 *internalpb.GetStatisticsResponse var r0 *internalpb.GetStatisticsResponse
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error)); ok {
return rf(ctx, req)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetStatisticsRequest) *internalpb.GetStatisticsResponse); ok { if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetStatisticsRequest) *internalpb.GetStatisticsResponse); ok {
r0 = rf(ctx, req) r0 = rf(ctx, req)
} else { } else {
@ -84,6 +75,7 @@ func (_m *MockWorker) GetStatistics(ctx context.Context, req *querypb.GetStatist
} }
} }
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetStatisticsRequest) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetStatisticsRequest) error); ok {
r1 = rf(ctx, req) r1 = rf(ctx, req)
} else { } else {
@ -99,8 +91,8 @@ type MockWorker_GetStatistics_Call struct {
} }
// GetStatistics is a helper method to define mock.On call // GetStatistics is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - req *querypb.GetStatisticsRequest // - req *querypb.GetStatisticsRequest
func (_e *MockWorker_Expecter) GetStatistics(ctx interface{}, req interface{}) *MockWorker_GetStatistics_Call { func (_e *MockWorker_Expecter) GetStatistics(ctx interface{}, req interface{}) *MockWorker_GetStatistics_Call {
return &MockWorker_GetStatistics_Call{Call: _e.mock.On("GetStatistics", ctx, req)} return &MockWorker_GetStatistics_Call{Call: _e.mock.On("GetStatistics", ctx, req)}
} }
@ -117,11 +109,6 @@ func (_c *MockWorker_GetStatistics_Call) Return(_a0 *internalpb.GetStatisticsRes
return _c return _c
} }
func (_c *MockWorker_GetStatistics_Call) RunAndReturn(run func(context.Context, *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error)) *MockWorker_GetStatistics_Call {
_c.Call.Return(run)
return _c
}
// IsHealthy provides a mock function with given fields: // IsHealthy provides a mock function with given fields:
func (_m *MockWorker) IsHealthy() bool { func (_m *MockWorker) IsHealthy() bool {
ret := _m.Called() ret := _m.Called()
@ -158,11 +145,6 @@ func (_c *MockWorker_IsHealthy_Call) Return(_a0 bool) *MockWorker_IsHealthy_Call
return _c return _c
} }
func (_c *MockWorker_IsHealthy_Call) RunAndReturn(run func() bool) *MockWorker_IsHealthy_Call {
_c.Call.Return(run)
return _c
}
// LoadSegments provides a mock function with given fields: _a0, _a1 // LoadSegments provides a mock function with given fields: _a0, _a1
func (_m *MockWorker) LoadSegments(_a0 context.Context, _a1 *querypb.LoadSegmentsRequest) error { func (_m *MockWorker) LoadSegments(_a0 context.Context, _a1 *querypb.LoadSegmentsRequest) error {
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1)
@ -183,8 +165,8 @@ type MockWorker_LoadSegments_Call struct {
} }
// LoadSegments is a helper method to define mock.On call // LoadSegments is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.LoadSegmentsRequest // - _a1 *querypb.LoadSegmentsRequest
func (_e *MockWorker_Expecter) LoadSegments(_a0 interface{}, _a1 interface{}) *MockWorker_LoadSegments_Call { func (_e *MockWorker_Expecter) LoadSegments(_a0 interface{}, _a1 interface{}) *MockWorker_LoadSegments_Call {
return &MockWorker_LoadSegments_Call{Call: _e.mock.On("LoadSegments", _a0, _a1)} return &MockWorker_LoadSegments_Call{Call: _e.mock.On("LoadSegments", _a0, _a1)}
} }
@ -201,20 +183,11 @@ func (_c *MockWorker_LoadSegments_Call) Return(_a0 error) *MockWorker_LoadSegmen
return _c return _c
} }
func (_c *MockWorker_LoadSegments_Call) RunAndReturn(run func(context.Context, *querypb.LoadSegmentsRequest) error) *MockWorker_LoadSegments_Call {
_c.Call.Return(run)
return _c
}
// QuerySegments provides a mock function with given fields: ctx, req // QuerySegments provides a mock function with given fields: ctx, req
func (_m *MockWorker) QuerySegments(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) { func (_m *MockWorker) QuerySegments(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
ret := _m.Called(ctx, req) ret := _m.Called(ctx, req)
var r0 *internalpb.RetrieveResults var r0 *internalpb.RetrieveResults
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest) (*internalpb.RetrieveResults, error)); ok {
return rf(ctx, req)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest) *internalpb.RetrieveResults); ok { if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest) *internalpb.RetrieveResults); ok {
r0 = rf(ctx, req) r0 = rf(ctx, req)
} else { } else {
@ -223,6 +196,7 @@ func (_m *MockWorker) QuerySegments(ctx context.Context, req *querypb.QueryReque
} }
} }
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.QueryRequest) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *querypb.QueryRequest) error); ok {
r1 = rf(ctx, req) r1 = rf(ctx, req)
} else { } else {
@ -238,8 +212,8 @@ type MockWorker_QuerySegments_Call struct {
} }
// QuerySegments is a helper method to define mock.On call // QuerySegments is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - req *querypb.QueryRequest // - req *querypb.QueryRequest
func (_e *MockWorker_Expecter) QuerySegments(ctx interface{}, req interface{}) *MockWorker_QuerySegments_Call { func (_e *MockWorker_Expecter) QuerySegments(ctx interface{}, req interface{}) *MockWorker_QuerySegments_Call {
return &MockWorker_QuerySegments_Call{Call: _e.mock.On("QuerySegments", ctx, req)} return &MockWorker_QuerySegments_Call{Call: _e.mock.On("QuerySegments", ctx, req)}
} }
@ -256,11 +230,6 @@ func (_c *MockWorker_QuerySegments_Call) Return(_a0 *internalpb.RetrieveResults,
return _c return _c
} }
func (_c *MockWorker_QuerySegments_Call) RunAndReturn(run func(context.Context, *querypb.QueryRequest) (*internalpb.RetrieveResults, error)) *MockWorker_QuerySegments_Call {
_c.Call.Return(run)
return _c
}
// ReleaseSegments provides a mock function with given fields: _a0, _a1 // ReleaseSegments provides a mock function with given fields: _a0, _a1
func (_m *MockWorker) ReleaseSegments(_a0 context.Context, _a1 *querypb.ReleaseSegmentsRequest) error { func (_m *MockWorker) ReleaseSegments(_a0 context.Context, _a1 *querypb.ReleaseSegmentsRequest) error {
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1)
@ -281,8 +250,8 @@ type MockWorker_ReleaseSegments_Call struct {
} }
// ReleaseSegments is a helper method to define mock.On call // ReleaseSegments is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.ReleaseSegmentsRequest // - _a1 *querypb.ReleaseSegmentsRequest
func (_e *MockWorker_Expecter) ReleaseSegments(_a0 interface{}, _a1 interface{}) *MockWorker_ReleaseSegments_Call { func (_e *MockWorker_Expecter) ReleaseSegments(_a0 interface{}, _a1 interface{}) *MockWorker_ReleaseSegments_Call {
return &MockWorker_ReleaseSegments_Call{Call: _e.mock.On("ReleaseSegments", _a0, _a1)} return &MockWorker_ReleaseSegments_Call{Call: _e.mock.On("ReleaseSegments", _a0, _a1)}
} }
@ -299,20 +268,11 @@ func (_c *MockWorker_ReleaseSegments_Call) Return(_a0 error) *MockWorker_Release
return _c return _c
} }
func (_c *MockWorker_ReleaseSegments_Call) RunAndReturn(run func(context.Context, *querypb.ReleaseSegmentsRequest) error) *MockWorker_ReleaseSegments_Call {
_c.Call.Return(run)
return _c
}
// SearchSegments provides a mock function with given fields: ctx, req // SearchSegments provides a mock function with given fields: ctx, req
func (_m *MockWorker) SearchSegments(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) { func (_m *MockWorker) SearchSegments(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) {
ret := _m.Called(ctx, req) ret := _m.Called(ctx, req)
var r0 *internalpb.SearchResults var r0 *internalpb.SearchResults
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SearchRequest) (*internalpb.SearchResults, error)); ok {
return rf(ctx, req)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SearchRequest) *internalpb.SearchResults); ok { if rf, ok := ret.Get(0).(func(context.Context, *querypb.SearchRequest) *internalpb.SearchResults); ok {
r0 = rf(ctx, req) r0 = rf(ctx, req)
} else { } else {
@ -321,6 +281,7 @@ func (_m *MockWorker) SearchSegments(ctx context.Context, req *querypb.SearchReq
} }
} }
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.SearchRequest) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *querypb.SearchRequest) error); ok {
r1 = rf(ctx, req) r1 = rf(ctx, req)
} else { } else {
@ -336,8 +297,8 @@ type MockWorker_SearchSegments_Call struct {
} }
// SearchSegments is a helper method to define mock.On call // SearchSegments is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - req *querypb.SearchRequest // - req *querypb.SearchRequest
func (_e *MockWorker_Expecter) SearchSegments(ctx interface{}, req interface{}) *MockWorker_SearchSegments_Call { func (_e *MockWorker_Expecter) SearchSegments(ctx interface{}, req interface{}) *MockWorker_SearchSegments_Call {
return &MockWorker_SearchSegments_Call{Call: _e.mock.On("SearchSegments", ctx, req)} return &MockWorker_SearchSegments_Call{Call: _e.mock.On("SearchSegments", ctx, req)}
} }
@ -354,11 +315,6 @@ func (_c *MockWorker_SearchSegments_Call) Return(_a0 *internalpb.SearchResults,
return _c return _c
} }
func (_c *MockWorker_SearchSegments_Call) RunAndReturn(run func(context.Context, *querypb.SearchRequest) (*internalpb.SearchResults, error)) *MockWorker_SearchSegments_Call {
_c.Call.Return(run)
return _c
}
// Stop provides a mock function with given fields: // Stop provides a mock function with given fields:
func (_m *MockWorker) Stop() { func (_m *MockWorker) Stop() {
_m.Called() _m.Called()
@ -386,11 +342,6 @@ func (_c *MockWorker_Stop_Call) Return() *MockWorker_Stop_Call {
return _c return _c
} }
func (_c *MockWorker_Stop_Call) RunAndReturn(run func()) *MockWorker_Stop_Call {
_c.Call.Return(run)
return _c
}
type mockConstructorTestingTNewMockWorker interface { type mockConstructorTestingTNewMockWorker interface {
mock.TestingT mock.TestingT
Cleanup(func()) Cleanup(func())

View File

@ -29,11 +29,11 @@ func (_m *MockLoader) EXPECT() *MockLoader_Expecter {
return &MockLoader_Expecter{mock: &_m.Mock} return &MockLoader_Expecter{mock: &_m.Mock}
} }
// Load provides a mock function with given fields: ctx, collectionID, segmentType, version, infos // Load provides a mock function with given fields: ctx, collectionID, segmentType, version, segments
func (_m *MockLoader) Load(ctx context.Context, collectionID int64, segmentType commonpb.SegmentState, version int64, infos ...*querypb.SegmentLoadInfo) ([]Segment, error) { func (_m *MockLoader) Load(ctx context.Context, collectionID int64, segmentType commonpb.SegmentState, version int64, segments ...*querypb.SegmentLoadInfo) ([]Segment, error) {
_va := make([]interface{}, len(infos)) _va := make([]interface{}, len(segments))
for _i := range infos { for _i := range segments {
_va[_i] = infos[_i] _va[_i] = segments[_i]
} }
var _ca []interface{} var _ca []interface{}
_ca = append(_ca, ctx, collectionID, segmentType, version) _ca = append(_ca, ctx, collectionID, segmentType, version)
@ -42,7 +42,7 @@ func (_m *MockLoader) Load(ctx context.Context, collectionID int64, segmentType
var r0 []Segment var r0 []Segment
if rf, ok := ret.Get(0).(func(context.Context, int64, commonpb.SegmentState, int64, ...*querypb.SegmentLoadInfo) []Segment); ok { if rf, ok := ret.Get(0).(func(context.Context, int64, commonpb.SegmentState, int64, ...*querypb.SegmentLoadInfo) []Segment); ok {
r0 = rf(ctx, collectionID, segmentType, version, infos...) r0 = rf(ctx, collectionID, segmentType, version, segments...)
} else { } else {
if ret.Get(0) != nil { if ret.Get(0) != nil {
r0 = ret.Get(0).([]Segment) r0 = ret.Get(0).([]Segment)
@ -51,7 +51,7 @@ func (_m *MockLoader) Load(ctx context.Context, collectionID int64, segmentType
var r1 error var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int64, commonpb.SegmentState, int64, ...*querypb.SegmentLoadInfo) error); ok { if rf, ok := ret.Get(1).(func(context.Context, int64, commonpb.SegmentState, int64, ...*querypb.SegmentLoadInfo) error); ok {
r1 = rf(ctx, collectionID, segmentType, version, infos...) r1 = rf(ctx, collectionID, segmentType, version, segments...)
} else { } else {
r1 = ret.Error(1) r1 = ret.Error(1)
} }
@ -75,7 +75,7 @@ func (_e *MockLoader_Expecter) Load(ctx interface{}, collectionID interface{}, s
append([]interface{}{ctx, collectionID, segmentType, version}, infos...)...)} append([]interface{}{ctx, collectionID, segmentType, version}, infos...)...)}
} }
func (_c *MockLoader_Load_Call) Run(run func(ctx context.Context, collectionID int64, segmentType commonpb.SegmentState, version int64, infos ...*querypb.SegmentLoadInfo)) *MockLoader_Load_Call { func (_c *MockLoader_Load_Call) Run(run func(ctx context.Context, collectionID int64, segmentType commonpb.SegmentState, version int64, segments ...*querypb.SegmentLoadInfo)) *MockLoader_Load_Call {
_c.Call.Run(func(args mock.Arguments) { _c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]*querypb.SegmentLoadInfo, len(args)-4) variadicArgs := make([]*querypb.SegmentLoadInfo, len(args)-4)
for i, a := range args[4:] { for i, a := range args[4:] {

View File

@ -36,9 +36,9 @@ func (t *alterAliasTask) Prepare(ctx context.Context) error {
} }
func (t *alterAliasTask) Execute(ctx context.Context) error { func (t *alterAliasTask) Execute(ctx context.Context) error {
if err := t.core.ExpireMetaCache(ctx, []string{t.Req.GetAlias()}, InvalidCollectionID, t.GetTs()); err != nil { if err := t.core.ExpireMetaCache(ctx, t.Req.GetDbName(), []string{t.Req.GetAlias()}, InvalidCollectionID, t.GetTs()); err != nil {
return err return err
} }
// alter alias is atomic enough. // alter alias is atomic enough.
return t.core.meta.AlterAlias(ctx, t.Req.GetAlias(), t.Req.GetCollectionName(), t.GetTs()) return t.core.meta.AlterAlias(ctx, t.Req.GetDbName(), t.Req.GetAlias(), t.Req.GetCollectionName(), t.GetTs())
} }

View File

@ -21,13 +21,12 @@ import (
"fmt" "fmt"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/pkg/log"
"go.uber.org/zap" "go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/pkg/log"
) )
type alterCollectionTask struct { type alterCollectionTask struct {
@ -49,7 +48,7 @@ func (a *alterCollectionTask) Execute(ctx context.Context) error {
return errors.New("only support alter collection properties, but collection properties is empty") return errors.New("only support alter collection properties, but collection properties is empty")
} }
oldColl, err := a.core.meta.GetCollectionByName(ctx, a.Req.GetCollectionName(), a.ts) oldColl, err := a.core.meta.GetCollectionByName(ctx, a.Req.GetDbName(), a.Req.GetCollectionName(), a.ts)
if err != nil { if err != nil {
log.Warn("get collection failed during changing collection state", log.Warn("get collection failed during changing collection state",
zap.String("collectionName", a.Req.GetCollectionName()), zap.Uint64("ts", a.ts)) zap.String("collectionName", a.Req.GetCollectionName()), zap.Uint64("ts", a.ts))
@ -70,6 +69,7 @@ func (a *alterCollectionTask) Execute(ctx context.Context) error {
redoTask.AddSyncStep(&expireCacheStep{ redoTask.AddSyncStep(&expireCacheStep{
baseStep: baseStep{core: a.core}, baseStep: baseStep{core: a.core},
dbName: a.Req.GetDbName(),
collectionNames: []string{oldColl.Name}, collectionNames: []string{oldColl.Name},
collectionID: oldColl.CollectionID, collectionID: oldColl.CollectionID,
ts: ts, ts: ts,

View File

@ -21,15 +21,14 @@ import (
"testing" "testing"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/metastore/model" "github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/metastore/model"
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
"github.com/milvus-io/milvus/pkg/common"
) )
func Test_alterCollectionTask_Prepare(t *testing.T) { func Test_alterCollectionTask_Prepare(t *testing.T) {
@ -80,13 +79,19 @@ func Test_alterCollectionTask_Execute(t *testing.T) {
}) })
t.Run("alter step failed", func(t *testing.T) { t.Run("alter step failed", func(t *testing.T) {
meta := newMockMetaTable() meta := mockrootcoord.NewIMetaTable(t)
meta.GetCollectionByNameFunc = func(ctx context.Context, collectionName string, ts Timestamp) (*model.Collection, error) { meta.On("GetCollectionByName",
return &model.Collection{CollectionID: int64(1)}, nil mock.Anything,
} mock.Anything,
meta.AlterCollectionFunc = func(ctx context.Context, oldColl *model.Collection, newColl *model.Collection, ts Timestamp) error { mock.Anything,
return errors.New("err") mock.Anything,
} ).Return(&model.Collection{CollectionID: int64(1)}, nil)
meta.On("AlterCollection",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(errors.New("err"))
core := newTestCore(withMeta(meta)) core := newTestCore(withMeta(meta))
task := &alterCollectionTask{ task := &alterCollectionTask{
@ -103,13 +108,19 @@ func Test_alterCollectionTask_Execute(t *testing.T) {
}) })
t.Run("broadcast step failed", func(t *testing.T) { t.Run("broadcast step failed", func(t *testing.T) {
meta := newMockMetaTable() meta := mockrootcoord.NewIMetaTable(t)
meta.GetCollectionByNameFunc = func(ctx context.Context, collectionName string, ts Timestamp) (*model.Collection, error) { meta.On("GetCollectionByName",
return &model.Collection{CollectionID: int64(1)}, nil mock.Anything,
} mock.Anything,
meta.AlterCollectionFunc = func(ctx context.Context, oldColl *model.Collection, newColl *model.Collection, ts Timestamp) error { mock.Anything,
return nil mock.Anything,
} ).Return(&model.Collection{CollectionID: int64(1)}, nil)
meta.On("AlterCollection",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(nil)
broker := newMockBroker() broker := newMockBroker()
broker.BroadcastAlteredCollectionFunc = func(ctx context.Context, req *milvuspb.AlterCollectionRequest) error { broker.BroadcastAlteredCollectionFunc = func(ctx context.Context, req *milvuspb.AlterCollectionRequest) error {
@ -131,13 +142,19 @@ func Test_alterCollectionTask_Execute(t *testing.T) {
}) })
t.Run("alter successfully", func(t *testing.T) { t.Run("alter successfully", func(t *testing.T) {
meta := newMockMetaTable() meta := mockrootcoord.NewIMetaTable(t)
meta.GetCollectionByNameFunc = func(ctx context.Context, collectionName string, ts Timestamp) (*model.Collection, error) { meta.On("GetCollectionByName",
return &model.Collection{CollectionID: int64(1)}, nil mock.Anything,
} mock.Anything,
meta.AlterCollectionFunc = func(ctx context.Context, oldColl *model.Collection, newColl *model.Collection, ts Timestamp) error { mock.Anything,
return nil mock.Anything,
} ).Return(&model.Collection{CollectionID: int64(1)}, nil)
meta.On("AlterCollection",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(nil)
broker := newMockBroker() broker := newMockBroker()
broker.BroadcastAlteredCollectionFunc = func(ctx context.Context, req *milvuspb.AlterCollectionRequest) error { broker.BroadcastAlteredCollectionFunc = func(ctx context.Context, req *milvuspb.AlterCollectionRequest) error {

View File

@ -21,7 +21,6 @@ import (
"fmt" "fmt"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"go.uber.org/zap" "go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@ -260,7 +259,7 @@ func (b *ServerBroker) GetSegmentIndexState(ctx context.Context, collID UniqueID
func (b *ServerBroker) BroadcastAlteredCollection(ctx context.Context, req *milvuspb.AlterCollectionRequest) error { func (b *ServerBroker) BroadcastAlteredCollection(ctx context.Context, req *milvuspb.AlterCollectionRequest) error {
log.Info("broadcasting request to alter collection", zap.String("collection name", req.GetCollectionName()), zap.Int64("collection id", req.GetCollectionID())) log.Info("broadcasting request to alter collection", zap.String("collection name", req.GetCollectionName()), zap.Int64("collection id", req.GetCollectionID()))
colMeta, err := b.s.meta.GetCollectionByID(ctx, req.GetCollectionID(), typeutil.MaxTimestamp, false) colMeta, err := b.s.meta.GetCollectionByID(ctx, req.GetDbName(), req.GetCollectionID(), typeutil.MaxTimestamp, false)
if err != nil { if err != nil {
return err return err
} }

View File

@ -29,6 +29,8 @@ import (
"github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/indexpb"
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
) )
@ -267,11 +269,15 @@ func TestServerBroker_BroadcastAlteredCollection(t *testing.T) {
t.Run("get meta fail", func(t *testing.T) { t.Run("get meta fail", func(t *testing.T) {
c := newTestCore(withInvalidDataCoord()) c := newTestCore(withInvalidDataCoord())
c.meta = &mockMetaTable{ meta := mockrootcoord.NewIMetaTable(t)
GetCollectionByIDFunc: func(ctx context.Context, collectionID UniqueID, ts Timestamp, allowUnavailable bool) (*model.Collection, error) { meta.On("GetCollectionByID",
return nil, errors.New("err") mock.Anything,
}, mock.Anything,
} mock.Anything,
mock.Anything,
mock.Anything,
).Return(nil, errors.New("err"))
c.meta = meta
b := newServerBroker(c) b := newServerBroker(c)
ctx := context.Background() ctx := context.Background()
err := b.BroadcastAlteredCollection(ctx, &milvuspb.AlterCollectionRequest{}) err := b.BroadcastAlteredCollection(ctx, &milvuspb.AlterCollectionRequest{})
@ -280,11 +286,15 @@ func TestServerBroker_BroadcastAlteredCollection(t *testing.T) {
t.Run("failed to execute", func(t *testing.T) { t.Run("failed to execute", func(t *testing.T) {
c := newTestCore(withInvalidDataCoord()) c := newTestCore(withInvalidDataCoord())
c.meta = &mockMetaTable{ meta := mockrootcoord.NewIMetaTable(t)
GetCollectionByIDFunc: func(ctx context.Context, collectionID UniqueID, ts Timestamp, allowUnavailable bool) (*model.Collection, error) { meta.On("GetCollectionByID",
return collMeta, nil mock.Anything,
}, mock.Anything,
} mock.Anything,
mock.Anything,
mock.Anything,
).Return(collMeta, nil)
c.meta = meta
b := newServerBroker(c) b := newServerBroker(c)
ctx := context.Background() ctx := context.Background()
err := b.BroadcastAlteredCollection(ctx, &milvuspb.AlterCollectionRequest{}) err := b.BroadcastAlteredCollection(ctx, &milvuspb.AlterCollectionRequest{})
@ -293,11 +303,15 @@ func TestServerBroker_BroadcastAlteredCollection(t *testing.T) {
t.Run("non success error code on execute", func(t *testing.T) { t.Run("non success error code on execute", func(t *testing.T) {
c := newTestCore(withFailedDataCoord()) c := newTestCore(withFailedDataCoord())
c.meta = &mockMetaTable{ meta := mockrootcoord.NewIMetaTable(t)
GetCollectionByIDFunc: func(ctx context.Context, collectionID UniqueID, ts Timestamp, allowUnavailable bool) (*model.Collection, error) { meta.On("GetCollectionByID",
return collMeta, nil mock.Anything,
}, mock.Anything,
} mock.Anything,
mock.Anything,
mock.Anything,
).Return(collMeta, nil)
c.meta = meta
b := newServerBroker(c) b := newServerBroker(c)
ctx := context.Background() ctx := context.Background()
err := b.BroadcastAlteredCollection(ctx, &milvuspb.AlterCollectionRequest{}) err := b.BroadcastAlteredCollection(ctx, &milvuspb.AlterCollectionRequest{})
@ -306,11 +320,15 @@ func TestServerBroker_BroadcastAlteredCollection(t *testing.T) {
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
c := newTestCore(withValidDataCoord()) c := newTestCore(withValidDataCoord())
c.meta = &mockMetaTable{ meta := mockrootcoord.NewIMetaTable(t)
GetCollectionByIDFunc: func(ctx context.Context, collectionID UniqueID, ts Timestamp, allowUnavailable bool) (*model.Collection, error) { meta.On("GetCollectionByID",
return collMeta, nil mock.Anything,
}, mock.Anything,
} mock.Anything,
mock.Anything,
mock.Anything,
).Return(collMeta, nil)
c.meta = meta
b := newServerBroker(c) b := newServerBroker(c)
ctx := context.Background() ctx := context.Background()

View File

@ -37,9 +37,9 @@ func (t *createAliasTask) Prepare(ctx context.Context) error {
} }
func (t *createAliasTask) Execute(ctx context.Context) error { func (t *createAliasTask) Execute(ctx context.Context) error {
if err := t.core.ExpireMetaCache(ctx, []string{t.Req.GetAlias(), t.Req.GetCollectionName()}, InvalidCollectionID, t.GetTs()); err != nil { if err := t.core.ExpireMetaCache(ctx, t.Req.GetDbName(), []string{t.Req.GetAlias(), t.Req.GetCollectionName()}, InvalidCollectionID, t.GetTs()); err != nil {
return err return err
} }
// create alias is atomic enough. // create alias is atomic enough.
return t.core.meta.CreateAlias(ctx, t.Req.GetAlias(), t.Req.GetCollectionName(), t.GetTs()) return t.core.meta.CreateAlias(ctx, t.Req.GetDbName(), t.Req.GetAlias(), t.Req.GetCollectionName(), t.GetTs())
} }

View File

@ -23,6 +23,8 @@ import (
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
@ -35,10 +37,8 @@ import (
"github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/parameterutil.go" parameterutil "github.com/milvus-io/milvus/pkg/util/parameterutil.go"
"github.com/milvus-io/milvus/pkg/util/typeutil" "github.com/milvus-io/milvus/pkg/util/typeutil"
"github.com/samber/lo"
"go.uber.org/zap"
) )
type collectionChannels struct { type collectionChannels struct {
@ -53,6 +53,7 @@ type createCollectionTask struct {
collID UniqueID collID UniqueID
partIDs []UniqueID partIDs []UniqueID
channels collectionChannels channels collectionChannels
dbID UniqueID
partitionNames []string partitionNames []string
} }
@ -77,6 +78,30 @@ func (t *createCollectionTask) validate() error {
return fmt.Errorf("shard num (%d) exceeds system limit (%d)", shardsNum, cfgShardLimit) return fmt.Errorf("shard num (%d) exceeds system limit (%d)", shardsNum, cfgShardLimit)
} }
db2CollIDs := t.core.meta.ListAllAvailCollections(t.ctx)
collIDs, ok := db2CollIDs[t.dbID]
if !ok {
log.Warn("can not found DB ID", zap.String("collection", t.Req.GetCollectionName()), zap.String("dbName", t.Req.GetDbName()))
return merr.WrapErrDatabaseNotFound(t.Req.GetDbName(), "failed to create collection")
}
maxColNumPerDB := Params.QuotaConfig.MaxCollectionNumPerDB.GetAsInt()
if len(collIDs) >= maxColNumPerDB {
log.Warn("unable to create collection because the number of collection has reached the limit in DB", zap.Int("maxCollectionNumPerDB", maxColNumPerDB))
return merr.WrapErrCollectionResourceLimitExceeded(fmt.Sprintf("Failed to create collection, maxCollectionNumPerDB={%d}", maxColNumPerDB))
}
totalCollections := 0
for _, collIDs := range db2CollIDs {
totalCollections += len(collIDs)
}
maxCollectionNum := Params.QuotaConfig.MaxCollectionNum.GetAsInt()
if totalCollections >= maxCollectionNum {
log.Warn("unable to create collection because the number of collection has reached the limit", zap.Int("max_collection_num", maxCollectionNum))
return merr.WrapErrCollectionResourceLimitExceeded(fmt.Sprintf("Failed to create collection, limit={%d}", maxCollectionNum))
}
return nil return nil
} }
@ -297,6 +322,12 @@ func (t *createCollectionTask) assignChannels() error {
} }
func (t *createCollectionTask) Prepare(ctx context.Context) error { func (t *createCollectionTask) Prepare(ctx context.Context) error {
db, err := t.core.meta.GetDatabaseByName(ctx, t.Req.GetDbName(), typeutil.MaxTimestamp)
if err != nil {
return err
}
t.dbID = db.ID
if err := t.validate(); err != nil { if err := t.validate(); err != nil {
return err return err
} }
@ -386,6 +417,7 @@ func (t *createCollectionTask) Execute(ctx context.Context) error {
collInfo := model.Collection{ collInfo := model.Collection{
CollectionID: collID, CollectionID: collID,
DBID: t.dbID,
Name: t.schema.Name, Name: t.schema.Name,
Description: t.schema.Description, Description: t.schema.Description,
AutoID: t.schema.AutoID, AutoID: t.schema.AutoID,
@ -407,7 +439,7 @@ func (t *createCollectionTask) Execute(ctx context.Context) error {
// are not promised idempotent. // are not promised idempotent.
clone := collInfo.Clone() clone := collInfo.Clone()
// need double check in meta table if we can't promise the sequence execution. // need double check in meta table if we can't promise the sequence execution.
existedCollInfo, err := t.core.meta.GetCollectionByName(ctx, t.Req.GetCollectionName(), typeutil.MaxTimestamp) existedCollInfo, err := t.core.meta.GetCollectionByName(ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), typeutil.MaxTimestamp)
if err == nil { if err == nil {
equal := existedCollInfo.Equal(*clone) equal := existedCollInfo.Equal(*clone)
if !equal { if !equal {
@ -418,30 +450,10 @@ func (t *createCollectionTask) Execute(ctx context.Context) error {
return nil return nil
} }
// check collection number quota for the entire the instance
existedCollInfos, err := t.core.meta.ListCollections(ctx, typeutil.MaxTimestamp)
if err != nil {
log.Warn("fail to list collections for checking the collection count", zap.Error(err))
return fmt.Errorf("fail to list collections for checking the collection count")
}
maxCollectionNum := Params.QuotaConfig.MaxCollectionNum.GetAsInt()
if len(existedCollInfos) >= maxCollectionNum {
log.Warn("unable to create collection because the number of collection has reached the limit", zap.Int("max_collection_num", maxCollectionNum))
return merr.WrapErrCollectionResourceLimitExceeded(fmt.Sprintf("Failed to create collection, limit={%d}", maxCollectionNum))
}
// check collection number quota for DB
existedColsInDB := lo.Filter(existedCollInfos, func(collection *model.Collection, _ int) bool {
return t.Req.GetDbName() != "" && collection.DBName == t.Req.GetDbName()
})
maxColNumPerDB := Params.QuotaConfig.MaxCollectionNumPerDB.GetAsInt()
if len(existedColsInDB) >= maxColNumPerDB {
log.Warn("unable to create collection because the number of collection has reached the limit in DB", zap.Int("maxCollectionNumPerDB", maxColNumPerDB))
return merr.WrapErrCollectionResourceLimitExceeded(fmt.Sprintf("Failed to create collection, maxCollectionNumPerDB={%d}", maxColNumPerDB))
}
undoTask := newBaseUndoTask(t.core.stepExecutor) undoTask := newBaseUndoTask(t.core.stepExecutor)
undoTask.AddStep(&expireCacheStep{ undoTask.AddStep(&expireCacheStep{
baseStep: baseStep{core: t.core}, baseStep: baseStep{core: t.core},
dbName: t.Req.GetDbName(),
collectionNames: []string{t.Req.GetCollectionName()}, collectionNames: []string{t.Req.GetCollectionName()},
collectionID: InvalidCollectionID, collectionID: InvalidCollectionID,
ts: ts, ts: ts,

View File

@ -19,26 +19,30 @@ package rootcoord
import ( import (
"context" "context"
"math" "math"
"strconv"
"testing" "testing"
"time" "time"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/etcdpb"
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks" mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
"github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util"
"github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/merr"
"github.com/stretchr/testify/assert" "github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/stretchr/testify/mock"
) )
func Test_createCollectionTask_validate(t *testing.T) { func Test_createCollectionTask_validate(t *testing.T) {
paramtable.Init()
t.Run("empty request", func(t *testing.T) { t.Run("empty request", func(t *testing.T) {
task := createCollectionTask{ task := createCollectionTask{
Req: nil, Req: nil,
@ -72,7 +76,7 @@ func Test_createCollectionTask_validate(t *testing.T) {
t.Run("shard num exceeds limit", func(t *testing.T) { t.Run("shard num exceeds limit", func(t *testing.T) {
// TODO: better to have a `Set` method for ParamItem. // TODO: better to have a `Set` method for ParamItem.
cfgShardLimit := Params.ProxyCfg.MaxShardNum.GetAsInt32() cfgShardLimit := paramtable.Get().ProxyCfg.MaxShardNum.GetAsInt32()
task := createCollectionTask{ task := createCollectionTask{
Req: &milvuspb.CreateCollectionRequest{ Req: &milvuspb.CreateCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection}, Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection},
@ -83,12 +87,91 @@ func Test_createCollectionTask_validate(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
}) })
t.Run("normal case", func(t *testing.T) { t.Run("total collection num exceeds limit", func(t *testing.T) {
paramtable.Get().Save(Params.QuotaConfig.MaxCollectionNum.Key, strconv.Itoa(2))
defer paramtable.Get().Reset(Params.QuotaConfig.MaxCollectionNum.Key)
meta := mockrootcoord.NewIMetaTable(t)
meta.On("ListAllAvailCollections",
mock.Anything,
).Return(map[int64][]int64{
1: {1, 2},
}, nil)
core := newTestCore(withMeta(meta))
task := createCollectionTask{ task := createCollectionTask{
baseTask: newBaseTask(context.TODO(), core),
Req: &milvuspb.CreateCollectionRequest{ Req: &milvuspb.CreateCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection}, Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection},
}, },
} }
err := task.validate()
assert.Error(t, err)
task = createCollectionTask{
baseTask: newBaseTask(context.TODO(), core),
Req: &milvuspb.CreateCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection},
},
dbID: util.DefaultDBID,
}
err = task.validate()
assert.Error(t, err)
})
t.Run("collection num per db exceeds limit", func(t *testing.T) {
paramtable.Get().Save(Params.QuotaConfig.MaxCollectionNumPerDB.Key, strconv.Itoa(2))
defer paramtable.Get().Reset(Params.QuotaConfig.MaxCollectionNumPerDB.Key)
meta := mockrootcoord.NewIMetaTable(t)
meta.On("ListAllAvailCollections",
mock.Anything,
).Return(map[int64][]int64{
1: {1, 2},
}, nil)
core := newTestCore(withMeta(meta))
task := createCollectionTask{
baseTask: newBaseTask(context.TODO(), core),
Req: &milvuspb.CreateCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection},
},
}
err := task.validate()
assert.Error(t, err)
task = createCollectionTask{
baseTask: newBaseTask(context.TODO(), core),
Req: &milvuspb.CreateCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection},
},
dbID: util.DefaultDBID,
}
err = task.validate()
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
meta := mockrootcoord.NewIMetaTable(t)
meta.On("ListAllAvailCollections",
mock.Anything,
).Return(map[int64][]int64{
1: {1, 2},
}, nil)
core := newTestCore(withMeta(meta))
task := createCollectionTask{
baseTask: newBaseTask(context.TODO(), core),
Req: &milvuspb.CreateCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection},
},
dbID: 1,
}
paramtable.Get().Save(Params.QuotaConfig.MaxCollectionNum.Key, strconv.Itoa(math.MaxInt64))
defer paramtable.Get().Reset(Params.QuotaConfig.MaxCollectionNum.Key)
paramtable.Get().Save(Params.QuotaConfig.MaxCollectionNumPerDB.Key, strconv.Itoa(math.MaxInt64))
defer paramtable.Get().Reset(Params.QuotaConfig.MaxCollectionNumPerDB.Key)
err := task.validate() err := task.validate()
assert.NoError(t, err) assert.NoError(t, err)
}) })
@ -377,8 +460,29 @@ func Test_createCollectionTask_prepareSchema(t *testing.T) {
} }
func Test_createCollectionTask_Prepare(t *testing.T) { func Test_createCollectionTask_Prepare(t *testing.T) {
paramtable.Init()
meta := mockrootcoord.NewIMetaTable(t)
meta.On("GetDatabaseByName",
mock.Anything,
mock.Anything,
mock.Anything,
).Return(model.NewDefaultDatabase(), nil)
meta.On("ListAllAvailCollections",
mock.Anything,
).Return(map[int64][]int64{
util.DefaultDBID: {1, 2},
}, nil)
paramtable.Get().Save(Params.QuotaConfig.MaxCollectionNum.Key, strconv.Itoa(math.MaxInt64))
defer paramtable.Get().Reset(Params.QuotaConfig.MaxCollectionNum.Key)
paramtable.Get().Save(Params.QuotaConfig.MaxCollectionNumPerDB.Key, strconv.Itoa(math.MaxInt64))
defer paramtable.Get().Reset(Params.QuotaConfig.MaxCollectionNumPerDB.Key)
t.Run("invalid msg type", func(t *testing.T) { t.Run("invalid msg type", func(t *testing.T) {
core := newTestCore(withMeta(meta))
task := &createCollectionTask{ task := &createCollectionTask{
baseTask: newBaseTask(context.TODO(), core),
Req: &milvuspb.CreateCollectionRequest{ Req: &milvuspb.CreateCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropCollection}, Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropCollection},
}, },
@ -388,13 +492,16 @@ func Test_createCollectionTask_Prepare(t *testing.T) {
}) })
t.Run("invalid schema", func(t *testing.T) { t.Run("invalid schema", func(t *testing.T) {
core := newTestCore(withMeta(meta))
collectionName := funcutil.GenRandomStr() collectionName := funcutil.GenRandomStr()
task := &createCollectionTask{ task := &createCollectionTask{
baseTask: newBaseTask(context.TODO(), core),
Req: &milvuspb.CreateCollectionRequest{ Req: &milvuspb.CreateCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection}, Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection},
CollectionName: collectionName, CollectionName: collectionName,
Schema: []byte("invalid schema"), Schema: []byte("invalid schema"),
}, },
dbID: 1,
} }
err := task.Prepare(context.Background()) err := task.Prepare(context.Background())
assert.Error(t, err) assert.Error(t, err)
@ -414,7 +521,7 @@ func Test_createCollectionTask_Prepare(t *testing.T) {
marshaledSchema, err := proto.Marshal(schema) marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err) assert.NoError(t, err)
core := newTestCore(withInvalidIDAllocator()) core := newTestCore(withInvalidIDAllocator(), withMeta(meta))
task := createCollectionTask{ task := createCollectionTask{
baseTask: baseTask{core: core}, baseTask: baseTask{core: core},
@ -423,6 +530,7 @@ func Test_createCollectionTask_Prepare(t *testing.T) {
CollectionName: collectionName, CollectionName: collectionName,
Schema: marshaledSchema, Schema: marshaledSchema,
}, },
dbID: 1,
} }
err = task.Prepare(context.Background()) err = task.Prepare(context.Background())
assert.Error(t, err) assert.Error(t, err)
@ -436,7 +544,7 @@ func Test_createCollectionTask_Prepare(t *testing.T) {
ticker := newRocksMqTtSynchronizer() ticker := newRocksMqTtSynchronizer()
core := newTestCore(withValidIDAllocator(), withTtSynchronizer(ticker)) core := newTestCore(withValidIDAllocator(), withTtSynchronizer(ticker), withMeta(meta))
schema := &schemapb.CollectionSchema{ schema := &schemapb.CollectionSchema{
Name: collectionName, Name: collectionName,
@ -456,6 +564,7 @@ func Test_createCollectionTask_Prepare(t *testing.T) {
CollectionName: collectionName, CollectionName: collectionName,
Schema: marshaledSchema, Schema: marshaledSchema,
}, },
dbID: 1,
} }
task.Req.ShardsNum = int32(Params.RootCoordCfg.DmlChannelNum.GetAsInt() + 1) // no enough channels. task.Req.ShardsNum = int32(Params.RootCoordCfg.DmlChannelNum.GetAsInt() + 1) // no enough channels.
err = task.Prepare(context.Background()) err = task.Prepare(context.Background())
@ -475,13 +584,13 @@ func Test_createCollectionTask_Execute(t *testing.T) {
field1 := funcutil.GenRandomStr() field1 := funcutil.GenRandomStr()
coll := &model.Collection{Name: collectionName} coll := &model.Collection{Name: collectionName}
meta := newMockMetaTable() meta := mockrootcoord.NewIMetaTable(t)
meta.GetCollectionByNameFunc = func(ctx context.Context, collectionName string, ts Timestamp) (*model.Collection, error) { meta.On("GetCollectionByName",
return coll, nil mock.Anything,
} mock.Anything,
meta.ListCollectionsFunc = func(ctx context.Context, ts Timestamp) ([]*model.Collection, error) { mock.Anything,
return []*model.Collection{}, nil mock.Anything,
} ).Return(coll, nil)
core := newTestCore(withMeta(meta), withTtSynchronizer(ticker)) core := newTestCore(withMeta(meta), withTtSynchronizer(ticker))
@ -522,13 +631,13 @@ func Test_createCollectionTask_Execute(t *testing.T) {
PhysicalChannelNames: channels.physicalChannels, PhysicalChannelNames: channels.physicalChannels,
} }
meta := newMockMetaTable() meta := mockrootcoord.NewIMetaTable(t)
meta.GetCollectionByNameFunc = func(ctx context.Context, collectionName string, ts Timestamp) (*model.Collection, error) { meta.On("GetCollectionByName",
return coll, nil mock.Anything,
} mock.Anything,
meta.ListCollectionsFunc = func(ctx context.Context, ts Timestamp) ([]*model.Collection, error) { mock.Anything,
return []*model.Collection{}, nil mock.Anything,
} ).Return(coll, nil)
core := newTestCore(withMeta(meta), withTtSynchronizer(ticker)) core := newTestCore(withMeta(meta), withTtSynchronizer(ticker))
@ -573,16 +682,23 @@ func Test_createCollectionTask_Execute(t *testing.T) {
ticker := newRocksMqTtSynchronizer() ticker := newRocksMqTtSynchronizer()
pchans := ticker.getDmlChannelNames(shardNum) pchans := ticker.getDmlChannelNames(shardNum)
meta := newMockMetaTable() meta := mockrootcoord.NewIMetaTable(t)
meta.GetCollectionByNameFunc = func(ctx context.Context, collectionName string, ts Timestamp) (*model.Collection, error) { meta.On("GetCollectionByName",
return nil, errors.New("error mock GetCollectionByName") mock.Anything,
} mock.Anything,
meta.AddCollectionFunc = func(ctx context.Context, coll *model.Collection) error { mock.Anything,
return nil mock.Anything,
} ).Return(nil, errors.New("error mock GetCollectionByName"))
meta.ChangeCollectionStateFunc = func(ctx context.Context, collectionID UniqueID, state etcdpb.CollectionState, ts Timestamp) error { meta.On("AddCollection",
return nil mock.Anything,
} mock.Anything,
).Return(nil)
meta.On("ChangeCollectionState",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(nil)
dc := newMockDataCoord() dc := newMockDataCoord()
dc.GetComponentStatesFunc = func(ctx context.Context) (*milvuspb.ComponentStates, error) { dc.GetComponentStatesFunc = func(ctx context.Context) (*milvuspb.ComponentStates, error) {
@ -630,40 +746,6 @@ func Test_createCollectionTask_Execute(t *testing.T) {
schema: schema, schema: schema,
} }
meta.ListCollectionsFunc = func(ctx context.Context, ts Timestamp) ([]*model.Collection, error) {
return nil, errors.New("mock error")
}
err = task.Execute(context.Background())
assert.Error(t, err)
originFormatter := Params.QuotaConfig.MaxCollectionNum.Formatter
Params.QuotaConfig.MaxCollectionNum.Formatter = func(originValue string) string {
return "10"
}
meta.ListCollectionsFunc = func(ctx context.Context, ts Timestamp) ([]*model.Collection, error) {
maxNum := Params.QuotaConfig.MaxCollectionNum.GetAsInt()
return make([]*model.Collection, maxNum), nil
}
err = task.Execute(context.Background())
assert.Error(t, err)
Params.QuotaConfig.MaxCollectionNum.Formatter = originFormatter
meta.ListCollectionsFunc = func(ctx context.Context, ts Timestamp) ([]*model.Collection, error) {
maxNum := Params.QuotaConfig.MaxCollectionNumPerDB.GetAsInt()
collections := make([]*model.Collection, 0, maxNum)
for i := 0; i < maxNum; i++ {
collections = append(collections, &model.Collection{DBName: task.Req.GetDbName()})
}
return collections, nil
}
err = task.Execute(context.Background())
assert.Error(t, err)
assert.True(t, errors.Is(merr.ErrCollectionNumLimitExceeded, err))
meta.ListCollectionsFunc = func(ctx context.Context, ts Timestamp) ([]*model.Collection, error) {
return []*model.Collection{}, nil
}
err = task.Execute(context.Background()) err = task.Execute(context.Background())
assert.NoError(t, err) assert.NoError(t, err)
}) })
@ -678,27 +760,35 @@ func Test_createCollectionTask_Execute(t *testing.T) {
ticker := newRocksMqTtSynchronizer() ticker := newRocksMqTtSynchronizer()
pchans := ticker.getDmlChannelNames(shardNum) pchans := ticker.getDmlChannelNames(shardNum)
meta := newMockMetaTable() meta := mockrootcoord.NewIMetaTable(t)
meta.GetCollectionByNameFunc = func(ctx context.Context, collectionName string, ts Timestamp) (*model.Collection, error) { meta.On("GetCollectionByName",
return nil, errors.New("error mock GetCollectionByName") mock.Anything,
} mock.Anything,
meta.AddCollectionFunc = func(ctx context.Context, coll *model.Collection) error { mock.Anything,
return nil mock.Anything,
} ).Return(nil, errors.New("error mock GetCollectionByName"))
meta.ListCollectionsFunc = func(ctx context.Context, ts Timestamp) ([]*model.Collection, error) { meta.On("AddCollection",
return []*model.Collection{}, nil mock.Anything,
} mock.Anything,
// inject error here. ).Return(nil)
meta.ChangeCollectionStateFunc = func(ctx context.Context, collectionID UniqueID, state etcdpb.CollectionState, ts Timestamp) error { meta.On("ChangeCollectionState",
return errors.New("error mock ChangeCollectionState") mock.Anything,
} mock.Anything,
mock.Anything,
mock.Anything,
).Return(errors.New("error mock ChangeCollectionState"))
removeCollectionCalled := false removeCollectionCalled := false
removeCollectionChan := make(chan struct{}, 1) removeCollectionChan := make(chan struct{}, 1)
meta.RemoveCollectionFunc = func(ctx context.Context, collectionID UniqueID, ts Timestamp) error { meta.On("RemoveCollection",
mock.Anything,
mock.Anything,
mock.Anything,
).Return(func(ctx context.Context, collID UniqueID, ts Timestamp) error {
removeCollectionCalled = true removeCollectionCalled = true
removeCollectionChan <- struct{}{} removeCollectionChan <- struct{}{}
return nil return nil
} })
broker := newMockBroker() broker := newMockBroker()
broker.WatchChannelsFunc = func(ctx context.Context, info *watchInfo) error { broker.WatchChannelsFunc = func(ctx context.Context, info *watchInfo) error {
@ -770,6 +860,7 @@ func Test_createCollectionTask_Execute(t *testing.T) {
} }
func Test_createCollectionTask_PartitionKey(t *testing.T) { func Test_createCollectionTask_PartitionKey(t *testing.T) {
paramtable.Init()
defer cleanTestEnv() defer cleanTestEnv()
collectionName := funcutil.GenRandomStr() collectionName := funcutil.GenRandomStr()
@ -777,6 +868,23 @@ func Test_createCollectionTask_PartitionKey(t *testing.T) {
ticker := newRocksMqTtSynchronizer() ticker := newRocksMqTtSynchronizer()
meta := mockrootcoord.NewIMetaTable(t) meta := mockrootcoord.NewIMetaTable(t)
meta.On("GetDatabaseByName",
mock.Anything,
mock.Anything,
mock.Anything,
).Return(model.NewDefaultDatabase(), nil)
meta.On("ListAllAvailCollections",
mock.Anything,
).Return(map[int64][]int64{
util.DefaultDBID: {1, 2},
}, nil)
paramtable.Get().Save(Params.QuotaConfig.MaxCollectionNum.Key, strconv.Itoa(math.MaxInt64))
defer paramtable.Get().Reset(Params.QuotaConfig.MaxCollectionNum.Key)
paramtable.Get().Save(Params.QuotaConfig.MaxCollectionNumPerDB.Key, strconv.Itoa(math.MaxInt64))
defer paramtable.Get().Reset(Params.QuotaConfig.MaxCollectionNumPerDB.Key)
core := newTestCore(withValidIDAllocator(), withTtSynchronizer(ticker), withMeta(meta)) core := newTestCore(withValidIDAllocator(), withTtSynchronizer(ticker), withMeta(meta))
partitionKeyField := &schemapb.FieldSchema{ partitionKeyField := &schemapb.FieldSchema{

View File

@ -0,0 +1,56 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"fmt"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/proto/etcdpb"
"github.com/milvus-io/milvus/pkg/util/merr"
)
// createDatabaseTask handles a CreateDatabase request on rootcoord.
// Prepare enforces the database-number quota and allocates the new
// database ID; Execute persists the database into the meta table.
type createDatabaseTask struct {
	baseTask
	// Req is the CreateDatabase request this task was created for.
	Req *milvuspb.CreateDatabaseRequest
	// dbID is the ID allocated for the new database during Prepare.
	dbID UniqueID
}
// Prepare validates the request before execution: it enforces the
// database-number quota and pre-allocates an ID for the new database.
//
// It returns an error if listing existing databases fails, if the number
// of databases has already reached the configured maximum, or if ID
// allocation fails.
func (t *createDatabaseTask) Prepare(ctx context.Context) error {
	dbs, err := t.core.meta.ListDatabases(ctx, t.GetTs())
	if err != nil {
		return err
	}

	cfgMaxDatabaseNum := Params.RootCoordCfg.MaxDatabaseNum.GetAsInt()
	// Use >= so the limit is enforced exactly: the previous `>` comparison
	// allowed one more database than the configured maximum to be created.
	// NOTE(review): ListDatabases presumably includes the default database,
	// which therefore counts toward the limit — confirm intended semantics.
	if len(dbs) >= cfgMaxDatabaseNum {
		return merr.WrapErrDatabaseResourceLimitExceeded(fmt.Sprintf("Failed to create database, limit={%d}", cfgMaxDatabaseNum))
	}

	t.dbID, err = t.core.idAllocator.AllocOne()
	if err != nil {
		return err
	}
	return nil
}
// Execute persists the new database in the meta table in the
// DatabaseCreated state, using the ID allocated during Prepare.
func (t *createDatabaseTask) Execute(ctx context.Context) error {
	return t.core.meta.CreateDatabase(
		ctx,
		model.NewDatabase(t.dbID, t.Req.GetDbName(), etcdpb.DatabaseState_DatabaseCreated),
		t.GetTs(),
	)
}

View File

@ -0,0 +1,124 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/metastore/model"
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
// Test_CreateDBTask_Prepare covers the three outcomes of Prepare:
// meta listing failure, database-number quota exceeded, and success.
func Test_CreateDBTask_Prepare(t *testing.T) {
	paramtable.Init()
	t.Run("list database fail", func(t *testing.T) {
		core := newTestCore(withInvalidMeta())
		task := &createDatabaseTask{
			baseTask: newBaseTask(context.TODO(), core),
			Req: &milvuspb.CreateDatabaseRequest{
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_CreateDatabase,
				},
				DbName: "db",
			},
		}
		err := task.Prepare(context.Background())
		assert.Error(t, err)
	})

	t.Run("check database number fail", func(t *testing.T) {
		meta := mockrootcoord.NewIMetaTable(t)
		cfgMaxDatabaseNum := Params.RootCoordCfg.MaxDatabaseNum.GetAsInt()
		// Return one more database than the configured maximum so the
		// quota check must reject the request. (Renamed the local from
		// `len`, which shadowed the builtin.)
		dbNum := cfgMaxDatabaseNum + 1
		dbs := make([]*model.Database, 0, dbNum)
		for i := 0; i < dbNum; i++ {
			dbs = append(dbs, model.NewDefaultDatabase())
		}
		meta.On("ListDatabases",
			mock.Anything,
			mock.Anything).
			Return(dbs, nil)

		core := newTestCore(withMeta(meta))
		task := &createDatabaseTask{
			baseTask: newBaseTask(context.TODO(), core),
			Req: &milvuspb.CreateDatabaseRequest{
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_CreateDatabase,
				},
				DbName: "db",
			},
		}
		err := task.Prepare(context.Background())
		assert.Error(t, err)
	})

	t.Run("ok", func(t *testing.T) {
		meta := mockrootcoord.NewIMetaTable(t)
		meta.On("ListDatabases",
			mock.Anything,
			mock.Anything).
			Return([]*model.Database{model.NewDefaultDatabase()}, nil)

		core := newTestCore(withMeta(meta), withValidIDAllocator())
		paramtable.Get().Save(Params.RootCoordCfg.MaxDatabaseNum.Key, strconv.Itoa(10))
		defer paramtable.Get().Reset(Params.RootCoordCfg.MaxDatabaseNum.Key)

		task := &createDatabaseTask{
			baseTask: newBaseTask(context.TODO(), core),
			Req: &milvuspb.CreateDatabaseRequest{
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_CreateDatabase,
				},
				DbName: "db",
			},
		}
		err := task.Prepare(context.Background())
		assert.NoError(t, err)
	})
}
// Test_CreateDBTask_Execute verifies that Execute forwards the new
// database to the meta table and returns nil when the meta call succeeds.
func Test_CreateDBTask_Execute(t *testing.T) {
	mockMeta := mockrootcoord.NewIMetaTable(t)
	mockMeta.On("CreateDatabase",
		mock.Anything,
		mock.Anything,
		mock.Anything).
		Return(nil)

	testCore := newTestCore(withMeta(mockMeta))
	task := &createDatabaseTask{
		baseTask: newBaseTask(context.TODO(), testCore),
		Req: &milvuspb.CreateDatabaseRequest{
			Base:   &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateDatabase},
			DbName: "db",
		},
	}
	assert.NoError(t, task.Execute(context.Background()))
}

View File

@ -20,15 +20,13 @@ import (
"context" "context"
"fmt" "fmt"
pb "github.com/milvus-io/milvus/internal/proto/etcdpb"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/pkg/log"
"go.uber.org/zap" "go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/metastore/model"
pb "github.com/milvus-io/milvus/internal/proto/etcdpb"
"github.com/milvus-io/milvus/pkg/log"
) )
type createPartitionTask struct { type createPartitionTask struct {
@ -41,7 +39,7 @@ func (t *createPartitionTask) Prepare(ctx context.Context) error {
if err := CheckMsgType(t.Req.GetBase().GetMsgType(), commonpb.MsgType_CreatePartition); err != nil { if err := CheckMsgType(t.Req.GetBase().GetMsgType(), commonpb.MsgType_CreatePartition); err != nil {
return err return err
} }
collMeta, err := t.core.meta.GetCollectionByName(ctx, t.Req.GetCollectionName(), t.GetTs()) collMeta, err := t.core.meta.GetCollectionByName(ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), t.GetTs())
if err != nil { if err != nil {
return err return err
} }
@ -80,6 +78,7 @@ func (t *createPartitionTask) Execute(ctx context.Context) error {
undoTask.AddStep(&expireCacheStep{ undoTask.AddStep(&expireCacheStep{
baseStep: baseStep{core: t.core}, baseStep: baseStep{core: t.core},
dbName: t.Req.GetDbName(),
collectionNames: []string{t.collMeta.Name}, collectionNames: []string{t.collMeta.Name},
collectionID: t.collMeta.CollectionID, collectionID: t.collMeta.CollectionID,
ts: t.GetTs(), ts: t.GetTs(),
@ -90,6 +89,7 @@ func (t *createPartitionTask) Execute(ctx context.Context) error {
partition: partition, partition: partition,
}, &removePartitionMetaStep{ }, &removePartitionMetaStep{
baseStep: baseStep{core: t.core}, baseStep: baseStep{core: t.core},
dbID: t.collMeta.DBID,
collectionID: partition.CollectionID, collectionID: partition.CollectionID,
partitionID: partition.PartitionID, partitionID: partition.PartitionID,
ts: t.GetTs(), ts: t.GetTs(),

View File

@ -20,15 +20,15 @@ import (
"context" "context"
"testing" "testing"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/proto/etcdpb" "github.com/milvus-io/milvus/internal/proto/etcdpb"
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
"github.com/milvus-io/milvus/pkg/util/funcutil"
) )
func Test_createPartitionTask_Prepare(t *testing.T) { func Test_createPartitionTask_Prepare(t *testing.T) {
@ -51,12 +51,17 @@ func Test_createPartitionTask_Prepare(t *testing.T) {
}) })
t.Run("normal case", func(t *testing.T) { t.Run("normal case", func(t *testing.T) {
meta := newMockMetaTable()
collectionName := funcutil.GenRandomStr() collectionName := funcutil.GenRandomStr()
coll := &model.Collection{Name: collectionName} coll := &model.Collection{Name: collectionName}
meta.GetCollectionByNameFunc = func(ctx context.Context, collectionName string, ts Timestamp) (*model.Collection, error) {
return coll.Clone(), nil meta := mockrootcoord.NewIMetaTable(t)
} meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(coll.Clone(), nil)
core := newTestCore(withMeta(meta)) core := newTestCore(withMeta(meta))
task := &createPartitionTask{ task := &createPartitionTask{
baseTask: baseTask{core: core}, baseTask: baseTask{core: core},

View File

@ -46,5 +46,6 @@ func (t *describeCollectionTask) Execute(ctx context.Context) (err error) {
} }
aliases := t.core.meta.ListAliasesByID(coll.CollectionID) aliases := t.core.meta.ListAliasesByID(coll.CollectionID)
t.Rsp = convertModelToDesc(coll, aliases) t.Rsp = convertModelToDesc(coll, aliases)
t.Rsp.DbName = t.Req.GetDbName()
return nil return nil
} }

Some files were not shown because too many files have changed in this diff Show More