fix: watchQueryChannelTask info could not be deleted from etcd (#6568)

Signed-off-by: xige-16 <xi.ge@zilliz.com>
xige-16 committed on 2021-07-16 10:21:55 +08:00 (via GitHub)
parent d2cbcb92ec
commit 2a42244ff6
7 changed files with 254 additions and 32 deletions

View File

@@ -497,6 +497,9 @@ func (c *queryNodeCluster) isOnService(nodeID int64) (bool, error) {
 }

 func (c *queryNodeCluster) printMeta() {
+    c.Lock()
+    defer c.Unlock()
     for id, node := range c.nodes {
+        if node.isOnService() {
             for collectionID, info := range node.collectionInfos {

View File

@@ -381,6 +381,13 @@ type queryNodeServerMock struct {
     queryNode   *qn.QueryNode
     grpcErrChan chan error
     grpcServer  *grpc.Server
+
+    addQueryChannels  func() (*commonpb.Status, error)
+    watchDmChannels   func() (*commonpb.Status, error)
+    loadSegment       func() (*commonpb.Status, error)
+    releaseCollection func() (*commonpb.Status, error)
+    releasePartition  func() (*commonpb.Status, error)
+    releaseSegment    func() (*commonpb.Status, error)
 }

 func newQueryNodeServerMock(ctx context.Context) *queryNodeServerMock {
@@ -392,6 +399,13 @@ func newQueryNodeServerMock(ctx context.Context) *queryNodeServerMock {
         cancel:      cancel,
         queryNode:   qn.NewQueryNode(ctx1, factory),
         grpcErrChan: make(chan error),
+
+        addQueryChannels:  returnSuccessResult,
+        watchDmChannels:   returnSuccessResult,
+        loadSegment:       returnSuccessResult,
+        releaseCollection: returnSuccessResult,
+        releasePartition:  returnSuccessResult,
+        releaseSegment:    returnSuccessResult,
     }
 }
@@ -475,39 +489,27 @@ func (qs *queryNodeServerMock) run() error {
 }

 func (qs *queryNodeServerMock) AddQueryChannel(ctx context.Context, req *querypb.AddQueryChannelRequest) (*commonpb.Status, error) {
-    return &commonpb.Status{
-        ErrorCode: commonpb.ErrorCode_Success,
-    }, nil
+    return qs.addQueryChannels()
 }

 func (qs *queryNodeServerMock) WatchDmChannels(ctx context.Context, req *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) {
-    return &commonpb.Status{
-        ErrorCode: commonpb.ErrorCode_Success,
-    }, nil
+    return qs.watchDmChannels()
 }

 func (qs *queryNodeServerMock) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) {
-    return &commonpb.Status{
-        ErrorCode: commonpb.ErrorCode_Success,
-    }, nil
+    return qs.loadSegment()
 }

 func (qs *queryNodeServerMock) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
-    return &commonpb.Status{
-        ErrorCode: commonpb.ErrorCode_Success,
-    }, nil
+    return qs.releaseCollection()
 }

 func (qs *queryNodeServerMock) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
-    return &commonpb.Status{
-        ErrorCode: commonpb.ErrorCode_Success,
-    }, nil
+    return qs.releasePartition()
 }

 func (qs *queryNodeServerMock) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
-    return &commonpb.Status{
-        ErrorCode: commonpb.ErrorCode_Success,
-    }, nil
+    return qs.releaseSegment()
 }

 func startQueryNodeServer(ctx context.Context) (*queryNodeServerMock, error) {
@@ -519,3 +521,15 @@ func startQueryNodeServer(ctx context.Context) (*queryNodeServerMock, error) {
     return node, nil
 }
+
+func returnSuccessResult() (*commonpb.Status, error) {
+    return &commonpb.Status{
+        ErrorCode: commonpb.ErrorCode_Success,
+    }, nil
+}
+
+func returnFailedResult() (*commonpb.Status, error) {
+    return &commonpb.Status{
+        ErrorCode: commonpb.ErrorCode_UnexpectedError,
+    }, errors.New("query node do task failed")
+}
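With these hooks in place, every RPC on the mock defaults to returnSuccessResult, and a test can swap in returnFailedResult to simulate a misbehaving node. A minimal sketch of such an override (a hypothetical test, not part of this commit; it uses only names defined in this file):

    func TestMockNodeRPCFailure(t *testing.T) {
        ctx, cancel := context.WithCancel(context.Background())
        defer cancel()

        // startQueryNodeServer and returnFailedResult are defined above.
        node, err := startQueryNodeServer(ctx)
        assert.Nil(t, err)
        defer node.stop()

        // From here on, every AddQueryChannel RPC against the mock fails,
        // simulating a node that can never watch its query channel.
        node.addQueryChannels = returnFailedResult
    }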

View File

@@ -17,17 +17,18 @@ import (
     "os"
     "strconv"
     "testing"
     "time"

     "github.com/stretchr/testify/assert"

     "github.com/milvus-io/milvus/internal/msgstream"
 )

+var metaRootPath string
+
 func setup() {
     Params.Init()
+    metaRootPath = Params.MetaRootPath
     rand.Seed(time.Now().UnixNano())
     suffix := "-test-query-Coord" + strconv.FormatInt(rand.Int63(), 10)
     Params.MetaRootPath = Params.MetaRootPath + suffix
 }

 func refreshChannelNames() {
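setup() now snapshots the pristine MetaRootPath into the new metaRootPath global before appending a per-run random suffix, so each test run works under its own etcd namespace while startQueryCoord (later in this diff) can still read the saved base value. A self-contained illustration of the naming scheme (the base path is a made-up example, not the project's real default):

    package main

    import (
        "fmt"
        "math/rand"
        "strconv"
        "time"
    )

    func main() {
        rand.Seed(time.Now().UnixNano())
        metaRootPath := "by-dev/meta" // hypothetical pristine value of Params.MetaRootPath
        suffix := "-test-query-Coord" + strconv.FormatInt(rand.Int63(), 10)
        // e.g. "by-dev/meta-test-query-Coord5577006791947779410"
        fmt.Println(metaRootPath + suffix)
    }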

View File

@@ -40,6 +40,7 @@ type queryNode struct {
     collectionInfos      map[UniqueID]*querypb.CollectionInfo
     watchedQueryChannels map[UniqueID]*querypb.QueryChannelInfo
     onService            bool
+    serviceLock          sync.Mutex
 }

 func newQueryNode(ctx context.Context, address string, id UniqueID, kv *etcdkv.EtcdKV) *queryNode {
@@ -73,12 +74,16 @@ func (qn *queryNode) start() error {
     }
     qn.client = client

+    qn.serviceLock.Lock()
     qn.onService = true
+    qn.serviceLock.Unlock()
     log.Debug("queryNode client start success", zap.Int64("nodeID", qn.id), zap.String("address", qn.address))
     return nil
 }

 func (qn *queryNode) stop() {
+    qn.serviceLock.Lock()
+    defer qn.serviceLock.Unlock()
     qn.onService = false
     if qn.client != nil {
         qn.client.Stop()
@@ -344,15 +349,15 @@ func (qn *queryNode) clearNodeInfo() error {
 }

 func (qn *queryNode) setNodeState(onService bool) {
-    qn.Lock()
-    defer qn.Unlock()
+    qn.serviceLock.Lock()
+    defer qn.serviceLock.Unlock()
     qn.onService = onService
 }

 func (qn *queryNode) isOnService() bool {
-    qn.Lock()
-    defer qn.Unlock()
+    qn.serviceLock.Lock()
+    defer qn.serviceLock.Unlock()
     return qn.onService
 }
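The service flag now has a dedicated serviceLock rather than sharing the node's embedded mutex, so callers such as printMeta (first file above) can query isOnService without contending on the lock that guards the node's heavier bookkeeping. A generic illustration of the pattern (not from this commit), using only the standard library:

    package example

    import "sync"

    // serviceFlag is an illustration only: a boolean guarded by its own
    // small mutex, independent of the owning struct's main lock.
    type serviceFlag struct {
        mu sync.Mutex
        on bool
    }

    func (f *serviceFlag) set(on bool) {
        f.mu.Lock()
        defer f.mu.Unlock()
        f.on = on
    }

    func (f *serviceFlag) get() bool {
        f.mu.Lock()
        defer f.mu.Unlock()
        return f.on
    }

Since the flag is read-mostly, a sync.RWMutex would fit as well; the commit keeps a plain Mutex.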

View File

@@ -13,13 +13,12 @@ package querycoord
 import (
     "context"
-    "math/rand"
-    "strconv"
     "testing"
     "time"

     "github.com/stretchr/testify/assert"

+    "github.com/milvus-io/milvus/internal/log"
     "github.com/milvus-io/milvus/internal/msgstream"
     "github.com/milvus-io/milvus/internal/proto/commonpb"
     "github.com/milvus-io/milvus/internal/proto/querypb"
@@ -27,9 +26,7 @@ import (

 func startQueryCoord(ctx context.Context) (*QueryCoord, error) {
     factory := msgstream.NewPmsFactory()
-    rand.Seed(time.Now().UnixNano())
-    suffix := "-test-query-Coord" + strconv.FormatInt(rand.Int63(), 10)
-    Params.MetaRootPath = metaRootPath + suffix
     coord, err := NewQueryCoord(ctx, factory)
     if err != nil {
         return nil, err
@@ -105,7 +102,25 @@ func TestQueryNode_MultiNode_stop(t *testing.T) {
     })
     assert.Nil(t, err)
     time.Sleep(2 * time.Second)
+    nodes, err := queryCoord.cluster.onServiceNodes()
+    assert.Nil(t, err)
     queryNode5.stop()
+
+    for {
+        allOffline := true
+        for nodeID := range nodes {
+            _, err = queryCoord.cluster.getNodeByID(nodeID)
+            if err == nil {
+                allOffline = false
+                time.Sleep(time.Second)
+                break
+            }
+        }
+        if allOffline {
+            break
+        }
+        log.Debug("wait all queryNode offline")
+    }
     queryCoord.Stop()
 }
@@ -146,9 +161,26 @@ func TestQueryNode_MultiNode_reStart(t *testing.T) {
         CollectionID: defaultCollectionID,
     })
     assert.Nil(t, err)
+    nodes, err := queryCoord.cluster.onServiceNodes()
+    assert.Nil(t, err)
     queryNode3.stop()
     queryNode4.stop()
     queryNode5.stop()
     time.Sleep(2 * time.Second)
+    for {
+        allOffline := true
+        for nodeID := range nodes {
+            _, err = queryCoord.cluster.getNodeByID(nodeID)
+            if err == nil {
+                allOffline = false
+                time.Sleep(time.Second)
+                break
+            }
+        }
+        if allOffline {
+            break
+        }
+        log.Debug("wait all queryNode offline")
+    }
     queryCoord.Stop()
 }
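Both tests now wait for shutdown with the same polling loop; a shared helper along these lines (hypothetical, not in this commit; it assumes onServiceNodes returns a map keyed by node ID, as the loops above imply) would factor it out:

    // Hypothetical helper: block until every node in `nodes` has been
    // dropped from the cluster's bookkeeping.
    func waitNodesOffline(cluster *queryNodeCluster, nodes map[int64]*queryNode) {
        for {
            allOffline := true
            for nodeID := range nodes {
                if _, err := cluster.getNodeByID(nodeID); err == nil {
                    allOffline = false
                    time.Sleep(time.Second)
                    break
                }
            }
            if allOffline {
                return
            }
            log.Debug("wait all queryNode offline")
        }
    }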

View File

@@ -471,7 +471,7 @@ func (scheduler *TaskScheduler) processTask(t task) error {
     childTask.SetID(id)
     kvs := make(map[string]string)
     taskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, childTask.ID())
-    kvs[taskKey] = t.Marshal()
+    kvs[taskKey] = childTask.Marshal()
     stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, childTask.ID())
     kvs[stateKey] = strconv.Itoa(int(taskUndo))
     err = scheduler.client.MultiSave(kvs)
@@ -610,6 +610,16 @@ func (scheduler *TaskScheduler) waitActivateTaskDone(wg *sync.WaitGroup, t task) {
         scheduler.activateTaskChan <- t
         wg.Add(1)
         go scheduler.waitActivateTaskDone(wg, t)
+    } else {
+        removes := make([]string, 0)
+        taskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, t.ID())
+        removes = append(removes, taskKey)
+        stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID())
+        removes = append(removes, stateKey)
+        err = scheduler.client.MultiRemove(removes)
+        if err != nil {
+            log.Error("waitActivateTaskDone: error when remove task from etcd")
+        }
     }
 }
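Two changes above combine to fix #6568: processTask now persists childTask.Marshal() rather than the parent's t.Marshal() under the child's key, so the stored record actually describes the watchQueryChannelTask, and waitActivateTaskDone removes both of the task's etcd records once a failed activate task will not be retried. A condensed sketch of that cleanup path (activeTaskPrefix, taskInfoPrefix, MultiRemove, and etcdkv.EtcdKV all appear in this diff; the standalone helper itself is illustrative, not part of the commit):

    // Illustrative helper: drop both etcd records for a task (its body
    // under activeTaskPrefix and its state under taskInfoPrefix) in one
    // multi-key operation.
    func removeTaskFromEtcd(client *etcdkv.EtcdKV, taskID int64) error {
        removes := []string{
            fmt.Sprintf("%s/%d", activeTaskPrefix, taskID),
            fmt.Sprintf("%s/%d", taskInfoPrefix, taskID),
        }
        return client.MultiRemove(removes)
    }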

View File

@@ -0,0 +1,157 @@
package querycoord

import (
    "context"
    "testing"
    "time"

    "github.com/milvus-io/milvus/internal/log"
    "github.com/milvus-io/milvus/internal/proto/commonpb"
    "github.com/milvus-io/milvus/internal/proto/querypb"
    "github.com/stretchr/testify/assert"
)
type testTask struct {
    BaseTask
    baseMsg *commonpb.MsgBase
    cluster *queryNodeCluster
    meta    *meta
    nodeID  int64
}

func (tt *testTask) MsgBase() *commonpb.MsgBase {
    return tt.baseMsg
}

func (tt *testTask) Marshal() string {
    return ""
}

func (tt *testTask) Type() commonpb.MsgType {
    return tt.baseMsg.MsgType
}

func (tt *testTask) Timestamp() Timestamp {
    return tt.baseMsg.Timestamp
}
func (tt *testTask) PreExecute(ctx context.Context) error {
    log.Debug("test task preExecute...")
    return nil
}

func (tt *testTask) Execute(ctx context.Context) error {
    log.Debug("test task execute...")

    switch tt.baseMsg.MsgType {
    case commonpb.MsgType_LoadSegments:
        childTask := &LoadSegmentTask{
            BaseTask: BaseTask{
                ctx:              tt.ctx,
                Condition:        NewTaskCondition(tt.ctx),
                triggerCondition: tt.triggerCondition,
            },
            LoadSegmentsRequest: &querypb.LoadSegmentsRequest{
                Base: &commonpb.MsgBase{
                    MsgType: commonpb.MsgType_LoadSegments,
                },
                NodeID: tt.nodeID,
            },
            meta:    tt.meta,
            cluster: tt.cluster,
        }
        tt.AddChildTask(childTask)
    case commonpb.MsgType_WatchDmChannels:
        childTask := &WatchDmChannelTask{
            BaseTask: BaseTask{
                ctx:              tt.ctx,
                Condition:        NewTaskCondition(tt.ctx),
                triggerCondition: tt.triggerCondition,
            },
            WatchDmChannelsRequest: &querypb.WatchDmChannelsRequest{
                Base: &commonpb.MsgBase{
                    MsgType: commonpb.MsgType_WatchDmChannels,
                },
                NodeID: tt.nodeID,
            },
            cluster: tt.cluster,
            meta:    tt.meta,
        }
        tt.AddChildTask(childTask)
    case commonpb.MsgType_WatchQueryChannels:
        childTask := &WatchQueryChannelTask{
            BaseTask: BaseTask{
                ctx:              tt.ctx,
                Condition:        NewTaskCondition(tt.ctx),
                triggerCondition: tt.triggerCondition,
            },
            AddQueryChannelRequest: &querypb.AddQueryChannelRequest{
                Base: &commonpb.MsgBase{
                    MsgType: commonpb.MsgType_WatchQueryChannels,
                },
                NodeID: tt.nodeID,
            },
            cluster: tt.cluster,
        }
        tt.AddChildTask(childTask)
    }

    return nil
}

func (tt *testTask) PostExecute(ctx context.Context) error {
    log.Debug("test task postExecute...")
    return nil
}
func TestWatchQueryChannel_ClearEtcdInfoAfterAssignedNodeDown(t *testing.T) {
    baseCtx := context.Background()
    queryCoord, err := startQueryCoord(baseCtx)
    assert.Nil(t, err)
    activeTaskIDKeys, _, err := queryCoord.scheduler.client.LoadWithPrefix(activeTaskPrefix)
    assert.Nil(t, err)
    queryNode, err := startQueryNodeServer(baseCtx)
    assert.Nil(t, err)
    queryNode.addQueryChannels = returnFailedResult

    time.Sleep(time.Second)
    nodes, err := queryCoord.cluster.onServiceNodes()
    assert.Nil(t, err)
    assert.Equal(t, len(nodes), 1)

    var nodeID int64
    for id := range nodes {
        nodeID = id
        break
    }
    testTask := &testTask{
        BaseTask: BaseTask{
            ctx:              baseCtx,
            Condition:        NewTaskCondition(baseCtx),
            triggerCondition: querypb.TriggerCondition_grpcRequest,
        },
        baseMsg: &commonpb.MsgBase{
            MsgType: commonpb.MsgType_WatchQueryChannels,
        },
        cluster: queryCoord.cluster,
        meta:    queryCoord.meta,
        nodeID:  nodeID,
    }
    queryCoord.scheduler.Enqueue([]task{testTask})

    time.Sleep(time.Second)
    queryNode.stop()
    // Poll until the coordinator has dropped the stopped node; as shown
    // in the extraction the loop broke out while the node was still
    // registered, which would defeat the wait.
    for {
        _, err = queryCoord.cluster.getNodeByID(nodeID)
        if err == nil {
            time.Sleep(time.Second)
            continue
        }
        break
    }
    time.Sleep(time.Second)

    newActiveTaskIDKeys, _, err := queryCoord.scheduler.client.LoadWithPrefix(activeTaskPrefix)
    assert.Nil(t, err)
    assert.Equal(t, len(newActiveTaskIDKeys), len(activeTaskIDKeys))
    queryCoord.Stop()
}