milvus/internal/datacoord/stats_task_meta_test.go
cai.zhang 7bf40694fd
fix: Record the nodeID before assigning tasks (#36371)
issue: #33744

---------

Signed-off-by: Cai Zhang <cai.zhang@zilliz.com>
2024-09-28 17:21:15 +08:00

290 lines
7.8 KiB
Go

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package datacoord
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus/internal/metastore/mocks"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/workerpb"
)
type statsTaskMetaSuite struct {
suite.Suite
collectionID int64
partitionID int64
segmentID int64
}
func (s *statsTaskMetaSuite) SetupSuite() {}
func (s *statsTaskMetaSuite) TearDownSuite() {}
func (s *statsTaskMetaSuite) SetupTest() {
s.collectionID = 100
s.partitionID = 101
s.segmentID = 102
}
func (s *statsTaskMetaSuite) Test_Method() {
s.Run("newStatsTaskMeta", func() {
s.Run("normal case", func() {
catalog := mocks.NewDataCoordCatalog(s.T())
catalog.EXPECT().ListStatsTasks(mock.Anything).Return([]*indexpb.StatsTask{
{
CollectionID: s.collectionID,
PartitionID: s.partitionID,
SegmentID: 10000,
InsertChannel: "ch1",
TaskID: 10001,
Version: 1,
NodeID: 0,
State: indexpb.JobState_JobStateFinished,
FailReason: "",
TargetSegmentID: 10002,
SubJobType: indexpb.StatsSubJob_Sort,
CanRecycle: true,
},
}, nil)
m, err := newStatsTaskMeta(context.Background(), catalog)
s.NoError(err)
s.NotNil(m)
})
s.Run("failed case", func() {
catalog := mocks.NewDataCoordCatalog(s.T())
catalog.EXPECT().ListStatsTasks(mock.Anything).Return(nil, fmt.Errorf("mock error"))
m, err := newStatsTaskMeta(context.Background(), catalog)
s.Error(err)
s.Nil(m)
})
})
catalog := mocks.NewDataCoordCatalog(s.T())
catalog.EXPECT().ListStatsTasks(mock.Anything).Return(nil, nil)
m, err := newStatsTaskMeta(context.Background(), catalog)
s.NoError(err)
t := &indexpb.StatsTask{
CollectionID: s.collectionID,
PartitionID: s.partitionID,
SegmentID: s.segmentID,
InsertChannel: "ch1",
TaskID: 1,
Version: 0,
NodeID: 0,
State: indexpb.JobState_JobStateInit,
FailReason: "",
SubJobType: indexpb.StatsSubJob_Sort,
}
s.Run("AddStatsTask", func() {
s.Run("failed case", func() {
catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(fmt.Errorf("mock error")).Once()
s.Error(m.AddStatsTask(t))
_, ok := m.tasks[1]
s.False(ok)
})
s.Run("normal case", func() {
catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(nil).Once()
s.NoError(m.AddStatsTask(t))
_, ok := m.tasks[1]
s.True(ok)
})
s.Run("already exist", func() {
s.Error(m.AddStatsTask(t))
_, ok := m.tasks[1]
s.True(ok)
})
})
s.Run("UpdateVersion", func() {
s.Run("normal case", func() {
catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(nil).Once()
s.NoError(m.UpdateVersion(1, 1180))
task, ok := m.tasks[1]
s.True(ok)
s.Equal(int64(1), task.GetVersion())
})
s.Run("task not exist", func() {
_, ok := m.tasks[100]
s.False(ok)
s.Error(m.UpdateVersion(100, 1180))
})
s.Run("failed case", func() {
catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(fmt.Errorf("mock error")).Once()
s.Error(m.UpdateVersion(1, 1180))
task, ok := m.tasks[1]
s.True(ok)
// still 1
s.Equal(int64(1), task.GetVersion())
})
})
s.Run("UpdateBuildingTask", func() {
s.Run("failed case", func() {
catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(fmt.Errorf("mock error")).Once()
s.Error(m.UpdateBuildingTask(1))
task, ok := m.tasks[1]
s.True(ok)
s.Equal(indexpb.JobState_JobStateInit, task.GetState())
s.Equal(int64(1180), task.GetNodeID())
})
s.Run("normal case", func() {
catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(nil).Once()
s.NoError(m.UpdateBuildingTask(1))
task, ok := m.tasks[1]
s.True(ok)
s.Equal(indexpb.JobState_JobStateInProgress, task.GetState())
s.Equal(int64(1180), task.GetNodeID())
})
s.Run("task not exist", func() {
_, ok := m.tasks[100]
s.False(ok)
s.Error(m.UpdateBuildingTask(100))
})
})
s.Run("FinishTask", func() {
result := &workerpb.StatsResult{
TaskID: 1,
State: indexpb.JobState_JobStateFinished,
FailReason: "",
CollectionID: s.collectionID,
PartitionID: s.partitionID,
SegmentID: s.segmentID,
Channel: "ch1",
InsertLogs: []*datapb.FieldBinlog{
{FieldID: 0, Binlogs: []*datapb.Binlog{{LogID: 1}, {LogID: 5}}},
{FieldID: 1, Binlogs: []*datapb.Binlog{{LogID: 2}, {LogID: 6}}},
{FieldID: 100, Binlogs: []*datapb.Binlog{{LogID: 3}, {LogID: 7}}},
{FieldID: 101, Binlogs: []*datapb.Binlog{{LogID: 4}, {LogID: 8}}},
},
StatsLogs: []*datapb.FieldBinlog{
{FieldID: 100, Binlogs: []*datapb.Binlog{{LogID: 9}}},
},
TextStatsLogs: map[int64]*datapb.TextIndexStats{
100: {
FieldID: 100,
Version: 1,
Files: []string{"file1", "file2", "file3"},
LogSize: 100,
MemorySize: 100,
},
},
NumRows: 2048,
}
s.Run("failed case", func() {
catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(fmt.Errorf("mock error")).Once()
s.Error(m.FinishTask(1, result))
task, ok := m.tasks[1]
s.True(ok)
s.Equal(indexpb.JobState_JobStateInProgress, task.GetState())
})
s.Run("normal case", func() {
catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(nil).Once()
s.NoError(m.FinishTask(1, result))
task, ok := m.tasks[1]
s.True(ok)
s.Equal(indexpb.JobState_JobStateFinished, task.GetState())
})
s.Run("task not exist", func() {
s.Error(m.FinishTask(100, result))
})
})
s.Run("GetStatsTaskState", func() {
s.Run("task not exist", func() {
state := m.GetStatsTaskState(100)
s.Equal(indexpb.JobState_JobStateNone, state)
})
s.Run("normal case", func() {
state := m.GetStatsTaskState(1)
s.Equal(indexpb.JobState_JobStateFinished, state)
})
})
s.Run("GetStatsTaskStateBySegmentID", func() {
s.Run("task not exist", func() {
state := m.GetStatsTaskStateBySegmentID(100, indexpb.StatsSubJob_Sort)
s.Equal(indexpb.JobState_JobStateNone, state)
state = m.GetStatsTaskStateBySegmentID(s.segmentID, indexpb.StatsSubJob_BM25Job)
s.Equal(indexpb.JobState_JobStateNone, state)
})
s.Run("normal case", func() {
state := m.GetStatsTaskStateBySegmentID(s.segmentID, indexpb.StatsSubJob_Sort)
s.Equal(indexpb.JobState_JobStateFinished, state)
})
})
s.Run("DropStatsTask", func() {
s.Run("failed case", func() {
catalog.EXPECT().DropStatsTask(mock.Anything, mock.Anything).Return(fmt.Errorf("mock error")).Once()
s.Error(m.DropStatsTask(1))
_, ok := m.tasks[1]
s.True(ok)
})
s.Run("normal case", func() {
catalog.EXPECT().DropStatsTask(mock.Anything, mock.Anything).Return(nil).Once()
s.NoError(m.DropStatsTask(1))
_, ok := m.tasks[1]
s.False(ok)
s.NoError(m.DropStatsTask(1000))
})
})
}
func Test_statsTaskMeta(t *testing.T) {
suite.Run(t, new(statsTaskMetaSuite))
}