mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-11-29 18:38:44 +08:00
fix: Unify hook singleton implementation in proxy (#34887)
Related to #34885 --------- Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
parent
6e9fbd1630
commit
783f9d9c33
@ -47,6 +47,7 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/distributed/proxy/httpserver"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/proxy"
|
||||
"github.com/milvus-io/milvus/internal/util/hookutil"
|
||||
milvusmock "github.com/milvus-io/milvus/internal/util/mock"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
@ -1165,8 +1166,8 @@ func TestHttpAuthenticate(t *testing.T) {
|
||||
}
|
||||
|
||||
{
|
||||
proxy.SetMockAPIHook("foo", nil)
|
||||
defer proxy.SetMockAPIHook("", nil)
|
||||
hookutil.SetMockAPIHook("foo", nil)
|
||||
defer hookutil.SetMockAPIHook("", nil)
|
||||
ctx.Request.Header.Set("Authorization", "Bearer 123456")
|
||||
authenticate(ctx)
|
||||
ctxName, _ := ctx.Get(httpserver.ContextUsername)
|
||||
|
@ -119,7 +119,7 @@ func TestAuthenticationInterceptor(t *testing.T) {
|
||||
|
||||
{
|
||||
// verify apikey error
|
||||
SetMockAPIHook("", errors.New("err"))
|
||||
hookutil.SetMockAPIHook("", errors.New("err"))
|
||||
md = metadata.Pairs(util.HeaderAuthorize, crypto.Base64Encode("mockapikey"))
|
||||
ctx = metadata.NewIncomingContext(ctx, md)
|
||||
_, err = AuthenticationInterceptor(ctx)
|
||||
@ -127,7 +127,7 @@ func TestAuthenticationInterceptor(t *testing.T) {
|
||||
}
|
||||
|
||||
{
|
||||
SetMockAPIHook("mockUser", nil)
|
||||
hookutil.SetMockAPIHook("mockUser", nil)
|
||||
md = metadata.Pairs(util.HeaderAuthorize, crypto.Base64Encode("mockapikey"))
|
||||
ctx = metadata.NewIncomingContext(ctx, md)
|
||||
authCtx, err := AuthenticationInterceptor(ctx)
|
||||
@ -141,5 +141,5 @@ func TestAuthenticationInterceptor(t *testing.T) {
|
||||
user, _ := parseMD(rawToken)
|
||||
assert.Equal(t, "mockUser", user)
|
||||
}
|
||||
hoo = hookutil.DefaultHook{}
|
||||
hookutil.SetTestHook(hookutil.DefaultHook{})
|
||||
}
|
||||
|
@ -8,15 +8,12 @@ import (
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/hook"
|
||||
"github.com/milvus-io/milvus/internal/util/hookutil"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
||||
var hoo hook.Hook
|
||||
|
||||
func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
return HookInterceptor(ctx, req, getCurrentUser(ctx), info.FullMethod, handler)
|
||||
@ -24,10 +21,7 @@ func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor {
|
||||
}
|
||||
|
||||
func HookInterceptor(ctx context.Context, req any, userName, fullMethod string, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
if hoo == nil {
|
||||
hookutil.InitOnceHook()
|
||||
hoo = hookutil.Hoo
|
||||
}
|
||||
hoo := hookutil.GetHook()
|
||||
var (
|
||||
newCtx context.Context
|
||||
isMock bool
|
||||
@ -80,14 +74,3 @@ func getCurrentUser(ctx context.Context) string {
|
||||
}
|
||||
return username
|
||||
}
|
||||
|
||||
func SetMockAPIHook(apiUser string, mockErr error) {
|
||||
if apiUser == "" && mockErr == nil {
|
||||
hoo = &hookutil.DefaultHook{}
|
||||
return
|
||||
}
|
||||
hoo = &hookutil.MockAPIHook{
|
||||
MockErr: mockErr,
|
||||
User: apiUser,
|
||||
}
|
||||
}
|
||||
|
@ -83,7 +83,7 @@ func TestHookInterceptor(t *testing.T) {
|
||||
err error
|
||||
)
|
||||
|
||||
hoo = mockHoo
|
||||
hookutil.SetTestHook(mockHoo)
|
||||
res, err = interceptor(ctx, "request", info, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
})
|
||||
@ -95,7 +95,7 @@ func TestHookInterceptor(t *testing.T) {
|
||||
assert.Equal(t, res, mockHoo.mockRes)
|
||||
assert.Equal(t, err, mockHoo.mockErr)
|
||||
|
||||
hoo = beforeHoo
|
||||
hookutil.SetTestHook(beforeHoo)
|
||||
_, err = interceptor(ctx, r, info, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
})
|
||||
@ -103,7 +103,7 @@ func TestHookInterceptor(t *testing.T) {
|
||||
assert.Equal(t, err, beforeHoo.err)
|
||||
|
||||
beforeHoo.err = nil
|
||||
hoo = beforeHoo
|
||||
hookutil.SetTestHook(beforeHoo)
|
||||
_, err = interceptor(ctx, r, info, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
assert.Equal(t, beforeHoo.ctxValue, ctx.Value(beforeHoo.ctxKey))
|
||||
return nil, nil
|
||||
@ -111,14 +111,14 @@ func TestHookInterceptor(t *testing.T) {
|
||||
assert.Equal(t, r.method, beforeHoo.method)
|
||||
assert.Equal(t, err, beforeHoo.err)
|
||||
|
||||
hoo = afterHoo
|
||||
hookutil.SetTestHook(afterHoo)
|
||||
_, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) {
|
||||
return re, nil
|
||||
})
|
||||
assert.Equal(t, re.method, afterHoo.method)
|
||||
assert.Equal(t, err, afterHoo.err)
|
||||
|
||||
hoo = &hookutil.DefaultHook{}
|
||||
hookutil.SetTestHook(&hookutil.DefaultHook{})
|
||||
res, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) {
|
||||
return &resp{
|
||||
method: r.(*req).method,
|
||||
|
@ -2592,7 +2592,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
|
||||
dbName := request.DbName
|
||||
collectionName := request.CollectionName
|
||||
|
||||
v := Extension.Report(map[string]any{
|
||||
v := hookutil.GetExtension().Report(map[string]any{
|
||||
hookutil.OpTypeKey: hookutil.OpTypeInsert,
|
||||
hookutil.DatabaseKey: dbName,
|
||||
hookutil.UsernameKey: username,
|
||||
@ -2696,7 +2696,7 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest)
|
||||
|
||||
username := GetCurUserFromContextOrDefault(ctx)
|
||||
collectionName := request.CollectionName
|
||||
v := Extension.Report(map[string]any{
|
||||
v := hookutil.GetExtension().Report(map[string]any{
|
||||
hookutil.OpTypeKey: hookutil.OpTypeDelete,
|
||||
hookutil.DatabaseKey: dbName,
|
||||
hookutil.UsernameKey: username,
|
||||
@ -2829,7 +2829,7 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest)
|
||||
nodeID := paramtable.GetStringNodeID()
|
||||
dbName := request.DbName
|
||||
collectionName := request.CollectionName
|
||||
v := Extension.Report(map[string]any{
|
||||
v := hookutil.GetExtension().Report(map[string]any{
|
||||
hookutil.OpTypeKey: hookutil.OpTypeUpsert,
|
||||
hookutil.DatabaseKey: request.DbName,
|
||||
hookutil.UsernameKey: username,
|
||||
@ -3072,7 +3072,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest)
|
||||
if qt.result != nil {
|
||||
username := GetCurUserFromContextOrDefault(ctx)
|
||||
sentSize := proto.Size(qt.result)
|
||||
v := Extension.Report(map[string]any{
|
||||
v := hookutil.GetExtension().Report(map[string]any{
|
||||
hookutil.OpTypeKey: hookutil.OpTypeSearch,
|
||||
hookutil.DatabaseKey: dbName,
|
||||
hookutil.UsernameKey: username,
|
||||
@ -3269,7 +3269,7 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea
|
||||
if qt.result != nil {
|
||||
sentSize := proto.Size(qt.result)
|
||||
username := GetCurUserFromContextOrDefault(ctx)
|
||||
v := Extension.Report(map[string]any{
|
||||
v := hookutil.GetExtension().Report(map[string]any{
|
||||
hookutil.OpTypeKey: hookutil.OpTypeHybridSearch,
|
||||
hookutil.DatabaseKey: dbName,
|
||||
hookutil.UsernameKey: username,
|
||||
@ -3595,7 +3595,7 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
|
||||
|
||||
username := GetCurUserFromContextOrDefault(ctx)
|
||||
nodeID := paramtable.GetStringNodeID()
|
||||
v := Extension.Report(map[string]any{
|
||||
v := hookutil.GetExtension().Report(map[string]any{
|
||||
hookutil.OpTypeKey: hookutil.OpTypeQuery,
|
||||
hookutil.DatabaseKey: request.DbName,
|
||||
hookutil.UsernameKey: username,
|
||||
|
@ -31,7 +31,6 @@ import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/hook"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/allocator"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
@ -67,9 +66,8 @@ type Timestamp = typeutil.Timestamp
|
||||
var _ types.Proxy = (*Proxy)(nil)
|
||||
|
||||
var (
|
||||
Params = paramtable.Get()
|
||||
Extension hook.Extension
|
||||
rateCol *ratelimitutil.RateCollector
|
||||
Params = paramtable.Get()
|
||||
rateCol *ratelimitutil.RateCollector
|
||||
)
|
||||
|
||||
// Proxy of milvus
|
||||
@ -157,7 +155,6 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
|
||||
node.UpdateStateCode(commonpb.StateCode_Abnormal)
|
||||
expr.Register("proxy", node)
|
||||
hookutil.InitOnceHook()
|
||||
Extension = hookutil.Extension
|
||||
logutil.Logger(ctx).Debug("create a new Proxy instance", zap.Any("state", node.stateCode.Load()))
|
||||
return node, nil
|
||||
}
|
||||
@ -422,7 +419,7 @@ func (node *Proxy) Start() error {
|
||||
cb()
|
||||
}
|
||||
|
||||
Extension.Report(map[string]any{
|
||||
hookutil.GetExtension().Report(map[string]any{
|
||||
hookutil.OpTypeKey: hookutil.OpTypeNodeID,
|
||||
hookutil.NodeIDKey: paramtable.GetNodeID(),
|
||||
})
|
||||
|
@ -36,6 +36,7 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/proto/planpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/hookutil"
|
||||
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
@ -924,9 +925,7 @@ func PasswordVerify(ctx context.Context, username, rawPwd string) bool {
|
||||
}
|
||||
|
||||
func VerifyAPIKey(rawToken string) (string, error) {
|
||||
if hoo == nil {
|
||||
return "", merr.WrapErrServiceInternal("internal: Milvus Proxy is not ready yet. please wait")
|
||||
}
|
||||
hoo := hookutil.GetHook()
|
||||
user, err := hoo.VerifyAPIKey(rawToken)
|
||||
if err != nil {
|
||||
log.Warn("fail to verify apikey", zap.String("api_key", rawToken), zap.Error(err))
|
||||
|
@ -50,17 +50,6 @@ func (d DefaultHook) After(ctx context.Context, result interface{}, err error, f
|
||||
return nil
|
||||
}
|
||||
|
||||
// MockAPIHook is a mock hook for api key verification, ONLY FOR TEST
|
||||
type MockAPIHook struct {
|
||||
DefaultHook
|
||||
MockErr error
|
||||
User string
|
||||
}
|
||||
|
||||
func (m MockAPIHook) VerifyAPIKey(apiKey string) (string, error) {
|
||||
return m.User, m.MockErr
|
||||
}
|
||||
|
||||
func (d DefaultHook) Release() {}
|
||||
|
||||
type DefaultExtension struct{}
|
||||
|
@ -22,6 +22,7 @@ import (
|
||||
"fmt"
|
||||
"plugin"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
@ -32,14 +33,37 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
Hoo hook.Hook
|
||||
Extension hook.Extension
|
||||
hoo atomic.Value // hook.Hook
|
||||
extension atomic.Value // hook.Extension
|
||||
initOnce sync.Once
|
||||
)
|
||||
|
||||
// hookContainer is Container to wrap hook.Hook interface
|
||||
// this struct is used to be stored in atomic.Value
|
||||
// since different type stored in it will cause panicking.
|
||||
type hookContainer struct {
|
||||
hook hook.Hook
|
||||
}
|
||||
|
||||
// extensionContainer is Container to wrap hook.Extension interface
|
||||
// this struct is used to be stored in atomic.Value
|
||||
// since different type stored in it will cause panicking.
|
||||
type extensionContainer struct {
|
||||
extension hook.Extension
|
||||
}
|
||||
|
||||
func storeHook(hook hook.Hook) {
|
||||
hoo.Store(hookContainer{hook: hook})
|
||||
}
|
||||
|
||||
func storeExtension(ext hook.Extension) {
|
||||
extension.Store(extensionContainer{extension: ext})
|
||||
}
|
||||
|
||||
func initHook() error {
|
||||
Hoo = DefaultHook{}
|
||||
Extension = DefaultExtension{}
|
||||
// setup default hook & extension
|
||||
storeHook(DefaultHook{})
|
||||
storeExtension(DefaultExtension{})
|
||||
|
||||
path := paramtable.Get().ProxyCfg.SoPath.GetValue()
|
||||
if path == "" {
|
||||
@ -59,22 +83,26 @@ func initHook() error {
|
||||
return fmt.Errorf("fail to the 'MilvusHook' object in the plugin, error: %s", err.Error())
|
||||
}
|
||||
|
||||
var hookVal hook.Hook
|
||||
var ok bool
|
||||
Hoo, ok = h.(hook.Hook)
|
||||
hookVal, ok = h.(hook.Hook)
|
||||
if !ok {
|
||||
return fmt.Errorf("fail to convert the `Hook` interface")
|
||||
}
|
||||
if err = Hoo.Init(paramtable.GetHookParams().SoConfig.GetValue()); err != nil {
|
||||
if err = hookVal.Init(paramtable.GetHookParams().SoConfig.GetValue()); err != nil {
|
||||
return fmt.Errorf("fail to init configs for the hook, error: %s", err.Error())
|
||||
}
|
||||
storeHook((hookVal))
|
||||
paramtable.GetHookParams().WatchHookWithPrefix("watch_hook", "", func(event *config.Event) {
|
||||
log.Info("receive the hook refresh event", zap.Any("event", event))
|
||||
go func() {
|
||||
hookVal := GetHook()
|
||||
soConfig := paramtable.GetHookParams().SoConfig.GetValue()
|
||||
log.Info("refresh hook configs", zap.Any("config", soConfig))
|
||||
if err = Hoo.Init(soConfig); err != nil {
|
||||
if err = hookVal.Init(soConfig); err != nil {
|
||||
log.Panic("fail to init configs for the hook when refreshing", zap.Error(err))
|
||||
}
|
||||
storeHook(hookVal)
|
||||
}()
|
||||
})
|
||||
|
||||
@ -82,10 +110,12 @@ func initHook() error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("fail to the 'MilvusExtension' object in the plugin, error: %s", err.Error())
|
||||
}
|
||||
Extension, ok = e.(hook.Extension)
|
||||
var extVal hook.Extension
|
||||
extVal, ok = e.(hook.Extension)
|
||||
if !ok {
|
||||
return fmt.Errorf("fail to convert the `Extension` interface")
|
||||
}
|
||||
storeExtension(extVal)
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -104,3 +134,15 @@ func InitOnceHook() {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// GetHook returns singleton hook.Hook instance.
|
||||
func GetHook() hook.Hook {
|
||||
InitOnceHook()
|
||||
return hoo.Load().(hookContainer).hook
|
||||
}
|
||||
|
||||
// GetHook returns singleton hook.Extension instance.
|
||||
func GetExtension() hook.Extension {
|
||||
InitOnceHook()
|
||||
return extension.Load().(extensionContainer).extension
|
||||
}
|
||||
|
@ -32,7 +32,7 @@ func TestInitHook(t *testing.T) {
|
||||
Params := paramtable.Get()
|
||||
paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "")
|
||||
initHook()
|
||||
assert.IsType(t, DefaultHook{}, Hoo)
|
||||
assert.IsType(t, DefaultHook{}, GetHook())
|
||||
|
||||
paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "/a/b/hook.so")
|
||||
err := initHook()
|
||||
|
54
internal/util/hookutil/mock_hook.go
Normal file
54
internal/util/hookutil/mock_hook.go
Normal file
@ -0,0 +1,54 @@
|
||||
//go:build test
|
||||
// +build test
|
||||
|
||||
/*
|
||||
* Licensed to the LF AI & Data foundation under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package hookutil
|
||||
|
||||
import "github.com/milvus-io/milvus-proto/go-api/v2/hook"
|
||||
|
||||
// MockAPIHook is a mock hook for api key verification, ONLY FOR TEST
|
||||
type MockAPIHook struct {
|
||||
DefaultHook
|
||||
MockErr error
|
||||
User string
|
||||
}
|
||||
|
||||
func (m MockAPIHook) VerifyAPIKey(apiKey string) (string, error) {
|
||||
return m.User, m.MockErr
|
||||
}
|
||||
|
||||
func SetMockAPIHook(apiUser string, mockErr error) {
|
||||
if apiUser == "" && mockErr == nil {
|
||||
storeHook(&DefaultHook{})
|
||||
return
|
||||
}
|
||||
storeHook(&MockAPIHook{
|
||||
MockErr: mockErr,
|
||||
User: apiUser,
|
||||
})
|
||||
}
|
||||
|
||||
func SetTestHook(hookVal hook.Hook) {
|
||||
storeHook(hookVal)
|
||||
}
|
||||
|
||||
func SetTestExtension(extVal hook.Extension) {
|
||||
storeExtension(extVal)
|
||||
}
|
@ -465,7 +465,7 @@ func (cluster *MiniClusterV2) GetAvailablePort() (int, error) {
|
||||
func InitReportExtension() *ReportChanExtension {
|
||||
e := NewReportChanExtension()
|
||||
hookutil.InitOnceHook()
|
||||
hookutil.Extension = e
|
||||
hookutil.SetTestExtension(e)
|
||||
return e
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user