enhance: Save collection targets by batches (#31616)

See also #28491 #31240

When the number of collections is large, querycoord saves collection
targets one by one, which is slow and may block querycoord from exiting.

In a local run, a scenario with 500 collections may take about 40 seconds
to save the collection targets.

This PR changes the `SaveCollectionTarget` interface into a batch one
(`SaveCollectionTargets`) and groups the collections into batches of 16
to accelerate this procedure.

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
congqixia 2024-03-27 00:09:08 +08:00 committed by GitHub
parent 248c923e59
commit 8e5865f630
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 102 additions and 54 deletions

View File

@ -165,7 +165,7 @@ type QueryCoordCatalog interface {
RemoveResourceGroup(rgName string) error
GetResourceGroups() ([]*querypb.ResourceGroup, error)
SaveCollectionTarget(target *querypb.CollectionTarget) error
SaveCollectionTargets(target ...*querypb.CollectionTarget) error
RemoveCollectionTarget(collectionID int64) error
GetCollectionTargets() (map[int64]*querypb.CollectionTarget, error)
}

View File

@ -241,16 +241,21 @@ func (s Catalog) ReleaseReplica(collection, replica int64) error {
return s.cli.Remove(key)
}
func (s Catalog) SaveCollectionTarget(target *querypb.CollectionTarget) error {
k := encodeCollectionTargetKey(target.GetCollectionID())
v, err := proto.Marshal(target)
if err != nil {
return err
func (s Catalog) SaveCollectionTargets(targets ...*querypb.CollectionTarget) error {
kvs := make(map[string]string)
for _, target := range targets {
k := encodeCollectionTargetKey(target.GetCollectionID())
v, err := proto.Marshal(target)
if err != nil {
return err
}
var compressed bytes.Buffer
compressor.ZstdCompress(bytes.NewReader(v), io.Writer(&compressed), zstd.WithEncoderLevel(zstd.SpeedBetterCompression))
kvs[k] = compressed.String()
}
// to reduce the target size, we do compress before write to etcd
var compressed bytes.Buffer
compressor.ZstdCompress(bytes.NewReader(v), io.Writer(&compressed), zstd.WithEncoderLevel(zstd.SpeedBetterCompression))
err = s.cli.Save(k, compressed.String())
err := s.cli.MultiSave(kvs)
if err != nil {
return err
}

View File

@ -203,22 +203,22 @@ func (suite *CatalogTestSuite) TestResourceGroup() {
}
func (suite *CatalogTestSuite) TestCollectionTarget() {
suite.catalog.SaveCollectionTarget(&querypb.CollectionTarget{
suite.catalog.SaveCollectionTargets(&querypb.CollectionTarget{
CollectionID: 1,
Version: 1,
})
suite.catalog.SaveCollectionTarget(&querypb.CollectionTarget{
CollectionID: 2,
Version: 2,
})
suite.catalog.SaveCollectionTarget(&querypb.CollectionTarget{
CollectionID: 3,
Version: 3,
})
suite.catalog.SaveCollectionTarget(&querypb.CollectionTarget{
CollectionID: 1,
Version: 4,
})
},
&querypb.CollectionTarget{
CollectionID: 2,
Version: 2,
},
&querypb.CollectionTarget{
CollectionID: 3,
Version: 3,
},
&querypb.CollectionTarget{
CollectionID: 1,
Version: 4,
})
suite.catalog.RemoveCollectionTarget(2)
targets, err := suite.catalog.GetCollectionTargets()
@ -230,18 +230,18 @@ func (suite *CatalogTestSuite) TestCollectionTarget() {
// test access meta store failed
mockStore := mocks.NewMetaKv(suite.T())
mockErr := errors.New("failed to access etcd")
mockStore.EXPECT().Save(mock.Anything, mock.Anything).Return(mockErr)
mockStore.EXPECT().MultiSave(mock.Anything).Return(mockErr)
mockStore.EXPECT().LoadWithPrefix(mock.Anything).Return(nil, nil, mockErr)
suite.catalog.cli = mockStore
err = suite.catalog.SaveCollectionTarget(&querypb.CollectionTarget{})
err = suite.catalog.SaveCollectionTargets(&querypb.CollectionTarget{})
suite.ErrorIs(err, mockErr)
_, err = suite.catalog.GetCollectionTargets()
suite.ErrorIs(err, mockErr)
// test invalid message
err = suite.catalog.SaveCollectionTarget(nil)
err = suite.catalog.SaveCollectionTargets(nil)
suite.Error(err)
}

View File

@ -610,13 +610,19 @@ func (_c *QueryCoordCatalog_SaveCollection_Call) RunAndReturn(run func(*querypb.
return _c
}
// SaveCollectionTarget provides a mock function with given fields: target
func (_m *QueryCoordCatalog) SaveCollectionTarget(target *querypb.CollectionTarget) error {
ret := _m.Called(target)
// SaveCollectionTargets provides a mock function with given fields: target
func (_m *QueryCoordCatalog) SaveCollectionTargets(target ...*querypb.CollectionTarget) error {
_va := make([]interface{}, len(target))
for _i := range target {
_va[_i] = target[_i]
}
var _ca []interface{}
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 error
if rf, ok := ret.Get(0).(func(*querypb.CollectionTarget) error); ok {
r0 = rf(target)
if rf, ok := ret.Get(0).(func(...*querypb.CollectionTarget) error); ok {
r0 = rf(target...)
} else {
r0 = ret.Error(0)
}
@ -624,30 +630,37 @@ func (_m *QueryCoordCatalog) SaveCollectionTarget(target *querypb.CollectionTarg
return r0
}
// QueryCoordCatalog_SaveCollectionTarget_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveCollectionTarget'
type QueryCoordCatalog_SaveCollectionTarget_Call struct {
// QueryCoordCatalog_SaveCollectionTargets_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveCollectionTargets'
type QueryCoordCatalog_SaveCollectionTargets_Call struct {
*mock.Call
}
// SaveCollectionTarget is a helper method to define mock.On call
// - target *querypb.CollectionTarget
func (_e *QueryCoordCatalog_Expecter) SaveCollectionTarget(target interface{}) *QueryCoordCatalog_SaveCollectionTarget_Call {
return &QueryCoordCatalog_SaveCollectionTarget_Call{Call: _e.mock.On("SaveCollectionTarget", target)}
// SaveCollectionTargets is a helper method to define mock.On call
// - target ...*querypb.CollectionTarget
func (_e *QueryCoordCatalog_Expecter) SaveCollectionTargets(target ...interface{}) *QueryCoordCatalog_SaveCollectionTargets_Call {
return &QueryCoordCatalog_SaveCollectionTargets_Call{Call: _e.mock.On("SaveCollectionTargets",
append([]interface{}{}, target...)...)}
}
func (_c *QueryCoordCatalog_SaveCollectionTarget_Call) Run(run func(target *querypb.CollectionTarget)) *QueryCoordCatalog_SaveCollectionTarget_Call {
func (_c *QueryCoordCatalog_SaveCollectionTargets_Call) Run(run func(target ...*querypb.CollectionTarget)) *QueryCoordCatalog_SaveCollectionTargets_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(*querypb.CollectionTarget))
variadicArgs := make([]*querypb.CollectionTarget, len(args)-0)
for i, a := range args[0:] {
if a != nil {
variadicArgs[i] = a.(*querypb.CollectionTarget)
}
}
run(variadicArgs...)
})
return _c
}
func (_c *QueryCoordCatalog_SaveCollectionTarget_Call) Return(_a0 error) *QueryCoordCatalog_SaveCollectionTarget_Call {
func (_c *QueryCoordCatalog_SaveCollectionTargets_Call) Return(_a0 error) *QueryCoordCatalog_SaveCollectionTargets_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *QueryCoordCatalog_SaveCollectionTarget_Call) RunAndReturn(run func(*querypb.CollectionTarget) error) *QueryCoordCatalog_SaveCollectionTarget_Call {
func (_c *QueryCoordCatalog_SaveCollectionTargets_Call) RunAndReturn(run func(...*querypb.CollectionTarget) error) *QueryCoordCatalog_SaveCollectionTargets_Call {
_c.Call.Return(run)
return _c
}

View File

@ -19,6 +19,7 @@ package meta
import (
"context"
"fmt"
"runtime"
"sync"
"github.com/cockroachdb/errors"
@ -28,9 +29,11 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/metastore"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/conc"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/retry"
@ -594,13 +597,38 @@ func (mgr *TargetManager) SaveCurrentTarget(catalog metastore.QueryCoordCatalog)
mgr.rwMutex.Lock()
defer mgr.rwMutex.Unlock()
if mgr.current != nil {
// use pool here to control maximal writer used by save target
pool := conc.NewPool[any](runtime.GOMAXPROCS(0) * 2)
// use batch write in case of the number of collections is large
batchSize := 16
var wg sync.WaitGroup
submit := func(tasks []typeutil.Pair[int64, *querypb.CollectionTarget]) {
wg.Add(1)
pool.Submit(func() (any, error) {
defer wg.Done()
ids := lo.Map(tasks, func(p typeutil.Pair[int64, *querypb.CollectionTarget], _ int) int64 { return p.A })
if err := catalog.SaveCollectionTargets(lo.Map(tasks, func(p typeutil.Pair[int64, *querypb.CollectionTarget], _ int) *querypb.CollectionTarget {
return p.B
})...); err != nil {
log.Warn("failed to save current target for collection", zap.Int64s("collectionIDs", ids), zap.Error(err))
} else {
log.Info("succeed to save current target for collection", zap.Int64s("collectionIDs", ids))
}
return nil, nil
})
}
tasks := make([]typeutil.Pair[int64, *querypb.CollectionTarget], 0, batchSize)
for id, target := range mgr.current.collectionTargetMap {
if err := catalog.SaveCollectionTarget(target.toPbMsg()); err != nil {
log.Warn("failed to save current target for collection", zap.Int64("collectionID", id), zap.Error(err))
} else {
log.Warn("succeed to save current target for collection", zap.Int64("collectionID", id))
tasks = append(tasks, typeutil.NewPair(id, target.toPbMsg()))
if len(tasks) >= batchSize {
submit(tasks)
tasks = make([]typeutil.Pair[int64, *querypb.CollectionTarget], 0, batchSize)
}
}
if len(tasks) > 0 {
submit(tasks)
}
wg.Wait()
}
}

View File

@ -26,6 +26,7 @@ import (
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/suite"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
@ -146,8 +147,8 @@ func (s *CoordSwitchSuite) checkCollections() bool {
TimeStamp: 0, // means now
}
resp, err := s.Cluster.Proxy.ShowCollections(context.TODO(), req)
s.NoError(err)
s.Equal(len(resp.CollectionIds), numCollections)
s.Require().NoError(merr.CheckRPCCall(resp, err))
s.Require().Equal(len(resp.CollectionIds), numCollections)
notLoaded := 0
loaded := 0
for _, name := range resp.CollectionNames {
@ -181,7 +182,7 @@ func (s *CoordSwitchSuite) search(collectionName string, dim int) {
GuaranteeTimestamp: 0,
}
queryResult, err := c.Proxy.Query(context.TODO(), queryReq)
s.NoError(err)
s.Require().NoError(merr.CheckRPCCall(queryResult, err))
s.Equal(len(queryResult.FieldsData), 1)
numEntities := queryResult.FieldsData[0].GetScalars().GetLongData().Data[0]
s.Equal(numEntities, int64(rowsPerCollection))
@ -198,10 +199,9 @@ func (s *CoordSwitchSuite) search(collectionName string, dim int) {
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, dim, topk, roundDecimal)
searchResult, _ := c.Proxy.Search(context.TODO(), searchReq)
searchResult, err := c.Proxy.Search(context.TODO(), searchReq)
err = merr.Error(searchResult.GetStatus())
s.NoError(err)
s.NoError(merr.CheckRPCCall(searchResult, err))
}
func (s *CoordSwitchSuite) insertBatchCollections(prefix string, collectionBatchSize, idxStart, dim int, wg *sync.WaitGroup) {
@ -229,7 +229,7 @@ func (s *CoordSwitchSuite) setupData() {
}
wg.Wait()
log.Info("=========================Data injection finished=========================")
s.checkCollections()
s.Require().True(s.checkCollections())
log.Info(fmt.Sprintf("=========================start to search %s=========================", searchName))
s.search(searchName, Dim)
log.Info("=========================Search finished=========================")
@ -238,11 +238,13 @@ func (s *CoordSwitchSuite) setupData() {
func (s *CoordSwitchSuite) switchCoord() float64 {
var err error
c := s.Cluster
start := time.Now()
log.Info("=========================Stopping Coordinators========================")
c.RootCoord.Stop()
c.DataCoord.Stop()
c.QueryCoord.Stop()
log.Info("=========================Coordinators stopped=========================")
start := time.Now()
log.Info("=========================Coordinators stopped=========================", zap.Duration("elapsed", time.Since(start)))
start = time.Now()
c.RootCoord, err = grpcrootcoord.NewServer(context.TODO(), c.GetFactory())
s.NoError(err)