enhance: refine querycoord meta/catalog related interfaces to ensure that each method includes a ctx parameter (#37916)

issue: #35917 
This PR refines the querycoord meta-related interfaces so that each method includes a ctx parameter.

Signed-off-by: tinswzy <zhenyuan.wei@zilliz.com>
tinswzy 2024-11-25 11:14:34 +08:00 committed by GitHub
parent 0b9edb62a9
commit e76802f910
66 changed files with 2658 additions and 2344 deletions
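For readers skimming the diff: the change is mechanical but wide. Every QueryCoordCatalog (and Balance) method gains a leading ctx context.Context parameter, and callers pass their request context through instead of dropping it. Below is a minimal, hedged sketch of the pattern; the catalog subset and the releaseCollection helper are illustrative only and are not code from this PR.

package example

import (
	"context"
	"fmt"
)

// catalog is a hypothetical subset of the QueryCoordCatalog interface after this PR:
// every method now takes a context.Context as its first parameter.
type catalog interface {
	ReleaseCollection(ctx context.Context, collection int64) error
	ReleaseReplicas(ctx context.Context, collectionID int64) error
}

// releaseCollection is an illustrative caller (not from this PR): rather than discarding
// the request context, it threads ctx into the catalog so cancellation and tracing
// metadata can propagate down to the metastore layer.
func releaseCollection(ctx context.Context, c catalog, collectionID int64) error {
	if err := c.ReleaseCollection(ctx, collectionID); err != nil {
		return fmt.Errorf("release collection %d: %w", collectionID, err)
	}
	return c.ReleaseReplicas(ctx, collectionID)
}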

View File

@ -182,23 +182,23 @@ type DataCoordCatalog interface {
}
type QueryCoordCatalog interface {
SaveCollection(collection *querypb.CollectionLoadInfo, partitions ...*querypb.PartitionLoadInfo) error
SavePartition(info ...*querypb.PartitionLoadInfo) error
SaveReplica(replicas ...*querypb.Replica) error
GetCollections() ([]*querypb.CollectionLoadInfo, error)
GetPartitions() (map[int64][]*querypb.PartitionLoadInfo, error)
GetReplicas() ([]*querypb.Replica, error)
ReleaseCollection(collection int64) error
ReleasePartition(collection int64, partitions ...int64) error
ReleaseReplicas(collectionID int64) error
ReleaseReplica(collection int64, replicas ...int64) error
SaveResourceGroup(rgs ...*querypb.ResourceGroup) error
RemoveResourceGroup(rgName string) error
GetResourceGroups() ([]*querypb.ResourceGroup, error)
SaveCollection(ctx context.Context, collection *querypb.CollectionLoadInfo, partitions ...*querypb.PartitionLoadInfo) error
SavePartition(ctx context.Context, info ...*querypb.PartitionLoadInfo) error
SaveReplica(ctx context.Context, replicas ...*querypb.Replica) error
GetCollections(ctx context.Context) ([]*querypb.CollectionLoadInfo, error)
GetPartitions(ctx context.Context) (map[int64][]*querypb.PartitionLoadInfo, error)
GetReplicas(ctx context.Context) ([]*querypb.Replica, error)
ReleaseCollection(ctx context.Context, collection int64) error
ReleasePartition(ctx context.Context, collection int64, partitions ...int64) error
ReleaseReplicas(ctx context.Context, collectionID int64) error
ReleaseReplica(ctx context.Context, collection int64, replicas ...int64) error
SaveResourceGroup(ctx context.Context, rgs ...*querypb.ResourceGroup) error
RemoveResourceGroup(ctx context.Context, rgName string) error
GetResourceGroups(ctx context.Context) ([]*querypb.ResourceGroup, error)
SaveCollectionTargets(target ...*querypb.CollectionTarget) error
RemoveCollectionTarget(collectionID int64) error
GetCollectionTargets() (map[int64]*querypb.CollectionTarget, error)
SaveCollectionTargets(ctx context.Context, target ...*querypb.CollectionTarget) error
RemoveCollectionTarget(ctx context.Context, collectionID int64) error
GetCollectionTargets(ctx context.Context) (map[int64]*querypb.CollectionTarget, error)
}
// StreamingCoordCataLog is the interface for streamingcoord catalog

View File

@ -2,6 +2,7 @@ package querycoord
import (
"bytes"
"context"
"fmt"
"io"
@ -42,7 +43,7 @@ func NewCatalog(cli kv.MetaKv) Catalog {
}
}
func (s Catalog) SaveCollection(collection *querypb.CollectionLoadInfo, partitions ...*querypb.PartitionLoadInfo) error {
func (s Catalog) SaveCollection(ctx context.Context, collection *querypb.CollectionLoadInfo, partitions ...*querypb.PartitionLoadInfo) error {
k := EncodeCollectionLoadInfoKey(collection.GetCollectionID())
v, err := proto.Marshal(collection)
if err != nil {
@ -52,10 +53,10 @@ func (s Catalog) SaveCollection(collection *querypb.CollectionLoadInfo, partitio
if err != nil {
return err
}
return s.SavePartition(partitions...)
return s.SavePartition(ctx, partitions...)
}
func (s Catalog) SavePartition(info ...*querypb.PartitionLoadInfo) error {
func (s Catalog) SavePartition(ctx context.Context, info ...*querypb.PartitionLoadInfo) error {
for _, partition := range info {
k := EncodePartitionLoadInfoKey(partition.GetCollectionID(), partition.GetPartitionID())
v, err := proto.Marshal(partition)
@ -70,7 +71,7 @@ func (s Catalog) SavePartition(info ...*querypb.PartitionLoadInfo) error {
return nil
}
func (s Catalog) SaveReplica(replicas ...*querypb.Replica) error {
func (s Catalog) SaveReplica(ctx context.Context, replicas ...*querypb.Replica) error {
kvs := make(map[string]string)
for _, replica := range replicas {
key := encodeReplicaKey(replica.GetCollectionID(), replica.GetID())
@ -83,7 +84,7 @@ func (s Catalog) SaveReplica(replicas ...*querypb.Replica) error {
return s.cli.MultiSave(kvs)
}
func (s Catalog) SaveResourceGroup(rgs ...*querypb.ResourceGroup) error {
func (s Catalog) SaveResourceGroup(ctx context.Context, rgs ...*querypb.ResourceGroup) error {
ret := make(map[string]string)
for _, rg := range rgs {
key := encodeResourceGroupKey(rg.GetName())
@ -98,12 +99,12 @@ func (s Catalog) SaveResourceGroup(rgs ...*querypb.ResourceGroup) error {
return s.cli.MultiSave(ret)
}
func (s Catalog) RemoveResourceGroup(rgName string) error {
func (s Catalog) RemoveResourceGroup(ctx context.Context, rgName string) error {
key := encodeResourceGroupKey(rgName)
return s.cli.Remove(key)
}
func (s Catalog) GetCollections() ([]*querypb.CollectionLoadInfo, error) {
func (s Catalog) GetCollections(ctx context.Context) ([]*querypb.CollectionLoadInfo, error) {
_, values, err := s.cli.LoadWithPrefix(CollectionLoadInfoPrefix)
if err != nil {
return nil, err
@ -120,7 +121,7 @@ func (s Catalog) GetCollections() ([]*querypb.CollectionLoadInfo, error) {
return ret, nil
}
func (s Catalog) GetPartitions() (map[int64][]*querypb.PartitionLoadInfo, error) {
func (s Catalog) GetPartitions(ctx context.Context) (map[int64][]*querypb.PartitionLoadInfo, error) {
_, values, err := s.cli.LoadWithPrefix(PartitionLoadInfoPrefix)
if err != nil {
return nil, err
@ -137,7 +138,7 @@ func (s Catalog) GetPartitions() (map[int64][]*querypb.PartitionLoadInfo, error)
return ret, nil
}
func (s Catalog) GetReplicas() ([]*querypb.Replica, error) {
func (s Catalog) GetReplicas(ctx context.Context) ([]*querypb.Replica, error) {
_, values, err := s.cli.LoadWithPrefix(ReplicaPrefix)
if err != nil {
return nil, err
@ -151,7 +152,7 @@ func (s Catalog) GetReplicas() ([]*querypb.Replica, error) {
ret = append(ret, &info)
}
replicasV1, err := s.getReplicasFromV1()
replicasV1, err := s.getReplicasFromV1(ctx)
if err != nil {
return nil, err
}
@ -160,7 +161,7 @@ func (s Catalog) GetReplicas() ([]*querypb.Replica, error) {
return ret, nil
}
func (s Catalog) getReplicasFromV1() ([]*querypb.Replica, error) {
func (s Catalog) getReplicasFromV1(ctx context.Context) ([]*querypb.Replica, error) {
_, replicaValues, err := s.cli.LoadWithPrefix(ReplicaMetaPrefixV1)
if err != nil {
return nil, err
@ -183,7 +184,7 @@ func (s Catalog) getReplicasFromV1() ([]*querypb.Replica, error) {
return ret, nil
}
func (s Catalog) GetResourceGroups() ([]*querypb.ResourceGroup, error) {
func (s Catalog) GetResourceGroups(ctx context.Context) ([]*querypb.ResourceGroup, error) {
_, rgs, err := s.cli.LoadWithPrefix(ResourceGroupPrefix)
if err != nil {
return nil, err
@ -202,7 +203,7 @@ func (s Catalog) GetResourceGroups() ([]*querypb.ResourceGroup, error) {
return ret, nil
}
func (s Catalog) ReleaseCollection(collection int64) error {
func (s Catalog) ReleaseCollection(ctx context.Context, collection int64) error {
// remove collection and obtained partitions
collectionKey := EncodeCollectionLoadInfoKey(collection)
err := s.cli.Remove(collectionKey)
@ -213,7 +214,7 @@ func (s Catalog) ReleaseCollection(collection int64) error {
return s.cli.RemoveWithPrefix(partitionsPrefix)
}
func (s Catalog) ReleasePartition(collection int64, partitions ...int64) error {
func (s Catalog) ReleasePartition(ctx context.Context, collection int64, partitions ...int64) error {
keys := lo.Map(partitions, func(partition int64, _ int) string {
return EncodePartitionLoadInfoKey(collection, partition)
})
@ -235,12 +236,12 @@ func (s Catalog) ReleasePartition(collection int64, partitions ...int64) error {
return s.cli.MultiRemove(keys)
}
func (s Catalog) ReleaseReplicas(collectionID int64) error {
func (s Catalog) ReleaseReplicas(ctx context.Context, collectionID int64) error {
key := encodeCollectionReplicaKey(collectionID)
return s.cli.RemoveWithPrefix(key)
}
func (s Catalog) ReleaseReplica(collection int64, replicas ...int64) error {
func (s Catalog) ReleaseReplica(ctx context.Context, collection int64, replicas ...int64) error {
keys := lo.Map(replicas, func(replica int64, _ int) string {
return encodeReplicaKey(collection, replica)
})
@ -262,7 +263,7 @@ func (s Catalog) ReleaseReplica(collection int64, replicas ...int64) error {
return s.cli.MultiRemove(keys)
}
func (s Catalog) SaveCollectionTargets(targets ...*querypb.CollectionTarget) error {
func (s Catalog) SaveCollectionTargets(ctx context.Context, targets ...*querypb.CollectionTarget) error {
kvs := make(map[string]string)
for _, target := range targets {
k := encodeCollectionTargetKey(target.GetCollectionID())
@ -283,12 +284,12 @@ func (s Catalog) SaveCollectionTargets(targets ...*querypb.CollectionTarget) err
return nil
}
func (s Catalog) RemoveCollectionTarget(collectionID int64) error {
func (s Catalog) RemoveCollectionTarget(ctx context.Context, collectionID int64) error {
k := encodeCollectionTargetKey(collectionID)
return s.cli.Remove(k)
}
func (s Catalog) GetCollectionTargets() (map[int64]*querypb.CollectionTarget, error) {
func (s Catalog) GetCollectionTargets(ctx context.Context) (map[int64]*querypb.CollectionTarget, error) {
keys, values, err := s.cli.LoadWithPrefix(CollectionTargetPrefix)
if err != nil {
return nil, err

View File

@ -1,6 +1,7 @@
package querycoord
import (
"context"
"sort"
"testing"
@ -50,53 +51,55 @@ func (suite *CatalogTestSuite) TearDownTest() {
}
func (suite *CatalogTestSuite) TestCollection() {
suite.catalog.SaveCollection(&querypb.CollectionLoadInfo{
ctx := context.Background()
suite.catalog.SaveCollection(ctx, &querypb.CollectionLoadInfo{
CollectionID: 1,
})
suite.catalog.SaveCollection(&querypb.CollectionLoadInfo{
suite.catalog.SaveCollection(ctx, &querypb.CollectionLoadInfo{
CollectionID: 2,
})
suite.catalog.SaveCollection(&querypb.CollectionLoadInfo{
suite.catalog.SaveCollection(ctx, &querypb.CollectionLoadInfo{
CollectionID: 3,
})
suite.catalog.ReleaseCollection(1)
suite.catalog.ReleaseCollection(2)
suite.catalog.ReleaseCollection(ctx, 1)
suite.catalog.ReleaseCollection(ctx, 2)
collections, err := suite.catalog.GetCollections()
collections, err := suite.catalog.GetCollections(ctx)
suite.NoError(err)
suite.Len(collections, 1)
}
func (suite *CatalogTestSuite) TestCollectionWithPartition() {
suite.catalog.SaveCollection(&querypb.CollectionLoadInfo{
ctx := context.Background()
suite.catalog.SaveCollection(ctx, &querypb.CollectionLoadInfo{
CollectionID: 1,
})
suite.catalog.SaveCollection(&querypb.CollectionLoadInfo{
suite.catalog.SaveCollection(ctx, &querypb.CollectionLoadInfo{
CollectionID: 2,
}, &querypb.PartitionLoadInfo{
CollectionID: 2,
PartitionID: 102,
})
suite.catalog.SaveCollection(&querypb.CollectionLoadInfo{
suite.catalog.SaveCollection(ctx, &querypb.CollectionLoadInfo{
CollectionID: 3,
}, &querypb.PartitionLoadInfo{
CollectionID: 3,
PartitionID: 103,
})
suite.catalog.ReleaseCollection(1)
suite.catalog.ReleaseCollection(2)
suite.catalog.ReleaseCollection(ctx, 1)
suite.catalog.ReleaseCollection(ctx, 2)
collections, err := suite.catalog.GetCollections()
collections, err := suite.catalog.GetCollections(ctx)
suite.NoError(err)
suite.Len(collections, 1)
suite.Equal(int64(3), collections[0].GetCollectionID())
partitions, err := suite.catalog.GetPartitions()
partitions, err := suite.catalog.GetPartitions(ctx)
suite.NoError(err)
suite.Len(partitions, 1)
suite.Len(partitions[int64(3)], 1)
@ -104,88 +107,92 @@ func (suite *CatalogTestSuite) TestCollectionWithPartition() {
}
func (suite *CatalogTestSuite) TestPartition() {
suite.catalog.SavePartition(&querypb.PartitionLoadInfo{
ctx := context.Background()
suite.catalog.SavePartition(ctx, &querypb.PartitionLoadInfo{
PartitionID: 1,
})
suite.catalog.SavePartition(&querypb.PartitionLoadInfo{
suite.catalog.SavePartition(ctx, &querypb.PartitionLoadInfo{
PartitionID: 2,
})
suite.catalog.SavePartition(&querypb.PartitionLoadInfo{
suite.catalog.SavePartition(ctx, &querypb.PartitionLoadInfo{
PartitionID: 3,
})
suite.catalog.ReleasePartition(1)
suite.catalog.ReleasePartition(2)
suite.catalog.ReleasePartition(ctx, 1)
suite.catalog.ReleasePartition(ctx, 2)
partitions, err := suite.catalog.GetPartitions()
partitions, err := suite.catalog.GetPartitions(ctx)
suite.NoError(err)
suite.Len(partitions, 1)
}
func (suite *CatalogTestSuite) TestReleaseManyPartitions() {
ctx := context.Background()
partitionIDs := make([]int64, 0)
for i := 1; i <= 150; i++ {
suite.catalog.SavePartition(&querypb.PartitionLoadInfo{
suite.catalog.SavePartition(ctx, &querypb.PartitionLoadInfo{
CollectionID: 1,
PartitionID: int64(i),
})
partitionIDs = append(partitionIDs, int64(i))
}
err := suite.catalog.ReleasePartition(1, partitionIDs...)
err := suite.catalog.ReleasePartition(ctx, 1, partitionIDs...)
suite.NoError(err)
partitions, err := suite.catalog.GetPartitions()
partitions, err := suite.catalog.GetPartitions(ctx)
suite.NoError(err)
suite.Len(partitions, 0)
}
func (suite *CatalogTestSuite) TestReplica() {
suite.catalog.SaveReplica(&querypb.Replica{
ctx := context.Background()
suite.catalog.SaveReplica(ctx, &querypb.Replica{
CollectionID: 1,
ID: 1,
})
suite.catalog.SaveReplica(&querypb.Replica{
suite.catalog.SaveReplica(ctx, &querypb.Replica{
CollectionID: 1,
ID: 2,
})
suite.catalog.SaveReplica(&querypb.Replica{
suite.catalog.SaveReplica(ctx, &querypb.Replica{
CollectionID: 1,
ID: 3,
})
suite.catalog.ReleaseReplica(1, 1)
suite.catalog.ReleaseReplica(1, 2)
suite.catalog.ReleaseReplica(ctx, 1, 1)
suite.catalog.ReleaseReplica(ctx, 1, 2)
replicas, err := suite.catalog.GetReplicas()
replicas, err := suite.catalog.GetReplicas(ctx)
suite.NoError(err)
suite.Len(replicas, 1)
}
func (suite *CatalogTestSuite) TestResourceGroup() {
suite.catalog.SaveResourceGroup(&querypb.ResourceGroup{
ctx := context.Background()
suite.catalog.SaveResourceGroup(ctx, &querypb.ResourceGroup{
Name: "rg1",
Capacity: 3,
Nodes: []int64{1, 2, 3},
})
suite.catalog.SaveResourceGroup(&querypb.ResourceGroup{
suite.catalog.SaveResourceGroup(ctx, &querypb.ResourceGroup{
Name: "rg2",
Capacity: 3,
Nodes: []int64{4, 5},
})
suite.catalog.SaveResourceGroup(&querypb.ResourceGroup{
suite.catalog.SaveResourceGroup(ctx, &querypb.ResourceGroup{
Name: "rg3",
Capacity: 0,
Nodes: []int64{},
})
suite.catalog.RemoveResourceGroup("rg3")
suite.catalog.RemoveResourceGroup(ctx, "rg3")
groups, err := suite.catalog.GetResourceGroups()
groups, err := suite.catalog.GetResourceGroups(ctx)
suite.NoError(err)
suite.Len(groups, 2)
@ -203,7 +210,8 @@ func (suite *CatalogTestSuite) TestResourceGroup() {
}
func (suite *CatalogTestSuite) TestCollectionTarget() {
suite.catalog.SaveCollectionTargets(&querypb.CollectionTarget{
ctx := context.Background()
suite.catalog.SaveCollectionTargets(ctx, &querypb.CollectionTarget{
CollectionID: 1,
Version: 1,
},
@ -219,9 +227,9 @@ func (suite *CatalogTestSuite) TestCollectionTarget() {
CollectionID: 1,
Version: 4,
})
suite.catalog.RemoveCollectionTarget(2)
suite.catalog.RemoveCollectionTarget(ctx, 2)
targets, err := suite.catalog.GetCollectionTargets()
targets, err := suite.catalog.GetCollectionTargets(ctx)
suite.NoError(err)
suite.Len(targets, 2)
suite.Equal(int64(4), targets[1].Version)
@ -234,14 +242,14 @@ func (suite *CatalogTestSuite) TestCollectionTarget() {
mockStore.EXPECT().LoadWithPrefix(mock.Anything).Return(nil, nil, mockErr)
suite.catalog.cli = mockStore
err = suite.catalog.SaveCollectionTargets(&querypb.CollectionTarget{})
err = suite.catalog.SaveCollectionTargets(ctx, &querypb.CollectionTarget{})
suite.ErrorIs(err, mockErr)
_, err = suite.catalog.GetCollectionTargets()
_, err = suite.catalog.GetCollectionTargets(ctx)
suite.ErrorIs(err, mockErr)
// test invalid message
err = suite.catalog.SaveCollectionTargets(nil)
err = suite.catalog.SaveCollectionTargets(ctx)
suite.Error(err)
}

View File

@ -1,10 +1,13 @@
// Code generated by mockery v2.32.4. DO NOT EDIT.
// Code generated by mockery v2.46.0. DO NOT EDIT.
package mocks
import (
querypb "github.com/milvus-io/milvus/internal/proto/querypb"
context "context"
mock "github.com/stretchr/testify/mock"
querypb "github.com/milvus-io/milvus/internal/proto/querypb"
)
// QueryCoordCatalog is an autogenerated mock type for the QueryCoordCatalog type
@ -20,25 +23,29 @@ func (_m *QueryCoordCatalog) EXPECT() *QueryCoordCatalog_Expecter {
return &QueryCoordCatalog_Expecter{mock: &_m.Mock}
}
// GetCollectionTargets provides a mock function with given fields:
func (_m *QueryCoordCatalog) GetCollectionTargets() (map[int64]*querypb.CollectionTarget, error) {
ret := _m.Called()
// GetCollectionTargets provides a mock function with given fields: ctx
func (_m *QueryCoordCatalog) GetCollectionTargets(ctx context.Context) (map[int64]*querypb.CollectionTarget, error) {
ret := _m.Called(ctx)
if len(ret) == 0 {
panic("no return value specified for GetCollectionTargets")
}
var r0 map[int64]*querypb.CollectionTarget
var r1 error
if rf, ok := ret.Get(0).(func() (map[int64]*querypb.CollectionTarget, error)); ok {
return rf()
if rf, ok := ret.Get(0).(func(context.Context) (map[int64]*querypb.CollectionTarget, error)); ok {
return rf(ctx)
}
if rf, ok := ret.Get(0).(func() map[int64]*querypb.CollectionTarget); ok {
r0 = rf()
if rf, ok := ret.Get(0).(func(context.Context) map[int64]*querypb.CollectionTarget); ok {
r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[int64]*querypb.CollectionTarget)
}
}
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
if rf, ok := ret.Get(1).(func(context.Context) error); ok {
r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@ -52,13 +59,14 @@ type QueryCoordCatalog_GetCollectionTargets_Call struct {
}
// GetCollectionTargets is a helper method to define mock.On call
func (_e *QueryCoordCatalog_Expecter) GetCollectionTargets() *QueryCoordCatalog_GetCollectionTargets_Call {
return &QueryCoordCatalog_GetCollectionTargets_Call{Call: _e.mock.On("GetCollectionTargets")}
// - ctx context.Context
func (_e *QueryCoordCatalog_Expecter) GetCollectionTargets(ctx interface{}) *QueryCoordCatalog_GetCollectionTargets_Call {
return &QueryCoordCatalog_GetCollectionTargets_Call{Call: _e.mock.On("GetCollectionTargets", ctx)}
}
func (_c *QueryCoordCatalog_GetCollectionTargets_Call) Run(run func()) *QueryCoordCatalog_GetCollectionTargets_Call {
func (_c *QueryCoordCatalog_GetCollectionTargets_Call) Run(run func(ctx context.Context)) *QueryCoordCatalog_GetCollectionTargets_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
run(args[0].(context.Context))
})
return _c
}
@ -68,30 +76,34 @@ func (_c *QueryCoordCatalog_GetCollectionTargets_Call) Return(_a0 map[int64]*que
return _c
}
func (_c *QueryCoordCatalog_GetCollectionTargets_Call) RunAndReturn(run func() (map[int64]*querypb.CollectionTarget, error)) *QueryCoordCatalog_GetCollectionTargets_Call {
func (_c *QueryCoordCatalog_GetCollectionTargets_Call) RunAndReturn(run func(context.Context) (map[int64]*querypb.CollectionTarget, error)) *QueryCoordCatalog_GetCollectionTargets_Call {
_c.Call.Return(run)
return _c
}
// GetCollections provides a mock function with given fields:
func (_m *QueryCoordCatalog) GetCollections() ([]*querypb.CollectionLoadInfo, error) {
ret := _m.Called()
// GetCollections provides a mock function with given fields: ctx
func (_m *QueryCoordCatalog) GetCollections(ctx context.Context) ([]*querypb.CollectionLoadInfo, error) {
ret := _m.Called(ctx)
if len(ret) == 0 {
panic("no return value specified for GetCollections")
}
var r0 []*querypb.CollectionLoadInfo
var r1 error
if rf, ok := ret.Get(0).(func() ([]*querypb.CollectionLoadInfo, error)); ok {
return rf()
if rf, ok := ret.Get(0).(func(context.Context) ([]*querypb.CollectionLoadInfo, error)); ok {
return rf(ctx)
}
if rf, ok := ret.Get(0).(func() []*querypb.CollectionLoadInfo); ok {
r0 = rf()
if rf, ok := ret.Get(0).(func(context.Context) []*querypb.CollectionLoadInfo); ok {
r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*querypb.CollectionLoadInfo)
}
}
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
if rf, ok := ret.Get(1).(func(context.Context) error); ok {
r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@ -105,13 +117,14 @@ type QueryCoordCatalog_GetCollections_Call struct {
}
// GetCollections is a helper method to define mock.On call
func (_e *QueryCoordCatalog_Expecter) GetCollections() *QueryCoordCatalog_GetCollections_Call {
return &QueryCoordCatalog_GetCollections_Call{Call: _e.mock.On("GetCollections")}
// - ctx context.Context
func (_e *QueryCoordCatalog_Expecter) GetCollections(ctx interface{}) *QueryCoordCatalog_GetCollections_Call {
return &QueryCoordCatalog_GetCollections_Call{Call: _e.mock.On("GetCollections", ctx)}
}
func (_c *QueryCoordCatalog_GetCollections_Call) Run(run func()) *QueryCoordCatalog_GetCollections_Call {
func (_c *QueryCoordCatalog_GetCollections_Call) Run(run func(ctx context.Context)) *QueryCoordCatalog_GetCollections_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
run(args[0].(context.Context))
})
return _c
}
@ -121,30 +134,34 @@ func (_c *QueryCoordCatalog_GetCollections_Call) Return(_a0 []*querypb.Collectio
return _c
}
func (_c *QueryCoordCatalog_GetCollections_Call) RunAndReturn(run func() ([]*querypb.CollectionLoadInfo, error)) *QueryCoordCatalog_GetCollections_Call {
func (_c *QueryCoordCatalog_GetCollections_Call) RunAndReturn(run func(context.Context) ([]*querypb.CollectionLoadInfo, error)) *QueryCoordCatalog_GetCollections_Call {
_c.Call.Return(run)
return _c
}
// GetPartitions provides a mock function with given fields:
func (_m *QueryCoordCatalog) GetPartitions() (map[int64][]*querypb.PartitionLoadInfo, error) {
ret := _m.Called()
// GetPartitions provides a mock function with given fields: ctx
func (_m *QueryCoordCatalog) GetPartitions(ctx context.Context) (map[int64][]*querypb.PartitionLoadInfo, error) {
ret := _m.Called(ctx)
if len(ret) == 0 {
panic("no return value specified for GetPartitions")
}
var r0 map[int64][]*querypb.PartitionLoadInfo
var r1 error
if rf, ok := ret.Get(0).(func() (map[int64][]*querypb.PartitionLoadInfo, error)); ok {
return rf()
if rf, ok := ret.Get(0).(func(context.Context) (map[int64][]*querypb.PartitionLoadInfo, error)); ok {
return rf(ctx)
}
if rf, ok := ret.Get(0).(func() map[int64][]*querypb.PartitionLoadInfo); ok {
r0 = rf()
if rf, ok := ret.Get(0).(func(context.Context) map[int64][]*querypb.PartitionLoadInfo); ok {
r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[int64][]*querypb.PartitionLoadInfo)
}
}
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
if rf, ok := ret.Get(1).(func(context.Context) error); ok {
r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@ -158,13 +175,14 @@ type QueryCoordCatalog_GetPartitions_Call struct {
}
// GetPartitions is a helper method to define mock.On call
func (_e *QueryCoordCatalog_Expecter) GetPartitions() *QueryCoordCatalog_GetPartitions_Call {
return &QueryCoordCatalog_GetPartitions_Call{Call: _e.mock.On("GetPartitions")}
// - ctx context.Context
func (_e *QueryCoordCatalog_Expecter) GetPartitions(ctx interface{}) *QueryCoordCatalog_GetPartitions_Call {
return &QueryCoordCatalog_GetPartitions_Call{Call: _e.mock.On("GetPartitions", ctx)}
}
func (_c *QueryCoordCatalog_GetPartitions_Call) Run(run func()) *QueryCoordCatalog_GetPartitions_Call {
func (_c *QueryCoordCatalog_GetPartitions_Call) Run(run func(ctx context.Context)) *QueryCoordCatalog_GetPartitions_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
run(args[0].(context.Context))
})
return _c
}
@ -174,30 +192,34 @@ func (_c *QueryCoordCatalog_GetPartitions_Call) Return(_a0 map[int64][]*querypb.
return _c
}
func (_c *QueryCoordCatalog_GetPartitions_Call) RunAndReturn(run func() (map[int64][]*querypb.PartitionLoadInfo, error)) *QueryCoordCatalog_GetPartitions_Call {
func (_c *QueryCoordCatalog_GetPartitions_Call) RunAndReturn(run func(context.Context) (map[int64][]*querypb.PartitionLoadInfo, error)) *QueryCoordCatalog_GetPartitions_Call {
_c.Call.Return(run)
return _c
}
// GetReplicas provides a mock function with given fields:
func (_m *QueryCoordCatalog) GetReplicas() ([]*querypb.Replica, error) {
ret := _m.Called()
// GetReplicas provides a mock function with given fields: ctx
func (_m *QueryCoordCatalog) GetReplicas(ctx context.Context) ([]*querypb.Replica, error) {
ret := _m.Called(ctx)
if len(ret) == 0 {
panic("no return value specified for GetReplicas")
}
var r0 []*querypb.Replica
var r1 error
if rf, ok := ret.Get(0).(func() ([]*querypb.Replica, error)); ok {
return rf()
if rf, ok := ret.Get(0).(func(context.Context) ([]*querypb.Replica, error)); ok {
return rf(ctx)
}
if rf, ok := ret.Get(0).(func() []*querypb.Replica); ok {
r0 = rf()
if rf, ok := ret.Get(0).(func(context.Context) []*querypb.Replica); ok {
r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*querypb.Replica)
}
}
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
if rf, ok := ret.Get(1).(func(context.Context) error); ok {
r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@ -211,13 +233,14 @@ type QueryCoordCatalog_GetReplicas_Call struct {
}
// GetReplicas is a helper method to define mock.On call
func (_e *QueryCoordCatalog_Expecter) GetReplicas() *QueryCoordCatalog_GetReplicas_Call {
return &QueryCoordCatalog_GetReplicas_Call{Call: _e.mock.On("GetReplicas")}
// - ctx context.Context
func (_e *QueryCoordCatalog_Expecter) GetReplicas(ctx interface{}) *QueryCoordCatalog_GetReplicas_Call {
return &QueryCoordCatalog_GetReplicas_Call{Call: _e.mock.On("GetReplicas", ctx)}
}
func (_c *QueryCoordCatalog_GetReplicas_Call) Run(run func()) *QueryCoordCatalog_GetReplicas_Call {
func (_c *QueryCoordCatalog_GetReplicas_Call) Run(run func(ctx context.Context)) *QueryCoordCatalog_GetReplicas_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
run(args[0].(context.Context))
})
return _c
}
@ -227,30 +250,34 @@ func (_c *QueryCoordCatalog_GetReplicas_Call) Return(_a0 []*querypb.Replica, _a1
return _c
}
func (_c *QueryCoordCatalog_GetReplicas_Call) RunAndReturn(run func() ([]*querypb.Replica, error)) *QueryCoordCatalog_GetReplicas_Call {
func (_c *QueryCoordCatalog_GetReplicas_Call) RunAndReturn(run func(context.Context) ([]*querypb.Replica, error)) *QueryCoordCatalog_GetReplicas_Call {
_c.Call.Return(run)
return _c
}
// GetResourceGroups provides a mock function with given fields:
func (_m *QueryCoordCatalog) GetResourceGroups() ([]*querypb.ResourceGroup, error) {
ret := _m.Called()
// GetResourceGroups provides a mock function with given fields: ctx
func (_m *QueryCoordCatalog) GetResourceGroups(ctx context.Context) ([]*querypb.ResourceGroup, error) {
ret := _m.Called(ctx)
if len(ret) == 0 {
panic("no return value specified for GetResourceGroups")
}
var r0 []*querypb.ResourceGroup
var r1 error
if rf, ok := ret.Get(0).(func() ([]*querypb.ResourceGroup, error)); ok {
return rf()
if rf, ok := ret.Get(0).(func(context.Context) ([]*querypb.ResourceGroup, error)); ok {
return rf(ctx)
}
if rf, ok := ret.Get(0).(func() []*querypb.ResourceGroup); ok {
r0 = rf()
if rf, ok := ret.Get(0).(func(context.Context) []*querypb.ResourceGroup); ok {
r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*querypb.ResourceGroup)
}
}
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
if rf, ok := ret.Get(1).(func(context.Context) error); ok {
r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@ -264,13 +291,14 @@ type QueryCoordCatalog_GetResourceGroups_Call struct {
}
// GetResourceGroups is a helper method to define mock.On call
func (_e *QueryCoordCatalog_Expecter) GetResourceGroups() *QueryCoordCatalog_GetResourceGroups_Call {
return &QueryCoordCatalog_GetResourceGroups_Call{Call: _e.mock.On("GetResourceGroups")}
// - ctx context.Context
func (_e *QueryCoordCatalog_Expecter) GetResourceGroups(ctx interface{}) *QueryCoordCatalog_GetResourceGroups_Call {
return &QueryCoordCatalog_GetResourceGroups_Call{Call: _e.mock.On("GetResourceGroups", ctx)}
}
func (_c *QueryCoordCatalog_GetResourceGroups_Call) Run(run func()) *QueryCoordCatalog_GetResourceGroups_Call {
func (_c *QueryCoordCatalog_GetResourceGroups_Call) Run(run func(ctx context.Context)) *QueryCoordCatalog_GetResourceGroups_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
run(args[0].(context.Context))
})
return _c
}
@ -280,18 +308,22 @@ func (_c *QueryCoordCatalog_GetResourceGroups_Call) Return(_a0 []*querypb.Resour
return _c
}
func (_c *QueryCoordCatalog_GetResourceGroups_Call) RunAndReturn(run func() ([]*querypb.ResourceGroup, error)) *QueryCoordCatalog_GetResourceGroups_Call {
func (_c *QueryCoordCatalog_GetResourceGroups_Call) RunAndReturn(run func(context.Context) ([]*querypb.ResourceGroup, error)) *QueryCoordCatalog_GetResourceGroups_Call {
_c.Call.Return(run)
return _c
}
// ReleaseCollection provides a mock function with given fields: collection
func (_m *QueryCoordCatalog) ReleaseCollection(collection int64) error {
ret := _m.Called(collection)
// ReleaseCollection provides a mock function with given fields: ctx, collection
func (_m *QueryCoordCatalog) ReleaseCollection(ctx context.Context, collection int64) error {
ret := _m.Called(ctx, collection)
if len(ret) == 0 {
panic("no return value specified for ReleaseCollection")
}
var r0 error
if rf, ok := ret.Get(0).(func(int64) error); ok {
r0 = rf(collection)
if rf, ok := ret.Get(0).(func(context.Context, int64) error); ok {
r0 = rf(ctx, collection)
} else {
r0 = ret.Error(0)
}
@ -305,14 +337,15 @@ type QueryCoordCatalog_ReleaseCollection_Call struct {
}
// ReleaseCollection is a helper method to define mock.On call
// - ctx context.Context
// - collection int64
func (_e *QueryCoordCatalog_Expecter) ReleaseCollection(collection interface{}) *QueryCoordCatalog_ReleaseCollection_Call {
return &QueryCoordCatalog_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", collection)}
func (_e *QueryCoordCatalog_Expecter) ReleaseCollection(ctx interface{}, collection interface{}) *QueryCoordCatalog_ReleaseCollection_Call {
return &QueryCoordCatalog_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", ctx, collection)}
}
func (_c *QueryCoordCatalog_ReleaseCollection_Call) Run(run func(collection int64)) *QueryCoordCatalog_ReleaseCollection_Call {
func (_c *QueryCoordCatalog_ReleaseCollection_Call) Run(run func(ctx context.Context, collection int64)) *QueryCoordCatalog_ReleaseCollection_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64))
run(args[0].(context.Context), args[1].(int64))
})
return _c
}
@ -322,25 +355,29 @@ func (_c *QueryCoordCatalog_ReleaseCollection_Call) Return(_a0 error) *QueryCoor
return _c
}
func (_c *QueryCoordCatalog_ReleaseCollection_Call) RunAndReturn(run func(int64) error) *QueryCoordCatalog_ReleaseCollection_Call {
func (_c *QueryCoordCatalog_ReleaseCollection_Call) RunAndReturn(run func(context.Context, int64) error) *QueryCoordCatalog_ReleaseCollection_Call {
_c.Call.Return(run)
return _c
}
// ReleasePartition provides a mock function with given fields: collection, partitions
func (_m *QueryCoordCatalog) ReleasePartition(collection int64, partitions ...int64) error {
// ReleasePartition provides a mock function with given fields: ctx, collection, partitions
func (_m *QueryCoordCatalog) ReleasePartition(ctx context.Context, collection int64, partitions ...int64) error {
_va := make([]interface{}, len(partitions))
for _i := range partitions {
_va[_i] = partitions[_i]
}
var _ca []interface{}
_ca = append(_ca, collection)
_ca = append(_ca, ctx, collection)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for ReleasePartition")
}
var r0 error
if rf, ok := ret.Get(0).(func(int64, ...int64) error); ok {
r0 = rf(collection, partitions...)
if rf, ok := ret.Get(0).(func(context.Context, int64, ...int64) error); ok {
r0 = rf(ctx, collection, partitions...)
} else {
r0 = ret.Error(0)
}
@ -354,22 +391,23 @@ type QueryCoordCatalog_ReleasePartition_Call struct {
}
// ReleasePartition is a helper method to define mock.On call
// - ctx context.Context
// - collection int64
// - partitions ...int64
func (_e *QueryCoordCatalog_Expecter) ReleasePartition(collection interface{}, partitions ...interface{}) *QueryCoordCatalog_ReleasePartition_Call {
func (_e *QueryCoordCatalog_Expecter) ReleasePartition(ctx interface{}, collection interface{}, partitions ...interface{}) *QueryCoordCatalog_ReleasePartition_Call {
return &QueryCoordCatalog_ReleasePartition_Call{Call: _e.mock.On("ReleasePartition",
append([]interface{}{collection}, partitions...)...)}
append([]interface{}{ctx, collection}, partitions...)...)}
}
func (_c *QueryCoordCatalog_ReleasePartition_Call) Run(run func(collection int64, partitions ...int64)) *QueryCoordCatalog_ReleasePartition_Call {
func (_c *QueryCoordCatalog_ReleasePartition_Call) Run(run func(ctx context.Context, collection int64, partitions ...int64)) *QueryCoordCatalog_ReleasePartition_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]int64, len(args)-1)
for i, a := range args[1:] {
variadicArgs := make([]int64, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(int64)
}
}
run(args[0].(int64), variadicArgs...)
run(args[0].(context.Context), args[1].(int64), variadicArgs...)
})
return _c
}
@ -379,25 +417,29 @@ func (_c *QueryCoordCatalog_ReleasePartition_Call) Return(_a0 error) *QueryCoord
return _c
}
func (_c *QueryCoordCatalog_ReleasePartition_Call) RunAndReturn(run func(int64, ...int64) error) *QueryCoordCatalog_ReleasePartition_Call {
func (_c *QueryCoordCatalog_ReleasePartition_Call) RunAndReturn(run func(context.Context, int64, ...int64) error) *QueryCoordCatalog_ReleasePartition_Call {
_c.Call.Return(run)
return _c
}
// ReleaseReplica provides a mock function with given fields: collection, replicas
func (_m *QueryCoordCatalog) ReleaseReplica(collection int64, replicas ...int64) error {
// ReleaseReplica provides a mock function with given fields: ctx, collection, replicas
func (_m *QueryCoordCatalog) ReleaseReplica(ctx context.Context, collection int64, replicas ...int64) error {
_va := make([]interface{}, len(replicas))
for _i := range replicas {
_va[_i] = replicas[_i]
}
var _ca []interface{}
_ca = append(_ca, collection)
_ca = append(_ca, ctx, collection)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for ReleaseReplica")
}
var r0 error
if rf, ok := ret.Get(0).(func(int64, ...int64) error); ok {
r0 = rf(collection, replicas...)
if rf, ok := ret.Get(0).(func(context.Context, int64, ...int64) error); ok {
r0 = rf(ctx, collection, replicas...)
} else {
r0 = ret.Error(0)
}
@ -411,22 +453,23 @@ type QueryCoordCatalog_ReleaseReplica_Call struct {
}
// ReleaseReplica is a helper method to define mock.On call
// - ctx context.Context
// - collection int64
// - replicas ...int64
func (_e *QueryCoordCatalog_Expecter) ReleaseReplica(collection interface{}, replicas ...interface{}) *QueryCoordCatalog_ReleaseReplica_Call {
func (_e *QueryCoordCatalog_Expecter) ReleaseReplica(ctx interface{}, collection interface{}, replicas ...interface{}) *QueryCoordCatalog_ReleaseReplica_Call {
return &QueryCoordCatalog_ReleaseReplica_Call{Call: _e.mock.On("ReleaseReplica",
append([]interface{}{collection}, replicas...)...)}
append([]interface{}{ctx, collection}, replicas...)...)}
}
func (_c *QueryCoordCatalog_ReleaseReplica_Call) Run(run func(collection int64, replicas ...int64)) *QueryCoordCatalog_ReleaseReplica_Call {
func (_c *QueryCoordCatalog_ReleaseReplica_Call) Run(run func(ctx context.Context, collection int64, replicas ...int64)) *QueryCoordCatalog_ReleaseReplica_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]int64, len(args)-1)
for i, a := range args[1:] {
variadicArgs := make([]int64, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(int64)
}
}
run(args[0].(int64), variadicArgs...)
run(args[0].(context.Context), args[1].(int64), variadicArgs...)
})
return _c
}
@ -436,18 +479,22 @@ func (_c *QueryCoordCatalog_ReleaseReplica_Call) Return(_a0 error) *QueryCoordCa
return _c
}
func (_c *QueryCoordCatalog_ReleaseReplica_Call) RunAndReturn(run func(int64, ...int64) error) *QueryCoordCatalog_ReleaseReplica_Call {
func (_c *QueryCoordCatalog_ReleaseReplica_Call) RunAndReturn(run func(context.Context, int64, ...int64) error) *QueryCoordCatalog_ReleaseReplica_Call {
_c.Call.Return(run)
return _c
}
// ReleaseReplicas provides a mock function with given fields: collectionID
func (_m *QueryCoordCatalog) ReleaseReplicas(collectionID int64) error {
ret := _m.Called(collectionID)
// ReleaseReplicas provides a mock function with given fields: ctx, collectionID
func (_m *QueryCoordCatalog) ReleaseReplicas(ctx context.Context, collectionID int64) error {
ret := _m.Called(ctx, collectionID)
if len(ret) == 0 {
panic("no return value specified for ReleaseReplicas")
}
var r0 error
if rf, ok := ret.Get(0).(func(int64) error); ok {
r0 = rf(collectionID)
if rf, ok := ret.Get(0).(func(context.Context, int64) error); ok {
r0 = rf(ctx, collectionID)
} else {
r0 = ret.Error(0)
}
@ -461,14 +508,15 @@ type QueryCoordCatalog_ReleaseReplicas_Call struct {
}
// ReleaseReplicas is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
func (_e *QueryCoordCatalog_Expecter) ReleaseReplicas(collectionID interface{}) *QueryCoordCatalog_ReleaseReplicas_Call {
return &QueryCoordCatalog_ReleaseReplicas_Call{Call: _e.mock.On("ReleaseReplicas", collectionID)}
func (_e *QueryCoordCatalog_Expecter) ReleaseReplicas(ctx interface{}, collectionID interface{}) *QueryCoordCatalog_ReleaseReplicas_Call {
return &QueryCoordCatalog_ReleaseReplicas_Call{Call: _e.mock.On("ReleaseReplicas", ctx, collectionID)}
}
func (_c *QueryCoordCatalog_ReleaseReplicas_Call) Run(run func(collectionID int64)) *QueryCoordCatalog_ReleaseReplicas_Call {
func (_c *QueryCoordCatalog_ReleaseReplicas_Call) Run(run func(ctx context.Context, collectionID int64)) *QueryCoordCatalog_ReleaseReplicas_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64))
run(args[0].(context.Context), args[1].(int64))
})
return _c
}
@ -478,18 +526,22 @@ func (_c *QueryCoordCatalog_ReleaseReplicas_Call) Return(_a0 error) *QueryCoordC
return _c
}
func (_c *QueryCoordCatalog_ReleaseReplicas_Call) RunAndReturn(run func(int64) error) *QueryCoordCatalog_ReleaseReplicas_Call {
func (_c *QueryCoordCatalog_ReleaseReplicas_Call) RunAndReturn(run func(context.Context, int64) error) *QueryCoordCatalog_ReleaseReplicas_Call {
_c.Call.Return(run)
return _c
}
// RemoveCollectionTarget provides a mock function with given fields: collectionID
func (_m *QueryCoordCatalog) RemoveCollectionTarget(collectionID int64) error {
ret := _m.Called(collectionID)
// RemoveCollectionTarget provides a mock function with given fields: ctx, collectionID
func (_m *QueryCoordCatalog) RemoveCollectionTarget(ctx context.Context, collectionID int64) error {
ret := _m.Called(ctx, collectionID)
if len(ret) == 0 {
panic("no return value specified for RemoveCollectionTarget")
}
var r0 error
if rf, ok := ret.Get(0).(func(int64) error); ok {
r0 = rf(collectionID)
if rf, ok := ret.Get(0).(func(context.Context, int64) error); ok {
r0 = rf(ctx, collectionID)
} else {
r0 = ret.Error(0)
}
@ -503,14 +555,15 @@ type QueryCoordCatalog_RemoveCollectionTarget_Call struct {
}
// RemoveCollectionTarget is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
func (_e *QueryCoordCatalog_Expecter) RemoveCollectionTarget(collectionID interface{}) *QueryCoordCatalog_RemoveCollectionTarget_Call {
return &QueryCoordCatalog_RemoveCollectionTarget_Call{Call: _e.mock.On("RemoveCollectionTarget", collectionID)}
func (_e *QueryCoordCatalog_Expecter) RemoveCollectionTarget(ctx interface{}, collectionID interface{}) *QueryCoordCatalog_RemoveCollectionTarget_Call {
return &QueryCoordCatalog_RemoveCollectionTarget_Call{Call: _e.mock.On("RemoveCollectionTarget", ctx, collectionID)}
}
func (_c *QueryCoordCatalog_RemoveCollectionTarget_Call) Run(run func(collectionID int64)) *QueryCoordCatalog_RemoveCollectionTarget_Call {
func (_c *QueryCoordCatalog_RemoveCollectionTarget_Call) Run(run func(ctx context.Context, collectionID int64)) *QueryCoordCatalog_RemoveCollectionTarget_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64))
run(args[0].(context.Context), args[1].(int64))
})
return _c
}
@ -520,18 +573,22 @@ func (_c *QueryCoordCatalog_RemoveCollectionTarget_Call) Return(_a0 error) *Quer
return _c
}
func (_c *QueryCoordCatalog_RemoveCollectionTarget_Call) RunAndReturn(run func(int64) error) *QueryCoordCatalog_RemoveCollectionTarget_Call {
func (_c *QueryCoordCatalog_RemoveCollectionTarget_Call) RunAndReturn(run func(context.Context, int64) error) *QueryCoordCatalog_RemoveCollectionTarget_Call {
_c.Call.Return(run)
return _c
}
// RemoveResourceGroup provides a mock function with given fields: rgName
func (_m *QueryCoordCatalog) RemoveResourceGroup(rgName string) error {
ret := _m.Called(rgName)
// RemoveResourceGroup provides a mock function with given fields: ctx, rgName
func (_m *QueryCoordCatalog) RemoveResourceGroup(ctx context.Context, rgName string) error {
ret := _m.Called(ctx, rgName)
if len(ret) == 0 {
panic("no return value specified for RemoveResourceGroup")
}
var r0 error
if rf, ok := ret.Get(0).(func(string) error); ok {
r0 = rf(rgName)
if rf, ok := ret.Get(0).(func(context.Context, string) error); ok {
r0 = rf(ctx, rgName)
} else {
r0 = ret.Error(0)
}
@ -545,14 +602,15 @@ type QueryCoordCatalog_RemoveResourceGroup_Call struct {
}
// RemoveResourceGroup is a helper method to define mock.On call
// - ctx context.Context
// - rgName string
func (_e *QueryCoordCatalog_Expecter) RemoveResourceGroup(rgName interface{}) *QueryCoordCatalog_RemoveResourceGroup_Call {
return &QueryCoordCatalog_RemoveResourceGroup_Call{Call: _e.mock.On("RemoveResourceGroup", rgName)}
func (_e *QueryCoordCatalog_Expecter) RemoveResourceGroup(ctx interface{}, rgName interface{}) *QueryCoordCatalog_RemoveResourceGroup_Call {
return &QueryCoordCatalog_RemoveResourceGroup_Call{Call: _e.mock.On("RemoveResourceGroup", ctx, rgName)}
}
func (_c *QueryCoordCatalog_RemoveResourceGroup_Call) Run(run func(rgName string)) *QueryCoordCatalog_RemoveResourceGroup_Call {
func (_c *QueryCoordCatalog_RemoveResourceGroup_Call) Run(run func(ctx context.Context, rgName string)) *QueryCoordCatalog_RemoveResourceGroup_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string))
run(args[0].(context.Context), args[1].(string))
})
return _c
}
@ -562,25 +620,29 @@ func (_c *QueryCoordCatalog_RemoveResourceGroup_Call) Return(_a0 error) *QueryCo
return _c
}
func (_c *QueryCoordCatalog_RemoveResourceGroup_Call) RunAndReturn(run func(string) error) *QueryCoordCatalog_RemoveResourceGroup_Call {
func (_c *QueryCoordCatalog_RemoveResourceGroup_Call) RunAndReturn(run func(context.Context, string) error) *QueryCoordCatalog_RemoveResourceGroup_Call {
_c.Call.Return(run)
return _c
}
// SaveCollection provides a mock function with given fields: collection, partitions
func (_m *QueryCoordCatalog) SaveCollection(collection *querypb.CollectionLoadInfo, partitions ...*querypb.PartitionLoadInfo) error {
// SaveCollection provides a mock function with given fields: ctx, collection, partitions
func (_m *QueryCoordCatalog) SaveCollection(ctx context.Context, collection *querypb.CollectionLoadInfo, partitions ...*querypb.PartitionLoadInfo) error {
_va := make([]interface{}, len(partitions))
for _i := range partitions {
_va[_i] = partitions[_i]
}
var _ca []interface{}
_ca = append(_ca, collection)
_ca = append(_ca, ctx, collection)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for SaveCollection")
}
var r0 error
if rf, ok := ret.Get(0).(func(*querypb.CollectionLoadInfo, ...*querypb.PartitionLoadInfo) error); ok {
r0 = rf(collection, partitions...)
if rf, ok := ret.Get(0).(func(context.Context, *querypb.CollectionLoadInfo, ...*querypb.PartitionLoadInfo) error); ok {
r0 = rf(ctx, collection, partitions...)
} else {
r0 = ret.Error(0)
}
@ -594,22 +656,23 @@ type QueryCoordCatalog_SaveCollection_Call struct {
}
// SaveCollection is a helper method to define mock.On call
// - ctx context.Context
// - collection *querypb.CollectionLoadInfo
// - partitions ...*querypb.PartitionLoadInfo
func (_e *QueryCoordCatalog_Expecter) SaveCollection(collection interface{}, partitions ...interface{}) *QueryCoordCatalog_SaveCollection_Call {
func (_e *QueryCoordCatalog_Expecter) SaveCollection(ctx interface{}, collection interface{}, partitions ...interface{}) *QueryCoordCatalog_SaveCollection_Call {
return &QueryCoordCatalog_SaveCollection_Call{Call: _e.mock.On("SaveCollection",
append([]interface{}{collection}, partitions...)...)}
append([]interface{}{ctx, collection}, partitions...)...)}
}
func (_c *QueryCoordCatalog_SaveCollection_Call) Run(run func(collection *querypb.CollectionLoadInfo, partitions ...*querypb.PartitionLoadInfo)) *QueryCoordCatalog_SaveCollection_Call {
func (_c *QueryCoordCatalog_SaveCollection_Call) Run(run func(ctx context.Context, collection *querypb.CollectionLoadInfo, partitions ...*querypb.PartitionLoadInfo)) *QueryCoordCatalog_SaveCollection_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]*querypb.PartitionLoadInfo, len(args)-1)
for i, a := range args[1:] {
variadicArgs := make([]*querypb.PartitionLoadInfo, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(*querypb.PartitionLoadInfo)
}
}
run(args[0].(*querypb.CollectionLoadInfo), variadicArgs...)
run(args[0].(context.Context), args[1].(*querypb.CollectionLoadInfo), variadicArgs...)
})
return _c
}
@ -619,24 +682,29 @@ func (_c *QueryCoordCatalog_SaveCollection_Call) Return(_a0 error) *QueryCoordCa
return _c
}
func (_c *QueryCoordCatalog_SaveCollection_Call) RunAndReturn(run func(*querypb.CollectionLoadInfo, ...*querypb.PartitionLoadInfo) error) *QueryCoordCatalog_SaveCollection_Call {
func (_c *QueryCoordCatalog_SaveCollection_Call) RunAndReturn(run func(context.Context, *querypb.CollectionLoadInfo, ...*querypb.PartitionLoadInfo) error) *QueryCoordCatalog_SaveCollection_Call {
_c.Call.Return(run)
return _c
}
// SaveCollectionTargets provides a mock function with given fields: target
func (_m *QueryCoordCatalog) SaveCollectionTargets(target ...*querypb.CollectionTarget) error {
// SaveCollectionTargets provides a mock function with given fields: ctx, target
func (_m *QueryCoordCatalog) SaveCollectionTargets(ctx context.Context, target ...*querypb.CollectionTarget) error {
_va := make([]interface{}, len(target))
for _i := range target {
_va[_i] = target[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for SaveCollectionTargets")
}
var r0 error
if rf, ok := ret.Get(0).(func(...*querypb.CollectionTarget) error); ok {
r0 = rf(target...)
if rf, ok := ret.Get(0).(func(context.Context, ...*querypb.CollectionTarget) error); ok {
r0 = rf(ctx, target...)
} else {
r0 = ret.Error(0)
}
@ -650,21 +718,22 @@ type QueryCoordCatalog_SaveCollectionTargets_Call struct {
}
// SaveCollectionTargets is a helper method to define mock.On call
// - ctx context.Context
// - target ...*querypb.CollectionTarget
func (_e *QueryCoordCatalog_Expecter) SaveCollectionTargets(target ...interface{}) *QueryCoordCatalog_SaveCollectionTargets_Call {
func (_e *QueryCoordCatalog_Expecter) SaveCollectionTargets(ctx interface{}, target ...interface{}) *QueryCoordCatalog_SaveCollectionTargets_Call {
return &QueryCoordCatalog_SaveCollectionTargets_Call{Call: _e.mock.On("SaveCollectionTargets",
append([]interface{}{}, target...)...)}
append([]interface{}{ctx}, target...)...)}
}
func (_c *QueryCoordCatalog_SaveCollectionTargets_Call) Run(run func(target ...*querypb.CollectionTarget)) *QueryCoordCatalog_SaveCollectionTargets_Call {
func (_c *QueryCoordCatalog_SaveCollectionTargets_Call) Run(run func(ctx context.Context, target ...*querypb.CollectionTarget)) *QueryCoordCatalog_SaveCollectionTargets_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]*querypb.CollectionTarget, len(args)-0)
for i, a := range args[0:] {
variadicArgs := make([]*querypb.CollectionTarget, len(args)-1)
for i, a := range args[1:] {
if a != nil {
variadicArgs[i] = a.(*querypb.CollectionTarget)
}
}
run(variadicArgs...)
run(args[0].(context.Context), variadicArgs...)
})
return _c
}
@ -674,24 +743,29 @@ func (_c *QueryCoordCatalog_SaveCollectionTargets_Call) Return(_a0 error) *Query
return _c
}
func (_c *QueryCoordCatalog_SaveCollectionTargets_Call) RunAndReturn(run func(...*querypb.CollectionTarget) error) *QueryCoordCatalog_SaveCollectionTargets_Call {
func (_c *QueryCoordCatalog_SaveCollectionTargets_Call) RunAndReturn(run func(context.Context, ...*querypb.CollectionTarget) error) *QueryCoordCatalog_SaveCollectionTargets_Call {
_c.Call.Return(run)
return _c
}
// SavePartition provides a mock function with given fields: info
func (_m *QueryCoordCatalog) SavePartition(info ...*querypb.PartitionLoadInfo) error {
// SavePartition provides a mock function with given fields: ctx, info
func (_m *QueryCoordCatalog) SavePartition(ctx context.Context, info ...*querypb.PartitionLoadInfo) error {
_va := make([]interface{}, len(info))
for _i := range info {
_va[_i] = info[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for SavePartition")
}
var r0 error
if rf, ok := ret.Get(0).(func(...*querypb.PartitionLoadInfo) error); ok {
r0 = rf(info...)
if rf, ok := ret.Get(0).(func(context.Context, ...*querypb.PartitionLoadInfo) error); ok {
r0 = rf(ctx, info...)
} else {
r0 = ret.Error(0)
}
@ -705,21 +779,22 @@ type QueryCoordCatalog_SavePartition_Call struct {
}
// SavePartition is a helper method to define mock.On call
// - ctx context.Context
// - info ...*querypb.PartitionLoadInfo
func (_e *QueryCoordCatalog_Expecter) SavePartition(info ...interface{}) *QueryCoordCatalog_SavePartition_Call {
func (_e *QueryCoordCatalog_Expecter) SavePartition(ctx interface{}, info ...interface{}) *QueryCoordCatalog_SavePartition_Call {
return &QueryCoordCatalog_SavePartition_Call{Call: _e.mock.On("SavePartition",
append([]interface{}{}, info...)...)}
append([]interface{}{ctx}, info...)...)}
}
func (_c *QueryCoordCatalog_SavePartition_Call) Run(run func(info ...*querypb.PartitionLoadInfo)) *QueryCoordCatalog_SavePartition_Call {
func (_c *QueryCoordCatalog_SavePartition_Call) Run(run func(ctx context.Context, info ...*querypb.PartitionLoadInfo)) *QueryCoordCatalog_SavePartition_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]*querypb.PartitionLoadInfo, len(args)-0)
for i, a := range args[0:] {
variadicArgs := make([]*querypb.PartitionLoadInfo, len(args)-1)
for i, a := range args[1:] {
if a != nil {
variadicArgs[i] = a.(*querypb.PartitionLoadInfo)
}
}
run(variadicArgs...)
run(args[0].(context.Context), variadicArgs...)
})
return _c
}
@ -729,24 +804,29 @@ func (_c *QueryCoordCatalog_SavePartition_Call) Return(_a0 error) *QueryCoordCat
return _c
}
func (_c *QueryCoordCatalog_SavePartition_Call) RunAndReturn(run func(...*querypb.PartitionLoadInfo) error) *QueryCoordCatalog_SavePartition_Call {
func (_c *QueryCoordCatalog_SavePartition_Call) RunAndReturn(run func(context.Context, ...*querypb.PartitionLoadInfo) error) *QueryCoordCatalog_SavePartition_Call {
_c.Call.Return(run)
return _c
}
// SaveReplica provides a mock function with given fields: replicas
func (_m *QueryCoordCatalog) SaveReplica(replicas ...*querypb.Replica) error {
// SaveReplica provides a mock function with given fields: ctx, replicas
func (_m *QueryCoordCatalog) SaveReplica(ctx context.Context, replicas ...*querypb.Replica) error {
_va := make([]interface{}, len(replicas))
for _i := range replicas {
_va[_i] = replicas[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for SaveReplica")
}
var r0 error
if rf, ok := ret.Get(0).(func(...*querypb.Replica) error); ok {
r0 = rf(replicas...)
if rf, ok := ret.Get(0).(func(context.Context, ...*querypb.Replica) error); ok {
r0 = rf(ctx, replicas...)
} else {
r0 = ret.Error(0)
}
@ -760,21 +840,22 @@ type QueryCoordCatalog_SaveReplica_Call struct {
}
// SaveReplica is a helper method to define mock.On call
// - ctx context.Context
// - replicas ...*querypb.Replica
func (_e *QueryCoordCatalog_Expecter) SaveReplica(replicas ...interface{}) *QueryCoordCatalog_SaveReplica_Call {
func (_e *QueryCoordCatalog_Expecter) SaveReplica(ctx interface{}, replicas ...interface{}) *QueryCoordCatalog_SaveReplica_Call {
return &QueryCoordCatalog_SaveReplica_Call{Call: _e.mock.On("SaveReplica",
append([]interface{}{}, replicas...)...)}
append([]interface{}{ctx}, replicas...)...)}
}
func (_c *QueryCoordCatalog_SaveReplica_Call) Run(run func(replicas ...*querypb.Replica)) *QueryCoordCatalog_SaveReplica_Call {
func (_c *QueryCoordCatalog_SaveReplica_Call) Run(run func(ctx context.Context, replicas ...*querypb.Replica)) *QueryCoordCatalog_SaveReplica_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]*querypb.Replica, len(args)-0)
for i, a := range args[0:] {
variadicArgs := make([]*querypb.Replica, len(args)-1)
for i, a := range args[1:] {
if a != nil {
variadicArgs[i] = a.(*querypb.Replica)
}
}
run(variadicArgs...)
run(args[0].(context.Context), variadicArgs...)
})
return _c
}
@ -784,24 +865,29 @@ func (_c *QueryCoordCatalog_SaveReplica_Call) Return(_a0 error) *QueryCoordCatal
return _c
}
func (_c *QueryCoordCatalog_SaveReplica_Call) RunAndReturn(run func(...*querypb.Replica) error) *QueryCoordCatalog_SaveReplica_Call {
func (_c *QueryCoordCatalog_SaveReplica_Call) RunAndReturn(run func(context.Context, ...*querypb.Replica) error) *QueryCoordCatalog_SaveReplica_Call {
_c.Call.Return(run)
return _c
}
// SaveResourceGroup provides a mock function with given fields: rgs
func (_m *QueryCoordCatalog) SaveResourceGroup(rgs ...*querypb.ResourceGroup) error {
// SaveResourceGroup provides a mock function with given fields: ctx, rgs
func (_m *QueryCoordCatalog) SaveResourceGroup(ctx context.Context, rgs ...*querypb.ResourceGroup) error {
_va := make([]interface{}, len(rgs))
for _i := range rgs {
_va[_i] = rgs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for SaveResourceGroup")
}
var r0 error
if rf, ok := ret.Get(0).(func(...*querypb.ResourceGroup) error); ok {
r0 = rf(rgs...)
if rf, ok := ret.Get(0).(func(context.Context, ...*querypb.ResourceGroup) error); ok {
r0 = rf(ctx, rgs...)
} else {
r0 = ret.Error(0)
}
@ -815,21 +901,22 @@ type QueryCoordCatalog_SaveResourceGroup_Call struct {
}
// SaveResourceGroup is a helper method to define mock.On call
// - ctx context.Context
// - rgs ...*querypb.ResourceGroup
func (_e *QueryCoordCatalog_Expecter) SaveResourceGroup(rgs ...interface{}) *QueryCoordCatalog_SaveResourceGroup_Call {
func (_e *QueryCoordCatalog_Expecter) SaveResourceGroup(ctx interface{}, rgs ...interface{}) *QueryCoordCatalog_SaveResourceGroup_Call {
return &QueryCoordCatalog_SaveResourceGroup_Call{Call: _e.mock.On("SaveResourceGroup",
append([]interface{}{}, rgs...)...)}
append([]interface{}{ctx}, rgs...)...)}
}
func (_c *QueryCoordCatalog_SaveResourceGroup_Call) Run(run func(rgs ...*querypb.ResourceGroup)) *QueryCoordCatalog_SaveResourceGroup_Call {
func (_c *QueryCoordCatalog_SaveResourceGroup_Call) Run(run func(ctx context.Context, rgs ...*querypb.ResourceGroup)) *QueryCoordCatalog_SaveResourceGroup_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]*querypb.ResourceGroup, len(args)-0)
for i, a := range args[0:] {
variadicArgs := make([]*querypb.ResourceGroup, len(args)-1)
for i, a := range args[1:] {
if a != nil {
variadicArgs[i] = a.(*querypb.ResourceGroup)
}
}
run(variadicArgs...)
run(args[0].(context.Context), variadicArgs...)
})
return _c
}
@ -839,7 +926,7 @@ func (_c *QueryCoordCatalog_SaveResourceGroup_Call) Return(_a0 error) *QueryCoor
return _c
}
func (_c *QueryCoordCatalog_SaveResourceGroup_Call) RunAndReturn(run func(...*querypb.ResourceGroup) error) *QueryCoordCatalog_SaveResourceGroup_Call {
func (_c *QueryCoordCatalog_SaveResourceGroup_Call) RunAndReturn(run func(context.Context, ...*querypb.ResourceGroup) error) *QueryCoordCatalog_SaveResourceGroup_Call {
_c.Call.Return(run)
return _c
}
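Since the mocks are regenerated with the context parameter, expectations built through the Expecter API now match on the context argument as well. A hedged usage sketch (the mocks.NewQueryCoordCatalog constructor is assumed from mockery's default output and does not appear in this diff):

// Inside a test: the regenerated mock expects ctx as the first argument,
// and mock.Anything is the usual matcher for any context value.
catalog := mocks.NewQueryCoordCatalog(t)
catalog.EXPECT().GetCollections(mock.Anything).Return(nil, nil)
catalog.EXPECT().ReleaseCollection(mock.Anything, int64(1)).Return(nil)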

View File

@ -17,6 +17,7 @@
package balance
import (
"context"
"fmt"
"sort"
@ -57,9 +58,9 @@ func (chanPlan *ChannelAssignPlan) String() string {
}
type Balance interface {
AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan
AssignChannel(channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan
BalanceReplica(replica *meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan)
AssignSegment(ctx context.Context, collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan
AssignChannel(ctx context.Context, channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan
BalanceReplica(ctx context.Context, replica *meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan)
}
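The interface change above ripples out to every caller, which now forwards the context it already holds into the balancer. A minimal caller-side sketch; the helper name `balanceAll` is illustrative, and it mirrors what the per-replica test helpers further down in this diff do:

```go
package balance

import (
	"context"

	"github.com/milvus-io/milvus/internal/querycoordv2/meta"
)

// balanceAll forwards its ctx through the refreshed Balance interface,
// collecting the plans produced for each replica.
func balanceAll(ctx context.Context, b Balance, replicas []*meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan) {
	segmentPlans := make([]SegmentAssignPlan, 0)
	channelPlans := make([]ChannelAssignPlan, 0)
	for _, replica := range replicas {
		sPlans, cPlans := b.BalanceReplica(ctx, replica)
		segmentPlans = append(segmentPlans, sPlans...)
		channelPlans = append(channelPlans, cPlans...)
	}
	return segmentPlans, channelPlans
}
```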
type RoundRobinBalancer struct {
@ -67,7 +68,7 @@ type RoundRobinBalancer struct {
nodeManager *session.NodeManager
}
func (b *RoundRobinBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan {
func (b *RoundRobinBalancer) AssignSegment(ctx context.Context, collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan {
// skip out suspend node and stopping node during assignment, but skip this check for manual balance
if !manualBalance {
nodes = lo.Filter(nodes, func(node int64, _ int) bool {
@ -103,7 +104,7 @@ func (b *RoundRobinBalancer) AssignSegment(collectionID int64, segments []*meta.
return ret
}
func (b *RoundRobinBalancer) AssignChannel(channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan {
func (b *RoundRobinBalancer) AssignChannel(ctx context.Context, channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan {
// skip out suspend node and stopping node during assignment, but skip this check for manual balance
if !manualBalance {
versionRangeFilter := semver.MustParseRange(">2.3.x")
@ -136,7 +137,7 @@ func (b *RoundRobinBalancer) AssignChannel(channels []*meta.DmChannel, nodes []i
return ret
}
func (b *RoundRobinBalancer) BalanceReplica(replica *meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan) {
func (b *RoundRobinBalancer) BalanceReplica(ctx context.Context, replica *meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan) {
// TODO by chun.han
return nil, nil
}

View File

@ -17,6 +17,7 @@
package balance
import (
"context"
"testing"
"github.com/stretchr/testify/mock"
@ -50,6 +51,7 @@ func (suite *BalanceTestSuite) SetupTest() {
}
func (suite *BalanceTestSuite) TestAssignBalance() {
ctx := context.Background()
cases := []struct {
name string
nodeIDs []int64
@ -108,13 +110,14 @@ func (suite *BalanceTestSuite) TestAssignBalance() {
suite.mockScheduler.EXPECT().GetSegmentTaskDelta(c.nodeIDs[i], int64(-1)).Return(c.deltaCnts[i])
}
}
plans := suite.roundRobinBalancer.AssignSegment(0, c.assignments, c.nodeIDs, false)
plans := suite.roundRobinBalancer.AssignSegment(ctx, 0, c.assignments, c.nodeIDs, false)
suite.ElementsMatch(c.expectPlans, plans)
})
}
}
func (suite *BalanceTestSuite) TestAssignChannel() {
ctx := context.Background()
cases := []struct {
name string
nodeIDs []int64
@ -174,7 +177,7 @@ func (suite *BalanceTestSuite) TestAssignChannel() {
suite.mockScheduler.EXPECT().GetChannelTaskDelta(c.nodeIDs[i], int64(-1)).Return(c.deltaCnts[i])
}
}
plans := suite.roundRobinBalancer.AssignChannel(c.assignments, c.nodeIDs, false)
plans := suite.roundRobinBalancer.AssignChannel(ctx, c.assignments, c.nodeIDs, false)
suite.ElementsMatch(c.expectPlans, plans)
})
}

View File

@ -17,6 +17,7 @@
package balance
import (
"context"
"fmt"
"math"
"sort"
@ -49,7 +50,7 @@ func NewChannelLevelScoreBalancer(scheduler task.Scheduler,
}
}
func (b *ChannelLevelScoreBalancer) BalanceReplica(replica *meta.Replica) (segmentPlans []SegmentAssignPlan, channelPlans []ChannelAssignPlan) {
func (b *ChannelLevelScoreBalancer) BalanceReplica(ctx context.Context, replica *meta.Replica) (segmentPlans []SegmentAssignPlan, channelPlans []ChannelAssignPlan) {
log := log.With(
zap.Int64("collection", replica.GetCollectionID()),
zap.Int64("replica id", replica.GetID()),
@ -67,7 +68,7 @@ func (b *ChannelLevelScoreBalancer) BalanceReplica(replica *meta.Replica) (segme
}()
exclusiveMode := true
channels := b.targetMgr.GetDmChannelsByCollection(replica.GetCollectionID(), meta.CurrentTarget)
channels := b.targetMgr.GetDmChannelsByCollection(ctx, replica.GetCollectionID(), meta.CurrentTarget)
for channelName := range channels {
if len(replica.GetChannelRWNodes(channelName)) == 0 {
exclusiveMode = false
@ -77,7 +78,7 @@ func (b *ChannelLevelScoreBalancer) BalanceReplica(replica *meta.Replica) (segme
// if some channel doesn't own nodes, exit exclusive mode
if !exclusiveMode {
return b.ScoreBasedBalancer.BalanceReplica(replica)
return b.ScoreBasedBalancer.BalanceReplica(ctx, replica)
}
channelPlans = make([]ChannelAssignPlan, 0)
@ -122,19 +123,19 @@ func (b *ChannelLevelScoreBalancer) BalanceReplica(replica *meta.Replica) (segme
)
// handle stopped nodes here, have to assign segments on stopping nodes to nodes with the smallest score
if b.permitBalanceChannel(replica.GetCollectionID()) {
channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, channelName, rwNodes, roNodes)...)
channelPlans = append(channelPlans, b.genStoppingChannelPlan(ctx, replica, channelName, rwNodes, roNodes)...)
}
if len(channelPlans) == 0 && b.permitBalanceSegment(replica.GetCollectionID()) {
segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(replica, channelName, rwNodes, roNodes)...)
segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(ctx, replica, channelName, rwNodes, roNodes)...)
}
} else {
if paramtable.Get().QueryCoordCfg.AutoBalanceChannel.GetAsBool() && b.permitBalanceChannel(replica.GetCollectionID()) {
channelPlans = append(channelPlans, b.genChannelPlan(replica, channelName, rwNodes)...)
channelPlans = append(channelPlans, b.genChannelPlan(ctx, replica, channelName, rwNodes)...)
}
if len(channelPlans) == 0 && b.permitBalanceSegment(replica.GetCollectionID()) {
segmentPlans = append(segmentPlans, b.genSegmentPlan(br, replica, channelName, rwNodes)...)
segmentPlans = append(segmentPlans, b.genSegmentPlan(ctx, br, replica, channelName, rwNodes)...)
}
}
}
@ -142,11 +143,11 @@ func (b *ChannelLevelScoreBalancer) BalanceReplica(replica *meta.Replica) (segme
return segmentPlans, channelPlans
}
func (b *ChannelLevelScoreBalancer) genStoppingChannelPlan(replica *meta.Replica, channelName string, onlineNodes []int64, offlineNodes []int64) []ChannelAssignPlan {
func (b *ChannelLevelScoreBalancer) genStoppingChannelPlan(ctx context.Context, replica *meta.Replica, channelName string, onlineNodes []int64, offlineNodes []int64) []ChannelAssignPlan {
channelPlans := make([]ChannelAssignPlan, 0)
for _, nodeID := range offlineNodes {
dmChannels := b.dist.ChannelDistManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithNodeID2Channel(nodeID), meta.WithChannelName2Channel(channelName))
plans := b.AssignChannel(dmChannels, onlineNodes, false)
plans := b.AssignChannel(ctx, dmChannels, onlineNodes, false)
for i := range plans {
plans[i].From = nodeID
plans[i].Replica = replica
@ -156,14 +157,14 @@ func (b *ChannelLevelScoreBalancer) genStoppingChannelPlan(replica *meta.Replica
return channelPlans
}
func (b *ChannelLevelScoreBalancer) genStoppingSegmentPlan(replica *meta.Replica, channelName string, onlineNodes []int64, offlineNodes []int64) []SegmentAssignPlan {
func (b *ChannelLevelScoreBalancer) genStoppingSegmentPlan(ctx context.Context, replica *meta.Replica, channelName string, onlineNodes []int64, offlineNodes []int64) []SegmentAssignPlan {
segmentPlans := make([]SegmentAssignPlan, 0)
for _, nodeID := range offlineNodes {
dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(nodeID), meta.WithChannel(channelName))
segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool {
return b.targetMgr.CanSegmentBeMoved(segment.GetCollectionID(), segment.GetID())
return b.targetMgr.CanSegmentBeMoved(ctx, segment.GetCollectionID(), segment.GetID())
})
plans := b.AssignSegment(replica.GetCollectionID(), segments, onlineNodes, false)
plans := b.AssignSegment(ctx, replica.GetCollectionID(), segments, onlineNodes, false)
for i := range plans {
plans[i].From = nodeID
plans[i].Replica = replica
@ -173,7 +174,7 @@ func (b *ChannelLevelScoreBalancer) genStoppingSegmentPlan(replica *meta.Replica
return segmentPlans
}
func (b *ChannelLevelScoreBalancer) genSegmentPlan(br *balanceReport, replica *meta.Replica, channelName string, onlineNodes []int64) []SegmentAssignPlan {
func (b *ChannelLevelScoreBalancer) genSegmentPlan(ctx context.Context, br *balanceReport, replica *meta.Replica, channelName string, onlineNodes []int64) []SegmentAssignPlan {
segmentDist := make(map[int64][]*meta.Segment)
nodeItemsMap := b.convertToNodeItems(br, replica.GetCollectionID(), onlineNodes)
if len(nodeItemsMap) == 0 {
@ -189,7 +190,7 @@ func (b *ChannelLevelScoreBalancer) genSegmentPlan(br *balanceReport, replica *m
for _, node := range onlineNodes {
dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(node), meta.WithChannel(channelName))
segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool {
return b.targetMgr.CanSegmentBeMoved(segment.GetCollectionID(), segment.GetID())
return b.targetMgr.CanSegmentBeMoved(ctx, segment.GetCollectionID(), segment.GetID())
})
segmentDist[node] = segments
}
@ -224,7 +225,7 @@ func (b *ChannelLevelScoreBalancer) genSegmentPlan(br *balanceReport, replica *m
return nil
}
segmentPlans := b.AssignSegment(replica.GetCollectionID(), segmentsToMove, onlineNodes, false)
segmentPlans := b.AssignSegment(ctx, replica.GetCollectionID(), segmentsToMove, onlineNodes, false)
for i := range segmentPlans {
segmentPlans[i].From = segmentPlans[i].Segment.Node
segmentPlans[i].Replica = replica
@ -233,7 +234,7 @@ func (b *ChannelLevelScoreBalancer) genSegmentPlan(br *balanceReport, replica *m
return segmentPlans
}
func (b *ChannelLevelScoreBalancer) genChannelPlan(replica *meta.Replica, channelName string, onlineNodes []int64) []ChannelAssignPlan {
func (b *ChannelLevelScoreBalancer) genChannelPlan(ctx context.Context, replica *meta.Replica, channelName string, onlineNodes []int64) []ChannelAssignPlan {
channelPlans := make([]ChannelAssignPlan, 0)
if len(onlineNodes) > 1 {
// start to balance channels on all available nodes
@ -261,7 +262,7 @@ func (b *ChannelLevelScoreBalancer) genChannelPlan(replica *meta.Replica, channe
return nil
}
channelPlans := b.AssignChannel(channelsToMove, nodeWithLessChannel, false)
channelPlans := b.AssignChannel(ctx, channelsToMove, nodeWithLessChannel, false)
for i := range channelPlans {
channelPlans[i].From = channelPlans[i].Channel.Node
channelPlans[i].Replica = replica

View File

@ -16,6 +16,7 @@
package balance
import (
"context"
"testing"
"github.com/samber/lo"
@ -85,6 +86,7 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TearDownTest() {
}
func (suite *ChannelLevelScoreBalancerTestSuite) TestAssignSegment() {
ctx := context.Background()
cases := []struct {
name string
comment string
@ -240,7 +242,7 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestAssignSegment() {
suite.balancer.nodeManager.Add(nodeInfo)
}
for i := range c.collectionIDs {
plans := balancer.AssignSegment(c.collectionIDs[i], c.assignments[i], c.nodes, false)
plans := balancer.AssignSegment(ctx, c.collectionIDs[i], c.assignments[i], c.nodes, false)
if c.unstableAssignment {
suite.Equal(len(plans), len(c.expectPlans[i]))
} else {
@ -252,6 +254,7 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestAssignSegment() {
}
func (suite *ChannelLevelScoreBalancerTestSuite) TestAssignSegmentWithGrowing() {
ctx := context.Background()
suite.SetupSuite()
defer suite.TearDownTest()
balancer := suite.balancer
@ -293,13 +296,14 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestAssignSegmentWithGrowing()
CollectionID: 1,
}
suite.balancer.dist.LeaderViewManager.Update(1, leaderView)
plans := balancer.AssignSegment(1, toAssign, lo.Keys(distributions), false)
plans := balancer.AssignSegment(ctx, 1, toAssign, lo.Keys(distributions), false)
for _, p := range plans {
suite.Equal(int64(2), p.To)
}
}
func (suite *ChannelLevelScoreBalancerTestSuite) TestBalanceOneRound() {
ctx := context.Background()
cases := []struct {
name string
nodes []int64
@ -376,11 +380,11 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestBalanceOneRound() {
suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionID).Return([]int64{c.collectionID}, nil).Maybe()
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(c.collectionID, c.collectionID))
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes))
balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionID)
balancer.meta.CollectionManager.PutCollection(ctx, collection)
balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(c.collectionID, c.collectionID))
balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes))
balancer.targetMgr.UpdateCollectionNextTarget(ctx, c.collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, c.collectionID)
// 2. set up target for distribution for multi collections
for node, s := range c.distributions {
@ -400,7 +404,7 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestBalanceOneRound() {
nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]])))
nodeInfo.SetState(c.states[i])
suite.balancer.nodeManager.Add(nodeInfo)
suite.balancer.meta.ResourceManager.HandleNodeUp(c.nodes[i])
suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, c.nodes[i])
}
// 4. balance and verify result
@ -412,6 +416,7 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestBalanceOneRound() {
}
func (suite *ChannelLevelScoreBalancerTestSuite) TestBalanceMultiRound() {
ctx := context.Background()
balanceCase := struct {
name string
nodes []int64
@ -495,12 +500,12 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestBalanceMultiRound() {
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded
collection.LoadType = querypb.LoadType_LoadCollection
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(balanceCase.collectionIDs[i], balanceCase.collectionIDs[i]))
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(balanceCase.replicaIDs[i], balanceCase.collectionIDs[i],
balancer.meta.CollectionManager.PutCollection(ctx, collection)
balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(balanceCase.collectionIDs[i], balanceCase.collectionIDs[i]))
balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(balanceCase.replicaIDs[i], balanceCase.collectionIDs[i],
append(balanceCase.nodes, balanceCase.notExistedNodes...)))
balancer.targetMgr.UpdateCollectionNextTarget(balanceCase.collectionIDs[i])
balancer.targetMgr.UpdateCollectionCurrentTarget(balanceCase.collectionIDs[i])
balancer.targetMgr.UpdateCollectionNextTarget(ctx, balanceCase.collectionIDs[i])
balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, balanceCase.collectionIDs[i])
}
// 2. set up target for distribution for multi collections
@ -517,7 +522,7 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestBalanceMultiRound() {
})
nodeInfo.SetState(balanceCase.states[i])
suite.balancer.nodeManager.Add(nodeInfo)
suite.balancer.meta.ResourceManager.HandleNodeUp(balanceCase.nodes[i])
suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, balanceCase.nodes[i])
}
// 4. first round balance
@ -535,6 +540,7 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestBalanceMultiRound() {
}
func (suite *ChannelLevelScoreBalancerTestSuite) TestStoppedBalance() {
ctx := context.Background()
cases := []struct {
name string
nodes []int64
@ -654,11 +660,11 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestStoppedBalance() {
suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionID).Return([]int64{c.collectionID}, nil).Maybe()
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(c.collectionID, c.collectionID))
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes))
balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionID)
balancer.meta.CollectionManager.PutCollection(ctx, collection)
balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(c.collectionID, c.collectionID))
balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes))
balancer.targetMgr.UpdateCollectionNextTarget(ctx, c.collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, c.collectionID)
// 2. set up target for distribution for multi collections
for node, s := range c.distributions {
@ -678,11 +684,11 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestStoppedBalance() {
nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]])))
nodeInfo.SetState(c.states[i])
suite.balancer.nodeManager.Add(nodeInfo)
suite.balancer.meta.ResourceManager.HandleNodeUp(c.nodes[i])
suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, c.nodes[i])
}
for i := range c.outBoundNodes {
suite.balancer.meta.ResourceManager.HandleNodeDown(c.outBoundNodes[i])
suite.balancer.meta.ResourceManager.HandleNodeDown(ctx, c.outBoundNodes[i])
}
utils.RecoverAllCollection(balancer.meta)
@ -695,6 +701,7 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestStoppedBalance() {
}
func (suite *ChannelLevelScoreBalancerTestSuite) TestMultiReplicaBalance() {
ctx := context.Background()
cases := []struct {
name string
collectionID int64
@ -771,13 +778,13 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestMultiReplicaBalance() {
suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionID).Return([]int64{c.collectionID}, nil).Maybe()
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(c.collectionID, c.collectionID))
balancer.meta.CollectionManager.PutCollection(ctx, collection)
balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(c.collectionID, c.collectionID))
for replicaID, nodes := range c.replicaWithNodes {
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(replicaID, c.collectionID, nodes))
balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(replicaID, c.collectionID, nodes))
}
balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionID)
balancer.targetMgr.UpdateCollectionNextTarget(ctx, c.collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, c.collectionID)
// 2. set up target for distribution for multi collections
for node, s := range c.segmentDist {
@ -798,7 +805,7 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestMultiReplicaBalance() {
nodeInfo.UpdateStats(session.WithChannelCnt(len(c.channelDist[nodes[i]])))
nodeInfo.SetState(c.states[i])
suite.balancer.nodeManager.Add(nodeInfo)
suite.balancer.meta.ResourceManager.HandleNodeUp(nodes[i])
suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, nodes[i])
}
}
@ -824,10 +831,11 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestMultiReplicaBalance() {
func (suite *ChannelLevelScoreBalancerTestSuite) getCollectionBalancePlans(balancer *ChannelLevelScoreBalancer,
collectionID int64,
) ([]SegmentAssignPlan, []ChannelAssignPlan) {
replicas := balancer.meta.ReplicaManager.GetByCollection(collectionID)
ctx := context.Background()
replicas := balancer.meta.ReplicaManager.GetByCollection(ctx, collectionID)
segmentPlans, channelPlans := make([]SegmentAssignPlan, 0), make([]ChannelAssignPlan, 0)
for _, replica := range replicas {
sPlans, cPlans := balancer.BalanceReplica(replica)
sPlans, cPlans := balancer.BalanceReplica(ctx, replica)
segmentPlans = append(segmentPlans, sPlans...)
channelPlans = append(channelPlans, cPlans...)
}
@ -835,6 +843,7 @@ func (suite *ChannelLevelScoreBalancerTestSuite) getCollectionBalancePlans(balan
}
func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_ChannelOutBound() {
ctx := context.Background()
Params.Save(Params.QueryCoordCfg.Balancer.Key, meta.ChannelLevelScoreBalancerName)
defer Params.Reset(Params.QueryCoordCfg.Balancer.Key)
Params.Save(Params.QueryCoordCfg.ChannelExclusiveNodeFactor.Key, "2")
@ -865,11 +874,11 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_Cha
collection := utils.CreateTestCollection(collectionID, int32(1))
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(collectionID, partitionID))
balancer.meta.ReplicaManager.Spawn(1, map[string]int{meta.DefaultResourceGroupName: 1}, []string{"channel1", "channel2"})
balancer.targetMgr.UpdateCollectionNextTarget(collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(collectionID)
balancer.meta.CollectionManager.PutCollection(ctx, collection)
balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(collectionID, partitionID))
balancer.meta.ReplicaManager.Spawn(ctx, 1, map[string]int{meta.DefaultResourceGroupName: 1}, []string{"channel1", "channel2"})
balancer.targetMgr.UpdateCollectionNextTarget(ctx, collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, collectionID)
// 3. set up nodes info and resourceManager for balancer
nodeCount := 4
@ -883,11 +892,11 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_Cha
// nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]])))
nodeInfo.SetState(session.NodeStateNormal)
suite.balancer.nodeManager.Add(nodeInfo)
suite.balancer.meta.ResourceManager.HandleNodeUp(nodeInfo.ID())
suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, nodeInfo.ID())
}
utils.RecoverAllCollection(balancer.meta)
replica := balancer.meta.ReplicaManager.GetByCollection(collectionID)[0]
replica := balancer.meta.ReplicaManager.GetByCollection(ctx, collectionID)[0]
ch1Nodes := replica.GetChannelRWNodes("channel1")
ch2Nodes := replica.GetChannelRWNodes("channel2")
suite.Len(ch1Nodes, 2)
@ -903,12 +912,13 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_Cha
},
}...)
sPlans, cPlans := balancer.BalanceReplica(replica)
sPlans, cPlans := balancer.BalanceReplica(ctx, replica)
suite.Len(sPlans, 0)
suite.Len(cPlans, 1)
}
func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_SegmentOutbound() {
ctx := context.Background()
Params.Save(Params.QueryCoordCfg.Balancer.Key, meta.ChannelLevelScoreBalancerName)
defer Params.Reset(Params.QueryCoordCfg.Balancer.Key)
Params.Save(Params.QueryCoordCfg.ChannelExclusiveNodeFactor.Key, "2")
@ -939,11 +949,11 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_Seg
collection := utils.CreateTestCollection(collectionID, int32(1))
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(collectionID, partitionID))
balancer.meta.ReplicaManager.Spawn(1, map[string]int{meta.DefaultResourceGroupName: 1}, []string{"channel1", "channel2"})
balancer.targetMgr.UpdateCollectionNextTarget(collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(collectionID)
balancer.meta.CollectionManager.PutCollection(ctx, collection)
balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(collectionID, partitionID))
balancer.meta.ReplicaManager.Spawn(ctx, 1, map[string]int{meta.DefaultResourceGroupName: 1}, []string{"channel1", "channel2"})
balancer.targetMgr.UpdateCollectionNextTarget(ctx, collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, collectionID)
// 3. set up nodes info and resourceManager for balancer
nodeCount := 4
@ -957,11 +967,11 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_Seg
// nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]])))
nodeInfo.SetState(session.NodeStateNormal)
suite.balancer.nodeManager.Add(nodeInfo)
suite.balancer.meta.ResourceManager.HandleNodeUp(nodeInfo.ID())
suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, nodeInfo.ID())
}
utils.RecoverAllCollection(balancer.meta)
replica := balancer.meta.ReplicaManager.GetByCollection(collectionID)[0]
replica := balancer.meta.ReplicaManager.GetByCollection(ctx, collectionID)[0]
ch1Nodes := replica.GetChannelRWNodes("channel1")
ch2Nodes := replica.GetChannelRWNodes("channel2")
suite.Len(ch1Nodes, 2)
@ -1000,12 +1010,13 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_Seg
},
}...)
sPlans, cPlans := balancer.BalanceReplica(replica)
sPlans, cPlans := balancer.BalanceReplica(ctx, replica)
suite.Len(sPlans, 1)
suite.Len(cPlans, 0)
}
func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_NodeStopping() {
ctx := context.Background()
Params.Save(Params.QueryCoordCfg.Balancer.Key, meta.ChannelLevelScoreBalancerName)
defer Params.Reset(Params.QueryCoordCfg.Balancer.Key)
Params.Save(Params.QueryCoordCfg.ChannelExclusiveNodeFactor.Key, "2")
@ -1036,11 +1047,11 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_Nod
collection := utils.CreateTestCollection(collectionID, int32(1))
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(collectionID, partitionID))
balancer.meta.ReplicaManager.Spawn(1, map[string]int{meta.DefaultResourceGroupName: 1}, []string{"channel1", "channel2"})
balancer.targetMgr.UpdateCollectionNextTarget(collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(collectionID)
balancer.meta.CollectionManager.PutCollection(ctx, collection)
balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(collectionID, partitionID))
balancer.meta.ReplicaManager.Spawn(ctx, 1, map[string]int{meta.DefaultResourceGroupName: 1}, []string{"channel1", "channel2"})
balancer.targetMgr.UpdateCollectionNextTarget(ctx, collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, collectionID)
// 3. set up nodes info and resourceManager for balancer
nodeCount := 4
@ -1054,11 +1065,11 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_Nod
// nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]])))
nodeInfo.SetState(session.NodeStateNormal)
suite.balancer.nodeManager.Add(nodeInfo)
suite.balancer.meta.ResourceManager.HandleNodeUp(nodeInfo.ID())
suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, nodeInfo.ID())
}
utils.RecoverAllCollection(balancer.meta)
replica := balancer.meta.ReplicaManager.GetByCollection(collectionID)[0]
replica := balancer.meta.ReplicaManager.GetByCollection(ctx, collectionID)[0]
ch1Nodes := replica.GetChannelRWNodes("channel1")
ch2Nodes := replica.GetChannelRWNodes("channel2")
suite.Len(ch1Nodes, 2)
@ -1112,24 +1123,25 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_Nod
balancer.nodeManager.Stopping(ch1Nodes[0])
balancer.nodeManager.Stopping(ch2Nodes[0])
suite.balancer.meta.ResourceManager.HandleNodeStopping(ch1Nodes[0])
suite.balancer.meta.ResourceManager.HandleNodeStopping(ch2Nodes[0])
suite.balancer.meta.ResourceManager.HandleNodeStopping(ctx, ch1Nodes[0])
suite.balancer.meta.ResourceManager.HandleNodeStopping(ctx, ch2Nodes[0])
utils.RecoverAllCollection(balancer.meta)
replica = balancer.meta.ReplicaManager.Get(replica.GetID())
sPlans, cPlans := balancer.BalanceReplica(replica)
replica = balancer.meta.ReplicaManager.Get(ctx, replica.GetID())
sPlans, cPlans := balancer.BalanceReplica(ctx, replica)
suite.Len(sPlans, 0)
suite.Len(cPlans, 2)
balancer.dist.ChannelDistManager.Update(ch1Nodes[0])
balancer.dist.ChannelDistManager.Update(ch2Nodes[0])
sPlans, cPlans = balancer.BalanceReplica(replica)
sPlans, cPlans = balancer.BalanceReplica(ctx, replica)
suite.Len(sPlans, 2)
suite.Len(cPlans, 0)
}
func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_SegmentUnbalance() {
ctx := context.Background()
Params.Save(Params.QueryCoordCfg.Balancer.Key, meta.ChannelLevelScoreBalancerName)
defer Params.Reset(Params.QueryCoordCfg.Balancer.Key)
Params.Save(Params.QueryCoordCfg.ChannelExclusiveNodeFactor.Key, "2")
@ -1160,11 +1172,11 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_Seg
collection := utils.CreateTestCollection(collectionID, int32(1))
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(collectionID, partitionID))
balancer.meta.ReplicaManager.Spawn(1, map[string]int{meta.DefaultResourceGroupName: 1}, []string{"channel1", "channel2"})
balancer.targetMgr.UpdateCollectionNextTarget(collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(collectionID)
balancer.meta.CollectionManager.PutCollection(ctx, collection)
balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(collectionID, partitionID))
balancer.meta.ReplicaManager.Spawn(ctx, 1, map[string]int{meta.DefaultResourceGroupName: 1}, []string{"channel1", "channel2"})
balancer.targetMgr.UpdateCollectionNextTarget(ctx, collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, collectionID)
// 3. set up nodes info and resourceManager for balancer
nodeCount := 4
@ -1178,11 +1190,11 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_Seg
// nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]])))
nodeInfo.SetState(session.NodeStateNormal)
suite.balancer.nodeManager.Add(nodeInfo)
suite.balancer.meta.ResourceManager.HandleNodeUp(nodeInfo.ID())
suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, nodeInfo.ID())
}
utils.RecoverAllCollection(balancer.meta)
replica := balancer.meta.ReplicaManager.GetByCollection(collectionID)[0]
replica := balancer.meta.ReplicaManager.GetByCollection(ctx, collectionID)[0]
ch1Nodes := replica.GetChannelRWNodes("channel1")
ch2Nodes := replica.GetChannelRWNodes("channel2")
suite.Len(ch1Nodes, 2)
@ -1254,7 +1266,7 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_Seg
},
}...)
sPlans, cPlans := balancer.BalanceReplica(replica)
sPlans, cPlans := balancer.BalanceReplica(ctx, replica)
suite.Len(sPlans, 2)
suite.Len(cPlans, 0)
}

View File

@ -3,6 +3,8 @@
package balance
import (
context "context"
meta "github.com/milvus-io/milvus/internal/querycoordv2/meta"
mock "github.com/stretchr/testify/mock"
)
@ -20,17 +22,17 @@ func (_m *MockBalancer) EXPECT() *MockBalancer_Expecter {
return &MockBalancer_Expecter{mock: &_m.Mock}
}
// AssignChannel provides a mock function with given fields: channels, nodes, manualBalance
func (_m *MockBalancer) AssignChannel(channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan {
ret := _m.Called(channels, nodes, manualBalance)
// AssignChannel provides a mock function with given fields: ctx, channels, nodes, manualBalance
func (_m *MockBalancer) AssignChannel(ctx context.Context, channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan {
ret := _m.Called(ctx, channels, nodes, manualBalance)
if len(ret) == 0 {
panic("no return value specified for AssignChannel")
}
var r0 []ChannelAssignPlan
if rf, ok := ret.Get(0).(func([]*meta.DmChannel, []int64, bool) []ChannelAssignPlan); ok {
r0 = rf(channels, nodes, manualBalance)
if rf, ok := ret.Get(0).(func(context.Context, []*meta.DmChannel, []int64, bool) []ChannelAssignPlan); ok {
r0 = rf(ctx, channels, nodes, manualBalance)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]ChannelAssignPlan)
@ -46,16 +48,17 @@ type MockBalancer_AssignChannel_Call struct {
}
// AssignChannel is a helper method to define mock.On call
// - ctx context.Context
// - channels []*meta.DmChannel
// - nodes []int64
// - manualBalance bool
func (_e *MockBalancer_Expecter) AssignChannel(channels interface{}, nodes interface{}, manualBalance interface{}) *MockBalancer_AssignChannel_Call {
return &MockBalancer_AssignChannel_Call{Call: _e.mock.On("AssignChannel", channels, nodes, manualBalance)}
func (_e *MockBalancer_Expecter) AssignChannel(ctx interface{}, channels interface{}, nodes interface{}, manualBalance interface{}) *MockBalancer_AssignChannel_Call {
return &MockBalancer_AssignChannel_Call{Call: _e.mock.On("AssignChannel", ctx, channels, nodes, manualBalance)}
}
func (_c *MockBalancer_AssignChannel_Call) Run(run func(channels []*meta.DmChannel, nodes []int64, manualBalance bool)) *MockBalancer_AssignChannel_Call {
func (_c *MockBalancer_AssignChannel_Call) Run(run func(ctx context.Context, channels []*meta.DmChannel, nodes []int64, manualBalance bool)) *MockBalancer_AssignChannel_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]*meta.DmChannel), args[1].([]int64), args[2].(bool))
run(args[0].(context.Context), args[1].([]*meta.DmChannel), args[2].([]int64), args[3].(bool))
})
return _c
}
@ -65,22 +68,22 @@ func (_c *MockBalancer_AssignChannel_Call) Return(_a0 []ChannelAssignPlan) *Mock
return _c
}
func (_c *MockBalancer_AssignChannel_Call) RunAndReturn(run func([]*meta.DmChannel, []int64, bool) []ChannelAssignPlan) *MockBalancer_AssignChannel_Call {
func (_c *MockBalancer_AssignChannel_Call) RunAndReturn(run func(context.Context, []*meta.DmChannel, []int64, bool) []ChannelAssignPlan) *MockBalancer_AssignChannel_Call {
_c.Call.Return(run)
return _c
}
// AssignSegment provides a mock function with given fields: collectionID, segments, nodes, manualBalance
func (_m *MockBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan {
ret := _m.Called(collectionID, segments, nodes, manualBalance)
// AssignSegment provides a mock function with given fields: ctx, collectionID, segments, nodes, manualBalance
func (_m *MockBalancer) AssignSegment(ctx context.Context, collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan {
ret := _m.Called(ctx, collectionID, segments, nodes, manualBalance)
if len(ret) == 0 {
panic("no return value specified for AssignSegment")
}
var r0 []SegmentAssignPlan
if rf, ok := ret.Get(0).(func(int64, []*meta.Segment, []int64, bool) []SegmentAssignPlan); ok {
r0 = rf(collectionID, segments, nodes, manualBalance)
if rf, ok := ret.Get(0).(func(context.Context, int64, []*meta.Segment, []int64, bool) []SegmentAssignPlan); ok {
r0 = rf(ctx, collectionID, segments, nodes, manualBalance)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]SegmentAssignPlan)
@ -96,17 +99,18 @@ type MockBalancer_AssignSegment_Call struct {
}
// AssignSegment is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - segments []*meta.Segment
// - nodes []int64
// - manualBalance bool
func (_e *MockBalancer_Expecter) AssignSegment(collectionID interface{}, segments interface{}, nodes interface{}, manualBalance interface{}) *MockBalancer_AssignSegment_Call {
return &MockBalancer_AssignSegment_Call{Call: _e.mock.On("AssignSegment", collectionID, segments, nodes, manualBalance)}
func (_e *MockBalancer_Expecter) AssignSegment(ctx interface{}, collectionID interface{}, segments interface{}, nodes interface{}, manualBalance interface{}) *MockBalancer_AssignSegment_Call {
return &MockBalancer_AssignSegment_Call{Call: _e.mock.On("AssignSegment", ctx, collectionID, segments, nodes, manualBalance)}
}
func (_c *MockBalancer_AssignSegment_Call) Run(run func(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool)) *MockBalancer_AssignSegment_Call {
func (_c *MockBalancer_AssignSegment_Call) Run(run func(ctx context.Context, collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool)) *MockBalancer_AssignSegment_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64), args[1].([]*meta.Segment), args[2].([]int64), args[3].(bool))
run(args[0].(context.Context), args[1].(int64), args[2].([]*meta.Segment), args[3].([]int64), args[4].(bool))
})
return _c
}
@ -116,14 +120,14 @@ func (_c *MockBalancer_AssignSegment_Call) Return(_a0 []SegmentAssignPlan) *Mock
return _c
}
func (_c *MockBalancer_AssignSegment_Call) RunAndReturn(run func(int64, []*meta.Segment, []int64, bool) []SegmentAssignPlan) *MockBalancer_AssignSegment_Call {
func (_c *MockBalancer_AssignSegment_Call) RunAndReturn(run func(context.Context, int64, []*meta.Segment, []int64, bool) []SegmentAssignPlan) *MockBalancer_AssignSegment_Call {
_c.Call.Return(run)
return _c
}
// BalanceReplica provides a mock function with given fields: replica
func (_m *MockBalancer) BalanceReplica(replica *meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan) {
ret := _m.Called(replica)
// BalanceReplica provides a mock function with given fields: ctx, replica
func (_m *MockBalancer) BalanceReplica(ctx context.Context, replica *meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan) {
ret := _m.Called(ctx, replica)
if len(ret) == 0 {
panic("no return value specified for BalanceReplica")
@ -131,19 +135,19 @@ func (_m *MockBalancer) BalanceReplica(replica *meta.Replica) ([]SegmentAssignPl
var r0 []SegmentAssignPlan
var r1 []ChannelAssignPlan
if rf, ok := ret.Get(0).(func(*meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan)); ok {
return rf(replica)
if rf, ok := ret.Get(0).(func(context.Context, *meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan)); ok {
return rf(ctx, replica)
}
if rf, ok := ret.Get(0).(func(*meta.Replica) []SegmentAssignPlan); ok {
r0 = rf(replica)
if rf, ok := ret.Get(0).(func(context.Context, *meta.Replica) []SegmentAssignPlan); ok {
r0 = rf(ctx, replica)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]SegmentAssignPlan)
}
}
if rf, ok := ret.Get(1).(func(*meta.Replica) []ChannelAssignPlan); ok {
r1 = rf(replica)
if rf, ok := ret.Get(1).(func(context.Context, *meta.Replica) []ChannelAssignPlan); ok {
r1 = rf(ctx, replica)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).([]ChannelAssignPlan)
@ -159,14 +163,15 @@ type MockBalancer_BalanceReplica_Call struct {
}
// BalanceReplica is a helper method to define mock.On call
// - ctx context.Context
// - replica *meta.Replica
func (_e *MockBalancer_Expecter) BalanceReplica(replica interface{}) *MockBalancer_BalanceReplica_Call {
return &MockBalancer_BalanceReplica_Call{Call: _e.mock.On("BalanceReplica", replica)}
func (_e *MockBalancer_Expecter) BalanceReplica(ctx interface{}, replica interface{}) *MockBalancer_BalanceReplica_Call {
return &MockBalancer_BalanceReplica_Call{Call: _e.mock.On("BalanceReplica", ctx, replica)}
}
func (_c *MockBalancer_BalanceReplica_Call) Run(run func(replica *meta.Replica)) *MockBalancer_BalanceReplica_Call {
func (_c *MockBalancer_BalanceReplica_Call) Run(run func(ctx context.Context, replica *meta.Replica)) *MockBalancer_BalanceReplica_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(*meta.Replica))
run(args[0].(context.Context), args[1].(*meta.Replica))
})
return _c
}
@ -176,7 +181,7 @@ func (_c *MockBalancer_BalanceReplica_Call) Return(_a0 []SegmentAssignPlan, _a1
return _c
}
func (_c *MockBalancer_BalanceReplica_Call) RunAndReturn(run func(*meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan)) *MockBalancer_BalanceReplica_Call {
func (_c *MockBalancer_BalanceReplica_Call) RunAndReturn(run func(context.Context, *meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan)) *MockBalancer_BalanceReplica_Call {
_c.Call.Return(run)
return _c
}
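Tests that stub the balancer see the same shift: the regenerated MockBalancer expects a leading context argument, and `mock.Anything` is the usual stand-in. A small sketch, constructing the mock with `new(...)` to stay self-contained (a generated constructor, where present, behaves the same way):

```go
// Illustrative expectations against the regenerated MockBalancer.
b := new(MockBalancer)
b.EXPECT().
	AssignChannel(mock.Anything, mock.Anything, mock.Anything, mock.Anything).
	Return(nil)
b.EXPECT().
	BalanceReplica(mock.Anything, mock.Anything).
	Return(nil, nil)

// Both calls satisfy the expectations above.
_ = b.AssignChannel(context.Background(), nil, nil, false)
_, _ = b.BalanceReplica(context.Background(), nil)
```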

View File

@ -1,6 +1,7 @@
package balance
import (
"context"
"fmt"
"math"
"math/rand"
@ -468,7 +469,7 @@ type MultiTargetBalancer struct {
targetMgr meta.TargetManagerInterface
}
func (b *MultiTargetBalancer) BalanceReplica(replica *meta.Replica) (segmentPlans []SegmentAssignPlan, channelPlans []ChannelAssignPlan) {
func (b *MultiTargetBalancer) BalanceReplica(ctx context.Context, replica *meta.Replica) (segmentPlans []SegmentAssignPlan, channelPlans []ChannelAssignPlan) {
log := log.With(
zap.Int64("collection", replica.GetCollectionID()),
zap.Int64("replica id", replica.GetID()),
@ -510,32 +511,32 @@ func (b *MultiTargetBalancer) BalanceReplica(replica *meta.Replica) (segmentPlan
)
// handle stopped nodes here, have to assign segments on stopping nodes to nodes with the smallest score
if b.permitBalanceChannel(replica.GetCollectionID()) {
channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, rwNodes, roNodes)...)
channelPlans = append(channelPlans, b.genStoppingChannelPlan(ctx, replica, rwNodes, roNodes)...)
}
if len(channelPlans) == 0 && b.permitBalanceSegment(replica.GetCollectionID()) {
segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(replica, rwNodes, roNodes)...)
segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(ctx, replica, rwNodes, roNodes)...)
}
} else {
if paramtable.Get().QueryCoordCfg.AutoBalanceChannel.GetAsBool() && b.permitBalanceChannel(replica.GetCollectionID()) {
channelPlans = append(channelPlans, b.genChannelPlan(br, replica, rwNodes)...)
channelPlans = append(channelPlans, b.genChannelPlan(ctx, br, replica, rwNodes)...)
}
if len(channelPlans) == 0 && b.permitBalanceSegment(replica.GetCollectionID()) {
segmentPlans = b.genSegmentPlan(replica, rwNodes)
segmentPlans = b.genSegmentPlan(ctx, replica, rwNodes)
}
}
return segmentPlans, channelPlans
}
func (b *MultiTargetBalancer) genSegmentPlan(replica *meta.Replica, rwNodes []int64) []SegmentAssignPlan {
func (b *MultiTargetBalancer) genSegmentPlan(ctx context.Context, replica *meta.Replica, rwNodes []int64) []SegmentAssignPlan {
// get segments distribution on replica level and global level
nodeSegments := make(map[int64][]*meta.Segment)
globalNodeSegments := make(map[int64][]*meta.Segment)
for _, node := range rwNodes {
dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(node))
segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool {
return b.targetMgr.CanSegmentBeMoved(segment.GetCollectionID(), segment.GetID())
return b.targetMgr.CanSegmentBeMoved(ctx, segment.GetCollectionID(), segment.GetID())
})
nodeSegments[node] = segments
globalNodeSegments[node] = b.dist.SegmentDistManager.GetByFilter(meta.WithNodeID(node))

View File

@ -42,7 +42,7 @@ type RowCountBasedBalancer struct {
// AssignSegment: when the row count based balancer assigns segments, it assigns each segment to the node with the least global row count,
// trying to keep the row count of every query node roughly the same.
func (b *RowCountBasedBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan {
func (b *RowCountBasedBalancer) AssignSegment(ctx context.Context, collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan {
// skip out suspend node and stopping node during assignment, but skip this check for manual balance
if !manualBalance {
nodes = lo.Filter(nodes, func(node int64, _ int) bool {
@ -87,7 +87,7 @@ func (b *RowCountBasedBalancer) AssignSegment(collectionID int64, segments []*me
// AssignChannel: when the row count based balancer assigns channels, it assigns each channel to the node with the least global channel count,
// trying to keep the channel count of every query node roughly the same.
func (b *RowCountBasedBalancer) AssignChannel(channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan {
func (b *RowCountBasedBalancer) AssignChannel(ctx context.Context, channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan {
// skip out suspend node and stopping node during assignment, but skip this check for manual balance
if !manualBalance {
versionRangeFilter := semver.MustParseRange(">2.3.x")
@ -167,7 +167,7 @@ func (b *RowCountBasedBalancer) convertToNodeItemsByChannel(nodeIDs []int64) []*
return ret
}
func (b *RowCountBasedBalancer) BalanceReplica(replica *meta.Replica) (segmentPlans []SegmentAssignPlan, channelPlans []ChannelAssignPlan) {
func (b *RowCountBasedBalancer) BalanceReplica(ctx context.Context, replica *meta.Replica) (segmentPlans []SegmentAssignPlan, channelPlans []ChannelAssignPlan) {
log := log.Ctx(context.TODO()).WithRateGroup("qcv2.RowCountBasedBalancer", 1, 60).With(
zap.Int64("collectionID", replica.GetCollectionID()),
zap.Int64("replicaID", replica.GetCollectionID()),
@ -206,33 +206,33 @@ func (b *RowCountBasedBalancer) BalanceReplica(replica *meta.Replica) (segmentPl
)
// handle stopped nodes here, have to assign segments on stopping nodes to nodes with the smallest score
if b.permitBalanceChannel(replica.GetCollectionID()) {
channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, rwNodes, roNodes)...)
channelPlans = append(channelPlans, b.genStoppingChannelPlan(ctx, replica, rwNodes, roNodes)...)
}
if len(channelPlans) == 0 && b.permitBalanceSegment(replica.GetCollectionID()) {
segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(replica, rwNodes, roNodes)...)
segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(ctx, replica, rwNodes, roNodes)...)
}
} else {
if paramtable.Get().QueryCoordCfg.AutoBalanceChannel.GetAsBool() && b.permitBalanceChannel(replica.GetCollectionID()) {
channelPlans = append(channelPlans, b.genChannelPlan(br, replica, rwNodes)...)
channelPlans = append(channelPlans, b.genChannelPlan(ctx, br, replica, rwNodes)...)
}
if len(channelPlans) == 0 && b.permitBalanceSegment(replica.GetCollectionID()) {
segmentPlans = append(segmentPlans, b.genSegmentPlan(replica, rwNodes)...)
segmentPlans = append(segmentPlans, b.genSegmentPlan(ctx, replica, rwNodes)...)
}
}
return segmentPlans, channelPlans
}
func (b *RowCountBasedBalancer) genStoppingSegmentPlan(replica *meta.Replica, rwNodes []int64, roNodes []int64) []SegmentAssignPlan {
func (b *RowCountBasedBalancer) genStoppingSegmentPlan(ctx context.Context, replica *meta.Replica, rwNodes []int64, roNodes []int64) []SegmentAssignPlan {
segmentPlans := make([]SegmentAssignPlan, 0)
for _, nodeID := range roNodes {
dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(nodeID))
segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool {
return b.targetMgr.CanSegmentBeMoved(segment.GetCollectionID(), segment.GetID())
return b.targetMgr.CanSegmentBeMoved(ctx, segment.GetCollectionID(), segment.GetID())
})
plans := b.AssignSegment(replica.GetCollectionID(), segments, rwNodes, false)
plans := b.AssignSegment(ctx, replica.GetCollectionID(), segments, rwNodes, false)
for i := range plans {
plans[i].From = nodeID
plans[i].Replica = replica
@ -242,7 +242,7 @@ func (b *RowCountBasedBalancer) genStoppingSegmentPlan(replica *meta.Replica, rw
return segmentPlans
}
func (b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, rwNodes []int64) []SegmentAssignPlan {
func (b *RowCountBasedBalancer) genSegmentPlan(ctx context.Context, replica *meta.Replica, rwNodes []int64) []SegmentAssignPlan {
segmentsToMove := make([]*meta.Segment, 0)
nodeRowCount := make(map[int64]int, 0)
@ -251,7 +251,7 @@ func (b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, rwNodes []
for _, node := range rwNodes {
dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(node))
segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool {
return b.targetMgr.CanSegmentBeMoved(segment.GetCollectionID(), segment.GetID())
return b.targetMgr.CanSegmentBeMoved(ctx, segment.GetCollectionID(), segment.GetID())
})
rowCount := 0
for _, s := range segments {
@ -298,7 +298,7 @@ func (b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, rwNodes []
return nil
}
segmentPlans := b.AssignSegment(replica.GetCollectionID(), segmentsToMove, nodesWithLessRow, false)
segmentPlans := b.AssignSegment(ctx, replica.GetCollectionID(), segmentsToMove, nodesWithLessRow, false)
for i := range segmentPlans {
segmentPlans[i].From = segmentPlans[i].Segment.Node
segmentPlans[i].Replica = replica
@ -307,11 +307,11 @@ func (b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, rwNodes []
return segmentPlans
}
func (b *RowCountBasedBalancer) genStoppingChannelPlan(replica *meta.Replica, rwNodes []int64, roNodes []int64) []ChannelAssignPlan {
func (b *RowCountBasedBalancer) genStoppingChannelPlan(ctx context.Context, replica *meta.Replica, rwNodes []int64, roNodes []int64) []ChannelAssignPlan {
channelPlans := make([]ChannelAssignPlan, 0)
for _, nodeID := range roNodes {
dmChannels := b.dist.ChannelDistManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithNodeID2Channel(nodeID))
plans := b.AssignChannel(dmChannels, rwNodes, false)
plans := b.AssignChannel(ctx, dmChannels, rwNodes, false)
for i := range plans {
plans[i].From = nodeID
plans[i].Replica = replica
@ -321,7 +321,7 @@ func (b *RowCountBasedBalancer) genStoppingChannelPlan(replica *meta.Replica, rw
return channelPlans
}
func (b *RowCountBasedBalancer) genChannelPlan(br *balanceReport, replica *meta.Replica, rwNodes []int64) []ChannelAssignPlan {
func (b *RowCountBasedBalancer) genChannelPlan(ctx context.Context, br *balanceReport, replica *meta.Replica, rwNodes []int64) []ChannelAssignPlan {
channelPlans := make([]ChannelAssignPlan, 0)
if len(rwNodes) > 1 {
// start to balance channels on all available nodes
@ -349,7 +349,7 @@ func (b *RowCountBasedBalancer) genChannelPlan(br *balanceReport, replica *meta.
return nil
}
channelPlans := b.AssignChannel(channelsToMove, nodeWithLessChannel, false)
channelPlans := b.AssignChannel(ctx, channelsToMove, nodeWithLessChannel, false)
for i := range channelPlans {
channelPlans[i].From = channelPlans[i].Channel.Node
channelPlans[i].Replica = replica

View File

@ -17,6 +17,7 @@
package balance
import (
"context"
"fmt"
"testing"
@ -90,6 +91,7 @@ func (suite *RowCountBasedBalancerTestSuite) TearDownTest() {
}
func (suite *RowCountBasedBalancerTestSuite) TestAssignSegment() {
ctx := context.Background()
cases := []struct {
name string
distributions map[int64][]*meta.Segment
@ -142,13 +144,14 @@ func (suite *RowCountBasedBalancerTestSuite) TestAssignSegment() {
nodeInfo.SetState(c.states[i])
suite.balancer.nodeManager.Add(nodeInfo)
}
plans := balancer.AssignSegment(0, c.assignments, c.nodes, false)
plans := balancer.AssignSegment(ctx, 0, c.assignments, c.nodes, false)
assertSegmentAssignPlanElementMatch(&suite.Suite, c.expectPlans, plans)
})
}
}
func (suite *RowCountBasedBalancerTestSuite) TestBalance() {
ctx := context.Background()
cases := []struct {
name string
nodes []int64
@ -403,13 +406,13 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalance() {
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded
collection.LoadType = querypb.LoadType_LoadCollection
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, c.nodes))
balancer.meta.CollectionManager.PutCollection(ctx, collection)
balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, c.nodes))
suite.broker.ExpectedCalls = nil
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(nil, segments, nil)
balancer.targetMgr.UpdateCollectionNextTarget(int64(1))
balancer.targetMgr.UpdateCollectionCurrentTarget(1)
balancer.targetMgr.UpdateCollectionNextTarget(ctx, int64(1))
balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, 1)
for node, s := range c.distributions {
balancer.dist.SegmentDistManager.Update(node, s...)
}
@ -427,7 +430,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalance() {
nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]])))
nodeInfo.SetState(c.states[i])
suite.balancer.nodeManager.Add(nodeInfo)
suite.balancer.meta.ResourceManager.HandleNodeUp(c.nodes[i])
suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, c.nodes[i])
}
utils.RecoverAllCollection(balancer.meta)
@ -443,7 +446,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalance() {
// clear distribution
for _, node := range c.nodes {
balancer.meta.ResourceManager.HandleNodeDown(node)
balancer.meta.ResourceManager.HandleNodeDown(ctx, node)
balancer.nodeManager.Remove(node)
balancer.dist.SegmentDistManager.Update(node)
balancer.dist.ChannelDistManager.Update(node)
@ -453,6 +456,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalance() {
}
func (suite *RowCountBasedBalancerTestSuite) TestBalanceOnPartStopping() {
ctx := context.Background()
cases := []struct {
name string
nodes []int64
@ -614,15 +618,15 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOnPartStopping() {
collection.LoadPercentage = 100
collection.LoadType = querypb.LoadType_LoadCollection
collection.Status = querypb.LoadStatus_Loaded
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, append(c.nodes, c.notExistedNodes...)))
balancer.meta.CollectionManager.PutCollection(ctx, collection)
balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, append(c.nodes, c.notExistedNodes...)))
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(nil, c.segmentInCurrent, nil)
balancer.targetMgr.UpdateCollectionNextTarget(int64(1))
balancer.targetMgr.UpdateCollectionCurrentTarget(1)
balancer.targetMgr.UpdateCollectionNextTarget(ctx, int64(1))
balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, 1)
suite.broker.ExpectedCalls = nil
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(nil, c.segmentInNext, nil)
balancer.targetMgr.UpdateCollectionNextTarget(int64(1))
balancer.targetMgr.UpdateCollectionNextTarget(ctx, int64(1))
for node, s := range c.distributions {
balancer.dist.SegmentDistManager.Update(node, s...)
}
@ -640,7 +644,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOnPartStopping() {
nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]])))
nodeInfo.SetState(c.states[i])
suite.balancer.nodeManager.Add(nodeInfo)
suite.balancer.meta.ResourceManager.HandleNodeUp(c.nodes[i])
suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, c.nodes[i])
}
utils.RecoverAllCollection(balancer.meta)
@ -652,6 +656,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOnPartStopping() {
}
func (suite *RowCountBasedBalancerTestSuite) TestBalanceOutboundNodes() {
ctx := context.Background()
cases := []struct {
name string
nodes []int64
@ -759,12 +764,12 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOutboundNodes() {
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded
collection.LoadType = querypb.LoadType_LoadCollection
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, append(c.nodes, c.notExistedNodes...)))
balancer.meta.CollectionManager.PutCollection(ctx, collection)
balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, append(c.nodes, c.notExistedNodes...)))
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(nil, segments, nil)
balancer.targetMgr.UpdateCollectionNextTarget(int64(1))
balancer.targetMgr.UpdateCollectionCurrentTarget(1)
balancer.targetMgr.UpdateCollectionNextTarget(ctx, int64(1))
balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, 1)
for node, s := range c.distributions {
balancer.dist.SegmentDistManager.Update(node, s...)
}
@ -784,8 +789,8 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOutboundNodes() {
suite.balancer.nodeManager.Add(nodeInfo)
}
// make node-3 outbound
balancer.meta.ResourceManager.HandleNodeUp(1)
balancer.meta.ResourceManager.HandleNodeUp(2)
balancer.meta.ResourceManager.HandleNodeUp(ctx, 1)
balancer.meta.ResourceManager.HandleNodeUp(ctx, 2)
utils.RecoverAllCollection(balancer.meta)
segmentPlans, channelPlans := suite.getCollectionBalancePlans(balancer, 1)
assertChannelAssignPlanElementMatch(&suite.Suite, c.expectChannelPlans, channelPlans)
@ -801,6 +806,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOutboundNodes() {
}
func (suite *RowCountBasedBalancerTestSuite) TestBalanceOnLoadingCollection() {
ctx := context.Background()
cases := []struct {
name string
nodes []int64
@ -830,8 +836,8 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOnLoadingCollection() {
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loading
collection.LoadType = querypb.LoadType_LoadCollection
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, c.nodes))
balancer.meta.CollectionManager.PutCollection(ctx, collection)
balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, c.nodes))
for node, s := range c.distributions {
balancer.dist.SegmentDistManager.Update(node, s...)
}
@ -845,10 +851,11 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOnLoadingCollection() {
func (suite *RowCountBasedBalancerTestSuite) getCollectionBalancePlans(balancer *RowCountBasedBalancer,
collectionID int64,
) ([]SegmentAssignPlan, []ChannelAssignPlan) {
replicas := balancer.meta.ReplicaManager.GetByCollection(collectionID)
ctx := context.Background()
replicas := balancer.meta.ReplicaManager.GetByCollection(ctx, collectionID)
segmentPlans, channelPlans := make([]SegmentAssignPlan, 0), make([]ChannelAssignPlan, 0)
for _, replica := range replicas {
sPlans, cPlans := balancer.BalanceReplica(replica)
sPlans, cPlans := balancer.BalanceReplica(ctx, replica)
segmentPlans = append(segmentPlans, sPlans...)
channelPlans = append(channelPlans, cPlans...)
}
@ -859,6 +866,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestAssignSegmentWithGrowing() {
suite.SetupSuite()
defer suite.TearDownTest()
balancer := suite.balancer
ctx := context.Background()
distributions := map[int64][]*meta.Segment{
1: {
@ -895,13 +903,14 @@ func (suite *RowCountBasedBalancerTestSuite) TestAssignSegmentWithGrowing() {
NumOfGrowingRows: 50,
}
suite.balancer.dist.LeaderViewManager.Update(1, leaderView)
plans := balancer.AssignSegment(1, toAssign, lo.Keys(distributions), false)
plans := balancer.AssignSegment(ctx, 1, toAssign, lo.Keys(distributions), false)
for _, p := range plans {
suite.Equal(int64(2), p.To)
}
}
func (suite *RowCountBasedBalancerTestSuite) TestDisableBalanceChannel() {
ctx := context.Background()
cases := []struct {
name string
nodes []int64
@ -989,13 +998,13 @@ func (suite *RowCountBasedBalancerTestSuite) TestDisableBalanceChannel() {
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded
collection.LoadType = querypb.LoadType_LoadCollection
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, append(c.nodes, c.notExistedNodes...)))
balancer.meta.CollectionManager.PutCollection(ctx, collection)
balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, append(c.nodes, c.notExistedNodes...)))
suite.broker.ExpectedCalls = nil
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(nil, segments, nil)
balancer.targetMgr.UpdateCollectionNextTarget(int64(1))
balancer.targetMgr.UpdateCollectionCurrentTarget(1)
balancer.targetMgr.UpdateCollectionNextTarget(ctx, int64(1))
balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, 1)
for node, s := range c.distributions {
balancer.dist.SegmentDistManager.Update(node, s...)
}
@ -1013,7 +1022,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestDisableBalanceChannel() {
nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]])))
nodeInfo.SetState(c.states[i])
suite.balancer.nodeManager.Add(nodeInfo)
suite.balancer.meta.ResourceManager.HandleNodeUp(c.nodes[i])
suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, c.nodes[i])
}
Params.Save(Params.QueryCoordCfg.AutoBalanceChannel.Key, fmt.Sprint(c.enableBalanceChannel))
@ -1039,6 +1048,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestDisableBalanceChannel() {
}
func (suite *RowCountBasedBalancerTestSuite) TestMultiReplicaBalance() {
ctx := context.Background()
cases := []struct {
name string
collectionID int64
@ -1115,13 +1125,13 @@ func (suite *RowCountBasedBalancerTestSuite) TestMultiReplicaBalance() {
suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionID).Return([]int64{c.collectionID}, nil).Maybe()
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(c.collectionID, c.collectionID))
balancer.meta.CollectionManager.PutCollection(ctx, collection)
balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(c.collectionID, c.collectionID))
for replicaID, nodes := range c.replicaWithNodes {
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(replicaID, c.collectionID, nodes))
balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(replicaID, c.collectionID, nodes))
}
balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionID)
balancer.targetMgr.UpdateCollectionNextTarget(ctx, c.collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, c.collectionID)
// 2. set up target for distribution for multi collections
for node, s := range c.segmentDist {
@ -1142,7 +1152,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestMultiReplicaBalance() {
nodeInfo.UpdateStats(session.WithChannelCnt(len(c.channelDist[nodes[i]])))
nodeInfo.SetState(c.states[i])
suite.balancer.nodeManager.Add(nodeInfo)
suite.balancer.meta.ResourceManager.HandleNodeUp(nodes[i])
suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, nodes[i])
}
}


@ -17,6 +17,7 @@
package balance
import (
"context"
"fmt"
"math"
"sort"
@ -50,7 +51,7 @@ func NewScoreBasedBalancer(scheduler task.Scheduler,
}
// AssignSegment takes a segment list and tries to assign each segment to the node with the lowest score
func (b *ScoreBasedBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan {
func (b *ScoreBasedBalancer) AssignSegment(ctx context.Context, collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan {
br := NewBalanceReport()
return b.assignSegment(br, collectionID, segments, nodes, manualBalance)
}
@ -263,7 +264,7 @@ func (b *ScoreBasedBalancer) calculateSegmentScore(s *meta.Segment) float64 {
return float64(s.GetNumOfRows()) * (1 + params.Params.QueryCoordCfg.GlobalRowCountFactor.GetAsFloat())
}
func (b *ScoreBasedBalancer) BalanceReplica(replica *meta.Replica) (segmentPlans []SegmentAssignPlan, channelPlans []ChannelAssignPlan) {
func (b *ScoreBasedBalancer) BalanceReplica(ctx context.Context, replica *meta.Replica) (segmentPlans []SegmentAssignPlan, channelPlans []ChannelAssignPlan) {
log := log.With(
zap.Int64("collection", replica.GetCollectionID()),
zap.Int64("replica id", replica.GetID()),
@ -308,32 +309,32 @@ func (b *ScoreBasedBalancer) BalanceReplica(replica *meta.Replica) (segmentPlans
br.AddRecord(StrRecordf("executing stopping balance: %v", roNodes))
// handle stopping nodes here: segments on stopping nodes have to be assigned to the nodes with the smallest score
if b.permitBalanceChannel(replica.GetCollectionID()) {
channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, rwNodes, roNodes)...)
channelPlans = append(channelPlans, b.genStoppingChannelPlan(ctx, replica, rwNodes, roNodes)...)
}
if len(channelPlans) == 0 && b.permitBalanceSegment(replica.GetCollectionID()) {
segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(replica, rwNodes, roNodes)...)
segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(ctx, replica, rwNodes, roNodes)...)
}
} else {
if paramtable.Get().QueryCoordCfg.AutoBalanceChannel.GetAsBool() && b.permitBalanceChannel(replica.GetCollectionID()) {
channelPlans = append(channelPlans, b.genChannelPlan(br, replica, rwNodes)...)
channelPlans = append(channelPlans, b.genChannelPlan(ctx, br, replica, rwNodes)...)
}
if len(channelPlans) == 0 && b.permitBalanceSegment(replica.GetCollectionID()) {
segmentPlans = append(segmentPlans, b.genSegmentPlan(br, replica, rwNodes)...)
segmentPlans = append(segmentPlans, b.genSegmentPlan(ctx, br, replica, rwNodes)...)
}
}
return segmentPlans, channelPlans
}
func (b *ScoreBasedBalancer) genStoppingSegmentPlan(replica *meta.Replica, onlineNodes []int64, offlineNodes []int64) []SegmentAssignPlan {
func (b *ScoreBasedBalancer) genStoppingSegmentPlan(ctx context.Context, replica *meta.Replica, onlineNodes []int64, offlineNodes []int64) []SegmentAssignPlan {
segmentPlans := make([]SegmentAssignPlan, 0)
for _, nodeID := range offlineNodes {
dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(nodeID))
segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool {
return b.targetMgr.CanSegmentBeMoved(segment.GetCollectionID(), segment.GetID())
return b.targetMgr.CanSegmentBeMoved(ctx, segment.GetCollectionID(), segment.GetID())
})
plans := b.AssignSegment(replica.GetCollectionID(), segments, onlineNodes, false)
plans := b.AssignSegment(ctx, replica.GetCollectionID(), segments, onlineNodes, false)
for i := range plans {
plans[i].From = nodeID
plans[i].Replica = replica
@ -343,7 +344,7 @@ func (b *ScoreBasedBalancer) genStoppingSegmentPlan(replica *meta.Replica, onlin
return segmentPlans
}
func (b *ScoreBasedBalancer) genSegmentPlan(br *balanceReport, replica *meta.Replica, onlineNodes []int64) []SegmentAssignPlan {
func (b *ScoreBasedBalancer) genSegmentPlan(ctx context.Context, br *balanceReport, replica *meta.Replica, onlineNodes []int64) []SegmentAssignPlan {
segmentDist := make(map[int64][]*meta.Segment)
nodeItemsMap := b.convertToNodeItems(br, replica.GetCollectionID(), onlineNodes)
if len(nodeItemsMap) == 0 {
@ -359,7 +360,7 @@ func (b *ScoreBasedBalancer) genSegmentPlan(br *balanceReport, replica *meta.Rep
for _, node := range onlineNodes {
dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(node))
segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool {
return b.targetMgr.CanSegmentBeMoved(segment.GetCollectionID(), segment.GetID())
return b.targetMgr.CanSegmentBeMoved(ctx, segment.GetCollectionID(), segment.GetID())
})
segmentDist[node] = segments
}
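
Taken together, the hunks above settle the balancer contract after this refactor: every balancer entry point now takes a context.Context as its first parameter. The snippet below is an illustrative reconstruction pieced together from those signatures, not code copied from the commit; the ctxAwareBalancer name and the example package are inventions for this sketch, and the real Balance interface may carry additional methods.

// Illustrative reconstruction only; see the note above.
package example

import (
	"context"

	"github.com/milvus-io/milvus/internal/querycoordv2/balance"
	"github.com/milvus-io/milvus/internal/querycoordv2/meta"
)

// ctxAwareBalancer mirrors the three balancer methods touched in this diff,
// each with ctx threaded in as the leading argument.
type ctxAwareBalancer interface {
	AssignSegment(ctx context.Context, collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []balance.SegmentAssignPlan
	AssignChannel(ctx context.Context, channels []*meta.DmChannel, nodes []int64, manualBalance bool) []balance.ChannelAssignPlan
	BalanceReplica(ctx context.Context, replica *meta.Replica) ([]balance.SegmentAssignPlan, []balance.ChannelAssignPlan)
}

The mockery-generated expectations in the test files change in lockstep: one extra mock.Anything per EXPECT() call and a leading ctx parameter in each RunAndReturn callback.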


@ -16,6 +16,7 @@
package balance
import (
"context"
"testing"
"github.com/samber/lo"
@ -85,6 +86,7 @@ func (suite *ScoreBasedBalancerTestSuite) TearDownTest() {
}
func (suite *ScoreBasedBalancerTestSuite) TestAssignSegment() {
ctx := context.Background()
cases := []struct {
name string
comment string
@ -240,7 +242,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestAssignSegment() {
suite.balancer.nodeManager.Add(nodeInfo)
}
for i := range c.collectionIDs {
plans := balancer.AssignSegment(c.collectionIDs[i], c.assignments[i], c.nodes, false)
plans := balancer.AssignSegment(ctx, c.collectionIDs[i], c.assignments[i], c.nodes, false)
if c.unstableAssignment {
suite.Len(plans, len(c.expectPlans[i]))
} else {
@ -255,9 +257,10 @@ func (suite *ScoreBasedBalancerTestSuite) TestAssignSegmentWithGrowing() {
suite.SetupSuite()
defer suite.TearDownTest()
balancer := suite.balancer
ctx := context.Background()
paramtable.Get().Save(paramtable.Get().QueryCoordCfg.DelegatorMemoryOverloadFactor.Key, "0.3")
suite.balancer.meta.PutCollection(&meta.Collection{
suite.balancer.meta.PutCollection(ctx, &meta.Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: 1,
},
@ -300,13 +303,14 @@ func (suite *ScoreBasedBalancerTestSuite) TestAssignSegmentWithGrowing() {
CollectionID: 1,
}
suite.balancer.dist.LeaderViewManager.Update(1, leaderView)
plans := balancer.AssignSegment(1, toAssign, lo.Keys(distributions), false)
plans := balancer.AssignSegment(ctx, 1, toAssign, lo.Keys(distributions), false)
for _, p := range plans {
suite.Equal(int64(2), p.To)
}
}
func (suite *ScoreBasedBalancerTestSuite) TestBalanceOneRound() {
ctx := context.Background()
cases := []struct {
name string
nodes []int64
@ -377,11 +381,11 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceOneRound() {
suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionID).Return([]int64{c.collectionID}, nil).Maybe()
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(c.collectionID, c.collectionID))
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes))
balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionID)
balancer.meta.CollectionManager.PutCollection(ctx, collection)
balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(c.collectionID, c.collectionID))
balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes))
balancer.targetMgr.UpdateCollectionNextTarget(ctx, c.collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, c.collectionID)
// 2. set up target for distribution for multi collections
for node, s := range c.distributions {
@ -401,7 +405,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceOneRound() {
nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]])))
nodeInfo.SetState(c.states[i])
suite.balancer.nodeManager.Add(nodeInfo)
suite.balancer.meta.ResourceManager.HandleNodeUp(c.nodes[i])
suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, c.nodes[i])
}
utils.RecoverAllCollection(balancer.meta)
@ -414,6 +418,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceOneRound() {
}
func (suite *ScoreBasedBalancerTestSuite) TestDelegatorPreserveMemory() {
ctx := context.Background()
cases := []struct {
name string
nodes []int64
@ -463,12 +468,12 @@ func (suite *ScoreBasedBalancerTestSuite) TestDelegatorPreserveMemory() {
suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionID).Return([]int64{c.collectionID}, nil).Maybe()
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(c.collectionID, c.collectionID))
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes))
balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionID)
balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID)
balancer.meta.CollectionManager.PutCollection(ctx, collection)
balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(c.collectionID, c.collectionID))
balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes))
balancer.targetMgr.UpdateCollectionNextTarget(ctx, c.collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, c.collectionID)
balancer.targetMgr.UpdateCollectionNextTarget(ctx, c.collectionID)
// 2. set up target for distribution for multi collections
for node, s := range c.distributions {
@ -494,7 +499,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestDelegatorPreserveMemory() {
nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]])))
nodeInfo.SetState(c.states[i])
suite.balancer.nodeManager.Add(nodeInfo)
suite.balancer.meta.ResourceManager.HandleNodeUp(c.nodes[i])
suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, c.nodes[i])
}
utils.RecoverAllCollection(balancer.meta)
@ -520,6 +525,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestDelegatorPreserveMemory() {
}
func (suite *ScoreBasedBalancerTestSuite) TestBalanceWithExecutingTask() {
ctx := context.Background()
cases := []struct {
name string
nodes []int64
@ -572,11 +578,11 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceWithExecutingTask() {
suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionID).Return([]int64{c.collectionID}, nil).Maybe()
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(c.collectionID, c.collectionID))
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes))
balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionID)
balancer.meta.CollectionManager.PutCollection(ctx, collection)
balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(c.collectionID, c.collectionID))
balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes))
balancer.targetMgr.UpdateCollectionNextTarget(ctx, c.collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, c.collectionID)
// 2. set up target for distribution for multi collections
for node, s := range c.distributions {
@ -596,7 +602,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceWithExecutingTask() {
nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]])))
nodeInfo.SetState(c.states[i])
suite.balancer.nodeManager.Add(nodeInfo)
suite.balancer.meta.ResourceManager.HandleNodeUp(c.nodes[i])
suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, c.nodes[i])
}
utils.RecoverAllCollection(balancer.meta)
@ -618,6 +624,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceWithExecutingTask() {
}
func (suite *ScoreBasedBalancerTestSuite) TestBalanceMultiRound() {
ctx := context.Background()
balanceCase := struct {
name string
nodes []int64
@ -695,12 +702,12 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceMultiRound() {
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded
collection.LoadType = querypb.LoadType_LoadCollection
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(balanceCase.collectionIDs[i], balanceCase.collectionIDs[i]))
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(balanceCase.replicaIDs[i], balanceCase.collectionIDs[i],
balancer.meta.CollectionManager.PutCollection(ctx, collection)
balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(balanceCase.collectionIDs[i], balanceCase.collectionIDs[i]))
balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(balanceCase.replicaIDs[i], balanceCase.collectionIDs[i],
append(balanceCase.nodes, balanceCase.notExistedNodes...)))
balancer.targetMgr.UpdateCollectionNextTarget(balanceCase.collectionIDs[i])
balancer.targetMgr.UpdateCollectionCurrentTarget(balanceCase.collectionIDs[i])
balancer.targetMgr.UpdateCollectionNextTarget(ctx, balanceCase.collectionIDs[i])
balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, balanceCase.collectionIDs[i])
}
// 2. set up target for distribution for multi collections
@ -717,7 +724,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceMultiRound() {
})
nodeInfo.SetState(balanceCase.states[i])
suite.balancer.nodeManager.Add(nodeInfo)
suite.balancer.meta.ResourceManager.HandleNodeUp(balanceCase.nodes[i])
suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, balanceCase.nodes[i])
}
// 4. first round balance
@ -735,6 +742,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceMultiRound() {
}
func (suite *ScoreBasedBalancerTestSuite) TestStoppedBalance() {
ctx := context.Background()
cases := []struct {
name string
nodes []int64
@ -838,11 +846,11 @@ func (suite *ScoreBasedBalancerTestSuite) TestStoppedBalance() {
suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionID).Return([]int64{c.collectionID}, nil).Maybe()
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(c.collectionID, c.collectionID))
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes))
balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionID)
balancer.meta.CollectionManager.PutCollection(ctx, collection)
balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(c.collectionID, c.collectionID))
balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes))
balancer.targetMgr.UpdateCollectionNextTarget(ctx, c.collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, c.collectionID)
// 2. set up target for distribution for multi collections
for node, s := range c.distributions {
@ -862,11 +870,11 @@ func (suite *ScoreBasedBalancerTestSuite) TestStoppedBalance() {
nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]])))
nodeInfo.SetState(c.states[i])
suite.balancer.nodeManager.Add(nodeInfo)
suite.balancer.meta.ResourceManager.HandleNodeUp(c.nodes[i])
suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, c.nodes[i])
}
for i := range c.outBoundNodes {
suite.balancer.meta.ResourceManager.HandleNodeDown(c.outBoundNodes[i])
suite.balancer.meta.ResourceManager.HandleNodeDown(ctx, c.outBoundNodes[i])
}
utils.RecoverAllCollection(balancer.meta)
@ -879,6 +887,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestStoppedBalance() {
}
func (suite *ScoreBasedBalancerTestSuite) TestMultiReplicaBalance() {
ctx := context.Background()
cases := []struct {
name string
collectionID int64
@ -955,13 +964,13 @@ func (suite *ScoreBasedBalancerTestSuite) TestMultiReplicaBalance() {
suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionID).Return([]int64{c.collectionID}, nil).Maybe()
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(c.collectionID, c.collectionID))
balancer.meta.CollectionManager.PutCollection(ctx, collection)
balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(c.collectionID, c.collectionID))
for replicaID, nodes := range c.replicaWithNodes {
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(replicaID, c.collectionID, nodes))
balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(replicaID, c.collectionID, nodes))
}
balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionID)
balancer.targetMgr.UpdateCollectionNextTarget(ctx, c.collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, c.collectionID)
// 2. set up target for distribution for multi collections
for node, s := range c.segmentDist {
@ -982,7 +991,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestMultiReplicaBalance() {
nodeInfo.UpdateStats(session.WithChannelCnt(len(c.channelDist[nodes[i]])))
nodeInfo.SetState(c.states[i])
suite.balancer.nodeManager.Add(nodeInfo)
suite.balancer.meta.ResourceManager.HandleNodeUp(nodes[i])
suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, nodes[i])
}
}
@ -1006,6 +1015,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestMultiReplicaBalance() {
}
func (suite *ScoreBasedBalancerTestSuite) TestQNMemoryCapacity() {
ctx := context.Background()
cases := []struct {
name string
nodes []int64
@ -1054,12 +1064,12 @@ func (suite *ScoreBasedBalancerTestSuite) TestQNMemoryCapacity() {
suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionID).Return([]int64{c.collectionID}, nil).Maybe()
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(c.collectionID, c.collectionID))
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes))
balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionID)
balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID)
balancer.meta.CollectionManager.PutCollection(ctx, collection)
balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(c.collectionID, c.collectionID))
balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes))
balancer.targetMgr.UpdateCollectionNextTarget(ctx, c.collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, c.collectionID)
balancer.targetMgr.UpdateCollectionNextTarget(ctx, c.collectionID)
// 2. set up target for distribution for multi collections
for node, s := range c.distributions {
@ -1081,7 +1091,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestQNMemoryCapacity() {
nodeInfo.SetState(c.states[i])
nodeInfoMap[c.nodes[i]] = nodeInfo
suite.balancer.nodeManager.Add(nodeInfo)
suite.balancer.meta.ResourceManager.HandleNodeUp(c.nodes[i])
suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, c.nodes[i])
}
utils.RecoverAllCollection(balancer.meta)
@ -1113,10 +1123,11 @@ func TestScoreBasedBalancerSuite(t *testing.T) {
func (suite *ScoreBasedBalancerTestSuite) getCollectionBalancePlans(balancer *ScoreBasedBalancer,
collectionID int64,
) ([]SegmentAssignPlan, []ChannelAssignPlan) {
replicas := balancer.meta.ReplicaManager.GetByCollection(collectionID)
ctx := context.Background()
replicas := balancer.meta.ReplicaManager.GetByCollection(ctx, collectionID)
segmentPlans, channelPlans := make([]SegmentAssignPlan, 0), make([]ChannelAssignPlan, 0)
for _, replica := range replicas {
sPlans, cPlans := balancer.BalanceReplica(replica)
sPlans, cPlans := balancer.BalanceReplica(ctx, replica)
segmentPlans = append(segmentPlans, sPlans...)
channelPlans = append(channelPlans, cPlans...)
}
@ -1124,6 +1135,7 @@ func (suite *ScoreBasedBalancerTestSuite) getCollectionBalancePlans(balancer *Sc
}
func (suite *ScoreBasedBalancerTestSuite) TestBalanceSegmentAndChannel() {
ctx := context.Background()
nodes := []int64{1, 2, 3}
collectionID := int64(1)
replicaID := int64(1)
@ -1140,11 +1152,11 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceSegmentAndChannel() {
suite.broker.EXPECT().GetPartitions(mock.Anything, collectionID).Return([]int64{collectionID}, nil).Maybe()
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(collectionID, collectionID))
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(replicaID, collectionID, nodes))
balancer.targetMgr.UpdateCollectionNextTarget(collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(collectionID)
balancer.meta.CollectionManager.PutCollection(ctx, collection)
balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(collectionID, collectionID))
balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(replicaID, collectionID, nodes))
balancer.targetMgr.UpdateCollectionNextTarget(ctx, collectionID)
balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, collectionID)
for i := range nodes {
nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{
@ -1155,7 +1167,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceSegmentAndChannel() {
})
nodeInfo.SetState(states[i])
suite.balancer.nodeManager.Add(nodeInfo)
suite.balancer.meta.ResourceManager.HandleNodeUp(nodes[i])
suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, nodes[i])
}
utils.RecoverAllCollection(balancer.meta)


@ -73,19 +73,19 @@ func (b *BalanceChecker) Description() string {
return "BalanceChecker checks the cluster distribution and generates balance tasks"
}
func (b *BalanceChecker) readyToCheck(collectionID int64) bool {
metaExist := (b.meta.GetCollection(collectionID) != nil)
targetExist := b.targetMgr.IsNextTargetExist(collectionID) || b.targetMgr.IsCurrentTargetExist(collectionID, common.AllPartitionsID)
func (b *BalanceChecker) readyToCheck(ctx context.Context, collectionID int64) bool {
metaExist := (b.meta.GetCollection(ctx, collectionID) != nil)
targetExist := b.targetMgr.IsNextTargetExist(ctx, collectionID) || b.targetMgr.IsCurrentTargetExist(ctx, collectionID, common.AllPartitionsID)
return metaExist && targetExist
}
func (b *BalanceChecker) replicasToBalance() []int64 {
ids := b.meta.GetAll()
func (b *BalanceChecker) replicasToBalance(ctx context.Context) []int64 {
ids := b.meta.GetAll(ctx)
// all replicas belonging to loading collections will be skipped
loadedCollections := lo.Filter(ids, func(cid int64, _ int) bool {
collection := b.meta.GetCollection(cid)
collection := b.meta.GetCollection(ctx, cid)
return collection != nil && collection.GetStatus() == querypb.LoadStatus_Loaded
})
sort.Slice(loadedCollections, func(i, j int) bool {
@ -97,10 +97,10 @@ func (b *BalanceChecker) replicasToBalance() []int64 {
stoppingReplicas := make([]int64, 0)
for _, cid := range loadedCollections {
// if the target and meta aren't ready, skip balancing this collection
if !b.readyToCheck(cid) {
if !b.readyToCheck(ctx, cid) {
continue
}
replicas := b.meta.ReplicaManager.GetByCollection(cid)
replicas := b.meta.ReplicaManager.GetByCollection(ctx, cid)
for _, replica := range replicas {
if replica.RONodesCount() > 0 {
stoppingReplicas = append(stoppingReplicas, replica.GetID())
@ -130,7 +130,7 @@ func (b *BalanceChecker) replicasToBalance() []int64 {
}
hasUnbalancedCollection = true
b.normalBalanceCollectionsCurrentRound.Insert(cid)
for _, replica := range b.meta.ReplicaManager.GetByCollection(cid) {
for _, replica := range b.meta.ReplicaManager.GetByCollection(ctx, cid) {
normalReplicasToBalance = append(normalReplicasToBalance, replica.GetID())
}
break
@ -144,14 +144,14 @@ func (b *BalanceChecker) replicasToBalance() []int64 {
return normalReplicasToBalance
}
func (b *BalanceChecker) balanceReplicas(replicaIDs []int64) ([]balance.SegmentAssignPlan, []balance.ChannelAssignPlan) {
func (b *BalanceChecker) balanceReplicas(ctx context.Context, replicaIDs []int64) ([]balance.SegmentAssignPlan, []balance.ChannelAssignPlan) {
segmentPlans, channelPlans := make([]balance.SegmentAssignPlan, 0), make([]balance.ChannelAssignPlan, 0)
for _, rid := range replicaIDs {
replica := b.meta.ReplicaManager.Get(rid)
replica := b.meta.ReplicaManager.Get(ctx, rid)
if replica == nil {
continue
}
sPlans, cPlans := b.getBalancerFunc().BalanceReplica(replica)
sPlans, cPlans := b.getBalancerFunc().BalanceReplica(ctx, replica)
segmentPlans = append(segmentPlans, sPlans...)
channelPlans = append(channelPlans, cPlans...)
if len(segmentPlans) != 0 || len(channelPlans) != 0 {
@ -164,12 +164,12 @@ func (b *BalanceChecker) balanceReplicas(replicaIDs []int64) ([]balance.SegmentA
func (b *BalanceChecker) Check(ctx context.Context) []task.Task {
ret := make([]task.Task, 0)
replicasToBalance := b.replicasToBalance()
segmentPlans, channelPlans := b.balanceReplicas(replicasToBalance)
replicasToBalance := b.replicasToBalance(ctx)
segmentPlans, channelPlans := b.balanceReplicas(ctx, replicasToBalance)
// iterate over all collections to find one to balance
for len(segmentPlans) == 0 && len(channelPlans) == 0 && b.normalBalanceCollectionsCurrentRound.Len() > 0 {
replicasToBalance := b.replicasToBalance()
segmentPlans, channelPlans = b.balanceReplicas(replicasToBalance)
replicasToBalance := b.replicasToBalance(ctx)
segmentPlans, channelPlans = b.balanceReplicas(ctx, replicasToBalance)
}
tasks := balance.CreateSegmentTasksFromPlans(ctx, b.ID(), Params.QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond), segmentPlans)
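
With Check, replicasToBalance, and balanceReplicas all sharing the same ctx, whatever deadline, cancellation, or tracing metadata the caller supplies now reaches the meta and target-manager lookups of an entire balance round. Below is a minimal caller-side sketch, assuming it lives inside the milvus module (so the internal packages are importable) and that the *BalanceChecker has already been constructed elsewhere; the function name and the three-second timeout are illustrative, not taken from this commit.

// Minimal sketch of driving the refactored checker with a bounded context.
package example

import (
	"context"
	"time"

	"github.com/milvus-io/milvus/internal/querycoordv2/checkers"
	"github.com/milvus-io/milvus/internal/querycoordv2/task"
)

func runBalanceRound(b *checkers.BalanceChecker) []task.Task {
	// The ctx handed to Check is forwarded into replicasToBalance and
	// balanceReplicas, so deadline, cancellation, and tracing metadata
	// reach the meta and target-manager calls they make.
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	return b.Check(ctx)
}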


@ -86,6 +86,7 @@ func (suite *BalanceCheckerTestSuite) TearDownTest() {
}
func (suite *BalanceCheckerTestSuite) TestAutoBalanceConf() {
ctx := context.Background()
// set up nodes info
nodeID1, nodeID2 := 1, 2
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
@ -98,8 +99,8 @@ func (suite *BalanceCheckerTestSuite) TestAutoBalanceConf() {
Address: "localhost",
Hostname: "localhost",
}))
suite.checker.meta.ResourceManager.HandleNodeUp(int64(nodeID1))
suite.checker.meta.ResourceManager.HandleNodeUp(int64(nodeID2))
suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID1))
suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID2))
// set collections meta
segments := []*datapb.SegmentInfo{
@ -123,46 +124,47 @@ func (suite *BalanceCheckerTestSuite) TestAutoBalanceConf() {
collection1.Status = querypb.LoadStatus_Loaded
replica1 := utils.CreateTestReplica(int64(replicaID1), int64(cid1), []int64{int64(nodeID1), int64(nodeID2)})
partition1 := utils.CreateTestPartition(int64(cid1), int64(partitionID1))
suite.checker.meta.CollectionManager.PutCollection(collection1, partition1)
suite.checker.meta.ReplicaManager.Put(replica1)
suite.targetMgr.UpdateCollectionNextTarget(int64(cid1))
suite.targetMgr.UpdateCollectionCurrentTarget(int64(cid1))
suite.checker.meta.CollectionManager.PutCollection(ctx, collection1, partition1)
suite.checker.meta.ReplicaManager.Put(ctx, replica1)
suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid1))
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid1))
cid2, replicaID2, partitionID2 := 2, 2, 2
collection2 := utils.CreateTestCollection(int64(cid2), int32(replicaID2))
collection2.Status = querypb.LoadStatus_Loaded
replica2 := utils.CreateTestReplica(int64(replicaID2), int64(cid2), []int64{int64(nodeID1), int64(nodeID2)})
partition2 := utils.CreateTestPartition(int64(cid2), int64(partitionID2))
suite.checker.meta.CollectionManager.PutCollection(collection2, partition2)
suite.checker.meta.ReplicaManager.Put(replica2)
suite.targetMgr.UpdateCollectionNextTarget(int64(cid2))
suite.targetMgr.UpdateCollectionCurrentTarget(int64(cid2))
suite.checker.meta.CollectionManager.PutCollection(ctx, collection2, partition2)
suite.checker.meta.ReplicaManager.Put(ctx, replica2)
suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid2))
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid2))
// test disable auto balance
paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "false")
suite.scheduler.EXPECT().GetSegmentTaskNum().Maybe().Return(func() int {
return 0
})
replicasToBalance := suite.checker.replicasToBalance()
replicasToBalance := suite.checker.replicasToBalance(ctx)
suite.Empty(replicasToBalance)
segPlans, _ := suite.checker.balanceReplicas(replicasToBalance)
segPlans, _ := suite.checker.balanceReplicas(ctx, replicasToBalance)
suite.Empty(segPlans)
// test enable auto balance
paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true")
idsToBalance := []int64{int64(replicaID1)}
replicasToBalance = suite.checker.replicasToBalance()
replicasToBalance = suite.checker.replicasToBalance(ctx)
suite.ElementsMatch(idsToBalance, replicasToBalance)
// next round
idsToBalance = []int64{int64(replicaID2)}
replicasToBalance = suite.checker.replicasToBalance()
replicasToBalance = suite.checker.replicasToBalance(ctx)
suite.ElementsMatch(idsToBalance, replicasToBalance)
// final round
replicasToBalance = suite.checker.replicasToBalance()
replicasToBalance = suite.checker.replicasToBalance(ctx)
suite.Empty(replicasToBalance)
}
func (suite *BalanceCheckerTestSuite) TestBusyScheduler() {
ctx := context.Background()
// set up nodes info
nodeID1, nodeID2 := 1, 2
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
@ -175,8 +177,8 @@ func (suite *BalanceCheckerTestSuite) TestBusyScheduler() {
Address: "localhost",
Hostname: "localhost",
}))
suite.checker.meta.ResourceManager.HandleNodeUp(int64(nodeID1))
suite.checker.meta.ResourceManager.HandleNodeUp(int64(nodeID2))
suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID1))
suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID2))
segments := []*datapb.SegmentInfo{
{
@ -199,31 +201,32 @@ func (suite *BalanceCheckerTestSuite) TestBusyScheduler() {
collection1.Status = querypb.LoadStatus_Loaded
replica1 := utils.CreateTestReplica(int64(replicaID1), int64(cid1), []int64{int64(nodeID1), int64(nodeID2)})
partition1 := utils.CreateTestPartition(int64(cid1), int64(partitionID1))
suite.checker.meta.CollectionManager.PutCollection(collection1, partition1)
suite.checker.meta.ReplicaManager.Put(replica1)
suite.targetMgr.UpdateCollectionNextTarget(int64(cid1))
suite.targetMgr.UpdateCollectionCurrentTarget(int64(cid1))
suite.checker.meta.CollectionManager.PutCollection(ctx, collection1, partition1)
suite.checker.meta.ReplicaManager.Put(ctx, replica1)
suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid1))
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid1))
cid2, replicaID2, partitionID2 := 2, 2, 2
collection2 := utils.CreateTestCollection(int64(cid2), int32(replicaID2))
collection2.Status = querypb.LoadStatus_Loaded
replica2 := utils.CreateTestReplica(int64(replicaID2), int64(cid2), []int64{int64(nodeID1), int64(nodeID2)})
partition2 := utils.CreateTestPartition(int64(cid2), int64(partitionID2))
suite.checker.meta.CollectionManager.PutCollection(collection2, partition2)
suite.checker.meta.ReplicaManager.Put(replica2)
suite.targetMgr.UpdateCollectionNextTarget(int64(cid2))
suite.targetMgr.UpdateCollectionCurrentTarget(int64(cid2))
suite.checker.meta.CollectionManager.PutCollection(ctx, collection2, partition2)
suite.checker.meta.ReplicaManager.Put(ctx, replica2)
suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid2))
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid2))
// test scheduler busy
paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true")
suite.scheduler.EXPECT().GetSegmentTaskNum().Maybe().Return(func() int {
return 1
})
replicasToBalance := suite.checker.replicasToBalance()
replicasToBalance := suite.checker.replicasToBalance(ctx)
suite.Len(replicasToBalance, 1)
}
func (suite *BalanceCheckerTestSuite) TestStoppingBalance() {
ctx := context.Background()
// set up nodes info, stopping node1
nodeID1, nodeID2 := 1, 2
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
@ -237,8 +240,8 @@ func (suite *BalanceCheckerTestSuite) TestStoppingBalance() {
Hostname: "localhost",
}))
suite.nodeMgr.Stopping(int64(nodeID1))
suite.checker.meta.ResourceManager.HandleNodeUp(int64(nodeID1))
suite.checker.meta.ResourceManager.HandleNodeUp(int64(nodeID2))
suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID1))
suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID2))
segments := []*datapb.SegmentInfo{
{
@ -261,32 +264,32 @@ func (suite *BalanceCheckerTestSuite) TestStoppingBalance() {
collection1.Status = querypb.LoadStatus_Loaded
replica1 := utils.CreateTestReplica(int64(replicaID1), int64(cid1), []int64{int64(nodeID1), int64(nodeID2)})
partition1 := utils.CreateTestPartition(int64(cid1), int64(partitionID1))
suite.checker.meta.CollectionManager.PutCollection(collection1, partition1)
suite.checker.meta.ReplicaManager.Put(replica1)
suite.targetMgr.UpdateCollectionNextTarget(int64(cid1))
suite.targetMgr.UpdateCollectionCurrentTarget(int64(cid1))
suite.checker.meta.CollectionManager.PutCollection(ctx, collection1, partition1)
suite.checker.meta.ReplicaManager.Put(ctx, replica1)
suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid1))
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid1))
cid2, replicaID2, partitionID2 := 2, 2, 2
collection2 := utils.CreateTestCollection(int64(cid2), int32(replicaID2))
collection2.Status = querypb.LoadStatus_Loaded
replica2 := utils.CreateTestReplica(int64(replicaID2), int64(cid2), []int64{int64(nodeID1), int64(nodeID2)})
partition2 := utils.CreateTestPartition(int64(cid2), int64(partitionID2))
suite.checker.meta.CollectionManager.PutCollection(collection2, partition2)
suite.checker.meta.ReplicaManager.Put(replica2)
suite.targetMgr.UpdateCollectionNextTarget(int64(cid2))
suite.targetMgr.UpdateCollectionCurrentTarget(int64(cid2))
suite.checker.meta.CollectionManager.PutCollection(ctx, collection2, partition2)
suite.checker.meta.ReplicaManager.Put(ctx, replica2)
suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid2))
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid2))
mr1 := replica1.CopyForWrite()
mr1.AddRONode(1)
suite.checker.meta.ReplicaManager.Put(mr1.IntoReplica())
suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica())
mr2 := replica2.CopyForWrite()
mr2.AddRONode(1)
suite.checker.meta.ReplicaManager.Put(mr2.IntoReplica())
suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica())
// test stopping balance
idsToBalance := []int64{int64(replicaID1), int64(replicaID2)}
replicasToBalance := suite.checker.replicasToBalance()
replicasToBalance := suite.checker.replicasToBalance(ctx)
suite.ElementsMatch(idsToBalance, replicasToBalance)
// checker check
@ -298,12 +301,13 @@ func (suite *BalanceCheckerTestSuite) TestStoppingBalance() {
To: 2,
}
segPlans = append(segPlans, mockPlan)
suite.balancer.EXPECT().BalanceReplica(mock.Anything).Return(segPlans, chanPlans)
suite.balancer.EXPECT().BalanceReplica(mock.Anything, mock.Anything).Return(segPlans, chanPlans)
tasks := suite.checker.Check(context.TODO())
suite.Len(tasks, 2)
}
func (suite *BalanceCheckerTestSuite) TestTargetNotReady() {
ctx := context.Background()
// set up nodes info, stopping node1
nodeID1, nodeID2 := int64(1), int64(2)
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
@ -317,8 +321,8 @@ func (suite *BalanceCheckerTestSuite) TestTargetNotReady() {
Hostname: "localhost",
}))
suite.nodeMgr.Stopping(nodeID1)
suite.checker.meta.ResourceManager.HandleNodeUp(nodeID1)
suite.checker.meta.ResourceManager.HandleNodeUp(nodeID2)
suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID1)
suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID2)
segments := []*datapb.SegmentInfo{
{
@ -341,30 +345,30 @@ func (suite *BalanceCheckerTestSuite) TestTargetNotReady() {
collection1.Status = querypb.LoadStatus_Loaded
replica1 := utils.CreateTestReplica(int64(replicaID1), int64(cid1), []int64{nodeID1, nodeID2})
partition1 := utils.CreateTestPartition(int64(cid1), int64(partitionID1))
suite.checker.meta.CollectionManager.PutCollection(collection1, partition1)
suite.checker.meta.ReplicaManager.Put(replica1)
suite.targetMgr.UpdateCollectionNextTarget(int64(cid1))
suite.targetMgr.UpdateCollectionCurrentTarget(int64(cid1))
suite.checker.meta.CollectionManager.PutCollection(ctx, collection1, partition1)
suite.checker.meta.ReplicaManager.Put(ctx, replica1)
suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid1))
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid1))
cid2, replicaID2, partitionID2 := 2, 2, 2
collection2 := utils.CreateTestCollection(int64(cid2), int32(replicaID2))
collection2.Status = querypb.LoadStatus_Loaded
replica2 := utils.CreateTestReplica(int64(replicaID2), int64(cid2), []int64{nodeID1, nodeID2})
partition2 := utils.CreateTestPartition(int64(cid2), int64(partitionID2))
suite.checker.meta.CollectionManager.PutCollection(collection2, partition2)
suite.checker.meta.ReplicaManager.Put(replica2)
suite.checker.meta.CollectionManager.PutCollection(ctx, collection2, partition2)
suite.checker.meta.ReplicaManager.Put(ctx, replica2)
mr1 := replica1.CopyForWrite()
mr1.AddRONode(1)
suite.checker.meta.ReplicaManager.Put(mr1.IntoReplica())
suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica())
mr2 := replica2.CopyForWrite()
mr2.AddRONode(1)
suite.checker.meta.ReplicaManager.Put(mr2.IntoReplica())
suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica())
// test stopping balance
idsToBalance := []int64{int64(replicaID1)}
replicasToBalance := suite.checker.replicasToBalance()
replicasToBalance := suite.checker.replicasToBalance(ctx)
suite.ElementsMatch(idsToBalance, replicasToBalance)
}


@ -70,9 +70,9 @@ func (c *ChannelChecker) Description() string {
return "DmChannelChecker checks the lack of DmChannels, or some DmChannels are redundant"
}
func (c *ChannelChecker) readyToCheck(collectionID int64) bool {
metaExist := (c.meta.GetCollection(collectionID) != nil)
targetExist := c.targetMgr.IsNextTargetExist(collectionID) || c.targetMgr.IsCurrentTargetExist(collectionID, common.AllPartitionsID)
func (c *ChannelChecker) readyToCheck(ctx context.Context, collectionID int64) bool {
metaExist := (c.meta.GetCollection(ctx, collectionID) != nil)
targetExist := c.targetMgr.IsNextTargetExist(ctx, collectionID) || c.targetMgr.IsCurrentTargetExist(ctx, collectionID, common.AllPartitionsID)
return metaExist && targetExist
}
@ -81,11 +81,11 @@ func (c *ChannelChecker) Check(ctx context.Context) []task.Task {
if !c.IsActive() {
return nil
}
collectionIDs := c.meta.CollectionManager.GetAll()
collectionIDs := c.meta.CollectionManager.GetAll(ctx)
tasks := make([]task.Task, 0)
for _, cid := range collectionIDs {
if c.readyToCheck(cid) {
replicas := c.meta.ReplicaManager.GetByCollection(cid)
if c.readyToCheck(ctx, cid) {
replicas := c.meta.ReplicaManager.GetByCollection(ctx, cid)
for _, r := range replicas {
tasks = append(tasks, c.checkReplica(ctx, r)...)
}
@ -105,7 +105,7 @@ func (c *ChannelChecker) Check(ctx context.Context) []task.Task {
channelOnQN := c.dist.ChannelDistManager.GetByFilter(meta.WithNodeID2Channel(nodeID))
collectionChannels := lo.GroupBy(channelOnQN, func(ch *meta.DmChannel) int64 { return ch.CollectionID })
for collectionID, channels := range collectionChannels {
replica := c.meta.ReplicaManager.GetByCollectionAndNode(collectionID, nodeID)
replica := c.meta.ReplicaManager.GetByCollectionAndNode(ctx, collectionID, nodeID)
if replica == nil {
reduceTasks := c.createChannelReduceTasks(ctx, channels, meta.NilReplica)
task.SetReason("dirty channel exists", reduceTasks...)
@ -119,7 +119,7 @@ func (c *ChannelChecker) Check(ctx context.Context) []task.Task {
func (c *ChannelChecker) checkReplica(ctx context.Context, replica *meta.Replica) []task.Task {
ret := make([]task.Task, 0)
lacks, redundancies := c.getDmChannelDiff(replica.GetCollectionID(), replica.GetID())
lacks, redundancies := c.getDmChannelDiff(ctx, replica.GetCollectionID(), replica.GetID())
tasks := c.createChannelLoadTask(c.getTraceCtx(ctx, replica.GetCollectionID()), lacks, replica)
task.SetReason("lacks of channel", tasks...)
ret = append(ret, tasks...)
@ -139,10 +139,10 @@ func (c *ChannelChecker) checkReplica(ctx context.Context, replica *meta.Replica
}
// getDmChannelDiff gets the channel diff between target and dist
func (c *ChannelChecker) getDmChannelDiff(collectionID int64,
func (c *ChannelChecker) getDmChannelDiff(ctx context.Context, collectionID int64,
replicaID int64,
) (toLoad, toRelease []*meta.DmChannel) {
replica := c.meta.Get(replicaID)
replica := c.meta.Get(ctx, replicaID)
if replica == nil {
log.Info("replica does not exist, skip it")
return
@ -154,8 +154,8 @@ func (c *ChannelChecker) getDmChannelDiff(collectionID int64,
distMap.Insert(ch.GetChannelName())
}
nextTargetMap := c.targetMgr.GetDmChannelsByCollection(collectionID, meta.NextTarget)
currentTargetMap := c.targetMgr.GetDmChannelsByCollection(collectionID, meta.CurrentTarget)
nextTargetMap := c.targetMgr.GetDmChannelsByCollection(ctx, collectionID, meta.NextTarget)
currentTargetMap := c.targetMgr.GetDmChannelsByCollection(ctx, collectionID, meta.CurrentTarget)
// get channels which exist in dist but not in the current or next target
for _, ch := range dist {
@ -179,7 +179,7 @@ func (c *ChannelChecker) getDmChannelDiff(collectionID int64,
func (c *ChannelChecker) findRepeatedChannels(ctx context.Context, replicaID int64) []*meta.DmChannel {
log := log.Ctx(ctx).WithRateGroup("ChannelChecker.findRepeatedChannels", 1, 60)
replica := c.meta.Get(replicaID)
replica := c.meta.Get(ctx, replicaID)
ret := make([]*meta.DmChannel, 0)
if replica == nil {
@ -232,7 +232,7 @@ func (c *ChannelChecker) createChannelLoadTask(ctx context.Context, channels []*
if len(rwNodes) == 0 {
rwNodes = replica.GetRWNodes()
}
plan := c.getBalancerFunc().AssignChannel([]*meta.DmChannel{ch}, rwNodes, false)
plan := c.getBalancerFunc().AssignChannel(ctx, []*meta.DmChannel{ch}, rwNodes, false)
plans = append(plans, plan...)
}
@ -264,7 +264,7 @@ func (c *ChannelChecker) createChannelReduceTasks(ctx context.Context, channels
}
func (c *ChannelChecker) getTraceCtx(ctx context.Context, collectionID int64) context.Context {
coll := c.meta.GetCollection(collectionID)
coll := c.meta.GetCollection(ctx, collectionID)
if coll == nil || coll.LoadSpan == nil {
return ctx
}


@ -100,7 +100,7 @@ func (suite *ChannelCheckerTestSuite) setNodeAvailable(nodes ...int64) {
func (suite *ChannelCheckerTestSuite) createMockBalancer() balance.Balance {
balancer := balance.NewMockBalancer(suite.T())
balancer.EXPECT().AssignChannel(mock.Anything, mock.Anything, mock.Anything).Maybe().Return(func(channels []*meta.DmChannel, nodes []int64, _ bool) []balance.ChannelAssignPlan {
balancer.EXPECT().AssignChannel(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Maybe().Return(func(ctx context.Context, channels []*meta.DmChannel, nodes []int64, _ bool) []balance.ChannelAssignPlan {
plans := make([]balance.ChannelAssignPlan, 0, len(channels))
for i, c := range channels {
plan := balance.ChannelAssignPlan{
@ -117,16 +117,17 @@ func (suite *ChannelCheckerTestSuite) createMockBalancer() balance.Balance {
}
func (suite *ChannelCheckerTestSuite) TestLoadChannel() {
ctx := context.Background()
checker := suite.checker
checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
suite.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1}))
checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1))
suite.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1}))
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1,
Address: "localhost",
Hostname: "localhost",
}))
checker.meta.ResourceManager.HandleNodeUp(1)
checker.meta.ResourceManager.HandleNodeUp(ctx, 1)
channels := []*datapb.VchannelInfo{
{
@ -137,7 +138,7 @@ func (suite *ChannelCheckerTestSuite) TestLoadChannel() {
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, nil, nil)
checker.targetMgr.UpdateCollectionNextTarget(int64(1))
checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1))
tasks := checker.Check(context.TODO())
suite.Len(tasks, 1)
@ -151,10 +152,11 @@ func (suite *ChannelCheckerTestSuite) TestLoadChannel() {
}
func (suite *ChannelCheckerTestSuite) TestReduceChannel() {
ctx := context.Background()
checker := suite.checker
checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
checker.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1}))
checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1))
checker.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1}))
channels := []*datapb.VchannelInfo{
{
@ -164,8 +166,8 @@ func (suite *ChannelCheckerTestSuite) TestReduceChannel() {
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, nil, nil)
checker.targetMgr.UpdateCollectionNextTarget(int64(1))
checker.targetMgr.UpdateCollectionCurrentTarget(int64(1))
checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1))
checker.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(1))
checker.dist.ChannelDistManager.Update(1, utils.CreateTestChannel(1, 1, 1, "test-insert-channel1"))
checker.dist.LeaderViewManager.Update(1, &meta.LeaderView{ID: 1, Channel: "test-insert-channel1"})
@ -184,11 +186,12 @@ func (suite *ChannelCheckerTestSuite) TestReduceChannel() {
}
func (suite *ChannelCheckerTestSuite) TestRepeatedChannels() {
ctx := context.Background()
checker := suite.checker
err := checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
suite.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
err := checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1))
suite.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
suite.NoError(err)
err = checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2}))
err = checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2}))
suite.NoError(err)
segments := []*datapb.SegmentInfo{
@ -206,7 +209,7 @@ func (suite *ChannelCheckerTestSuite) TestRepeatedChannels() {
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, segments, nil)
checker.targetMgr.UpdateCollectionNextTarget(int64(1))
checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1))
checker.dist.ChannelDistManager.Update(1, utils.CreateTestChannel(1, 1, 1, "test-insert-channel"))
checker.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 2, "test-insert-channel"))
@ -228,11 +231,12 @@ func (suite *ChannelCheckerTestSuite) TestRepeatedChannels() {
}
func (suite *ChannelCheckerTestSuite) TestReleaseDirtyChannels() {
ctx := context.Background()
checker := suite.checker
err := checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
suite.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
err := checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1))
suite.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
suite.NoError(err)
err = checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1}))
err = checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1}))
suite.NoError(err)
segments := []*datapb.SegmentInfo{
@ -261,7 +265,7 @@ func (suite *ChannelCheckerTestSuite) TestReleaseDirtyChannels() {
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, segments, nil)
checker.targetMgr.UpdateCollectionNextTarget(int64(1))
checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1))
checker.dist.ChannelDistManager.Update(1, utils.CreateTestChannel(1, 1, 2, "test-insert-channel"))
checker.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 2, "test-insert-channel"))
checker.dist.LeaderViewManager.Update(1, &meta.LeaderView{ID: 1, Channel: "test-insert-channel"})


@ -17,6 +17,7 @@
package checkers
import (
"context"
"testing"
"time"
@ -85,10 +86,11 @@ func (suite *CheckerControllerSuite) SetupTest() {
}
func (suite *CheckerControllerSuite) TestBasic() {
ctx := context.Background()
// set meta
suite.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
suite.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
suite.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2}))
suite.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1))
suite.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
suite.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2}))
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1,
Address: "localhost",
@ -99,8 +101,8 @@ func (suite *CheckerControllerSuite) TestBasic() {
Address: "localhost",
Hostname: "localhost",
}))
suite.meta.ResourceManager.HandleNodeUp(1)
suite.meta.ResourceManager.HandleNodeUp(2)
suite.meta.ResourceManager.HandleNodeUp(ctx, 1)
suite.meta.ResourceManager.HandleNodeUp(ctx, 2)
// set target
channels := []*datapb.VchannelInfo{
@ -119,7 +121,7 @@ func (suite *CheckerControllerSuite) TestBasic() {
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, segments, nil)
suite.targetManager.UpdateCollectionNextTarget(int64(1))
suite.targetManager.UpdateCollectionNextTarget(ctx, int64(1))
// set dist
suite.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel"))
@ -134,11 +136,11 @@ func (suite *CheckerControllerSuite) TestBasic() {
assignSegCounter := atomic.NewInt32(0)
assingChanCounter := atomic.NewInt32(0)
suite.balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(i1 int64, s []*meta.Segment, i2 []int64, i4 bool) []balance.SegmentAssignPlan {
suite.balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, i1 int64, s []*meta.Segment, i2 []int64, i4 bool) []balance.SegmentAssignPlan {
assignSegCounter.Inc()
return nil
})
suite.balancer.EXPECT().AssignChannel(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(dc []*meta.DmChannel, i []int64, _ bool) []balance.ChannelAssignPlan {
suite.balancer.EXPECT().AssignChannel(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, dc []*meta.DmChannel, i []int64, _ bool) []balance.ChannelAssignPlan {
assingChanCounter.Inc()
return nil
})
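
The mock updates above follow directly from the interface change: once the balancer's AssignSegment and AssignChannel take ctx as their first argument, both the EXPECT() matcher list and the RunAndReturn callback gain one extra parameter. Below is a minimal, self-contained Go sketch of the same pattern using hand-rolled stand-in types; Balancer, fakeBalancer, Segment and SegmentAssignPlan here are illustrative only, not the real querycoord/balance definitions.

package main

import (
	"context"
	"fmt"
)

// Illustrative stand-ins; the real types live in the meta and balance packages.
type Segment struct{ ID int64 }

type SegmentAssignPlan struct {
	Segment *Segment
	Node    int64
}

// After the refactor the assignment method takes ctx first.
type Balancer interface {
	AssignSegment(ctx context.Context, collectionID int64, segments []*Segment, nodes []int64, forceAssign bool) []SegmentAssignPlan
}

// A hand-rolled fake used the same way as the mockery RunAndReturn callbacks
// above: its callback signature grows by the same ctx argument.
type fakeBalancer struct {
	assignFn func(ctx context.Context, collectionID int64, segments []*Segment, nodes []int64, forceAssign bool) []SegmentAssignPlan
}

func (f *fakeBalancer) AssignSegment(ctx context.Context, collectionID int64, segments []*Segment, nodes []int64, forceAssign bool) []SegmentAssignPlan {
	return f.assignFn(ctx, collectionID, segments, nodes, forceAssign)
}

func main() {
	var b Balancer = &fakeBalancer{
		assignFn: func(ctx context.Context, collectionID int64, segments []*Segment, nodes []int64, _ bool) []SegmentAssignPlan {
			// Round-robin the segments over the candidate nodes.
			plans := make([]SegmentAssignPlan, 0, len(segments))
			for i, s := range segments {
				plans = append(plans, SegmentAssignPlan{Segment: s, Node: nodes[i%len(nodes)]})
			}
			return plans
		},
	}
	plans := b.AssignSegment(context.Background(), 1, []*Segment{{ID: 10}, {ID: 11}}, []int64{1, 2}, false)
	fmt.Println("plans:", len(plans)) // plans: 2
}

The same one-extra-argument adjustment appears in every mock expectation touched by this diff.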

View File

@ -79,7 +79,7 @@ func (c *IndexChecker) Check(ctx context.Context) []task.Task {
if !c.IsActive() {
return nil
}
collectionIDs := c.meta.CollectionManager.GetAll()
collectionIDs := c.meta.CollectionManager.GetAll(ctx)
var tasks []task.Task
for _, collectionID := range collectionIDs {
@ -89,12 +89,12 @@ func (c *IndexChecker) Check(ctx context.Context) []task.Task {
continue
}
collection := c.meta.CollectionManager.GetCollection(collectionID)
collection := c.meta.CollectionManager.GetCollection(ctx, collectionID)
if collection == nil {
log.Warn("collection released during check index", zap.Int64("collection", collectionID))
continue
}
replicas := c.meta.ReplicaManager.GetByCollection(collectionID)
replicas := c.meta.ReplicaManager.GetByCollection(ctx, collectionID)
for _, replica := range replicas {
tasks = append(tasks, c.checkReplica(ctx, collection, replica, indexInfos)...)
}
@ -121,7 +121,7 @@ func (c *IndexChecker) checkReplica(ctx context.Context, collection *meta.Collec
}
// skip update index for l0 segment
segmentInTarget := c.targetMgr.GetSealedSegment(collection.GetCollectionID(), segment.GetID(), meta.CurrentTargetFirst)
segmentInTarget := c.targetMgr.GetSealedSegment(ctx, collection.GetCollectionID(), segment.GetID(), meta.CurrentTargetFirst)
if segmentInTarget == nil || segmentInTarget.GetLevel() == datapb.SegmentLevel_L0 {
continue
}

View File

@ -78,7 +78,7 @@ func (suite *IndexCheckerSuite) SetupTest() {
suite.targetMgr = meta.NewMockTargetManager(suite.T())
suite.checker = NewIndexChecker(suite.meta, distManager, suite.broker, suite.nodeMgr, suite.targetMgr)
suite.targetMgr.EXPECT().GetSealedSegment(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(cid, sid int64, i3 int32) *datapb.SegmentInfo {
suite.targetMgr.EXPECT().GetSealedSegment(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cid, sid int64, i3 int32) *datapb.SegmentInfo {
return &datapb.SegmentInfo{
ID: sid,
Level: datapb.SegmentLevel_L1,
@ -92,12 +92,13 @@ func (suite *IndexCheckerSuite) TearDownTest() {
func (suite *IndexCheckerSuite) TestLoadIndex() {
checker := suite.checker
ctx := context.Background()
// meta
coll := utils.CreateTestCollection(1, 1)
coll.FieldIndexID = map[int64]int64{101: 1000}
checker.meta.CollectionManager.PutCollection(coll)
checker.meta.ReplicaManager.Put(utils.CreateTestReplica(200, 1, []int64{1, 2}))
checker.meta.CollectionManager.PutCollection(ctx, coll)
checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(200, 1, []int64{1, 2}))
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1,
Address: "localhost",
@ -108,8 +109,8 @@ func (suite *IndexCheckerSuite) TestLoadIndex() {
Address: "localhost",
Hostname: "localhost",
}))
checker.meta.ResourceManager.HandleNodeUp(1)
checker.meta.ResourceManager.HandleNodeUp(2)
checker.meta.ResourceManager.HandleNodeUp(ctx, 1)
checker.meta.ResourceManager.HandleNodeUp(ctx, 2)
// dist
checker.dist.SegmentDistManager.Update(1, utils.CreateTestSegment(1, 1, 2, 1, 1, "test-insert-channel"))
@ -147,8 +148,8 @@ func (suite *IndexCheckerSuite) TestLoadIndex() {
// test skip load index for read only node
suite.nodeMgr.Stopping(1)
suite.nodeMgr.Stopping(2)
suite.meta.ResourceManager.HandleNodeStopping(1)
suite.meta.ResourceManager.HandleNodeStopping(2)
suite.meta.ResourceManager.HandleNodeStopping(ctx, 1)
suite.meta.ResourceManager.HandleNodeStopping(ctx, 2)
utils.RecoverAllCollection(suite.meta)
tasks = checker.Check(context.Background())
suite.Require().Len(tasks, 0)
@ -156,12 +157,13 @@ func (suite *IndexCheckerSuite) TestLoadIndex() {
func (suite *IndexCheckerSuite) TestIndexInfoNotMatch() {
checker := suite.checker
ctx := context.Background()
// meta
coll := utils.CreateTestCollection(1, 1)
coll.FieldIndexID = map[int64]int64{101: 1000}
checker.meta.CollectionManager.PutCollection(coll)
checker.meta.ReplicaManager.Put(utils.CreateTestReplica(200, 1, []int64{1, 2}))
checker.meta.CollectionManager.PutCollection(ctx, coll)
checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(200, 1, []int64{1, 2}))
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1,
Address: "localhost",
@ -172,8 +174,8 @@ func (suite *IndexCheckerSuite) TestIndexInfoNotMatch() {
Address: "localhost",
Hostname: "localhost",
}))
checker.meta.ResourceManager.HandleNodeUp(1)
checker.meta.ResourceManager.HandleNodeUp(2)
checker.meta.ResourceManager.HandleNodeUp(ctx, 1)
checker.meta.ResourceManager.HandleNodeUp(ctx, 2)
// dist
checker.dist.SegmentDistManager.Update(1, utils.CreateTestSegment(1, 1, 2, 1, 1, "test-insert-channel"))
@ -216,12 +218,13 @@ func (suite *IndexCheckerSuite) TestIndexInfoNotMatch() {
func (suite *IndexCheckerSuite) TestGetIndexInfoFailed() {
checker := suite.checker
ctx := context.Background()
// meta
coll := utils.CreateTestCollection(1, 1)
coll.FieldIndexID = map[int64]int64{101: 1000}
checker.meta.CollectionManager.PutCollection(coll)
checker.meta.ReplicaManager.Put(utils.CreateTestReplica(200, 1, []int64{1, 2}))
checker.meta.CollectionManager.PutCollection(ctx, coll)
checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(200, 1, []int64{1, 2}))
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1,
Address: "localhost",
@ -232,8 +235,8 @@ func (suite *IndexCheckerSuite) TestGetIndexInfoFailed() {
Address: "localhost",
Hostname: "localhost",
}))
checker.meta.ResourceManager.HandleNodeUp(1)
checker.meta.ResourceManager.HandleNodeUp(2)
checker.meta.ResourceManager.HandleNodeUp(ctx, 1)
checker.meta.ResourceManager.HandleNodeUp(ctx, 2)
// dist
checker.dist.SegmentDistManager.Update(1, utils.CreateTestSegment(1, 1, 2, 1, 1, "test-insert-channel"))
@ -255,12 +258,13 @@ func (suite *IndexCheckerSuite) TestGetIndexInfoFailed() {
func (suite *IndexCheckerSuite) TestCreateNewIndex() {
checker := suite.checker
ctx := context.Background()
// meta
coll := utils.CreateTestCollection(1, 1)
coll.FieldIndexID = map[int64]int64{101: 1000}
checker.meta.CollectionManager.PutCollection(coll)
checker.meta.ReplicaManager.Put(utils.CreateTestReplica(200, 1, []int64{1, 2}))
checker.meta.CollectionManager.PutCollection(ctx, coll)
checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(200, 1, []int64{1, 2}))
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1,
Address: "localhost",
@ -271,8 +275,8 @@ func (suite *IndexCheckerSuite) TestCreateNewIndex() {
Address: "localhost",
Hostname: "localhost",
}))
checker.meta.ResourceManager.HandleNodeUp(1)
checker.meta.ResourceManager.HandleNodeUp(2)
checker.meta.ResourceManager.HandleNodeUp(ctx, 1)
checker.meta.ResourceManager.HandleNodeUp(ctx, 2)
// dist
segment := utils.CreateTestSegment(1, 1, 2, 1, 1, "test-insert-channel")

View File

@ -65,9 +65,9 @@ func (c *LeaderChecker) Description() string {
return "LeaderChecker checks the difference of leader view between dist, and try to correct it"
}
func (c *LeaderChecker) readyToCheck(collectionID int64) bool {
metaExist := (c.meta.GetCollection(collectionID) != nil)
targetExist := c.target.IsNextTargetExist(collectionID) || c.target.IsCurrentTargetExist(collectionID, common.AllPartitionsID)
func (c *LeaderChecker) readyToCheck(ctx context.Context, collectionID int64) bool {
metaExist := (c.meta.GetCollection(ctx, collectionID) != nil)
targetExist := c.target.IsNextTargetExist(ctx, collectionID) || c.target.IsCurrentTargetExist(ctx, collectionID, common.AllPartitionsID)
return metaExist && targetExist
}
@ -77,20 +77,20 @@ func (c *LeaderChecker) Check(ctx context.Context) []task.Task {
return nil
}
collectionIDs := c.meta.CollectionManager.GetAll()
collectionIDs := c.meta.CollectionManager.GetAll(ctx)
tasks := make([]task.Task, 0)
for _, collectionID := range collectionIDs {
if !c.readyToCheck(collectionID) {
if !c.readyToCheck(ctx, collectionID) {
continue
}
collection := c.meta.CollectionManager.GetCollection(collectionID)
collection := c.meta.CollectionManager.GetCollection(ctx, collectionID)
if collection == nil {
log.Warn("collection released during check leader", zap.Int64("collection", collectionID))
continue
}
replicas := c.meta.ReplicaManager.GetByCollection(collectionID)
replicas := c.meta.ReplicaManager.GetByCollection(ctx, collectionID)
for _, replica := range replicas {
for _, node := range replica.GetRWNodes() {
leaderViews := c.dist.LeaderViewManager.GetByFilter(meta.WithCollectionID2LeaderView(replica.GetCollectionID()), meta.WithNodeID2LeaderView(node))
@ -109,7 +109,7 @@ func (c *LeaderChecker) Check(ctx context.Context) []task.Task {
func (c *LeaderChecker) findNeedSyncPartitionStats(ctx context.Context, replica *meta.Replica, leaderView *meta.LeaderView, nodeID int64) []task.Task {
ret := make([]task.Task, 0)
curDmlChannel := c.target.GetDmChannel(leaderView.CollectionID, leaderView.Channel, meta.CurrentTarget)
curDmlChannel := c.target.GetDmChannel(ctx, leaderView.CollectionID, leaderView.Channel, meta.CurrentTarget)
if curDmlChannel == nil {
return ret
}
@ -163,7 +163,7 @@ func (c *LeaderChecker) findNeedLoadedSegments(ctx context.Context, replica *met
latestNodeDist := utils.FindMaxVersionSegments(dist)
for _, s := range latestNodeDist {
segment := c.target.GetSealedSegment(leaderView.CollectionID, s.GetID(), meta.CurrentTargetFirst)
segment := c.target.GetSealedSegment(ctx, leaderView.CollectionID, s.GetID(), meta.CurrentTargetFirst)
existInTarget := segment != nil
isL0Segment := existInTarget && segment.GetLevel() == datapb.SegmentLevel_L0
// shouldn't set the l0 segment location to the delegator; l0 segments should be reloaded in the delegator
@ -213,7 +213,7 @@ func (c *LeaderChecker) findNeedRemovedSegments(ctx context.Context, replica *me
for sid, s := range leaderView.Segments {
_, ok := distMap[sid]
segment := c.target.GetSealedSegment(leaderView.CollectionID, sid, meta.CurrentTargetFirst)
segment := c.target.GetSealedSegment(ctx, leaderView.CollectionID, sid, meta.CurrentTargetFirst)
existInTarget := segment != nil
isL0Segment := existInTarget && segment.GetLevel() == datapb.SegmentLevel_L0
if ok || existInTarget || isL0Segment {
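
readyToCheck and the CollectionManager/ReplicaManager getters now receive the ctx that Check already holds, so request-scoped values can reach the meta layer. A rough sketch of that idea with hypothetical types (ctxKey, Collection and CollectionManager are illustrative only; the real manager does much more than a map lookup):

package main

import (
	"context"
	"fmt"
)

type ctxKey string

const traceIDKey ctxKey = "traceID"

// Hypothetical collection record, standing in for the real load info type.
type Collection struct{ ID int64 }

// A minimal manager whose read path takes ctx, mirroring
// CollectionManager.GetCollection(ctx, collectionID) above. The ctx is not
// used for cancellation here; it simply lets request-scoped values (such as a
// trace ID) reach the call site.
type CollectionManager struct {
	collections map[int64]*Collection
}

func (m *CollectionManager) GetCollection(ctx context.Context, collectionID int64) *Collection {
	if traceID, ok := ctx.Value(traceIDKey).(string); ok {
		fmt.Printf("traceID=%s get collection %d\n", traceID, collectionID)
	}
	return m.collections[collectionID]
}

func main() {
	m := &CollectionManager{collections: map[int64]*Collection{1: {ID: 1}}}
	ctx := context.WithValue(context.Background(), traceIDKey, "req-42")
	if coll := m.GetCollection(ctx, 1); coll != nil {
		fmt.Println("found collection", coll.ID)
	}
}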

View File

@ -83,10 +83,11 @@ func (suite *LeaderCheckerTestSuite) TearDownTest() {
}
func (suite *LeaderCheckerTestSuite) TestSyncLoadedSegments() {
ctx := context.Background()
observer := suite.checker
observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
observer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
observer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2}))
observer.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1))
observer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
observer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2}))
segments := []*datapb.SegmentInfo{
{
ID: 1,
@ -119,13 +120,13 @@ func (suite *LeaderCheckerTestSuite) TestSyncLoadedSegments() {
}))
// test leader view lack of segments
observer.target.UpdateCollectionNextTarget(int64(1))
observer.target.UpdateCollectionCurrentTarget(1)
observer.target.UpdateCollectionNextTarget(ctx, int64(1))
observer.target.UpdateCollectionCurrentTarget(ctx, 1)
loadVersion := time.Now().UnixMilli()
observer.dist.SegmentDistManager.Update(1, utils.CreateTestSegment(1, 1, 1, 2, loadVersion, "test-insert-channel"))
observer.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel"))
view := utils.CreateTestLeaderView(2, 1, "test-insert-channel", map[int64]int64{}, map[int64]*meta.Segment{})
view.TargetVersion = observer.target.GetCollectionTargetVersion(1, meta.CurrentTarget)
view.TargetVersion = observer.target.GetCollectionTargetVersion(ctx, 1, meta.CurrentTarget)
observer.dist.LeaderViewManager.Update(2, view)
tasks = suite.checker.Check(context.TODO())
@ -140,7 +141,7 @@ func (suite *LeaderCheckerTestSuite) TestSyncLoadedSegments() {
// test segment's version in leader view doesn't match segment's version in dist
observer.dist.SegmentDistManager.Update(1, utils.CreateTestSegment(1, 1, 1, 2, 1, "test-insert-channel"))
view = utils.CreateTestLeaderView(2, 1, "test-insert-channel", map[int64]int64{}, map[int64]*meta.Segment{})
view.TargetVersion = observer.target.GetCollectionTargetVersion(1, meta.CurrentTarget)
view.TargetVersion = observer.target.GetCollectionTargetVersion(ctx, 1, meta.CurrentTarget)
view.Segments[1] = &querypb.SegmentDist{
NodeID: 0,
Version: time.Now().UnixMilli() - 1,
@ -168,23 +169,24 @@ func (suite *LeaderCheckerTestSuite) TestSyncLoadedSegments() {
suite.broker.ExpectedCalls = nil
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, segments, nil)
observer.target.UpdateCollectionNextTarget(int64(1))
observer.target.UpdateCollectionCurrentTarget(1)
observer.target.UpdateCollectionNextTarget(ctx, int64(1))
observer.target.UpdateCollectionCurrentTarget(ctx, 1)
// mock an l0 segment that exists on a non-delegator node; it is not set to the leader view
observer.dist.SegmentDistManager.Update(1, utils.CreateTestSegment(1, 1, 1, 1, loadVersion, "test-insert-channel"))
observer.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel"))
view = utils.CreateTestLeaderView(2, 1, "test-insert-channel", map[int64]int64{}, map[int64]*meta.Segment{})
view.TargetVersion = observer.target.GetCollectionTargetVersion(1, meta.CurrentTarget)
view.TargetVersion = observer.target.GetCollectionTargetVersion(ctx, 1, meta.CurrentTarget)
observer.dist.LeaderViewManager.Update(2, view)
tasks = suite.checker.Check(context.TODO())
suite.Len(tasks, 0)
}
func (suite *LeaderCheckerTestSuite) TestActivation() {
ctx := context.Background()
observer := suite.checker
observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
observer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
observer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2}))
observer.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1))
observer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
observer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2}))
segments := []*datapb.SegmentInfo{
{
ID: 1,
@ -211,12 +213,12 @@ func (suite *LeaderCheckerTestSuite) TestActivation() {
Address: "localhost",
Hostname: "localhost",
}))
observer.target.UpdateCollectionNextTarget(int64(1))
observer.target.UpdateCollectionCurrentTarget(1)
observer.target.UpdateCollectionNextTarget(ctx, int64(1))
observer.target.UpdateCollectionCurrentTarget(ctx, 1)
observer.dist.SegmentDistManager.Update(1, utils.CreateTestSegment(1, 1, 1, 2, 1, "test-insert-channel"))
observer.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel"))
view := utils.CreateTestLeaderView(2, 1, "test-insert-channel", map[int64]int64{}, map[int64]*meta.Segment{})
view.TargetVersion = observer.target.GetCollectionTargetVersion(1, meta.CurrentTarget)
view.TargetVersion = observer.target.GetCollectionTargetVersion(ctx, 1, meta.CurrentTarget)
observer.dist.LeaderViewManager.Update(2, view)
suite.checker.Deactivate()
@ -234,11 +236,12 @@ func (suite *LeaderCheckerTestSuite) TestActivation() {
}
func (suite *LeaderCheckerTestSuite) TestStoppingNode() {
ctx := context.Background()
observer := suite.checker
observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
observer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
observer.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1))
observer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
replica := utils.CreateTestReplica(1, 1, []int64{1, 2})
observer.meta.ReplicaManager.Put(replica)
observer.meta.ReplicaManager.Put(ctx, replica)
segments := []*datapb.SegmentInfo{
{
ID: 1,
@ -254,27 +257,28 @@ func (suite *LeaderCheckerTestSuite) TestStoppingNode() {
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, segments, nil)
observer.target.UpdateCollectionNextTarget(int64(1))
observer.target.UpdateCollectionCurrentTarget(1)
observer.target.UpdateCollectionNextTarget(ctx, int64(1))
observer.target.UpdateCollectionCurrentTarget(ctx, 1)
observer.dist.SegmentDistManager.Update(1, utils.CreateTestSegment(1, 1, 1, 2, 1, "test-insert-channel"))
observer.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel"))
view := utils.CreateTestLeaderView(2, 1, "test-insert-channel", map[int64]int64{}, map[int64]*meta.Segment{})
view.TargetVersion = observer.target.GetCollectionTargetVersion(1, meta.CurrentTarget)
view.TargetVersion = observer.target.GetCollectionTargetVersion(ctx, 1, meta.CurrentTarget)
observer.dist.LeaderViewManager.Update(2, view)
mutableReplica := replica.CopyForWrite()
mutableReplica.AddRONode(2)
observer.meta.ReplicaManager.Put(mutableReplica.IntoReplica())
observer.meta.ReplicaManager.Put(ctx, mutableReplica.IntoReplica())
tasks := suite.checker.Check(context.TODO())
suite.Len(tasks, 0)
}
func (suite *LeaderCheckerTestSuite) TestIgnoreSyncLoadedSegments() {
ctx := context.Background()
observer := suite.checker
observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
observer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
observer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2}))
observer.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1))
observer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
observer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2}))
segments := []*datapb.SegmentInfo{
{
ID: 1,
@ -301,14 +305,14 @@ func (suite *LeaderCheckerTestSuite) TestIgnoreSyncLoadedSegments() {
Address: "localhost",
Hostname: "localhost",
}))
observer.target.UpdateCollectionNextTarget(int64(1))
observer.target.UpdateCollectionCurrentTarget(1)
observer.target.UpdateCollectionNextTarget(int64(1))
observer.target.UpdateCollectionNextTarget(ctx, int64(1))
observer.target.UpdateCollectionCurrentTarget(ctx, 1)
observer.target.UpdateCollectionNextTarget(ctx, int64(1))
observer.dist.SegmentDistManager.Update(1, utils.CreateTestSegment(1, 1, 1, 2, 1, "test-insert-channel"),
utils.CreateTestSegment(1, 1, 2, 2, 1, "test-insert-channel"))
observer.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel"))
view := utils.CreateTestLeaderView(2, 1, "test-insert-channel", map[int64]int64{}, map[int64]*meta.Segment{})
view.TargetVersion = observer.target.GetCollectionTargetVersion(1, meta.CurrentTarget)
view.TargetVersion = observer.target.GetCollectionTargetVersion(ctx, 1, meta.CurrentTarget)
observer.dist.LeaderViewManager.Update(2, view)
tasks := suite.checker.Check(context.TODO())
@ -322,11 +326,12 @@ func (suite *LeaderCheckerTestSuite) TestIgnoreSyncLoadedSegments() {
}
func (suite *LeaderCheckerTestSuite) TestSyncLoadedSegmentsWithReplicas() {
ctx := context.Background()
observer := suite.checker
observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 2))
observer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
observer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2}))
observer.meta.ReplicaManager.Put(utils.CreateTestReplica(2, 1, []int64{3, 4}))
observer.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 2))
observer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
observer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2}))
observer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(2, 1, []int64{3, 4}))
segments := []*datapb.SegmentInfo{
{
ID: 1,
@ -354,17 +359,17 @@ func (suite *LeaderCheckerTestSuite) TestSyncLoadedSegmentsWithReplicas() {
Hostname: "localhost",
}))
observer.target.UpdateCollectionNextTarget(int64(1))
observer.target.UpdateCollectionCurrentTarget(1)
observer.target.UpdateCollectionNextTarget(ctx, int64(1))
observer.target.UpdateCollectionCurrentTarget(ctx, 1)
observer.dist.SegmentDistManager.Update(1, utils.CreateTestSegment(1, 1, 1, 1, 0, "test-insert-channel"))
observer.dist.SegmentDistManager.Update(4, utils.CreateTestSegment(1, 1, 1, 4, 0, "test-insert-channel"))
observer.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel"))
observer.dist.ChannelDistManager.Update(4, utils.CreateTestChannel(1, 4, 2, "test-insert-channel"))
view := utils.CreateTestLeaderView(2, 1, "test-insert-channel", map[int64]int64{}, map[int64]*meta.Segment{})
view.TargetVersion = observer.target.GetCollectionTargetVersion(1, meta.CurrentTarget)
view.TargetVersion = observer.target.GetCollectionTargetVersion(ctx, 1, meta.CurrentTarget)
observer.dist.LeaderViewManager.Update(2, view)
view2 := utils.CreateTestLeaderView(4, 1, "test-insert-channel", map[int64]int64{1: 4}, map[int64]*meta.Segment{})
view.TargetVersion = observer.target.GetCollectionTargetVersion(1, meta.CurrentTarget)
view.TargetVersion = observer.target.GetCollectionTargetVersion(ctx, 1, meta.CurrentTarget)
observer.dist.LeaderViewManager.Update(4, view2)
tasks := suite.checker.Check(context.TODO())
@ -379,10 +384,11 @@ func (suite *LeaderCheckerTestSuite) TestSyncLoadedSegmentsWithReplicas() {
}
func (suite *LeaderCheckerTestSuite) TestSyncRemovedSegments() {
ctx := context.Background()
observer := suite.checker
observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
observer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
observer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2}))
observer.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1))
observer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
observer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2}))
channels := []*datapb.VchannelInfo{
{
@ -393,12 +399,12 @@ func (suite *LeaderCheckerTestSuite) TestSyncRemovedSegments() {
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, nil, nil)
observer.target.UpdateCollectionNextTarget(int64(1))
observer.target.UpdateCollectionCurrentTarget(1)
observer.target.UpdateCollectionNextTarget(ctx, int64(1))
observer.target.UpdateCollectionCurrentTarget(ctx, 1)
observer.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel"))
view := utils.CreateTestLeaderView(2, 1, "test-insert-channel", map[int64]int64{3: 1}, map[int64]*meta.Segment{})
view.TargetVersion = observer.target.GetCollectionTargetVersion(1, meta.CurrentTarget)
view.TargetVersion = observer.target.GetCollectionTargetVersion(ctx, 1, meta.CurrentTarget)
observer.dist.LeaderViewManager.Update(2, view)
tasks := suite.checker.Check(context.TODO())
@ -425,12 +431,12 @@ func (suite *LeaderCheckerTestSuite) TestSyncRemovedSegments() {
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, segments, nil)
observer.target.UpdateCollectionNextTarget(int64(1))
observer.target.UpdateCollectionCurrentTarget(1)
observer.target.UpdateCollectionNextTarget(ctx, int64(1))
observer.target.UpdateCollectionCurrentTarget(ctx, 1)
observer.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel"))
view = utils.CreateTestLeaderView(2, 1, "test-insert-channel", map[int64]int64{3: 1}, map[int64]*meta.Segment{})
view.TargetVersion = observer.target.GetCollectionTargetVersion(1, meta.CurrentTarget)
view.TargetVersion = observer.target.GetCollectionTargetVersion(ctx, 1, meta.CurrentTarget)
observer.dist.LeaderViewManager.Update(2, view)
tasks = suite.checker.Check(context.TODO())
@ -438,10 +444,11 @@ func (suite *LeaderCheckerTestSuite) TestSyncRemovedSegments() {
}
func (suite *LeaderCheckerTestSuite) TestIgnoreSyncRemovedSegments() {
ctx := context.Background()
observer := suite.checker
observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
observer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
observer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2}))
observer.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1))
observer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
observer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2}))
segments := []*datapb.SegmentInfo{
{
@ -458,7 +465,7 @@ func (suite *LeaderCheckerTestSuite) TestIgnoreSyncRemovedSegments() {
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, segments, nil)
observer.target.UpdateCollectionNextTarget(int64(1))
observer.target.UpdateCollectionNextTarget(ctx, int64(1))
observer.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel"))
observer.dist.LeaderViewManager.Update(2, utils.CreateTestLeaderView(2, 1, "test-insert-channel", map[int64]int64{3: 2, 2: 2}, map[int64]*meta.Segment{}))
@ -475,12 +482,13 @@ func (suite *LeaderCheckerTestSuite) TestIgnoreSyncRemovedSegments() {
}
func (suite *LeaderCheckerTestSuite) TestUpdatePartitionStats() {
ctx := context.Background()
testChannel := "test-insert-channel"
leaderID := int64(2)
observer := suite.checker
observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
observer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
observer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2}))
observer.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1))
observer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
observer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2}))
segments := []*datapb.SegmentInfo{
{
ID: 1,
@ -506,8 +514,8 @@ func (suite *LeaderCheckerTestSuite) TestUpdatePartitionStats() {
suite.Len(tasks, 0)
// try to update cur/next target
observer.target.UpdateCollectionNextTarget(int64(1))
observer.target.UpdateCollectionCurrentTarget(1)
observer.target.UpdateCollectionNextTarget(ctx, int64(1))
observer.target.UpdateCollectionCurrentTarget(ctx, 1)
loadVersion := time.Now().UnixMilli()
observer.dist.SegmentDistManager.Update(1, utils.CreateTestSegment(1, 1, 2, 1, loadVersion, testChannel))
observer.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, testChannel))
@ -516,7 +524,7 @@ func (suite *LeaderCheckerTestSuite) TestUpdatePartitionStats() {
1: 100,
}
// current partition stat version in leader view is version100 for partition1
view.TargetVersion = observer.target.GetCollectionTargetVersion(1, meta.CurrentTarget)
view.TargetVersion = observer.target.GetCollectionTargetVersion(ctx, 1, meta.CurrentTarget)
observer.dist.LeaderViewManager.Update(leaderID, view)
tasks = suite.checker.Check(context.TODO())

View File

@ -75,9 +75,9 @@ func (c *SegmentChecker) Description() string {
return "SegmentChecker checks the lack of segments, or some segments are redundant"
}
func (c *SegmentChecker) readyToCheck(collectionID int64) bool {
metaExist := (c.meta.GetCollection(collectionID) != nil)
targetExist := c.targetMgr.IsNextTargetExist(collectionID) || c.targetMgr.IsCurrentTargetExist(collectionID, common.AllPartitionsID)
func (c *SegmentChecker) readyToCheck(ctx context.Context, collectionID int64) bool {
metaExist := (c.meta.GetCollection(ctx, collectionID) != nil)
targetExist := c.targetMgr.IsNextTargetExist(ctx, collectionID) || c.targetMgr.IsCurrentTargetExist(ctx, collectionID, common.AllPartitionsID)
return metaExist && targetExist
}
@ -86,11 +86,11 @@ func (c *SegmentChecker) Check(ctx context.Context) []task.Task {
if !c.IsActive() {
return nil
}
collectionIDs := c.meta.CollectionManager.GetAll()
collectionIDs := c.meta.CollectionManager.GetAll(ctx)
results := make([]task.Task, 0)
for _, cid := range collectionIDs {
if c.readyToCheck(cid) {
replicas := c.meta.ReplicaManager.GetByCollection(cid)
if c.readyToCheck(ctx, cid) {
replicas := c.meta.ReplicaManager.GetByCollection(ctx, cid)
for _, r := range replicas {
results = append(results, c.checkReplica(ctx, r)...)
}
@ -111,7 +111,7 @@ func (c *SegmentChecker) Check(ctx context.Context) []task.Task {
segmentsOnQN := c.dist.SegmentDistManager.GetByFilter(meta.WithNodeID(nodeID))
collectionSegments := lo.GroupBy(segmentsOnQN, func(segment *meta.Segment) int64 { return segment.GetCollectionID() })
for collectionID, segments := range collectionSegments {
replica := c.meta.ReplicaManager.GetByCollectionAndNode(collectionID, nodeID)
replica := c.meta.ReplicaManager.GetByCollectionAndNode(ctx, collectionID, nodeID)
if replica == nil {
reduceTasks := c.createSegmentReduceTasks(ctx, segments, meta.NilReplica, querypb.DataScope_Historical)
task.SetReason("dirty segment exists", reduceTasks...)
@ -128,21 +128,21 @@ func (c *SegmentChecker) checkReplica(ctx context.Context, replica *meta.Replica
ret := make([]task.Task, 0)
// compare with targets to find the lack and redundancy of segments
lacks, redundancies := c.getSealedSegmentDiff(replica.GetCollectionID(), replica.GetID())
lacks, redundancies := c.getSealedSegmentDiff(ctx, replica.GetCollectionID(), replica.GetID())
// loadCtx := trace.ContextWithSpan(context.Background(), c.meta.GetCollection(replica.CollectionID).LoadSpan)
tasks := c.createSegmentLoadTasks(c.getTraceCtx(ctx, replica.GetCollectionID()), lacks, replica)
task.SetReason("lacks of segment", tasks...)
task.SetPriority(task.TaskPriorityNormal, tasks...)
ret = append(ret, tasks...)
redundancies = c.filterSegmentInUse(replica, redundancies)
redundancies = c.filterSegmentInUse(ctx, replica, redundancies)
tasks = c.createSegmentReduceTasks(c.getTraceCtx(ctx, replica.GetCollectionID()), redundancies, replica, querypb.DataScope_Historical)
task.SetReason("segment not exists in target", tasks...)
task.SetPriority(task.TaskPriorityNormal, tasks...)
ret = append(ret, tasks...)
// compare inner dists to find repeated loaded segments
redundancies = c.findRepeatedSealedSegments(replica.GetID())
redundancies = c.findRepeatedSealedSegments(ctx, replica.GetID())
redundancies = c.filterExistedOnLeader(replica, redundancies)
tasks = c.createSegmentReduceTasks(c.getTraceCtx(ctx, replica.GetCollectionID()), redundancies, replica, querypb.DataScope_Historical)
task.SetReason("redundancies of segment", tasks...)
@ -151,7 +151,7 @@ func (c *SegmentChecker) checkReplica(ctx context.Context, replica *meta.Replica
ret = append(ret, tasks...)
// compare with target to find the lack and redundancy of segments
_, redundancies = c.getGrowingSegmentDiff(replica.GetCollectionID(), replica.GetID())
_, redundancies = c.getGrowingSegmentDiff(ctx, replica.GetCollectionID(), replica.GetID())
tasks = c.createSegmentReduceTasks(c.getTraceCtx(ctx, replica.GetCollectionID()), redundancies, replica, querypb.DataScope_Streaming)
task.SetReason("streaming segment not exists in target", tasks...)
task.SetPriority(task.TaskPriorityNormal, tasks...)
@ -161,10 +161,10 @@ func (c *SegmentChecker) checkReplica(ctx context.Context, replica *meta.Replica
}
// GetGrowingSegmentDiff get streaming segment diff between leader view and target
func (c *SegmentChecker) getGrowingSegmentDiff(collectionID int64,
func (c *SegmentChecker) getGrowingSegmentDiff(ctx context.Context, collectionID int64,
replicaID int64,
) (toLoad []*datapb.SegmentInfo, toRelease []*meta.Segment) {
replica := c.meta.Get(replicaID)
replica := c.meta.Get(ctx, replicaID)
if replica == nil {
log.Info("replica does not exist, skip it")
return
@ -181,7 +181,7 @@ func (c *SegmentChecker) getGrowingSegmentDiff(collectionID int64,
log.Info("leaderView is not ready, skip", zap.String("channelName", channelName), zap.Int64("node", node))
continue
}
targetVersion := c.targetMgr.GetCollectionTargetVersion(collectionID, meta.CurrentTarget)
targetVersion := c.targetMgr.GetCollectionTargetVersion(ctx, collectionID, meta.CurrentTarget)
if view.TargetVersion != targetVersion {
// before shard delegator update it's readable version, skip release segment
log.RatedInfo(20, "before shard delegator update it's readable version, skip release segment",
@ -193,10 +193,10 @@ func (c *SegmentChecker) getGrowingSegmentDiff(collectionID int64,
continue
}
nextTargetExist := c.targetMgr.IsNextTargetExist(collectionID)
nextTargetSegmentIDs := c.targetMgr.GetGrowingSegmentsByCollection(collectionID, meta.NextTarget)
currentTargetSegmentIDs := c.targetMgr.GetGrowingSegmentsByCollection(collectionID, meta.CurrentTarget)
currentTargetChannelMap := c.targetMgr.GetDmChannelsByCollection(collectionID, meta.CurrentTarget)
nextTargetExist := c.targetMgr.IsNextTargetExist(ctx, collectionID)
nextTargetSegmentIDs := c.targetMgr.GetGrowingSegmentsByCollection(ctx, collectionID, meta.NextTarget)
currentTargetSegmentIDs := c.targetMgr.GetGrowingSegmentsByCollection(ctx, collectionID, meta.CurrentTarget)
currentTargetChannelMap := c.targetMgr.GetDmChannelsByCollection(ctx, collectionID, meta.CurrentTarget)
// get segments which exist in the leader view but not in the current or next target
for _, segment := range view.GrowingSegments {
@ -227,10 +227,11 @@ func (c *SegmentChecker) getGrowingSegmentDiff(collectionID int64,
// GetSealedSegmentDiff get historical segment diff between target and dist
func (c *SegmentChecker) getSealedSegmentDiff(
ctx context.Context,
collectionID int64,
replicaID int64,
) (toLoad []*datapb.SegmentInfo, toRelease []*meta.Segment) {
replica := c.meta.Get(replicaID)
replica := c.meta.Get(ctx, replicaID)
if replica == nil {
log.Info("replica does not exist, skip it")
return
@ -278,9 +279,9 @@ func (c *SegmentChecker) getSealedSegmentDiff(
return !existInDist
}
nextTargetExist := c.targetMgr.IsNextTargetExist(collectionID)
nextTargetMap := c.targetMgr.GetSealedSegmentsByCollection(collectionID, meta.NextTarget)
currentTargetMap := c.targetMgr.GetSealedSegmentsByCollection(collectionID, meta.CurrentTarget)
nextTargetExist := c.targetMgr.IsNextTargetExist(ctx, collectionID)
nextTargetMap := c.targetMgr.GetSealedSegmentsByCollection(ctx, collectionID, meta.NextTarget)
currentTargetMap := c.targetMgr.GetSealedSegmentsByCollection(ctx, collectionID, meta.CurrentTarget)
// Segment which exist on next target, but not on dist
for _, segment := range nextTargetMap {
@ -325,9 +326,9 @@ func (c *SegmentChecker) getSealedSegmentDiff(
return
}
func (c *SegmentChecker) findRepeatedSealedSegments(replicaID int64) []*meta.Segment {
func (c *SegmentChecker) findRepeatedSealedSegments(ctx context.Context, replicaID int64) []*meta.Segment {
segments := make([]*meta.Segment, 0)
replica := c.meta.Get(replicaID)
replica := c.meta.Get(ctx, replicaID)
if replica == nil {
log.Info("replica does not exist, skip it")
return segments
@ -336,7 +337,7 @@ func (c *SegmentChecker) findRepeatedSealedSegments(replicaID int64) []*meta.Seg
versions := make(map[int64]*meta.Segment)
for _, s := range dist {
// l0 segments should be released together with the channel
segment := c.targetMgr.GetSealedSegment(s.GetCollectionID(), s.GetID(), meta.CurrentTargetFirst)
segment := c.targetMgr.GetSealedSegment(ctx, s.GetCollectionID(), s.GetID(), meta.CurrentTargetFirst)
existInTarget := segment != nil
isL0Segment := existInTarget && segment.GetLevel() == datapb.SegmentLevel_L0
if isL0Segment {
@ -378,7 +379,7 @@ func (c *SegmentChecker) filterExistedOnLeader(replica *meta.Replica, segments [
return filtered
}
func (c *SegmentChecker) filterSegmentInUse(replica *meta.Replica, segments []*meta.Segment) []*meta.Segment {
func (c *SegmentChecker) filterSegmentInUse(ctx context.Context, replica *meta.Replica, segments []*meta.Segment) []*meta.Segment {
filtered := make([]*meta.Segment, 0, len(segments))
for _, s := range segments {
leaderID, ok := c.dist.ChannelDistManager.GetShardLeader(replica, s.GetInsertChannel())
@ -387,8 +388,8 @@ func (c *SegmentChecker) filterSegmentInUse(replica *meta.Replica, segments []*m
}
view := c.dist.LeaderViewManager.GetLeaderShardView(leaderID, s.GetInsertChannel())
currentTargetVersion := c.targetMgr.GetCollectionTargetVersion(s.CollectionID, meta.CurrentTarget)
partition := c.meta.CollectionManager.GetPartition(s.PartitionID)
currentTargetVersion := c.targetMgr.GetCollectionTargetVersion(ctx, s.CollectionID, meta.CurrentTarget)
partition := c.meta.CollectionManager.GetPartition(ctx, s.PartitionID)
// if the delegator has a valid target version but has not yet updated to the latest readable version, skip releasing its sealed segments
// Notice: if syncTargetVersion gets stuck, segments on the delegator won't be released
@ -435,7 +436,7 @@ func (c *SegmentChecker) createSegmentLoadTasks(ctx context.Context, segments []
SegmentInfo: s,
}
})
shardPlans := c.getBalancerFunc().AssignSegment(replica.GetCollectionID(), segmentInfos, rwNodes, false)
shardPlans := c.getBalancerFunc().AssignSegment(ctx, replica.GetCollectionID(), segmentInfos, rwNodes, false)
for i := range shardPlans {
shardPlans[i].Replica = replica
}
@ -474,7 +475,7 @@ func (c *SegmentChecker) createSegmentReduceTasks(ctx context.Context, segments
}
func (c *SegmentChecker) getTraceCtx(ctx context.Context, collectionID int64) context.Context {
coll := c.meta.GetCollection(collectionID)
coll := c.meta.GetCollection(ctx, collectionID)
if coll == nil || coll.LoadSpan == nil {
return ctx
}
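
getTraceCtx derives a per-collection context, and the commented-out ContextWithSpan line in checkReplica together with the coll.LoadSpan check above indicate that the collection carries a load span. Below is a hedged sketch of how such a helper can attach a stored span to the incoming ctx via OpenTelemetry; Collection and getTraceCtx here are simplified stand-ins, not the Milvus implementation.

package main

import (
	"context"
	"fmt"

	"go.opentelemetry.io/otel"
	"go.opentelemetry.io/otel/trace"
)

// Hypothetical stand-in for collection metadata carrying a load span; it only
// illustrates the ctx plumbing.
type Collection struct {
	CollectionID int64
	LoadSpan     trace.Span
}

// Sketch of a getTraceCtx-style helper: attach the collection's load span to
// the incoming ctx so downstream tasks are traced under the load operation.
func getTraceCtx(ctx context.Context, coll *Collection) context.Context {
	if coll == nil || coll.LoadSpan == nil {
		return ctx
	}
	return trace.ContextWithSpan(ctx, coll.LoadSpan)
}

func main() {
	// With no tracer provider configured this yields a non-recording span,
	// so the example runs without any OpenTelemetry setup.
	_, loadSpan := otel.Tracer("example").Start(context.Background(), "collection-load")
	defer loadSpan.End()

	coll := &Collection{CollectionID: 1, LoadSpan: loadSpan}
	ctx := getTraceCtx(context.Background(), coll)

	// Downstream code retrieves the same span from ctx and annotates it.
	span := trace.SpanFromContext(ctx)
	span.AddEvent("segment check scheduled")
	fmt.Println("recording:", span.IsRecording())
}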

View File

@ -88,7 +88,7 @@ func (suite *SegmentCheckerTestSuite) TearDownTest() {
func (suite *SegmentCheckerTestSuite) createMockBalancer() balance.Balance {
balancer := balance.NewMockBalancer(suite.T())
balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Maybe().Return(func(collectionID int64, segments []*meta.Segment, nodes []int64, _ bool) []balance.SegmentAssignPlan {
balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Maybe().Return(func(ctx context.Context, collectionID int64, segments []*meta.Segment, nodes []int64, _ bool) []balance.SegmentAssignPlan {
plans := make([]balance.SegmentAssignPlan, 0, len(segments))
for i, s := range segments {
plan := balance.SegmentAssignPlan{
@ -105,11 +105,12 @@ func (suite *SegmentCheckerTestSuite) createMockBalancer() balance.Balance {
}
func (suite *SegmentCheckerTestSuite) TestLoadSegments() {
ctx := context.Background()
checker := suite.checker
// set meta
checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
checker.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2}))
checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1))
checker.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2}))
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1,
Address: "localhost",
@ -120,8 +121,8 @@ func (suite *SegmentCheckerTestSuite) TestLoadSegments() {
Address: "localhost",
Hostname: "localhost",
}))
checker.meta.ResourceManager.HandleNodeUp(1)
checker.meta.ResourceManager.HandleNodeUp(2)
checker.meta.ResourceManager.HandleNodeUp(ctx, 1)
checker.meta.ResourceManager.HandleNodeUp(ctx, 2)
// set target
segments := []*datapb.SegmentInfo{
@ -141,7 +142,7 @@ func (suite *SegmentCheckerTestSuite) TestLoadSegments() {
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, segments, nil)
checker.targetMgr.UpdateCollectionNextTarget(int64(1))
checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1))
// set dist
checker.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel"))
@ -170,11 +171,12 @@ func (suite *SegmentCheckerTestSuite) TestLoadSegments() {
}
func (suite *SegmentCheckerTestSuite) TestLoadL0Segments() {
ctx := context.Background()
checker := suite.checker
// set meta
checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
checker.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2}))
checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1))
checker.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2}))
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1,
Address: "localhost",
@ -187,8 +189,8 @@ func (suite *SegmentCheckerTestSuite) TestLoadL0Segments() {
Hostname: "localhost",
Version: common.Version,
}))
checker.meta.ResourceManager.HandleNodeUp(1)
checker.meta.ResourceManager.HandleNodeUp(2)
checker.meta.ResourceManager.HandleNodeUp(ctx, 1)
checker.meta.ResourceManager.HandleNodeUp(ctx, 2)
// set target
segments := []*datapb.SegmentInfo{
@ -209,7 +211,7 @@ func (suite *SegmentCheckerTestSuite) TestLoadL0Segments() {
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, segments, nil)
checker.targetMgr.UpdateCollectionNextTarget(int64(1))
checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1))
// set dist
checker.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel"))
@ -227,7 +229,7 @@ func (suite *SegmentCheckerTestSuite) TestLoadL0Segments() {
suite.EqualValues(2, action.Node())
suite.Equal(tasks[0].Priority(), task.TaskPriorityNormal)
checker.targetMgr.UpdateCollectionCurrentTarget(int64(1))
checker.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(1))
// test load l0 segments in current target
tasks = checker.Check(context.TODO())
suite.Len(tasks, 1)
@ -241,7 +243,7 @@ func (suite *SegmentCheckerTestSuite) TestLoadL0Segments() {
suite.Equal(tasks[0].Priority(), task.TaskPriorityNormal)
// an l0 segment exists on a non-delegator node
checker.targetMgr.UpdateCollectionNextTarget(int64(1))
checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1))
checker.dist.SegmentDistManager.Update(1, utils.CreateTestSegment(1, 1, 1, 1, 1, "test-insert-channel"))
// test load l0 segments to delegator
tasks = checker.Check(context.TODO())
@ -257,11 +259,12 @@ func (suite *SegmentCheckerTestSuite) TestLoadL0Segments() {
}
func (suite *SegmentCheckerTestSuite) TestReleaseL0Segments() {
ctx := context.Background()
checker := suite.checker
// set meta
checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
checker.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2}))
checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1))
checker.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2}))
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1,
Address: "localhost",
@ -272,8 +275,8 @@ func (suite *SegmentCheckerTestSuite) TestReleaseL0Segments() {
Address: "localhost",
Hostname: "localhost",
}))
checker.meta.ResourceManager.HandleNodeUp(1)
checker.meta.ResourceManager.HandleNodeUp(2)
checker.meta.ResourceManager.HandleNodeUp(ctx, 1)
checker.meta.ResourceManager.HandleNodeUp(ctx, 2)
// set target
segments := []*datapb.SegmentInfo{
@ -294,8 +297,8 @@ func (suite *SegmentCheckerTestSuite) TestReleaseL0Segments() {
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, segments, nil)
checker.targetMgr.UpdateCollectionNextTarget(int64(1))
checker.targetMgr.UpdateCollectionCurrentTarget(int64(1))
checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1))
checker.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(1))
// set dist
checker.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel"))
@ -315,9 +318,9 @@ func (suite *SegmentCheckerTestSuite) TestReleaseL0Segments() {
suite.broker.ExpectedCalls = nil
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, nil, nil)
checker.targetMgr.UpdateCollectionNextTarget(int64(1))
checker.targetMgr.UpdateCollectionCurrentTarget(int64(1))
checker.targetMgr.UpdateCollectionNextTarget(int64(1))
checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1))
checker.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(1))
checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1))
tasks = checker.Check(context.TODO())
suite.Len(tasks, 1)
@ -332,11 +335,12 @@ func (suite *SegmentCheckerTestSuite) TestReleaseL0Segments() {
}
func (suite *SegmentCheckerTestSuite) TestSkipLoadSegments() {
ctx := context.Background()
checker := suite.checker
// set meta
checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
checker.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2}))
checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1))
checker.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2}))
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1,
Address: "localhost",
@ -347,8 +351,8 @@ func (suite *SegmentCheckerTestSuite) TestSkipLoadSegments() {
Address: "localhost",
Hostname: "localhost",
}))
checker.meta.ResourceManager.HandleNodeUp(1)
checker.meta.ResourceManager.HandleNodeUp(2)
checker.meta.ResourceManager.HandleNodeUp(ctx, 1)
checker.meta.ResourceManager.HandleNodeUp(ctx, 2)
// set target
segments := []*datapb.SegmentInfo{
@ -368,7 +372,7 @@ func (suite *SegmentCheckerTestSuite) TestSkipLoadSegments() {
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, segments, nil)
checker.targetMgr.UpdateCollectionNextTarget(int64(1))
checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1))
// when channel not subscribed, segment_checker won't generate load segment task
tasks := checker.Check(context.TODO())
@ -376,11 +380,12 @@ func (suite *SegmentCheckerTestSuite) TestSkipLoadSegments() {
}
func (suite *SegmentCheckerTestSuite) TestReleaseSegments() {
ctx := context.Background()
checker := suite.checker
// set meta
checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
checker.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2}))
checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1))
checker.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2}))
// set target
channels := []*datapb.VchannelInfo{
@ -391,7 +396,7 @@ func (suite *SegmentCheckerTestSuite) TestReleaseSegments() {
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, nil, nil)
checker.targetMgr.UpdateCollectionNextTarget(int64(1))
checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1))
// set dist
checker.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel"))
@ -410,11 +415,12 @@ func (suite *SegmentCheckerTestSuite) TestReleaseSegments() {
}
func (suite *SegmentCheckerTestSuite) TestReleaseRepeatedSegments() {
ctx := context.Background()
checker := suite.checker
// set meta
checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
checker.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2}))
checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1))
checker.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2}))
// set target
segments := []*datapb.SegmentInfo{
@ -432,7 +438,7 @@ func (suite *SegmentCheckerTestSuite) TestReleaseRepeatedSegments() {
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, segments, nil)
checker.targetMgr.UpdateCollectionNextTarget(int64(1))
checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1))
// set dist
checker.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel"))
@ -458,11 +464,12 @@ func (suite *SegmentCheckerTestSuite) TestReleaseRepeatedSegments() {
}
func (suite *SegmentCheckerTestSuite) TestReleaseDirtySegments() {
ctx := context.Background()
checker := suite.checker
// set meta
checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
checker.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1}))
checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1))
checker.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1}))
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1,
Address: "localhost",
@ -490,7 +497,7 @@ func (suite *SegmentCheckerTestSuite) TestReleaseDirtySegments() {
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, segments, nil)
checker.targetMgr.UpdateCollectionNextTarget(int64(1))
checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1))
// set dist
checker.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel"))
@ -510,15 +517,16 @@ func (suite *SegmentCheckerTestSuite) TestReleaseDirtySegments() {
}
func (suite *SegmentCheckerTestSuite) TestSkipReleaseSealedSegments() {
ctx := context.Background()
checker := suite.checker
collectionID := int64(1)
partitionID := int64(1)
// set meta
checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(collectionID, 1))
checker.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
checker.meta.CollectionManager.PutPartition(utils.CreateTestPartition(collectionID, partitionID))
checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, collectionID, []int64{1, 2}))
checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(collectionID, 1))
checker.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
checker.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(collectionID, partitionID))
checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, collectionID, []int64{1, 2}))
// set target
channels := []*datapb.VchannelInfo{
@ -531,10 +539,10 @@ func (suite *SegmentCheckerTestSuite) TestSkipReleaseSealedSegments() {
segments := []*datapb.SegmentInfo{}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, segments, nil)
checker.targetMgr.UpdateCollectionNextTarget(collectionID)
checker.targetMgr.UpdateCollectionCurrentTarget(collectionID)
checker.targetMgr.UpdateCollectionNextTarget(collectionID)
readableVersion := checker.targetMgr.GetCollectionTargetVersion(collectionID, meta.CurrentTarget)
checker.targetMgr.UpdateCollectionNextTarget(ctx, collectionID)
checker.targetMgr.UpdateCollectionCurrentTarget(ctx, collectionID)
checker.targetMgr.UpdateCollectionNextTarget(ctx, collectionID)
readableVersion := checker.targetMgr.GetCollectionTargetVersion(ctx, collectionID, meta.CurrentTarget)
// test a lower target version existing on the leader: segments that don't exist in the target should be released
nodeID := int64(2)
@ -579,12 +587,13 @@ func (suite *SegmentCheckerTestSuite) TestSkipReleaseSealedSegments() {
}
func (suite *SegmentCheckerTestSuite) TestReleaseGrowingSegments() {
ctx := context.Background()
checker := suite.checker
// segment3 is compacted from segment2, and node2 has growing segments 2 and 3. checker should generate
// 2 tasks to reduce segment 2 and 3.
checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
checker.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2}))
checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1))
checker.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2}))
segments := []*datapb.SegmentInfo{
{
@ -602,9 +611,9 @@ func (suite *SegmentCheckerTestSuite) TestReleaseGrowingSegments() {
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, segments, nil)
checker.targetMgr.UpdateCollectionNextTarget(int64(1))
checker.targetMgr.UpdateCollectionCurrentTarget(int64(1))
checker.targetMgr.UpdateCollectionNextTarget(int64(1))
checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1))
checker.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(1))
checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1))
growingSegments := make(map[int64]*meta.Segment)
growingSegments[2] = utils.CreateTestSegment(1, 1, 2, 2, 0, "test-insert-channel")
@ -618,7 +627,7 @@ func (suite *SegmentCheckerTestSuite) TestReleaseGrowingSegments() {
dmChannel.UnflushedSegmentIds = []int64{2, 3}
checker.dist.ChannelDistManager.Update(2, dmChannel)
view := utils.CreateTestLeaderView(2, 1, "test-insert-channel", map[int64]int64{3: 2}, growingSegments)
view.TargetVersion = checker.targetMgr.GetCollectionTargetVersion(int64(1), meta.CurrentTarget)
view.TargetVersion = checker.targetMgr.GetCollectionTargetVersion(ctx, int64(1), meta.CurrentTarget)
checker.dist.LeaderViewManager.Update(2, view)
checker.dist.SegmentDistManager.Update(2, utils.CreateTestSegment(1, 1, 3, 2, 2, "test-insert-channel"))
@ -647,11 +656,12 @@ func (suite *SegmentCheckerTestSuite) TestReleaseGrowingSegments() {
}
func (suite *SegmentCheckerTestSuite) TestReleaseCompactedGrowingSegments() {
ctx := context.Background()
checker := suite.checker
checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
checker.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2}))
checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1))
checker.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2}))
segments := []*datapb.SegmentInfo{
{
@ -670,9 +680,9 @@ func (suite *SegmentCheckerTestSuite) TestReleaseCompactedGrowingSegments() {
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, segments, nil)
checker.targetMgr.UpdateCollectionNextTarget(int64(1))
checker.targetMgr.UpdateCollectionCurrentTarget(int64(1))
checker.targetMgr.UpdateCollectionNextTarget(int64(1))
checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1))
checker.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(1))
checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1))
growingSegments := make(map[int64]*meta.Segment)
// segment start pos after checkpoint
@ -683,7 +693,7 @@ func (suite *SegmentCheckerTestSuite) TestReleaseCompactedGrowingSegments() {
dmChannel.UnflushedSegmentIds = []int64{2, 3}
checker.dist.ChannelDistManager.Update(2, dmChannel)
view := utils.CreateTestLeaderView(2, 1, "test-insert-channel", map[int64]int64{3: 2}, growingSegments)
view.TargetVersion = checker.targetMgr.GetCollectionTargetVersion(int64(1), meta.CurrentTarget)
view.TargetVersion = checker.targetMgr.GetCollectionTargetVersion(ctx, int64(1), meta.CurrentTarget)
checker.dist.LeaderViewManager.Update(2, view)
checker.dist.SegmentDistManager.Update(2, utils.CreateTestSegment(1, 1, 3, 2, 2, "test-insert-channel"))
@ -703,10 +713,11 @@ func (suite *SegmentCheckerTestSuite) TestReleaseCompactedGrowingSegments() {
}
func (suite *SegmentCheckerTestSuite) TestSkipReleaseGrowingSegments() {
ctx := context.Background()
checker := suite.checker
checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
checker.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2}))
checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1))
checker.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1))
checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2}))
segments := []*datapb.SegmentInfo{}
channels := []*datapb.VchannelInfo{
@ -718,9 +729,9 @@ func (suite *SegmentCheckerTestSuite) TestSkipReleaseGrowingSegments() {
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, segments, nil)
checker.targetMgr.UpdateCollectionNextTarget(int64(1))
checker.targetMgr.UpdateCollectionCurrentTarget(int64(1))
checker.targetMgr.UpdateCollectionNextTarget(int64(1))
checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1))
checker.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(1))
checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1))
growingSegments := make(map[int64]*meta.Segment)
growingSegments[2] = utils.CreateTestSegment(1, 1, 2, 2, 0, "test-insert-channel")
@ -730,13 +741,13 @@ func (suite *SegmentCheckerTestSuite) TestSkipReleaseGrowingSegments() {
dmChannel.UnflushedSegmentIds = []int64{2, 3}
checker.dist.ChannelDistManager.Update(2, dmChannel)
view := utils.CreateTestLeaderView(2, 1, "test-insert-channel", map[int64]int64{}, growingSegments)
view.TargetVersion = checker.targetMgr.GetCollectionTargetVersion(int64(1), meta.CurrentTarget) - 1
view.TargetVersion = checker.targetMgr.GetCollectionTargetVersion(ctx, int64(1), meta.CurrentTarget) - 1
checker.dist.LeaderViewManager.Update(2, view)
tasks := checker.Check(context.TODO())
suite.Len(tasks, 0)
view.TargetVersion = checker.targetMgr.GetCollectionTargetVersion(int64(1), meta.CurrentTarget)
view.TargetVersion = checker.targetMgr.GetCollectionTargetVersion(ctx, int64(1), meta.CurrentTarget)
checker.dist.LeaderViewManager.Update(2, view)
tasks = checker.Check(context.TODO())
suite.Len(tasks, 1)
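
The checker-test hunks above all apply the same mechanical change: each test creates a single ctx at the top and passes it as the new first argument of every target-manager and meta call. A minimal, self-contained sketch of that pattern follows; fakeTargetMgr and its methods are illustrative stand-ins, not the real Milvus interfaces.

package example

import (
	"context"
	"testing"
)

// fakeTargetMgr is a hypothetical stand-in for the real target manager;
// its methods only mirror the new ctx-first signatures shown in the hunks above.
type fakeTargetMgr struct{ version int64 }

func (m *fakeTargetMgr) UpdateCollectionNextTarget(ctx context.Context, collectionID int64) {}

func (m *fakeTargetMgr) GetCollectionTargetVersion(ctx context.Context, collectionID int64) int64 {
	return m.version
}

func TestCtxThreadedIntoTargetCalls(t *testing.T) {
	ctx := context.Background() // one ctx per test, as in the checker tests above
	mgr := &fakeTargetMgr{version: 1}
	mgr.UpdateCollectionNextTarget(ctx, 1)
	if got := mgr.GetCollectionTargetVersion(ctx, 1); got != 1 {
		t.Fatalf("unexpected target version: %d", got)
	}
}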

View File

@ -79,7 +79,7 @@ func (dc *ControllerImpl) SyncAll(ctx context.Context) {
if err != nil {
log.Warn("SyncAll come across err when getting data distribution", zap.Error(err))
} else {
handler.handleDistResp(resp, true)
handler.handleDistResp(ctx, resp, true)
}
}(h)
}

View File

@ -103,11 +103,11 @@ func (dh *distHandler) pullDist(ctx context.Context, failures *int, dispatchTask
log.RatedWarn(30.0, "failed to get data distribution", fields...)
} else {
*failures = 0
dh.handleDistResp(resp, dispatchTask)
dh.handleDistResp(ctx, resp, dispatchTask)
}
}
func (dh *distHandler) handleDistResp(resp *querypb.GetDataDistributionResponse, dispatchTask bool) {
func (dh *distHandler) handleDistResp(ctx context.Context, resp *querypb.GetDataDistributionResponse, dispatchTask bool) {
node := dh.nodeManager.Get(resp.GetNodeID())
if node == nil {
return
@ -130,9 +130,9 @@ func (dh *distHandler) handleDistResp(resp *querypb.GetDataDistributionResponse,
session.WithChannelCnt(len(resp.GetChannels())),
session.WithMemCapacity(resp.GetMemCapacityInMB()),
)
dh.updateSegmentsDistribution(resp)
dh.updateChannelsDistribution(resp)
dh.updateLeaderView(resp)
dh.updateSegmentsDistribution(ctx, resp)
dh.updateChannelsDistribution(ctx, resp)
dh.updateLeaderView(ctx, resp)
}
if dispatchTask {
@ -140,10 +140,10 @@ func (dh *distHandler) handleDistResp(resp *querypb.GetDataDistributionResponse,
}
}
func (dh *distHandler) updateSegmentsDistribution(resp *querypb.GetDataDistributionResponse) {
func (dh *distHandler) updateSegmentsDistribution(ctx context.Context, resp *querypb.GetDataDistributionResponse) {
updates := make([]*meta.Segment, 0, len(resp.GetSegments()))
for _, s := range resp.GetSegments() {
segmentInfo := dh.target.GetSealedSegment(s.GetCollection(), s.GetID(), meta.CurrentTargetFirst)
segmentInfo := dh.target.GetSealedSegment(ctx, s.GetCollection(), s.GetID(), meta.CurrentTargetFirst)
if segmentInfo == nil {
segmentInfo = &datapb.SegmentInfo{
ID: s.GetID(),
@ -166,10 +166,10 @@ func (dh *distHandler) updateSegmentsDistribution(resp *querypb.GetDataDistribut
dh.dist.SegmentDistManager.Update(resp.GetNodeID(), updates...)
}
func (dh *distHandler) updateChannelsDistribution(resp *querypb.GetDataDistributionResponse) {
func (dh *distHandler) updateChannelsDistribution(ctx context.Context, resp *querypb.GetDataDistributionResponse) {
updates := make([]*meta.DmChannel, 0, len(resp.GetChannels()))
for _, ch := range resp.GetChannels() {
channelInfo := dh.target.GetDmChannel(ch.GetCollection(), ch.GetChannel(), meta.CurrentTarget)
channelInfo := dh.target.GetDmChannel(ctx, ch.GetCollection(), ch.GetChannel(), meta.CurrentTarget)
var channel *meta.DmChannel
if channelInfo == nil {
channel = &meta.DmChannel{
@ -193,7 +193,7 @@ func (dh *distHandler) updateChannelsDistribution(resp *querypb.GetDataDistribut
dh.dist.ChannelDistManager.Update(resp.GetNodeID(), updates...)
}
func (dh *distHandler) updateLeaderView(resp *querypb.GetDataDistributionResponse) {
func (dh *distHandler) updateLeaderView(ctx context.Context, resp *querypb.GetDataDistributionResponse) {
updates := make([]*meta.LeaderView, 0, len(resp.GetLeaderViews()))
channels := lo.SliceToMap(resp.GetChannels(), func(channel *querypb.ChannelVersionInfo) (string, *querypb.ChannelVersionInfo) {
@ -248,7 +248,7 @@ func (dh *distHandler) updateLeaderView(resp *querypb.GetDataDistributionRespons
// if target version hasn't been synced, delegator will get empty readable segment list
// so shard leader should be unserviceable until target version is synced
currentTargetVersion := dh.target.GetCollectionTargetVersion(lview.GetCollection(), meta.CurrentTarget)
currentTargetVersion := dh.target.GetCollectionTargetVersion(ctx, lview.GetCollection(), meta.CurrentTarget)
if lview.TargetVersion <= 0 {
err := merr.WrapErrServiceInternal(fmt.Sprintf("target version mismatch, collection: %d, channel: %s, current target version: %v, leader version: %v",
lview.GetCollection(), lview.GetChannel(), currentTargetVersion, lview.TargetVersion))
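
As the dist-handler hunks above show, the ctx received by pullDist and SyncAll is now forwarded through handleDistResp into the update helpers and, from there, into the target-manager lookups. The sketch below traces that call chain under simplified, stand-in types (targetReader, distHandlerSketch); it is not the real querypb/meta API.

package example

import "context"

// targetReader is an illustrative slice of the target-manager API after the change.
type targetReader interface {
	GetSealedSegment(ctx context.Context, collectionID, segmentID int64) interface{}
}

type distHandlerSketch struct {
	target targetReader
}

// handleDistResp receives the ctx from its caller (the pull loop or SyncAll)...
func (dh *distHandlerSketch) handleDistResp(ctx context.Context, collectionID int64, segmentIDs []int64) {
	dh.updateSegmentsDistribution(ctx, collectionID, segmentIDs)
}

// ...and forwards it so the target lookups carry the same tracing/log metadata end to end.
func (dh *distHandlerSketch) updateSegmentsDistribution(ctx context.Context, collectionID int64, segmentIDs []int64) {
	for _, id := range segmentIDs {
		_ = dh.target.GetSealedSegment(ctx, collectionID, id)
	}
}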

View File

@ -66,9 +66,9 @@ func (suite *DistHandlerSuite) SetupSuite() {
suite.executedFlagChan = make(chan struct{}, 1)
suite.scheduler.EXPECT().GetExecutedFlag(mock.Anything).Return(suite.executedFlagChan).Maybe()
suite.target.EXPECT().GetSealedSegment(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe()
suite.target.EXPECT().GetDmChannel(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe()
suite.target.EXPECT().GetCollectionTargetVersion(mock.Anything, mock.Anything).Return(1011).Maybe()
suite.target.EXPECT().GetSealedSegment(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe()
suite.target.EXPECT().GetDmChannel(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe()
suite.target.EXPECT().GetCollectionTargetVersion(mock.Anything, mock.Anything, mock.Anything).Return(1011).Maybe()
}
func (suite *DistHandlerSuite) TestBasic() {
@ -77,7 +77,7 @@ func (suite *DistHandlerSuite) TestBasic() {
suite.dispatchMockCall = nil
}
suite.target.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{})
suite.target.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{})
suite.dispatchMockCall = suite.scheduler.EXPECT().Dispatch(mock.Anything).Maybe()
suite.nodeManager.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1,
@ -126,7 +126,7 @@ func (suite *DistHandlerSuite) TestGetDistributionFailed() {
suite.dispatchMockCall.Unset()
suite.dispatchMockCall = nil
}
suite.target.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{}).Maybe()
suite.target.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{}).Maybe()
suite.dispatchMockCall = suite.scheduler.EXPECT().Dispatch(mock.Anything).Maybe()
suite.nodeManager.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1,
@ -148,7 +148,7 @@ func (suite *DistHandlerSuite) TestForcePullDist() {
suite.dispatchMockCall = nil
}
suite.target.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{}).Maybe()
suite.target.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{}).Maybe()
suite.nodeManager.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1,
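
Because the mocked methods gained a leading ctx parameter, every expectation in the test hunks above needs one extra mock.Anything at the front of the argument list. A hedged sketch of the idea using a hand-written testify mock (mockTarget and the int32 scope parameter are assumptions standing in for the generated MockTargetManager):

package example

import (
	"context"
	"testing"

	"github.com/stretchr/testify/mock"
)

// mockTarget is a hand-written testify mock standing in for the generated mock.
type mockTarget struct{ mock.Mock }

func (m *mockTarget) GetSealedSegment(ctx context.Context, collectionID, segmentID int64, scope int32) interface{} {
	args := m.Called(ctx, collectionID, segmentID, scope)
	return args.Get(0)
}

func TestExpectationGainsCtxMatcher(t *testing.T) {
	target := &mockTarget{}
	// Before this change the expectation had three matchers; the new ctx
	// parameter adds a fourth mock.Anything at the front.
	target.On("GetSealedSegment", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)

	if got := target.GetSealedSegment(context.Background(), 1, 2, 0); got != nil {
		t.Fatalf("expected nil, got %v", got)
	}
	target.AssertExpectations(t)
}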

View File

@ -49,7 +49,7 @@ import (
// may come from different replica group. We only need these shards to form a replica that serves query
// requests.
func (s *Server) checkAnyReplicaAvailable(collectionID int64) bool {
for _, replica := range s.meta.ReplicaManager.GetByCollection(collectionID) {
for _, replica := range s.meta.ReplicaManager.GetByCollection(s.ctx, collectionID) {
isAvailable := true
for _, node := range replica.GetRONodes() {
if s.nodeMgr.Get(node) == nil {
@ -64,9 +64,9 @@ func (s *Server) checkAnyReplicaAvailable(collectionID int64) bool {
return false
}
func (s *Server) getCollectionSegmentInfo(collection int64) []*querypb.SegmentInfo {
func (s *Server) getCollectionSegmentInfo(ctx context.Context, collection int64) []*querypb.SegmentInfo {
segments := s.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(collection))
currentTargetSegmentsMap := s.targetMgr.GetSealedSegmentsByCollection(collection, meta.CurrentTarget)
currentTargetSegmentsMap := s.targetMgr.GetSealedSegmentsByCollection(ctx, collection, meta.CurrentTarget)
infos := make(map[int64]*querypb.SegmentInfo)
for _, segment := range segments {
if _, existCurrentTarget := currentTargetSegmentsMap[segment.GetID()]; !existCurrentTarget {
@ -104,7 +104,7 @@ func (s *Server) balanceSegments(ctx context.Context,
copyMode bool,
) error {
log := log.Ctx(ctx).With(zap.Int64("collectionID", collectionID), zap.Int64("srcNode", srcNode))
plans := s.getBalancerFunc().AssignSegment(collectionID, segments, dstNodes, true)
plans := s.getBalancerFunc().AssignSegment(ctx, collectionID, segments, dstNodes, true)
for i := range plans {
plans[i].From = srcNode
plans[i].Replica = replica
@ -183,7 +183,7 @@ func (s *Server) balanceChannels(ctx context.Context,
) error {
log := log.Ctx(ctx).With(zap.Int64("collectionID", collectionID))
plans := s.getBalancerFunc().AssignChannel(channels, dstNodes, true)
plans := s.getBalancerFunc().AssignChannel(ctx, channels, dstNodes, true)
for i := range plans {
plans[i].From = srcNode
plans[i].Replica = replica
@ -458,16 +458,16 @@ func (s *Server) tryGetNodesMetrics(ctx context.Context, req *milvuspb.GetMetric
return ret
}
func (s *Server) fillReplicaInfo(replica *meta.Replica, withShardNodes bool) *milvuspb.ReplicaInfo {
func (s *Server) fillReplicaInfo(ctx context.Context, replica *meta.Replica, withShardNodes bool) *milvuspb.ReplicaInfo {
info := &milvuspb.ReplicaInfo{
ReplicaID: replica.GetID(),
CollectionID: replica.GetCollectionID(),
NodeIds: replica.GetNodes(),
ResourceGroupName: replica.GetResourceGroup(),
NumOutboundNode: s.meta.GetOutgoingNodeNumByReplica(replica),
NumOutboundNode: s.meta.GetOutgoingNodeNumByReplica(ctx, replica),
}
channels := s.targetMgr.GetDmChannelsByCollection(replica.GetCollectionID(), meta.CurrentTarget)
channels := s.targetMgr.GetDmChannelsByCollection(ctx, replica.GetCollectionID(), meta.CurrentTarget)
if len(channels) == 0 {
log.Warn("failed to get channels, collection may be not loaded or in recovering", zap.Int64("collectionID", replica.GetCollectionID()))
return info
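
One detail worth noting in the server handlers above: methods with no incoming request context (checkAnyReplicaAvailable) fall back to the server's long-lived s.ctx, while methods that do receive one (getCollectionSegmentInfo, fillReplicaInfo) take ctx explicitly and thread the caller's context through. A small sketch of that split, with hypothetical server fields and a reduced replicaReader interface:

package example

import "context"

// replicaReader is an illustrative slice of the ReplicaManager API after the change.
type replicaReader interface {
	GetByCollection(ctx context.Context, collectionID int64) []int64
}

type serverSketch struct {
	ctx      context.Context // long-lived server context, set at startup
	replicas replicaReader
}

// No request ctx is available here, so the server's own ctx is used.
func (s *serverSketch) checkAnyReplicaAvailable(collectionID int64) bool {
	return len(s.replicas.GetByCollection(s.ctx, collectionID)) > 0
}

// A request-scoped method takes ctx explicitly and forwards it.
func (s *serverSketch) getCollectionReplicaIDs(ctx context.Context, collectionID int64) []int64 {
	return s.replicas.GetByCollection(ctx, collectionID)
}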

View File

@ -98,14 +98,14 @@ func (job *LoadCollectionJob) PreExecute() error {
req.ResourceGroups = []string{meta.DefaultResourceGroupName}
}
collection := job.meta.GetCollection(req.GetCollectionID())
collection := job.meta.GetCollection(job.ctx, req.GetCollectionID())
if collection == nil {
return nil
}
if collection.GetReplicaNumber() != req.GetReplicaNumber() {
msg := fmt.Sprintf("collection with different replica number %d existed, release this collection first before changing its replica number",
job.meta.GetReplicaNumber(req.GetCollectionID()),
job.meta.GetReplicaNumber(job.ctx, req.GetCollectionID()),
)
log.Warn(msg)
return merr.WrapErrParameterInvalid(collection.GetReplicaNumber(), req.GetReplicaNumber(), "can't change the replica number for loaded collection")
@ -125,7 +125,7 @@ func (job *LoadCollectionJob) PreExecute() error {
)
return merr.WrapErrParameterInvalid(collection.GetLoadFields(), req.GetLoadFields(), "can't change the load field list for loaded collection")
}
collectionUsedRG := job.meta.ReplicaManager.GetResourceGroupByCollection(collection.GetCollectionID()).Collect()
collectionUsedRG := job.meta.ReplicaManager.GetResourceGroupByCollection(job.ctx, collection.GetCollectionID()).Collect()
left, right := lo.Difference(collectionUsedRG, req.GetResourceGroups())
if len(left) > 0 || len(right) > 0 {
msg := fmt.Sprintf("collection with different resource groups %v existed, release this collection first before changing its resource groups",
@ -149,7 +149,7 @@ func (job *LoadCollectionJob) Execute() error {
log.Warn(msg, zap.Error(err))
return errors.Wrap(err, msg)
}
loadedPartitionIDs := lo.Map(job.meta.CollectionManager.GetPartitionsByCollection(req.GetCollectionID()),
loadedPartitionIDs := lo.Map(job.meta.CollectionManager.GetPartitionsByCollection(job.ctx, req.GetCollectionID()),
func(partition *meta.Partition, _ int) int64 {
return partition.GetPartitionID()
})
@ -163,10 +163,10 @@ func (job *LoadCollectionJob) Execute() error {
job.undo.LackPartitions = lackPartitionIDs
log.Info("find partitions to load", zap.Int64s("partitions", lackPartitionIDs))
colExisted := job.meta.CollectionManager.Exist(req.GetCollectionID())
colExisted := job.meta.CollectionManager.Exist(job.ctx, req.GetCollectionID())
if !colExisted {
// Clear stale replicas, https://github.com/milvus-io/milvus/issues/20444
err = job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID())
err = job.meta.ReplicaManager.RemoveCollection(job.ctx, req.GetCollectionID())
if err != nil {
msg := "failed to clear stale replicas"
log.Warn(msg, zap.Error(err))
@ -175,7 +175,7 @@ func (job *LoadCollectionJob) Execute() error {
}
// 2. create replica if not exist
replicas := job.meta.ReplicaManager.GetByCollection(req.GetCollectionID())
replicas := job.meta.ReplicaManager.GetByCollection(job.ctx, req.GetCollectionID())
if len(replicas) == 0 {
collectionInfo, err := job.broker.DescribeCollection(job.ctx, req.GetCollectionID())
if err != nil {
@ -184,7 +184,7 @@ func (job *LoadCollectionJob) Execute() error {
// API of LoadCollection is weird; we should use map[resourceGroupNames]replicaNumber as input to keep consistency with the `TransferReplica` API.
// Then we can implement dynamic replica changed in different resource group independently.
_, err = utils.SpawnReplicasWithRG(job.meta, req.GetCollectionID(), req.GetResourceGroups(), req.GetReplicaNumber(), collectionInfo.GetVirtualChannelNames())
_, err = utils.SpawnReplicasWithRG(job.ctx, job.meta, req.GetCollectionID(), req.GetResourceGroups(), req.GetReplicaNumber(), collectionInfo.GetVirtualChannelNames())
if err != nil {
msg := "failed to spawn replica for collection"
log.Warn(msg, zap.Error(err))
@ -227,7 +227,7 @@ func (job *LoadCollectionJob) Execute() error {
LoadSpan: sp,
}
job.undo.IsNewCollection = true
err = job.meta.CollectionManager.PutCollection(collection, partitions...)
err = job.meta.CollectionManager.PutCollection(job.ctx, collection, partitions...)
if err != nil {
msg := "failed to store collection and partitions"
log.Warn(msg, zap.Error(err))
@ -312,7 +312,7 @@ func (job *LoadPartitionJob) PreExecute() error {
req.ResourceGroups = []string{meta.DefaultResourceGroupName}
}
collection := job.meta.GetCollection(req.GetCollectionID())
collection := job.meta.GetCollection(job.ctx, req.GetCollectionID())
if collection == nil {
return nil
}
@ -337,7 +337,7 @@ func (job *LoadPartitionJob) PreExecute() error {
)
return merr.WrapErrParameterInvalid(collection.GetLoadFields(), req.GetLoadFields(), "can't change the load field list for loaded collection")
}
collectionUsedRG := job.meta.ReplicaManager.GetResourceGroupByCollection(collection.GetCollectionID()).Collect()
collectionUsedRG := job.meta.ReplicaManager.GetResourceGroupByCollection(job.ctx, collection.GetCollectionID()).Collect()
left, right := lo.Difference(collectionUsedRG, req.GetResourceGroups())
if len(left) > 0 || len(right) > 0 {
msg := fmt.Sprintf("collection with different resource groups %v existed, release this collection first before changing its resource groups",
@ -358,7 +358,7 @@ func (job *LoadPartitionJob) Execute() error {
meta.GlobalFailedLoadCache.Remove(req.GetCollectionID())
// 1. Fetch target partitions
loadedPartitionIDs := lo.Map(job.meta.CollectionManager.GetPartitionsByCollection(req.GetCollectionID()),
loadedPartitionIDs := lo.Map(job.meta.CollectionManager.GetPartitionsByCollection(job.ctx, req.GetCollectionID()),
func(partition *meta.Partition, _ int) int64 {
return partition.GetPartitionID()
})
@ -373,9 +373,9 @@ func (job *LoadPartitionJob) Execute() error {
log.Info("find partitions to load", zap.Int64s("partitions", lackPartitionIDs))
var err error
if !job.meta.CollectionManager.Exist(req.GetCollectionID()) {
if !job.meta.CollectionManager.Exist(job.ctx, req.GetCollectionID()) {
// Clear stale replicas, https://github.com/milvus-io/milvus/issues/20444
err = job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID())
err = job.meta.ReplicaManager.RemoveCollection(job.ctx, req.GetCollectionID())
if err != nil {
msg := "failed to clear stale replicas"
log.Warn(msg, zap.Error(err))
@ -384,13 +384,13 @@ func (job *LoadPartitionJob) Execute() error {
}
// 2. create replica if not exist
replicas := job.meta.ReplicaManager.GetByCollection(req.GetCollectionID())
replicas := job.meta.ReplicaManager.GetByCollection(context.TODO(), req.GetCollectionID())
if len(replicas) == 0 {
collectionInfo, err := job.broker.DescribeCollection(job.ctx, req.GetCollectionID())
if err != nil {
return err
}
_, err = utils.SpawnReplicasWithRG(job.meta, req.GetCollectionID(), req.GetResourceGroups(), req.GetReplicaNumber(), collectionInfo.GetVirtualChannelNames())
_, err = utils.SpawnReplicasWithRG(job.ctx, job.meta, req.GetCollectionID(), req.GetResourceGroups(), req.GetReplicaNumber(), collectionInfo.GetVirtualChannelNames())
if err != nil {
msg := "failed to spawn replica for collection"
log.Warn(msg, zap.Error(err))
@ -419,7 +419,7 @@ func (job *LoadPartitionJob) Execute() error {
}
})
ctx, sp := otel.Tracer(typeutil.QueryCoordRole).Start(job.ctx, "LoadPartition", trace.WithNewRoot())
if !job.meta.CollectionManager.Exist(req.GetCollectionID()) {
if !job.meta.CollectionManager.Exist(job.ctx, req.GetCollectionID()) {
job.undo.IsNewCollection = true
collection := &meta.Collection{
@ -434,14 +434,14 @@ func (job *LoadPartitionJob) Execute() error {
CreatedAt: time.Now(),
LoadSpan: sp,
}
err = job.meta.CollectionManager.PutCollection(collection, partitions...)
err = job.meta.CollectionManager.PutCollection(job.ctx, collection, partitions...)
if err != nil {
msg := "failed to store collection and partitions"
log.Warn(msg, zap.Error(err))
return errors.Wrap(err, msg)
}
} else { // collection exists, put partitions only
err = job.meta.CollectionManager.PutPartition(partitions...)
err = job.meta.CollectionManager.PutPartition(job.ctx, partitions...)
if err != nil {
msg := "failed to store partitions"
log.Warn(msg, zap.Error(err))
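
The load-job hunks follow the same rule: each job already owns a job.ctx, and that context is now handed to every CollectionManager/ReplicaManager call instead of the managers fabricating one internally. A reduced sketch of the Execute flow under that assumption, with a stand-in collectionStore interface rather than the real managers:

package example

import (
	"context"
	"fmt"
)

// collectionStore is an illustrative subset of the CollectionManager API.
type collectionStore interface {
	Exist(ctx context.Context, collectionID int64) bool
	PutCollection(ctx context.Context, collectionID int64) error
}

type loadJobSketch struct {
	ctx          context.Context // set when the job is created
	collectionID int64
	collections  collectionStore
}

func (job *loadJobSketch) Execute() error {
	// job.ctx travels with every meta access made on behalf of this job.
	if job.collections.Exist(job.ctx, job.collectionID) {
		return nil
	}
	if err := job.collections.PutCollection(job.ctx, job.collectionID); err != nil {
		return fmt.Errorf("failed to store collection %d: %w", job.collectionID, err)
	}
	return nil
}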

View File

@ -77,25 +77,25 @@ func (job *ReleaseCollectionJob) Execute() error {
req := job.req
log := log.Ctx(job.ctx).With(zap.Int64("collectionID", req.GetCollectionID()))
if !job.meta.CollectionManager.Exist(req.GetCollectionID()) {
if !job.meta.CollectionManager.Exist(job.ctx, req.GetCollectionID()) {
log.Info("release collection end, the collection has not been loaded into QueryNode")
return nil
}
loadedPartitions := job.meta.CollectionManager.GetPartitionsByCollection(req.GetCollectionID())
loadedPartitions := job.meta.CollectionManager.GetPartitionsByCollection(job.ctx, req.GetCollectionID())
toRelease := lo.Map(loadedPartitions, func(partition *meta.Partition, _ int) int64 {
return partition.GetPartitionID()
})
releasePartitions(job.ctx, job.meta, job.cluster, req.GetCollectionID(), toRelease...)
err := job.meta.CollectionManager.RemoveCollection(req.GetCollectionID())
err := job.meta.CollectionManager.RemoveCollection(job.ctx, req.GetCollectionID())
if err != nil {
msg := "failed to remove collection"
log.Warn(msg, zap.Error(err))
return errors.Wrap(err, msg)
}
err = job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID())
err = job.meta.ReplicaManager.RemoveCollection(job.ctx, req.GetCollectionID())
if err != nil {
msg := "failed to remove replicas"
log.Warn(msg, zap.Error(err))
@ -166,12 +166,12 @@ func (job *ReleasePartitionJob) Execute() error {
zap.Int64s("partitionIDs", req.GetPartitionIDs()),
)
if !job.meta.CollectionManager.Exist(req.GetCollectionID()) {
if !job.meta.CollectionManager.Exist(job.ctx, req.GetCollectionID()) {
log.Info("release collection end, the collection has not been loaded into QueryNode")
return nil
}
loadedPartitions := job.meta.CollectionManager.GetPartitionsByCollection(req.GetCollectionID())
loadedPartitions := job.meta.CollectionManager.GetPartitionsByCollection(job.ctx, req.GetCollectionID())
toRelease := lo.FilterMap(loadedPartitions, func(partition *meta.Partition, _ int) (int64, bool) {
return partition.GetPartitionID(), lo.Contains(req.GetPartitionIDs(), partition.GetPartitionID())
})
@ -185,13 +185,13 @@ func (job *ReleasePartitionJob) Execute() error {
// If all partitions are released, clear all
if len(toRelease) == len(loadedPartitions) {
log.Info("release partitions covers all partitions, will remove the whole collection")
err := job.meta.CollectionManager.RemoveCollection(req.GetCollectionID())
err := job.meta.CollectionManager.RemoveCollection(job.ctx, req.GetCollectionID())
if err != nil {
msg := "failed to release partitions from store"
log.Warn(msg, zap.Error(err))
return errors.Wrap(err, msg)
}
err = job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID())
err = job.meta.ReplicaManager.RemoveCollection(job.ctx, req.GetCollectionID())
if err != nil {
log.Warn("failed to remove replicas", zap.Error(err))
}
@ -207,7 +207,7 @@ func (job *ReleasePartitionJob) Execute() error {
waitCollectionReleased(job.dist, job.checkerController, req.GetCollectionID())
} else {
err := job.meta.CollectionManager.RemovePartition(req.GetCollectionID(), toRelease...)
err := job.meta.CollectionManager.RemovePartition(job.ctx, req.GetCollectionID(), toRelease...)
if err != nil {
msg := "failed to release partitions from store"
log.Warn(msg, zap.Error(err))

View File

@ -65,13 +65,13 @@ func (job *SyncNewCreatedPartitionJob) Execute() error {
)
// check if the collection is not loaded or its loadType is LoadPartition
collection := job.meta.GetCollection(job.req.GetCollectionID())
collection := job.meta.GetCollection(job.ctx, job.req.GetCollectionID())
if collection == nil || collection.GetLoadType() == querypb.LoadType_LoadPartition {
return nil
}
// check if partition already existed
if partition := job.meta.GetPartition(job.req.GetPartitionID()); partition != nil {
if partition := job.meta.GetPartition(job.ctx, job.req.GetPartitionID()); partition != nil {
return nil
}
@ -89,7 +89,7 @@ func (job *SyncNewCreatedPartitionJob) Execute() error {
LoadPercentage: 100,
CreatedAt: time.Now(),
}
err = job.meta.CollectionManager.PutPartition(partition)
err = job.meta.CollectionManager.PutPartition(job.ctx, partition)
if err != nil {
msg := "failed to store partitions"
log.Warn(msg, zap.Error(err))

View File

@ -77,6 +77,8 @@ type JobSuite struct {
// Test objects
scheduler *Scheduler
ctx context.Context
}
func (suite *JobSuite) SetupSuite() {
@ -160,6 +162,7 @@ func (suite *JobSuite) SetupTest() {
config.EtcdTLSMinVersion.GetValue())
suite.Require().NoError(err)
suite.kv = etcdkv.NewEtcdKV(cli, config.MetaRootPath.GetValue())
suite.ctx = context.Background()
suite.store = querycoord.NewCatalog(suite.kv)
suite.dist = meta.NewDistributionManager()
@ -195,9 +198,9 @@ func (suite *JobSuite) SetupTest() {
Hostname: "localhost",
}))
suite.meta.HandleNodeUp(1000)
suite.meta.HandleNodeUp(2000)
suite.meta.HandleNodeUp(3000)
suite.meta.HandleNodeUp(suite.ctx, 1000)
suite.meta.HandleNodeUp(suite.ctx, 2000)
suite.meta.HandleNodeUp(suite.ctx, 3000)
suite.checkerController = &checkers.CheckerController{}
suite.collectionObserver = observers.NewCollectionObserver(
@ -253,8 +256,8 @@ func (suite *JobSuite) TestLoadCollection() {
suite.scheduler.Add(job)
err := job.Wait()
suite.NoError(err)
suite.EqualValues(1, suite.meta.GetReplicaNumber(collection))
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.EqualValues(1, suite.meta.GetReplicaNumber(ctx, collection))
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
suite.assertCollectionLoaded(collection)
}
@ -346,9 +349,9 @@ func (suite *JobSuite) TestLoadCollection() {
},
}
suite.meta.ResourceManager.AddResourceGroup("rg1", cfg)
suite.meta.ResourceManager.AddResourceGroup("rg2", cfg)
suite.meta.ResourceManager.AddResourceGroup("rg3", cfg)
suite.meta.ResourceManager.AddResourceGroup(ctx, "rg1", cfg)
suite.meta.ResourceManager.AddResourceGroup(ctx, "rg2", cfg)
suite.meta.ResourceManager.AddResourceGroup(ctx, "rg3", cfg)
// Load with 3 replica on 1 rg
req := &querypb.LoadCollectionRequest{
@ -455,8 +458,8 @@ func (suite *JobSuite) TestLoadCollectionWithLoadFields() {
suite.scheduler.Add(job)
err := job.Wait()
suite.NoError(err)
suite.EqualValues(1, suite.meta.GetReplicaNumber(collection))
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.EqualValues(1, suite.meta.GetReplicaNumber(ctx, collection))
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
suite.assertCollectionLoaded(collection)
}
})
@ -580,8 +583,8 @@ func (suite *JobSuite) TestLoadPartition() {
suite.scheduler.Add(job)
err := job.Wait()
suite.NoError(err)
suite.EqualValues(1, suite.meta.GetReplicaNumber(collection))
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.EqualValues(1, suite.meta.GetReplicaNumber(ctx, collection))
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
suite.assertCollectionLoaded(collection)
}
@ -704,9 +707,9 @@ func (suite *JobSuite) TestLoadPartition() {
NodeNum: 1,
},
}
suite.meta.ResourceManager.AddResourceGroup("rg1", cfg)
suite.meta.ResourceManager.AddResourceGroup("rg2", cfg)
suite.meta.ResourceManager.AddResourceGroup("rg3", cfg)
suite.meta.ResourceManager.AddResourceGroup(ctx, "rg1", cfg)
suite.meta.ResourceManager.AddResourceGroup(ctx, "rg2", cfg)
suite.meta.ResourceManager.AddResourceGroup(ctx, "rg3", cfg)
// test load 3 replica in 1 rg, should pass rg check
req := &querypb.LoadPartitionsRequest{
@ -786,8 +789,8 @@ func (suite *JobSuite) TestLoadPartitionWithLoadFields() {
suite.scheduler.Add(job)
err := job.Wait()
suite.NoError(err)
suite.EqualValues(1, suite.meta.GetReplicaNumber(collection))
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.EqualValues(1, suite.meta.GetReplicaNumber(ctx, collection))
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
suite.assertCollectionLoaded(collection)
}
})
@ -941,7 +944,7 @@ func (suite *JobSuite) TestDynamicLoad() {
suite.scheduler.Add(job)
err := job.Wait()
suite.NoError(err)
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
suite.assertPartitionLoaded(collection, p0, p1, p2)
// loaded: p0, p1, p2
@ -961,13 +964,13 @@ func (suite *JobSuite) TestDynamicLoad() {
suite.scheduler.Add(job)
err = job.Wait()
suite.NoError(err)
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
suite.assertPartitionLoaded(collection, p0, p1)
job = newLoadPartJob(p2)
suite.scheduler.Add(job)
err = job.Wait()
suite.NoError(err)
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
suite.assertPartitionLoaded(collection, p2)
// loaded: p0, p1
@ -978,13 +981,13 @@ func (suite *JobSuite) TestDynamicLoad() {
suite.scheduler.Add(job)
err = job.Wait()
suite.NoError(err)
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
suite.assertPartitionLoaded(collection, p0, p1)
job = newLoadPartJob(p1, p2)
suite.scheduler.Add(job)
err = job.Wait()
suite.NoError(err)
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
suite.assertPartitionLoaded(collection, p2)
// loaded: p0, p1
@ -995,13 +998,13 @@ func (suite *JobSuite) TestDynamicLoad() {
suite.scheduler.Add(job)
err = job.Wait()
suite.NoError(err)
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
suite.assertPartitionLoaded(collection, p0, p1)
colJob := newLoadColJob()
suite.scheduler.Add(colJob)
err = colJob.Wait()
suite.NoError(err)
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
suite.assertPartitionLoaded(collection, p2)
}
@ -1166,8 +1169,8 @@ func (suite *JobSuite) TestReleasePartition() {
suite.scheduler.Add(job)
err := job.Wait()
suite.NoError(err)
suite.True(suite.meta.Exist(collection))
partitions := suite.meta.GetPartitionsByCollection(collection)
suite.True(suite.meta.Exist(ctx, collection))
partitions := suite.meta.GetPartitionsByCollection(ctx, collection)
suite.Len(partitions, 1)
suite.Equal(suite.partitions[collection][0], partitions[0].GetPartitionID())
suite.assertPartitionReleased(collection, suite.partitions[collection][1:]...)
@ -1247,7 +1250,7 @@ func (suite *JobSuite) TestDynamicRelease() {
err = job.Wait()
suite.NoError(err)
suite.assertPartitionReleased(col0, p0, p1, p2)
suite.False(suite.meta.Exist(col0))
suite.False(suite.meta.Exist(ctx, col0))
// loaded: p0, p1, p2
// action: release col
@ -1275,14 +1278,15 @@ func (suite *JobSuite) TestDynamicRelease() {
}
func (suite *JobSuite) TestLoadCollectionStoreFailed() {
ctx := context.Background()
// Store collection failed
store := mocks.NewQueryCoordCatalog(suite.T())
suite.meta = meta.NewMeta(RandomIncrementIDAllocator(), store, suite.nodeMgr)
store.EXPECT().SaveResourceGroup(mock.Anything, mock.Anything).Return(nil)
suite.meta.HandleNodeUp(1000)
suite.meta.HandleNodeUp(2000)
suite.meta.HandleNodeUp(3000)
suite.meta.HandleNodeUp(ctx, 1000)
suite.meta.HandleNodeUp(ctx, 2000)
suite.meta.HandleNodeUp(ctx, 3000)
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadCollection {
@ -1290,9 +1294,9 @@ func (suite *JobSuite) TestLoadCollectionStoreFailed() {
}
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil)
err := errors.New("failed to store collection")
store.EXPECT().SaveReplica(mock.Anything).Return(nil)
store.EXPECT().SaveCollection(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(err)
store.EXPECT().ReleaseReplicas(collection).Return(nil)
store.EXPECT().SaveReplica(mock.Anything, mock.Anything).Return(nil)
store.EXPECT().SaveCollection(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(err)
store.EXPECT().ReleaseReplicas(mock.Anything, collection).Return(nil)
req := &querypb.LoadCollectionRequest{
CollectionID: collection,
@ -1316,14 +1320,15 @@ func (suite *JobSuite) TestLoadCollectionStoreFailed() {
}
func (suite *JobSuite) TestLoadPartitionStoreFailed() {
ctx := context.Background()
// Store partition failed
store := mocks.NewQueryCoordCatalog(suite.T())
suite.meta = meta.NewMeta(RandomIncrementIDAllocator(), store, suite.nodeMgr)
store.EXPECT().SaveResourceGroup(mock.Anything, mock.Anything).Return(nil)
suite.meta.HandleNodeUp(1000)
suite.meta.HandleNodeUp(2000)
suite.meta.HandleNodeUp(3000)
store.EXPECT().SaveResourceGroup(mock.Anything, mock.Anything, mock.Anything).Return(nil)
suite.meta.HandleNodeUp(ctx, 1000)
suite.meta.HandleNodeUp(ctx, 2000)
suite.meta.HandleNodeUp(ctx, 3000)
err := errors.New("failed to store collection")
for _, collection := range suite.collections {
@ -1331,9 +1336,9 @@ func (suite *JobSuite) TestLoadPartitionStoreFailed() {
continue
}
store.EXPECT().SaveReplica(mock.Anything).Return(nil)
store.EXPECT().SaveCollection(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(err)
store.EXPECT().ReleaseReplicas(collection).Return(nil)
store.EXPECT().SaveReplica(mock.Anything, mock.Anything).Return(nil)
store.EXPECT().SaveCollection(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(err)
store.EXPECT().ReleaseReplicas(mock.Anything, collection).Return(nil)
req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
@ -1548,6 +1553,7 @@ func (suite *JobSuite) TestCallReleasePartitionFailed() {
func (suite *JobSuite) TestSyncNewCreatedPartition() {
newPartition := int64(999)
ctx := context.Background()
// test sync new created partition
suite.loadAll()
@ -1565,7 +1571,7 @@ func (suite *JobSuite) TestSyncNewCreatedPartition() {
suite.scheduler.Add(job)
err := job.Wait()
suite.NoError(err)
partition := suite.meta.CollectionManager.GetPartition(newPartition)
partition := suite.meta.CollectionManager.GetPartition(ctx, newPartition)
suite.NotNil(partition)
suite.Equal(querypb.LoadStatus_Loaded, partition.GetStatus())
@ -1624,11 +1630,11 @@ func (suite *JobSuite) loadAll() {
suite.scheduler.Add(job)
err := job.Wait()
suite.NoError(err)
suite.EqualValues(1, suite.meta.GetReplicaNumber(collection))
suite.True(suite.meta.Exist(collection))
suite.NotNil(suite.meta.GetCollection(collection))
suite.NotNil(suite.meta.GetPartitionsByCollection(collection))
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.EqualValues(1, suite.meta.GetReplicaNumber(ctx, collection))
suite.True(suite.meta.Exist(ctx, collection))
suite.NotNil(suite.meta.GetCollection(ctx, collection))
suite.NotNil(suite.meta.GetPartitionsByCollection(ctx, collection))
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
} else {
req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
@ -1649,11 +1655,11 @@ func (suite *JobSuite) loadAll() {
suite.scheduler.Add(job)
err := job.Wait()
suite.NoError(err)
suite.EqualValues(1, suite.meta.GetReplicaNumber(collection))
suite.True(suite.meta.Exist(collection))
suite.NotNil(suite.meta.GetCollection(collection))
suite.NotNil(suite.meta.GetPartitionsByCollection(collection))
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.EqualValues(1, suite.meta.GetReplicaNumber(ctx, collection))
suite.True(suite.meta.Exist(ctx, collection))
suite.NotNil(suite.meta.GetCollection(ctx, collection))
suite.NotNil(suite.meta.GetPartitionsByCollection(ctx, collection))
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
}
}
}
@ -1684,54 +1690,58 @@ func (suite *JobSuite) releaseAll() {
}
func (suite *JobSuite) assertCollectionLoaded(collection int64) {
suite.True(suite.meta.Exist(collection))
suite.NotEqual(0, len(suite.meta.ReplicaManager.GetByCollection(collection)))
ctx := context.Background()
suite.True(suite.meta.Exist(ctx, collection))
suite.NotEqual(0, len(suite.meta.ReplicaManager.GetByCollection(ctx, collection)))
for _, channel := range suite.channels[collection] {
suite.NotNil(suite.targetMgr.GetDmChannel(collection, channel, meta.CurrentTarget))
suite.NotNil(suite.targetMgr.GetDmChannel(ctx, collection, channel, meta.CurrentTarget))
}
for _, segments := range suite.segments[collection] {
for _, segment := range segments {
suite.NotNil(suite.targetMgr.GetSealedSegment(collection, segment, meta.CurrentTarget))
suite.NotNil(suite.targetMgr.GetSealedSegment(ctx, collection, segment, meta.CurrentTarget))
}
}
}
func (suite *JobSuite) assertPartitionLoaded(collection int64, partitionIDs ...int64) {
suite.True(suite.meta.Exist(collection))
suite.NotEqual(0, len(suite.meta.ReplicaManager.GetByCollection(collection)))
ctx := context.Background()
suite.True(suite.meta.Exist(ctx, collection))
suite.NotEqual(0, len(suite.meta.ReplicaManager.GetByCollection(ctx, collection)))
for _, channel := range suite.channels[collection] {
suite.NotNil(suite.targetMgr.GetDmChannel(collection, channel, meta.CurrentTarget))
suite.NotNil(suite.targetMgr.GetDmChannel(ctx, collection, channel, meta.CurrentTarget))
}
for partitionID, segments := range suite.segments[collection] {
if !lo.Contains(partitionIDs, partitionID) {
continue
}
suite.NotNil(suite.meta.GetPartition(partitionID))
suite.NotNil(suite.meta.GetPartition(ctx, partitionID))
for _, segment := range segments {
suite.NotNil(suite.targetMgr.GetSealedSegment(collection, segment, meta.CurrentTarget))
suite.NotNil(suite.targetMgr.GetSealedSegment(ctx, collection, segment, meta.CurrentTarget))
}
}
}
func (suite *JobSuite) assertCollectionReleased(collection int64) {
suite.False(suite.meta.Exist(collection))
suite.Equal(0, len(suite.meta.ReplicaManager.GetByCollection(collection)))
ctx := context.Background()
suite.False(suite.meta.Exist(ctx, collection))
suite.Equal(0, len(suite.meta.ReplicaManager.GetByCollection(ctx, collection)))
for _, channel := range suite.channels[collection] {
suite.Nil(suite.targetMgr.GetDmChannel(collection, channel, meta.CurrentTarget))
suite.Nil(suite.targetMgr.GetDmChannel(ctx, collection, channel, meta.CurrentTarget))
}
for _, partitions := range suite.segments[collection] {
for _, segment := range partitions {
suite.Nil(suite.targetMgr.GetSealedSegment(collection, segment, meta.CurrentTarget))
suite.Nil(suite.targetMgr.GetSealedSegment(ctx, collection, segment, meta.CurrentTarget))
}
}
}
func (suite *JobSuite) assertPartitionReleased(collection int64, partitionIDs ...int64) {
ctx := context.Background()
for _, partition := range partitionIDs {
suite.Nil(suite.meta.GetPartition(partition))
suite.Nil(suite.meta.GetPartition(ctx, partition))
segments := suite.segments[collection][partition]
for _, segment := range segments {
suite.Nil(suite.targetMgr.GetSealedSegment(collection, segment, meta.CurrentTarget))
suite.Nil(suite.targetMgr.GetSealedSegment(ctx, collection, segment, meta.CurrentTarget))
}
}
}
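
In JobSuite the context is created once in SetupTest, stored on the suite, and reused for calls such as HandleNodeUp, while the assert helpers that run later simply create their own context.Background(). A sketch of that suite layout, assuming a testify suite (jobSuiteSketch is illustrative, not the real JobSuite):

package example

import (
	"context"
	"testing"

	"github.com/stretchr/testify/suite"
)

type jobSuiteSketch struct {
	suite.Suite
	ctx context.Context
}

func (s *jobSuiteSketch) SetupTest() {
	// Created once per test and reused by the test body, mirroring suite.ctx above.
	s.ctx = context.Background()
}

func (s *jobSuiteSketch) TestUsesSuiteCtx() {
	s.NotNil(s.ctx)
}

// Helper assertions without access to the suite field create a local ctx.
func (s *jobSuiteSketch) assertSomething() {
	ctx := context.Background()
	_ = ctx
}

func TestJobSuiteSketch(t *testing.T) {
	suite.Run(t, new(jobSuiteSketch))
}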

View File

@ -62,7 +62,7 @@ func NewUpdateLoadConfigJob(ctx context.Context,
}
func (job *UpdateLoadConfigJob) Execute() error {
if !job.meta.CollectionManager.Exist(job.collectionID) {
if !job.meta.CollectionManager.Exist(job.ctx, job.collectionID) {
msg := "modify replica for unloaded collection is not supported"
err := merr.WrapErrCollectionNotLoaded(msg)
log.Warn(msg, zap.Error(err))
@ -83,7 +83,7 @@ func (job *UpdateLoadConfigJob) Execute() error {
var err error
// 2. reassign
toSpawn, toTransfer, toRelease, err := utils.ReassignReplicaToRG(job.meta, job.collectionID, job.newReplicaNumber, job.newResourceGroups)
toSpawn, toTransfer, toRelease, err := utils.ReassignReplicaToRG(job.ctx, job.meta, job.collectionID, job.newReplicaNumber, job.newResourceGroups)
if err != nil {
log.Warn("failed to reassign replica", zap.Error(err))
return err
@ -98,8 +98,8 @@ func (job *UpdateLoadConfigJob) Execute() error {
zap.Any("toRelease", toRelease))
// 3. try to spawn new replica
channels := job.targetMgr.GetDmChannelsByCollection(job.collectionID, meta.CurrentTargetFirst)
newReplicas, spawnErr := job.meta.ReplicaManager.Spawn(job.collectionID, toSpawn, lo.Keys(channels))
channels := job.targetMgr.GetDmChannelsByCollection(job.ctx, job.collectionID, meta.CurrentTargetFirst)
newReplicas, spawnErr := job.meta.ReplicaManager.Spawn(job.ctx, job.collectionID, toSpawn, lo.Keys(channels))
if spawnErr != nil {
log.Warn("failed to spawn replica", zap.Error(spawnErr))
err := spawnErr
@ -109,7 +109,7 @@ func (job *UpdateLoadConfigJob) Execute() error {
if err != nil {
// roll back replica from meta
replicaIDs := lo.Map(newReplicas, func(r *meta.Replica, _ int) int64 { return r.GetID() })
err := job.meta.ReplicaManager.RemoveReplicas(job.collectionID, replicaIDs...)
err := job.meta.ReplicaManager.RemoveReplicas(job.ctx, job.collectionID, replicaIDs...)
if err != nil {
log.Warn("failed to remove replicas", zap.Int64s("replicaIDs", replicaIDs), zap.Error(err))
}
@ -125,7 +125,7 @@ func (job *UpdateLoadConfigJob) Execute() error {
replicaOldRG[replica.GetID()] = replica.GetResourceGroup()
}
if transferErr := job.meta.ReplicaManager.MoveReplica(rg, replicas); transferErr != nil {
if transferErr := job.meta.ReplicaManager.MoveReplica(job.ctx, rg, replicas); transferErr != nil {
log.Warn("failed to transfer replica for collection", zap.Int64("collectionID", collectionID), zap.Error(transferErr))
err = transferErr
return err
@ -138,7 +138,7 @@ func (job *UpdateLoadConfigJob) Execute() error {
for _, replica := range replicas {
oldRG := replicaOldRG[replica.GetID()]
if replica.GetResourceGroup() != oldRG {
if err := job.meta.ReplicaManager.TransferReplica(replica.GetID(), replica.GetResourceGroup(), oldRG, 1); err != nil {
if err := job.meta.ReplicaManager.TransferReplica(job.ctx, replica.GetID(), replica.GetResourceGroup(), oldRG, 1); err != nil {
log.Warn("failed to roll back replicas", zap.Int64("replica", replica.GetID()), zap.Error(err))
}
}
@ -148,17 +148,17 @@ func (job *UpdateLoadConfigJob) Execute() error {
}()
// 5. remove replica from meta
err = job.meta.ReplicaManager.RemoveReplicas(job.collectionID, toRelease...)
err = job.meta.ReplicaManager.RemoveReplicas(job.ctx, job.collectionID, toRelease...)
if err != nil {
log.Warn("failed to remove replicas", zap.Int64s("replicaIDs", toRelease), zap.Error(err))
return err
}
// 6. recover node distribution among replicas
utils.RecoverReplicaOfCollection(job.meta, job.collectionID)
utils.RecoverReplicaOfCollection(job.ctx, job.meta, job.collectionID)
// 7. update replica number in meta
err = job.meta.UpdateReplicaNumber(job.collectionID, job.newReplicaNumber)
err = job.meta.UpdateReplicaNumber(job.ctx, job.collectionID, job.newReplicaNumber)
if err != nil {
msg := "failed to update replica number"
log.Warn(msg, zap.Error(err))

View File

@ -68,9 +68,9 @@ func (u *UndoList) RollBack() {
var err error
if u.IsNewCollection || u.IsReplicaCreated {
err = u.meta.CollectionManager.RemoveCollection(u.CollectionID)
err = u.meta.CollectionManager.RemoveCollection(u.ctx, u.CollectionID)
} else {
err = u.meta.CollectionManager.RemovePartition(u.CollectionID, u.LackPartitions...)
err = u.meta.CollectionManager.RemovePartition(u.ctx, u.CollectionID, u.LackPartitions...)
}
if err != nil {
log.Warn("failed to rollback collection from meta", zap.Error(err))

View File

@ -90,7 +90,7 @@ func loadPartitions(ctx context.Context,
return err
}
replicas := meta.ReplicaManager.GetByCollection(collection)
replicas := meta.ReplicaManager.GetByCollection(ctx, collection)
loadReq := &querypb.LoadPartitionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadPartitions,
@ -124,7 +124,7 @@ func releasePartitions(ctx context.Context,
partitions ...int64,
) {
log := log.Ctx(ctx).With(zap.Int64("collection", collection), zap.Int64s("partitions", partitions))
replicas := meta.ReplicaManager.GetByCollection(collection)
replicas := meta.ReplicaManager.GetByCollection(ctx, collection)
releaseReq := &querypb.ReleasePartitionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ReleasePartitions,

View File

@ -122,17 +122,17 @@ func NewCollectionManager(catalog metastore.QueryCoordCatalog) *CollectionManage
// Recover recovers collections from kv store,
// panics if failed
func (m *CollectionManager) Recover(broker Broker) error {
collections, err := m.catalog.GetCollections()
func (m *CollectionManager) Recover(ctx context.Context, broker Broker) error {
collections, err := m.catalog.GetCollections(ctx)
if err != nil {
return err
}
partitions, err := m.catalog.GetPartitions()
partitions, err := m.catalog.GetPartitions(ctx)
if err != nil {
return err
}
ctx := log.WithTraceID(context.Background(), strconv.FormatInt(time.Now().UnixNano(), 10))
ctx = log.WithTraceID(ctx, strconv.FormatInt(time.Now().UnixNano(), 10))
ctxLog := log.Ctx(ctx)
ctxLog.Info("recover collections and partitions from kv store")
@ -141,13 +141,13 @@ func (m *CollectionManager) Recover(broker Broker) error {
ctxLog.Info("skip recovery and release collection due to invalid replica number",
zap.Int64("collectionID", collection.GetCollectionID()),
zap.Int32("replicaNumber", collection.GetReplicaNumber()))
m.catalog.ReleaseCollection(collection.GetCollectionID())
m.catalog.ReleaseCollection(ctx, collection.GetCollectionID())
continue
}
if collection.GetStatus() != querypb.LoadStatus_Loaded {
if collection.RecoverTimes >= paramtable.Get().QueryCoordCfg.CollectionRecoverTimesLimit.GetAsInt32() {
m.catalog.ReleaseCollection(collection.CollectionID)
m.catalog.ReleaseCollection(ctx, collection.CollectionID)
ctxLog.Info("recover loading collection times reach limit, release collection",
zap.Int64("collectionID", collection.CollectionID),
zap.Int32("recoverTimes", collection.RecoverTimes))
@ -155,11 +155,11 @@ func (m *CollectionManager) Recover(broker Broker) error {
}
// update recoverTimes meta in etcd
collection.RecoverTimes += 1
m.putCollection(true, &Collection{CollectionLoadInfo: collection})
m.putCollection(ctx, true, &Collection{CollectionLoadInfo: collection})
continue
}
err := m.upgradeLoadFields(collection, broker)
err := m.upgradeLoadFields(ctx, collection, broker)
if err != nil {
if errors.Is(err, merr.ErrCollectionNotFound) {
log.Warn("collection not found, skip upgrade logic and wait for release")
@ -170,7 +170,7 @@ func (m *CollectionManager) Recover(broker Broker) error {
}
// update collection's CreateAt and UpdateAt to now after qc restart
m.putCollection(false, &Collection{
m.putCollection(ctx, false, &Collection{
CollectionLoadInfo: collection,
CreatedAt: time.Now(),
})
@ -181,7 +181,7 @@ func (m *CollectionManager) Recover(broker Broker) error {
// Partitions whose load has not finished should be deprecated
if partition.GetStatus() != querypb.LoadStatus_Loaded {
if partition.RecoverTimes >= paramtable.Get().QueryCoordCfg.CollectionRecoverTimesLimit.GetAsInt32() {
m.catalog.ReleaseCollection(collection)
m.catalog.ReleaseCollection(ctx, collection)
ctxLog.Info("recover loading partition times reach limit, release collection",
zap.Int64("collectionID", collection),
zap.Int32("recoverTimes", partition.RecoverTimes))
@ -189,7 +189,7 @@ func (m *CollectionManager) Recover(broker Broker) error {
}
partition.RecoverTimes += 1
m.putPartition([]*Partition{
m.putPartition(ctx, []*Partition{
{
PartitionLoadInfo: partition,
CreatedAt: time.Now(),
@ -198,7 +198,7 @@ func (m *CollectionManager) Recover(broker Broker) error {
continue
}
m.putPartition([]*Partition{
m.putPartition(ctx, []*Partition{
{
PartitionLoadInfo: partition,
CreatedAt: time.Now(),
@ -207,7 +207,7 @@ func (m *CollectionManager) Recover(broker Broker) error {
}
}
err = m.upgradeRecover(broker)
err = m.upgradeRecover(ctx, broker)
if err != nil {
log.Warn("upgrade recover failed", zap.Error(err))
return err
@ -215,7 +215,7 @@ func (m *CollectionManager) Recover(broker Broker) error {
return nil
}
func (m *CollectionManager) upgradeLoadFields(collection *querypb.CollectionLoadInfo, broker Broker) error {
func (m *CollectionManager) upgradeLoadFields(ctx context.Context, collection *querypb.CollectionLoadInfo, broker Broker) error {
// only fill load fields when value is nil
if collection.LoadFields != nil {
return nil
@ -234,7 +234,7 @@ func (m *CollectionManager) upgradeLoadFields(collection *querypb.CollectionLoad
})
// put updated meta back to store
err = m.putCollection(true, &Collection{
err = m.putCollection(ctx, true, &Collection{
CollectionLoadInfo: collection,
LoadPercentage: 100,
})
@ -246,10 +246,10 @@ func (m *CollectionManager) upgradeLoadFields(collection *querypb.CollectionLoad
}
// upgradeRecover recovers from old version <= 2.2.x for compatibility.
func (m *CollectionManager) upgradeRecover(broker Broker) error {
func (m *CollectionManager) upgradeRecover(ctx context.Context, broker Broker) error {
// for a collection loaded from 2.2, only an old-version CollectionLoadInfo without LoadType was saved.
// we should update the CollectionLoadInfo and save all PartitionLoadInfo to meta store
for _, collection := range m.GetAllCollections() {
for _, collection := range m.GetAllCollections(ctx) {
if collection.GetLoadType() == querypb.LoadType_UnKnownType {
partitionIDs, err := broker.GetPartitions(context.Background(), collection.GetCollectionID())
if err != nil {
@ -267,14 +267,14 @@ func (m *CollectionManager) upgradeRecover(broker Broker) error {
LoadPercentage: 100,
}
})
err = m.putPartition(partitions, true)
err = m.putPartition(ctx, partitions, true)
if err != nil {
return err
}
newInfo := collection.Clone()
newInfo.LoadType = querypb.LoadType_LoadCollection
err = m.putCollection(true, newInfo)
err = m.putCollection(ctx, true, newInfo)
if err != nil {
return err
}
@ -283,7 +283,7 @@ func (m *CollectionManager) upgradeRecover(broker Broker) error {
// for a partition loaded from 2.2, only the PartitionLoadInfo was saved.
// we should save its CollectionLoadInfo to the meta store
for _, partition := range m.GetAllPartitions() {
for _, partition := range m.GetAllPartitions(ctx) {
// In old version, collection would NOT be stored if the partition existed.
if _, ok := m.collections[partition.GetCollectionID()]; !ok {
col := &Collection{
@ -296,7 +296,7 @@ func (m *CollectionManager) upgradeRecover(broker Broker) error {
},
LoadPercentage: 100,
}
err := m.PutCollection(col)
err := m.PutCollection(ctx, col)
if err != nil {
return err
}
@ -305,21 +305,21 @@ func (m *CollectionManager) upgradeRecover(broker Broker) error {
return nil
}
func (m *CollectionManager) GetCollection(collectionID typeutil.UniqueID) *Collection {
func (m *CollectionManager) GetCollection(ctx context.Context, collectionID typeutil.UniqueID) *Collection {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
return m.collections[collectionID]
}
func (m *CollectionManager) GetPartition(partitionID typeutil.UniqueID) *Partition {
func (m *CollectionManager) GetPartition(ctx context.Context, partitionID typeutil.UniqueID) *Partition {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
return m.partitions[partitionID]
}
func (m *CollectionManager) GetLoadType(collectionID typeutil.UniqueID) querypb.LoadType {
func (m *CollectionManager) GetLoadType(ctx context.Context, collectionID typeutil.UniqueID) querypb.LoadType {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
@ -330,7 +330,7 @@ func (m *CollectionManager) GetLoadType(collectionID typeutil.UniqueID) querypb.
return querypb.LoadType_UnKnownType
}
func (m *CollectionManager) GetReplicaNumber(collectionID typeutil.UniqueID) int32 {
func (m *CollectionManager) GetReplicaNumber(ctx context.Context, collectionID typeutil.UniqueID) int32 {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
@ -342,7 +342,7 @@ func (m *CollectionManager) GetReplicaNumber(collectionID typeutil.UniqueID) int
}
// CalculateLoadPercentage checks if collection is currently fully loaded.
func (m *CollectionManager) CalculateLoadPercentage(collectionID typeutil.UniqueID) int32 {
func (m *CollectionManager) CalculateLoadPercentage(ctx context.Context, collectionID typeutil.UniqueID) int32 {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
@ -362,7 +362,7 @@ func (m *CollectionManager) calculateLoadPercentage(collectionID typeutil.Unique
return -1
}
func (m *CollectionManager) GetPartitionLoadPercentage(partitionID typeutil.UniqueID) int32 {
func (m *CollectionManager) GetPartitionLoadPercentage(ctx context.Context, partitionID typeutil.UniqueID) int32 {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
@ -373,7 +373,7 @@ func (m *CollectionManager) GetPartitionLoadPercentage(partitionID typeutil.Uniq
return -1
}
func (m *CollectionManager) CalculateLoadStatus(collectionID typeutil.UniqueID) querypb.LoadStatus {
func (m *CollectionManager) CalculateLoadStatus(ctx context.Context, collectionID typeutil.UniqueID) querypb.LoadStatus {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
@ -396,7 +396,7 @@ func (m *CollectionManager) CalculateLoadStatus(collectionID typeutil.UniqueID)
return querypb.LoadStatus_Invalid
}
func (m *CollectionManager) GetFieldIndex(collectionID typeutil.UniqueID) map[int64]int64 {
func (m *CollectionManager) GetFieldIndex(ctx context.Context, collectionID typeutil.UniqueID) map[int64]int64 {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
@ -407,7 +407,7 @@ func (m *CollectionManager) GetFieldIndex(collectionID typeutil.UniqueID) map[in
return nil
}
func (m *CollectionManager) GetLoadFields(collectionID typeutil.UniqueID) []int64 {
func (m *CollectionManager) GetLoadFields(ctx context.Context, collectionID typeutil.UniqueID) []int64 {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
@ -418,7 +418,7 @@ func (m *CollectionManager) GetLoadFields(collectionID typeutil.UniqueID) []int6
return nil
}
func (m *CollectionManager) Exist(collectionID typeutil.UniqueID) bool {
func (m *CollectionManager) Exist(ctx context.Context, collectionID typeutil.UniqueID) bool {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
@ -427,7 +427,7 @@ func (m *CollectionManager) Exist(collectionID typeutil.UniqueID) bool {
}
// GetAll returns the collection ID of all loaded collections
func (m *CollectionManager) GetAll() []int64 {
func (m *CollectionManager) GetAll(ctx context.Context) []int64 {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
@ -438,21 +438,21 @@ func (m *CollectionManager) GetAll() []int64 {
return ids.Collect()
}
func (m *CollectionManager) GetAllCollections() []*Collection {
func (m *CollectionManager) GetAllCollections(ctx context.Context) []*Collection {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
return lo.Values(m.collections)
}
func (m *CollectionManager) GetAllPartitions() []*Partition {
func (m *CollectionManager) GetAllPartitions(ctx context.Context) []*Partition {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
return lo.Values(m.partitions)
}
func (m *CollectionManager) GetPartitionsByCollection(collectionID typeutil.UniqueID) []*Partition {
func (m *CollectionManager) GetPartitionsByCollection(ctx context.Context, collectionID typeutil.UniqueID) []*Partition {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
@ -463,26 +463,26 @@ func (m *CollectionManager) getPartitionsByCollection(collectionID typeutil.Uniq
return lo.Map(m.collectionPartitions[collectionID].Collect(), func(partitionID int64, _ int) *Partition { return m.partitions[partitionID] })
}
func (m *CollectionManager) PutCollection(collection *Collection, partitions ...*Partition) error {
func (m *CollectionManager) PutCollection(ctx context.Context, collection *Collection, partitions ...*Partition) error {
m.rwmutex.Lock()
defer m.rwmutex.Unlock()
return m.putCollection(true, collection, partitions...)
return m.putCollection(ctx, true, collection, partitions...)
}
func (m *CollectionManager) PutCollectionWithoutSave(collection *Collection) error {
func (m *CollectionManager) PutCollectionWithoutSave(ctx context.Context, collection *Collection) error {
m.rwmutex.Lock()
defer m.rwmutex.Unlock()
return m.putCollection(false, collection)
return m.putCollection(ctx, false, collection)
}
func (m *CollectionManager) putCollection(withSave bool, collection *Collection, partitions ...*Partition) error {
func (m *CollectionManager) putCollection(ctx context.Context, withSave bool, collection *Collection, partitions ...*Partition) error {
if withSave {
partitionInfos := lo.Map(partitions, func(partition *Partition, _ int) *querypb.PartitionLoadInfo {
return partition.PartitionLoadInfo
})
err := m.catalog.SaveCollection(collection.CollectionLoadInfo, partitionInfos...)
err := m.catalog.SaveCollection(ctx, collection.CollectionLoadInfo, partitionInfos...)
if err != nil {
return err
}
@ -504,26 +504,26 @@ func (m *CollectionManager) putCollection(withSave bool, collection *Collection,
return nil
}
func (m *CollectionManager) PutPartition(partitions ...*Partition) error {
func (m *CollectionManager) PutPartition(ctx context.Context, partitions ...*Partition) error {
m.rwmutex.Lock()
defer m.rwmutex.Unlock()
return m.putPartition(partitions, true)
return m.putPartition(ctx, partitions, true)
}
func (m *CollectionManager) PutPartitionWithoutSave(partitions ...*Partition) error {
func (m *CollectionManager) PutPartitionWithoutSave(ctx context.Context, partitions ...*Partition) error {
m.rwmutex.Lock()
defer m.rwmutex.Unlock()
return m.putPartition(partitions, false)
return m.putPartition(ctx, partitions, false)
}
func (m *CollectionManager) putPartition(partitions []*Partition, withSave bool) error {
func (m *CollectionManager) putPartition(ctx context.Context, partitions []*Partition, withSave bool) error {
if withSave {
loadInfos := lo.Map(partitions, func(partition *Partition, _ int) *querypb.PartitionLoadInfo {
return partition.PartitionLoadInfo
})
err := m.catalog.SavePartition(loadInfos...)
err := m.catalog.SavePartition(ctx, loadInfos...)
if err != nil {
return err
}
@ -543,7 +543,7 @@ func (m *CollectionManager) putPartition(partitions []*Partition, withSave bool)
return nil
}
func (m *CollectionManager) UpdateLoadPercent(partitionID int64, loadPercent int32) (int32, error) {
func (m *CollectionManager) UpdateLoadPercent(ctx context.Context, partitionID int64, loadPercent int32) (int32, error) {
m.rwmutex.Lock()
defer m.rwmutex.Unlock()
@ -565,7 +565,7 @@ func (m *CollectionManager) UpdateLoadPercent(partitionID int64, loadPercent int
metrics.QueryCoordLoadLatency.WithLabelValues().Observe(float64(elapsed.Milliseconds()))
eventlog.Record(eventlog.NewRawEvt(eventlog.Level_Info, fmt.Sprintf("Partition %d loaded", partitionID)))
}
err := m.putPartition([]*Partition{newPartition}, savePartition)
err := m.putPartition(ctx, []*Partition{newPartition}, savePartition)
if err != nil {
return 0, err
}
@ -595,17 +595,17 @@ func (m *CollectionManager) UpdateLoadPercent(partitionID int64, loadPercent int
metrics.QueryCoordLoadLatency.WithLabelValues().Observe(float64(elapsed.Milliseconds()))
eventlog.Record(eventlog.NewRawEvt(eventlog.Level_Info, fmt.Sprintf("Collection %d loaded", newCollection.CollectionID)))
}
return collectionPercent, m.putCollection(saveCollection, newCollection)
return collectionPercent, m.putCollection(ctx, saveCollection, newCollection)
}
// RemoveCollection removes collection and its partitions.
func (m *CollectionManager) RemoveCollection(collectionID typeutil.UniqueID) error {
func (m *CollectionManager) RemoveCollection(ctx context.Context, collectionID typeutil.UniqueID) error {
m.rwmutex.Lock()
defer m.rwmutex.Unlock()
_, ok := m.collections[collectionID]
if ok {
err := m.catalog.ReleaseCollection(collectionID)
err := m.catalog.ReleaseCollection(ctx, collectionID)
if err != nil {
return err
}
@ -619,7 +619,7 @@ func (m *CollectionManager) RemoveCollection(collectionID typeutil.UniqueID) err
return nil
}
func (m *CollectionManager) RemovePartition(collectionID typeutil.UniqueID, partitionIDs ...typeutil.UniqueID) error {
func (m *CollectionManager) RemovePartition(ctx context.Context, collectionID typeutil.UniqueID, partitionIDs ...typeutil.UniqueID) error {
if len(partitionIDs) == 0 {
return nil
}
@ -627,11 +627,11 @@ func (m *CollectionManager) RemovePartition(collectionID typeutil.UniqueID, part
m.rwmutex.Lock()
defer m.rwmutex.Unlock()
return m.removePartition(collectionID, partitionIDs...)
return m.removePartition(ctx, collectionID, partitionIDs...)
}
func (m *CollectionManager) removePartition(collectionID typeutil.UniqueID, partitionIDs ...typeutil.UniqueID) error {
err := m.catalog.ReleasePartition(collectionID, partitionIDs...)
func (m *CollectionManager) removePartition(ctx context.Context, collectionID typeutil.UniqueID, partitionIDs ...typeutil.UniqueID) error {
err := m.catalog.ReleasePartition(ctx, collectionID, partitionIDs...)
if err != nil {
return err
}
@ -644,7 +644,7 @@ func (m *CollectionManager) removePartition(collectionID typeutil.UniqueID, part
return nil
}
func (m *CollectionManager) UpdateReplicaNumber(collectionID typeutil.UniqueID, replicaNumber int32) error {
func (m *CollectionManager) UpdateReplicaNumber(ctx context.Context, collectionID typeutil.UniqueID, replicaNumber int32) error {
m.rwmutex.Lock()
defer m.rwmutex.Unlock()
@ -663,5 +663,5 @@ func (m *CollectionManager) UpdateReplicaNumber(collectionID typeutil.UniqueID,
newPartitions = append(newPartitions, newPartition)
}
return m.putCollection(true, newCollection, newPartitions...)
return m.putCollection(ctx, true, newCollection, newPartitions...)
}
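As a minimal sketch of the refined calling convention (assuming code living inside the querycoord meta package with the usual context import; the helper name and collection ID are illustrative, not part of this change):
// exampleReleaseCollection is a hypothetical helper showing the ctx-first API.
func exampleReleaseCollection(ctx context.Context, m *CollectionManager, collectionID int64) error {
	// Reads now take the caller's ctx so tracing and cancellation can propagate.
	if !m.Exist(ctx, collectionID) {
		return nil
	}
	// Writes forward the same ctx down to the QueryCoordCatalog layer.
	return m.RemoveCollection(ctx, collectionID)
}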

View File

@ -17,6 +17,7 @@
package meta
import (
"context"
"sort"
"testing"
"time"
@ -59,6 +60,8 @@ type CollectionManagerSuite struct {
// Test object
mgr *CollectionManager
ctx context.Context
}
func (suite *CollectionManagerSuite) SetupSuite() {
@ -85,6 +88,7 @@ func (suite *CollectionManagerSuite) SetupSuite() {
102: {100, 100, 100},
103: {},
}
suite.ctx = context.Background()
}
func (suite *CollectionManagerSuite) SetupTest() {
@ -113,12 +117,13 @@ func (suite *CollectionManagerSuite) TearDownTest() {
func (suite *CollectionManagerSuite) TestGetProperty() {
mgr := suite.mgr
ctx := suite.ctx
for i, collection := range suite.collections {
loadType := mgr.GetLoadType(collection)
replicaNumber := mgr.GetReplicaNumber(collection)
percentage := mgr.CalculateLoadPercentage(collection)
exist := mgr.Exist(collection)
loadType := mgr.GetLoadType(ctx, collection)
replicaNumber := mgr.GetReplicaNumber(ctx, collection)
percentage := mgr.CalculateLoadPercentage(ctx, collection)
exist := mgr.Exist(ctx, collection)
suite.Equal(suite.loadTypes[i], loadType)
suite.Equal(suite.replicaNumber[i], replicaNumber)
suite.Equal(suite.colLoadPercent[i], percentage)
@ -126,10 +131,10 @@ func (suite *CollectionManagerSuite) TestGetProperty() {
}
invalidCollection := -1
loadType := mgr.GetLoadType(int64(invalidCollection))
replicaNumber := mgr.GetReplicaNumber(int64(invalidCollection))
percentage := mgr.CalculateLoadPercentage(int64(invalidCollection))
exist := mgr.Exist(int64(invalidCollection))
loadType := mgr.GetLoadType(ctx, int64(invalidCollection))
replicaNumber := mgr.GetReplicaNumber(ctx, int64(invalidCollection))
percentage := mgr.CalculateLoadPercentage(ctx, int64(invalidCollection))
exist := mgr.Exist(ctx, int64(invalidCollection))
suite.Equal(querypb.LoadType_UnKnownType, loadType)
suite.EqualValues(-1, replicaNumber)
suite.EqualValues(-1, percentage)
@ -138,6 +143,7 @@ func (suite *CollectionManagerSuite) TestGetProperty() {
func (suite *CollectionManagerSuite) TestPut() {
suite.releaseAll()
ctx := suite.ctx
// test put collection with partitions
for i, collection := range suite.collections {
status := querypb.LoadStatus_Loaded
@ -167,7 +173,7 @@ func (suite *CollectionManagerSuite) TestPut() {
CreatedAt: time.Now(),
}
})
err := suite.mgr.PutCollection(col, partitions...)
err := suite.mgr.PutCollection(ctx, col, partitions...)
suite.NoError(err)
}
suite.checkLoadResult()
@ -179,43 +185,44 @@ func (suite *CollectionManagerSuite) TestGet() {
func (suite *CollectionManagerSuite) TestUpdate() {
mgr := suite.mgr
ctx := suite.ctx
collections := mgr.GetAllCollections()
partitions := mgr.GetAllPartitions()
collections := mgr.GetAllCollections(ctx)
partitions := mgr.GetAllPartitions(ctx)
for _, collection := range collections {
collection := collection.Clone()
collection.LoadPercentage = 100
err := mgr.PutCollectionWithoutSave(collection)
err := mgr.PutCollectionWithoutSave(ctx, collection)
suite.NoError(err)
modified := mgr.GetCollection(collection.GetCollectionID())
modified := mgr.GetCollection(ctx, collection.GetCollectionID())
suite.Equal(collection, modified)
suite.EqualValues(100, modified.LoadPercentage)
collection.Status = querypb.LoadStatus_Loaded
err = mgr.PutCollection(collection)
err = mgr.PutCollection(ctx, collection)
suite.NoError(err)
}
for _, partition := range partitions {
partition := partition.Clone()
partition.LoadPercentage = 100
err := mgr.PutPartitionWithoutSave(partition)
err := mgr.PutPartitionWithoutSave(ctx, partition)
suite.NoError(err)
modified := mgr.GetPartition(partition.GetPartitionID())
modified := mgr.GetPartition(ctx, partition.GetPartitionID())
suite.Equal(partition, modified)
suite.EqualValues(100, modified.LoadPercentage)
partition.Status = querypb.LoadStatus_Loaded
err = mgr.PutPartition(partition)
err = mgr.PutPartition(ctx, partition)
suite.NoError(err)
}
suite.clearMemory()
err := mgr.Recover(suite.broker)
err := mgr.Recover(ctx, suite.broker)
suite.NoError(err)
collections = mgr.GetAllCollections()
partitions = mgr.GetAllPartitions()
collections = mgr.GetAllCollections(ctx)
partitions = mgr.GetAllPartitions(ctx)
for _, collection := range collections {
suite.Equal(querypb.LoadStatus_Loaded, collection.GetStatus())
}
@ -226,7 +233,8 @@ func (suite *CollectionManagerSuite) TestUpdate() {
func (suite *CollectionManagerSuite) TestGetFieldIndex() {
mgr := suite.mgr
mgr.PutCollection(&Collection{
ctx := suite.ctx
mgr.PutCollection(ctx, &Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: 1,
ReplicaNumber: 1,
@ -237,7 +245,7 @@ func (suite *CollectionManagerSuite) TestGetFieldIndex() {
LoadPercentage: 0,
CreatedAt: time.Now(),
})
indexID := mgr.GetFieldIndex(1)
indexID := mgr.GetFieldIndex(ctx, 1)
suite.Len(indexID, 2)
suite.Contains(indexID, int64(1))
suite.Contains(indexID, int64(2))
@ -245,14 +253,15 @@ func (suite *CollectionManagerSuite) TestGetFieldIndex() {
func (suite *CollectionManagerSuite) TestRemove() {
mgr := suite.mgr
ctx := suite.ctx
// Remove collections/partitions
for i, collectionID := range suite.collections {
if suite.loadTypes[i] == querypb.LoadType_LoadCollection {
err := mgr.RemoveCollection(collectionID)
err := mgr.RemoveCollection(ctx, collectionID)
suite.NoError(err)
} else {
err := mgr.RemovePartition(collectionID, suite.partitions[collectionID]...)
err := mgr.RemovePartition(ctx, collectionID, suite.partitions[collectionID]...)
suite.NoError(err)
}
}
@ -260,23 +269,23 @@ func (suite *CollectionManagerSuite) TestRemove() {
// Try to get the removed items
for i, collectionID := range suite.collections {
if suite.loadTypes[i] == querypb.LoadType_LoadCollection {
collection := mgr.GetCollection(collectionID)
collection := mgr.GetCollection(ctx, collectionID)
suite.Nil(collection)
} else {
partitions := mgr.GetPartitionsByCollection(collectionID)
partitions := mgr.GetPartitionsByCollection(ctx, collectionID)
suite.Empty(partitions)
}
}
// Make sure the removes applied to meta store
err := mgr.Recover(suite.broker)
err := mgr.Recover(ctx, suite.broker)
suite.NoError(err)
for i, collectionID := range suite.collections {
if suite.loadTypes[i] == querypb.LoadType_LoadCollection {
collection := mgr.GetCollection(collectionID)
collection := mgr.GetCollection(ctx, collectionID)
suite.Nil(collection)
} else {
partitions := mgr.GetPartitionsByCollection(collectionID)
partitions := mgr.GetPartitionsByCollection(ctx, collectionID)
suite.Empty(partitions)
}
}
@ -285,9 +294,9 @@ func (suite *CollectionManagerSuite) TestRemove() {
suite.loadAll()
for i, collectionID := range suite.collections {
if suite.loadTypes[i] == querypb.LoadType_LoadPartition {
err := mgr.RemoveCollection(collectionID)
err := mgr.RemoveCollection(ctx, collectionID)
suite.NoError(err)
partitions := mgr.GetPartitionsByCollection(collectionID)
partitions := mgr.GetPartitionsByCollection(ctx, collectionID)
suite.Empty(partitions)
}
}
@ -296,27 +305,28 @@ func (suite *CollectionManagerSuite) TestRemove() {
suite.releaseAll()
suite.loadAll()
for _, collectionID := range suite.collections {
err := mgr.RemoveCollection(collectionID)
err := mgr.RemoveCollection(ctx, collectionID)
suite.NoError(err)
err = mgr.Recover(suite.broker)
err = mgr.Recover(ctx, suite.broker)
suite.NoError(err)
collection := mgr.GetCollection(collectionID)
collection := mgr.GetCollection(ctx, collectionID)
suite.Nil(collection)
partitions := mgr.GetPartitionsByCollection(collectionID)
partitions := mgr.GetPartitionsByCollection(ctx, collectionID)
suite.Empty(partitions)
}
}
func (suite *CollectionManagerSuite) TestRecover_normal() {
mgr := suite.mgr
ctx := suite.ctx
suite.clearMemory()
err := mgr.Recover(suite.broker)
err := mgr.Recover(ctx, suite.broker)
suite.NoError(err)
for _, collection := range suite.collections {
suite.True(mgr.Exist(collection))
suite.True(mgr.Exist(ctx, collection))
for _, partitionID := range suite.partitions[collection] {
partition := mgr.GetPartition(partitionID)
partition := mgr.GetPartition(ctx, partitionID)
suite.NotNil(partition)
}
}
@ -325,6 +335,7 @@ func (suite *CollectionManagerSuite) TestRecover_normal() {
func (suite *CollectionManagerSuite) TestRecoverLoadingCollection() {
mgr := suite.mgr
suite.releaseAll()
ctx := suite.ctx
// test put collection with partitions
for i, collection := range suite.collections {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil).Maybe()
@ -350,20 +361,20 @@ func (suite *CollectionManagerSuite) TestRecoverLoadingCollection() {
CreatedAt: time.Now(),
}
})
err := suite.mgr.PutCollection(col, partitions...)
err := suite.mgr.PutCollection(ctx, col, partitions...)
suite.NoError(err)
}
// recover for first time, expected recover success
suite.clearMemory()
err := mgr.Recover(suite.broker)
err := mgr.Recover(ctx, suite.broker)
suite.NoError(err)
for _, collectionID := range suite.collections {
collection := mgr.GetCollection(collectionID)
collection := mgr.GetCollection(ctx, collectionID)
suite.NotNil(collection)
suite.Equal(int32(1), collection.GetRecoverTimes())
for _, partitionID := range suite.partitions[collectionID] {
partition := mgr.GetPartition(partitionID)
partition := mgr.GetPartition(ctx, partitionID)
suite.NotNil(partition)
suite.Equal(int32(1), partition.GetRecoverTimes())
}
@ -372,18 +383,18 @@ func (suite *CollectionManagerSuite) TestRecoverLoadingCollection() {
// update load percent, then recover for second time
for _, collectionID := range suite.collections {
for _, partitionID := range suite.partitions[collectionID] {
mgr.UpdateLoadPercent(partitionID, 10)
mgr.UpdateLoadPercent(ctx, partitionID, 10)
}
}
suite.clearMemory()
err = mgr.Recover(suite.broker)
err = mgr.Recover(ctx, suite.broker)
suite.NoError(err)
for _, collectionID := range suite.collections {
collection := mgr.GetCollection(collectionID)
collection := mgr.GetCollection(ctx, collectionID)
suite.NotNil(collection)
suite.Equal(int32(2), collection.GetRecoverTimes())
for _, partitionID := range suite.partitions[collectionID] {
partition := mgr.GetPartition(partitionID)
partition := mgr.GetPartition(ctx, partitionID)
suite.NotNil(partition)
suite.Equal(int32(2), partition.GetRecoverTimes())
}
@ -393,14 +404,14 @@ func (suite *CollectionManagerSuite) TestRecoverLoadingCollection() {
for i := 0; i < int(paramtable.Get().QueryCoordCfg.CollectionRecoverTimesLimit.GetAsInt32()); i++ {
log.Info("stupid", zap.Int("count", i))
suite.clearMemory()
err = mgr.Recover(suite.broker)
err = mgr.Recover(ctx, suite.broker)
suite.NoError(err)
}
for _, collectionID := range suite.collections {
collection := mgr.GetCollection(collectionID)
collection := mgr.GetCollection(ctx, collectionID)
suite.Nil(collection)
for _, partitionID := range suite.partitions[collectionID] {
partition := mgr.GetPartition(partitionID)
partition := mgr.GetPartition(ctx, partitionID)
suite.Nil(partition)
}
}
@ -408,7 +419,8 @@ func (suite *CollectionManagerSuite) TestRecoverLoadingCollection() {
func (suite *CollectionManagerSuite) TestUpdateLoadPercentage() {
mgr := suite.mgr
mgr.PutCollection(&Collection{
ctx := suite.ctx
mgr.PutCollection(ctx, &Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: 1,
ReplicaNumber: 1,
@ -421,7 +433,7 @@ func (suite *CollectionManagerSuite) TestUpdateLoadPercentage() {
partitions := []int64{1, 2}
for _, partition := range partitions {
mgr.PutPartition(&Partition{
mgr.PutPartition(ctx, &Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: 1,
PartitionID: partition,
@ -432,42 +444,43 @@ func (suite *CollectionManagerSuite) TestUpdateLoadPercentage() {
})
}
// test update partition load percentage
mgr.UpdateLoadPercent(1, 30)
partition := mgr.GetPartition(1)
mgr.UpdateLoadPercent(ctx, 1, 30)
partition := mgr.GetPartition(ctx, 1)
suite.Equal(int32(30), partition.LoadPercentage)
suite.Equal(int32(30), mgr.GetPartitionLoadPercentage(partition.PartitionID))
suite.Equal(int32(30), mgr.GetPartitionLoadPercentage(ctx, partition.PartitionID))
suite.Equal(querypb.LoadStatus_Loading, partition.Status)
collection := mgr.GetCollection(1)
collection := mgr.GetCollection(ctx, 1)
suite.Equal(int32(15), collection.LoadPercentage)
suite.Equal(querypb.LoadStatus_Loading, collection.Status)
// test update partition load percentage to 100
mgr.UpdateLoadPercent(1, 100)
partition = mgr.GetPartition(1)
mgr.UpdateLoadPercent(ctx, 1, 100)
partition = mgr.GetPartition(ctx, 1)
suite.Equal(int32(100), partition.LoadPercentage)
suite.Equal(querypb.LoadStatus_Loaded, partition.Status)
collection = mgr.GetCollection(1)
collection = mgr.GetCollection(ctx, 1)
suite.Equal(int32(50), collection.LoadPercentage)
suite.Equal(querypb.LoadStatus_Loading, collection.Status)
// test update collection load percentage
mgr.UpdateLoadPercent(2, 100)
partition = mgr.GetPartition(1)
mgr.UpdateLoadPercent(ctx, 2, 100)
partition = mgr.GetPartition(ctx, 1)
suite.Equal(int32(100), partition.LoadPercentage)
suite.Equal(querypb.LoadStatus_Loaded, partition.Status)
collection = mgr.GetCollection(1)
collection = mgr.GetCollection(ctx, 1)
suite.Equal(int32(100), collection.LoadPercentage)
suite.Equal(querypb.LoadStatus_Loaded, collection.Status)
suite.Equal(querypb.LoadStatus_Loaded, mgr.CalculateLoadStatus(collection.CollectionID))
suite.Equal(querypb.LoadStatus_Loaded, mgr.CalculateLoadStatus(ctx, collection.CollectionID))
}
func (suite *CollectionManagerSuite) TestUpgradeRecover() {
suite.releaseAll()
mgr := suite.mgr
ctx := suite.ctx
// put old version of collections and partitions
for i, collection := range suite.collections {
status := querypb.LoadStatus_Loaded
if suite.loadTypes[i] == querypb.LoadType_LoadCollection {
mgr.PutCollection(&Collection{
mgr.PutCollection(ctx, &Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: collection,
ReplicaNumber: suite.replicaNumber[i],
@ -479,7 +492,7 @@ func (suite *CollectionManagerSuite) TestUpgradeRecover() {
})
} else {
for _, partition := range suite.partitions[collection] {
mgr.PutPartition(&Partition{
mgr.PutPartition(ctx, &Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: collection,
PartitionID: partition,
@ -513,12 +526,12 @@ func (suite *CollectionManagerSuite) TestUpgradeRecover() {
// do recovery
suite.clearMemory()
err := mgr.Recover(suite.broker)
err := mgr.Recover(ctx, suite.broker)
suite.NoError(err)
suite.checkLoadResult()
for i, collection := range suite.collections {
newColl := mgr.GetCollection(collection)
newColl := mgr.GetCollection(ctx, collection)
suite.Equal(suite.loadTypes[i], newColl.GetLoadType())
}
}
@ -526,10 +539,11 @@ func (suite *CollectionManagerSuite) TestUpgradeRecover() {
func (suite *CollectionManagerSuite) TestUpgradeLoadFields() {
suite.releaseAll()
mgr := suite.mgr
ctx := suite.ctx
// put old version of collections and partitions
for i, collection := range suite.collections {
mgr.PutCollection(&Collection{
mgr.PutCollection(ctx, &Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: collection,
ReplicaNumber: suite.replicaNumber[i],
@ -541,7 +555,7 @@ func (suite *CollectionManagerSuite) TestUpgradeLoadFields() {
CreatedAt: time.Now(),
})
for j, partition := range suite.partitions[collection] {
mgr.PutPartition(&Partition{
mgr.PutPartition(ctx, &Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: collection,
PartitionID: partition,
@ -570,12 +584,12 @@ func (suite *CollectionManagerSuite) TestUpgradeLoadFields() {
// do recovery
suite.clearMemory()
err := mgr.Recover(suite.broker)
err := mgr.Recover(ctx, suite.broker)
suite.NoError(err)
suite.checkLoadResult()
for _, collection := range suite.collections {
newColl := mgr.GetCollection(collection)
newColl := mgr.GetCollection(ctx, collection)
suite.ElementsMatch([]int64{100, 101}, newColl.GetLoadFields())
}
}
@ -584,8 +598,9 @@ func (suite *CollectionManagerSuite) TestUpgradeLoadFieldsFail() {
suite.Run("normal_error", func() {
suite.releaseAll()
mgr := suite.mgr
ctx := suite.ctx
mgr.PutCollection(&Collection{
mgr.PutCollection(ctx, &Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: 100,
ReplicaNumber: 1,
@ -596,7 +611,7 @@ func (suite *CollectionManagerSuite) TestUpgradeLoadFieldsFail() {
LoadPercentage: 100,
CreatedAt: time.Now(),
})
mgr.PutPartition(&Partition{
mgr.PutPartition(ctx, &Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: 100,
PartitionID: 1000,
@ -609,15 +624,16 @@ func (suite *CollectionManagerSuite) TestUpgradeLoadFieldsFail() {
suite.broker.EXPECT().DescribeCollection(mock.Anything, int64(100)).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
// do recovery
suite.clearMemory()
err := mgr.Recover(suite.broker)
err := mgr.Recover(ctx, suite.broker)
suite.Error(err)
})
suite.Run("normal_error", func() {
suite.releaseAll()
mgr := suite.mgr
ctx := suite.ctx
mgr.PutCollection(&Collection{
mgr.PutCollection(ctx, &Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: 100,
ReplicaNumber: 1,
@ -628,7 +644,7 @@ func (suite *CollectionManagerSuite) TestUpgradeLoadFieldsFail() {
LoadPercentage: 100,
CreatedAt: time.Now(),
})
mgr.PutPartition(&Partition{
mgr.PutPartition(ctx, &Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: 100,
PartitionID: 1000,
@ -643,13 +659,14 @@ func (suite *CollectionManagerSuite) TestUpgradeLoadFieldsFail() {
}, nil).Once()
// do recovery
suite.clearMemory()
err := mgr.Recover(suite.broker)
err := mgr.Recover(ctx, suite.broker)
suite.NoError(err)
})
}
func (suite *CollectionManagerSuite) loadAll() {
mgr := suite.mgr
ctx := suite.ctx
for i, collection := range suite.collections {
status := querypb.LoadStatus_Loaded
@ -657,7 +674,7 @@ func (suite *CollectionManagerSuite) loadAll() {
status = querypb.LoadStatus_Loading
}
mgr.PutCollection(&Collection{
mgr.PutCollection(ctx, &Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: collection,
ReplicaNumber: suite.replicaNumber[i],
@ -670,7 +687,7 @@ func (suite *CollectionManagerSuite) loadAll() {
})
for j, partition := range suite.partitions[collection] {
mgr.PutPartition(&Partition{
mgr.PutPartition(ctx, &Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: collection,
PartitionID: partition,
@ -685,18 +702,19 @@ func (suite *CollectionManagerSuite) loadAll() {
func (suite *CollectionManagerSuite) checkLoadResult() {
mgr := suite.mgr
ctx := suite.ctx
allCollections := mgr.GetAllCollections()
allPartitions := mgr.GetAllPartitions()
allCollections := mgr.GetAllCollections(ctx)
allPartitions := mgr.GetAllPartitions(ctx)
for _, collectionID := range suite.collections {
collection := mgr.GetCollection(collectionID)
collection := mgr.GetCollection(ctx, collectionID)
suite.Equal(collectionID, collection.GetCollectionID())
suite.Contains(allCollections, collection)
partitions := mgr.GetPartitionsByCollection(collectionID)
partitions := mgr.GetPartitionsByCollection(ctx, collectionID)
suite.Len(partitions, len(suite.partitions[collectionID]))
for _, partitionID := range suite.partitions[collectionID] {
partition := mgr.GetPartition(partitionID)
partition := mgr.GetPartition(ctx, partitionID)
suite.Equal(collectionID, partition.GetCollectionID())
suite.Equal(partitionID, partition.GetPartitionID())
suite.Contains(partitions, partition)
@ -704,14 +722,14 @@ func (suite *CollectionManagerSuite) checkLoadResult() {
}
}
all := mgr.GetAll()
all := mgr.GetAll(ctx)
sort.Slice(all, func(i, j int) bool { return all[i] < all[j] })
suite.Equal(suite.collections, all)
}
func (suite *CollectionManagerSuite) releaseAll() {
for _, collection := range suite.collections {
err := suite.mgr.RemoveCollection(collection)
err := suite.mgr.RemoveCollection(context.TODO(), collection)
suite.NoError(err)
}
}
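The test changes above all follow one idiom: a suite-level ctx created in SetupSuite is passed as the first argument to every manager call, while mock expectations match it with mock.Anything. A condensed sketch of that pattern, assuming the suite's existing broker/mgr fields and the usual testify/mock import; the IDs 100 and 1000 are placeholders:
// Hypothetical test method illustrating the updated idiom; not part of this suite.
func (suite *CollectionManagerSuite) TestCtxThreadingSketch() {
	ctx := suite.ctx
	// Expectations match the ctx argument like any other; mock.Anything keeps them permissive.
	suite.broker.EXPECT().GetPartitions(mock.Anything, int64(100)).Return([]int64{1000}, nil).Maybe()
	// Manager calls now pass the same ctx as their first argument.
	suite.NoError(suite.mgr.RemoveCollection(ctx, 100))
	suite.Empty(suite.mgr.GetPartitionsByCollection(ctx, 100))
}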

View File

@ -3,6 +3,8 @@
package meta
import (
context "context"
metastore "github.com/milvus-io/milvus/internal/metastore"
datapb "github.com/milvus-io/milvus/internal/proto/datapb"
@ -24,17 +26,17 @@ func (_m *MockTargetManager) EXPECT() *MockTargetManager_Expecter {
return &MockTargetManager_Expecter{mock: &_m.Mock}
}
// CanSegmentBeMoved provides a mock function with given fields: collectionID, segmentID
func (_m *MockTargetManager) CanSegmentBeMoved(collectionID int64, segmentID int64) bool {
ret := _m.Called(collectionID, segmentID)
// CanSegmentBeMoved provides a mock function with given fields: ctx, collectionID, segmentID
func (_m *MockTargetManager) CanSegmentBeMoved(ctx context.Context, collectionID int64, segmentID int64) bool {
ret := _m.Called(ctx, collectionID, segmentID)
if len(ret) == 0 {
panic("no return value specified for CanSegmentBeMoved")
}
var r0 bool
if rf, ok := ret.Get(0).(func(int64, int64) bool); ok {
r0 = rf(collectionID, segmentID)
if rf, ok := ret.Get(0).(func(context.Context, int64, int64) bool); ok {
r0 = rf(ctx, collectionID, segmentID)
} else {
r0 = ret.Get(0).(bool)
}
@ -48,15 +50,16 @@ type MockTargetManager_CanSegmentBeMoved_Call struct {
}
// CanSegmentBeMoved is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - segmentID int64
func (_e *MockTargetManager_Expecter) CanSegmentBeMoved(collectionID interface{}, segmentID interface{}) *MockTargetManager_CanSegmentBeMoved_Call {
return &MockTargetManager_CanSegmentBeMoved_Call{Call: _e.mock.On("CanSegmentBeMoved", collectionID, segmentID)}
func (_e *MockTargetManager_Expecter) CanSegmentBeMoved(ctx interface{}, collectionID interface{}, segmentID interface{}) *MockTargetManager_CanSegmentBeMoved_Call {
return &MockTargetManager_CanSegmentBeMoved_Call{Call: _e.mock.On("CanSegmentBeMoved", ctx, collectionID, segmentID)}
}
func (_c *MockTargetManager_CanSegmentBeMoved_Call) Run(run func(collectionID int64, segmentID int64)) *MockTargetManager_CanSegmentBeMoved_Call {
func (_c *MockTargetManager_CanSegmentBeMoved_Call) Run(run func(ctx context.Context, collectionID int64, segmentID int64)) *MockTargetManager_CanSegmentBeMoved_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64), args[1].(int64))
run(args[0].(context.Context), args[1].(int64), args[2].(int64))
})
return _c
}
@ -66,22 +69,22 @@ func (_c *MockTargetManager_CanSegmentBeMoved_Call) Return(_a0 bool) *MockTarget
return _c
}
func (_c *MockTargetManager_CanSegmentBeMoved_Call) RunAndReturn(run func(int64, int64) bool) *MockTargetManager_CanSegmentBeMoved_Call {
func (_c *MockTargetManager_CanSegmentBeMoved_Call) RunAndReturn(run func(context.Context, int64, int64) bool) *MockTargetManager_CanSegmentBeMoved_Call {
_c.Call.Return(run)
return _c
}
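With ctx promoted to the first parameter, the regenerated expecters take one extra matcher. A hedged usage sketch, assuming the usual context and testify/mock imports; the mock construction and IDs are assumptions for illustration only:
// Illustrative use of the regenerated mock; the IDs are placeholders.
tm := &MockTargetManager{}
// ctx is matched like any other argument; mock.Anything is the typical choice in tests.
tm.EXPECT().CanSegmentBeMoved(mock.Anything, int64(1), int64(2)).Return(true)
movable := tm.CanSegmentBeMoved(context.Background(), 1, 2) // movable == true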
// GetCollectionTargetVersion provides a mock function with given fields: collectionID, scope
func (_m *MockTargetManager) GetCollectionTargetVersion(collectionID int64, scope int32) int64 {
ret := _m.Called(collectionID, scope)
// GetCollectionTargetVersion provides a mock function with given fields: ctx, collectionID, scope
func (_m *MockTargetManager) GetCollectionTargetVersion(ctx context.Context, collectionID int64, scope int32) int64 {
ret := _m.Called(ctx, collectionID, scope)
if len(ret) == 0 {
panic("no return value specified for GetCollectionTargetVersion")
}
var r0 int64
if rf, ok := ret.Get(0).(func(int64, int32) int64); ok {
r0 = rf(collectionID, scope)
if rf, ok := ret.Get(0).(func(context.Context, int64, int32) int64); ok {
r0 = rf(ctx, collectionID, scope)
} else {
r0 = ret.Get(0).(int64)
}
@ -95,15 +98,16 @@ type MockTargetManager_GetCollectionTargetVersion_Call struct {
}
// GetCollectionTargetVersion is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - scope int32
func (_e *MockTargetManager_Expecter) GetCollectionTargetVersion(collectionID interface{}, scope interface{}) *MockTargetManager_GetCollectionTargetVersion_Call {
return &MockTargetManager_GetCollectionTargetVersion_Call{Call: _e.mock.On("GetCollectionTargetVersion", collectionID, scope)}
func (_e *MockTargetManager_Expecter) GetCollectionTargetVersion(ctx interface{}, collectionID interface{}, scope interface{}) *MockTargetManager_GetCollectionTargetVersion_Call {
return &MockTargetManager_GetCollectionTargetVersion_Call{Call: _e.mock.On("GetCollectionTargetVersion", ctx, collectionID, scope)}
}
func (_c *MockTargetManager_GetCollectionTargetVersion_Call) Run(run func(collectionID int64, scope int32)) *MockTargetManager_GetCollectionTargetVersion_Call {
func (_c *MockTargetManager_GetCollectionTargetVersion_Call) Run(run func(ctx context.Context, collectionID int64, scope int32)) *MockTargetManager_GetCollectionTargetVersion_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64), args[1].(int32))
run(args[0].(context.Context), args[1].(int64), args[2].(int32))
})
return _c
}
@ -113,22 +117,22 @@ func (_c *MockTargetManager_GetCollectionTargetVersion_Call) Return(_a0 int64) *
return _c
}
func (_c *MockTargetManager_GetCollectionTargetVersion_Call) RunAndReturn(run func(int64, int32) int64) *MockTargetManager_GetCollectionTargetVersion_Call {
func (_c *MockTargetManager_GetCollectionTargetVersion_Call) RunAndReturn(run func(context.Context, int64, int32) int64) *MockTargetManager_GetCollectionTargetVersion_Call {
_c.Call.Return(run)
return _c
}
// GetDmChannel provides a mock function with given fields: collectionID, channel, scope
func (_m *MockTargetManager) GetDmChannel(collectionID int64, channel string, scope int32) *DmChannel {
ret := _m.Called(collectionID, channel, scope)
// GetDmChannel provides a mock function with given fields: ctx, collectionID, channel, scope
func (_m *MockTargetManager) GetDmChannel(ctx context.Context, collectionID int64, channel string, scope int32) *DmChannel {
ret := _m.Called(ctx, collectionID, channel, scope)
if len(ret) == 0 {
panic("no return value specified for GetDmChannel")
}
var r0 *DmChannel
if rf, ok := ret.Get(0).(func(int64, string, int32) *DmChannel); ok {
r0 = rf(collectionID, channel, scope)
if rf, ok := ret.Get(0).(func(context.Context, int64, string, int32) *DmChannel); ok {
r0 = rf(ctx, collectionID, channel, scope)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*DmChannel)
@ -144,16 +148,17 @@ type MockTargetManager_GetDmChannel_Call struct {
}
// GetDmChannel is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - channel string
// - scope int32
func (_e *MockTargetManager_Expecter) GetDmChannel(collectionID interface{}, channel interface{}, scope interface{}) *MockTargetManager_GetDmChannel_Call {
return &MockTargetManager_GetDmChannel_Call{Call: _e.mock.On("GetDmChannel", collectionID, channel, scope)}
func (_e *MockTargetManager_Expecter) GetDmChannel(ctx interface{}, collectionID interface{}, channel interface{}, scope interface{}) *MockTargetManager_GetDmChannel_Call {
return &MockTargetManager_GetDmChannel_Call{Call: _e.mock.On("GetDmChannel", ctx, collectionID, channel, scope)}
}
func (_c *MockTargetManager_GetDmChannel_Call) Run(run func(collectionID int64, channel string, scope int32)) *MockTargetManager_GetDmChannel_Call {
func (_c *MockTargetManager_GetDmChannel_Call) Run(run func(ctx context.Context, collectionID int64, channel string, scope int32)) *MockTargetManager_GetDmChannel_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64), args[1].(string), args[2].(int32))
run(args[0].(context.Context), args[1].(int64), args[2].(string), args[3].(int32))
})
return _c
}
@ -163,22 +168,22 @@ func (_c *MockTargetManager_GetDmChannel_Call) Return(_a0 *DmChannel) *MockTarge
return _c
}
func (_c *MockTargetManager_GetDmChannel_Call) RunAndReturn(run func(int64, string, int32) *DmChannel) *MockTargetManager_GetDmChannel_Call {
func (_c *MockTargetManager_GetDmChannel_Call) RunAndReturn(run func(context.Context, int64, string, int32) *DmChannel) *MockTargetManager_GetDmChannel_Call {
_c.Call.Return(run)
return _c
}
// GetDmChannelsByCollection provides a mock function with given fields: collectionID, scope
func (_m *MockTargetManager) GetDmChannelsByCollection(collectionID int64, scope int32) map[string]*DmChannel {
ret := _m.Called(collectionID, scope)
// GetDmChannelsByCollection provides a mock function with given fields: ctx, collectionID, scope
func (_m *MockTargetManager) GetDmChannelsByCollection(ctx context.Context, collectionID int64, scope int32) map[string]*DmChannel {
ret := _m.Called(ctx, collectionID, scope)
if len(ret) == 0 {
panic("no return value specified for GetDmChannelsByCollection")
}
var r0 map[string]*DmChannel
if rf, ok := ret.Get(0).(func(int64, int32) map[string]*DmChannel); ok {
r0 = rf(collectionID, scope)
if rf, ok := ret.Get(0).(func(context.Context, int64, int32) map[string]*DmChannel); ok {
r0 = rf(ctx, collectionID, scope)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[string]*DmChannel)
@ -194,15 +199,16 @@ type MockTargetManager_GetDmChannelsByCollection_Call struct {
}
// GetDmChannelsByCollection is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - scope int32
func (_e *MockTargetManager_Expecter) GetDmChannelsByCollection(collectionID interface{}, scope interface{}) *MockTargetManager_GetDmChannelsByCollection_Call {
return &MockTargetManager_GetDmChannelsByCollection_Call{Call: _e.mock.On("GetDmChannelsByCollection", collectionID, scope)}
func (_e *MockTargetManager_Expecter) GetDmChannelsByCollection(ctx interface{}, collectionID interface{}, scope interface{}) *MockTargetManager_GetDmChannelsByCollection_Call {
return &MockTargetManager_GetDmChannelsByCollection_Call{Call: _e.mock.On("GetDmChannelsByCollection", ctx, collectionID, scope)}
}
func (_c *MockTargetManager_GetDmChannelsByCollection_Call) Run(run func(collectionID int64, scope int32)) *MockTargetManager_GetDmChannelsByCollection_Call {
func (_c *MockTargetManager_GetDmChannelsByCollection_Call) Run(run func(ctx context.Context, collectionID int64, scope int32)) *MockTargetManager_GetDmChannelsByCollection_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64), args[1].(int32))
run(args[0].(context.Context), args[1].(int64), args[2].(int32))
})
return _c
}
@ -212,22 +218,22 @@ func (_c *MockTargetManager_GetDmChannelsByCollection_Call) Return(_a0 map[strin
return _c
}
func (_c *MockTargetManager_GetDmChannelsByCollection_Call) RunAndReturn(run func(int64, int32) map[string]*DmChannel) *MockTargetManager_GetDmChannelsByCollection_Call {
func (_c *MockTargetManager_GetDmChannelsByCollection_Call) RunAndReturn(run func(context.Context, int64, int32) map[string]*DmChannel) *MockTargetManager_GetDmChannelsByCollection_Call {
_c.Call.Return(run)
return _c
}
// GetDroppedSegmentsByChannel provides a mock function with given fields: collectionID, channelName, scope
func (_m *MockTargetManager) GetDroppedSegmentsByChannel(collectionID int64, channelName string, scope int32) []int64 {
ret := _m.Called(collectionID, channelName, scope)
// GetDroppedSegmentsByChannel provides a mock function with given fields: ctx, collectionID, channelName, scope
func (_m *MockTargetManager) GetDroppedSegmentsByChannel(ctx context.Context, collectionID int64, channelName string, scope int32) []int64 {
ret := _m.Called(ctx, collectionID, channelName, scope)
if len(ret) == 0 {
panic("no return value specified for GetDroppedSegmentsByChannel")
}
var r0 []int64
if rf, ok := ret.Get(0).(func(int64, string, int32) []int64); ok {
r0 = rf(collectionID, channelName, scope)
if rf, ok := ret.Get(0).(func(context.Context, int64, string, int32) []int64); ok {
r0 = rf(ctx, collectionID, channelName, scope)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int64)
@ -243,16 +249,17 @@ type MockTargetManager_GetDroppedSegmentsByChannel_Call struct {
}
// GetDroppedSegmentsByChannel is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - channelName string
// - scope int32
func (_e *MockTargetManager_Expecter) GetDroppedSegmentsByChannel(collectionID interface{}, channelName interface{}, scope interface{}) *MockTargetManager_GetDroppedSegmentsByChannel_Call {
return &MockTargetManager_GetDroppedSegmentsByChannel_Call{Call: _e.mock.On("GetDroppedSegmentsByChannel", collectionID, channelName, scope)}
func (_e *MockTargetManager_Expecter) GetDroppedSegmentsByChannel(ctx interface{}, collectionID interface{}, channelName interface{}, scope interface{}) *MockTargetManager_GetDroppedSegmentsByChannel_Call {
return &MockTargetManager_GetDroppedSegmentsByChannel_Call{Call: _e.mock.On("GetDroppedSegmentsByChannel", ctx, collectionID, channelName, scope)}
}
func (_c *MockTargetManager_GetDroppedSegmentsByChannel_Call) Run(run func(collectionID int64, channelName string, scope int32)) *MockTargetManager_GetDroppedSegmentsByChannel_Call {
func (_c *MockTargetManager_GetDroppedSegmentsByChannel_Call) Run(run func(ctx context.Context, collectionID int64, channelName string, scope int32)) *MockTargetManager_GetDroppedSegmentsByChannel_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64), args[1].(string), args[2].(int32))
run(args[0].(context.Context), args[1].(int64), args[2].(string), args[3].(int32))
})
return _c
}
@ -262,22 +269,22 @@ func (_c *MockTargetManager_GetDroppedSegmentsByChannel_Call) Return(_a0 []int64
return _c
}
func (_c *MockTargetManager_GetDroppedSegmentsByChannel_Call) RunAndReturn(run func(int64, string, int32) []int64) *MockTargetManager_GetDroppedSegmentsByChannel_Call {
func (_c *MockTargetManager_GetDroppedSegmentsByChannel_Call) RunAndReturn(run func(context.Context, int64, string, int32) []int64) *MockTargetManager_GetDroppedSegmentsByChannel_Call {
_c.Call.Return(run)
return _c
}
// GetGrowingSegmentsByChannel provides a mock function with given fields: collectionID, channelName, scope
func (_m *MockTargetManager) GetGrowingSegmentsByChannel(collectionID int64, channelName string, scope int32) typeutil.Set[int64] {
ret := _m.Called(collectionID, channelName, scope)
// GetGrowingSegmentsByChannel provides a mock function with given fields: ctx, collectionID, channelName, scope
func (_m *MockTargetManager) GetGrowingSegmentsByChannel(ctx context.Context, collectionID int64, channelName string, scope int32) typeutil.Set[int64] {
ret := _m.Called(ctx, collectionID, channelName, scope)
if len(ret) == 0 {
panic("no return value specified for GetGrowingSegmentsByChannel")
}
var r0 typeutil.Set[int64]
if rf, ok := ret.Get(0).(func(int64, string, int32) typeutil.Set[int64]); ok {
r0 = rf(collectionID, channelName, scope)
if rf, ok := ret.Get(0).(func(context.Context, int64, string, int32) typeutil.Set[int64]); ok {
r0 = rf(ctx, collectionID, channelName, scope)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(typeutil.Set[int64])
@ -293,16 +300,17 @@ type MockTargetManager_GetGrowingSegmentsByChannel_Call struct {
}
// GetGrowingSegmentsByChannel is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - channelName string
// - scope int32
func (_e *MockTargetManager_Expecter) GetGrowingSegmentsByChannel(collectionID interface{}, channelName interface{}, scope interface{}) *MockTargetManager_GetGrowingSegmentsByChannel_Call {
return &MockTargetManager_GetGrowingSegmentsByChannel_Call{Call: _e.mock.On("GetGrowingSegmentsByChannel", collectionID, channelName, scope)}
func (_e *MockTargetManager_Expecter) GetGrowingSegmentsByChannel(ctx interface{}, collectionID interface{}, channelName interface{}, scope interface{}) *MockTargetManager_GetGrowingSegmentsByChannel_Call {
return &MockTargetManager_GetGrowingSegmentsByChannel_Call{Call: _e.mock.On("GetGrowingSegmentsByChannel", ctx, collectionID, channelName, scope)}
}
func (_c *MockTargetManager_GetGrowingSegmentsByChannel_Call) Run(run func(collectionID int64, channelName string, scope int32)) *MockTargetManager_GetGrowingSegmentsByChannel_Call {
func (_c *MockTargetManager_GetGrowingSegmentsByChannel_Call) Run(run func(ctx context.Context, collectionID int64, channelName string, scope int32)) *MockTargetManager_GetGrowingSegmentsByChannel_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64), args[1].(string), args[2].(int32))
run(args[0].(context.Context), args[1].(int64), args[2].(string), args[3].(int32))
})
return _c
}
@ -312,22 +320,22 @@ func (_c *MockTargetManager_GetGrowingSegmentsByChannel_Call) Return(_a0 typeuti
return _c
}
func (_c *MockTargetManager_GetGrowingSegmentsByChannel_Call) RunAndReturn(run func(int64, string, int32) typeutil.Set[int64]) *MockTargetManager_GetGrowingSegmentsByChannel_Call {
func (_c *MockTargetManager_GetGrowingSegmentsByChannel_Call) RunAndReturn(run func(context.Context, int64, string, int32) typeutil.Set[int64]) *MockTargetManager_GetGrowingSegmentsByChannel_Call {
_c.Call.Return(run)
return _c
}
// GetGrowingSegmentsByCollection provides a mock function with given fields: collectionID, scope
func (_m *MockTargetManager) GetGrowingSegmentsByCollection(collectionID int64, scope int32) typeutil.Set[int64] {
ret := _m.Called(collectionID, scope)
// GetGrowingSegmentsByCollection provides a mock function with given fields: ctx, collectionID, scope
func (_m *MockTargetManager) GetGrowingSegmentsByCollection(ctx context.Context, collectionID int64, scope int32) typeutil.Set[int64] {
ret := _m.Called(ctx, collectionID, scope)
if len(ret) == 0 {
panic("no return value specified for GetGrowingSegmentsByCollection")
}
var r0 typeutil.Set[int64]
if rf, ok := ret.Get(0).(func(int64, int32) typeutil.Set[int64]); ok {
r0 = rf(collectionID, scope)
if rf, ok := ret.Get(0).(func(context.Context, int64, int32) typeutil.Set[int64]); ok {
r0 = rf(ctx, collectionID, scope)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(typeutil.Set[int64])
@ -343,15 +351,16 @@ type MockTargetManager_GetGrowingSegmentsByCollection_Call struct {
}
// GetGrowingSegmentsByCollection is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - scope int32
func (_e *MockTargetManager_Expecter) GetGrowingSegmentsByCollection(collectionID interface{}, scope interface{}) *MockTargetManager_GetGrowingSegmentsByCollection_Call {
return &MockTargetManager_GetGrowingSegmentsByCollection_Call{Call: _e.mock.On("GetGrowingSegmentsByCollection", collectionID, scope)}
func (_e *MockTargetManager_Expecter) GetGrowingSegmentsByCollection(ctx interface{}, collectionID interface{}, scope interface{}) *MockTargetManager_GetGrowingSegmentsByCollection_Call {
return &MockTargetManager_GetGrowingSegmentsByCollection_Call{Call: _e.mock.On("GetGrowingSegmentsByCollection", ctx, collectionID, scope)}
}
func (_c *MockTargetManager_GetGrowingSegmentsByCollection_Call) Run(run func(collectionID int64, scope int32)) *MockTargetManager_GetGrowingSegmentsByCollection_Call {
func (_c *MockTargetManager_GetGrowingSegmentsByCollection_Call) Run(run func(ctx context.Context, collectionID int64, scope int32)) *MockTargetManager_GetGrowingSegmentsByCollection_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64), args[1].(int32))
run(args[0].(context.Context), args[1].(int64), args[2].(int32))
})
return _c
}
@ -361,22 +370,22 @@ func (_c *MockTargetManager_GetGrowingSegmentsByCollection_Call) Return(_a0 type
return _c
}
func (_c *MockTargetManager_GetGrowingSegmentsByCollection_Call) RunAndReturn(run func(int64, int32) typeutil.Set[int64]) *MockTargetManager_GetGrowingSegmentsByCollection_Call {
func (_c *MockTargetManager_GetGrowingSegmentsByCollection_Call) RunAndReturn(run func(context.Context, int64, int32) typeutil.Set[int64]) *MockTargetManager_GetGrowingSegmentsByCollection_Call {
_c.Call.Return(run)
return _c
}
// GetSealedSegment provides a mock function with given fields: collectionID, id, scope
func (_m *MockTargetManager) GetSealedSegment(collectionID int64, id int64, scope int32) *datapb.SegmentInfo {
ret := _m.Called(collectionID, id, scope)
// GetSealedSegment provides a mock function with given fields: ctx, collectionID, id, scope
func (_m *MockTargetManager) GetSealedSegment(ctx context.Context, collectionID int64, id int64, scope int32) *datapb.SegmentInfo {
ret := _m.Called(ctx, collectionID, id, scope)
if len(ret) == 0 {
panic("no return value specified for GetSealedSegment")
}
var r0 *datapb.SegmentInfo
if rf, ok := ret.Get(0).(func(int64, int64, int32) *datapb.SegmentInfo); ok {
r0 = rf(collectionID, id, scope)
if rf, ok := ret.Get(0).(func(context.Context, int64, int64, int32) *datapb.SegmentInfo); ok {
r0 = rf(ctx, collectionID, id, scope)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*datapb.SegmentInfo)
@ -392,16 +401,17 @@ type MockTargetManager_GetSealedSegment_Call struct {
}
// GetSealedSegment is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - id int64
// - scope int32
func (_e *MockTargetManager_Expecter) GetSealedSegment(collectionID interface{}, id interface{}, scope interface{}) *MockTargetManager_GetSealedSegment_Call {
return &MockTargetManager_GetSealedSegment_Call{Call: _e.mock.On("GetSealedSegment", collectionID, id, scope)}
func (_e *MockTargetManager_Expecter) GetSealedSegment(ctx interface{}, collectionID interface{}, id interface{}, scope interface{}) *MockTargetManager_GetSealedSegment_Call {
return &MockTargetManager_GetSealedSegment_Call{Call: _e.mock.On("GetSealedSegment", ctx, collectionID, id, scope)}
}
func (_c *MockTargetManager_GetSealedSegment_Call) Run(run func(collectionID int64, id int64, scope int32)) *MockTargetManager_GetSealedSegment_Call {
func (_c *MockTargetManager_GetSealedSegment_Call) Run(run func(ctx context.Context, collectionID int64, id int64, scope int32)) *MockTargetManager_GetSealedSegment_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64), args[1].(int64), args[2].(int32))
run(args[0].(context.Context), args[1].(int64), args[2].(int64), args[3].(int32))
})
return _c
}
@ -411,22 +421,22 @@ func (_c *MockTargetManager_GetSealedSegment_Call) Return(_a0 *datapb.SegmentInf
return _c
}
func (_c *MockTargetManager_GetSealedSegment_Call) RunAndReturn(run func(int64, int64, int32) *datapb.SegmentInfo) *MockTargetManager_GetSealedSegment_Call {
func (_c *MockTargetManager_GetSealedSegment_Call) RunAndReturn(run func(context.Context, int64, int64, int32) *datapb.SegmentInfo) *MockTargetManager_GetSealedSegment_Call {
_c.Call.Return(run)
return _c
}
// GetSealedSegmentsByChannel provides a mock function with given fields: collectionID, channelName, scope
func (_m *MockTargetManager) GetSealedSegmentsByChannel(collectionID int64, channelName string, scope int32) map[int64]*datapb.SegmentInfo {
ret := _m.Called(collectionID, channelName, scope)
// GetSealedSegmentsByChannel provides a mock function with given fields: ctx, collectionID, channelName, scope
func (_m *MockTargetManager) GetSealedSegmentsByChannel(ctx context.Context, collectionID int64, channelName string, scope int32) map[int64]*datapb.SegmentInfo {
ret := _m.Called(ctx, collectionID, channelName, scope)
if len(ret) == 0 {
panic("no return value specified for GetSealedSegmentsByChannel")
}
var r0 map[int64]*datapb.SegmentInfo
if rf, ok := ret.Get(0).(func(int64, string, int32) map[int64]*datapb.SegmentInfo); ok {
r0 = rf(collectionID, channelName, scope)
if rf, ok := ret.Get(0).(func(context.Context, int64, string, int32) map[int64]*datapb.SegmentInfo); ok {
r0 = rf(ctx, collectionID, channelName, scope)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[int64]*datapb.SegmentInfo)
@ -442,16 +452,17 @@ type MockTargetManager_GetSealedSegmentsByChannel_Call struct {
}
// GetSealedSegmentsByChannel is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - channelName string
// - scope int32
func (_e *MockTargetManager_Expecter) GetSealedSegmentsByChannel(collectionID interface{}, channelName interface{}, scope interface{}) *MockTargetManager_GetSealedSegmentsByChannel_Call {
return &MockTargetManager_GetSealedSegmentsByChannel_Call{Call: _e.mock.On("GetSealedSegmentsByChannel", collectionID, channelName, scope)}
func (_e *MockTargetManager_Expecter) GetSealedSegmentsByChannel(ctx interface{}, collectionID interface{}, channelName interface{}, scope interface{}) *MockTargetManager_GetSealedSegmentsByChannel_Call {
return &MockTargetManager_GetSealedSegmentsByChannel_Call{Call: _e.mock.On("GetSealedSegmentsByChannel", ctx, collectionID, channelName, scope)}
}
func (_c *MockTargetManager_GetSealedSegmentsByChannel_Call) Run(run func(collectionID int64, channelName string, scope int32)) *MockTargetManager_GetSealedSegmentsByChannel_Call {
func (_c *MockTargetManager_GetSealedSegmentsByChannel_Call) Run(run func(ctx context.Context, collectionID int64, channelName string, scope int32)) *MockTargetManager_GetSealedSegmentsByChannel_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64), args[1].(string), args[2].(int32))
run(args[0].(context.Context), args[1].(int64), args[2].(string), args[3].(int32))
})
return _c
}
@ -461,22 +472,22 @@ func (_c *MockTargetManager_GetSealedSegmentsByChannel_Call) Return(_a0 map[int6
return _c
}
func (_c *MockTargetManager_GetSealedSegmentsByChannel_Call) RunAndReturn(run func(int64, string, int32) map[int64]*datapb.SegmentInfo) *MockTargetManager_GetSealedSegmentsByChannel_Call {
func (_c *MockTargetManager_GetSealedSegmentsByChannel_Call) RunAndReturn(run func(context.Context, int64, string, int32) map[int64]*datapb.SegmentInfo) *MockTargetManager_GetSealedSegmentsByChannel_Call {
_c.Call.Return(run)
return _c
}
// GetSealedSegmentsByCollection provides a mock function with given fields: collectionID, scope
func (_m *MockTargetManager) GetSealedSegmentsByCollection(collectionID int64, scope int32) map[int64]*datapb.SegmentInfo {
ret := _m.Called(collectionID, scope)
// GetSealedSegmentsByCollection provides a mock function with given fields: ctx, collectionID, scope
func (_m *MockTargetManager) GetSealedSegmentsByCollection(ctx context.Context, collectionID int64, scope int32) map[int64]*datapb.SegmentInfo {
ret := _m.Called(ctx, collectionID, scope)
if len(ret) == 0 {
panic("no return value specified for GetSealedSegmentsByCollection")
}
var r0 map[int64]*datapb.SegmentInfo
if rf, ok := ret.Get(0).(func(int64, int32) map[int64]*datapb.SegmentInfo); ok {
r0 = rf(collectionID, scope)
if rf, ok := ret.Get(0).(func(context.Context, int64, int32) map[int64]*datapb.SegmentInfo); ok {
r0 = rf(ctx, collectionID, scope)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[int64]*datapb.SegmentInfo)
@ -492,15 +503,16 @@ type MockTargetManager_GetSealedSegmentsByCollection_Call struct {
}
// GetSealedSegmentsByCollection is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - scope int32
func (_e *MockTargetManager_Expecter) GetSealedSegmentsByCollection(collectionID interface{}, scope interface{}) *MockTargetManager_GetSealedSegmentsByCollection_Call {
return &MockTargetManager_GetSealedSegmentsByCollection_Call{Call: _e.mock.On("GetSealedSegmentsByCollection", collectionID, scope)}
func (_e *MockTargetManager_Expecter) GetSealedSegmentsByCollection(ctx interface{}, collectionID interface{}, scope interface{}) *MockTargetManager_GetSealedSegmentsByCollection_Call {
return &MockTargetManager_GetSealedSegmentsByCollection_Call{Call: _e.mock.On("GetSealedSegmentsByCollection", ctx, collectionID, scope)}
}
func (_c *MockTargetManager_GetSealedSegmentsByCollection_Call) Run(run func(collectionID int64, scope int32)) *MockTargetManager_GetSealedSegmentsByCollection_Call {
func (_c *MockTargetManager_GetSealedSegmentsByCollection_Call) Run(run func(ctx context.Context, collectionID int64, scope int32)) *MockTargetManager_GetSealedSegmentsByCollection_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64), args[1].(int32))
run(args[0].(context.Context), args[1].(int64), args[2].(int32))
})
return _c
}
@ -510,22 +522,22 @@ func (_c *MockTargetManager_GetSealedSegmentsByCollection_Call) Return(_a0 map[i
return _c
}
func (_c *MockTargetManager_GetSealedSegmentsByCollection_Call) RunAndReturn(run func(int64, int32) map[int64]*datapb.SegmentInfo) *MockTargetManager_GetSealedSegmentsByCollection_Call {
func (_c *MockTargetManager_GetSealedSegmentsByCollection_Call) RunAndReturn(run func(context.Context, int64, int32) map[int64]*datapb.SegmentInfo) *MockTargetManager_GetSealedSegmentsByCollection_Call {
_c.Call.Return(run)
return _c
}
// GetSealedSegmentsByPartition provides a mock function with given fields: collectionID, partitionID, scope
func (_m *MockTargetManager) GetSealedSegmentsByPartition(collectionID int64, partitionID int64, scope int32) map[int64]*datapb.SegmentInfo {
ret := _m.Called(collectionID, partitionID, scope)
// GetSealedSegmentsByPartition provides a mock function with given fields: ctx, collectionID, partitionID, scope
func (_m *MockTargetManager) GetSealedSegmentsByPartition(ctx context.Context, collectionID int64, partitionID int64, scope int32) map[int64]*datapb.SegmentInfo {
ret := _m.Called(ctx, collectionID, partitionID, scope)
if len(ret) == 0 {
panic("no return value specified for GetSealedSegmentsByPartition")
}
var r0 map[int64]*datapb.SegmentInfo
if rf, ok := ret.Get(0).(func(int64, int64, int32) map[int64]*datapb.SegmentInfo); ok {
r0 = rf(collectionID, partitionID, scope)
if rf, ok := ret.Get(0).(func(context.Context, int64, int64, int32) map[int64]*datapb.SegmentInfo); ok {
r0 = rf(ctx, collectionID, partitionID, scope)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[int64]*datapb.SegmentInfo)
@ -541,16 +553,17 @@ type MockTargetManager_GetSealedSegmentsByPartition_Call struct {
}
// GetSealedSegmentsByPartition is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - partitionID int64
// - scope int32
func (_e *MockTargetManager_Expecter) GetSealedSegmentsByPartition(collectionID interface{}, partitionID interface{}, scope interface{}) *MockTargetManager_GetSealedSegmentsByPartition_Call {
return &MockTargetManager_GetSealedSegmentsByPartition_Call{Call: _e.mock.On("GetSealedSegmentsByPartition", collectionID, partitionID, scope)}
func (_e *MockTargetManager_Expecter) GetSealedSegmentsByPartition(ctx interface{}, collectionID interface{}, partitionID interface{}, scope interface{}) *MockTargetManager_GetSealedSegmentsByPartition_Call {
return &MockTargetManager_GetSealedSegmentsByPartition_Call{Call: _e.mock.On("GetSealedSegmentsByPartition", ctx, collectionID, partitionID, scope)}
}
func (_c *MockTargetManager_GetSealedSegmentsByPartition_Call) Run(run func(collectionID int64, partitionID int64, scope int32)) *MockTargetManager_GetSealedSegmentsByPartition_Call {
func (_c *MockTargetManager_GetSealedSegmentsByPartition_Call) Run(run func(ctx context.Context, collectionID int64, partitionID int64, scope int32)) *MockTargetManager_GetSealedSegmentsByPartition_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64), args[1].(int64), args[2].(int32))
run(args[0].(context.Context), args[1].(int64), args[2].(int64), args[3].(int32))
})
return _c
}
@ -560,22 +573,22 @@ func (_c *MockTargetManager_GetSealedSegmentsByPartition_Call) Return(_a0 map[in
return _c
}
func (_c *MockTargetManager_GetSealedSegmentsByPartition_Call) RunAndReturn(run func(int64, int64, int32) map[int64]*datapb.SegmentInfo) *MockTargetManager_GetSealedSegmentsByPartition_Call {
func (_c *MockTargetManager_GetSealedSegmentsByPartition_Call) RunAndReturn(run func(context.Context, int64, int64, int32) map[int64]*datapb.SegmentInfo) *MockTargetManager_GetSealedSegmentsByPartition_Call {
_c.Call.Return(run)
return _c
}
// GetTargetJSON provides a mock function with given fields: scope
func (_m *MockTargetManager) GetTargetJSON(scope int32) string {
ret := _m.Called(scope)
// GetTargetJSON provides a mock function with given fields: ctx, scope
func (_m *MockTargetManager) GetTargetJSON(ctx context.Context, scope int32) string {
ret := _m.Called(ctx, scope)
if len(ret) == 0 {
panic("no return value specified for GetTargetJSON")
}
var r0 string
if rf, ok := ret.Get(0).(func(int32) string); ok {
r0 = rf(scope)
if rf, ok := ret.Get(0).(func(context.Context, int32) string); ok {
r0 = rf(ctx, scope)
} else {
r0 = ret.Get(0).(string)
}
@ -589,14 +602,15 @@ type MockTargetManager_GetTargetJSON_Call struct {
}
// GetTargetJSON is a helper method to define mock.On call
// - ctx context.Context
// - scope int32
func (_e *MockTargetManager_Expecter) GetTargetJSON(scope interface{}) *MockTargetManager_GetTargetJSON_Call {
return &MockTargetManager_GetTargetJSON_Call{Call: _e.mock.On("GetTargetJSON", scope)}
func (_e *MockTargetManager_Expecter) GetTargetJSON(ctx interface{}, scope interface{}) *MockTargetManager_GetTargetJSON_Call {
return &MockTargetManager_GetTargetJSON_Call{Call: _e.mock.On("GetTargetJSON", ctx, scope)}
}
func (_c *MockTargetManager_GetTargetJSON_Call) Run(run func(scope int32)) *MockTargetManager_GetTargetJSON_Call {
func (_c *MockTargetManager_GetTargetJSON_Call) Run(run func(ctx context.Context, scope int32)) *MockTargetManager_GetTargetJSON_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int32))
run(args[0].(context.Context), args[1].(int32))
})
return _c
}
@ -606,22 +620,22 @@ func (_c *MockTargetManager_GetTargetJSON_Call) Return(_a0 string) *MockTargetMa
return _c
}
func (_c *MockTargetManager_GetTargetJSON_Call) RunAndReturn(run func(int32) string) *MockTargetManager_GetTargetJSON_Call {
func (_c *MockTargetManager_GetTargetJSON_Call) RunAndReturn(run func(context.Context, int32) string) *MockTargetManager_GetTargetJSON_Call {
_c.Call.Return(run)
return _c
}
// IsCurrentTargetExist provides a mock function with given fields: collectionID, partitionID
func (_m *MockTargetManager) IsCurrentTargetExist(collectionID int64, partitionID int64) bool {
ret := _m.Called(collectionID, partitionID)
// IsCurrentTargetExist provides a mock function with given fields: ctx, collectionID, partitionID
func (_m *MockTargetManager) IsCurrentTargetExist(ctx context.Context, collectionID int64, partitionID int64) bool {
ret := _m.Called(ctx, collectionID, partitionID)
if len(ret) == 0 {
panic("no return value specified for IsCurrentTargetExist")
}
var r0 bool
if rf, ok := ret.Get(0).(func(int64, int64) bool); ok {
r0 = rf(collectionID, partitionID)
if rf, ok := ret.Get(0).(func(context.Context, int64, int64) bool); ok {
r0 = rf(ctx, collectionID, partitionID)
} else {
r0 = ret.Get(0).(bool)
}
@ -635,15 +649,16 @@ type MockTargetManager_IsCurrentTargetExist_Call struct {
}
// IsCurrentTargetExist is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - partitionID int64
func (_e *MockTargetManager_Expecter) IsCurrentTargetExist(collectionID interface{}, partitionID interface{}) *MockTargetManager_IsCurrentTargetExist_Call {
return &MockTargetManager_IsCurrentTargetExist_Call{Call: _e.mock.On("IsCurrentTargetExist", collectionID, partitionID)}
func (_e *MockTargetManager_Expecter) IsCurrentTargetExist(ctx interface{}, collectionID interface{}, partitionID interface{}) *MockTargetManager_IsCurrentTargetExist_Call {
return &MockTargetManager_IsCurrentTargetExist_Call{Call: _e.mock.On("IsCurrentTargetExist", ctx, collectionID, partitionID)}
}
func (_c *MockTargetManager_IsCurrentTargetExist_Call) Run(run func(collectionID int64, partitionID int64)) *MockTargetManager_IsCurrentTargetExist_Call {
func (_c *MockTargetManager_IsCurrentTargetExist_Call) Run(run func(ctx context.Context, collectionID int64, partitionID int64)) *MockTargetManager_IsCurrentTargetExist_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64), args[1].(int64))
run(args[0].(context.Context), args[1].(int64), args[2].(int64))
})
return _c
}
@ -653,22 +668,22 @@ func (_c *MockTargetManager_IsCurrentTargetExist_Call) Return(_a0 bool) *MockTar
return _c
}
func (_c *MockTargetManager_IsCurrentTargetExist_Call) RunAndReturn(run func(int64, int64) bool) *MockTargetManager_IsCurrentTargetExist_Call {
func (_c *MockTargetManager_IsCurrentTargetExist_Call) RunAndReturn(run func(context.Context, int64, int64) bool) *MockTargetManager_IsCurrentTargetExist_Call {
_c.Call.Return(run)
return _c
}
// IsNextTargetExist provides a mock function with given fields: collectionID
func (_m *MockTargetManager) IsNextTargetExist(collectionID int64) bool {
ret := _m.Called(collectionID)
// IsNextTargetExist provides a mock function with given fields: ctx, collectionID
func (_m *MockTargetManager) IsNextTargetExist(ctx context.Context, collectionID int64) bool {
ret := _m.Called(ctx, collectionID)
if len(ret) == 0 {
panic("no return value specified for IsNextTargetExist")
}
var r0 bool
if rf, ok := ret.Get(0).(func(int64) bool); ok {
r0 = rf(collectionID)
if rf, ok := ret.Get(0).(func(context.Context, int64) bool); ok {
r0 = rf(ctx, collectionID)
} else {
r0 = ret.Get(0).(bool)
}
@ -682,14 +697,15 @@ type MockTargetManager_IsNextTargetExist_Call struct {
}
// IsNextTargetExist is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
func (_e *MockTargetManager_Expecter) IsNextTargetExist(collectionID interface{}) *MockTargetManager_IsNextTargetExist_Call {
return &MockTargetManager_IsNextTargetExist_Call{Call: _e.mock.On("IsNextTargetExist", collectionID)}
func (_e *MockTargetManager_Expecter) IsNextTargetExist(ctx interface{}, collectionID interface{}) *MockTargetManager_IsNextTargetExist_Call {
return &MockTargetManager_IsNextTargetExist_Call{Call: _e.mock.On("IsNextTargetExist", ctx, collectionID)}
}
func (_c *MockTargetManager_IsNextTargetExist_Call) Run(run func(collectionID int64)) *MockTargetManager_IsNextTargetExist_Call {
func (_c *MockTargetManager_IsNextTargetExist_Call) Run(run func(ctx context.Context, collectionID int64)) *MockTargetManager_IsNextTargetExist_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64))
run(args[0].(context.Context), args[1].(int64))
})
return _c
}
@ -699,22 +715,22 @@ func (_c *MockTargetManager_IsNextTargetExist_Call) Return(_a0 bool) *MockTarget
return _c
}
func (_c *MockTargetManager_IsNextTargetExist_Call) RunAndReturn(run func(int64) bool) *MockTargetManager_IsNextTargetExist_Call {
func (_c *MockTargetManager_IsNextTargetExist_Call) RunAndReturn(run func(context.Context, int64) bool) *MockTargetManager_IsNextTargetExist_Call {
_c.Call.Return(run)
return _c
}
// Recover provides a mock function with given fields: catalog
func (_m *MockTargetManager) Recover(catalog metastore.QueryCoordCatalog) error {
ret := _m.Called(catalog)
// Recover provides a mock function with given fields: ctx, catalog
func (_m *MockTargetManager) Recover(ctx context.Context, catalog metastore.QueryCoordCatalog) error {
ret := _m.Called(ctx, catalog)
if len(ret) == 0 {
panic("no return value specified for Recover")
}
var r0 error
if rf, ok := ret.Get(0).(func(metastore.QueryCoordCatalog) error); ok {
r0 = rf(catalog)
if rf, ok := ret.Get(0).(func(context.Context, metastore.QueryCoordCatalog) error); ok {
r0 = rf(ctx, catalog)
} else {
r0 = ret.Error(0)
}
@ -728,14 +744,15 @@ type MockTargetManager_Recover_Call struct {
}
// Recover is a helper method to define mock.On call
// - ctx context.Context
// - catalog metastore.QueryCoordCatalog
func (_e *MockTargetManager_Expecter) Recover(catalog interface{}) *MockTargetManager_Recover_Call {
return &MockTargetManager_Recover_Call{Call: _e.mock.On("Recover", catalog)}
func (_e *MockTargetManager_Expecter) Recover(ctx interface{}, catalog interface{}) *MockTargetManager_Recover_Call {
return &MockTargetManager_Recover_Call{Call: _e.mock.On("Recover", ctx, catalog)}
}
func (_c *MockTargetManager_Recover_Call) Run(run func(catalog metastore.QueryCoordCatalog)) *MockTargetManager_Recover_Call {
func (_c *MockTargetManager_Recover_Call) Run(run func(ctx context.Context, catalog metastore.QueryCoordCatalog)) *MockTargetManager_Recover_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(metastore.QueryCoordCatalog))
run(args[0].(context.Context), args[1].(metastore.QueryCoordCatalog))
})
return _c
}
@ -745,14 +762,14 @@ func (_c *MockTargetManager_Recover_Call) Return(_a0 error) *MockTargetManager_R
return _c
}
func (_c *MockTargetManager_Recover_Call) RunAndReturn(run func(metastore.QueryCoordCatalog) error) *MockTargetManager_Recover_Call {
func (_c *MockTargetManager_Recover_Call) RunAndReturn(run func(context.Context, metastore.QueryCoordCatalog) error) *MockTargetManager_Recover_Call {
_c.Call.Return(run)
return _c
}
// RemoveCollection provides a mock function with given fields: collectionID
func (_m *MockTargetManager) RemoveCollection(collectionID int64) {
_m.Called(collectionID)
// RemoveCollection provides a mock function with given fields: ctx, collectionID
func (_m *MockTargetManager) RemoveCollection(ctx context.Context, collectionID int64) {
_m.Called(ctx, collectionID)
}
// MockTargetManager_RemoveCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveCollection'
@ -761,14 +778,15 @@ type MockTargetManager_RemoveCollection_Call struct {
}
// RemoveCollection is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
func (_e *MockTargetManager_Expecter) RemoveCollection(collectionID interface{}) *MockTargetManager_RemoveCollection_Call {
return &MockTargetManager_RemoveCollection_Call{Call: _e.mock.On("RemoveCollection", collectionID)}
func (_e *MockTargetManager_Expecter) RemoveCollection(ctx interface{}, collectionID interface{}) *MockTargetManager_RemoveCollection_Call {
return &MockTargetManager_RemoveCollection_Call{Call: _e.mock.On("RemoveCollection", ctx, collectionID)}
}
func (_c *MockTargetManager_RemoveCollection_Call) Run(run func(collectionID int64)) *MockTargetManager_RemoveCollection_Call {
func (_c *MockTargetManager_RemoveCollection_Call) Run(run func(ctx context.Context, collectionID int64)) *MockTargetManager_RemoveCollection_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64))
run(args[0].(context.Context), args[1].(int64))
})
return _c
}
@ -778,19 +796,19 @@ func (_c *MockTargetManager_RemoveCollection_Call) Return() *MockTargetManager_R
return _c
}
func (_c *MockTargetManager_RemoveCollection_Call) RunAndReturn(run func(int64)) *MockTargetManager_RemoveCollection_Call {
func (_c *MockTargetManager_RemoveCollection_Call) RunAndReturn(run func(context.Context, int64)) *MockTargetManager_RemoveCollection_Call {
_c.Call.Return(run)
return _c
}
// RemovePartition provides a mock function with given fields: collectionID, partitionIDs
func (_m *MockTargetManager) RemovePartition(collectionID int64, partitionIDs ...int64) {
// RemovePartition provides a mock function with given fields: ctx, collectionID, partitionIDs
func (_m *MockTargetManager) RemovePartition(ctx context.Context, collectionID int64, partitionIDs ...int64) {
_va := make([]interface{}, len(partitionIDs))
for _i := range partitionIDs {
_va[_i] = partitionIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, collectionID)
_ca = append(_ca, ctx, collectionID)
_ca = append(_ca, _va...)
_m.Called(_ca...)
}
@ -801,22 +819,23 @@ type MockTargetManager_RemovePartition_Call struct {
}
// RemovePartition is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - partitionIDs ...int64
func (_e *MockTargetManager_Expecter) RemovePartition(collectionID interface{}, partitionIDs ...interface{}) *MockTargetManager_RemovePartition_Call {
func (_e *MockTargetManager_Expecter) RemovePartition(ctx interface{}, collectionID interface{}, partitionIDs ...interface{}) *MockTargetManager_RemovePartition_Call {
return &MockTargetManager_RemovePartition_Call{Call: _e.mock.On("RemovePartition",
append([]interface{}{collectionID}, partitionIDs...)...)}
append([]interface{}{ctx, collectionID}, partitionIDs...)...)}
}
func (_c *MockTargetManager_RemovePartition_Call) Run(run func(collectionID int64, partitionIDs ...int64)) *MockTargetManager_RemovePartition_Call {
func (_c *MockTargetManager_RemovePartition_Call) Run(run func(ctx context.Context, collectionID int64, partitionIDs ...int64)) *MockTargetManager_RemovePartition_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]int64, len(args)-1)
for i, a := range args[1:] {
variadicArgs := make([]int64, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(int64)
}
}
run(args[0].(int64), variadicArgs...)
run(args[0].(context.Context), args[1].(int64), variadicArgs...)
})
return _c
}
@ -826,14 +845,14 @@ func (_c *MockTargetManager_RemovePartition_Call) Return() *MockTargetManager_Re
return _c
}
func (_c *MockTargetManager_RemovePartition_Call) RunAndReturn(run func(int64, ...int64)) *MockTargetManager_RemovePartition_Call {
func (_c *MockTargetManager_RemovePartition_Call) RunAndReturn(run func(context.Context, int64, ...int64)) *MockTargetManager_RemovePartition_Call {
_c.Call.Return(run)
return _c
}
// SaveCurrentTarget provides a mock function with given fields: catalog
func (_m *MockTargetManager) SaveCurrentTarget(catalog metastore.QueryCoordCatalog) {
_m.Called(catalog)
// SaveCurrentTarget provides a mock function with given fields: ctx, catalog
func (_m *MockTargetManager) SaveCurrentTarget(ctx context.Context, catalog metastore.QueryCoordCatalog) {
_m.Called(ctx, catalog)
}
// MockTargetManager_SaveCurrentTarget_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveCurrentTarget'
@ -842,14 +861,15 @@ type MockTargetManager_SaveCurrentTarget_Call struct {
}
// SaveCurrentTarget is a helper method to define mock.On call
// - ctx context.Context
// - catalog metastore.QueryCoordCatalog
func (_e *MockTargetManager_Expecter) SaveCurrentTarget(catalog interface{}) *MockTargetManager_SaveCurrentTarget_Call {
return &MockTargetManager_SaveCurrentTarget_Call{Call: _e.mock.On("SaveCurrentTarget", catalog)}
func (_e *MockTargetManager_Expecter) SaveCurrentTarget(ctx interface{}, catalog interface{}) *MockTargetManager_SaveCurrentTarget_Call {
return &MockTargetManager_SaveCurrentTarget_Call{Call: _e.mock.On("SaveCurrentTarget", ctx, catalog)}
}
func (_c *MockTargetManager_SaveCurrentTarget_Call) Run(run func(catalog metastore.QueryCoordCatalog)) *MockTargetManager_SaveCurrentTarget_Call {
func (_c *MockTargetManager_SaveCurrentTarget_Call) Run(run func(ctx context.Context, catalog metastore.QueryCoordCatalog)) *MockTargetManager_SaveCurrentTarget_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(metastore.QueryCoordCatalog))
run(args[0].(context.Context), args[1].(metastore.QueryCoordCatalog))
})
return _c
}
@ -859,22 +879,22 @@ func (_c *MockTargetManager_SaveCurrentTarget_Call) Return() *MockTargetManager_
return _c
}
func (_c *MockTargetManager_SaveCurrentTarget_Call) RunAndReturn(run func(metastore.QueryCoordCatalog)) *MockTargetManager_SaveCurrentTarget_Call {
func (_c *MockTargetManager_SaveCurrentTarget_Call) RunAndReturn(run func(context.Context, metastore.QueryCoordCatalog)) *MockTargetManager_SaveCurrentTarget_Call {
_c.Call.Return(run)
return _c
}
// UpdateCollectionCurrentTarget provides a mock function with given fields: collectionID
func (_m *MockTargetManager) UpdateCollectionCurrentTarget(collectionID int64) bool {
ret := _m.Called(collectionID)
// UpdateCollectionCurrentTarget provides a mock function with given fields: ctx, collectionID
func (_m *MockTargetManager) UpdateCollectionCurrentTarget(ctx context.Context, collectionID int64) bool {
ret := _m.Called(ctx, collectionID)
if len(ret) == 0 {
panic("no return value specified for UpdateCollectionCurrentTarget")
}
var r0 bool
if rf, ok := ret.Get(0).(func(int64) bool); ok {
r0 = rf(collectionID)
if rf, ok := ret.Get(0).(func(context.Context, int64) bool); ok {
r0 = rf(ctx, collectionID)
} else {
r0 = ret.Get(0).(bool)
}
@ -888,14 +908,15 @@ type MockTargetManager_UpdateCollectionCurrentTarget_Call struct {
}
// UpdateCollectionCurrentTarget is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
func (_e *MockTargetManager_Expecter) UpdateCollectionCurrentTarget(collectionID interface{}) *MockTargetManager_UpdateCollectionCurrentTarget_Call {
return &MockTargetManager_UpdateCollectionCurrentTarget_Call{Call: _e.mock.On("UpdateCollectionCurrentTarget", collectionID)}
func (_e *MockTargetManager_Expecter) UpdateCollectionCurrentTarget(ctx interface{}, collectionID interface{}) *MockTargetManager_UpdateCollectionCurrentTarget_Call {
return &MockTargetManager_UpdateCollectionCurrentTarget_Call{Call: _e.mock.On("UpdateCollectionCurrentTarget", ctx, collectionID)}
}
func (_c *MockTargetManager_UpdateCollectionCurrentTarget_Call) Run(run func(collectionID int64)) *MockTargetManager_UpdateCollectionCurrentTarget_Call {
func (_c *MockTargetManager_UpdateCollectionCurrentTarget_Call) Run(run func(ctx context.Context, collectionID int64)) *MockTargetManager_UpdateCollectionCurrentTarget_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64))
run(args[0].(context.Context), args[1].(int64))
})
return _c
}
@ -905,22 +926,22 @@ func (_c *MockTargetManager_UpdateCollectionCurrentTarget_Call) Return(_a0 bool)
return _c
}
func (_c *MockTargetManager_UpdateCollectionCurrentTarget_Call) RunAndReturn(run func(int64) bool) *MockTargetManager_UpdateCollectionCurrentTarget_Call {
func (_c *MockTargetManager_UpdateCollectionCurrentTarget_Call) RunAndReturn(run func(context.Context, int64) bool) *MockTargetManager_UpdateCollectionCurrentTarget_Call {
_c.Call.Return(run)
return _c
}
// UpdateCollectionNextTarget provides a mock function with given fields: collectionID
func (_m *MockTargetManager) UpdateCollectionNextTarget(collectionID int64) error {
ret := _m.Called(collectionID)
// UpdateCollectionNextTarget provides a mock function with given fields: ctx, collectionID
func (_m *MockTargetManager) UpdateCollectionNextTarget(ctx context.Context, collectionID int64) error {
ret := _m.Called(ctx, collectionID)
if len(ret) == 0 {
panic("no return value specified for UpdateCollectionNextTarget")
}
var r0 error
if rf, ok := ret.Get(0).(func(int64) error); ok {
r0 = rf(collectionID)
if rf, ok := ret.Get(0).(func(context.Context, int64) error); ok {
r0 = rf(ctx, collectionID)
} else {
r0 = ret.Error(0)
}
@ -934,14 +955,15 @@ type MockTargetManager_UpdateCollectionNextTarget_Call struct {
}
// UpdateCollectionNextTarget is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
func (_e *MockTargetManager_Expecter) UpdateCollectionNextTarget(collectionID interface{}) *MockTargetManager_UpdateCollectionNextTarget_Call {
return &MockTargetManager_UpdateCollectionNextTarget_Call{Call: _e.mock.On("UpdateCollectionNextTarget", collectionID)}
func (_e *MockTargetManager_Expecter) UpdateCollectionNextTarget(ctx interface{}, collectionID interface{}) *MockTargetManager_UpdateCollectionNextTarget_Call {
return &MockTargetManager_UpdateCollectionNextTarget_Call{Call: _e.mock.On("UpdateCollectionNextTarget", ctx, collectionID)}
}
func (_c *MockTargetManager_UpdateCollectionNextTarget_Call) Run(run func(collectionID int64)) *MockTargetManager_UpdateCollectionNextTarget_Call {
func (_c *MockTargetManager_UpdateCollectionNextTarget_Call) Run(run func(ctx context.Context, collectionID int64)) *MockTargetManager_UpdateCollectionNextTarget_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64))
run(args[0].(context.Context), args[1].(int64))
})
return _c
}
@ -951,7 +973,7 @@ func (_c *MockTargetManager_UpdateCollectionNextTarget_Call) Return(_a0 error) *
return _c
}
func (_c *MockTargetManager_UpdateCollectionNextTarget_Call) RunAndReturn(run func(int64) error) *MockTargetManager_UpdateCollectionNextTarget_Call {
func (_c *MockTargetManager_UpdateCollectionNextTarget_Call) RunAndReturn(run func(context.Context, int64) error) *MockTargetManager_UpdateCollectionNextTarget_Call {
_c.Call.Return(run)
return _c
}
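
For tests built on the regenerated MockTargetManager, every expectation, Run callback, and direct call now carries a leading context argument, and variadic methods shift their trailing arguments by one position. A minimal sketch of the new call shape, assuming the mockery-generated NewMockTargetManager constructor and the usual testify imports in the same package as the mock; the collection and partition IDs are illustrative:

package meta

import (
	"context"
	"testing"

	"github.com/stretchr/testify/mock"
)

func TestTargetManagerMockCtxSketch(t *testing.T) {
	ctx := context.Background()
	tm := NewMockTargetManager(t) // assumed mockery-generated constructor

	// Non-variadic method: one extra matcher up front for ctx.
	tm.EXPECT().UpdateCollectionNextTarget(mock.Anything, int64(100)).Return(nil)
	_ = tm.UpdateCollectionNextTarget(ctx, 100)

	// Variadic method: ctx matcher first, then collectionID, then one matcher per partition ID.
	tm.EXPECT().RemovePartition(mock.Anything, int64(100), int64(10), int64(11)).Return()
	tm.RemovePartition(ctx, 100, 10, 11)
}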

View File

@ -17,6 +17,7 @@
package meta
import (
"context"
"fmt"
"sync"
@ -78,8 +79,8 @@ func NewReplicaManager(idAllocator func() (int64, error), catalog metastore.Quer
}
// Recover recovers the replicas for given collections from meta store
func (m *ReplicaManager) Recover(collections []int64) error {
replicas, err := m.catalog.GetReplicas()
func (m *ReplicaManager) Recover(ctx context.Context, collections []int64) error {
replicas, err := m.catalog.GetReplicas(ctx)
if err != nil {
return fmt.Errorf("failed to recover replicas, err=%w", err)
}
@ -98,7 +99,7 @@ func (m *ReplicaManager) Recover(collections []int64) error {
zap.Int64s("nodes", replica.GetNodes()),
)
} else {
err := m.catalog.ReleaseReplica(replica.GetCollectionID(), replica.GetID())
err := m.catalog.ReleaseReplica(ctx, replica.GetCollectionID(), replica.GetID())
if err != nil {
return err
}
@ -114,7 +115,7 @@ func (m *ReplicaManager) Recover(collections []int64) error {
// Get returns the replica by id.
// Replica should be read-only, do not modify it.
func (m *ReplicaManager) Get(id typeutil.UniqueID) *Replica {
func (m *ReplicaManager) Get(ctx context.Context, id typeutil.UniqueID) *Replica {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
@ -122,7 +123,7 @@ func (m *ReplicaManager) Get(id typeutil.UniqueID) *Replica {
}
// Spawn spawns N replicas at resource group for given collection in ReplicaManager.
func (m *ReplicaManager) Spawn(collection int64, replicaNumInRG map[string]int, channels []string) ([]*Replica, error) {
func (m *ReplicaManager) Spawn(ctx context.Context, collection int64, replicaNumInRG map[string]int, channels []string) ([]*Replica, error) {
m.rwmutex.Lock()
defer m.rwmutex.Unlock()
@ -151,7 +152,7 @@ func (m *ReplicaManager) Spawn(collection int64, replicaNumInRG map[string]int,
}))
}
}
if err := m.put(replicas...); err != nil {
if err := m.put(ctx, replicas...); err != nil {
return nil, err
}
return replicas, nil
@ -159,14 +160,14 @@ func (m *ReplicaManager) Spawn(collection int64, replicaNumInRG map[string]int,
// Deprecated: Warning, this breaks the consistency of ReplicaManager;
// never use it in non-test code, use Spawn instead.
func (m *ReplicaManager) Put(replicas ...*Replica) error {
func (m *ReplicaManager) Put(ctx context.Context, replicas ...*Replica) error {
m.rwmutex.Lock()
defer m.rwmutex.Unlock()
return m.put(replicas...)
return m.put(ctx, replicas...)
}
func (m *ReplicaManager) put(replicas ...*Replica) error {
func (m *ReplicaManager) put(ctx context.Context, replicas ...*Replica) error {
if len(replicas) == 0 {
return nil
}
@ -175,7 +176,7 @@ func (m *ReplicaManager) put(replicas ...*Replica) error {
for _, replica := range replicas {
replicaPBs = append(replicaPBs, replica.replicaPB)
}
if err := m.catalog.SaveReplica(replicaPBs...); err != nil {
if err := m.catalog.SaveReplica(ctx, replicaPBs...); err != nil {
return err
}
@ -198,7 +199,7 @@ func (m *ReplicaManager) putReplicaInMemory(replicas ...*Replica) {
}
// TransferReplica transfers N replicas from srcRGName to dstRGName.
func (m *ReplicaManager) TransferReplica(collectionID typeutil.UniqueID, srcRGName string, dstRGName string, replicaNum int) error {
func (m *ReplicaManager) TransferReplica(ctx context.Context, collectionID typeutil.UniqueID, srcRGName string, dstRGName string, replicaNum int) error {
if srcRGName == dstRGName {
return merr.WrapErrParameterInvalidMsg("source resource group and target resource group should not be the same, resource group: %s", srcRGName)
}
@ -223,10 +224,10 @@ func (m *ReplicaManager) TransferReplica(collectionID typeutil.UniqueID, srcRGNa
mutableReplica.SetResourceGroup(dstRGName)
replicas = append(replicas, mutableReplica.IntoReplica())
}
return m.put(replicas...)
return m.put(ctx, replicas...)
}
func (m *ReplicaManager) MoveReplica(dstRGName string, toMove []*Replica) error {
func (m *ReplicaManager) MoveReplica(ctx context.Context, dstRGName string, toMove []*Replica) error {
m.rwmutex.Lock()
defer m.rwmutex.Unlock()
replicas := make([]*Replica, 0, len(toMove))
@ -238,7 +239,7 @@ func (m *ReplicaManager) MoveReplica(dstRGName string, toMove []*Replica) error
replicaIDs = append(replicaIDs, replica.GetID())
}
log.Info("move replicas to resource group", zap.String("dstRGName", dstRGName), zap.Int64s("replicas", replicaIDs))
return m.put(replicas...)
return m.put(ctx, replicas...)
}
// getSrcReplicasAndCheckIfTransferable checks whether the collection can be transferred from srcRGName to dstRGName.
@ -267,11 +268,11 @@ func (m *ReplicaManager) getSrcReplicasAndCheckIfTransferable(collectionID typeu
// RemoveCollection removes replicas of given collection,
// returns error if failed to remove replica from KV
func (m *ReplicaManager) RemoveCollection(collectionID typeutil.UniqueID) error {
func (m *ReplicaManager) RemoveCollection(ctx context.Context, collectionID typeutil.UniqueID) error {
m.rwmutex.Lock()
defer m.rwmutex.Unlock()
err := m.catalog.ReleaseReplicas(collectionID)
err := m.catalog.ReleaseReplicas(ctx, collectionID)
if err != nil {
return err
}
@ -286,17 +287,17 @@ func (m *ReplicaManager) RemoveCollection(collectionID typeutil.UniqueID) error
return nil
}
func (m *ReplicaManager) RemoveReplicas(collectionID typeutil.UniqueID, replicas ...typeutil.UniqueID) error {
func (m *ReplicaManager) RemoveReplicas(ctx context.Context, collectionID typeutil.UniqueID, replicas ...typeutil.UniqueID) error {
m.rwmutex.Lock()
defer m.rwmutex.Unlock()
log.Info("release replicas", zap.Int64("collectionID", collectionID), zap.Int64s("replicas", replicas))
return m.removeReplicas(collectionID, replicas...)
return m.removeReplicas(ctx, collectionID, replicas...)
}
func (m *ReplicaManager) removeReplicas(collectionID typeutil.UniqueID, replicas ...typeutil.UniqueID) error {
err := m.catalog.ReleaseReplica(collectionID, replicas...)
func (m *ReplicaManager) removeReplicas(ctx context.Context, collectionID typeutil.UniqueID, replicas ...typeutil.UniqueID) error {
err := m.catalog.ReleaseReplica(ctx, collectionID, replicas...)
if err != nil {
return err
}
@ -312,7 +313,7 @@ func (m *ReplicaManager) removeReplicas(collectionID typeutil.UniqueID, replicas
return nil
}
func (m *ReplicaManager) GetByCollection(collectionID typeutil.UniqueID) []*Replica {
func (m *ReplicaManager) GetByCollection(ctx context.Context, collectionID typeutil.UniqueID) []*Replica {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
return m.getByCollection(collectionID)
@ -327,7 +328,7 @@ func (m *ReplicaManager) getByCollection(collectionID typeutil.UniqueID) []*Repl
return collReplicas.replicas
}
func (m *ReplicaManager) GetByCollectionAndNode(collectionID, nodeID typeutil.UniqueID) *Replica {
func (m *ReplicaManager) GetByCollectionAndNode(ctx context.Context, collectionID, nodeID typeutil.UniqueID) *Replica {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
@ -342,7 +343,7 @@ func (m *ReplicaManager) GetByCollectionAndNode(collectionID, nodeID typeutil.Un
return nil
}
func (m *ReplicaManager) GetByNode(nodeID typeutil.UniqueID) []*Replica {
func (m *ReplicaManager) GetByNode(ctx context.Context, nodeID typeutil.UniqueID) []*Replica {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
@ -367,7 +368,7 @@ func (m *ReplicaManager) getByCollectionAndRG(collectionID int64, rgName string)
})
}
func (m *ReplicaManager) GetByResourceGroup(rgName string) []*Replica {
func (m *ReplicaManager) GetByResourceGroup(ctx context.Context, rgName string) []*Replica {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
@ -386,7 +387,7 @@ func (m *ReplicaManager) GetByResourceGroup(rgName string) []*Replica {
// 1. Move the rw nodes to ro nodes if they are not in the related resource group.
// 2. Add new incoming nodes into the replica if they are not in use by other replicas of the same collection.
// 3. Replicas in the same resource group share the nodes of that resource group fairly.
func (m *ReplicaManager) RecoverNodesInCollection(collectionID typeutil.UniqueID, rgs map[string]typeutil.UniqueSet) error {
func (m *ReplicaManager) RecoverNodesInCollection(ctx context.Context, collectionID typeutil.UniqueID, rgs map[string]typeutil.UniqueSet) error {
if err := m.validateResourceGroups(rgs); err != nil {
return err
}
@ -427,7 +428,7 @@ func (m *ReplicaManager) RecoverNodesInCollection(collectionID typeutil.UniqueID
modifiedReplicas = append(modifiedReplicas, mutableReplica.IntoReplica())
})
})
return m.put(modifiedReplicas...)
return m.put(ctx, modifiedReplicas...)
}
// validateResourceGroups checks if the resource groups are valid.
@ -468,7 +469,7 @@ func (m *ReplicaManager) getCollectionAssignmentHelper(collectionID typeutil.Uni
}
// RemoveNode removes the given nodes from the specified replica.
func (m *ReplicaManager) RemoveNode(replicaID typeutil.UniqueID, nodes ...typeutil.UniqueID) error {
func (m *ReplicaManager) RemoveNode(ctx context.Context, replicaID typeutil.UniqueID, nodes ...typeutil.UniqueID) error {
m.rwmutex.Lock()
defer m.rwmutex.Unlock()
@ -479,11 +480,11 @@ func (m *ReplicaManager) RemoveNode(replicaID typeutil.UniqueID, nodes ...typeut
mutableReplica := replica.CopyForWrite()
mutableReplica.RemoveNode(nodes...) // ro -> unused
return m.put(mutableReplica.IntoReplica())
return m.put(ctx, mutableReplica.IntoReplica())
}
func (m *ReplicaManager) GetResourceGroupByCollection(collection typeutil.UniqueID) typeutil.Set[string] {
replicas := m.GetByCollection(collection)
func (m *ReplicaManager) GetResourceGroupByCollection(ctx context.Context, collection typeutil.UniqueID) typeutil.Set[string] {
replicas := m.GetByCollection(ctx, collection)
ret := typeutil.NewSet(lo.Map(replicas, func(r *Replica, _ int) string { return r.GetResourceGroup() })...)
return ret
}
@ -492,7 +493,7 @@ func (m *ReplicaManager) GetResourceGroupByCollection(collection typeutil.Unique
// It locks the ReplicaManager for reading, converts the replicas to their protobuf representation,
// marshals them into a JSON string, and returns the result.
// If an error occurs during marshaling, it logs a warning and returns an empty string.
func (m *ReplicaManager) GetReplicasJSON() string {
func (m *ReplicaManager) GetReplicasJSON(ctx context.Context) string {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
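
Call sites of ReplicaManager now thread the caller's context straight through to the catalog writes, while the read paths accept it for consistency. A minimal caller-side sketch under the new signatures, written as a fragment in the same package (only the "context" import is needed); the helper name, collection ID, and replica count are illustrative:

func exampleReplicaFlow(ctx context.Context, m *ReplicaManager) error {
	// Rebuild in-memory state from the catalog for collection 100,
	// then spawn a single replica in the default resource group.
	if err := m.Recover(ctx, []int64{100}); err != nil {
		return err
	}
	replicas, err := m.Spawn(ctx, 100, map[string]int{DefaultResourceGroupName: 1}, nil)
	if err != nil {
		return err
	}
	// Reads take ctx as well, even though they only touch in-memory state.
	for _, r := range replicas {
		_ = m.Get(ctx, r.GetID())
	}
	return nil
}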

View File

@ -17,6 +17,7 @@
package meta
import (
"context"
"testing"
"github.com/samber/lo"
@ -62,6 +63,7 @@ type ReplicaManagerSuite struct {
kv kv.MetaKv
catalog metastore.QueryCoordCatalog
mgr *ReplicaManager
ctx context.Context
}
func (suite *ReplicaManagerSuite) SetupSuite() {
@ -86,6 +88,7 @@ func (suite *ReplicaManagerSuite) SetupSuite() {
spawnConfig: map[string]int{"RG1": 1, "RG2": 1, "RG3": 1},
},
}
suite.ctx = context.Background()
}
func (suite *ReplicaManagerSuite) SetupTest() {
@ -114,16 +117,17 @@ func (suite *ReplicaManagerSuite) TearDownTest() {
func (suite *ReplicaManagerSuite) TestSpawn() {
mgr := suite.mgr
ctx := suite.ctx
mgr.idAllocator = ErrorIDAllocator()
_, err := mgr.Spawn(1, map[string]int{DefaultResourceGroupName: 1}, nil)
_, err := mgr.Spawn(ctx, 1, map[string]int{DefaultResourceGroupName: 1}, nil)
suite.Error(err)
replicas := mgr.GetByCollection(1)
replicas := mgr.GetByCollection(ctx, 1)
suite.Len(replicas, 0)
mgr.idAllocator = suite.idAllocator
replicas, err = mgr.Spawn(1, map[string]int{DefaultResourceGroupName: 1}, []string{"channel1", "channel2"})
replicas, err = mgr.Spawn(ctx, 1, map[string]int{DefaultResourceGroupName: 1}, []string{"channel1", "channel2"})
suite.NoError(err)
for _, replica := range replicas {
suite.Len(replica.replicaPB.GetChannelNodeInfos(), 0)
@ -131,7 +135,7 @@ func (suite *ReplicaManagerSuite) TestSpawn() {
paramtable.Get().Save(paramtable.Get().QueryCoordCfg.Balancer.Key, ChannelLevelScoreBalancerName)
defer paramtable.Get().Reset(paramtable.Get().QueryCoordCfg.Balancer.Key)
replicas, err = mgr.Spawn(2, map[string]int{DefaultResourceGroupName: 1}, []string{"channel1", "channel2"})
replicas, err = mgr.Spawn(ctx, 2, map[string]int{DefaultResourceGroupName: 1}, []string{"channel1", "channel2"})
suite.NoError(err)
for _, replica := range replicas {
suite.Len(replica.replicaPB.GetChannelNodeInfos(), 2)
@ -140,14 +144,15 @@ func (suite *ReplicaManagerSuite) TestSpawn() {
func (suite *ReplicaManagerSuite) TestGet() {
mgr := suite.mgr
ctx := suite.ctx
for collectionID, collectionCfg := range suite.collections {
replicas := mgr.GetByCollection(collectionID)
replicas := mgr.GetByCollection(ctx, collectionID)
replicaNodes := make(map[int64][]int64)
nodes := make([]int64, 0)
for _, replica := range replicas {
suite.Equal(collectionID, replica.GetCollectionID())
suite.Equal(replica, mgr.Get(replica.GetID()))
suite.Equal(replica, mgr.Get(ctx, replica.GetID()))
suite.Equal(len(replica.replicaPB.GetNodes()), replica.RWNodesCount())
suite.Equal(replica.replicaPB.GetNodes(), replica.GetNodes())
replicaNodes[replica.GetID()] = replica.GetNodes()
@ -162,7 +167,7 @@ func (suite *ReplicaManagerSuite) TestGet() {
for replicaID, nodes := range replicaNodes {
for _, node := range nodes {
replica := mgr.GetByCollectionAndNode(collectionID, node)
replica := mgr.GetByCollectionAndNode(ctx, collectionID, node)
suite.Equal(replicaID, replica.GetID())
}
}
@ -171,6 +176,7 @@ func (suite *ReplicaManagerSuite) TestGet() {
func (suite *ReplicaManagerSuite) TestGetByNode() {
mgr := suite.mgr
ctx := suite.ctx
randomNodeID := int64(11111)
testReplica1 := newReplica(&querypb.Replica{
@ -185,18 +191,19 @@ func (suite *ReplicaManagerSuite) TestGetByNode() {
Nodes: []int64{randomNodeID},
ResourceGroup: DefaultResourceGroupName,
})
mgr.Put(testReplica1, testReplica2)
mgr.Put(ctx, testReplica1, testReplica2)
replicas := mgr.GetByNode(randomNodeID)
replicas := mgr.GetByNode(ctx, randomNodeID)
suite.Len(replicas, 2)
}
func (suite *ReplicaManagerSuite) TestRecover() {
mgr := suite.mgr
ctx := suite.ctx
// Clear data in memory, and then recover from meta store
suite.clearMemory()
mgr.Recover(lo.Keys(suite.collections))
mgr.Recover(ctx, lo.Keys(suite.collections))
suite.TestGet()
// Test recover from 2.1 meta store
@ -210,8 +217,8 @@ func (suite *ReplicaManagerSuite) TestRecover() {
suite.kv.Save(querycoord.ReplicaMetaPrefixV1+"/2100", string(value))
suite.clearMemory()
mgr.Recover(append(lo.Keys(suite.collections), 1000))
replica := mgr.Get(2100)
mgr.Recover(ctx, append(lo.Keys(suite.collections), 1000))
replica := mgr.Get(ctx, 2100)
suite.NotNil(replica)
suite.EqualValues(1000, replica.GetCollectionID())
suite.EqualValues([]int64{1, 2, 3}, replica.GetNodes())
@ -223,25 +230,27 @@ func (suite *ReplicaManagerSuite) TestRecover() {
func (suite *ReplicaManagerSuite) TestRemove() {
mgr := suite.mgr
ctx := suite.ctx
for collection := range suite.collections {
err := mgr.RemoveCollection(collection)
err := mgr.RemoveCollection(ctx, collection)
suite.NoError(err)
replicas := mgr.GetByCollection(collection)
replicas := mgr.GetByCollection(ctx, collection)
suite.Empty(replicas)
}
// Check whether the replicas are also removed from meta store
mgr.Recover(lo.Keys(suite.collections))
mgr.Recover(ctx, lo.Keys(suite.collections))
for collection := range suite.collections {
replicas := mgr.GetByCollection(collection)
replicas := mgr.GetByCollection(ctx, collection)
suite.Empty(replicas)
}
}
func (suite *ReplicaManagerSuite) TestNodeManipulate() {
mgr := suite.mgr
ctx := suite.ctx
// add node into rg.
rgs := map[string]typeutil.UniqueSet{
@ -256,10 +265,10 @@ func (suite *ReplicaManagerSuite) TestNodeManipulate() {
for rg := range cfg.spawnConfig {
rgsOfCollection[rg] = rgs[rg]
}
mgr.RecoverNodesInCollection(collectionID, rgsOfCollection)
mgr.RecoverNodesInCollection(ctx, collectionID, rgsOfCollection)
for rg := range cfg.spawnConfig {
for _, node := range rgs[rg].Collect() {
replica := mgr.GetByCollectionAndNode(collectionID, node)
replica := mgr.GetByCollectionAndNode(ctx, collectionID, node)
suite.Contains(replica.GetNodes(), node)
}
}
@ -267,11 +276,11 @@ func (suite *ReplicaManagerSuite) TestNodeManipulate() {
// Check these modifications are applied to meta store
suite.clearMemory()
mgr.Recover(lo.Keys(suite.collections))
mgr.Recover(ctx, lo.Keys(suite.collections))
for collectionID, cfg := range suite.collections {
for rg := range cfg.spawnConfig {
for _, node := range rgs[rg].Collect() {
replica := mgr.GetByCollectionAndNode(collectionID, node)
replica := mgr.GetByCollectionAndNode(ctx, collectionID, node)
suite.Contains(replica.GetNodes(), node)
}
}
@ -280,9 +289,10 @@ func (suite *ReplicaManagerSuite) TestNodeManipulate() {
func (suite *ReplicaManagerSuite) spawnAll() {
mgr := suite.mgr
ctx := suite.ctx
for id, cfg := range suite.collections {
replicas, err := mgr.Spawn(id, cfg.spawnConfig, nil)
replicas, err := mgr.Spawn(ctx, id, cfg.spawnConfig, nil)
suite.NoError(err)
totalSpawn := 0
rgsOfCollection := make(map[string]typeutil.UniqueSet)
@ -290,26 +300,27 @@ func (suite *ReplicaManagerSuite) spawnAll() {
totalSpawn += spawnNum
rgsOfCollection[rg] = suite.rgs[rg]
}
mgr.RecoverNodesInCollection(id, rgsOfCollection)
mgr.RecoverNodesInCollection(ctx, id, rgsOfCollection)
suite.Len(replicas, totalSpawn)
}
}
func (suite *ReplicaManagerSuite) TestResourceGroup() {
mgr := NewReplicaManager(suite.idAllocator, suite.catalog)
replicas1, err := mgr.Spawn(int64(1000), map[string]int{DefaultResourceGroupName: 1}, nil)
ctx := suite.ctx
replicas1, err := mgr.Spawn(ctx, int64(1000), map[string]int{DefaultResourceGroupName: 1}, nil)
suite.NoError(err)
suite.NotNil(replicas1)
suite.Len(replicas1, 1)
replica2, err := mgr.Spawn(int64(2000), map[string]int{DefaultResourceGroupName: 1}, nil)
replica2, err := mgr.Spawn(ctx, int64(2000), map[string]int{DefaultResourceGroupName: 1}, nil)
suite.NoError(err)
suite.NotNil(replica2)
suite.Len(replica2, 1)
replicas := mgr.GetByResourceGroup(DefaultResourceGroupName)
replicas := mgr.GetByResourceGroup(ctx, DefaultResourceGroupName)
suite.Len(replicas, 2)
rgNames := mgr.GetResourceGroupByCollection(int64(1000))
rgNames := mgr.GetResourceGroupByCollection(ctx, int64(1000))
suite.Len(rgNames, 1)
suite.True(rgNames.Contain(DefaultResourceGroupName))
}
@ -326,6 +337,7 @@ type ReplicaManagerV2Suite struct {
kv kv.MetaKv
catalog metastore.QueryCoordCatalog
mgr *ReplicaManager
ctx context.Context
}
func (suite *ReplicaManagerV2Suite) SetupSuite() {
@ -375,6 +387,7 @@ func (suite *ReplicaManagerV2Suite) SetupSuite() {
idAllocator := RandomIncrementIDAllocator()
suite.mgr = NewReplicaManager(idAllocator, suite.catalog)
suite.ctx = context.Background()
}
func (suite *ReplicaManagerV2Suite) TearDownSuite() {
@ -383,32 +396,34 @@ func (suite *ReplicaManagerV2Suite) TearDownSuite() {
func (suite *ReplicaManagerV2Suite) TestSpawn() {
mgr := suite.mgr
ctx := suite.ctx
for id, cfg := range suite.collections {
replicas, err := mgr.Spawn(id, cfg.spawnConfig, nil)
replicas, err := mgr.Spawn(ctx, id, cfg.spawnConfig, nil)
suite.NoError(err)
rgsOfCollection := make(map[string]typeutil.UniqueSet)
for rg := range cfg.spawnConfig {
rgsOfCollection[rg] = suite.rgs[rg]
}
mgr.RecoverNodesInCollection(id, rgsOfCollection)
mgr.RecoverNodesInCollection(ctx, id, rgsOfCollection)
for rg := range cfg.spawnConfig {
for _, node := range suite.rgs[rg].Collect() {
replica := mgr.GetByCollectionAndNode(id, node)
replica := mgr.GetByCollectionAndNode(ctx, id, node)
suite.Contains(replica.GetNodes(), node)
}
}
suite.Len(replicas, cfg.getTotalSpawn())
replicas = mgr.GetByCollection(id)
replicas = mgr.GetByCollection(ctx, id)
suite.Len(replicas, cfg.getTotalSpawn())
}
suite.testIfBalanced()
}
func (suite *ReplicaManagerV2Suite) testIfBalanced() {
ctx := suite.ctx
// If balanced
for id := range suite.collections {
replicas := suite.mgr.GetByCollection(id)
replicas := suite.mgr.GetByCollection(ctx, id)
rgToReplica := make(map[string][]*Replica, 0)
for _, r := range replicas {
rgToReplica[r.GetResourceGroup()] = append(rgToReplica[r.GetResourceGroup()], r)
@ -440,22 +455,24 @@ func (suite *ReplicaManagerV2Suite) testIfBalanced() {
}
func (suite *ReplicaManagerV2Suite) TestTransferReplica() {
ctx := suite.ctx
// param error
err := suite.mgr.TransferReplica(10086, "RG4", "RG5", 1)
err := suite.mgr.TransferReplica(ctx, 10086, "RG4", "RG5", 1)
suite.Error(err)
err = suite.mgr.TransferReplica(1005, "RG4", "RG5", 0)
err = suite.mgr.TransferReplica(ctx, 1005, "RG4", "RG5", 0)
suite.Error(err)
err = suite.mgr.TransferReplica(1005, "RG4", "RG4", 1)
err = suite.mgr.TransferReplica(ctx, 1005, "RG4", "RG4", 1)
suite.Error(err)
err = suite.mgr.TransferReplica(1005, "RG4", "RG5", 1)
err = suite.mgr.TransferReplica(ctx, 1005, "RG4", "RG5", 1)
suite.NoError(err)
suite.recoverReplica(2, true)
suite.testIfBalanced()
}
func (suite *ReplicaManagerV2Suite) TestTransferReplicaAndAddNode() {
suite.mgr.TransferReplica(1005, "RG4", "RG5", 1)
ctx := suite.ctx
suite.mgr.TransferReplica(ctx, 1005, "RG4", "RG5", 1)
suite.recoverReplica(1, false)
suite.rgs["RG5"].Insert(16, 17, 18)
suite.recoverReplica(2, true)
@ -470,6 +487,7 @@ func (suite *ReplicaManagerV2Suite) TestTransferNode() {
}
func (suite *ReplicaManagerV2Suite) recoverReplica(k int, clearOutbound bool) {
ctx := suite.ctx
// need at least two passes to recover the replicas.
// transferring a node between replicas requires setting it to outbound first and then to incoming.
for i := 0; i < k; i++ {
@ -479,16 +497,16 @@ func (suite *ReplicaManagerV2Suite) recoverReplica(k int, clearOutbound bool) {
for rg := range cfg.spawnConfig {
rgsOfCollection[rg] = suite.rgs[rg]
}
suite.mgr.RecoverNodesInCollection(id, rgsOfCollection)
suite.mgr.RecoverNodesInCollection(ctx, id, rgsOfCollection)
}
// clear all outbound nodes
if clearOutbound {
for id := range suite.collections {
replicas := suite.mgr.GetByCollection(id)
replicas := suite.mgr.GetByCollection(ctx, id)
for _, r := range replicas {
outboundNodes := r.GetRONodes()
suite.mgr.RemoveNode(r.GetID(), outboundNodes...)
suite.mgr.RemoveNode(ctx, r.GetID(), outboundNodes...)
}
}
}
@ -502,9 +520,10 @@ func TestReplicaManager(t *testing.T) {
func TestGetReplicasJSON(t *testing.T) {
catalog := mocks.NewQueryCoordCatalog(t)
catalog.EXPECT().SaveReplica(mock.Anything).Return(nil)
catalog.EXPECT().SaveReplica(mock.Anything, mock.Anything).Return(nil)
idAllocator := RandomIncrementIDAllocator()
replicaManager := NewReplicaManager(idAllocator, catalog)
ctx := context.Background()
// Add some replicas to the ReplicaManager
replica1 := newReplica(&querypb.Replica{
@ -520,13 +539,13 @@ func TestGetReplicasJSON(t *testing.T) {
Nodes: []int64{4, 5, 6},
})
err := replicaManager.put(replica1)
err := replicaManager.put(ctx, replica1)
assert.NoError(t, err)
err = replicaManager.put(replica2)
err = replicaManager.put(ctx, replica2)
assert.NoError(t, err)
jsonOutput := replicaManager.GetReplicasJSON()
jsonOutput := replicaManager.GetReplicasJSON(ctx)
var replicas []*metricsinfo.Replica
err = json.Unmarshal([]byte(jsonOutput), &replicas)
assert.NoError(t, err)
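
Because SaveReplica takes a leading context ahead of its variadic replicas, mock expectations need one matcher for ctx plus one matcher per replica written in a single call. A sketch that could sit beside the test above and reuse its imports; the replica IDs are illustrative:

func TestSaveReplicaMatcherShape(t *testing.T) {
	ctx := context.Background()
	catalog := mocks.NewQueryCoordCatalog(t)
	// Matches SaveReplica(ctx, r1): ctx matcher + one replica matcher.
	catalog.EXPECT().SaveReplica(mock.Anything, mock.Anything).Return(nil)
	// Matches SaveReplica(ctx, r1, r2): ctx matcher + two replica matchers.
	catalog.EXPECT().SaveReplica(mock.Anything, mock.Anything, mock.Anything).Return(nil)

	replicaManager := NewReplicaManager(RandomIncrementIDAllocator(), catalog)
	err := replicaManager.Put(ctx, newReplica(&querypb.Replica{ID: 1, CollectionID: 10}))
	assert.NoError(t, err)
	err = replicaManager.Put(ctx, newReplica(&querypb.Replica{ID: 2, CollectionID: 10}), newReplica(&querypb.Replica{ID: 3, CollectionID: 10}))
	assert.NoError(t, err)
}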

View File

@ -17,6 +17,7 @@
package meta
import (
"context"
"fmt"
"sync"
@ -77,11 +78,11 @@ func NewResourceManager(catalog metastore.QueryCoordCatalog, nodeMgr *session.No
}
// Recover recovers resource groups from meta; other interfaces of ResourceManager can only be called after recover is done.
func (rm *ResourceManager) Recover() error {
func (rm *ResourceManager) Recover(ctx context.Context) error {
rm.rwmutex.Lock()
defer rm.rwmutex.Unlock()
rgs, err := rm.catalog.GetResourceGroups()
rgs, err := rm.catalog.GetResourceGroups(ctx)
if err != nil {
return errors.Wrap(err, "failed to recover resource group from store")
}
@ -111,14 +112,14 @@ func (rm *ResourceManager) Recover() error {
}
if len(upgrades) > 0 {
log.Info("upgrade resource group meta into latest", zap.Int("num", len(upgrades)))
return rm.catalog.SaveResourceGroup(upgrades...)
return rm.catalog.SaveResourceGroup(ctx, upgrades...)
}
return nil
}
// AddResourceGroup creates a new ResourceGroup.
// No nodes are changed here; all nodes will be reassigned to the new resource group by auto recover.
func (rm *ResourceManager) AddResourceGroup(rgName string, cfg *rgpb.ResourceGroupConfig) error {
func (rm *ResourceManager) AddResourceGroup(ctx context.Context, rgName string, cfg *rgpb.ResourceGroupConfig) error {
if len(rgName) == 0 {
return merr.WrapErrParameterMissing("resource group name couldn't be empty")
}
@ -148,7 +149,7 @@ func (rm *ResourceManager) AddResourceGroup(rgName string, cfg *rgpb.ResourceGro
}
rg := NewResourceGroup(rgName, cfg, rm.nodeMgr)
if err := rm.catalog.SaveResourceGroup(rg.GetMeta()); err != nil {
if err := rm.catalog.SaveResourceGroup(ctx, rg.GetMeta()); err != nil {
log.Warn("failed to add resource group",
zap.String("rgName", rgName),
zap.Any("config", cfg),
@ -170,18 +171,18 @@ func (rm *ResourceManager) AddResourceGroup(rgName string, cfg *rgpb.ResourceGro
// UpdateResourceGroups updates resource group configurations.
// Only the configuration is changed, not the nodes; all nodes will be reassigned by auto recover.
func (rm *ResourceManager) UpdateResourceGroups(rgs map[string]*rgpb.ResourceGroupConfig) error {
func (rm *ResourceManager) UpdateResourceGroups(ctx context.Context, rgs map[string]*rgpb.ResourceGroupConfig) error {
if len(rgs) == 0 {
return nil
}
rm.rwmutex.Lock()
defer rm.rwmutex.Unlock()
return rm.updateResourceGroups(rgs)
return rm.updateResourceGroups(ctx, rgs)
}
// updateResourceGroups update resource group configuration.
func (rm *ResourceManager) updateResourceGroups(rgs map[string]*rgpb.ResourceGroupConfig) error {
func (rm *ResourceManager) updateResourceGroups(ctx context.Context, rgs map[string]*rgpb.ResourceGroupConfig) error {
modifiedRG := make([]*ResourceGroup, 0, len(rgs))
updates := make([]*querypb.ResourceGroup, 0, len(rgs))
for rgName, cfg := range rgs {
@ -200,7 +201,7 @@ func (rm *ResourceManager) updateResourceGroups(rgs map[string]*rgpb.ResourceGro
modifiedRG = append(modifiedRG, rg)
}
if err := rm.catalog.SaveResourceGroup(updates...); err != nil {
if err := rm.catalog.SaveResourceGroup(ctx, updates...); err != nil {
for rgName, cfg := range rgs {
log.Warn("failed to update resource group",
zap.String("rgName", rgName),
@ -227,7 +228,7 @@ func (rm *ResourceManager) updateResourceGroups(rgs map[string]*rgpb.ResourceGro
// go:deprecated TransferNode transfers nodes from the source resource group to the target resource group.
// Deprecated: use the declarative API `UpdateResourceGroups` instead.
func (rm *ResourceManager) TransferNode(sourceRGName string, targetRGName string, nodeNum int) error {
func (rm *ResourceManager) TransferNode(ctx context.Context, sourceRGName string, targetRGName string, nodeNum int) error {
if sourceRGName == targetRGName {
return merr.WrapErrParameterInvalidMsg("source resource group and target resource group should not be the same, resource group: %s", sourceRGName)
}
@ -272,14 +273,14 @@ func (rm *ResourceManager) TransferNode(sourceRGName string, targetRGName string
if targetCfg.Requests.NodeNum > targetCfg.Limits.NodeNum {
targetCfg.Limits.NodeNum = targetCfg.Requests.NodeNum
}
return rm.updateResourceGroups(map[string]*rgpb.ResourceGroupConfig{
return rm.updateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{
sourceRGName: sourceCfg,
targetRGName: targetCfg,
})
}
// RemoveResourceGroup remove resource group.
func (rm *ResourceManager) RemoveResourceGroup(rgName string) error {
func (rm *ResourceManager) RemoveResourceGroup(ctx context.Context, rgName string) error {
rm.rwmutex.Lock()
defer rm.rwmutex.Unlock()
@ -296,7 +297,7 @@ func (rm *ResourceManager) RemoveResourceGroup(rgName string) error {
// Nodes may still be assigned to this group;
// recover the resource group from its redundant status before removing it.
if rm.groups[rgName].NodeNum() > 0 {
if err := rm.recoverRedundantNodeRG(rgName); err != nil {
if err := rm.recoverRedundantNodeRG(ctx, rgName); err != nil {
log.Info("failed to recover redundant node resource group before remove it",
zap.String("rgName", rgName),
zap.Error(err),
@ -306,7 +307,7 @@ func (rm *ResourceManager) RemoveResourceGroup(rgName string) error {
}
// Remove it from meta storage.
if err := rm.catalog.RemoveResourceGroup(rgName); err != nil {
if err := rm.catalog.RemoveResourceGroup(ctx, rgName); err != nil {
log.Info("failed to remove resource group",
zap.String("rgName", rgName),
zap.Error(err),
@ -327,7 +328,7 @@ func (rm *ResourceManager) RemoveResourceGroup(rgName string) error {
}
// GetNodesOfMultiRG returns the nodes of multiple resource groups; it can be used to get a consistent view of the nodes across those groups.
func (rm *ResourceManager) GetNodesOfMultiRG(rgName []string) (map[string]typeutil.UniqueSet, error) {
func (rm *ResourceManager) GetNodesOfMultiRG(ctx context.Context, rgName []string) (map[string]typeutil.UniqueSet, error) {
rm.rwmutex.RLock()
defer rm.rwmutex.RUnlock()
@ -342,7 +343,7 @@ func (rm *ResourceManager) GetNodesOfMultiRG(rgName []string) (map[string]typeut
}
// GetNodes return nodes of given resource group.
func (rm *ResourceManager) GetNodes(rgName string) ([]int64, error) {
func (rm *ResourceManager) GetNodes(ctx context.Context, rgName string) ([]int64, error) {
rm.rwmutex.RLock()
defer rm.rwmutex.RUnlock()
if rm.groups[rgName] == nil {
@ -352,7 +353,7 @@ func (rm *ResourceManager) GetNodes(rgName string) ([]int64, error) {
}
// VerifyNodeCount returns whether each resource group's node count matches the required node count.
func (rm *ResourceManager) VerifyNodeCount(requiredNodeCount map[string]int) error {
func (rm *ResourceManager) VerifyNodeCount(ctx context.Context, requiredNodeCount map[string]int) error {
rm.rwmutex.RLock()
defer rm.rwmutex.RUnlock()
for rgName, nodeCount := range requiredNodeCount {
@ -368,7 +369,7 @@ func (rm *ResourceManager) VerifyNodeCount(requiredNodeCount map[string]int) err
}
// GetOutgoingNodeNumByReplica return outgoing node num on each rg from this replica.
func (rm *ResourceManager) GetOutgoingNodeNumByReplica(replica *Replica) map[string]int32 {
func (rm *ResourceManager) GetOutgoingNodeNumByReplica(ctx context.Context, replica *Replica) map[string]int32 {
rm.rwmutex.RLock()
defer rm.rwmutex.RUnlock()
@ -397,7 +398,7 @@ func (rm *ResourceManager) getResourceGroupByNodeID(nodeID int64) *ResourceGroup
}
// ContainsNode return whether given node is in given resource group.
func (rm *ResourceManager) ContainsNode(rgName string, node int64) bool {
func (rm *ResourceManager) ContainsNode(ctx context.Context, rgName string, node int64) bool {
rm.rwmutex.RLock()
defer rm.rwmutex.RUnlock()
if rm.groups[rgName] == nil {
@ -407,14 +408,14 @@ func (rm *ResourceManager) ContainsNode(rgName string, node int64) bool {
}
// ContainResourceGroup returns whether the given resource group exists.
func (rm *ResourceManager) ContainResourceGroup(rgName string) bool {
func (rm *ResourceManager) ContainResourceGroup(ctx context.Context, rgName string) bool {
rm.rwmutex.RLock()
defer rm.rwmutex.RUnlock()
return rm.groups[rgName] != nil
}
// GetResourceGroup return resource group snapshot by name.
func (rm *ResourceManager) GetResourceGroup(rgName string) *ResourceGroup {
func (rm *ResourceManager) GetResourceGroup(ctx context.Context, rgName string) *ResourceGroup {
rm.rwmutex.RLock()
defer rm.rwmutex.RUnlock()
@ -425,7 +426,7 @@ func (rm *ResourceManager) GetResourceGroup(rgName string) *ResourceGroup {
}
// ListResourceGroups return all resource groups names.
func (rm *ResourceManager) ListResourceGroups() []string {
func (rm *ResourceManager) ListResourceGroups(ctx context.Context) []string {
rm.rwmutex.RLock()
defer rm.rwmutex.RUnlock()
@ -434,7 +435,7 @@ func (rm *ResourceManager) ListResourceGroups() []string {
// MeetRequirement returns whether the resource group meets its requirement,
// returning an error with the reason if it does not.
func (rm *ResourceManager) MeetRequirement(rgName string) error {
func (rm *ResourceManager) MeetRequirement(ctx context.Context, rgName string) error {
rm.rwmutex.RLock()
defer rm.rwmutex.RUnlock()
if rm.groups[rgName] == nil {
@ -444,21 +445,21 @@ func (rm *ResourceManager) MeetRequirement(rgName string) error {
}
// CheckIncomingNodeNum return incoming node num.
func (rm *ResourceManager) CheckIncomingNodeNum() int {
func (rm *ResourceManager) CheckIncomingNodeNum(ctx context.Context) int {
rm.rwmutex.RLock()
defer rm.rwmutex.RUnlock()
return rm.incomingNode.Len()
}
// HandleNodeUp handles a new node when it comes online.
func (rm *ResourceManager) HandleNodeUp(node int64) {
func (rm *ResourceManager) HandleNodeUp(ctx context.Context, node int64) {
rm.rwmutex.Lock()
defer rm.rwmutex.Unlock()
rm.incomingNode.Insert(node)
// Trigger assignment of the incoming node right away.
// The error can be ignored here because `AssignPendingIncomingNode` will retry the assignment.
rgName, err := rm.assignIncomingNodeWithNodeCheck(node)
rgName, err := rm.assignIncomingNodeWithNodeCheck(ctx, node)
log.Info("HandleNodeUp: add node to resource group",
zap.String("rgName", rgName),
zap.Int64("node", node),
@ -467,7 +468,7 @@ func (rm *ResourceManager) HandleNodeUp(node int64) {
}
// HandleNodeDown handles the node when it leaves.
func (rm *ResourceManager) HandleNodeDown(node int64) {
func (rm *ResourceManager) HandleNodeDown(ctx context.Context, node int64) {
rm.rwmutex.Lock()
defer rm.rwmutex.Unlock()
@ -476,7 +477,7 @@ func (rm *ResourceManager) HandleNodeDown(node int64) {
// when a stopping query node goes offline, no node change is triggered,
// because it was already removed from the resource manager when it entered the stopping state,
// so `unassignNode` will do nothing
rgName, err := rm.unassignNode(node)
rgName, err := rm.unassignNode(ctx, node)
// trigger node changes, expected to remove ro node from replica immediately
rm.nodeChangedNotifier.NotifyAll()
@ -487,12 +488,12 @@ func (rm *ResourceManager) HandleNodeDown(node int64) {
)
}
func (rm *ResourceManager) HandleNodeStopping(node int64) {
func (rm *ResourceManager) HandleNodeStopping(ctx context.Context, node int64) {
rm.rwmutex.Lock()
defer rm.rwmutex.Unlock()
rm.incomingNode.Remove(node)
rgName, err := rm.unassignNode(node)
rgName, err := rm.unassignNode(ctx, node)
log.Info("HandleNodeStopping: remove node from resource group",
zap.String("rgName", rgName),
zap.Int64("node", node),
@ -501,22 +502,22 @@ func (rm *ResourceManager) HandleNodeStopping(node int64) {
}
// ListenResourceGroupChanged return a listener for resource group changed.
func (rm *ResourceManager) ListenResourceGroupChanged() *syncutil.VersionedListener {
func (rm *ResourceManager) ListenResourceGroupChanged(ctx context.Context) *syncutil.VersionedListener {
return rm.rgChangedNotifier.Listen(syncutil.VersionedListenAtEarliest)
}
// ListenNodeChanged return a listener for node changed.
func (rm *ResourceManager) ListenNodeChanged() *syncutil.VersionedListener {
func (rm *ResourceManager) ListenNodeChanged(ctx context.Context) *syncutil.VersionedListener {
return rm.nodeChangedNotifier.Listen(syncutil.VersionedListenAtEarliest)
}
// AssignPendingIncomingNode assign incoming node to resource group.
func (rm *ResourceManager) AssignPendingIncomingNode() {
func (rm *ResourceManager) AssignPendingIncomingNode(ctx context.Context) {
rm.rwmutex.Lock()
defer rm.rwmutex.Unlock()
for node := range rm.incomingNode {
rgName, err := rm.assignIncomingNodeWithNodeCheck(node)
rgName, err := rm.assignIncomingNodeWithNodeCheck(ctx, node)
log.Info("Pending HandleNodeUp: add node to resource group",
zap.String("rgName", rgName),
zap.Int64("node", node),
@ -526,7 +527,7 @@ func (rm *ResourceManager) AssignPendingIncomingNode() {
}
// AutoRecoverResourceGroup automatically recovers the resource group, returning an error if recovery fails.
func (rm *ResourceManager) AutoRecoverResourceGroup(rgName string) error {
func (rm *ResourceManager) AutoRecoverResourceGroup(ctx context.Context, rgName string) error {
rm.rwmutex.Lock()
defer rm.rwmutex.Unlock()
@ -536,19 +537,19 @@ func (rm *ResourceManager) AutoRecoverResourceGroup(rgName string) error {
}
if rg.MissingNumOfNodes() > 0 {
return rm.recoverMissingNodeRG(rgName)
return rm.recoverMissingNodeRG(ctx, rgName)
}
// DefaultResourceGroup is the fallback resource group for redundant recovery,
// so after all other resource groups reach their `limits`, the remaining redundant nodes are transferred to DefaultResourceGroup.
if rg.RedundantNumOfNodes() > 0 {
return rm.recoverRedundantNodeRG(rgName)
return rm.recoverRedundantNodeRG(ctx, rgName)
}
return nil
}
// recoverMissingNodeRG recovers the resource group by transferring nodes from other resource groups.
func (rm *ResourceManager) recoverMissingNodeRG(rgName string) error {
func (rm *ResourceManager) recoverMissingNodeRG(ctx context.Context, rgName string) error {
for rm.groups[rgName].MissingNumOfNodes() > 0 {
targetRG := rm.groups[rgName]
node, sourceRG := rm.selectNodeForMissingRecover(targetRG)
@ -557,7 +558,7 @@ func (rm *ResourceManager) recoverMissingNodeRG(rgName string) error {
return ErrNodeNotEnough
}
err := rm.transferNode(targetRG.GetName(), node)
err := rm.transferNode(ctx, targetRG.GetName(), node)
if err != nil {
log.Warn("failed to recover missing node by transfer node from other resource group",
zap.String("sourceRG", sourceRG.GetName()),
@ -622,7 +623,7 @@ func (rm *ResourceManager) selectNodeForMissingRecover(targetRG *ResourceGroup)
}
// recoverRedundantNodeRG recovers the resource group by transferring nodes to other resource groups.
func (rm *ResourceManager) recoverRedundantNodeRG(rgName string) error {
func (rm *ResourceManager) recoverRedundantNodeRG(ctx context.Context, rgName string) error {
for rm.groups[rgName].RedundantNumOfNodes() > 0 {
sourceRG := rm.groups[rgName]
node, targetRG := rm.selectNodeForRedundantRecover(sourceRG)
@ -632,7 +633,7 @@ func (rm *ResourceManager) recoverRedundantNodeRG(rgName string) error {
return errors.New("all resource group reach limits")
}
if err := rm.transferNode(targetRG.GetName(), node); err != nil {
if err := rm.transferNode(ctx, targetRG.GetName(), node); err != nil {
log.Warn("failed to recover redundant node by transfer node to other resource group",
zap.String("sourceRG", sourceRG.GetName()),
zap.String("targetRG", targetRG.GetName()),
@ -704,7 +705,7 @@ func (rm *ResourceManager) selectNodeForRedundantRecover(sourceRG *ResourceGroup
}
// assignIncomingNodeWithNodeCheck assign node to resource group with node status check.
func (rm *ResourceManager) assignIncomingNodeWithNodeCheck(node int64) (string, error) {
func (rm *ResourceManager) assignIncomingNodeWithNodeCheck(ctx context.Context, node int64) (string, error) {
// node is on stopping or stopped, remove it from incoming node set.
if rm.nodeMgr.Get(node) == nil {
rm.incomingNode.Remove(node)
@ -715,7 +716,7 @@ func (rm *ResourceManager) assignIncomingNodeWithNodeCheck(node int64) (string,
return "", errors.New("node has been stopped")
}
rgName, err := rm.assignIncomingNode(node)
rgName, err := rm.assignIncomingNode(ctx, node)
if err != nil {
return "", err
}
@ -725,7 +726,7 @@ func (rm *ResourceManager) assignIncomingNodeWithNodeCheck(node int64) (string,
}
// assignIncomingNode assigns the node to a resource group.
func (rm *ResourceManager) assignIncomingNode(node int64) (string, error) {
func (rm *ResourceManager) assignIncomingNode(ctx context.Context, node int64) (string, error) {
// If the node is already assigned to a resource group.
rg := rm.getResourceGroupByNodeID(node)
if rg != nil {
@ -738,7 +739,7 @@ func (rm *ResourceManager) assignIncomingNode(node int64) (string, error) {
// Select a resource group to assign the incoming node to.
rg = rm.mustSelectAssignIncomingNodeTargetRG(node)
if err := rm.transferNode(rg.GetName(), node); err != nil {
if err := rm.transferNode(ctx, rg.GetName(), node); err != nil {
return "", errors.Wrap(err, "at finally assign to default resource group")
}
return rg.GetName(), nil
@ -791,7 +792,7 @@ func (rm *ResourceManager) findMaxRGWithGivenFilter(filter func(rg *ResourceGrou
// transferNode transfers the given node to the given resource group.
// If the node is already assigned to that resource group, it does nothing.
// If the node is assigned to another resource group, it is unassigned from it first.
func (rm *ResourceManager) transferNode(rgName string, node int64) error {
func (rm *ResourceManager) transferNode(ctx context.Context, rgName string, node int64) error {
if rm.groups[rgName] == nil {
return merr.WrapErrResourceGroupNotFound(rgName)
}
@ -827,7 +828,7 @@ func (rm *ResourceManager) transferNode(rgName string, node int64) error {
modifiedRG = append(modifiedRG, rg)
// Commit updates to meta storage.
if err := rm.catalog.SaveResourceGroup(updates...); err != nil {
if err := rm.catalog.SaveResourceGroup(ctx, updates...); err != nil {
log.Warn("failed to transfer node to resource group",
zap.String("rgName", rgName),
zap.String("originalRG", originalRG),
@ -854,12 +855,12 @@ func (rm *ResourceManager) transferNode(rgName string, node int64) error {
}
// unassignNode removes a node from the resource group it belongs to.
func (rm *ResourceManager) unassignNode(node int64) (string, error) {
func (rm *ResourceManager) unassignNode(ctx context.Context, node int64) (string, error) {
if rg := rm.getResourceGroupByNodeID(node); rg != nil {
mrg := rg.CopyForWrite()
mrg.UnassignNode(node)
rg := mrg.ToResourceGroup()
if err := rm.catalog.SaveResourceGroup(rg.GetMeta()); err != nil {
if err := rm.catalog.SaveResourceGroup(ctx, rg.GetMeta()); err != nil {
log.Fatal("unassign node from resource group",
zap.String("rgName", rg.GetName()),
zap.Int64("node", node),
@ -943,7 +944,7 @@ func (rm *ResourceManager) validateResourceGroupIsDeletable(rgName string) error
return nil
}
func (rm *ResourceManager) GetResourceGroupsJSON() string {
func (rm *ResourceManager) GetResourceGroupsJSON(ctx context.Context) string {
rm.rwmutex.RLock()
defer rm.rwmutex.RUnlock()

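Taken together, the resource_manager.go hunks above thread a caller-supplied context from the public API down to every catalog write. As a hedged illustration only, a caller in the meta package might drive the ctx-aware recovery methods like the sketch below; the loop shape, the interval parameter, and the assumption that ListResourceGroups(ctx) yields resource group names are mine, not part of this diff.

// Hypothetical recovery loop (illustrative, not from this commit). It forwards one
// cancellable ctx into every ResourceManager call whose signature changed above.
func runRecoveryLoop(ctx context.Context, rm *ResourceManager, interval time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return // caller cancelled: stop recovering
		case <-ticker.C:
			// place nodes that registered since the last tick
			rm.AssignPendingIncomingNode(ctx)
			// then move every resource group toward its configured limits
			for _, rgName := range rm.ListResourceGroups(ctx) { // assumed to return rg names
				_ = rm.AutoRecoverResourceGroup(ctx, rgName) // errors are retried on the next tick
			}
		}
	}
}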
View File

@ -16,6 +16,7 @@
package meta
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
@ -44,6 +45,7 @@ type ResourceManagerSuite struct {
kv kv.MetaKv
manager *ResourceManager
ctx context.Context
}
func (suite *ResourceManagerSuite) SetupSuite() {
@ -65,6 +67,7 @@ func (suite *ResourceManagerSuite) SetupTest() {
store := querycoord.NewCatalog(suite.kv)
suite.manager = NewResourceManager(store, session.NewNodeManager())
suite.ctx = context.Background()
}
func (suite *ResourceManagerSuite) TearDownSuite() {
@ -76,6 +79,7 @@ func TestResourceManager(t *testing.T) {
}
func (suite *ResourceManagerSuite) TestValidateConfiguration() {
ctx := suite.ctx
err := suite.manager.validateResourceGroupConfig("rg1", newResourceGroupConfig(0, 0))
suite.NoError(err)
@ -111,16 +115,17 @@ func (suite *ResourceManagerSuite) TestValidateConfiguration() {
err = suite.manager.validateResourceGroupConfig("rg1", cfg)
suite.ErrorIs(err, merr.ErrResourceGroupIllegalConfig)
err = suite.manager.AddResourceGroup("rg2", newResourceGroupConfig(0, 0))
err = suite.manager.AddResourceGroup(ctx, "rg2", newResourceGroupConfig(0, 0))
suite.NoError(err)
err = suite.manager.RemoveResourceGroup("rg2")
err = suite.manager.RemoveResourceGroup(ctx, "rg2")
suite.NoError(err)
}
func (suite *ResourceManagerSuite) TestValidateDelete() {
ctx := suite.ctx
// A non-empty resource group cannot be removed.
err := suite.manager.AddResourceGroup("rg1", newResourceGroupConfig(1, 1))
err := suite.manager.AddResourceGroup(ctx, "rg1", newResourceGroupConfig(1, 1))
suite.NoError(err)
err = suite.manager.validateResourceGroupIsDeletable(DefaultResourceGroupName)
@ -131,8 +136,8 @@ func (suite *ResourceManagerSuite) TestValidateDelete() {
cfg := newResourceGroupConfig(0, 0)
cfg.TransferFrom = []*rgpb.ResourceGroupTransfer{{ResourceGroup: "rg1"}}
suite.manager.AddResourceGroup("rg2", cfg)
suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{
suite.manager.AddResourceGroup(ctx, "rg2", cfg)
suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{
"rg1": newResourceGroupConfig(0, 0),
})
err = suite.manager.validateResourceGroupIsDeletable("rg1")
@ -140,64 +145,65 @@ func (suite *ResourceManagerSuite) TestValidateDelete() {
cfg = newResourceGroupConfig(0, 0)
cfg.TransferTo = []*rgpb.ResourceGroupTransfer{{ResourceGroup: "rg1"}}
suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{
suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{
"rg2": cfg,
})
err = suite.manager.validateResourceGroupIsDeletable("rg1")
suite.ErrorIs(err, merr.ErrParameterInvalid)
suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{
suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{
"rg2": newResourceGroupConfig(0, 0),
})
err = suite.manager.validateResourceGroupIsDeletable("rg1")
suite.NoError(err)
err = suite.manager.RemoveResourceGroup("rg1")
err = suite.manager.RemoveResourceGroup(ctx, "rg1")
suite.NoError(err)
err = suite.manager.RemoveResourceGroup("rg2")
err = suite.manager.RemoveResourceGroup(ctx, "rg2")
suite.NoError(err)
}
func (suite *ResourceManagerSuite) TestManipulateResourceGroup() {
ctx := suite.ctx
// test add rg
err := suite.manager.AddResourceGroup("rg1", newResourceGroupConfig(0, 0))
err := suite.manager.AddResourceGroup(ctx, "rg1", newResourceGroupConfig(0, 0))
suite.NoError(err)
suite.True(suite.manager.ContainResourceGroup("rg1"))
suite.Len(suite.manager.ListResourceGroups(), 2)
suite.True(suite.manager.ContainResourceGroup(ctx, "rg1"))
suite.Len(suite.manager.ListResourceGroups(ctx), 2)
// test add duplicate rg but same configuration is ok
err = suite.manager.AddResourceGroup("rg1", newResourceGroupConfig(0, 0))
err = suite.manager.AddResourceGroup(ctx, "rg1", newResourceGroupConfig(0, 0))
suite.NoError(err)
err = suite.manager.AddResourceGroup("rg1", newResourceGroupConfig(1, 1))
err = suite.manager.AddResourceGroup(ctx, "rg1", newResourceGroupConfig(1, 1))
suite.Error(err)
// test delete rg
err = suite.manager.RemoveResourceGroup("rg1")
err = suite.manager.RemoveResourceGroup(ctx, "rg1")
suite.NoError(err)
// test delete rg which doesn't exist
err = suite.manager.RemoveResourceGroup("rg1")
err = suite.manager.RemoveResourceGroup(ctx, "rg1")
suite.NoError(err)
// test delete default rg
err = suite.manager.RemoveResourceGroup(DefaultResourceGroupName)
err = suite.manager.RemoveResourceGroup(ctx, DefaultResourceGroupName)
suite.ErrorIs(err, merr.ErrParameterInvalid)
// test delete a non-empty rg.
err = suite.manager.AddResourceGroup("rg2", newResourceGroupConfig(1, 1))
err = suite.manager.AddResourceGroup(ctx, "rg2", newResourceGroupConfig(1, 1))
suite.NoError(err)
err = suite.manager.RemoveResourceGroup("rg2")
err = suite.manager.RemoveResourceGroup(ctx, "rg2")
suite.ErrorIs(err, merr.ErrParameterInvalid)
// test delete a rg after update
suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{
suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{
"rg2": newResourceGroupConfig(0, 0),
})
err = suite.manager.RemoveResourceGroup("rg2")
err = suite.manager.RemoveResourceGroup(ctx, "rg2")
suite.NoError(err)
// assign a node to rg.
err = suite.manager.AddResourceGroup("rg2", newResourceGroupConfig(1, 1))
err = suite.manager.AddResourceGroup(ctx, "rg2", newResourceGroupConfig(1, 1))
suite.NoError(err)
suite.manager.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1,
@ -205,68 +211,69 @@ func (suite *ResourceManagerSuite) TestManipulateResourceGroup() {
Hostname: "localhost",
}))
defer suite.manager.nodeMgr.Remove(1)
suite.manager.HandleNodeUp(1)
err = suite.manager.RemoveResourceGroup("rg2")
suite.manager.HandleNodeUp(ctx, 1)
err = suite.manager.RemoveResourceGroup(ctx, "rg2")
suite.ErrorIs(err, merr.ErrParameterInvalid)
suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{
suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{
"rg2": newResourceGroupConfig(0, 0),
})
log.Info("xxxxx")
// RemoveResourceGroup will remove all nodes from the resource group.
err = suite.manager.RemoveResourceGroup("rg2")
err = suite.manager.RemoveResourceGroup(ctx, "rg2")
suite.NoError(err)
}
func (suite *ResourceManagerSuite) TestNodeUpAndDown() {
ctx := suite.ctx
suite.manager.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1,
Address: "localhost",
Hostname: "localhost",
}))
err := suite.manager.AddResourceGroup("rg1", newResourceGroupConfig(1, 1))
err := suite.manager.AddResourceGroup(ctx, "rg1", newResourceGroupConfig(1, 1))
suite.NoError(err)
// test add node to rg
suite.manager.HandleNodeUp(1)
suite.Equal(1, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.manager.HandleNodeUp(ctx, 1)
suite.Equal(1, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
// test add a non-existent node to rg
err = suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{
err = suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{
"rg1": newResourceGroupConfig(2, 3),
})
suite.NoError(err)
suite.manager.HandleNodeUp(2)
suite.Equal(1, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Zero(suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.manager.HandleNodeUp(ctx, 2)
suite.Equal(1, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Zero(suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
// tear down a non-existent node from rg.
suite.manager.HandleNodeDown(2)
suite.Equal(1, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Zero(suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.manager.HandleNodeDown(ctx, 2)
suite.Equal(1, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Zero(suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
// test add an existing node to rg
suite.manager.HandleNodeUp(1)
suite.Equal(1, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Zero(suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.manager.HandleNodeUp(ctx, 1)
suite.Equal(1, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Zero(suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
// tear down an existing node from rg.
suite.manager.HandleNodeDown(1)
suite.Zero(suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Zero(suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.manager.HandleNodeDown(ctx, 1)
suite.Zero(suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Zero(suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
// tear down the same node from rg again.
suite.manager.HandleNodeDown(1)
suite.Zero(suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Zero(suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.manager.HandleNodeDown(ctx, 1)
suite.Zero(suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Zero(suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
suite.manager.HandleNodeUp(1)
suite.Equal(1, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Zero(suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.manager.HandleNodeUp(ctx, 1)
suite.Equal(1, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Zero(suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
err = suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{
err = suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{
"rg1": newResourceGroupConfig(4, 4),
})
suite.NoError(err)
suite.manager.AddResourceGroup("rg2", newResourceGroupConfig(1, 1))
suite.manager.AddResourceGroup(ctx, "rg2", newResourceGroupConfig(1, 1))
suite.NoError(err)
suite.manager.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
@ -289,29 +296,29 @@ func (suite *ResourceManagerSuite) TestNodeUpAndDown() {
Address: "localhost",
Hostname: "localhost",
}))
suite.manager.HandleNodeUp(11)
suite.manager.HandleNodeUp(12)
suite.manager.HandleNodeUp(13)
suite.manager.HandleNodeUp(14)
suite.manager.HandleNodeUp(ctx, 11)
suite.manager.HandleNodeUp(ctx, 12)
suite.manager.HandleNodeUp(ctx, 13)
suite.manager.HandleNodeUp(ctx, 14)
suite.Equal(4, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Equal(1, suite.manager.GetResourceGroup("rg2").NodeNum())
suite.Zero(suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.Equal(4, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Equal(1, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum())
suite.Zero(suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
suite.manager.HandleNodeDown(11)
suite.manager.HandleNodeDown(12)
suite.manager.HandleNodeDown(13)
suite.manager.HandleNodeDown(14)
suite.Equal(1, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Zero(suite.manager.GetResourceGroup("rg2").NodeNum())
suite.Zero(suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.manager.HandleNodeDown(ctx, 11)
suite.manager.HandleNodeDown(ctx, 12)
suite.manager.HandleNodeDown(ctx, 13)
suite.manager.HandleNodeDown(ctx, 14)
suite.Equal(1, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Zero(suite.manager.GetResourceGroup(ctx, "rg2").NodeNum())
suite.Zero(suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
suite.manager.HandleNodeDown(1)
suite.Zero(suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Zero(suite.manager.GetResourceGroup("rg2").NodeNum())
suite.Zero(suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.manager.HandleNodeDown(ctx, 1)
suite.Zero(suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Zero(suite.manager.GetResourceGroup(ctx, "rg2").NodeNum())
suite.Zero(suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{
suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{
"rg1": newResourceGroupConfig(20, 30),
"rg2": newResourceGroupConfig(30, 40),
})
@ -321,106 +328,107 @@ func (suite *ResourceManagerSuite) TestNodeUpAndDown() {
Address: "localhost",
Hostname: "localhost",
}))
suite.manager.HandleNodeUp(int64(i))
suite.manager.HandleNodeUp(ctx, int64(i))
}
suite.Equal(20, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Equal(30, suite.manager.GetResourceGroup("rg2").NodeNum())
suite.Equal(50, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.Equal(20, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Equal(30, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum())
suite.Equal(50, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
// down all nodes
for i := 1; i <= 100; i++ {
suite.manager.HandleNodeDown(int64(i))
suite.Equal(100-i, suite.manager.GetResourceGroup("rg1").NodeNum()+
suite.manager.GetResourceGroup("rg2").NodeNum()+
suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.manager.HandleNodeDown(ctx, int64(i))
suite.Equal(100-i, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()+
suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()+
suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
}
// if all rgs have reached their limits, incoming nodes should fall back to the default rg.
suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{
suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{
"rg1": newResourceGroupConfig(0, 0),
"rg2": newResourceGroupConfig(0, 0),
DefaultResourceGroupName: newResourceGroupConfig(0, 0),
})
for i := 1; i <= 100; i++ {
suite.manager.HandleNodeUp(int64(i))
suite.Equal(i, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.Equal(0, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Equal(0, suite.manager.GetResourceGroup("rg2").NodeNum())
suite.manager.HandleNodeUp(ctx, int64(i))
suite.Equal(i, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
suite.Equal(0, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Equal(0, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum())
}
}
func (suite *ResourceManagerSuite) TestAutoRecover() {
ctx := suite.ctx
for i := 1; i <= 100; i++ {
suite.manager.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: int64(i),
Address: "localhost",
Hostname: "localhost",
}))
suite.manager.HandleNodeUp(int64(i))
suite.manager.HandleNodeUp(ctx, int64(i))
}
suite.Equal(100, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.Equal(100, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
// Recover 10 nodes from default resource group
suite.manager.AddResourceGroup("rg1", newResourceGroupConfig(10, 30))
suite.Zero(suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup("rg1").MissingNumOfNodes())
suite.Equal(100, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.manager.AutoRecoverResourceGroup("rg1")
suite.Equal(10, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Equal(0, suite.manager.GetResourceGroup("rg1").MissingNumOfNodes())
suite.Equal(90, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.manager.AddResourceGroup(ctx, "rg1", newResourceGroupConfig(10, 30))
suite.Zero(suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg1").MissingNumOfNodes())
suite.Equal(100, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
suite.manager.AutoRecoverResourceGroup(ctx, "rg1")
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Equal(0, suite.manager.GetResourceGroup(ctx, "rg1").MissingNumOfNodes())
suite.Equal(90, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
// Recover 20 nodes from default resource group
suite.manager.AddResourceGroup("rg2", newResourceGroupConfig(20, 30))
suite.Zero(suite.manager.GetResourceGroup("rg2").NodeNum())
suite.Equal(20, suite.manager.GetResourceGroup("rg2").MissingNumOfNodes())
suite.Equal(10, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Equal(90, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.manager.AutoRecoverResourceGroup("rg2")
suite.Equal(20, suite.manager.GetResourceGroup("rg2").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Equal(70, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.manager.AddResourceGroup(ctx, "rg2", newResourceGroupConfig(20, 30))
suite.Zero(suite.manager.GetResourceGroup(ctx, "rg2").NodeNum())
suite.Equal(20, suite.manager.GetResourceGroup(ctx, "rg2").MissingNumOfNodes())
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Equal(90, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
suite.manager.AutoRecoverResourceGroup(ctx, "rg2")
suite.Equal(20, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Equal(70, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
// Recover 5 redundant nodes from resource group 1
suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{
suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{
"rg1": newResourceGroupConfig(5, 5),
})
suite.manager.AutoRecoverResourceGroup("rg1")
suite.Equal(20, suite.manager.GetResourceGroup("rg2").NodeNum())
suite.Equal(5, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Equal(75, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.manager.AutoRecoverResourceGroup(ctx, "rg1")
suite.Equal(20, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum())
suite.Equal(5, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Equal(75, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
// Recover 10 redundant nodes from resource group 2 to resource group 1 and default resource group.
suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{
suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{
"rg1": newResourceGroupConfig(10, 20),
"rg2": newResourceGroupConfig(5, 10),
})
suite.manager.AutoRecoverResourceGroup("rg2")
suite.Equal(10, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup("rg2").NodeNum())
suite.Equal(80, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.manager.AutoRecoverResourceGroup(ctx, "rg2")
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum())
suite.Equal(80, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
// recover redundant nodes from default resource group
suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{
suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{
"rg1": newResourceGroupConfig(10, 20),
"rg2": newResourceGroupConfig(20, 30),
DefaultResourceGroupName: newResourceGroupConfig(10, 20),
})
suite.manager.AutoRecoverResourceGroup("rg1")
suite.manager.AutoRecoverResourceGroup("rg2")
suite.manager.AutoRecoverResourceGroup(DefaultResourceGroupName)
suite.manager.AutoRecoverResourceGroup(ctx, "rg1")
suite.manager.AutoRecoverResourceGroup(ctx, "rg2")
suite.manager.AutoRecoverResourceGroup(ctx, DefaultResourceGroupName)
// Even though the default resource group has a 20-node limit,
// all redundant nodes will be assigned to the default resource group.
suite.Equal(20, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Equal(30, suite.manager.GetResourceGroup("rg2").NodeNum())
suite.Equal(50, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.Equal(20, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Equal(30, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum())
suite.Equal(50, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
// Test recovering missing nodes from a high-priority resource group by setting `from`.
suite.manager.AddResourceGroup("rg3", &rgpb.ResourceGroupConfig{
suite.manager.AddResourceGroup(ctx, "rg3", &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{
NodeNum: 15,
},
@ -431,23 +439,23 @@ func (suite *ResourceManagerSuite) TestAutoRecover() {
ResourceGroup: "rg1",
}},
})
suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{
suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{
DefaultResourceGroupName: newResourceGroupConfig(30, 40),
})
suite.manager.AutoRecoverResourceGroup("rg1")
suite.manager.AutoRecoverResourceGroup("rg2")
suite.manager.AutoRecoverResourceGroup(DefaultResourceGroupName)
suite.manager.AutoRecoverResourceGroup("rg3")
suite.manager.AutoRecoverResourceGroup(ctx, "rg1")
suite.manager.AutoRecoverResourceGroup(ctx, "rg2")
suite.manager.AutoRecoverResourceGroup(ctx, DefaultResourceGroupName)
suite.manager.AutoRecoverResourceGroup(ctx, "rg3")
// rg3 takes 10 redundant nodes from the default group and 5 from rg1, which has high priority.
suite.Equal(15, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Equal(30, suite.manager.GetResourceGroup("rg2").NodeNum())
suite.Equal(15, suite.manager.GetResourceGroup("rg3").NodeNum())
suite.Equal(40, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.Equal(15, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Equal(30, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum())
suite.Equal(15, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum())
suite.Equal(40, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
// Test recovering redundant nodes to a high-priority resource group by setting `to`.
suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{
suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{
"rg3": {
Requests: &rgpb.ResourceGroupLimit{
NodeNum: 0,
@ -463,21 +471,21 @@ func (suite *ResourceManagerSuite) TestAutoRecover() {
"rg2": newResourceGroupConfig(15, 40),
})
suite.manager.AutoRecoverResourceGroup("rg1")
suite.manager.AutoRecoverResourceGroup("rg2")
suite.manager.AutoRecoverResourceGroup(DefaultResourceGroupName)
suite.manager.AutoRecoverResourceGroup("rg3")
suite.manager.AutoRecoverResourceGroup(ctx, "rg1")
suite.manager.AutoRecoverResourceGroup(ctx, "rg2")
suite.manager.AutoRecoverResourceGroup(ctx, DefaultResourceGroupName)
suite.manager.AutoRecoverResourceGroup(ctx, "rg3")
// Recover rg3 by transferring 10 nodes to rg2 at high priority and 5 to rg1.
suite.Equal(20, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Equal(40, suite.manager.GetResourceGroup("rg2").NodeNum())
suite.Equal(0, suite.manager.GetResourceGroup("rg3").NodeNum())
suite.Equal(40, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.Equal(20, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Equal(40, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum())
suite.Equal(0, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum())
suite.Equal(40, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
suite.testTransferNode()
// Test that redundant nodes are recovered to the default resource group.
suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{
suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{
DefaultResourceGroupName: newResourceGroupConfig(1, 1),
"rg3": newResourceGroupConfig(0, 0),
"rg2": newResourceGroupConfig(0, 0),
@ -485,107 +493,109 @@ func (suite *ResourceManagerSuite) TestAutoRecover() {
})
// Even though the default resource group has a 1-node limit,
// all redundant nodes will be assigned to the default resource group if no other resource group can hold them.
suite.manager.AutoRecoverResourceGroup(DefaultResourceGroupName)
suite.manager.AutoRecoverResourceGroup("rg1")
suite.manager.AutoRecoverResourceGroup("rg2")
suite.manager.AutoRecoverResourceGroup("rg3")
suite.Equal(0, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Equal(0, suite.manager.GetResourceGroup("rg2").NodeNum())
suite.Equal(0, suite.manager.GetResourceGroup("rg3").NodeNum())
suite.Equal(100, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.manager.AutoRecoverResourceGroup(ctx, DefaultResourceGroupName)
suite.manager.AutoRecoverResourceGroup(ctx, "rg1")
suite.manager.AutoRecoverResourceGroup(ctx, "rg2")
suite.manager.AutoRecoverResourceGroup(ctx, "rg3")
suite.Equal(0, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Equal(0, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum())
suite.Equal(0, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum())
suite.Equal(100, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
// Test recovering redundant nodes into groups that are missing nodes, and recovering missing nodes from groups with redundant ones.
// Initialize
suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{
suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{
DefaultResourceGroupName: newResourceGroupConfig(0, 0),
"rg3": newResourceGroupConfig(10, 10),
"rg2": newResourceGroupConfig(80, 80),
"rg1": newResourceGroupConfig(10, 10),
})
suite.manager.AutoRecoverResourceGroup(DefaultResourceGroupName)
suite.manager.AutoRecoverResourceGroup("rg1")
suite.manager.AutoRecoverResourceGroup("rg2")
suite.manager.AutoRecoverResourceGroup("rg3")
suite.Equal(10, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Equal(80, suite.manager.GetResourceGroup("rg2").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup("rg3").NodeNum())
suite.Equal(0, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.manager.AutoRecoverResourceGroup(ctx, DefaultResourceGroupName)
suite.manager.AutoRecoverResourceGroup(ctx, "rg1")
suite.manager.AutoRecoverResourceGroup(ctx, "rg2")
suite.manager.AutoRecoverResourceGroup(ctx, "rg3")
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Equal(80, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum())
suite.Equal(0, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{
suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{
DefaultResourceGroupName: newResourceGroupConfig(0, 5),
"rg3": newResourceGroupConfig(5, 5),
"rg2": newResourceGroupConfig(80, 80),
"rg1": newResourceGroupConfig(20, 30),
})
suite.manager.AutoRecoverResourceGroup("rg3") // recover redundant to missing rg.
suite.Equal(15, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Equal(80, suite.manager.GetResourceGroup("rg2").NodeNum())
suite.Equal(5, suite.manager.GetResourceGroup("rg3").NodeNum())
suite.Equal(0, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.manager.updateResourceGroups(map[string]*rgpb.ResourceGroupConfig{
suite.manager.AutoRecoverResourceGroup(ctx, "rg3") // recover redundant to missing rg.
suite.Equal(15, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Equal(80, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum())
suite.Equal(5, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum())
suite.Equal(0, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
suite.manager.updateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{
DefaultResourceGroupName: newResourceGroupConfig(5, 5),
"rg3": newResourceGroupConfig(5, 10),
"rg2": newResourceGroupConfig(80, 80),
"rg1": newResourceGroupConfig(10, 10),
})
suite.manager.AutoRecoverResourceGroup(DefaultResourceGroupName) // recover missing from redundant rg.
suite.Equal(10, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Equal(80, suite.manager.GetResourceGroup("rg2").NodeNum())
suite.Equal(5, suite.manager.GetResourceGroup("rg3").NodeNum())
suite.Equal(5, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.manager.AutoRecoverResourceGroup(ctx, DefaultResourceGroupName) // recover missing from redundant rg.
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Equal(80, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum())
suite.Equal(5, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum())
suite.Equal(5, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
}
func (suite *ResourceManagerSuite) testTransferNode() {
ctx := suite.ctx
// Test that redundant nodes are recovered to the default resource group.
suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{
suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{
DefaultResourceGroupName: newResourceGroupConfig(40, 40),
"rg3": newResourceGroupConfig(0, 0),
"rg2": newResourceGroupConfig(40, 40),
"rg1": newResourceGroupConfig(20, 20),
})
suite.manager.AutoRecoverResourceGroup("rg1")
suite.manager.AutoRecoverResourceGroup("rg2")
suite.manager.AutoRecoverResourceGroup(DefaultResourceGroupName)
suite.manager.AutoRecoverResourceGroup("rg3")
suite.manager.AutoRecoverResourceGroup(ctx, "rg1")
suite.manager.AutoRecoverResourceGroup(ctx, "rg2")
suite.manager.AutoRecoverResourceGroup(ctx, DefaultResourceGroupName)
suite.manager.AutoRecoverResourceGroup(ctx, "rg3")
suite.Equal(20, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Equal(40, suite.manager.GetResourceGroup("rg2").NodeNum())
suite.Equal(0, suite.manager.GetResourceGroup("rg3").NodeNum())
suite.Equal(40, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.Equal(20, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Equal(40, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum())
suite.Equal(0, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum())
suite.Equal(40, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
// Test TransferNode.
// param error.
err := suite.manager.TransferNode("rg1", "rg1", 1)
err := suite.manager.TransferNode(ctx, "rg1", "rg1", 1)
suite.Error(err)
err = suite.manager.TransferNode("rg1", "rg2", 0)
err = suite.manager.TransferNode(ctx, "rg1", "rg2", 0)
suite.Error(err)
err = suite.manager.TransferNode("rg3", "rg2", 1)
err = suite.manager.TransferNode(ctx, "rg3", "rg2", 1)
suite.Error(err)
err = suite.manager.TransferNode("rg1", "rg10086", 1)
err = suite.manager.TransferNode(ctx, "rg1", "rg10086", 1)
suite.Error(err)
err = suite.manager.TransferNode("rg10086", "rg2", 1)
err = suite.manager.TransferNode(ctx, "rg10086", "rg2", 1)
suite.Error(err)
// success
err = suite.manager.TransferNode("rg1", "rg3", 5)
err = suite.manager.TransferNode(ctx, "rg1", "rg3", 5)
suite.NoError(err)
suite.manager.AutoRecoverResourceGroup("rg1")
suite.manager.AutoRecoverResourceGroup("rg2")
suite.manager.AutoRecoverResourceGroup(DefaultResourceGroupName)
suite.manager.AutoRecoverResourceGroup("rg3")
suite.manager.AutoRecoverResourceGroup(ctx, "rg1")
suite.manager.AutoRecoverResourceGroup(ctx, "rg2")
suite.manager.AutoRecoverResourceGroup(ctx, DefaultResourceGroupName)
suite.manager.AutoRecoverResourceGroup(ctx, "rg3")
suite.Equal(15, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Equal(40, suite.manager.GetResourceGroup("rg2").NodeNum())
suite.Equal(5, suite.manager.GetResourceGroup("rg3").NodeNum())
suite.Equal(40, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.Equal(15, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Equal(40, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum())
suite.Equal(5, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum())
suite.Equal(40, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
}
func (suite *ResourceManagerSuite) TestIncomingNode() {
ctx := suite.ctx
suite.manager.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1,
Address: "localhost",
@ -593,15 +603,16 @@ func (suite *ResourceManagerSuite) TestIncomingNode() {
}))
suite.manager.incomingNode.Insert(1)
suite.Equal(1, suite.manager.CheckIncomingNodeNum())
suite.manager.AssignPendingIncomingNode()
suite.Equal(0, suite.manager.CheckIncomingNodeNum())
nodes, err := suite.manager.GetNodes(DefaultResourceGroupName)
suite.Equal(1, suite.manager.CheckIncomingNodeNum(ctx))
suite.manager.AssignPendingIncomingNode(ctx)
suite.Equal(0, suite.manager.CheckIncomingNodeNum(ctx))
nodes, err := suite.manager.GetNodes(ctx, DefaultResourceGroupName)
suite.NoError(err)
suite.Len(nodes, 1)
}
func (suite *ResourceManagerSuite) TestUnassignFail() {
ctx := suite.ctx
// suite.man
mockKV := mocks.NewMetaKv(suite.T())
mockKV.EXPECT().MultiSave(mock.Anything).Return(nil).Once()
@ -609,7 +620,7 @@ func (suite *ResourceManagerSuite) TestUnassignFail() {
store := querycoord.NewCatalog(mockKV)
suite.manager = NewResourceManager(store, session.NewNodeManager())
suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{
suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{
"rg1": newResourceGroupConfig(20, 30),
})
@ -618,16 +629,17 @@ func (suite *ResourceManagerSuite) TestUnassignFail() {
Address: "localhost",
Hostname: "localhost",
}))
suite.manager.HandleNodeUp(1)
suite.manager.HandleNodeUp(ctx, 1)
mockKV.EXPECT().MultiSave(mock.Anything).Return(merr.WrapErrServiceInternal("mocked")).Once()
suite.Panics(func() {
suite.manager.HandleNodeDown(1)
suite.manager.HandleNodeDown(ctx, 1)
})
}
func TestGetResourceGroupsJSON(t *testing.T) {
ctx := context.Background()
nodeManager := session.NewNodeManager()
manager := &ResourceManager{groups: make(map[string]*ResourceGroup)}
rg1 := NewResourceGroup("rg1", newResourceGroupConfig(0, 10), nodeManager)
@ -637,7 +649,7 @@ func TestGetResourceGroupsJSON(t *testing.T) {
manager.groups["rg1"] = rg1
manager.groups["rg2"] = rg2
jsonOutput := manager.GetResourceGroupsJSON()
jsonOutput := manager.GetResourceGroupsJSON(ctx)
var resourceGroups []*metricsinfo.ResourceGroup
err := json.Unmarshal([]byte(jsonOutput), &resourceGroups)
assert.NoError(t, err)
@ -659,7 +671,8 @@ func TestGetResourceGroupsJSON(t *testing.T) {
}
func (suite *ResourceManagerSuite) TestNodeLabels_NodeAssign() {
suite.manager.AddResourceGroup("rg1", &rgpb.ResourceGroupConfig{
ctx := suite.ctx
suite.manager.AddResourceGroup(ctx, "rg1", &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{
NodeNum: 10,
},
@ -676,7 +689,7 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeAssign() {
},
})
suite.manager.AddResourceGroup("rg2", &rgpb.ResourceGroupConfig{
suite.manager.AddResourceGroup(ctx, "rg2", &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{
NodeNum: 10,
},
@ -693,7 +706,7 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeAssign() {
},
})
suite.manager.AddResourceGroup("rg3", &rgpb.ResourceGroupConfig{
suite.manager.AddResourceGroup(ctx, "rg3", &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{
NodeNum: 10,
},
@ -720,12 +733,12 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeAssign() {
"dc_name": "label1",
},
}))
suite.manager.HandleNodeUp(int64(i))
suite.manager.HandleNodeUp(ctx, int64(i))
}
suite.Equal(10, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Equal(0, suite.manager.GetResourceGroup("rg2").NodeNum())
suite.Equal(0, suite.manager.GetResourceGroup("rg3").NodeNum())
suite.Equal(20, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Equal(0, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum())
suite.Equal(0, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum())
suite.Equal(20, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
// test new querynode with label2
for i := 31; i <= 40; i++ {
@ -737,13 +750,13 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeAssign() {
"dc_name": "label2",
},
}))
suite.manager.HandleNodeUp(int64(i))
suite.manager.HandleNodeUp(ctx, int64(i))
}
suite.Equal(10, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup("rg2").NodeNum())
suite.Equal(0, suite.manager.GetResourceGroup("rg3").NodeNum())
suite.Equal(20, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
nodesInRG, _ := suite.manager.GetNodes("rg2")
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum())
suite.Equal(0, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum())
suite.Equal(20, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
nodesInRG, _ := suite.manager.GetNodes(ctx, "rg2")
for _, node := range nodesInRG {
suite.Equal("label2", suite.manager.nodeMgr.Get(node).Labels()["dc_name"])
}
@ -758,19 +771,19 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeAssign() {
"dc_name": "label3",
},
}))
suite.manager.HandleNodeUp(int64(i))
suite.manager.HandleNodeUp(ctx, int64(i))
}
suite.Equal(10, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup("rg2").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup("rg3").NodeNum())
suite.Equal(20, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
nodesInRG, _ = suite.manager.GetNodes("rg3")
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum())
suite.Equal(20, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
nodesInRG, _ = suite.manager.GetNodes(ctx, "rg3")
for _, node := range nodesInRG {
suite.Equal("label3", suite.manager.nodeMgr.Get(node).Labels()["dc_name"])
}
// test swap rg's label
suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{
suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{
"rg1": {
Requests: &rgpb.ResourceGroupLimit{
NodeNum: 10,
@ -823,33 +836,34 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeAssign() {
log.Info("test swap rg's label")
for i := 0; i < 4; i++ {
suite.manager.AutoRecoverResourceGroup("rg1")
suite.manager.AutoRecoverResourceGroup("rg2")
suite.manager.AutoRecoverResourceGroup("rg3")
suite.manager.AutoRecoverResourceGroup(DefaultResourceGroupName)
suite.manager.AutoRecoverResourceGroup(ctx, "rg1")
suite.manager.AutoRecoverResourceGroup(ctx, "rg2")
suite.manager.AutoRecoverResourceGroup(ctx, "rg3")
suite.manager.AutoRecoverResourceGroup(ctx, DefaultResourceGroupName)
}
suite.Equal(10, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup("rg2").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup("rg3").NodeNum())
suite.Equal(20, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
nodesInRG, _ = suite.manager.GetNodes("rg1")
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum())
suite.Equal(20, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
nodesInRG, _ = suite.manager.GetNodes(ctx, "rg1")
for _, node := range nodesInRG {
suite.Equal("label2", suite.manager.nodeMgr.Get(node).Labels()["dc_name"])
}
nodesInRG, _ = suite.manager.GetNodes("rg2")
nodesInRG, _ = suite.manager.GetNodes(ctx, "rg2")
for _, node := range nodesInRG {
suite.Equal("label3", suite.manager.nodeMgr.Get(node).Labels()["dc_name"])
}
nodesInRG, _ = suite.manager.GetNodes("rg3")
nodesInRG, _ = suite.manager.GetNodes(ctx, "rg3")
for _, node := range nodesInRG {
suite.Equal("label1", suite.manager.nodeMgr.Get(node).Labels()["dc_name"])
}
}
func (suite *ResourceManagerSuite) TestNodeLabels_NodeDown() {
suite.manager.AddResourceGroup("rg1", &rgpb.ResourceGroupConfig{
ctx := suite.ctx
suite.manager.AddResourceGroup(ctx, "rg1", &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{
NodeNum: 10,
},
@ -866,7 +880,7 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeDown() {
},
})
suite.manager.AddResourceGroup("rg2", &rgpb.ResourceGroupConfig{
suite.manager.AddResourceGroup(ctx, "rg2", &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{
NodeNum: 10,
},
@ -883,7 +897,7 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeDown() {
},
})
suite.manager.AddResourceGroup("rg3", &rgpb.ResourceGroupConfig{
suite.manager.AddResourceGroup(ctx, "rg3", &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{
NodeNum: 10,
},
@ -910,7 +924,7 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeDown() {
"dc_name": "label1",
},
}))
suite.manager.HandleNodeUp(int64(i))
suite.manager.HandleNodeUp(ctx, int64(i))
}
// test new querynode with label2
@ -923,7 +937,7 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeDown() {
"dc_name": "label2",
},
}))
suite.manager.HandleNodeUp(int64(i))
suite.manager.HandleNodeUp(ctx, int64(i))
}
// test new querynode with label3
for i := 41; i <= 50; i++ {
@ -935,18 +949,18 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeDown() {
"dc_name": "label3",
},
}))
suite.manager.HandleNodeUp(int64(i))
suite.manager.HandleNodeUp(ctx, int64(i))
}
suite.Equal(10, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup("rg2").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup("rg3").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum())
// test node down with label1
suite.manager.HandleNodeDown(int64(1))
suite.manager.HandleNodeDown(ctx, int64(1))
suite.manager.nodeMgr.Remove(int64(1))
suite.Equal(9, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup("rg2").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup("rg3").NodeNum())
suite.Equal(9, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum())
// test node up with label2
suite.manager.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
@ -957,11 +971,11 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeDown() {
"dc_name": "label2",
},
}))
suite.manager.HandleNodeUp(int64(101))
suite.Equal(9, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup("rg2").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup("rg3").NodeNum())
suite.Equal(1, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
suite.manager.HandleNodeUp(ctx, int64(101))
suite.Equal(9, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum())
suite.Equal(1, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
// test node up with label1
suite.manager.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
@ -972,21 +986,21 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeDown() {
"dc_name": "label1",
},
}))
suite.manager.HandleNodeUp(int64(102))
suite.Equal(10, suite.manager.GetResourceGroup("rg1").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup("rg2").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup("rg3").NodeNum())
suite.Equal(1, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum())
nodesInRG, _ := suite.manager.GetNodes("rg1")
suite.manager.HandleNodeUp(ctx, int64(102))
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum())
suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum())
suite.Equal(1, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum())
nodesInRG, _ := suite.manager.GetNodes(ctx, "rg1")
for _, node := range nodesInRG {
suite.Equal("label1", suite.manager.nodeMgr.Get(node).Labels()["dc_name"])
}
suite.manager.AutoRecoverResourceGroup("rg1")
suite.manager.AutoRecoverResourceGroup("rg2")
suite.manager.AutoRecoverResourceGroup("rg3")
suite.manager.AutoRecoverResourceGroup(DefaultResourceGroupName)
nodesInRG, _ = suite.manager.GetNodes(DefaultResourceGroupName)
suite.manager.AutoRecoverResourceGroup(ctx, "rg1")
suite.manager.AutoRecoverResourceGroup(ctx, "rg2")
suite.manager.AutoRecoverResourceGroup(ctx, "rg3")
suite.manager.AutoRecoverResourceGroup(ctx, DefaultResourceGroupName)
nodesInRG, _ = suite.manager.GetNodes(ctx, DefaultResourceGroupName)
for _, node := range nodesInRG {
suite.Equal("label2", suite.manager.nodeMgr.Get(node).Labels()["dc_name"])
}

View File

@ -50,26 +50,26 @@ const (
)
type TargetManagerInterface interface {
UpdateCollectionCurrentTarget(collectionID int64) bool
UpdateCollectionNextTarget(collectionID int64) error
RemoveCollection(collectionID int64)
RemovePartition(collectionID int64, partitionIDs ...int64)
GetGrowingSegmentsByCollection(collectionID int64, scope TargetScope) typeutil.UniqueSet
GetGrowingSegmentsByChannel(collectionID int64, channelName string, scope TargetScope) typeutil.UniqueSet
GetSealedSegmentsByCollection(collectionID int64, scope TargetScope) map[int64]*datapb.SegmentInfo
GetSealedSegmentsByChannel(collectionID int64, channelName string, scope TargetScope) map[int64]*datapb.SegmentInfo
GetDroppedSegmentsByChannel(collectionID int64, channelName string, scope TargetScope) []int64
GetSealedSegmentsByPartition(collectionID int64, partitionID int64, scope TargetScope) map[int64]*datapb.SegmentInfo
GetDmChannelsByCollection(collectionID int64, scope TargetScope) map[string]*DmChannel
GetDmChannel(collectionID int64, channel string, scope TargetScope) *DmChannel
GetSealedSegment(collectionID int64, id int64, scope TargetScope) *datapb.SegmentInfo
GetCollectionTargetVersion(collectionID int64, scope TargetScope) int64
IsCurrentTargetExist(collectionID int64, partitionID int64) bool
IsNextTargetExist(collectionID int64) bool
SaveCurrentTarget(catalog metastore.QueryCoordCatalog)
Recover(catalog metastore.QueryCoordCatalog) error
CanSegmentBeMoved(collectionID, segmentID int64) bool
GetTargetJSON(scope TargetScope) string
UpdateCollectionCurrentTarget(ctx context.Context, collectionID int64) bool
UpdateCollectionNextTarget(ctx context.Context, collectionID int64) error
RemoveCollection(ctx context.Context, collectionID int64)
RemovePartition(ctx context.Context, collectionID int64, partitionIDs ...int64)
GetGrowingSegmentsByCollection(ctx context.Context, collectionID int64, scope TargetScope) typeutil.UniqueSet
GetGrowingSegmentsByChannel(ctx context.Context, collectionID int64, channelName string, scope TargetScope) typeutil.UniqueSet
GetSealedSegmentsByCollection(ctx context.Context, collectionID int64, scope TargetScope) map[int64]*datapb.SegmentInfo
GetSealedSegmentsByChannel(ctx context.Context, collectionID int64, channelName string, scope TargetScope) map[int64]*datapb.SegmentInfo
GetDroppedSegmentsByChannel(ctx context.Context, collectionID int64, channelName string, scope TargetScope) []int64
GetSealedSegmentsByPartition(ctx context.Context, collectionID int64, partitionID int64, scope TargetScope) map[int64]*datapb.SegmentInfo
GetDmChannelsByCollection(ctx context.Context, collectionID int64, scope TargetScope) map[string]*DmChannel
GetDmChannel(ctx context.Context, collectionID int64, channel string, scope TargetScope) *DmChannel
GetSealedSegment(ctx context.Context, collectionID int64, id int64, scope TargetScope) *datapb.SegmentInfo
GetCollectionTargetVersion(ctx context.Context, collectionID int64, scope TargetScope) int64
IsCurrentTargetExist(ctx context.Context, collectionID int64, partitionID int64) bool
IsNextTargetExist(ctx context.Context, collectionID int64) bool
SaveCurrentTarget(ctx context.Context, catalog metastore.QueryCoordCatalog)
Recover(ctx context.Context, catalog metastore.QueryCoordCatalog) error
CanSegmentBeMoved(ctx context.Context, collectionID, segmentID int64) bool
GetTargetJSON(ctx context.Context, scope TargetScope) string
}
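As a hedged sketch of how a consumer would use the ctx-aware interface above, the helper below refreshes one collection's target; the function name and flow are illustrative assumptions, not taken from this diff, and it relies only on methods declared in TargetManagerInterface.

// Illustrative helper (not from this commit): pull a fresh next target and promote
// it once it exists, threading the caller's ctx through every TargetManager call.
func refreshTarget(ctx context.Context, mgr TargetManagerInterface, collectionID int64) error {
	if err := mgr.UpdateCollectionNextTarget(ctx, collectionID); err != nil {
		return err
	}
	if mgr.IsNextTargetExist(ctx, collectionID) {
		// In production the target observer owns this promotion; see the WARN comments below.
		mgr.UpdateCollectionCurrentTarget(ctx, collectionID)
	}
	return nil
}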
type TargetManager struct {
@ -96,7 +96,7 @@ func NewTargetManager(broker Broker, meta *Meta) *TargetManager {
// UpdateCollectionCurrentTarget promotes the next target to the current target.
// WARN: DO NOT call this method for an existing collection while the target observer is running, or it will lead to a double update,
// which may make the current target unavailable.
func (mgr *TargetManager) UpdateCollectionCurrentTarget(collectionID int64) bool {
func (mgr *TargetManager) UpdateCollectionCurrentTarget(ctx context.Context, collectionID int64) bool {
mgr.rwMutex.Lock()
defer mgr.rwMutex.Unlock()
log := log.With(zap.Int64("collectionID", collectionID))
@ -137,7 +137,7 @@ func (mgr *TargetManager) UpdateCollectionCurrentTarget(collectionID int64) bool
// UpdateCollectionNextTarget updates the next target with a new target pulled from DataCoord.
// WARN: DO NOT call this method for an existing collection while the target observer is running, or it will lead to a double update,
// which may make the current target unavailable.
func (mgr *TargetManager) UpdateCollectionNextTarget(collectionID int64) error {
func (mgr *TargetManager) UpdateCollectionNextTarget(ctx context.Context, collectionID int64) error {
var vChannelInfos []*datapb.VchannelInfo
var segmentInfos []*datapb.SegmentInfo
err := retry.Handle(context.TODO(), func() (bool, error) {
@ -155,7 +155,7 @@ func (mgr *TargetManager) UpdateCollectionNextTarget(collectionID int64) error {
mgr.rwMutex.Lock()
defer mgr.rwMutex.Unlock()
partitions := mgr.meta.GetPartitionsByCollection(collectionID)
partitions := mgr.meta.GetPartitionsByCollection(ctx, collectionID)
partitionIDs := lo.Map(partitions, func(partition *Partition, i int) int64 {
return partition.PartitionID
})
@ -223,7 +223,7 @@ func (mgr *TargetManager) mergeDmChannelInfo(infos []*datapb.VchannelInfo) *DmCh
}
// RemoveCollection removes all channels and segments in the given collection
func (mgr *TargetManager) RemoveCollection(collectionID int64) {
func (mgr *TargetManager) RemoveCollection(ctx context.Context, collectionID int64) {
mgr.rwMutex.Lock()
defer mgr.rwMutex.Unlock()
log.Info("remove collection from targets",
@ -245,7 +245,7 @@ func (mgr *TargetManager) RemoveCollection(collectionID int64) {
// RemovePartition removes all segments in the given partitions.
// NOTE: this doesn't remove any channel, even if the given partition is the only one.
func (mgr *TargetManager) RemovePartition(collectionID int64, partitionIDs ...int64) {
func (mgr *TargetManager) RemovePartition(ctx context.Context, collectionID int64, partitionIDs ...int64) {
mgr.rwMutex.Lock()
defer mgr.rwMutex.Unlock()
@ -352,7 +352,7 @@ func (mgr *TargetManager) getCollectionTarget(scope TargetScope, collectionID in
return nil
}
func (mgr *TargetManager) GetGrowingSegmentsByCollection(collectionID int64,
func (mgr *TargetManager) GetGrowingSegmentsByCollection(ctx context.Context, collectionID int64,
scope TargetScope,
) typeutil.UniqueSet {
mgr.rwMutex.RLock()
@ -374,7 +374,7 @@ func (mgr *TargetManager) GetGrowingSegmentsByCollection(collectionID int64,
return nil
}
func (mgr *TargetManager) GetGrowingSegmentsByChannel(collectionID int64,
func (mgr *TargetManager) GetGrowingSegmentsByChannel(ctx context.Context, collectionID int64,
channelName string,
scope TargetScope,
) typeutil.UniqueSet {
@ -398,7 +398,7 @@ func (mgr *TargetManager) GetGrowingSegmentsByChannel(collectionID int64,
return nil
}
func (mgr *TargetManager) GetSealedSegmentsByCollection(collectionID int64,
func (mgr *TargetManager) GetSealedSegmentsByCollection(ctx context.Context, collectionID int64,
scope TargetScope,
) map[int64]*datapb.SegmentInfo {
mgr.rwMutex.RLock()
@ -413,7 +413,7 @@ func (mgr *TargetManager) GetSealedSegmentsByCollection(collectionID int64,
return nil
}
func (mgr *TargetManager) GetSealedSegmentsByChannel(collectionID int64,
func (mgr *TargetManager) GetSealedSegmentsByChannel(ctx context.Context, collectionID int64,
channelName string,
scope TargetScope,
) map[int64]*datapb.SegmentInfo {
@ -437,7 +437,7 @@ func (mgr *TargetManager) GetSealedSegmentsByChannel(collectionID int64,
return nil
}
func (mgr *TargetManager) GetDroppedSegmentsByChannel(collectionID int64,
func (mgr *TargetManager) GetDroppedSegmentsByChannel(ctx context.Context, collectionID int64,
channelName string,
scope TargetScope,
) []int64 {
@ -454,7 +454,7 @@ func (mgr *TargetManager) GetDroppedSegmentsByChannel(collectionID int64,
return nil
}
func (mgr *TargetManager) GetSealedSegmentsByPartition(collectionID int64,
func (mgr *TargetManager) GetSealedSegmentsByPartition(ctx context.Context, collectionID int64,
partitionID int64,
scope TargetScope,
) map[int64]*datapb.SegmentInfo {
@ -478,7 +478,7 @@ func (mgr *TargetManager) GetSealedSegmentsByPartition(collectionID int64,
return nil
}
func (mgr *TargetManager) GetDmChannelsByCollection(collectionID int64, scope TargetScope) map[string]*DmChannel {
func (mgr *TargetManager) GetDmChannelsByCollection(ctx context.Context, collectionID int64, scope TargetScope) map[string]*DmChannel {
mgr.rwMutex.RLock()
defer mgr.rwMutex.RUnlock()
@ -491,7 +491,7 @@ func (mgr *TargetManager) GetDmChannelsByCollection(collectionID int64, scope Ta
return nil
}
func (mgr *TargetManager) GetDmChannel(collectionID int64, channel string, scope TargetScope) *DmChannel {
func (mgr *TargetManager) GetDmChannel(ctx context.Context, collectionID int64, channel string, scope TargetScope) *DmChannel {
mgr.rwMutex.RLock()
defer mgr.rwMutex.RUnlock()
@ -504,7 +504,7 @@ func (mgr *TargetManager) GetDmChannel(collectionID int64, channel string, scope
return nil
}
func (mgr *TargetManager) GetSealedSegment(collectionID int64, id int64, scope TargetScope) *datapb.SegmentInfo {
func (mgr *TargetManager) GetSealedSegment(ctx context.Context, collectionID int64, id int64, scope TargetScope) *datapb.SegmentInfo {
mgr.rwMutex.RLock()
defer mgr.rwMutex.RUnlock()
@ -518,7 +518,7 @@ func (mgr *TargetManager) GetSealedSegment(collectionID int64, id int64, scope T
return nil
}
func (mgr *TargetManager) GetCollectionTargetVersion(collectionID int64, scope TargetScope) int64 {
func (mgr *TargetManager) GetCollectionTargetVersion(ctx context.Context, collectionID int64, scope TargetScope) int64 {
mgr.rwMutex.RLock()
defer mgr.rwMutex.RUnlock()
@ -532,7 +532,7 @@ func (mgr *TargetManager) GetCollectionTargetVersion(collectionID int64, scope T
return 0
}
func (mgr *TargetManager) IsCurrentTargetExist(collectionID int64, partitionID int64) bool {
func (mgr *TargetManager) IsCurrentTargetExist(ctx context.Context, collectionID int64, partitionID int64) bool {
mgr.rwMutex.RLock()
defer mgr.rwMutex.RUnlock()
@ -541,13 +541,13 @@ func (mgr *TargetManager) IsCurrentTargetExist(collectionID int64, partitionID i
return len(targets) > 0 && (targets[0].partitions.Contain(partitionID) || partitionID == common.AllPartitionsID) && len(targets[0].dmChannels) > 0
}
func (mgr *TargetManager) IsNextTargetExist(collectionID int64) bool {
newChannels := mgr.GetDmChannelsByCollection(collectionID, NextTarget)
func (mgr *TargetManager) IsNextTargetExist(ctx context.Context, collectionID int64) bool {
newChannels := mgr.GetDmChannelsByCollection(ctx, collectionID, NextTarget)
return len(newChannels) > 0
}
func (mgr *TargetManager) SaveCurrentTarget(catalog metastore.QueryCoordCatalog) {
func (mgr *TargetManager) SaveCurrentTarget(ctx context.Context, catalog metastore.QueryCoordCatalog) {
mgr.rwMutex.Lock()
defer mgr.rwMutex.Unlock()
if mgr.current != nil {
@ -562,7 +562,7 @@ func (mgr *TargetManager) SaveCurrentTarget(catalog metastore.QueryCoordCatalog)
pool.Submit(func() (any, error) {
defer wg.Done()
ids := lo.Map(tasks, func(p typeutil.Pair[int64, *querypb.CollectionTarget], _ int) int64 { return p.A })
if err := catalog.SaveCollectionTargets(lo.Map(tasks, func(p typeutil.Pair[int64, *querypb.CollectionTarget], _ int) *querypb.CollectionTarget {
if err := catalog.SaveCollectionTargets(ctx, lo.Map(tasks, func(p typeutil.Pair[int64, *querypb.CollectionTarget], _ int) *querypb.CollectionTarget {
return p.B
})...); err != nil {
log.Warn("failed to save current target for collection", zap.Int64s("collectionIDs", ids), zap.Error(err))
@ -587,11 +587,11 @@ func (mgr *TargetManager) SaveCurrentTarget(catalog metastore.QueryCoordCatalog)
}
}
func (mgr *TargetManager) Recover(catalog metastore.QueryCoordCatalog) error {
func (mgr *TargetManager) Recover(ctx context.Context, catalog metastore.QueryCoordCatalog) error {
mgr.rwMutex.Lock()
defer mgr.rwMutex.Unlock()
targets, err := catalog.GetCollectionTargets()
targets, err := catalog.GetCollectionTargets(ctx)
if err != nil {
log.Warn("failed to recover collection target from etcd", zap.Error(err))
return err
@ -608,7 +608,7 @@ func (mgr *TargetManager) Recover(catalog metastore.QueryCoordCatalog) error {
)
// clear target info in meta store
err := catalog.RemoveCollectionTarget(t.GetCollectionID())
err := catalog.RemoveCollectionTarget(ctx, t.GetCollectionID())
if err != nil {
log.Warn("failed to clear collection target from etcd", zap.Error(err))
}
@ -618,7 +618,7 @@ func (mgr *TargetManager) Recover(catalog metastore.QueryCoordCatalog) error {
}
// if the segment isn't an L0 segment and exists in the current/next target, then it can be moved
func (mgr *TargetManager) CanSegmentBeMoved(collectionID, segmentID int64) bool {
func (mgr *TargetManager) CanSegmentBeMoved(ctx context.Context, collectionID, segmentID int64) bool {
mgr.rwMutex.Lock()
defer mgr.rwMutex.Unlock()
current := mgr.current.getCollectionTarget(collectionID)
@ -634,7 +634,7 @@ func (mgr *TargetManager) CanSegmentBeMoved(collectionID, segmentID int64) bool
return false
}
func (mgr *TargetManager) GetTargetJSON(scope TargetScope) string {
func (mgr *TargetManager) GetTargetJSON(ctx context.Context, scope TargetScope) string {
mgr.rwMutex.RLock()
defer mgr.rwMutex.RUnlock()

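The TargetManager changes above all follow one mechanical rule: context.Context becomes the first parameter and is forwarded to the catalog instead of being created inside the method. A minimal, self-contained sketch of that idiom — the catalog/manager names below are placeholders, not the actual querycoordv2 types:

```go
package main

import (
	"context"
	"fmt"
)

// catalog stands in for metastore.QueryCoordCatalog here; the real interface
// has many more methods, all of which now take ctx as their first parameter.
type catalog interface {
	SaveTargets(ctx context.Context, collectionIDs ...int64) error
}

// memCatalog is a trivial in-memory stand-in for the etcd-backed catalog.
type memCatalog struct{}

func (memCatalog) SaveTargets(ctx context.Context, collectionIDs ...int64) error {
	// A real implementation would hand ctx to the underlying KV client so
	// cancellation and tracing propagate all the way to the metastore.
	if err := ctx.Err(); err != nil {
		return err
	}
	fmt.Println("saved targets for", collectionIDs)
	return nil
}

// manager mirrors the TargetManager shape after the refactor:
// ctx first, then the business arguments.
type manager struct{ cat catalog }

func (m *manager) SaveCurrentTarget(ctx context.Context, collectionIDs ...int64) error {
	// Forward the caller's ctx instead of creating context.Background() here,
	// so a cancelled request also stops the metastore write.
	return m.cat.SaveTargets(ctx, collectionIDs...)
}

func main() {
	m := &manager{cat: memCatalog{}}
	if err := m.SaveCurrentTarget(context.Background(), 1000, 1001); err != nil {
		fmt.Println("save failed:", err)
	}
}
```

The payoff is that a caller's cancellation or deadline now reaches the metastore write, which the old ctx-less signatures could not express.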
View File

@ -17,6 +17,7 @@
package meta
import (
"context"
"testing"
"time"
@ -60,6 +61,8 @@ type TargetManagerSuite struct {
broker *MockBroker
// Test object
mgr *TargetManager
ctx context.Context
}
func (suite *TargetManagerSuite) SetupSuite() {
@ -110,6 +113,7 @@ func (suite *TargetManagerSuite) SetupTest() {
config.EtcdTLSMinVersion.GetValue())
suite.Require().NoError(err)
suite.kv = etcdkv.NewEtcdKV(cli, config.MetaRootPath.GetValue())
suite.ctx = context.Background()
// meta
suite.catalog = querycoord.NewCatalog(suite.kv)
@ -141,14 +145,14 @@ func (suite *TargetManagerSuite) SetupTest() {
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collection).Return(dmChannels, allSegments, nil)
suite.meta.PutCollection(&Collection{
suite.meta.PutCollection(suite.ctx, &Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: collection,
ReplicaNumber: 1,
},
})
for _, partition := range suite.partitions[collection] {
suite.meta.PutPartition(&Partition{
suite.meta.PutPartition(suite.ctx, &Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: collection,
PartitionID: partition,
@ -156,7 +160,7 @@ func (suite *TargetManagerSuite) SetupTest() {
})
}
suite.mgr.UpdateCollectionNextTarget(collection)
suite.mgr.UpdateCollectionNextTarget(suite.ctx, collection)
}
}
@ -165,35 +169,37 @@ func (suite *TargetManagerSuite) TearDownSuite() {
}
func (suite *TargetManagerSuite) TestUpdateCurrentTarget() {
ctx := suite.ctx
collectionID := int64(1000)
suite.assertSegments(suite.getAllSegment(collectionID, suite.partitions[collectionID]),
suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget))
suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget))
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget))
suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, NextTarget))
suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(ctx, collectionID, NextTarget))
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, CurrentTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, CurrentTarget))
suite.mgr.UpdateCollectionCurrentTarget(collectionID)
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget))
suite.mgr.UpdateCollectionCurrentTarget(ctx, collectionID)
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, NextTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, NextTarget))
suite.assertSegments(suite.getAllSegment(collectionID, suite.partitions[collectionID]),
suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget))
suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget))
suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, CurrentTarget))
suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(ctx, collectionID, CurrentTarget))
}
func (suite *TargetManagerSuite) TestUpdateNextTarget() {
ctx := suite.ctx
collectionID := int64(1003)
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget))
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget))
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, NextTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, NextTarget))
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, CurrentTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, CurrentTarget))
suite.meta.PutCollection(&Collection{
suite.meta.PutCollection(ctx, &Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: collectionID,
ReplicaNumber: 1,
},
})
suite.meta.PutPartition(&Partition{
suite.meta.PutPartition(ctx, &Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: collectionID,
PartitionID: 1,
@ -225,62 +231,64 @@ func (suite *TargetManagerSuite) TestUpdateNextTarget() {
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(nextTargetChannels, nextTargetSegments, nil)
suite.mgr.UpdateCollectionNextTarget(collectionID)
suite.assertSegments([]int64{11, 12}, suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget))
suite.assertChannels([]string{"channel-1", "channel-2"}, suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget))
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget))
suite.mgr.UpdateCollectionNextTarget(ctx, collectionID)
suite.assertSegments([]int64{11, 12}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, NextTarget))
suite.assertChannels([]string{"channel-1", "channel-2"}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, NextTarget))
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, CurrentTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, CurrentTarget))
suite.broker.ExpectedCalls = nil
// test getRecoveryInfoV2 failing once, then the retried getRecoveryInfoV2 succeeding
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(nil, nil, errors.New("fake error")).Times(1)
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(nextTargetChannels, nextTargetSegments, nil)
err := suite.mgr.UpdateCollectionNextTarget(collectionID)
err := suite.mgr.UpdateCollectionNextTarget(ctx, collectionID)
suite.NoError(err)
err = suite.mgr.UpdateCollectionNextTarget(collectionID)
err = suite.mgr.UpdateCollectionNextTarget(ctx, collectionID)
suite.NoError(err)
}
func (suite *TargetManagerSuite) TestRemovePartition() {
ctx := suite.ctx
collectionID := int64(1000)
suite.assertSegments(suite.getAllSegment(collectionID, suite.partitions[collectionID]), suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget))
suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget))
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget))
suite.assertSegments(suite.getAllSegment(collectionID, suite.partitions[collectionID]), suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, NextTarget))
suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(ctx, collectionID, NextTarget))
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, CurrentTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, CurrentTarget))
suite.mgr.RemovePartition(collectionID, 100)
suite.assertSegments(append([]int64{3, 4}, suite.level0Segments...), suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget))
suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget))
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget))
suite.mgr.RemovePartition(ctx, collectionID, 100)
suite.assertSegments(append([]int64{3, 4}, suite.level0Segments...), suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, NextTarget))
suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(ctx, collectionID, NextTarget))
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, CurrentTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, CurrentTarget))
}
func (suite *TargetManagerSuite) TestRemoveCollection() {
ctx := suite.ctx
collectionID := int64(1000)
suite.assertSegments(suite.getAllSegment(collectionID, suite.partitions[collectionID]), suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget))
suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget))
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget))
suite.assertSegments(suite.getAllSegment(collectionID, suite.partitions[collectionID]), suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, NextTarget))
suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(ctx, collectionID, NextTarget))
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, CurrentTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, CurrentTarget))
suite.mgr.RemoveCollection(collectionID)
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget))
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget))
suite.mgr.RemoveCollection(ctx, collectionID)
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, NextTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, NextTarget))
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, CurrentTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, CurrentTarget))
collectionID = int64(1001)
suite.mgr.UpdateCollectionCurrentTarget(collectionID)
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget))
suite.assertSegments(suite.getAllSegment(collectionID, suite.partitions[collectionID]), suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget))
suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget))
suite.mgr.UpdateCollectionCurrentTarget(ctx, collectionID)
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, NextTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, NextTarget))
suite.assertSegments(suite.getAllSegment(collectionID, suite.partitions[collectionID]), suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, CurrentTarget))
suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(ctx, collectionID, CurrentTarget))
suite.mgr.RemoveCollection(collectionID)
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget))
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget))
suite.mgr.RemoveCollection(ctx, collectionID)
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, NextTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, NextTarget))
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, CurrentTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, CurrentTarget))
}
func (suite *TargetManagerSuite) getAllSegment(collectionID int64, partitionIDs []int64) []int64 {
@ -325,6 +333,7 @@ func (suite *TargetManagerSuite) assertSegments(expected []int64, actual map[int
}
func (suite *TargetManagerSuite) TestGetCollectionTargetVersion() {
ctx := suite.ctx
t1 := time.Now().UnixNano()
target := NewCollectionTarget(nil, nil, nil)
t2 := time.Now().UnixNano()
@ -335,28 +344,29 @@ func (suite *TargetManagerSuite) TestGetCollectionTargetVersion() {
collectionID := suite.collections[0]
t3 := time.Now().UnixNano()
suite.mgr.UpdateCollectionNextTarget(collectionID)
suite.mgr.UpdateCollectionNextTarget(ctx, collectionID)
t4 := time.Now().UnixNano()
collectionVersion := suite.mgr.GetCollectionTargetVersion(collectionID, NextTarget)
collectionVersion := suite.mgr.GetCollectionTargetVersion(ctx, collectionID, NextTarget)
suite.True(t3 <= collectionVersion)
suite.True(t4 >= collectionVersion)
}
func (suite *TargetManagerSuite) TestGetSegmentByChannel() {
ctx := suite.ctx
collectionID := int64(1003)
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget))
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget))
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, NextTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, NextTarget))
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, CurrentTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, CurrentTarget))
suite.meta.PutCollection(&Collection{
suite.meta.PutCollection(ctx, &Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: collectionID,
ReplicaNumber: 1,
},
})
suite.meta.PutPartition(&Partition{
suite.meta.PutPartition(ctx, &Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: collectionID,
PartitionID: 1,
@ -391,17 +401,17 @@ func (suite *TargetManagerSuite) TestGetSegmentByChannel() {
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(nextTargetChannels, nextTargetSegments, nil)
suite.mgr.UpdateCollectionNextTarget(collectionID)
suite.Len(suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget), 2)
suite.Len(suite.mgr.GetSealedSegmentsByChannel(collectionID, "channel-1", NextTarget), 1)
suite.Len(suite.mgr.GetSealedSegmentsByChannel(collectionID, "channel-2", NextTarget), 1)
suite.Len(suite.mgr.GetGrowingSegmentsByChannel(collectionID, "channel-1", NextTarget), 4)
suite.Len(suite.mgr.GetGrowingSegmentsByChannel(collectionID, "channel-2", NextTarget), 1)
suite.Len(suite.mgr.GetDroppedSegmentsByChannel(collectionID, "channel-1", NextTarget), 3)
suite.Len(suite.mgr.GetGrowingSegmentsByCollection(collectionID, NextTarget), 5)
suite.Len(suite.mgr.GetSealedSegmentsByPartition(collectionID, 1, NextTarget), 2)
suite.NotNil(suite.mgr.GetSealedSegment(collectionID, 11, NextTarget))
suite.NotNil(suite.mgr.GetDmChannel(collectionID, "channel-1", NextTarget))
suite.mgr.UpdateCollectionNextTarget(ctx, collectionID)
suite.Len(suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, NextTarget), 2)
suite.Len(suite.mgr.GetSealedSegmentsByChannel(ctx, collectionID, "channel-1", NextTarget), 1)
suite.Len(suite.mgr.GetSealedSegmentsByChannel(ctx, collectionID, "channel-2", NextTarget), 1)
suite.Len(suite.mgr.GetGrowingSegmentsByChannel(ctx, collectionID, "channel-1", NextTarget), 4)
suite.Len(suite.mgr.GetGrowingSegmentsByChannel(ctx, collectionID, "channel-2", NextTarget), 1)
suite.Len(suite.mgr.GetDroppedSegmentsByChannel(ctx, collectionID, "channel-1", NextTarget), 3)
suite.Len(suite.mgr.GetGrowingSegmentsByCollection(ctx, collectionID, NextTarget), 5)
suite.Len(suite.mgr.GetSealedSegmentsByPartition(ctx, collectionID, 1, NextTarget), 2)
suite.NotNil(suite.mgr.GetSealedSegment(ctx, collectionID, 11, NextTarget))
suite.NotNil(suite.mgr.GetDmChannel(ctx, collectionID, "channel-1", NextTarget))
}
func (suite *TargetManagerSuite) TestGetTarget() {
@ -535,19 +545,20 @@ func (suite *TargetManagerSuite) TestGetTarget() {
}
func (suite *TargetManagerSuite) TestRecover() {
ctx := suite.ctx
collectionID := int64(1003)
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget))
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget))
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, NextTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, NextTarget))
suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, CurrentTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, CurrentTarget))
suite.meta.PutCollection(&Collection{
suite.meta.PutCollection(ctx, &Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: collectionID,
ReplicaNumber: 1,
},
})
suite.meta.PutPartition(&Partition{
suite.meta.PutPartition(ctx, &Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: collectionID,
PartitionID: 1,
@ -582,16 +593,16 @@ func (suite *TargetManagerSuite) TestRecover() {
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(nextTargetChannels, nextTargetSegments, nil)
suite.mgr.UpdateCollectionNextTarget(collectionID)
suite.mgr.UpdateCollectionCurrentTarget(collectionID)
suite.mgr.UpdateCollectionNextTarget(ctx, collectionID)
suite.mgr.UpdateCollectionCurrentTarget(ctx, collectionID)
suite.mgr.SaveCurrentTarget(suite.catalog)
suite.mgr.SaveCurrentTarget(ctx, suite.catalog)
// clear target in memory
version := suite.mgr.current.getCollectionTarget(collectionID).GetTargetVersion()
suite.mgr.current.removeCollectionTarget(collectionID)
// try to recover
suite.mgr.Recover(suite.catalog)
suite.mgr.Recover(ctx, suite.catalog)
target := suite.mgr.current.getCollectionTarget(collectionID)
suite.NotNil(target)
@ -600,20 +611,21 @@ func (suite *TargetManagerSuite) TestRecover() {
suite.Equal(target.GetTargetVersion(), version)
// after recovery, the target info should be cleaned up
targets, err := suite.catalog.GetCollectionTargets()
targets, err := suite.catalog.GetCollectionTargets(ctx)
suite.NoError(err)
suite.Len(targets, 0)
}
func (suite *TargetManagerSuite) TestGetTargetJSON() {
ctx := suite.ctx
collectionID := int64(1003)
suite.meta.PutCollection(&Collection{
suite.meta.PutCollection(ctx, &Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: collectionID,
ReplicaNumber: 1,
},
})
suite.meta.PutPartition(&Partition{
suite.meta.PutPartition(ctx, &Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: collectionID,
PartitionID: 1,
@ -648,10 +660,10 @@ func (suite *TargetManagerSuite) TestGetTargetJSON() {
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(nextTargetChannels, nextTargetSegments, nil)
suite.NoError(suite.mgr.UpdateCollectionNextTarget(collectionID))
suite.True(suite.mgr.UpdateCollectionCurrentTarget(collectionID))
suite.NoError(suite.mgr.UpdateCollectionNextTarget(ctx, collectionID))
suite.True(suite.mgr.UpdateCollectionCurrentTarget(ctx, collectionID))
jsonStr := suite.mgr.GetTargetJSON(CurrentTarget)
jsonStr := suite.mgr.GetTargetJSON(ctx, CurrentTarget)
assert.NotEmpty(suite.T(), jsonStr)
var currentTarget []*metricsinfo.QueryCoordTarget

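The test updates are equally mechanical: each suite gains a ctx field, SetupTest initializes it with context.Background(), and every test method passes it through. A small illustrative suite, assuming only testify/suite and hypothetical names:

```go
package example

import (
	"context"
	"testing"

	"github.com/stretchr/testify/suite"
)

// CtxSuite shows the pattern the updated suites use: one context created in
// SetupTest and reused by every test method via suite.ctx / ctx := suite.ctx.
type CtxSuite struct {
	suite.Suite
	ctx context.Context
}

func (s *CtxSuite) SetupTest() {
	// context.Background() is enough for unit tests; real callers would pass
	// a request-scoped context into the manager methods instead.
	s.ctx = context.Background()
}

func (s *CtxSuite) TestCtxIsAvailable() {
	ctx := s.ctx
	s.NoError(ctx.Err())
}

func TestCtxSuite(t *testing.T) {
	suite.Run(t, new(CtxSuite))
}
```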
View File

@ -86,7 +86,7 @@ func NewCollectionObserver(
}
// Add load task for collection recovery
collections := meta.GetAllCollections()
collections := meta.GetAllCollections(context.TODO())
for _, collection := range collections {
ob.LoadCollection(context.Background(), collection.GetCollectionID())
}
@ -157,13 +157,13 @@ func (ob *CollectionObserver) LoadPartitions(ctx context.Context, collectionID i
}
func (ob *CollectionObserver) Observe(ctx context.Context) {
ob.observeTimeout()
ob.observeTimeout(ctx)
ob.observeLoadStatus(ctx)
}
func (ob *CollectionObserver) observeTimeout() {
func (ob *CollectionObserver) observeTimeout(ctx context.Context) {
ob.loadTasks.Range(func(traceID string, task LoadTask) bool {
collection := ob.meta.CollectionManager.GetCollection(task.CollectionID)
collection := ob.meta.CollectionManager.GetCollection(ctx, task.CollectionID)
// collection released
if collection == nil {
log.Info("Load Collection Task canceled, collection removed from meta", zap.Int64("collectionID", task.CollectionID), zap.String("traceID", traceID))
@ -178,14 +178,14 @@ func (ob *CollectionObserver) observeTimeout() {
log.Info("load collection timeout, cancel it",
zap.Int64("collectionID", collection.GetCollectionID()),
zap.Duration("loadTime", time.Since(collection.CreatedAt)))
ob.meta.CollectionManager.RemoveCollection(collection.GetCollectionID())
ob.meta.ReplicaManager.RemoveCollection(collection.GetCollectionID())
ob.meta.CollectionManager.RemoveCollection(ctx, collection.GetCollectionID())
ob.meta.ReplicaManager.RemoveCollection(ctx, collection.GetCollectionID())
ob.targetObserver.ReleaseCollection(collection.GetCollectionID())
ob.loadTasks.Remove(traceID)
}
case querypb.LoadType_LoadPartition:
partitionIDs := typeutil.NewSet(task.PartitionIDs...)
partitions := ob.meta.GetPartitionsByCollection(task.CollectionID)
partitions := ob.meta.GetPartitionsByCollection(ctx, task.CollectionID)
partitions = lo.Filter(partitions, func(partition *meta.Partition, _ int) bool {
return partitionIDs.Contain(partition.GetPartitionID())
})
@ -213,16 +213,16 @@ func (ob *CollectionObserver) observeTimeout() {
zap.Int64("collectionID", task.CollectionID),
zap.Int64s("partitionIDs", task.PartitionIDs))
for _, partition := range partitions {
ob.meta.CollectionManager.RemovePartition(partition.CollectionID, partition.GetPartitionID())
ob.meta.CollectionManager.RemovePartition(ctx, partition.CollectionID, partition.GetPartitionID())
ob.targetObserver.ReleasePartition(partition.GetCollectionID(), partition.GetPartitionID())
}
// all partitions timed out, remove the collection
if len(ob.meta.CollectionManager.GetPartitionsByCollection(task.CollectionID)) == 0 {
if len(ob.meta.CollectionManager.GetPartitionsByCollection(ctx, task.CollectionID)) == 0 {
log.Info("collection timeout due to all partition removed", zap.Int64("collection", task.CollectionID))
ob.meta.CollectionManager.RemoveCollection(task.CollectionID)
ob.meta.ReplicaManager.RemoveCollection(task.CollectionID)
ob.meta.CollectionManager.RemoveCollection(ctx, task.CollectionID)
ob.meta.ReplicaManager.RemoveCollection(ctx, task.CollectionID)
ob.targetObserver.ReleaseCollection(task.CollectionID)
}
}
@ -231,9 +231,9 @@ func (ob *CollectionObserver) observeTimeout() {
})
}
func (ob *CollectionObserver) readyToObserve(collectionID int64) bool {
metaExist := (ob.meta.GetCollection(collectionID) != nil)
targetExist := ob.targetMgr.IsNextTargetExist(collectionID) || ob.targetMgr.IsCurrentTargetExist(collectionID, common.AllPartitionsID)
func (ob *CollectionObserver) readyToObserve(ctx context.Context, collectionID int64) bool {
metaExist := (ob.meta.GetCollection(ctx, collectionID) != nil)
targetExist := ob.targetMgr.IsNextTargetExist(ctx, collectionID) || ob.targetMgr.IsCurrentTargetExist(ctx, collectionID, common.AllPartitionsID)
return metaExist && targetExist
}
@ -243,7 +243,7 @@ func (ob *CollectionObserver) observeLoadStatus(ctx context.Context) {
ob.loadTasks.Range(func(traceID string, task LoadTask) bool {
loading = true
collection := ob.meta.CollectionManager.GetCollection(task.CollectionID)
collection := ob.meta.CollectionManager.GetCollection(ctx, task.CollectionID)
if collection == nil {
return true
}
@ -251,10 +251,10 @@ func (ob *CollectionObserver) observeLoadStatus(ctx context.Context) {
var partitions []*meta.Partition
switch task.LoadType {
case querypb.LoadType_LoadCollection:
partitions = ob.meta.GetPartitionsByCollection(task.CollectionID)
partitions = ob.meta.GetPartitionsByCollection(ctx, task.CollectionID)
case querypb.LoadType_LoadPartition:
partitionIDs := typeutil.NewSet[int64](task.PartitionIDs...)
partitions = ob.meta.GetPartitionsByCollection(task.CollectionID)
partitions = ob.meta.GetPartitionsByCollection(ctx, task.CollectionID)
partitions = lo.Filter(partitions, func(partition *meta.Partition, _ int) bool {
return partitionIDs.Contain(partition.GetPartitionID())
})
@ -265,11 +265,11 @@ func (ob *CollectionObserver) observeLoadStatus(ctx context.Context) {
if partition.LoadPercentage == 100 {
continue
}
if ob.readyToObserve(partition.CollectionID) {
replicaNum := ob.meta.GetReplicaNumber(partition.GetCollectionID())
if ob.readyToObserve(ctx, partition.CollectionID) {
replicaNum := ob.meta.GetReplicaNumber(ctx, partition.GetCollectionID())
ob.observePartitionLoadStatus(ctx, partition, replicaNum)
}
partition = ob.meta.GetPartition(partition.PartitionID)
partition = ob.meta.GetPartition(ctx, partition.PartitionID)
if partition != nil && partition.LoadPercentage != 100 {
loaded = false
}
@ -299,8 +299,8 @@ func (ob *CollectionObserver) observePartitionLoadStatus(ctx context.Context, pa
zap.Int64("partitionID", partition.GetPartitionID()),
)
segmentTargets := ob.targetMgr.GetSealedSegmentsByPartition(partition.GetCollectionID(), partition.GetPartitionID(), meta.NextTarget)
channelTargets := ob.targetMgr.GetDmChannelsByCollection(partition.GetCollectionID(), meta.NextTarget)
segmentTargets := ob.targetMgr.GetSealedSegmentsByPartition(ctx, partition.GetCollectionID(), partition.GetPartitionID(), meta.NextTarget)
channelTargets := ob.targetMgr.GetDmChannelsByCollection(ctx, partition.GetCollectionID(), meta.NextTarget)
targetNum := len(segmentTargets) + len(channelTargets)
if targetNum == 0 {
@ -320,7 +320,7 @@ func (ob *CollectionObserver) observePartitionLoadStatus(ctx context.Context, pa
for _, channel := range channelTargets {
views := ob.dist.LeaderViewManager.GetByFilter(meta.WithChannelName2LeaderView(channel.GetChannelName()))
nodes := lo.Map(views, func(v *meta.LeaderView, _ int) int64 { return v.ID })
group := utils.GroupNodesByReplica(ob.meta.ReplicaManager, partition.GetCollectionID(), nodes)
group := utils.GroupNodesByReplica(ctx, ob.meta.ReplicaManager, partition.GetCollectionID(), nodes)
loadedCount += len(group)
}
subChannelCount := loadedCount
@ -329,7 +329,7 @@ func (ob *CollectionObserver) observePartitionLoadStatus(ctx context.Context, pa
meta.WithChannelName2LeaderView(segment.GetInsertChannel()),
meta.WithSegment2LeaderView(segment.GetID(), false))
nodes := lo.Map(views, func(view *meta.LeaderView, _ int) int64 { return view.ID })
group := utils.GroupNodesByReplica(ob.meta.ReplicaManager, partition.GetCollectionID(), nodes)
group := utils.GroupNodesByReplica(ctx, ob.meta.ReplicaManager, partition.GetCollectionID(), nodes)
loadedCount += len(group)
}
if loadedCount > 0 {
@ -352,7 +352,7 @@ func (ob *CollectionObserver) observePartitionLoadStatus(ctx context.Context, pa
}
delete(ob.partitionLoadedCount, partition.GetPartitionID())
}
collectionPercentage, err := ob.meta.CollectionManager.UpdateLoadPercent(partition.PartitionID, loadPercentage)
collectionPercentage, err := ob.meta.CollectionManager.UpdateLoadPercent(ctx, partition.PartitionID, loadPercentage)
if err != nil {
log.Warn("failed to update load percentage")
}

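Two plumbing choices show up in collection_observer.go: helpers reachable from Observe(ctx), such as observeTimeout and readyToObserve, now accept the caller's context, while call sites with no natural parent (the constructor's recovery loop) use context.TODO() as an explicit placeholder. A toy sketch of the first pattern, with hypothetical names rather than the actual observer types:

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// observer is a toy stand-in for CollectionObserver.
type observer struct{}

// Observe receives the loop's ctx and hands it to every helper, mirroring how
// observeTimeout and observeLoadStatus now take ctx instead of building one.
func (o *observer) Observe(ctx context.Context) {
	o.observeTimeout(ctx)
	o.observeLoadStatus(ctx)
}

func (o *observer) observeTimeout(ctx context.Context) {
	// Helpers can now respect cancellation instead of always running to completion.
	select {
	case <-ctx.Done():
		fmt.Println("timeout check skipped:", ctx.Err())
	default:
		fmt.Println("checking load timeouts")
	}
}

func (o *observer) observeLoadStatus(ctx context.Context) {
	_ = ctx // the real helper passes ctx on to meta and target-manager lookups
	fmt.Println("checking load status")
}

func main() {
	ob := &observer{}
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	ob.Observe(ctx)
}
```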
View File

@ -75,6 +75,8 @@ type CollectionObserverSuite struct {
// Test object
ob *CollectionObserver
ctx context.Context
}
func (suite *CollectionObserverSuite) SetupSuite() {
@ -236,6 +238,7 @@ func (suite *CollectionObserverSuite) SetupTest() {
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 3,
}))
suite.ctx = context.Background()
}
func (suite *CollectionObserverSuite) TearDownTest() {
@ -249,7 +252,7 @@ func (suite *CollectionObserverSuite) TestObserve() {
timeout = 3 * time.Second
)
// time before load
time := suite.meta.GetCollection(suite.collections[2]).UpdatedAt
time := suite.meta.GetCollection(suite.ctx, suite.collections[2]).UpdatedAt
// Not timeout
paramtable.Get().Save(Params.QueryCoordCfg.LoadTimeoutSeconds.Key, "3")
@ -357,12 +360,13 @@ func (suite *CollectionObserverSuite) TestObservePartition() {
}
func (suite *CollectionObserverSuite) isCollectionLoaded(collection int64) bool {
exist := suite.meta.Exist(collection)
percentage := suite.meta.CalculateLoadPercentage(collection)
status := suite.meta.CalculateLoadStatus(collection)
replicas := suite.meta.ReplicaManager.GetByCollection(collection)
channels := suite.targetMgr.GetDmChannelsByCollection(collection, meta.CurrentTarget)
segments := suite.targetMgr.GetSealedSegmentsByCollection(collection, meta.CurrentTarget)
ctx := suite.ctx
exist := suite.meta.Exist(ctx, collection)
percentage := suite.meta.CalculateLoadPercentage(ctx, collection)
status := suite.meta.CalculateLoadStatus(ctx, collection)
replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection)
channels := suite.targetMgr.GetDmChannelsByCollection(ctx, collection, meta.CurrentTarget)
segments := suite.targetMgr.GetSealedSegmentsByCollection(ctx, collection, meta.CurrentTarget)
return exist &&
percentage == 100 &&
@ -373,15 +377,16 @@ func (suite *CollectionObserverSuite) isCollectionLoaded(collection int64) bool
}
func (suite *CollectionObserverSuite) isPartitionLoaded(partitionID int64) bool {
partition := suite.meta.GetPartition(partitionID)
ctx := suite.ctx
partition := suite.meta.GetPartition(ctx, partitionID)
if partition == nil {
return false
}
collection := partition.GetCollectionID()
percentage := suite.meta.GetPartitionLoadPercentage(partitionID)
percentage := suite.meta.GetPartitionLoadPercentage(ctx, partitionID)
status := partition.GetStatus()
channels := suite.targetMgr.GetDmChannelsByCollection(collection, meta.CurrentTarget)
segments := suite.targetMgr.GetSealedSegmentsByPartition(collection, partitionID, meta.CurrentTarget)
channels := suite.targetMgr.GetDmChannelsByCollection(ctx, collection, meta.CurrentTarget)
segments := suite.targetMgr.GetSealedSegmentsByPartition(ctx, collection, partitionID, meta.CurrentTarget)
expectedSegments := lo.Filter(suite.segments[collection], func(seg *datapb.SegmentInfo, _ int) bool {
return seg.PartitionID == partitionID
})
@ -392,10 +397,11 @@ func (suite *CollectionObserverSuite) isPartitionLoaded(partitionID int64) bool
}
func (suite *CollectionObserverSuite) isCollectionTimeout(collection int64) bool {
exist := suite.meta.Exist(collection)
replicas := suite.meta.ReplicaManager.GetByCollection(collection)
channels := suite.targetMgr.GetDmChannelsByCollection(collection, meta.CurrentTarget)
segments := suite.targetMgr.GetSealedSegmentsByCollection(collection, meta.CurrentTarget)
ctx := suite.ctx
exist := suite.meta.Exist(ctx, collection)
replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection)
channels := suite.targetMgr.GetDmChannelsByCollection(ctx, collection, meta.CurrentTarget)
segments := suite.targetMgr.GetSealedSegmentsByCollection(ctx, collection, meta.CurrentTarget)
return !(exist ||
len(replicas) > 0 ||
len(channels) > 0 ||
@ -403,36 +409,39 @@ func (suite *CollectionObserverSuite) isCollectionTimeout(collection int64) bool
}
func (suite *CollectionObserverSuite) isPartitionTimeout(collection int64, partitionID int64) bool {
partition := suite.meta.GetPartition(partitionID)
segments := suite.targetMgr.GetSealedSegmentsByPartition(collection, partitionID, meta.CurrentTarget)
ctx := suite.ctx
partition := suite.meta.GetPartition(ctx, partitionID)
segments := suite.targetMgr.GetSealedSegmentsByPartition(ctx, collection, partitionID, meta.CurrentTarget)
return partition == nil && len(segments) == 0
}
func (suite *CollectionObserverSuite) isCollectionLoadedContinue(collection int64, beforeTime time.Time) bool {
return suite.meta.GetCollection(collection).UpdatedAt.After(beforeTime)
return suite.meta.GetCollection(suite.ctx, collection).UpdatedAt.After(beforeTime)
}
func (suite *CollectionObserverSuite) loadAll() {
ctx := suite.ctx
for _, collection := range suite.collections {
suite.load(collection)
}
suite.targetMgr.UpdateCollectionCurrentTarget(suite.collections[0])
suite.targetMgr.UpdateCollectionNextTarget(suite.collections[0])
suite.targetMgr.UpdateCollectionCurrentTarget(suite.collections[2])
suite.targetMgr.UpdateCollectionNextTarget(suite.collections[2])
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, suite.collections[0])
suite.targetMgr.UpdateCollectionNextTarget(ctx, suite.collections[0])
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, suite.collections[2])
suite.targetMgr.UpdateCollectionNextTarget(ctx, suite.collections[2])
}
func (suite *CollectionObserverSuite) load(collection int64) {
ctx := suite.ctx
// Mock meta data
replicas, err := suite.meta.ReplicaManager.Spawn(collection, map[string]int{meta.DefaultResourceGroupName: int(suite.replicaNumber[collection])}, nil)
replicas, err := suite.meta.ReplicaManager.Spawn(ctx, collection, map[string]int{meta.DefaultResourceGroupName: int(suite.replicaNumber[collection])}, nil)
suite.NoError(err)
for _, replica := range replicas {
replica.AddRWNode(suite.nodes...)
}
err = suite.meta.ReplicaManager.Put(replicas...)
err = suite.meta.ReplicaManager.Put(ctx, replicas...)
suite.NoError(err)
suite.meta.PutCollection(&meta.Collection{
suite.meta.PutCollection(ctx, &meta.Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: collection,
ReplicaNumber: suite.replicaNumber[collection],
@ -444,7 +453,7 @@ func (suite *CollectionObserverSuite) load(collection int64) {
})
for _, partition := range suite.partitions[collection] {
suite.meta.PutPartition(&meta.Partition{
suite.meta.PutPartition(ctx, &meta.Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: collection,
PartitionID: partition,
@ -474,7 +483,7 @@ func (suite *CollectionObserverSuite) load(collection int64) {
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collection).Return(dmChannels, allSegments, nil)
suite.targetMgr.UpdateCollectionNextTarget(collection)
suite.targetMgr.UpdateCollectionNextTarget(ctx, collection)
suite.ob.LoadCollection(context.Background(), collection)
}

View File

@ -71,7 +71,7 @@ func (ob *ReplicaObserver) schedule(ctx context.Context) {
defer ob.wg.Done()
log.Info("Start check replica loop")
listener := ob.meta.ResourceManager.ListenNodeChanged()
listener := ob.meta.ResourceManager.ListenNodeChanged(ctx)
for {
ob.waitNodeChangedOrTimeout(ctx, listener)
// stop if the context is canceled.
@ -92,15 +92,16 @@ func (ob *ReplicaObserver) waitNodeChangedOrTimeout(ctx context.Context, listene
}
func (ob *ReplicaObserver) checkNodesInReplica() {
log := log.Ctx(context.Background()).WithRateGroup("qcv2.replicaObserver", 1, 60)
collections := ob.meta.GetAll()
ctx := context.Background()
log := log.Ctx(ctx).WithRateGroup("qcv2.replicaObserver", 1, 60)
collections := ob.meta.GetAll(ctx)
for _, collectionID := range collections {
utils.RecoverReplicaOfCollection(ob.meta, collectionID)
utils.RecoverReplicaOfCollection(ctx, ob.meta, collectionID)
}
// check all RO nodes, remove them from the replica once all segments/channels have been moved off
for _, collectionID := range collections {
replicas := ob.meta.ReplicaManager.GetByCollection(collectionID)
replicas := ob.meta.ReplicaManager.GetByCollection(ctx, collectionID)
for _, replica := range replicas {
roNodes := replica.GetRONodes()
rwNodes := replica.GetRWNodes()
@ -130,7 +131,7 @@ func (ob *ReplicaObserver) checkNodesInReplica() {
zap.Int64s("roNodes", roNodes),
zap.Int64s("rwNodes", rwNodes),
)
if err := ob.meta.ReplicaManager.RemoveNode(replica.GetID(), removeNodes...); err != nil {
if err := ob.meta.ReplicaManager.RemoveNode(ctx, replica.GetID(), removeNodes...); err != nil {
logger.Warn("fail to remove node from replica", zap.Error(err))
continue
}

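checkNodesInReplica runs on the observer's own goroutine rather than inside an RPC, so the updated code creates a context.Background() at the top of each check and threads it downward. A stripped-down loop doing the same, with made-up helper names:

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// checkOnce plays the role of checkNodesInReplica: the observer loop has no
// caller request to inherit a context from, so it creates one per check.
func checkOnce() {
	ctx := context.Background()
	checkCollection(ctx, 1000)
}

// checkCollection receives the ctx so downstream meta/replica calls can later
// gain tracing or cancellation without another signature change.
func checkCollection(ctx context.Context, collectionID int64) {
	fmt.Println("checked collection", collectionID, "ctx error:", ctx.Err())
}

func main() {
	ticker := time.NewTicker(10 * time.Millisecond)
	defer ticker.Stop()
	for i := 0; i < 3; i++ {
		<-ticker.C
		checkOnce()
	}
}
```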
View File

@ -16,6 +16,7 @@
package observers
import (
"context"
"testing"
"time"
@ -47,6 +48,7 @@ type ReplicaObserverSuite struct {
collectionID int64
partitionID int64
ctx context.Context
}
func (suite *ReplicaObserverSuite) SetupSuite() {
@ -67,6 +69,7 @@ func (suite *ReplicaObserverSuite) SetupTest() {
config.EtcdTLSMinVersion.GetValue())
suite.Require().NoError(err)
suite.kv = etcdkv.NewEtcdKV(cli, config.MetaRootPath.GetValue())
suite.ctx = context.Background()
// meta
store := querycoord.NewCatalog(suite.kv)
@ -82,11 +85,12 @@ func (suite *ReplicaObserverSuite) SetupTest() {
}
func (suite *ReplicaObserverSuite) TestCheckNodesInReplica() {
suite.meta.ResourceManager.AddResourceGroup("rg1", &rgpb.ResourceGroupConfig{
ctx := suite.ctx
suite.meta.ResourceManager.AddResourceGroup(ctx, "rg1", &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{NodeNum: 2},
Limits: &rgpb.ResourceGroupLimit{NodeNum: 2},
})
suite.meta.ResourceManager.AddResourceGroup("rg2", &rgpb.ResourceGroupConfig{
suite.meta.ResourceManager.AddResourceGroup(ctx, "rg2", &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{NodeNum: 2},
Limits: &rgpb.ResourceGroupLimit{NodeNum: 2},
})
@ -110,14 +114,14 @@ func (suite *ReplicaObserverSuite) TestCheckNodesInReplica() {
Address: "localhost:8080",
Hostname: "localhost",
}))
suite.meta.ResourceManager.HandleNodeUp(1)
suite.meta.ResourceManager.HandleNodeUp(2)
suite.meta.ResourceManager.HandleNodeUp(3)
suite.meta.ResourceManager.HandleNodeUp(4)
suite.meta.ResourceManager.HandleNodeUp(ctx, 1)
suite.meta.ResourceManager.HandleNodeUp(ctx, 2)
suite.meta.ResourceManager.HandleNodeUp(ctx, 3)
suite.meta.ResourceManager.HandleNodeUp(ctx, 4)
err := suite.meta.CollectionManager.PutCollection(utils.CreateTestCollection(suite.collectionID, 2))
err := suite.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(suite.collectionID, 2))
suite.NoError(err)
replicas, err := suite.meta.Spawn(suite.collectionID, map[string]int{
replicas, err := suite.meta.Spawn(ctx, suite.collectionID, map[string]int{
"rg1": 1,
"rg2": 1,
}, nil)
@ -127,7 +131,7 @@ func (suite *ReplicaObserverSuite) TestCheckNodesInReplica() {
suite.Eventually(func() bool {
availableNodes := typeutil.NewUniqueSet()
for _, r := range replicas {
replica := suite.meta.ReplicaManager.Get(r.GetID())
replica := suite.meta.ReplicaManager.Get(ctx, r.GetID())
suite.NotNil(replica)
if replica.RWNodesCount() != 2 {
return false
@ -151,13 +155,13 @@ func (suite *ReplicaObserverSuite) TestCheckNodesInReplica() {
}
// Do a replica transfer.
suite.meta.ReplicaManager.TransferReplica(suite.collectionID, "rg1", "rg2", 1)
suite.meta.ReplicaManager.TransferReplica(ctx, suite.collectionID, "rg1", "rg2", 1)
// All replicas should be in rg2, not rg1.
// Some nodes become RO nodes until every segment and channel on them is cleaned up.
suite.Eventually(func() bool {
for _, r := range replicas {
replica := suite.meta.ReplicaManager.Get(r.GetID())
replica := suite.meta.ReplicaManager.Get(ctx, r.GetID())
suite.NotNil(replica)
suite.Equal("rg2", replica.GetResourceGroup())
// every replica should have RO nodes.
@ -178,7 +182,7 @@ func (suite *ReplicaObserverSuite) TestCheckNodesInReplica() {
suite.Eventually(func() bool {
for _, r := range replicas {
replica := suite.meta.ReplicaManager.Get(r.GetID())
replica := suite.meta.ReplicaManager.Get(ctx, r.GetID())
suite.NotNil(replica)
suite.Equal("rg2", replica.GetResourceGroup())
if replica.RONodesCount() > 0 {

View File

@ -69,7 +69,7 @@ func (ob *ResourceObserver) schedule(ctx context.Context) {
defer ob.wg.Done()
log.Info("Start check resource group loop")
listener := ob.meta.ResourceManager.ListenResourceGroupChanged()
listener := ob.meta.ResourceManager.ListenResourceGroupChanged(ctx)
for {
ob.waitRGChangedOrTimeout(ctx, listener)
// stop if the context is canceled.
@ -79,7 +79,7 @@ func (ob *ResourceObserver) schedule(ctx context.Context) {
}
// do check once.
ob.checkAndRecoverResourceGroup()
ob.checkAndRecoverResourceGroup(ctx)
}
}
@ -89,29 +89,29 @@ func (ob *ResourceObserver) waitRGChangedOrTimeout(ctx context.Context, listener
listener.Wait(ctxWithTimeout)
}
func (ob *ResourceObserver) checkAndRecoverResourceGroup() {
func (ob *ResourceObserver) checkAndRecoverResourceGroup(ctx context.Context) {
manager := ob.meta.ResourceManager
rgNames := manager.ListResourceGroups()
rgNames := manager.ListResourceGroups(ctx)
enableRGAutoRecover := params.Params.QueryCoordCfg.EnableRGAutoRecover.GetAsBool()
log.Debug("start to check resource group", zap.Bool("enableRGAutoRecover", enableRGAutoRecover), zap.Int("resourceGroupNum", len(rgNames)))
// Check if there is any incoming node.
if manager.CheckIncomingNodeNum() > 0 {
log.Info("new incoming node is ready to be assigned...", zap.Int("incomingNodeNum", manager.CheckIncomingNodeNum()))
manager.AssignPendingIncomingNode()
if manager.CheckIncomingNodeNum(ctx) > 0 {
log.Info("new incoming node is ready to be assigned...", zap.Int("incomingNodeNum", manager.CheckIncomingNodeNum(ctx)))
manager.AssignPendingIncomingNode(ctx)
}
log.Debug("recover resource groups...")
// Recover all resource group into expected configuration.
for _, rgName := range rgNames {
if err := manager.MeetRequirement(rgName); err != nil {
if err := manager.MeetRequirement(ctx, rgName); err != nil {
log.Info("found resource group need to be recovered",
zap.String("rgName", rgName),
zap.String("reason", err.Error()),
)
if enableRGAutoRecover {
err := manager.AutoRecoverResourceGroup(rgName)
err := manager.AutoRecoverResourceGroup(ctx, rgName)
if err != nil {
log.Warn("failed to recover resource group",
zap.String("rgName", rgName),

View File

@ -16,6 +16,7 @@
package observers
import (
"context"
"fmt"
"testing"
"time"
@ -47,6 +48,8 @@ type ResourceObserverSuite struct {
collectionID int64
partitionID int64
ctx context.Context
}
func (suite *ResourceObserverSuite) SetupSuite() {
@ -67,6 +70,7 @@ func (suite *ResourceObserverSuite) SetupTest() {
config.EtcdTLSMinVersion.GetValue())
suite.Require().NoError(err)
suite.kv = etcdKV.NewEtcdKV(cli, config.MetaRootPath.GetValue())
suite.ctx = context.Background()
// meta
suite.store = mocks.NewQueryCoordCatalog(suite.T())
@ -76,15 +80,15 @@ func (suite *ResourceObserverSuite) SetupTest() {
suite.observer = NewResourceObserver(suite.meta)
suite.store.EXPECT().SaveResourceGroup(mock.Anything).Return(nil)
suite.store.EXPECT().SaveResourceGroup(mock.Anything, mock.Anything).Return(nil)
suite.store.EXPECT().SaveResourceGroup(mock.Anything, mock.Anything, mock.Anything).Return(nil)
for i := 0; i < 10; i++ {
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: int64(i),
Address: "localhost",
Hostname: "localhost",
}))
suite.meta.ResourceManager.HandleNodeUp(int64(i))
suite.meta.ResourceManager.HandleNodeUp(suite.ctx, int64(i))
}
}
@ -93,80 +97,82 @@ func (suite *ResourceObserverSuite) TearDownTest() {
}
func (suite *ResourceObserverSuite) TestObserverRecoverOperation() {
suite.meta.ResourceManager.AddResourceGroup("rg", &rgpb.ResourceGroupConfig{
ctx := suite.ctx
suite.meta.ResourceManager.AddResourceGroup(ctx, "rg", &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{NodeNum: 4},
Limits: &rgpb.ResourceGroupLimit{NodeNum: 6},
})
suite.Error(suite.meta.ResourceManager.MeetRequirement("rg"))
suite.Error(suite.meta.ResourceManager.MeetRequirement(ctx, "rg"))
// There are 10 existing nodes in the cluster; the new resource group should get 4 nodes after recovery.
suite.observer.checkAndRecoverResourceGroup()
suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg"))
suite.observer.checkAndRecoverResourceGroup(ctx)
suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg"))
suite.meta.ResourceManager.AddResourceGroup("rg2", &rgpb.ResourceGroupConfig{
suite.meta.ResourceManager.AddResourceGroup(ctx, "rg2", &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{NodeNum: 6},
Limits: &rgpb.ResourceGroupLimit{NodeNum: 10},
})
suite.Error(suite.meta.ResourceManager.MeetRequirement("rg2"))
suite.Error(suite.meta.ResourceManager.MeetRequirement(ctx, "rg2"))
// There are 10 existing nodes in the cluster; the new resource group should get 6 nodes after recovery.
suite.observer.checkAndRecoverResourceGroup()
suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg1"))
suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg2"))
suite.observer.checkAndRecoverResourceGroup(ctx)
suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg1"))
suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg2"))
suite.meta.ResourceManager.AddResourceGroup("rg3", &rgpb.ResourceGroupConfig{
suite.meta.ResourceManager.AddResourceGroup(ctx, "rg3", &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{NodeNum: 1},
Limits: &rgpb.ResourceGroupLimit{NodeNum: 1},
})
suite.Error(suite.meta.ResourceManager.MeetRequirement("rg3"))
suite.Error(suite.meta.ResourceManager.MeetRequirement(ctx, "rg3"))
// There are 10 existing nodes in the cluster, but rg1 and rg2 already occupy them, so the new resource group cannot get any node.
suite.observer.checkAndRecoverResourceGroup()
suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg1"))
suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg2"))
suite.Error(suite.meta.ResourceManager.MeetRequirement("rg3"))
suite.observer.checkAndRecoverResourceGroup(ctx)
suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg1"))
suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg2"))
suite.Error(suite.meta.ResourceManager.MeetRequirement(ctx, "rg3"))
// New node up, rg3 should get the node.
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 10,
}))
suite.meta.ResourceManager.HandleNodeUp(10)
suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg1"))
suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg2"))
suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg3"))
suite.meta.ResourceManager.HandleNodeUp(ctx, 10)
suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg1"))
suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg2"))
suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg3"))
// new node is down, rg3 cannot use that node anymore.
suite.nodeMgr.Remove(10)
suite.meta.ResourceManager.HandleNodeDown(10)
suite.observer.checkAndRecoverResourceGroup()
suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg1"))
suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg2"))
suite.Error(suite.meta.ResourceManager.MeetRequirement("rg3"))
suite.meta.ResourceManager.HandleNodeDown(ctx, 10)
suite.observer.checkAndRecoverResourceGroup(ctx)
suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg1"))
suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg2"))
suite.Error(suite.meta.ResourceManager.MeetRequirement(ctx, "rg3"))
// simulate a metastore save failure for the next incoming node.
suite.store.EXPECT().SaveResourceGroup(mock.Anything).Unset()
suite.store.EXPECT().SaveResourceGroup(mock.Anything).Return(errors.New("failure"))
suite.store.EXPECT().SaveResourceGroup(mock.Anything, mock.Anything).Unset()
suite.store.EXPECT().SaveResourceGroup(mock.Anything, mock.Anything).Return(errors.New("failure"))
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 11,
}))
// the save should fail, so the new node cannot be used by rg3.
suite.meta.ResourceManager.HandleNodeUp(11)
suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg1"))
suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg2"))
suite.Error(suite.meta.ResourceManager.MeetRequirement("rg3"))
suite.store.EXPECT().SaveResourceGroup(mock.Anything).Unset()
suite.store.EXPECT().SaveResourceGroup(mock.Anything).Return(nil)
suite.meta.ResourceManager.HandleNodeUp(ctx, 11)
suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg1"))
suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg2"))
suite.Error(suite.meta.ResourceManager.MeetRequirement(ctx, "rg3"))
suite.store.EXPECT().SaveResourceGroup(mock.Anything, mock.Anything).Unset()
suite.store.EXPECT().SaveResourceGroup(mock.Anything, mock.Anything).Return(nil)
// storage recovered, so the next recovery will succeed.
suite.observer.checkAndRecoverResourceGroup()
suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg1"))
suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg2"))
suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg3"))
suite.observer.checkAndRecoverResourceGroup(ctx)
suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg1"))
suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg2"))
suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg3"))
}
func (suite *ResourceObserverSuite) TestSchedule() {
suite.observer.Start()
defer suite.observer.Stop()
ctx := suite.ctx
check := func() {
suite.Eventually(func() bool {
rgs := suite.meta.ResourceManager.ListResourceGroups()
rgs := suite.meta.ResourceManager.ListResourceGroups(ctx)
for _, rg := range rgs {
if err := suite.meta.ResourceManager.GetResourceGroup(rg).MeetRequirement(); err != nil {
if err := suite.meta.ResourceManager.GetResourceGroup(ctx, rg).MeetRequirement(); err != nil {
return false
}
}
@ -175,7 +181,7 @@ func (suite *ResourceObserverSuite) TestSchedule() {
}
for i := 1; i <= 4; i++ {
suite.meta.ResourceManager.AddResourceGroup(fmt.Sprintf("rg%d", i), &rgpb.ResourceGroupConfig{
suite.meta.ResourceManager.AddResourceGroup(ctx, fmt.Sprintf("rg%d", i), &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{NodeNum: int32(i)},
Limits: &rgpb.ResourceGroupLimit{NodeNum: int32(i)},
})
@ -183,7 +189,7 @@ func (suite *ResourceObserverSuite) TestSchedule() {
check()
for i := 1; i <= 4; i++ {
suite.meta.ResourceManager.AddResourceGroup(fmt.Sprintf("rg%d", i), &rgpb.ResourceGroupConfig{
suite.meta.ResourceManager.AddResourceGroup(ctx, fmt.Sprintf("rg%d", i), &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{NodeNum: 0},
Limits: &rgpb.ResourceGroupLimit{NodeNum: 0},
})

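When a mocked interface gains a ctx parameter, every mockery/testify expectation needs one more argument matcher, which is why the SaveResourceGroup expectations above grow an extra mock.Anything. A hand-rolled illustration of that effect, assuming only testify/mock and a hypothetical method shape:

```go
package example

import (
	"context"
	"testing"

	"github.com/stretchr/testify/mock"
)

// fakeCatalog is a hand-written testify mock standing in for the generated
// QueryCoordCatalog mock; only SaveResourceGroup is sketched here.
type fakeCatalog struct{ mock.Mock }

func (f *fakeCatalog) SaveResourceGroup(ctx context.Context, names ...string) error {
	args := f.Called(ctx, names)
	return args.Error(0)
}

func TestSaveResourceGroupExpectsCtx(t *testing.T) {
	cat := new(fakeCatalog)
	// Before the refactor the expectation matched a single argument; with ctx
	// added, the On/EXPECT call needs one more mock.Anything, which is what
	// the suite's SaveResourceGroup expectations above were updated to do.
	cat.On("SaveResourceGroup", mock.Anything, mock.Anything).Return(nil)

	if err := cat.SaveResourceGroup(context.Background(), "rg1"); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	cat.AssertExpectations(t)
}
```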
View File

@ -169,14 +169,14 @@ func (ob *TargetObserver) schedule(ctx context.Context) {
return
case <-ob.initChan:
for _, collectionID := range ob.meta.GetAll() {
for _, collectionID := range ob.meta.GetAll(ctx) {
ob.init(ctx, collectionID)
}
log.Info("target observer init done")
case <-ticker.C:
ob.clean()
loaded := lo.FilterMap(ob.meta.GetAllCollections(), func(collection *meta.Collection, _ int) (int64, bool) {
loaded := lo.FilterMap(ob.meta.GetAllCollections(ctx), func(collection *meta.Collection, _ int) (int64, bool) {
if collection.GetStatus() == querypb.LoadStatus_Loaded {
return collection.GetCollectionID(), true
}
@ -192,7 +192,7 @@ func (ob *TargetObserver) schedule(ctx context.Context) {
switch req.opType {
case UpdateCollection:
ob.keylocks.Lock(req.CollectionID)
err := ob.updateNextTarget(req.CollectionID)
err := ob.updateNextTarget(ctx, req.CollectionID)
ob.keylocks.Unlock(req.CollectionID)
if err != nil {
log.Warn("failed to manually update next target",
@ -214,10 +214,10 @@ func (ob *TargetObserver) schedule(ctx context.Context) {
delete(ob.readyNotifiers, req.CollectionID)
ob.mut.Unlock()
ob.targetMgr.RemoveCollection(req.CollectionID)
ob.targetMgr.RemoveCollection(ctx, req.CollectionID)
req.Notifier <- nil
case ReleasePartition:
ob.targetMgr.RemovePartition(req.CollectionID, req.PartitionIDs...)
ob.targetMgr.RemovePartition(ctx, req.CollectionID, req.PartitionIDs...)
req.Notifier <- nil
}
log.Info("manually trigger update target done",
@ -230,7 +230,7 @@ func (ob *TargetObserver) schedule(ctx context.Context) {
// Check whether the provided collection has a current target.
// If not, submit an async task into dispatcher.
func (ob *TargetObserver) Check(ctx context.Context, collectionID int64, partitionID int64) bool {
result := ob.targetMgr.IsCurrentTargetExist(collectionID, partitionID)
result := ob.targetMgr.IsCurrentTargetExist(ctx, collectionID, partitionID)
if !result {
ob.loadingDispatcher.AddTask(collectionID)
}
@ -246,24 +246,24 @@ func (ob *TargetObserver) check(ctx context.Context, collectionID int64) {
defer ob.keylocks.Unlock(collectionID)
if ob.shouldUpdateCurrentTarget(ctx, collectionID) {
ob.updateCurrentTarget(collectionID)
ob.updateCurrentTarget(ctx, collectionID)
}
if ob.shouldUpdateNextTarget(collectionID) {
if ob.shouldUpdateNextTarget(ctx, collectionID) {
// update next target in collection level
ob.updateNextTarget(collectionID)
ob.updateNextTarget(ctx, collectionID)
}
}
func (ob *TargetObserver) init(ctx context.Context, collectionID int64) {
// pull the next target first if it does not exist
if !ob.targetMgr.IsNextTargetExist(collectionID) {
ob.updateNextTarget(collectionID)
if !ob.targetMgr.IsNextTargetExist(ctx, collectionID) {
ob.updateNextTarget(ctx, collectionID)
}
// try to update current target if all segment/channel are ready
if ob.shouldUpdateCurrentTarget(ctx, collectionID) {
ob.updateCurrentTarget(collectionID)
ob.updateCurrentTarget(ctx, collectionID)
}
// refresh collection loading status upon restart
ob.check(ctx, collectionID)
@ -310,7 +310,7 @@ func (ob *TargetObserver) ReleasePartition(collectionID int64, partitionID ...in
}
func (ob *TargetObserver) clean() {
collectionSet := typeutil.NewUniqueSet(ob.meta.GetAll()...)
collectionSet := typeutil.NewUniqueSet(ob.meta.GetAll(context.TODO())...)
// for collections that have been removed from the target, try to clear nextTargetLastUpdate
ob.nextTargetLastUpdate.Range(func(collectionID int64, _ time.Time) bool {
if !collectionSet.Contain(collectionID) {
@ -331,8 +331,8 @@ func (ob *TargetObserver) clean() {
}
}
func (ob *TargetObserver) shouldUpdateNextTarget(collectionID int64) bool {
return !ob.targetMgr.IsNextTargetExist(collectionID) || ob.isNextTargetExpired(collectionID)
func (ob *TargetObserver) shouldUpdateNextTarget(ctx context.Context, collectionID int64) bool {
return !ob.targetMgr.IsNextTargetExist(ctx, collectionID) || ob.isNextTargetExpired(collectionID)
}
func (ob *TargetObserver) isNextTargetExpired(collectionID int64) bool {
@ -343,12 +343,12 @@ func (ob *TargetObserver) isNextTargetExpired(collectionID int64) bool {
return time.Since(lastUpdated) > params.Params.QueryCoordCfg.NextTargetSurviveTime.GetAsDuration(time.Second)
}
func (ob *TargetObserver) updateNextTarget(collectionID int64) error {
func (ob *TargetObserver) updateNextTarget(ctx context.Context, collectionID int64) error {
log := log.Ctx(context.TODO()).WithRateGroup("qcv2.TargetObserver", 1, 60).
With(zap.Int64("collectionID", collectionID))
log.RatedInfo(10, "observer trigger update next target")
err := ob.targetMgr.UpdateCollectionNextTarget(collectionID)
err := ob.targetMgr.UpdateCollectionNextTarget(ctx, collectionID)
if err != nil {
log.Warn("failed to update next target for collection",
zap.Error(err))
@ -363,7 +363,7 @@ func (ob *TargetObserver) updateNextTargetTimestamp(collectionID int64) {
}
func (ob *TargetObserver) shouldUpdateCurrentTarget(ctx context.Context, collectionID int64) bool {
replicaNum := ob.meta.CollectionManager.GetReplicaNumber(collectionID)
replicaNum := ob.meta.CollectionManager.GetReplicaNumber(ctx, collectionID)
log := log.Ctx(ctx).WithRateGroup(
fmt.Sprintf("qcv2.TargetObserver-%d", collectionID),
10,
@ -374,7 +374,7 @@ func (ob *TargetObserver) shouldUpdateCurrentTarget(ctx context.Context, collect
)
// check channel first
channelNames := ob.targetMgr.GetDmChannelsByCollection(collectionID, meta.NextTarget)
channelNames := ob.targetMgr.GetDmChannelsByCollection(ctx, collectionID, meta.NextTarget)
if len(channelNames) == 0 {
// next target is empty, no need to update
log.RatedInfo(10, "next target is empty, no need to update")
@ -402,13 +402,13 @@ func (ob *TargetObserver) shouldUpdateCurrentTarget(ctx context.Context, collect
var partitions []int64
var indexInfo []*indexpb.IndexInfo
var err error
newVersion := ob.targetMgr.GetCollectionTargetVersion(collectionID, meta.NextTarget)
newVersion := ob.targetMgr.GetCollectionTargetVersion(ctx, collectionID, meta.NextTarget)
for _, leader := range collectionReadyLeaders {
updateVersionAction := ob.checkNeedUpdateTargetVersion(ctx, leader, newVersion)
if updateVersionAction == nil {
continue
}
replica := ob.meta.ReplicaManager.GetByCollectionAndNode(collectionID, leader.ID)
replica := ob.meta.ReplicaManager.GetByCollectionAndNode(ctx, collectionID, leader.ID)
if replica == nil {
log.Warn("replica not found", zap.Int64("nodeID", leader.ID), zap.Int64("collectionID", collectionID))
continue
@ -422,7 +422,7 @@ func (ob *TargetObserver) shouldUpdateCurrentTarget(ctx context.Context, collect
return false
}
partitions, err = utils.GetPartitions(ob.meta.CollectionManager, collectionID)
partitions, err = utils.GetPartitions(ctx, ob.meta.CollectionManager, collectionID)
if err != nil {
log.Warn("failed to get partitions", zap.Error(err))
return false
@ -467,7 +467,7 @@ func (ob *TargetObserver) sync(ctx context.Context, replica *meta.Replica, leade
Actions: diffs,
Schema: collectionInfo.GetSchema(),
LoadMeta: &querypb.LoadMetaInfo{
LoadType: ob.meta.GetLoadType(leaderView.CollectionID),
LoadType: ob.meta.GetLoadType(ctx, leaderView.CollectionID),
CollectionID: leaderView.CollectionID,
PartitionIDs: partitions,
DbName: collectionInfo.GetDbName(),
@ -506,10 +506,10 @@ func (ob *TargetObserver) checkNeedUpdateTargetVersion(ctx context.Context, lead
zap.Int64("newVersion", targetVersion),
)
sealedSegments := ob.targetMgr.GetSealedSegmentsByChannel(leaderView.CollectionID, leaderView.Channel, meta.NextTarget)
growingSegments := ob.targetMgr.GetGrowingSegmentsByChannel(leaderView.CollectionID, leaderView.Channel, meta.NextTarget)
droppedSegments := ob.targetMgr.GetDroppedSegmentsByChannel(leaderView.CollectionID, leaderView.Channel, meta.NextTarget)
channel := ob.targetMgr.GetDmChannel(leaderView.CollectionID, leaderView.Channel, meta.NextTargetFirst)
sealedSegments := ob.targetMgr.GetSealedSegmentsByChannel(ctx, leaderView.CollectionID, leaderView.Channel, meta.NextTarget)
growingSegments := ob.targetMgr.GetGrowingSegmentsByChannel(ctx, leaderView.CollectionID, leaderView.Channel, meta.NextTarget)
droppedSegments := ob.targetMgr.GetDroppedSegmentsByChannel(ctx, leaderView.CollectionID, leaderView.Channel, meta.NextTarget)
channel := ob.targetMgr.GetDmChannel(ctx, leaderView.CollectionID, leaderView.Channel, meta.NextTargetFirst)
action := &querypb.SyncAction{
Type: querypb.SyncType_UpdateVersion,
@ -526,10 +526,10 @@ func (ob *TargetObserver) checkNeedUpdateTargetVersion(ctx context.Context, lead
return action
}
func (ob *TargetObserver) updateCurrentTarget(collectionID int64) {
log := log.Ctx(context.TODO()).WithRateGroup("qcv2.TargetObserver", 1, 60)
func (ob *TargetObserver) updateCurrentTarget(ctx context.Context, collectionID int64) {
log := log.Ctx(ctx).WithRateGroup("qcv2.TargetObserver", 1, 60)
log.RatedInfo(10, "observer trigger update current target", zap.Int64("collectionID", collectionID))
if ob.targetMgr.UpdateCollectionCurrentTarget(collectionID) {
if ob.targetMgr.UpdateCollectionCurrentTarget(ctx, collectionID) {
ob.mut.Lock()
defer ob.mut.Unlock()
notifiers := ob.readyNotifiers[collectionID]
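
As a hedged illustration of the two situations in the target-observer hunks above: methods that now receive a ctx simply forward it, while a caller that has no context of its own (such as clean) falls back to context.TODO(). The observer and targetManager types below are simplified placeholders.

package main

import (
	"context"
	"fmt"
)

// targetManager stands in for meta.TargetManager after the refactor:
// every method takes ctx as its first argument.
type targetManager struct{}

func (t *targetManager) UpdateCollectionNextTarget(ctx context.Context, collectionID int64) error {
	// a real implementation would pull recovery info from the broker with this ctx
	fmt.Println("update next target for collection", collectionID)
	return nil
}

func (t *targetManager) RemoveCollection(ctx context.Context, collectionID int64) {
	fmt.Println("remove collection", collectionID)
}

type observer struct{ targetMgr *targetManager }

// updateNextTarget now accepts ctx and forwards it unchanged.
func (o *observer) updateNextTarget(ctx context.Context, collectionID int64) error {
	return o.targetMgr.UpdateCollectionNextTarget(ctx, collectionID)
}

// clean runs on a timer without a request-scoped context,
// so it uses context.TODO(), mirroring the clean() hunk above.
func (o *observer) clean() {
	o.targetMgr.RemoveCollection(context.TODO(), 1001)
}

func main() {
	o := &observer{targetMgr: &targetManager{}}
	_ = o.updateNextTarget(context.Background(), 1001)
	o.clean()
}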

View File

@ -57,6 +57,7 @@ type TargetObserverSuite struct {
partitionID int64
nextTargetSegments []*datapb.SegmentInfo
nextTargetChannels []*datapb.VchannelInfo
ctx context.Context
}
func (suite *TargetObserverSuite) SetupSuite() {
@ -77,6 +78,7 @@ func (suite *TargetObserverSuite) SetupTest() {
config.EtcdTLSMinVersion.GetValue())
suite.Require().NoError(err)
suite.kv = etcdkv.NewEtcdKV(cli, config.MetaRootPath.GetValue())
suite.ctx = context.Background()
// meta
nodeMgr := session.NewNodeManager()
@ -100,14 +102,14 @@ func (suite *TargetObserverSuite) SetupTest() {
testCollection := utils.CreateTestCollection(suite.collectionID, 1)
testCollection.Status = querypb.LoadStatus_Loaded
err = suite.meta.CollectionManager.PutCollection(testCollection)
err = suite.meta.CollectionManager.PutCollection(suite.ctx, testCollection)
suite.NoError(err)
err = suite.meta.CollectionManager.PutPartition(utils.CreateTestPartition(suite.collectionID, suite.partitionID))
err = suite.meta.CollectionManager.PutPartition(suite.ctx, utils.CreateTestPartition(suite.collectionID, suite.partitionID))
suite.NoError(err)
replicas, err := suite.meta.ReplicaManager.Spawn(suite.collectionID, map[string]int{meta.DefaultResourceGroupName: 1}, nil)
replicas, err := suite.meta.ReplicaManager.Spawn(suite.ctx, suite.collectionID, map[string]int{meta.DefaultResourceGroupName: 1}, nil)
suite.NoError(err)
replicas[0].AddRWNode(2)
err = suite.meta.ReplicaManager.Put(replicas...)
err = suite.meta.ReplicaManager.Put(suite.ctx, replicas...)
suite.NoError(err)
suite.nextTargetChannels = []*datapb.VchannelInfo{
@ -140,9 +142,11 @@ func (suite *TargetObserverSuite) SetupTest() {
}
func (suite *TargetObserverSuite) TestTriggerUpdateTarget() {
ctx := suite.ctx
suite.Eventually(func() bool {
return len(suite.targetMgr.GetSealedSegmentsByCollection(suite.collectionID, meta.NextTarget)) == 2 &&
len(suite.targetMgr.GetDmChannelsByCollection(suite.collectionID, meta.NextTarget)) == 2
return len(suite.targetMgr.GetSealedSegmentsByCollection(ctx, suite.collectionID, meta.NextTarget)) == 2 &&
len(suite.targetMgr.GetDmChannelsByCollection(ctx, suite.collectionID, meta.NextTarget)) == 2
}, 5*time.Second, 1*time.Second)
suite.distMgr.LeaderViewManager.Update(2,
@ -166,7 +170,7 @@ func (suite *TargetObserverSuite) TestTriggerUpdateTarget() {
// Never update the current target if it's empty, even if the next target is ready
suite.Eventually(func() bool {
return len(suite.targetMgr.GetDmChannelsByCollection(suite.collectionID, meta.CurrentTarget)) == 0
return len(suite.targetMgr.GetDmChannelsByCollection(ctx, suite.collectionID, meta.CurrentTarget)) == 0
}, 3*time.Second, 1*time.Second)
suite.broker.AssertExpectations(suite.T())
@ -176,7 +180,7 @@ func (suite *TargetObserverSuite) TestTriggerUpdateTarget() {
PartitionID: suite.partitionID,
InsertChannel: "channel-1",
})
suite.targetMgr.UpdateCollectionCurrentTarget(suite.collectionID)
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, suite.collectionID)
// Pull next again
suite.broker.EXPECT().
@ -184,8 +188,8 @@ func (suite *TargetObserverSuite) TestTriggerUpdateTarget() {
Return(suite.nextTargetChannels, suite.nextTargetSegments, nil)
suite.Eventually(func() bool {
return len(suite.targetMgr.GetSealedSegmentsByCollection(suite.collectionID, meta.NextTarget)) == 3 &&
len(suite.targetMgr.GetDmChannelsByCollection(suite.collectionID, meta.NextTarget)) == 2
return len(suite.targetMgr.GetSealedSegmentsByCollection(ctx, suite.collectionID, meta.NextTarget)) == 3 &&
len(suite.targetMgr.GetDmChannelsByCollection(ctx, suite.collectionID, meta.NextTarget)) == 2
}, 7*time.Second, 1*time.Second)
suite.broker.AssertExpectations(suite.T())
@ -226,18 +230,19 @@ func (suite *TargetObserverSuite) TestTriggerUpdateTarget() {
default:
}
return isReady &&
len(suite.targetMgr.GetSealedSegmentsByCollection(suite.collectionID, meta.CurrentTarget)) == 3 &&
len(suite.targetMgr.GetDmChannelsByCollection(suite.collectionID, meta.CurrentTarget)) == 2
len(suite.targetMgr.GetSealedSegmentsByCollection(ctx, suite.collectionID, meta.CurrentTarget)) == 3 &&
len(suite.targetMgr.GetDmChannelsByCollection(ctx, suite.collectionID, meta.CurrentTarget)) == 2
}, 7*time.Second, 1*time.Second)
}
func (suite *TargetObserverSuite) TestTriggerRelease() {
ctx := suite.ctx
// Manually update next target
_, err := suite.observer.UpdateNextTarget(suite.collectionID)
suite.NoError(err)
// manually release partition
partitions := suite.meta.CollectionManager.GetPartitionsByCollection(suite.collectionID)
partitions := suite.meta.CollectionManager.GetPartitionsByCollection(ctx, suite.collectionID)
partitionIDs := lo.Map(partitions, func(partition *meta.Partition, _ int) int64 { return partition.PartitionID })
suite.observer.ReleasePartition(suite.collectionID, partitionIDs[0])
@ -265,6 +270,7 @@ type TargetObserverCheckSuite struct {
collectionID int64
partitionID int64
ctx context.Context
}
func (suite *TargetObserverCheckSuite) SetupSuite() {
@ -284,6 +290,7 @@ func (suite *TargetObserverCheckSuite) SetupTest() {
config.EtcdTLSMinVersion.GetValue())
suite.Require().NoError(err)
suite.kv = etcdkv.NewEtcdKV(cli, config.MetaRootPath.GetValue())
suite.ctx = context.Background()
// meta
store := querycoord.NewCatalog(suite.kv)
@ -306,14 +313,14 @@ func (suite *TargetObserverCheckSuite) SetupTest() {
suite.collectionID = int64(1000)
suite.partitionID = int64(100)
err = suite.meta.CollectionManager.PutCollection(utils.CreateTestCollection(suite.collectionID, 1))
err = suite.meta.CollectionManager.PutCollection(suite.ctx, utils.CreateTestCollection(suite.collectionID, 1))
suite.NoError(err)
err = suite.meta.CollectionManager.PutPartition(utils.CreateTestPartition(suite.collectionID, suite.partitionID))
err = suite.meta.CollectionManager.PutPartition(suite.ctx, utils.CreateTestPartition(suite.collectionID, suite.partitionID))
suite.NoError(err)
replicas, err := suite.meta.ReplicaManager.Spawn(suite.collectionID, map[string]int{meta.DefaultResourceGroupName: 1}, nil)
replicas, err := suite.meta.ReplicaManager.Spawn(suite.ctx, suite.collectionID, map[string]int{meta.DefaultResourceGroupName: 1}, nil)
suite.NoError(err)
replicas[0].AddRWNode(2)
err = suite.meta.ReplicaManager.Put(replicas...)
err = suite.meta.ReplicaManager.Put(suite.ctx, replicas...)
suite.NoError(err)
}
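
The suite changes above (a new ctx field initialized in SetupTest) follow a common testify pattern; the sketch below assumes github.com/stretchr/testify is available and uses placeholder assertions.

package example

import (
	"context"
	"testing"

	"github.com/stretchr/testify/suite"
)

// ObserverSuite mirrors the hunks above: a ctx field created once per test
// and reused as the first argument of every meta/manager call.
type ObserverSuite struct {
	suite.Suite
	ctx context.Context
}

func (s *ObserverSuite) SetupTest() {
	s.ctx = context.Background()
}

func (s *ObserverSuite) TestUsesSuiteContext() {
	// a real test would call e.g. s.meta.CollectionManager.PutCollection(s.ctx, ...)
	s.NotNil(s.ctx)
}

func TestObserverSuite(t *testing.T) {
	suite.Run(t, new(ObserverSuite))
}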

View File

@ -440,8 +440,8 @@ func (suite *OpsServiceSuite) TestSuspendAndResumeNode() {
Address: "localhost",
Hostname: "localhost",
}))
suite.meta.ResourceManager.HandleNodeUp(1)
nodes, err := suite.meta.ResourceManager.GetNodes(meta.DefaultResourceGroupName)
suite.meta.ResourceManager.HandleNodeUp(ctx, 1)
nodes, err := suite.meta.ResourceManager.GetNodes(ctx, meta.DefaultResourceGroupName)
suite.NoError(err)
suite.Contains(nodes, int64(1))
// test success
@ -451,7 +451,7 @@ func (suite *OpsServiceSuite) TestSuspendAndResumeNode() {
})
suite.NoError(err)
suite.True(merr.Ok(resp))
nodes, err = suite.meta.ResourceManager.GetNodes(meta.DefaultResourceGroupName)
nodes, err = suite.meta.ResourceManager.GetNodes(ctx, meta.DefaultResourceGroupName)
suite.NoError(err)
suite.NotContains(nodes, int64(1))
@ -460,7 +460,7 @@ func (suite *OpsServiceSuite) TestSuspendAndResumeNode() {
})
suite.NoError(err)
suite.True(merr.Ok(resp))
nodes, err = suite.meta.ResourceManager.GetNodes(meta.DefaultResourceGroupName)
nodes, err = suite.meta.ResourceManager.GetNodes(ctx, meta.DefaultResourceGroupName)
suite.NoError(err)
suite.Contains(nodes, int64(1))
}
@ -492,10 +492,10 @@ func (suite *OpsServiceSuite) TestTransferSegment() {
replicaID := int64(1)
nodes := []int64{1, 2, 3, 4}
replica := utils.CreateTestReplica(replicaID, collectionID, nodes)
suite.meta.ReplicaManager.Put(replica)
suite.meta.ReplicaManager.Put(ctx, replica)
collection := utils.CreateTestCollection(collectionID, 1)
partition := utils.CreateTestPartition(partitionID, collectionID)
suite.meta.PutCollection(collection, partition)
suite.meta.PutCollection(ctx, collection, partition)
segmentIDs := []int64{1, 2, 3, 4}
channelNames := []string{"channel-1", "channel-2", "channel-3", "channel-4"}
@ -594,8 +594,8 @@ func (suite *OpsServiceSuite) TestTransferSegment() {
suite.True(merr.Ok(resp))
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(channels, segments, nil)
suite.targetMgr.UpdateCollectionNextTarget(1)
suite.targetMgr.UpdateCollectionCurrentTarget(1)
suite.targetMgr.UpdateCollectionNextTarget(ctx, 1)
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, 1)
suite.dist.SegmentDistManager.Update(1, segmentInfos...)
suite.dist.ChannelDistManager.Update(1, chanenlInfos...)
@ -605,7 +605,7 @@ func (suite *OpsServiceSuite) TestTransferSegment() {
Address: "localhost",
Hostname: "localhost",
}))
suite.meta.ResourceManager.HandleNodeUp(node)
suite.meta.ResourceManager.HandleNodeUp(ctx, node)
}
// test transfer segment success, expect generate 1 balance segment task
@ -741,10 +741,10 @@ func (suite *OpsServiceSuite) TestTransferChannel() {
replicaID := int64(1)
nodes := []int64{1, 2, 3, 4}
replica := utils.CreateTestReplica(replicaID, collectionID, nodes)
suite.meta.ReplicaManager.Put(replica)
suite.meta.ReplicaManager.Put(ctx, replica)
collection := utils.CreateTestCollection(collectionID, 1)
partition := utils.CreateTestPartition(partitionID, collectionID)
suite.meta.PutCollection(collection, partition)
suite.meta.PutCollection(ctx, collection, partition)
segmentIDs := []int64{1, 2, 3, 4}
channelNames := []string{"channel-1", "channel-2", "channel-3", "channel-4"}
@ -845,8 +845,8 @@ func (suite *OpsServiceSuite) TestTransferChannel() {
suite.True(merr.Ok(resp))
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(channels, segments, nil)
suite.targetMgr.UpdateCollectionNextTarget(1)
suite.targetMgr.UpdateCollectionCurrentTarget(1)
suite.targetMgr.UpdateCollectionNextTarget(ctx, 1)
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, 1)
suite.dist.SegmentDistManager.Update(1, segmentInfos...)
suite.dist.ChannelDistManager.Update(1, chanenlInfos...)
@ -856,7 +856,7 @@ func (suite *OpsServiceSuite) TestTransferChannel() {
Address: "localhost",
Hostname: "localhost",
}))
suite.meta.ResourceManager.HandleNodeUp(node)
suite.meta.ResourceManager.HandleNodeUp(ctx, node)
}
// test transfer channel success, expect generate 1 balance channel task

View File

@ -212,7 +212,7 @@ func (s *Server) SuspendNode(ctx context.Context, req *querypb.SuspendNodeReques
return merr.Status(err), nil
}
s.meta.ResourceManager.HandleNodeDown(req.GetNodeID())
s.meta.ResourceManager.HandleNodeDown(ctx, req.GetNodeID())
return merr.Success(), nil
}
@ -233,7 +233,7 @@ func (s *Server) ResumeNode(ctx context.Context, req *querypb.ResumeNodeRequest)
return merr.Status(err), nil
}
s.meta.ResourceManager.HandleNodeUp(req.GetNodeID())
s.meta.ResourceManager.HandleNodeUp(ctx, req.GetNodeID())
return merr.Success(), nil
}
@ -262,7 +262,7 @@ func (s *Server) TransferSegment(ctx context.Context, req *querypb.TransferSegme
return merr.Status(err), nil
}
replicas := s.meta.ReplicaManager.GetByNode(req.GetSourceNodeID())
replicas := s.meta.ReplicaManager.GetByNode(ctx, req.GetSourceNodeID())
for _, replica := range replicas {
// when no dst node is specified, default to all other nodes in the same replica
dstNodeSet := typeutil.NewUniqueSet()
@ -292,7 +292,7 @@ func (s *Server) TransferSegment(ctx context.Context, req *querypb.TransferSegme
return merr.Status(err), nil
}
existInTarget := s.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil
existInTarget := s.targetMgr.GetSealedSegment(ctx, segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil
if !existInTarget {
log.Info("segment doesn't exist in current target, skip it", zap.Int64("segmentID", req.GetSegmentID()))
} else {
@ -334,7 +334,7 @@ func (s *Server) TransferChannel(ctx context.Context, req *querypb.TransferChann
return merr.Status(err), nil
}
replicas := s.meta.ReplicaManager.GetByNode(req.GetSourceNodeID())
replicas := s.meta.ReplicaManager.GetByNode(ctx, req.GetSourceNodeID())
for _, replica := range replicas {
// when no dst node is specified, default to all other nodes in the same replica
dstNodeSet := typeutil.NewUniqueSet()
@ -362,7 +362,7 @@ func (s *Server) TransferChannel(ctx context.Context, req *querypb.TransferChann
err := merr.WrapErrChannelNotFound(req.GetChannelName(), "channel not found in source node")
return merr.Status(err), nil
}
existInTarget := s.targetMgr.GetDmChannel(channel.GetCollectionID(), channel.GetChannelName(), meta.CurrentTarget) != nil
existInTarget := s.targetMgr.GetDmChannel(ctx, channel.GetCollectionID(), channel.GetChannelName(), meta.CurrentTarget) != nil
if !existInTarget {
log.Info("channel doesn't exist in current target, skip it", zap.String("channelName", channel.GetChannelName()))
} else {
@ -414,7 +414,7 @@ func (s *Server) CheckQueryNodeDistribution(ctx context.Context, req *querypb.Ch
return ch.GetChannelName(), ch
})
for _, ch := range channelOnSrc {
if s.targetMgr.GetDmChannel(ch.GetCollectionID(), ch.GetChannelName(), meta.CurrentTargetFirst) == nil {
if s.targetMgr.GetDmChannel(ctx, ch.GetCollectionID(), ch.GetChannelName(), meta.CurrentTargetFirst) == nil {
continue
}
@ -430,7 +430,7 @@ func (s *Server) CheckQueryNodeDistribution(ctx context.Context, req *querypb.Ch
return s.GetID(), s
})
for _, segment := range segmentOnSrc {
if s.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTargetFirst) == nil {
if s.targetMgr.GetSealedSegment(ctx, segment.GetCollectionID(), segment.GetID(), meta.CurrentTargetFirst) == nil {
continue
}
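
A minimal sketch of the handler-side pattern in the ops-service hunks: gRPC handlers already receive a ctx, and after the refactor they forward it into replica and target lookups instead of dropping it. The types below are simplified stand-ins for the real querycoord managers.

package main

import (
	"context"
	"fmt"
)

type replica struct {
	id           int64
	collectionID int64
}

type replicaManager struct{ replicas []replica }

// GetByNode now takes ctx first, matching the refactored ReplicaManager interface.
func (m *replicaManager) GetByNode(ctx context.Context, nodeID int64) []replica {
	return m.replicas
}

type server struct{ replicaMgr *replicaManager }

// A TransferSegment-like handler: the ctx from the RPC is forwarded
// to every meta lookup it performs.
func (s *server) TransferSegment(ctx context.Context, sourceNode int64) error {
	replicas := s.replicaMgr.GetByNode(ctx, sourceNode)
	fmt.Println("replicas on source node:", len(replicas))
	return nil
}

func main() {
	s := &server{replicaMgr: &replicaManager{replicas: []replica{{id: 1, collectionID: 100}}}}
	_ = s.TransferSegment(context.Background(), 1)
}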

View File

@ -209,15 +209,15 @@ func (s *Server) registerMetricsRequest() {
if v.Exists() {
scope = meta.TargetScope(v.Int())
}
return s.targetMgr.GetTargetJSON(scope), nil
return s.targetMgr.GetTargetJSON(ctx, scope), nil
}
QueryReplicasAction := func(ctx context.Context, req *milvuspb.GetMetricsRequest, jsonReq gjson.Result) (string, error) {
return s.meta.GetReplicasJSON(), nil
return s.meta.GetReplicasJSON(ctx), nil
}
QueryResourceGroupsAction := func(ctx context.Context, req *milvuspb.GetMetricsRequest, jsonReq gjson.Result) (string, error) {
return s.meta.GetResourceGroupsJSON(), nil
return s.meta.GetResourceGroupsJSON(ctx), nil
}
QuerySegmentsAction := func(ctx context.Context, req *milvuspb.GetMetricsRequest, jsonReq gjson.Result) (string, error) {
@ -421,26 +421,26 @@ func (s *Server) initMeta() error {
)
log.Info("recover meta...")
err := s.meta.CollectionManager.Recover(s.broker)
err := s.meta.CollectionManager.Recover(s.ctx, s.broker)
if err != nil {
log.Warn("failed to recover collections", zap.Error(err))
return err
}
collections := s.meta.GetAll()
collections := s.meta.GetAll(s.ctx)
log.Info("recovering collections...", zap.Int64s("collections", collections))
// The metric is actually updated only after the observers consider the collection loaded.
metrics.QueryCoordNumCollections.WithLabelValues().Set(0)
metrics.QueryCoordNumPartitions.WithLabelValues().Set(float64(len(s.meta.GetAllPartitions())))
metrics.QueryCoordNumPartitions.WithLabelValues().Set(float64(len(s.meta.GetAllPartitions(s.ctx))))
err = s.meta.ReplicaManager.Recover(collections)
err = s.meta.ReplicaManager.Recover(s.ctx, collections)
if err != nil {
log.Warn("failed to recover replicas", zap.Error(err))
return err
}
err = s.meta.ResourceManager.Recover()
err = s.meta.ResourceManager.Recover(s.ctx)
if err != nil {
log.Warn("failed to recover resource groups", zap.Error(err))
return err
@ -452,7 +452,7 @@ func (s *Server) initMeta() error {
LeaderViewManager: meta.NewLeaderViewManager(),
}
s.targetMgr = meta.NewTargetManager(s.broker, s.meta)
err = s.targetMgr.Recover(s.store)
err = s.targetMgr.Recover(s.ctx, s.store)
if err != nil {
log.Warn("failed to recover collection targets", zap.Error(err))
}
@ -609,7 +609,7 @@ func (s *Server) Stop() error {
// save the target to the meta store so that the current target can be recovered quickly after a querycoord restart
// the target must be saved after the target observer stops, in case the target changes
if s.targetMgr != nil {
s.targetMgr.SaveCurrentTarget(s.store)
s.targetMgr.SaveCurrentTarget(s.ctx, s.store)
}
if s.replicaObserver != nil {
@ -773,7 +773,7 @@ func (s *Server) watchNodes(revision int64) {
)
s.nodeMgr.Stopping(nodeID)
s.checkerController.Check()
s.meta.ResourceManager.HandleNodeStopping(nodeID)
s.meta.ResourceManager.HandleNodeStopping(s.ctx, nodeID)
case sessionutil.SessionDelEvent:
nodeID := event.Session.ServerID
@ -833,7 +833,7 @@ func (s *Server) handleNodeUp(node int64) {
s.taskScheduler.AddExecutor(node)
s.distController.StartDistInstance(s.ctx, node)
// needs to be assigned to a new rg and replica
s.meta.ResourceManager.HandleNodeUp(node)
s.meta.ResourceManager.HandleNodeUp(s.ctx, node)
}
func (s *Server) handleNodeDown(node int64) {
@ -848,18 +848,18 @@ func (s *Server) handleNodeDown(node int64) {
// Clear tasks
s.taskScheduler.RemoveByNode(node)
s.meta.ResourceManager.HandleNodeDown(node)
s.meta.ResourceManager.HandleNodeDown(s.ctx, node)
}
func (s *Server) checkNodeStateInRG() {
for _, rgName := range s.meta.ListResourceGroups() {
rg := s.meta.ResourceManager.GetResourceGroup(rgName)
for _, rgName := range s.meta.ListResourceGroups(s.ctx) {
rg := s.meta.ResourceManager.GetResourceGroup(s.ctx, rgName)
for _, node := range rg.GetNodes() {
info := s.nodeMgr.Get(node)
if info == nil {
s.meta.ResourceManager.HandleNodeDown(node)
s.meta.ResourceManager.HandleNodeDown(s.ctx, node)
} else if info.IsStoppingState() {
s.meta.ResourceManager.HandleNodeStopping(node)
s.meta.ResourceManager.HandleNodeStopping(s.ctx, node)
}
}
}
@ -917,7 +917,7 @@ func (s *Server) watchLoadConfigChanges() {
replicaNumHandler := config.NewHandler("watchReplicaNumberChanges", func(e *config.Event) {
log.Info("watch load config changes", zap.String("key", e.Key), zap.String("value", e.Value), zap.String("type", e.EventType))
collectionIDs := s.meta.GetAll()
collectionIDs := s.meta.GetAll(s.ctx)
if len(collectionIDs) == 0 {
log.Warn("no collection loaded, skip to trigger update load config")
return
@ -944,7 +944,7 @@ func (s *Server) watchLoadConfigChanges() {
rgHandler := config.NewHandler("watchResourceGroupChanges", func(e *config.Event) {
log.Info("watch load config changes", zap.String("key", e.Key), zap.String("value", e.Value), zap.String("type", e.EventType))
collectionIDs := s.meta.GetAll()
collectionIDs := s.meta.GetAll(s.ctx)
if len(collectionIDs) == 0 {
log.Warn("no collection loaded, skip to trigger update load config")
return
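
As a hedged sketch of the server-side choice visible above: request handlers and registered metrics actions forward their own ctx, while lifecycle paths with no incoming request (recovery, node-event handling) use the server's long-lived s.ctx. All names here are illustrative.

package main

import (
	"context"
	"fmt"
)

type resourceManager struct{}

func (r *resourceManager) Recover(ctx context.Context) error {
	fmt.Println("recovering resource groups")
	return nil
}

func (r *resourceManager) HandleNodeUp(ctx context.Context, nodeID int64) {
	fmt.Println("node up:", nodeID)
}

type server struct {
	ctx         context.Context
	resourceMgr *resourceManager
}

// initMeta has no request context, so it threads the server's own ctx,
// as the initMeta hunk above does with s.ctx.
func (s *server) initMeta() error {
	return s.resourceMgr.Recover(s.ctx)
}

// handleNodeUp is driven by session events rather than an RPC,
// so it also relies on s.ctx.
func (s *server) handleNodeUp(node int64) {
	s.resourceMgr.HandleNodeUp(s.ctx, node)
}

func main() {
	s := &server{ctx: context.Background(), resourceMgr: &resourceManager{}}
	if err := s.initMeta(); err != nil {
		fmt.Println("init failed:", err)
	}
	s.handleNodeUp(7)
}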

View File

@ -89,6 +89,7 @@ type ServerSuite struct {
tikvCli *txnkv.Client
server *Server
nodes []*mocks.MockQueryNode
ctx context.Context
}
var testMeta string
@ -125,6 +126,7 @@ func (suite *ServerSuite) SetupSuite() {
1001: 3,
}
suite.nodes = make([]*mocks.MockQueryNode, 3)
suite.ctx = context.Background()
}
func (suite *ServerSuite) SetupTest() {
@ -144,13 +146,13 @@ func (suite *ServerSuite) SetupTest() {
suite.Require().NoError(err)
ok := suite.waitNodeUp(suite.nodes[i], 5*time.Second)
suite.Require().True(ok)
suite.server.meta.ResourceManager.HandleNodeUp(suite.nodes[i].ID)
suite.server.meta.ResourceManager.HandleNodeUp(suite.ctx, suite.nodes[i].ID)
suite.expectLoadAndReleasePartitions(suite.nodes[i])
}
suite.loadAll()
for _, collection := range suite.collections {
suite.True(suite.server.meta.Exist(collection))
suite.True(suite.server.meta.Exist(suite.ctx, collection))
suite.updateCollectionStatus(collection, querypb.LoadStatus_Loaded)
}
}
@ -181,7 +183,7 @@ func (suite *ServerSuite) TestRecover() {
suite.NoError(err)
for _, collection := range suite.collections {
suite.True(suite.server.meta.Exist(collection))
suite.True(suite.server.meta.Exist(suite.ctx, collection))
}
suite.True(suite.server.nodeMgr.IsStoppingNode(suite.nodes[0].ID))
@ -201,7 +203,7 @@ func (suite *ServerSuite) TestNodeUp() {
return false
}
for _, collection := range suite.collections {
replica := suite.server.meta.ReplicaManager.GetByCollectionAndNode(collection, node1.ID)
replica := suite.server.meta.ReplicaManager.GetByCollectionAndNode(suite.ctx, collection, node1.ID)
if replica == nil {
return false
}
@ -230,7 +232,7 @@ func (suite *ServerSuite) TestNodeUp() {
return false
}
for _, collection := range suite.collections {
replica := suite.server.meta.ReplicaManager.GetByCollectionAndNode(collection, node2.ID)
replica := suite.server.meta.ReplicaManager.GetByCollectionAndNode(suite.ctx, collection, node2.ID)
if replica == nil {
return true
}
@ -249,7 +251,7 @@ func (suite *ServerSuite) TestNodeUp() {
return false
}
for _, collection := range suite.collections {
replica := suite.server.meta.ReplicaManager.GetByCollectionAndNode(collection, node2.ID)
replica := suite.server.meta.ReplicaManager.GetByCollectionAndNode(suite.ctx, collection, node2.ID)
if replica == nil {
return false
}
@ -279,7 +281,7 @@ func (suite *ServerSuite) TestNodeDown() {
return false
}
for _, collection := range suite.collections {
replica := suite.server.meta.ReplicaManager.GetByCollectionAndNode(collection, downNode.ID)
replica := suite.server.meta.ReplicaManager.GetByCollectionAndNode(suite.ctx, collection, downNode.ID)
if replica != nil {
return false
}
@ -525,7 +527,7 @@ func (suite *ServerSuite) expectGetRecoverInfoByMockDataCoord(collection int64,
}
func (suite *ServerSuite) updateCollectionStatus(collectionID int64, status querypb.LoadStatus) {
collection := suite.server.meta.GetCollection(collectionID)
collection := suite.server.meta.GetCollection(suite.ctx, collectionID)
if collection != nil {
collection := collection.Clone()
collection.LoadPercentage = 0
@ -533,9 +535,9 @@ func (suite *ServerSuite) updateCollectionStatus(collectionID int64, status quer
collection.LoadPercentage = 100
}
collection.CollectionLoadInfo.Status = status
suite.server.meta.PutCollection(collection)
suite.server.meta.PutCollection(suite.ctx, collection)
partitions := suite.server.meta.GetPartitionsByCollection(collectionID)
partitions := suite.server.meta.GetPartitionsByCollection(suite.ctx, collectionID)
for _, partition := range partitions {
partition := partition.Clone()
partition.LoadPercentage = 0
@ -543,7 +545,7 @@ func (suite *ServerSuite) updateCollectionStatus(collectionID int64, status quer
partition.LoadPercentage = 100
}
partition.PartitionLoadInfo.Status = status
suite.server.meta.PutPartition(partition)
suite.server.meta.PutPartition(suite.ctx, partition)
}
}
}

View File

@ -70,7 +70,7 @@ func (s *Server) ShowCollections(ctx context.Context, req *querypb.ShowCollectio
isGetAll := false
collectionSet := typeutil.NewUniqueSet(req.GetCollectionIDs()...)
if len(req.GetCollectionIDs()) == 0 {
for _, collection := range s.meta.GetAllCollections() {
for _, collection := range s.meta.GetAllCollections(ctx) {
collectionSet.Insert(collection.GetCollectionID())
}
isGetAll = true
@ -86,9 +86,9 @@ func (s *Server) ShowCollections(ctx context.Context, req *querypb.ShowCollectio
for _, collectionID := range collections {
log := log.With(zap.Int64("collectionID", collectionID))
collection := s.meta.CollectionManager.GetCollection(collectionID)
percentage := s.meta.CollectionManager.CalculateLoadPercentage(collectionID)
loadFields := s.meta.CollectionManager.GetLoadFields(collectionID)
collection := s.meta.CollectionManager.GetCollection(ctx, collectionID)
percentage := s.meta.CollectionManager.CalculateLoadPercentage(ctx, collectionID)
loadFields := s.meta.CollectionManager.GetLoadFields(ctx, collectionID)
refreshProgress := int64(0)
if percentage < 0 {
if isGetAll {
@ -150,13 +150,13 @@ func (s *Server) ShowPartitions(ctx context.Context, req *querypb.ShowPartitions
refreshProgress := int64(0)
if len(partitions) == 0 {
partitions = lo.Map(s.meta.GetPartitionsByCollection(req.GetCollectionID()), func(partition *meta.Partition, _ int) int64 {
partitions = lo.Map(s.meta.GetPartitionsByCollection(ctx, req.GetCollectionID()), func(partition *meta.Partition, _ int) int64 {
return partition.GetPartitionID()
})
}
for _, partitionID := range partitions {
percentage := s.meta.GetPartitionLoadPercentage(partitionID)
percentage := s.meta.GetPartitionLoadPercentage(ctx, partitionID)
if percentage < 0 {
err := meta.GlobalFailedLoadCache.Get(req.GetCollectionID())
if err != nil {
@ -177,7 +177,7 @@ func (s *Server) ShowPartitions(ctx context.Context, req *querypb.ShowPartitions
percentages = append(percentages, int64(percentage))
}
collection := s.meta.GetCollection(req.GetCollectionID())
collection := s.meta.GetCollection(ctx, req.GetCollectionID())
if collection != nil && collection.IsRefreshed() {
refreshProgress = 100
}
@ -217,7 +217,7 @@ func (s *Server) LoadCollection(ctx context.Context, req *querypb.LoadCollection
// If refresh mode is ON.
if req.GetRefresh() {
err := s.refreshCollection(req.GetCollectionID())
err := s.refreshCollection(ctx, req.GetCollectionID())
if err != nil {
log.Warn("failed to refresh collection", zap.Error(err))
}
@ -253,11 +253,11 @@ func (s *Server) LoadCollection(ctx context.Context, req *querypb.LoadCollection
}
var loadJob job.Job
collection := s.meta.GetCollection(req.GetCollectionID())
collection := s.meta.GetCollection(ctx, req.GetCollectionID())
if collection != nil && collection.GetStatus() == querypb.LoadStatus_Loaded {
// if collection is loaded, check if collection is loaded with the same replica number and resource groups
// if replica number or resource group changes switch to update load config
collectionUsedRG := s.meta.ReplicaManager.GetResourceGroupByCollection(collection.GetCollectionID()).Collect()
collectionUsedRG := s.meta.ReplicaManager.GetResourceGroupByCollection(ctx, collection.GetCollectionID()).Collect()
left, right := lo.Difference(collectionUsedRG, req.GetResourceGroups())
rgChanged := len(left) > 0 || len(right) > 0
replicaChanged := collection.GetReplicaNumber() != req.GetReplicaNumber()
@ -372,7 +372,7 @@ func (s *Server) LoadPartitions(ctx context.Context, req *querypb.LoadPartitions
// If refresh mode is ON.
if req.GetRefresh() {
err := s.refreshCollection(req.GetCollectionID())
err := s.refreshCollection(ctx, req.GetCollectionID())
if err != nil {
log.Warn("failed to refresh partitions", zap.Error(err))
}
@ -494,9 +494,9 @@ func (s *Server) GetPartitionStates(ctx context.Context, req *querypb.GetPartiti
}
states := make([]*querypb.PartitionStates, 0, len(req.GetPartitionIDs()))
switch s.meta.GetLoadType(req.GetCollectionID()) {
switch s.meta.GetLoadType(ctx, req.GetCollectionID()) {
case querypb.LoadType_LoadCollection:
collection := s.meta.GetCollection(req.GetCollectionID())
collection := s.meta.GetCollection(ctx, req.GetCollectionID())
state := querypb.PartitionState_PartialInMemory
if collection.LoadPercentage >= 100 {
state = querypb.PartitionState_InMemory
@ -515,7 +515,7 @@ func (s *Server) GetPartitionStates(ctx context.Context, req *querypb.GetPartiti
case querypb.LoadType_LoadPartition:
for _, partitionID := range req.GetPartitionIDs() {
partition := s.meta.GetPartition(partitionID)
partition := s.meta.GetPartition(ctx, partitionID)
if partition == nil {
log.Warn(msg, zap.Int64("partition", partitionID))
return notLoadResp, nil
@ -558,7 +558,7 @@ func (s *Server) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfo
infos := make([]*querypb.SegmentInfo, 0, len(req.GetSegmentIDs()))
if len(req.GetSegmentIDs()) == 0 {
infos = s.getCollectionSegmentInfo(req.GetCollectionID())
infos = s.getCollectionSegmentInfo(ctx, req.GetCollectionID())
} else {
for _, segmentID := range req.GetSegmentIDs() {
segments := s.dist.SegmentDistManager.GetByFilter(meta.WithSegmentID(segmentID))
@ -611,8 +611,8 @@ func (s *Server) SyncNewCreatedPartition(ctx context.Context, req *querypb.SyncN
// tries to load them up. It returns when all segments of the given collection are loaded, or when error happens.
// Note that a collection's loading progress always stays at 100% after a successful load and will not get updated
// during refreshCollection.
func (s *Server) refreshCollection(collectionID int64) error {
collection := s.meta.CollectionManager.GetCollection(collectionID)
func (s *Server) refreshCollection(ctx context.Context, collectionID int64) error {
collection := s.meta.CollectionManager.GetCollection(ctx, collectionID)
if collection == nil {
return merr.WrapErrCollectionNotLoaded(collectionID)
}
@ -724,14 +724,14 @@ func (s *Server) LoadBalance(ctx context.Context, req *querypb.LoadBalanceReques
log.Warn(msg, zap.Int("source-nodes-num", len(req.GetSourceNodeIDs())))
return merr.Status(err), nil
}
if s.meta.CollectionManager.CalculateLoadPercentage(req.GetCollectionID()) < 100 {
if s.meta.CollectionManager.CalculateLoadPercentage(ctx, req.GetCollectionID()) < 100 {
err := merr.WrapErrCollectionNotFullyLoaded(req.GetCollectionID())
msg := "can't balance segments of not fully loaded collection"
log.Warn(msg)
return merr.Status(err), nil
}
srcNode := req.GetSourceNodeIDs()[0]
replica := s.meta.ReplicaManager.GetByCollectionAndNode(req.GetCollectionID(), srcNode)
replica := s.meta.ReplicaManager.GetByCollectionAndNode(ctx, req.GetCollectionID(), srcNode)
if replica == nil {
err := merr.WrapErrNodeNotFound(srcNode, fmt.Sprintf("source node not found in any replica of collection %d", req.GetCollectionID()))
msg := "source node not found in any replica"
@ -785,7 +785,7 @@ func (s *Server) LoadBalance(ctx context.Context, req *querypb.LoadBalanceReques
}
// Only balance segments in targets
existInTarget := s.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil
existInTarget := s.targetMgr.GetSealedSegment(ctx, segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil
if !existInTarget {
log.Info("segment doesn't exist in current target, skip it", zap.Int64("segmentID", segmentID))
continue
@ -881,13 +881,13 @@ func (s *Server) GetReplicas(ctx context.Context, req *milvuspb.GetReplicasReque
Replicas: make([]*milvuspb.ReplicaInfo, 0),
}
replicas := s.meta.ReplicaManager.GetByCollection(req.GetCollectionID())
replicas := s.meta.ReplicaManager.GetByCollection(ctx, req.GetCollectionID())
if len(replicas) == 0 {
return resp, nil
}
for _, replica := range replicas {
resp.Replicas = append(resp.Replicas, s.fillReplicaInfo(replica, req.GetWithShardNodes()))
resp.Replicas = append(resp.Replicas, s.fillReplicaInfo(ctx, replica, req.GetWithShardNodes()))
}
return resp, nil
}
@ -969,7 +969,7 @@ func (s *Server) CreateResourceGroup(ctx context.Context, req *milvuspb.CreateRe
return merr.Status(err), nil
}
err := s.meta.ResourceManager.AddResourceGroup(req.GetResourceGroup(), req.GetConfig())
err := s.meta.ResourceManager.AddResourceGroup(ctx, req.GetResourceGroup(), req.GetConfig())
if err != nil {
log.Warn("failed to create resource group", zap.Error(err))
return merr.Status(err), nil
@ -988,7 +988,7 @@ func (s *Server) UpdateResourceGroups(ctx context.Context, req *querypb.UpdateRe
return merr.Status(err), nil
}
err := s.meta.ResourceManager.UpdateResourceGroups(req.GetResourceGroups())
err := s.meta.ResourceManager.UpdateResourceGroups(ctx, req.GetResourceGroups())
if err != nil {
log.Warn("failed to update resource group", zap.Error(err))
return merr.Status(err), nil
@ -1007,14 +1007,14 @@ func (s *Server) DropResourceGroup(ctx context.Context, req *milvuspb.DropResour
return merr.Status(err), nil
}
replicas := s.meta.ReplicaManager.GetByResourceGroup(req.GetResourceGroup())
replicas := s.meta.ReplicaManager.GetByResourceGroup(ctx, req.GetResourceGroup())
if len(replicas) > 0 {
err := merr.WrapErrParameterInvalid("empty resource group", fmt.Sprintf("resource group %s has collection %d loaded", req.GetResourceGroup(), replicas[0].GetCollectionID()))
return merr.Status(errors.Wrap(err,
fmt.Sprintf("some replicas still loaded in resource group[%s], release it first", req.GetResourceGroup()))), nil
}
err := s.meta.ResourceManager.RemoveResourceGroup(req.GetResourceGroup())
err := s.meta.ResourceManager.RemoveResourceGroup(ctx, req.GetResourceGroup())
if err != nil {
log.Warn("failed to drop resource group", zap.Error(err))
return merr.Status(err), nil
@ -1037,7 +1037,7 @@ func (s *Server) TransferNode(ctx context.Context, req *milvuspb.TransferNodeReq
}
// Move node from source resource group to target resource group.
if err := s.meta.ResourceManager.TransferNode(req.GetSourceResourceGroup(), req.GetTargetResourceGroup(), int(req.GetNumNode())); err != nil {
if err := s.meta.ResourceManager.TransferNode(ctx, req.GetSourceResourceGroup(), req.GetTargetResourceGroup(), int(req.GetNumNode())); err != nil {
log.Warn("failed to transfer node", zap.Error(err))
return merr.Status(err), nil
}
@ -1059,20 +1059,20 @@ func (s *Server) TransferReplica(ctx context.Context, req *querypb.TransferRepli
}
// TODO: !!!WARNING, the replica manager and the resource manager are not protected against each other by a lock.
if ok := s.meta.ResourceManager.ContainResourceGroup(req.GetSourceResourceGroup()); !ok {
if ok := s.meta.ResourceManager.ContainResourceGroup(ctx, req.GetSourceResourceGroup()); !ok {
err := merr.WrapErrResourceGroupNotFound(req.GetSourceResourceGroup())
return merr.Status(errors.Wrap(err,
fmt.Sprintf("the source resource group[%s] doesn't exist", req.GetSourceResourceGroup()))), nil
}
if ok := s.meta.ResourceManager.ContainResourceGroup(req.GetTargetResourceGroup()); !ok {
if ok := s.meta.ResourceManager.ContainResourceGroup(ctx, req.GetTargetResourceGroup()); !ok {
err := merr.WrapErrResourceGroupNotFound(req.GetTargetResourceGroup())
return merr.Status(errors.Wrap(err,
fmt.Sprintf("the target resource group[%s] doesn't exist", req.GetTargetResourceGroup()))), nil
}
// Apply change into replica manager.
err := s.meta.TransferReplica(req.GetCollectionID(), req.GetSourceResourceGroup(), req.GetTargetResourceGroup(), int(req.GetNumReplica()))
err := s.meta.TransferReplica(ctx, req.GetCollectionID(), req.GetSourceResourceGroup(), req.GetTargetResourceGroup(), int(req.GetNumReplica()))
return merr.Status(err), nil
}
@ -1089,7 +1089,7 @@ func (s *Server) ListResourceGroups(ctx context.Context, req *milvuspb.ListResou
return resp, nil
}
resp.ResourceGroups = s.meta.ResourceManager.ListResourceGroups()
resp.ResourceGroups = s.meta.ResourceManager.ListResourceGroups(ctx)
return resp, nil
}
@ -1108,7 +1108,7 @@ func (s *Server) DescribeResourceGroup(ctx context.Context, req *querypb.Describ
return resp, nil
}
rg := s.meta.ResourceManager.GetResourceGroup(req.GetResourceGroup())
rg := s.meta.ResourceManager.GetResourceGroup(ctx, req.GetResourceGroup())
if rg == nil {
err := merr.WrapErrResourceGroupNotFound(req.GetResourceGroup())
resp.Status = merr.Status(err)
@ -1117,26 +1117,26 @@ func (s *Server) DescribeResourceGroup(ctx context.Context, req *querypb.Describ
loadedReplicas := make(map[int64]int32)
outgoingNodes := make(map[int64]int32)
replicasInRG := s.meta.GetByResourceGroup(req.GetResourceGroup())
replicasInRG := s.meta.GetByResourceGroup(ctx, req.GetResourceGroup())
for _, replica := range replicasInRG {
loadedReplicas[replica.GetCollectionID()]++
for _, node := range replica.GetRONodes() {
if !s.meta.ContainsNode(replica.GetResourceGroup(), node) {
if !s.meta.ContainsNode(ctx, replica.GetResourceGroup(), node) {
outgoingNodes[replica.GetCollectionID()]++
}
}
}
incomingNodes := make(map[int64]int32)
collections := s.meta.GetAll()
collections := s.meta.GetAll(ctx)
for _, collection := range collections {
replicas := s.meta.GetByCollection(collection)
replicas := s.meta.GetByCollection(ctx, collection)
for _, replica := range replicas {
if replica.GetResourceGroup() == req.GetResourceGroup() {
continue
}
for _, node := range replica.GetRONodes() {
if s.meta.ContainsNode(req.GetResourceGroup(), node) {
if s.meta.ContainsNode(ctx, req.GetResourceGroup(), node) {
incomingNodes[collection]++
}
}
@ -1184,14 +1184,14 @@ func (s *Server) UpdateLoadConfig(ctx context.Context, req *querypb.UpdateLoadCo
jobs := make([]job.Job, 0, len(req.GetCollectionIDs()))
for _, collectionID := range req.GetCollectionIDs() {
collection := s.meta.GetCollection(collectionID)
collection := s.meta.GetCollection(ctx, collectionID)
if collection == nil || collection.GetStatus() != querypb.LoadStatus_Loaded {
err := merr.WrapErrCollectionNotLoaded(collectionID)
log.Warn("failed to update load config", zap.Error(err))
continue
}
collectionUsedRG := s.meta.ReplicaManager.GetResourceGroupByCollection(collection.GetCollectionID()).Collect()
collectionUsedRG := s.meta.ReplicaManager.GetResourceGroupByCollection(ctx, collection.GetCollectionID()).Collect()
left, right := lo.Difference(collectionUsedRG, req.GetResourceGroups())
rgChanged := len(left) > 0 || len(right) > 0
replicaChanged := collection.GetReplicaNumber() != req.GetReplicaNumber()
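
A minimal sketch of the helper-signature change shown for refreshCollection above: private helpers that previously took only a collection ID now take ctx so the handler's context reaches the collection manager. The error value and manager type below are placeholders.

package main

import (
	"context"
	"errors"
	"fmt"
)

var errCollectionNotLoaded = errors.New("collection not loaded")

type collectionManager struct{ loaded map[int64]bool }

// GetCollection carries ctx after the refactor.
func (m *collectionManager) GetCollection(ctx context.Context, collectionID int64) bool {
	return m.loaded[collectionID]
}

type server struct{ collections *collectionManager }

// refreshCollection gained a ctx parameter so the gRPC handler's context
// flows down into the collection manager, mirroring the hunk above.
func (s *server) refreshCollection(ctx context.Context, collectionID int64) error {
	if !s.collections.GetCollection(ctx, collectionID) {
		return errCollectionNotLoaded
	}
	return nil
}

func main() {
	s := &server{collections: &collectionManager{loaded: map[int64]bool{100: true}}}
	fmt.Println(s.refreshCollection(context.Background(), 100)) // <nil>
	fmt.Println(s.refreshCollection(context.Background(), 200)) // collection not loaded
}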

View File

@ -166,7 +166,7 @@ func (suite *ServiceSuite) SetupTest() {
Address: "localhost",
Hostname: "localhost",
}))
suite.meta.ResourceManager.HandleNodeUp(node)
suite.meta.ResourceManager.HandleNodeUp(context.TODO(), node)
}
suite.cluster = session.NewMockCluster(suite.T())
suite.cluster.EXPECT().SyncDistribution(mock.Anything, mock.Anything, mock.Anything).Return(merr.Success(), nil).Maybe()
@ -250,15 +250,15 @@ func (suite *ServiceSuite) TestShowCollections() {
suite.Equal(collection, resp.CollectionIDs[0])
// Test insufficient memory
colBak := suite.meta.CollectionManager.GetCollection(collection)
err = suite.meta.CollectionManager.RemoveCollection(collection)
colBak := suite.meta.CollectionManager.GetCollection(ctx, collection)
err = suite.meta.CollectionManager.RemoveCollection(ctx, collection)
suite.NoError(err)
meta.GlobalFailedLoadCache.Put(collection, merr.WrapErrServiceMemoryLimitExceeded(100, 10))
resp, err = server.ShowCollections(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_InsufficientMemoryToLoad, resp.GetStatus().GetErrorCode())
meta.GlobalFailedLoadCache.Remove(collection)
err = suite.meta.CollectionManager.PutCollection(colBak)
err = suite.meta.CollectionManager.PutCollection(ctx, colBak)
suite.NoError(err)
// Test when server is not healthy
@ -304,27 +304,27 @@ func (suite *ServiceSuite) TestShowPartitions() {
// Test insufficient memory
if suite.loadTypes[collection] == querypb.LoadType_LoadCollection {
colBak := suite.meta.CollectionManager.GetCollection(collection)
err = suite.meta.CollectionManager.RemoveCollection(collection)
colBak := suite.meta.CollectionManager.GetCollection(ctx, collection)
err = suite.meta.CollectionManager.RemoveCollection(ctx, collection)
suite.NoError(err)
meta.GlobalFailedLoadCache.Put(collection, merr.WrapErrServiceMemoryLimitExceeded(100, 10))
resp, err = server.ShowPartitions(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_InsufficientMemoryToLoad, resp.GetStatus().GetErrorCode())
meta.GlobalFailedLoadCache.Remove(collection)
err = suite.meta.CollectionManager.PutCollection(colBak)
err = suite.meta.CollectionManager.PutCollection(ctx, colBak)
suite.NoError(err)
} else {
partitionID := partitions[0]
parBak := suite.meta.CollectionManager.GetPartition(partitionID)
err = suite.meta.CollectionManager.RemovePartition(collection, partitionID)
parBak := suite.meta.CollectionManager.GetPartition(ctx, partitionID)
err = suite.meta.CollectionManager.RemovePartition(ctx, collection, partitionID)
suite.NoError(err)
meta.GlobalFailedLoadCache.Put(collection, merr.WrapErrServiceMemoryLimitExceeded(100, 10))
resp, err = server.ShowPartitions(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_InsufficientMemoryToLoad, resp.GetStatus().GetErrorCode())
meta.GlobalFailedLoadCache.Remove(collection)
err = suite.meta.CollectionManager.PutPartition(parBak)
err = suite.meta.CollectionManager.PutPartition(ctx, parBak)
suite.NoError(err)
}
}
@ -354,7 +354,7 @@ func (suite *ServiceSuite) TestLoadCollection() {
resp, err := server.LoadCollection(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode)
suite.assertLoaded(collection)
suite.assertLoaded(ctx, collection)
}
// Test load again
@ -430,21 +430,21 @@ func (suite *ServiceSuite) TestResourceGroup() {
Address: "localhost",
Hostname: "localhost",
}))
server.meta.ResourceManager.AddResourceGroup("rg11", &rgpb.ResourceGroupConfig{
server.meta.ResourceManager.AddResourceGroup(ctx, "rg11", &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{NodeNum: 2},
Limits: &rgpb.ResourceGroupLimit{NodeNum: 2},
})
server.meta.ResourceManager.HandleNodeUp(1011)
server.meta.ResourceManager.HandleNodeUp(1012)
server.meta.ResourceManager.AddResourceGroup("rg12", &rgpb.ResourceGroupConfig{
server.meta.ResourceManager.HandleNodeUp(ctx, 1011)
server.meta.ResourceManager.HandleNodeUp(ctx, 1012)
server.meta.ResourceManager.AddResourceGroup(ctx, "rg12", &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{NodeNum: 2},
Limits: &rgpb.ResourceGroupLimit{NodeNum: 2},
})
server.meta.ResourceManager.HandleNodeUp(1013)
server.meta.ResourceManager.HandleNodeUp(1014)
server.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
server.meta.CollectionManager.PutCollection(utils.CreateTestCollection(2, 1))
server.meta.ReplicaManager.Put(meta.NewReplica(&querypb.Replica{
server.meta.ResourceManager.HandleNodeUp(ctx, 1013)
server.meta.ResourceManager.HandleNodeUp(ctx, 1014)
server.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1))
server.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(2, 1))
server.meta.ReplicaManager.Put(ctx, meta.NewReplica(&querypb.Replica{
ID: 1,
CollectionID: 1,
Nodes: []int64{1011},
@ -453,7 +453,7 @@ func (suite *ServiceSuite) TestResourceGroup() {
},
typeutil.NewUniqueSet(1011, 1013)),
)
server.meta.ReplicaManager.Put(meta.NewReplica(&querypb.Replica{
server.meta.ReplicaManager.Put(ctx, meta.NewReplica(&querypb.Replica{
ID: 2,
CollectionID: 2,
Nodes: []int64{1014},
@ -548,18 +548,18 @@ func (suite *ServiceSuite) TestTransferNode() {
defer server.resourceObserver.Stop()
defer server.replicaObserver.Stop()
err := server.meta.ResourceManager.AddResourceGroup("rg1", &rgpb.ResourceGroupConfig{
err := server.meta.ResourceManager.AddResourceGroup(ctx, "rg1", &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{NodeNum: 0},
Limits: &rgpb.ResourceGroupLimit{NodeNum: 0},
})
suite.NoError(err)
err = server.meta.ResourceManager.AddResourceGroup("rg2", &rgpb.ResourceGroupConfig{
err = server.meta.ResourceManager.AddResourceGroup(ctx, "rg2", &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{NodeNum: 0},
Limits: &rgpb.ResourceGroupLimit{NodeNum: 0},
})
suite.NoError(err)
suite.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 2))
suite.meta.ReplicaManager.Put(meta.NewReplica(
suite.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 2))
suite.meta.ReplicaManager.Put(ctx, meta.NewReplica(
&querypb.Replica{
ID: 1,
CollectionID: 1,
@ -578,15 +578,15 @@ func (suite *ServiceSuite) TestTransferNode() {
suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode)
suite.Eventually(func() bool {
nodes, err := server.meta.ResourceManager.GetNodes("rg1")
nodes, err := server.meta.ResourceManager.GetNodes(ctx, "rg1")
if err != nil || len(nodes) != 1 {
return false
}
nodesInReplica := server.meta.ReplicaManager.Get(1).GetNodes()
nodesInReplica := server.meta.ReplicaManager.Get(ctx, 1).GetNodes()
return len(nodesInReplica) == 1
}, 5*time.Second, 100*time.Millisecond)
suite.meta.ReplicaManager.Put(meta.NewReplica(
suite.meta.ReplicaManager.Put(ctx, meta.NewReplica(
&querypb.Replica{
ID: 2,
CollectionID: 1,
@ -612,12 +612,12 @@ func (suite *ServiceSuite) TestTransferNode() {
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_IllegalArgument, resp.ErrorCode)
err = server.meta.ResourceManager.AddResourceGroup("rg3", &rgpb.ResourceGroupConfig{
err = server.meta.ResourceManager.AddResourceGroup(ctx, "rg3", &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{NodeNum: 4},
Limits: &rgpb.ResourceGroupLimit{NodeNum: 4},
})
suite.NoError(err)
err = server.meta.ResourceManager.AddResourceGroup("rg4", &rgpb.ResourceGroupConfig{
err = server.meta.ResourceManager.AddResourceGroup(ctx, "rg4", &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{NodeNum: 0},
Limits: &rgpb.ResourceGroupLimit{NodeNum: 0},
})
@ -642,10 +642,10 @@ func (suite *ServiceSuite) TestTransferNode() {
Address: "localhost",
Hostname: "localhost",
}))
suite.meta.ResourceManager.HandleNodeUp(11)
suite.meta.ResourceManager.HandleNodeUp(12)
suite.meta.ResourceManager.HandleNodeUp(13)
suite.meta.ResourceManager.HandleNodeUp(14)
suite.meta.ResourceManager.HandleNodeUp(ctx, 11)
suite.meta.ResourceManager.HandleNodeUp(ctx, 12)
suite.meta.ResourceManager.HandleNodeUp(ctx, 13)
suite.meta.ResourceManager.HandleNodeUp(ctx, 14)
resp, err = server.TransferNode(ctx, &milvuspb.TransferNodeRequest{
SourceResourceGroup: "rg3",
@ -656,11 +656,11 @@ func (suite *ServiceSuite) TestTransferNode() {
suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode)
suite.Eventually(func() bool {
nodes, err := server.meta.ResourceManager.GetNodes("rg3")
nodes, err := server.meta.ResourceManager.GetNodes(ctx, "rg3")
if err != nil || len(nodes) != 1 {
return false
}
nodes, err = server.meta.ResourceManager.GetNodes("rg4")
nodes, err = server.meta.ResourceManager.GetNodes(ctx, "rg4")
return err == nil && len(nodes) == 3
}, 5*time.Second, 100*time.Millisecond)
@ -695,17 +695,17 @@ func (suite *ServiceSuite) TestTransferReplica() {
ctx := context.Background()
server := suite.server
err := server.meta.ResourceManager.AddResourceGroup("rg1", &rgpb.ResourceGroupConfig{
err := server.meta.ResourceManager.AddResourceGroup(ctx, "rg1", &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{NodeNum: 1},
Limits: &rgpb.ResourceGroupLimit{NodeNum: 1},
})
suite.NoError(err)
err = server.meta.ResourceManager.AddResourceGroup("rg2", &rgpb.ResourceGroupConfig{
err = server.meta.ResourceManager.AddResourceGroup(ctx, "rg2", &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{NodeNum: 1},
Limits: &rgpb.ResourceGroupLimit{NodeNum: 1},
})
suite.NoError(err)
err = server.meta.ResourceManager.AddResourceGroup("rg3", &rgpb.ResourceGroupConfig{
err = server.meta.ResourceManager.AddResourceGroup(ctx, "rg3", &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{NodeNum: 3},
Limits: &rgpb.ResourceGroupLimit{NodeNum: 3},
})
@ -747,17 +747,17 @@ func (suite *ServiceSuite) TestTransferReplica() {
suite.NoError(err)
suite.ErrorIs(merr.Error(resp), merr.ErrParameterInvalid)
suite.server.meta.Put(meta.NewReplica(&querypb.Replica{
suite.server.meta.Put(ctx, meta.NewReplica(&querypb.Replica{
CollectionID: 1,
ID: 111,
ResourceGroup: meta.DefaultResourceGroupName,
}, typeutil.NewUniqueSet(1)))
suite.server.meta.Put(meta.NewReplica(&querypb.Replica{
suite.server.meta.Put(ctx, meta.NewReplica(&querypb.Replica{
CollectionID: 1,
ID: 222,
ResourceGroup: meta.DefaultResourceGroupName,
}, typeutil.NewUniqueSet(2)))
suite.server.meta.Put(meta.NewReplica(&querypb.Replica{
suite.server.meta.Put(ctx, meta.NewReplica(&querypb.Replica{
CollectionID: 1,
ID: 333,
ResourceGroup: meta.DefaultResourceGroupName,
@ -788,18 +788,18 @@ func (suite *ServiceSuite) TestTransferReplica() {
Address: "localhost",
Hostname: "localhost",
}))
suite.server.meta.HandleNodeUp(1001)
suite.server.meta.HandleNodeUp(1002)
suite.server.meta.HandleNodeUp(1003)
suite.server.meta.HandleNodeUp(1004)
suite.server.meta.HandleNodeUp(1005)
suite.server.meta.HandleNodeUp(ctx, 1001)
suite.server.meta.HandleNodeUp(ctx, 1002)
suite.server.meta.HandleNodeUp(ctx, 1003)
suite.server.meta.HandleNodeUp(ctx, 1004)
suite.server.meta.HandleNodeUp(ctx, 1005)
suite.server.meta.Put(meta.NewReplica(&querypb.Replica{
suite.server.meta.Put(ctx, meta.NewReplica(&querypb.Replica{
CollectionID: 2,
ID: 444,
ResourceGroup: meta.DefaultResourceGroupName,
}, typeutil.NewUniqueSet(3)))
suite.server.meta.Put(meta.NewReplica(&querypb.Replica{
suite.server.meta.Put(ctx, meta.NewReplica(&querypb.Replica{
CollectionID: 2,
ID: 555,
ResourceGroup: "rg2",
@ -824,7 +824,7 @@ func (suite *ServiceSuite) TestTransferReplica() {
// transferring a replica to a resource group that loads the same collection is supported.
suite.Equal(resp.ErrorCode, commonpb.ErrorCode_Success)
replicaNum := len(suite.server.meta.ReplicaManager.GetByCollection(1))
replicaNum := len(suite.server.meta.ReplicaManager.GetByCollection(ctx, 1))
suite.Equal(3, replicaNum)
resp, err = suite.server.TransferReplica(ctx, &querypb.TransferReplicaRequest{
SourceResourceGroup: meta.DefaultResourceGroupName,
@ -842,7 +842,7 @@ func (suite *ServiceSuite) TestTransferReplica() {
})
suite.NoError(err)
suite.Equal(resp.ErrorCode, commonpb.ErrorCode_Success)
suite.Len(suite.server.meta.GetByResourceGroup("rg3"), 3)
suite.Len(suite.server.meta.GetByResourceGroup(ctx, "rg3"), 3)
// server unhealthy
server.UpdateStateCode(commonpb.StateCode_Abnormal)
@ -924,7 +924,7 @@ func (suite *ServiceSuite) TestLoadPartition() {
resp, err := server.LoadPartitions(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode)
suite.assertLoaded(collection)
suite.assertLoaded(ctx, collection)
}
// Test load again
@ -1020,7 +1020,7 @@ func (suite *ServiceSuite) TestReleaseCollection() {
resp, err := server.ReleaseCollection(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode)
suite.assertReleased(collection)
suite.assertReleased(ctx, collection)
}
// Test release again
@ -1059,7 +1059,7 @@ func (suite *ServiceSuite) TestReleasePartition() {
resp, err := server.ReleasePartitions(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode)
suite.assertPartitionLoaded(collection, suite.partitions[collection][1:]...)
suite.assertPartitionLoaded(ctx, collection, suite.partitions[collection][1:]...)
}
// Test release again
@ -1071,7 +1071,7 @@ func (suite *ServiceSuite) TestReleasePartition() {
resp, err := server.ReleasePartitions(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode)
suite.assertPartitionLoaded(collection, suite.partitions[collection][1:]...)
suite.assertPartitionLoaded(ctx, collection, suite.partitions[collection][1:]...)
}
// Test when server is not healthy
@ -1086,11 +1086,12 @@ func (suite *ServiceSuite) TestReleasePartition() {
}
func (suite *ServiceSuite) TestRefreshCollection() {
ctx := context.Background()
server := suite.server
// Test refresh all collections.
for _, collection := range suite.collections {
err := server.refreshCollection(collection)
err := server.refreshCollection(ctx, collection)
// Collection not loaded error.
suite.ErrorIs(err, merr.ErrCollectionNotLoaded)
}
@ -1100,19 +1101,19 @@ func (suite *ServiceSuite) TestRefreshCollection() {
// Test refresh all collections again when collections are loaded. This time should fail with collection not 100% loaded.
for _, collection := range suite.collections {
suite.updateCollectionStatus(collection, querypb.LoadStatus_Loading)
err := server.refreshCollection(collection)
suite.updateCollectionStatus(ctx, collection, querypb.LoadStatus_Loading)
err := server.refreshCollection(ctx, collection)
suite.ErrorIs(err, merr.ErrCollectionNotLoaded)
}
// Test refresh all collections
for _, id := range suite.collections {
// Load and explicitly mark load percentage to 100%.
suite.updateChannelDist(id)
suite.updateChannelDist(ctx, id)
suite.updateSegmentDist(id, suite.nodes[0])
suite.updateCollectionStatus(id, querypb.LoadStatus_Loaded)
suite.updateCollectionStatus(ctx, id, querypb.LoadStatus_Loaded)
err := server.refreshCollection(id)
err := server.refreshCollection(ctx, id)
suite.NoError(err)
readyCh, err := server.targetObserver.UpdateNextTarget(id)
@ -1120,18 +1121,18 @@ func (suite *ServiceSuite) TestRefreshCollection() {
<-readyCh
// Now the refresh must be done
collection := server.meta.CollectionManager.GetCollection(id)
collection := server.meta.CollectionManager.GetCollection(ctx, id)
suite.True(collection.IsRefreshed())
}
// Test refresh not ready
for _, id := range suite.collections {
suite.updateChannelDistWithoutSegment(id)
err := server.refreshCollection(id)
suite.updateChannelDistWithoutSegment(ctx, id)
err := server.refreshCollection(ctx, id)
suite.NoError(err)
// Now the refresh must not be done yet
collection := server.meta.CollectionManager.GetCollection(id)
collection := server.meta.CollectionManager.GetCollection(ctx, id)
suite.False(collection.IsRefreshed())
}
}
@ -1209,11 +1210,11 @@ func (suite *ServiceSuite) TestLoadBalance() {
// Test get balance first segment
for _, collection := range suite.collections {
replicas := suite.meta.ReplicaManager.GetByCollection(collection)
replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection)
nodes := replicas[0].GetNodes()
srcNode := nodes[0]
dstNode := nodes[1]
suite.updateCollectionStatus(collection, querypb.LoadStatus_Loaded)
suite.updateCollectionStatus(ctx, collection, querypb.LoadStatus_Loaded)
suite.updateSegmentDist(collection, srcNode)
segments := suite.getAllSegments(collection)
req := &querypb.LoadBalanceRequest{
@ -1258,10 +1259,10 @@ func (suite *ServiceSuite) TestLoadBalanceWithNoDstNode() {
// Test get balance first segment
for _, collection := range suite.collections {
replicas := suite.meta.ReplicaManager.GetByCollection(collection)
replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection)
nodes := replicas[0].GetNodes()
srcNode := nodes[0]
suite.updateCollectionStatus(collection, querypb.LoadStatus_Loaded)
suite.updateCollectionStatus(ctx, collection, querypb.LoadStatus_Loaded)
suite.updateSegmentDist(collection, srcNode)
segments := suite.getAllSegments(collection)
req := &querypb.LoadBalanceRequest{
@ -1310,10 +1311,10 @@ func (suite *ServiceSuite) TestLoadBalanceWithEmptySegmentList() {
// update two collection's dist
for _, collection := range suite.collections {
replicas := suite.meta.ReplicaManager.GetByCollection(collection)
replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection)
replicas[0].AddRWNode(srcNode)
replicas[0].AddRWNode(dstNode)
suite.updateCollectionStatus(collection, querypb.LoadStatus_Loaded)
suite.updateCollectionStatus(ctx, collection, querypb.LoadStatus_Loaded)
for partition, segments := range suite.segments[collection] {
for _, segment := range segments {
@ -1336,9 +1337,9 @@ func (suite *ServiceSuite) TestLoadBalanceWithEmptySegmentList() {
}))
defer func() {
for _, collection := range suite.collections {
replicas := suite.meta.ReplicaManager.GetByCollection(collection)
suite.meta.ReplicaManager.RemoveNode(replicas[0].GetID(), srcNode)
suite.meta.ReplicaManager.RemoveNode(replicas[0].GetID(), dstNode)
replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection)
suite.meta.ReplicaManager.RemoveNode(ctx, replicas[0].GetID(), srcNode)
suite.meta.ReplicaManager.RemoveNode(ctx, replicas[0].GetID(), dstNode)
}
suite.nodeMgr.Remove(1001)
suite.nodeMgr.Remove(1002)
@ -1380,7 +1381,7 @@ func (suite *ServiceSuite) TestLoadBalanceFailed() {
// Test load balance without source node
for _, collection := range suite.collections {
replicas := suite.meta.ReplicaManager.GetByCollection(collection)
replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection)
dstNode := replicas[0].GetNodes()[1]
segments := suite.getAllSegments(collection)
req := &querypb.LoadBalanceRequest{
@ -1395,11 +1396,11 @@ func (suite *ServiceSuite) TestLoadBalanceFailed() {
// Test load balance with not fully loaded
for _, collection := range suite.collections {
replicas := suite.meta.ReplicaManager.GetByCollection(collection)
replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection)
nodes := replicas[0].GetNodes()
srcNode := nodes[0]
dstNode := nodes[1]
suite.updateCollectionStatus(collection, querypb.LoadStatus_Loading)
suite.updateCollectionStatus(ctx, collection, querypb.LoadStatus_Loading)
segments := suite.getAllSegments(collection)
req := &querypb.LoadBalanceRequest{
CollectionID: collection,
@ -1418,10 +1419,10 @@ func (suite *ServiceSuite) TestLoadBalanceFailed() {
continue
}
replicas := suite.meta.ReplicaManager.GetByCollection(collection)
replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection)
srcNode := replicas[0].GetNodes()[0]
dstNode := replicas[1].GetNodes()[0]
suite.updateCollectionStatus(collection, querypb.LoadStatus_Loaded)
suite.updateCollectionStatus(ctx, collection, querypb.LoadStatus_Loaded)
suite.updateSegmentDist(collection, srcNode)
segments := suite.getAllSegments(collection)
req := &querypb.LoadBalanceRequest{
@ -1437,11 +1438,11 @@ func (suite *ServiceSuite) TestLoadBalanceFailed() {
// Test balance task failed
for _, collection := range suite.collections {
replicas := suite.meta.ReplicaManager.GetByCollection(collection)
replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection)
nodes := replicas[0].GetNodes()
srcNode := nodes[0]
dstNode := nodes[1]
suite.updateCollectionStatus(collection, querypb.LoadStatus_Loaded)
suite.updateCollectionStatus(ctx, collection, querypb.LoadStatus_Loaded)
suite.updateSegmentDist(collection, srcNode)
segments := suite.getAllSegments(collection)
req := &querypb.LoadBalanceRequest{
@ -1458,7 +1459,7 @@ func (suite *ServiceSuite) TestLoadBalanceFailed() {
suite.Equal(commonpb.ErrorCode_UnexpectedError, resp.ErrorCode)
suite.Contains(resp.Reason, "mock error")
suite.meta.ReplicaManager.RecoverNodesInCollection(collection, map[string]typeutil.UniqueSet{meta.DefaultResourceGroupName: typeutil.NewUniqueSet(10)})
suite.meta.ReplicaManager.RecoverNodesInCollection(ctx, collection, map[string]typeutil.UniqueSet{meta.DefaultResourceGroupName: typeutil.NewUniqueSet(10)})
req.SourceNodeIDs = []int64{10}
resp, err = server.LoadBalance(ctx, req)
suite.NoError(err)
@ -1480,7 +1481,7 @@ func (suite *ServiceSuite) TestLoadBalanceFailed() {
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_UnexpectedError, resp.ErrorCode)
suite.nodeMgr.Remove(10)
suite.meta.ReplicaManager.RemoveNode(replicas[0].GetID(), 10)
suite.meta.ReplicaManager.RemoveNode(ctx, replicas[0].GetID(), 10)
}
}
@ -1545,7 +1546,7 @@ func (suite *ServiceSuite) TestGetReplicas() {
server := suite.server
for _, collection := range suite.collections {
suite.updateChannelDist(collection)
suite.updateChannelDist(ctx, collection)
req := &milvuspb.GetReplicasRequest{
CollectionID: collection,
}
@ -1557,11 +1558,11 @@ func (suite *ServiceSuite) TestGetReplicas() {
// Test get with shard nodes
for _, collection := range suite.collections {
replicas := suite.meta.ReplicaManager.GetByCollection(collection)
replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection)
for _, replica := range replicas {
suite.updateSegmentDist(collection, replica.GetNodes()[0])
}
suite.updateChannelDist(collection)
suite.updateChannelDist(ctx, collection)
req := &milvuspb.GetReplicasRequest{
CollectionID: collection,
WithShardNodes: true,
@ -1582,7 +1583,7 @@ func (suite *ServiceSuite) TestGetReplicas() {
}
}
suite.Equal(len(replica.GetNodeIds()), len(suite.meta.ReplicaManager.Get(replica.ReplicaID).GetNodes()))
suite.Equal(len(replica.GetNodeIds()), len(suite.meta.ReplicaManager.Get(ctx, replica.ReplicaID).GetNodes()))
}
}
@ -1601,13 +1602,13 @@ func (suite *ServiceSuite) TestGetReplicasWhenNoAvailableNodes() {
ctx := context.Background()
server := suite.server
replicas := suite.meta.ReplicaManager.GetByCollection(suite.collections[0])
replicas := suite.meta.ReplicaManager.GetByCollection(ctx, suite.collections[0])
for _, replica := range replicas {
suite.updateSegmentDist(suite.collections[0], replica.GetNodes()[0])
}
suite.updateChannelDist(suite.collections[0])
suite.updateChannelDist(ctx, suite.collections[0])
suite.meta.ReplicaManager.Put(utils.CreateTestReplica(100001, suite.collections[0], []int64{}))
suite.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(100001, suite.collections[0], []int64{}))
req := &milvuspb.GetReplicasRequest{
CollectionID: suite.collections[0],
@ -1660,14 +1661,14 @@ func (suite *ServiceSuite) TestCheckHealth() {
// Test for check channel ok
for _, collection := range suite.collections {
suite.updateCollectionStatus(collection, querypb.LoadStatus_Loaded)
suite.updateChannelDist(collection)
suite.updateCollectionStatus(ctx, collection, querypb.LoadStatus_Loaded)
suite.updateChannelDist(ctx, collection)
}
assertCheckHealthResult(true)
// Test for check channel fail
tm := meta.NewMockTargetManager(suite.T())
tm.EXPECT().GetDmChannelsByCollection(mock.Anything, mock.Anything).Return(nil).Maybe()
tm.EXPECT().GetDmChannelsByCollection(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe()
otm := server.targetMgr
server.targetMgr = tm
assertCheckHealthResult(true)
@ -1686,8 +1687,8 @@ func (suite *ServiceSuite) TestGetShardLeaders() {
server := suite.server
for _, collection := range suite.collections {
suite.updateCollectionStatus(collection, querypb.LoadStatus_Loaded)
suite.updateChannelDist(collection)
suite.updateCollectionStatus(ctx, collection, querypb.LoadStatus_Loaded)
suite.updateChannelDist(ctx, collection)
req := &querypb.GetShardLeadersRequest{
CollectionID: collection,
}
@ -1718,8 +1719,8 @@ func (suite *ServiceSuite) TestGetShardLeadersFailed() {
server := suite.server
for _, collection := range suite.collections {
suite.updateCollectionStatus(collection, querypb.LoadStatus_Loaded)
suite.updateChannelDist(collection)
suite.updateCollectionStatus(ctx, collection, querypb.LoadStatus_Loaded)
suite.updateChannelDist(ctx, collection)
req := &querypb.GetShardLeadersRequest{
CollectionID: collection,
}
@ -1746,7 +1747,7 @@ func (suite *ServiceSuite) TestGetShardLeadersFailed() {
suite.dist.ChannelDistManager.Update(node)
suite.dist.LeaderViewManager.Update(node)
}
suite.updateChannelDistWithoutSegment(collection)
suite.updateChannelDistWithoutSegment(ctx, collection)
suite.fetchHeartbeats(time.Now())
resp, err = server.GetShardLeaders(ctx, req)
suite.NoError(err)
@ -1789,9 +1790,10 @@ func (suite *ServiceSuite) TestHandleNodeUp() {
suite.server.resourceObserver.Start()
defer suite.server.resourceObserver.Stop()
ctx := context.Background()
server := suite.server
suite.server.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
suite.server.meta.ReplicaManager.Put(meta.NewReplica(
suite.server.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1))
suite.server.meta.ReplicaManager.Put(ctx, meta.NewReplica(
&querypb.Replica{
ID: 1,
CollectionID: 1,
@ -1812,12 +1814,12 @@ func (suite *ServiceSuite) TestHandleNodeUp() {
server.handleNodeUp(111)
// wait for async update by observer
suite.Eventually(func() bool {
nodes := suite.server.meta.ReplicaManager.Get(1).GetNodes()
nodesInRG, _ := suite.server.meta.ResourceManager.GetNodes(meta.DefaultResourceGroupName)
nodes := suite.server.meta.ReplicaManager.Get(ctx, 1).GetNodes()
nodesInRG, _ := suite.server.meta.ResourceManager.GetNodes(ctx, meta.DefaultResourceGroupName)
return len(nodes) == len(nodesInRG)
}, 5*time.Second, 100*time.Millisecond)
nodes := suite.server.meta.ReplicaManager.Get(1).GetNodes()
nodesInRG, _ := suite.server.meta.ResourceManager.GetNodes(meta.DefaultResourceGroupName)
nodes := suite.server.meta.ReplicaManager.Get(ctx, 1).GetNodes()
nodesInRG, _ := suite.server.meta.ResourceManager.GetNodes(ctx, meta.DefaultResourceGroupName)
suite.ElementsMatch(nodes, nodesInRG)
}
@ -1846,10 +1848,10 @@ func (suite *ServiceSuite) loadAll() {
suite.jobScheduler.Add(job)
err := job.Wait()
suite.NoError(err)
suite.EqualValues(suite.replicaNumber[collection], suite.meta.GetReplicaNumber(collection))
suite.True(suite.meta.Exist(collection))
suite.NotNil(suite.meta.GetCollection(collection))
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.EqualValues(suite.replicaNumber[collection], suite.meta.GetReplicaNumber(ctx, collection))
suite.True(suite.meta.Exist(ctx, collection))
suite.NotNil(suite.meta.GetCollection(ctx, collection))
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
} else {
req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
@ -1871,30 +1873,30 @@ func (suite *ServiceSuite) loadAll() {
suite.jobScheduler.Add(job)
err := job.Wait()
suite.NoError(err)
suite.EqualValues(suite.replicaNumber[collection], suite.meta.GetReplicaNumber(collection))
suite.True(suite.meta.Exist(collection))
suite.NotNil(suite.meta.GetPartitionsByCollection(collection))
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.EqualValues(suite.replicaNumber[collection], suite.meta.GetReplicaNumber(ctx, collection))
suite.True(suite.meta.Exist(ctx, collection))
suite.NotNil(suite.meta.GetPartitionsByCollection(ctx, collection))
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
}
}
}
func (suite *ServiceSuite) assertLoaded(collection int64) {
suite.True(suite.meta.Exist(collection))
func (suite *ServiceSuite) assertLoaded(ctx context.Context, collection int64) {
suite.True(suite.meta.Exist(ctx, collection))
for _, channel := range suite.channels[collection] {
suite.NotNil(suite.targetMgr.GetDmChannel(collection, channel, meta.NextTarget))
suite.NotNil(suite.targetMgr.GetDmChannel(ctx, collection, channel, meta.NextTarget))
}
for _, partitions := range suite.segments[collection] {
for _, segment := range partitions {
suite.NotNil(suite.targetMgr.GetSealedSegment(collection, segment, meta.NextTarget))
suite.NotNil(suite.targetMgr.GetSealedSegment(ctx, collection, segment, meta.NextTarget))
}
}
}
func (suite *ServiceSuite) assertPartitionLoaded(collection int64, partitions ...int64) {
suite.True(suite.meta.Exist(collection))
func (suite *ServiceSuite) assertPartitionLoaded(ctx context.Context, collection int64, partitions ...int64) {
suite.True(suite.meta.Exist(ctx, collection))
for _, channel := range suite.channels[collection] {
suite.NotNil(suite.targetMgr.GetDmChannel(collection, channel, meta.CurrentTarget))
suite.NotNil(suite.targetMgr.GetDmChannel(ctx, collection, channel, meta.CurrentTarget))
}
partitionSet := typeutil.NewUniqueSet(partitions...)
for partition, segments := range suite.segments[collection] {
@ -1902,20 +1904,20 @@ func (suite *ServiceSuite) assertPartitionLoaded(collection int64, partitions ..
continue
}
for _, segment := range segments {
suite.NotNil(suite.targetMgr.GetSealedSegment(collection, segment, meta.CurrentTarget))
suite.NotNil(suite.targetMgr.GetSealedSegment(ctx, collection, segment, meta.CurrentTarget))
}
}
}
func (suite *ServiceSuite) assertReleased(collection int64) {
suite.False(suite.meta.Exist(collection))
func (suite *ServiceSuite) assertReleased(ctx context.Context, collection int64) {
suite.False(suite.meta.Exist(ctx, collection))
for _, channel := range suite.channels[collection] {
suite.Nil(suite.targetMgr.GetDmChannel(collection, channel, meta.CurrentTarget))
suite.Nil(suite.targetMgr.GetDmChannel(ctx, collection, channel, meta.CurrentTarget))
}
for _, partitions := range suite.segments[collection] {
for _, segment := range partitions {
suite.Nil(suite.targetMgr.GetSealedSegment(collection, segment, meta.CurrentTarget))
suite.Nil(suite.targetMgr.GetSealedSegment(collection, segment, meta.NextTarget))
suite.Nil(suite.targetMgr.GetSealedSegment(ctx, collection, segment, meta.CurrentTarget))
suite.Nil(suite.targetMgr.GetSealedSegment(ctx, collection, segment, meta.NextTarget))
}
}
}
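The assert helpers above lean on target-manager accessors that now take ctx first (GetDmChannel(ctx, collection, channel, scope), GetSealedSegment(ctx, collection, segment, scope)). As a rough, dependency-free sketch of that accessor shape — TargetScope, ChannelTarget, and the map-backed reader below are invented for illustration, not Milvus types:

package example

import "context"

// TargetScope selects which target (current or next) to read; illustrative only.
type TargetScope int

const (
	CurrentTarget TargetScope = iota
	NextTarget
)

// ChannelTarget is a stand-in for the per-collection channel target state.
type ChannelTarget struct {
	Collection int64
	Channel    string
}

// TargetReader mirrors the ctx-first accessor style: ctx, collection, key, scope.
type TargetReader interface {
	GetDmChannel(ctx context.Context, collection int64, channel string, scope TargetScope) *ChannelTarget
}

// mapTargetReader is a trivial in-memory implementation used only to show the call shape.
type mapTargetReader struct {
	targets map[TargetScope]map[int64]map[string]*ChannelTarget
}

func (r *mapTargetReader) GetDmChannel(ctx context.Context, collection int64, channel string, scope TargetScope) *ChannelTarget {
	// A real implementation could honor ctx cancellation or attach trace info;
	// this in-memory lookup only demonstrates the signature.
	if byCollection, ok := r.targets[scope]; ok {
		if byChannel, ok := byCollection[collection]; ok {
			return byChannel[channel]
		}
	}
	return nil
}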
@ -1989,11 +1991,11 @@ func (suite *ServiceSuite) updateSegmentDist(collection, node int64) {
suite.dist.SegmentDistManager.Update(node, metaSegments...)
}
func (suite *ServiceSuite) updateChannelDist(collection int64) {
func (suite *ServiceSuite) updateChannelDist(ctx context.Context, collection int64) {
channels := suite.channels[collection]
segments := lo.Flatten(lo.Values(suite.segments[collection]))
replicas := suite.meta.ReplicaManager.GetByCollection(collection)
replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection)
for _, replica := range replicas {
i := 0
for _, node := range suite.sortInt64(replica.GetNodes()) {
@ -2027,10 +2029,10 @@ func (suite *ServiceSuite) sortInt64(ints []int64) []int64 {
return ints
}
func (suite *ServiceSuite) updateChannelDistWithoutSegment(collection int64) {
func (suite *ServiceSuite) updateChannelDistWithoutSegment(ctx context.Context, collection int64) {
channels := suite.channels[collection]
replicas := suite.meta.ReplicaManager.GetByCollection(collection)
replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection)
for _, replica := range replicas {
i := 0
for _, node := range suite.sortInt64(replica.GetNodes()) {
@ -2052,8 +2054,8 @@ func (suite *ServiceSuite) updateChannelDistWithoutSegment(collection int64) {
}
}
func (suite *ServiceSuite) updateCollectionStatus(collectionID int64, status querypb.LoadStatus) {
collection := suite.meta.GetCollection(collectionID)
func (suite *ServiceSuite) updateCollectionStatus(ctx context.Context, collectionID int64, status querypb.LoadStatus) {
collection := suite.meta.GetCollection(ctx, collectionID)
if collection != nil {
collection := collection.Clone()
collection.LoadPercentage = 0
@ -2061,9 +2063,9 @@ func (suite *ServiceSuite) updateCollectionStatus(collectionID int64, status que
collection.LoadPercentage = 100
}
collection.CollectionLoadInfo.Status = status
suite.meta.PutCollection(collection)
suite.meta.PutCollection(ctx, collection)
partitions := suite.meta.GetPartitionsByCollection(collectionID)
partitions := suite.meta.GetPartitionsByCollection(ctx, collectionID)
for _, partition := range partitions {
partition := partition.Clone()
partition.LoadPercentage = 0
@ -2071,7 +2073,7 @@ func (suite *ServiceSuite) updateCollectionStatus(collectionID int64, status que
partition.LoadPercentage = 100
}
partition.PartitionLoadInfo.Status = status
suite.meta.PutPartition(partition)
suite.meta.PutPartition(ctx, partition)
}
}
}
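The service tests above follow one mechanical pattern: every suite helper that reads or writes query-coordination metadata now takes a context.Context as its first argument, and each test obtains a ctx (typically context.Background()) and threads it through. Below is a minimal, self-contained sketch of that shape; HelperSuite, fakeMeta, and the collection ID are invented for illustration and are not the actual test types.

package example

import (
	"context"
	"testing"

	"github.com/stretchr/testify/suite"
)

// fakeMeta stands in for the ctx-aware meta layer; only the call shape matters here.
type fakeMeta struct{ loaded map[int64]bool }

func (m *fakeMeta) Exist(ctx context.Context, collection int64) bool {
	return m.loaded[collection]
}

type HelperSuite struct {
	suite.Suite
	meta *fakeMeta
}

func (s *HelperSuite) SetupTest() {
	s.meta = &fakeMeta{loaded: map[int64]bool{1: true}}
}

// assertLoaded mirrors the helper style above: ctx first, then the collection ID.
func (s *HelperSuite) assertLoaded(ctx context.Context, collection int64) {
	s.True(s.meta.Exist(ctx, collection))
}

func (s *HelperSuite) TestLoaded() {
	ctx := context.Background()
	s.assertLoaded(ctx, 1)
}

func TestHelperSuite(t *testing.T) {
	suite.Run(t, new(HelperSuite))
}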

View File

@ -208,7 +208,7 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error {
)
// get segment's replica first, then get shard leader by replica
replica := ex.meta.ReplicaManager.GetByCollectionAndNode(task.CollectionID(), action.Node())
replica := ex.meta.ReplicaManager.GetByCollectionAndNode(ctx, task.CollectionID(), action.Node())
if replica == nil {
msg := "node doesn't belong to any replica"
err := merr.WrapErrNodeNotAvailable(action.Node())
@ -259,7 +259,7 @@ func (ex *Executor) releaseSegment(task *SegmentTask, step int) {
dstNode := action.Node()
req := packReleaseSegmentRequest(task, action)
channel := ex.targetMgr.GetDmChannel(task.CollectionID(), task.Shard(), meta.CurrentTarget)
channel := ex.targetMgr.GetDmChannel(ctx, task.CollectionID(), task.Shard(), meta.CurrentTarget)
if channel != nil {
// if the channel exists in the current target, set its checkpoint on ReleaseSegmentRequest; it is used as the growing segment's exclude ts
req.Checkpoint = channel.GetSeekPosition()
@ -272,9 +272,9 @@ func (ex *Executor) releaseSegment(task *SegmentTask, step int) {
} else {
req.Shard = task.shard
if ex.meta.CollectionManager.Exist(task.CollectionID()) {
if ex.meta.CollectionManager.Exist(ctx, task.CollectionID()) {
// get segment's replica first, then get shard leader by replica
replica := ex.meta.ReplicaManager.GetByCollectionAndNode(task.CollectionID(), action.Node())
replica := ex.meta.ReplicaManager.GetByCollectionAndNode(ctx, task.CollectionID(), action.Node())
if replica == nil {
msg := "node doesn't belong to any replica, try to send release to worker"
err := merr.WrapErrNodeNotAvailable(action.Node())
@ -344,8 +344,8 @@ func (ex *Executor) subscribeChannel(task *ChannelTask, step int) error {
log.Warn("failed to get collection info")
return err
}
loadFields := ex.meta.GetLoadFields(task.CollectionID())
partitions, err := utils.GetPartitions(ex.meta.CollectionManager, task.CollectionID())
loadFields := ex.meta.GetLoadFields(ctx, task.CollectionID())
partitions, err := utils.GetPartitions(ctx, ex.meta.CollectionManager, task.CollectionID())
if err != nil {
log.Warn("failed to get partitions of collection")
return err
@ -356,7 +356,7 @@ func (ex *Executor) subscribeChannel(task *ChannelTask, step int) error {
return err
}
loadMeta := packLoadMeta(
ex.meta.GetLoadType(task.CollectionID()),
ex.meta.GetLoadType(ctx, task.CollectionID()),
task.CollectionID(),
collectionInfo.GetDbName(),
task.ResourceGroup(),
@ -364,7 +364,7 @@ func (ex *Executor) subscribeChannel(task *ChannelTask, step int) error {
partitions...,
)
dmChannel := ex.targetMgr.GetDmChannel(task.CollectionID(), action.ChannelName(), meta.NextTarget)
dmChannel := ex.targetMgr.GetDmChannel(ctx, task.CollectionID(), action.ChannelName(), meta.NextTarget)
if dmChannel == nil {
msg := "channel does not exist in next target, skip it"
log.Warn(msg, zap.String("channelName", action.ChannelName()))
@ -652,15 +652,15 @@ func (ex *Executor) getMetaInfo(ctx context.Context, task Task) (*milvuspb.Descr
log.Warn("failed to get collection info", zap.Error(err))
return nil, nil, nil, err
}
loadFields := ex.meta.GetLoadFields(task.CollectionID())
partitions, err := utils.GetPartitions(ex.meta.CollectionManager, collectionID)
loadFields := ex.meta.GetLoadFields(ctx, task.CollectionID())
partitions, err := utils.GetPartitions(ctx, ex.meta.CollectionManager, collectionID)
if err != nil {
log.Warn("failed to get partitions of collection", zap.Error(err))
return nil, nil, nil, err
}
loadMeta := packLoadMeta(
ex.meta.GetLoadType(task.CollectionID()),
ex.meta.GetLoadType(ctx, task.CollectionID()),
task.CollectionID(),
collectionInfo.GetDbName(),
task.ResourceGroup(),
@ -669,7 +669,7 @@ func (ex *Executor) getMetaInfo(ctx context.Context, task Task) (*milvuspb.Descr
)
// get channel first, in case of target updated after segment info fetched
channel := ex.targetMgr.GetDmChannel(collectionID, shard, meta.NextTargetFirst)
channel := ex.targetMgr.GetDmChannel(ctx, collectionID, shard, meta.NextTargetFirst)
if channel == nil {
return nil, nil, nil, merr.WrapErrChannelNotAvailable(shard)
}
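Within the executor, every metadata or target lookup performed while running a step now receives the same ctx the step runs under, so a cancelled task also stops its metadata work. The sketch below illustrates that shape under stated assumptions; replicaLookup, loadStep, and the error values are invented stand-ins, not the executor's real API.

package example

import (
	"context"
	"errors"
	"fmt"
)

// replicaLookup stands in for the ctx-aware ReplicaManager accessor used by the executor.
type replicaLookup interface {
	GetByCollectionAndNode(ctx context.Context, collection, node int64) (int64, bool)
}

var errNodeNotInReplica = errors.New("node doesn't belong to any replica")

// loadStep shows the pattern: check ctx first, then pass the same ctx into every lookup.
func loadStep(ctx context.Context, replicas replicaLookup, collection, node int64) error {
	// Bail out early if the task's context was already cancelled.
	if err := ctx.Err(); err != nil {
		return err
	}
	replicaID, ok := replicas.GetByCollectionAndNode(ctx, collection, node)
	if !ok {
		return fmt.Errorf("%w: node=%d collection=%d", errNodeNotInReplica, node, collection)
	}
	_ = replicaID // a real step would go on to resolve the shard leader for this replica
	return nil
}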

View File

@ -385,7 +385,7 @@ func (scheduler *taskScheduler) preAdd(task Task) error {
if taskType == TaskTypeGrow {
views := scheduler.distMgr.LeaderViewManager.GetByFilter(meta.WithChannelName2LeaderView(task.Channel()))
nodesWithChannel := lo.Map(views, func(v *meta.LeaderView, _ int) UniqueID { return v.ID })
replicaNodeMap := utils.GroupNodesByReplica(scheduler.meta.ReplicaManager, task.CollectionID(), nodesWithChannel)
replicaNodeMap := utils.GroupNodesByReplica(task.ctx, scheduler.meta.ReplicaManager, task.CollectionID(), nodesWithChannel)
if _, ok := replicaNodeMap[task.ReplicaID()]; ok {
return merr.WrapErrServiceInternal("channel subscribed, it can be only balanced")
}
@ -535,7 +535,7 @@ func (scheduler *taskScheduler) calculateTaskDelta(collectionID int64, targetAct
case *SegmentAction:
// skip growing segments' row count, since the realtime row number of a growing segment is unknown
if action.Scope == querypb.DataScope_Historical {
segment := scheduler.targetMgr.GetSealedSegment(collectionID, action.SegmentID, meta.NextTargetFirst)
segment := scheduler.targetMgr.GetSealedSegment(scheduler.ctx, collectionID, action.SegmentID, meta.NextTargetFirst)
if segment != nil {
sum += int(segment.GetNumOfRows()) * delta
}
@ -708,14 +708,14 @@ func (scheduler *taskScheduler) isRelated(task Task, node int64) bool {
taskType := GetTaskType(task)
var segment *datapb.SegmentInfo
if taskType == TaskTypeMove || taskType == TaskTypeUpdate {
segment = scheduler.targetMgr.GetSealedSegment(task.CollectionID(), task.SegmentID(), meta.CurrentTarget)
segment = scheduler.targetMgr.GetSealedSegment(task.ctx, task.CollectionID(), task.SegmentID(), meta.CurrentTarget)
} else {
segment = scheduler.targetMgr.GetSealedSegment(task.CollectionID(), task.SegmentID(), meta.NextTarget)
segment = scheduler.targetMgr.GetSealedSegment(task.ctx, task.CollectionID(), task.SegmentID(), meta.NextTarget)
}
if segment == nil {
continue
}
replica := scheduler.meta.ReplicaManager.GetByCollectionAndNode(task.CollectionID(), action.Node())
replica := scheduler.meta.ReplicaManager.GetByCollectionAndNode(task.ctx, task.CollectionID(), action.Node())
if replica == nil {
continue
}
@ -851,7 +851,7 @@ func (scheduler *taskScheduler) remove(task Task) {
if errors.Is(task.Err(), merr.ErrSegmentNotFound) {
log.Info("segment in target has been cleaned, trigger force update next target", zap.Int64("collectionID", task.CollectionID()))
scheduler.targetMgr.UpdateCollectionNextTarget(task.CollectionID())
scheduler.targetMgr.UpdateCollectionNextTarget(task.Context(), task.CollectionID())
}
task.Cancel(nil)
@ -884,7 +884,7 @@ func (scheduler *taskScheduler) remove(task Task) {
scheduler.updateTaskMetrics()
log.Info("task removed")
if scheduler.meta.Exist(task.CollectionID()) {
if scheduler.meta.Exist(task.Context(), task.CollectionID()) {
metrics.QueryCoordTaskLatency.WithLabelValues(fmt.Sprint(task.CollectionID()),
scheduler.getTaskMetricsLabel(task), task.Shard()).Observe(float64(task.GetTaskLatency()))
}
@ -985,7 +985,7 @@ func (scheduler *taskScheduler) checkSegmentTaskStale(task *SegmentTask) error {
return merr.WrapErrNodeOffline(action.Node())
}
taskType := GetTaskType(task)
segment := scheduler.targetMgr.GetSealedSegment(task.CollectionID(), task.SegmentID(), meta.CurrentTargetFirst)
segment := scheduler.targetMgr.GetSealedSegment(task.ctx, task.CollectionID(), task.SegmentID(), meta.CurrentTargetFirst)
if segment == nil {
log.Warn("task stale due to the segment to load not exists in targets",
zap.Int64("segment", task.segmentID),
@ -994,7 +994,7 @@ func (scheduler *taskScheduler) checkSegmentTaskStale(task *SegmentTask) error {
return merr.WrapErrSegmentReduplicate(task.SegmentID(), "target doesn't contain this segment")
}
replica := scheduler.meta.ReplicaManager.GetByCollectionAndNode(task.CollectionID(), action.Node())
replica := scheduler.meta.ReplicaManager.GetByCollectionAndNode(task.ctx, task.CollectionID(), action.Node())
if replica == nil {
log.Warn("task stale due to replica not found")
return merr.WrapErrReplicaNotFound(task.CollectionID(), "by collectionID")
@ -1027,7 +1027,7 @@ func (scheduler *taskScheduler) checkChannelTaskStale(task *ChannelTask) error {
log.Warn("task stale due to node offline", zap.String("channel", task.Channel()))
return merr.WrapErrNodeOffline(action.Node())
}
if scheduler.targetMgr.GetDmChannel(task.collectionID, task.Channel(), meta.NextTargetFirst) == nil {
if scheduler.targetMgr.GetDmChannel(task.ctx, task.collectionID, task.Channel(), meta.NextTargetFirst) == nil {
log.Warn("the task is stale, the channel to subscribe not exists in targets",
zap.String("channel", task.Channel()))
return merr.WrapErrChannelReduplicate(task.Channel(), "target doesn't contain this channel")
@ -1058,7 +1058,7 @@ func (scheduler *taskScheduler) checkLeaderTaskStale(task *LeaderTask) error {
}
taskType := GetTaskType(task)
segment := scheduler.targetMgr.GetSealedSegment(task.CollectionID(), task.SegmentID(), meta.CurrentTargetFirst)
segment := scheduler.targetMgr.GetSealedSegment(task.ctx, task.CollectionID(), task.SegmentID(), meta.CurrentTargetFirst)
if segment == nil {
log.Warn("task stale due to the segment to load not exists in targets",
zap.Int64("segment", task.segmentID),
@ -1067,7 +1067,7 @@ func (scheduler *taskScheduler) checkLeaderTaskStale(task *LeaderTask) error {
return merr.WrapErrSegmentReduplicate(task.SegmentID(), "target doesn't contain this segment")
}
replica := scheduler.meta.ReplicaManager.GetByCollectionAndNode(task.CollectionID(), action.Node())
replica := scheduler.meta.ReplicaManager.GetByCollectionAndNode(task.ctx, task.CollectionID(), action.Node())
if replica == nil {
log.Warn("task stale due to replica not found")
return merr.WrapErrReplicaNotFound(task.CollectionID(), "by collectionID")
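On the scheduler side, the change relies on each task already carrying its own context: calls into the target and replica managers now forward task.Context() (or the task's ctx field), and collection-level work uses the scheduler's own ctx. A minimal sketch of a task owning its context — the Task type and helpers here are simplified stand-ins, not the real scheduler types:

package example

import "context"

// Task is a simplified stand-in for a scheduler task that owns its lifecycle context.
type Task struct {
	ctx        context.Context
	cancel     context.CancelFunc
	collection int64
}

func NewTask(parent context.Context, collection int64) *Task {
	ctx, cancel := context.WithCancel(parent)
	return &Task{ctx: ctx, cancel: cancel, collection: collection}
}

// Context exposes the task's ctx so the scheduler can forward it into meta lookups.
func (t *Task) Context() context.Context { return t.ctx }

func (t *Task) Cancel() { t.cancel() }

// existChecker stands in for the ctx-aware meta.Exist call made when a task is removed.
type existChecker func(ctx context.Context, collection int64) bool

// onTaskRemoved mirrors the scheduler pattern: metadata reads use the task's own context.
func onTaskRemoved(t *Task, exist existChecker) bool {
	return exist(t.Context(), t.collection)
}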

View File

@ -86,6 +86,7 @@ type TaskSuite struct {
// Test object
scheduler *taskScheduler
ctx context.Context
}
func (suite *TaskSuite) SetupSuite() {
@ -133,6 +134,7 @@ func (suite *TaskSuite) SetupSuite() {
segments: typeutil.NewSet[int64](),
},
}
suite.ctx = context.Background()
}
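Because nearly every fixture call in BeforeTest now needs a context, the suite keeps one shared ctx created once in SetupSuite instead of constructing it in each test body. A small sketch of that arrangement, with invented fixture names, assuming the stretchr/testify suite package:

package example

import (
	"context"
	"testing"

	"github.com/stretchr/testify/suite"
)

type fixtureStore struct{ collections map[int64]bool }

// PutCollection mimics a ctx-first fixture call made from BeforeTest.
func (s *fixtureStore) PutCollection(ctx context.Context, id int64) {
	s.collections[id] = true
}

type SchedulerSuite struct {
	suite.Suite
	ctx   context.Context
	store *fixtureStore
}

func (s *SchedulerSuite) SetupSuite() {
	// One shared background context for all fixture setup in this suite.
	s.ctx = context.Background()
	s.store = &fixtureStore{collections: map[int64]bool{}}
}

func (s *SchedulerSuite) BeforeTest(suiteName, testName string) {
	s.store.PutCollection(s.ctx, 1)
}

func (s *SchedulerSuite) TestCollectionPresent() {
	s.True(s.store.collections[1])
}

func TestSchedulerSuite(t *testing.T) {
	suite.Run(t, new(SchedulerSuite))
}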
func (suite *TaskSuite) TearDownSuite() {
@ -193,20 +195,20 @@ func (suite *TaskSuite) BeforeTest(suiteName, testName string) {
"TestLeaderTaskSet",
"TestLeaderTaskRemove",
"TestNoExecutor":
suite.meta.PutCollection(&meta.Collection{
suite.meta.PutCollection(suite.ctx, &meta.Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: suite.collection,
ReplicaNumber: 1,
Status: querypb.LoadStatus_Loading,
},
})
suite.meta.PutPartition(&meta.Partition{
suite.meta.PutPartition(suite.ctx, &meta.Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: suite.collection,
PartitionID: 1,
},
})
suite.meta.ReplicaManager.Put(utils.CreateTestReplica(suite.replica.GetID(), suite.collection, []int64{1, 2, 3}))
suite.meta.ReplicaManager.Put(suite.ctx, utils.CreateTestReplica(suite.replica.GetID(), suite.collection, []int64{1, 2, 3}))
}
}
@ -276,7 +278,7 @@ func (suite *TaskSuite) TestSubscribeChannelTask() {
suite.NoError(err)
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return(dmChannels, nil, nil)
suite.target.UpdateCollectionNextTarget(suite.collection)
suite.target.UpdateCollectionNextTarget(ctx, suite.collection)
suite.AssertTaskNum(0, len(suite.subChannels), len(suite.subChannels), 0)
// Process tasks
@ -371,7 +373,7 @@ func (suite *TaskSuite) TestUnsubscribeChannelTask() {
suite.NoError(err)
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return(dmChannels, nil, nil)
suite.target.UpdateCollectionNextTarget(suite.collection)
suite.target.UpdateCollectionNextTarget(ctx, suite.collection)
// Only first channel exists
suite.dist.LeaderViewManager.Update(targetNode, &meta.LeaderView{
@ -463,7 +465,7 @@ func (suite *TaskSuite) TestLoadSegmentTask() {
suite.NoError(err)
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return([]*datapb.VchannelInfo{channel}, segments, nil)
suite.target.UpdateCollectionNextTarget(suite.collection)
suite.target.UpdateCollectionNextTarget(ctx, suite.collection)
segmentsNum := len(suite.loadSegments)
suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum)
@ -564,7 +566,7 @@ func (suite *TaskSuite) TestLoadSegmentTaskNotIndex() {
suite.NoError(err)
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return([]*datapb.VchannelInfo{channel}, segments, nil)
suite.target.UpdateCollectionNextTarget(suite.collection)
suite.target.UpdateCollectionNextTarget(ctx, suite.collection)
segmentsNum := len(suite.loadSegments)
suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum)
@ -658,7 +660,7 @@ func (suite *TaskSuite) TestLoadSegmentTaskFailed() {
suite.NoError(err)
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return([]*datapb.VchannelInfo{channel}, segments, nil)
suite.target.UpdateCollectionNextTarget(suite.collection)
suite.target.UpdateCollectionNextTarget(ctx, suite.collection)
segmentsNum := len(suite.loadSegments)
suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum)
@ -874,8 +876,8 @@ func (suite *TaskSuite) TestMoveSegmentTask() {
tasks = append(tasks, task)
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return([]*datapb.VchannelInfo{vchannel}, segmentInfos, nil)
suite.target.UpdateCollectionNextTarget(suite.collection)
suite.target.UpdateCollectionCurrentTarget(suite.collection)
suite.target.UpdateCollectionNextTarget(ctx, suite.collection)
suite.target.UpdateCollectionCurrentTarget(ctx, suite.collection)
suite.dist.SegmentDistManager.Update(sourceNode, segments...)
suite.dist.LeaderViewManager.Update(leader, view)
for _, task := range tasks {
@ -958,8 +960,8 @@ func (suite *TaskSuite) TestMoveSegmentTaskStale() {
tasks = append(tasks, task)
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return([]*datapb.VchannelInfo{vchannel}, segmentInfos, nil)
suite.target.UpdateCollectionNextTarget(suite.collection)
suite.target.UpdateCollectionCurrentTarget(suite.collection)
suite.target.UpdateCollectionNextTarget(ctx, suite.collection)
suite.target.UpdateCollectionCurrentTarget(ctx, suite.collection)
suite.dist.LeaderViewManager.Update(leader, view)
for _, task := range tasks {
err := suite.scheduler.Add(task)
@ -1039,8 +1041,8 @@ func (suite *TaskSuite) TestTaskCanceled() {
segmentsNum := len(suite.loadSegments)
suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum)
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return([]*datapb.VchannelInfo{channel}, segmentInfos, nil)
suite.meta.CollectionManager.PutPartition(utils.CreateTestPartition(suite.collection, partition))
suite.target.UpdateCollectionNextTarget(suite.collection)
suite.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(suite.collection, partition))
suite.target.UpdateCollectionNextTarget(ctx, suite.collection)
// Process tasks
suite.dispatchAndWait(targetNode)
@ -1100,7 +1102,7 @@ func (suite *TaskSuite) TestSegmentTaskStale() {
suite.cluster.EXPECT().LoadSegments(mock.Anything, targetNode, mock.Anything).Return(merr.Success(), nil)
// Test load segment task
suite.meta.ReplicaManager.Put(createReplica(suite.collection, targetNode))
suite.meta.ReplicaManager.Put(ctx, createReplica(suite.collection, targetNode))
suite.dist.ChannelDistManager.Update(targetNode, meta.DmChannelFromVChannel(&datapb.VchannelInfo{
CollectionID: suite.collection,
ChannelName: channel.ChannelName,
@ -1128,8 +1130,8 @@ func (suite *TaskSuite) TestSegmentTaskStale() {
suite.NoError(err)
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return([]*datapb.VchannelInfo{channel}, segments, nil)
suite.meta.CollectionManager.PutPartition(utils.CreateTestPartition(suite.collection, partition))
suite.target.UpdateCollectionNextTarget(suite.collection)
suite.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(suite.collection, partition))
suite.target.UpdateCollectionNextTarget(ctx, suite.collection)
segmentsNum := len(suite.loadSegments)
suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum)
@ -1166,8 +1168,8 @@ func (suite *TaskSuite) TestSegmentTaskStale() {
suite.broker.ExpectedCalls = suite.broker.ExpectedCalls[:0]
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return([]*datapb.VchannelInfo{channel}, segments, nil)
suite.meta.CollectionManager.PutPartition(utils.CreateTestPartition(suite.collection, 2))
suite.target.UpdateCollectionNextTarget(suite.collection)
suite.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(suite.collection, 2))
suite.target.UpdateCollectionNextTarget(ctx, suite.collection)
suite.dispatchAndWait(targetNode)
suite.AssertTaskNum(0, 0, 0, 0)
@ -1306,7 +1308,7 @@ func (suite *TaskSuite) TestLeaderTaskSet() {
suite.NoError(err)
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return([]*datapb.VchannelInfo{channel}, segments, nil)
suite.target.UpdateCollectionNextTarget(suite.collection)
suite.target.UpdateCollectionNextTarget(ctx, suite.collection)
segmentsNum := len(suite.loadSegments)
suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum)
@ -1452,7 +1454,7 @@ func (suite *TaskSuite) TestNoExecutor() {
ChannelName: Params.CommonCfg.RootCoordDml.GetValue() + "-test",
}
suite.meta.ReplicaManager.Put(utils.CreateTestReplica(suite.replica.GetID(), suite.collection, []int64{1, 2, 3, -1}))
suite.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(suite.replica.GetID(), suite.collection, []int64{1, 2, 3, -1}))
// Test load segment task
suite.dist.ChannelDistManager.Update(targetNode, meta.DmChannelFromVChannel(&datapb.VchannelInfo{
@ -1479,7 +1481,7 @@ func (suite *TaskSuite) TestNoExecutor() {
suite.NoError(err)
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return([]*datapb.VchannelInfo{channel}, segments, nil)
suite.target.UpdateCollectionNextTarget(suite.collection)
suite.target.UpdateCollectionNextTarget(ctx, suite.collection)
segmentsNum := len(suite.loadSegments)
suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum)
@ -1625,6 +1627,7 @@ func createReplica(collection int64, nodes ...int64) *meta.Replica {
}
func (suite *TaskSuite) TestBalanceChannelTask() {
ctx := context.Background()
collectionID := int64(1)
partitionID := int64(1)
channel := "channel-1"
@ -1653,12 +1656,12 @@ func (suite *TaskSuite) TestBalanceChannelTask() {
InsertChannel: channel,
},
}
suite.meta.PutCollection(utils.CreateTestCollection(collectionID, 1), utils.CreateTestPartition(collectionID, 1))
suite.meta.PutCollection(ctx, utils.CreateTestCollection(collectionID, 1), utils.CreateTestPartition(collectionID, 1))
suite.broker.ExpectedCalls = nil
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, mock.Anything).Return([]*datapb.VchannelInfo{vchannel}, segments, nil)
suite.target.UpdateCollectionNextTarget(collectionID)
suite.target.UpdateCollectionCurrentTarget(collectionID)
suite.target.UpdateCollectionNextTarget(collectionID)
suite.target.UpdateCollectionNextTarget(ctx, collectionID)
suite.target.UpdateCollectionCurrentTarget(ctx, collectionID)
suite.target.UpdateCollectionNextTarget(ctx, collectionID)
suite.dist.LeaderViewManager.Update(2, &meta.LeaderView{
ID: 2,
@ -1712,6 +1715,7 @@ func (suite *TaskSuite) TestBalanceChannelTask() {
}
func (suite *TaskSuite) TestBalanceChannelWithL0SegmentTask() {
ctx := context.Background()
collectionID := int64(1)
partitionID := int64(1)
channel := "channel-1"
@ -1743,12 +1747,12 @@ func (suite *TaskSuite) TestBalanceChannelWithL0SegmentTask() {
Level: datapb.SegmentLevel_L0,
},
}
suite.meta.PutCollection(utils.CreateTestCollection(collectionID, 1), utils.CreateTestPartition(collectionID, 1))
suite.meta.PutCollection(ctx, utils.CreateTestCollection(collectionID, 1), utils.CreateTestPartition(collectionID, 1))
suite.broker.ExpectedCalls = nil
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, mock.Anything).Return([]*datapb.VchannelInfo{vchannel}, segments, nil)
suite.target.UpdateCollectionNextTarget(collectionID)
suite.target.UpdateCollectionCurrentTarget(collectionID)
suite.target.UpdateCollectionNextTarget(collectionID)
suite.target.UpdateCollectionNextTarget(ctx, collectionID)
suite.target.UpdateCollectionCurrentTarget(ctx, collectionID)
suite.target.UpdateCollectionNextTarget(ctx, collectionID)
suite.dist.LeaderViewManager.Update(2, &meta.LeaderView{
ID: 2,

View File

@ -17,6 +17,7 @@
package utils
import (
"context"
"strings"
"github.com/cockroachdb/errors"
@ -29,10 +30,10 @@ import (
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
func GetPartitions(collectionMgr *meta.CollectionManager, collectionID int64) ([]int64, error) {
collection := collectionMgr.GetCollection(collectionID)
func GetPartitions(ctx context.Context, collectionMgr *meta.CollectionManager, collectionID int64) ([]int64, error) {
collection := collectionMgr.GetCollection(ctx, collectionID)
if collection != nil {
partitions := collectionMgr.GetPartitionsByCollection(collectionID)
partitions := collectionMgr.GetPartitionsByCollection(ctx, collectionID)
if partitions != nil {
return lo.Map(partitions, func(partition *meta.Partition, i int) int64 {
return partition.PartitionID
@ -45,9 +46,9 @@ func GetPartitions(collectionMgr *meta.CollectionManager, collectionID int64) ([
// GroupNodesByReplica groups nodes by replica,
// returns ReplicaID -> NodeIDs
func GroupNodesByReplica(replicaMgr *meta.ReplicaManager, collectionID int64, nodes []int64) map[int64][]int64 {
func GroupNodesByReplica(ctx context.Context, replicaMgr *meta.ReplicaManager, collectionID int64, nodes []int64) map[int64][]int64 {
ret := make(map[int64][]int64)
replicas := replicaMgr.GetByCollection(collectionID)
replicas := replicaMgr.GetByCollection(ctx, collectionID)
for _, replica := range replicas {
for _, node := range nodes {
if replica.Contains(node) {
@ -71,9 +72,9 @@ func GroupPartitionsByCollection(partitions []*meta.Partition) map[int64][]*meta
// GroupSegmentsByReplica groups segments by replica,
// returns ReplicaID -> Segments
func GroupSegmentsByReplica(replicaMgr *meta.ReplicaManager, collectionID int64, segments []*meta.Segment) map[int64][]*meta.Segment {
func GroupSegmentsByReplica(ctx context.Context, replicaMgr *meta.ReplicaManager, collectionID int64, segments []*meta.Segment) map[int64][]*meta.Segment {
ret := make(map[int64][]*meta.Segment)
replicas := replicaMgr.GetByCollection(collectionID)
replicas := replicaMgr.GetByCollection(ctx, collectionID)
for _, replica := range replicas {
for _, segment := range segments {
if replica.Contains(segment.Node) {
@ -85,32 +86,32 @@ func GroupSegmentsByReplica(replicaMgr *meta.ReplicaManager, collectionID int64,
}
// RecoverReplicaOfCollection recovers all replicas of a collection with the latest resource group state.
func RecoverReplicaOfCollection(m *meta.Meta, collectionID typeutil.UniqueID) {
func RecoverReplicaOfCollection(ctx context.Context, m *meta.Meta, collectionID typeutil.UniqueID) {
logger := log.With(zap.Int64("collectionID", collectionID))
rgNames := m.ReplicaManager.GetResourceGroupByCollection(collectionID)
rgNames := m.ReplicaManager.GetResourceGroupByCollection(ctx, collectionID)
if rgNames.Len() == 0 {
logger.Error("no resource group found for collection", zap.Int64("collectionID", collectionID))
return
}
rgs, err := m.ResourceManager.GetNodesOfMultiRG(rgNames.Collect())
rgs, err := m.ResourceManager.GetNodesOfMultiRG(ctx, rgNames.Collect())
if err != nil {
logger.Error("unreachable code as expected, fail to get resource group for replica", zap.Error(err))
return
}
if err := m.ReplicaManager.RecoverNodesInCollection(collectionID, rgs); err != nil {
if err := m.ReplicaManager.RecoverNodesInCollection(ctx, collectionID, rgs); err != nil {
logger.Warn("fail to set available nodes in replica", zap.Error(err))
}
}
// RecoverAllCollection recovers all replicas of every collection in the resource groups.
func RecoverAllCollection(m *meta.Meta) {
for _, collection := range m.CollectionManager.GetAll() {
RecoverReplicaOfCollection(m, collection)
for _, collection := range m.CollectionManager.GetAll(context.TODO()) {
RecoverReplicaOfCollection(context.TODO(), m, collection)
}
}
func AssignReplica(m *meta.Meta, resourceGroups []string, replicaNumber int32, checkNodeNum bool) (map[string]int, error) {
func AssignReplica(ctx context.Context, m *meta.Meta, resourceGroups []string, replicaNumber int32, checkNodeNum bool) (map[string]int, error) {
if len(resourceGroups) != 0 && len(resourceGroups) != 1 && len(resourceGroups) != int(replicaNumber) {
return nil, errors.Errorf(
"replica=[%d] resource group=[%s], resource group num can only be 0, 1 or same as replica number", replicaNumber, strings.Join(resourceGroups, ","))
@ -135,10 +136,10 @@ func AssignReplica(m *meta.Meta, resourceGroups []string, replicaNumber int32, c
// 2. rg1 is removed.
// 3. replica1 spawn finished, but cannot find related resource group.
for rgName, num := range replicaNumInRG {
if !m.ContainResourceGroup(rgName) {
if !m.ContainResourceGroup(ctx, rgName) {
return nil, merr.WrapErrResourceGroupNotFound(rgName)
}
nodes, err := m.ResourceManager.GetNodes(rgName)
nodes, err := m.ResourceManager.GetNodes(ctx, rgName)
if err != nil {
return nil, err
}
@ -155,35 +156,36 @@ func AssignReplica(m *meta.Meta, resourceGroups []string, replicaNumber int32, c
}
// SpawnReplicasWithRG spawns replicas in rgs one by one for given collection.
func SpawnReplicasWithRG(m *meta.Meta, collection int64, resourceGroups []string, replicaNumber int32, channels []string) ([]*meta.Replica, error) {
replicaNumInRG, err := AssignReplica(m, resourceGroups, replicaNumber, true)
func SpawnReplicasWithRG(ctx context.Context, m *meta.Meta, collection int64, resourceGroups []string, replicaNumber int32, channels []string) ([]*meta.Replica, error) {
replicaNumInRG, err := AssignReplica(ctx, m, resourceGroups, replicaNumber, true)
if err != nil {
return nil, err
}
// Spawn it in replica manager.
replicas, err := m.ReplicaManager.Spawn(collection, replicaNumInRG, channels)
replicas, err := m.ReplicaManager.Spawn(ctx, collection, replicaNumInRG, channels)
if err != nil {
return nil, err
}
// Active recover it.
RecoverReplicaOfCollection(m, collection)
RecoverReplicaOfCollection(ctx, m, collection)
return replicas, nil
}
func ReassignReplicaToRG(
ctx context.Context,
m *meta.Meta,
collectionID int64,
newReplicaNumber int32,
newResourceGroups []string,
) (map[string]int, map[string][]*meta.Replica, []int64, error) {
// assign all replicas to newResourceGroups, got each rg's replica number
newAssignment, err := AssignReplica(m, newResourceGroups, newReplicaNumber, false)
newAssignment, err := AssignReplica(ctx, m, newResourceGroups, newReplicaNumber, false)
if err != nil {
return nil, nil, nil, err
}
replicas := m.ReplicaManager.GetByCollection(collectionID)
replicas := m.ReplicaManager.GetByCollection(context.TODO(), collectionID)
replicasInRG := lo.GroupBy(replicas, func(replica *meta.Replica) string {
return replica.GetResourceGroup()
})
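Two conventions in this utility file are worth noting: free functions take ctx as their first parameter, and the few call sites that have no caller-supplied context yet pass context.TODO() to mark the gap explicitly rather than hide it. The generic sketch below illustrates both; the grouping function and membership callback are invented for the example and are not the actual utilities.

package example

import "context"

// membership reports whether a node currently belongs to the given replica.
type membership func(ctx context.Context, replicaID, node int64) bool

// GroupNodesByReplica groups nodes under the replicas that contain them,
// following the ctx-first signature convention.
func GroupNodesByReplica(ctx context.Context, replicaIDs, nodes []int64, contains membership) map[int64][]int64 {
	out := make(map[int64][]int64)
	for _, replica := range replicaIDs {
		for _, node := range nodes {
			if contains(ctx, replica, node) {
				out[replica] = append(out[replica], node)
			}
		}
	}
	return out
}

// recoverAll shows the context.TODO() convention: a caller that is not yet
// plumbed for a context passes an explicit placeholder instead of nil.
func recoverAll(replicaIDs, nodes []int64, contains membership) map[int64][]int64 {
	return GroupNodesByReplica(context.TODO(), replicaIDs, nodes, contains)
}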

View File

@ -17,6 +17,7 @@
package utils
import (
"context"
"testing"
"github.com/cockroachdb/errors"
@ -51,18 +52,19 @@ func TestSpawnReplicasWithRG(t *testing.T) {
require.NoError(t, err)
kv := etcdKV.NewEtcdKV(cli, config.MetaRootPath.GetValue())
ctx := context.Background()
store := querycoord.NewCatalog(kv)
nodeMgr := session.NewNodeManager()
m := meta.NewMeta(RandomIncrementIDAllocator(), store, nodeMgr)
m.ResourceManager.AddResourceGroup("rg1", &rgpb.ResourceGroupConfig{
m.ResourceManager.AddResourceGroup(ctx, "rg1", &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{NodeNum: 3},
Limits: &rgpb.ResourceGroupLimit{NodeNum: 3},
})
m.ResourceManager.AddResourceGroup("rg2", &rgpb.ResourceGroupConfig{
m.ResourceManager.AddResourceGroup(ctx, "rg2", &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{NodeNum: 3},
Limits: &rgpb.ResourceGroupLimit{NodeNum: 3},
})
m.ResourceManager.AddResourceGroup("rg3", &rgpb.ResourceGroupConfig{
m.ResourceManager.AddResourceGroup(ctx, "rg3", &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{NodeNum: 3},
Limits: &rgpb.ResourceGroupLimit{NodeNum: 3},
})
@ -74,13 +76,13 @@ func TestSpawnReplicasWithRG(t *testing.T) {
Hostname: "localhost",
}))
if i%3 == 0 {
m.ResourceManager.HandleNodeUp(int64(i))
m.ResourceManager.HandleNodeUp(ctx, int64(i))
}
if i%3 == 1 {
m.ResourceManager.HandleNodeUp(int64(i))
m.ResourceManager.HandleNodeUp(ctx, int64(i))
}
if i%3 == 2 {
m.ResourceManager.HandleNodeUp(int64(i))
m.ResourceManager.HandleNodeUp(ctx, int64(i))
}
}
@ -120,7 +122,7 @@ func TestSpawnReplicasWithRG(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := SpawnReplicasWithRG(tt.args.m, tt.args.collection, tt.args.resourceGroups, tt.args.replicaNumber, nil)
got, err := SpawnReplicasWithRG(ctx, tt.args.m, tt.args.collection, tt.args.resourceGroups, tt.args.replicaNumber, nil)
if (err != nil) != tt.wantErr {
t.Errorf("SpawnReplicasWithRG() error = %v, wantErr %v", err, tt.wantErr)
return
@ -135,21 +137,22 @@ func TestSpawnReplicasWithRG(t *testing.T) {
func TestAddNodesToCollectionsInRGFailed(t *testing.T) {
paramtable.Init()
ctx := context.Background()
store := mocks.NewQueryCoordCatalog(t)
store.EXPECT().SaveCollection(mock.Anything).Return(nil)
store.EXPECT().SaveReplica(mock.Anything).Return(nil).Times(4)
store.EXPECT().SaveResourceGroup(mock.Anything).Return(nil)
store.EXPECT().SaveCollection(mock.Anything, mock.Anything).Return(nil)
store.EXPECT().SaveReplica(mock.Anything, mock.Anything).Return(nil).Times(4)
store.EXPECT().SaveResourceGroup(mock.Anything, mock.Anything).Return(nil)
store.EXPECT().SaveResourceGroup(mock.Anything, mock.Anything, mock.Anything).Return(nil)
nodeMgr := session.NewNodeManager()
m := meta.NewMeta(RandomIncrementIDAllocator(), store, nodeMgr)
m.ResourceManager.AddResourceGroup("rg", &rgpb.ResourceGroupConfig{
m.ResourceManager.AddResourceGroup(ctx, "rg", &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{NodeNum: 0},
Limits: &rgpb.ResourceGroupLimit{NodeNum: 0},
})
m.CollectionManager.PutCollection(CreateTestCollection(1, 2))
m.CollectionManager.PutCollection(CreateTestCollection(2, 2))
m.ReplicaManager.Put(meta.NewReplica(
m.CollectionManager.PutCollection(ctx, CreateTestCollection(1, 2))
m.CollectionManager.PutCollection(ctx, CreateTestCollection(2, 2))
m.ReplicaManager.Put(ctx, meta.NewReplica(
&querypb.Replica{
ID: 1,
CollectionID: 1,
@ -159,7 +162,7 @@ func TestAddNodesToCollectionsInRGFailed(t *testing.T) {
typeutil.NewUniqueSet(),
))
m.ReplicaManager.Put(meta.NewReplica(
m.ReplicaManager.Put(ctx, meta.NewReplica(
&querypb.Replica{
ID: 2,
CollectionID: 1,
@ -169,7 +172,7 @@ func TestAddNodesToCollectionsInRGFailed(t *testing.T) {
typeutil.NewUniqueSet(),
))
m.ReplicaManager.Put(meta.NewReplica(
m.ReplicaManager.Put(ctx, meta.NewReplica(
&querypb.Replica{
ID: 3,
CollectionID: 2,
@ -179,7 +182,7 @@ func TestAddNodesToCollectionsInRGFailed(t *testing.T) {
typeutil.NewUniqueSet(),
))
m.ReplicaManager.Put(meta.NewReplica(
m.ReplicaManager.Put(ctx, meta.NewReplica(
&querypb.Replica{
ID: 4,
CollectionID: 2,
@ -190,33 +193,34 @@ func TestAddNodesToCollectionsInRGFailed(t *testing.T) {
))
storeErr := errors.New("store error")
store.EXPECT().SaveReplica(mock.Anything).Return(storeErr)
store.EXPECT().SaveReplica(mock.Anything, mock.Anything).Return(storeErr)
RecoverAllCollection(m)
assert.Len(t, m.ReplicaManager.Get(1).GetNodes(), 0)
assert.Len(t, m.ReplicaManager.Get(2).GetNodes(), 0)
assert.Len(t, m.ReplicaManager.Get(3).GetNodes(), 0)
assert.Len(t, m.ReplicaManager.Get(4).GetNodes(), 0)
assert.Len(t, m.ReplicaManager.Get(ctx, 1).GetNodes(), 0)
assert.Len(t, m.ReplicaManager.Get(ctx, 2).GetNodes(), 0)
assert.Len(t, m.ReplicaManager.Get(ctx, 3).GetNodes(), 0)
assert.Len(t, m.ReplicaManager.Get(ctx, 4).GetNodes(), 0)
}
func TestAddNodesToCollectionsInRG(t *testing.T) {
paramtable.Init()
ctx := context.Background()
store := mocks.NewQueryCoordCatalog(t)
store.EXPECT().SaveCollection(mock.Anything).Return(nil)
store.EXPECT().SaveReplica(mock.Anything).Return(nil)
store.EXPECT().SaveReplica(mock.Anything, mock.Anything).Return(nil)
store.EXPECT().SaveResourceGroup(mock.Anything).Return(nil)
store.EXPECT().SaveCollection(mock.Anything, mock.Anything).Return(nil)
store.EXPECT().SaveReplica(mock.Anything, mock.Anything, mock.Anything).Return(nil)
store.EXPECT().SaveReplica(mock.Anything, mock.Anything, mock.Anything).Return(nil)
store.EXPECT().SaveResourceGroup(mock.Anything, mock.Anything).Return(nil)
store.EXPECT().SaveResourceGroup(mock.Anything, mock.Anything, mock.Anything).Return(nil)
nodeMgr := session.NewNodeManager()
m := meta.NewMeta(RandomIncrementIDAllocator(), store, nodeMgr)
m.ResourceManager.AddResourceGroup("rg", &rgpb.ResourceGroupConfig{
m.ResourceManager.AddResourceGroup(ctx, "rg", &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{NodeNum: 4},
Limits: &rgpb.ResourceGroupLimit{NodeNum: 4},
})
m.CollectionManager.PutCollection(CreateTestCollection(1, 2))
m.CollectionManager.PutCollection(CreateTestCollection(2, 2))
m.ReplicaManager.Put(meta.NewReplica(
m.CollectionManager.PutCollection(ctx, CreateTestCollection(1, 2))
m.CollectionManager.PutCollection(ctx, CreateTestCollection(2, 2))
m.ReplicaManager.Put(ctx, meta.NewReplica(
&querypb.Replica{
ID: 1,
CollectionID: 1,
@ -226,7 +230,7 @@ func TestAddNodesToCollectionsInRG(t *testing.T) {
typeutil.NewUniqueSet(),
))
m.ReplicaManager.Put(meta.NewReplica(
m.ReplicaManager.Put(ctx, meta.NewReplica(
&querypb.Replica{
ID: 2,
CollectionID: 1,
@ -236,7 +240,7 @@ func TestAddNodesToCollectionsInRG(t *testing.T) {
typeutil.NewUniqueSet(),
))
m.ReplicaManager.Put(meta.NewReplica(
m.ReplicaManager.Put(ctx, meta.NewReplica(
&querypb.Replica{
ID: 3,
CollectionID: 2,
@ -246,7 +250,7 @@ func TestAddNodesToCollectionsInRG(t *testing.T) {
typeutil.NewUniqueSet(),
))
m.ReplicaManager.Put(meta.NewReplica(
m.ReplicaManager.Put(ctx, meta.NewReplica(
&querypb.Replica{
ID: 4,
CollectionID: 2,
@ -262,12 +266,12 @@ func TestAddNodesToCollectionsInRG(t *testing.T) {
Address: "127.0.0.1",
Hostname: "localhost",
}))
m.ResourceManager.HandleNodeUp(nodeID)
m.ResourceManager.HandleNodeUp(ctx, nodeID)
}
RecoverAllCollection(m)
assert.Len(t, m.ReplicaManager.Get(1).GetNodes(), 2)
assert.Len(t, m.ReplicaManager.Get(2).GetNodes(), 2)
assert.Len(t, m.ReplicaManager.Get(3).GetNodes(), 2)
assert.Len(t, m.ReplicaManager.Get(4).GetNodes(), 2)
assert.Len(t, m.ReplicaManager.Get(ctx, 1).GetNodes(), 2)
assert.Len(t, m.ReplicaManager.Get(ctx, 2).GetNodes(), 2)
assert.Len(t, m.ReplicaManager.Get(ctx, 3).GetNodes(), 2)
assert.Len(t, m.ReplicaManager.Get(ctx, 4).GetNodes(), 2)
}
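A knock-on effect visible in these tests: once a catalog method takes ctx, every mock expectation needs one more argument matcher, e.g. SaveReplica(mock.Anything) becomes SaveReplica(mock.Anything, mock.Anything). The hand-rolled testify/mock sketch below shows why; the Catalog interface and mockCatalog are illustrative only, not the generated Milvus mocks.

package example

import (
	"context"
	"testing"

	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
)

// Catalog is a cut-down, ctx-first interface used only for this example.
type Catalog interface {
	SaveReplica(ctx context.Context, replicaID int64) error
}

type mockCatalog struct{ mock.Mock }

var _ Catalog = (*mockCatalog)(nil)

func (m *mockCatalog) SaveReplica(ctx context.Context, replicaID int64) error {
	// ctx is recorded like any other argument, so expectations must match it too.
	args := m.Called(ctx, replicaID)
	return args.Error(0)
}

func TestSaveReplicaExpectation(t *testing.T) {
	m := &mockCatalog{}
	// One matcher per parameter: the first mock.Anything covers the new ctx argument.
	m.On("SaveReplica", mock.Anything, int64(1)).Return(nil)

	err := m.SaveReplica(context.Background(), 1)
	require.NoError(t, err)
	m.AssertExpectations(t)
}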

View File

@ -68,7 +68,7 @@ func CheckDelegatorDataReady(nodeMgr *session.NodeManager, targetMgr meta.Target
return err
}
}
segmentDist := targetMgr.GetSealedSegmentsByChannel(leader.CollectionID, leader.Channel, scope)
segmentDist := targetMgr.GetSealedSegmentsByChannel(context.TODO(), leader.CollectionID, leader.Channel, scope)
// Check whether segments are fully loaded
for segmentID, info := range segmentDist {
_, exist := leader.Segments[segmentID]
@ -87,13 +87,13 @@ func CheckDelegatorDataReady(nodeMgr *session.NodeManager, targetMgr meta.Target
}
func checkLoadStatus(ctx context.Context, m *meta.Meta, collectionID int64) error {
percentage := m.CollectionManager.CalculateLoadPercentage(collectionID)
percentage := m.CollectionManager.CalculateLoadPercentage(ctx, collectionID)
if percentage < 0 {
err := merr.WrapErrCollectionNotLoaded(collectionID)
log.Ctx(ctx).Warn("failed to GetShardLeaders", zap.Error(err))
return err
}
collection := m.CollectionManager.GetCollection(collectionID)
collection := m.CollectionManager.GetCollection(ctx, collectionID)
if collection != nil && collection.GetStatus() == querypb.LoadStatus_Loaded {
// when collection is loaded, regard collection as readable, set percentage == 100
percentage = 100
@ -108,7 +108,7 @@ func checkLoadStatus(ctx context.Context, m *meta.Meta, collectionID int64) erro
return nil
}
func GetShardLeadersWithChannels(m *meta.Meta, targetMgr meta.TargetManagerInterface, dist *meta.DistributionManager,
func GetShardLeadersWithChannels(ctx context.Context, m *meta.Meta, targetMgr meta.TargetManagerInterface, dist *meta.DistributionManager,
nodeMgr *session.NodeManager, collectionID int64, channels map[string]*meta.DmChannel,
) ([]*querypb.ShardLeadersList, error) {
ret := make([]*querypb.ShardLeadersList, 0)
@ -137,7 +137,7 @@ func GetShardLeadersWithChannels(m *meta.Meta, targetMgr meta.TargetManagerInter
return nil, err
}
readableLeaders = filterDupLeaders(m.ReplicaManager, readableLeaders)
readableLeaders = filterDupLeaders(ctx, m.ReplicaManager, readableLeaders)
ids := make([]int64, 0, len(leaders))
addrs := make([]string, 0, len(leaders))
for _, leader := range readableLeaders {
@ -174,26 +174,26 @@ func GetShardLeaders(ctx context.Context, m *meta.Meta, targetMgr meta.TargetMan
return nil, err
}
channels := targetMgr.GetDmChannelsByCollection(collectionID, meta.CurrentTarget)
channels := targetMgr.GetDmChannelsByCollection(ctx, collectionID, meta.CurrentTarget)
if len(channels) == 0 {
msg := "loaded collection do not found any channel in target, may be in recovery"
err := merr.WrapErrCollectionOnRecovering(collectionID, msg)
log.Ctx(ctx).Warn("failed to get channels", zap.Error(err))
return nil, err
}
return GetShardLeadersWithChannels(m, targetMgr, dist, nodeMgr, collectionID, channels)
return GetShardLeadersWithChannels(ctx, m, targetMgr, dist, nodeMgr, collectionID, channels)
}
// CheckCollectionsQueryable checks that all channels are watched and all segments are loaded for every loaded collection
func CheckCollectionsQueryable(ctx context.Context, m *meta.Meta, targetMgr meta.TargetManagerInterface, dist *meta.DistributionManager, nodeMgr *session.NodeManager) error {
maxInterval := paramtable.Get().QueryCoordCfg.UpdateCollectionLoadStatusInterval.GetAsDuration(time.Minute)
for _, coll := range m.GetAllCollections() {
for _, coll := range m.GetAllCollections(ctx) {
err := checkCollectionQueryable(ctx, m, targetMgr, dist, nodeMgr, coll)
// the collection is regarded as not queryable if all of the following conditions are met:
// 1. Some segments are not loaded
// 2. The collection is not being released
// 3. The load percentage has not been updated within the last 5 minutes.
if err != nil && m.Exist(coll.CollectionID) && time.Since(coll.UpdatedAt) >= maxInterval {
if err != nil && m.Exist(ctx, coll.CollectionID) && time.Since(coll.UpdatedAt) >= maxInterval {
log.Ctx(ctx).Warn("collection not querable",
zap.Int64("collectionID", coll.CollectionID),
zap.Time("lastUpdated", coll.UpdatedAt),
@ -212,7 +212,7 @@ func checkCollectionQueryable(ctx context.Context, m *meta.Meta, targetMgr meta.
return err
}
channels := targetMgr.GetDmChannelsByCollection(collectionID, meta.CurrentTarget)
channels := targetMgr.GetDmChannelsByCollection(ctx, collectionID, meta.CurrentTarget)
if len(channels) == 0 {
msg := "loaded collection do not found any channel in target, may be in recovery"
err := merr.WrapErrCollectionOnRecovering(collectionID, msg)
@ -220,7 +220,7 @@ func checkCollectionQueryable(ctx context.Context, m *meta.Meta, targetMgr meta.
return err
}
shardList, err := GetShardLeadersWithChannels(m, targetMgr, dist, nodeMgr, collectionID, channels)
shardList, err := GetShardLeadersWithChannels(ctx, m, targetMgr, dist, nodeMgr, collectionID, channels)
if err != nil {
return err
}
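
GetShardLeaders, GetShardLeadersWithChannels, checkCollectionQueryable, and filterDupLeaders now share a single request ctx that is passed straight down the call chain instead of being re-created at each layer. A condensed sketch of that layering with hypothetical function names; the only point is that each layer accepts ctx first and hands it to the next one unchanged.

package main

import (
    "context"
    "fmt"
)

// fetchChannels stands in for targetMgr.GetDmChannelsByCollection.
func fetchChannels(ctx context.Context, collectionID int64) []string {
    // Respecting cancellation is one concrete benefit of threading ctx down.
    if ctx.Err() != nil {
        return nil
    }
    return []string{"dml-ch-0", "dml-ch-1"}
}

// shardLeadersWithChannels stands in for GetShardLeadersWithChannels.
func shardLeadersWithChannels(ctx context.Context, collectionID int64, channels []string) map[string]int64 {
    leaders := make(map[string]int64, len(channels))
    for i, ch := range channels {
        leaders[ch] = int64(i + 1) // hypothetical leader node per channel
    }
    return leaders
}

// shardLeaders mirrors the GetShardLeaders -> GetShardLeadersWithChannels
// layering: one ctx comes in and is passed unchanged to every lower layer.
func shardLeaders(ctx context.Context, collectionID int64) (map[string]int64, error) {
    channels := fetchChannels(ctx, collectionID)
    if len(channels) == 0 {
        return nil, fmt.Errorf("collection %d has no channel in target, may be in recovery", collectionID)
    }
    return shardLeadersWithChannels(ctx, collectionID, channels), nil
}

func main() {
    leaders, err := shardLeaders(context.Background(), 100)
    fmt.Println(leaders, err)
}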
@ -232,7 +232,7 @@ func checkCollectionQueryable(ctx context.Context, m *meta.Meta, targetMgr meta.
return nil
}
func filterDupLeaders(replicaManager *meta.ReplicaManager, leaders map[int64]*meta.LeaderView) map[int64]*meta.LeaderView {
func filterDupLeaders(ctx context.Context, replicaManager *meta.ReplicaManager, leaders map[int64]*meta.LeaderView) map[int64]*meta.LeaderView {
type leaderID struct {
ReplicaID int64
Shard string
@ -240,7 +240,7 @@ func filterDupLeaders(replicaManager *meta.ReplicaManager, leaders map[int64]*me
newLeaders := make(map[leaderID]*meta.LeaderView)
for _, view := range leaders {
replica := replicaManager.GetByCollectionAndNode(view.CollectionID, view.ID)
replica := replicaManager.GetByCollectionAndNode(ctx, view.CollectionID, view.ID)
if replica == nil {
continue
}
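
filterDupLeaders keys the incoming leader views by (replicaID, shard) so that at most one view per replica and shard survives; the new ctx argument is only needed for the replica lookup. A small self-contained sketch of that composite-key de-duplication; the tie-break rule (keep the view with the higher version) is an assumption for illustration, not necessarily the real selection logic.

package main

import (
    "context"
    "fmt"
)

type leaderView struct {
    NodeID    int64
    ReplicaID int64
    Shard     string
    Version   int64
}

type leaderKey struct {
    ReplicaID int64
    Shard     string
}

// filterDup keeps one view per (replica, shard); resolveReplica stands in
// for ReplicaManager.GetByCollectionAndNode and is where ctx is needed.
func filterDup(ctx context.Context, views []leaderView, resolveReplica func(ctx context.Context, nodeID int64) (int64, bool)) []leaderView {
    best := make(map[leaderKey]leaderView)
    for _, v := range views {
        replicaID, ok := resolveReplica(ctx, v.NodeID)
        if !ok {
            continue // node no longer belongs to any replica
        }
        v.ReplicaID = replicaID
        k := leaderKey{ReplicaID: replicaID, Shard: v.Shard}
        if cur, exists := best[k]; !exists || v.Version > cur.Version {
            best[k] = v
        }
    }
    out := make([]leaderView, 0, len(best))
    for _, v := range best {
        out = append(out, v)
    }
    return out
}

func main() {
    resolve := func(ctx context.Context, nodeID int64) (int64, bool) { return nodeID % 2, true }
    views := []leaderView{
        {NodeID: 1, Shard: "ch-0", Version: 3},
        {NodeID: 3, Shard: "ch-0", Version: 5}, // same replica and shard, newer version
    }
    fmt.Println(len(filterDup(context.Background(), views, resolve))) // 1
}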

View File

@ -59,13 +59,13 @@ func (suite *UtilTestSuite) TestCheckLeaderAvaliable() {
}
mockTargetManager := meta.NewMockTargetManager(suite.T())
mockTargetManager.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{
mockTargetManager.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{
2: {
ID: 2,
InsertChannel: "test",
},
}).Maybe()
mockTargetManager.EXPECT().GetCollectionTargetVersion(mock.Anything, mock.Anything).Return(1011).Maybe()
mockTargetManager.EXPECT().GetCollectionTargetVersion(mock.Anything, mock.Anything, mock.Anything).Return(1011).Maybe()
suite.setNodeAvailable(1, 2)
err := CheckDelegatorDataReady(suite.nodeMgr, mockTargetManager, leadview, meta.CurrentTarget)
@ -81,13 +81,13 @@ func (suite *UtilTestSuite) TestCheckLeaderAvaliableFailed() {
TargetVersion: 1011,
}
mockTargetManager := meta.NewMockTargetManager(suite.T())
mockTargetManager.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{
mockTargetManager.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{
2: {
ID: 2,
InsertChannel: "test",
},
}).Maybe()
mockTargetManager.EXPECT().GetCollectionTargetVersion(mock.Anything, mock.Anything).Return(1011).Maybe()
mockTargetManager.EXPECT().GetCollectionTargetVersion(mock.Anything, mock.Anything, mock.Anything).Return(1011).Maybe()
// leader nodeID=1 not available
suite.setNodeAvailable(2)
err := CheckDelegatorDataReady(suite.nodeMgr, mockTargetManager, leadview, meta.CurrentTarget)
@ -103,13 +103,13 @@ func (suite *UtilTestSuite) TestCheckLeaderAvaliableFailed() {
}
mockTargetManager := meta.NewMockTargetManager(suite.T())
mockTargetManager.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{
mockTargetManager.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{
2: {
ID: 2,
InsertChannel: "test",
},
}).Maybe()
mockTargetManager.EXPECT().GetCollectionTargetVersion(mock.Anything, mock.Anything).Return(1011).Maybe()
mockTargetManager.EXPECT().GetCollectionTargetVersion(mock.Anything, mock.Anything, mock.Anything).Return(1011).Maybe()
// leader nodeID=2 not available
suite.setNodeAvailable(1)
err := CheckDelegatorDataReady(suite.nodeMgr, mockTargetManager, leadview, meta.CurrentTarget)
@ -124,14 +124,14 @@ func (suite *UtilTestSuite) TestCheckLeaderAvaliableFailed() {
TargetVersion: 1011,
}
mockTargetManager := meta.NewMockTargetManager(suite.T())
mockTargetManager.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{
mockTargetManager.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{
// target segmentID=1 not in leadView
1: {
ID: 1,
InsertChannel: "test",
},
}).Maybe()
mockTargetManager.EXPECT().GetCollectionTargetVersion(mock.Anything, mock.Anything).Return(1011).Maybe()
mockTargetManager.EXPECT().GetCollectionTargetVersion(mock.Anything, mock.Anything, mock.Anything).Return(1011).Maybe()
suite.setNodeAvailable(1, 2)
err := CheckDelegatorDataReady(suite.nodeMgr, mockTargetManager, leadview, meta.CurrentTarget)
suite.Error(err)
@ -144,14 +144,14 @@ func (suite *UtilTestSuite) TestCheckLeaderAvaliableFailed() {
Segments: map[int64]*querypb.SegmentDist{2: {NodeID: 2}},
}
mockTargetManager := meta.NewMockTargetManager(suite.T())
mockTargetManager.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{
mockTargetManager.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{
// target segmentID=1 not in leadView
1: {
ID: 1,
InsertChannel: "test",
},
}).Maybe()
mockTargetManager.EXPECT().GetCollectionTargetVersion(mock.Anything, mock.Anything).Return(1011).Maybe()
mockTargetManager.EXPECT().GetCollectionTargetVersion(mock.Anything, mock.Anything, mock.Anything).Return(1011).Maybe()
suite.setNodeAvailable(1, 2)
err := CheckDelegatorDataReady(suite.nodeMgr, mockTargetManager, leadview, meta.CurrentTarget)
suite.Error(err)
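
Because the mocked TargetManager methods gained a leading ctx parameter, every expectation in these suites needs one more mock.Anything matcher (four for GetSealedSegmentsByChannel, three for GetCollectionTargetVersion); otherwise the call no longer matches the recorded expectation. The same adjustment is shown below with a hand-written testify mock; mockTarget and its method signature are stand-ins, not the generated MockTargetManager.

package main

import (
    "context"
    "fmt"

    "github.com/stretchr/testify/mock"
)

// mockTarget is a hand-rolled stand-in for the mockery-generated
// MockTargetManager used in the suite above.
type mockTarget struct{ mock.Mock }

// The refactored signature takes ctx first, so Called receives four
// arguments instead of three.
func (m *mockTarget) GetSealedSegmentsByChannel(ctx context.Context, collectionID int64, channel string, scope int32) map[int64]string {
    args := m.Called(ctx, collectionID, channel, scope)
    return args.Get(0).(map[int64]string)
}

func main() {
    m := &mockTarget{}
    // One extra mock.Anything matches the new ctx parameter; with only
    // three matchers the call would fail to match the expectation.
    m.On("GetSealedSegmentsByChannel", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
        Return(map[int64]string{2: "test"})

    got := m.GetSealedSegmentsByChannel(context.TODO(), 1, "test", 0)
    fmt.Println(got)
}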