diff --git a/internal/datanode/channel_checkpoint_updater.go b/internal/datanode/channel_checkpoint_updater.go index 16a3dc71fa..93c4cd5cb4 100644 --- a/internal/datanode/channel_checkpoint_updater.go +++ b/internal/datanode/channel_checkpoint_updater.go @@ -25,6 +25,7 @@ import ( "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus/internal/datanode/broker" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -41,7 +42,7 @@ type channelCPUpdateTask struct { } type channelCheckpointUpdater struct { - dn *DataNode + broker broker.Broker mu sync.RWMutex tasks map[string]*channelCPUpdateTask @@ -51,9 +52,9 @@ type channelCheckpointUpdater struct { closeOnce sync.Once } -func newChannelCheckpointUpdater(dn *DataNode) *channelCheckpointUpdater { +func newChannelCheckpointUpdater(broker broker.Broker) *channelCheckpointUpdater { return &channelCheckpointUpdater{ - dn: dn, + broker: broker, tasks: make(map[string]*channelCPUpdateTask), closeCh: make(chan struct{}), notifyChan: make(chan struct{}, 1), @@ -124,7 +125,7 @@ func (ccu *channelCheckpointUpdater) updateCheckpoints(tasks []*channelCPUpdateT channelCPs := lo.Map(tasks, func(t *channelCPUpdateTask, _ int) *msgpb.MsgPosition { return t.pos }) - err := ccu.dn.broker.UpdateChannelCheckpoint(ctx, channelCPs) + err := ccu.broker.UpdateChannelCheckpoint(ctx, channelCPs) if err != nil { log.Warn("update channel checkpoint failed", zap.Error(err)) return diff --git a/internal/datanode/channel_checkpoint_updater_test.go b/internal/datanode/channel_checkpoint_updater_test.go index 5eedecee7f..9fc52b3cb9 100644 --- a/internal/datanode/channel_checkpoint_updater_test.go +++ b/internal/datanode/channel_checkpoint_updater_test.go @@ -35,23 +35,23 @@ import ( type ChannelCPUpdaterSuite struct { suite.Suite + broker *broker.MockBroker updater *channelCheckpointUpdater } func (s *ChannelCPUpdaterSuite) SetupTest() { - s.updater = newChannelCheckpointUpdater(&DataNode{}) + s.broker = broker.NewMockBroker(s.T()) + s.updater = newChannelCheckpointUpdater(s.broker) } func (s *ChannelCPUpdaterSuite) TestUpdate() { paramtable.Get().Save(paramtable.Get().DataNodeCfg.ChannelCheckpointUpdateTickInSeconds.Key, "0.01") defer paramtable.Get().Save(paramtable.Get().DataNodeCfg.ChannelCheckpointUpdateTickInSeconds.Key, "10") - b := broker.NewMockBroker(s.T()) - b.EXPECT().UpdateChannelCheckpoint(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, positions []*msgpb.MsgPosition) error { + s.broker.EXPECT().UpdateChannelCheckpoint(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, positions []*msgpb.MsgPosition) error { time.Sleep(10 * time.Millisecond) return nil }) - s.updater.dn.broker = b go s.updater.start() defer s.updater.close() @@ -75,10 +75,10 @@ func (s *ChannelCPUpdaterSuite) TestUpdate() { } } }() + wg.Wait() s.Eventually(func() bool { return counter.Load() == int64(tasksNum) }, time.Second*10, time.Millisecond*100) - wg.Wait() } func TestChannelCPUpdater(t *testing.T) { diff --git a/internal/datanode/data_node.go b/internal/datanode/data_node.go index 820fd21e84..628ab9198e 100644 --- a/internal/datanode/data_node.go +++ b/internal/datanode/data_node.go @@ -289,7 +289,7 @@ func (node *DataNode) Init() error { node.importTaskMgr = importv2.NewTaskManager() node.importScheduler = importv2.NewScheduler(node.importTaskMgr, node.syncMgr, node.chunkManager) - node.channelCheckpointUpdater = newChannelCheckpointUpdater(node) + node.channelCheckpointUpdater = newChannelCheckpointUpdater(node.broker) node.flowgraphManager = newFlowgraphManager() if paramtable.Get().DataCoordCfg.EnableBalanceChannelWithRPC.GetAsBool() { diff --git a/internal/datanode/data_sync_service_test.go b/internal/datanode/data_sync_service_test.go index 55c9f58fac..c73e3257dd 100644 --- a/internal/datanode/data_sync_service_test.go +++ b/internal/datanode/data_sync_service_test.go @@ -397,7 +397,7 @@ func (s *DataSyncServiceSuite) SetupTest() { }, } s.node.ctx = context.Background() - s.node.channelCheckpointUpdater = newChannelCheckpointUpdater(s.node) + s.node.channelCheckpointUpdater = newChannelCheckpointUpdater(s.node.broker) paramtable.Get().Save(paramtable.Get().DataNodeCfg.ChannelCheckpointUpdateTickInSeconds.Key, "0.01") defer paramtable.Get().Save(paramtable.Get().DataNodeCfg.ChannelCheckpointUpdateTickInSeconds.Key, "10") go s.node.channelCheckpointUpdater.start()