enhance: refine proxy metacache for concurrency safety (#29872)

relate: https://github.com/milvus-io/milvus/issues/29675

Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
aoiasd 2024-01-22 14:28:55 +08:00 committed by GitHub
parent 4436effdc3
commit a81d2b4780
3 changed files with 154 additions and 181 deletions

@ -43,6 +43,7 @@ import (
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/conc"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
@ -75,7 +76,6 @@ type Cache interface {
expireShardLeaderCache(ctx context.Context)
RemoveCollection(ctx context.Context, database, collectionName string)
RemoveCollectionsByID(ctx context.Context, collectionID UniqueID) []string
RemovePartition(ctx context.Context, database, collectionName string, partitionName string)
// GetCredentialInfo operate credential cache
GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error)
@ -256,6 +256,7 @@ type MetaCache struct {
credMut sync.RWMutex
privilegeMut sync.RWMutex
shardMgr shardClientMgr
sfGlobal conc.Singleflight[*collectionInfo]
}
// globalMetaCache is singleton instance of Cache
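The new sfGlobal field is a singleflight group keyed per collection, so concurrent cache misses for the same collection collapse into a single rootcoord round trip instead of a stampede of DescribeCollection/ShowPartitions calls. A minimal sketch of that deduplication, assuming conc.Singleflight is a thin generic wrapper over golang.org/x/sync/singleflight (an assumption about its internals; the key string below is illustrative only):

package main

import (
	"fmt"
	"sync"
	"sync/atomic"

	"golang.org/x/sync/singleflight"
)

func main() {
	var g singleflight.Group
	var fetches int32 // counts how often the expensive fetch actually runs

	var wg sync.WaitGroup
	for i := 0; i < 16; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// Every goroutine asks for the same key, so the callback runs once;
			// the other callers block and receive the shared result.
			v, _, shared := g.Do("default-collection1", func() (interface{}, error) {
				atomic.AddInt32(&fetches, 1)
				return "collection info", nil
			})
			_, _ = v, shared
		}()
	}
	wg.Wait()
	fmt.Println("fetch executed", atomic.LoadInt32(&fetches), "time(s)") // typically 1
}

UpdateByName and UpdateByID below key this group via buildSfKeyByName and buildSfKeyById, so by-name and by-ID refreshes are deduplicated independently.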
@ -294,33 +295,121 @@ func NewMetaCache(rootCoord types.RootCoordClient, queryCoord types.QueryCoordCl
}, nil
}
// GetCollectionID returns the corresponding collection id for provided collection name
func (m *MetaCache) GetCollectionID(ctx context.Context, database, collectionName string) (typeutil.UniqueID, error) {
func (m *MetaCache) getCollection(database, collectionName string, collectionID UniqueID) (*collectionInfo, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
var ok bool
var collInfo *collectionInfo
db, dbOk := m.collInfo[database]
if dbOk && db != nil {
collInfo, ok = db[collectionName]
db, ok := m.collInfo[database]
if !ok {
return nil, false
}
if collectionName == "" {
for _, collection := range db {
if collection.collID == collectionID {
return collection, collection.isCollectionCached()
}
}
} else {
if collection, ok := db[collectionName]; ok {
return collection, collection.isCollectionCached()
}
}
return nil, false
}
func (m *MetaCache) update(ctx context.Context, database, collectionName string, collectionID UniqueID) (*collectionInfo, error) {
if collInfo, ok := m.getCollection(database, collectionName, collectionID); ok {
return collInfo, nil
}
collection, err := m.describeCollection(ctx, database, collectionName, collectionID)
if err != nil {
return nil, err
}
partitions, err := m.showPartitions(ctx, database, collectionName, collectionID)
if err != nil {
return nil, err
}
// check that partition names, created timestamps and created UTC timestamps have the same number of elements
if len(partitions.PartitionNames) != len(partitions.CreatedTimestamps) || len(partitions.PartitionNames) != len(partitions.CreatedUtcTimestamps) {
return nil, merr.WrapErrParameterInvalidMsg("partition names and timestamps number is not aligned, response: %s", partitions.String())
}
infos := lo.Map(partitions.GetPartitionIDs(), func(partitionID int64, idx int) *partitionInfo {
return &partitionInfo{
name: partitions.PartitionNames[idx],
partitionID: partitions.PartitionIDs[idx],
createdTimestamp: partitions.CreatedTimestamps[idx],
createdUtcTimestamp: partitions.CreatedUtcTimestamps[idx],
}
})
collectionName = collection.Schema.GetName()
m.mu.Lock()
defer m.mu.Unlock()
_, dbOk := m.collInfo[database]
if !dbOk {
m.collInfo[database] = make(map[string]*collectionInfo)
}
_, ok := m.collInfo[database][collectionName]
if !ok {
m.collInfo[database][collectionName] = &collectionInfo{}
}
collInfo := m.collInfo[database][collectionName]
collInfo.schema = newSchemaInfo(collection.Schema)
collInfo.collID = collection.CollectionID
collInfo.createdTimestamp = collection.CreatedTimestamp
collInfo.createdUtcTimestamp = collection.CreatedUtcTimestamp
collInfo.consistencyLevel = collection.ConsistencyLevel
collInfo.partInfo = parsePartitionsInfo(infos)
log.Info("meta update success", zap.String("database", database), zap.String("collectionName", collectionName), zap.Int64("collectionID", collInfo.collID))
return m.collInfo[database][collectionName], nil
}
func buildSfKeyByName(database, collectionName string) string {
return database + "-" + collectionName
}
func buildSfKeyById(database string, collectionID UniqueID) string {
return database + "--" + fmt.Sprint(collectionID)
}
func (m *MetaCache) UpdateByName(ctx context.Context, database, collectionName string) (*collectionInfo, error) {
collection, err, _ := m.sfGlobal.Do(buildSfKeyByName(database, collectionName), func() (*collectionInfo, error) {
return m.update(ctx, database, collectionName, 0)
})
return collection, err
}
func (m *MetaCache) UpdateByID(ctx context.Context, database string, collectionID UniqueID) (*collectionInfo, error) {
collection, err, _ := m.sfGlobal.Do(buildSfKeyById(database, collectionID), func() (*collectionInfo, error) {
return m.update(ctx, database, "", collectionID)
})
return collection, err
}
// GetCollectionID returns the corresponding collection id for provided collection name
func (m *MetaCache) GetCollectionID(ctx context.Context, database, collectionName string) (UniqueID, error) {
method := "GetCollectionID"
if !ok || !collInfo.isCollectionCached() {
m.mu.RLock()
collInfo, ok := m.getCollection(database, collectionName, 0)
if !ok {
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
tr := timerecord.NewTimeRecorder("UpdateCache")
m.mu.RUnlock()
coll, err := m.describeCollection(ctx, database, collectionName, 0)
if err != nil {
return 0, err
}
m.mu.Lock()
defer m.mu.Unlock()
m.updateCollection(coll, database, collectionName)
collInfo, err := m.UpdateByName(ctx, database, collectionName)
if err != nil {
return UniqueID(0), err
}
metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
collInfo = m.collInfo[database][collectionName]
return collInfo.collID, nil
}
defer m.mu.RUnlock()
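Taken together, GetCollectionID (and the getters in the following hunks) now follows a read-through shape: read under the RLock, and on a miss drop the lock and funnel the refresh through the singleflight-guarded update, which re-checks the cache before calling rootcoord and only then writes under the exclusive lock. A compact stand-in for that shape, with hypothetical names (readThroughCache, newReadThroughCache, fetch) that are not the Milvus types:

package cache

import (
	"sync"

	"golang.org/x/sync/singleflight"
)

// readThroughCache is an illustrative stand-in for MetaCache, not the real
// type: the same RWMutex + singleflight arrangement in miniature.
type readThroughCache struct {
	mu    sync.RWMutex
	items map[string]string
	sf    singleflight.Group
	fetch func(key string) (string, error) // stands in for describeCollection + showPartitions
}

func newReadThroughCache(fetch func(string) (string, error)) *readThroughCache {
	return &readThroughCache{items: make(map[string]string), fetch: fetch}
}

func (c *readThroughCache) Get(key string) (string, error) {
	c.mu.RLock()
	v, ok := c.items[key]
	c.mu.RUnlock()
	if ok {
		return v, nil // cache hit: no coordinator round trip
	}
	// Cache miss: concurrent misses for the same key share one fetch.
	res, err, _ := c.sf.Do(key, func() (interface{}, error) {
		// Re-check first; another caller may have filled the entry while we
		// waited for the singleflight slot (mirrors update() calling getCollection).
		c.mu.RLock()
		if v, ok := c.items[key]; ok {
			c.mu.RUnlock()
			return v, nil
		}
		c.mu.RUnlock()

		fetched, err := c.fetch(key)
		if err != nil {
			return "", err
		}
		c.mu.Lock()
		c.items[key] = fetched
		c.mu.Unlock()
		return fetched, nil
	})
	if err != nil {
		return "", err
	}
	return res.(string), nil
}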
@ -331,32 +420,22 @@ func (m *MetaCache) GetCollectionID(ctx context.Context, database, collectionNam
// GetCollectionName returns the corresponding collection name for provided collection id
func (m *MetaCache) GetCollectionName(ctx context.Context, database string, collectionID int64) (string, error) {
m.mu.RLock()
var collInfo *collectionInfo
for _, db := range m.collInfo {
for _, coll := range db {
if coll.collID == collectionID {
collInfo = coll
break
}
}
}
method := "GetCollectionName"
if collInfo == nil || !collInfo.isCollectionCached() {
m.mu.RLock()
collInfo, ok := m.getCollection(database, "", collectionID)
if !ok {
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
tr := timerecord.NewTimeRecorder("UpdateCache")
m.mu.RUnlock()
coll, err := m.describeCollection(ctx, database, "", collectionID)
collInfo, err := m.UpdateByID(ctx, database, collectionID)
if err != nil {
return "", err
}
m.mu.Lock()
defer m.mu.Unlock()
m.updateCollection(coll, coll.GetDbName(), coll.Schema.Name)
metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
return coll.Schema.Name, nil
return collInfo.schema.Name, nil
}
defer m.mu.RUnlock()
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheHitLabel).Inc()
@ -366,29 +445,20 @@ func (m *MetaCache) GetCollectionName(ctx context.Context, database string, coll
func (m *MetaCache) GetCollectionInfo(ctx context.Context, database string, collectionName string, collectionID int64) (*collectionBasicInfo, error) {
m.mu.RLock()
var collInfo *collectionInfo
var ok bool
db, dbOk := m.collInfo[database]
if dbOk {
collInfo, ok = db[collectionName]
}
collInfo, ok := m.getCollection(database, collectionName, 0)
method := "GetCollectionInfo"
// if collInfo.collID != collectionID, means that the cache is not trustable
// try to get collection according to collectionID
if !ok || !collInfo.isCollectionCached() || collInfo.collID != collectionID {
if !ok || collInfo.collID != collectionID {
m.mu.RUnlock()
tr := timerecord.NewTimeRecorder("UpdateCache")
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
coll, err := m.describeCollection(ctx, database, "", collectionID)
collInfo, err := m.UpdateByID(ctx, database, collectionID)
if err != nil {
return nil, err
}
m.mu.Lock()
defer m.mu.Unlock()
m.updateCollection(coll, database, collectionName)
collInfo = m.collInfo[database][collectionName]
metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
return collInfo.getBasicInfo(), nil
}
@ -403,34 +473,20 @@ func (m *MetaCache) GetCollectionInfo(ctx context.Context, database string, coll
// TODO: may cause data race of this implementation, should be refactored in future.
func (m *MetaCache) getFullCollectionInfo(ctx context.Context, database, collectionName string, collectionID int64) (*collectionInfo, error) {
m.mu.RLock()
var collInfo *collectionInfo
var ok bool
db, dbOk := m.collInfo[database]
if dbOk {
collInfo, ok = db[collectionName]
}
collInfo, ok := m.getCollection(database, collectionName, collectionID)
method := "GetCollectionInfo"
// if collInfo.collID != collectionID, means that the cache is not trustable
// try to get collection according to collectionID
if !ok || !collInfo.isCollectionCached() || collInfo.collID != collectionID {
if !ok || collInfo.collID != collectionID {
m.mu.RUnlock()
tr := timerecord.NewTimeRecorder("UpdateCache")
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
var coll *milvuspb.DescribeCollectionResponse
var err error
// collectionName maybe not trustable, get collection according to id
coll, err = m.describeCollection(ctx, database, "", collectionID)
collInfo, err := m.UpdateByID(ctx, database, collectionID)
if err != nil {
return nil, err
}
m.mu.Lock()
m.updateCollection(coll, database, collectionName)
collInfo = m.collInfo[database][collectionName]
m.mu.Unlock()
metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
return collInfo, nil
}
@ -442,31 +498,18 @@ func (m *MetaCache) getFullCollectionInfo(ctx context.Context, database, collect
func (m *MetaCache) GetCollectionSchema(ctx context.Context, database, collectionName string) (*schemaInfo, error) {
m.mu.RLock()
var collInfo *collectionInfo
var ok bool
db, dbOk := m.collInfo[database]
if dbOk {
collInfo, ok = db[collectionName]
}
collInfo, ok := m.getCollection(database, collectionName, 0)
method := "GetCollectionSchema"
if !ok || !collInfo.isCollectionCached() {
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
tr := timerecord.NewTimeRecorder("UpdateCache")
if !ok {
m.mu.RUnlock()
coll, err := m.describeCollection(ctx, database, collectionName, 0)
tr := timerecord.NewTimeRecorder("UpdateCache")
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
collInfo, err := m.UpdateByName(ctx, database, collectionName)
if err != nil {
log.Warn("Failed to load collection from rootcoord ",
zap.String("collection name ", collectionName),
zap.Error(err))
return nil, err
}
m.mu.Lock()
defer m.mu.Unlock()
m.updateCollection(coll, database, collectionName)
collInfo = m.collInfo[database][collectionName]
metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
log.Debug("Reload collection from root coordinator ",
zap.String("collectionName", collectionName),
@ -479,23 +522,6 @@ func (m *MetaCache) GetCollectionSchema(ctx context.Context, database, collectio
return collInfo.schema, nil
}
func (m *MetaCache) updateCollection(coll *milvuspb.DescribeCollectionResponse, database, collectionName string) {
_, dbOk := m.collInfo[database]
if !dbOk {
m.collInfo[database] = make(map[string]*collectionInfo)
}
_, ok := m.collInfo[database][collectionName]
if !ok {
m.collInfo[database][collectionName] = &collectionInfo{}
}
m.collInfo[database][collectionName].schema = newSchemaInfo(coll.Schema)
m.collInfo[database][collectionName].collID = coll.CollectionID
m.collInfo[database][collectionName].createdTimestamp = coll.CreatedTimestamp
m.collInfo[database][collectionName].createdUtcTimestamp = coll.CreatedUtcTimestamp
m.collInfo[database][collectionName].consistencyLevel = coll.ConsistencyLevel
}
func (m *MetaCache) GetPartitionID(ctx context.Context, database, collectionName string, partitionName string) (typeutil.UniqueID, error) {
partInfo, err := m.GetPartitionInfo(ctx, database, collectionName, partitionName)
if err != nil {
@ -505,7 +531,7 @@ func (m *MetaCache) GetPartitionID(ctx context.Context, database, collectionName
}
func (m *MetaCache) GetPartitions(ctx context.Context, database, collectionName string) (map[string]typeutil.UniqueID, error) {
partitions, err := m.getPartitionInfos(ctx, database, collectionName)
partitions, err := m.GetPartitionInfos(ctx, database, collectionName)
if err != nil {
return nil, err
}
@ -514,7 +540,7 @@ func (m *MetaCache) GetPartitions(ctx context.Context, database, collectionName
}
func (m *MetaCache) GetPartitionInfo(ctx context.Context, database, collectionName string, partitionName string) (*partitionInfo, error) {
partitions, err := m.getPartitionInfos(ctx, database, collectionName)
partitions, err := m.GetPartitionInfos(ctx, database, collectionName)
if err != nil {
return nil, err
}
@ -527,7 +553,7 @@ func (m *MetaCache) GetPartitionInfo(ctx context.Context, database, collectionNa
}
func (m *MetaCache) GetPartitionsIndex(ctx context.Context, database, collectionName string) ([]string, error) {
partitions, err := m.getPartitionInfos(ctx, database, collectionName)
partitions, err := m.GetPartitionInfos(ctx, database, collectionName)
if err != nil {
return nil, err
}
@ -539,49 +565,26 @@ func (m *MetaCache) GetPartitionsIndex(ctx context.Context, database, collection
return partitions.indexedPartitionNames, nil
}
func (m *MetaCache) getPartitionInfos(ctx context.Context, database, collectionName string) (*partitionInfos, error) {
_, err := m.GetCollectionID(ctx, database, collectionName)
if err != nil {
return nil, err
}
func (m *MetaCache) GetPartitionInfos(ctx context.Context, database, collectionName string) (*partitionInfos, error) {
m.mu.RLock()
var collInfo *collectionInfo
var ok bool
db, dbOk := m.collInfo[database]
if dbOk {
collInfo, ok = db[collectionName]
}
method := "GetPartitionInfo"
collInfo, ok := m.getCollection(database, collectionName, 0)
if !ok {
m.mu.RUnlock()
return nil, fmt.Errorf("can't find collection name %s:%s", database, collectionName)
}
partitionInfos := collInfo.partInfo
m.mu.RUnlock()
method := "GetPartitionInfo"
if partitionInfos == nil {
tr := timerecord.NewTimeRecorder("UpdateCache")
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
partitions, err := m.showPartitions(ctx, database, collectionName)
collInfo, err := m.UpdateByName(ctx, database, collectionName)
if err != nil {
return nil, err
}
m.mu.Lock()
defer m.mu.Unlock()
err = m.updatePartitions(partitions, database, collectionName)
if err != nil {
return nil, err
}
metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
partitionInfos = m.collInfo[database][collectionName].partInfo
return partitionInfos, nil
return collInfo.partInfo, nil
}
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheHitLabel).Inc()
return partitionInfos, nil
defer m.mu.RUnlock()
return collInfo.partInfo, nil
}
// Get the collection information from rootcoord.
@ -627,21 +630,23 @@ func (m *MetaCache) describeCollection(ctx context.Context, database, collection
return resp, nil
}
func (m *MetaCache) showPartitions(ctx context.Context, dbName string, collectionName string) (*milvuspb.ShowPartitionsResponse, error) {
func (m *MetaCache) showPartitions(ctx context.Context, dbName string, collectionName string, collectionID UniqueID) (*milvuspb.ShowPartitionsResponse, error) {
req := &milvuspb.ShowPartitionsRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_ShowPartitions),
),
DbName: dbName,
CollectionName: collectionName,
CollectionID: collectionID,
}
partitions, err := m.rootCoord.ShowPartitions(ctx, req)
if err != nil {
return nil, err
}
if partitions.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
return nil, fmt.Errorf("%s", partitions.GetStatus().GetReason())
if err := merr.Error(partitions.GetStatus()); err != nil {
return nil, err
}
if len(partitions.PartitionIDs) != len(partitions.PartitionNames) {
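The status check in showPartitions above also moves from a hand-rolled ErrorCode comparison to merr.Error, so any non-success status surfaces as an ordinary Go error. A rough sketch of the idea, not the actual merr implementation (statusToErr and its message format are made up for illustration; the real helper does more, such as mapping codes to typed errors):

package example

import (
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
)

// statusToErr only shows the shape of turning a *commonpb.Status into an error.
func statusToErr(status *commonpb.Status) error {
	if status == nil || status.GetErrorCode() == commonpb.ErrorCode_Success {
		return nil
	}
	return fmt.Errorf("rpc failed: code=%s reason=%s",
		status.GetErrorCode().String(), status.GetReason())
}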
@ -690,35 +695,6 @@ func parsePartitionsInfo(infos []*partitionInfo) *partitionInfos {
return result
}
func (m *MetaCache) updatePartitions(partitions *milvuspb.ShowPartitionsResponse, database, collectionName string) error {
// check partitionID, createdTimestamp and utcstamp has sam element numbers
if len(partitions.PartitionNames) != len(partitions.CreatedTimestamps) || len(partitions.PartitionNames) != len(partitions.CreatedUtcTimestamps) {
return merr.WrapErrParameterInvalidMsg("partition names and timestamps number is not aligned, response: %s", partitions.String())
}
_, dbOk := m.collInfo[database]
if !dbOk {
m.collInfo[database] = make(map[string]*collectionInfo)
}
_, ok := m.collInfo[database][collectionName]
if !ok {
m.collInfo[database][collectionName] = &collectionInfo{}
}
infos := lo.Map(partitions.GetPartitionIDs(), func(partitionID int64, idx int) *partitionInfo {
return &partitionInfo{
name: partitions.PartitionNames[idx],
partitionID: partitions.PartitionIDs[idx],
createdTimestamp: partitions.CreatedTimestamps[idx],
createdUtcTimestamp: partitions.CreatedUtcTimestamps[idx],
}
})
m.collInfo[database][collectionName].partInfo = parsePartitionsInfo(infos)
return nil
}
func (m *MetaCache) RemoveCollection(ctx context.Context, database, collectionName string) {
m.mu.Lock()
defer m.mu.Unlock()

@ -69,7 +69,7 @@ func (m *MockRootCoordClientInterface) ShowPartitions(ctx context.Context, in *m
if m.Error {
return nil, errors.New("mocked error")
}
if in.CollectionName == "collection1" {
if in.CollectionName == "collection1" || in.CollectionID == 1 {
return &milvuspb.ShowPartitionsResponse{
Status: merr.Success(),
PartitionIDs: []typeutil.UniqueID{1, 2},
@ -78,7 +78,7 @@ func (m *MockRootCoordClientInterface) ShowPartitions(ctx context.Context, in *m
PartitionNames: []string{"par1", "par2"},
}, nil
}
if in.CollectionName == "collection2" {
if in.CollectionName == "collection2" || in.CollectionID == 2 {
return &milvuspb.ShowPartitionsResponse{
Status: merr.Success(),
PartitionIDs: []typeutil.UniqueID{3, 4},
@ -900,12 +900,6 @@ func TestMetaCache_Database(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, globalMetaCache.HasDatabase(ctx, dbName), false)
queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: merr.Success(),
CollectionIDs: []UniqueID{1, 2},
InMemoryPercentages: []int64{100, 50},
}, nil)
_, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1", 1)
assert.NoError(t, err)
_, err = GetCachedCollectionSchema(ctx, dbName, "collection1")

@ -57,10 +57,6 @@ func TestGetIndexStateTask_Execute(t *testing.T) {
rootCoord := newMockRootCoord()
queryCoord := getMockQueryCoord()
queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: merr.Success(),
CollectionIDs: []int64{},
}, nil)
datacoord := NewDataCoordMock()
gist := &getIndexStateTask{
@ -75,7 +71,7 @@ func TestGetIndexStateTask_Execute(t *testing.T) {
rootCoord: rootCoord,
dataCoord: datacoord,
result: &milvuspb.GetIndexStateResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: "mock"},
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: "mock-1"},
State: commonpb.IndexState_Unissued,
},
collectionID: collectionID,
@ -83,7 +79,8 @@ func TestGetIndexStateTask_Execute(t *testing.T) {
shardMgr := newShardClientMgr()
// failed to get collection id.
_ = InitMetaCache(ctx, rootCoord, queryCoord, shardMgr)
err := InitMetaCache(ctx, rootCoord, queryCoord, shardMgr)
assert.NoError(t, err)
assert.Error(t, gist.Execute(ctx))
rootCoord.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) {
@ -95,6 +92,12 @@ func TestGetIndexStateTask_Execute(t *testing.T) {
}, nil
}
rootCoord.ShowPartitionsFunc = func(ctx context.Context, request *milvuspb.ShowPartitionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error) {
return &milvuspb.ShowPartitionsResponse{
Status: merr.Success(),
}, nil
}
datacoord.GetIndexStateFunc = func(ctx context.Context, request *indexpb.GetIndexStateRequest, opts ...grpc.CallOption) (*indexpb.GetIndexStateResponse, error) {
return &indexpb.GetIndexStateResponse{
Status: merr.Success(),