fix: cmux graceful shutdown on proxy service (#28383)

issue #28305

Signed-off-by: jaime <yun.zhang@zilliz.com>
This commit is contained in:
jaime 2023-11-27 16:28:34 +08:00 committed by GitHub
parent 911a915798
commit c5f455dc6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 112 additions and 13 deletions

View File

@ -30,6 +30,7 @@ import (
"sync"
"time"
"github.com/cockroachdb/errors"
"github.com/gin-gonic/gin"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
@ -80,6 +81,10 @@ var (
errInvalidToken = status.Errorf(codes.Unauthenticated, "invalid token")
// registerHTTPHandlerOnce avoid register http handler multiple times
registerHTTPHandlerOnce sync.Once
// only for test
enableCustomInterceptor = true
// only for test, register internal interface to external service
enableRegisterProxyServer = false
)
const apiPathPrefix = "/api/v1"
@ -232,12 +237,10 @@ func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) {
log.Debug("Get proxy rate limiter done", zap.Int("port", grpcPort))
opts := tracer.GetInterceptorOpts()
grpcOpts := []grpc.ServerOption{
grpc.KeepaliveEnforcementPolicy(kaep),
grpc.KeepaliveParams(kasp),
grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize.GetAsInt()),
grpc.MaxSendMsgSize(Params.ServerMaxSendSize.GetAsInt()),
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
var unaryServerOption grpc.ServerOption
if enableCustomInterceptor {
unaryServerOption = grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
otelgrpc.UnaryServerInterceptor(opts...),
grpc_auth.UnaryServerInterceptor(proxy.AuthenticationInterceptor),
proxy.DatabaseInterceptor(),
@ -248,7 +251,17 @@ func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) {
accesslog.UnaryAccessLoggerInterceptor,
proxy.TraceLogInterceptor,
proxy.KeepActiveInterceptor,
)),
))
} else {
unaryServerOption = grpc.EmptyServerOption{}
}
grpcOpts := []grpc.ServerOption{
grpc.KeepaliveEnforcementPolicy(kaep),
grpc.KeepaliveParams(kasp),
grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize.GetAsInt()),
grpc.MaxSendMsgSize(Params.ServerMaxSendSize.GetAsInt()),
unaryServerOption,
}
if Params.TLSMode.GetAsInt() == 1 {
@ -290,6 +303,11 @@ func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) {
grpcOpts = append(grpcOpts, grpc.Creds(credentials.NewTLS(tlsConf)))
}
s.grpcExternalServer = grpc.NewServer(grpcOpts...)
if enableRegisterProxyServer {
proxypb.RegisterProxyServer(s.grpcExternalServer, s)
}
milvuspb.RegisterMilvusServiceServer(s.grpcExternalServer, s)
grpc_health_v1.RegisterHealthServer(s.grpcExternalServer, s)
errChan <- nil
@ -390,7 +408,7 @@ func (s *Server) Run() error {
s.wg.Add(1)
go func() {
defer s.wg.Done()
if err := s.tcpServer.Serve(); err != nil && err != cmux.ErrServerClosed {
if err := s.tcpServer.Serve(); err != nil && !errors.Is(err, net.ErrClosed) {
log.Warn("Proxy server for tcp port failed", zap.Error(err))
return
}
@ -651,11 +669,8 @@ func (s *Server) Stop() error {
go func() {
defer gracefulWg.Done()
if s.tcpServer != nil {
log.Info("Proxy stop tcp server...")
s.tcpServer.Close()
}
// try to close grpc server firstly, it has the same root listener with cmux server and
// http listener that tls has not been enabled.
if s.grpcExternalServer != nil {
log.Info("Proxy stop external grpc server")
utils.GracefulStopGRPCServer(s.grpcExternalServer)
@ -666,6 +681,17 @@ func (s *Server) Stop() error {
s.httpServer.Close()
}
// close cmux server, it isn't a synchronized operation.
// Note that:
// 1. all listeners can be closed after closing cmux server that has the root listener, it will automatically
// propagate the closure to all the listeners derived from it, but it doesn't provide a graceful shutdown
// grpc server ideally.
// 2. avoid resource leak also need to close cmux after grpc and http listener closed.
if s.tcpServer != nil {
log.Info("Proxy stop tcp server...")
s.tcpServer.Close()
}
if s.grpcInternalServer != nil {
log.Info("Proxy stop internal grpc server")
utils.GracefulStopGRPCServer(s.grpcInternalServer)

View File

@ -25,6 +25,7 @@ import (
"net/http/httptest"
"os"
"strconv"
"sync/atomic"
"testing"
"time"
@ -34,6 +35,7 @@ import (
"github.com/stretchr/testify/mock"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
@ -42,6 +44,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/federpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
grpcproxyclient "github.com/milvus-io/milvus/internal/distributed/proxy/client"
"github.com/milvus-io/milvus/internal/distributed/proxy/httpserver"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/internalpb"
@ -50,6 +53,7 @@ import (
"github.com/milvus-io/milvus/internal/types"
milvusmock "github.com/milvus-io/milvus/internal/util/mock"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/uniquegenerator"
@ -1380,3 +1384,72 @@ func TestHttpAuthenticate(t *testing.T) {
assert.Equal(t, "foo", ctxName)
}
}
func Test_Service_GracefulStop(t *testing.T) {
mockedProxy := mocks.NewMockProxy(t)
var count int32
mockedProxy.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Run(func(_a0 context.Context, _a1 *milvuspb.GetComponentStatesRequest) {
fmt.Println("rpc start")
time.Sleep(10 * time.Second)
atomic.AddInt32(&count, 1)
fmt.Println("rpc done")
}).Return(&milvuspb.ComponentStates{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}}, nil)
mockedProxy.EXPECT().Init().Return(nil)
mockedProxy.EXPECT().Start().Return(nil)
mockedProxy.EXPECT().Stop().Return(nil)
mockedProxy.EXPECT().Register().Return(nil)
mockedProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
mockedProxy.EXPECT().GetRateLimiter().Return(nil, nil)
mockedProxy.EXPECT().SetDataCoordClient(mock.Anything).Return()
mockedProxy.EXPECT().SetRootCoordClient(mock.Anything).Return()
mockedProxy.EXPECT().SetQueryCoordClient(mock.Anything).Return()
mockedProxy.EXPECT().UpdateStateCode(mock.Anything).Return()
mockedProxy.EXPECT().SetAddress(mock.Anything).Return()
Params := &paramtable.Get().ProxyGrpcServerCfg
paramtable.Get().Save(Params.TLSMode.Key, "0")
paramtable.Get().Save(Params.Port.Key, fmt.Sprintf("%d", funcutil.GetAvailablePort()))
paramtable.Get().Save(Params.InternalPort.Key, fmt.Sprintf("%d", funcutil.GetAvailablePort()))
paramtable.Get().Save(Params.ServerPemPath.Key, "../../../configs/cert/server.pem")
paramtable.Get().Save(Params.ServerKeyPath.Key, "../../../configs/cert/server.key")
paramtable.Get().Save(proxy.Params.HTTPCfg.Enabled.Key, "true")
paramtable.Get().Save(proxy.Params.HTTPCfg.Port.Key, "")
ctx := context.Background()
enableCustomInterceptor = false
enableRegisterProxyServer = true
defer func() {
enableCustomInterceptor = true
enableRegisterProxyServer = false
}()
server := getServer(t)
assert.NotNil(t, server)
server.proxy = mockedProxy
err := server.Run()
assert.Nil(t, err)
proxyClient, err := grpcproxyclient.NewClient(ctx, fmt.Sprintf("localhost:%s", Params.Port.GetValue()), 0)
assert.Nil(t, err)
group := &errgroup.Group{}
for i := 0; i < 3; i++ {
group.Go(func() error {
_, err := proxyClient.GetComponentStates(context.TODO(), &milvuspb.GetComponentStatesRequest{})
return err
})
}
// waiting for all requests have been launched
time.Sleep(1 * time.Second)
server.Stop()
err = group.Wait()
assert.Nil(t, err)
assert.Equal(t, count, int32(3))
}