enhance: support httpv1/v2 throttle and add it for httpV2(#35350) (#35504)

related: #35350
pr: https://github.com/milvus-io/milvus/pull/35470

Signed-off-by: MrPresent-Han <chun.han@gmail.com>
Co-authored-by: MrPresent-Han <chun.han@gmail.com>
This commit is contained in:
Chun Han 2024-08-20 16:32:56 +08:00 committed by GitHub
parent fc344d1eae
commit cf8494ef45
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 392 additions and 184 deletions

View File

@ -247,6 +247,10 @@ func checkAuthorizationV2(ctx context.Context, c *gin.Context, ignoreErr bool, r
}
func wrapperProxy(ctx context.Context, c *gin.Context, req any, checkAuth bool, ignoreErr bool, fullMethod string, handler func(reqCtx context.Context, req any) (any, error)) (interface{}, error) {
return wrapperProxyWithLimit(ctx, c, req, checkAuth, ignoreErr, fullMethod, false, nil, handler)
}
func wrapperProxyWithLimit(ctx context.Context, c *gin.Context, req any, checkAuth bool, ignoreErr bool, fullMethod string, checkLimit bool, pxy types.ProxyComponent, handler func(reqCtx context.Context, req any) (any, error)) (interface{}, error) {
if baseGetter, ok := req.(BaseGetter); ok {
span := trace.SpanFromContext(ctx)
span.AddEvent(baseGetter.GetBase().GetMsgType().String())
@ -257,6 +261,17 @@ func wrapperProxy(ctx context.Context, c *gin.Context, req any, checkAuth bool,
return nil, err
}
}
if checkLimit {
_, err := CheckLimiter(ctx, req, pxy)
if err != nil {
log.Warn("high level restful api, fail to check limiter", zap.Error(err), zap.String("method", fullMethod))
HTTPAbortReturn(c, http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(merr.ErrHTTPRateLimit),
HTTPReturnMessage: merr.ErrHTTPRateLimit.Error() + ", error: " + err.Error(),
})
return nil, RestRequestInterceptorErr
}
}
log.Ctx(ctx).Debug("high level restful api, try to do a grpc call", zap.Any("grpcRequest", req))
username, ok := c.Get(ContextUsername)
if !ok {
@ -506,7 +521,7 @@ func (h *HandlersV2) dropCollection(ctx context.Context, c *gin.Context, anyReq
CollectionName: getter.GetCollectionName(),
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropCollection", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropCollection", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.DropCollection(reqCtx, req.(*milvuspb.DropCollectionRequest))
})
if err == nil {
@ -527,7 +542,7 @@ func (h *HandlersV2) renameCollection(ctx context.Context, c *gin.Context, anyRe
if req.NewDBName == "" {
req.NewDBName = dbName
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/RenameCollection", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/RenameCollection", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.RenameCollection(reqCtx, req.(*milvuspb.RenameCollectionRequest))
})
if err == nil {
@ -543,7 +558,7 @@ func (h *HandlersV2) loadCollection(ctx context.Context, c *gin.Context, anyReq
CollectionName: getter.GetCollectionName(),
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/LoadCollection", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/LoadCollection", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.LoadCollection(reqCtx, req.(*milvuspb.LoadCollectionRequest))
})
if err == nil {
@ -559,7 +574,7 @@ func (h *HandlersV2) releaseCollection(ctx context.Context, c *gin.Context, anyR
CollectionName: getter.GetCollectionName(),
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/ReleaseCollection", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/ReleaseCollection", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.ReleaseCollection(reqCtx, req.(*milvuspb.ReleaseCollectionRequest))
})
if err == nil {
@ -591,7 +606,7 @@ func (h *HandlersV2) query(ctx context.Context, c *gin.Context, anyReq any, dbNa
if httpReq.Limit > 0 && !matchCountRule(httpReq.OutputFields) {
req.QueryParams = append(req.QueryParams, &commonpb.KeyValuePair{Key: ParamLimit, Value: strconv.FormatInt(int64(httpReq.Limit), 10)})
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Query", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Query", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.Query(reqCtx, req.(*milvuspb.QueryRequest))
})
if err == nil {
@ -639,7 +654,7 @@ func (h *HandlersV2) get(ctx context.Context, c *gin.Context, anyReq any, dbName
UseDefaultConsistency: true,
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Query", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Query", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.Query(reqCtx, req.(*milvuspb.QueryRequest))
})
if err == nil {
@ -688,7 +703,7 @@ func (h *HandlersV2) delete(ctx context.Context, c *gin.Context, anyReq any, dbN
}
req.Expr = filter
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Delete", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Delete", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.Delete(reqCtx, req.(*milvuspb.DeleteRequest))
})
if err == nil {
@ -734,7 +749,7 @@ func (h *HandlersV2) insert(ctx context.Context, c *gin.Context, anyReq any, dbN
})
return nil, err
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Insert", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Insert", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.Insert(reqCtx, req.(*milvuspb.InsertRequest))
})
if err == nil {
@ -812,7 +827,7 @@ func (h *HandlersV2) upsert(ctx context.Context, c *gin.Context, anyReq any, dbN
})
return nil, err
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Upsert", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Upsert", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.Upsert(reqCtx, req.(*milvuspb.UpsertRequest))
})
if err == nil {
@ -957,7 +972,7 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN
}
req.SearchParams = searchParams
req.PlaceholderGroup = placeholderGroup
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Search", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Search", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.Search(reqCtx, req.(*milvuspb.SearchRequest))
})
if err == nil {
@ -1037,7 +1052,7 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq
{Key: ParamLimit, Value: strconv.FormatInt(int64(httpReq.Limit), 10)},
{Key: ParamRoundDecimal, Value: "-1"},
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/HybridSearch", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/HybridSearch", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.HybridSearch(reqCtx, req.(*milvuspb.HybridSearchRequest))
})
if err == nil {
@ -1252,7 +1267,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
Value: fmt.Sprintf("%v", httpReq.Params["partitionKeyIsolation"]),
})
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateCollection", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateCollection", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.CreateCollection(reqCtx, req.(*milvuspb.CreateCollectionRequest))
})
if err != nil {
@ -1269,7 +1284,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
IndexName: httpReq.VectorFieldName,
ExtraParams: []*commonpb.KeyValuePair{{Key: common.MetricTypeKey, Value: httpReq.MetricType}},
}
statusResponse, err := wrapperProxy(ctx, c, createIndexReq, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateIndex", func(reqCtx context.Context, req any) (interface{}, error) {
statusResponse, err := wrapperProxyWithLimit(ctx, c, createIndexReq, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateIndex", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.CreateIndex(ctx, req.(*milvuspb.CreateIndexRequest))
})
if err != nil {
@ -1298,7 +1313,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
for key, value := range indexParam.Params {
createIndexReq.ExtraParams = append(createIndexReq.ExtraParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)})
}
statusResponse, err := wrapperProxy(ctx, c, createIndexReq, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateIndex", func(reqCtx context.Context, req any) (interface{}, error) {
statusResponse, err := wrapperProxyWithLimit(ctx, c, createIndexReq, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateIndex", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.CreateIndex(ctx, req.(*milvuspb.CreateIndexRequest))
})
if err != nil {
@ -1310,7 +1325,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
DbName: dbName,
CollectionName: httpReq.CollectionName,
}
statusResponse, err := wrapperProxy(ctx, c, loadReq, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/LoadCollection", func(reqCtx context.Context, req any) (interface{}, error) {
statusResponse, err := wrapperProxyWithLimit(ctx, c, loadReq, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/LoadCollection", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.LoadCollection(ctx, req.(*milvuspb.LoadCollectionRequest))
})
if err == nil {
@ -1383,7 +1398,7 @@ func (h *HandlersV2) createPartition(ctx context.Context, c *gin.Context, anyReq
PartitionName: partitionGetter.GetPartitionName(),
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreatePartition", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreatePartition", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.CreatePartition(reqCtx, req.(*milvuspb.CreatePartitionRequest))
})
if err == nil {
@ -1401,7 +1416,7 @@ func (h *HandlersV2) dropPartition(ctx context.Context, c *gin.Context, anyReq a
PartitionName: partitionGetter.GetPartitionName(),
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropPartition", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropPartition", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.DropPartition(reqCtx, req.(*milvuspb.DropPartitionRequest))
})
if err == nil {
@ -1418,7 +1433,7 @@ func (h *HandlersV2) loadPartitions(ctx context.Context, c *gin.Context, anyReq
PartitionNames: httpReq.PartitionNames,
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/LoadPartitions", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/LoadPartitions", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.LoadPartitions(reqCtx, req.(*milvuspb.LoadPartitionsRequest))
})
if err == nil {
@ -1435,7 +1450,7 @@ func (h *HandlersV2) releasePartitions(ctx context.Context, c *gin.Context, anyR
PartitionNames: httpReq.PartitionNames,
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/ReleasePartitions", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/ReleasePartitions", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.ReleasePartitions(reqCtx, req.(*milvuspb.ReleasePartitionsRequest))
})
if err == nil {
@ -1743,7 +1758,7 @@ func (h *HandlersV2) createIndex(ctx context.Context, c *gin.Context, anyReq any
for key, value := range indexParam.Params {
req.ExtraParams = append(req.ExtraParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)})
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateIndex", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateIndex", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.CreateIndex(reqCtx, req.(*milvuspb.CreateIndexRequest))
})
if err != nil {
@ -1764,7 +1779,7 @@ func (h *HandlersV2) dropIndex(ctx context.Context, c *gin.Context, anyReq any,
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropIndex", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropIndex", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.DropIndex(reqCtx, req.(*milvuspb.DropIndexRequest))
})
if err == nil {
@ -1822,7 +1837,7 @@ func (h *HandlersV2) createAlias(ctx context.Context, c *gin.Context, anyReq any
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateAlias", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateAlias", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.CreateAlias(reqCtx, req.(*milvuspb.CreateAliasRequest))
})
if err == nil {
@ -1839,7 +1854,7 @@ func (h *HandlersV2) dropAlias(ctx context.Context, c *gin.Context, anyReq any,
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropAlias", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropAlias", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.DropAlias(reqCtx, req.(*milvuspb.DropAliasRequest))
})
if err == nil {
@ -1858,7 +1873,7 @@ func (h *HandlersV2) alterAlias(ctx context.Context, c *gin.Context, anyReq any,
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/AlterAlias", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/AlterAlias", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.AlterAlias(reqCtx, req.(*milvuspb.AlterAliasRequest))
})
if err == nil {

View File

@ -469,6 +469,11 @@ func TestDatabaseWrapper(t *testing.T) {
}
func TestCreateCollection(t *testing.T) {
paramtable.Init()
// disable rate limit
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
postTestCases := []requestBodyTestCase{}
mp := mocks.NewMockProxy(t)
mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(13)
@ -974,6 +979,9 @@ var commonErrorStatus = &commonpb.Status{
func TestMethodDelete(t *testing.T) {
paramtable.Init()
// disable rate limit
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
mp := mocks.NewMockProxy(t)
mp.EXPECT().DropCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
mp.EXPECT().DropPartition(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
@ -1023,6 +1031,9 @@ func TestMethodDelete(t *testing.T) {
func TestMethodPost(t *testing.T) {
paramtable.Init()
// disable rate limit
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
mp := mocks.NewMockProxy(t)
mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
mp.EXPECT().RenameCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
@ -1161,6 +1172,9 @@ func TestMethodPost(t *testing.T) {
func TestDML(t *testing.T) {
paramtable.Init()
// disable rate limit
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
mp := mocks.NewMockProxy(t)
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
CollectionName: DefaultCollectionName,
@ -1280,6 +1294,9 @@ func TestDML(t *testing.T) {
func TestAllowInt64(t *testing.T) {
paramtable.Init()
// disable rate limit
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
mp := mocks.NewMockProxy(t)
testEngine := initHTTPServerV2(mp, false)
queryTestCases := []requestBodyTestCase{}
@ -1322,6 +1339,9 @@ func TestAllowInt64(t *testing.T) {
func TestSearchV2(t *testing.T) {
paramtable.Init()
// disable rate limit
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
outputFields := []string{FieldBookID, FieldWordCount, "author", "date"}
mp := mocks.NewMockProxy(t)
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{

View File

@ -2,6 +2,7 @@ package httpserver
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"math"
@ -19,12 +20,16 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/json"
"github.com/milvus-io/milvus/internal/proxy"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/parameterutil"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
@ -1256,3 +1261,29 @@ func formatInt64(intArray []int64) []string {
}
return stringArray
}
func CheckLimiter(ctx context.Context, req interface{}, pxy types.ProxyComponent) (any, error) {
if !paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.GetAsBool() {
return nil, nil
}
// apply limiter for http/http2 server
limiter, err := pxy.GetRateLimiter()
if err != nil {
log.Error("Get proxy rate limiter for httpV1/V2 server failed", zap.Error(err))
return nil, err
}
dbID, collectionIDToPartIDs, rt, n, err := proxy.GetRequestInfo(ctx, req)
if err != nil {
return nil, err
}
err = limiter.Check(dbID, collectionIDToPartIDs, rt, n)
nodeID := strconv.FormatInt(paramtable.GetNodeID(), 10)
metrics.ProxyRateLimitReqCount.WithLabelValues(nodeID, rt.String(), metrics.TotalLabel).Inc()
if err != nil {
metrics.ProxyRateLimitReqCount.WithLabelValues(nodeID, rt.String(), metrics.FailLabel).Inc()
return proxy.GetFailedResponse(req, err), err
}
metrics.ProxyRateLimitReqCount.WithLabelValues(nodeID, rt.String(), metrics.SuccessLabel).Inc()
return nil, nil
}

View File

@ -448,19 +448,19 @@ func (s *Server) startInternalGrpc(grpcPort int, errChan chan error) {
// Start start the Proxy Server
func (s *Server) Run() error {
log.Debug("init Proxy server")
log.Info("init Proxy server")
if err := s.init(); err != nil {
log.Warn("init Proxy server failed", zap.Error(err))
return err
}
log.Debug("init Proxy server done")
log.Info("init Proxy server done")
log.Debug("start Proxy server")
log.Info("start Proxy server")
if err := s.start(); err != nil {
log.Warn("start Proxy server failed", zap.Error(err))
return err
}
log.Debug("start Proxy server done")
log.Info("start Proxy server done")
if s.tcpServer != nil {
s.wg.Add(1)
@ -479,23 +479,23 @@ func (s *Server) Run() error {
func (s *Server) init() error {
etcdConfig := &paramtable.Get().EtcdCfg
Params := &paramtable.Get().ProxyGrpcServerCfg
log.Debug("Proxy init service's parameter table done")
log.Info("Proxy init service's parameter table done")
HTTPParams := &paramtable.Get().HTTPCfg
log.Debug("Proxy init http server's parameter table done")
log.Info("Proxy init http server's parameter table done")
if !funcutil.CheckPortAvailable(Params.Port.GetAsInt()) {
paramtable.Get().Save(Params.Port.Key, fmt.Sprintf("%d", funcutil.GetAvailablePort()))
log.Warn("Proxy get available port when init", zap.Int("Port", Params.Port.GetAsInt()))
}
log.Debug("init Proxy's parameter table done",
log.Info("init Proxy's parameter table done",
zap.String("internalAddress", Params.GetInternalAddress()),
zap.String("externalAddress", Params.GetAddress()),
)
accesslog.InitAccessLogger(paramtable.Get())
serviceName := fmt.Sprintf("Proxy ip: %s, port: %d", Params.IP, Params.Port.GetAsInt())
log.Debug("init Proxy's tracer done", zap.String("service name", serviceName))
log.Info("init Proxy's tracer done", zap.String("service name", serviceName))
etcdCli, err := etcd.CreateEtcdClient(
etcdConfig.UseEmbedEtcd.GetAsBool(),
@ -530,7 +530,7 @@ func (s *Server) init() error {
log.Info("Proxy server listen on tcp", zap.Int("port", port))
var lis net.Listener
log.Info("Proxy server already listen on tcp", zap.Int("port", port))
log.Info("Proxy server already listen on tcp", zap.Int("port", httpPort))
lis, err = net.Listen("tcp", ":"+strconv.Itoa(port))
if err != nil {
log.Error("Proxy server(grpc/http) failed to listen on", zap.Int("port", port), zap.Error(err))

View File

@ -18,15 +18,12 @@ package proxy
import (
"context"
"fmt"
"strconv"
"github.com/golang/protobuf/proto"
"go.uber.org/zap"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
@ -39,7 +36,7 @@ import (
// RateLimitInterceptor returns a new unary server interceptors that performs request rate limiting.
func RateLimitInterceptor(limiter types.Limiter) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
dbID, collectionIDToPartIDs, rt, n, err := getRequestInfo(ctx, req)
dbID, collectionIDToPartIDs, rt, n, err := GetRequestInfo(ctx, req)
if err != nil {
log.Warn("failed to get request info", zap.Error(err))
return handler(ctx, req)
@ -50,7 +47,7 @@ func RateLimitInterceptor(limiter types.Limiter) grpc.UnaryServerInterceptor {
metrics.ProxyRateLimitReqCount.WithLabelValues(nodeID, rt.String(), metrics.TotalLabel).Inc()
if err != nil {
metrics.ProxyRateLimitReqCount.WithLabelValues(nodeID, rt.String(), metrics.FailLabel).Inc()
rsp := getFailedResponse(req, err)
rsp := GetFailedResponse(req, err)
if rsp != nil {
return rsp, nil
}
@ -126,127 +123,9 @@ func getCollectionID(r reqCollName) (int64, map[int64][]int64) {
return db.dbID, map[int64][]int64{collectionID: {}}
}
// getRequestInfo returns collection name and rateType of request and return tokens needed.
func getRequestInfo(ctx context.Context, req interface{}) (int64, map[int64][]int64, internalpb.RateType, int, error) {
switch r := req.(type) {
case *milvuspb.InsertRequest:
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
return dbID, collToPartIDs, internalpb.RateType_DMLInsert, proto.Size(r), err
case *milvuspb.UpsertRequest:
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
return dbID, collToPartIDs, internalpb.RateType_DMLInsert, proto.Size(r), err
case *milvuspb.DeleteRequest:
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
return dbID, collToPartIDs, internalpb.RateType_DMLDelete, proto.Size(r), err
case *milvuspb.ImportRequest:
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
return dbID, collToPartIDs, internalpb.RateType_DMLBulkLoad, proto.Size(r), err
case *milvuspb.SearchRequest:
dbID, collToPartIDs, err := getCollectionAndPartitionIDs(ctx, req.(reqPartNames))
return dbID, collToPartIDs, internalpb.RateType_DQLSearch, int(r.GetNq()), err
case *milvuspb.QueryRequest:
dbID, collToPartIDs, err := getCollectionAndPartitionIDs(ctx, req.(reqPartNames))
return dbID, collToPartIDs, internalpb.RateType_DQLQuery, 1, err // think of the query request's nq as 1
case *milvuspb.CreateCollectionRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
case *milvuspb.DropCollectionRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
case *milvuspb.LoadCollectionRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
case *milvuspb.ReleaseCollectionRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
case *milvuspb.CreatePartitionRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
case *milvuspb.DropPartitionRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
case *milvuspb.LoadPartitionsRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
case *milvuspb.ReleasePartitionsRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
case *milvuspb.CreateIndexRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLIndex, 1, nil
case *milvuspb.DropIndexRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLIndex, 1, nil
case *milvuspb.FlushRequest:
db, err := globalMetaCache.GetDatabaseInfo(ctx, r.GetDbName())
if err != nil {
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
}
collToPartIDs := make(map[int64][]int64, 0)
for _, collectionName := range r.GetCollectionNames() {
collectionID, err := globalMetaCache.GetCollectionID(ctx, r.GetDbName(), collectionName)
if err != nil {
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
}
collToPartIDs[collectionID] = []int64{}
}
return db.dbID, collToPartIDs, internalpb.RateType_DDLFlush, 1, nil
case *milvuspb.ManualCompactionRequest:
dbName := GetCurDBNameFromContextOrDefault(ctx)
dbInfo, err := globalMetaCache.GetDatabaseInfo(ctx, dbName)
if err != nil {
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
}
return dbInfo.dbID, map[int64][]int64{
r.GetCollectionID(): {},
}, internalpb.RateType_DDLCompaction, 1, nil
default: // TODO: support more request
if req == nil {
return util.InvalidDBID, map[int64][]int64{}, 0, 0, fmt.Errorf("null request")
}
return util.InvalidDBID, map[int64][]int64{}, 0, 0, nil
}
}
// failedMutationResult returns failed mutation result.
func failedMutationResult(err error) *milvuspb.MutationResult {
return &milvuspb.MutationResult{
Status: merr.Status(err),
}
}
// getFailedResponse returns failed response.
func getFailedResponse(req any, err error) any {
switch req.(type) {
case *milvuspb.InsertRequest, *milvuspb.DeleteRequest, *milvuspb.UpsertRequest:
return failedMutationResult(err)
case *milvuspb.ImportRequest:
return &milvuspb.ImportResponse{
Status: merr.Status(err),
}
case *milvuspb.SearchRequest:
return &milvuspb.SearchResults{
Status: merr.Status(err),
}
case *milvuspb.QueryRequest:
return &milvuspb.QueryResults{
Status: merr.Status(err),
}
case *milvuspb.CreateCollectionRequest, *milvuspb.DropCollectionRequest,
*milvuspb.LoadCollectionRequest, *milvuspb.ReleaseCollectionRequest,
*milvuspb.CreatePartitionRequest, *milvuspb.DropPartitionRequest,
*milvuspb.LoadPartitionsRequest, *milvuspb.ReleasePartitionsRequest,
*milvuspb.CreateIndexRequest, *milvuspb.DropIndexRequest:
return merr.Status(err)
case *milvuspb.FlushRequest:
return &milvuspb.FlushResponse{
Status: merr.Status(err),
}
case *milvuspb.ManualCompactionRequest:
return &milvuspb.ManualCompactionResponse{
Status: merr.Status(err),
}
}
return nil
}

View File

@ -69,7 +69,7 @@ func TestRateLimitInterceptor(t *testing.T) {
createdTimestamp: 1,
}, nil)
globalMetaCache = mockCache
database, col2part, rt, size, err := getRequestInfo(context.Background(), &milvuspb.InsertRequest{
database, col2part, rt, size, err := GetRequestInfo(context.Background(), &milvuspb.InsertRequest{
CollectionName: "foo",
PartitionName: "p1",
DbName: "db1",
@ -85,7 +85,7 @@ func TestRateLimitInterceptor(t *testing.T) {
assert.True(t, len(col2part) == 1)
assert.Equal(t, int64(10), col2part[1][0])
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.UpsertRequest{
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.UpsertRequest{
CollectionName: "foo",
PartitionName: "p1",
DbName: "db1",
@ -101,7 +101,7 @@ func TestRateLimitInterceptor(t *testing.T) {
assert.True(t, len(col2part) == 1)
assert.Equal(t, int64(10), col2part[1][0])
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.DeleteRequest{
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.DeleteRequest{
CollectionName: "foo",
PartitionName: "p1",
DbName: "db1",
@ -117,7 +117,7 @@ func TestRateLimitInterceptor(t *testing.T) {
assert.True(t, len(col2part) == 1)
assert.Equal(t, int64(10), col2part[1][0])
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.ImportRequest{
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.ImportRequest{
CollectionName: "foo",
PartitionName: "p1",
DbName: "db1",
@ -133,7 +133,7 @@ func TestRateLimitInterceptor(t *testing.T) {
assert.True(t, len(col2part) == 1)
assert.Equal(t, int64(10), col2part[1][0])
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.SearchRequest{
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.SearchRequest{
Nq: 5,
PartitionNames: []string{
"p1",
@ -146,7 +146,7 @@ func TestRateLimitInterceptor(t *testing.T) {
assert.Equal(t, 1, len(col2part))
assert.Equal(t, 1, len(col2part[1]))
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.QueryRequest{
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.QueryRequest{
CollectionName: "foo",
PartitionNames: []string{
"p1",
@ -160,7 +160,7 @@ func TestRateLimitInterceptor(t *testing.T) {
assert.Equal(t, 1, len(col2part))
assert.Equal(t, 1, len(col2part[1]))
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.CreateCollectionRequest{})
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.CreateCollectionRequest{})
assert.NoError(t, err)
assert.Equal(t, 1, size)
assert.Equal(t, internalpb.RateType_DDLCollection, rt)
@ -168,7 +168,7 @@ func TestRateLimitInterceptor(t *testing.T) {
assert.Equal(t, 1, len(col2part))
assert.Equal(t, 0, len(col2part[1]))
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.LoadCollectionRequest{})
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.LoadCollectionRequest{})
assert.NoError(t, err)
assert.Equal(t, 1, size)
assert.Equal(t, internalpb.RateType_DDLCollection, rt)
@ -176,7 +176,7 @@ func TestRateLimitInterceptor(t *testing.T) {
assert.Equal(t, 1, len(col2part))
assert.Equal(t, 0, len(col2part[1]))
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.ReleaseCollectionRequest{})
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.ReleaseCollectionRequest{})
assert.NoError(t, err)
assert.Equal(t, 1, size)
assert.Equal(t, internalpb.RateType_DDLCollection, rt)
@ -184,7 +184,7 @@ func TestRateLimitInterceptor(t *testing.T) {
assert.Equal(t, 1, len(col2part))
assert.Equal(t, 0, len(col2part[1]))
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.DropCollectionRequest{})
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.DropCollectionRequest{})
assert.NoError(t, err)
assert.Equal(t, 1, size)
assert.Equal(t, internalpb.RateType_DDLCollection, rt)
@ -192,7 +192,7 @@ func TestRateLimitInterceptor(t *testing.T) {
assert.Equal(t, 1, len(col2part))
assert.Equal(t, 0, len(col2part[1]))
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.CreatePartitionRequest{})
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.CreatePartitionRequest{})
assert.NoError(t, err)
assert.Equal(t, 1, size)
assert.Equal(t, internalpb.RateType_DDLPartition, rt)
@ -200,7 +200,7 @@ func TestRateLimitInterceptor(t *testing.T) {
assert.Equal(t, 1, len(col2part))
assert.Equal(t, 0, len(col2part[1]))
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.LoadPartitionsRequest{})
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.LoadPartitionsRequest{})
assert.NoError(t, err)
assert.Equal(t, 1, size)
assert.Equal(t, internalpb.RateType_DDLPartition, rt)
@ -208,7 +208,7 @@ func TestRateLimitInterceptor(t *testing.T) {
assert.Equal(t, 1, len(col2part))
assert.Equal(t, 0, len(col2part[1]))
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.ReleasePartitionsRequest{})
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.ReleasePartitionsRequest{})
assert.NoError(t, err)
assert.Equal(t, 1, size)
assert.Equal(t, internalpb.RateType_DDLPartition, rt)
@ -216,7 +216,7 @@ func TestRateLimitInterceptor(t *testing.T) {
assert.Equal(t, 1, len(col2part))
assert.Equal(t, 0, len(col2part[1]))
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.DropPartitionRequest{})
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.DropPartitionRequest{})
assert.NoError(t, err)
assert.Equal(t, 1, size)
assert.Equal(t, internalpb.RateType_DDLPartition, rt)
@ -224,7 +224,7 @@ func TestRateLimitInterceptor(t *testing.T) {
assert.Equal(t, 1, len(col2part))
assert.Equal(t, 0, len(col2part[1]))
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.CreateIndexRequest{})
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.CreateIndexRequest{})
assert.NoError(t, err)
assert.Equal(t, 1, size)
assert.Equal(t, internalpb.RateType_DDLIndex, rt)
@ -232,7 +232,7 @@ func TestRateLimitInterceptor(t *testing.T) {
assert.Equal(t, 1, len(col2part))
assert.Equal(t, 0, len(col2part[1]))
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.DropIndexRequest{})
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.DropIndexRequest{})
assert.NoError(t, err)
assert.Equal(t, 1, size)
assert.Equal(t, internalpb.RateType_DDLIndex, rt)
@ -240,7 +240,7 @@ func TestRateLimitInterceptor(t *testing.T) {
assert.Equal(t, 1, len(col2part))
assert.Equal(t, 0, len(col2part[1]))
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.FlushRequest{
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.FlushRequest{
CollectionNames: []string{
"col1",
},
@ -251,22 +251,22 @@ func TestRateLimitInterceptor(t *testing.T) {
assert.Equal(t, database, int64(100))
assert.Equal(t, 1, len(col2part))
database, _, rt, size, err = getRequestInfo(context.Background(), &milvuspb.ManualCompactionRequest{})
database, _, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.ManualCompactionRequest{})
assert.NoError(t, err)
assert.Equal(t, 1, size)
assert.Equal(t, internalpb.RateType_DDLCompaction, rt)
assert.Equal(t, database, int64(100))
_, _, _, _, err = getRequestInfo(context.Background(), nil)
_, _, _, _, err = GetRequestInfo(context.Background(), nil)
assert.Error(t, err)
_, _, _, _, err = getRequestInfo(context.Background(), &milvuspb.CalcDistanceRequest{})
_, _, _, _, err = GetRequestInfo(context.Background(), &milvuspb.CalcDistanceRequest{})
assert.NoError(t, err)
})
t.Run("test getFailedResponse", func(t *testing.T) {
t.Run("test GetFailedResponse", func(t *testing.T) {
testGetFailedResponse := func(req interface{}, rt internalpb.RateType, err error, fullMethod string) {
rsp := getFailedResponse(req, err)
rsp := GetFailedResponse(req, err)
assert.NotNil(t, rsp)
}
@ -280,9 +280,9 @@ func TestRateLimitInterceptor(t *testing.T) {
testGetFailedResponse(&milvuspb.ManualCompactionRequest{}, internalpb.RateType_DDLCompaction, merr.ErrServiceRateLimit, "compaction")
// test illegal
rsp := getFailedResponse(&milvuspb.SearchResults{}, merr.OldCodeToMerr(commonpb.ErrorCode_UnexpectedError))
rsp := GetFailedResponse(&milvuspb.SearchResults{}, merr.OldCodeToMerr(commonpb.ErrorCode_UnexpectedError))
assert.Nil(t, rsp)
rsp = getFailedResponse(nil, merr.OldCodeToMerr(commonpb.ErrorCode_UnexpectedError))
rsp = GetFailedResponse(nil, merr.OldCodeToMerr(commonpb.ErrorCode_UnexpectedError))
assert.Nil(t, rsp)
})
@ -390,13 +390,13 @@ func TestGetInfo(t *testing.T) {
assert.Error(t, err)
}
{
_, _, _, _, err := getRequestInfo(ctx, &milvuspb.FlushRequest{
_, _, _, _, err := GetRequestInfo(ctx, &milvuspb.FlushRequest{
DbName: "foo",
})
assert.Error(t, err)
}
{
_, _, _, _, err := getRequestInfo(ctx, &milvuspb.ManualCompactionRequest{})
_, _, _, _, err := GetRequestInfo(ctx, &milvuspb.ManualCompactionRequest{})
assert.Error(t, err)
}
{
@ -429,7 +429,7 @@ func TestGetInfo(t *testing.T) {
assert.Error(t, err)
}
{
_, _, _, _, err := getRequestInfo(ctx, &milvuspb.FlushRequest{
_, _, _, _, err := GetRequestInfo(ctx, &milvuspb.FlushRequest{
DbName: "foo",
CollectionNames: []string{"coo"},
})

View File

@ -25,6 +25,7 @@ import (
"time"
"github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto"
"go.uber.org/zap"
"golang.org/x/crypto/bcrypt"
"google.golang.org/grpc/metadata"
@ -33,6 +34,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
@ -1687,3 +1689,121 @@ func GetRequestLabelFromContext(ctx context.Context) bool {
}
return v.(bool)
}
// GetRequestInfo returns collection name and rateType of request and return tokens needed.
func GetRequestInfo(ctx context.Context, req interface{}) (int64, map[int64][]int64, internalpb.RateType, int, error) {
switch r := req.(type) {
case *milvuspb.InsertRequest:
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
return dbID, collToPartIDs, internalpb.RateType_DMLInsert, proto.Size(r), err
case *milvuspb.UpsertRequest:
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
return dbID, collToPartIDs, internalpb.RateType_DMLInsert, proto.Size(r), err
case *milvuspb.DeleteRequest:
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
return dbID, collToPartIDs, internalpb.RateType_DMLDelete, proto.Size(r), err
case *milvuspb.ImportRequest:
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
return dbID, collToPartIDs, internalpb.RateType_DMLBulkLoad, proto.Size(r), err
case *milvuspb.SearchRequest:
dbID, collToPartIDs, err := getCollectionAndPartitionIDs(ctx, req.(reqPartNames))
return dbID, collToPartIDs, internalpb.RateType_DQLSearch, int(r.GetNq()), err
case *milvuspb.QueryRequest:
dbID, collToPartIDs, err := getCollectionAndPartitionIDs(ctx, req.(reqPartNames))
return dbID, collToPartIDs, internalpb.RateType_DQLQuery, 1, err // think of the query request's nq as 1
case *milvuspb.CreateCollectionRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
case *milvuspb.DropCollectionRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
case *milvuspb.LoadCollectionRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
case *milvuspb.ReleaseCollectionRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
case *milvuspb.CreatePartitionRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
case *milvuspb.DropPartitionRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
case *milvuspb.LoadPartitionsRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
case *milvuspb.ReleasePartitionsRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
case *milvuspb.CreateIndexRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLIndex, 1, nil
case *milvuspb.DropIndexRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLIndex, 1, nil
case *milvuspb.FlushRequest:
db, err := globalMetaCache.GetDatabaseInfo(ctx, r.GetDbName())
if err != nil {
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
}
collToPartIDs := make(map[int64][]int64, 0)
for _, collectionName := range r.GetCollectionNames() {
collectionID, err := globalMetaCache.GetCollectionID(ctx, r.GetDbName(), collectionName)
if err != nil {
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
}
collToPartIDs[collectionID] = []int64{}
}
return db.dbID, collToPartIDs, internalpb.RateType_DDLFlush, 1, nil
case *milvuspb.ManualCompactionRequest:
dbName := GetCurDBNameFromContextOrDefault(ctx)
dbInfo, err := globalMetaCache.GetDatabaseInfo(ctx, dbName)
if err != nil {
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
}
return dbInfo.dbID, map[int64][]int64{
r.GetCollectionID(): {},
}, internalpb.RateType_DDLCompaction, 1, nil
default: // TODO: support more request
if req == nil {
return util.InvalidDBID, map[int64][]int64{}, 0, 0, fmt.Errorf("null request")
}
return util.InvalidDBID, map[int64][]int64{}, 0, 0, nil
}
}
// GetFailedResponse returns failed response.
func GetFailedResponse(req any, err error) any {
switch req.(type) {
case *milvuspb.InsertRequest, *milvuspb.DeleteRequest, *milvuspb.UpsertRequest:
return failedMutationResult(err)
case *milvuspb.ImportRequest:
return &milvuspb.ImportResponse{
Status: merr.Status(err),
}
case *milvuspb.SearchRequest:
return &milvuspb.SearchResults{
Status: merr.Status(err),
}
case *milvuspb.QueryRequest:
return &milvuspb.QueryResults{
Status: merr.Status(err),
}
case *milvuspb.CreateCollectionRequest, *milvuspb.DropCollectionRequest,
*milvuspb.LoadCollectionRequest, *milvuspb.ReleaseCollectionRequest,
*milvuspb.CreatePartitionRequest, *milvuspb.DropPartitionRequest,
*milvuspb.LoadPartitionsRequest, *milvuspb.ReleasePartitionsRequest,
*milvuspb.CreateIndexRequest, *milvuspb.DropIndexRequest:
return merr.Status(err)
case *milvuspb.FlushRequest:
return &milvuspb.FlushResponse{
Status: merr.Status(err),
}
case *milvuspb.ManualCompactionRequest:
return &milvuspb.ManualCompactionResponse{
Status: merr.Status(err),
}
}
return nil
}

View File

@ -167,6 +167,7 @@ var (
ErrInvalidInsertData = newMilvusError("fail to deal the insert data", 1804, false)
ErrInvalidSearchResult = newMilvusError("fail to parse search result", 1805, false)
ErrCheckPrimaryKey = newMilvusError("please check the primary key and its' type can only in [int, string]", 1806, false)
ErrHTTPRateLimit = newMilvusError("request is rejected by limiter", 1807, true)
// replicate related
ErrDenyReplicateMessage = newMilvusError("deny to use the replicate message in the normal instance", 1900, false)

View File

@ -0,0 +1,142 @@
package httpserver
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/suite"
"go.uber.org/atomic"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/distributed/proxy/httpserver"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/tests/integration"
)
type HTTPServerSuite struct {
integration.MiniClusterSuite
}
const (
PORT = 40001
CollectionName = "collectionName"
Data = "data"
Dimension = "dimension"
IDType = "idType"
PrimaryFieldName = "primaryFieldName"
VectorFieldName = "vectorFieldName"
Schema = "schema"
)
func (s *HTTPServerSuite) SetupSuite() {
paramtable.Init()
paramtable.Get().Save(paramtable.Get().HTTPCfg.Port.Key, fmt.Sprintf("%d", PORT))
paramtable.Get().Save(paramtable.Get().QuotaConfig.DMLLimitEnabled.Key, "true")
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "true")
paramtable.Get().Save(paramtable.Get().QuotaConfig.DMLMaxInsertRate.Key, "1")
paramtable.Get().Save(paramtable.Get().QuotaConfig.DMLMaxInsertRatePerDB.Key, "1")
paramtable.Get().Save(paramtable.Get().QuotaConfig.DMLMaxInsertRatePerCollection.Key, "1")
paramtable.Get().Save(paramtable.Get().QuotaConfig.DMLMaxInsertRatePerPartition.Key, "1")
s.MiniClusterSuite.SetupSuite()
}
func (s *HTTPServerSuite) TearDownSuite() {
paramtable.Get().Reset(paramtable.Get().HTTPCfg.Port.Key)
paramtable.Get().Reset(paramtable.Get().QuotaConfig.DMLLimitEnabled.Key)
paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
paramtable.Get().Reset(paramtable.Get().QuotaConfig.DMLMaxInsertRate.Key)
paramtable.Get().Reset(paramtable.Get().QuotaConfig.DMLMaxInsertRatePerDB.Key)
paramtable.Get().Reset(paramtable.Get().QuotaConfig.DMLMaxInsertRatePerCollection.Key)
paramtable.Get().Reset(paramtable.Get().QuotaConfig.DMLMaxInsertRatePerPartition.Key)
s.MiniClusterSuite.TearDownSuite()
}
func (s *HTTPServerSuite) TestInsertThrottle() {
collectionName := "test_collection"
dim := 768
client := http.Client{}
pkFieldName := "pk"
vecFieldName := "vector"
// create collection
{
dataMap := make(map[string]any, 0)
dataMap[CollectionName] = collectionName
dataMap[Dimension] = dim
dataMap[IDType] = "Int64"
dataMap[PrimaryFieldName] = "pk"
dataMap[VectorFieldName] = "vector"
pkField := httpserver.FieldSchema{FieldName: pkFieldName, DataType: "Int64", IsPrimary: true}
vecParams := map[string]interface{}{"dim": dim}
vecField := httpserver.FieldSchema{FieldName: vecFieldName, DataType: "FloatVector", ElementTypeParams: vecParams}
schema := httpserver.CollectionSchema{Fields: []httpserver.FieldSchema{pkField, vecField}}
dataMap[Schema] = schema
payload, _ := json.Marshal(dataMap)
url := "http://localhost:" + strconv.Itoa(PORT) + "/v2/vectordb" + httpserver.CollectionCategory + httpserver.CreateAction
req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(payload))
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
s.NoError(err)
defer resp.Body.Close()
}
// insert data
{
url := "http://localhost:" + strconv.Itoa(PORT) + "/v2/vectordb" + httpserver.EntityCategory + httpserver.InsertAction
prepareData := func() []byte {
dataMap := make(map[string]any, 0)
dataMap[CollectionName] = collectionName
vectorData := make([]float32, dim)
for i := 0; i < dim; i++ {
vectorData[i] = 1.0
}
count := 500
dataMap[Data] = make([]map[string]interface{}, count)
for i := 0; i < count; i++ {
data := map[string]interface{}{pkFieldName: i, vecFieldName: vectorData}
dataMap[Data].([]map[string]interface{})[i] = data
}
payload, _ := json.Marshal(dataMap)
return payload
}
threadCount := 3
wg := &sync.WaitGroup{}
limitedThreadCount := atomic.Int32{}
for i := 0; i < threadCount; i++ {
wg.Add(1)
go func() {
defer wg.Done()
payload := prepareData()
req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(payload))
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
s.NoError(err)
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
bodyStr := string(body)
if strings.Contains(bodyStr, strconv.Itoa(int(merr.Code(merr.ErrHTTPRateLimit)))) {
s.True(strings.Contains(bodyStr, "request is rejected by limiter"))
limitedThreadCount.Inc()
}
}()
}
wg.Wait()
// it's expected at least one insert request is rejected for throttle
log.Info("limited thread count", zap.Int32("limitedThreadCount", limitedThreadCount.Load()))
s.True(limitedThreadCount.Load() > 0)
}
}
func TestHttpSearch(t *testing.T) {
suite.Run(t, new(HTTPServerSuite))
}

View File

@ -14,7 +14,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package balance
package target
import (
"context"