fix: [cherry-pick] [restful v2] count(*) & hook (#34433)

issue: #31224 #34374
pr: #34369

for query api:

1. param filter is not requried
2. param limit is useless while outputFields = [count(*)]

add hook about grpc call

---------

Signed-off-by: PowderLi <min.li@zilliz.com>
This commit is contained in:
PowderLi 2024-07-05 09:52:10 +08:00 committed by GitHub
parent 190a2ca7b8
commit ba2c232331
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 136 additions and 108 deletions

View File

@ -7,6 +7,7 @@ import (
"io"
"net/http"
"strconv"
"strings"
"github.com/cockroachdb/errors"
"github.com/gin-gonic/gin"
@ -245,7 +246,7 @@ func checkAuthorizationV2(ctx context.Context, c *gin.Context, ignoreErr bool, r
return nil
}
func wrapperProxy(ctx context.Context, c *gin.Context, req any, checkAuth bool, ignoreErr bool, handler func(reqCtx context.Context, req any) (any, error)) (interface{}, error) {
func wrapperProxy(ctx context.Context, c *gin.Context, req any, checkAuth bool, ignoreErr bool, fullMethod string, handler func(reqCtx context.Context, req any) (any, error)) (interface{}, error) {
if baseGetter, ok := req.(BaseGetter); ok {
span := trace.SpanFromContext(ctx)
span.AddEvent(baseGetter.GetBase().GetMsgType().String())
@ -257,7 +258,11 @@ func wrapperProxy(ctx context.Context, c *gin.Context, req any, checkAuth bool,
}
}
log.Ctx(ctx).Debug("high level restful api, try to do a grpc call", zap.Any("grpcRequest", req))
response, err := handler(ctx, req)
username, ok := c.Get(ContextUsername)
if !ok {
username = ""
}
response, err := proxy.HookInterceptor(ctx, req, username.(string), fullMethod, handler)
if err == nil {
status, ok := requestutil.GetStatusFromResponse(response)
if ok {
@ -278,7 +283,7 @@ func (h *HandlersV2) wrapperCheckDatabase(v2 handlerFuncV2) handlerFuncV2 {
if dbName == DefaultDbName || proxy.CheckDatabase(ctx, dbName) {
return v2(ctx, c, req, dbName)
}
resp, err := wrapperProxy(ctx, c, req, false, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, false, false, "/milvus.proto.milvus.MilvusService/ListDatabases", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.ListDatabases(reqCtx, &milvuspb.ListDatabasesRequest{})
})
if err != nil {
@ -308,7 +313,7 @@ func (h *HandlersV2) hasCollection(ctx context.Context, c *gin.Context, anyReq a
DbName: dbName,
CollectionName: collectionName,
}
resp, err := wrapperProxy(ctx, c, req, false, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, false, false, "/milvus.proto.milvus.MilvusService/HasCollection", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.HasCollection(reqCtx, req.(*milvuspb.HasCollectionRequest))
})
if err != nil {
@ -325,7 +330,7 @@ func (h *HandlersV2) listCollections(ctx context.Context, c *gin.Context, anyReq
DbName: dbName,
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, false, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, false, false, "/milvus.proto.milvus.MilvusService/ShowCollections", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.ShowCollections(reqCtx, req.(*milvuspb.ShowCollectionsRequest))
})
if err == nil {
@ -342,7 +347,7 @@ func (h *HandlersV2) getCollectionDetails(ctx context.Context, c *gin.Context, a
CollectionName: collectionName,
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (any, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DescribeCollection", func(reqCtx context.Context, req any) (any, error) {
return h.proxy.DescribeCollection(reqCtx, req.(*milvuspb.DescribeCollectionRequest))
})
if err != nil {
@ -361,7 +366,7 @@ func (h *HandlersV2) getCollectionDetails(ctx context.Context, c *gin.Context, a
DbName: dbName,
CollectionName: collectionName,
}
stateResp, err := wrapperProxy(ctx, c, loadStateReq, h.checkAuth, true, func(reqCtx context.Context, req any) (any, error) {
stateResp, err := wrapperProxy(ctx, c, loadStateReq, h.checkAuth, true, "/milvus.proto.milvus.MilvusService/GetLoadState", func(reqCtx context.Context, req any) (any, error) {
return h.proxy.GetLoadState(reqCtx, req.(*milvuspb.GetLoadStateRequest))
})
collLoadState := ""
@ -383,7 +388,7 @@ func (h *HandlersV2) getCollectionDetails(ctx context.Context, c *gin.Context, a
CollectionName: collectionName,
FieldName: vectorField,
}
indexResp, err := wrapperProxy(ctx, c, descIndexReq, h.checkAuth, true, func(reqCtx context.Context, req any) (any, error) {
indexResp, err := wrapperProxy(ctx, c, descIndexReq, h.checkAuth, true, "/milvus.proto.milvus.MilvusService/DescribeIndex", func(reqCtx context.Context, req any) (any, error) {
return h.proxy.DescribeIndex(reqCtx, req.(*milvuspb.DescribeIndexRequest))
})
if err == nil {
@ -396,7 +401,7 @@ func (h *HandlersV2) getCollectionDetails(ctx context.Context, c *gin.Context, a
DbName: dbName,
CollectionName: collectionName,
}
aliasResp, err := wrapperProxy(ctx, c, aliasReq, h.checkAuth, true, func(reqCtx context.Context, req any) (interface{}, error) {
aliasResp, err := wrapperProxy(ctx, c, aliasReq, h.checkAuth, true, "/milvus.proto.milvus.MilvusService/ListAliases", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.ListAliases(reqCtx, req.(*milvuspb.ListAliasesRequest))
})
if err == nil {
@ -435,7 +440,7 @@ func (h *HandlersV2) getCollectionStats(ctx context.Context, c *gin.Context, any
CollectionName: collectionGetter.GetCollectionName(),
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (any, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/GetCollectionStatistics", func(reqCtx context.Context, req any) (any, error) {
return h.proxy.GetCollectionStatistics(reqCtx, req.(*milvuspb.GetCollectionStatisticsRequest))
})
if err == nil {
@ -451,7 +456,7 @@ func (h *HandlersV2) getCollectionLoadState(ctx context.Context, c *gin.Context,
CollectionName: collectionGetter.GetCollectionName(),
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (any, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/GetLoadState", func(reqCtx context.Context, req any) (any, error) {
return h.proxy.GetLoadState(reqCtx, req.(*milvuspb.GetLoadStateRequest))
})
if err != nil {
@ -473,7 +478,7 @@ func (h *HandlersV2) getCollectionLoadState(ctx context.Context, c *gin.Context,
PartitionNames: partitionsGetter.GetPartitionNames(),
DbName: dbName,
}
progressResp, err := wrapperProxy(ctx, c, progressReq, h.checkAuth, true, func(reqCtx context.Context, req any) (any, error) {
progressResp, err := wrapperProxy(ctx, c, progressReq, h.checkAuth, true, "/milvus.proto.milvus.MilvusService/GetLoadingProgress", func(reqCtx context.Context, req any) (any, error) {
return h.proxy.GetLoadingProgress(reqCtx, req.(*milvuspb.GetLoadingProgressRequest))
})
progress := int64(-1)
@ -501,7 +506,7 @@ func (h *HandlersV2) dropCollection(ctx context.Context, c *gin.Context, anyReq
CollectionName: getter.GetCollectionName(),
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropCollection", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.DropCollection(reqCtx, req.(*milvuspb.DropCollectionRequest))
})
if err == nil {
@ -522,7 +527,7 @@ func (h *HandlersV2) renameCollection(ctx context.Context, c *gin.Context, anyRe
if req.NewDBName == "" {
req.NewDBName = dbName
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/RenameCollection", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.RenameCollection(reqCtx, req.(*milvuspb.RenameCollectionRequest))
})
if err == nil {
@ -538,7 +543,7 @@ func (h *HandlersV2) loadCollection(ctx context.Context, c *gin.Context, anyReq
CollectionName: getter.GetCollectionName(),
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/LoadCollection", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.LoadCollection(reqCtx, req.(*milvuspb.LoadCollectionRequest))
})
if err == nil {
@ -554,7 +559,7 @@ func (h *HandlersV2) releaseCollection(ctx context.Context, c *gin.Context, anyR
CollectionName: getter.GetCollectionName(),
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/ReleaseCollection", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.ReleaseCollection(reqCtx, req.(*milvuspb.ReleaseCollectionRequest))
})
if err == nil {
@ -563,6 +568,11 @@ func (h *HandlersV2) releaseCollection(ctx context.Context, c *gin.Context, anyR
return resp, err
}
// copy from internal/proxy/task_query.go
func matchCountRule(outputs []string) bool {
return len(outputs) == 1 && strings.ToLower(strings.TrimSpace(outputs[0])) == "count(*)"
}
func (h *HandlersV2) query(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
httpReq := anyReq.(*QueryReqV2)
req := &milvuspb.QueryRequest{
@ -578,10 +588,10 @@ func (h *HandlersV2) query(ctx context.Context, c *gin.Context, anyReq any, dbNa
if httpReq.Offset > 0 {
req.QueryParams = append(req.QueryParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)})
}
if httpReq.Limit > 0 {
if httpReq.Limit > 0 && !matchCountRule(httpReq.OutputFields) {
req.QueryParams = append(req.QueryParams, &commonpb.KeyValuePair{Key: ParamLimit, Value: strconv.FormatInt(int64(httpReq.Limit), 10)})
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Query", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.Query(reqCtx, req.(*milvuspb.QueryRequest))
})
if err == nil {
@ -629,7 +639,7 @@ func (h *HandlersV2) get(ctx context.Context, c *gin.Context, anyReq any, dbName
UseDefaultConsistency: true,
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Query", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.Query(reqCtx, req.(*milvuspb.QueryRequest))
})
if err == nil {
@ -678,7 +688,7 @@ func (h *HandlersV2) delete(ctx context.Context, c *gin.Context, anyReq any, dbN
}
req.Expr = filter
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Delete", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.Delete(reqCtx, req.(*milvuspb.DeleteRequest))
})
if err == nil {
@ -724,7 +734,7 @@ func (h *HandlersV2) insert(ctx context.Context, c *gin.Context, anyReq any, dbN
})
return nil, err
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Insert", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.Insert(reqCtx, req.(*milvuspb.InsertRequest))
})
if err == nil {
@ -802,7 +812,7 @@ func (h *HandlersV2) upsert(ctx context.Context, c *gin.Context, anyReq any, dbN
})
return nil, err
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Upsert", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.Upsert(reqCtx, req.(*milvuspb.UpsertRequest))
})
if err == nil {
@ -947,7 +957,7 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN
}
req.SearchParams = searchParams
req.PlaceholderGroup = placeholderGroup
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Search", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.Search(reqCtx, req.(*milvuspb.SearchRequest))
})
if err == nil {
@ -1027,7 +1037,7 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq
{Key: ParamLimit, Value: strconv.FormatInt(int64(httpReq.Limit), 10)},
{Key: ParamRoundDecimal, Value: "-1"},
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/HybridSearch", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.HybridSearch(reqCtx, req.(*milvuspb.HybridSearchRequest))
})
if err == nil {
@ -1236,7 +1246,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
Value: fmt.Sprintf("%v", httpReq.Params["ttlSeconds"]),
})
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateCollection", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.CreateCollection(reqCtx, req.(*milvuspb.CreateCollectionRequest))
})
if err != nil {
@ -1253,7 +1263,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
IndexName: httpReq.VectorFieldName,
ExtraParams: []*commonpb.KeyValuePair{{Key: common.MetricTypeKey, Value: httpReq.MetricType}},
}
statusResponse, err := wrapperProxy(ctx, c, createIndexReq, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
statusResponse, err := wrapperProxy(ctx, c, createIndexReq, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateIndex", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.CreateIndex(ctx, req.(*milvuspb.CreateIndexRequest))
})
if err != nil {
@ -1282,7 +1292,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
for key, value := range indexParam.Params {
createIndexReq.ExtraParams = append(createIndexReq.ExtraParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)})
}
statusResponse, err := wrapperProxy(ctx, c, createIndexReq, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
statusResponse, err := wrapperProxy(ctx, c, createIndexReq, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateIndex", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.CreateIndex(ctx, req.(*milvuspb.CreateIndexRequest))
})
if err != nil {
@ -1294,7 +1304,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
DbName: dbName,
CollectionName: httpReq.CollectionName,
}
statusResponse, err := wrapperProxy(ctx, c, loadReq, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
statusResponse, err := wrapperProxy(ctx, c, loadReq, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/LoadCollection", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.LoadCollection(ctx, req.(*milvuspb.LoadCollectionRequest))
})
if err == nil {
@ -1311,7 +1321,7 @@ func (h *HandlersV2) listPartitions(ctx context.Context, c *gin.Context, anyReq
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/ShowPartitions", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.ShowPartitions(reqCtx, req.(*milvuspb.ShowPartitionsRequest))
})
if err == nil {
@ -1329,7 +1339,7 @@ func (h *HandlersV2) hasPartitions(ctx context.Context, c *gin.Context, anyReq a
PartitionName: partitionGetter.GetPartitionName(),
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/HasPartition", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.HasPartition(reqCtx, req.(*milvuspb.HasPartitionRequest))
})
if err == nil {
@ -1349,7 +1359,7 @@ func (h *HandlersV2) statsPartition(ctx context.Context, c *gin.Context, anyReq
PartitionName: partitionGetter.GetPartitionName(),
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/GetPartitionStatistics", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.GetPartitionStatistics(reqCtx, req.(*milvuspb.GetPartitionStatisticsRequest))
})
if err == nil {
@ -1367,7 +1377,7 @@ func (h *HandlersV2) createPartition(ctx context.Context, c *gin.Context, anyReq
PartitionName: partitionGetter.GetPartitionName(),
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreatePartition", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.CreatePartition(reqCtx, req.(*milvuspb.CreatePartitionRequest))
})
if err == nil {
@ -1385,7 +1395,7 @@ func (h *HandlersV2) dropPartition(ctx context.Context, c *gin.Context, anyReq a
PartitionName: partitionGetter.GetPartitionName(),
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropPartition", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.DropPartition(reqCtx, req.(*milvuspb.DropPartitionRequest))
})
if err == nil {
@ -1402,7 +1412,7 @@ func (h *HandlersV2) loadPartitions(ctx context.Context, c *gin.Context, anyReq
PartitionNames: httpReq.PartitionNames,
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/LoadPartitions", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.LoadPartitions(reqCtx, req.(*milvuspb.LoadPartitionsRequest))
})
if err == nil {
@ -1419,7 +1429,7 @@ func (h *HandlersV2) releasePartitions(ctx context.Context, c *gin.Context, anyR
PartitionNames: httpReq.PartitionNames,
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/ReleasePartitions", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.ReleasePartitions(reqCtx, req.(*milvuspb.ReleasePartitionsRequest))
})
if err == nil {
@ -1431,7 +1441,7 @@ func (h *HandlersV2) releasePartitions(ctx context.Context, c *gin.Context, anyR
func (h *HandlersV2) listUsers(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
req := &milvuspb.ListCredUsersRequest{}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/ListCredUsers", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.ListCredUsers(reqCtx, req.(*milvuspb.ListCredUsersRequest))
})
if err == nil {
@ -1451,7 +1461,7 @@ func (h *HandlersV2) describeUser(ctx context.Context, c *gin.Context, anyReq an
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/SelectUser", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.SelectUser(reqCtx, req.(*milvuspb.SelectUserRequest))
})
if err == nil {
@ -1474,7 +1484,7 @@ func (h *HandlersV2) createUser(ctx context.Context, c *gin.Context, anyReq any,
Username: httpReq.UserName,
Password: crypto.Base64Encode(httpReq.Password),
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateCredential", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.CreateCredential(reqCtx, req.(*milvuspb.CreateCredentialRequest))
})
if err == nil {
@ -1490,7 +1500,7 @@ func (h *HandlersV2) updateUser(ctx context.Context, c *gin.Context, anyReq any,
OldPassword: crypto.Base64Encode(httpReq.Password),
NewPassword: crypto.Base64Encode(httpReq.NewPassword),
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/UpdateCredential", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.UpdateCredential(reqCtx, req.(*milvuspb.UpdateCredentialRequest))
})
if err == nil {
@ -1504,7 +1514,7 @@ func (h *HandlersV2) dropUser(ctx context.Context, c *gin.Context, anyReq any, d
req := &milvuspb.DeleteCredentialRequest{
Username: getter.GetUserName(),
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DeleteCredential", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.DeleteCredential(reqCtx, req.(*milvuspb.DeleteCredentialRequest))
})
if err == nil {
@ -1519,7 +1529,7 @@ func (h *HandlersV2) operateRoleToUser(ctx context.Context, c *gin.Context, user
RoleName: roleName,
Type: operateType,
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/OperateUserRole", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.OperateUserRole(reqCtx, req.(*milvuspb.OperateUserRoleRequest))
})
if err == nil {
@ -1538,7 +1548,7 @@ func (h *HandlersV2) removeRoleFromUser(ctx context.Context, c *gin.Context, any
func (h *HandlersV2) listRoles(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
req := &milvuspb.SelectRoleRequest{}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/SelectRole", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.SelectRole(reqCtx, req.(*milvuspb.SelectRoleRequest))
})
if err == nil {
@ -1556,7 +1566,7 @@ func (h *HandlersV2) describeRole(ctx context.Context, c *gin.Context, anyReq an
req := &milvuspb.SelectGrantRequest{
Entity: &milvuspb.GrantEntity{Role: &milvuspb.RoleEntity{Name: getter.GetRoleName()}, DbName: dbName},
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/SelectGrant", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.SelectGrant(reqCtx, req.(*milvuspb.SelectGrantRequest))
})
if err == nil {
@ -1581,7 +1591,7 @@ func (h *HandlersV2) createRole(ctx context.Context, c *gin.Context, anyReq any,
req := &milvuspb.CreateRoleRequest{
Entity: &milvuspb.RoleEntity{Name: getter.GetRoleName()},
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateRole", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.CreateRole(reqCtx, req.(*milvuspb.CreateRoleRequest))
})
if err == nil {
@ -1595,7 +1605,7 @@ func (h *HandlersV2) dropRole(ctx context.Context, c *gin.Context, anyReq any, d
req := &milvuspb.DropRoleRequest{
RoleName: getter.GetRoleName(),
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropRole", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.DropRole(reqCtx, req.(*milvuspb.DropRoleRequest))
})
if err == nil {
@ -1617,7 +1627,7 @@ func (h *HandlersV2) operatePrivilegeToRole(ctx context.Context, c *gin.Context,
},
Type: operateType,
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/OperatePrivilege", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.OperatePrivilege(reqCtx, req.(*milvuspb.OperatePrivilegeRequest))
})
if err == nil {
@ -1643,7 +1653,7 @@ func (h *HandlersV2) listIndexes(ctx context.Context, c *gin.Context, anyReq any
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (any, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DescribeIndex", func(reqCtx context.Context, req any) (any, error) {
resp, err := h.proxy.DescribeIndex(reqCtx, req.(*milvuspb.DescribeIndexRequest))
if errors.Is(err, merr.ErrIndexNotFound) {
return &milvuspb.DescribeIndexResponse{
@ -1677,7 +1687,7 @@ func (h *HandlersV2) describeIndex(ctx context.Context, c *gin.Context, anyReq a
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DescribeIndex", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.DescribeIndex(reqCtx, req.(*milvuspb.DescribeIndexRequest))
})
if err == nil {
@ -1727,7 +1737,7 @@ func (h *HandlersV2) createIndex(ctx context.Context, c *gin.Context, anyReq any
for key, value := range indexParam.Params {
req.ExtraParams = append(req.ExtraParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)})
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateIndex", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.CreateIndex(reqCtx, req.(*milvuspb.CreateIndexRequest))
})
if err != nil {
@ -1748,7 +1758,7 @@ func (h *HandlersV2) dropIndex(ctx context.Context, c *gin.Context, anyReq any,
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropIndex", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.DropIndex(reqCtx, req.(*milvuspb.DropIndexRequest))
})
if err == nil {
@ -1765,7 +1775,7 @@ func (h *HandlersV2) listAlias(ctx context.Context, c *gin.Context, anyReq any,
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/ListAliases", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.ListAliases(reqCtx, req.(*milvuspb.ListAliasesRequest))
})
if err == nil {
@ -1782,7 +1792,7 @@ func (h *HandlersV2) describeAlias(ctx context.Context, c *gin.Context, anyReq a
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DescribeAlias", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.DescribeAlias(reqCtx, req.(*milvuspb.DescribeAliasRequest))
})
if err == nil {
@ -1806,7 +1816,7 @@ func (h *HandlersV2) createAlias(ctx context.Context, c *gin.Context, anyReq any
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateAlias", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.CreateAlias(reqCtx, req.(*milvuspb.CreateAliasRequest))
})
if err == nil {
@ -1823,7 +1833,7 @@ func (h *HandlersV2) dropAlias(ctx context.Context, c *gin.Context, anyReq any,
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropAlias", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.DropAlias(reqCtx, req.(*milvuspb.DropAliasRequest))
})
if err == nil {
@ -1842,7 +1852,7 @@ func (h *HandlersV2) alterAlias(ctx context.Context, c *gin.Context, anyReq any,
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/AlterAlias", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.AlterAlias(reqCtx, req.(*milvuspb.AlterAliasRequest))
})
if err == nil {
@ -1871,7 +1881,7 @@ func (h *HandlersV2) listImportJob(ctx context.Context, c *gin.Context, anyReq a
return nil, err
}
}
resp, err := wrapperProxy(ctx, c, req, false, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, false, false, "/milvus.proto.milvus.MilvusService/ListImports", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.ListImports(reqCtx, req.(*internalpb.ListImportsRequest))
})
if err == nil {
@ -1924,7 +1934,7 @@ func (h *HandlersV2) createImportJob(ctx context.Context, c *gin.Context, anyReq
return nil, err
}
}
resp, err := wrapperProxy(ctx, c, req, false, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, false, false, "/milvus.proto.milvus.MilvusService/ImportV2", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.ImportV2(reqCtx, req.(*internalpb.ImportRequest))
})
if err == nil {
@ -1951,7 +1961,7 @@ func (h *HandlersV2) getImportJobProcess(ctx context.Context, c *gin.Context, an
return nil, err
}
}
resp, err := wrapperProxy(ctx, c, req, false, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, false, false, "/milvus.proto.milvus.MilvusService/GetImportProgress", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.GetImportProgress(reqCtx, req.(*internalpb.GetImportProgressRequest))
})
if err == nil {
@ -2002,7 +2012,7 @@ func (h *HandlersV2) GetCollectionSchema(ctx context.Context, c *gin.Context, db
DbName: dbName,
CollectionName: collectionName,
}
descResp, err := wrapperProxy(ctx, c, descReq, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
descResp, err := wrapperProxy(ctx, c, descReq, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DescribeCollection", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.DescribeCollection(reqCtx, req.(*milvuspb.DescribeCollectionRequest))
})
if err != nil {

View File

@ -114,7 +114,7 @@ func TestHTTPWrapper(t *testing.T) {
})
path = "/wrapper/post/trace/call"
app.POST(path, wrapperPost(func() any { return &DefaultReq{} }, wrapperTraceLog(func(ctx context.Context, c *gin.Context, req any, dbName string) (interface{}, error) {
return wrapperProxy(ctx, c, req, false, false, func(reqctx context.Context, req any) (any, error) {
return wrapperProxy(ctx, c, req, false, false, "", func(reqctx context.Context, req any) (any, error) {
return nil, nil
})
})))
@ -174,12 +174,12 @@ func TestGrpcWrapper(t *testing.T) {
}
app.GET(path, func(c *gin.Context) {
ctx := proxy.NewContextWithMetadata(c, "", DefaultDbName)
wrapperProxy(ctx, c, &DefaultReq{}, false, false, handle)
wrapperProxy(ctx, c, &DefaultReq{}, false, false, "", handle)
})
appNeedAuth.GET(path, func(c *gin.Context) {
username, _ := c.Get(ContextUsername)
ctx := proxy.NewContextWithMetadata(c, username.(string), DefaultDbName)
wrapperProxy(ctx, c, &milvuspb.DescribeCollectionRequest{}, true, false, handle)
wrapperProxy(ctx, c, &milvuspb.DescribeCollectionRequest{}, true, false, "", handle)
})
getTestCases = append(getTestCases, rawTestCase{
path: path,
@ -193,12 +193,12 @@ func TestGrpcWrapper(t *testing.T) {
}
app.GET(path, func(c *gin.Context) {
ctx := proxy.NewContextWithMetadata(c, "", DefaultDbName)
wrapperProxy(ctx, c, &DefaultReq{}, false, false, handle)
wrapperProxy(ctx, c, &DefaultReq{}, false, false, "", handle)
})
appNeedAuth.GET(path, func(c *gin.Context) {
username, _ := c.Get(ContextUsername)
ctx := proxy.NewContextWithMetadata(c, username.(string), DefaultDbName)
wrapperProxy(ctx, c, &milvuspb.DescribeCollectionRequest{}, true, false, handle)
wrapperProxy(ctx, c, &milvuspb.DescribeCollectionRequest{}, true, false, "", handle)
})
getTestCases = append(getTestCases, rawTestCase{
path: path,
@ -215,12 +215,12 @@ func TestGrpcWrapper(t *testing.T) {
}
app.GET(path, func(c *gin.Context) {
ctx := proxy.NewContextWithMetadata(c, "", DefaultDbName)
wrapperProxy(ctx, c, &DefaultReq{}, false, false, handle)
wrapperProxy(ctx, c, &DefaultReq{}, false, false, "", handle)
})
appNeedAuth.GET(path, func(c *gin.Context) {
username, _ := c.Get(ContextUsername)
ctx := proxy.NewContextWithMetadata(c, username.(string), DefaultDbName)
wrapperProxy(ctx, c, &milvuspb.DescribeCollectionRequest{}, true, false, handle)
wrapperProxy(ctx, c, &milvuspb.DescribeCollectionRequest{}, true, false, "", handle)
})
getTestCases = append(getTestCases, rawTestCase{
path: path,
@ -239,12 +239,12 @@ func TestGrpcWrapper(t *testing.T) {
}
app.GET(path, func(c *gin.Context) {
ctx := proxy.NewContextWithMetadata(c, "", DefaultDbName)
wrapperProxy(ctx, c, &DefaultReq{}, false, false, handle)
wrapperProxy(ctx, c, &DefaultReq{}, false, false, "", handle)
})
appNeedAuth.GET(path, func(c *gin.Context) {
username, _ := c.Get(ContextUsername)
ctx := proxy.NewContextWithMetadata(c, username.(string), DefaultDbName)
wrapperProxy(ctx, c, &milvuspb.DescribeCollectionRequest{}, true, false, handle)
wrapperProxy(ctx, c, &milvuspb.DescribeCollectionRequest{}, true, false, "", handle)
})
getTestCases = append(getTestCases, rawTestCase{
path: path,
@ -291,11 +291,11 @@ func TestGrpcWrapper(t *testing.T) {
path = "/wrapper/grpc/auth"
app.GET(path, func(c *gin.Context) {
wrapperProxy(context.Background(), c, &milvuspb.DescribeCollectionRequest{}, true, false, handle)
wrapperProxy(context.Background(), c, &milvuspb.DescribeCollectionRequest{}, true, false, "", handle)
})
appNeedAuth.GET(path, func(c *gin.Context) {
ctx := proxy.NewContextWithMetadata(c, "test", DefaultDbName)
wrapperProxy(ctx, c, &milvuspb.LoadCollectionRequest{}, true, false, handle)
wrapperProxy(ctx, c, &milvuspb.LoadCollectionRequest{}, true, false, "", handle)
})
t.Run("check authorization", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, path, nil)
@ -1159,7 +1159,16 @@ func TestDML(t *testing.T) {
Status: &StatusSuccess,
}, nil).Times(6)
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{Status: commonErrorStatus}, nil).Times(4)
mp.EXPECT().Query(mock.Anything, mock.Anything).Return(&milvuspb.QueryResults{Status: commonSuccessStatus, OutputFields: []string{}, FieldsData: []*schemapb.FieldData{}}, nil).Times(3)
mp.EXPECT().Query(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, req *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) {
if matchCountRule(req.OutputFields) {
for _, pair := range req.QueryParams {
if pair.GetKey() == ParamLimit {
return nil, fmt.Errorf("mock error")
}
}
}
return &milvuspb.QueryResults{Status: commonSuccessStatus, OutputFields: []string{}, FieldsData: []*schemapb.FieldData{}}, nil
}).Times(4)
mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, InsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{}}}}}, nil).Once()
mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, InsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_StrId{StrId: &schemapb.StringArray{Data: []string{}}}}}, nil).Once()
mp.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, UpsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{}}}}}, nil).Once()
@ -1181,6 +1190,10 @@ func TestDML(t *testing.T) {
path: QueryAction,
requestBody: []byte(`{"collectionName": "book", "filter": "book_id in [2, 4, 6, 8]"}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: QueryAction,
requestBody: []byte(`{"collectionName": "book", "filter": "", "outputFields": ["count(*)"], "limit": 10}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: InsertAction,
requestBody: []byte(`{"collectionName": "book", "data": [{"book_id": 0, "word_count": 0, "book_intro": [0.11825, 0.6]}]}`),

View File

@ -104,7 +104,7 @@ type QueryReqV2 struct {
CollectionName string `json:"collectionName" binding:"required"`
PartitionNames []string `json:"partitionNames"`
OutputFields []string `json:"outputFields"`
Filter string `json:"filter" binding:"required"`
Filter string `json:"filter"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
}

View File

@ -18,46 +18,51 @@ import (
var hoo hook.Hook
func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor {
hookutil.InitOnceHook()
hoo = hookutil.Hoo
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
var (
fullMethod = info.FullMethod
newCtx context.Context
isMock bool
mockResp interface{}
realResp interface{}
realErr error
err error
)
if isMock, mockResp, err = hoo.Mock(ctx, req, fullMethod); isMock {
log.Info("hook mock", zap.String("user", getCurrentUser(ctx)),
zap.String("full method", fullMethod), zap.Error(err))
metrics.ProxyHookFunc.WithLabelValues(metrics.HookMock, fullMethod).Inc()
updateProxyFunctionCallMetric(fullMethod)
return mockResp, err
}
if newCtx, err = hoo.Before(ctx, req, fullMethod); err != nil {
log.Warn("hook before error", zap.String("user", getCurrentUser(ctx)), zap.String("full method", fullMethod),
zap.Any("request", req), zap.Error(err))
metrics.ProxyHookFunc.WithLabelValues(metrics.HookBefore, fullMethod).Inc()
updateProxyFunctionCallMetric(fullMethod)
return nil, err
}
realResp, realErr = handler(newCtx, req)
if err = hoo.After(newCtx, realResp, realErr, fullMethod); err != nil {
log.Warn("hook after error", zap.String("user", getCurrentUser(ctx)), zap.String("full method", fullMethod),
zap.Any("request", req), zap.Error(err))
metrics.ProxyHookFunc.WithLabelValues(metrics.HookAfter, fullMethod).Inc()
updateProxyFunctionCallMetric(fullMethod)
return nil, err
}
return realResp, realErr
return HookInterceptor(ctx, req, getCurrentUser(ctx), info.FullMethod, handler)
}
}
func HookInterceptor(ctx context.Context, req any, userName, fullMethod string, handler grpc.UnaryHandler) (interface{}, error) {
if hoo == nil {
hookutil.InitOnceHook()
hoo = hookutil.Hoo
}
var (
newCtx context.Context
isMock bool
mockResp interface{}
realResp interface{}
realErr error
err error
)
if isMock, mockResp, err = hoo.Mock(ctx, req, fullMethod); isMock {
log.Info("hook mock", zap.String("user", userName),
zap.String("full method", fullMethod), zap.Error(err))
metrics.ProxyHookFunc.WithLabelValues(metrics.HookMock, fullMethod).Inc()
updateProxyFunctionCallMetric(fullMethod)
return mockResp, err
}
if newCtx, err = hoo.Before(ctx, req, fullMethod); err != nil {
log.Warn("hook before error", zap.String("user", userName), zap.String("full method", fullMethod),
zap.Any("request", req), zap.Error(err))
metrics.ProxyHookFunc.WithLabelValues(metrics.HookBefore, fullMethod).Inc()
updateProxyFunctionCallMetric(fullMethod)
return nil, err
}
realResp, realErr = handler(newCtx, req)
if err = hoo.After(newCtx, realResp, realErr, fullMethod); err != nil {
log.Warn("hook after error", zap.String("user", userName), zap.String("full method", fullMethod),
zap.Any("request", req), zap.Error(err))
metrics.ProxyHookFunc.WithLabelValues(metrics.HookAfter, fullMethod).Inc()
updateProxyFunctionCallMetric(fullMethod)
return nil, err
}
return realResp, realErr
}
func updateProxyFunctionCallMetric(fullMethod string) {
strs := strings.Split(fullMethod, "/")
method := strs[len(strs)-1]