enhance: update shard leader cache when leader location changed (#32470)

issue: #32466

This PR updates the proxy's shard leader cache whenever a shard leader's location changes, so that in a query node failover the proxy can find the recovered replica.

---------

Signed-off-by: Wei Liu <wei.liu@zilliz.com>
wei liu 2024-05-08 10:05:29 +08:00 committed by GitHub
parent 5038036ece
commit ba02d54a30
26 changed files with 987 additions and 212 deletions
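
At a glance, the change wires three pieces together: QueryCoord's LeaderViewManager notices when a delegator (shard leader) moves, a new LeaderCacheObserver forwards the affected collection IDs, and every proxy drops those collections from its shard leader cache so the next request refetches leaders. A minimal stand-alone sketch of that chain, using simplified stand-in types rather than the Milvus APIs:

package main

import "fmt"

// invalidator is the proxy-side hook: drop cached shard leaders for these collections.
type invalidator interface {
	InvalidateShardLeaderCache(collectionIDs ...int64)
}

// observer forwards leader-change events from QueryCoord to all proxies.
type observer struct{ proxies []invalidator }

func (o *observer) RegisterEvent(collectionIDs ...int64) {
	for _, p := range o.proxies {
		p.InvalidateShardLeaderCache(collectionIDs...)
	}
}

// fakeProxy just records what was invalidated.
type fakeProxy struct{ invalidated []int64 }

func (p *fakeProxy) InvalidateShardLeaderCache(ids ...int64) {
	p.invalidated = append(p.invalidated, ids...)
}

func main() {
	p := &fakeProxy{}
	obs := &observer{proxies: []invalidator{p}}
	// In Milvus this call comes from LeaderViewManager.Update via SetNotifyFunc.
	obs.RegisterEvent(100, 101)
	fmt.Println(p.invalidated) // [100 101]
}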

View File

@ -216,3 +216,9 @@ func (c *Client) ListImports(ctx context.Context, req *internalpb.ListImportsReq
return client.ListImports(ctx, req)
})
}
func (c *Client) InvalidateShardLeaderCache(ctx context.Context, req *proxypb.InvalidateShardLeaderCacheRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
return wrapGrpcCall(ctx, c, func(client proxypb.ProxyClient) (*commonpb.Status, error) {
return client.InvalidateShardLeaderCache(ctx, req)
})
}

View File

@ -462,3 +462,40 @@ func Test_ImportV2(t *testing.T) {
_, err = client.ListImports(ctx, &internalpb.ListImportsRequest{})
assert.Nil(t, err)
}
func Test_InvalidateShardLeaderCache(t *testing.T) {
paramtable.Init()
ctx := context.Background()
client, err := NewClient(ctx, "test", 1)
assert.NoError(t, err)
assert.NotNil(t, client)
defer client.Close()
mockProxy := mocks.NewMockProxyClient(t)
mockGrpcClient := mocks.NewMockGrpcClient[proxypb.ProxyClient](t)
mockGrpcClient.EXPECT().Close().Return(nil)
mockGrpcClient.EXPECT().ReCall(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, f func(proxypb.ProxyClient) (interface{}, error)) (interface{}, error) {
return f(mockProxy)
})
client.(*Client).grpcClient = mockGrpcClient
// test success
mockProxy.EXPECT().InvalidateShardLeaderCache(mock.Anything, mock.Anything).Return(merr.Success(), nil)
_, err = client.InvalidateShardLeaderCache(ctx, &proxypb.InvalidateShardLeaderCacheRequest{})
assert.Nil(t, err)
// test return error code
mockProxy.ExpectedCalls = nil
mockProxy.EXPECT().InvalidateShardLeaderCache(mock.Anything, mock.Anything).Return(merr.Status(merr.ErrServiceNotReady), nil)
_, err = client.InvalidateShardLeaderCache(ctx, &proxypb.InvalidateShardLeaderCacheRequest{})
assert.Nil(t, err)
// test ctx done
ctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond)
defer cancel()
time.Sleep(20 * time.Millisecond)
_, err = client.InvalidateShardLeaderCache(ctx, &proxypb.InvalidateShardLeaderCacheRequest{})
assert.ErrorIs(t, err, context.DeadlineExceeded)
}

View File

@ -1229,3 +1229,7 @@ func (s *Server) ListImports(ctx context.Context, req *internalpb.ListImportsReq
func (s *Server) AlterDatabase(ctx context.Context, req *milvuspb.AlterDatabaseRequest) (*commonpb.Status, error) {
return s.proxy.AlterDatabase(ctx, req)
}
func (s *Server) InvalidateShardLeaderCache(ctx context.Context, req *proxypb.InvalidateShardLeaderCacheRequest) (*commonpb.Status, error) {
return s.proxy.InvalidateShardLeaderCache(ctx, req)
}

View File

@ -229,6 +229,12 @@ func Test_NewServer(t *testing.T) {
assert.NoError(t, err)
})
t.Run("InvalidateShardLeaderCache", func(t *testing.T) {
mockProxy.EXPECT().InvalidateShardLeaderCache(mock.Anything, mock.Anything).Return(nil, nil)
_, err := server.InvalidateShardLeaderCache(ctx, nil)
assert.NoError(t, err)
})
t.Run("CreateCollection", func(t *testing.T) { t.Run("CreateCollection", func(t *testing.T) {
mockProxy.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(nil, nil) mockProxy.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(nil, nil)
_, err := server.CreateCollection(ctx, nil) _, err := server.CreateCollection(ctx, nil)

View File

@ -3634,6 +3634,61 @@ func (_c *MockProxy_InvalidateCredentialCache_Call) RunAndReturn(run func(contex
return _c
}
// InvalidateShardLeaderCache provides a mock function with given fields: _a0, _a1
func (_m *MockProxy) InvalidateShardLeaderCache(_a0 context.Context, _a1 *proxypb.InvalidateShardLeaderCacheRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1)
var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateShardLeaderCacheRequest) (*commonpb.Status, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateShardLeaderCacheRequest) *commonpb.Status); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *proxypb.InvalidateShardLeaderCacheRequest) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockProxy_InvalidateShardLeaderCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InvalidateShardLeaderCache'
type MockProxy_InvalidateShardLeaderCache_Call struct {
*mock.Call
}
// InvalidateShardLeaderCache is a helper method to define mock.On call
// - _a0 context.Context
// - _a1 *proxypb.InvalidateShardLeaderCacheRequest
func (_e *MockProxy_Expecter) InvalidateShardLeaderCache(_a0 interface{}, _a1 interface{}) *MockProxy_InvalidateShardLeaderCache_Call {
return &MockProxy_InvalidateShardLeaderCache_Call{Call: _e.mock.On("InvalidateShardLeaderCache", _a0, _a1)}
}
func (_c *MockProxy_InvalidateShardLeaderCache_Call) Run(run func(_a0 context.Context, _a1 *proxypb.InvalidateShardLeaderCacheRequest)) *MockProxy_InvalidateShardLeaderCache_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*proxypb.InvalidateShardLeaderCacheRequest))
})
return _c
}
func (_c *MockProxy_InvalidateShardLeaderCache_Call) Return(_a0 *commonpb.Status, _a1 error) *MockProxy_InvalidateShardLeaderCache_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockProxy_InvalidateShardLeaderCache_Call) RunAndReturn(run func(context.Context, *proxypb.InvalidateShardLeaderCacheRequest) (*commonpb.Status, error)) *MockProxy_InvalidateShardLeaderCache_Call {
_c.Call.Return(run)
return _c
}
// ListAliases provides a mock function with given fields: _a0, _a1
func (_m *MockProxy) ListAliases(_a0 context.Context, _a1 *milvuspb.ListAliasesRequest) (*milvuspb.ListAliasesResponse, error) {
ret := _m.Called(_a0, _a1)

View File

@ -632,6 +632,76 @@ func (_c *MockProxyClient_InvalidateCredentialCache_Call) RunAndReturn(run func(
return _c
}
// InvalidateShardLeaderCache provides a mock function with given fields: ctx, in, opts
func (_m *MockProxyClient) InvalidateShardLeaderCache(ctx context.Context, in *proxypb.InvalidateShardLeaderCacheRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
_va := make([]interface{}, len(opts))
for _i := range opts {
_va[_i] = opts[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, in)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateShardLeaderCacheRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok {
return rf(ctx, in, opts...)
}
if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateShardLeaderCacheRequest, ...grpc.CallOption) *commonpb.Status); ok {
r0 = rf(ctx, in, opts...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *proxypb.InvalidateShardLeaderCacheRequest, ...grpc.CallOption) error); ok {
r1 = rf(ctx, in, opts...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockProxyClient_InvalidateShardLeaderCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InvalidateShardLeaderCache'
type MockProxyClient_InvalidateShardLeaderCache_Call struct {
*mock.Call
}
// InvalidateShardLeaderCache is a helper method to define mock.On call
// - ctx context.Context
// - in *proxypb.InvalidateShardLeaderCacheRequest
// - opts ...grpc.CallOption
func (_e *MockProxyClient_Expecter) InvalidateShardLeaderCache(ctx interface{}, in interface{}, opts ...interface{}) *MockProxyClient_InvalidateShardLeaderCache_Call {
return &MockProxyClient_InvalidateShardLeaderCache_Call{Call: _e.mock.On("InvalidateShardLeaderCache",
append([]interface{}{ctx, in}, opts...)...)}
}
func (_c *MockProxyClient_InvalidateShardLeaderCache_Call) Run(run func(ctx context.Context, in *proxypb.InvalidateShardLeaderCacheRequest, opts ...grpc.CallOption)) *MockProxyClient_InvalidateShardLeaderCache_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]grpc.CallOption, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(grpc.CallOption)
}
}
run(args[0].(context.Context), args[1].(*proxypb.InvalidateShardLeaderCacheRequest), variadicArgs...)
})
return _c
}
func (_c *MockProxyClient_InvalidateShardLeaderCache_Call) Return(_a0 *commonpb.Status, _a1 error) *MockProxyClient_InvalidateShardLeaderCache_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockProxyClient_InvalidateShardLeaderCache_Call) RunAndReturn(run func(context.Context, *proxypb.InvalidateShardLeaderCacheRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockProxyClient_InvalidateShardLeaderCache_Call {
_c.Call.Return(run)
return _c
}
// ListClientInfos provides a mock function with given fields: ctx, in, opts
func (_m *MockProxyClient) ListClientInfos(ctx context.Context, in *proxypb.ListClientInfosRequest, opts ...grpc.CallOption) (*proxypb.ListClientInfosResponse, error) {
_va := make([]interface{}, len(opts))

View File

@ -27,6 +27,8 @@ service Proxy {
rpc ImportV2(internal.ImportRequest) returns(internal.ImportResponse){}
rpc GetImportProgress(internal.GetImportProgressRequest) returns(internal.GetImportProgressResponse){}
rpc ListImports(internal.ListImportsRequest) returns(internal.ListImportsResponse){}
rpc InvalidateShardLeaderCache(InvalidateShardLeaderCacheRequest) returns (common.Status) {}
}
message InvalidateCollMetaCacheRequest {
@ -40,6 +42,11 @@ message InvalidateCollMetaCacheRequest {
string partition_name = 5;
}
message InvalidateShardLeaderCacheRequest {
common.MsgBase base = 1;
repeated int64 collectionIDs = 2;
}
message InvalidateCredCacheRequest {
common.MsgBase base = 1;
string username = 2;

View File

@ -172,6 +172,29 @@ func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *p
return merr.Success(), nil
}
// InvalidateShardLeaderCache invalidates the shard leader cache of the given collections.
func (node *Proxy) InvalidateShardLeaderCache(ctx context.Context, request *proxypb.InvalidateShardLeaderCacheRequest) (*commonpb.Status, error) {
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
return merr.Status(err), nil
}
ctx = logutil.WithModule(ctx, moduleName)
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-InvalidateShardLeaderCache")
defer sp.End()
log := log.Ctx(ctx).With(
zap.String("role", typeutil.ProxyRole),
)
log.Info("received request to invalidate shard leader cache", zap.Int64s("collectionIDs", request.GetCollectionIDs()))
if globalMetaCache != nil {
globalMetaCache.InvalidateShardLeaderCache(request.GetCollectionIDs())
}
log.Info("complete to invalidate shard leader cache", zap.Int64s("collectionIDs", request.GetCollectionIDs()))
return merr.Success(), nil
}
func (node *Proxy) CreateDatabase(ctx context.Context, request *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) {
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
return merr.Status(err), nil

View File

@ -1720,3 +1720,30 @@ func TestGetCollectionRateSubLabel(t *testing.T) {
}
})
}
func TestProxy_InvalidateShardLeaderCache(t *testing.T) {
t.Run("proxy unhealthy", func(t *testing.T) {
node := &Proxy{}
node.UpdateStateCode(commonpb.StateCode_Abnormal)
resp, err := node.InvalidateShardLeaderCache(context.TODO(), nil)
assert.NoError(t, err)
assert.False(t, merr.Ok(resp))
})
t.Run("success", func(t *testing.T) {
node := &Proxy{}
node.UpdateStateCode(commonpb.StateCode_Healthy)
cacheBak := globalMetaCache
defer func() { globalMetaCache = cacheBak }()
// set expectations
cache := NewMockCache(t)
cache.EXPECT().InvalidateShardLeaderCache(mock.Anything)
globalMetaCache = cache
resp, err := node.InvalidateShardLeaderCache(context.TODO(), &proxypb.InvalidateShardLeaderCacheRequest{})
assert.NoError(t, err)
assert.True(t, merr.Ok(resp))
})
}

View File

@ -73,6 +73,7 @@ type Cache interface {
GetCollectionSchema(ctx context.Context, database, collectionName string) (*schemaInfo, error)
GetShards(ctx context.Context, withCache bool, database, collectionName string, collectionID int64) (map[string][]nodeInfo, error)
DeprecateShardCache(database, collectionName string)
InvalidateShardLeaderCache(collections []int64)
RemoveCollection(ctx context.Context, database, collectionName string)
RemoveCollectionsByID(ctx context.Context, collectionID UniqueID) []string
RemovePartition(ctx context.Context, database, collectionName string, partitionName string)
@ -201,6 +202,7 @@ type shardLeaders struct {
idx *atomic.Int64
deprecated *atomic.Bool
collectionID int64
shardLeaders map[string][]nodeInfo
}
@ -944,6 +946,7 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
shards := parseShardLeaderList2QueryNode(resp.GetShards())
newShardLeaders := &shardLeaders{
collectionID: info.collID,
shardLeaders: shards,
deprecated: atomic.NewBool(false),
idx: atomic.NewInt64(0),
@ -997,6 +1000,21 @@ func (m *MetaCache) DeprecateShardCache(database, collectionName string) {
}
}
func (m *MetaCache) InvalidateShardLeaderCache(collections []int64) {
log.Info("Invalidate shard cache for collections", zap.Int64s("collectionIDs", collections))
m.mu.RLock()
defer m.mu.RUnlock()
collectionSet := typeutil.NewUniqueSet(collections...)
for _, db := range m.collLeader {
for _, shardLeaders := range db {
if collectionSet.Contain(shardLeaders.collectionID) {
shardLeaders.deprecated.Store(true)
}
}
}
}
func (m *MetaCache) InitPolicyInfo(info []string, userRoles []string) {
defer func() {
err := getEnforcer().LoadPolicy()
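
InvalidateShardLeaderCache never refetches anything itself: under a read lock it only flips the deprecated flag on cached shardLeaders entries whose collection ID matches, and the next GetShards call notices the flag and pulls fresh leaders from QueryCoord. A simplified sketch of that lazy-invalidation pattern (the cache and fetcher below are illustrative stand-ins, not the MetaCache types):

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

type entry struct {
	collectionID int64
	deprecated   atomic.Bool
	leaders      []string
}

type leaderCache struct {
	mu      sync.RWMutex
	entries map[string]*entry // collectionName -> cached leaders
	fetch   func(collectionID int64) []string
}

// Invalidate marks matching entries deprecated; it never blocks on a refetch.
func (c *leaderCache) Invalidate(collectionIDs ...int64) {
	set := map[int64]struct{}{}
	for _, id := range collectionIDs {
		set[id] = struct{}{}
	}
	c.mu.RLock()
	defer c.mu.RUnlock()
	for _, e := range c.entries {
		if _, ok := set[e.collectionID]; ok {
			e.deprecated.Store(true)
		}
	}
}

// GetShards serves from cache unless the entry is missing or deprecated.
func (c *leaderCache) GetShards(name string, collectionID int64) []string {
	c.mu.RLock()
	e, ok := c.entries[name]
	c.mu.RUnlock()
	if ok && !e.deprecated.Load() {
		return e.leaders
	}
	leaders := c.fetch(collectionID)
	c.mu.Lock()
	c.entries[name] = &entry{collectionID: collectionID, leaders: leaders}
	c.mu.Unlock()
	return leaders
}

func main() {
	calls := 0
	c := &leaderCache{
		entries: map[string]*entry{},
		fetch:   func(int64) []string { calls++; return []string{"qn-1", "qn-2"} },
	}
	c.GetShards("collection1", 1)
	c.GetShards("collection1", 1) // served from cache
	c.Invalidate(1)
	c.GetShards("collection1", 1)  // refetched
	fmt.Println("fetches:", calls) // fetches: 2
}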

View File

@ -42,6 +42,7 @@ import (
"github.com/milvus-io/milvus/pkg/util/crypto" "github.com/milvus-io/milvus/pkg/util/crypto"
"github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil" "github.com/milvus-io/milvus/pkg/util/typeutil"
) )
@ -1070,3 +1071,51 @@ func TestGlobalMetaCache_GetCollectionNamesByID(t *testing.T) {
assert.Equal(t, []string{"db1", "db1"}, dbNames)
})
}
func TestMetaCache_InvalidateShardLeaderCache(t *testing.T) {
paramtable.Init()
paramtable.Get().Save(Params.ProxyCfg.ShardLeaderCacheInterval.Key, "1")
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &mocks.MockQueryCoordClient{}
shardMgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, queryCoord, shardMgr)
assert.NoError(t, err)
queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: merr.Success(),
CollectionIDs: []UniqueID{1},
InMemoryPercentages: []int64{100},
}, nil)
called := uatomic.NewInt32(0)
queryCoord.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context,
gslr *querypb.GetShardLeadersRequest, co ...grpc.CallOption,
) (*querypb.GetShardLeadersResponse, error) {
called.Inc()
return &querypb.GetShardLeadersResponse{
Status: merr.Success(),
Shards: []*querypb.ShardLeadersList{
{
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
},
},
}, nil
})
nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1", 1)
assert.NoError(t, err)
assert.Len(t, nodeInfos["channel-1"], 3)
assert.Equal(t, called.Load(), int32(1))
globalMetaCache.GetShards(ctx, true, dbName, "collection1", 1)
assert.Equal(t, called.Load(), int32(1))
globalMetaCache.InvalidateShardLeaderCache([]int64{1})
nodeInfos, err = globalMetaCache.GetShards(ctx, true, dbName, "collection1", 1)
assert.NoError(t, err)
assert.Len(t, nodeInfos["channel-1"], 3)
assert.Equal(t, called.Load(), int32(2))
}

View File

@ -952,6 +952,39 @@ func (_c *MockCache_InitPolicyInfo_Call) RunAndReturn(run func([]string, []strin
return _c
}
// InvalidateShardLeaderCache provides a mock function with given fields: collections
func (_m *MockCache) InvalidateShardLeaderCache(collections []int64) {
_m.Called(collections)
}
// MockCache_InvalidateShardLeaderCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InvalidateShardLeaderCache'
type MockCache_InvalidateShardLeaderCache_Call struct {
*mock.Call
}
// InvalidateShardLeaderCache is a helper method to define mock.On call
// - collections []int64
func (_e *MockCache_Expecter) InvalidateShardLeaderCache(collections interface{}) *MockCache_InvalidateShardLeaderCache_Call {
return &MockCache_InvalidateShardLeaderCache_Call{Call: _e.mock.On("InvalidateShardLeaderCache", collections)}
}
func (_c *MockCache_InvalidateShardLeaderCache_Call) Run(run func(collections []int64)) *MockCache_InvalidateShardLeaderCache_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]int64))
})
return _c
}
func (_c *MockCache_InvalidateShardLeaderCache_Call) Return() *MockCache_InvalidateShardLeaderCache_Call {
_c.Call.Return()
return _c
}
func (_c *MockCache_InvalidateShardLeaderCache_Call) RunAndReturn(run func([]int64)) *MockCache_InvalidateShardLeaderCache_Call {
_c.Call.Return(run)
return _c
}
// RefreshPolicyInfo provides a mock function with given fields: op
func (_m *MockCache) RefreshPolicyInfo(op typeutil.CacheOp) error {
ret := _m.Called(op)

View File

@ -191,7 +191,7 @@ func (c *ChannelChecker) findRepeatedChannels(ctx context.Context, replicaID int
continue
}
if err := utils.CheckLeaderAvailable(c.nodeMgr, leaderView, targets); err != nil {
log.RatedInfo(10, "replica has unavailable shard leader",
zap.Int64("collectionID", replica.GetCollectionID()),
zap.Int64("replicaID", replicaID),

View File

@ -1,81 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package checkers
import (
"context"
"fmt"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/merr"
)
func CheckNodeAvailable(nodeID int64, info *session.NodeInfo) error {
if info == nil {
return merr.WrapErrNodeOffline(nodeID)
}
return nil
}
// In a replica, a shard is available, if and only if:
// 1. The leader is online
// 2. All QueryNodes in the distribution are online
// 3. The last heartbeat response time is within HeartbeatAvailableInterval for all QueryNodes(include leader) in the distribution
// 4. All segments of the shard in target should be in the distribution
func CheckLeaderAvailable(nodeMgr *session.NodeManager, leader *meta.LeaderView, currentTargets map[int64]*datapb.SegmentInfo) error {
log := log.Ctx(context.TODO()).
WithRateGroup("checkers.CheckLeaderAvailable", 1, 60).
With(zap.Int64("leaderID", leader.ID))
info := nodeMgr.Get(leader.ID)
// Check whether leader is online
err := CheckNodeAvailable(leader.ID, info)
if err != nil {
log.Info("leader is not available", zap.Error(err))
return fmt.Errorf("leader not available: %w", err)
}
for id, version := range leader.Segments {
info := nodeMgr.Get(version.GetNodeID())
err = CheckNodeAvailable(version.GetNodeID(), info)
if err != nil {
log.Info("leader is not available due to QueryNode unavailable",
zap.Int64("segmentID", id),
zap.Error(err))
return err
}
}
// Check whether segments are fully loaded
for segmentID, info := range currentTargets {
if info.GetInsertChannel() != leader.Channel {
continue
}
_, exist := leader.Segments[segmentID]
if !exist {
log.RatedInfo(10, "leader is not available due to lack of segment", zap.Int64("segmentID", segmentID))
return merr.WrapErrSegmentLack(segmentID)
}
}
return nil
}

View File

@ -426,31 +426,3 @@ func (s *Server) fillReplicaInfo(replica *meta.Replica, withShardNodes bool) *mi
info.ShardReplicas = shardReplicas
return info
}
func filterDupLeaders(replicaManager *meta.ReplicaManager, leaders map[int64]*meta.LeaderView) map[int64]*meta.LeaderView {
type leaderID struct {
ReplicaID int64
Shard string
}
newLeaders := make(map[leaderID]*meta.LeaderView)
for _, view := range leaders {
replica := replicaManager.GetByCollectionAndNode(view.CollectionID, view.ID)
if replica == nil {
continue
}
id := leaderID{replica.GetID(), view.Channel}
if old, ok := newLeaders[id]; ok && old.Version > view.Version {
continue
}
newLeaders[id] = view
}
result := make(map[int64]*meta.LeaderView)
for _, v := range newLeaders {
result[v.ID] = v
}
return result
}

View File

@ -22,6 +22,7 @@ import (
"github.com/samber/lo" "github.com/samber/lo"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type lvCriterion struct {
@ -192,9 +193,12 @@ func composeNodeViews(views ...*LeaderView) nodeViews {
}
}
type NotifyDelegatorChanges = func(collectionID ...int64)
type LeaderViewManager struct {
rwmutex sync.RWMutex
views map[int64]nodeViews // LeaderID -> Views (one per shard)
notifyFunc NotifyDelegatorChanges
}
func NewLeaderViewManager() *LeaderViewManager {
@ -203,11 +207,45 @@ func NewLeaderViewManager() *LeaderViewManager {
}
}
func (mgr *LeaderViewManager) SetNotifyFunc(notifyFunc NotifyDelegatorChanges) {
mgr.notifyFunc = notifyFunc
}
// Update updates the leader's views, all views have to be with the same leader ID
func (mgr *LeaderViewManager) Update(leaderID int64, views ...*LeaderView) {
mgr.rwmutex.Lock()
defer mgr.rwmutex.Unlock()
oldViews := make(map[string]*LeaderView, 0)
if _, ok := mgr.views[leaderID]; ok {
oldViews = mgr.views[leaderID].channelView
}
newViews := lo.SliceToMap(views, func(v *LeaderView) (string, *LeaderView) {
return v.Channel, v
})
// update leader views
mgr.views[leaderID] = composeNodeViews(views...)
// compute leader location changes and find the corresponding collections
if mgr.notifyFunc != nil {
viewChanges := typeutil.NewUniqueSet()
for channel, oldView := range oldViews {
// if channel released from current node
if _, ok := newViews[channel]; !ok {
viewChanges.Insert(oldView.CollectionID)
}
}
for channel, newView := range newViews {
// if channel loaded to current node
if _, ok := oldViews[channel]; !ok {
viewChanges.Insert(newView.CollectionID)
}
}
mgr.notifyFunc(viewChanges.Collect()...)
}
}
func (mgr *LeaderViewManager) GetLeaderShardView(id int64, shard string) *LeaderView {
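
Update now diffs the delegator's old and new channel views: any channel that disappears from or newly appears on the node marks its collection as changed, and the registered notify callback receives the collected collection IDs. A stand-alone sketch of that set difference, with a trimmed-down view type standing in for LeaderView:

package main

import "fmt"

type view struct {
	collectionID int64
	channel      string
}

// changedCollections returns collection IDs whose channels were added or removed on a node.
func changedCollections(oldViews, newViews []view) []int64 {
	oldByCh := map[string]int64{}
	for _, v := range oldViews {
		oldByCh[v.channel] = v.collectionID
	}
	newByCh := map[string]int64{}
	for _, v := range newViews {
		newByCh[v.channel] = v.collectionID
	}
	changed := map[int64]struct{}{}
	for ch, coll := range oldByCh {
		if _, ok := newByCh[ch]; !ok { // channel released from this node
			changed[coll] = struct{}{}
		}
	}
	for ch, coll := range newByCh {
		if _, ok := oldByCh[ch]; !ok { // channel newly loaded on this node
			changed[coll] = struct{}{}
		}
	}
	out := make([]int64, 0, len(changed))
	for id := range changed {
		out = append(out, id)
	}
	return out
}

func main() {
	old := []view{{100, "ch-1"}, {101, "ch-2"}}
	cur := []view{{101, "ch-2"}, {103, "ch-4"}}
	fmt.Println(changedCollections(old, cur)) // 100 and 103, in some order
}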

View File

@ -208,6 +208,58 @@ func (suite *LeaderViewManagerSuite) TestClone() {
}
}
func (suite *LeaderViewManagerSuite) TestNotifyDelegatorChanges() {
mgr := NewLeaderViewManager()
oldViews := []*LeaderView{
{
ID: 1,
CollectionID: 100,
Channel: "test-channel-1",
},
{
ID: 1,
CollectionID: 101,
Channel: "test-channel-2",
},
{
ID: 1,
CollectionID: 102,
Channel: "test-channel-3",
},
}
mgr.Update(1, oldViews...)
newViews := []*LeaderView{
{
ID: 1,
CollectionID: 101,
Channel: "test-channel-2",
},
{
ID: 1,
CollectionID: 102,
Channel: "test-channel-3",
},
{
ID: 1,
CollectionID: 103,
Channel: "test-channel-4",
},
}
updateCollections := make([]int64, 0)
mgr.SetNotifyFunc(func(collectionIDs ...int64) {
updateCollections = append(updateCollections, collectionIDs...)
})
mgr.Update(1, newViews...)
suite.Equal(2, len(updateCollections))
suite.Contains(updateCollections, int64(100))
suite.Contains(updateCollections, int64(103))
}
func TestLeaderViewManager(t *testing.T) {
suite.Run(t, new(LeaderViewManagerSuite))
}

View File

@ -0,0 +1,114 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package observers
import (
"context"
"sync"
"time"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/proxyutil"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
type CollectionShardLeaderCache = map[string]*querypb.ShardLeadersList
// LeaderCacheObserver is to invalidate shard leader cache when leader location changes
type LeaderCacheObserver struct {
wg sync.WaitGroup
proxyManager proxyutil.ProxyClientManagerInterface
stopOnce sync.Once
closeCh chan struct{}
// channel of collection IDs whose shard leader cache needs to be invalidated
eventCh chan int64
}
func (o *LeaderCacheObserver) Start(ctx context.Context) {
o.wg.Add(1)
go o.schedule(ctx)
}
func (o *LeaderCacheObserver) Stop() {
o.stopOnce.Do(func() {
close(o.closeCh)
o.wg.Wait()
})
}
func (o *LeaderCacheObserver) RegisterEvent(events ...int64) {
for _, event := range events {
o.eventCh <- event
}
}
func (o *LeaderCacheObserver) schedule(ctx context.Context) {
defer o.wg.Done()
for {
select {
case <-ctx.Done():
log.Info("stop leader cache observer due to context done")
return
case <-o.closeCh:
log.Info("stop leader cache observer")
return
case event := <-o.eventCh:
log.Info("receive event, trigger leader cache update", zap.Int64("event", event))
ret := make([]int64, 0)
ret = append(ret, event)
// batch any events already queued into a single invalidation
eventNum := len(o.eventCh)
if eventNum > 0 {
for eventNum > 0 {
event := <-o.eventCh
ret = append(ret, event)
eventNum--
}
}
o.HandleEvent(ctx, ret...)
}
}
}
func (o *LeaderCacheObserver) HandleEvent(ctx context.Context, collectionIDs ...int64) {
ctx, cancel := context.WithTimeout(ctx, paramtable.Get().QueryCoordCfg.BrokerTimeout.GetAsDuration(time.Second))
defer cancel()
err := o.proxyManager.InvalidateShardLeaderCache(ctx, &proxypb.InvalidateShardLeaderCacheRequest{
CollectionIDs: collectionIDs,
})
if err != nil {
log.Warn("failed to invalidate proxy's shard leader cache", zap.Error(err))
return
}
}
func NewLeaderCacheObserver(
proxyManager proxyutil.ProxyClientManagerInterface,
) *LeaderCacheObserver {
return &LeaderCacheObserver{
proxyManager: proxyManager,
closeCh: make(chan struct{}),
eventCh: make(chan int64, 1024),
}
}
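
RegisterEvent only pushes collection IDs onto a buffered channel; the scheduler goroutine takes one event and then drains whatever is already queued, so a burst of leader changes collapses into a single InvalidateShardLeaderCache call. A self-contained sketch of that drain-to-batch loop, with a plain callback standing in for the proxy manager:

package main

import (
	"context"
	"fmt"
)

// batchLoop reads one event, drains the backlog, and hands the batch to handle.
func batchLoop(ctx context.Context, events <-chan int64, handle func(ids ...int64)) {
	for {
		select {
		case <-ctx.Done():
			return
		case first := <-events:
			batch := []int64{first}
			for n := len(events); n > 0; n-- { // drain what is already buffered
				batch = append(batch, <-events)
			}
			handle(batch...)
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	events := make(chan int64, 1024)
	events <- 1
	events <- 2
	events <- 3
	done := make(chan struct{})
	go batchLoop(ctx, events, func(ids ...int64) {
		fmt.Println("invalidate:", ids) // invalidate: [1 2 3]
		close(done)
	})
	<-done
}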

View File

@ -0,0 +1,94 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package observers
import (
"context"
"testing"
"time"
"github.com/samber/lo"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"go.uber.org/atomic"
"github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/internal/util/proxyutil"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type LeaderCacheObserverTestSuite struct {
suite.Suite
mockProxyManager *proxyutil.MockProxyClientManager
observer *LeaderCacheObserver
}
func (suite *LeaderCacheObserverTestSuite) SetupSuite() {
paramtable.Init()
suite.mockProxyManager = proxyutil.NewMockProxyClientManager(suite.T())
suite.observer = NewLeaderCacheObserver(suite.mockProxyManager)
}
func (suite *LeaderCacheObserverTestSuite) TestInvalidateShardLeaderCache() {
suite.observer.Start(context.TODO())
defer suite.observer.Stop()
ret := atomic.NewBool(false)
collectionIDs := typeutil.NewConcurrentSet[int64]()
suite.mockProxyManager.EXPECT().InvalidateShardLeaderCache(mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, req *proxypb.InvalidateShardLeaderCacheRequest) error {
collectionIDs.Upsert(req.GetCollectionIDs()...)
collectionIDs := req.GetCollectionIDs()
if len(collectionIDs) == 1 && lo.Contains(collectionIDs, 1) {
ret.Store(true)
}
return nil
})
suite.observer.RegisterEvent(1)
suite.Eventually(func() bool {
return ret.Load()
}, 3*time.Second, 1*time.Second)
// test batch submit events
ret.Store(false)
suite.mockProxyManager.ExpectedCalls = nil
suite.mockProxyManager.EXPECT().InvalidateShardLeaderCache(mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, req *proxypb.InvalidateShardLeaderCacheRequest) error {
collectionIDs.Upsert(req.GetCollectionIDs()...)
collectionIDs := req.GetCollectionIDs()
if len(collectionIDs) == 3 && lo.Contains(collectionIDs, 1) && lo.Contains(collectionIDs, 2) && lo.Contains(collectionIDs, 3) {
ret.Store(true)
}
return nil
})
suite.observer.RegisterEvent(1)
suite.observer.RegisterEvent(2)
suite.observer.RegisterEvent(3)
suite.Eventually(func() bool {
return ret.Load()
}, 3*time.Second, 1*time.Second)
}
func TestLeaderCacheObserverTestSuite(t *testing.T) {
suite.Run(t, new(LeaderCacheObserverTestSuite))
}

View File

@ -51,6 +51,7 @@ import (
"github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/task" "github.com/milvus-io/milvus/internal/querycoordv2/task"
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/proxyutil"
"github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/tsoutil" "github.com/milvus-io/milvus/internal/util/tsoutil"
"github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/common"
@ -108,10 +109,11 @@ type Server struct {
checkerController *checkers.CheckerController
// Observers
collectionObserver *observers.CollectionObserver
targetObserver *observers.TargetObserver
replicaObserver *observers.ReplicaObserver
resourceObserver *observers.ResourceObserver
leaderCacheObserver *observers.LeaderCacheObserver
balancer balance.Balance
balancerMap map[string]balance.Balance
@ -122,6 +124,11 @@ type Server struct {
nodeUpEventChan chan int64
notifyNodeUp chan struct{}
// proxy client manager
proxyCreator proxyutil.ProxyCreator
proxyWatcher proxyutil.ProxyWatcherInterface
proxyClientManager proxyutil.ProxyClientManagerInterface
}
func NewQueryCoord(ctx context.Context) (*Server, error) {
@ -261,6 +268,16 @@ func (s *Server) initQueryCoord() error {
s.nodeMgr,
)
// init proxy client manager
s.proxyClientManager = proxyutil.NewProxyClientManager(proxyutil.DefaultProxyCreator)
s.proxyWatcher = proxyutil.NewProxyWatcher(
s.etcdCli,
s.proxyClientManager.AddProxyClients,
)
s.proxyWatcher.AddSessionFunc(s.proxyClientManager.AddProxyClient)
s.proxyWatcher.DelSessionFunc(s.proxyClientManager.DelProxyClient)
log.Info("init proxy manager done")
// Init heartbeat
log.Info("init dist controller")
s.distController = dist.NewDistController(
@ -387,6 +404,11 @@ func (s *Server) initObserver() {
)
s.resourceObserver = observers.NewResourceObserver(s.meta)
s.leaderCacheObserver = observers.NewLeaderCacheObserver(
s.proxyClientManager,
)
s.dist.LeaderViewManager.SetNotifyFunc(s.leaderCacheObserver.RegisterEvent)
}
func (s *Server) afterStart() {}
@ -432,6 +454,10 @@ func (s *Server) startQueryCoord() error {
// check whether old node exist, if yes suspend auto balance until all old nodes down
s.updateBalanceConfigLoop(s.ctx)
if err := s.proxyWatcher.WatchProxy(s.ctx); err != nil {
log.Warn("querycoord failed to watch proxy", zap.Error(err))
}
// Recover dist, to avoid generate too much task when dist not ready after restart
s.distController.SyncAll(s.ctx)
@ -453,6 +479,7 @@ func (s *Server) startServerLoop() {
s.targetObserver.Start()
s.replicaObserver.Start()
s.resourceObserver.Start()
s.leaderCacheObserver.Start(s.ctx)
log.Info("start task scheduler...") log.Info("start task scheduler...")
s.taskScheduler.Start() s.taskScheduler.Start()
@ -504,6 +531,9 @@ func (s *Server) Stop() error {
if s.resourceObserver != nil {
s.resourceObserver.Stop()
}
if s.leaderCacheObserver != nil {
s.leaderCacheObserver.Stop()
}
if s.distController != nil {
log.Info("stop dist controller...")

View File

@ -23,7 +23,6 @@ import (
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/samber/lo" "github.com/samber/lo"
"go.uber.org/multierr"
"go.uber.org/zap" "go.uber.org/zap"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
@ -31,7 +30,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/checkers"
"github.com/milvus-io/milvus/internal/querycoordv2/job" "github.com/milvus-io/milvus/internal/querycoordv2/job"
"github.com/milvus-io/milvus/internal/querycoordv2/meta" "github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/utils" "github.com/milvus-io/milvus/internal/querycoordv2/utils"
@ -867,99 +865,11 @@ func (s *Server) GetShardLeaders(ctx context.Context, req *querypb.GetShardLeade
}, nil
}
leaders, err := utils.GetShardLeaders(s.meta, s.targetMgr, s.dist, s.nodeMgr, req.GetCollectionID())
return &querypb.GetShardLeadersResponse{
Status: merr.Status(err),
Shards: leaders,
}, nil

resp := &querypb.GetShardLeadersResponse{
Status: merr.Success(),
}

percentage := s.meta.CollectionManager.CalculateLoadPercentage(req.GetCollectionID())
if percentage < 0 {
err := merr.WrapErrCollectionNotLoaded(req.GetCollectionID())
log.Warn("failed to GetShardLeaders", zap.Error(err))
resp.Status = merr.Status(err)
return resp, nil
}
collection := s.meta.CollectionManager.GetCollection(req.GetCollectionID())
if collection != nil && collection.GetStatus() == querypb.LoadStatus_Loaded {
// when collection is loaded, regard collection as readable, set percentage == 100
percentage = 100
}
if percentage < 100 {
err := merr.WrapErrCollectionNotFullyLoaded(req.GetCollectionID())
msg := fmt.Sprintf("collection %v is not fully loaded", req.GetCollectionID())
log.Warn(msg)
resp.Status = merr.Status(err)
return resp, nil
}
channels := s.targetMgr.GetDmChannelsByCollection(req.GetCollectionID(), meta.CurrentTarget)
if len(channels) == 0 {
err := merr.WrapErrCollectionOnRecovering(req.GetCollectionID(),
"loaded collection do not found any channel in target, may be in recovery")
log.Warn("failed to get channels", zap.Error(err))
resp.Status = merr.Status(err)
return resp, nil
}
currentTargets := s.targetMgr.GetSealedSegmentsByCollection(req.GetCollectionID(), meta.CurrentTarget)
for _, channel := range channels {
log := log.With(zap.String("channel", channel.GetChannelName()))
leaders := s.dist.LeaderViewManager.GetByFilter(meta.WithChannelName2LeaderView(channel.GetChannelName()))
readableLeaders := make(map[int64]*meta.LeaderView)
var channelErr error
if len(leaders) == 0 {
channelErr = merr.WrapErrChannelLack(channel.GetChannelName(), "channel not subscribed")
}
for _, leader := range leaders {
if err := checkers.CheckLeaderAvailable(s.nodeMgr, leader, currentTargets); err != nil {
multierr.AppendInto(&channelErr, err)
continue
}
readableLeaders[leader.ID] = leader
}
if len(readableLeaders) == 0 {
msg := fmt.Sprintf("channel %s is not available in any replica", channel.GetChannelName())
log.Warn(msg, zap.Error(channelErr))
resp.Status = merr.Status(
errors.Wrap(merr.WrapErrChannelNotAvailable(channel.GetChannelName()), channelErr.Error()))
resp.Shards = nil
return resp, nil
}
readableLeaders = filterDupLeaders(s.meta.ReplicaManager, readableLeaders)
ids := make([]int64, 0, len(leaders))
addrs := make([]string, 0, len(leaders))
for _, leader := range readableLeaders {
info := s.nodeMgr.Get(leader.ID)
if info != nil {
ids = append(ids, info.ID())
addrs = append(addrs, info.Addr())
}
}
// to avoid node down during GetShardLeaders
if len(ids) == 0 {
msg := fmt.Sprintf("channel %s is not available in any replica", channel.GetChannelName())
log.Warn(msg, zap.Error(channelErr))
resp.Status = merr.Status(
errors.Wrap(merr.WrapErrChannelNotAvailable(channel.GetChannelName()), channelErr.Error()))
resp.Shards = nil
return resp, nil
}
resp.Shards = append(resp.Shards, &querypb.ShardLeadersList{
ChannelName: channel.GetChannelName(),
NodeIds: ids,
NodeAddrs: addrs,
})
}
return resp, nil
}
func (s *Server) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) {

View File

@ -0,0 +1,195 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"context"
"fmt"
"go.uber.org/multierr"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/merr"
)
func CheckNodeAvailable(nodeID int64, info *session.NodeInfo) error {
if info == nil {
return merr.WrapErrNodeOffline(nodeID)
}
return nil
}
// In a replica, a shard is available, if and only if:
// 1. The leader is online
// 2. All QueryNodes in the distribution are online
// 3. The last heartbeat response time is within HeartbeatAvailableInterval for all QueryNodes(include leader) in the distribution
// 4. All segments of the shard in target should be in the distribution
func CheckLeaderAvailable(nodeMgr *session.NodeManager, leader *meta.LeaderView, currentTargets map[int64]*datapb.SegmentInfo) error {
log := log.Ctx(context.TODO()).
WithRateGroup("utils.CheckLeaderAvailable", 1, 60).
With(zap.Int64("leaderID", leader.ID))
info := nodeMgr.Get(leader.ID)
// Check whether leader is online
err := CheckNodeAvailable(leader.ID, info)
if err != nil {
log.Info("leader is not available", zap.Error(err))
return fmt.Errorf("leader not available: %w", err)
}
for id, version := range leader.Segments {
info := nodeMgr.Get(version.GetNodeID())
err = CheckNodeAvailable(version.GetNodeID(), info)
if err != nil {
log.Info("leader is not available due to QueryNode unavailable",
zap.Int64("segmentID", id),
zap.Error(err))
return err
}
}
// Check whether segments are fully loaded
for segmentID, info := range currentTargets {
if info.GetInsertChannel() != leader.Channel {
continue
}
_, exist := leader.Segments[segmentID]
if !exist {
log.RatedInfo(10, "leader is not available due to lack of segment", zap.Int64("segmentID", segmentID))
return merr.WrapErrSegmentLack(segmentID)
}
}
return nil
}
func GetShardLeaders(m *meta.Meta, targetMgr *meta.TargetManager, dist *meta.DistributionManager, nodeMgr *session.NodeManager, collectionID int64) ([]*querypb.ShardLeadersList, error) {
percentage := m.CollectionManager.CalculateLoadPercentage(collectionID)
if percentage < 0 {
err := merr.WrapErrCollectionNotLoaded(collectionID)
log.Warn("failed to GetShardLeaders", zap.Error(err))
return nil, err
}
collection := m.CollectionManager.GetCollection(collectionID)
if collection != nil && collection.GetStatus() == querypb.LoadStatus_Loaded {
// when collection is loaded, regard collection as readable, set percentage == 100
percentage = 100
}
if percentage < 100 {
err := merr.WrapErrCollectionNotFullyLoaded(collectionID)
msg := fmt.Sprintf("collection %v is not fully loaded", collectionID)
log.Warn(msg)
return nil, err
}
channels := targetMgr.GetDmChannelsByCollection(collectionID, meta.CurrentTarget)
if len(channels) == 0 {
msg := "loaded collection do not found any channel in target, may be in recovery"
err := merr.WrapErrCollectionOnRecovering(collectionID, msg)
log.Warn("failed to get channels", zap.Error(err))
return nil, err
}
ret := make([]*querypb.ShardLeadersList, 0)
currentTargets := targetMgr.GetSealedSegmentsByCollection(collectionID, meta.CurrentTarget)
for _, channel := range channels {
log := log.With(zap.String("channel", channel.GetChannelName()))
var channelErr error
leaders := dist.LeaderViewManager.GetByFilter(meta.WithChannelName2LeaderView(channel.GetChannelName()))
if len(leaders) == 0 {
channelErr = merr.WrapErrChannelLack(channel.GetChannelName(), "channel not subscribed")
}
readableLeaders := make(map[int64]*meta.LeaderView)
for _, leader := range leaders {
if err := CheckLeaderAvailable(nodeMgr, leader, currentTargets); err != nil {
multierr.AppendInto(&channelErr, err)
continue
}
readableLeaders[leader.ID] = leader
}
if len(readableLeaders) == 0 {
msg := fmt.Sprintf("channel %s is not available in any replica", channel.GetChannelName())
log.Warn(msg, zap.Error(channelErr))
err := merr.WrapErrChannelNotAvailable(channel.GetChannelName(), channelErr.Error())
return nil, err
}
readableLeaders = filterDupLeaders(m.ReplicaManager, readableLeaders)
ids := make([]int64, 0, len(leaders))
addrs := make([]string, 0, len(leaders))
for _, leader := range readableLeaders {
info := nodeMgr.Get(leader.ID)
if info != nil {
ids = append(ids, info.ID())
addrs = append(addrs, info.Addr())
}
}
// to avoid node down during GetShardLeaders
if len(ids) == 0 {
msg := fmt.Sprintf("channel %s is not available in any replica", channel.GetChannelName())
log.Warn(msg, zap.Error(channelErr))
err := merr.WrapErrChannelNotAvailable(channel.GetChannelName(), channelErr.Error())
return nil, err
}
ret = append(ret, &querypb.ShardLeadersList{
ChannelName: channel.GetChannelName(),
NodeIds: ids,
NodeAddrs: addrs,
})
}
return ret, nil
}
func filterDupLeaders(replicaManager *meta.ReplicaManager, leaders map[int64]*meta.LeaderView) map[int64]*meta.LeaderView {
type leaderID struct {
ReplicaID int64
Shard string
}
newLeaders := make(map[leaderID]*meta.LeaderView)
for _, view := range leaders {
replica := replicaManager.GetByCollectionAndNode(view.CollectionID, view.ID)
if replica == nil {
continue
}
id := leaderID{replica.GetID(), view.Channel}
if old, ok := newLeaders[id]; ok && old.Version > view.Version {
continue
}
newLeaders[id] = view
}
result := make(map[int64]*meta.LeaderView)
for _, v := range newLeaders {
result[v.ID] = v
}
return result
}
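
filterDupLeaders keeps, for each (replica, shard) pair, only the leader view with the highest version, so a delegator that is mid-handover between nodes is reported once. A reduced sketch of that keep-max-version dedup, with types trimmed to the fields the function actually uses:

package main

import "fmt"

type leaderView struct {
	nodeID    int64
	replicaID int64
	shard     string
	version   int64
}

// dedupLeaders keeps the highest-version view per (replica, shard).
func dedupLeaders(views []leaderView) map[int64]leaderView {
	type key struct {
		replicaID int64
		shard     string
	}
	best := map[key]leaderView{}
	for _, v := range views {
		k := key{v.replicaID, v.shard}
		if old, ok := best[k]; ok && old.version > v.version {
			continue
		}
		best[k] = v
	}
	out := make(map[int64]leaderView, len(best))
	for _, v := range best {
		out[v.nodeID] = v
	}
	return out
}

func main() {
	views := []leaderView{
		{nodeID: 1, replicaID: 10, shard: "ch-1", version: 1},
		{nodeID: 2, replicaID: 10, shard: "ch-1", version: 2}, // newer leader for the same shard
	}
	fmt.Println(dedupLeaders(views)) // only nodeID 2 remains
}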

View File

@ -14,7 +14,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"testing"

View File

@ -422,6 +422,49 @@ func (_c *MockProxyClientManager_InvalidateCredentialCache_Call) RunAndReturn(ru
return _c
}
// InvalidateShardLeaderCache provides a mock function with given fields: ctx, request
func (_m *MockProxyClientManager) InvalidateShardLeaderCache(ctx context.Context, request *proxypb.InvalidateShardLeaderCacheRequest) error {
ret := _m.Called(ctx, request)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateShardLeaderCacheRequest) error); ok {
r0 = rf(ctx, request)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockProxyClientManager_InvalidateShardLeaderCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InvalidateShardLeaderCache'
type MockProxyClientManager_InvalidateShardLeaderCache_Call struct {
*mock.Call
}
// InvalidateShardLeaderCache is a helper method to define mock.On call
// - ctx context.Context
// - request *proxypb.InvalidateShardLeaderCacheRequest
func (_e *MockProxyClientManager_Expecter) InvalidateShardLeaderCache(ctx interface{}, request interface{}) *MockProxyClientManager_InvalidateShardLeaderCache_Call {
return &MockProxyClientManager_InvalidateShardLeaderCache_Call{Call: _e.mock.On("InvalidateShardLeaderCache", ctx, request)}
}
func (_c *MockProxyClientManager_InvalidateShardLeaderCache_Call) Run(run func(ctx context.Context, request *proxypb.InvalidateShardLeaderCacheRequest)) *MockProxyClientManager_InvalidateShardLeaderCache_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*proxypb.InvalidateShardLeaderCacheRequest))
})
return _c
}
func (_c *MockProxyClientManager_InvalidateShardLeaderCache_Call) Return(_a0 error) *MockProxyClientManager_InvalidateShardLeaderCache_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockProxyClientManager_InvalidateShardLeaderCache_Call) RunAndReturn(run func(context.Context, *proxypb.InvalidateShardLeaderCacheRequest) error) *MockProxyClientManager_InvalidateShardLeaderCache_Call {
_c.Call.Return(run)
return _c
}
// RefreshPolicyInfoCache provides a mock function with given fields: ctx, req
func (_m *MockProxyClientManager) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest) error {
ret := _m.Called(ctx, req)

View File

@ -88,6 +88,7 @@ type ProxyClientManagerInterface interface {
GetProxyCount() int
InvalidateCollectionMetaCache(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest, opts ...ExpireCacheOpt) error
InvalidateShardLeaderCache(ctx context.Context, request *proxypb.InvalidateShardLeaderCacheRequest) error
InvalidateCredentialCache(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) error
UpdateCredentialCache(ctx context.Context, request *proxypb.UpdateCredCacheRequest) error
RefreshPolicyInfoCache(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest) error
@ -188,6 +189,11 @@ func (p *ProxyClientManager) InvalidateCollectionMetaCache(ctx context.Context,
log.Warn("InvalidateCollectionMetaCache failed due to proxy service not found", zap.Error(err)) log.Warn("InvalidateCollectionMetaCache failed due to proxy service not found", zap.Error(err))
return nil return nil
} }
if errors.Is(err, merr.ErrServiceUnimplemented) {
return nil
}
return fmt.Errorf("InvalidateCollectionMetaCache failed, proxyID = %d, err = %s", k, err) return fmt.Errorf("InvalidateCollectionMetaCache failed, proxyID = %d, err = %s", k, err)
} }
if sta.ErrorCode != commonpb.ErrorCode_Success { if sta.ErrorCode != commonpb.ErrorCode_Success {
@ -363,3 +369,31 @@ func (p *ProxyClientManager) GetComponentStates(ctx context.Context) (map[int64]
return states, nil
}
func (p *ProxyClientManager) InvalidateShardLeaderCache(ctx context.Context, request *proxypb.InvalidateShardLeaderCacheRequest) error {
if p.proxyClient.Len() == 0 {
log.Warn("proxy client is empty, InvalidateShardLeaderCache will not send to any client")
return nil
}
group := &errgroup.Group{}
p.proxyClient.Range(func(key int64, value types.ProxyClient) bool {
k, v := key, value
group.Go(func() error {
sta, err := v.InvalidateShardLeaderCache(ctx, request)
if err != nil {
if errors.Is(err, merr.ErrNodeNotFound) {
log.Warn("InvalidateShardLeaderCache failed due to proxy service not found", zap.Error(err))
return nil
}
return fmt.Errorf("InvalidateShardLeaderCache failed, proxyID = %d, err = %s", k, err)
}
if sta.ErrorCode != commonpb.ErrorCode_Success {
return fmt.Errorf("InvalidateShardLeaderCache failed, proxyID = %d, err = %s", k, sta.Reason)
}
return nil
})
return true
})
return group.Wait()
}
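
The proxy manager fans the invalidation out to every known proxy in parallel via errgroup, tolerating proxies that have just gone offline and surfacing any other failure. A minimal sketch of the same fan-out shape (the client interface below is a stand-in for types.ProxyClient, not the real one; it uses golang.org/x/sync/errgroup as the original does):

package main

import (
	"context"
	"errors"
	"fmt"

	"golang.org/x/sync/errgroup"
)

var errNodeNotFound = errors.New("node not found")

type proxyClient interface {
	InvalidateShardLeaderCache(ctx context.Context, collectionIDs []int64) error
}

// broadcast invalidates the cache on every proxy, ignoring proxies that already left.
func broadcast(ctx context.Context, proxies map[int64]proxyClient, collectionIDs []int64) error {
	group, ctx := errgroup.WithContext(ctx)
	for id, client := range proxies {
		id, client := id, client // capture loop variables for the goroutine
		group.Go(func() error {
			if err := client.InvalidateShardLeaderCache(ctx, collectionIDs); err != nil {
				if errors.Is(err, errNodeNotFound) {
					return nil // proxy went away; nothing to invalidate
				}
				return fmt.Errorf("invalidate shard leader cache on proxy %d: %w", id, err)
			}
			return nil
		})
	}
	return group.Wait()
}

type okProxy struct{}

func (okProxy) InvalidateShardLeaderCache(context.Context, []int64) error { return nil }

func main() {
	err := broadcast(context.Background(), map[int64]proxyClient{1: okProxy{}}, []int64{100})
	fmt.Println(err) // <nil>
}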

View File

@ -313,6 +313,14 @@ func TestProxyClientManager_RefreshPolicyInfoCache(t *testing.T) {
})
}
func TestProxyClientManager_TestGetProxyCount(t *testing.T) {
p1 := mocks.NewMockProxyClient(t)
pcm := NewProxyClientManager(DefaultProxyCreator)
pcm.proxyClient.Insert(TestProxyID, p1)
assert.Equal(t, pcm.GetProxyCount(), 1)
}
func TestProxyClientManager_GetProxyMetrics(t *testing.T) {
TestProxyID := int64(1001)
t.Run("empty proxy list", func(t *testing.T) {
@ -424,3 +432,34 @@ func TestProxyClientManager_GetComponentStates(t *testing.T) {
assert.NoError(t, err)
})
}
func TestProxyClientManager_InvalidateShardLeaderCache(t *testing.T) {
TestProxyID := int64(1001)
t.Run("empty proxy list", func(t *testing.T) {
ctx := context.Background()
pcm := NewProxyClientManager(DefaultProxyCreator)
err := pcm.InvalidateShardLeaderCache(ctx, &proxypb.InvalidateShardLeaderCacheRequest{})
assert.NoError(t, err)
})
t.Run("mock rpc error", func(t *testing.T) {
ctx := context.Background()
p1 := mocks.NewMockProxyClient(t)
p1.EXPECT().InvalidateShardLeaderCache(mock.Anything, mock.Anything).Return(nil, errors.New("error mock InvalidateCredentialCache"))
pcm := NewProxyClientManager(DefaultProxyCreator)
pcm.proxyClient.Insert(TestProxyID, p1)
err := pcm.InvalidateShardLeaderCache(ctx, &proxypb.InvalidateShardLeaderCacheRequest{})
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
ctx := context.Background()
p1 := mocks.NewMockProxyClient(t)
p1.EXPECT().InvalidateShardLeaderCache(mock.Anything, mock.Anything).Return(merr.Success(), nil)
pcm := NewProxyClientManager(DefaultProxyCreator)
pcm.proxyClient.Insert(TestProxyID, p1)
err := pcm.InvalidateShardLeaderCache(ctx, &proxypb.InvalidateShardLeaderCacheRequest{})
assert.NoError(t, err)
})
}