mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-02 03:48:37 +08:00
related: #35350 pr: https://github.com/milvus-io/milvus/pull/35470 Signed-off-by: MrPresent-Han <chun.han@gmail.com> Co-authored-by: MrPresent-Han <chun.han@gmail.com>
This commit is contained in:
parent
fc344d1eae
commit
cf8494ef45
@ -247,6 +247,10 @@ func checkAuthorizationV2(ctx context.Context, c *gin.Context, ignoreErr bool, r
|
||||
}
|
||||
|
||||
func wrapperProxy(ctx context.Context, c *gin.Context, req any, checkAuth bool, ignoreErr bool, fullMethod string, handler func(reqCtx context.Context, req any) (any, error)) (interface{}, error) {
|
||||
return wrapperProxyWithLimit(ctx, c, req, checkAuth, ignoreErr, fullMethod, false, nil, handler)
|
||||
}
|
||||
|
||||
func wrapperProxyWithLimit(ctx context.Context, c *gin.Context, req any, checkAuth bool, ignoreErr bool, fullMethod string, checkLimit bool, pxy types.ProxyComponent, handler func(reqCtx context.Context, req any) (any, error)) (interface{}, error) {
|
||||
if baseGetter, ok := req.(BaseGetter); ok {
|
||||
span := trace.SpanFromContext(ctx)
|
||||
span.AddEvent(baseGetter.GetBase().GetMsgType().String())
|
||||
@ -257,6 +261,17 @@ func wrapperProxy(ctx context.Context, c *gin.Context, req any, checkAuth bool,
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if checkLimit {
|
||||
_, err := CheckLimiter(ctx, req, pxy)
|
||||
if err != nil {
|
||||
log.Warn("high level restful api, fail to check limiter", zap.Error(err), zap.String("method", fullMethod))
|
||||
HTTPAbortReturn(c, http.StatusOK, gin.H{
|
||||
HTTPReturnCode: merr.Code(merr.ErrHTTPRateLimit),
|
||||
HTTPReturnMessage: merr.ErrHTTPRateLimit.Error() + ", error: " + err.Error(),
|
||||
})
|
||||
return nil, RestRequestInterceptorErr
|
||||
}
|
||||
}
|
||||
log.Ctx(ctx).Debug("high level restful api, try to do a grpc call", zap.Any("grpcRequest", req))
|
||||
username, ok := c.Get(ContextUsername)
|
||||
if !ok {
|
||||
@ -506,7 +521,7 @@ func (h *HandlersV2) dropCollection(ctx context.Context, c *gin.Context, anyReq
|
||||
CollectionName: getter.GetCollectionName(),
|
||||
}
|
||||
c.Set(ContextRequest, req)
|
||||
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropCollection", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropCollection", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.DropCollection(reqCtx, req.(*milvuspb.DropCollectionRequest))
|
||||
})
|
||||
if err == nil {
|
||||
@ -527,7 +542,7 @@ func (h *HandlersV2) renameCollection(ctx context.Context, c *gin.Context, anyRe
|
||||
if req.NewDBName == "" {
|
||||
req.NewDBName = dbName
|
||||
}
|
||||
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/RenameCollection", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/RenameCollection", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.RenameCollection(reqCtx, req.(*milvuspb.RenameCollectionRequest))
|
||||
})
|
||||
if err == nil {
|
||||
@ -543,7 +558,7 @@ func (h *HandlersV2) loadCollection(ctx context.Context, c *gin.Context, anyReq
|
||||
CollectionName: getter.GetCollectionName(),
|
||||
}
|
||||
c.Set(ContextRequest, req)
|
||||
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/LoadCollection", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/LoadCollection", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.LoadCollection(reqCtx, req.(*milvuspb.LoadCollectionRequest))
|
||||
})
|
||||
if err == nil {
|
||||
@ -559,7 +574,7 @@ func (h *HandlersV2) releaseCollection(ctx context.Context, c *gin.Context, anyR
|
||||
CollectionName: getter.GetCollectionName(),
|
||||
}
|
||||
c.Set(ContextRequest, req)
|
||||
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/ReleaseCollection", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/ReleaseCollection", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.ReleaseCollection(reqCtx, req.(*milvuspb.ReleaseCollectionRequest))
|
||||
})
|
||||
if err == nil {
|
||||
@ -591,7 +606,7 @@ func (h *HandlersV2) query(ctx context.Context, c *gin.Context, anyReq any, dbNa
|
||||
if httpReq.Limit > 0 && !matchCountRule(httpReq.OutputFields) {
|
||||
req.QueryParams = append(req.QueryParams, &commonpb.KeyValuePair{Key: ParamLimit, Value: strconv.FormatInt(int64(httpReq.Limit), 10)})
|
||||
}
|
||||
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Query", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Query", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.Query(reqCtx, req.(*milvuspb.QueryRequest))
|
||||
})
|
||||
if err == nil {
|
||||
@ -639,7 +654,7 @@ func (h *HandlersV2) get(ctx context.Context, c *gin.Context, anyReq any, dbName
|
||||
UseDefaultConsistency: true,
|
||||
}
|
||||
c.Set(ContextRequest, req)
|
||||
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Query", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Query", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.Query(reqCtx, req.(*milvuspb.QueryRequest))
|
||||
})
|
||||
if err == nil {
|
||||
@ -688,7 +703,7 @@ func (h *HandlersV2) delete(ctx context.Context, c *gin.Context, anyReq any, dbN
|
||||
}
|
||||
req.Expr = filter
|
||||
}
|
||||
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Delete", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Delete", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.Delete(reqCtx, req.(*milvuspb.DeleteRequest))
|
||||
})
|
||||
if err == nil {
|
||||
@ -734,7 +749,7 @@ func (h *HandlersV2) insert(ctx context.Context, c *gin.Context, anyReq any, dbN
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Insert", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Insert", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.Insert(reqCtx, req.(*milvuspb.InsertRequest))
|
||||
})
|
||||
if err == nil {
|
||||
@ -812,7 +827,7 @@ func (h *HandlersV2) upsert(ctx context.Context, c *gin.Context, anyReq any, dbN
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Upsert", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Upsert", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.Upsert(reqCtx, req.(*milvuspb.UpsertRequest))
|
||||
})
|
||||
if err == nil {
|
||||
@ -957,7 +972,7 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN
|
||||
}
|
||||
req.SearchParams = searchParams
|
||||
req.PlaceholderGroup = placeholderGroup
|
||||
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Search", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Search", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.Search(reqCtx, req.(*milvuspb.SearchRequest))
|
||||
})
|
||||
if err == nil {
|
||||
@ -1037,7 +1052,7 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq
|
||||
{Key: ParamLimit, Value: strconv.FormatInt(int64(httpReq.Limit), 10)},
|
||||
{Key: ParamRoundDecimal, Value: "-1"},
|
||||
}
|
||||
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/HybridSearch", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/HybridSearch", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.HybridSearch(reqCtx, req.(*milvuspb.HybridSearchRequest))
|
||||
})
|
||||
if err == nil {
|
||||
@ -1252,7 +1267,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
|
||||
Value: fmt.Sprintf("%v", httpReq.Params["partitionKeyIsolation"]),
|
||||
})
|
||||
}
|
||||
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateCollection", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateCollection", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.CreateCollection(reqCtx, req.(*milvuspb.CreateCollectionRequest))
|
||||
})
|
||||
if err != nil {
|
||||
@ -1269,7 +1284,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
|
||||
IndexName: httpReq.VectorFieldName,
|
||||
ExtraParams: []*commonpb.KeyValuePair{{Key: common.MetricTypeKey, Value: httpReq.MetricType}},
|
||||
}
|
||||
statusResponse, err := wrapperProxy(ctx, c, createIndexReq, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateIndex", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
statusResponse, err := wrapperProxyWithLimit(ctx, c, createIndexReq, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateIndex", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.CreateIndex(ctx, req.(*milvuspb.CreateIndexRequest))
|
||||
})
|
||||
if err != nil {
|
||||
@ -1298,7 +1313,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
|
||||
for key, value := range indexParam.Params {
|
||||
createIndexReq.ExtraParams = append(createIndexReq.ExtraParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)})
|
||||
}
|
||||
statusResponse, err := wrapperProxy(ctx, c, createIndexReq, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateIndex", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
statusResponse, err := wrapperProxyWithLimit(ctx, c, createIndexReq, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateIndex", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.CreateIndex(ctx, req.(*milvuspb.CreateIndexRequest))
|
||||
})
|
||||
if err != nil {
|
||||
@ -1310,7 +1325,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
|
||||
DbName: dbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
}
|
||||
statusResponse, err := wrapperProxy(ctx, c, loadReq, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/LoadCollection", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
statusResponse, err := wrapperProxyWithLimit(ctx, c, loadReq, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/LoadCollection", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.LoadCollection(ctx, req.(*milvuspb.LoadCollectionRequest))
|
||||
})
|
||||
if err == nil {
|
||||
@ -1383,7 +1398,7 @@ func (h *HandlersV2) createPartition(ctx context.Context, c *gin.Context, anyReq
|
||||
PartitionName: partitionGetter.GetPartitionName(),
|
||||
}
|
||||
c.Set(ContextRequest, req)
|
||||
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreatePartition", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreatePartition", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.CreatePartition(reqCtx, req.(*milvuspb.CreatePartitionRequest))
|
||||
})
|
||||
if err == nil {
|
||||
@ -1401,7 +1416,7 @@ func (h *HandlersV2) dropPartition(ctx context.Context, c *gin.Context, anyReq a
|
||||
PartitionName: partitionGetter.GetPartitionName(),
|
||||
}
|
||||
c.Set(ContextRequest, req)
|
||||
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropPartition", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropPartition", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.DropPartition(reqCtx, req.(*milvuspb.DropPartitionRequest))
|
||||
})
|
||||
if err == nil {
|
||||
@ -1418,7 +1433,7 @@ func (h *HandlersV2) loadPartitions(ctx context.Context, c *gin.Context, anyReq
|
||||
PartitionNames: httpReq.PartitionNames,
|
||||
}
|
||||
c.Set(ContextRequest, req)
|
||||
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/LoadPartitions", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/LoadPartitions", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.LoadPartitions(reqCtx, req.(*milvuspb.LoadPartitionsRequest))
|
||||
})
|
||||
if err == nil {
|
||||
@ -1435,7 +1450,7 @@ func (h *HandlersV2) releasePartitions(ctx context.Context, c *gin.Context, anyR
|
||||
PartitionNames: httpReq.PartitionNames,
|
||||
}
|
||||
c.Set(ContextRequest, req)
|
||||
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/ReleasePartitions", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/ReleasePartitions", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.ReleasePartitions(reqCtx, req.(*milvuspb.ReleasePartitionsRequest))
|
||||
})
|
||||
if err == nil {
|
||||
@ -1743,7 +1758,7 @@ func (h *HandlersV2) createIndex(ctx context.Context, c *gin.Context, anyReq any
|
||||
for key, value := range indexParam.Params {
|
||||
req.ExtraParams = append(req.ExtraParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)})
|
||||
}
|
||||
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateIndex", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateIndex", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.CreateIndex(reqCtx, req.(*milvuspb.CreateIndexRequest))
|
||||
})
|
||||
if err != nil {
|
||||
@ -1764,7 +1779,7 @@ func (h *HandlersV2) dropIndex(ctx context.Context, c *gin.Context, anyReq any,
|
||||
}
|
||||
c.Set(ContextRequest, req)
|
||||
|
||||
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropIndex", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropIndex", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.DropIndex(reqCtx, req.(*milvuspb.DropIndexRequest))
|
||||
})
|
||||
if err == nil {
|
||||
@ -1822,7 +1837,7 @@ func (h *HandlersV2) createAlias(ctx context.Context, c *gin.Context, anyReq any
|
||||
}
|
||||
c.Set(ContextRequest, req)
|
||||
|
||||
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateAlias", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateAlias", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.CreateAlias(reqCtx, req.(*milvuspb.CreateAliasRequest))
|
||||
})
|
||||
if err == nil {
|
||||
@ -1839,7 +1854,7 @@ func (h *HandlersV2) dropAlias(ctx context.Context, c *gin.Context, anyReq any,
|
||||
}
|
||||
c.Set(ContextRequest, req)
|
||||
|
||||
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropAlias", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropAlias", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.DropAlias(reqCtx, req.(*milvuspb.DropAliasRequest))
|
||||
})
|
||||
if err == nil {
|
||||
@ -1858,7 +1873,7 @@ func (h *HandlersV2) alterAlias(ctx context.Context, c *gin.Context, anyReq any,
|
||||
}
|
||||
c.Set(ContextRequest, req)
|
||||
|
||||
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/AlterAlias", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/AlterAlias", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.AlterAlias(reqCtx, req.(*milvuspb.AlterAliasRequest))
|
||||
})
|
||||
if err == nil {
|
||||
|
@ -469,6 +469,11 @@ func TestDatabaseWrapper(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCreateCollection(t *testing.T) {
|
||||
paramtable.Init()
|
||||
// disable rate limit
|
||||
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
|
||||
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
|
||||
|
||||
postTestCases := []requestBodyTestCase{}
|
||||
mp := mocks.NewMockProxy(t)
|
||||
mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(13)
|
||||
@ -974,6 +979,9 @@ var commonErrorStatus = &commonpb.Status{
|
||||
|
||||
func TestMethodDelete(t *testing.T) {
|
||||
paramtable.Init()
|
||||
// disable rate limit
|
||||
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
|
||||
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
|
||||
mp := mocks.NewMockProxy(t)
|
||||
mp.EXPECT().DropCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
|
||||
mp.EXPECT().DropPartition(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
|
||||
@ -1023,6 +1031,9 @@ func TestMethodDelete(t *testing.T) {
|
||||
|
||||
func TestMethodPost(t *testing.T) {
|
||||
paramtable.Init()
|
||||
// disable rate limit
|
||||
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
|
||||
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
|
||||
mp := mocks.NewMockProxy(t)
|
||||
mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
|
||||
mp.EXPECT().RenameCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
|
||||
@ -1161,6 +1172,9 @@ func TestMethodPost(t *testing.T) {
|
||||
|
||||
func TestDML(t *testing.T) {
|
||||
paramtable.Init()
|
||||
// disable rate limit
|
||||
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
|
||||
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
|
||||
mp := mocks.NewMockProxy(t)
|
||||
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
|
||||
CollectionName: DefaultCollectionName,
|
||||
@ -1280,6 +1294,9 @@ func TestDML(t *testing.T) {
|
||||
|
||||
func TestAllowInt64(t *testing.T) {
|
||||
paramtable.Init()
|
||||
// disable rate limit
|
||||
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
|
||||
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
|
||||
mp := mocks.NewMockProxy(t)
|
||||
testEngine := initHTTPServerV2(mp, false)
|
||||
queryTestCases := []requestBodyTestCase{}
|
||||
@ -1322,6 +1339,9 @@ func TestAllowInt64(t *testing.T) {
|
||||
|
||||
func TestSearchV2(t *testing.T) {
|
||||
paramtable.Init()
|
||||
// disable rate limit
|
||||
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
|
||||
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
|
||||
outputFields := []string{FieldBookID, FieldWordCount, "author", "date"}
|
||||
mp := mocks.NewMockProxy(t)
|
||||
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
|
||||
|
@ -2,6 +2,7 @@ package httpserver
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
@ -19,12 +20,16 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/json"
|
||||
"github.com/milvus-io/milvus/internal/proxy"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/util"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/parameterutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
@ -1256,3 +1261,29 @@ func formatInt64(intArray []int64) []string {
|
||||
}
|
||||
return stringArray
|
||||
}
|
||||
|
||||
func CheckLimiter(ctx context.Context, req interface{}, pxy types.ProxyComponent) (any, error) {
|
||||
if !paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.GetAsBool() {
|
||||
return nil, nil
|
||||
}
|
||||
// apply limiter for http/http2 server
|
||||
limiter, err := pxy.GetRateLimiter()
|
||||
if err != nil {
|
||||
log.Error("Get proxy rate limiter for httpV1/V2 server failed", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dbID, collectionIDToPartIDs, rt, n, err := proxy.GetRequestInfo(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = limiter.Check(dbID, collectionIDToPartIDs, rt, n)
|
||||
nodeID := strconv.FormatInt(paramtable.GetNodeID(), 10)
|
||||
metrics.ProxyRateLimitReqCount.WithLabelValues(nodeID, rt.String(), metrics.TotalLabel).Inc()
|
||||
if err != nil {
|
||||
metrics.ProxyRateLimitReqCount.WithLabelValues(nodeID, rt.String(), metrics.FailLabel).Inc()
|
||||
return proxy.GetFailedResponse(req, err), err
|
||||
}
|
||||
metrics.ProxyRateLimitReqCount.WithLabelValues(nodeID, rt.String(), metrics.SuccessLabel).Inc()
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -448,19 +448,19 @@ func (s *Server) startInternalGrpc(grpcPort int, errChan chan error) {
|
||||
|
||||
// Start start the Proxy Server
|
||||
func (s *Server) Run() error {
|
||||
log.Debug("init Proxy server")
|
||||
log.Info("init Proxy server")
|
||||
if err := s.init(); err != nil {
|
||||
log.Warn("init Proxy server failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
log.Debug("init Proxy server done")
|
||||
log.Info("init Proxy server done")
|
||||
|
||||
log.Debug("start Proxy server")
|
||||
log.Info("start Proxy server")
|
||||
if err := s.start(); err != nil {
|
||||
log.Warn("start Proxy server failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
log.Debug("start Proxy server done")
|
||||
log.Info("start Proxy server done")
|
||||
|
||||
if s.tcpServer != nil {
|
||||
s.wg.Add(1)
|
||||
@ -479,23 +479,23 @@ func (s *Server) Run() error {
|
||||
func (s *Server) init() error {
|
||||
etcdConfig := ¶mtable.Get().EtcdCfg
|
||||
Params := ¶mtable.Get().ProxyGrpcServerCfg
|
||||
log.Debug("Proxy init service's parameter table done")
|
||||
log.Info("Proxy init service's parameter table done")
|
||||
HTTPParams := ¶mtable.Get().HTTPCfg
|
||||
log.Debug("Proxy init http server's parameter table done")
|
||||
log.Info("Proxy init http server's parameter table done")
|
||||
|
||||
if !funcutil.CheckPortAvailable(Params.Port.GetAsInt()) {
|
||||
paramtable.Get().Save(Params.Port.Key, fmt.Sprintf("%d", funcutil.GetAvailablePort()))
|
||||
log.Warn("Proxy get available port when init", zap.Int("Port", Params.Port.GetAsInt()))
|
||||
}
|
||||
|
||||
log.Debug("init Proxy's parameter table done",
|
||||
log.Info("init Proxy's parameter table done",
|
||||
zap.String("internalAddress", Params.GetInternalAddress()),
|
||||
zap.String("externalAddress", Params.GetAddress()),
|
||||
)
|
||||
|
||||
accesslog.InitAccessLogger(paramtable.Get())
|
||||
serviceName := fmt.Sprintf("Proxy ip: %s, port: %d", Params.IP, Params.Port.GetAsInt())
|
||||
log.Debug("init Proxy's tracer done", zap.String("service name", serviceName))
|
||||
log.Info("init Proxy's tracer done", zap.String("service name", serviceName))
|
||||
|
||||
etcdCli, err := etcd.CreateEtcdClient(
|
||||
etcdConfig.UseEmbedEtcd.GetAsBool(),
|
||||
@ -530,7 +530,7 @@ func (s *Server) init() error {
|
||||
log.Info("Proxy server listen on tcp", zap.Int("port", port))
|
||||
var lis net.Listener
|
||||
|
||||
log.Info("Proxy server already listen on tcp", zap.Int("port", port))
|
||||
log.Info("Proxy server already listen on tcp", zap.Int("port", httpPort))
|
||||
lis, err = net.Listen("tcp", ":"+strconv.Itoa(port))
|
||||
if err != nil {
|
||||
log.Error("Proxy server(grpc/http) failed to listen on", zap.Int("port", port), zap.Error(err))
|
||||
|
@ -18,15 +18,12 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
@ -39,7 +36,7 @@ import (
|
||||
// RateLimitInterceptor returns a new unary server interceptors that performs request rate limiting.
|
||||
func RateLimitInterceptor(limiter types.Limiter) grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
dbID, collectionIDToPartIDs, rt, n, err := getRequestInfo(ctx, req)
|
||||
dbID, collectionIDToPartIDs, rt, n, err := GetRequestInfo(ctx, req)
|
||||
if err != nil {
|
||||
log.Warn("failed to get request info", zap.Error(err))
|
||||
return handler(ctx, req)
|
||||
@ -50,7 +47,7 @@ func RateLimitInterceptor(limiter types.Limiter) grpc.UnaryServerInterceptor {
|
||||
metrics.ProxyRateLimitReqCount.WithLabelValues(nodeID, rt.String(), metrics.TotalLabel).Inc()
|
||||
if err != nil {
|
||||
metrics.ProxyRateLimitReqCount.WithLabelValues(nodeID, rt.String(), metrics.FailLabel).Inc()
|
||||
rsp := getFailedResponse(req, err)
|
||||
rsp := GetFailedResponse(req, err)
|
||||
if rsp != nil {
|
||||
return rsp, nil
|
||||
}
|
||||
@ -126,127 +123,9 @@ func getCollectionID(r reqCollName) (int64, map[int64][]int64) {
|
||||
return db.dbID, map[int64][]int64{collectionID: {}}
|
||||
}
|
||||
|
||||
// getRequestInfo returns collection name and rateType of request and return tokens needed.
|
||||
func getRequestInfo(ctx context.Context, req interface{}) (int64, map[int64][]int64, internalpb.RateType, int, error) {
|
||||
switch r := req.(type) {
|
||||
case *milvuspb.InsertRequest:
|
||||
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DMLInsert, proto.Size(r), err
|
||||
case *milvuspb.UpsertRequest:
|
||||
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DMLInsert, proto.Size(r), err
|
||||
case *milvuspb.DeleteRequest:
|
||||
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DMLDelete, proto.Size(r), err
|
||||
case *milvuspb.ImportRequest:
|
||||
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DMLBulkLoad, proto.Size(r), err
|
||||
case *milvuspb.SearchRequest:
|
||||
dbID, collToPartIDs, err := getCollectionAndPartitionIDs(ctx, req.(reqPartNames))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DQLSearch, int(r.GetNq()), err
|
||||
case *milvuspb.QueryRequest:
|
||||
dbID, collToPartIDs, err := getCollectionAndPartitionIDs(ctx, req.(reqPartNames))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DQLQuery, 1, err // think of the query request's nq as 1
|
||||
case *milvuspb.CreateCollectionRequest:
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
|
||||
case *milvuspb.DropCollectionRequest:
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
|
||||
case *milvuspb.LoadCollectionRequest:
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
|
||||
case *milvuspb.ReleaseCollectionRequest:
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
|
||||
case *milvuspb.CreatePartitionRequest:
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
|
||||
case *milvuspb.DropPartitionRequest:
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
|
||||
case *milvuspb.LoadPartitionsRequest:
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
|
||||
case *milvuspb.ReleasePartitionsRequest:
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
|
||||
case *milvuspb.CreateIndexRequest:
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLIndex, 1, nil
|
||||
case *milvuspb.DropIndexRequest:
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLIndex, 1, nil
|
||||
case *milvuspb.FlushRequest:
|
||||
db, err := globalMetaCache.GetDatabaseInfo(ctx, r.GetDbName())
|
||||
if err != nil {
|
||||
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
|
||||
}
|
||||
|
||||
collToPartIDs := make(map[int64][]int64, 0)
|
||||
for _, collectionName := range r.GetCollectionNames() {
|
||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, r.GetDbName(), collectionName)
|
||||
if err != nil {
|
||||
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
|
||||
}
|
||||
collToPartIDs[collectionID] = []int64{}
|
||||
}
|
||||
return db.dbID, collToPartIDs, internalpb.RateType_DDLFlush, 1, nil
|
||||
case *milvuspb.ManualCompactionRequest:
|
||||
dbName := GetCurDBNameFromContextOrDefault(ctx)
|
||||
dbInfo, err := globalMetaCache.GetDatabaseInfo(ctx, dbName)
|
||||
if err != nil {
|
||||
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
|
||||
}
|
||||
return dbInfo.dbID, map[int64][]int64{
|
||||
r.GetCollectionID(): {},
|
||||
}, internalpb.RateType_DDLCompaction, 1, nil
|
||||
default: // TODO: support more request
|
||||
if req == nil {
|
||||
return util.InvalidDBID, map[int64][]int64{}, 0, 0, fmt.Errorf("null request")
|
||||
}
|
||||
return util.InvalidDBID, map[int64][]int64{}, 0, 0, nil
|
||||
}
|
||||
}
|
||||
|
||||
// failedMutationResult returns failed mutation result.
|
||||
func failedMutationResult(err error) *milvuspb.MutationResult {
|
||||
return &milvuspb.MutationResult{
|
||||
Status: merr.Status(err),
|
||||
}
|
||||
}
|
||||
|
||||
// getFailedResponse returns failed response.
|
||||
func getFailedResponse(req any, err error) any {
|
||||
switch req.(type) {
|
||||
case *milvuspb.InsertRequest, *milvuspb.DeleteRequest, *milvuspb.UpsertRequest:
|
||||
return failedMutationResult(err)
|
||||
case *milvuspb.ImportRequest:
|
||||
return &milvuspb.ImportResponse{
|
||||
Status: merr.Status(err),
|
||||
}
|
||||
case *milvuspb.SearchRequest:
|
||||
return &milvuspb.SearchResults{
|
||||
Status: merr.Status(err),
|
||||
}
|
||||
case *milvuspb.QueryRequest:
|
||||
return &milvuspb.QueryResults{
|
||||
Status: merr.Status(err),
|
||||
}
|
||||
case *milvuspb.CreateCollectionRequest, *milvuspb.DropCollectionRequest,
|
||||
*milvuspb.LoadCollectionRequest, *milvuspb.ReleaseCollectionRequest,
|
||||
*milvuspb.CreatePartitionRequest, *milvuspb.DropPartitionRequest,
|
||||
*milvuspb.LoadPartitionsRequest, *milvuspb.ReleasePartitionsRequest,
|
||||
*milvuspb.CreateIndexRequest, *milvuspb.DropIndexRequest:
|
||||
return merr.Status(err)
|
||||
case *milvuspb.FlushRequest:
|
||||
return &milvuspb.FlushResponse{
|
||||
Status: merr.Status(err),
|
||||
}
|
||||
case *milvuspb.ManualCompactionRequest:
|
||||
return &milvuspb.ManualCompactionResponse{
|
||||
Status: merr.Status(err),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -69,7 +69,7 @@ func TestRateLimitInterceptor(t *testing.T) {
|
||||
createdTimestamp: 1,
|
||||
}, nil)
|
||||
globalMetaCache = mockCache
|
||||
database, col2part, rt, size, err := getRequestInfo(context.Background(), &milvuspb.InsertRequest{
|
||||
database, col2part, rt, size, err := GetRequestInfo(context.Background(), &milvuspb.InsertRequest{
|
||||
CollectionName: "foo",
|
||||
PartitionName: "p1",
|
||||
DbName: "db1",
|
||||
@ -85,7 +85,7 @@ func TestRateLimitInterceptor(t *testing.T) {
|
||||
assert.True(t, len(col2part) == 1)
|
||||
assert.Equal(t, int64(10), col2part[1][0])
|
||||
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.UpsertRequest{
|
||||
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.UpsertRequest{
|
||||
CollectionName: "foo",
|
||||
PartitionName: "p1",
|
||||
DbName: "db1",
|
||||
@ -101,7 +101,7 @@ func TestRateLimitInterceptor(t *testing.T) {
|
||||
assert.True(t, len(col2part) == 1)
|
||||
assert.Equal(t, int64(10), col2part[1][0])
|
||||
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.DeleteRequest{
|
||||
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.DeleteRequest{
|
||||
CollectionName: "foo",
|
||||
PartitionName: "p1",
|
||||
DbName: "db1",
|
||||
@ -117,7 +117,7 @@ func TestRateLimitInterceptor(t *testing.T) {
|
||||
assert.True(t, len(col2part) == 1)
|
||||
assert.Equal(t, int64(10), col2part[1][0])
|
||||
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.ImportRequest{
|
||||
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.ImportRequest{
|
||||
CollectionName: "foo",
|
||||
PartitionName: "p1",
|
||||
DbName: "db1",
|
||||
@ -133,7 +133,7 @@ func TestRateLimitInterceptor(t *testing.T) {
|
||||
assert.True(t, len(col2part) == 1)
|
||||
assert.Equal(t, int64(10), col2part[1][0])
|
||||
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.SearchRequest{
|
||||
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.SearchRequest{
|
||||
Nq: 5,
|
||||
PartitionNames: []string{
|
||||
"p1",
|
||||
@ -146,7 +146,7 @@ func TestRateLimitInterceptor(t *testing.T) {
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
assert.Equal(t, 1, len(col2part[1]))
|
||||
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.QueryRequest{
|
||||
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.QueryRequest{
|
||||
CollectionName: "foo",
|
||||
PartitionNames: []string{
|
||||
"p1",
|
||||
@ -160,7 +160,7 @@ func TestRateLimitInterceptor(t *testing.T) {
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
assert.Equal(t, 1, len(col2part[1]))
|
||||
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.CreateCollectionRequest{})
|
||||
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.CreateCollectionRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, size)
|
||||
assert.Equal(t, internalpb.RateType_DDLCollection, rt)
|
||||
@ -168,7 +168,7 @@ func TestRateLimitInterceptor(t *testing.T) {
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
assert.Equal(t, 0, len(col2part[1]))
|
||||
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.LoadCollectionRequest{})
|
||||
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.LoadCollectionRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, size)
|
||||
assert.Equal(t, internalpb.RateType_DDLCollection, rt)
|
||||
@ -176,7 +176,7 @@ func TestRateLimitInterceptor(t *testing.T) {
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
assert.Equal(t, 0, len(col2part[1]))
|
||||
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.ReleaseCollectionRequest{})
|
||||
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.ReleaseCollectionRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, size)
|
||||
assert.Equal(t, internalpb.RateType_DDLCollection, rt)
|
||||
@ -184,7 +184,7 @@ func TestRateLimitInterceptor(t *testing.T) {
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
assert.Equal(t, 0, len(col2part[1]))
|
||||
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.DropCollectionRequest{})
|
||||
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.DropCollectionRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, size)
|
||||
assert.Equal(t, internalpb.RateType_DDLCollection, rt)
|
||||
@ -192,7 +192,7 @@ func TestRateLimitInterceptor(t *testing.T) {
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
assert.Equal(t, 0, len(col2part[1]))
|
||||
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.CreatePartitionRequest{})
|
||||
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.CreatePartitionRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, size)
|
||||
assert.Equal(t, internalpb.RateType_DDLPartition, rt)
|
||||
@ -200,7 +200,7 @@ func TestRateLimitInterceptor(t *testing.T) {
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
assert.Equal(t, 0, len(col2part[1]))
|
||||
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.LoadPartitionsRequest{})
|
||||
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.LoadPartitionsRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, size)
|
||||
assert.Equal(t, internalpb.RateType_DDLPartition, rt)
|
||||
@ -208,7 +208,7 @@ func TestRateLimitInterceptor(t *testing.T) {
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
assert.Equal(t, 0, len(col2part[1]))
|
||||
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.ReleasePartitionsRequest{})
|
||||
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.ReleasePartitionsRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, size)
|
||||
assert.Equal(t, internalpb.RateType_DDLPartition, rt)
|
||||
@ -216,7 +216,7 @@ func TestRateLimitInterceptor(t *testing.T) {
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
assert.Equal(t, 0, len(col2part[1]))
|
||||
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.DropPartitionRequest{})
|
||||
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.DropPartitionRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, size)
|
||||
assert.Equal(t, internalpb.RateType_DDLPartition, rt)
|
||||
@ -224,7 +224,7 @@ func TestRateLimitInterceptor(t *testing.T) {
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
assert.Equal(t, 0, len(col2part[1]))
|
||||
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.CreateIndexRequest{})
|
||||
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.CreateIndexRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, size)
|
||||
assert.Equal(t, internalpb.RateType_DDLIndex, rt)
|
||||
@ -232,7 +232,7 @@ func TestRateLimitInterceptor(t *testing.T) {
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
assert.Equal(t, 0, len(col2part[1]))
|
||||
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.DropIndexRequest{})
|
||||
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.DropIndexRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, size)
|
||||
assert.Equal(t, internalpb.RateType_DDLIndex, rt)
|
||||
@ -240,7 +240,7 @@ func TestRateLimitInterceptor(t *testing.T) {
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
assert.Equal(t, 0, len(col2part[1]))
|
||||
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.FlushRequest{
|
||||
database, col2part, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.FlushRequest{
|
||||
CollectionNames: []string{
|
||||
"col1",
|
||||
},
|
||||
@ -251,22 +251,22 @@ func TestRateLimitInterceptor(t *testing.T) {
|
||||
assert.Equal(t, database, int64(100))
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
|
||||
database, _, rt, size, err = getRequestInfo(context.Background(), &milvuspb.ManualCompactionRequest{})
|
||||
database, _, rt, size, err = GetRequestInfo(context.Background(), &milvuspb.ManualCompactionRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, size)
|
||||
assert.Equal(t, internalpb.RateType_DDLCompaction, rt)
|
||||
assert.Equal(t, database, int64(100))
|
||||
|
||||
_, _, _, _, err = getRequestInfo(context.Background(), nil)
|
||||
_, _, _, _, err = GetRequestInfo(context.Background(), nil)
|
||||
assert.Error(t, err)
|
||||
|
||||
_, _, _, _, err = getRequestInfo(context.Background(), &milvuspb.CalcDistanceRequest{})
|
||||
_, _, _, _, err = GetRequestInfo(context.Background(), &milvuspb.CalcDistanceRequest{})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("test getFailedResponse", func(t *testing.T) {
|
||||
t.Run("test GetFailedResponse", func(t *testing.T) {
|
||||
testGetFailedResponse := func(req interface{}, rt internalpb.RateType, err error, fullMethod string) {
|
||||
rsp := getFailedResponse(req, err)
|
||||
rsp := GetFailedResponse(req, err)
|
||||
assert.NotNil(t, rsp)
|
||||
}
|
||||
|
||||
@ -280,9 +280,9 @@ func TestRateLimitInterceptor(t *testing.T) {
|
||||
testGetFailedResponse(&milvuspb.ManualCompactionRequest{}, internalpb.RateType_DDLCompaction, merr.ErrServiceRateLimit, "compaction")
|
||||
|
||||
// test illegal
|
||||
rsp := getFailedResponse(&milvuspb.SearchResults{}, merr.OldCodeToMerr(commonpb.ErrorCode_UnexpectedError))
|
||||
rsp := GetFailedResponse(&milvuspb.SearchResults{}, merr.OldCodeToMerr(commonpb.ErrorCode_UnexpectedError))
|
||||
assert.Nil(t, rsp)
|
||||
rsp = getFailedResponse(nil, merr.OldCodeToMerr(commonpb.ErrorCode_UnexpectedError))
|
||||
rsp = GetFailedResponse(nil, merr.OldCodeToMerr(commonpb.ErrorCode_UnexpectedError))
|
||||
assert.Nil(t, rsp)
|
||||
})
|
||||
|
||||
@ -390,13 +390,13 @@ func TestGetInfo(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
{
|
||||
_, _, _, _, err := getRequestInfo(ctx, &milvuspb.FlushRequest{
|
||||
_, _, _, _, err := GetRequestInfo(ctx, &milvuspb.FlushRequest{
|
||||
DbName: "foo",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
{
|
||||
_, _, _, _, err := getRequestInfo(ctx, &milvuspb.ManualCompactionRequest{})
|
||||
_, _, _, _, err := GetRequestInfo(ctx, &milvuspb.ManualCompactionRequest{})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
{
|
||||
@ -429,7 +429,7 @@ func TestGetInfo(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
{
|
||||
_, _, _, _, err := getRequestInfo(ctx, &milvuspb.FlushRequest{
|
||||
_, _, _, _, err := GetRequestInfo(ctx, &milvuspb.FlushRequest{
|
||||
DbName: "foo",
|
||||
CollectionNames: []string{"coo"},
|
||||
})
|
||||
|
@ -25,6 +25,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"google.golang.org/grpc/metadata"
|
||||
@ -33,6 +34,7 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/parser/planparserv2"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/planpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
@ -1687,3 +1689,121 @@ func GetRequestLabelFromContext(ctx context.Context) bool {
|
||||
}
|
||||
return v.(bool)
|
||||
}
|
||||
|
||||
// GetRequestInfo returns collection name and rateType of request and return tokens needed.
|
||||
func GetRequestInfo(ctx context.Context, req interface{}) (int64, map[int64][]int64, internalpb.RateType, int, error) {
|
||||
switch r := req.(type) {
|
||||
case *milvuspb.InsertRequest:
|
||||
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DMLInsert, proto.Size(r), err
|
||||
case *milvuspb.UpsertRequest:
|
||||
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DMLInsert, proto.Size(r), err
|
||||
case *milvuspb.DeleteRequest:
|
||||
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DMLDelete, proto.Size(r), err
|
||||
case *milvuspb.ImportRequest:
|
||||
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DMLBulkLoad, proto.Size(r), err
|
||||
case *milvuspb.SearchRequest:
|
||||
dbID, collToPartIDs, err := getCollectionAndPartitionIDs(ctx, req.(reqPartNames))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DQLSearch, int(r.GetNq()), err
|
||||
case *milvuspb.QueryRequest:
|
||||
dbID, collToPartIDs, err := getCollectionAndPartitionIDs(ctx, req.(reqPartNames))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DQLQuery, 1, err // think of the query request's nq as 1
|
||||
case *milvuspb.CreateCollectionRequest:
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
|
||||
case *milvuspb.DropCollectionRequest:
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
|
||||
case *milvuspb.LoadCollectionRequest:
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
|
||||
case *milvuspb.ReleaseCollectionRequest:
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
|
||||
case *milvuspb.CreatePartitionRequest:
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
|
||||
case *milvuspb.DropPartitionRequest:
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
|
||||
case *milvuspb.LoadPartitionsRequest:
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
|
||||
case *milvuspb.ReleasePartitionsRequest:
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
|
||||
case *milvuspb.CreateIndexRequest:
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLIndex, 1, nil
|
||||
case *milvuspb.DropIndexRequest:
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLIndex, 1, nil
|
||||
case *milvuspb.FlushRequest:
|
||||
db, err := globalMetaCache.GetDatabaseInfo(ctx, r.GetDbName())
|
||||
if err != nil {
|
||||
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
|
||||
}
|
||||
|
||||
collToPartIDs := make(map[int64][]int64, 0)
|
||||
for _, collectionName := range r.GetCollectionNames() {
|
||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, r.GetDbName(), collectionName)
|
||||
if err != nil {
|
||||
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
|
||||
}
|
||||
collToPartIDs[collectionID] = []int64{}
|
||||
}
|
||||
return db.dbID, collToPartIDs, internalpb.RateType_DDLFlush, 1, nil
|
||||
case *milvuspb.ManualCompactionRequest:
|
||||
dbName := GetCurDBNameFromContextOrDefault(ctx)
|
||||
dbInfo, err := globalMetaCache.GetDatabaseInfo(ctx, dbName)
|
||||
if err != nil {
|
||||
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
|
||||
}
|
||||
return dbInfo.dbID, map[int64][]int64{
|
||||
r.GetCollectionID(): {},
|
||||
}, internalpb.RateType_DDLCompaction, 1, nil
|
||||
default: // TODO: support more request
|
||||
if req == nil {
|
||||
return util.InvalidDBID, map[int64][]int64{}, 0, 0, fmt.Errorf("null request")
|
||||
}
|
||||
return util.InvalidDBID, map[int64][]int64{}, 0, 0, nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetFailedResponse returns failed response.
|
||||
func GetFailedResponse(req any, err error) any {
|
||||
switch req.(type) {
|
||||
case *milvuspb.InsertRequest, *milvuspb.DeleteRequest, *milvuspb.UpsertRequest:
|
||||
return failedMutationResult(err)
|
||||
case *milvuspb.ImportRequest:
|
||||
return &milvuspb.ImportResponse{
|
||||
Status: merr.Status(err),
|
||||
}
|
||||
case *milvuspb.SearchRequest:
|
||||
return &milvuspb.SearchResults{
|
||||
Status: merr.Status(err),
|
||||
}
|
||||
case *milvuspb.QueryRequest:
|
||||
return &milvuspb.QueryResults{
|
||||
Status: merr.Status(err),
|
||||
}
|
||||
case *milvuspb.CreateCollectionRequest, *milvuspb.DropCollectionRequest,
|
||||
*milvuspb.LoadCollectionRequest, *milvuspb.ReleaseCollectionRequest,
|
||||
*milvuspb.CreatePartitionRequest, *milvuspb.DropPartitionRequest,
|
||||
*milvuspb.LoadPartitionsRequest, *milvuspb.ReleasePartitionsRequest,
|
||||
*milvuspb.CreateIndexRequest, *milvuspb.DropIndexRequest:
|
||||
return merr.Status(err)
|
||||
case *milvuspb.FlushRequest:
|
||||
return &milvuspb.FlushResponse{
|
||||
Status: merr.Status(err),
|
||||
}
|
||||
case *milvuspb.ManualCompactionRequest:
|
||||
return &milvuspb.ManualCompactionResponse{
|
||||
Status: merr.Status(err),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -167,6 +167,7 @@ var (
|
||||
ErrInvalidInsertData = newMilvusError("fail to deal the insert data", 1804, false)
|
||||
ErrInvalidSearchResult = newMilvusError("fail to parse search result", 1805, false)
|
||||
ErrCheckPrimaryKey = newMilvusError("please check the primary key and its' type can only in [int, string]", 1806, false)
|
||||
ErrHTTPRateLimit = newMilvusError("request is rejected by limiter", 1807, true)
|
||||
|
||||
// replicate related
|
||||
ErrDenyReplicateMessage = newMilvusError("deny to use the replicate message in the normal instance", 1900, false)
|
||||
|
142
tests/integration/httpserver/httpserver_test.go
Normal file
142
tests/integration/httpserver/httpserver_test.go
Normal file
@ -0,0 +1,142 @@
|
||||
package httpserver
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/distributed/proxy/httpserver"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/tests/integration"
|
||||
)
|
||||
|
||||
type HTTPServerSuite struct {
|
||||
integration.MiniClusterSuite
|
||||
}
|
||||
|
||||
const (
|
||||
PORT = 40001
|
||||
CollectionName = "collectionName"
|
||||
Data = "data"
|
||||
Dimension = "dimension"
|
||||
IDType = "idType"
|
||||
PrimaryFieldName = "primaryFieldName"
|
||||
VectorFieldName = "vectorFieldName"
|
||||
Schema = "schema"
|
||||
)
|
||||
|
||||
func (s *HTTPServerSuite) SetupSuite() {
|
||||
paramtable.Init()
|
||||
paramtable.Get().Save(paramtable.Get().HTTPCfg.Port.Key, fmt.Sprintf("%d", PORT))
|
||||
paramtable.Get().Save(paramtable.Get().QuotaConfig.DMLLimitEnabled.Key, "true")
|
||||
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "true")
|
||||
paramtable.Get().Save(paramtable.Get().QuotaConfig.DMLMaxInsertRate.Key, "1")
|
||||
paramtable.Get().Save(paramtable.Get().QuotaConfig.DMLMaxInsertRatePerDB.Key, "1")
|
||||
paramtable.Get().Save(paramtable.Get().QuotaConfig.DMLMaxInsertRatePerCollection.Key, "1")
|
||||
paramtable.Get().Save(paramtable.Get().QuotaConfig.DMLMaxInsertRatePerPartition.Key, "1")
|
||||
s.MiniClusterSuite.SetupSuite()
|
||||
}
|
||||
|
||||
func (s *HTTPServerSuite) TearDownSuite() {
|
||||
paramtable.Get().Reset(paramtable.Get().HTTPCfg.Port.Key)
|
||||
paramtable.Get().Reset(paramtable.Get().QuotaConfig.DMLLimitEnabled.Key)
|
||||
paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
|
||||
paramtable.Get().Reset(paramtable.Get().QuotaConfig.DMLMaxInsertRate.Key)
|
||||
paramtable.Get().Reset(paramtable.Get().QuotaConfig.DMLMaxInsertRatePerDB.Key)
|
||||
paramtable.Get().Reset(paramtable.Get().QuotaConfig.DMLMaxInsertRatePerCollection.Key)
|
||||
paramtable.Get().Reset(paramtable.Get().QuotaConfig.DMLMaxInsertRatePerPartition.Key)
|
||||
s.MiniClusterSuite.TearDownSuite()
|
||||
}
|
||||
|
||||
func (s *HTTPServerSuite) TestInsertThrottle() {
|
||||
collectionName := "test_collection"
|
||||
dim := 768
|
||||
client := http.Client{}
|
||||
pkFieldName := "pk"
|
||||
vecFieldName := "vector"
|
||||
// create collection
|
||||
{
|
||||
dataMap := make(map[string]any, 0)
|
||||
dataMap[CollectionName] = collectionName
|
||||
dataMap[Dimension] = dim
|
||||
dataMap[IDType] = "Int64"
|
||||
dataMap[PrimaryFieldName] = "pk"
|
||||
dataMap[VectorFieldName] = "vector"
|
||||
|
||||
pkField := httpserver.FieldSchema{FieldName: pkFieldName, DataType: "Int64", IsPrimary: true}
|
||||
vecParams := map[string]interface{}{"dim": dim}
|
||||
vecField := httpserver.FieldSchema{FieldName: vecFieldName, DataType: "FloatVector", ElementTypeParams: vecParams}
|
||||
schema := httpserver.CollectionSchema{Fields: []httpserver.FieldSchema{pkField, vecField}}
|
||||
dataMap[Schema] = schema
|
||||
payload, _ := json.Marshal(dataMap)
|
||||
url := "http://localhost:" + strconv.Itoa(PORT) + "/v2/vectordb" + httpserver.CollectionCategory + httpserver.CreateAction
|
||||
req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(payload))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := client.Do(req)
|
||||
s.NoError(err)
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
// insert data
|
||||
{
|
||||
url := "http://localhost:" + strconv.Itoa(PORT) + "/v2/vectordb" + httpserver.EntityCategory + httpserver.InsertAction
|
||||
prepareData := func() []byte {
|
||||
dataMap := make(map[string]any, 0)
|
||||
dataMap[CollectionName] = collectionName
|
||||
vectorData := make([]float32, dim)
|
||||
for i := 0; i < dim; i++ {
|
||||
vectorData[i] = 1.0
|
||||
}
|
||||
count := 500
|
||||
dataMap[Data] = make([]map[string]interface{}, count)
|
||||
for i := 0; i < count; i++ {
|
||||
data := map[string]interface{}{pkFieldName: i, vecFieldName: vectorData}
|
||||
dataMap[Data].([]map[string]interface{})[i] = data
|
||||
}
|
||||
payload, _ := json.Marshal(dataMap)
|
||||
return payload
|
||||
}
|
||||
|
||||
threadCount := 3
|
||||
wg := &sync.WaitGroup{}
|
||||
limitedThreadCount := atomic.Int32{}
|
||||
for i := 0; i < threadCount; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
payload := prepareData()
|
||||
req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(payload))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := client.Do(req)
|
||||
s.NoError(err)
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
bodyStr := string(body)
|
||||
if strings.Contains(bodyStr, strconv.Itoa(int(merr.Code(merr.ErrHTTPRateLimit)))) {
|
||||
s.True(strings.Contains(bodyStr, "request is rejected by limiter"))
|
||||
limitedThreadCount.Inc()
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
// it's expected at least one insert request is rejected for throttle
|
||||
log.Info("limited thread count", zap.Int32("limitedThreadCount", limitedThreadCount.Load()))
|
||||
s.True(limitedThreadCount.Load() > 0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHttpSearch(t *testing.T) {
|
||||
suite.Run(t, new(HTTPServerSuite))
|
||||
}
|
@ -14,7 +14,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package balance
|
||||
package target
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
Loading…
Reference in New Issue
Block a user