enhance: set database properties to restrict read access (#35745)

issue: #35744

Signed-off-by: jaime <yun.zhang@zilliz.com>
This commit is contained in:
jaime 2024-08-29 13:17:01 +08:00 committed by GitHub
parent b51b4a2838
commit b0ac04d104
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 231 additions and 101 deletions

View File

@ -407,7 +407,7 @@ func (q *QuotaCenter) collectMetrics() error {
coll, getErr := q.meta.GetCollectionByIDWithMaxTs(context.TODO(), collectionID) coll, getErr := q.meta.GetCollectionByIDWithMaxTs(context.TODO(), collectionID)
if getErr != nil { if getErr != nil {
// skip limit check if the collection meta has been removed from rootcoord meta // skip limit check if the collection meta has been removed from rootcoord meta
return false return true
} }
collIDToPartIDs, ok := q.readableCollections[coll.DBID] collIDToPartIDs, ok := q.readableCollections[coll.DBID]
if !ok { if !ok {
@ -463,7 +463,7 @@ func (q *QuotaCenter) collectMetrics() error {
coll, getErr := q.meta.GetCollectionByIDWithMaxTs(context.TODO(), collectionID) coll, getErr := q.meta.GetCollectionByIDWithMaxTs(context.TODO(), collectionID)
if getErr != nil { if getErr != nil {
// skip limit check if the collection meta has been removed from rootcoord meta // skip limit check if the collection meta has been removed from rootcoord meta
return false return true
} }
collIDToPartIDs, ok := q.writableCollections[coll.DBID] collIDToPartIDs, ok := q.writableCollections[coll.DBID]
@ -562,7 +562,7 @@ func (q *QuotaCenter) forceDenyWriting(errorCode commonpb.ErrorCode, cluster boo
dbLimiters := q.rateLimiter.GetDatabaseLimiters(dbID) dbLimiters := q.rateLimiter.GetDatabaseLimiters(dbID)
if dbLimiters == nil { if dbLimiters == nil {
log.Warn("db limiter not found of db ID", zap.Int64("dbID", dbID)) log.Warn("db limiter not found of db ID", zap.Int64("dbID", dbID))
return fmt.Errorf("db limiter not found of db ID: %d", dbID) continue
} }
updateLimiter(dbLimiters, GetEarliestLimiter(), internalpb.RateScope_Database, dml) updateLimiter(dbLimiters, GetEarliestLimiter(), internalpb.RateScope_Database, dml)
dbLimiters.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToWrite, errorCode) dbLimiters.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToWrite, errorCode)
@ -578,7 +578,7 @@ func (q *QuotaCenter) forceDenyWriting(errorCode commonpb.ErrorCode, cluster boo
log.Warn("collection limiter not found of collection ID", log.Warn("collection limiter not found of collection ID",
zap.Int64("dbID", dbID), zap.Int64("dbID", dbID),
zap.Int64("collectionID", collectionID)) zap.Int64("collectionID", collectionID))
return fmt.Errorf("collection limiter not found of collection ID: %d", collectionID) continue
} }
updateLimiter(collectionLimiter, GetEarliestLimiter(), internalpb.RateScope_Collection, dml) updateLimiter(collectionLimiter, GetEarliestLimiter(), internalpb.RateScope_Collection, dml)
collectionLimiter.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToWrite, errorCode) collectionLimiter.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToWrite, errorCode)
@ -596,7 +596,7 @@ func (q *QuotaCenter) forceDenyWriting(errorCode commonpb.ErrorCode, cluster boo
zap.Int64("dbID", dbID), zap.Int64("dbID", dbID),
zap.Int64("collectionID", collectionID), zap.Int64("collectionID", collectionID),
zap.Int64("partitionID", partitionID)) zap.Int64("partitionID", partitionID))
return fmt.Errorf("partition limiter not found of partition ID: %d", partitionID) continue
} }
updateLimiter(partitionLimiter, GetEarliestLimiter(), internalpb.RateScope_Partition, dml) updateLimiter(partitionLimiter, GetEarliestLimiter(), internalpb.RateScope_Partition, dml)
partitionLimiter.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToWrite, errorCode) partitionLimiter.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToWrite, errorCode)
@ -604,7 +604,7 @@ func (q *QuotaCenter) forceDenyWriting(errorCode commonpb.ErrorCode, cluster boo
} }
if cluster || len(dbIDs) > 0 || len(collectionIDs) > 0 || len(col2partitionIDs) > 0 { if cluster || len(dbIDs) > 0 || len(collectionIDs) > 0 || len(col2partitionIDs) > 0 {
log.RatedWarn(10, "QuotaCenter force to deny writing", log.RatedWarn(30, "QuotaCenter force to deny writing",
zap.Bool("cluster", cluster), zap.Bool("cluster", cluster),
zap.Int64s("dbIDs", dbIDs), zap.Int64s("dbIDs", dbIDs),
zap.Int64s("collectionIDs", collectionIDs), zap.Int64s("collectionIDs", collectionIDs),
@ -616,7 +616,8 @@ func (q *QuotaCenter) forceDenyWriting(errorCode commonpb.ErrorCode, cluster boo
} }
// forceDenyReading sets dql rates to 0 to reject all dql requests. // forceDenyReading sets dql rates to 0 to reject all dql requests.
func (q *QuotaCenter) forceDenyReading(errorCode commonpb.ErrorCode) { func (q *QuotaCenter) forceDenyReading(errorCode commonpb.ErrorCode, cluster bool, dbIDs []int64, mlog *log.MLogger) {
if cluster {
var collectionIDs []int64 var collectionIDs []int64
for dbID, collectionIDToPartIDs := range q.readableCollections { for dbID, collectionIDToPartIDs := range q.readableCollections {
for collectionID := range collectionIDToPartIDs { for collectionID := range collectionIDToPartIDs {
@ -627,11 +628,27 @@ func (q *QuotaCenter) forceDenyReading(errorCode commonpb.ErrorCode) {
} }
} }
log.Warn("QuotaCenter force to deny reading", mlog.RatedWarn(10, "QuotaCenter force to deny reading",
zap.Int64s("collectionIDs", collectionIDs), zap.Int64s("collectionIDs", collectionIDs),
zap.String("reason", errorCode.String())) zap.String("reason", errorCode.String()))
} }
if len(dbIDs) > 0 {
for _, dbID := range dbIDs {
dbLimiters := q.rateLimiter.GetDatabaseLimiters(dbID)
if dbLimiters == nil {
log.Warn("db limiter not found of db ID", zap.Int64("dbID", dbID))
continue
}
updateLimiter(dbLimiters, GetEarliestLimiter(), internalpb.RateScope_Database, dql)
dbLimiters.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToRead, errorCode)
mlog.RatedWarn(10, "QuotaCenter force to deny reading",
zap.Int64s("dbIDs", dbIDs),
zap.String("reason", errorCode.String()))
}
}
}
// getRealTimeRate return real time rate in Proxy. // getRealTimeRate return real time rate in Proxy.
func (q *QuotaCenter) getRealTimeRate(label string) float64 { func (q *QuotaCenter) getRealTimeRate(label string) float64 {
var rate float64 var rate float64
@ -654,58 +671,26 @@ func (q *QuotaCenter) guaranteeMinRate(minRate float64, rt internalpb.RateType,
} }
} }
// calculateReadRates calculates and sets dql rates. func (q *QuotaCenter) getDenyReadingDBs() map[int64]struct{} {
func (q *QuotaCenter) calculateReadRates() error { dbIDs := make(map[int64]struct{})
log := log.Ctx(context.Background()).WithRateGroup("rootcoord.QuotaCenter", 1.0, 60.0) for _, dbID := range lo.Uniq(q.collectionIDToDBID.Values()) {
if Params.QuotaConfig.ForceDenyReading.GetAsBool() { if db, err := q.meta.GetDatabaseByID(q.ctx, dbID, typeutil.MaxTimestamp); err == nil {
q.forceDenyReading(commonpb.ErrorCode_ForceDeny) if v := db.GetProperty(common.DatabaseForceDenyReadingKey); v != "" {
return nil if dbForceDenyReadingEnabled, _ := strconv.ParseBool(v); dbForceDenyReadingEnabled {
dbIDs[dbID] = struct{}{}
}
}
}
}
return dbIDs
} }
limitCollectionSet := typeutil.NewUniqueSet() // getReadRates get rate information of collections and databases from proxy metrics
limitDBNameSet := typeutil.NewSet[string]() func (q *QuotaCenter) getReadRates() (map[string]float64, map[string]map[string]map[string]float64) {
limitCollectionNameSet := typeutil.NewSet[string]() // label metric
clusterLimit := false metricMap := make(map[string]float64)
// sub label metric, label -> db -> collection -> value
formatCollctionRateKey := func(dbName, collectionName string) string { collectionMetricMap := make(map[string]map[string]map[string]float64)
return fmt.Sprintf("%s.%s", dbName, collectionName)
}
splitCollctionRateKey := func(key string) (string, string) {
parts := strings.Split(key, ".")
return parts[0], parts[1]
}
// query latency
queueLatencyThreshold := Params.QuotaConfig.QueueLatencyThreshold.GetAsDuration(time.Second)
// enableQueueProtection && queueLatencyThreshold >= 0 means enable queue latency protection
if queueLatencyThreshold >= 0 {
for _, metric := range q.queryNodeMetrics {
searchLatency := metric.SearchQueue.AvgQueueDuration
queryLatency := metric.QueryQueue.AvgQueueDuration
if searchLatency >= queueLatencyThreshold || queryLatency >= queueLatencyThreshold {
limitCollectionSet.Insert(metric.Effect.CollectionIDs...)
}
}
}
// queue length
enableQueueProtection := Params.QuotaConfig.QueueProtectionEnabled.GetAsBool()
nqInQueueThreshold := Params.QuotaConfig.NQInQueueThreshold.GetAsInt64()
if enableQueueProtection && nqInQueueThreshold >= 0 {
// >= 0 means enable queue length protection
sum := func(ri metricsinfo.ReadInfoInQueue) int64 {
return ri.UnsolvedQueue + ri.ReadyQueue + ri.ReceiveChan + ri.ExecuteChan
}
for _, metric := range q.queryNodeMetrics {
// We think of the NQ of query request as 1.
// search use same queue length counter with query
if sum(metric.SearchQueue) >= nqInQueueThreshold {
limitCollectionSet.Insert(metric.Effect.CollectionIDs...)
}
}
}
metricMap := make(map[string]float64) // label metric
collectionMetricMap := make(map[string]map[string]map[string]float64) // sub label metric, label -> db -> collection -> value
for _, metric := range q.proxyMetrics { for _, metric := range q.proxyMetrics {
for _, rm := range metric.Rms { for _, rm := range metric.Rms {
if !ratelimitutil.IsSubLabel(rm.Label) { if !ratelimitutil.IsSubLabel(rm.Label) {
@ -729,8 +714,20 @@ func (q *QuotaCenter) calculateReadRates() error {
databaseMetric[collection] += rm.Rate databaseMetric[collection] += rm.Rate
} }
} }
return metricMap, collectionMetricMap
}
func (q *QuotaCenter) getLimitedDBAndCollections(metricMap map[string]float64,
collectionMetricMap map[string]map[string]map[string]float64,
) (bool, *typeutil.Set[string], *typeutil.Set[string]) {
limitDBNameSet := typeutil.NewSet[string]()
limitCollectionNameSet := typeutil.NewSet[string]()
clusterLimit := false
formatCollctionRateKey := func(dbName, collectionName string) string {
return fmt.Sprintf("%s.%s", dbName, collectionName)
}
// read result
enableResultProtection := Params.QuotaConfig.ResultProtectionEnabled.GetAsBool() enableResultProtection := Params.QuotaConfig.ResultProtectionEnabled.GetAsBool()
if enableResultProtection { if enableResultProtection {
maxRate := Params.QuotaConfig.MaxReadResultRate.GetAsFloat() maxRate := Params.QuotaConfig.MaxReadResultRate.GetAsFloat()
@ -765,28 +762,12 @@ func (q *QuotaCenter) calculateReadRates() error {
} }
} }
} }
return clusterLimit, &limitDBNameSet, &limitCollectionNameSet
dbIDs := make(map[int64]string, q.dbs.Len())
collectionIDs := make(map[int64]string, q.collections.Len())
q.dbs.Range(func(name string, id int64) bool {
dbIDs[id] = name
return true
})
q.collections.Range(func(name string, id int64) bool {
_, collectionName := SplitCollectionKey(name)
collectionIDs[id] = collectionName
return true
})
coolOffSpeed := Params.QuotaConfig.CoolOffSpeed.GetAsFloat()
if clusterLimit {
realTimeClusterSearchRate := metricMap[internalpb.RateType_DQLSearch.String()]
realTimeClusterQueryRate := metricMap[internalpb.RateType_DQLQuery.String()]
q.coolOffReading(realTimeClusterSearchRate, realTimeClusterQueryRate, coolOffSpeed, q.rateLimiter.GetRootLimiters(), log)
} }
var updateLimitErr error func (q *QuotaCenter) coolOffDatabaseReading(deniedDatabaseIDs map[int64]struct{}, limitDBNameSet *typeutil.Set[string],
collectionMetricMap map[string]map[string]map[string]float64, log *log.MLogger,
) error {
if limitDBNameSet.Len() > 0 { if limitDBNameSet.Len() > 0 {
databaseSearchRate := make(map[string]float64) databaseSearchRate := make(map[string]float64)
databaseQueryRate := make(map[string]float64) databaseQueryRate := make(map[string]float64)
@ -806,18 +787,24 @@ func (q *QuotaCenter) calculateReadRates() error {
} }
} }
coolOffSpeed := Params.QuotaConfig.CoolOffSpeed.GetAsFloat()
limitDBNameSet.Range(func(name string) bool { limitDBNameSet.Range(func(name string) bool {
dbID, ok := q.dbs.Get(name) dbID, ok := q.dbs.Get(name)
if !ok { if !ok {
log.Warn("db not found", zap.String("dbName", name)) log.Warn("db not found", zap.String("dbName", name))
updateLimitErr = fmt.Errorf("db not found: %s", name) return true
return false
} }
// skip this database because it has been denied access for reading
_, ok = deniedDatabaseIDs[dbID]
if ok {
return true
}
dbLimiter := q.rateLimiter.GetDatabaseLimiters(dbID) dbLimiter := q.rateLimiter.GetDatabaseLimiters(dbID)
if dbLimiter == nil { if dbLimiter == nil {
log.Warn("database limiter not found", zap.Int64("dbID", dbID)) log.Warn("database limiter not found", zap.Int64("dbID", dbID))
updateLimitErr = fmt.Errorf("database limiter not found") return true
return false
} }
realTimeSearchRate := databaseSearchRate[name] realTimeSearchRate := databaseSearchRate[name]
@ -825,24 +812,46 @@ func (q *QuotaCenter) calculateReadRates() error {
q.coolOffReading(realTimeSearchRate, realTimeQueryRate, coolOffSpeed, dbLimiter, log) q.coolOffReading(realTimeSearchRate, realTimeQueryRate, coolOffSpeed, dbLimiter, log)
return true return true
}) })
if updateLimitErr != nil {
return updateLimitErr
} }
return nil
} }
func (q *QuotaCenter) coolOffCollectionReading(deniedDatabaseIDs map[int64]struct{}, limitCollectionSet *typeutil.UniqueSet, limitCollectionNameSet *typeutil.Set[string],
collectionMetricMap map[string]map[string]map[string]float64, log *log.MLogger,
) error {
var updateLimitErr error
coolOffSpeed := Params.QuotaConfig.CoolOffSpeed.GetAsFloat()
splitCollctionRateKey := func(key string) (string, string) {
parts := strings.Split(key, ".")
return parts[0], parts[1]
}
dbIDs := make(map[int64]string, q.dbs.Len())
collectionIDs := make(map[int64]string, q.collections.Len())
q.dbs.Range(func(name string, id int64) bool {
dbIDs[id] = name
return true
})
q.collections.Range(func(name string, id int64) bool {
_, collectionName := SplitCollectionKey(name)
collectionIDs[id] = collectionName
return true
})
limitCollectionNameSet.Range(func(name string) bool { limitCollectionNameSet.Range(func(name string) bool {
dbName, collectionName := splitCollctionRateKey(name) dbName, collectionName := splitCollctionRateKey(name)
dbID, ok := q.dbs.Get(dbName) dbID, ok := q.dbs.Get(dbName)
if !ok { if !ok {
log.Warn("db not found", zap.String("dbName", dbName)) log.Warn("db not found", zap.String("dbName", dbName))
updateLimitErr = fmt.Errorf("db not found: %s", dbName) updateLimitErr = fmt.Errorf("db not found: %s", dbName)
return false return true
} }
collectionID, ok := q.collections.Get(FormatCollectionKey(dbID, collectionName)) collectionID, ok := q.collections.Get(FormatCollectionKey(dbID, collectionName))
if !ok { if !ok {
log.Warn("collection not found", zap.String("collectionName", name)) log.Warn("collection not found", zap.String("collectionName", name))
updateLimitErr = fmt.Errorf("collection not found: %s", name) updateLimitErr = fmt.Errorf("collection not found: %s", name)
return false return true
} }
limitCollectionSet.Insert(collectionID) limitCollectionSet.Insert(collectionID)
return true return true
@ -868,6 +877,12 @@ func (q *QuotaCenter) calculateReadRates() error {
if !ok { if !ok {
return fmt.Errorf("db ID not found of collection ID: %d", collection) return fmt.Errorf("db ID not found of collection ID: %d", collection)
} }
// skip this database because it has been denied access for reading
_, ok = deniedDatabaseIDs[dbID]
if ok {
continue
}
collectionLimiter := q.rateLimiter.GetCollectionLimiters(dbID, collection) collectionLimiter := q.rateLimiter.GetCollectionLimiters(dbID, collection)
if collectionLimiter == nil { if collectionLimiter == nil {
return fmt.Errorf("collection limiter not found: %d", collection) return fmt.Errorf("collection limiter not found: %d", collection)
@ -897,6 +912,73 @@ func (q *QuotaCenter) calculateReadRates() error {
if updateLimitErr = coolOffCollectionID(limitCollectionSet.Collect()...); updateLimitErr != nil { if updateLimitErr = coolOffCollectionID(limitCollectionSet.Collect()...); updateLimitErr != nil {
return updateLimitErr return updateLimitErr
} }
return nil
}
// calculateReadRates calculates and sets dql rates.
func (q *QuotaCenter) calculateReadRates() error {
log := log.Ctx(context.Background()).WithRateGroup("rootcoord.QuotaCenter", 1.0, 60.0)
if Params.QuotaConfig.ForceDenyReading.GetAsBool() {
q.forceDenyReading(commonpb.ErrorCode_ForceDeny, true, []int64{}, log)
return nil
}
deniedDatabaseIDs := q.getDenyReadingDBs()
if len(deniedDatabaseIDs) != 0 {
q.forceDenyReading(commonpb.ErrorCode_ForceDeny, false, maps.Keys(deniedDatabaseIDs), log)
}
queueLatencyThreshold := Params.QuotaConfig.QueueLatencyThreshold.GetAsDuration(time.Second)
limitCollectionSet := typeutil.NewUniqueSet()
// enableQueueProtection && queueLatencyThreshold >= 0 means enable queue latency protection
if queueLatencyThreshold >= 0 {
for _, metric := range q.queryNodeMetrics {
searchLatency := metric.SearchQueue.AvgQueueDuration
queryLatency := metric.QueryQueue.AvgQueueDuration
if searchLatency >= queueLatencyThreshold || queryLatency >= queueLatencyThreshold {
limitCollectionSet.Insert(metric.Effect.CollectionIDs...)
}
}
}
// queue length
enableQueueProtection := Params.QuotaConfig.QueueProtectionEnabled.GetAsBool()
nqInQueueThreshold := Params.QuotaConfig.NQInQueueThreshold.GetAsInt64()
if enableQueueProtection && nqInQueueThreshold >= 0 {
// >= 0 means enable queue length protection
sum := func(ri metricsinfo.ReadInfoInQueue) int64 {
return ri.UnsolvedQueue + ri.ReadyQueue + ri.ReceiveChan + ri.ExecuteChan
}
for _, metric := range q.queryNodeMetrics {
// We think of the NQ of query request as 1.
// search use same queue length counter with query
if sum(metric.SearchQueue) >= nqInQueueThreshold {
limitCollectionSet.Insert(metric.Effect.CollectionIDs...)
}
}
}
metricMap, collectionMetricMap := q.getReadRates()
clusterLimit, limitDBNameSet, limitCollectionNameSet := q.getLimitedDBAndCollections(metricMap, collectionMetricMap)
coolOffSpeed := Params.QuotaConfig.CoolOffSpeed.GetAsFloat()
if clusterLimit {
realTimeClusterSearchRate := metricMap[internalpb.RateType_DQLSearch.String()]
realTimeClusterQueryRate := metricMap[internalpb.RateType_DQLQuery.String()]
q.coolOffReading(realTimeClusterSearchRate, realTimeClusterQueryRate, coolOffSpeed, q.rateLimiter.GetRootLimiters(), log)
}
if updateLimitErr := q.coolOffDatabaseReading(deniedDatabaseIDs, limitDBNameSet, collectionMetricMap,
log); updateLimitErr != nil {
return updateLimitErr
}
if updateLimitErr := q.coolOffCollectionReading(deniedDatabaseIDs, &limitCollectionSet, limitCollectionNameSet,
collectionMetricMap, log); updateLimitErr != nil {
return updateLimitErr
}
return nil return nil
} }

View File

@ -563,17 +563,27 @@ func TestQuotaCenter(t *testing.T) {
ID: 0, ID: 0,
Name: "default", Name: "default",
}, },
{
ID: 1,
Name: "db1",
},
}, nil).Maybe() }, nil).Maybe()
meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrDatabaseNotFound).Maybe()
quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta)
quotaCenter.clearMetrics() quotaCenter.clearMetrics()
quotaCenter.collectionIDToDBID = collectionIDToDBID quotaCenter.collectionIDToDBID = collectionIDToDBID
quotaCenter.readableCollections = map[int64]map[int64][]int64{ quotaCenter.readableCollections = map[int64]map[int64][]int64{
0: collectionIDToPartitionIDs, 0: {1: {}, 2: {}, 3: {}},
1: {4: {}},
} }
quotaCenter.dbs.Insert("default", 0) quotaCenter.dbs.Insert("default", 0)
quotaCenter.dbs.Insert("db1", 1)
quotaCenter.collections.Insert("0.col1", 1) quotaCenter.collections.Insert("0.col1", 1)
quotaCenter.collections.Insert("0.col2", 2) quotaCenter.collections.Insert("0.col2", 2)
quotaCenter.collections.Insert("0.col3", 3) quotaCenter.collections.Insert("0.col3", 3)
quotaCenter.collections.Insert("1.col4", 4)
colSubLabel := ratelimitutil.GetCollectionSubLabel("default", "col1") colSubLabel := ratelimitutil.GetCollectionSubLabel("default", "col1")
quotaCenter.proxyMetrics = map[UniqueID]*metricsinfo.ProxyQuotaMetrics{ quotaCenter.proxyMetrics = map[UniqueID]*metricsinfo.ProxyQuotaMetrics{
1: {Rms: []metricsinfo.RateMetric{ 1: {Rms: []metricsinfo.RateMetric{
@ -652,6 +662,41 @@ func TestQuotaCenter(t *testing.T) {
err = quotaCenter.calculateReadRates() err = quotaCenter.calculateReadRates()
assert.NoError(t, err) assert.NoError(t, err)
checkLimiter() checkLimiter()
meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything).Unset()
meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything).
RunAndReturn(func(ctx context.Context, i int64, u uint64) (*model.Database, error) {
if i == 1 {
return &model.Database{
ID: 1,
Name: "db1",
Properties: []*commonpb.KeyValuePair{
{
Key: common.DatabaseForceDenyReadingKey,
Value: "true",
},
},
}, nil
}
return nil, errors.New("mock error")
}).Maybe()
quotaCenter.resetAllCurrentRates()
err = quotaCenter.calculateReadRates()
assert.NoError(t, err)
assert.NoError(t, err)
rln := quotaCenter.rateLimiter.GetDatabaseLimiters(0)
limiters := rln.GetLimiters()
a, _ := limiters.Get(internalpb.RateType_DQLSearch)
assert.NotEqual(t, Limit(0), a.Limit())
b, _ := limiters.Get(internalpb.RateType_DQLQuery)
assert.NotEqual(t, Limit(0), b.Limit())
rln = quotaCenter.rateLimiter.GetDatabaseLimiters(1)
limiters = rln.GetLimiters()
a, _ = limiters.Get(internalpb.RateType_DQLSearch)
assert.Equal(t, Limit(0), a.Limit())
b, _ = limiters.Get(internalpb.RateType_DQLQuery)
assert.Equal(t, Limit(0), b.Limit())
}) })
t.Run("test calculateWriteRates", func(t *testing.T) { t.Run("test calculateWriteRates", func(t *testing.T) {
@ -1544,6 +1589,8 @@ func TestCalculateReadRates(t *testing.T) {
t.Run("cool off db", func(t *testing.T) { t.Run("cool off db", func(t *testing.T) {
qc := mocks.NewMockQueryCoordClient(t) qc := mocks.NewMockQueryCoordClient(t)
meta := mockrootcoord.NewIMetaTable(t) meta := mockrootcoord.NewIMetaTable(t)
meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrDatabaseNotFound).Maybe()
pcm := proxyutil.NewMockProxyClientManager(t) pcm := proxyutil.NewMockProxyClientManager(t)
dc := mocks.NewMockDataCoordClient(t) dc := mocks.NewMockDataCoordClient(t)
core, _ := NewCore(ctx, nil) core, _ := NewCore(ctx, nil)

View File

@ -130,7 +130,7 @@ func TestRateLimiterNodeGetQuotaExceededError(t *testing.T) {
err := limitNode.GetQuotaExceededError(internalpb.RateType_DMLInsert) err := limitNode.GetQuotaExceededError(internalpb.RateType_DMLInsert)
assert.True(t, errors.Is(err, merr.ErrServiceQuotaExceeded)) assert.True(t, errors.Is(err, merr.ErrServiceQuotaExceeded))
// reference: ratelimitutil.GetQuotaErrorString(errCode) // reference: ratelimitutil.GetQuotaErrorString(errCode)
assert.True(t, strings.Contains(err.Error(), "deactivated")) assert.True(t, strings.Contains(err.Error(), "disabled"))
}) })
t.Run("read", func(t *testing.T) { t.Run("read", func(t *testing.T) {
@ -139,7 +139,7 @@ func TestRateLimiterNodeGetQuotaExceededError(t *testing.T) {
err := limitNode.GetQuotaExceededError(internalpb.RateType_DQLSearch) err := limitNode.GetQuotaExceededError(internalpb.RateType_DQLSearch)
assert.True(t, errors.Is(err, merr.ErrServiceQuotaExceeded)) assert.True(t, errors.Is(err, merr.ErrServiceQuotaExceeded))
// reference: ratelimitutil.GetQuotaErrorString(errCode) // reference: ratelimitutil.GetQuotaErrorString(errCode)
assert.True(t, strings.Contains(err.Error(), "deactivated")) assert.True(t, strings.Contains(err.Error(), "disabled"))
}) })
t.Run("unknown", func(t *testing.T) { t.Run("unknown", func(t *testing.T) {

View File

@ -159,6 +159,7 @@ const (
DatabaseDiskQuotaKey = "database.diskQuota.mb" DatabaseDiskQuotaKey = "database.diskQuota.mb"
DatabaseMaxCollectionsKey = "database.max.collections" DatabaseMaxCollectionsKey = "database.max.collections"
DatabaseForceDenyWritingKey = "database.force.deny.writing" DatabaseForceDenyWritingKey = "database.force.deny.writing"
DatabaseForceDenyReadingKey = "database.force.deny.reading"
// collection level load properties // collection level load properties
CollectionReplicaNumber = "collection.replica.number" CollectionReplicaNumber = "collection.replica.number"

View File

@ -19,7 +19,7 @@ package ratelimitutil
import "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" import "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
var QuotaErrorString = map[commonpb.ErrorCode]string{ var QuotaErrorString = map[commonpb.ErrorCode]string{
commonpb.ErrorCode_ForceDeny: "the writing has been deactivated by the administrator", commonpb.ErrorCode_ForceDeny: "access has been disabled by the administrator",
commonpb.ErrorCode_MemoryQuotaExhausted: "memory quota exceeded, please allocate more resources", commonpb.ErrorCode_MemoryQuotaExhausted: "memory quota exceeded, please allocate more resources",
commonpb.ErrorCode_DiskQuotaExhausted: "disk quota exceeded, please allocate more resources", commonpb.ErrorCode_DiskQuotaExhausted: "disk quota exceeded, please allocate more resources",
commonpb.ErrorCode_TimeTickLongDelay: "time tick long delay", commonpb.ErrorCode_TimeTickLongDelay: "time tick long delay",

View File

@ -15,7 +15,7 @@ func TestGetQuotaErrorString(t *testing.T) {
{ {
name: "Test ErrorCode_ForceDeny", name: "Test ErrorCode_ForceDeny",
args: commonpb.ErrorCode_ForceDeny, args: commonpb.ErrorCode_ForceDeny,
want: "the writing has been deactivated by the administrator", want: "access has been disabled by the administrator",
}, },
{ {
name: "Test ErrorCode_MemoryQuotaExhausted", name: "Test ErrorCode_MemoryQuotaExhausted",