mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-11-29 18:38:44 +08:00
enhance: set database properties to restrict read access (#35745)
issue: #35744 Signed-off-by: jaime <yun.zhang@zilliz.com>
This commit is contained in:
parent
b51b4a2838
commit
b0ac04d104
@ -407,7 +407,7 @@ func (q *QuotaCenter) collectMetrics() error {
|
||||
coll, getErr := q.meta.GetCollectionByIDWithMaxTs(context.TODO(), collectionID)
|
||||
if getErr != nil {
|
||||
// skip limit check if the collection meta has been removed from rootcoord meta
|
||||
return false
|
||||
return true
|
||||
}
|
||||
collIDToPartIDs, ok := q.readableCollections[coll.DBID]
|
||||
if !ok {
|
||||
@ -463,7 +463,7 @@ func (q *QuotaCenter) collectMetrics() error {
|
||||
coll, getErr := q.meta.GetCollectionByIDWithMaxTs(context.TODO(), collectionID)
|
||||
if getErr != nil {
|
||||
// skip limit check if the collection meta has been removed from rootcoord meta
|
||||
return false
|
||||
return true
|
||||
}
|
||||
|
||||
collIDToPartIDs, ok := q.writableCollections[coll.DBID]
|
||||
@ -562,7 +562,7 @@ func (q *QuotaCenter) forceDenyWriting(errorCode commonpb.ErrorCode, cluster boo
|
||||
dbLimiters := q.rateLimiter.GetDatabaseLimiters(dbID)
|
||||
if dbLimiters == nil {
|
||||
log.Warn("db limiter not found of db ID", zap.Int64("dbID", dbID))
|
||||
return fmt.Errorf("db limiter not found of db ID: %d", dbID)
|
||||
continue
|
||||
}
|
||||
updateLimiter(dbLimiters, GetEarliestLimiter(), internalpb.RateScope_Database, dml)
|
||||
dbLimiters.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToWrite, errorCode)
|
||||
@ -578,7 +578,7 @@ func (q *QuotaCenter) forceDenyWriting(errorCode commonpb.ErrorCode, cluster boo
|
||||
log.Warn("collection limiter not found of collection ID",
|
||||
zap.Int64("dbID", dbID),
|
||||
zap.Int64("collectionID", collectionID))
|
||||
return fmt.Errorf("collection limiter not found of collection ID: %d", collectionID)
|
||||
continue
|
||||
}
|
||||
updateLimiter(collectionLimiter, GetEarliestLimiter(), internalpb.RateScope_Collection, dml)
|
||||
collectionLimiter.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToWrite, errorCode)
|
||||
@ -596,7 +596,7 @@ func (q *QuotaCenter) forceDenyWriting(errorCode commonpb.ErrorCode, cluster boo
|
||||
zap.Int64("dbID", dbID),
|
||||
zap.Int64("collectionID", collectionID),
|
||||
zap.Int64("partitionID", partitionID))
|
||||
return fmt.Errorf("partition limiter not found of partition ID: %d", partitionID)
|
||||
continue
|
||||
}
|
||||
updateLimiter(partitionLimiter, GetEarliestLimiter(), internalpb.RateScope_Partition, dml)
|
||||
partitionLimiter.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToWrite, errorCode)
|
||||
@ -604,7 +604,7 @@ func (q *QuotaCenter) forceDenyWriting(errorCode commonpb.ErrorCode, cluster boo
|
||||
}
|
||||
|
||||
if cluster || len(dbIDs) > 0 || len(collectionIDs) > 0 || len(col2partitionIDs) > 0 {
|
||||
log.RatedWarn(10, "QuotaCenter force to deny writing",
|
||||
log.RatedWarn(30, "QuotaCenter force to deny writing",
|
||||
zap.Bool("cluster", cluster),
|
||||
zap.Int64s("dbIDs", dbIDs),
|
||||
zap.Int64s("collectionIDs", collectionIDs),
|
||||
@ -616,20 +616,37 @@ func (q *QuotaCenter) forceDenyWriting(errorCode commonpb.ErrorCode, cluster boo
|
||||
}
|
||||
|
||||
// forceDenyReading sets dql rates to 0 to reject all dql requests.
|
||||
func (q *QuotaCenter) forceDenyReading(errorCode commonpb.ErrorCode) {
|
||||
var collectionIDs []int64
|
||||
for dbID, collectionIDToPartIDs := range q.readableCollections {
|
||||
for collectionID := range collectionIDToPartIDs {
|
||||
collectionLimiter := q.rateLimiter.GetCollectionLimiters(dbID, collectionID)
|
||||
updateLimiter(collectionLimiter, GetEarliestLimiter(), internalpb.RateScope_Collection, dql)
|
||||
collectionLimiter.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToRead, errorCode)
|
||||
collectionIDs = append(collectionIDs, collectionID)
|
||||
func (q *QuotaCenter) forceDenyReading(errorCode commonpb.ErrorCode, cluster bool, dbIDs []int64, mlog *log.MLogger) {
|
||||
if cluster {
|
||||
var collectionIDs []int64
|
||||
for dbID, collectionIDToPartIDs := range q.readableCollections {
|
||||
for collectionID := range collectionIDToPartIDs {
|
||||
collectionLimiter := q.rateLimiter.GetCollectionLimiters(dbID, collectionID)
|
||||
updateLimiter(collectionLimiter, GetEarliestLimiter(), internalpb.RateScope_Collection, dql)
|
||||
collectionLimiter.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToRead, errorCode)
|
||||
collectionIDs = append(collectionIDs, collectionID)
|
||||
}
|
||||
}
|
||||
|
||||
mlog.RatedWarn(10, "QuotaCenter force to deny reading",
|
||||
zap.Int64s("collectionIDs", collectionIDs),
|
||||
zap.String("reason", errorCode.String()))
|
||||
}
|
||||
|
||||
log.Warn("QuotaCenter force to deny reading",
|
||||
zap.Int64s("collectionIDs", collectionIDs),
|
||||
zap.String("reason", errorCode.String()))
|
||||
if len(dbIDs) > 0 {
|
||||
for _, dbID := range dbIDs {
|
||||
dbLimiters := q.rateLimiter.GetDatabaseLimiters(dbID)
|
||||
if dbLimiters == nil {
|
||||
log.Warn("db limiter not found of db ID", zap.Int64("dbID", dbID))
|
||||
continue
|
||||
}
|
||||
updateLimiter(dbLimiters, GetEarliestLimiter(), internalpb.RateScope_Database, dql)
|
||||
dbLimiters.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToRead, errorCode)
|
||||
mlog.RatedWarn(10, "QuotaCenter force to deny reading",
|
||||
zap.Int64s("dbIDs", dbIDs),
|
||||
zap.String("reason", errorCode.String()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getRealTimeRate return real time rate in Proxy.
|
||||
@ -654,58 +671,26 @@ func (q *QuotaCenter) guaranteeMinRate(minRate float64, rt internalpb.RateType,
|
||||
}
|
||||
}
|
||||
|
||||
// calculateReadRates calculates and sets dql rates.
|
||||
func (q *QuotaCenter) calculateReadRates() error {
|
||||
log := log.Ctx(context.Background()).WithRateGroup("rootcoord.QuotaCenter", 1.0, 60.0)
|
||||
if Params.QuotaConfig.ForceDenyReading.GetAsBool() {
|
||||
q.forceDenyReading(commonpb.ErrorCode_ForceDeny)
|
||||
return nil
|
||||
}
|
||||
|
||||
limitCollectionSet := typeutil.NewUniqueSet()
|
||||
limitDBNameSet := typeutil.NewSet[string]()
|
||||
limitCollectionNameSet := typeutil.NewSet[string]()
|
||||
clusterLimit := false
|
||||
|
||||
formatCollctionRateKey := func(dbName, collectionName string) string {
|
||||
return fmt.Sprintf("%s.%s", dbName, collectionName)
|
||||
}
|
||||
splitCollctionRateKey := func(key string) (string, string) {
|
||||
parts := strings.Split(key, ".")
|
||||
return parts[0], parts[1]
|
||||
}
|
||||
// query latency
|
||||
queueLatencyThreshold := Params.QuotaConfig.QueueLatencyThreshold.GetAsDuration(time.Second)
|
||||
// enableQueueProtection && queueLatencyThreshold >= 0 means enable queue latency protection
|
||||
if queueLatencyThreshold >= 0 {
|
||||
for _, metric := range q.queryNodeMetrics {
|
||||
searchLatency := metric.SearchQueue.AvgQueueDuration
|
||||
queryLatency := metric.QueryQueue.AvgQueueDuration
|
||||
if searchLatency >= queueLatencyThreshold || queryLatency >= queueLatencyThreshold {
|
||||
limitCollectionSet.Insert(metric.Effect.CollectionIDs...)
|
||||
func (q *QuotaCenter) getDenyReadingDBs() map[int64]struct{} {
|
||||
dbIDs := make(map[int64]struct{})
|
||||
for _, dbID := range lo.Uniq(q.collectionIDToDBID.Values()) {
|
||||
if db, err := q.meta.GetDatabaseByID(q.ctx, dbID, typeutil.MaxTimestamp); err == nil {
|
||||
if v := db.GetProperty(common.DatabaseForceDenyReadingKey); v != "" {
|
||||
if dbForceDenyReadingEnabled, _ := strconv.ParseBool(v); dbForceDenyReadingEnabled {
|
||||
dbIDs[dbID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return dbIDs
|
||||
}
|
||||
|
||||
// queue length
|
||||
enableQueueProtection := Params.QuotaConfig.QueueProtectionEnabled.GetAsBool()
|
||||
nqInQueueThreshold := Params.QuotaConfig.NQInQueueThreshold.GetAsInt64()
|
||||
if enableQueueProtection && nqInQueueThreshold >= 0 {
|
||||
// >= 0 means enable queue length protection
|
||||
sum := func(ri metricsinfo.ReadInfoInQueue) int64 {
|
||||
return ri.UnsolvedQueue + ri.ReadyQueue + ri.ReceiveChan + ri.ExecuteChan
|
||||
}
|
||||
for _, metric := range q.queryNodeMetrics {
|
||||
// We think of the NQ of query request as 1.
|
||||
// search use same queue length counter with query
|
||||
if sum(metric.SearchQueue) >= nqInQueueThreshold {
|
||||
limitCollectionSet.Insert(metric.Effect.CollectionIDs...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
metricMap := make(map[string]float64) // label metric
|
||||
collectionMetricMap := make(map[string]map[string]map[string]float64) // sub label metric, label -> db -> collection -> value
|
||||
// getReadRates get rate information of collections and databases from proxy metrics
|
||||
func (q *QuotaCenter) getReadRates() (map[string]float64, map[string]map[string]map[string]float64) {
|
||||
// label metric
|
||||
metricMap := make(map[string]float64)
|
||||
// sub label metric, label -> db -> collection -> value
|
||||
collectionMetricMap := make(map[string]map[string]map[string]float64)
|
||||
for _, metric := range q.proxyMetrics {
|
||||
for _, rm := range metric.Rms {
|
||||
if !ratelimitutil.IsSubLabel(rm.Label) {
|
||||
@ -729,8 +714,20 @@ func (q *QuotaCenter) calculateReadRates() error {
|
||||
databaseMetric[collection] += rm.Rate
|
||||
}
|
||||
}
|
||||
return metricMap, collectionMetricMap
|
||||
}
|
||||
|
||||
func (q *QuotaCenter) getLimitedDBAndCollections(metricMap map[string]float64,
|
||||
collectionMetricMap map[string]map[string]map[string]float64,
|
||||
) (bool, *typeutil.Set[string], *typeutil.Set[string]) {
|
||||
limitDBNameSet := typeutil.NewSet[string]()
|
||||
limitCollectionNameSet := typeutil.NewSet[string]()
|
||||
clusterLimit := false
|
||||
|
||||
formatCollctionRateKey := func(dbName, collectionName string) string {
|
||||
return fmt.Sprintf("%s.%s", dbName, collectionName)
|
||||
}
|
||||
|
||||
// read result
|
||||
enableResultProtection := Params.QuotaConfig.ResultProtectionEnabled.GetAsBool()
|
||||
if enableResultProtection {
|
||||
maxRate := Params.QuotaConfig.MaxReadResultRate.GetAsFloat()
|
||||
@ -765,28 +762,12 @@ func (q *QuotaCenter) calculateReadRates() error {
|
||||
}
|
||||
}
|
||||
}
|
||||
return clusterLimit, &limitDBNameSet, &limitCollectionNameSet
|
||||
}
|
||||
|
||||
dbIDs := make(map[int64]string, q.dbs.Len())
|
||||
collectionIDs := make(map[int64]string, q.collections.Len())
|
||||
q.dbs.Range(func(name string, id int64) bool {
|
||||
dbIDs[id] = name
|
||||
return true
|
||||
})
|
||||
q.collections.Range(func(name string, id int64) bool {
|
||||
_, collectionName := SplitCollectionKey(name)
|
||||
collectionIDs[id] = collectionName
|
||||
return true
|
||||
})
|
||||
|
||||
coolOffSpeed := Params.QuotaConfig.CoolOffSpeed.GetAsFloat()
|
||||
|
||||
if clusterLimit {
|
||||
realTimeClusterSearchRate := metricMap[internalpb.RateType_DQLSearch.String()]
|
||||
realTimeClusterQueryRate := metricMap[internalpb.RateType_DQLQuery.String()]
|
||||
q.coolOffReading(realTimeClusterSearchRate, realTimeClusterQueryRate, coolOffSpeed, q.rateLimiter.GetRootLimiters(), log)
|
||||
}
|
||||
|
||||
var updateLimitErr error
|
||||
func (q *QuotaCenter) coolOffDatabaseReading(deniedDatabaseIDs map[int64]struct{}, limitDBNameSet *typeutil.Set[string],
|
||||
collectionMetricMap map[string]map[string]map[string]float64, log *log.MLogger,
|
||||
) error {
|
||||
if limitDBNameSet.Len() > 0 {
|
||||
databaseSearchRate := make(map[string]float64)
|
||||
databaseQueryRate := make(map[string]float64)
|
||||
@ -806,18 +787,24 @@ func (q *QuotaCenter) calculateReadRates() error {
|
||||
}
|
||||
}
|
||||
|
||||
coolOffSpeed := Params.QuotaConfig.CoolOffSpeed.GetAsFloat()
|
||||
limitDBNameSet.Range(func(name string) bool {
|
||||
dbID, ok := q.dbs.Get(name)
|
||||
if !ok {
|
||||
log.Warn("db not found", zap.String("dbName", name))
|
||||
updateLimitErr = fmt.Errorf("db not found: %s", name)
|
||||
return false
|
||||
return true
|
||||
}
|
||||
|
||||
// skip this database because it has been denied access for reading
|
||||
_, ok = deniedDatabaseIDs[dbID]
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
|
||||
dbLimiter := q.rateLimiter.GetDatabaseLimiters(dbID)
|
||||
if dbLimiter == nil {
|
||||
log.Warn("database limiter not found", zap.Int64("dbID", dbID))
|
||||
updateLimitErr = fmt.Errorf("database limiter not found")
|
||||
return false
|
||||
return true
|
||||
}
|
||||
|
||||
realTimeSearchRate := databaseSearchRate[name]
|
||||
@ -825,10 +812,32 @@ func (q *QuotaCenter) calculateReadRates() error {
|
||||
q.coolOffReading(realTimeSearchRate, realTimeQueryRate, coolOffSpeed, dbLimiter, log)
|
||||
return true
|
||||
})
|
||||
if updateLimitErr != nil {
|
||||
return updateLimitErr
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *QuotaCenter) coolOffCollectionReading(deniedDatabaseIDs map[int64]struct{}, limitCollectionSet *typeutil.UniqueSet, limitCollectionNameSet *typeutil.Set[string],
|
||||
collectionMetricMap map[string]map[string]map[string]float64, log *log.MLogger,
|
||||
) error {
|
||||
var updateLimitErr error
|
||||
coolOffSpeed := Params.QuotaConfig.CoolOffSpeed.GetAsFloat()
|
||||
|
||||
splitCollctionRateKey := func(key string) (string, string) {
|
||||
parts := strings.Split(key, ".")
|
||||
return parts[0], parts[1]
|
||||
}
|
||||
|
||||
dbIDs := make(map[int64]string, q.dbs.Len())
|
||||
collectionIDs := make(map[int64]string, q.collections.Len())
|
||||
q.dbs.Range(func(name string, id int64) bool {
|
||||
dbIDs[id] = name
|
||||
return true
|
||||
})
|
||||
q.collections.Range(func(name string, id int64) bool {
|
||||
_, collectionName := SplitCollectionKey(name)
|
||||
collectionIDs[id] = collectionName
|
||||
return true
|
||||
})
|
||||
|
||||
limitCollectionNameSet.Range(func(name string) bool {
|
||||
dbName, collectionName := splitCollctionRateKey(name)
|
||||
@ -836,13 +845,13 @@ func (q *QuotaCenter) calculateReadRates() error {
|
||||
if !ok {
|
||||
log.Warn("db not found", zap.String("dbName", dbName))
|
||||
updateLimitErr = fmt.Errorf("db not found: %s", dbName)
|
||||
return false
|
||||
return true
|
||||
}
|
||||
collectionID, ok := q.collections.Get(FormatCollectionKey(dbID, collectionName))
|
||||
if !ok {
|
||||
log.Warn("collection not found", zap.String("collectionName", name))
|
||||
updateLimitErr = fmt.Errorf("collection not found: %s", name)
|
||||
return false
|
||||
return true
|
||||
}
|
||||
limitCollectionSet.Insert(collectionID)
|
||||
return true
|
||||
@ -868,6 +877,12 @@ func (q *QuotaCenter) calculateReadRates() error {
|
||||
if !ok {
|
||||
return fmt.Errorf("db ID not found of collection ID: %d", collection)
|
||||
}
|
||||
// skip this database because it has been denied access for reading
|
||||
_, ok = deniedDatabaseIDs[dbID]
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
|
||||
collectionLimiter := q.rateLimiter.GetCollectionLimiters(dbID, collection)
|
||||
if collectionLimiter == nil {
|
||||
return fmt.Errorf("collection limiter not found: %d", collection)
|
||||
@ -897,6 +912,73 @@ func (q *QuotaCenter) calculateReadRates() error {
|
||||
if updateLimitErr = coolOffCollectionID(limitCollectionSet.Collect()...); updateLimitErr != nil {
|
||||
return updateLimitErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateReadRates calculates and sets dql rates.
|
||||
func (q *QuotaCenter) calculateReadRates() error {
|
||||
log := log.Ctx(context.Background()).WithRateGroup("rootcoord.QuotaCenter", 1.0, 60.0)
|
||||
if Params.QuotaConfig.ForceDenyReading.GetAsBool() {
|
||||
q.forceDenyReading(commonpb.ErrorCode_ForceDeny, true, []int64{}, log)
|
||||
return nil
|
||||
}
|
||||
|
||||
deniedDatabaseIDs := q.getDenyReadingDBs()
|
||||
if len(deniedDatabaseIDs) != 0 {
|
||||
q.forceDenyReading(commonpb.ErrorCode_ForceDeny, false, maps.Keys(deniedDatabaseIDs), log)
|
||||
}
|
||||
|
||||
queueLatencyThreshold := Params.QuotaConfig.QueueLatencyThreshold.GetAsDuration(time.Second)
|
||||
limitCollectionSet := typeutil.NewUniqueSet()
|
||||
|
||||
// enableQueueProtection && queueLatencyThreshold >= 0 means enable queue latency protection
|
||||
if queueLatencyThreshold >= 0 {
|
||||
for _, metric := range q.queryNodeMetrics {
|
||||
searchLatency := metric.SearchQueue.AvgQueueDuration
|
||||
queryLatency := metric.QueryQueue.AvgQueueDuration
|
||||
if searchLatency >= queueLatencyThreshold || queryLatency >= queueLatencyThreshold {
|
||||
limitCollectionSet.Insert(metric.Effect.CollectionIDs...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// queue length
|
||||
enableQueueProtection := Params.QuotaConfig.QueueProtectionEnabled.GetAsBool()
|
||||
nqInQueueThreshold := Params.QuotaConfig.NQInQueueThreshold.GetAsInt64()
|
||||
if enableQueueProtection && nqInQueueThreshold >= 0 {
|
||||
// >= 0 means enable queue length protection
|
||||
sum := func(ri metricsinfo.ReadInfoInQueue) int64 {
|
||||
return ri.UnsolvedQueue + ri.ReadyQueue + ri.ReceiveChan + ri.ExecuteChan
|
||||
}
|
||||
for _, metric := range q.queryNodeMetrics {
|
||||
// We think of the NQ of query request as 1.
|
||||
// search use same queue length counter with query
|
||||
if sum(metric.SearchQueue) >= nqInQueueThreshold {
|
||||
limitCollectionSet.Insert(metric.Effect.CollectionIDs...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
metricMap, collectionMetricMap := q.getReadRates()
|
||||
clusterLimit, limitDBNameSet, limitCollectionNameSet := q.getLimitedDBAndCollections(metricMap, collectionMetricMap)
|
||||
|
||||
coolOffSpeed := Params.QuotaConfig.CoolOffSpeed.GetAsFloat()
|
||||
|
||||
if clusterLimit {
|
||||
realTimeClusterSearchRate := metricMap[internalpb.RateType_DQLSearch.String()]
|
||||
realTimeClusterQueryRate := metricMap[internalpb.RateType_DQLQuery.String()]
|
||||
q.coolOffReading(realTimeClusterSearchRate, realTimeClusterQueryRate, coolOffSpeed, q.rateLimiter.GetRootLimiters(), log)
|
||||
}
|
||||
|
||||
if updateLimitErr := q.coolOffDatabaseReading(deniedDatabaseIDs, limitDBNameSet, collectionMetricMap,
|
||||
log); updateLimitErr != nil {
|
||||
return updateLimitErr
|
||||
}
|
||||
|
||||
if updateLimitErr := q.coolOffCollectionReading(deniedDatabaseIDs, &limitCollectionSet, limitCollectionNameSet,
|
||||
collectionMetricMap, log); updateLimitErr != nil {
|
||||
return updateLimitErr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -563,17 +563,27 @@ func TestQuotaCenter(t *testing.T) {
|
||||
ID: 0,
|
||||
Name: "default",
|
||||
},
|
||||
{
|
||||
ID: 1,
|
||||
Name: "db1",
|
||||
},
|
||||
}, nil).Maybe()
|
||||
meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrDatabaseNotFound).Maybe()
|
||||
|
||||
quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta)
|
||||
quotaCenter.clearMetrics()
|
||||
quotaCenter.collectionIDToDBID = collectionIDToDBID
|
||||
quotaCenter.readableCollections = map[int64]map[int64][]int64{
|
||||
0: collectionIDToPartitionIDs,
|
||||
0: {1: {}, 2: {}, 3: {}},
|
||||
1: {4: {}},
|
||||
}
|
||||
quotaCenter.dbs.Insert("default", 0)
|
||||
quotaCenter.dbs.Insert("db1", 1)
|
||||
quotaCenter.collections.Insert("0.col1", 1)
|
||||
quotaCenter.collections.Insert("0.col2", 2)
|
||||
quotaCenter.collections.Insert("0.col3", 3)
|
||||
quotaCenter.collections.Insert("1.col4", 4)
|
||||
|
||||
colSubLabel := ratelimitutil.GetCollectionSubLabel("default", "col1")
|
||||
quotaCenter.proxyMetrics = map[UniqueID]*metricsinfo.ProxyQuotaMetrics{
|
||||
1: {Rms: []metricsinfo.RateMetric{
|
||||
@ -652,6 +662,41 @@ func TestQuotaCenter(t *testing.T) {
|
||||
err = quotaCenter.calculateReadRates()
|
||||
assert.NoError(t, err)
|
||||
checkLimiter()
|
||||
|
||||
meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything).Unset()
|
||||
meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything).
|
||||
RunAndReturn(func(ctx context.Context, i int64, u uint64) (*model.Database, error) {
|
||||
if i == 1 {
|
||||
return &model.Database{
|
||||
ID: 1,
|
||||
Name: "db1",
|
||||
Properties: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.DatabaseForceDenyReadingKey,
|
||||
Value: "true",
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return nil, errors.New("mock error")
|
||||
}).Maybe()
|
||||
quotaCenter.resetAllCurrentRates()
|
||||
err = quotaCenter.calculateReadRates()
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
rln := quotaCenter.rateLimiter.GetDatabaseLimiters(0)
|
||||
limiters := rln.GetLimiters()
|
||||
a, _ := limiters.Get(internalpb.RateType_DQLSearch)
|
||||
assert.NotEqual(t, Limit(0), a.Limit())
|
||||
b, _ := limiters.Get(internalpb.RateType_DQLQuery)
|
||||
assert.NotEqual(t, Limit(0), b.Limit())
|
||||
|
||||
rln = quotaCenter.rateLimiter.GetDatabaseLimiters(1)
|
||||
limiters = rln.GetLimiters()
|
||||
a, _ = limiters.Get(internalpb.RateType_DQLSearch)
|
||||
assert.Equal(t, Limit(0), a.Limit())
|
||||
b, _ = limiters.Get(internalpb.RateType_DQLQuery)
|
||||
assert.Equal(t, Limit(0), b.Limit())
|
||||
})
|
||||
|
||||
t.Run("test calculateWriteRates", func(t *testing.T) {
|
||||
@ -1544,6 +1589,8 @@ func TestCalculateReadRates(t *testing.T) {
|
||||
t.Run("cool off db", func(t *testing.T) {
|
||||
qc := mocks.NewMockQueryCoordClient(t)
|
||||
meta := mockrootcoord.NewIMetaTable(t)
|
||||
meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrDatabaseNotFound).Maybe()
|
||||
|
||||
pcm := proxyutil.NewMockProxyClientManager(t)
|
||||
dc := mocks.NewMockDataCoordClient(t)
|
||||
core, _ := NewCore(ctx, nil)
|
||||
|
@ -130,7 +130,7 @@ func TestRateLimiterNodeGetQuotaExceededError(t *testing.T) {
|
||||
err := limitNode.GetQuotaExceededError(internalpb.RateType_DMLInsert)
|
||||
assert.True(t, errors.Is(err, merr.ErrServiceQuotaExceeded))
|
||||
// reference: ratelimitutil.GetQuotaErrorString(errCode)
|
||||
assert.True(t, strings.Contains(err.Error(), "deactivated"))
|
||||
assert.True(t, strings.Contains(err.Error(), "disabled"))
|
||||
})
|
||||
|
||||
t.Run("read", func(t *testing.T) {
|
||||
@ -139,7 +139,7 @@ func TestRateLimiterNodeGetQuotaExceededError(t *testing.T) {
|
||||
err := limitNode.GetQuotaExceededError(internalpb.RateType_DQLSearch)
|
||||
assert.True(t, errors.Is(err, merr.ErrServiceQuotaExceeded))
|
||||
// reference: ratelimitutil.GetQuotaErrorString(errCode)
|
||||
assert.True(t, strings.Contains(err.Error(), "deactivated"))
|
||||
assert.True(t, strings.Contains(err.Error(), "disabled"))
|
||||
})
|
||||
|
||||
t.Run("unknown", func(t *testing.T) {
|
||||
|
@ -159,6 +159,7 @@ const (
|
||||
DatabaseDiskQuotaKey = "database.diskQuota.mb"
|
||||
DatabaseMaxCollectionsKey = "database.max.collections"
|
||||
DatabaseForceDenyWritingKey = "database.force.deny.writing"
|
||||
DatabaseForceDenyReadingKey = "database.force.deny.reading"
|
||||
|
||||
// collection level load properties
|
||||
CollectionReplicaNumber = "collection.replica.number"
|
||||
|
@ -19,7 +19,7 @@ package ratelimitutil
|
||||
import "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
|
||||
var QuotaErrorString = map[commonpb.ErrorCode]string{
|
||||
commonpb.ErrorCode_ForceDeny: "the writing has been deactivated by the administrator",
|
||||
commonpb.ErrorCode_ForceDeny: "access has been disabled by the administrator",
|
||||
commonpb.ErrorCode_MemoryQuotaExhausted: "memory quota exceeded, please allocate more resources",
|
||||
commonpb.ErrorCode_DiskQuotaExhausted: "disk quota exceeded, please allocate more resources",
|
||||
commonpb.ErrorCode_TimeTickLongDelay: "time tick long delay",
|
||||
|
@ -15,7 +15,7 @@ func TestGetQuotaErrorString(t *testing.T) {
|
||||
{
|
||||
name: "Test ErrorCode_ForceDeny",
|
||||
args: commonpb.ErrorCode_ForceDeny,
|
||||
want: "the writing has been deactivated by the administrator",
|
||||
want: "access has been disabled by the administrator",
|
||||
},
|
||||
{
|
||||
name: "Test ErrorCode_MemoryQuotaExhausted",
|
||||
|
Loading…
Reference in New Issue
Block a user