Refactor segment allocate policy on querycoord (#11181)

Signed-off-by: xige-16 <xi.ge@zilliz.com>
xige-16 2021-11-11 12:56:42 +08:00 committed by GitHub
parent feb6e866c8
commit dba0ae4421
8 changed files with 588 additions and 340 deletions

View File

@ -0,0 +1,76 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package querycoord
import (
"context"
"errors"
"sort"
"time"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/querypb"
)
func defaultChannelAllocatePolicy() ChannelAllocatePolicy {
return shuffleChannelsToQueryNode
}
// ChannelAllocatePolicy is a helper function type used to allocate dmChannels to queryNodes
type ChannelAllocatePolicy func(ctx context.Context, reqs []*querypb.WatchDmChannelsRequest, cluster Cluster, wait bool, excludeNodeIDs []int64) error
func shuffleChannelsToQueryNode(ctx context.Context, reqs []*querypb.WatchDmChannelsRequest, cluster Cluster, wait bool, excludeNodeIDs []int64) error {
for {
availableNodes, err := cluster.onlineNodes()
if err != nil {
log.Debug(err.Error())
if !wait {
return err
}
time.Sleep(1 * time.Second)
continue
}
for _, id := range excludeNodeIDs {
delete(availableNodes, id)
}
nodeID2NumChannels := make(map[int64]int)
for nodeID := range availableNodes {
numChannels, err := cluster.getNumDmChannels(nodeID)
if err != nil {
delete(availableNodes, nodeID)
continue
}
nodeID2NumChannels[nodeID] = numChannels
}
if len(availableNodes) > 0 {
nodeIDSlice := make([]int64, 0)
for nodeID := range availableNodes {
nodeIDSlice = append(nodeIDSlice, nodeID)
}
for _, req := range reqs {
sort.Slice(nodeIDSlice, func(i, j int) bool {
return nodeID2NumChannels[nodeIDSlice[i]] < nodeID2NumChannels[nodeIDSlice[j]]
})
req.NodeID = nodeIDSlice[0]
nodeID2NumChannels[nodeIDSlice[0]]++
}
return nil
}
if !wait {
return errors.New("no queryNode to allocate")
}
}
}
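
For reference, a minimal standalone sketch of the same balancing idea, outside the querycoord package and with hypothetical inputs: before every assignment the candidate nodes are re-sorted by their current channel count and the request goes to the least-loaded node. The names (shuffleToLeastLoaded, channelCount) are made up for illustration.

package main

import (
	"fmt"
	"sort"
)

// shuffleToLeastLoaded mirrors the loop above in isolation: re-sort the
// candidates by channel count before each assignment, pick the least-loaded
// node, then bump its count so the next request sees the updated load.
func shuffleToLeastLoaded(channels []string, channelCount map[int64]int) map[string]int64 {
	nodeIDs := make([]int64, 0, len(channelCount))
	for id := range channelCount {
		nodeIDs = append(nodeIDs, id)
	}
	assignment := make(map[string]int64, len(channels))
	for _, ch := range channels {
		sort.Slice(nodeIDs, func(i, j int) bool {
			return channelCount[nodeIDs[i]] < channelCount[nodeIDs[j]]
		})
		assignment[ch] = nodeIDs[0]
		channelCount[nodeIDs[0]]++
	}
	return assignment
}

func main() {
	// hypothetical state: node 1 already watches two channels, node 2 none
	counts := map[int64]int{1: 2, 2: 0}
	fmt.Println(shuffleToLeastLoaded([]string{"dml_0", "dml_1", "dml_2"}, counts))
}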

View File

@ -0,0 +1,84 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package querycoord
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
func TestShuffleChannelsToQueryNode(t *testing.T) {
refreshParams()
baseCtx, cancel := context.WithCancel(context.Background())
kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath)
assert.Nil(t, err)
clusterSession := sessionutil.NewSession(context.Background(), Params.MetaRootPath, Params.EtcdEndpoints)
clusterSession.Init(typeutil.QueryCoordRole, Params.Address, true)
meta, err := newMeta(baseCtx, kv, nil, nil)
assert.Nil(t, err)
cluster := &queryNodeCluster{
ctx: baseCtx,
cancel: cancel,
client: kv,
clusterMeta: meta,
nodes: make(map[int64]Node),
newNodeFn: newQueryNodeTest,
session: clusterSession,
}
firstReq := &querypb.WatchDmChannelsRequest{
CollectionID: defaultCollectionID,
PartitionID: defaultPartitionID,
Infos: []*datapb.VchannelInfo{
{
ChannelName: "test1",
},
},
}
secondReq := &querypb.WatchDmChannelsRequest{
CollectionID: defaultCollectionID,
PartitionID: defaultPartitionID,
Infos: []*datapb.VchannelInfo{
{
ChannelName: "test2",
},
},
}
reqs := []*querypb.WatchDmChannelsRequest{firstReq, secondReq}
err = shuffleChannelsToQueryNode(baseCtx, reqs, cluster, false, nil)
assert.NotNil(t, err)
node, err := startQueryNodeServer(baseCtx)
assert.Nil(t, err)
nodeSession := node.session
nodeID := node.queryNodeID
cluster.registerNode(baseCtx, nodeSession, nodeID, disConnect)
waitQueryNodeOnline(cluster, nodeID)
err = shuffleChannelsToQueryNode(baseCtx, reqs, cluster, false, nil)
assert.Nil(t, err)
assert.Equal(t, nodeID, firstReq.NodeID)
assert.Equal(t, nodeID, secondReq.NodeID)
err = removeAllSession()
assert.Nil(t, err)
}

View File

@ -69,6 +69,9 @@ type Cluster interface {
offlineNodes() (map[int64]Node, error)
hasNode(nodeID int64) bool
allocateSegmentsToQueryNode(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, wait bool, excludeNodeIDs []int64) error
allocateChannelsToQueryNode(ctx context.Context, reqs []*querypb.WatchDmChannelsRequest, wait bool, excludeNodeIDs []int64) error
getSessionVersion() int64
getMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest) []queryNodeGetMetricsResponse
@ -93,22 +96,26 @@ type queryNodeCluster struct {
sessionVersion int64
sync.RWMutex
clusterMeta Meta
nodes map[int64]Node
newNodeFn newQueryNodeFn
clusterMeta Meta
nodes map[int64]Node
newNodeFn newQueryNodeFn
segmentAllocator SegmentAllocatePolicy
channelAllocator ChannelAllocatePolicy
}
func newQueryNodeCluster(ctx context.Context, clusterMeta Meta, kv *etcdkv.EtcdKV, newNodeFn newQueryNodeFn, session *sessionutil.Session) (Cluster, error) {
childCtx, cancel := context.WithCancel(ctx)
nodes := make(map[int64]Node)
c := &queryNodeCluster{
ctx: childCtx,
cancel: cancel,
client: kv,
session: session,
clusterMeta: clusterMeta,
nodes: nodes,
newNodeFn: newNodeFn,
ctx: childCtx,
cancel: cancel,
client: kv,
session: session,
clusterMeta: clusterMeta,
nodes: nodes,
newNodeFn: newNodeFn,
segmentAllocator: defaultSegAllocatePolicy(),
channelAllocator: defaultChannelAllocatePolicy(),
}
err := c.reloadFromKV()
if err != nil {
@ -642,3 +649,11 @@ func (c *queryNodeCluster) getCollectionInfosByID(ctx context.Context, nodeID in
return nil
}
func (c *queryNodeCluster) allocateSegmentsToQueryNode(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, wait bool, excludeNodeIDs []int64) error {
return c.segmentAllocator(ctx, reqs, c, wait, excludeNodeIDs)
}
func (c *queryNodeCluster) allocateChannelsToQueryNode(ctx context.Context, reqs []*querypb.WatchDmChannelsRequest, wait bool, excludeNodeIDs []int64) error {
return c.channelAllocator(ctx, reqs, c, wait, excludeNodeIDs)
}
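
The two wrapper methods above simply forward to function-typed fields (segmentAllocator, channelAllocator) that are set to the default policies at construction time, so tests can swap in a stub. A minimal sketch of that pattern with made-up names (allocatePolicy, cluster.allocate), not the actual querycoord types:

package main

import "fmt"

// allocatePolicy is a stand-in for SegmentAllocatePolicy / ChannelAllocatePolicy:
// the cluster keeps the policy as a field and forwards calls to it.
type allocatePolicy func(reqs []string) error

type cluster struct {
	allocator allocatePolicy
}

func newCluster(policy allocatePolicy) *cluster {
	if policy == nil {
		// same idea as defaultSegAllocatePolicy()/defaultChannelAllocatePolicy()
		policy = func(reqs []string) error {
			fmt.Println("default allocation for", reqs)
			return nil
		}
	}
	return &cluster{allocator: policy}
}

// allocate forwards to the configured policy, like allocateSegmentsToQueryNode above.
func (c *cluster) allocate(reqs []string) error {
	return c.allocator(reqs)
}

func main() {
	c := newCluster(nil) // production path: default policy
	_ = c.allocate([]string{"segment-1", "segment-2"})
}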

View File

@ -38,11 +38,12 @@ import (
)
const (
defaultCollectionID = UniqueID(2021)
defaultPartitionID = UniqueID(2021)
defaultSegmentID = UniqueID(2021)
defaultQueryNodeID = int64(100)
defaultChannelNum = 2
defaultCollectionID = UniqueID(2021)
defaultPartitionID = UniqueID(2021)
defaultSegmentID = UniqueID(2021)
defaultQueryNodeID = int64(100)
defaultChannelNum = 2
defaultNumRowPerSegment = 10000
)
func genCollectionSchema(collectionID UniqueID, isBinary bool) *schemapb.CollectionSchema {
@ -347,6 +348,7 @@ func (data *dataCoordMock) GetRecoveryInfo(ctx context.Context, req *datapb.GetR
segmentBinlog := &datapb.SegmentBinlogs{
SegmentID: segmentID,
FieldBinlogs: fieldBinlogs,
NumOfRows: defaultNumRowPerSegment,
}
data.Segment2Binlog[segmentID] = segmentBinlog
}

View File

@ -0,0 +1,183 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package querycoord
import (
"context"
"errors"
"sort"
"time"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
const MaxMemUsagePerNode = 0.9
func defaultSegAllocatePolicy() SegmentAllocatePolicy {
return shuffleSegmentsToQueryNodeV2
}
// SegmentAllocatePolicy is a helper function type used to allocate segments to queryNodes
type SegmentAllocatePolicy func(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, cluster Cluster, wait bool, excludeNodeIDs []int64) error
// shuffleSegmentsToQueryNode shuffles segments to online queryNodes, balancing by per-node segment count;
// on success every LoadSegmentsRequest in reqs has its DstNodeID set to the node it was assigned to
func shuffleSegmentsToQueryNode(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, cluster Cluster, wait bool, excludeNodeIDs []int64) error {
if len(reqs) == 0 {
return nil
}
for {
availableNodes, err := cluster.onlineNodes()
if err != nil {
log.Debug(err.Error())
if !wait {
return err
}
time.Sleep(1 * time.Second)
continue
}
for _, id := range excludeNodeIDs {
delete(availableNodes, id)
}
nodeID2NumSegment := make(map[int64]int)
for nodeID := range availableNodes {
numSegments, err := cluster.getNumSegments(nodeID)
if err != nil {
delete(availableNodes, nodeID)
continue
}
nodeID2NumSegment[nodeID] = numSegments
}
if len(availableNodes) > 0 {
nodeIDSlice := make([]int64, 0)
for nodeID := range availableNodes {
nodeIDSlice = append(nodeIDSlice, nodeID)
}
for _, req := range reqs {
sort.Slice(nodeIDSlice, func(i, j int) bool {
return nodeID2NumSegment[nodeIDSlice[i]] < nodeID2NumSegment[nodeIDSlice[j]]
})
req.DstNodeID = nodeIDSlice[0]
nodeID2NumSegment[nodeIDSlice[0]]++
}
return nil
}
if !wait {
return errors.New("no queryNode to allocate")
}
}
}
func shuffleSegmentsToQueryNodeV2(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, cluster Cluster, wait bool, excludeNodeIDs []int64) error {
// dataSizePerReq[offset] holds the estimated data size of reqs[offset]
if len(reqs) == 0 {
return nil
}
dataSizePerReq := make([]int64, 0)
for _, req := range reqs {
sizePerRecord, err := typeutil.EstimateSizePerRecord(req.Schema)
if err != nil {
return err
}
sizeOfReq := int64(0)
for _, loadInfo := range req.Infos {
sizeOfReq += int64(sizePerRecord) * loadInfo.NumOfRows
}
dataSizePerReq = append(dataSizePerReq, sizeOfReq)
}
for {
// collect the online nodes map plus totalMem, memUsage and memUsageRate of every node
totalMem := make(map[int64]uint64)
memUsage := make(map[int64]uint64)
memUsageRate := make(map[int64]float64)
availableNodes, err := cluster.onlineNodes()
if err != nil && !wait {
return errors.New("no online queryNode to allocate")
}
for _, id := range excludeNodeIDs {
delete(availableNodes, id)
}
for nodeID := range availableNodes {
// collect nodeInfo: total memory, used memory and memory usage rate of every query node
nodeInfo, err := cluster.getNodeInfoByID(nodeID)
if err != nil {
log.Debug("shuffleSegmentsToQueryNodeV2: getNodeInfoByID failed", zap.Error(err))
delete(availableNodes, nodeID)
continue
}
queryNodeInfo := nodeInfo.(*queryNode)
// avoid allocating segments to nodes whose memUsageRate is already high
if queryNodeInfo.memUsageRate >= MaxMemUsagePerNode {
log.Debug("shuffleSegmentsToQueryNodeV2: queryNode memUsageRate large than MaxMemUsagePerNode", zap.Int64("nodeID", nodeID), zap.Float64("current rate", queryNodeInfo.memUsageRate))
delete(availableNodes, nodeID)
continue
}
// update totalMem, memUsage, memUsageRate
totalMem[nodeID], memUsage[nodeID], memUsageRate[nodeID] = queryNodeInfo.totalMem, queryNodeInfo.memUsage, queryNodeInfo.memUsageRate
}
if len(availableNodes) > 0 {
nodeIDSlice := make([]int64, 0, len(availableNodes))
for nodeID := range availableNodes {
nodeIDSlice = append(nodeIDSlice, nodeID)
}
allocateSegmentsDone := true
for offset, sizeOfReq := range dataSizePerReq {
// sort nodes by memUsageRate, low to high
sort.Slice(nodeIDSlice, func(i, j int) bool {
return memUsageRate[nodeIDSlice[i]] < memUsageRate[nodeIDSlice[j]]
})
findNodeToAllocate := false
// assign the load segment request to the query node with the lowest memUsageRate
for _, nodeID := range nodeIDSlice {
memUsageAfterLoad := memUsage[nodeID] + uint64(sizeOfReq)
memUsageRateAfterLoad := float64(memUsageAfterLoad) / float64(totalMem[nodeID])
if memUsageRateAfterLoad > MaxMemUsagePerNode {
continue
}
reqs[offset].DstNodeID = nodeID
memUsage[nodeID] = memUsageAfterLoad
memUsageRate[nodeID] = memUsageRateAfterLoad
findNodeToAllocate = true
break
}
// the load segment request can't be allocated to any query node
if !findNodeToAllocate {
allocateSegmentsDone = false
break
}
}
if allocateSegmentsDone {
return nil
}
}
if wait {
time.Sleep(1 * time.Second)
continue
} else {
return errors.New("no queryNode to allocate")
}
}
}
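
A standalone sketch of the V2 placement rule under simplified assumptions (fixed per-node memory totals, request sizes already estimated, hypothetical nodeMem type): pick the node with the lowest memory usage rate whose post-load rate stays under the 0.9 cap, and fail when no node qualifies.

package main

import (
	"errors"
	"fmt"
	"sort"
)

const maxMemUsage = 0.9 // stand-in for MaxMemUsagePerNode

type nodeMem struct {
	id    int64
	total uint64
	used  uint64
}

func (n nodeMem) rate() float64 { return float64(n.used) / float64(n.total) }

// placeSegments assigns each request size to the node with the lowest memory
// usage rate that can still hold it without crossing the cap; it fails when
// no node qualifies, mirroring the "no queryNode to allocate" case above.
func placeSegments(sizes []uint64, nodes []nodeMem) (map[int]int64, error) {
	placement := make(map[int]int64, len(sizes))
	for i, size := range sizes {
		// sort candidates by current usage rate, low to high
		sort.Slice(nodes, func(a, b int) bool { return nodes[a].rate() < nodes[b].rate() })
		placed := false
		for j := range nodes {
			if float64(nodes[j].used+size)/float64(nodes[j].total) > maxMemUsage {
				continue
			}
			nodes[j].used += size
			placement[i] = nodes[j].id
			placed = true
			break
		}
		if !placed {
			return nil, errors.New("no node can hold this segment without exceeding the cap")
		}
	}
	return placement, nil
}

func main() {
	// hypothetical nodes: 1 GiB total each, 600 MiB and 200 MiB already in use
	nodes := []nodeMem{{id: 1, total: 1 << 30, used: 600 << 20}, {id: 2, total: 1 << 30, used: 200 << 20}}
	placement, err := placeSegments([]uint64{300 << 20, 300 << 20}, nodes)
	fmt.Println(placement, err)
}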

View File

@ -0,0 +1,110 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package querycoord
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
func TestShuffleSegmentsToQueryNode(t *testing.T) {
refreshParams()
baseCtx, cancel := context.WithCancel(context.Background())
kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath)
assert.Nil(t, err)
clusterSession := sessionutil.NewSession(context.Background(), Params.MetaRootPath, Params.EtcdEndpoints)
clusterSession.Init(typeutil.QueryCoordRole, Params.Address, true)
meta, err := newMeta(baseCtx, kv, nil, nil)
assert.Nil(t, err)
cluster := &queryNodeCluster{
ctx: baseCtx,
cancel: cancel,
client: kv,
clusterMeta: meta,
nodes: make(map[int64]Node),
newNodeFn: newQueryNodeTest,
session: clusterSession,
}
schema := genCollectionSchema(defaultCollectionID, false)
firstReq := &querypb.LoadSegmentsRequest{
CollectionID: defaultCollectionID,
Schema: schema,
Infos: []*querypb.SegmentLoadInfo{
{
SegmentID: defaultSegmentID,
PartitionID: defaultPartitionID,
CollectionID: defaultCollectionID,
NumOfRows: defaultNumRowPerSegment,
},
},
}
secondReq := &querypb.LoadSegmentsRequest{
CollectionID: defaultCollectionID,
Schema: schema,
Infos: []*querypb.SegmentLoadInfo{
{
SegmentID: defaultSegmentID + 1,
PartitionID: defaultPartitionID,
CollectionID: defaultCollectionID,
NumOfRows: defaultNumRowPerSegment,
},
},
}
reqs := []*querypb.LoadSegmentsRequest{firstReq, secondReq}
t.Run("Test shuffleSegmentsWithoutQueryNode", func(t *testing.T) {
err = shuffleSegmentsToQueryNode(baseCtx, reqs, cluster, false, nil)
assert.NotNil(t, err)
})
node1, err := startQueryNodeServer(baseCtx)
assert.Nil(t, err)
node1Session := node1.session
node1ID := node1.queryNodeID
cluster.registerNode(baseCtx, node1Session, node1ID, disConnect)
waitQueryNodeOnline(cluster, node1ID)
t.Run("Test shuffleSegmentsToQueryNode", func(t *testing.T) {
err = shuffleSegmentsToQueryNode(baseCtx, reqs, cluster, false, nil)
assert.Nil(t, err)
assert.Equal(t, node1ID, firstReq.DstNodeID)
assert.Equal(t, node1ID, secondReq.DstNodeID)
})
node2, err := startQueryNodeServer(baseCtx)
assert.Nil(t, err)
node2Session := node2.session
node2ID := node2.queryNodeID
cluster.registerNode(baseCtx, node2Session, node2ID, disConnect)
waitQueryNodeOnline(cluster, node2ID)
cluster.stopNode(node1ID)
t.Run("Test shuffleSegmentsToQueryNodeV2", func(t *testing.T) {
err = shuffleSegmentsToQueryNodeV2(baseCtx, reqs, cluster, false, nil)
assert.Nil(t, err)
assert.Equal(t, node2ID, firstReq.DstNodeID)
assert.Equal(t, node2ID, secondReq.DstNodeID)
})
err = removeAllSession()
assert.Nil(t, err)
}

View File

@ -16,7 +16,6 @@ import (
"errors"
"fmt"
"sync"
"time"
"github.com/golang/protobuf/proto"
"go.uber.org/zap"
@ -384,6 +383,7 @@ func (lct *loadCollectionTask) execute(ctx context.Context) error {
Infos: []*querypb.SegmentLoadInfo{segmentLoadInfo},
Schema: lct.Schema,
LoadCondition: querypb.TriggerCondition_grpcRequest,
CollectionID: collectionID,
}
segmentsToLoad = append(segmentsToLoad, segmentID)
@ -453,12 +453,16 @@ func (lct *loadCollectionTask) execute(ctx context.Context) error {
}
err = assignInternalTask(ctx, collectionID, lct, lct.meta, lct.cluster, loadSegmentReqs, watchDmChannelReqs, watchDeltaChannelReqs, false, nil)
internalTasks, err := assignInternalTask(ctx, collectionID, lct, lct.meta, lct.cluster, loadSegmentReqs, watchDmChannelReqs, watchDeltaChannelReqs, false, nil)
if err != nil {
log.Warn("loadCollectionTask: assign child task failed", zap.Int64("collectionID", collectionID))
lct.setResultInfo(err)
return err
}
for _, internalTask := range internalTasks {
lct.addChildTask(internalTask)
log.Debug("loadCollectionTask: add a childTask", zap.Int32("task type", int32(internalTask.msgType())), zap.Int64("collectionID", collectionID), zap.Any("task", internalTask))
}
log.Debug("loadCollectionTask: assign child task done", zap.Int64("collectionID", collectionID))
log.Debug("LoadCollection execute done",
@ -735,6 +739,7 @@ func (lpt *loadPartitionTask) execute(ctx context.Context) error {
Infos: []*querypb.SegmentLoadInfo{segmentLoadInfo},
Schema: lpt.Schema,
LoadCondition: querypb.TriggerCondition_grpcRequest,
CollectionID: collectionID,
}
segmentsToLoad = append(segmentsToLoad, segmentID)
loadSegmentReqs = append(loadSegmentReqs, loadSegmentReq)
@ -778,12 +783,16 @@ func (lpt *loadPartitionTask) execute(ctx context.Context) error {
}
}
err := assignInternalTask(ctx, collectionID, lpt, lpt.meta, lpt.cluster, loadSegmentReqs, watchDmReqs, watchDeltaReqs, false, nil)
internalTasks, err := assignInternalTask(ctx, collectionID, lpt, lpt.meta, lpt.cluster, loadSegmentReqs, watchDmReqs, watchDeltaReqs, false, nil)
if err != nil {
log.Warn("loadPartitionTask: assign child task failed", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs))
lpt.setResultInfo(err)
return err
}
for _, internalTask := range internalTasks {
lpt.addChildTask(internalTask)
log.Debug("loadPartitionTask: add a childTask", zap.Int32("task type", int32(internalTask.msgType())), zap.Int64("collectionID", collectionID), zap.Any("task", internalTask))
}
log.Debug("loadPartitionTask: assign child task done", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs))
log.Debug("loadPartitionTask Execute done",
@ -1053,78 +1062,33 @@ func (lst *loadSegmentTask) postExecute(context.Context) error {
}
func (lst *loadSegmentTask) reschedule(ctx context.Context) ([]task, error) {
segmentIDs := make([]UniqueID, 0)
collectionID := lst.Infos[0].CollectionID
reScheduledTask := make([]task, 0)
loadSegmentReqs := make([]*querypb.LoadSegmentsRequest, 0)
collectionID := lst.CollectionID
for _, info := range lst.Infos {
segmentIDs = append(segmentIDs, info.SegmentID)
msgBase := proto.Clone(lst.Base).(*commonpb.MsgBase)
msgBase.MsgType = commonpb.MsgType_LoadSegments
req := &querypb.LoadSegmentsRequest{
Base: msgBase,
Infos: []*querypb.SegmentLoadInfo{info},
Schema: lst.Schema,
LoadCondition: lst.triggerCondition,
SourceNodeID: lst.SourceNodeID,
CollectionID: lst.CollectionID,
}
loadSegmentReqs = append(loadSegmentReqs, req)
}
if lst.excludeNodeIDs == nil {
lst.excludeNodeIDs = []int64{}
}
lst.excludeNodeIDs = append(lst.excludeNodeIDs, lst.DstNodeID)
segment2Nodes, err := shuffleSegmentsToQueryNode(segmentIDs, lst.cluster, false, lst.excludeNodeIDs)
//TODO:: decide whether to wait according to the msgType
reScheduledTasks, err := assignInternalTask(ctx, collectionID, lst.getParentTask(), lst.meta, lst.cluster, loadSegmentReqs, nil, nil, false, lst.excludeNodeIDs)
if err != nil {
log.Error("loadSegment reschedule failed", zap.Int64s("excludeNodes", lst.excludeNodeIDs), zap.Error(err))
return nil, err
}
node2segmentInfos := make(map[int64][]*querypb.SegmentLoadInfo)
for index, info := range lst.Infos {
nodeID := segment2Nodes[index]
if _, ok := node2segmentInfos[nodeID]; !ok {
node2segmentInfos[nodeID] = make([]*querypb.SegmentLoadInfo, 0)
}
node2segmentInfos[nodeID] = append(node2segmentInfos[nodeID], info)
}
for nodeID, infos := range node2segmentInfos {
loadSegmentBaseTask := newBaseTask(ctx, lst.getTriggerCondition())
loadSegmentBaseTask.setParentTask(lst.getParentTask())
loadSegmentTask := &loadSegmentTask{
baseTask: loadSegmentBaseTask,
LoadSegmentsRequest: &querypb.LoadSegmentsRequest{
Base: lst.Base,
DstNodeID: nodeID,
Infos: infos,
Schema: lst.Schema,
LoadCondition: lst.LoadCondition,
},
meta: lst.meta,
cluster: lst.cluster,
excludeNodeIDs: lst.excludeNodeIDs,
}
reScheduledTask = append(reScheduledTask, loadSegmentTask)
log.Debug("loadSegmentTask: add a loadSegmentTask to RescheduleTasks", zap.Any("task", loadSegmentTask))
hasWatchQueryChannel := lst.cluster.hasWatchedQueryChannel(lst.ctx, nodeID, collectionID)
if !hasWatchQueryChannel {
queryChannelInfo, err := lst.meta.getQueryChannelInfoByID(collectionID)
if err != nil {
return nil, err
}
msgBase := proto.Clone(lst.Base).(*commonpb.MsgBase)
msgBase.MsgType = commonpb.MsgType_WatchQueryChannels
addQueryChannelRequest := &querypb.AddQueryChannelRequest{
Base: msgBase,
NodeID: nodeID,
CollectionID: collectionID,
RequestChannelID: queryChannelInfo.QueryChannelID,
ResultChannelID: queryChannelInfo.QueryResultChannelID,
GlobalSealedSegments: queryChannelInfo.GlobalSealedSegments,
SeekPosition: queryChannelInfo.SeekPosition,
}
watchQueryChannelBaseTask := newBaseTask(ctx, lst.getTriggerCondition())
watchQueryChannelBaseTask.setParentTask(lst.getParentTask())
watchQueryChannelTask := &watchQueryChannelTask{
baseTask: watchQueryChannelBaseTask,
AddQueryChannelRequest: addQueryChannelRequest,
cluster: lst.cluster,
}
reScheduledTask = append(reScheduledTask, watchQueryChannelTask)
log.Debug("loadSegmentTask: add a watchQueryChannelTask to RescheduleTasks", zap.Any("task", watchQueryChannelTask))
}
}
return reScheduledTask, nil
return reScheduledTasks, nil
}
type releaseSegmentTask struct {
@ -1273,79 +1237,33 @@ func (wdt *watchDmChannelTask) postExecute(context.Context) error {
func (wdt *watchDmChannelTask) reschedule(ctx context.Context) ([]task, error) {
collectionID := wdt.CollectionID
channelIDs := make([]string, 0)
reScheduledTask := make([]task, 0)
watchDmChannelReqs := make([]*querypb.WatchDmChannelsRequest, 0)
for _, info := range wdt.Infos {
channelIDs = append(channelIDs, info.ChannelName)
msgBase := proto.Clone(wdt.Base).(*commonpb.MsgBase)
msgBase.MsgType = commonpb.MsgType_WatchDmChannels
req := &querypb.WatchDmChannelsRequest{
Base: msgBase,
CollectionID: collectionID,
PartitionID: wdt.PartitionID,
Infos: []*datapb.VchannelInfo{info},
Schema: wdt.Schema,
ExcludeInfos: wdt.ExcludeInfos,
}
watchDmChannelReqs = append(watchDmChannelReqs, req)
}
if wdt.excludeNodeIDs == nil {
wdt.excludeNodeIDs = []int64{}
}
wdt.excludeNodeIDs = append(wdt.excludeNodeIDs, wdt.NodeID)
channel2Nodes, err := shuffleChannelsToQueryNode(channelIDs, wdt.cluster, false, wdt.excludeNodeIDs)
//TODO:: decide whether to wait according to the msgType
reScheduledTasks, err := assignInternalTask(ctx, collectionID, wdt.parentTask, wdt.meta, wdt.cluster, nil, watchDmChannelReqs, nil, false, wdt.excludeNodeIDs)
if err != nil {
log.Error("watchDmChannel reschedule failed", zap.Int64s("excludeNodes", wdt.excludeNodeIDs), zap.Error(err))
return nil, err
}
node2channelInfos := make(map[int64][]*datapb.VchannelInfo)
for index, info := range wdt.Infos {
nodeID := channel2Nodes[index]
if _, ok := node2channelInfos[nodeID]; !ok {
node2channelInfos[nodeID] = make([]*datapb.VchannelInfo, 0)
}
node2channelInfos[nodeID] = append(node2channelInfos[nodeID], info)
}
for nodeID, infos := range node2channelInfos {
watchDmChannelBaseTask := newBaseTask(ctx, wdt.getTriggerCondition())
watchDmChannelBaseTask.setParentTask(wdt.getParentTask())
watchDmChannelTask := &watchDmChannelTask{
baseTask: watchDmChannelBaseTask,
WatchDmChannelsRequest: &querypb.WatchDmChannelsRequest{
Base: wdt.Base,
NodeID: nodeID,
CollectionID: wdt.CollectionID,
PartitionID: wdt.PartitionID,
Infos: infos,
Schema: wdt.Schema,
ExcludeInfos: wdt.ExcludeInfos,
},
meta: wdt.meta,
cluster: wdt.cluster,
excludeNodeIDs: wdt.excludeNodeIDs,
}
reScheduledTask = append(reScheduledTask, watchDmChannelTask)
log.Debug("watchDmChannelTask: add a watchDmChannelTask to RescheduleTasks", zap.Any("task", watchDmChannelTask))
hasWatchQueryChannel := wdt.cluster.hasWatchedQueryChannel(wdt.ctx, nodeID, collectionID)
if !hasWatchQueryChannel {
queryChannelInfo, err := wdt.meta.getQueryChannelInfoByID(collectionID)
if err != nil {
return nil, err
}
msgBase := proto.Clone(wdt.Base).(*commonpb.MsgBase)
msgBase.MsgType = commonpb.MsgType_WatchQueryChannels
addQueryChannelRequest := &querypb.AddQueryChannelRequest{
Base: msgBase,
NodeID: nodeID,
CollectionID: collectionID,
RequestChannelID: queryChannelInfo.QueryChannelID,
ResultChannelID: queryChannelInfo.QueryResultChannelID,
GlobalSealedSegments: queryChannelInfo.GlobalSealedSegments,
SeekPosition: queryChannelInfo.SeekPosition,
}
watchQueryChannelBaseTask := newBaseTask(ctx, wdt.getTriggerCondition())
watchQueryChannelBaseTask.setParentTask(wdt.getParentTask())
watchQueryChannelTask := &watchQueryChannelTask{
baseTask: watchQueryChannelBaseTask,
AddQueryChannelRequest: addQueryChannelRequest,
cluster: wdt.cluster,
}
reScheduledTask = append(reScheduledTask, watchQueryChannelTask)
log.Debug("watchDmChannelTask: add a watchQueryChannelTask to RescheduleTasks", zap.Any("task", watchQueryChannelTask))
}
}
return reScheduledTask, nil
return reScheduledTasks, nil
}
type watchDeltaChannelTask struct {
@ -1639,12 +1557,16 @@ func (ht *handoffTask) execute(ctx context.Context) error {
ht.setResultInfo(err)
return err
}
err = assignInternalTask(ctx, collectionID, ht, ht.meta, ht.cluster, []*querypb.LoadSegmentsRequest{loadSegmentReq}, nil, watchDeltaChannelReqs, true, nil)
internalTasks, err := assignInternalTask(ctx, collectionID, ht, ht.meta, ht.cluster, []*querypb.LoadSegmentsRequest{loadSegmentReq}, nil, watchDeltaChannelReqs, true, nil)
if err != nil {
log.Error("handoffTask: assign child task failed", zap.Any("segmentInfo", segmentInfo))
ht.setResultInfo(err)
return err
}
for _, internalTask := range internalTasks {
ht.addChildTask(internalTask)
log.Debug("handoffTask: add a childTask", zap.Int32("task type", int32(internalTask.msgType())), zap.Int64("segmentID", segmentID), zap.Any("task", internalTask))
}
} else {
err = fmt.Errorf("sealed segment has been exist on query node, segmentID is %d", segmentID)
log.Error("handoffTask: sealed segment has been exist on query node", zap.Int64("segmentID", segmentID))
@ -1851,12 +1773,17 @@ func (lbt *loadBalanceTask) execute(ctx context.Context) error {
}
}
}
err = assignInternalTask(ctx, collectionID, lbt, lbt.meta, lbt.cluster, loadSegmentReqs, watchDmChannelReqs, watchDeltaChannelReqs, true, lbt.SourceNodeIDs)
internalTasks, err := assignInternalTask(ctx, collectionID, lbt, lbt.meta, lbt.cluster, loadSegmentReqs, watchDmChannelReqs, watchDeltaChannelReqs, true, lbt.SourceNodeIDs)
if err != nil {
log.Warn("loadBalanceTask: assign child task failed", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs))
lbt.setResultInfo(err)
return err
}
for _, internalTask := range internalTasks {
lbt.addChildTask(internalTask)
log.Debug("loadBalanceTask: add a childTask", zap.Int32("task type", int32(internalTask.msgType())), zap.Any("task", internalTask))
}
log.Debug("loadBalanceTask: assign child task done", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs))
}
}
@ -1998,12 +1925,16 @@ func (lbt *loadBalanceTask) execute(ctx context.Context) error {
}
// TODO:: support assignInternalTask with multiple collections
err = assignInternalTask(ctx, collectionID, lbt, lbt.meta, lbt.cluster, loadSegmentReqs, nil, watchDeltaChannelReqs, false, lbt.SourceNodeIDs)
internalTasks, err := assignInternalTask(ctx, collectionID, lbt, lbt.meta, lbt.cluster, loadSegmentReqs, nil, watchDeltaChannelReqs, false, lbt.SourceNodeIDs)
if err != nil {
log.Warn("loadBalanceTask: assign child task failed", zap.Int64("collectionID", collectionID))
log.Warn("loadBalanceTask: assign child task failed", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs))
lbt.setResultInfo(err)
return err
}
for _, internalTask := range internalTasks {
lbt.addChildTask(internalTask)
log.Debug("loadBalanceTask: add a childTask", zap.Int32("task type", int32(internalTask.msgType())), zap.Any("task", internalTask))
}
}
log.Debug("loadBalanceTask: assign child task done", zap.Any("balance request", lbt.LoadBalanceRequest))
}
@ -2038,143 +1969,6 @@ func (lbt *loadBalanceTask) postExecute(context.Context) error {
return nil
}
func shuffleChannelsToQueryNode(dmChannels []string, cluster Cluster, wait bool, excludeNodeIDs []int64) ([]int64, error) {
maxNumChannels := 0
nodes := make(map[int64]Node)
var err error
for {
nodes, err = cluster.onlineNodes()
if err != nil {
log.Debug(err.Error())
if !wait {
return nil, err
}
time.Sleep(1 * time.Second)
continue
}
for _, id := range excludeNodeIDs {
delete(nodes, id)
}
if len(nodes) > 0 {
break
}
if !wait {
return nil, errors.New("no queryNode to allocate")
}
}
for nodeID := range nodes {
numChannels, _ := cluster.getNumDmChannels(nodeID)
if numChannels > maxNumChannels {
maxNumChannels = numChannels
}
}
res := make([]int64, 0)
if len(dmChannels) == 0 {
return res, nil
}
offset := 0
loopAll := false
for {
lastOffset := offset
if !loopAll {
for nodeID := range nodes {
numSegments, _ := cluster.getNumSegments(nodeID)
if numSegments >= maxNumChannels {
continue
}
res = append(res, nodeID)
offset++
if offset == len(dmChannels) {
return res, nil
}
}
} else {
for nodeID := range nodes {
res = append(res, nodeID)
offset++
if offset == len(dmChannels) {
return res, nil
}
}
}
if lastOffset == offset {
loopAll = true
}
}
}
// shuffleSegmentsToQueryNode shuffles segments to online nodes
// returned are node IDs for each segment, which satisfy:
// len(returnedNodeIds) == len(segmentIDs) && segmentIDs[i] is assigned to returnedNodeIds[i]
func shuffleSegmentsToQueryNode(segmentIDs []UniqueID, cluster Cluster, wait bool, excludeNodeIDs []int64) ([]int64, error) {
maxNumSegments := 0
nodes := make(map[int64]Node)
var err error
for {
nodes, err = cluster.onlineNodes()
if err != nil {
log.Debug(err.Error())
if !wait {
return nil, err
}
time.Sleep(1 * time.Second)
continue
}
for _, id := range excludeNodeIDs {
delete(nodes, id)
}
if len(nodes) > 0 {
break
}
if !wait {
return nil, errors.New("no queryNode to allocate")
}
}
for nodeID := range nodes {
numSegments, _ := cluster.getNumSegments(nodeID)
if numSegments > maxNumSegments {
maxNumSegments = numSegments
}
}
res := make([]int64, 0)
if len(segmentIDs) == 0 {
return res, nil
}
offset := 0
loopAll := false
for {
lastOffset := offset
if !loopAll {
for nodeID := range nodes {
numSegments, _ := cluster.getNumSegments(nodeID)
if numSegments >= maxNumSegments {
continue
}
res = append(res, nodeID)
offset++
if offset == len(segmentIDs) {
return res, nil
}
}
} else {
for nodeID := range nodes {
res = append(res, nodeID)
offset++
if offset == len(segmentIDs) {
return res, nil
}
}
}
if lastOffset == offset {
loopAll = true
}
}
}
func mergeVChannelInfo(info1 *datapb.VchannelInfo, info2 *datapb.VchannelInfo) *datapb.VchannelInfo {
collectionID := info1.CollectionID
channelName := info1.ChannelName
@ -2208,53 +2002,45 @@ func mergeVChannelInfo(info1 *datapb.VchannelInfo, info2 *datapb.VchannelInfo) *
}
func assignInternalTask(ctx context.Context,
collectionID UniqueID,
parentTask task,
meta Meta,
cluster Cluster,
collectionID UniqueID, parentTask task, meta Meta, cluster Cluster,
loadSegmentRequests []*querypb.LoadSegmentsRequest,
watchDmChannelRequests []*querypb.WatchDmChannelsRequest,
watchDeltaChannelRequests []*querypb.WatchDeltaChannelsRequest,
wait bool, excludeNodeIDs []int64) error {
wait bool, excludeNodeIDs []int64) ([]task, error) {
sp, _ := trace.StartSpanFromContext(ctx)
defer sp.Finish()
segmentsToLoad := make([]UniqueID, 0)
for _, req := range loadSegmentRequests {
segmentsToLoad = append(segmentsToLoad, req.Infos[0].SegmentID)
}
channelsToWatch := make([]string, 0)
for _, req := range watchDmChannelRequests {
channelsToWatch = append(channelsToWatch, req.Infos[0].ChannelName)
}
segment2Nodes, err := shuffleSegmentsToQueryNode(segmentsToLoad, cluster, wait, excludeNodeIDs)
internalTasks := make([]task, 0)
err := cluster.allocateSegmentsToQueryNode(ctx, loadSegmentRequests, wait, excludeNodeIDs)
if err != nil {
log.Error("assignInternalTask: segment to node failed", zap.Any("segments map", segment2Nodes), zap.Int64("collectionID", collectionID))
return err
log.Error("assignInternalTask: assign segment to node failed", zap.Any("load segments requests", loadSegmentRequests))
return nil, err
}
log.Debug("assignInternalTask: segment to node", zap.Any("segments map", segment2Nodes), zap.Int64("collectionID", collectionID))
watchRequest2Nodes, err := shuffleChannelsToQueryNode(channelsToWatch, cluster, wait, excludeNodeIDs)
log.Debug("assignInternalTask: assign segment to node success", zap.Any("load segments requests", loadSegmentRequests))
err = cluster.allocateChannelsToQueryNode(ctx, watchDmChannelRequests, wait, excludeNodeIDs)
if err != nil {
log.Error("assignInternalTask: watch request to node failed", zap.Any("request map", watchRequest2Nodes), zap.Int64("collectionID", collectionID))
return err
log.Error("assignInternalTask: assign dmChannel to node failed", zap.Any("watch dmChannel requests", watchDmChannelRequests))
return nil, err
}
log.Debug("assignInternalTask: watch request to node", zap.Any("request map", watchRequest2Nodes), zap.Int64("collectionID", collectionID))
log.Debug("assignInternalTask: assign dmChannel to node success", zap.Any("watch dmChannel requests", watchDmChannelRequests))
watchQueryChannelInfo := make(map[int64]bool)
node2Segments := make(map[int64][]*querypb.LoadSegmentsRequest)
sizeCounts := make(map[int64]int)
for index, nodeID := range segment2Nodes {
sizeOfReq := getSizeOfLoadSegmentReq(loadSegmentRequests[index])
for _, req := range loadSegmentRequests {
nodeID := req.DstNodeID
sizeOfReq := getSizeOfLoadSegmentReq(req)
if _, ok := node2Segments[nodeID]; !ok {
node2Segments[nodeID] = make([]*querypb.LoadSegmentsRequest, 0)
node2Segments[nodeID] = append(node2Segments[nodeID], loadSegmentRequests[index])
node2Segments[nodeID] = append(node2Segments[nodeID], req)
sizeCounts[nodeID] = sizeOfReq
} else {
if sizeCounts[nodeID]+sizeOfReq > MaxSendSizeToEtcd {
node2Segments[nodeID] = append(node2Segments[nodeID], loadSegmentRequests[index])
node2Segments[nodeID] = append(node2Segments[nodeID], req)
sizeCounts[nodeID] = sizeOfReq
} else {
lastReq := node2Segments[nodeID][len(node2Segments[nodeID])-1]
lastReq.Infos = append(lastReq.Infos, loadSegmentRequests[index].Infos...)
lastReq.Infos = append(lastReq.Infos, req.Infos...)
sizeCounts[nodeID] += sizeOfReq
}
}
@ -2265,18 +2051,10 @@ func assignInternalTask(ctx context.Context,
}
watchQueryChannelInfo[nodeID] = false
}
for _, nodeID := range watchRequest2Nodes {
if cluster.hasWatchedQueryChannel(parentTask.traceCtx(), nodeID, collectionID) {
watchQueryChannelInfo[nodeID] = true
continue
}
watchQueryChannelInfo[nodeID] = false
}
for nodeID, loadSegmentsReqs := range node2Segments {
for _, req := range loadSegmentsReqs {
ctx = opentracing.ContextWithSpan(context.Background(), sp)
req.DstNodeID = nodeID
baseTask := newBaseTask(ctx, parentTask.getTriggerCondition())
baseTask.setParentTask(parentTask)
loadSegmentTask := &loadSegmentTask{
@ -2284,10 +2062,9 @@ func assignInternalTask(ctx context.Context,
LoadSegmentsRequest: req,
meta: meta,
cluster: cluster,
excludeNodeIDs: []int64{},
excludeNodeIDs: excludeNodeIDs,
}
parentTask.addChildTask(loadSegmentTask)
log.Debug("assignInternalTask: add a loadSegmentTask childTask", zap.Any("task", loadSegmentTask))
internalTasks = append(internalTasks, loadSegmentTask)
}
for _, req := range watchDeltaChannelRequests {
@ -2303,27 +2080,29 @@ func assignInternalTask(ctx context.Context,
cluster: cluster,
excludeNodeIDs: []int64{},
}
parentTask.addChildTask(watchDeltaTask)
log.Debug("assignInternalTask: add a watchDeltaChannelTask childTask", zap.Any("task", watchDeltaTask))
internalTasks = append(internalTasks, watchDeltaTask)
}
}
for index, nodeID := range watchRequest2Nodes {
for _, req := range watchDmChannelRequests {
nodeID := req.NodeID
ctx = opentracing.ContextWithSpan(context.Background(), sp)
watchDmChannelReq := watchDmChannelRequests[index]
watchDmChannelReq.NodeID = nodeID
baseTask := newBaseTask(ctx, parentTask.getTriggerCondition())
baseTask.setParentTask(parentTask)
watchDmChannelTask := &watchDmChannelTask{
baseTask: baseTask,
WatchDmChannelsRequest: watchDmChannelReq,
WatchDmChannelsRequest: req,
meta: meta,
cluster: cluster,
excludeNodeIDs: []int64{},
excludeNodeIDs: excludeNodeIDs,
}
parentTask.addChildTask(watchDmChannelTask)
log.Debug("assignInternalTask: add a watchDmChannelTask childTask", zap.Any("task", watchDmChannelTask))
internalTasks = append(internalTasks, watchDmChannelTask)
if cluster.hasWatchedQueryChannel(parentTask.traceCtx(), nodeID, collectionID) {
watchQueryChannelInfo[nodeID] = true
continue
}
watchQueryChannelInfo[nodeID] = false
}
for nodeID, watched := range watchQueryChannelInfo {
@ -2331,7 +2110,7 @@ func assignInternalTask(ctx context.Context,
ctx = opentracing.ContextWithSpan(context.Background(), sp)
queryChannelInfo, err := meta.getQueryChannelInfoByID(collectionID)
if err != nil {
return err
return nil, err
}
msgBase := proto.Clone(parentTask.msgBase()).(*commonpb.MsgBase)
@ -2353,11 +2132,10 @@ func assignInternalTask(ctx context.Context,
AddQueryChannelRequest: addQueryChannelRequest,
cluster: cluster,
}
parentTask.addChildTask(watchQueryChannelTask)
log.Debug("assignInternalTask: add a watchQueryChannelTask childTask", zap.Any("task", watchQueryChannelTask))
internalTasks = append(internalTasks, watchQueryChannelTask)
}
}
return nil
return internalTasks, nil
}
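
For illustration, a standalone sketch of the per-node batching used above (merge payloads destined for the same node into one batch until adding the next would exceed a size cap, then start a new batch). The payload type, batchByNode and the cap value are made up; the real code keys on DstNodeID and uses MaxSendSizeToEtcd.

package main

import "fmt"

const maxBatchSize = 100 // stand-in for MaxSendSizeToEtcd

type payload struct {
	node int64
	size int
}

// batchByNode groups payloads by destination node; payloads for a node are
// merged into the current batch until adding one more would exceed the cap,
// at which point a new batch is started (same shape as the
// node2Segments/sizeCounts bookkeeping above).
func batchByNode(items []payload) map[int64][][]payload {
	batches := make(map[int64][][]payload)
	sizes := make(map[int64]int)
	for _, it := range items {
		cur := batches[it.node]
		if len(cur) == 0 || sizes[it.node]+it.size > maxBatchSize {
			batches[it.node] = append(cur, []payload{it})
			sizes[it.node] = it.size
			continue
		}
		last := len(cur) - 1
		cur[last] = append(cur[last], it)
		sizes[it.node] += it.size
	}
	return batches
}

func main() {
	items := []payload{{node: 1, size: 60}, {node: 1, size: 30}, {node: 1, size: 50}, {node: 2, size: 20}}
	fmt.Println(batchByNode(items))
}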
func getSizeOfLoadSegmentReq(req *querypb.LoadSegmentsRequest) int {

View File

@ -694,10 +694,10 @@ func Test_AssignInternalTask(t *testing.T) {
loadSegmentRequests = append(loadSegmentRequests, req)
}
err = assignInternalTask(queryCoord.loopCtx, defaultCollectionID, loadCollectionTask, queryCoord.meta, queryCoord.cluster, loadSegmentRequests, nil, nil, false, nil)
internalTasks, err := assignInternalTask(queryCoord.loopCtx, defaultCollectionID, loadCollectionTask, queryCoord.meta, queryCoord.cluster, loadSegmentRequests, nil, nil, false, nil)
assert.Nil(t, err)
assert.NotEqual(t, 1, len(loadCollectionTask.getChildTask()))
assert.NotEqual(t, 1, len(internalTasks))
queryCoord.Stop()
err = removeAllSession()