enhance: set database properties to restrict read access (#35745)

issue: #35744

Signed-off-by: jaime <yun.zhang@zilliz.com>
This commit is contained in:
jaime 2024-08-29 13:17:01 +08:00 committed by GitHub
parent b51b4a2838
commit b0ac04d104
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 231 additions and 101 deletions

View File

@ -407,7 +407,7 @@ func (q *QuotaCenter) collectMetrics() error {
coll, getErr := q.meta.GetCollectionByIDWithMaxTs(context.TODO(), collectionID)
if getErr != nil {
// skip limit check if the collection meta has been removed from rootcoord meta
return false
return true
}
collIDToPartIDs, ok := q.readableCollections[coll.DBID]
if !ok {
@ -463,7 +463,7 @@ func (q *QuotaCenter) collectMetrics() error {
coll, getErr := q.meta.GetCollectionByIDWithMaxTs(context.TODO(), collectionID)
if getErr != nil {
// skip limit check if the collection meta has been removed from rootcoord meta
return false
return true
}
collIDToPartIDs, ok := q.writableCollections[coll.DBID]
@ -562,7 +562,7 @@ func (q *QuotaCenter) forceDenyWriting(errorCode commonpb.ErrorCode, cluster boo
dbLimiters := q.rateLimiter.GetDatabaseLimiters(dbID)
if dbLimiters == nil {
log.Warn("db limiter not found of db ID", zap.Int64("dbID", dbID))
return fmt.Errorf("db limiter not found of db ID: %d", dbID)
continue
}
updateLimiter(dbLimiters, GetEarliestLimiter(), internalpb.RateScope_Database, dml)
dbLimiters.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToWrite, errorCode)
@ -578,7 +578,7 @@ func (q *QuotaCenter) forceDenyWriting(errorCode commonpb.ErrorCode, cluster boo
log.Warn("collection limiter not found of collection ID",
zap.Int64("dbID", dbID),
zap.Int64("collectionID", collectionID))
return fmt.Errorf("collection limiter not found of collection ID: %d", collectionID)
continue
}
updateLimiter(collectionLimiter, GetEarliestLimiter(), internalpb.RateScope_Collection, dml)
collectionLimiter.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToWrite, errorCode)
@ -596,7 +596,7 @@ func (q *QuotaCenter) forceDenyWriting(errorCode commonpb.ErrorCode, cluster boo
zap.Int64("dbID", dbID),
zap.Int64("collectionID", collectionID),
zap.Int64("partitionID", partitionID))
return fmt.Errorf("partition limiter not found of partition ID: %d", partitionID)
continue
}
updateLimiter(partitionLimiter, GetEarliestLimiter(), internalpb.RateScope_Partition, dml)
partitionLimiter.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToWrite, errorCode)
@ -604,7 +604,7 @@ func (q *QuotaCenter) forceDenyWriting(errorCode commonpb.ErrorCode, cluster boo
}
if cluster || len(dbIDs) > 0 || len(collectionIDs) > 0 || len(col2partitionIDs) > 0 {
log.RatedWarn(10, "QuotaCenter force to deny writing",
log.RatedWarn(30, "QuotaCenter force to deny writing",
zap.Bool("cluster", cluster),
zap.Int64s("dbIDs", dbIDs),
zap.Int64s("collectionIDs", collectionIDs),
@ -616,20 +616,37 @@ func (q *QuotaCenter) forceDenyWriting(errorCode commonpb.ErrorCode, cluster boo
}
// forceDenyReading sets dql rates to 0 to reject all dql requests.
func (q *QuotaCenter) forceDenyReading(errorCode commonpb.ErrorCode) {
var collectionIDs []int64
for dbID, collectionIDToPartIDs := range q.readableCollections {
for collectionID := range collectionIDToPartIDs {
collectionLimiter := q.rateLimiter.GetCollectionLimiters(dbID, collectionID)
updateLimiter(collectionLimiter, GetEarliestLimiter(), internalpb.RateScope_Collection, dql)
collectionLimiter.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToRead, errorCode)
collectionIDs = append(collectionIDs, collectionID)
// forceDenyReading sets dql rates to 0 to reject dql requests.
// When cluster is true, every readable collection in the cluster is denied;
// otherwise only the databases listed in dbIDs are denied. mlog is the
// rate-limited logger owned by the caller.
func (q *QuotaCenter) forceDenyReading(errorCode commonpb.ErrorCode, cluster bool, dbIDs []int64, mlog *log.MLogger) {
	if cluster {
		// Cluster-wide deny: zero the dql limiter of every readable collection.
		var collectionIDs []int64
		for dbID, collectionIDToPartIDs := range q.readableCollections {
			for collectionID := range collectionIDToPartIDs {
				collectionLimiter := q.rateLimiter.GetCollectionLimiters(dbID, collectionID)
				updateLimiter(collectionLimiter, GetEarliestLimiter(), internalpb.RateScope_Collection, dql)
				collectionLimiter.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToRead, errorCode)
				collectionIDs = append(collectionIDs, collectionID)
			}
		}
		mlog.RatedWarn(10, "QuotaCenter force to deny reading",
			zap.Int64s("collectionIDs", collectionIDs),
			zap.String("reason", errorCode.String()))
	}

	if len(dbIDs) > 0 {
		// Database-scoped deny: zero the dql limiter of each listed database.
		for _, dbID := range dbIDs {
			dbLimiters := q.rateLimiter.GetDatabaseLimiters(dbID)
			if dbLimiters == nil {
				// The limiter may be missing if the database was just dropped; skip it.
				log.Warn("db limiter not found of db ID", zap.Int64("dbID", dbID))
				continue
			}
			updateLimiter(dbLimiters, GetEarliestLimiter(), internalpb.RateScope_Database, dql)
			dbLimiters.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToRead, errorCode)
		}
		// Log once for the whole batch rather than once per database iteration.
		mlog.RatedWarn(10, "QuotaCenter force to deny reading",
			zap.Int64s("dbIDs", dbIDs),
			zap.String("reason", errorCode.String()))
	}
}
// getRealTimeRate return real time rate in Proxy.
@ -654,58 +671,26 @@ func (q *QuotaCenter) guaranteeMinRate(minRate float64, rt internalpb.RateType,
}
}
// calculateReadRates calculates and sets dql rates.
func (q *QuotaCenter) calculateReadRates() error {
log := log.Ctx(context.Background()).WithRateGroup("rootcoord.QuotaCenter", 1.0, 60.0)
if Params.QuotaConfig.ForceDenyReading.GetAsBool() {
q.forceDenyReading(commonpb.ErrorCode_ForceDeny)
return nil
}
limitCollectionSet := typeutil.NewUniqueSet()
limitDBNameSet := typeutil.NewSet[string]()
limitCollectionNameSet := typeutil.NewSet[string]()
clusterLimit := false
formatCollctionRateKey := func(dbName, collectionName string) string {
return fmt.Sprintf("%s.%s", dbName, collectionName)
}
splitCollctionRateKey := func(key string) (string, string) {
parts := strings.Split(key, ".")
return parts[0], parts[1]
}
// query latency
queueLatencyThreshold := Params.QuotaConfig.QueueLatencyThreshold.GetAsDuration(time.Second)
// enableQueueProtection && queueLatencyThreshold >= 0 means enable queue latency protection
if queueLatencyThreshold >= 0 {
for _, metric := range q.queryNodeMetrics {
searchLatency := metric.SearchQueue.AvgQueueDuration
queryLatency := metric.QueryQueue.AvgQueueDuration
if searchLatency >= queueLatencyThreshold || queryLatency >= queueLatencyThreshold {
limitCollectionSet.Insert(metric.Effect.CollectionIDs...)
// getDenyReadingDBs returns the set of database IDs whose
// common.DatabaseForceDenyReadingKey property parses as a true boolean.
// Databases whose metadata cannot be fetched are silently skipped.
func (q *QuotaCenter) getDenyReadingDBs() map[int64]struct{} {
	denied := make(map[int64]struct{})
	for _, dbID := range lo.Uniq(q.collectionIDToDBID.Values()) {
		db, err := q.meta.GetDatabaseByID(q.ctx, dbID, typeutil.MaxTimestamp)
		if err != nil {
			// Metadata lookup failed (e.g. database dropped); best-effort skip.
			continue
		}
		prop := db.GetProperty(common.DatabaseForceDenyReadingKey)
		if prop == "" {
			continue
		}
		// Malformed values parse as false and leave the database readable.
		if enabled, _ := strconv.ParseBool(prop); enabled {
			denied[dbID] = struct{}{}
		}
	}
	return denied
}
// queue length
enableQueueProtection := Params.QuotaConfig.QueueProtectionEnabled.GetAsBool()
nqInQueueThreshold := Params.QuotaConfig.NQInQueueThreshold.GetAsInt64()
if enableQueueProtection && nqInQueueThreshold >= 0 {
// >= 0 means enable queue length protection
sum := func(ri metricsinfo.ReadInfoInQueue) int64 {
return ri.UnsolvedQueue + ri.ReadyQueue + ri.ReceiveChan + ri.ExecuteChan
}
for _, metric := range q.queryNodeMetrics {
// We think of the NQ of query request as 1.
// search use same queue length counter with query
if sum(metric.SearchQueue) >= nqInQueueThreshold {
limitCollectionSet.Insert(metric.Effect.CollectionIDs...)
}
}
}
metricMap := make(map[string]float64) // label metric
collectionMetricMap := make(map[string]map[string]map[string]float64) // sub label metric, label -> db -> collection -> value
// getReadRates get rate information of collections and databases from proxy metrics
func (q *QuotaCenter) getReadRates() (map[string]float64, map[string]map[string]map[string]float64) {
// label metric
metricMap := make(map[string]float64)
// sub label metric, label -> db -> collection -> value
collectionMetricMap := make(map[string]map[string]map[string]float64)
for _, metric := range q.proxyMetrics {
for _, rm := range metric.Rms {
if !ratelimitutil.IsSubLabel(rm.Label) {
@ -729,8 +714,20 @@ func (q *QuotaCenter) calculateReadRates() error {
databaseMetric[collection] += rm.Rate
}
}
return metricMap, collectionMetricMap
}
func (q *QuotaCenter) getLimitedDBAndCollections(metricMap map[string]float64,
collectionMetricMap map[string]map[string]map[string]float64,
) (bool, *typeutil.Set[string], *typeutil.Set[string]) {
limitDBNameSet := typeutil.NewSet[string]()
limitCollectionNameSet := typeutil.NewSet[string]()
clusterLimit := false
formatCollctionRateKey := func(dbName, collectionName string) string {
return fmt.Sprintf("%s.%s", dbName, collectionName)
}
// read result
enableResultProtection := Params.QuotaConfig.ResultProtectionEnabled.GetAsBool()
if enableResultProtection {
maxRate := Params.QuotaConfig.MaxReadResultRate.GetAsFloat()
@ -765,28 +762,12 @@ func (q *QuotaCenter) calculateReadRates() error {
}
}
}
return clusterLimit, &limitDBNameSet, &limitCollectionNameSet
}
dbIDs := make(map[int64]string, q.dbs.Len())
collectionIDs := make(map[int64]string, q.collections.Len())
q.dbs.Range(func(name string, id int64) bool {
dbIDs[id] = name
return true
})
q.collections.Range(func(name string, id int64) bool {
_, collectionName := SplitCollectionKey(name)
collectionIDs[id] = collectionName
return true
})
coolOffSpeed := Params.QuotaConfig.CoolOffSpeed.GetAsFloat()
if clusterLimit {
realTimeClusterSearchRate := metricMap[internalpb.RateType_DQLSearch.String()]
realTimeClusterQueryRate := metricMap[internalpb.RateType_DQLQuery.String()]
q.coolOffReading(realTimeClusterSearchRate, realTimeClusterQueryRate, coolOffSpeed, q.rateLimiter.GetRootLimiters(), log)
}
var updateLimitErr error
func (q *QuotaCenter) coolOffDatabaseReading(deniedDatabaseIDs map[int64]struct{}, limitDBNameSet *typeutil.Set[string],
collectionMetricMap map[string]map[string]map[string]float64, log *log.MLogger,
) error {
if limitDBNameSet.Len() > 0 {
databaseSearchRate := make(map[string]float64)
databaseQueryRate := make(map[string]float64)
@ -806,18 +787,24 @@ func (q *QuotaCenter) calculateReadRates() error {
}
}
coolOffSpeed := Params.QuotaConfig.CoolOffSpeed.GetAsFloat()
limitDBNameSet.Range(func(name string) bool {
dbID, ok := q.dbs.Get(name)
if !ok {
log.Warn("db not found", zap.String("dbName", name))
updateLimitErr = fmt.Errorf("db not found: %s", name)
return false
return true
}
// skip this database because it has been denied access for reading
_, ok = deniedDatabaseIDs[dbID]
if ok {
return true
}
dbLimiter := q.rateLimiter.GetDatabaseLimiters(dbID)
if dbLimiter == nil {
log.Warn("database limiter not found", zap.Int64("dbID", dbID))
updateLimitErr = fmt.Errorf("database limiter not found")
return false
return true
}
realTimeSearchRate := databaseSearchRate[name]
@ -825,10 +812,32 @@ func (q *QuotaCenter) calculateReadRates() error {
q.coolOffReading(realTimeSearchRate, realTimeQueryRate, coolOffSpeed, dbLimiter, log)
return true
})
if updateLimitErr != nil {
return updateLimitErr
}
}
return nil
}
func (q *QuotaCenter) coolOffCollectionReading(deniedDatabaseIDs map[int64]struct{}, limitCollectionSet *typeutil.UniqueSet, limitCollectionNameSet *typeutil.Set[string],
collectionMetricMap map[string]map[string]map[string]float64, log *log.MLogger,
) error {
var updateLimitErr error
coolOffSpeed := Params.QuotaConfig.CoolOffSpeed.GetAsFloat()
splitCollctionRateKey := func(key string) (string, string) {
parts := strings.Split(key, ".")
return parts[0], parts[1]
}
dbIDs := make(map[int64]string, q.dbs.Len())
collectionIDs := make(map[int64]string, q.collections.Len())
q.dbs.Range(func(name string, id int64) bool {
dbIDs[id] = name
return true
})
q.collections.Range(func(name string, id int64) bool {
_, collectionName := SplitCollectionKey(name)
collectionIDs[id] = collectionName
return true
})
limitCollectionNameSet.Range(func(name string) bool {
dbName, collectionName := splitCollctionRateKey(name)
@ -836,13 +845,13 @@ func (q *QuotaCenter) calculateReadRates() error {
if !ok {
log.Warn("db not found", zap.String("dbName", dbName))
updateLimitErr = fmt.Errorf("db not found: %s", dbName)
return false
return true
}
collectionID, ok := q.collections.Get(FormatCollectionKey(dbID, collectionName))
if !ok {
log.Warn("collection not found", zap.String("collectionName", name))
updateLimitErr = fmt.Errorf("collection not found: %s", name)
return false
return true
}
limitCollectionSet.Insert(collectionID)
return true
@ -868,6 +877,12 @@ func (q *QuotaCenter) calculateReadRates() error {
if !ok {
return fmt.Errorf("db ID not found of collection ID: %d", collection)
}
// skip this database because it has been denied access for reading
_, ok = deniedDatabaseIDs[dbID]
if ok {
continue
}
collectionLimiter := q.rateLimiter.GetCollectionLimiters(dbID, collection)
if collectionLimiter == nil {
return fmt.Errorf("collection limiter not found: %d", collection)
@ -897,6 +912,73 @@ func (q *QuotaCenter) calculateReadRates() error {
if updateLimitErr = coolOffCollectionID(limitCollectionSet.Collect()...); updateLimitErr != nil {
return updateLimitErr
}
return nil
}
// calculateReadRates calculates and sets dql rates.
func (q *QuotaCenter) calculateReadRates() error {
	log := log.Ctx(context.Background()).WithRateGroup("rootcoord.QuotaCenter", 1.0, 60.0)
	// Global kill switch: deny all reads cluster-wide and skip every other check.
	if Params.QuotaConfig.ForceDenyReading.GetAsBool() {
		q.forceDenyReading(commonpb.ErrorCode_ForceDeny, true, []int64{}, log)
		return nil
	}

	// Per-database kill switch driven by the database property checked in
	// getDenyReadingDBs; only the listed databases are force-denied.
	deniedDatabaseIDs := q.getDenyReadingDBs()
	if len(deniedDatabaseIDs) != 0 {
		q.forceDenyReading(commonpb.ErrorCode_ForceDeny, false, maps.Keys(deniedDatabaseIDs), log)
	}

	queueLatencyThreshold := Params.QuotaConfig.QueueLatencyThreshold.GetAsDuration(time.Second)
	limitCollectionSet := typeutil.NewUniqueSet()
	// enableQueueProtection && queueLatencyThreshold >= 0 means enable queue latency protection
	if queueLatencyThreshold >= 0 {
		for _, metric := range q.queryNodeMetrics {
			searchLatency := metric.SearchQueue.AvgQueueDuration
			queryLatency := metric.QueryQueue.AvgQueueDuration
			// A query node whose avg queue latency breaches the threshold flags
			// every collection it serves for cool-off.
			if searchLatency >= queueLatencyThreshold || queryLatency >= queueLatencyThreshold {
				limitCollectionSet.Insert(metric.Effect.CollectionIDs...)
			}
		}
	}

	// queue length
	enableQueueProtection := Params.QuotaConfig.QueueProtectionEnabled.GetAsBool()
	nqInQueueThreshold := Params.QuotaConfig.NQInQueueThreshold.GetAsInt64()
	if enableQueueProtection && nqInQueueThreshold >= 0 {
		// >= 0 means enable queue length protection
		sum := func(ri metricsinfo.ReadInfoInQueue) int64 {
			return ri.UnsolvedQueue + ri.ReadyQueue + ri.ReceiveChan + ri.ExecuteChan
		}
		for _, metric := range q.queryNodeMetrics {
			// We think of the NQ of query request as 1.
			// search use same queue length counter with query
			if sum(metric.SearchQueue) >= nqInQueueThreshold {
				limitCollectionSet.Insert(metric.Effect.CollectionIDs...)
			}
		}
	}

	// Pull real-time read rates from proxy metrics, then determine which
	// scopes (cluster / database / collection) are over their limits.
	metricMap, collectionMetricMap := q.getReadRates()
	clusterLimit, limitDBNameSet, limitCollectionNameSet := q.getLimitedDBAndCollections(metricMap, collectionMetricMap)

	coolOffSpeed := Params.QuotaConfig.CoolOffSpeed.GetAsFloat()
	if clusterLimit {
		// Cool off the root limiter using the cluster-wide observed rates.
		realTimeClusterSearchRate := metricMap[internalpb.RateType_DQLSearch.String()]
		realTimeClusterQueryRate := metricMap[internalpb.RateType_DQLQuery.String()]
		q.coolOffReading(realTimeClusterSearchRate, realTimeClusterQueryRate, coolOffSpeed, q.rateLimiter.GetRootLimiters(), log)
	}

	// Database- and collection-level cool-off; databases that were
	// force-denied above are skipped inside these helpers.
	if updateLimitErr := q.coolOffDatabaseReading(deniedDatabaseIDs, limitDBNameSet, collectionMetricMap,
		log); updateLimitErr != nil {
		return updateLimitErr
	}

	if updateLimitErr := q.coolOffCollectionReading(deniedDatabaseIDs, &limitCollectionSet, limitCollectionNameSet,
		collectionMetricMap, log); updateLimitErr != nil {
		return updateLimitErr
	}

	return nil
}

View File

@ -563,17 +563,27 @@ func TestQuotaCenter(t *testing.T) {
ID: 0,
Name: "default",
},
{
ID: 1,
Name: "db1",
},
}, nil).Maybe()
meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrDatabaseNotFound).Maybe()
quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta)
quotaCenter.clearMetrics()
quotaCenter.collectionIDToDBID = collectionIDToDBID
quotaCenter.readableCollections = map[int64]map[int64][]int64{
0: collectionIDToPartitionIDs,
0: {1: {}, 2: {}, 3: {}},
1: {4: {}},
}
quotaCenter.dbs.Insert("default", 0)
quotaCenter.dbs.Insert("db1", 1)
quotaCenter.collections.Insert("0.col1", 1)
quotaCenter.collections.Insert("0.col2", 2)
quotaCenter.collections.Insert("0.col3", 3)
quotaCenter.collections.Insert("1.col4", 4)
colSubLabel := ratelimitutil.GetCollectionSubLabel("default", "col1")
quotaCenter.proxyMetrics = map[UniqueID]*metricsinfo.ProxyQuotaMetrics{
1: {Rms: []metricsinfo.RateMetric{
@ -652,6 +662,41 @@ func TestQuotaCenter(t *testing.T) {
err = quotaCenter.calculateReadRates()
assert.NoError(t, err)
checkLimiter()
meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything).Unset()
meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything).
RunAndReturn(func(ctx context.Context, i int64, u uint64) (*model.Database, error) {
if i == 1 {
return &model.Database{
ID: 1,
Name: "db1",
Properties: []*commonpb.KeyValuePair{
{
Key: common.DatabaseForceDenyReadingKey,
Value: "true",
},
},
}, nil
}
return nil, errors.New("mock error")
}).Maybe()
quotaCenter.resetAllCurrentRates()
err = quotaCenter.calculateReadRates()
assert.NoError(t, err)
assert.NoError(t, err)
rln := quotaCenter.rateLimiter.GetDatabaseLimiters(0)
limiters := rln.GetLimiters()
a, _ := limiters.Get(internalpb.RateType_DQLSearch)
assert.NotEqual(t, Limit(0), a.Limit())
b, _ := limiters.Get(internalpb.RateType_DQLQuery)
assert.NotEqual(t, Limit(0), b.Limit())
rln = quotaCenter.rateLimiter.GetDatabaseLimiters(1)
limiters = rln.GetLimiters()
a, _ = limiters.Get(internalpb.RateType_DQLSearch)
assert.Equal(t, Limit(0), a.Limit())
b, _ = limiters.Get(internalpb.RateType_DQLQuery)
assert.Equal(t, Limit(0), b.Limit())
})
t.Run("test calculateWriteRates", func(t *testing.T) {
@ -1544,6 +1589,8 @@ func TestCalculateReadRates(t *testing.T) {
t.Run("cool off db", func(t *testing.T) {
qc := mocks.NewMockQueryCoordClient(t)
meta := mockrootcoord.NewIMetaTable(t)
meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrDatabaseNotFound).Maybe()
pcm := proxyutil.NewMockProxyClientManager(t)
dc := mocks.NewMockDataCoordClient(t)
core, _ := NewCore(ctx, nil)

View File

@ -130,7 +130,7 @@ func TestRateLimiterNodeGetQuotaExceededError(t *testing.T) {
err := limitNode.GetQuotaExceededError(internalpb.RateType_DMLInsert)
assert.True(t, errors.Is(err, merr.ErrServiceQuotaExceeded))
// reference: ratelimitutil.GetQuotaErrorString(errCode)
assert.True(t, strings.Contains(err.Error(), "deactivated"))
assert.True(t, strings.Contains(err.Error(), "disabled"))
})
t.Run("read", func(t *testing.T) {
@ -139,7 +139,7 @@ func TestRateLimiterNodeGetQuotaExceededError(t *testing.T) {
err := limitNode.GetQuotaExceededError(internalpb.RateType_DQLSearch)
assert.True(t, errors.Is(err, merr.ErrServiceQuotaExceeded))
// reference: ratelimitutil.GetQuotaErrorString(errCode)
assert.True(t, strings.Contains(err.Error(), "deactivated"))
assert.True(t, strings.Contains(err.Error(), "disabled"))
})
t.Run("unknown", func(t *testing.T) {

View File

@ -159,6 +159,7 @@ const (
DatabaseDiskQuotaKey = "database.diskQuota.mb"
DatabaseMaxCollectionsKey = "database.max.collections"
DatabaseForceDenyWritingKey = "database.force.deny.writing"
DatabaseForceDenyReadingKey = "database.force.deny.reading"
// collection level load properties
CollectionReplicaNumber = "collection.replica.number"

View File

@ -19,7 +19,7 @@ package ratelimitutil
import "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
var QuotaErrorString = map[commonpb.ErrorCode]string{
commonpb.ErrorCode_ForceDeny: "the writing has been deactivated by the administrator",
commonpb.ErrorCode_ForceDeny: "access has been disabled by the administrator",
commonpb.ErrorCode_MemoryQuotaExhausted: "memory quota exceeded, please allocate more resources",
commonpb.ErrorCode_DiskQuotaExhausted: "disk quota exceeded, please allocate more resources",
commonpb.ErrorCode_TimeTickLongDelay: "time tick long delay",

View File

@ -15,7 +15,7 @@ func TestGetQuotaErrorString(t *testing.T) {
{
name: "Test ErrorCode_ForceDeny",
args: commonpb.ErrorCode_ForceDeny,
want: "the writing has been deactivated by the administrator",
want: "access has been disabled by the administrator",
},
{
name: "Test ErrorCode_MemoryQuotaExhausted",