milvus/internal/querynode/shard_cluster.go
congqixia 3a6db2faeb
Fix handling segment change logic (#16695)
Dispatch segmentChangeInfo to ShardCluster leader
Hold segment remove before search is done

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
2022-04-27 22:23:46 +08:00

740 lines
22 KiB
Go

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querynode
import (
"context"
"fmt"
"sync"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/errorutil"
"go.uber.org/atomic"
"go.uber.org/zap"
)
// shardClusterState tells whether a ShardCluster can currently serve requests.
type shardClusterState int32

const (
	available shardClusterState = iota + 1 // 1
	unavailable                            // 2
)

// nodeEventType is the kind of change carried by a nodeEvent.
type nodeEventType int32

const (
	nodeAdd nodeEventType = iota + 1 // 1
	nodeDel                          // 2
)

// segmentEventType is the kind of change carried by a segmentEvent.
type segmentEventType int32

const (
	segmentAdd segmentEventType = iota + 1 // 1
	segmentDel                             // 2
)

// segmentState is the lifecycle state of a segment as tracked by ShardCluster.
type segmentState int32

const (
	segmentStateNone    segmentState = iota // 0
	segmentStateOffline                     // 1
	segmentStateLoading                     // 2
	segmentStateLoaded                      // 3
)
// nodeEvent describes a change in the set of query nodes of the shard, as
// delivered by a ShardNodeDetector and consumed by ShardCluster.watchNodes.
type nodeEvent struct {
	eventType nodeEventType // nodeAdd or nodeDel
	nodeID    int64
	nodeAddr  string // passed to ShardNodeBuilder when the node is added
}

// segmentEvent describes a placement/state change of a single segment. It is
// delivered by a ShardSegmentDetector, and ShardCluster also builds it
// internally (SyncSegments, HandoffSegments).
type segmentEvent struct {
	eventType   segmentEventType // segmentAdd or segmentDel
	segmentID   int64
	partitionID int64
	nodeID      int64        // node the event refers to
	state       segmentState // reported state of the segment on nodeID
}

// shardQueryNode is the subset of the query node client API used by
// ShardCluster: Search/Query to dispatch per-node sub-requests,
// ReleaseSegments when finishing a handoff, and Stop when the node is
// replaced or removed.
type shardQueryNode interface {
	Search(context.Context, *querypb.SearchRequest) (*internalpb.SearchResults, error)
	Query(context.Context, *querypb.QueryRequest) (*internalpb.RetrieveResults, error)
	ReleaseSegments(ctx context.Context, in *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error)
	Stop() error
}

// shardNode is an online node of the shard together with the client used to
// reach it.
type shardNode struct {
	nodeID   int64
	nodeAddr string
	client   shardQueryNode
}

// shardSegmentInfo is ShardCluster's record of one segment.
type shardSegmentInfo struct {
	segmentID   int64
	partitionID int64
	nodeID      int64 // node requests for this segment are routed to
	state       segmentState
	// inUse counts requests currently referencing the segment: incremented by
	// segmentAllocations and decremented by finishUsage, both while holding
	// ShardCluster.mut.
	inUse int32
}
// ShardNodeDetector provides method to detect node events.
// ShardCluster.init treats the slice returned by watchNodes as the initial
// node list of the given collection replica and DML channel, and consumes the
// returned channel for later changes until the cluster is closed or the
// channel is closed.
type ShardNodeDetector interface {
	watchNodes(collectionID int64, replicaID int64, vchannelName string) ([]nodeEvent, <-chan nodeEvent)
}

// ShardSegmentDetector provides method to detect segment events.
// ShardCluster.init applies the slice returned by watchSegments as the initial
// segment distribution, then consumes the returned channel for later
// segmentAdd/segmentDel events.
type ShardSegmentDetector interface {
	watchSegments(collectionID int64, replicaID int64, vchannelName string) ([]segmentEvent, <-chan segmentEvent)
}

// ShardNodeBuilder builds the shardQueryNode client used to talk to the node
// with the given ID and address.
type ShardNodeBuilder func(nodeID int64, addr string) shardQueryNode
// ShardCluster tracks the query nodes and the segment distribution of one DML
// channel (vchannelName) of a collection replica, and performs shard-level
// operations (search, query, segment handoff) across those nodes.
type ShardCluster struct {
	state *atomic.Int32 // holds a shardClusterState; Search/Query require available

	collectionID int64
	replicaID    int64
	vchannelName string

	nodeDetector    ShardNodeDetector
	segmentDetector ShardSegmentDetector
	nodeBuilder     ShardNodeBuilder

	// mut guards nodes, segments (including each record's inUse counter) and handoffs.
	mut         sync.RWMutex
	nodes       map[int64]*shardNode                  // online nodes
	segments    map[int64]*shardSegmentInfo           // shard segments
	handoffs    map[int32]*querypb.SegmentChangeInfo  // current pending handoff
	lastToken   *atomic.Int32                         // last token used for segment change info
	segmentCond *sync.Cond                            // segment state change condition; broadcast after segment updates
	rcCond      *sync.Cond                            // segment rc change condition; broadcast by finishUsage

	closeOnce sync.Once
	closeCh   chan struct{} // closed by Close to stop the watcher goroutines
}
// NewShardCluster creates a ShardCluster for the given collection replica and
// DML channel. Before returning it synchronously loads the current node and
// segment distribution and starts the background watchers (see init). The
// cluster starts out unavailable; init's health check decides whether it
// becomes available.
func NewShardCluster(collectionID int64, replicaID int64, vchannelName string,
	nodeDetector ShardNodeDetector, segmentDetector ShardSegmentDetector, nodeBuilder ShardNodeBuilder) *ShardCluster {
	sc := &ShardCluster{
		collectionID:    collectionID,
		replicaID:       replicaID,
		vchannelName:    vchannelName,
		nodeDetector:    nodeDetector,
		segmentDetector: segmentDetector,
		nodeBuilder:     nodeBuilder,

		state:     atomic.NewInt32(int32(unavailable)),
		lastToken: atomic.NewInt32(0),

		nodes:    make(map[int64]*shardNode),
		segments: make(map[int64]*shardSegmentInfo),
		handoffs: make(map[int32]*querypb.SegmentChangeInfo),

		// each condition variable owns a dedicated mutex, independent of sc.mut
		segmentCond: sync.NewCond(&sync.Mutex{}),
		rcCond:      sync.NewCond(&sync.Mutex{}),

		closeCh: make(chan struct{}),
	}

	sc.init()
	return sc
}
// Close marks the cluster unavailable and closes closeCh, which makes the
// watchNodes and watchSegments goroutines return. Only the first call has an
// effect; later calls are no-ops.
func (sc *ShardCluster) Close() {
	shutdown := func() {
		sc.state.Store(int32(unavailable))
		close(sc.closeCh)
	}
	sc.closeOnce.Do(shutdown)
}
// addNode registers the node described by evt and builds its client via
// nodeBuilder. An event for an already known node with the same address is
// ignored. If the node re-appears under a different address, the new client
// replaces the old one and the old client is stopped before sc.mut is
// released.
func (sc *ShardCluster) addNode(evt nodeEvent) {
	log.Debug("ShardCluster add node", zap.Int64("nodeID", evt.nodeID))
	sc.mut.Lock()
	defer sc.mut.Unlock()

	if prev, exists := sc.nodes[evt.nodeID]; exists {
		if prev.nodeAddr == evt.nodeAddr {
			log.Warn("ShardCluster add same node, skip", zap.Int64("nodeID", evt.nodeID), zap.String("addr", evt.nodeAddr))
			return
		}
		// deferred after Unlock, so it runs first: the stale client is stopped
		// once its replacement is installed, still under sc.mut
		defer prev.client.Stop()
	}

	client := sc.nodeBuilder(evt.nodeID, evt.nodeAddr)
	sc.nodes[evt.nodeID] = &shardNode{nodeID: evt.nodeID, nodeAddr: evt.nodeAddr, client: client}
}
// removeNode drops an offline node from the cluster and stops its client.
// Every segment served by that node is marked offline, and the cluster becomes
// unavailable if there was at least one such segment. Unknown nodes are logged
// and ignored.
func (sc *ShardCluster) removeNode(evt nodeEvent) {
	sc.mut.Lock()
	defer sc.mut.Unlock()

	gone, found := sc.nodes[evt.nodeID]
	if !found {
		log.Warn("ShardCluster removeNode does not belong to it", zap.Int64("nodeID", evt.nodeID), zap.String("addr", evt.nodeAddr))
		return
	}
	// runs before the deferred Unlock above
	defer gone.client.Stop()

	delete(sc.nodes, evt.nodeID)

	for _, seg := range sc.segments {
		if seg.nodeID != evt.nodeID {
			continue
		}
		seg.state = segmentStateOffline
		sc.state.Store(int32(unavailable))
	}
}
// updateSegment applies a single segment change to the shard cluster. An
// unknown segment is recorded as-is; a known one goes through transferSegment.
// Waiters on segmentCond (handoffs waiting for segments to come online) are
// woken up after sc.mut has been released.
func (sc *ShardCluster) updateSegment(evt segmentEvent) {
	log.Debug("ShardCluster update segment", zap.Int64("nodeID", evt.nodeID), zap.Int64("segmentID", evt.segmentID), zap.Int32("state", int32(evt.state)))

	// registered first, so it runs after the sc.mut unlock below
	defer func() {
		sc.segmentCond.L.Lock()
		defer sc.segmentCond.L.Unlock()
		sc.segmentCond.Broadcast()
	}()

	sc.mut.Lock()
	defer sc.mut.Unlock()

	if existing, known := sc.segments[evt.segmentID]; known {
		sc.transferSegment(existing, evt)
		return
	}

	sc.segments[evt.segmentID] = &shardSegmentInfo{
		segmentID:   evt.segmentID,
		partitionID: evt.partitionID,
		nodeID:      evt.nodeID,
		state:       evt.state,
	}
}
// SyncSegments applies a batch of segment distribution lines, all with the
// given state. Unknown segments are recorded directly; known ones go through
// transferSegment as a segmentAdd event. segmentCond waiters are woken up
// after sc.mut has been released.
func (sc *ShardCluster) SyncSegments(distribution []*querypb.ReplicaSegmentsInfo, state segmentState) {
	// registered first, so it runs after the sc.mut unlock below
	defer func() {
		sc.segmentCond.L.Lock()
		defer sc.segmentCond.L.Unlock()
		sc.segmentCond.Broadcast()
	}()

	sc.mut.Lock()
	defer sc.mut.Unlock()

	for _, line := range distribution {
		nodeID, partitionID := line.GetNodeId(), line.GetPartitionId()
		for _, segmentID := range line.GetSegmentIds() {
			if existing, known := sc.segments[segmentID]; known {
				sc.transferSegment(existing, segmentEvent{
					eventType:   segmentAdd,
					segmentID:   segmentID,
					partitionID: partitionID,
					nodeID:      nodeID,
					state:       state,
				})
				continue
			}
			sc.segments[segmentID] = &shardSegmentInfo{
				segmentID:   segmentID,
				partitionID: partitionID,
				nodeID:      nodeID,
				state:       state,
			}
		}
	}
}
// transferSegment applies evt to the existing record old, following these
// transitions:
//   - offline -> any node/state; cluster health is rechecked if it becomes loaded;
//   - loading -> loaded only on the same node (a loaded report from a different
//     node is logged and ignored), otherwise any; rechecked if it becomes loaded;
//   - loaded  -> any; rechecked if it leaves the loaded state.
//
// Records in any other state are left untouched. Callers must hold sc.mut.
func (sc *ShardCluster) transferSegment(old *shardSegmentInfo, evt segmentEvent) {
	toLoaded := evt.state == segmentStateLoaded
	var recheck bool

	switch old.state {
	case segmentStateOffline: // safe to update nodeID and state
		recheck = toLoaded
	case segmentStateLoading: // to Loaded only when nodeID equal
		if toLoaded && evt.nodeID != old.nodeID {
			log.Warn("transferSegment to loaded failed, nodeID not match", zap.Int64("segmentID", evt.segmentID), zap.Int64("nodeID", old.nodeID), zap.Int64("evtNodeID", evt.nodeID))
			return
		}
		recheck = toLoaded
	case segmentStateLoaded:
		recheck = !toLoaded
	default:
		return
	}

	old.nodeID, old.state = evt.nodeID, evt.state
	if recheck {
		sc.healthCheck()
	}
}
// removeSegment deletes a segment record from the cluster. It should only be
// applied in the hand-off or load balance procedure. The record is kept, and a
// warning logged, when the segment is unknown or is served by a node other
// than evt.nodeID.
func (sc *ShardCluster) removeSegment(evt segmentEvent) {
	log.Debug("ShardCluster remove segment", zap.Int64("nodeID", evt.nodeID), zap.Int64("segmentID", evt.segmentID), zap.Int32("state", int32(evt.state)))

	sc.mut.Lock()
	defer sc.mut.Unlock()

	current, known := sc.segments[evt.segmentID]
	switch {
	case !known:
		log.Warn("ShardCluster removeSegment does not belong to it", zap.Int64("nodeID", evt.nodeID), zap.Int64("segmentID", evt.segmentID))
	case current.nodeID != evt.nodeID:
		log.Warn("ShardCluster removeSegment found node not match", zap.Int64("segmentID", evt.segmentID), zap.Int64("nodeID", current.nodeID), zap.Int64("evtNodeID", evt.nodeID))
	default:
		delete(sc.segments, evt.segmentID)
	}
}
// init lists all nodes and segment states, starts watching both for changes in
// background goroutines, and evaluates the initial cluster health.
func (sc *ShardCluster) init() {
	// list nodes
	nodes, nodeEvtCh := sc.nodeDetector.watchNodes(sc.collectionID, sc.replicaID, sc.vchannelName)
	for _, node := range nodes {
		sc.addNode(node)
	}
	go sc.watchNodes(nodeEvtCh)

	// list segments
	segments, segmentEvtCh := sc.segmentDetector.watchSegments(sc.collectionID, sc.replicaID, sc.vchannelName)
	for _, segment := range segments {
		sc.updateSegment(segment)
	}
	go sc.watchSegments(segmentEvtCh)

	// healthCheck iterates sc.segments, which the watcher goroutines started
	// above may already be mutating (e.g. removeNode), so it must run under
	// sc.mut like its other callers (transferSegment is always invoked with
	// sc.mut held).
	sc.mut.Lock()
	defer sc.mut.Unlock()
	sc.healthCheck()
}
// healthCheck marks the cluster available when every tracked segment is in
// the loaded state, and unavailable otherwise. A cluster without segments
// counts as available. It reads sc.segments without locking, so callers must
// hold sc.mut.
func (sc *ShardCluster) healthCheck() {
	next := available
	for _, segment := range sc.segments {
		if segment.state != segmentStateLoaded { // TODO check hand-off or load balance
			next = unavailable
			break
		}
	}
	sc.state.Store(int32(next))
}
// watchNodes applies node events from evtCh (nodeAdd -> addNode,
// nodeDel -> removeNode) until the channel is closed or the cluster is closed.
// Events of any other type are ignored.
func (sc *ShardCluster) watchNodes(evtCh <-chan nodeEvent) {
	for {
		var (
			evt  nodeEvent
			open bool
		)
		select {
		case <-sc.closeCh:
			log.Info("ShardCluster watchNode quit", zap.Int64("collectionID", sc.collectionID), zap.Int64("replicaID", sc.replicaID), zap.String("vchannelName", sc.vchannelName))
			return
		case evt, open = <-evtCh:
		}

		if !open {
			log.Warn("ShardCluster node channel closed", zap.Int64("collectionID", sc.collectionID), zap.Int64("replicaID", sc.replicaID))
			return
		}

		switch evt.eventType {
		case nodeAdd:
			sc.addNode(evt)
		case nodeDel:
			sc.removeNode(evt)
		}
	}
}
// watchSegments applies segment events from evtCh (segmentAdd ->
// updateSegment, segmentDel -> removeSegment) until the channel is closed or
// the cluster is closed. Events of any other type are ignored.
func (sc *ShardCluster) watchSegments(evtCh <-chan segmentEvent) {
	for {
		var (
			evt  segmentEvent
			open bool
		)
		select {
		case <-sc.closeCh:
			log.Info("ShardCluster watchSegments quit", zap.Int64("collectionID", sc.collectionID), zap.Int64("replicaID", sc.replicaID), zap.String("vchannelName", sc.vchannelName))
			return
		case evt, open = <-evtCh:
		}

		if !open {
			log.Warn("ShardCluster segment channel closed", zap.Int64("collectionID", sc.collectionID), zap.Int64("replicaID", sc.replicaID))
			return
		}

		switch evt.eventType {
		case segmentAdd:
			sc.updateSegment(evt)
		case segmentDel:
			sc.removeSegment(evt)
		}
	}
}
// getNode returns a shallow copy of the shardNode registered under nodeID, or
// (nil, false) if the node is not online. The copy shares the client with the
// registered node.
func (sc *ShardCluster) getNode(nodeID int64) (*shardNode, bool) {
	sc.mut.RLock()
	defer sc.mut.RUnlock()

	if node, online := sc.nodes[nodeID]; online {
		snapshot := *node // shallow copy: client is shared
		return &snapshot, true
	}
	return nil, false
}
// getSegment returns a detached copy of the segment record for segmentID, or
// (nil, false) if the cluster does not track it. The copy carries ID,
// partition, node and state only; its inUse counter is left at zero.
func (sc *ShardCluster) getSegment(segmentID int64) (*shardSegmentInfo, bool) {
	sc.mut.RLock()
	defer sc.mut.RUnlock()

	if segment, known := sc.segments[segmentID]; known {
		return &shardSegmentInfo{
			segmentID:   segment.segmentID,
			partitionID: segment.partitionID,
			nodeID:      segment.nodeID,
			state:       segment.state,
		}, true
	}
	return nil, false
}
// segmentAllocations groups the tracked segments by serving node
// (nodeID => segmentIDs). Only partitionIDs are included when the list is
// non-empty. Segments pending offline in a handoff are skipped. Every returned
// segment has its inUse reference count increased; the caller must hand the
// result to finishUsage once done.
func (sc *ShardCluster) segmentAllocations(partitionIDs []int64) map[int64][]int64 {
	allocs := make(map[int64][]int64) // nodeID => segmentIDs
	filterPartitions := len(partitionIDs) > 0

	sc.mut.Lock()
	defer sc.mut.Unlock()

	for _, seg := range sc.segments {
		if filterPartitions && !inList(partitionIDs, seg.partitionID) {
			continue
		}
		if sc.inHandoffOffline(seg.segmentID) {
			log.Debug("segment ignore in pending offline list", zap.Int64("collectionID", sc.collectionID), zap.Int64("replicaID", sc.replicaID), zap.Int64("segmentID", seg.segmentID))
			continue
		}
		seg.inUse++ // reference count ++
		allocs[seg.nodeID] = append(allocs[seg.nodeID], seg.segmentID)
	}
	return allocs
}
// inHandoffOffline reports whether segmentID appears in the offline list of
// any pending handoff.
// Note that sc.mut Lock is assumed to be hold outside of this function!
func (sc *ShardCluster) inHandoffOffline(segmentID int64) bool {
	for _, change := range sc.handoffs {
		offline := change.OfflineSegments
		for i := range offline {
			if offline[i].GetSegmentID() == segmentID {
				return true
			}
		}
	}
	return false
}
// finishUsage releases the references taken by segmentAllocations by
// decrementing inUse on every listed segment. Segments no longer tracked are
// logged as errors and skipped. After sc.mut is released, waiters on rcCond
// (handoffs waiting for segments to go idle) are woken up.
func (sc *ShardCluster) finishUsage(allocs map[int64][]int64) {
	// registered first, so it runs after the sc.mut unlock below
	defer func() {
		sc.rcCond.L.Lock()
		defer sc.rcCond.L.Unlock()
		sc.rcCond.Broadcast()
	}()

	sc.mut.Lock()
	defer sc.mut.Unlock()

	for _, segmentIDs := range allocs {
		for _, segmentID := range segmentIDs {
			if segment, known := sc.segments[segmentID]; known && segment != nil {
				segment.inUse-- // decrease the reference count
				continue
			}
			// this shall not happen, since removing segment without decreasing rc to zero is illegal
			log.Error("finishUsage with non-existing segment", zap.Int64("collectionID", sc.collectionID), zap.Int64("replicaID", sc.replicaID), zap.String("vchannel", sc.vchannelName), zap.Int64("segmentID", segmentID))
		}
	}
}
// HandoffSegments processes the handoff/load balance segments update procedure.
//
// Steps:
//  1. wait until every online segment of this cluster is loaded;
//  2. register info as a pending handoff, which makes segmentAllocations skip
//     its offline segments for new requests;
//  3. wait until requests already referencing those offline segments finish;
//  4. drop the offline segment records and ask the serving nodes to release
//     them;
//  5. unregister the pending handoff. This step is deferred, so it still
//     happens if a later step panics.
//
// Only segments whose collection ID and DML channel match this cluster are
// considered, on both the online and the offline side. Release failures do not
// stop the procedure; they are collected and returned together as an
// errorutil.ErrorList. Returns nil when every release succeeded.
func (sc *ShardCluster) HandoffSegments(info *querypb.SegmentChangeInfo) error {
	// wait for all OnlineSegment is loaded
	onlineSegments := make([]int64, 0, len(info.OnlineSegments))
	for _, seg := range info.OnlineSegments {
		// filter out segments not maintained in this cluster
		if seg.GetCollectionID() != sc.collectionID || seg.GetDmChannel() != sc.vchannelName {
			continue
		}
		onlineSegments = append(onlineSegments, seg.GetSegmentID())
	}
	sc.waitSegmentsOnline(onlineSegments)

	// add segmentChangeInfo to pending list; removed again when this returns
	token := sc.appendHandoff(info)
	defer sc.finishHandoff(token)

	// collect offline segments with the same filter as the online ones, so the
	// wait below and the removal after it cover exactly the same segments
	offlineSegments := make([]int64, 0, len(info.OfflineSegments))
	offlineEvents := make([]segmentEvent, 0, len(info.OfflineSegments))
	for _, seg := range info.OfflineSegments {
		// filter out segments not maintained in this cluster
		if seg.GetCollectionID() != sc.collectionID || seg.GetDmChannel() != sc.vchannelName {
			continue
		}
		offlineSegments = append(offlineSegments, seg.GetSegmentID())
		offlineEvents = append(offlineEvents, segmentEvent{segmentID: seg.GetSegmentID(), nodeID: seg.GetNodeID()})
	}

	// wait for all OfflineSegments is not in use
	sc.waitSegmentsNotInUse(offlineSegments)

	// remove offline segments record, grouped by serving node for release
	removes := make(map[int64][]int64) // nodeID => []segmentIDs
	for _, evt := range offlineEvents {
		sc.removeSegment(evt)
		removes[evt.nodeID] = append(removes[evt.nodeID], evt.segmentID)
	}

	var errs errorutil.ErrorList
	// notify querynode(s) to release segments
	for nodeID, segmentIDs := range removes {
		node, ok := sc.getNode(nodeID)
		if !ok {
			log.Warn("node not in cluster", zap.Int64("nodeID", nodeID), zap.Int64("collectionID", sc.collectionID), zap.String("vchannel", sc.vchannelName))
			errs = append(errs, fmt.Errorf("node not in cluster nodeID %d", nodeID))
			continue
		}
		state, err := node.client.ReleaseSegments(context.Background(), &querypb.ReleaseSegmentsRequest{
			CollectionID: sc.collectionID,
			SegmentIDs:   segmentIDs,
		})
		if err != nil {
			errs = append(errs, err)
			continue
		}
		if state.GetErrorCode() != commonpb.ErrorCode_Success {
			errs = append(errs, fmt.Errorf("Release segments failed with reason: %s", state.GetReason()))
		}
	}

	// return a literal nil on success: an empty ErrorList would be a non-nil error
	if len(errs) > 0 {
		return errs
	}
	return nil
}
// appendHandoff registers info as a pending handoff under a freshly allocated
// token and returns that token. The caller later passes it to finishHandoff.
func (sc *ShardCluster) appendHandoff(info *querypb.SegmentChangeInfo) int32 {
	sc.mut.Lock()
	defer sc.mut.Unlock()

	next := sc.lastToken.Add(1)
	sc.handoffs[next] = info
	return next
}
// finishHandoff unregisters the pending handoff identified by token. Unknown
// tokens are a no-op.
func (sc *ShardCluster) finishHandoff(token int32) {
	sc.mut.Lock()
	defer sc.mut.Unlock()

	delete(sc.handoffs, token)
}
// waitSegmentsOnline blocks until every provided segment is tracked and in the
// loaded state. It re-evaluates on each segmentCond broadcast, which segment
// updates issue.
func (sc *ShardCluster) waitSegmentsOnline(segments []int64) {
	cond := sc.segmentCond
	cond.L.Lock()
	defer cond.L.Unlock()

	for !sc.segmentsOnline(segments) {
		cond.Wait()
	}
}
// waitSegmentsNotInUse blocks until none of the provided segments is
// referenced by a request any more. It re-evaluates on each rcCond broadcast,
// which finishUsage issues.
func (sc *ShardCluster) waitSegmentsNotInUse(segments []int64) {
	cond := sc.rcCond
	cond.L.Lock()
	defer cond.L.Unlock()

	for sc.segmentsInUse(segments) {
		cond.Wait()
	}
}
// segmentsOnline reports whether every provided segment ID is tracked by the
// cluster and in the loaded state. An empty list is trivially online.
func (sc *ShardCluster) segmentsOnline(segmentIDs []int64) bool {
	sc.mut.RLock()
	defer sc.mut.RUnlock()

	allLoaded := true
	for _, id := range segmentIDs {
		if seg, known := sc.segments[id]; !known || seg.state != segmentStateLoaded {
			allLoaded = false
			break
		}
	}
	return allLoaded
}
// segmentsInUse reports whether ANY of the provided segments is still
// referenced by a request (inUse > 0). IDs the cluster does not track are
// skipped, since they might be in streaming.
func (sc *ShardCluster) segmentsInUse(segmentIDs []int64) bool {
	sc.mut.RLock()
	defer sc.mut.RUnlock()

	for _, id := range segmentIDs {
		if seg, known := sc.segments[id]; known && seg.inUse > 0 {
			return true
		}
	}
	return false
}
// Search performs a search operation on the shard cluster.
//
// The segments of the requested partitions are grouped by serving node, and
// one sub-request per node is dispatched concurrently. The segments stay
// reference-counted until every sub-request has returned, so a concurrent
// handoff cannot release them mid-search. The first failing node cancels the
// remaining sub-requests.
//
// It returns one partial result per node. It returns an error if:
//   - the cluster is not available;
//   - the request targets another DML channel;
//   - a node in the distribution is unknown;
//   - any sub-request fails.
func (sc *ShardCluster) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) {
	if sc.state.Load() != int32(available) {
		return nil, fmt.Errorf("ShardCluster for %s replicaID %d is no available", sc.vchannelName, sc.replicaID)
	}

	if sc.vchannelName != req.GetDmlChannel() {
		return nil, fmt.Errorf("ShardCluster for %s does not match to request channel :%s", sc.vchannelName, req.GetDmlChannel())
	}

	// get node allocation and maintains the inUse reference count
	segAllocs := sc.segmentAllocations(req.GetReq().GetPartitionIDs())
	defer sc.finishUsage(segAllocs)

	log.Debug("cluster segment distribution", zap.Int("len", len(segAllocs)))
	for nodeID, segmentIDs := range segAllocs {
		log.Debug("segments distribution", zap.Int64("nodeID", nodeID), zap.Int64s("segments", segmentIDs))
	}

	// Resolve every target node before dispatching anything. Returning from the
	// middle of the dispatch loop would leave already-started requests running
	// while the deferred finishUsage drops their segment references.
	nodes := make(map[int64]*shardNode, len(segAllocs))
	for nodeID := range segAllocs {
		node, ok := sc.getNode(nodeID)
		if !ok { // meta dismatch, report error
			return nil, fmt.Errorf("ShardCluster for %s replicaID %d is no available", sc.vchannelName, sc.replicaID)
		}
		nodes[nodeID] = node
	}

	// concurrent visiting nodes
	var wg sync.WaitGroup
	reqCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	var err error
	var resultMut sync.Mutex
	results := make([]*internalpb.SearchResults, 0, len(segAllocs)+1) // count(nodes) + 1(growing)

	for nodeID, segments := range segAllocs {
		node := nodes[nodeID]
		nodeReq := proto.Clone(req).(*querypb.SearchRequest)
		nodeReq.SegmentIDs = segments
		wg.Add(1)
		go func() {
			defer wg.Done()
			partialResult, nodeErr := node.client.Search(reqCtx, nodeReq)
			resultMut.Lock()
			defer resultMut.Unlock()
			if nodeErr != nil || partialResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
				// a failed node cancels the remaining sub-requests
				cancel()
				err = fmt.Errorf("Search %d failed, reason %s err %w", node.nodeID, partialResult.GetStatus().GetReason(), nodeErr)
				return
			}
			results = append(results, partialResult)
		}()
	}

	// every sub-request must finish before the deferred finishUsage runs
	wg.Wait()
	if err != nil {
		return nil, err
	}

	return results, nil
}
// Query performs a query operation on the shard cluster.
//
// It handles only the DML channel part; segment IDs are dispatched by the
// cluster itself. The segments of the requested partitions are grouped by
// serving node, and one sub-request per node is dispatched concurrently. The
// segments stay reference-counted until every sub-request has returned. The
// first failing node cancels the remaining sub-requests.
//
// It returns one partial result per node. It returns an error if:
//   - the cluster is not available;
//   - the request targets another DML channel;
//   - a node in the distribution is unknown;
//   - any sub-request fails.
func (sc *ShardCluster) Query(ctx context.Context, req *querypb.QueryRequest) ([]*internalpb.RetrieveResults, error) {
	if sc.state.Load() != int32(available) {
		return nil, fmt.Errorf("ShardCluster for %s replicaID %d is no available", sc.vchannelName, sc.replicaID)
	}

	// handles only the dml channel part, segment ids is dispatch by cluster itself
	if sc.vchannelName != req.GetDmlChannel() {
		return nil, fmt.Errorf("ShardCluster for %s does not match to request channel :%s", sc.vchannelName, req.GetDmlChannel())
	}

	// get node allocation and maintains the inUse reference count
	segAllocs := sc.segmentAllocations(req.GetReq().GetPartitionIDs())
	defer sc.finishUsage(segAllocs)

	// Resolve every target node before dispatching anything. Returning from the
	// middle of the dispatch loop would leave already-started requests running
	// while the deferred finishUsage drops their segment references.
	nodes := make(map[int64]*shardNode, len(segAllocs))
	for nodeID := range segAllocs {
		node, ok := sc.getNode(nodeID)
		if !ok { // meta dismatch, report error
			return nil, fmt.Errorf("ShardCluster for %s replicaID %d is no available", sc.vchannelName, sc.replicaID)
		}
		nodes[nodeID] = node
	}

	// concurrent visiting nodes
	var wg sync.WaitGroup
	reqCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	var err error
	var resultMut sync.Mutex
	results := make([]*internalpb.RetrieveResults, 0, len(segAllocs)+1) // count(nodes) + 1(growing)

	for nodeID, segments := range segAllocs {
		node := nodes[nodeID]
		nodeReq := proto.Clone(req).(*querypb.QueryRequest)
		nodeReq.SegmentIDs = segments
		wg.Add(1)
		go func() {
			defer wg.Done()
			partialResult, nodeErr := node.client.Query(reqCtx, nodeReq)
			resultMut.Lock()
			defer resultMut.Unlock()
			if nodeErr != nil || partialResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
				// a failed node cancels the remaining sub-requests
				cancel()
				err = fmt.Errorf("Query %d failed, reason %s err %w", node.nodeID, partialResult.GetStatus().GetReason(), nodeErr)
				return
			}
			results = append(results, partialResult)
		}()
	}

	// every sub-request must finish before the deferred finishUsage runs
	wg.Wait()
	if err != nil {
		return nil, err
	}
	return results, nil
}