mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-02 03:48:37 +08:00
Implement the hook
interceptor (#19294)
Signed-off-by: SimFG <bang.fu@zilliz.com> Signed-off-by: SimFG <bang.fu@zilliz.com>
This commit is contained in:
parent
901d3fb367
commit
68a257458b
9
api/hook/hook.go
Normal file
9
api/hook/hook.go
Normal file
@ -0,0 +1,9 @@
|
||||
package hook
|
||||
|
||||
type Hook interface {
|
||||
Init(params map[string]string) error
|
||||
Mock(req interface{}, fullMethod string) (bool, interface{}, error)
|
||||
Before(req interface{}, fullMethod string) error
|
||||
After(result interface{}, err error, fullMethod string) error
|
||||
Release()
|
||||
}
|
0
configs/hook.yaml
Normal file
0
configs/hook.yaml
Normal file
@ -22,11 +22,7 @@ macro( build_rocksdb )
|
||||
message( STATUS "Building ROCKSDB-${ROCKSDB_VERSION} from source" )
|
||||
|
||||
set ( ROCKSDB_MD5 "e4a0625f0cec82060e62c81b787a1124" )
|
||||
|
||||
if ( EMBEDDED_MILVUS )
|
||||
message ( STATUS "Turning on fPIC while building embedded Milvus" )
|
||||
set( FPIC_ARG "-DCMAKE_POSITION_INDEPENDENT_CODE=ON" )
|
||||
endif()
|
||||
set( FPIC_ARG "-DCMAKE_POSITION_INDEPENDENT_CODE=ON" )
|
||||
set( ROCKSDB_CMAKE_ARGS
|
||||
"-DWITH_GFLAGS=OFF"
|
||||
"-DROCKSDB_BUILD_SHARED=OFF"
|
||||
|
@ -179,6 +179,7 @@ func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) {
|
||||
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
|
||||
ot.UnaryServerInterceptor(opts...),
|
||||
grpc_auth.UnaryServerInterceptor(proxy.AuthenticationInterceptor),
|
||||
proxy.UnaryServerHookInterceptor(),
|
||||
proxy.UnaryServerInterceptor(proxy.PrivilegeInterceptor),
|
||||
logutil.UnaryTraceLoggerInterceptor,
|
||||
proxy.RateLimitInterceptor(limiter),
|
||||
|
95
internal/proxy/hook_interceptor.go
Normal file
95
internal/proxy/hook_interceptor.go
Normal file
@ -0,0 +1,95 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"plugin"
|
||||
|
||||
"github.com/milvus-io/milvus/api/hook"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
type defaultHook struct {
|
||||
}
|
||||
|
||||
func (d defaultHook) Init(params map[string]string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d defaultHook) Mock(req interface{}, fullMethod string) (bool, interface{}, error) {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
func (d defaultHook) Before(req interface{}, fullMethod string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d defaultHook) After(result interface{}, err error, fullMethod string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d defaultHook) Release() {}
|
||||
|
||||
var hoo hook.Hook
|
||||
|
||||
func initHook() {
|
||||
path := Params.ProxyCfg.SoPath
|
||||
if path == "" {
|
||||
hoo = defaultHook{}
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debug("start to load plugin", zap.String("path", path))
|
||||
p, err := plugin.Open(path)
|
||||
if err != nil {
|
||||
exit("fail to open the plugin", err)
|
||||
}
|
||||
logger.Debug("plugin open")
|
||||
|
||||
h, err := p.Lookup("MilvusHook")
|
||||
if err != nil {
|
||||
exit("fail to the 'MilvusHook' object in the plugin", err)
|
||||
}
|
||||
|
||||
var ok bool
|
||||
hoo, ok = h.(hook.Hook)
|
||||
if !ok {
|
||||
exit("fail to convert the `Hook` interface", nil)
|
||||
}
|
||||
if err = hoo.Init(Params.HookCfg.SoConfig); err != nil {
|
||||
exit("fail to init configs for the hoo", err)
|
||||
}
|
||||
}
|
||||
|
||||
func exit(errMsg string, err error) {
|
||||
logger.Panic("hoo error", zap.String("path", Params.ProxyCfg.SoPath), zap.String("msg", errMsg), zap.Error(err))
|
||||
}
|
||||
|
||||
func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor {
|
||||
initHook()
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
var (
|
||||
fullMethod = info.FullMethod
|
||||
isMock bool
|
||||
mockResp interface{}
|
||||
realResp interface{}
|
||||
realErr error
|
||||
err error
|
||||
)
|
||||
|
||||
if isMock, mockResp, err = hoo.Mock(req, fullMethod); isMock {
|
||||
return mockResp, err
|
||||
}
|
||||
|
||||
if err = hoo.Before(req, fullMethod); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
realResp, realErr = handler(ctx, req)
|
||||
if err = hoo.After(realResp, realErr, fullMethod); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return realResp, realErr
|
||||
}
|
||||
}
|
127
internal/proxy/hook_interceptor_test.go
Normal file
127
internal/proxy/hook_interceptor_test.go
Normal file
@ -0,0 +1,127 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestInitHook(t *testing.T) {
|
||||
Params.ProxyCfg.SoPath = ""
|
||||
initHook()
|
||||
assert.IsType(t, defaultHook{}, hoo)
|
||||
|
||||
Params.ProxyCfg.SoPath = "/a/b/hook.so"
|
||||
assert.Panics(t, func() {
|
||||
initHook()
|
||||
})
|
||||
Params.ProxyCfg.SoPath = ""
|
||||
}
|
||||
|
||||
type mockHook struct {
|
||||
defaultHook
|
||||
mockRes interface{}
|
||||
mockErr error
|
||||
}
|
||||
|
||||
func (m mockHook) Mock(req interface{}, fullMethod string) (bool, interface{}, error) {
|
||||
return true, m.mockRes, m.mockErr
|
||||
}
|
||||
|
||||
type req struct {
|
||||
method string
|
||||
}
|
||||
|
||||
type beforeMock struct {
|
||||
defaultHook
|
||||
method string
|
||||
err error
|
||||
}
|
||||
|
||||
func (b beforeMock) Before(r interface{}, fullMethod string) error {
|
||||
re, ok := r.(*req)
|
||||
if !ok {
|
||||
return errors.New("r is invalid type")
|
||||
}
|
||||
re.method = b.method
|
||||
return b.err
|
||||
}
|
||||
|
||||
type resp struct {
|
||||
method string
|
||||
}
|
||||
|
||||
type afterMock struct {
|
||||
defaultHook
|
||||
method string
|
||||
err error
|
||||
}
|
||||
|
||||
func (a afterMock) After(r interface{}, err error, fullMethod string) error {
|
||||
re, ok := r.(*resp)
|
||||
if !ok {
|
||||
return errors.New("r is invalid type")
|
||||
}
|
||||
re.method = a.method
|
||||
return a.err
|
||||
}
|
||||
|
||||
func TestHookInterceptor(t *testing.T) {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
info = &grpc.UnaryServerInfo{
|
||||
FullMethod: "test",
|
||||
}
|
||||
interceptor = UnaryServerHookInterceptor()
|
||||
mockHoo = mockHook{mockRes: "mock", mockErr: errors.New("mock")}
|
||||
r = &req{method: "req"}
|
||||
re = &resp{method: "resp"}
|
||||
beforeHoo = beforeMock{method: "before", err: errors.New("before")}
|
||||
afterHoo = afterMock{method: "after", err: errors.New("after")}
|
||||
|
||||
res interface{}
|
||||
err error
|
||||
)
|
||||
|
||||
hoo = mockHoo
|
||||
res, err = interceptor(ctx, "request", info, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
})
|
||||
assert.Equal(t, res, mockHoo.mockRes)
|
||||
assert.Equal(t, err, mockHoo.mockErr)
|
||||
|
||||
hoo = beforeHoo
|
||||
_, err = interceptor(ctx, r, info, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
})
|
||||
assert.Equal(t, r.method, beforeHoo.method)
|
||||
assert.Equal(t, err, beforeHoo.err)
|
||||
|
||||
hoo = afterHoo
|
||||
_, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) {
|
||||
return re, nil
|
||||
})
|
||||
assert.Equal(t, re.method, afterHoo.method)
|
||||
assert.Equal(t, err, afterHoo.err)
|
||||
|
||||
hoo = defaultHook{}
|
||||
res, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) {
|
||||
return &resp{
|
||||
method: r.(*req).method,
|
||||
}, nil
|
||||
})
|
||||
assert.Equal(t, res.(*resp).method, r.method)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestDefaultHook(t *testing.T) {
|
||||
d := defaultHook{}
|
||||
assert.NoError(t, d.Init(nil))
|
||||
assert.NotPanics(t, func() {
|
||||
d.Release()
|
||||
})
|
||||
}
|
@ -45,6 +45,13 @@ const (
|
||||
DefaultEtcdEndpoints = "localhost:2379"
|
||||
DefaultInsertBufferSize = "16777216"
|
||||
DefaultEnvPrefix = "milvus"
|
||||
|
||||
DefaultLogFormat = "text"
|
||||
DefaultLogLevelForBase = "debug"
|
||||
DefaultRootPath = ""
|
||||
DefaultMaxSize = 300
|
||||
DefaultMaxAge = 10
|
||||
DefaultMaxBackups = 20
|
||||
)
|
||||
|
||||
var defaultYaml = DefaultMilvusYaml
|
||||
@ -71,6 +78,8 @@ type BaseTable struct {
|
||||
RoleName string
|
||||
Log log.Config
|
||||
LogCfgFunc func(log.Config)
|
||||
|
||||
YamlFile string
|
||||
}
|
||||
|
||||
// GlobalInitWithYaml initializes the param table with the given yaml.
|
||||
@ -94,6 +103,9 @@ func (gp *BaseTable) Init() {
|
||||
ret = strings.ReplaceAll(ret, ".", "")
|
||||
return ret
|
||||
}
|
||||
if gp.YamlFile == "" {
|
||||
gp.YamlFile = defaultYaml
|
||||
}
|
||||
gp.initConfigsFromLocal(formatter)
|
||||
gp.initConfigsFromRemote(formatter)
|
||||
gp.InitLogCfg()
|
||||
@ -107,7 +119,7 @@ func (gp *BaseTable) initConfigsFromLocal(formatter func(key string) string) {
|
||||
}
|
||||
|
||||
gp.configDir = gp.initConfPath()
|
||||
configFilePath := gp.configDir + "/" + defaultYaml
|
||||
configFilePath := gp.configDir + "/" + gp.YamlFile
|
||||
gp.mgr, err = config.Init(config.WithEnvSource(formatter), config.WithFilesSource(configFilePath))
|
||||
if err != nil {
|
||||
log.Warn("init baseTable with file failed", zap.String("configFile", configFilePath), zap.Error(err))
|
||||
@ -127,7 +139,7 @@ func (gp *BaseTable) initConfigsFromRemote(formatter func(key string) string) {
|
||||
return
|
||||
}
|
||||
|
||||
configFilePath := gp.configDir + "/" + defaultYaml
|
||||
configFilePath := gp.configDir + "/" + gp.YamlFile
|
||||
gp.mgr, err = config.Init(config.WithEnvSource(formatter),
|
||||
config.WithFilesSource(configFilePath),
|
||||
config.WithEtcdSource(&config.EtcdInfo{
|
||||
@ -164,6 +176,10 @@ func (gp *BaseTable) initConfPath() string {
|
||||
return configDir
|
||||
}
|
||||
|
||||
func (gp *BaseTable) Configs() map[string]string {
|
||||
return gp.mgr.Configs()
|
||||
}
|
||||
|
||||
// Load loads an object with @key.
|
||||
func (gp *BaseTable) Load(key string) (string, error) {
|
||||
return gp.mgr.GetConfig(key)
|
||||
@ -366,19 +382,13 @@ func ConvertRangeToIntSlice(rangeStr, sep string) []int {
|
||||
// InitLogCfg init log of the base table
|
||||
func (gp *BaseTable) InitLogCfg() {
|
||||
gp.Log = log.Config{}
|
||||
format, err := gp.Load("log.format")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
format := gp.LoadWithDefault("log.format", DefaultLogFormat)
|
||||
gp.Log.Format = format
|
||||
level, err := gp.Load("log.level")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
level := gp.LoadWithDefault("log.level", DefaultLogLevelForBase)
|
||||
gp.Log.Level = level
|
||||
gp.Log.File.MaxSize = gp.ParseInt("log.file.maxSize")
|
||||
gp.Log.File.MaxBackups = gp.ParseInt("log.file.maxBackups")
|
||||
gp.Log.File.MaxDays = gp.ParseInt("log.file.maxAge")
|
||||
gp.Log.File.MaxSize = gp.ParseIntWithDefault("log.file.maxSize", DefaultMaxSize)
|
||||
gp.Log.File.MaxBackups = gp.ParseIntWithDefault("log.file.maxBackups", DefaultMaxBackups)
|
||||
gp.Log.File.MaxDays = gp.ParseIntWithDefault("log.file.maxAge", DefaultMaxAge)
|
||||
}
|
||||
|
||||
// SetLogConfig set log config of the base table
|
||||
@ -398,10 +408,7 @@ func (gp *BaseTable) SetLogConfig() {
|
||||
|
||||
// SetLogger sets the logger file by given id
|
||||
func (gp *BaseTable) SetLogger(id UniqueID) {
|
||||
rootPath, err := gp.Load("log.file.rootPath")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
rootPath := gp.LoadWithDefault("log.file.rootPath", DefaultRootPath)
|
||||
if rootPath != "" {
|
||||
if id < 0 {
|
||||
gp.Log.File.Filename = path.Join(rootPath, gp.RoleName+".log")
|
||||
|
@ -52,6 +52,7 @@ type ComponentParam struct {
|
||||
DataNodeCfg dataNodeConfig
|
||||
IndexCoordCfg indexCoordConfig
|
||||
IndexNodeCfg indexNodeConfig
|
||||
HookCfg HookConfig
|
||||
}
|
||||
|
||||
// InitOnce initialize once
|
||||
@ -76,6 +77,7 @@ func (p *ComponentParam) Init() {
|
||||
p.DataNodeCfg.init(&p.BaseTable)
|
||||
p.IndexCoordCfg.init(&p.BaseTable)
|
||||
p.IndexNodeCfg.init(&p.BaseTable)
|
||||
p.HookCfg.init()
|
||||
}
|
||||
|
||||
// SetLogConfig set log config with given role
|
||||
@ -431,7 +433,8 @@ type proxyConfig struct {
|
||||
IP string
|
||||
NetworkAddress string
|
||||
|
||||
Alias string
|
||||
Alias string
|
||||
SoPath string
|
||||
|
||||
NodeID atomic.Value
|
||||
TimeTickInterval time.Duration
|
||||
@ -475,6 +478,8 @@ func (p *proxyConfig) init(base *BaseTable) {
|
||||
p.initGinLogging()
|
||||
p.initMaxUserNum()
|
||||
p.initMaxRoleNum()
|
||||
|
||||
p.initSoPath()
|
||||
}
|
||||
|
||||
// InitAlias initialize Alias member.
|
||||
@ -482,6 +487,10 @@ func (p *proxyConfig) InitAlias(alias string) {
|
||||
p.Alias = alias
|
||||
}
|
||||
|
||||
func (p *proxyConfig) initSoPath() {
|
||||
p.SoPath = p.Base.LoadWithDefault("proxy.soPath", "")
|
||||
}
|
||||
|
||||
func (p *proxyConfig) initTimeTickInterval() {
|
||||
interval := p.Base.ParseIntWithDefault("proxy.timeTickInterval", 200)
|
||||
p.TimeTickInterval = time.Duration(interval) * time.Millisecond
|
||||
|
26
internal/util/paramtable/hook_config.go
Normal file
26
internal/util/paramtable/hook_config.go
Normal file
@ -0,0 +1,26 @@
|
||||
package paramtable
|
||||
|
||||
const hookYamlFile = "hook.yaml"
|
||||
|
||||
type HookConfig struct {
|
||||
Base *BaseTable
|
||||
SoPath string
|
||||
SoConfig map[string]string
|
||||
}
|
||||
|
||||
func (h *HookConfig) init() {
|
||||
h.Base = &BaseTable{YamlFile: hookYamlFile}
|
||||
h.Base.Init()
|
||||
|
||||
h.initSoPath()
|
||||
h.initSoConfig()
|
||||
}
|
||||
|
||||
func (h *HookConfig) initSoPath() {
|
||||
h.SoPath = h.Base.LoadWithDefault("soPath", "")
|
||||
}
|
||||
|
||||
func (h *HookConfig) initSoConfig() {
|
||||
// all keys have been set lower
|
||||
h.SoConfig = h.Base.Configs()
|
||||
}
|
Loading…
Reference in New Issue
Block a user