milvus/internal/querycoordv2/task/executor.go

721 lines
21 KiB
Go
Raw Normal View History

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package task
import (
"context"
"sync"
"time"
"github.com/blang/semver/v4"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"go.uber.org/atomic"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
. "github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/indexparams"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
// segmentsVersion is the node version (2.3.4) that shouldIncludeFlushedSegmentInfo
// compares against: only nodes whose version is lower than this one are reported
// as needing flushed segment info when a watch dm channel request is filled.
var segmentsVersion = semver.Version{
Major: 2,
Minor: 3,
Patch: 4,
}
// Executor runs task actions asynchronously (see Execute) and keeps track of
// the tasks whose actions are currently in flight.
type Executor struct {
// doneCh is created in NewExecutor; no other use is visible in this part of the file.
doneCh chan struct{}
// wg is waited on by Stop.
wg sync.WaitGroup
// Handles the action handlers read metadata from or send RPCs through.
meta *meta.Meta
dist *meta.DistributionManager
broker meta.Broker
targetMgr meta.TargetManagerInterface
cluster session.Cluster
nodeMgr *session.NodeManager
// executingTasks holds task.Index() of every task that currently has an action
// executing; Execute uses it to refuse committing the same task twice.
executingTasks *typeutil.ConcurrentSet[string] // task index
// executingTaskNum counts executing tasks; Execute checks it against
// QueryCoordCfg.TaskExecutionCap.
executingTaskNum atomic.Int32
// executedFlag (buffer size 1) receives a non-blocking signal from removeTask
// whenever a task action finishes while the task has no error.
executedFlag chan struct{}
}
// NewExecutor creates an Executor bound to the given collaborators.
// The returned executor starts with an empty set of executing tasks and a
// one-slot executedFlag channel so that completion signals never block.
func NewExecutor(meta *meta.Meta,
	dist *meta.DistributionManager,
	broker meta.Broker,
	targetMgr meta.TargetManagerInterface,
	cluster session.Cluster,
	nodeMgr *session.NodeManager,
) *Executor {
	// Injected dependencies first ...
	executor := &Executor{
		meta:      meta,
		dist:      dist,
		broker:    broker,
		targetMgr: targetMgr,
		cluster:   cluster,
		nodeMgr:   nodeMgr,
	}

	// ... then the state owned by the executor itself.
	executor.doneCh = make(chan struct{})
	executor.executingTasks = typeutil.NewConcurrentSet[string]()
	executor.executedFlag = make(chan struct{}, 1)

	return executor
}
// Start is currently a no-op: it accepts ctx but does not use it and starts
// no background work.
func (ex *Executor) Start(ctx context.Context) {
}
// Stop blocks until ex.wg is done.
// NOTE(review): no ex.wg.Add call is visible in this part of the file, so as
// far as can be seen here this does not wait for goroutines spawned by
// Execute — confirm whether that is intended.
func (ex *Executor) Stop() {
ex.wg.Wait()
}
// Execute commits the action at index step of the given task for
// asynchronous execution.
//
// It returns false and does nothing when the task is already registered as
// executing, or when accepting it would push the number of executing tasks
// above QueryCoordCfg.TaskExecutionCap. Otherwise it spawns a goroutine that
// dispatches on the concrete action type and returns true immediately.
func (ex *Executor) Execute(task Task, step int) bool {
	// A task index that is already in the set means the task was committed before.
	if inserted := ex.executingTasks.Insert(task.Index()); !inserted {
		return false
	}

	running := ex.executingTaskNum.Inc()
	if running > Params.QueryCoordCfg.TaskExecutionCap.GetAsInt32() {
		// Over capacity: undo both the registration and the counter bump.
		ex.executingTasks.Remove(task.Index())
		ex.executingTaskNum.Dec()
		return false
	}

	logger := log.With(
		zap.Int64("taskID", task.ID()),
		zap.Int64("collectionID", task.CollectionID()),
		zap.Int64("replicaID", task.ReplicaID()),
		zap.Int("step", step),
		zap.String("source", task.Source().String()),
	)

	// The handlers are responsible for unregistering the task via removeTask.
	go func() {
		logger.Info("execute the action of task")

		switch task.Actions()[step].(type) {
		case *LeaderAction:
			ex.executeLeaderAction(task.(*LeaderTask), step)

		case *ChannelAction:
			ex.executeDmChannelAction(task.(*ChannelTask), step)

		case *SegmentAction:
			ex.executeSegmentAction(task.(*SegmentTask), step)
		}
	}()

	return true
}
// GetExecutedFlag returns the receive-only side of executedFlag, the channel
// removeTask signals (without blocking) when a task action finishes while the
// task has no error.
func (ex *Executor) GetExecutedFlag() <-chan struct{} {
return ex.executedFlag
}
// removeTask unregisters the task from the executing set and decrements the
// executing counter. Before that, it either logs the task's error, or — when
// the task has no error — posts a non-blocking signal on executedFlag.
func (ex *Executor) removeTask(task Task, step int) {
	if task.Err() == nil {
		// The channel has a single slot; if a signal is already pending,
		// dropping this one is fine.
		select {
		case ex.executedFlag <- struct{}{}:
		default:
		}
	} else {
		log.Info("execute action done, remove it",
			zap.Int64("taskID", task.ID()),
			zap.Int("step", step),
			zap.Error(task.Err()),
		)
	}

	ex.executingTasks.Remove(task.Index())
	ex.executingTaskNum.Dec()
}
// executeSegmentAction routes a segment action by its type: Grow and Update
// load the segment, Reduce releases it. Other types are ignored.
func (ex *Executor) executeSegmentAction(task *SegmentTask, step int) {
	actionType := task.Actions()[step].Type()

	if actionType == ActionTypeReduce {
		ex.releaseSegment(task, step)
		return
	}

	if actionType == ActionTypeGrow || actionType == ActionTypeUpdate {
		// The error is already recorded on the task by loadSegment itself.
		_ = ex.loadSegment(task, step)
	}
}
// loadSegment sends a LoadSegments request for the segment action at the
// given step to the shard leader that serves the action's shard.
//
// Steps: fetch collection/load metadata and the segment's load info, pack the
// request, resolve the replica that contains action.Node(), resolve the latest
// leader view of action.Shard() within that replica, then issue the RPC to the
// leader.
//
// On every failure path the task is failed with the returned error by the
// deferred handler. The task is always removed from the executing set on
// return, after which action.rpcReturned is set to true.
func (ex *Executor) loadSegment(task *SegmentTask, step int) error {
	action := task.Actions()[step].(*SegmentAction)
	defer action.rpcReturned.Store(true)

	ctx := task.Context()
	log := log.Ctx(ctx).With(
		zap.Int64("taskID", task.ID()),
		zap.Int64("collectionID", task.CollectionID()),
		zap.Int64("replicaID", task.ReplicaID()),
		zap.Int64("segmentID", task.segmentID),
		zap.Int64("node", action.Node()),
		zap.String("source", task.Source().String()),
	)

	// err is captured by the deferred closure below. Every failure path must
	// assign to this variable (never shadow it with := in an inner scope),
	// otherwise the task would not be marked as failed.
	var err error
	defer func() {
		if err != nil {
			task.Fail(err)
		}
		ex.removeTask(task, step)
	}()

	collectionInfo, loadMeta, channel, err := ex.getMetaInfo(ctx, task)
	if err != nil {
		return err
	}

	loadInfo, indexInfos, err := ex.getLoadInfo(ctx, task.CollectionID(), action.SegmentID(), channel)
	if err != nil {
		return err
	}

	req := packLoadSegmentRequest(
		task,
		action,
		collectionInfo.GetSchema(),
		collectionInfo.GetProperties(),
		loadMeta,
		loadInfo,
		indexInfos,
	)

	// get segment's replica first, then get shard leader by replica
	replica := ex.meta.ReplicaManager.GetByCollectionAndNode(task.CollectionID(), action.Node())
	if replica == nil {
		msg := "node doesn't belong to any replica"
		// Assign (not :=) so the deferred handler sees the error and fails the task.
		err = merr.WrapErrNodeNotAvailable(action.Node())
		log.Warn(msg, zap.Error(err))
		return err
	}

	view := ex.dist.LeaderViewManager.GetLatestShardLeaderByFilter(meta.WithReplica2LeaderView(replica), meta.WithChannelName2LeaderView(action.Shard()))
	if view == nil {
		msg := "no shard leader for the segment to execute loading"
		err = merr.WrapErrChannelNotFound(task.Shard(), "shard delegator not found")
		log.Warn(msg, zap.Error(err))
		return err
	}
	log = log.With(zap.Int64("shardLeader", view.ID))

	startTs := time.Now()
	log.Info("load segments...")
	status, err := ex.cluster.LoadSegments(task.Context(), view.ID, req)
	// Fold a non-success status into err so both failure kinds are handled alike.
	err = merr.CheckRPCCall(status, err)
	if err != nil {
		log.Warn("failed to load segment", zap.Error(err))
		return err
	}

	elapsed := time.Since(startTs)
	log.Info("load segments done", zap.Duration("elapsed", elapsed))

	return nil
}
// releaseSegment sends a ReleaseSegments request for the segment action at
// the given step.
//
// For streaming-scope actions the request goes straight to action.Node() with
// NeedTransfer set. Otherwise the request carries the task's shard and, if the
// collection still exists, is redirected to the shard leader of the replica
// that contains action.Node(). Failures are only logged; the task is always
// removed from the executing set on return.
func (ex *Executor) releaseSegment(task *SegmentTask, step int) {
	defer ex.removeTask(task, step)
	begin := time.Now()

	segAction := task.Actions()[step].(*SegmentAction)
	defer segAction.rpcReturned.Store(true)

	logger := log.With(
		zap.Int64("taskID", task.ID()),
		zap.Int64("collectionID", task.CollectionID()),
		zap.Int64("replicaID", task.ReplicaID()),
		zap.Int64("segmentID", task.segmentID),
		zap.Int64("node", segAction.Node()),
		zap.String("source", task.Source().String()),
	)

	var (
		ctx    = task.Context()
		target = segAction.Node()
		req    = packReleaseSegmentRequest(task, segAction)
	)

	// When the channel is part of the current target, pass its seek position
	// along: it is needed as the exclude ts of the growing segment.
	if dmChannel := ex.targetMgr.GetDmChannel(task.CollectionID(), task.Shard(), meta.CurrentTarget); dmChannel != nil {
		req.Checkpoint = dmChannel.GetSeekPosition()
	}

	if segAction.Scope() != querypb.DataScope_Streaming {
		req.Shard = task.shard

		if ex.meta.CollectionManager.Exist(task.CollectionID()) {
			// Resolve the replica of the node first, then the shard leader inside it.
			replica := ex.meta.ReplicaManager.GetByCollectionAndNode(task.CollectionID(), segAction.Node())
			if replica == nil {
				notAvailable := merr.WrapErrNodeNotAvailable(segAction.Node())
				logger.Warn("node doesn't belong to any replica", zap.Error(notAvailable))
				return
			}

			leaderView := ex.dist.LeaderViewManager.GetLatestShardLeaderByFilter(
				meta.WithReplica2LeaderView(replica),
				meta.WithChannelName2LeaderView(segAction.Shard()),
			)
			if leaderView == nil {
				notFound := merr.WrapErrChannelNotFound(task.Shard(), "shard delegator not found")
				logger.Warn("no shard leader for the segment to execute releasing", zap.Error(notFound))
				return
			}

			target = leaderView.ID
			logger = logger.With(zap.Int64("shardLeader", leaderView.ID))
			req.NeedTransfer = true
		}
	} else {
		// Any change to the segment distribution has to set NeedTransfer,
		// to protect the version which serves search/query.
		req.NeedTransfer = true
	}

	logger.Info("release segment...")
	status, rpcErr := ex.cluster.ReleaseSegments(ctx, target, req)
	if rpcErr = merr.CheckRPCCall(status, rpcErr); rpcErr != nil {
		logger.Warn("failed to release segment", zap.Error(rpcErr))
		return
	}

	logger.Info("release segment done", zap.Int64("taskID", task.ID()), zap.Duration("time taken", time.Since(begin)))
}
// executeDmChannelAction routes a channel action by its type: Grow subscribes
// the channel, Reduce unsubscribes it. Other types are ignored.
func (ex *Executor) executeDmChannelAction(task *ChannelTask, step int) {
	actionType := task.Actions()[step].Type()

	if actionType == ActionTypeGrow {
		// Errors are recorded on the task inside subscribeChannel.
		_ = ex.subscribeChannel(task, step)
		return
	}

	if actionType == ActionTypeReduce {
		_ = ex.unsubscribeChannel(task, step)
	}
}
// subscribeChannel sends a WatchDmChannels request to action.Node() for the
// channel action at the given step.
//
// It collects the collection description, loaded partitions and index meta,
// looks the channel up in the next target, packs and fills the request, then
// issues the RPC. Failures assigned to err fail the task through the deferred
// handler; the task is always removed from the executing set on return.
func (ex *Executor) subscribeChannel(task *ChannelTask, step int) error {
	defer ex.removeTask(task, step)
	startTs := time.Now()
	action := task.Actions()[step].(*ChannelAction)
	log := log.With(
		zap.Int64("taskID", task.ID()),
		zap.Int64("collectionID", task.CollectionID()),
		zap.Int64("replicaID", task.ReplicaID()),
		zap.String("channel", task.Channel()),
		zap.Int64("node", action.Node()),
		zap.String("source", task.Source().String()),
	)

	// Registered after removeTask, so it runs first: the task is failed
	// before it is removed from the executing set.
	var err error
	defer func() {
		if err != nil {
			task.Fail(err)
		}
	}()

	ctx := task.Context()

	collectionInfo, err := ex.broker.DescribeCollection(ctx, task.CollectionID())
	if err != nil {
		log.Warn("failed to get collection info", zap.Error(err))
		return err
	}
	partitions, err := utils.GetPartitions(ex.meta.CollectionManager, task.CollectionID())
	if err != nil {
		log.Warn("failed to get partitions of collection", zap.Error(err))
		return err
	}
	indexInfo, err := ex.broker.ListIndexes(ctx, task.CollectionID())
	if err != nil {
		log.Warn("fail to get index meta of collection", zap.Error(err))
		return err
	}
	loadMeta := packLoadMeta(
		ex.meta.GetLoadType(task.CollectionID()),
		task.CollectionID(),
		collectionInfo.GetDbName(),
		task.ResourceGroup(),
		partitions...,
	)

	dmChannel := ex.targetMgr.GetDmChannel(task.CollectionID(), action.ChannelName(), meta.NextTarget)
	if dmChannel == nil {
		msg := "channel does not exist in next target, skip it"
		log.Warn(msg, zap.String("channelName", action.ChannelName()))
		// NOTE(review): this error is returned without being assigned to err,
		// so the deferred handler does not fail the task on this path. That
		// matches the "skip it" wording but should be confirmed as intended.
		return merr.WrapErrChannelReduplicate(action.ChannelName())
	}
	req := packSubChannelRequest(
		task,
		action,
		collectionInfo.GetSchema(),
		loadMeta,
		dmChannel,
		indexInfo,
	)
	err = fillSubChannelRequest(ctx, req, ex.broker, ex.shouldIncludeFlushedSegmentInfo(action.Node()))
	if err != nil {
		log.Warn("failed to subscribe channel, failed to fill the request with segments",
			zap.Error(err))
		return err
	}

	ts := dmChannel.GetSeekPosition().GetTimestamp()
	log.Info("subscribe channel...",
		zap.Uint64("checkpoint", ts),
		zap.Duration("sinceCheckpoint", time.Since(tsoutil.PhysicalTime(ts))),
	)
	status, err := ex.cluster.WatchDmChannels(ctx, action.Node(), req)
	if err != nil {
		log.Warn("failed to subscribe channel, it may be a false failure", zap.Error(err))
		return err
	}
	if !merr.Ok(status) {
		err = merr.Error(status)
		log.Warn("failed to subscribe channel", zap.Error(err))
		return err
	}
	elapsed := time.Since(startTs)
	log.Info("subscribe channel done", zap.Int64("taskID", task.ID()), zap.Duration("time taken", elapsed))
	return nil
}
// shouldIncludeFlushedSegmentInfo reports whether the node with the given ID
// is known to the node manager and runs a version lower than segmentsVersion.
// It returns false for unknown nodes.
func (ex *Executor) shouldIncludeFlushedSegmentInfo(nodeID int64) bool {
	if nodeInfo := ex.nodeMgr.Get(nodeID); nodeInfo != nil {
		return nodeInfo.Version().LT(segmentsVersion)
	}
	return false
}
// unsubscribeChannel sends an UnsubDmChannel request to action.Node() for the
// channel action at the given step. A transport error or a non-OK status
// fails the task; the task is always removed from the executing set on return.
func (ex *Executor) unsubscribeChannel(task *ChannelTask, step int) error {
	// Single cleanup hook: fail the task first (if needed), then unregister it.
	var err error
	defer func() {
		if err != nil {
			task.Fail(err)
		}
		ex.removeTask(task, step)
	}()

	begin := time.Now()
	chAction := task.Actions()[step].(*ChannelAction)
	logger := log.With(
		zap.Int64("taskID", task.ID()),
		zap.Int64("collectionID", task.CollectionID()),
		zap.Int64("replicaID", task.ReplicaID()),
		zap.String("channel", task.Channel()),
		zap.Int64("node", chAction.Node()),
		zap.String("source", task.Source().String()),
	)

	ctx := task.Context()
	req := packUnsubDmChannelRequest(task, chAction)

	logger.Info("unsubscribe channel...")
	status, err := ex.cluster.UnsubDmChannel(ctx, chAction.Node(), req)
	switch {
	case err != nil:
		logger.Warn("failed to unsubscribe channel, it may be a false failure", zap.Error(err))
		return err

	case !merr.Ok(status):
		err = merr.Error(status)
		logger.Warn("failed to unsubscribe channel", zap.Error(err))
		return err
	}

	logger.Info("unsubscribe channel done", zap.Int64("taskID", task.ID()), zap.Duration("time taken", time.Since(begin)))
	return nil
}
// executeLeaderAction routes a leader action by its type: Grow sets the
// distribution, Reduce removes it, Update refreshes the partition stats
// versions. Other types are ignored.
func (ex *Executor) executeLeaderAction(task *LeaderTask, step int) {
	// Each handler records its own failure on the task, so the returned
	// errors are intentionally dropped here.
	switch actionType := task.Actions()[step].Type(); actionType {
	case ActionTypeUpdate:
		_ = ex.updatePartStatsVersions(task, step)

	case ActionTypeReduce:
		_ = ex.removeDistribution(task, step)

	case ActionTypeGrow:
		_ = ex.setDistribution(task, step)
	}
}
// updatePartStatsVersions sends a SyncDistribution request with a single
// UpdatePartitionStats action to the task's leader. On failure the task is
// failed with the error; the task is always removed from the executing set on
// return, after which action.rpcReturned is set to true.
func (ex *Executor) updatePartStatsVersions(task *LeaderTask, step int) error {
	leaderAction := task.Actions()[step].(*LeaderAction)
	defer leaderAction.rpcReturned.Store(true)

	ctx := task.Context()
	logger := log.Ctx(ctx).With(
		zap.Int64("taskID", task.ID()),
		zap.Int64("collectionID", task.CollectionID()),
		zap.Int64("replicaID", task.ReplicaID()),
		zap.Int64("leader", leaderAction.leaderID),
		zap.Int64("node", leaderAction.Node()),
		zap.String("source", task.Source().String()),
	)

	// Deferred calls run in reverse order: fail the task (if needed) first,
	// then unregister it.
	var err error
	defer ex.removeTask(task, step)
	defer func() {
		if err != nil {
			task.Fail(err)
		}
	}()

	syncAction := &querypb.SyncAction{
		Type:                   querypb.SyncType_UpdatePartitionStats,
		SegmentID:              leaderAction.SegmentID(),
		NodeID:                 leaderAction.Node(),
		Version:                leaderAction.Version(),
		PartitionStatsVersions: leaderAction.partStatsVersions,
	}
	req := &querypb.SyncDistributionRequest{
		Base: commonpbutil.NewMsgBase(
			commonpbutil.WithMsgType(commonpb.MsgType_SyncDistribution),
			commonpbutil.WithMsgID(task.ID()),
		),
		CollectionID: task.collectionID,
		Channel:      task.Shard(),
		ReplicaID:    task.ReplicaID(),
		Actions:      []*querypb.SyncAction{syncAction},
	}

	begin := time.Now()
	logger.Debug("Update partition stats versions...")
	status, err := ex.cluster.SyncDistribution(task.Context(), task.leaderID, req)
	if err = merr.CheckRPCCall(status, err); err != nil {
		logger.Warn("failed to update partition stats versions", zap.Error(err))
		return err
	}

	logger.Debug("update partition stats done", zap.Duration("elapsed", time.Since(begin)))
	return nil
}
// setDistribution sends a SyncDistribution request with a single Set action
// to the task's leader, carrying the segment's load info together with the
// collection schema, load meta and index info. On failure the task is failed
// with the error; the task is always removed from the executing set on return,
// after which action.rpcReturned is set to true.
func (ex *Executor) setDistribution(task *LeaderTask, step int) error {
	leaderAction := task.Actions()[step].(*LeaderAction)
	defer leaderAction.rpcReturned.Store(true)

	ctx := task.Context()
	logger := log.Ctx(ctx).With(
		zap.Int64("taskID", task.ID()),
		zap.Int64("collectionID", task.CollectionID()),
		zap.Int64("replicaID", task.ReplicaID()),
		zap.Int64("segmentID", task.segmentID),
		zap.Int64("leader", leaderAction.leaderID),
		zap.Int64("node", leaderAction.Node()),
		zap.String("source", task.Source().String()),
	)

	// Deferred calls run in reverse order: fail the task (if needed) first,
	// then unregister it.
	var err error
	defer ex.removeTask(task, step)
	defer func() {
		if err != nil {
			task.Fail(err)
		}
	}()

	collectionInfo, loadMeta, channel, err := ex.getMetaInfo(ctx, task)
	if err != nil {
		return err
	}

	loadInfo, indexInfo, err := ex.getLoadInfo(ctx, task.CollectionID(), leaderAction.SegmentID(), channel)
	if err != nil {
		return err
	}

	setAction := &querypb.SyncAction{
		Type:        querypb.SyncType_Set,
		PartitionID: loadInfo.GetPartitionID(),
		SegmentID:   leaderAction.SegmentID(),
		NodeID:      leaderAction.Node(),
		Info:        loadInfo,
		Version:     leaderAction.Version(),
	}
	req := &querypb.SyncDistributionRequest{
		Base: commonpbutil.NewMsgBase(
			commonpbutil.WithMsgType(commonpb.MsgType_LoadSegments),
			commonpbutil.WithMsgID(task.ID()),
		),
		CollectionID:  task.collectionID,
		Channel:       task.Shard(),
		Schema:        collectionInfo.GetSchema(),
		LoadMeta:      loadMeta,
		ReplicaID:     task.ReplicaID(),
		Actions:       []*querypb.SyncAction{setAction},
		IndexInfoList: indexInfo,
	}

	begin := time.Now()
	logger.Info("Sync Distribution...")
	status, err := ex.cluster.SyncDistribution(task.Context(), task.leaderID, req)
	if err = merr.CheckRPCCall(status, err); err != nil {
		logger.Warn("failed to sync distribution", zap.Error(err))
		return err
	}

	logger.Info("sync distribution done", zap.Duration("elapsed", time.Since(begin)))
	return nil
}
// removeDistribution sends a SyncDistribution request with a single Remove
// action to the task's leader. On failure the task is failed with the error;
// the task is always removed from the executing set on return, after which
// action.rpcReturned is set to true.
func (ex *Executor) removeDistribution(task *LeaderTask, step int) error {
	leaderAction := task.Actions()[step].(*LeaderAction)
	defer leaderAction.rpcReturned.Store(true)

	ctx := task.Context()
	logger := log.Ctx(ctx).With(
		zap.Int64("taskID", task.ID()),
		zap.Int64("collectionID", task.CollectionID()),
		zap.Int64("replicaID", task.ReplicaID()),
		zap.Int64("segmentID", task.segmentID),
		zap.Int64("leader", leaderAction.leaderID),
		zap.Int64("node", leaderAction.Node()),
		zap.String("source", task.Source().String()),
	)

	// Deferred calls run in reverse order: fail the task (if needed) first,
	// then unregister it.
	var err error
	defer ex.removeTask(task, step)
	defer func() {
		if err != nil {
			task.Fail(err)
		}
	}()

	removeAction := &querypb.SyncAction{
		Type:      querypb.SyncType_Remove,
		SegmentID: leaderAction.SegmentID(),
		NodeID:    leaderAction.Node(),
	}
	req := &querypb.SyncDistributionRequest{
		Base: commonpbutil.NewMsgBase(
			commonpbutil.WithMsgType(commonpb.MsgType_SyncDistribution),
			commonpbutil.WithMsgID(task.ID()),
		),
		CollectionID: task.collectionID,
		Channel:      task.Shard(),
		ReplicaID:    task.ReplicaID(),
		Actions:      []*querypb.SyncAction{removeAction},
	}

	begin := time.Now()
	logger.Info("Remove Distribution...")
	status, err := ex.cluster.SyncDistribution(task.Context(), task.leaderID, req)
	if err = merr.CheckRPCCall(status, err); err != nil {
		logger.Warn("failed to remove distribution", zap.Error(err))
		return err
	}

	logger.Info("remove distribution done", zap.Duration("elapsed", time.Since(begin)))
	return nil
}
// getMetaInfo gathers what is needed to build a load/sync request for the
// task: the collection description from the broker, the load meta (load type,
// db name, resource group, partitions) and the task's shard channel, looked
// up with meta.NextTargetFirst.
//
// It returns an error if the collection cannot be described, its partitions
// cannot be listed, or the channel is not found in the target.
func (ex *Executor) getMetaInfo(ctx context.Context, task Task) (*milvuspb.DescribeCollectionResponse, *querypb.LoadMetaInfo, *meta.DmChannel, error) {
	var (
		collectionID = task.CollectionID()
		shard        = task.Shard()
		logger       = log.Ctx(ctx)
	)

	describeResp, err := ex.broker.DescribeCollection(ctx, collectionID)
	if err != nil {
		logger.Warn("failed to get collection info", zap.Error(err))
		return nil, nil, nil, err
	}

	partitionIDs, err := utils.GetPartitions(ex.meta.CollectionManager, collectionID)
	if err != nil {
		logger.Warn("failed to get partitions of collection", zap.Error(err))
		return nil, nil, nil, err
	}

	loadMeta := packLoadMeta(
		ex.meta.GetLoadType(task.CollectionID()),
		task.CollectionID(),
		describeResp.GetDbName(),
		task.ResourceGroup(),
		partitionIDs...,
	)

	// Resolve the channel here, before any segment info is fetched by the
	// caller, in case the target is updated in between.
	dmChannel := ex.targetMgr.GetDmChannel(collectionID, shard, meta.NextTargetFirst)
	if dmChannel != nil {
		return describeResp, loadMeta, dmChannel, nil
	}
	return nil, nil, nil, merr.WrapErrChannelNotAvailable(shard)
}
// getLoadInfo builds the SegmentLoadInfo for the given segment and returns it
// together with the collection's index infos.
//
// It fetches the segment info and the segment's indexes from the broker
// (a merr.ErrIndexNotFound from GetIndexInfo is tolerated and treated as "no
// indexes"), lists the collection indexes, overlays the configurable user
// index params of each matching collection index onto the segment index
// params, and packs the result with the channel's seek position.
//
// An error is returned when any broker call fails, or when the broker returns
// no segment info at all for segmentID.
func (ex *Executor) getLoadInfo(ctx context.Context, collectionID, segmentID int64, channel *meta.DmChannel) (*querypb.SegmentLoadInfo, []*indexpb.IndexInfo, error) {
	log := log.Ctx(ctx)
	segmentInfos, err := ex.broker.GetSegmentInfo(ctx, segmentID)
	if err != nil {
		log.Warn("failed to get segment info from DataCoord", zap.Error(err))
		return nil, nil, err
	}
	if len(segmentInfos) == 0 {
		// An empty result with a nil error must not be reported as success:
		// callers would go on with a nil load info.
		err = merr.WrapErrSegmentNotFound(segmentID)
		log.Warn("failed to get segment info from DataCoord", zap.Error(err))
		return nil, nil, err
	}
	segment := segmentInfos[0]
	log = log.With(zap.String("level", segment.GetLevel().String()))

	indexes, err := ex.broker.GetIndexInfo(ctx, collectionID, segment.GetID())
	if err != nil {
		if !errors.Is(err, merr.ErrIndexNotFound) {
			log.Warn("failed to get index of segment", zap.Error(err))
			return nil, nil, err
		}
		// A segment without any index is still loadable.
		indexes = nil
	}

	// Get collection index info
	indexInfos, err := ex.broker.ListIndexes(ctx, collectionID)
	if err != nil {
		log.Warn("fail to get index meta of collection", zap.Error(err))
		return nil, nil, err
	}

	// update the field index params
	for _, segmentIndex := range indexes {
		index, found := lo.Find(indexInfos, func(indexInfo *indexpb.IndexInfo) bool {
			return indexInfo.IndexID == segmentIndex.IndexID
		})
		if !found {
			// Only warn and carry on. NOTE(review): index is nil here, so the
			// loop below relies on GetUserIndexParams being safe to call on a
			// nil receiver — confirm against the generated code.
			log.Warn("no collection index info for the given segment index", zap.String("indexName", segmentIndex.GetIndexName()))
		}

		params := funcutil.KeyValuePair2Map(segmentIndex.GetIndexParams())
		for _, kv := range index.GetUserIndexParams() {
			if indexparams.IsConfigableIndexParam(kv.GetKey()) {
				params[kv.GetKey()] = kv.GetValue()
			}
		}
		segmentIndex.IndexParams = funcutil.Map2KeyValuePair(params)
	}

	loadInfo := utils.PackSegmentLoadInfo(segment, channel.GetSeekPosition(), indexes)
	return loadInfo, indexInfos, nil
}