Implement the hook interceptor (#19294)

Signed-off-by: SimFG <bang.fu@zilliz.com>

Signed-off-by: SimFG <bang.fu@zilliz.com>
This commit is contained in:
SimFG 2022-09-23 09:50:52 +08:00 committed by GitHub
parent 901d3fb367
commit 68a257458b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 293 additions and 23 deletions

9
api/hook/hook.go Normal file
View File

@ -0,0 +1,9 @@
package hook
type Hook interface {
Init(params map[string]string) error
Mock(req interface{}, fullMethod string) (bool, interface{}, error)
Before(req interface{}, fullMethod string) error
After(result interface{}, err error, fullMethod string) error
Release()
}

0
configs/hook.yaml Normal file
View File

View File

@ -22,11 +22,7 @@ macro( build_rocksdb )
message( STATUS "Building ROCKSDB-${ROCKSDB_VERSION} from source" )
set ( ROCKSDB_MD5 "e4a0625f0cec82060e62c81b787a1124" )
if ( EMBEDDED_MILVUS )
message ( STATUS "Turning on fPIC while building embedded Milvus" )
set( FPIC_ARG "-DCMAKE_POSITION_INDEPENDENT_CODE=ON" )
endif()
set( FPIC_ARG "-DCMAKE_POSITION_INDEPENDENT_CODE=ON" )
set( ROCKSDB_CMAKE_ARGS
"-DWITH_GFLAGS=OFF"
"-DROCKSDB_BUILD_SHARED=OFF"

View File

@ -179,6 +179,7 @@ func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) {
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
ot.UnaryServerInterceptor(opts...),
grpc_auth.UnaryServerInterceptor(proxy.AuthenticationInterceptor),
proxy.UnaryServerHookInterceptor(),
proxy.UnaryServerInterceptor(proxy.PrivilegeInterceptor),
logutil.UnaryTraceLoggerInterceptor,
proxy.RateLimitInterceptor(limiter),

View File

@ -0,0 +1,95 @@
package proxy
import (
"context"
"plugin"
"github.com/milvus-io/milvus/api/hook"
"go.uber.org/zap"
"google.golang.org/grpc"
)
type defaultHook struct {
}
func (d defaultHook) Init(params map[string]string) error {
return nil
}
func (d defaultHook) Mock(req interface{}, fullMethod string) (bool, interface{}, error) {
return false, nil, nil
}
func (d defaultHook) Before(req interface{}, fullMethod string) error {
return nil
}
func (d defaultHook) After(result interface{}, err error, fullMethod string) error {
return nil
}
func (d defaultHook) Release() {}
var hoo hook.Hook
func initHook() {
path := Params.ProxyCfg.SoPath
if path == "" {
hoo = defaultHook{}
return
}
logger.Debug("start to load plugin", zap.String("path", path))
p, err := plugin.Open(path)
if err != nil {
exit("fail to open the plugin", err)
}
logger.Debug("plugin open")
h, err := p.Lookup("MilvusHook")
if err != nil {
exit("fail to the 'MilvusHook' object in the plugin", err)
}
var ok bool
hoo, ok = h.(hook.Hook)
if !ok {
exit("fail to convert the `Hook` interface", nil)
}
if err = hoo.Init(Params.HookCfg.SoConfig); err != nil {
exit("fail to init configs for the hoo", err)
}
}
func exit(errMsg string, err error) {
logger.Panic("hoo error", zap.String("path", Params.ProxyCfg.SoPath), zap.String("msg", errMsg), zap.Error(err))
}
func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor {
initHook()
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
var (
fullMethod = info.FullMethod
isMock bool
mockResp interface{}
realResp interface{}
realErr error
err error
)
if isMock, mockResp, err = hoo.Mock(req, fullMethod); isMock {
return mockResp, err
}
if err = hoo.Before(req, fullMethod); err != nil {
return nil, err
}
realResp, realErr = handler(ctx, req)
if err = hoo.After(realResp, realErr, fullMethod); err != nil {
return nil, err
}
return realResp, realErr
}
}

View File

@ -0,0 +1,127 @@
package proxy
import (
"context"
"errors"
"testing"
"google.golang.org/grpc"
"github.com/stretchr/testify/assert"
)
func TestInitHook(t *testing.T) {
Params.ProxyCfg.SoPath = ""
initHook()
assert.IsType(t, defaultHook{}, hoo)
Params.ProxyCfg.SoPath = "/a/b/hook.so"
assert.Panics(t, func() {
initHook()
})
Params.ProxyCfg.SoPath = ""
}
type mockHook struct {
defaultHook
mockRes interface{}
mockErr error
}
func (m mockHook) Mock(req interface{}, fullMethod string) (bool, interface{}, error) {
return true, m.mockRes, m.mockErr
}
type req struct {
method string
}
type beforeMock struct {
defaultHook
method string
err error
}
func (b beforeMock) Before(r interface{}, fullMethod string) error {
re, ok := r.(*req)
if !ok {
return errors.New("r is invalid type")
}
re.method = b.method
return b.err
}
type resp struct {
method string
}
type afterMock struct {
defaultHook
method string
err error
}
func (a afterMock) After(r interface{}, err error, fullMethod string) error {
re, ok := r.(*resp)
if !ok {
return errors.New("r is invalid type")
}
re.method = a.method
return a.err
}
func TestHookInterceptor(t *testing.T) {
var (
ctx = context.Background()
info = &grpc.UnaryServerInfo{
FullMethod: "test",
}
interceptor = UnaryServerHookInterceptor()
mockHoo = mockHook{mockRes: "mock", mockErr: errors.New("mock")}
r = &req{method: "req"}
re = &resp{method: "resp"}
beforeHoo = beforeMock{method: "before", err: errors.New("before")}
afterHoo = afterMock{method: "after", err: errors.New("after")}
res interface{}
err error
)
hoo = mockHoo
res, err = interceptor(ctx, "request", info, func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
assert.Equal(t, res, mockHoo.mockRes)
assert.Equal(t, err, mockHoo.mockErr)
hoo = beforeHoo
_, err = interceptor(ctx, r, info, func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
assert.Equal(t, r.method, beforeHoo.method)
assert.Equal(t, err, beforeHoo.err)
hoo = afterHoo
_, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) {
return re, nil
})
assert.Equal(t, re.method, afterHoo.method)
assert.Equal(t, err, afterHoo.err)
hoo = defaultHook{}
res, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) {
return &resp{
method: r.(*req).method,
}, nil
})
assert.Equal(t, res.(*resp).method, r.method)
assert.NoError(t, err)
}
func TestDefaultHook(t *testing.T) {
d := defaultHook{}
assert.NoError(t, d.Init(nil))
assert.NotPanics(t, func() {
d.Release()
})
}

View File

@ -45,6 +45,13 @@ const (
DefaultEtcdEndpoints = "localhost:2379"
DefaultInsertBufferSize = "16777216"
DefaultEnvPrefix = "milvus"
DefaultLogFormat = "text"
DefaultLogLevelForBase = "debug"
DefaultRootPath = ""
DefaultMaxSize = 300
DefaultMaxAge = 10
DefaultMaxBackups = 20
)
var defaultYaml = DefaultMilvusYaml
@ -71,6 +78,8 @@ type BaseTable struct {
RoleName string
Log log.Config
LogCfgFunc func(log.Config)
YamlFile string
}
// GlobalInitWithYaml initializes the param table with the given yaml.
@ -94,6 +103,9 @@ func (gp *BaseTable) Init() {
ret = strings.ReplaceAll(ret, ".", "")
return ret
}
if gp.YamlFile == "" {
gp.YamlFile = defaultYaml
}
gp.initConfigsFromLocal(formatter)
gp.initConfigsFromRemote(formatter)
gp.InitLogCfg()
@ -107,7 +119,7 @@ func (gp *BaseTable) initConfigsFromLocal(formatter func(key string) string) {
}
gp.configDir = gp.initConfPath()
configFilePath := gp.configDir + "/" + defaultYaml
configFilePath := gp.configDir + "/" + gp.YamlFile
gp.mgr, err = config.Init(config.WithEnvSource(formatter), config.WithFilesSource(configFilePath))
if err != nil {
log.Warn("init baseTable with file failed", zap.String("configFile", configFilePath), zap.Error(err))
@ -127,7 +139,7 @@ func (gp *BaseTable) initConfigsFromRemote(formatter func(key string) string) {
return
}
configFilePath := gp.configDir + "/" + defaultYaml
configFilePath := gp.configDir + "/" + gp.YamlFile
gp.mgr, err = config.Init(config.WithEnvSource(formatter),
config.WithFilesSource(configFilePath),
config.WithEtcdSource(&config.EtcdInfo{
@ -164,6 +176,10 @@ func (gp *BaseTable) initConfPath() string {
return configDir
}
func (gp *BaseTable) Configs() map[string]string {
return gp.mgr.Configs()
}
// Load loads an object with @key.
func (gp *BaseTable) Load(key string) (string, error) {
return gp.mgr.GetConfig(key)
@ -366,19 +382,13 @@ func ConvertRangeToIntSlice(rangeStr, sep string) []int {
// InitLogCfg init log of the base table
func (gp *BaseTable) InitLogCfg() {
gp.Log = log.Config{}
format, err := gp.Load("log.format")
if err != nil {
panic(err)
}
format := gp.LoadWithDefault("log.format", DefaultLogFormat)
gp.Log.Format = format
level, err := gp.Load("log.level")
if err != nil {
panic(err)
}
level := gp.LoadWithDefault("log.level", DefaultLogLevelForBase)
gp.Log.Level = level
gp.Log.File.MaxSize = gp.ParseInt("log.file.maxSize")
gp.Log.File.MaxBackups = gp.ParseInt("log.file.maxBackups")
gp.Log.File.MaxDays = gp.ParseInt("log.file.maxAge")
gp.Log.File.MaxSize = gp.ParseIntWithDefault("log.file.maxSize", DefaultMaxSize)
gp.Log.File.MaxBackups = gp.ParseIntWithDefault("log.file.maxBackups", DefaultMaxBackups)
gp.Log.File.MaxDays = gp.ParseIntWithDefault("log.file.maxAge", DefaultMaxAge)
}
// SetLogConfig set log config of the base table
@ -398,10 +408,7 @@ func (gp *BaseTable) SetLogConfig() {
// SetLogger sets the logger file by given id
func (gp *BaseTable) SetLogger(id UniqueID) {
rootPath, err := gp.Load("log.file.rootPath")
if err != nil {
panic(err)
}
rootPath := gp.LoadWithDefault("log.file.rootPath", DefaultRootPath)
if rootPath != "" {
if id < 0 {
gp.Log.File.Filename = path.Join(rootPath, gp.RoleName+".log")

View File

@ -52,6 +52,7 @@ type ComponentParam struct {
DataNodeCfg dataNodeConfig
IndexCoordCfg indexCoordConfig
IndexNodeCfg indexNodeConfig
HookCfg HookConfig
}
// InitOnce initialize once
@ -76,6 +77,7 @@ func (p *ComponentParam) Init() {
p.DataNodeCfg.init(&p.BaseTable)
p.IndexCoordCfg.init(&p.BaseTable)
p.IndexNodeCfg.init(&p.BaseTable)
p.HookCfg.init()
}
// SetLogConfig set log config with given role
@ -431,7 +433,8 @@ type proxyConfig struct {
IP string
NetworkAddress string
Alias string
Alias string
SoPath string
NodeID atomic.Value
TimeTickInterval time.Duration
@ -475,6 +478,8 @@ func (p *proxyConfig) init(base *BaseTable) {
p.initGinLogging()
p.initMaxUserNum()
p.initMaxRoleNum()
p.initSoPath()
}
// InitAlias initialize Alias member.
@ -482,6 +487,10 @@ func (p *proxyConfig) InitAlias(alias string) {
p.Alias = alias
}
func (p *proxyConfig) initSoPath() {
p.SoPath = p.Base.LoadWithDefault("proxy.soPath", "")
}
func (p *proxyConfig) initTimeTickInterval() {
interval := p.Base.ParseIntWithDefault("proxy.timeTickInterval", 200)
p.TimeTickInterval = time.Duration(interval) * time.Millisecond

View File

@ -0,0 +1,26 @@
package paramtable
const hookYamlFile = "hook.yaml"
type HookConfig struct {
Base *BaseTable
SoPath string
SoConfig map[string]string
}
func (h *HookConfig) init() {
h.Base = &BaseTable{YamlFile: hookYamlFile}
h.Base.Init()
h.initSoPath()
h.initSoConfig()
}
func (h *HookConfig) initSoPath() {
h.SoPath = h.Base.LoadWithDefault("soPath", "")
}
func (h *HookConfig) initSoConfig() {
// all keys have been set lower
h.SoConfig = h.Base.Configs()
}