diff --git a/internal/querycoord/cluster.go b/internal/querycoord/cluster.go index c30586bca9..361f60fe1a 100644 --- a/internal/querycoord/cluster.go +++ b/internal/querycoord/cluster.go @@ -497,6 +497,9 @@ func (c *queryNodeCluster) isOnService(nodeID int64) (bool, error) { } func (c *queryNodeCluster) printMeta() { + c.Lock() + defer c.Unlock() + for id, node := range c.nodes { if node.isOnService() { for collectionID, info := range node.collectionInfos { diff --git a/internal/querycoord/mock_test.go b/internal/querycoord/mock_test.go index 4b39974d46..10b333e611 100644 --- a/internal/querycoord/mock_test.go +++ b/internal/querycoord/mock_test.go @@ -381,6 +381,13 @@ type queryNodeServerMock struct { queryNode *qn.QueryNode grpcErrChan chan error grpcServer *grpc.Server + + addQueryChannels func() (*commonpb.Status, error) + watchDmChannels func() (*commonpb.Status, error) + loadSegment func() (*commonpb.Status, error) + releaseCollection func() (*commonpb.Status, error) + releasePartition func() (*commonpb.Status, error) + releaseSegment func() (*commonpb.Status, error) } func newQueryNodeServerMock(ctx context.Context) *queryNodeServerMock { @@ -392,6 +399,13 @@ func newQueryNodeServerMock(ctx context.Context) *queryNodeServerMock { cancel: cancel, queryNode: qn.NewQueryNode(ctx1, factory), grpcErrChan: make(chan error), + + addQueryChannels: returnSuccessResult, + watchDmChannels: returnSuccessResult, + loadSegment: returnSuccessResult, + releaseCollection: returnSuccessResult, + releasePartition: returnSuccessResult, + releaseSegment: returnSuccessResult, } } @@ -475,39 +489,27 @@ func (qs *queryNodeServerMock) run() error { } func (qs *queryNodeServerMock) AddQueryChannel(ctx context.Context, req *querypb.AddQueryChannelRequest) (*commonpb.Status, error) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil + return qs.addQueryChannels() } func (qs *queryNodeServerMock) WatchDmChannels(ctx context.Context, req *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil + return qs.watchDmChannels() } func (qs *queryNodeServerMock) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil + return qs.loadSegment() } func (qs *queryNodeServerMock) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil + return qs.releaseCollection() } func (qs *queryNodeServerMock) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil + return qs.releasePartition() } func (qs *queryNodeServerMock) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil + return qs.releaseSegment() } func startQueryNodeServer(ctx context.Context) (*queryNodeServerMock, error) { @@ -519,3 +521,15 @@ func startQueryNodeServer(ctx context.Context) (*queryNodeServerMock, error) { return node, nil } + +func returnSuccessResult() (*commonpb.Status, error) { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, nil +} + +func returnFailedResult() (*commonpb.Status, error) { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + }, errors.New("query node do task failed") +} diff --git a/internal/querycoord/query_coord_test.go b/internal/querycoord/query_coord_test.go index ce62c1fa3a..70fc7fa538 100644 --- a/internal/querycoord/query_coord_test.go +++ b/internal/querycoord/query_coord_test.go @@ -17,17 +17,18 @@ import ( "os" "strconv" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus/internal/msgstream" ) -var metaRootPath string - func setup() { Params.Init() - metaRootPath = Params.MetaRootPath + rand.Seed(time.Now().UnixNano()) + suffix := "-test-query-Coord" + strconv.FormatInt(rand.Int63(), 10) + Params.MetaRootPath = Params.MetaRootPath + suffix } func refreshChannelNames() { diff --git a/internal/querycoord/querynode.go b/internal/querycoord/querynode.go index 799807995f..9a5b203ee0 100644 --- a/internal/querycoord/querynode.go +++ b/internal/querycoord/querynode.go @@ -40,6 +40,7 @@ type queryNode struct { collectionInfos map[UniqueID]*querypb.CollectionInfo watchedQueryChannels map[UniqueID]*querypb.QueryChannelInfo onService bool + serviceLock sync.Mutex } func newQueryNode(ctx context.Context, address string, id UniqueID, kv *etcdkv.EtcdKV) *queryNode { @@ -73,12 +74,16 @@ func (qn *queryNode) start() error { } qn.client = client + qn.serviceLock.Lock() qn.onService = true + qn.serviceLock.Unlock() log.Debug("queryNode client start success", zap.Int64("nodeID", qn.id), zap.String("address", qn.address)) return nil } func (qn *queryNode) stop() { + qn.serviceLock.Lock() + defer qn.serviceLock.Unlock() qn.onService = false if qn.client != nil { qn.client.Stop() @@ -344,15 +349,15 @@ func (qn *queryNode) clearNodeInfo() error { } func (qn *queryNode) setNodeState(onService bool) { - qn.Lock() - defer qn.Unlock() + qn.serviceLock.Lock() + defer qn.serviceLock.Unlock() qn.onService = onService } func (qn *queryNode) isOnService() bool { - qn.Lock() - defer qn.Unlock() + qn.serviceLock.Lock() + defer qn.serviceLock.Unlock() return qn.onService } diff --git a/internal/querycoord/querynode_test.go b/internal/querycoord/querynode_test.go index 8b2a2e4576..1ca8c74cbc 100644 --- a/internal/querycoord/querynode_test.go +++ b/internal/querycoord/querynode_test.go @@ -13,13 +13,12 @@ package querycoord import ( "context" - "math/rand" - "strconv" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/msgstream" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/querypb" @@ -27,9 +26,7 @@ import ( func startQueryCoord(ctx context.Context) (*QueryCoord, error) { factory := msgstream.NewPmsFactory() - rand.Seed(time.Now().UnixNano()) - suffix := "-test-query-Coord" + strconv.FormatInt(rand.Int63(), 10) - Params.MetaRootPath = metaRootPath + suffix + coord, err := NewQueryCoord(ctx, factory) if err != nil { return nil, err @@ -105,7 +102,25 @@ func TestQueryNode_MultiNode_stop(t *testing.T) { }) assert.Nil(t, err) time.Sleep(2 * time.Second) + nodes, err := queryCoord.cluster.onServiceNodes() + assert.Nil(t, err) queryNode5.stop() + + for { + allOffline := true + for nodeID := range nodes { + _, err = queryCoord.cluster.getNodeByID(nodeID) + if err == nil { + allOffline = false + time.Sleep(time.Second) + break + } + } + if allOffline { + break + } + log.Debug("wait all queryNode offline") + } queryCoord.Stop() } @@ -146,9 +161,26 @@ func TestQueryNode_MultiNode_reStart(t *testing.T) { CollectionID: defaultCollectionID, }) assert.Nil(t, err) + nodes, err := queryCoord.cluster.onServiceNodes() + assert.Nil(t, err) queryNode3.stop() queryNode4.stop() queryNode5.stop() - time.Sleep(2 * time.Second) + + for { + allOffline := true + for nodeID := range nodes { + _, err = queryCoord.cluster.getNodeByID(nodeID) + if err == nil { + allOffline = false + time.Sleep(time.Second) + break + } + } + if allOffline { + break + } + log.Debug("wait all queryNode offline") + } queryCoord.Stop() } diff --git a/internal/querycoord/task_scheduler.go b/internal/querycoord/task_scheduler.go index 361c54d7e2..0a36272270 100644 --- a/internal/querycoord/task_scheduler.go +++ b/internal/querycoord/task_scheduler.go @@ -471,7 +471,7 @@ func (scheduler *TaskScheduler) processTask(t task) error { childTask.SetID(id) kvs := make(map[string]string) taskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, childTask.ID()) - kvs[taskKey] = t.Marshal() + kvs[taskKey] = childTask.Marshal() stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, childTask.ID()) kvs[stateKey] = strconv.Itoa(int(taskUndo)) err = scheduler.client.MultiSave(kvs) @@ -610,6 +610,16 @@ func (scheduler *TaskScheduler) waitActivateTaskDone(wg *sync.WaitGroup, t task) scheduler.activateTaskChan <- t wg.Add(1) go scheduler.waitActivateTaskDone(wg, t) + } else { + removes := make([]string, 0) + taskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, t.ID()) + removes = append(removes, taskKey) + stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, t.ID()) + removes = append(removes, stateKey) + err = scheduler.client.MultiRemove(removes) + if err != nil { + log.Error("waitActivateTaskDone: error when remove task from etcd") + } } } diff --git a/internal/querycoord/task_scheduler_test.go b/internal/querycoord/task_scheduler_test.go new file mode 100644 index 0000000000..d8972fef85 --- /dev/null +++ b/internal/querycoord/task_scheduler_test.go @@ -0,0 +1,157 @@ +package querycoord + +import ( + "context" + "testing" + "time" + + "github.com/milvus-io/milvus/internal/log" + "github.com/milvus-io/milvus/internal/proto/commonpb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/stretchr/testify/assert" +) + +type testTask struct { + BaseTask + baseMsg *commonpb.MsgBase + cluster *queryNodeCluster + meta *meta + nodeID int64 +} + +func (tt *testTask) MsgBase() *commonpb.MsgBase { + return tt.baseMsg +} + +func (tt *testTask) Marshal() string { + return "" +} + +func (tt *testTask) Type() commonpb.MsgType { + return tt.baseMsg.MsgType +} + +func (tt *testTask) Timestamp() Timestamp { + return tt.baseMsg.Timestamp +} + +func (tt *testTask) PreExecute(ctx context.Context) error { + log.Debug("test task preExecute...") + return nil +} + +func (tt *testTask) Execute(ctx context.Context) error { + log.Debug("test task execute...") + + switch tt.baseMsg.MsgType { + case commonpb.MsgType_LoadSegments: + childTask := &LoadSegmentTask{ + BaseTask: BaseTask{ + ctx: tt.ctx, + Condition: NewTaskCondition(tt.ctx), + triggerCondition: tt.triggerCondition, + }, + LoadSegmentsRequest: &querypb.LoadSegmentsRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_LoadSegments, + }, + NodeID: tt.nodeID, + }, + meta: tt.meta, + cluster: tt.cluster, + } + tt.AddChildTask(childTask) + case commonpb.MsgType_WatchDmChannels: + childTask := &WatchDmChannelTask{ + BaseTask: BaseTask{ + ctx: tt.ctx, + Condition: NewTaskCondition(tt.ctx), + triggerCondition: tt.triggerCondition, + }, + WatchDmChannelsRequest: &querypb.WatchDmChannelsRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_WatchDmChannels, + }, + NodeID: tt.nodeID, + }, + cluster: tt.cluster, + meta: tt.meta, + } + tt.AddChildTask(childTask) + case commonpb.MsgType_WatchQueryChannels: + childTask := &WatchQueryChannelTask{ + BaseTask: BaseTask{ + ctx: tt.ctx, + Condition: NewTaskCondition(tt.ctx), + triggerCondition: tt.triggerCondition, + }, + AddQueryChannelRequest: &querypb.AddQueryChannelRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_WatchQueryChannels, + }, + NodeID: tt.nodeID, + }, + cluster: tt.cluster, + } + tt.AddChildTask(childTask) + } + + return nil +} + +func (tt *testTask) PostExecute(ctx context.Context) error { + log.Debug("test task postExecute...") + return nil +} + +func TestWatchQueryChannel_ClearEtcdInfoAfterAssignedNodeDown(t *testing.T) { + baseCtx := context.Background() + queryCoord, err := startQueryCoord(baseCtx) + assert.Nil(t, err) + activeTaskIDKeys, _, err := queryCoord.scheduler.client.LoadWithPrefix(activeTaskPrefix) + assert.Nil(t, err) + queryNode, err := startQueryNodeServer(baseCtx) + assert.Nil(t, err) + queryNode.addQueryChannels = returnFailedResult + + time.Sleep(time.Second) + nodes, err := queryCoord.cluster.onServiceNodes() + assert.Nil(t, err) + assert.Equal(t, len(nodes), 1) + var nodeID int64 + for id := range nodes { + nodeID = id + break + } + testTask := &testTask{ + BaseTask: BaseTask{ + ctx: baseCtx, + Condition: NewTaskCondition(baseCtx), + triggerCondition: querypb.TriggerCondition_grpcRequest, + }, + baseMsg: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_WatchQueryChannels, + }, + cluster: queryCoord.cluster, + meta: queryCoord.meta, + nodeID: nodeID, + } + queryCoord.scheduler.Enqueue([]task{testTask}) + + time.Sleep(time.Second) + queryNode.stop() + + for { + _, err = queryCoord.cluster.getNodeByID(nodeID) + if err == nil { + time.Sleep(time.Second) + break + } + } + + time.Sleep(time.Second) + newActiveTaskIDKeys, _, err := queryCoord.scheduler.client.LoadWithPrefix(activeTaskPrefix) + assert.Nil(t, err) + assert.Equal(t, len(newActiveTaskIDKeys), len(activeTaskIDKeys)) + queryCoord.Stop() +}