fix: Unify hook singleton implementation in proxy (#34887)

Related to #34885

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
congqixia 2024-07-26 18:07:53 +08:00 committed by GitHub
parent 6e9fbd1630
commit 783f9d9c33
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 129 additions and 64 deletions

View File

@ -47,6 +47,7 @@ import (
"github.com/milvus-io/milvus/internal/distributed/proxy/httpserver"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proxy"
"github.com/milvus-io/milvus/internal/util/hookutil"
milvusmock "github.com/milvus-io/milvus/internal/util/mock"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil"
@ -1165,8 +1166,8 @@ func TestHttpAuthenticate(t *testing.T) {
}
{
proxy.SetMockAPIHook("foo", nil)
defer proxy.SetMockAPIHook("", nil)
hookutil.SetMockAPIHook("foo", nil)
defer hookutil.SetMockAPIHook("", nil)
ctx.Request.Header.Set("Authorization", "Bearer 123456")
authenticate(ctx)
ctxName, _ := ctx.Get(httpserver.ContextUsername)

View File

@ -119,7 +119,7 @@ func TestAuthenticationInterceptor(t *testing.T) {
{
// verify apikey error
SetMockAPIHook("", errors.New("err"))
hookutil.SetMockAPIHook("", errors.New("err"))
md = metadata.Pairs(util.HeaderAuthorize, crypto.Base64Encode("mockapikey"))
ctx = metadata.NewIncomingContext(ctx, md)
_, err = AuthenticationInterceptor(ctx)
@ -127,7 +127,7 @@ func TestAuthenticationInterceptor(t *testing.T) {
}
{
SetMockAPIHook("mockUser", nil)
hookutil.SetMockAPIHook("mockUser", nil)
md = metadata.Pairs(util.HeaderAuthorize, crypto.Base64Encode("mockapikey"))
ctx = metadata.NewIncomingContext(ctx, md)
authCtx, err := AuthenticationInterceptor(ctx)
@ -141,5 +141,5 @@ func TestAuthenticationInterceptor(t *testing.T) {
user, _ := parseMD(rawToken)
assert.Equal(t, "mockUser", user)
}
hoo = hookutil.DefaultHook{}
hookutil.SetTestHook(hookutil.DefaultHook{})
}

View File

@ -8,15 +8,12 @@ import (
"go.uber.org/zap"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/hook"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
var hoo hook.Hook
func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
return HookInterceptor(ctx, req, getCurrentUser(ctx), info.FullMethod, handler)
@ -24,10 +21,7 @@ func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor {
}
func HookInterceptor(ctx context.Context, req any, userName, fullMethod string, handler grpc.UnaryHandler) (interface{}, error) {
if hoo == nil {
hookutil.InitOnceHook()
hoo = hookutil.Hoo
}
hoo := hookutil.GetHook()
var (
newCtx context.Context
isMock bool
@ -80,14 +74,3 @@ func getCurrentUser(ctx context.Context) string {
}
return username
}
func SetMockAPIHook(apiUser string, mockErr error) {
if apiUser == "" && mockErr == nil {
hoo = &hookutil.DefaultHook{}
return
}
hoo = &hookutil.MockAPIHook{
MockErr: mockErr,
User: apiUser,
}
}

View File

@ -83,7 +83,7 @@ func TestHookInterceptor(t *testing.T) {
err error
)
hoo = mockHoo
hookutil.SetTestHook(mockHoo)
res, err = interceptor(ctx, "request", info, func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
@ -95,7 +95,7 @@ func TestHookInterceptor(t *testing.T) {
assert.Equal(t, res, mockHoo.mockRes)
assert.Equal(t, err, mockHoo.mockErr)
hoo = beforeHoo
hookutil.SetTestHook(beforeHoo)
_, err = interceptor(ctx, r, info, func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
@ -103,7 +103,7 @@ func TestHookInterceptor(t *testing.T) {
assert.Equal(t, err, beforeHoo.err)
beforeHoo.err = nil
hoo = beforeHoo
hookutil.SetTestHook(beforeHoo)
_, err = interceptor(ctx, r, info, func(ctx context.Context, req interface{}) (interface{}, error) {
assert.Equal(t, beforeHoo.ctxValue, ctx.Value(beforeHoo.ctxKey))
return nil, nil
@ -111,14 +111,14 @@ func TestHookInterceptor(t *testing.T) {
assert.Equal(t, r.method, beforeHoo.method)
assert.Equal(t, err, beforeHoo.err)
hoo = afterHoo
hookutil.SetTestHook(afterHoo)
_, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) {
return re, nil
})
assert.Equal(t, re.method, afterHoo.method)
assert.Equal(t, err, afterHoo.err)
hoo = &hookutil.DefaultHook{}
hookutil.SetTestHook(&hookutil.DefaultHook{})
res, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) {
return &resp{
method: r.(*req).method,

View File

@ -2592,7 +2592,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
dbName := request.DbName
collectionName := request.CollectionName
v := Extension.Report(map[string]any{
v := hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeInsert,
hookutil.DatabaseKey: dbName,
hookutil.UsernameKey: username,
@ -2696,7 +2696,7 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest)
username := GetCurUserFromContextOrDefault(ctx)
collectionName := request.CollectionName
v := Extension.Report(map[string]any{
v := hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeDelete,
hookutil.DatabaseKey: dbName,
hookutil.UsernameKey: username,
@ -2829,7 +2829,7 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest)
nodeID := paramtable.GetStringNodeID()
dbName := request.DbName
collectionName := request.CollectionName
v := Extension.Report(map[string]any{
v := hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeUpsert,
hookutil.DatabaseKey: request.DbName,
hookutil.UsernameKey: username,
@ -3072,7 +3072,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest)
if qt.result != nil {
username := GetCurUserFromContextOrDefault(ctx)
sentSize := proto.Size(qt.result)
v := Extension.Report(map[string]any{
v := hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeSearch,
hookutil.DatabaseKey: dbName,
hookutil.UsernameKey: username,
@ -3269,7 +3269,7 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea
if qt.result != nil {
sentSize := proto.Size(qt.result)
username := GetCurUserFromContextOrDefault(ctx)
v := Extension.Report(map[string]any{
v := hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeHybridSearch,
hookutil.DatabaseKey: dbName,
hookutil.UsernameKey: username,
@ -3595,7 +3595,7 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
username := GetCurUserFromContextOrDefault(ctx)
nodeID := paramtable.GetStringNodeID()
v := Extension.Report(map[string]any{
v := hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeQuery,
hookutil.DatabaseKey: request.DbName,
hookutil.UsernameKey: username,

View File

@ -31,7 +31,6 @@ import (
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/hook"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/proto/internalpb"
@ -67,9 +66,8 @@ type Timestamp = typeutil.Timestamp
var _ types.Proxy = (*Proxy)(nil)
var (
Params = paramtable.Get()
Extension hook.Extension
rateCol *ratelimitutil.RateCollector
Params = paramtable.Get()
rateCol *ratelimitutil.RateCollector
)
// Proxy of milvus
@ -157,7 +155,6 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
node.UpdateStateCode(commonpb.StateCode_Abnormal)
expr.Register("proxy", node)
hookutil.InitOnceHook()
Extension = hookutil.Extension
logutil.Logger(ctx).Debug("create a new Proxy instance", zap.Any("state", node.stateCode.Load()))
return node, nil
}
@ -422,7 +419,7 @@ func (node *Proxy) Start() error {
cb()
}
Extension.Report(map[string]any{
hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeNodeID,
hookutil.NodeIDKey: paramtable.GetNodeID(),
})

View File

@ -36,6 +36,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/hookutil"
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
@ -924,9 +925,7 @@ func PasswordVerify(ctx context.Context, username, rawPwd string) bool {
}
func VerifyAPIKey(rawToken string) (string, error) {
if hoo == nil {
return "", merr.WrapErrServiceInternal("internal: Milvus Proxy is not ready yet. please wait")
}
hoo := hookutil.GetHook()
user, err := hoo.VerifyAPIKey(rawToken)
if err != nil {
log.Warn("fail to verify apikey", zap.String("api_key", rawToken), zap.Error(err))

View File

@ -50,17 +50,6 @@ func (d DefaultHook) After(ctx context.Context, result interface{}, err error, f
return nil
}
// MockAPIHook is a mock hook for api key verification, ONLY FOR TEST
type MockAPIHook struct {
DefaultHook
MockErr error
User string
}
func (m MockAPIHook) VerifyAPIKey(apiKey string) (string, error) {
return m.User, m.MockErr
}
func (d DefaultHook) Release() {}
type DefaultExtension struct{}

View File

@ -22,6 +22,7 @@ import (
"fmt"
"plugin"
"sync"
"sync/atomic"
"go.uber.org/zap"
@ -32,14 +33,37 @@ import (
)
var (
Hoo hook.Hook
Extension hook.Extension
hoo atomic.Value // hook.Hook
extension atomic.Value // hook.Extension
initOnce sync.Once
)
// hookContainer is Container to wrap hook.Hook interface
// this struct is used to be stored in atomic.Value
// since different type stored in it will cause panicking.
type hookContainer struct {
hook hook.Hook
}
// extensionContainer is Container to wrap hook.Extension interface
// this struct is used to be stored in atomic.Value
// since different type stored in it will cause panicking.
type extensionContainer struct {
extension hook.Extension
}
func storeHook(hook hook.Hook) {
hoo.Store(hookContainer{hook: hook})
}
func storeExtension(ext hook.Extension) {
extension.Store(extensionContainer{extension: ext})
}
func initHook() error {
Hoo = DefaultHook{}
Extension = DefaultExtension{}
// setup default hook & extension
storeHook(DefaultHook{})
storeExtension(DefaultExtension{})
path := paramtable.Get().ProxyCfg.SoPath.GetValue()
if path == "" {
@ -59,22 +83,26 @@ func initHook() error {
return fmt.Errorf("fail to the 'MilvusHook' object in the plugin, error: %s", err.Error())
}
var hookVal hook.Hook
var ok bool
Hoo, ok = h.(hook.Hook)
hookVal, ok = h.(hook.Hook)
if !ok {
return fmt.Errorf("fail to convert the `Hook` interface")
}
if err = Hoo.Init(paramtable.GetHookParams().SoConfig.GetValue()); err != nil {
if err = hookVal.Init(paramtable.GetHookParams().SoConfig.GetValue()); err != nil {
return fmt.Errorf("fail to init configs for the hook, error: %s", err.Error())
}
storeHook((hookVal))
paramtable.GetHookParams().WatchHookWithPrefix("watch_hook", "", func(event *config.Event) {
log.Info("receive the hook refresh event", zap.Any("event", event))
go func() {
hookVal := GetHook()
soConfig := paramtable.GetHookParams().SoConfig.GetValue()
log.Info("refresh hook configs", zap.Any("config", soConfig))
if err = Hoo.Init(soConfig); err != nil {
if err = hookVal.Init(soConfig); err != nil {
log.Panic("fail to init configs for the hook when refreshing", zap.Error(err))
}
storeHook(hookVal)
}()
})
@ -82,10 +110,12 @@ func initHook() error {
if err != nil {
return fmt.Errorf("fail to the 'MilvusExtension' object in the plugin, error: %s", err.Error())
}
Extension, ok = e.(hook.Extension)
var extVal hook.Extension
extVal, ok = e.(hook.Extension)
if !ok {
return fmt.Errorf("fail to convert the `Extension` interface")
}
storeExtension(extVal)
return nil
}
@ -104,3 +134,15 @@ func InitOnceHook() {
}
})
}
// GetHook returns singleton hook.Hook instance.
func GetHook() hook.Hook {
InitOnceHook()
return hoo.Load().(hookContainer).hook
}
// GetHook returns singleton hook.Extension instance.
func GetExtension() hook.Extension {
InitOnceHook()
return extension.Load().(extensionContainer).extension
}

View File

@ -32,7 +32,7 @@ func TestInitHook(t *testing.T) {
Params := paramtable.Get()
paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "")
initHook()
assert.IsType(t, DefaultHook{}, Hoo)
assert.IsType(t, DefaultHook{}, GetHook())
paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "/a/b/hook.so")
err := initHook()

View File

@ -0,0 +1,54 @@
//go:build test
// +build test
/*
* Licensed to the LF AI & Data foundation under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package hookutil
import "github.com/milvus-io/milvus-proto/go-api/v2/hook"
// MockAPIHook is a mock hook for api key verification, ONLY FOR TEST
type MockAPIHook struct {
DefaultHook
MockErr error
User string
}
func (m MockAPIHook) VerifyAPIKey(apiKey string) (string, error) {
return m.User, m.MockErr
}
func SetMockAPIHook(apiUser string, mockErr error) {
if apiUser == "" && mockErr == nil {
storeHook(&DefaultHook{})
return
}
storeHook(&MockAPIHook{
MockErr: mockErr,
User: apiUser,
})
}
func SetTestHook(hookVal hook.Hook) {
storeHook(hookVal)
}
func SetTestExtension(extVal hook.Extension) {
storeExtension(extVal)
}

View File

@ -465,7 +465,7 @@ func (cluster *MiniClusterV2) GetAvailablePort() (int, error) {
func InitReportExtension() *ReportChanExtension {
e := NewReportChanExtension()
hookutil.InitOnceHook()
hookutil.Extension = e
hookutil.SetTestExtension(e)
return e
}