From b9a10a2f689afa0ee61011ac171bad7df4b3e16d Mon Sep 17 00:00:00 2001 From: Zhen Ye Date: Sat, 23 Nov 2024 21:36:33 +0800 Subject: [PATCH] enhance: remove the rpc layer of coordinator when enabling standalone or mixcoord (#37815) issue: #37764 - add a local client to call local server directly for querycoord/rootcoord/datacoord. - enable local client if milvus is running mixcoord or standalone mode. --------- Signed-off-by: chyezh --- cmd/milvus/util.go | 8 +- internal/coordinator/coordclient/registry.go | 158 ++++++++++++++++++ .../coordinator/coordclient/registry_test.go | 74 ++++++++ internal/datacoord/server.go | 4 +- internal/distributed/datacoord/service.go | 2 + internal/distributed/querycoord/service.go | 16 +- internal/distributed/rootcoord/service.go | 29 +--- .../distributed/rootcoord/service_test.go | 12 +- internal/util/grpcclient/local_grpc_client.go | 68 ++++++++ .../util/grpcclient/local_grpc_client_test.go | 46 +++++ pkg/util/syncutil/future.go | 15 ++ 11 files changed, 390 insertions(+), 42 deletions(-) create mode 100644 internal/coordinator/coordclient/registry.go create mode 100644 internal/coordinator/coordclient/registry_test.go create mode 100644 internal/util/grpcclient/local_grpc_client.go create mode 100644 internal/util/grpcclient/local_grpc_client_test.go diff --git a/cmd/milvus/util.go b/cmd/milvus/util.go index e7dcb20353..c21fe4ec28 100644 --- a/cmd/milvus/util.go +++ b/cmd/milvus/util.go @@ -20,6 +20,7 @@ import ( "go.uber.org/zap" "github.com/milvus-io/milvus/cmd/roles" + "github.com/milvus-io/milvus/internal/coordinator/coordclient" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/log" @@ -179,7 +180,12 @@ func GetMilvusRoles(args []string, flags *flag.FlagSet) *roles.MilvusRoles { fmt.Fprintf(os.Stderr, "Unknown server type = %s\n%s", serverType, getHelp()) os.Exit(-1) } - + coordclient.EnableLocalClientRole(&coordclient.LocalClientRoleConfig{ + ServerType: serverType, + EnableQueryCoord: role.EnableQueryCoord, + EnableDataCoord: role.EnableDataCoord, + EnableRootCoord: role.EnableRootCoord, + }) return role } diff --git a/internal/coordinator/coordclient/registry.go b/internal/coordinator/coordclient/registry.go new file mode 100644 index 0000000000..f143a29a9f --- /dev/null +++ b/internal/coordinator/coordclient/registry.go @@ -0,0 +1,158 @@ +package coordclient + +import ( + "context" + "fmt" + + "go.uber.org/zap" + + dcc "github.com/milvus-io/milvus/internal/distributed/datacoord/client" + qcc "github.com/milvus-io/milvus/internal/distributed/querycoord/client" + rcc "github.com/milvus-io/milvus/internal/distributed/rootcoord/client" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/grpcclient" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/syncutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +// localClient is a client that can access local server directly +type localClient struct { + queryCoordClient *syncutil.Future[types.QueryCoordClient] + dataCoordClient *syncutil.Future[types.DataCoordClient] + rootCoordClient *syncutil.Future[types.RootCoordClient] +} + +var ( + enableLocal *LocalClientRoleConfig // a global map to store all can be local accessible roles. + glocalClient *localClient // !!! WARNING: local client will ignore all interceptor of grpc client and server. +) + +func init() { + enableLocal = &LocalClientRoleConfig{} + glocalClient = &localClient{ + queryCoordClient: syncutil.NewFuture[types.QueryCoordClient](), + dataCoordClient: syncutil.NewFuture[types.DataCoordClient](), + rootCoordClient: syncutil.NewFuture[types.RootCoordClient](), + } +} + +type LocalClientRoleConfig struct { + ServerType string + EnableQueryCoord bool + EnableDataCoord bool + EnableRootCoord bool +} + +// EnableLocalClientRole init localable roles +func EnableLocalClientRole(cfg *LocalClientRoleConfig) { + if cfg.ServerType != typeutil.StandaloneRole && cfg.ServerType != typeutil.MixtureRole { + return + } + enableLocal = cfg +} + +// RegisterQueryCoordServer register query coord server +func RegisterQueryCoordServer(server querypb.QueryCoordServer) { + if !enableLocal.EnableQueryCoord { + return + } + newLocalClient := grpcclient.NewLocalGRPCClient(&querypb.QueryCoord_ServiceDesc, server, querypb.NewQueryCoordClient) + glocalClient.queryCoordClient.Set(&nopCloseQueryCoordClient{newLocalClient}) + log.Info("register query coord server", zap.Any("enableLocalClient", enableLocal)) +} + +// RegsterDataCoordServer register data coord server +func RegisterDataCoordServer(server datapb.DataCoordServer) { + if !enableLocal.EnableDataCoord { + return + } + newLocalClient := grpcclient.NewLocalGRPCClient(&datapb.DataCoord_ServiceDesc, server, datapb.NewDataCoordClient) + glocalClient.dataCoordClient.Set(&nopCloseDataCoordClient{newLocalClient}) + log.Info("register data coord server", zap.Any("enableLocalClient", enableLocal)) +} + +// RegisterRootCoordServer register root coord server +func RegisterRootCoordServer(server rootcoordpb.RootCoordServer) { + if !enableLocal.EnableRootCoord { + return + } + newLocalClient := grpcclient.NewLocalGRPCClient(&rootcoordpb.RootCoord_ServiceDesc, server, rootcoordpb.NewRootCoordClient) + glocalClient.rootCoordClient.Set(&nopCloseRootCoordClient{newLocalClient}) + log.Info("register root coord server", zap.Any("enableLocalClient", enableLocal)) +} + +// GetQueryCoordClient return query coord client +func GetQueryCoordClient(ctx context.Context) types.QueryCoordClient { + var client types.QueryCoordClient + var err error + if enableLocal.EnableQueryCoord { + client, err = glocalClient.queryCoordClient.GetWithContext(ctx) + } else { + // TODO: we should make a singleton here. but most unittest rely on a dedicated client. + client, err = qcc.NewClient(ctx) + } + if err != nil { + panic(fmt.Sprintf("get query coord client failed: %v", err)) + } + return client +} + +// GetDataCoordClient return data coord client +func GetDataCoordClient(ctx context.Context) types.DataCoordClient { + var client types.DataCoordClient + var err error + if enableLocal.EnableDataCoord { + client, err = glocalClient.dataCoordClient.GetWithContext(ctx) + } else { + // TODO: we should make a singleton here. but most unittest rely on a dedicated client. + client, err = dcc.NewClient(ctx) + } + if err != nil { + panic(fmt.Sprintf("get data coord client failed: %v", err)) + } + return client +} + +// GetRootCoordClient return root coord client +func GetRootCoordClient(ctx context.Context) types.RootCoordClient { + var client types.RootCoordClient + var err error + if enableLocal.EnableRootCoord { + client, err = glocalClient.rootCoordClient.GetWithContext(ctx) + } else { + // TODO: we should make a singleton here. but most unittest rely on a dedicated client. + client, err = rcc.NewClient(ctx) + } + if err != nil { + panic(fmt.Sprintf("get root coord client failed: %v", err)) + } + return client +} + +type nopCloseQueryCoordClient struct { + querypb.QueryCoordClient +} + +func (n *nopCloseQueryCoordClient) Close() error { + return nil +} + +type nopCloseDataCoordClient struct { + datapb.DataCoordClient +} + +func (n *nopCloseDataCoordClient) Close() error { + return nil +} + +type nopCloseRootCoordClient struct { + rootcoordpb.RootCoordClient +} + +func (n *nopCloseRootCoordClient) Close() error { + return nil +} diff --git a/internal/coordinator/coordclient/registry_test.go b/internal/coordinator/coordclient/registry_test.go new file mode 100644 index 0000000000..8ed97ac3d5 --- /dev/null +++ b/internal/coordinator/coordclient/registry_test.go @@ -0,0 +1,74 @@ +package coordclient + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +func TestRegistry(t *testing.T) { + assert.False(t, enableLocal.EnableQueryCoord) + assert.False(t, enableLocal.EnableDataCoord) + assert.False(t, enableLocal.EnableRootCoord) + + EnableLocalClientRole(&LocalClientRoleConfig{ + ServerType: typeutil.RootCoordRole, + EnableQueryCoord: true, + EnableDataCoord: true, + EnableRootCoord: true, + }) + assert.False(t, enableLocal.EnableQueryCoord) + assert.False(t, enableLocal.EnableDataCoord) + assert.False(t, enableLocal.EnableRootCoord) + + RegisterRootCoordServer(&rootcoordpb.UnimplementedRootCoordServer{}) + RegisterDataCoordServer(&datapb.UnimplementedDataCoordServer{}) + RegisterQueryCoordServer(&querypb.UnimplementedQueryCoordServer{}) + assert.False(t, glocalClient.dataCoordClient.Ready()) + assert.False(t, glocalClient.queryCoordClient.Ready()) + assert.False(t, glocalClient.rootCoordClient.Ready()) + + enableLocal = &LocalClientRoleConfig{} + + EnableLocalClientRole(&LocalClientRoleConfig{ + ServerType: typeutil.StandaloneRole, + EnableQueryCoord: true, + EnableDataCoord: true, + EnableRootCoord: true, + }) + assert.True(t, enableLocal.EnableDataCoord) + assert.True(t, enableLocal.EnableQueryCoord) + assert.True(t, enableLocal.EnableRootCoord) + + RegisterRootCoordServer(&rootcoordpb.UnimplementedRootCoordServer{}) + RegisterDataCoordServer(&datapb.UnimplementedDataCoordServer{}) + RegisterQueryCoordServer(&querypb.UnimplementedQueryCoordServer{}) + assert.True(t, glocalClient.dataCoordClient.Ready()) + assert.True(t, glocalClient.queryCoordClient.Ready()) + assert.True(t, glocalClient.rootCoordClient.Ready()) + + enableLocal = &LocalClientRoleConfig{} + + EnableLocalClientRole(&LocalClientRoleConfig{ + ServerType: typeutil.MixtureRole, + EnableQueryCoord: true, + EnableDataCoord: true, + EnableRootCoord: true, + }) + assert.True(t, enableLocal.EnableDataCoord) + assert.True(t, enableLocal.EnableQueryCoord) + assert.True(t, enableLocal.EnableRootCoord) + + assert.NotNil(t, GetQueryCoordClient(context.Background())) + assert.NotNil(t, GetDataCoordClient(context.Background())) + assert.NotNil(t, GetRootCoordClient(context.Background())) + GetQueryCoordClient(context.Background()).Close() + GetDataCoordClient(context.Background()).Close() + GetRootCoordClient(context.Background()).Close() +} diff --git a/internal/datacoord/server.go b/internal/datacoord/server.go index e1ccd37c70..0d6f572018 100644 --- a/internal/datacoord/server.go +++ b/internal/datacoord/server.go @@ -38,12 +38,12 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" globalIDAllocator "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/coordinator/coordclient" "github.com/milvus-io/milvus/internal/datacoord/allocator" "github.com/milvus-io/milvus/internal/datacoord/broker" "github.com/milvus-io/milvus/internal/datacoord/session" datanodeclient "github.com/milvus-io/milvus/internal/distributed/datanode/client" indexnodeclient "github.com/milvus-io/milvus/internal/distributed/indexnode/client" - rootcoordclient "github.com/milvus-io/milvus/internal/distributed/rootcoord/client" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/kv/tikv" "github.com/milvus-io/milvus/internal/metastore/kv/datacoord" @@ -237,7 +237,7 @@ func defaultIndexNodeCreatorFunc(ctx context.Context, addr string, nodeID int64) } func defaultRootCoordCreatorFunc(ctx context.Context) (types.RootCoordClient, error) { - return rootcoordclient.NewClient(ctx) + return coordclient.GetRootCoordClient(ctx), nil } // QuitSignal returns signal when server quits diff --git a/internal/distributed/datacoord/service.go b/internal/distributed/datacoord/service.go index ee17f8c0d3..40a43841cd 100644 --- a/internal/distributed/datacoord/service.go +++ b/internal/distributed/datacoord/service.go @@ -32,6 +32,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/coordinator/coordclient" "github.com/milvus-io/milvus/internal/datacoord" "github.com/milvus-io/milvus/internal/distributed/utils" "github.com/milvus-io/milvus/internal/proto/datapb" @@ -212,6 +213,7 @@ func (s *Server) startGrpcLoop() { if streamingutil.IsStreamingServiceEnabled() { s.dataCoord.RegisterStreamingCoordGRPCService(s.grpcServer) } + coordclient.RegisterDataCoordServer(s) go funcutil.CheckGrpcReady(ctx, s.grpcErrChan) if err := s.grpcServer.Serve(s.listener); err != nil { s.grpcErrChan <- err diff --git a/internal/distributed/querycoord/service.go b/internal/distributed/querycoord/service.go index 3dbca5bc69..3273e6de28 100644 --- a/internal/distributed/querycoord/service.go +++ b/internal/distributed/querycoord/service.go @@ -31,8 +31,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - dcc "github.com/milvus-io/milvus/internal/distributed/datacoord/client" - rcc "github.com/milvus-io/milvus/internal/distributed/rootcoord/client" + "github.com/milvus-io/milvus/internal/coordinator/coordclient" "github.com/milvus-io/milvus/internal/distributed/utils" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" @@ -169,11 +168,7 @@ func (s *Server) init() error { // --- Master Server Client --- if s.rootCoord == nil { - s.rootCoord, err = rcc.NewClient(s.loopCtx) - if err != nil { - log.Error("QueryCoord try to new RootCoord client failed", zap.Error(err)) - panic(err) - } + s.rootCoord = coordclient.GetRootCoordClient(s.loopCtx) } // wait for master init or healthy @@ -191,11 +186,7 @@ func (s *Server) init() error { // --- Data service client --- if s.dataCoord == nil { - s.dataCoord, err = dcc.NewClient(s.loopCtx) - if err != nil { - log.Error("QueryCoord try to new DataCoord client failed", zap.Error(err)) - panic(err) - } + s.dataCoord = coordclient.GetDataCoordClient(s.loopCtx) } log.Info("QueryCoord try to wait for DataCoord ready") @@ -261,6 +252,7 @@ func (s *Server) startGrpcLoop() { grpcOpts = append(grpcOpts, utils.EnableInternalTLS("QueryCoord")) s.grpcServer = grpc.NewServer(grpcOpts...) querypb.RegisterQueryCoordServer(s.grpcServer, s) + coordclient.RegisterQueryCoordServer(s) go funcutil.CheckGrpcReady(ctx, s.grpcErrChan) if err := s.grpcServer.Serve(s.listener); err != nil { diff --git a/internal/distributed/rootcoord/service.go b/internal/distributed/rootcoord/service.go index d329a1bc1f..c1e0241a5e 100644 --- a/internal/distributed/rootcoord/service.go +++ b/internal/distributed/rootcoord/service.go @@ -31,8 +31,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - dcc "github.com/milvus-io/milvus/internal/distributed/datacoord/client" - qcc "github.com/milvus-io/milvus/internal/distributed/querycoord/client" + "github.com/milvus-io/milvus/internal/coordinator/coordclient" "github.com/milvus-io/milvus/internal/distributed/utils" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/proxypb" @@ -72,8 +71,8 @@ type Server struct { dataCoord types.DataCoordClient queryCoord types.QueryCoordClient - newDataCoordClient func() types.DataCoordClient - newQueryCoordClient func() types.QueryCoordClient + newDataCoordClient func(ctx context.Context) types.DataCoordClient + newQueryCoordClient func(ctx context.Context) types.QueryCoordClient } func (s *Server) DescribeDatabase(ctx context.Context, request *rootcoordpb.DescribeDatabaseRequest) (*rootcoordpb.DescribeDatabaseResponse, error) { @@ -157,21 +156,8 @@ func (s *Server) Prepare() error { } func (s *Server) setClient() { - s.newDataCoordClient = func() types.DataCoordClient { - dsClient, err := dcc.NewClient(s.ctx) - if err != nil { - panic(err) - } - return dsClient - } - - s.newQueryCoordClient = func() types.QueryCoordClient { - qsClient, err := qcc.NewClient(s.ctx) - if err != nil { - panic(err) - } - return qsClient - } + s.newDataCoordClient = coordclient.GetDataCoordClient + s.newQueryCoordClient = coordclient.GetQueryCoordClient } // Run initializes and starts RootCoord's grpc service. @@ -234,7 +220,7 @@ func (s *Server) init() error { if s.newDataCoordClient != nil { log.Info("RootCoord start to create DataCoord client") - dataCoord := s.newDataCoordClient() + dataCoord := s.newDataCoordClient(s.ctx) s.dataCoord = dataCoord if err := s.rootCoord.SetDataCoordClient(dataCoord); err != nil { panic(err) @@ -243,7 +229,7 @@ func (s *Server) init() error { if s.newQueryCoordClient != nil { log.Info("RootCoord start to create QueryCoord client") - queryCoord := s.newQueryCoordClient() + queryCoord := s.newQueryCoordClient(s.ctx) s.queryCoord = queryCoord if err := s.rootCoord.SetQueryCoordClient(queryCoord); err != nil { panic(err) @@ -309,6 +295,7 @@ func (s *Server) startGrpcLoop() { grpcOpts = append(grpcOpts, utils.EnableInternalTLS("RootCoord")) s.grpcServer = grpc.NewServer(grpcOpts...) rootcoordpb.RegisterRootCoordServer(s.grpcServer, s) + coordclient.RegisterRootCoordServer(s) go funcutil.CheckGrpcReady(ctx, s.grpcErrChan) if err := s.grpcServer.Serve(s.listener); err != nil { diff --git a/internal/distributed/rootcoord/service_test.go b/internal/distributed/rootcoord/service_test.go index 43c9ba6ba0..5965b47883 100644 --- a/internal/distributed/rootcoord/service_test.go +++ b/internal/distributed/rootcoord/service_test.go @@ -142,13 +142,13 @@ func TestRun(t *testing.T) { mockDataCoord := mocks.NewMockDataCoordClient(t) mockDataCoord.EXPECT().Close().Return(nil) - svr.newDataCoordClient = func() types.DataCoordClient { + svr.newDataCoordClient = func(_ context.Context) types.DataCoordClient { return mockDataCoord } mockQueryCoord := mocks.NewMockQueryCoordClient(t) mockQueryCoord.EXPECT().Close().Return(nil) - svr.newQueryCoordClient = func() types.QueryCoordClient { + svr.newQueryCoordClient = func(_ context.Context) types.QueryCoordClient { return mockQueryCoord } @@ -238,7 +238,7 @@ func TestServerRun_DataCoordClientInitErr(t *testing.T) { mockDataCoord := mocks.NewMockDataCoordClient(t) mockDataCoord.EXPECT().Close().Return(nil) - server.newDataCoordClient = func() types.DataCoordClient { + server.newDataCoordClient = func(_ context.Context) types.DataCoordClient { return mockDataCoord } err = server.Prepare() @@ -268,7 +268,7 @@ func TestServerRun_DataCoordClientStartErr(t *testing.T) { mockDataCoord := mocks.NewMockDataCoordClient(t) mockDataCoord.EXPECT().Close().Return(nil) - server.newDataCoordClient = func() types.DataCoordClient { + server.newDataCoordClient = func(_ context.Context) types.DataCoordClient { return mockDataCoord } err = server.Prepare() @@ -298,7 +298,7 @@ func TestServerRun_QueryCoordClientInitErr(t *testing.T) { mockQueryCoord := mocks.NewMockQueryCoordClient(t) mockQueryCoord.EXPECT().Close().Return(nil) - server.newQueryCoordClient = func() types.QueryCoordClient { + server.newQueryCoordClient = func(_ context.Context) types.QueryCoordClient { return mockQueryCoord } err = server.Prepare() @@ -328,7 +328,7 @@ func TestServer_QueryCoordClientStartErr(t *testing.T) { mockQueryCoord := mocks.NewMockQueryCoordClient(t) mockQueryCoord.EXPECT().Close().Return(nil) - server.newQueryCoordClient = func() types.QueryCoordClient { + server.newQueryCoordClient = func(_ context.Context) types.QueryCoordClient { return mockQueryCoord } err = server.Prepare() diff --git a/internal/util/grpcclient/local_grpc_client.go b/internal/util/grpcclient/local_grpc_client.go new file mode 100644 index 0000000000..21afecf061 --- /dev/null +++ b/internal/util/grpcclient/local_grpc_client.go @@ -0,0 +1,68 @@ +package grpcclient + +import ( + "context" + "fmt" + "reflect" + "strings" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +var _ grpc.ClientConnInterface = &localConn{} + +// NewLocalGRPCClient creates a grpc client that calls the server directly. +// !!! Warning: it didn't make any network or serialization/deserialization, so it's not promise concurrent safe. +// and there's no interceptor for client and server like the common grpc client/server. +func NewLocalGRPCClient[C any, S any](desc *grpc.ServiceDesc, server S, clientCreator func(grpc.ClientConnInterface) C) C { + return clientCreator(&localConn{ + serviceDesc: desc, + server: server, + }) +} + +// localConn is a grpc.ClientConnInterface implementation that calls the server directly. +type localConn struct { + serviceDesc *grpc.ServiceDesc // ServiceDesc is the descriptor for this service. + server interface{} // the server object. +} + +// Invoke calls the server method directly. +func (c *localConn) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...grpc.CallOption) error { + methodDesc := c.findMethod(method) + if methodDesc == nil { + return status.Errorf(codes.Unimplemented, fmt.Sprintf("method %s not implemented", method)) + } + resp, err := methodDesc.Handler(c.server, ctx, func(in any) error { + reflect.ValueOf(in).Elem().Set(reflect.ValueOf(args).Elem()) + return nil + }, nil) + if err != nil { + return err + } + reflect.ValueOf(reply).Elem().Set(reflect.ValueOf(resp).Elem()) + return nil +} + +// NewStream is not supported by now, wait for implementation. +func (c *localConn) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + panic("we don't support local stream rpc by now") +} + +// findMethod finds the method descriptor by the full method name. +func (c *localConn) findMethod(fullMethodName string) *grpc.MethodDesc { + strs := strings.SplitN(fullMethodName[1:], "/", 2) + serviceName := strs[0] + methodName := strs[1] + if c.serviceDesc.ServiceName != serviceName { + return nil + } + for i := range c.serviceDesc.Methods { + if c.serviceDesc.Methods[i].MethodName == methodName { + return &c.serviceDesc.Methods[i] + } + } + return nil +} diff --git a/internal/util/grpcclient/local_grpc_client_test.go b/internal/util/grpcclient/local_grpc_client_test.go new file mode 100644 index 0000000000..bcd59e62a9 --- /dev/null +++ b/internal/util/grpcclient/local_grpc_client_test.go @@ -0,0 +1,46 @@ +package grpcclient + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" +) + +type mockRootCoordServer struct { + t *testing.T + *rootcoordpb.UnimplementedRootCoordServer +} + +func (s *mockRootCoordServer) AllocID(ctx context.Context, req *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error) { + assert.NotNil(s.t, req) + assert.Equal(s.t, uint32(100), req.Count) + return &rootcoordpb.AllocIDResponse{ + ID: 1, + Count: 2, + }, nil +} + +func TestLocalGRPCClient(t *testing.T) { + localClient := NewLocalGRPCClient( + &rootcoordpb.RootCoord_ServiceDesc, + &mockRootCoordServer{ + t: t, + UnimplementedRootCoordServer: &rootcoordpb.UnimplementedRootCoordServer{}, + }, + rootcoordpb.NewRootCoordClient, + ) + result, err := localClient.AllocTimestamp(context.Background(), &rootcoordpb.AllocTimestampRequest{}) + assert.Error(t, err) + assert.Nil(t, result) + + result2, err := localClient.AllocID(context.Background(), &rootcoordpb.AllocIDRequest{ + Count: 100, + }) + assert.NoError(t, err) + assert.NotNil(t, result2) + assert.Equal(t, int64(1), result2.ID) + assert.Equal(t, uint32(2), result2.Count) +} diff --git a/pkg/util/syncutil/future.go b/pkg/util/syncutil/future.go index cbeac95bec..f13c40f58e 100644 --- a/pkg/util/syncutil/future.go +++ b/pkg/util/syncutil/future.go @@ -1,5 +1,9 @@ package syncutil +import ( + "context" +) + // Future is a future value that can be set and retrieved. type Future[T any] struct { ch chan struct{} @@ -19,6 +23,17 @@ func (f *Future[T]) Set(value T) { close(f.ch) } +// GetWithContext retrieves the value of the future if set, otherwise block until set or the context is done. +func (f *Future[T]) GetWithContext(ctx context.Context) (T, error) { + select { + case <-ctx.Done(): + var val T + return val, ctx.Err() + case <-f.ch: + return f.value, nil + } +} + // Get retrieves the value of the future if set, otherwise block until set. func (f *Future[T]) Get() T { <-f.ch