mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-02 03:48:37 +08:00
enhance: update shard leader cache when leader location changed (#32470)
issue: #32466 this PR enhance that when shard location changed, update proxy's shard leader cache. in case of query node failover case, proxy can find replica recover --------- Signed-off-by: Wei Liu <wei.liu@zilliz.com>
This commit is contained in:
parent
5038036ece
commit
ba02d54a30
@ -216,3 +216,9 @@ func (c *Client) ListImports(ctx context.Context, req *internalpb.ListImportsReq
|
||||
return client.ListImports(ctx, req)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) InvalidateShardLeaderCache(ctx context.Context, req *proxypb.InvalidateShardLeaderCacheRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
return wrapGrpcCall(ctx, c, func(client proxypb.ProxyClient) (*commonpb.Status, error) {
|
||||
return client.InvalidateShardLeaderCache(ctx, req)
|
||||
})
|
||||
}
|
||||
|
@ -462,3 +462,40 @@ func Test_ImportV2(t *testing.T) {
|
||||
_, err = client.ListImports(ctx, &internalpb.ListImportsRequest{})
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func Test_InvalidateShardLeaderCache(t *testing.T) {
|
||||
paramtable.Init()
|
||||
|
||||
ctx := context.Background()
|
||||
client, err := NewClient(ctx, "test", 1)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, client)
|
||||
defer client.Close()
|
||||
|
||||
mockProxy := mocks.NewMockProxyClient(t)
|
||||
mockGrpcClient := mocks.NewMockGrpcClient[proxypb.ProxyClient](t)
|
||||
mockGrpcClient.EXPECT().Close().Return(nil)
|
||||
mockGrpcClient.EXPECT().ReCall(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, f func(proxypb.ProxyClient) (interface{}, error)) (interface{}, error) {
|
||||
return f(mockProxy)
|
||||
})
|
||||
client.(*Client).grpcClient = mockGrpcClient
|
||||
|
||||
// test success
|
||||
mockProxy.EXPECT().InvalidateShardLeaderCache(mock.Anything, mock.Anything).Return(merr.Success(), nil)
|
||||
_, err = client.InvalidateShardLeaderCache(ctx, &proxypb.InvalidateShardLeaderCacheRequest{})
|
||||
assert.Nil(t, err)
|
||||
|
||||
// test return error code
|
||||
mockProxy.ExpectedCalls = nil
|
||||
mockProxy.EXPECT().InvalidateShardLeaderCache(mock.Anything, mock.Anything).Return(merr.Status(merr.ErrServiceNotReady), nil)
|
||||
|
||||
_, err = client.InvalidateShardLeaderCache(ctx, &proxypb.InvalidateShardLeaderCacheRequest{})
|
||||
assert.Nil(t, err)
|
||||
|
||||
// test ctx done
|
||||
ctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond)
|
||||
defer cancel()
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
_, err = client.InvalidateShardLeaderCache(ctx, &proxypb.InvalidateShardLeaderCacheRequest{})
|
||||
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
}
|
||||
|
@ -1229,3 +1229,7 @@ func (s *Server) ListImports(ctx context.Context, req *internalpb.ListImportsReq
|
||||
func (s *Server) AlterDatabase(ctx context.Context, req *milvuspb.AlterDatabaseRequest) (*commonpb.Status, error) {
|
||||
return s.proxy.AlterDatabase(ctx, req)
|
||||
}
|
||||
|
||||
func (s *Server) InvalidateShardLeaderCache(ctx context.Context, req *proxypb.InvalidateShardLeaderCacheRequest) (*commonpb.Status, error) {
|
||||
return s.proxy.InvalidateShardLeaderCache(ctx, req)
|
||||
}
|
||||
|
@ -229,6 +229,12 @@ func Test_NewServer(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("InvalidateShardLeaderCache", func(t *testing.T) {
|
||||
mockProxy.EXPECT().InvalidateShardLeaderCache(mock.Anything, mock.Anything).Return(nil, nil)
|
||||
_, err := server.InvalidateShardLeaderCache(ctx, nil)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("CreateCollection", func(t *testing.T) {
|
||||
mockProxy.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(nil, nil)
|
||||
_, err := server.CreateCollection(ctx, nil)
|
||||
|
@ -3634,6 +3634,61 @@ func (_c *MockProxy_InvalidateCredentialCache_Call) RunAndReturn(run func(contex
|
||||
return _c
|
||||
}
|
||||
|
||||
// InvalidateShardLeaderCache provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockProxy) InvalidateShardLeaderCache(_a0 context.Context, _a1 *proxypb.InvalidateShardLeaderCacheRequest) (*commonpb.Status, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
var r0 *commonpb.Status
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateShardLeaderCacheRequest) (*commonpb.Status, error)); ok {
|
||||
return rf(_a0, _a1)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateShardLeaderCacheRequest) *commonpb.Status); ok {
|
||||
r0 = rf(_a0, _a1)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*commonpb.Status)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *proxypb.InvalidateShardLeaderCacheRequest) error); ok {
|
||||
r1 = rf(_a0, _a1)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockProxy_InvalidateShardLeaderCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InvalidateShardLeaderCache'
|
||||
type MockProxy_InvalidateShardLeaderCache_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// InvalidateShardLeaderCache is a helper method to define mock.On call
|
||||
// - _a0 context.Context
|
||||
// - _a1 *proxypb.InvalidateShardLeaderCacheRequest
|
||||
func (_e *MockProxy_Expecter) InvalidateShardLeaderCache(_a0 interface{}, _a1 interface{}) *MockProxy_InvalidateShardLeaderCache_Call {
|
||||
return &MockProxy_InvalidateShardLeaderCache_Call{Call: _e.mock.On("InvalidateShardLeaderCache", _a0, _a1)}
|
||||
}
|
||||
|
||||
func (_c *MockProxy_InvalidateShardLeaderCache_Call) Run(run func(_a0 context.Context, _a1 *proxypb.InvalidateShardLeaderCacheRequest)) *MockProxy_InvalidateShardLeaderCache_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*proxypb.InvalidateShardLeaderCacheRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockProxy_InvalidateShardLeaderCache_Call) Return(_a0 *commonpb.Status, _a1 error) *MockProxy_InvalidateShardLeaderCache_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockProxy_InvalidateShardLeaderCache_Call) RunAndReturn(run func(context.Context, *proxypb.InvalidateShardLeaderCacheRequest) (*commonpb.Status, error)) *MockProxy_InvalidateShardLeaderCache_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// ListAliases provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockProxy) ListAliases(_a0 context.Context, _a1 *milvuspb.ListAliasesRequest) (*milvuspb.ListAliasesResponse, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
@ -632,6 +632,76 @@ func (_c *MockProxyClient_InvalidateCredentialCache_Call) RunAndReturn(run func(
|
||||
return _c
|
||||
}
|
||||
|
||||
// InvalidateShardLeaderCache provides a mock function with given fields: ctx, in, opts
|
||||
func (_m *MockProxyClient) InvalidateShardLeaderCache(ctx context.Context, in *proxypb.InvalidateShardLeaderCacheRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
_va := make([]interface{}, len(opts))
|
||||
for _i := range opts {
|
||||
_va[_i] = opts[_i]
|
||||
}
|
||||
var _ca []interface{}
|
||||
_ca = append(_ca, ctx, in)
|
||||
_ca = append(_ca, _va...)
|
||||
ret := _m.Called(_ca...)
|
||||
|
||||
var r0 *commonpb.Status
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateShardLeaderCacheRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok {
|
||||
return rf(ctx, in, opts...)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateShardLeaderCacheRequest, ...grpc.CallOption) *commonpb.Status); ok {
|
||||
r0 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*commonpb.Status)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *proxypb.InvalidateShardLeaderCacheRequest, ...grpc.CallOption) error); ok {
|
||||
r1 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockProxyClient_InvalidateShardLeaderCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InvalidateShardLeaderCache'
|
||||
type MockProxyClient_InvalidateShardLeaderCache_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// InvalidateShardLeaderCache is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - in *proxypb.InvalidateShardLeaderCacheRequest
|
||||
// - opts ...grpc.CallOption
|
||||
func (_e *MockProxyClient_Expecter) InvalidateShardLeaderCache(ctx interface{}, in interface{}, opts ...interface{}) *MockProxyClient_InvalidateShardLeaderCache_Call {
|
||||
return &MockProxyClient_InvalidateShardLeaderCache_Call{Call: _e.mock.On("InvalidateShardLeaderCache",
|
||||
append([]interface{}{ctx, in}, opts...)...)}
|
||||
}
|
||||
|
||||
func (_c *MockProxyClient_InvalidateShardLeaderCache_Call) Run(run func(ctx context.Context, in *proxypb.InvalidateShardLeaderCacheRequest, opts ...grpc.CallOption)) *MockProxyClient_InvalidateShardLeaderCache_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
variadicArgs := make([]grpc.CallOption, len(args)-2)
|
||||
for i, a := range args[2:] {
|
||||
if a != nil {
|
||||
variadicArgs[i] = a.(grpc.CallOption)
|
||||
}
|
||||
}
|
||||
run(args[0].(context.Context), args[1].(*proxypb.InvalidateShardLeaderCacheRequest), variadicArgs...)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockProxyClient_InvalidateShardLeaderCache_Call) Return(_a0 *commonpb.Status, _a1 error) *MockProxyClient_InvalidateShardLeaderCache_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockProxyClient_InvalidateShardLeaderCache_Call) RunAndReturn(run func(context.Context, *proxypb.InvalidateShardLeaderCacheRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockProxyClient_InvalidateShardLeaderCache_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// ListClientInfos provides a mock function with given fields: ctx, in, opts
|
||||
func (_m *MockProxyClient) ListClientInfos(ctx context.Context, in *proxypb.ListClientInfosRequest, opts ...grpc.CallOption) (*proxypb.ListClientInfosResponse, error) {
|
||||
_va := make([]interface{}, len(opts))
|
||||
|
@ -27,6 +27,8 @@ service Proxy {
|
||||
rpc ImportV2(internal.ImportRequest) returns(internal.ImportResponse){}
|
||||
rpc GetImportProgress(internal.GetImportProgressRequest) returns(internal.GetImportProgressResponse){}
|
||||
rpc ListImports(internal.ListImportsRequest) returns(internal.ListImportsResponse){}
|
||||
|
||||
rpc InvalidateShardLeaderCache(InvalidateShardLeaderCacheRequest) returns (common.Status) {}
|
||||
}
|
||||
|
||||
message InvalidateCollMetaCacheRequest {
|
||||
@ -40,6 +42,11 @@ message InvalidateCollMetaCacheRequest {
|
||||
string partition_name = 5;
|
||||
}
|
||||
|
||||
message InvalidateShardLeaderCacheRequest {
|
||||
common.MsgBase base = 1;
|
||||
repeated int64 collectionIDs = 2;
|
||||
}
|
||||
|
||||
message InvalidateCredCacheRequest {
|
||||
common.MsgBase base = 1;
|
||||
string username = 2;
|
||||
|
@ -172,6 +172,29 @@ func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *p
|
||||
return merr.Success(), nil
|
||||
}
|
||||
|
||||
// InvalidateCollectionMetaCache invalidate the meta cache of specific collection.
|
||||
func (node *Proxy) InvalidateShardLeaderCache(ctx context.Context, request *proxypb.InvalidateShardLeaderCacheRequest) (*commonpb.Status, error) {
|
||||
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
ctx = logutil.WithModule(ctx, moduleName)
|
||||
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-InvalidateShardLeaderCache")
|
||||
defer sp.End()
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.String("role", typeutil.ProxyRole),
|
||||
)
|
||||
|
||||
log.Info("received request to invalidate shard leader cache", zap.Int64s("collectionIDs", request.GetCollectionIDs()))
|
||||
|
||||
if globalMetaCache != nil {
|
||||
globalMetaCache.InvalidateShardLeaderCache(request.GetCollectionIDs())
|
||||
}
|
||||
log.Info("complete to invalidate shard leader cache", zap.Int64s("collectionIDs", request.GetCollectionIDs()))
|
||||
|
||||
return merr.Success(), nil
|
||||
}
|
||||
|
||||
func (node *Proxy) CreateDatabase(ctx context.Context, request *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) {
|
||||
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
|
||||
return merr.Status(err), nil
|
||||
|
@ -1720,3 +1720,30 @@ func TestGetCollectionRateSubLabel(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxy_InvalidateShardLeaderCache(t *testing.T) {
|
||||
t.Run("proxy unhealthy", func(t *testing.T) {
|
||||
node := &Proxy{}
|
||||
node.UpdateStateCode(commonpb.StateCode_Abnormal)
|
||||
|
||||
resp, err := node.InvalidateShardLeaderCache(context.TODO(), nil)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, merr.Ok(resp))
|
||||
})
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
node := &Proxy{}
|
||||
node.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
|
||||
cacheBak := globalMetaCache
|
||||
defer func() { globalMetaCache = cacheBak }()
|
||||
// set expectations
|
||||
cache := NewMockCache(t)
|
||||
cache.EXPECT().InvalidateShardLeaderCache(mock.Anything)
|
||||
globalMetaCache = cache
|
||||
|
||||
resp, err := node.InvalidateShardLeaderCache(context.TODO(), &proxypb.InvalidateShardLeaderCacheRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, merr.Ok(resp))
|
||||
})
|
||||
}
|
||||
|
@ -73,6 +73,7 @@ type Cache interface {
|
||||
GetCollectionSchema(ctx context.Context, database, collectionName string) (*schemaInfo, error)
|
||||
GetShards(ctx context.Context, withCache bool, database, collectionName string, collectionID int64) (map[string][]nodeInfo, error)
|
||||
DeprecateShardCache(database, collectionName string)
|
||||
InvalidateShardLeaderCache(collections []int64)
|
||||
RemoveCollection(ctx context.Context, database, collectionName string)
|
||||
RemoveCollectionsByID(ctx context.Context, collectionID UniqueID) []string
|
||||
RemovePartition(ctx context.Context, database, collectionName string, partitionName string)
|
||||
@ -201,6 +202,7 @@ type shardLeaders struct {
|
||||
idx *atomic.Int64
|
||||
deprecated *atomic.Bool
|
||||
|
||||
collectionID int64
|
||||
shardLeaders map[string][]nodeInfo
|
||||
}
|
||||
|
||||
@ -944,6 +946,7 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
|
||||
|
||||
shards := parseShardLeaderList2QueryNode(resp.GetShards())
|
||||
newShardLeaders := &shardLeaders{
|
||||
collectionID: info.collID,
|
||||
shardLeaders: shards,
|
||||
deprecated: atomic.NewBool(false),
|
||||
idx: atomic.NewInt64(0),
|
||||
@ -997,6 +1000,21 @@ func (m *MetaCache) DeprecateShardCache(database, collectionName string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MetaCache) InvalidateShardLeaderCache(collections []int64) {
|
||||
log.Info("Invalidate shard cache for collections", zap.Int64s("collectionIDs", collections))
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
collectionSet := typeutil.NewUniqueSet(collections...)
|
||||
for _, db := range m.collLeader {
|
||||
for _, shardLeaders := range db {
|
||||
if collectionSet.Contain(shardLeaders.collectionID) {
|
||||
shardLeaders.deprecated.Store(true)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MetaCache) InitPolicyInfo(info []string, userRoles []string) {
|
||||
defer func() {
|
||||
err := getEnforcer().LoadPolicy()
|
||||
|
@ -42,6 +42,7 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/util/crypto"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
@ -1070,3 +1071,51 @@ func TestGlobalMetaCache_GetCollectionNamesByID(t *testing.T) {
|
||||
assert.Equal(t, []string{"db1", "db1"}, dbNames)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMetaCache_InvalidateShardLeaderCache(t *testing.T) {
|
||||
paramtable.Init()
|
||||
paramtable.Get().Save(Params.ProxyCfg.ShardLeaderCacheInterval.Key, "1")
|
||||
|
||||
ctx := context.Background()
|
||||
rootCoord := &MockRootCoordClientInterface{}
|
||||
queryCoord := &mocks.MockQueryCoordClient{}
|
||||
shardMgr := newShardClientMgr()
|
||||
err := InitMetaCache(ctx, rootCoord, queryCoord, shardMgr)
|
||||
assert.NoError(t, err)
|
||||
|
||||
queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
|
||||
Status: merr.Success(),
|
||||
CollectionIDs: []UniqueID{1},
|
||||
InMemoryPercentages: []int64{100},
|
||||
}, nil)
|
||||
|
||||
called := uatomic.NewInt32(0)
|
||||
queryCoord.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context,
|
||||
gslr *querypb.GetShardLeadersRequest, co ...grpc.CallOption,
|
||||
) (*querypb.GetShardLeadersResponse, error) {
|
||||
called.Inc()
|
||||
return &querypb.GetShardLeadersResponse{
|
||||
Status: merr.Success(),
|
||||
Shards: []*querypb.ShardLeadersList{
|
||||
{
|
||||
ChannelName: "channel-1",
|
||||
NodeIds: []int64{1, 2, 3},
|
||||
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
})
|
||||
nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1", 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, nodeInfos["channel-1"], 3)
|
||||
assert.Equal(t, called.Load(), int32(1))
|
||||
|
||||
globalMetaCache.GetShards(ctx, true, dbName, "collection1", 1)
|
||||
assert.Equal(t, called.Load(), int32(1))
|
||||
|
||||
globalMetaCache.InvalidateShardLeaderCache([]int64{1})
|
||||
nodeInfos, err = globalMetaCache.GetShards(ctx, true, dbName, "collection1", 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, nodeInfos["channel-1"], 3)
|
||||
assert.Equal(t, called.Load(), int32(2))
|
||||
}
|
||||
|
@ -952,6 +952,39 @@ func (_c *MockCache_InitPolicyInfo_Call) RunAndReturn(run func([]string, []strin
|
||||
return _c
|
||||
}
|
||||
|
||||
// InvalidateShardLeaderCache provides a mock function with given fields: collections
|
||||
func (_m *MockCache) InvalidateShardLeaderCache(collections []int64) {
|
||||
_m.Called(collections)
|
||||
}
|
||||
|
||||
// MockCache_InvalidateShardLeaderCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InvalidateShardLeaderCache'
|
||||
type MockCache_InvalidateShardLeaderCache_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// InvalidateShardLeaderCache is a helper method to define mock.On call
|
||||
// - collections []int64
|
||||
func (_e *MockCache_Expecter) InvalidateShardLeaderCache(collections interface{}) *MockCache_InvalidateShardLeaderCache_Call {
|
||||
return &MockCache_InvalidateShardLeaderCache_Call{Call: _e.mock.On("InvalidateShardLeaderCache", collections)}
|
||||
}
|
||||
|
||||
func (_c *MockCache_InvalidateShardLeaderCache_Call) Run(run func(collections []int64)) *MockCache_InvalidateShardLeaderCache_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]int64))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_InvalidateShardLeaderCache_Call) Return() *MockCache_InvalidateShardLeaderCache_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_InvalidateShardLeaderCache_Call) RunAndReturn(run func([]int64)) *MockCache_InvalidateShardLeaderCache_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RefreshPolicyInfo provides a mock function with given fields: op
|
||||
func (_m *MockCache) RefreshPolicyInfo(op typeutil.CacheOp) error {
|
||||
ret := _m.Called(op)
|
||||
|
@ -191,7 +191,7 @@ func (c *ChannelChecker) findRepeatedChannels(ctx context.Context, replicaID int
|
||||
continue
|
||||
}
|
||||
|
||||
if err := CheckLeaderAvailable(c.nodeMgr, leaderView, targets); err != nil {
|
||||
if err := utils.CheckLeaderAvailable(c.nodeMgr, leaderView, targets); err != nil {
|
||||
log.RatedInfo(10, "replica has unavailable shard leader",
|
||||
zap.Int64("collectionID", replica.GetCollectionID()),
|
||||
zap.Int64("replicaID", replicaID),
|
||||
|
@ -1,81 +0,0 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package checkers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/session"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
func CheckNodeAvailable(nodeID int64, info *session.NodeInfo) error {
|
||||
if info == nil {
|
||||
return merr.WrapErrNodeOffline(nodeID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// In a replica, a shard is available, if and only if:
|
||||
// 1. The leader is online
|
||||
// 2. All QueryNodes in the distribution are online
|
||||
// 3. The last heartbeat response time is within HeartbeatAvailableInterval for all QueryNodes(include leader) in the distribution
|
||||
// 4. All segments of the shard in target should be in the distribution
|
||||
func CheckLeaderAvailable(nodeMgr *session.NodeManager, leader *meta.LeaderView, currentTargets map[int64]*datapb.SegmentInfo) error {
|
||||
log := log.Ctx(context.TODO()).
|
||||
WithRateGroup("checkers.CheckLeaderAvailable", 1, 60).
|
||||
With(zap.Int64("leaderID", leader.ID))
|
||||
info := nodeMgr.Get(leader.ID)
|
||||
|
||||
// Check whether leader is online
|
||||
err := CheckNodeAvailable(leader.ID, info)
|
||||
if err != nil {
|
||||
log.Info("leader is not available", zap.Error(err))
|
||||
return fmt.Errorf("leader not available: %w", err)
|
||||
}
|
||||
|
||||
for id, version := range leader.Segments {
|
||||
info := nodeMgr.Get(version.GetNodeID())
|
||||
err = CheckNodeAvailable(version.GetNodeID(), info)
|
||||
if err != nil {
|
||||
log.Info("leader is not available due to QueryNode unavailable",
|
||||
zap.Int64("segmentID", id),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Check whether segments are fully loaded
|
||||
for segmentID, info := range currentTargets {
|
||||
if info.GetInsertChannel() != leader.Channel {
|
||||
continue
|
||||
}
|
||||
|
||||
_, exist := leader.Segments[segmentID]
|
||||
if !exist {
|
||||
log.RatedInfo(10, "leader is not available due to lack of segment", zap.Int64("segmentID", segmentID))
|
||||
return merr.WrapErrSegmentLack(segmentID)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
@ -426,31 +426,3 @@ func (s *Server) fillReplicaInfo(replica *meta.Replica, withShardNodes bool) *mi
|
||||
info.ShardReplicas = shardReplicas
|
||||
return info
|
||||
}
|
||||
|
||||
func filterDupLeaders(replicaManager *meta.ReplicaManager, leaders map[int64]*meta.LeaderView) map[int64]*meta.LeaderView {
|
||||
type leaderID struct {
|
||||
ReplicaID int64
|
||||
Shard string
|
||||
}
|
||||
|
||||
newLeaders := make(map[leaderID]*meta.LeaderView)
|
||||
for _, view := range leaders {
|
||||
replica := replicaManager.GetByCollectionAndNode(view.CollectionID, view.ID)
|
||||
if replica == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
id := leaderID{replica.GetID(), view.Channel}
|
||||
if old, ok := newLeaders[id]; ok && old.Version > view.Version {
|
||||
continue
|
||||
}
|
||||
|
||||
newLeaders[id] = view
|
||||
}
|
||||
|
||||
result := make(map[int64]*meta.LeaderView)
|
||||
for _, v := range newLeaders {
|
||||
result[v.ID] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
@ -22,6 +22,7 @@ import (
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
type lvCriterion struct {
|
||||
@ -192,9 +193,12 @@ func composeNodeViews(views ...*LeaderView) nodeViews {
|
||||
}
|
||||
}
|
||||
|
||||
type NotifyDelegatorChanges = func(collectionID ...int64)
|
||||
|
||||
type LeaderViewManager struct {
|
||||
rwmutex sync.RWMutex
|
||||
views map[int64]nodeViews // LeaderID -> Views (one per shard)
|
||||
rwmutex sync.RWMutex
|
||||
views map[int64]nodeViews // LeaderID -> Views (one per shard)
|
||||
notifyFunc NotifyDelegatorChanges
|
||||
}
|
||||
|
||||
func NewLeaderViewManager() *LeaderViewManager {
|
||||
@ -203,11 +207,45 @@ func NewLeaderViewManager() *LeaderViewManager {
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *LeaderViewManager) SetNotifyFunc(notifyFunc NotifyDelegatorChanges) {
|
||||
mgr.notifyFunc = notifyFunc
|
||||
}
|
||||
|
||||
// Update updates the leader's views, all views have to be with the same leader ID
|
||||
func (mgr *LeaderViewManager) Update(leaderID int64, views ...*LeaderView) {
|
||||
mgr.rwmutex.Lock()
|
||||
defer mgr.rwmutex.Unlock()
|
||||
|
||||
oldViews := make(map[string]*LeaderView, 0)
|
||||
if _, ok := mgr.views[leaderID]; ok {
|
||||
oldViews = mgr.views[leaderID].channelView
|
||||
}
|
||||
|
||||
newViews := lo.SliceToMap(views, func(v *LeaderView) (string, *LeaderView) {
|
||||
return v.Channel, v
|
||||
})
|
||||
|
||||
// update leader views
|
||||
mgr.views[leaderID] = composeNodeViews(views...)
|
||||
|
||||
// compute leader location change, find it's correspond collection
|
||||
if mgr.notifyFunc != nil {
|
||||
viewChanges := typeutil.NewUniqueSet()
|
||||
for channel, oldView := range oldViews {
|
||||
// if channel released from current node
|
||||
if _, ok := newViews[channel]; !ok {
|
||||
viewChanges.Insert(oldView.CollectionID)
|
||||
}
|
||||
}
|
||||
|
||||
for channel, newView := range newViews {
|
||||
// if channel loaded to current node
|
||||
if _, ok := oldViews[channel]; !ok {
|
||||
viewChanges.Insert(newView.CollectionID)
|
||||
}
|
||||
}
|
||||
mgr.notifyFunc(viewChanges.Collect()...)
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *LeaderViewManager) GetLeaderShardView(id int64, shard string) *LeaderView {
|
||||
|
@ -208,6 +208,58 @@ func (suite *LeaderViewManagerSuite) TestClone() {
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *LeaderViewManagerSuite) TestNotifyDelegatorChanges() {
|
||||
mgr := NewLeaderViewManager()
|
||||
|
||||
oldViews := []*LeaderView{
|
||||
{
|
||||
ID: 1,
|
||||
CollectionID: 100,
|
||||
Channel: "test-channel-1",
|
||||
},
|
||||
{
|
||||
ID: 1,
|
||||
CollectionID: 101,
|
||||
Channel: "test-channel-2",
|
||||
},
|
||||
{
|
||||
ID: 1,
|
||||
CollectionID: 102,
|
||||
Channel: "test-channel-3",
|
||||
},
|
||||
}
|
||||
mgr.Update(1, oldViews...)
|
||||
|
||||
newViews := []*LeaderView{
|
||||
{
|
||||
ID: 1,
|
||||
CollectionID: 101,
|
||||
Channel: "test-channel-2",
|
||||
},
|
||||
{
|
||||
ID: 1,
|
||||
CollectionID: 102,
|
||||
Channel: "test-channel-3",
|
||||
},
|
||||
{
|
||||
ID: 1,
|
||||
CollectionID: 103,
|
||||
Channel: "test-channel-4",
|
||||
},
|
||||
}
|
||||
|
||||
updateCollections := make([]int64, 0)
|
||||
mgr.SetNotifyFunc(func(collectionIDs ...int64) {
|
||||
updateCollections = append(updateCollections, collectionIDs...)
|
||||
})
|
||||
|
||||
mgr.Update(1, newViews...)
|
||||
|
||||
suite.Equal(2, len(updateCollections))
|
||||
suite.Contains(updateCollections, int64(100))
|
||||
suite.Contains(updateCollections, int64(103))
|
||||
}
|
||||
|
||||
func TestLeaderViewManager(t *testing.T) {
|
||||
suite.Run(t, new(LeaderViewManagerSuite))
|
||||
}
|
||||
|
114
internal/querycoordv2/observers/leader_cache_observer.go
Normal file
114
internal/querycoordv2/observers/leader_cache_observer.go
Normal file
@ -0,0 +1,114 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package observers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/proxypb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/util/proxyutil"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
||||
type CollectionShardLeaderCache = map[string]*querypb.ShardLeadersList
|
||||
|
||||
// LeaderCacheObserver is to invalidate shard leader cache when leader location changes
|
||||
type LeaderCacheObserver struct {
|
||||
wg sync.WaitGroup
|
||||
proxyManager proxyutil.ProxyClientManagerInterface
|
||||
stopOnce sync.Once
|
||||
closeCh chan struct{}
|
||||
|
||||
// collections which need to update event
|
||||
eventCh chan int64
|
||||
}
|
||||
|
||||
func (o *LeaderCacheObserver) Start(ctx context.Context) {
|
||||
o.wg.Add(1)
|
||||
go o.schedule(ctx)
|
||||
}
|
||||
|
||||
func (o *LeaderCacheObserver) Stop() {
|
||||
o.stopOnce.Do(func() {
|
||||
close(o.closeCh)
|
||||
o.wg.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
func (o *LeaderCacheObserver) RegisterEvent(events ...int64) {
|
||||
for _, event := range events {
|
||||
o.eventCh <- event
|
||||
}
|
||||
}
|
||||
|
||||
func (o *LeaderCacheObserver) schedule(ctx context.Context) {
|
||||
defer o.wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("stop leader cache observer due to context done")
|
||||
return
|
||||
case <-o.closeCh:
|
||||
log.Info("stop leader cache observer")
|
||||
return
|
||||
|
||||
case event := <-o.eventCh:
|
||||
log.Info("receive event, trigger leader cache update", zap.Int64("event", event))
|
||||
ret := make([]int64, 0)
|
||||
ret = append(ret, event)
|
||||
|
||||
// try batch submit events
|
||||
eventNum := len(o.eventCh)
|
||||
if eventNum > 0 {
|
||||
for eventNum > 0 {
|
||||
event := <-o.eventCh
|
||||
ret = append(ret, event)
|
||||
eventNum--
|
||||
}
|
||||
}
|
||||
o.HandleEvent(ctx, ret...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (o *LeaderCacheObserver) HandleEvent(ctx context.Context, collectionIDs ...int64) {
|
||||
ctx, cancel := context.WithTimeout(ctx, paramtable.Get().QueryCoordCfg.BrokerTimeout.GetAsDuration(time.Second))
|
||||
defer cancel()
|
||||
err := o.proxyManager.InvalidateShardLeaderCache(ctx, &proxypb.InvalidateShardLeaderCacheRequest{
|
||||
CollectionIDs: collectionIDs,
|
||||
})
|
||||
if err != nil {
|
||||
log.Warn("failed to invalidate proxy's shard leader cache", zap.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func NewLeaderCacheObserver(
|
||||
proxyManager proxyutil.ProxyClientManagerInterface,
|
||||
) *LeaderCacheObserver {
|
||||
return &LeaderCacheObserver{
|
||||
proxyManager: proxyManager,
|
||||
closeCh: make(chan struct{}),
|
||||
eventCh: make(chan int64, 1024),
|
||||
}
|
||||
}
|
@ -0,0 +1,94 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package observers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/samber/lo"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/proxypb"
|
||||
"github.com/milvus-io/milvus/internal/util/proxyutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
type LeaderCacheObserverTestSuite struct {
|
||||
suite.Suite
|
||||
|
||||
mockProxyManager *proxyutil.MockProxyClientManager
|
||||
|
||||
observer *LeaderCacheObserver
|
||||
}
|
||||
|
||||
func (suite *LeaderCacheObserverTestSuite) SetupSuite() {
|
||||
paramtable.Init()
|
||||
suite.mockProxyManager = proxyutil.NewMockProxyClientManager(suite.T())
|
||||
suite.observer = NewLeaderCacheObserver(suite.mockProxyManager)
|
||||
}
|
||||
|
||||
func (suite *LeaderCacheObserverTestSuite) TestInvalidateShardLeaderCache() {
|
||||
suite.observer.Start(context.TODO())
|
||||
defer suite.observer.Stop()
|
||||
|
||||
ret := atomic.NewBool(false)
|
||||
collectionIDs := typeutil.NewConcurrentSet[int64]()
|
||||
suite.mockProxyManager.EXPECT().InvalidateShardLeaderCache(mock.Anything, mock.Anything).RunAndReturn(
|
||||
func(ctx context.Context, req *proxypb.InvalidateShardLeaderCacheRequest) error {
|
||||
collectionIDs.Upsert(req.GetCollectionIDs()...)
|
||||
collectionIDs := req.GetCollectionIDs()
|
||||
|
||||
if len(collectionIDs) == 1 && lo.Contains(collectionIDs, 1) {
|
||||
ret.Store(true)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
suite.observer.RegisterEvent(1)
|
||||
suite.Eventually(func() bool {
|
||||
return ret.Load()
|
||||
}, 3*time.Second, 1*time.Second)
|
||||
|
||||
// test batch submit events
|
||||
ret.Store(false)
|
||||
suite.mockProxyManager.ExpectedCalls = nil
|
||||
suite.mockProxyManager.EXPECT().InvalidateShardLeaderCache(mock.Anything, mock.Anything).RunAndReturn(
|
||||
func(ctx context.Context, req *proxypb.InvalidateShardLeaderCacheRequest) error {
|
||||
collectionIDs.Upsert(req.GetCollectionIDs()...)
|
||||
collectionIDs := req.GetCollectionIDs()
|
||||
|
||||
if len(collectionIDs) == 3 && lo.Contains(collectionIDs, 1) && lo.Contains(collectionIDs, 2) && lo.Contains(collectionIDs, 3) {
|
||||
ret.Store(true)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
suite.observer.RegisterEvent(1)
|
||||
suite.observer.RegisterEvent(2)
|
||||
suite.observer.RegisterEvent(3)
|
||||
suite.Eventually(func() bool {
|
||||
return ret.Load()
|
||||
}, 3*time.Second, 1*time.Second)
|
||||
}
|
||||
|
||||
func TestLeaderCacheObserverTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(LeaderCacheObserverTestSuite))
|
||||
}
|
@ -51,6 +51,7 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/session"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/task"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/proxyutil"
|
||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||
"github.com/milvus-io/milvus/internal/util/tsoutil"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
@ -108,10 +109,11 @@ type Server struct {
|
||||
checkerController *checkers.CheckerController
|
||||
|
||||
// Observers
|
||||
collectionObserver *observers.CollectionObserver
|
||||
targetObserver *observers.TargetObserver
|
||||
replicaObserver *observers.ReplicaObserver
|
||||
resourceObserver *observers.ResourceObserver
|
||||
collectionObserver *observers.CollectionObserver
|
||||
targetObserver *observers.TargetObserver
|
||||
replicaObserver *observers.ReplicaObserver
|
||||
resourceObserver *observers.ResourceObserver
|
||||
leaderCacheObserver *observers.LeaderCacheObserver
|
||||
|
||||
balancer balance.Balance
|
||||
balancerMap map[string]balance.Balance
|
||||
@ -122,6 +124,11 @@ type Server struct {
|
||||
|
||||
nodeUpEventChan chan int64
|
||||
notifyNodeUp chan struct{}
|
||||
|
||||
// proxy client manager
|
||||
proxyCreator proxyutil.ProxyCreator
|
||||
proxyWatcher proxyutil.ProxyWatcherInterface
|
||||
proxyClientManager proxyutil.ProxyClientManagerInterface
|
||||
}
|
||||
|
||||
func NewQueryCoord(ctx context.Context) (*Server, error) {
|
||||
@ -261,6 +268,16 @@ func (s *Server) initQueryCoord() error {
|
||||
s.nodeMgr,
|
||||
)
|
||||
|
||||
// init proxy client manager
|
||||
s.proxyClientManager = proxyutil.NewProxyClientManager(proxyutil.DefaultProxyCreator)
|
||||
s.proxyWatcher = proxyutil.NewProxyWatcher(
|
||||
s.etcdCli,
|
||||
s.proxyClientManager.AddProxyClients,
|
||||
)
|
||||
s.proxyWatcher.AddSessionFunc(s.proxyClientManager.AddProxyClient)
|
||||
s.proxyWatcher.DelSessionFunc(s.proxyClientManager.DelProxyClient)
|
||||
log.Info("init proxy manager done")
|
||||
|
||||
// Init heartbeat
|
||||
log.Info("init dist controller")
|
||||
s.distController = dist.NewDistController(
|
||||
@ -387,6 +404,11 @@ func (s *Server) initObserver() {
|
||||
)
|
||||
|
||||
s.resourceObserver = observers.NewResourceObserver(s.meta)
|
||||
|
||||
s.leaderCacheObserver = observers.NewLeaderCacheObserver(
|
||||
s.proxyClientManager,
|
||||
)
|
||||
s.dist.LeaderViewManager.SetNotifyFunc(s.leaderCacheObserver.RegisterEvent)
|
||||
}
|
||||
|
||||
func (s *Server) afterStart() {}
|
||||
@ -432,6 +454,10 @@ func (s *Server) startQueryCoord() error {
|
||||
// check whether old node exist, if yes suspend auto balance until all old nodes down
|
||||
s.updateBalanceConfigLoop(s.ctx)
|
||||
|
||||
if err := s.proxyWatcher.WatchProxy(s.ctx); err != nil {
|
||||
log.Warn("querycoord failed to watch proxy", zap.Error(err))
|
||||
}
|
||||
|
||||
// Recover dist, to avoid generate too much task when dist not ready after restart
|
||||
s.distController.SyncAll(s.ctx)
|
||||
|
||||
@ -453,6 +479,7 @@ func (s *Server) startServerLoop() {
|
||||
s.targetObserver.Start()
|
||||
s.replicaObserver.Start()
|
||||
s.resourceObserver.Start()
|
||||
s.leaderCacheObserver.Start(s.ctx)
|
||||
|
||||
log.Info("start task scheduler...")
|
||||
s.taskScheduler.Start()
|
||||
@ -504,6 +531,9 @@ func (s *Server) Stop() error {
|
||||
if s.resourceObserver != nil {
|
||||
s.resourceObserver.Stop()
|
||||
}
|
||||
if s.leaderCacheObserver != nil {
|
||||
s.leaderCacheObserver.Stop()
|
||||
}
|
||||
|
||||
if s.distController != nil {
|
||||
log.Info("stop dist controller...")
|
||||
|
@ -23,7 +23,6 @@ import (
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/samber/lo"
|
||||
"go.uber.org/multierr"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
@ -31,7 +30,6 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/checkers"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/job"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
|
||||
@ -867,99 +865,11 @@ func (s *Server) GetShardLeaders(ctx context.Context, req *querypb.GetShardLeade
|
||||
}, nil
|
||||
}
|
||||
|
||||
resp := &querypb.GetShardLeadersResponse{
|
||||
Status: merr.Success(),
|
||||
}
|
||||
|
||||
percentage := s.meta.CollectionManager.CalculateLoadPercentage(req.GetCollectionID())
|
||||
if percentage < 0 {
|
||||
err := merr.WrapErrCollectionNotLoaded(req.GetCollectionID())
|
||||
log.Warn("failed to GetShardLeaders", zap.Error(err))
|
||||
resp.Status = merr.Status(err)
|
||||
return resp, nil
|
||||
}
|
||||
collection := s.meta.CollectionManager.GetCollection(req.GetCollectionID())
|
||||
if collection != nil && collection.GetStatus() == querypb.LoadStatus_Loaded {
|
||||
// when collection is loaded, regard collection as readable, set percentage == 100
|
||||
percentage = 100
|
||||
}
|
||||
if percentage < 100 {
|
||||
err := merr.WrapErrCollectionNotFullyLoaded(req.GetCollectionID())
|
||||
msg := fmt.Sprintf("collection %v is not fully loaded", req.GetCollectionID())
|
||||
log.Warn(msg)
|
||||
resp.Status = merr.Status(err)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
channels := s.targetMgr.GetDmChannelsByCollection(req.GetCollectionID(), meta.CurrentTarget)
|
||||
if len(channels) == 0 {
|
||||
err := merr.WrapErrCollectionOnRecovering(req.GetCollectionID(),
|
||||
"loaded collection do not found any channel in target, may be in recovery")
|
||||
log.Warn("failed to get channels", zap.Error(err))
|
||||
resp.Status = merr.Status(err)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
currentTargets := s.targetMgr.GetSealedSegmentsByCollection(req.GetCollectionID(), meta.CurrentTarget)
|
||||
for _, channel := range channels {
|
||||
log := log.With(zap.String("channel", channel.GetChannelName()))
|
||||
|
||||
leaders := s.dist.LeaderViewManager.GetByFilter(meta.WithChannelName2LeaderView(channel.GetChannelName()))
|
||||
|
||||
readableLeaders := make(map[int64]*meta.LeaderView)
|
||||
|
||||
var channelErr error
|
||||
if len(leaders) == 0 {
|
||||
channelErr = merr.WrapErrChannelLack(channel.GetChannelName(), "channel not subscribed")
|
||||
}
|
||||
|
||||
for _, leader := range leaders {
|
||||
if err := checkers.CheckLeaderAvailable(s.nodeMgr, leader, currentTargets); err != nil {
|
||||
multierr.AppendInto(&channelErr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
readableLeaders[leader.ID] = leader
|
||||
}
|
||||
|
||||
if len(readableLeaders) == 0 {
|
||||
msg := fmt.Sprintf("channel %s is not available in any replica", channel.GetChannelName())
|
||||
log.Warn(msg, zap.Error(channelErr))
|
||||
resp.Status = merr.Status(
|
||||
errors.Wrap(merr.WrapErrChannelNotAvailable(channel.GetChannelName()), channelErr.Error()))
|
||||
resp.Shards = nil
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
readableLeaders = filterDupLeaders(s.meta.ReplicaManager, readableLeaders)
|
||||
ids := make([]int64, 0, len(leaders))
|
||||
addrs := make([]string, 0, len(leaders))
|
||||
for _, leader := range readableLeaders {
|
||||
info := s.nodeMgr.Get(leader.ID)
|
||||
if info != nil {
|
||||
ids = append(ids, info.ID())
|
||||
addrs = append(addrs, info.Addr())
|
||||
}
|
||||
}
|
||||
|
||||
// to avoid node down during GetShardLeaders
|
||||
if len(ids) == 0 {
|
||||
msg := fmt.Sprintf("channel %s is not available in any replica", channel.GetChannelName())
|
||||
log.Warn(msg, zap.Error(channelErr))
|
||||
resp.Status = merr.Status(
|
||||
errors.Wrap(merr.WrapErrChannelNotAvailable(channel.GetChannelName()), channelErr.Error()))
|
||||
resp.Shards = nil
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
resp.Shards = append(resp.Shards, &querypb.ShardLeadersList{
|
||||
ChannelName: channel.GetChannelName(),
|
||||
NodeIds: ids,
|
||||
NodeAddrs: addrs,
|
||||
})
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
leaders, err := utils.GetShardLeaders(s.meta, s.targetMgr, s.dist, s.nodeMgr, req.GetCollectionID())
|
||||
return &querypb.GetShardLeadersResponse{
|
||||
Status: merr.Status(err),
|
||||
Shards: leaders,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) {
|
||||
|
195
internal/querycoordv2/utils/util.go
Normal file
195
internal/querycoordv2/utils/util.go
Normal file
@ -0,0 +1,195 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"go.uber.org/multierr"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/session"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
func CheckNodeAvailable(nodeID int64, info *session.NodeInfo) error {
|
||||
if info == nil {
|
||||
return merr.WrapErrNodeOffline(nodeID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// In a replica, a shard is available, if and only if:
|
||||
// 1. The leader is online
|
||||
// 2. All QueryNodes in the distribution are online
|
||||
// 3. The last heartbeat response time is within HeartbeatAvailableInterval for all QueryNodes(include leader) in the distribution
|
||||
// 4. All segments of the shard in target should be in the distribution
|
||||
func CheckLeaderAvailable(nodeMgr *session.NodeManager, leader *meta.LeaderView, currentTargets map[int64]*datapb.SegmentInfo) error {
|
||||
log := log.Ctx(context.TODO()).
|
||||
WithRateGroup("utils.CheckLeaderAvailable", 1, 60).
|
||||
With(zap.Int64("leaderID", leader.ID))
|
||||
info := nodeMgr.Get(leader.ID)
|
||||
|
||||
// Check whether leader is online
|
||||
err := CheckNodeAvailable(leader.ID, info)
|
||||
if err != nil {
|
||||
log.Info("leader is not available", zap.Error(err))
|
||||
return fmt.Errorf("leader not available: %w", err)
|
||||
}
|
||||
|
||||
for id, version := range leader.Segments {
|
||||
info := nodeMgr.Get(version.GetNodeID())
|
||||
err = CheckNodeAvailable(version.GetNodeID(), info)
|
||||
if err != nil {
|
||||
log.Info("leader is not available due to QueryNode unavailable",
|
||||
zap.Int64("segmentID", id),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Check whether segments are fully loaded
|
||||
for segmentID, info := range currentTargets {
|
||||
if info.GetInsertChannel() != leader.Channel {
|
||||
continue
|
||||
}
|
||||
|
||||
_, exist := leader.Segments[segmentID]
|
||||
if !exist {
|
||||
log.RatedInfo(10, "leader is not available due to lack of segment", zap.Int64("segmentID", segmentID))
|
||||
return merr.WrapErrSegmentLack(segmentID)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetShardLeaders(m *meta.Meta, targetMgr *meta.TargetManager, dist *meta.DistributionManager, nodeMgr *session.NodeManager, collectionID int64) ([]*querypb.ShardLeadersList, error) {
|
||||
percentage := m.CollectionManager.CalculateLoadPercentage(collectionID)
|
||||
if percentage < 0 {
|
||||
err := merr.WrapErrCollectionNotLoaded(collectionID)
|
||||
log.Warn("failed to GetShardLeaders", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
collection := m.CollectionManager.GetCollection(collectionID)
|
||||
if collection != nil && collection.GetStatus() == querypb.LoadStatus_Loaded {
|
||||
// when collection is loaded, regard collection as readable, set percentage == 100
|
||||
percentage = 100
|
||||
}
|
||||
|
||||
if percentage < 100 {
|
||||
err := merr.WrapErrCollectionNotFullyLoaded(collectionID)
|
||||
msg := fmt.Sprintf("collection %v is not fully loaded", collectionID)
|
||||
log.Warn(msg)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
channels := targetMgr.GetDmChannelsByCollection(collectionID, meta.CurrentTarget)
|
||||
if len(channels) == 0 {
|
||||
msg := "loaded collection do not found any channel in target, may be in recovery"
|
||||
err := merr.WrapErrCollectionOnRecovering(collectionID, msg)
|
||||
log.Warn("failed to get channels", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret := make([]*querypb.ShardLeadersList, 0)
|
||||
currentTargets := targetMgr.GetSealedSegmentsByCollection(collectionID, meta.CurrentTarget)
|
||||
for _, channel := range channels {
|
||||
log := log.With(zap.String("channel", channel.GetChannelName()))
|
||||
|
||||
var channelErr error
|
||||
leaders := dist.LeaderViewManager.GetByFilter(meta.WithChannelName2LeaderView(channel.GetChannelName()))
|
||||
if len(leaders) == 0 {
|
||||
channelErr = merr.WrapErrChannelLack(channel.GetChannelName(), "channel not subscribed")
|
||||
}
|
||||
|
||||
readableLeaders := make(map[int64]*meta.LeaderView)
|
||||
for _, leader := range leaders {
|
||||
if err := CheckLeaderAvailable(nodeMgr, leader, currentTargets); err != nil {
|
||||
multierr.AppendInto(&channelErr, err)
|
||||
continue
|
||||
}
|
||||
readableLeaders[leader.ID] = leader
|
||||
}
|
||||
|
||||
if len(readableLeaders) == 0 {
|
||||
msg := fmt.Sprintf("channel %s is not available in any replica", channel.GetChannelName())
|
||||
log.Warn(msg, zap.Error(channelErr))
|
||||
err := merr.WrapErrChannelNotAvailable(channel.GetChannelName(), channelErr.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
readableLeaders = filterDupLeaders(m.ReplicaManager, readableLeaders)
|
||||
ids := make([]int64, 0, len(leaders))
|
||||
addrs := make([]string, 0, len(leaders))
|
||||
for _, leader := range readableLeaders {
|
||||
info := nodeMgr.Get(leader.ID)
|
||||
if info != nil {
|
||||
ids = append(ids, info.ID())
|
||||
addrs = append(addrs, info.Addr())
|
||||
}
|
||||
}
|
||||
|
||||
// to avoid node down during GetShardLeaders
|
||||
if len(ids) == 0 {
|
||||
msg := fmt.Sprintf("channel %s is not available in any replica", channel.GetChannelName())
|
||||
log.Warn(msg, zap.Error(channelErr))
|
||||
err := merr.WrapErrChannelNotAvailable(channel.GetChannelName(), channelErr.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret = append(ret, &querypb.ShardLeadersList{
|
||||
ChannelName: channel.GetChannelName(),
|
||||
NodeIds: ids,
|
||||
NodeAddrs: addrs,
|
||||
})
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func filterDupLeaders(replicaManager *meta.ReplicaManager, leaders map[int64]*meta.LeaderView) map[int64]*meta.LeaderView {
|
||||
type leaderID struct {
|
||||
ReplicaID int64
|
||||
Shard string
|
||||
}
|
||||
|
||||
newLeaders := make(map[leaderID]*meta.LeaderView)
|
||||
for _, view := range leaders {
|
||||
replica := replicaManager.GetByCollectionAndNode(view.CollectionID, view.ID)
|
||||
if replica == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
id := leaderID{replica.GetID(), view.Channel}
|
||||
if old, ok := newLeaders[id]; ok && old.Version > view.Version {
|
||||
continue
|
||||
}
|
||||
|
||||
newLeaders[id] = view
|
||||
}
|
||||
|
||||
result := make(map[int64]*meta.LeaderView)
|
||||
for _, v := range newLeaders {
|
||||
result[v.ID] = v
|
||||
}
|
||||
return result
|
||||
}
|
@ -14,7 +14,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package checkers
|
||||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
@ -422,6 +422,49 @@ func (_c *MockProxyClientManager_InvalidateCredentialCache_Call) RunAndReturn(ru
|
||||
return _c
|
||||
}
|
||||
|
||||
// InvalidateShardLeaderCache provides a mock function with given fields: ctx, request
|
||||
func (_m *MockProxyClientManager) InvalidateShardLeaderCache(ctx context.Context, request *proxypb.InvalidateShardLeaderCacheRequest) error {
|
||||
ret := _m.Called(ctx, request)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateShardLeaderCacheRequest) error); ok {
|
||||
r0 = rf(ctx, request)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockProxyClientManager_InvalidateShardLeaderCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InvalidateShardLeaderCache'
|
||||
type MockProxyClientManager_InvalidateShardLeaderCache_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// InvalidateShardLeaderCache is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - request *proxypb.InvalidateShardLeaderCacheRequest
|
||||
func (_e *MockProxyClientManager_Expecter) InvalidateShardLeaderCache(ctx interface{}, request interface{}) *MockProxyClientManager_InvalidateShardLeaderCache_Call {
|
||||
return &MockProxyClientManager_InvalidateShardLeaderCache_Call{Call: _e.mock.On("InvalidateShardLeaderCache", ctx, request)}
|
||||
}
|
||||
|
||||
func (_c *MockProxyClientManager_InvalidateShardLeaderCache_Call) Run(run func(ctx context.Context, request *proxypb.InvalidateShardLeaderCacheRequest)) *MockProxyClientManager_InvalidateShardLeaderCache_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*proxypb.InvalidateShardLeaderCacheRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockProxyClientManager_InvalidateShardLeaderCache_Call) Return(_a0 error) *MockProxyClientManager_InvalidateShardLeaderCache_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockProxyClientManager_InvalidateShardLeaderCache_Call) RunAndReturn(run func(context.Context, *proxypb.InvalidateShardLeaderCacheRequest) error) *MockProxyClientManager_InvalidateShardLeaderCache_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RefreshPolicyInfoCache provides a mock function with given fields: ctx, req
|
||||
func (_m *MockProxyClientManager) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest) error {
|
||||
ret := _m.Called(ctx, req)
|
||||
|
@ -88,6 +88,7 @@ type ProxyClientManagerInterface interface {
|
||||
GetProxyCount() int
|
||||
|
||||
InvalidateCollectionMetaCache(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest, opts ...ExpireCacheOpt) error
|
||||
InvalidateShardLeaderCache(ctx context.Context, request *proxypb.InvalidateShardLeaderCacheRequest) error
|
||||
InvalidateCredentialCache(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) error
|
||||
UpdateCredentialCache(ctx context.Context, request *proxypb.UpdateCredCacheRequest) error
|
||||
RefreshPolicyInfoCache(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest) error
|
||||
@ -188,6 +189,11 @@ func (p *ProxyClientManager) InvalidateCollectionMetaCache(ctx context.Context,
|
||||
log.Warn("InvalidateCollectionMetaCache failed due to proxy service not found", zap.Error(err))
|
||||
return nil
|
||||
}
|
||||
|
||||
if errors.Is(err, merr.ErrServiceUnimplemented) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("InvalidateCollectionMetaCache failed, proxyID = %d, err = %s", k, err)
|
||||
}
|
||||
if sta.ErrorCode != commonpb.ErrorCode_Success {
|
||||
@ -363,3 +369,31 @@ func (p *ProxyClientManager) GetComponentStates(ctx context.Context) (map[int64]
|
||||
|
||||
return states, nil
|
||||
}
|
||||
|
||||
func (p *ProxyClientManager) InvalidateShardLeaderCache(ctx context.Context, request *proxypb.InvalidateShardLeaderCacheRequest) error {
|
||||
if p.proxyClient.Len() == 0 {
|
||||
log.Warn("proxy client is empty, InvalidateShardLeaderCache will not send to any client")
|
||||
return nil
|
||||
}
|
||||
|
||||
group := &errgroup.Group{}
|
||||
p.proxyClient.Range(func(key int64, value types.ProxyClient) bool {
|
||||
k, v := key, value
|
||||
group.Go(func() error {
|
||||
sta, err := v.InvalidateShardLeaderCache(ctx, request)
|
||||
if err != nil {
|
||||
if errors.Is(err, merr.ErrNodeNotFound) {
|
||||
log.Warn("InvalidateShardLeaderCache failed due to proxy service not found", zap.Error(err))
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("InvalidateShardLeaderCache failed, proxyID = %d, err = %s", k, err)
|
||||
}
|
||||
if sta.ErrorCode != commonpb.ErrorCode_Success {
|
||||
return fmt.Errorf("InvalidateShardLeaderCache failed, proxyID = %d, err = %s", k, sta.Reason)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return true
|
||||
})
|
||||
return group.Wait()
|
||||
}
|
||||
|
@ -313,6 +313,14 @@ func TestProxyClientManager_RefreshPolicyInfoCache(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyClientManager_TestGetProxyCount(t *testing.T) {
|
||||
p1 := mocks.NewMockProxyClient(t)
|
||||
pcm := NewProxyClientManager(DefaultProxyCreator)
|
||||
pcm.proxyClient.Insert(TestProxyID, p1)
|
||||
|
||||
assert.Equal(t, pcm.GetProxyCount(), 1)
|
||||
}
|
||||
|
||||
func TestProxyClientManager_GetProxyMetrics(t *testing.T) {
|
||||
TestProxyID := int64(1001)
|
||||
t.Run("empty proxy list", func(t *testing.T) {
|
||||
@ -424,3 +432,34 @@ func TestProxyClientManager_GetComponentStates(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyClientManager_InvalidateShardLeaderCache(t *testing.T) {
|
||||
TestProxyID := int64(1001)
|
||||
t.Run("empty proxy list", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
pcm := NewProxyClientManager(DefaultProxyCreator)
|
||||
|
||||
err := pcm.InvalidateShardLeaderCache(ctx, &proxypb.InvalidateShardLeaderCacheRequest{})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("mock rpc error", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
p1 := mocks.NewMockProxyClient(t)
|
||||
p1.EXPECT().InvalidateShardLeaderCache(mock.Anything, mock.Anything).Return(nil, errors.New("error mock InvalidateCredentialCache"))
|
||||
pcm := NewProxyClientManager(DefaultProxyCreator)
|
||||
pcm.proxyClient.Insert(TestProxyID, p1)
|
||||
err := pcm.InvalidateShardLeaderCache(ctx, &proxypb.InvalidateShardLeaderCacheRequest{})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
p1 := mocks.NewMockProxyClient(t)
|
||||
p1.EXPECT().InvalidateShardLeaderCache(mock.Anything, mock.Anything).Return(merr.Success(), nil)
|
||||
pcm := NewProxyClientManager(DefaultProxyCreator)
|
||||
pcm.proxyClient.Insert(TestProxyID, p1)
|
||||
err := pcm.InvalidateShardLeaderCache(ctx, &proxypb.InvalidateShardLeaderCacheRequest{})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user