diff --git a/internal/datacoord/channel_checker.go b/internal/datacoord/channel_checker.go index f1f03a27cb..6ef7dd4176 100644 --- a/internal/datacoord/channel_checker.go +++ b/internal/datacoord/channel_checker.go @@ -20,7 +20,6 @@ import ( "fmt" "path" "strconv" - "sync" "time" "github.com/golang/protobuf/proto" @@ -31,15 +30,17 @@ import ( "github.com/milvus-io/milvus/internal/kv" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) type channelStateTimer struct { watchkv kv.WatchKV - runningTimers sync.Map - runningTimerStops sync.Map // channel name to timer stop channels - etcdWatcher clientv3.WatchChan - timeoutWatcher chan *ackEvent + runningTimers *typeutil.ConcurrentMap[string, *time.Timer] + runningTimerStops *typeutil.ConcurrentMap[string, chan struct{}] // channel name to timer stop channels + + etcdWatcher clientv3.WatchChan + timeoutWatcher chan *ackEvent //Modifies afterwards must guarantee that runningTimerCount is updated synchronized with runningTimers //in order to keep consistency runningTimerCount atomic.Int32 @@ -47,8 +48,10 @@ type channelStateTimer struct { func newChannelStateTimer(kv kv.WatchKV) *channelStateTimer { return &channelStateTimer{ - watchkv: kv, - timeoutWatcher: make(chan *ackEvent, 20), + watchkv: kv, + timeoutWatcher: make(chan *ackEvent, 20), + runningTimers: typeutil.NewConcurrentMap[string, *time.Timer](), + runningTimerStops: typeutil.NewConcurrentMap[string, chan struct{}](), } } @@ -103,8 +106,8 @@ func (c *channelStateTimer) startOne(watchState datapb.ChannelWatchState, channe stop := make(chan struct{}) ticker := time.NewTimer(timeout) c.removeTimers([]string{channelName}) - c.runningTimerStops.Store(channelName, stop) - c.runningTimers.Store(channelName, ticker) + c.runningTimerStops.Insert(channelName, stop) + c.runningTimers.Insert(channelName, ticker) c.runningTimerCount.Inc() go func() { log.Info("timer started", @@ -145,9 +148,9 @@ func (c *channelStateTimer) notifyTimeoutWatcher(e *ackEvent) { func (c *channelStateTimer) removeTimers(channels []string) { for _, channel := range channels { - if stop, ok := c.runningTimerStops.LoadAndDelete(channel); ok { - close(stop.(chan struct{})) - c.runningTimers.Delete(channel) + if stop, ok := c.runningTimerStops.GetAndRemove(channel); ok { + close(stop) + c.runningTimers.GetAndRemove(channel) c.runningTimerCount.Dec() log.Info("remove timer for channel", zap.String("channel", channel), zap.Int32("timerCount", c.runningTimerCount.Load())) @@ -156,10 +159,10 @@ func (c *channelStateTimer) removeTimers(channels []string) { } func (c *channelStateTimer) stopIfExist(e *ackEvent) { - stop, ok := c.runningTimerStops.LoadAndDelete(e.channelName) + stop, ok := c.runningTimerStops.GetAndRemove(e.channelName) if ok && e.ackType != watchTimeoutAck && e.ackType != releaseTimeoutAck { - close(stop.(chan struct{})) - c.runningTimers.Delete(e.channelName) + close(stop) + c.runningTimers.GetAndRemove(e.channelName) c.runningTimerCount.Dec() log.Info("stop timer for channel", zap.String("channel", e.channelName), zap.Int32("timerCount", c.runningTimerCount.Load())) @@ -167,8 +170,7 @@ func (c *channelStateTimer) stopIfExist(e *ackEvent) { } func (c *channelStateTimer) resetIfExist(channel string, interval time.Duration) { - if value, ok := c.runningTimers.Load(channel); ok { - timer := value.(*time.Timer) + if timer, ok := c.runningTimers.Get(channel); ok { timer.Reset(interval) } } diff --git a/internal/datacoord/channel_checker_test.go b/internal/datacoord/channel_checker_test.go index 1fdc05bfbf..a15cd4bd5f 100644 --- a/internal/datacoord/channel_checker_test.go +++ b/internal/datacoord/channel_checker_test.go @@ -132,18 +132,18 @@ func TestChannelStateTimer(t *testing.T) { timer := newChannelStateTimer(kv) timer.startOne(datapb.ChannelWatchState_ToRelease, "channel-1", 1, 20*time.Second) - stop, ok := timer.runningTimerStops.Load("channel-1") + stop, ok := timer.runningTimerStops.Get("channel-1") require.True(t, ok) timer.startOne(datapb.ChannelWatchState_ToWatch, "channel-1", 1, 20*time.Second) - _, ok = <-stop.(chan struct{}) + _, ok = <-stop assert.False(t, ok) - stop2, ok := timer.runningTimerStops.Load("channel-1") + stop2, ok := timer.runningTimerStops.Get("channel-1") assert.True(t, ok) timer.removeTimers([]string{"channel-1"}) - _, ok = <-stop2.(chan struct{}) + _, ok = <-stop2 assert.False(t, ok) }) } diff --git a/internal/datacoord/channel_manager_test.go b/internal/datacoord/channel_manager_test.go index be5a536837..799308b3b2 100644 --- a/internal/datacoord/channel_manager_test.go +++ b/internal/datacoord/channel_manager_test.go @@ -127,7 +127,7 @@ func TestChannelManager_StateTransfer(t *testing.T) { waitAndCheckState(t, watchkv, datapb.ChannelWatchState_WatchSuccess, nodeID, cName, collectionID) assert.Eventually(t, func() bool { - _, loaded := chManager.stateTimer.runningTimerStops.Load(cName) + loaded := chManager.stateTimer.runningTimerStops.Contain(cName) return !loaded }, waitFor, tick) @@ -157,7 +157,7 @@ func TestChannelManager_StateTransfer(t *testing.T) { waitAndCheckState(t, watchkv, datapb.ChannelWatchState_ToRelease, nodeID, cName, collectionID) assert.Eventually(t, func() bool { - _, loaded := chManager.stateTimer.runningTimerStops.Load(cName) + loaded := chManager.stateTimer.runningTimerStops.Contain(cName) return loaded }, waitFor, tick) @@ -193,7 +193,7 @@ func TestChannelManager_StateTransfer(t *testing.T) { waitAndCheckState(t, watchkv, datapb.ChannelWatchState_ToRelease, nodeID, cName, collectionID) assert.Eventually(t, func() bool { - _, loaded := chManager.stateTimer.runningTimerStops.Load(cName) + loaded := chManager.stateTimer.runningTimerStops.Contain(cName) return loaded }, waitFor, tick) @@ -242,7 +242,7 @@ func TestChannelManager_StateTransfer(t *testing.T) { assert.Error(t, err) assert.Empty(t, w) - _, loaded := chManager.stateTimer.runningTimerStops.Load(cName) + loaded := chManager.stateTimer.runningTimerStops.Contain(cName) assert.True(t, loaded) chManager.stateTimer.removeTimers([]string{cName}) }) @@ -279,7 +279,7 @@ func TestChannelManager_StateTransfer(t *testing.T) { waitAndCheckState(t, watchkv, datapb.ChannelWatchState_ToWatch, nodeID, cName, collectionID) assert.Eventually(t, func() bool { - _, loaded := chManager.stateTimer.runningTimerStops.Load(cName) + loaded := chManager.stateTimer.runningTimerStops.Contain(cName) return loaded }, waitFor, tick) cancel() @@ -331,7 +331,7 @@ func TestChannelManager_StateTransfer(t *testing.T) { assert.Error(t, err) assert.Empty(t, w) - _, loaded := chManager.stateTimer.runningTimerStops.Load(cName) + loaded := chManager.stateTimer.runningTimerStops.Contain(cName) assert.True(t, loaded) chManager.stateTimer.removeTimers([]string{cName}) }) @@ -370,7 +370,7 @@ func TestChannelManager_StateTransfer(t *testing.T) { waitAndCheckState(t, watchkv, datapb.ChannelWatchState_ToWatch, nodeID, cName, collectionID) assert.Eventually(t, func() bool { - _, loaded := chManager.stateTimer.runningTimerStops.Load(cName) + loaded := chManager.stateTimer.runningTimerStops.Contain(cName) return loaded }, waitFor, tick) @@ -811,7 +811,7 @@ func TestChannelManager_Reload(t *testing.T) { require.NoError(t, err) chManager.checkOldNodes([]UniqueID{nodeID}) - _, ok := chManager.stateTimer.runningTimerStops.Load(channelName) + ok := chManager.stateTimer.runningTimerStops.Contain(channelName) assert.True(t, ok) chManager.stateTimer.removeTimers([]string{channelName}) }) @@ -827,7 +827,7 @@ func TestChannelManager_Reload(t *testing.T) { err = chManager.checkOldNodes([]UniqueID{nodeID}) assert.NoError(t, err) - _, ok := chManager.stateTimer.runningTimerStops.Load(channelName) + ok := chManager.stateTimer.runningTimerStops.Contain(channelName) assert.True(t, ok) chManager.stateTimer.removeTimers([]string{channelName}) }) diff --git a/internal/datacoord/session_manager.go b/internal/datacoord/session_manager.go index 39e12b086c..78d023f383 100644 --- a/internal/datacoord/session_manager.go +++ b/internal/datacoord/session_manager.go @@ -30,6 +30,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/retry" + "github.com/milvus-io/milvus/pkg/util/typeutil" "go.uber.org/zap" ) @@ -251,7 +252,7 @@ func (c *SessionManager) GetCompactionState() map[int64]*datapb.CompactionStateR wg := sync.WaitGroup{} ctx := context.Background() - plans := sync.Map{} + plans := typeutil.NewConcurrentMap[int64, *datapb.CompactionStateResult]() c.sessions.RLock() for nodeID, s := range c.sessions.data { wg.Add(1) @@ -280,7 +281,7 @@ func (c *SessionManager) GetCompactionState() map[int64]*datapb.CompactionStateR return } for _, rst := range resp.GetResults() { - plans.Store(rst.PlanID, rst) + plans.Insert(rst.PlanID, rst) } }(nodeID, s) } @@ -288,8 +289,8 @@ func (c *SessionManager) GetCompactionState() map[int64]*datapb.CompactionStateR wg.Wait() rst := make(map[int64]*datapb.CompactionStateResult) - plans.Range(func(key, value any) bool { - rst[key.(int64)] = value.(*datapb.CompactionStateResult) + plans.Range(func(planID int64, result *datapb.CompactionStateResult) bool { + rst[planID] = result return true }) diff --git a/internal/datanode/cache.go b/internal/datanode/cache.go index 6649c49802..fde4b7e0be 100644 --- a/internal/datanode/cache.go +++ b/internal/datanode/cache.go @@ -16,9 +16,7 @@ package datanode -import ( - "sync" -) +import "github.com/milvus-io/milvus/pkg/util/typeutil" // Cache stores flushing segments' ids to prevent flushing the same segment again and again. // @@ -28,37 +26,33 @@ import ( // After the flush procedure, whether the segment successfully flushed or not, // it'll be removed from the cache. So if flush failed, the secondary flush can be triggered. type Cache struct { - cacheMap sync.Map + *typeutil.ConcurrentSet[UniqueID] } // newCache returns a new Cache func newCache() *Cache { return &Cache{ - cacheMap: sync.Map{}, + ConcurrentSet: typeutil.NewConcurrentSet[UniqueID](), } } // checkIfCached returns whether unique id is in cache func (c *Cache) checkIfCached(key UniqueID) bool { - _, ok := c.cacheMap.Load(key) - return ok + return c.Contain(key) } // Cache caches a specific ID into the cache func (c *Cache) Cache(ID UniqueID) { - c.cacheMap.Store(ID, struct{}{}) + c.Insert(ID) } // checkOrCache returns true if `key` is present. // Otherwise, it returns false and stores `key` into cache. func (c *Cache) checkOrCache(key UniqueID) bool { - _, exist := c.cacheMap.LoadOrStore(key, struct{}{}) - return exist + return !c.Insert(key) } // Remove removes a set of IDs from the cache func (c *Cache) Remove(IDs ...UniqueID) { - for _, id := range IDs { - c.cacheMap.Delete(id) - } + c.ConcurrentSet.Remove(IDs...) } diff --git a/internal/datanode/compaction_executor.go b/internal/datanode/compaction_executor.go index 1d854c33d8..272edef586 100644 --- a/internal/datanode/compaction_executor.go +++ b/internal/datanode/compaction_executor.go @@ -18,12 +18,12 @@ package datanode import ( "context" - "sync" "go.uber.org/zap" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) const ( @@ -31,17 +31,20 @@ const ( ) type compactionExecutor struct { - executing sync.Map // planID to compactor - completedCompactor sync.Map // planID to compactor - completed sync.Map // planID to CompactionResult + executing *typeutil.ConcurrentMap[int64, compactor] // planID to compactor + completedCompactor *typeutil.ConcurrentMap[int64, compactor] // planID to compactor + completed *typeutil.ConcurrentMap[int64, *datapb.CompactionResult] // planID to CompactionResult taskCh chan compactor - dropped sync.Map // vchannel dropped + dropped *typeutil.ConcurrentSet[string] // vchannel dropped } func newCompactionExecutor() *compactionExecutor { return &compactionExecutor{ - executing: sync.Map{}, - taskCh: make(chan compactor, maxTaskNum), + executing: typeutil.NewConcurrentMap[int64, compactor](), + completedCompactor: typeutil.NewConcurrentMap[int64, compactor](), + completed: typeutil.NewConcurrentMap[int64, *datapb.CompactionResult](), + taskCh: make(chan compactor, maxTaskNum), + dropped: typeutil.NewConcurrentSet[string](), } } @@ -51,19 +54,19 @@ func (c *compactionExecutor) execute(task compactor) { } func (c *compactionExecutor) toExecutingState(task compactor) { - c.executing.Store(task.getPlanID(), task) + c.executing.Insert(task.getPlanID(), task) } func (c *compactionExecutor) toCompleteState(task compactor) { task.complete() - c.executing.Delete(task.getPlanID()) + c.executing.GetAndRemove(task.getPlanID()) } func (c *compactionExecutor) injectDone(planID UniqueID, success bool) { - c.completed.Delete(planID) - task, loaded := c.completedCompactor.LoadAndDelete(planID) + c.completed.GetAndRemove(planID) + task, loaded := c.completedCompactor.GetAndRemove(planID) if loaded { - task.(compactor).injectDone(success) + task.injectDone(success) } } @@ -97,42 +100,41 @@ func (c *compactionExecutor) executeTask(task compactor) { zap.Error(err), ) } else { - c.completed.Store(task.getPlanID(), result) - c.completedCompactor.Store(task.getPlanID(), task) + c.completed.Insert(task.getPlanID(), result) + c.completedCompactor.Insert(task.getPlanID(), task) } log.Info("end to execute compaction", zap.Int64("planID", task.getPlanID())) } func (c *compactionExecutor) stopTask(planID UniqueID) { - task, loaded := c.executing.LoadAndDelete(planID) + task, loaded := c.executing.GetAndRemove(planID) if loaded { - log.Warn("compaction executor stop task", zap.Int64("planID", planID), zap.String("vChannelName", task.(compactor).getChannelName())) - task.(compactor).stop() + log.Warn("compaction executor stop task", zap.Int64("planID", planID), zap.String("vChannelName", task.getChannelName())) + task.stop() } } func (c *compactionExecutor) channelValidateForCompaction(vChannelName string) bool { // if vchannel marked dropped, compaction should not proceed - _, loaded := c.dropped.Load(vChannelName) - return !loaded + return !c.dropped.Contain(vChannelName) } func (c *compactionExecutor) stopExecutingtaskByVChannelName(vChannelName string) { - c.dropped.Store(vChannelName, struct{}{}) - c.executing.Range(func(key interface{}, value interface{}) bool { - if value.(compactor).getChannelName() == vChannelName { - c.stopTask(key.(UniqueID)) + c.dropped.Insert(vChannelName) + c.executing.Range(func(planID int64, task compactor) bool { + if task.getChannelName() == vChannelName { + c.stopTask(planID) } return true }) // remove all completed plans for vChannelName - c.completed.Range(func(key interface{}, value interface{}) bool { - if value.(*datapb.CompactionResult).GetChannel() == vChannelName { - c.injectDone(key.(UniqueID), true) + c.completed.Range(func(planID int64, result *datapb.CompactionResult) bool { + if result.GetChannel() == vChannelName { + c.injectDone(planID, true) log.Info("remove compaction results for dropped channel", zap.String("channel", vChannelName), - zap.Int64("planID", key.(UniqueID))) + zap.Int64("planID", planID)) } return true }) diff --git a/internal/datanode/compaction_executor_test.go b/internal/datanode/compaction_executor_test.go index a19ad9a767..b3b92a7702 100644 --- a/internal/datanode/compaction_executor_test.go +++ b/internal/datanode/compaction_executor_test.go @@ -103,7 +103,7 @@ func TestCompactionExecutor(t *testing.T) { // wait for task enqueued found := false for !found { - _, found = ex.executing.Load(mc.getPlanID()) + found = ex.executing.Contain(mc.getPlanID()) } ex.stopExecutingtaskByVChannelName("mock") diff --git a/internal/datanode/flow_graph_insert_buffer_node.go b/internal/datanode/flow_graph_insert_buffer_node.go index 1e5579b782..b5190dbf9f 100644 --- a/internal/datanode/flow_graph_insert_buffer_node.go +++ b/internal/datanode/flow_graph_insert_buffer_node.go @@ -21,7 +21,6 @@ import ( "fmt" "math" "reflect" - "sync" "github.com/cockroachdb/errors" "github.com/golang/protobuf/proto" @@ -54,7 +53,6 @@ type insertBufferNode struct { channel Channel idAllocator allocator.Allocator - flushMap sync.Map flushChan <-chan flushMsg resendTTChan <-chan resendTTMsg flushingSegCache *Cache @@ -700,7 +698,6 @@ func newInsertBufferNode(ctx context.Context, collID UniqueID, delBufManager *De ctx: ctx, BaseNode: baseNode, - flushMap: sync.Map{}, flushChan: flushCh, resendTTChan: resendTTCh, flushingSegCache: flushingSegCache, @@ -767,7 +764,6 @@ func newInsertBufferNode(ctx context.Context, collID UniqueID, delBufManager *De BaseNode: baseNode, timeTickStream: wTtMsgStream, - flushMap: sync.Map{}, flushChan: flushCh, resendTTChan: resendTTCh, flushingSegCache: flushingSegCache, diff --git a/internal/datanode/flush_manager.go b/internal/datanode/flush_manager.go index 404cb2365f..f45a18347b 100644 --- a/internal/datanode/flush_manager.go +++ b/internal/datanode/flush_manager.go @@ -43,6 +43,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/retry" "github.com/milvus-io/milvus/pkg/util/timerecord" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) // flushManager defines a flush manager signature @@ -97,7 +98,7 @@ type orderFlushQueue struct { injectCh chan *taskInjection // MsgID => flushTask - working sync.Map + working *typeutil.ConcurrentMap[string, *flushTaskRunner] notifyFunc notifyMetaFunc tailMut sync.Mutex @@ -114,6 +115,7 @@ func newOrderFlushQueue(segID UniqueID, f notifyMetaFunc) *orderFlushQueue { segmentID: segID, notifyFunc: f, injectCh: make(chan *taskInjection, 100), + working: typeutil.NewConcurrentMap[string, *flushTaskRunner](), } return q } @@ -128,8 +130,7 @@ func (q *orderFlushQueue) init() { } func (q *orderFlushQueue) getFlushTaskRunner(pos *msgpb.MsgPosition) *flushTaskRunner { - actual, loaded := q.working.LoadOrStore(getSyncTaskID(pos), newFlushTaskRunner(q.segmentID, q.injectCh)) - t := actual.(*flushTaskRunner) + t, loaded := q.working.GetOrInsert(getSyncTaskID(pos), newFlushTaskRunner(q.segmentID, q.injectCh)) // not loaded means the task runner is new, do initializtion if !loaded { // take over injection if task queue is handling it @@ -152,7 +153,7 @@ func (q *orderFlushQueue) getFlushTaskRunner(pos *msgpb.MsgPosition) *flushTaskR // postTask handles clean up work after a task is done func (q *orderFlushQueue) postTask(pack *segmentFlushPack, postInjection postInjectionFunc) { // delete task from working map - q.working.Delete(getSyncTaskID(pack.pos)) + q.working.GetAndRemove(getSyncTaskID(pack.pos)) // after descreasing working count, check whether flush queue is empty q.injectMut.Lock() q.runningTasks-- @@ -271,7 +272,7 @@ type rendezvousFlushManager struct { Channel // segment id => flush queue - dispatcher sync.Map + dispatcher *typeutil.ConcurrentMap[int64, *orderFlushQueue] notifyFunc notifyMetaFunc dropping atomic.Bool @@ -281,9 +282,7 @@ type rendezvousFlushManager struct { // getFlushQueue gets or creates an orderFlushQueue for segment id if not found func (m *rendezvousFlushManager) getFlushQueue(segmentID UniqueID) *orderFlushQueue { newQueue := newOrderFlushQueue(segmentID, m.notifyFunc) - actual, _ := m.dispatcher.LoadOrStore(segmentID, newQueue) - // all operation on dispatcher is private, assertion ok guaranteed - queue := actual.(*orderFlushQueue) + queue, _ := m.dispatcher.GetOrInsert(segmentID, newQueue) queue.init() return queue } @@ -321,7 +320,7 @@ func (m *rendezvousFlushManager) handleDeleteTask(segmentID UniqueID, task flush if m.dropping.Load() { // preventing separate delete, check position exists in queue first q := m.getFlushQueue(segmentID) - _, ok := q.working.Load(getSyncTaskID(pos)) + _, ok := q.working.Get(getSyncTaskID(pos)) // if ok, means position insert data already in queue, just handle task in normal mode // if not ok, means the insert buf should be handle in drop mode if !ok { @@ -422,12 +421,8 @@ func (m *rendezvousFlushManager) serializePkStatsLog(segmentID int64, flushed bo // isFull return true if the task pool is full func (m *rendezvousFlushManager) isFull() bool { var num int - m.dispatcher.Range(func(_, q any) bool { - queue := q.(*orderFlushQueue) - queue.working.Range(func(_, _ any) bool { - num++ - return true - }) + m.dispatcher.Range(func(_ int64, queue *orderFlushQueue) bool { + num += queue.working.Len() return true }) return num >= Params.DataNodeCfg.MaxParallelSyncTaskNum.GetAsInt() @@ -605,8 +600,7 @@ func (m *rendezvousFlushManager) getSegmentMeta(segmentID UniqueID, pos *msgpb.M // waitForAllTaskQueue waits for all flush queues in dispatcher become empty func (m *rendezvousFlushManager) waitForAllFlushQueue() { var wg sync.WaitGroup - m.dispatcher.Range(func(k, v interface{}) bool { - queue := v.(*orderFlushQueue) + m.dispatcher.Range(func(segmentID int64, queue *orderFlushQueue) bool { wg.Add(1) go func() { <-queue.tailCh @@ -652,9 +646,8 @@ func getSyncTaskID(pos *msgpb.MsgPosition) string { // close cleans up all the left members func (m *rendezvousFlushManager) close() { - m.dispatcher.Range(func(k, v interface{}) bool { + m.dispatcher.Range(func(segmentID int64, queue *orderFlushQueue) bool { //assertion ok - queue := v.(*orderFlushQueue) queue.injectMut.Lock() for i := 0; i < len(queue.injectCh); i++ { go queue.handleInject(<-queue.injectCh) @@ -721,6 +714,7 @@ func NewRendezvousFlushManager(allocator allocator.Allocator, cm storage.ChunkMa dropHandler: dropHandler{ flushAndDrop: drop, }, + dispatcher: typeutil.NewConcurrentMap[int64, *orderFlushQueue](), } // start with normal mode fm.dropping.Store(false) diff --git a/internal/datanode/services.go b/internal/datanode/services.go index 82369952b4..b1a0e7d6ee 100644 --- a/internal/datanode/services.go +++ b/internal/datanode/services.go @@ -309,18 +309,18 @@ func (node *DataNode) GetCompactionState(ctx context.Context, req *datapb.Compac }, nil } results := make([]*datapb.CompactionStateResult, 0) - node.compactionExecutor.executing.Range(func(k, v any) bool { + node.compactionExecutor.executing.Range(func(planID int64, task compactor) bool { results = append(results, &datapb.CompactionStateResult{ State: commonpb.CompactionState_Executing, - PlanID: k.(UniqueID), + PlanID: planID, }) return true }) - node.compactionExecutor.completed.Range(func(k, v any) bool { + node.compactionExecutor.completed.Range(func(planID int64, result *datapb.CompactionResult) bool { results = append(results, &datapb.CompactionStateResult{ State: commonpb.CompactionState_Completed, - PlanID: k.(UniqueID), - Result: v.(*datapb.CompactionResult), + PlanID: planID, + Result: result, }) return true }) diff --git a/internal/datanode/services_test.go b/internal/datanode/services_test.go index 8fc61be8df..6834a6212e 100644 --- a/internal/datanode/services_test.go +++ b/internal/datanode/services_test.go @@ -146,9 +146,9 @@ func (s *DataNodeServicesSuite) TestGetComponentStates() { func (s *DataNodeServicesSuite) TestGetCompactionState() { s.Run("success", func() { - s.node.compactionExecutor.executing.Store(int64(3), 0) - s.node.compactionExecutor.executing.Store(int64(2), 0) - s.node.compactionExecutor.completed.Store(int64(1), &datapb.CompactionResult{ + s.node.compactionExecutor.executing.Insert(int64(3), newMockCompactor(true)) + s.node.compactionExecutor.executing.Insert(int64(2), newMockCompactor(true)) + s.node.compactionExecutor.completed.Insert(int64(1), &datapb.CompactionResult{ PlanID: 1, SegmentID: 10, }) @@ -169,16 +169,7 @@ func (s *DataNodeServicesSuite) TestGetCompactionState() { s.Assert().Equal(1, cnt) mu.Unlock() - mu.Lock() - cnt = 0 - mu.Unlock() - s.node.compactionExecutor.completed.Range(func(k, v any) bool { - mu.Lock() - cnt++ - mu.Unlock() - return true - }) - s.Assert().Equal(1, cnt) + s.Assert().Equal(1, s.node.compactionExecutor.completed.Len()) }) s.Run("unhealthy", func() { diff --git a/internal/indexnode/chunk_mgr_factory.go b/internal/indexnode/chunk_mgr_factory.go index 8d6bad3f0c..17ae59385a 100644 --- a/internal/indexnode/chunk_mgr_factory.go +++ b/internal/indexnode/chunk_mgr_factory.go @@ -3,25 +3,31 @@ package indexnode import ( "context" "fmt" - "sync" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) type StorageFactory interface { NewChunkManager(ctx context.Context, config *indexpb.StorageConfig) (storage.ChunkManager, error) } -type chunkMgr struct { - cached sync.Map +type chunkMgrFactory struct { + cached *typeutil.ConcurrentMap[string, storage.ChunkManager] } -func (m *chunkMgr) NewChunkManager(ctx context.Context, config *indexpb.StorageConfig) (storage.ChunkManager, error) { +func NewChunkMgrFactory() *chunkMgrFactory { + return &chunkMgrFactory{ + cached: typeutil.NewConcurrentMap[string, storage.ChunkManager](), + } +} + +func (m *chunkMgrFactory) NewChunkManager(ctx context.Context, config *indexpb.StorageConfig) (storage.ChunkManager, error) { key := m.cacheKey(config.StorageType, config.BucketName, config.Address) - if v, ok := m.cached.Load(key); ok { - return v.(storage.ChunkManager), nil + if v, ok := m.cached.Get(key); ok { + return v, nil } chunkManagerFactory := storage.NewChunkManagerFactoryWithParam(Params) @@ -29,11 +35,11 @@ func (m *chunkMgr) NewChunkManager(ctx context.Context, config *indexpb.StorageC if err != nil { return nil, err } - v, _ := m.cached.LoadOrStore(key, mgr) + v, _ := m.cached.GetOrInsert(key, mgr) log.Ctx(ctx).Info("index node successfully init chunk manager") - return v.(storage.ChunkManager), nil + return v, nil } -func (m *chunkMgr) cacheKey(storageType, bucket, address string) string { +func (m *chunkMgrFactory) cacheKey(storageType, bucket, address string) string { return fmt.Sprintf("%s/%s/%s", storageType, bucket, address) } diff --git a/internal/indexnode/indexnode.go b/internal/indexnode/indexnode.go index 911aa6c979..1ceed43b57 100644 --- a/internal/indexnode/indexnode.go +++ b/internal/indexnode/indexnode.go @@ -111,7 +111,7 @@ func NewIndexNode(ctx context.Context, factory dependency.Factory) *IndexNode { loopCtx: ctx1, loopCancel: cancel, factory: factory, - storageFactory: &chunkMgr{}, + storageFactory: NewChunkMgrFactory(), tasks: map[taskKey]*taskInfo{}, lifetime: lifetime.NewLifetime(commonpb.StateCode_Abnormal), } diff --git a/internal/mq/mqimpl/rocksmq/server/rocksmq_impl.go b/internal/mq/mqimpl/rocksmq/server/rocksmq_impl.go index 183d538a0a..cecce0eaae 100644 --- a/internal/mq/mqimpl/rocksmq/server/rocksmq_impl.go +++ b/internal/mq/mqimpl/rocksmq/server/rocksmq_impl.go @@ -417,7 +417,7 @@ func (rmq *rocksmq) CreateTopic(topicName string) error { rmq.retentionInfo.mutex.Lock() defer rmq.retentionInfo.mutex.Unlock() - rmq.retentionInfo.topicRetetionTime.Store(topicName, time.Now().Unix()) + rmq.retentionInfo.topicRetetionTime.Insert(topicName, time.Now().Unix()) log.Debug("Rocksmq create topic successfully ", zap.String("topic", topicName), zap.Int64("elapsed", time.Since(start).Milliseconds())) return nil } @@ -480,7 +480,7 @@ func (rmq *rocksmq) DestroyTopic(topicName string) error { // clean up retention info topicMu.Delete(topicName) - rmq.retentionInfo.topicRetetionTime.Delete(topicName) + rmq.retentionInfo.topicRetetionTime.GetAndRemove(topicName) log.Debug("Rocksmq destroy topic successfully ", zap.String("topic", topicName), zap.Int64("elapsed", time.Since(start).Milliseconds())) return nil diff --git a/internal/mq/mqimpl/rocksmq/server/rocksmq_retention.go b/internal/mq/mqimpl/rocksmq/server/rocksmq_retention.go index bbe6ac3fbd..80ebec395d 100644 --- a/internal/mq/mqimpl/rocksmq/server/rocksmq_retention.go +++ b/internal/mq/mqimpl/rocksmq/server/rocksmq_retention.go @@ -34,7 +34,7 @@ const ( type retentionInfo struct { // key is topic name, value is last retention time - topicRetetionTime sync.Map + topicRetetionTime *typeutil.ConcurrentMap[string, int64] mutex sync.RWMutex kv *rocksdbkv.RocksdbKV @@ -47,7 +47,7 @@ type retentionInfo struct { func initRetentionInfo(kv *rocksdbkv.RocksdbKV, db *gorocksdb.DB) (*retentionInfo, error) { ri := &retentionInfo{ - topicRetetionTime: sync.Map{}, + topicRetetionTime: typeutil.NewConcurrentMap[string, int64](), mutex: sync.RWMutex{}, kv: kv, db: db, @@ -61,7 +61,7 @@ func initRetentionInfo(kv *rocksdbkv.RocksdbKV, db *gorocksdb.DB) (*retentionInf } for _, key := range topicKeys { topic := key[len(TopicIDTitle):] - ri.topicRetetionTime.Store(topic, time.Now().Unix()) + ri.topicRetetionTime.Insert(topic, time.Now().Unix()) topicMu.Store(topic, new(sync.Mutex)) } return ri, nil @@ -99,19 +99,13 @@ func (ri *retentionInfo) retention() error { timeNow := t.Unix() checkTime := int64(params.RocksmqCfg.RetentionTimeInMinutes.GetAsFloat() * 60 / 10) ri.mutex.RLock() - ri.topicRetetionTime.Range(func(k, v interface{}) bool { - topic, _ := k.(string) - lastRetentionTs, ok := v.(int64) - if !ok { - log.Warn("Can't parse lastRetention to int64", zap.String("topic", topic), zap.Any("value", v)) - return true - } + ri.topicRetetionTime.Range(func(topic string, lastRetentionTs int64) bool { if lastRetentionTs+checkTime < timeNow { err := ri.expiredCleanUp(topic) if err != nil { - log.Warn("Retention expired clean failed", zap.Any("error", err)) + log.Warn("Retention expired clean failed", zap.Error(err)) } - ri.topicRetetionTime.Store(topic, timeNow) + ri.topicRetetionTime.Insert(topic, timeNow) } return true }) diff --git a/internal/proxy/msg_pack.go b/internal/proxy/msg_pack.go index 7be5521712..6f801c616f 100644 --- a/internal/proxy/msg_pack.go +++ b/internal/proxy/msg_pack.go @@ -18,7 +18,6 @@ package proxy import ( "context" - "sync" "go.uber.org/zap" "golang.org/x/sync/errgroup" @@ -247,7 +246,7 @@ func repackInsertDataWithPartitionKey(ctx context.Context, } errGroup, _ := errgroup.WithContext(ctx) - partition2Msgs := sync.Map{} + partition2Msgs := typeutil.NewConcurrentMap[string, []msgstream.TsMsg]() for partitionName, offsets := range partition2RowOffsets { partitionName := partitionName offsets := offsets @@ -257,7 +256,7 @@ func repackInsertDataWithPartitionKey(ctx context.Context, return err } - partition2Msgs.Store(partitionName, msgs) + partition2Msgs.Insert(partitionName, msgs) return nil }) } @@ -271,8 +270,7 @@ func repackInsertDataWithPartitionKey(ctx context.Context, return nil, err } - partition2Msgs.Range(func(k, v interface{}) bool { - msgs := v.([]msgstream.TsMsg) + partition2Msgs.Range(func(name string, msgs []msgstream.TsMsg) bool { msgPack.Msgs = append(msgPack.Msgs, msgs...) return true }) diff --git a/internal/querycoordv2/task/executor.go b/internal/querycoordv2/task/executor.go index 55b8856449..cdf71a3307 100644 --- a/internal/querycoordv2/task/executor.go +++ b/internal/querycoordv2/task/executor.go @@ -22,6 +22,7 @@ import ( "time" "github.com/milvus-io/milvus/pkg/util/tsoutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" "go.uber.org/atomic" "go.uber.org/zap" @@ -48,7 +49,7 @@ type Executor struct { // Merge load segment requests merger *Merger[segmentIndex, *querypb.LoadSegmentsRequest] - executingTasks sync.Map + executingTasks *typeutil.ConcurrentSet[int64] // taskID executingTaskNum atomic.Int32 } @@ -68,7 +69,7 @@ func NewExecutor(meta *meta.Meta, nodeMgr: nodeMgr, merger: NewMerger[segmentIndex, *querypb.LoadSegmentsRequest](), - executingTasks: sync.Map{}, + executingTasks: typeutil.NewConcurrentSet[int64](), } } @@ -86,12 +87,12 @@ func (ex *Executor) Stop() { // does nothing and returns false if the action is already committed, // returns true otherwise. func (ex *Executor) Execute(task Task, step int) bool { - _, exist := ex.executingTasks.LoadOrStore(task.ID(), struct{}{}) + exist := !ex.executingTasks.Insert(task.ID()) if exist { return false } if ex.executingTaskNum.Inc() > Params.QueryCoordCfg.TaskExecutionCap.GetAsInt32() { - ex.executingTasks.Delete(task.ID()) + ex.executingTasks.Remove(task.ID()) ex.executingTaskNum.Dec() return false } @@ -119,8 +120,7 @@ func (ex *Executor) Execute(task Task, step int) bool { } func (ex *Executor) Exist(taskID int64) bool { - _, ok := ex.executingTasks.Load(taskID) - return ok + return ex.executingTasks.Contain(taskID) } func (ex *Executor) scheduleRequests() { @@ -207,7 +207,7 @@ func (ex *Executor) removeTask(task Task, step int) { zap.Error(task.Err())) } - ex.executingTasks.Delete(task.ID()) + ex.executingTasks.Remove(task.ID()) ex.executingTaskNum.Dec() } diff --git a/internal/querycoordv2/task/task_test.go b/internal/querycoordv2/task/task_test.go index 66ef47d921..dfc64a2992 100644 --- a/internal/querycoordv2/task/task_test.go +++ b/internal/querycoordv2/task/task_test.go @@ -1191,8 +1191,8 @@ func (suite *TaskSuite) dispatchAndWait(node int64) { keys = make([]any, 0) for _, executor := range suite.scheduler.executors { - executor.executingTasks.Range(func(key, value any) bool { - keys = append(keys, key) + executor.executingTasks.Range(func(taskID int64) bool { + keys = append(keys, taskID) count++ return true }) diff --git a/internal/rootcoord/dml_channels.go b/internal/rootcoord/dml_channels.go index 9be083f4af..af9df06300 100644 --- a/internal/rootcoord/dml_channels.go +++ b/internal/rootcoord/dml_channels.go @@ -142,7 +142,7 @@ type dmlChannels struct { namePrefix string capacity int64 // pool maintains channelName => dmlMsgStream mapping, stable - pool sync.Map + pool *typeutil.ConcurrentMap[string, *dmlMsgStream] // mut protects channelsHeap only mut sync.Mutex // channelsHeap is the heap to pop next dms for use @@ -174,6 +174,7 @@ func newDmlChannels(ctx context.Context, factory msgstream.Factory, chanNamePref namePrefix: chanNamePrefix, capacity: chanNum, channelsHeap: make([]*dmlMsgStream, 0, chanNum), + pool: typeutil.NewConcurrentMap[string, *dmlMsgStream](), } for i, name := range names { @@ -206,7 +207,7 @@ func newDmlChannels(ctx context.Context, factory msgstream.Factory, chanNamePref idx: int64(i), pos: i, } - d.pool.Store(name, dms) + d.pool.Insert(name, dms) d.channelsHeap = append(d.channelsHeap, dms) } @@ -247,8 +248,7 @@ func (d *dmlChannels) listChannels() []string { var chanNames []string d.pool.Range( - func(k, v interface{}) bool { - dms := v.(*dmlMsgStream) + func(channel string, dms *dmlMsgStream) bool { if dms.RefCnt() > 0 { chanNames = append(chanNames, getChannelName(d.namePrefix, dms.idx)) } @@ -262,12 +262,12 @@ func (d *dmlChannels) getChannelNum() int { } func (d *dmlChannels) getMsgStreamByName(chanName string) (*dmlMsgStream, error) { - v, ok := d.pool.Load(chanName) + dms, ok := d.pool.Get(chanName) if !ok { log.Error("invalid channelName", zap.String("chanName", chanName)) return nil, errors.Newf("invalid channel name: %s", chanName) } - return v.(*dmlMsgStream), nil + return dms, nil } func (d *dmlChannels) broadcast(chanNames []string, pack *msgstream.MsgPack) error { diff --git a/internal/rootcoord/timeticksync.go b/internal/rootcoord/timeticksync.go index 177bd9b33a..698069cd9c 100644 --- a/internal/rootcoord/timeticksync.go +++ b/internal/rootcoord/timeticksync.go @@ -47,28 +47,31 @@ var ( ) type ttHistogram struct { - sync.Map + *typeutil.ConcurrentMap[string, Timestamp] } func newTtHistogram() *ttHistogram { - return &ttHistogram{} + return &ttHistogram{ + ConcurrentMap: typeutil.NewConcurrentMap[string, Timestamp](), + } } func (h *ttHistogram) update(channel string, ts Timestamp) { - h.Store(channel, ts) + h.Insert(channel, ts) + } func (h *ttHistogram) get(channel string) Timestamp { - ts, ok := h.Load(channel) + ts, ok := h.Get(channel) if !ok { return typeutil.ZeroTimestamp } - return ts.(Timestamp) + return ts } func (h *ttHistogram) remove(channels ...string) { for _, channel := range channels { - h.Delete(channel) + h.GetAndRemove(channel) } } diff --git a/pkg/mq/msgdispatcher/client.go b/pkg/mq/msgdispatcher/client.go index 7bea3d6d7f..48c2455ff0 100644 --- a/pkg/mq/msgdispatcher/client.go +++ b/pkg/mq/msgdispatcher/client.go @@ -17,8 +17,6 @@ package msgdispatcher import ( - "sync" - "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "go.uber.org/zap" @@ -26,6 +24,7 @@ import ( "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) type ( @@ -45,15 +44,16 @@ var _ Client = (*client)(nil) type client struct { role string nodeID int64 - managers sync.Map // pchannel -> DispatcherManager + managers *typeutil.ConcurrentMap[string, DispatcherManager] factory msgstream.Factory } func NewClient(factory msgstream.Factory, role string, nodeID int64) Client { return &client{ - role: role, - nodeID: nodeID, - factory: factory, + role: role, + nodeID: nodeID, + factory: factory, + managers: typeutil.NewConcurrentMap[string, DispatcherManager](), } } @@ -62,20 +62,17 @@ func (c *client) Register(vchannel string, pos *Pos, subPos SubPos) (<-chan *Msg zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel)) pchannel := funcutil.ToPhysicalChannel(vchannel) var manager DispatcherManager - res, ok := c.managers.Load(pchannel) + manager, ok := c.managers.Get(pchannel) if !ok { manager = NewDispatcherManager(pchannel, c.role, c.nodeID, c.factory) - c.managers.Store(pchannel, manager) + c.managers.Insert(pchannel, manager) go manager.Run() - } else { - manager, _ = res.(DispatcherManager) } - ch, err := manager.Add(vchannel, pos, subPos) if err != nil { if manager.Num() == 0 { manager.Close() - c.managers.Delete(pchannel) + c.managers.GetAndRemove(pchannel) } log.Error("register failed", zap.Error(err)) return nil, err @@ -86,12 +83,11 @@ func (c *client) Register(vchannel string, pos *Pos, subPos SubPos) (<-chan *Msg func (c *client) Deregister(vchannel string) { pchannel := funcutil.ToPhysicalChannel(vchannel) - if res, ok := c.managers.Load(pchannel); ok { - manager, _ := res.(DispatcherManager) + if manager, ok := c.managers.Get(pchannel); ok { manager.Remove(vchannel) if manager.Num() == 0 { manager.Close() - c.managers.Delete(pchannel) + c.managers.GetAndRemove(pchannel) } log.Info("deregister done", zap.String("role", c.role), zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel)) @@ -101,11 +97,9 @@ func (c *client) Deregister(vchannel string) { func (c *client) Close() { log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID)) - c.managers.Range(func(key, value any) bool { - pchannel := key.(string) - manager := value.(DispatcherManager) + c.managers.Range(func(pchannel string, manager DispatcherManager) bool { log.Info("close manager", zap.String("channel", pchannel)) - c.managers.Delete(pchannel) + c.managers.GetAndRemove(pchannel) manager.Close() return true }) diff --git a/pkg/mq/msgdispatcher/client_test.go b/pkg/mq/msgdispatcher/client_test.go index 70a9851342..eee7901e08 100644 --- a/pkg/mq/msgdispatcher/client_test.go +++ b/pkg/mq/msgdispatcher/client_test.go @@ -61,10 +61,6 @@ func TestClient_Concurrency(t *testing.T) { wg.Wait() expected := int(total - deregisterCount.Load()) - var n int - client1.(*client).managers.Range(func(_, _ any) bool { - n++ - return true - }) + n := client1.(*client).managers.Len() assert.Equal(t, expected, n) } diff --git a/pkg/mq/msgdispatcher/dispatcher.go b/pkg/mq/msgdispatcher/dispatcher.go index 2268f9e6fc..455d609aad 100644 --- a/pkg/mq/msgdispatcher/dispatcher.go +++ b/pkg/mq/msgdispatcher/dispatcher.go @@ -69,7 +69,7 @@ type Dispatcher struct { curTs atomic.Uint64 lagNotifyChan chan struct{} - lagTargets *sync.Map // vchannel -> *target + lagTargets *typeutil.ConcurrentMap[string, *target] // vchannel -> *target // vchannel -> *target, lock free since we guarantee that // it's modified only after dispatcher paused or terminated @@ -85,7 +85,7 @@ func NewDispatcher(factory msgstream.Factory, subName string, subPos SubPos, lagNotifyChan chan struct{}, - lagTargets *sync.Map, + lagTargets *typeutil.ConcurrentMap[string, *target], ) (*Dispatcher, error) { log := log.With(zap.String("pchannel", pchannel), zap.String("subName", subName), zap.Bool("isMain", isMain)) @@ -227,7 +227,7 @@ func (d *Dispatcher) work() { t.pos = pack.StartPositions[0] // replace the pChannel with vChannel t.pos.ChannelName = t.vchannel - d.lagTargets.LoadOrStore(t.vchannel, t) + d.lagTargets.Insert(t.vchannel, t) d.nonBlockingNotify() delete(d.targets, vchannel) log.Warn("lag target notified", zap.Error(err)) diff --git a/pkg/mq/msgdispatcher/manager.go b/pkg/mq/msgdispatcher/manager.go index fe849eaa20..4546dc6bf9 100644 --- a/pkg/mq/msgdispatcher/manager.go +++ b/pkg/mq/msgdispatcher/manager.go @@ -54,7 +54,7 @@ type dispatcherManager struct { pchannel string lagNotifyChan chan struct{} - lagTargets *sync.Map // vchannel -> *target + lagTargets *typeutil.ConcurrentMap[string, *target] // vchannel -> *target mu sync.RWMutex // guards mainDispatcher and soloDispatchers mainDispatcher *Dispatcher @@ -73,7 +73,7 @@ func NewDispatcherManager(pchannel string, role string, nodeID int64, factory ms nodeID: nodeID, pchannel: pchannel, lagNotifyChan: make(chan struct{}, 1), - lagTargets: &sync.Map{}, + lagTargets: typeutil.NewConcurrentMap[string, *target](), soloDispatchers: make(map[string]*Dispatcher), factory: factory, closeChan: make(chan struct{}), @@ -132,7 +132,7 @@ func (c *dispatcherManager) Remove(vchannel string) { c.deleteMetric(vchannel) log.Info("remove soloDispatcher done") } - c.lagTargets.Delete(vchannel) + c.lagTargets.GetAndRemove(vchannel) } func (c *dispatcherManager) Num() int { @@ -170,9 +170,9 @@ func (c *dispatcherManager) Run() { c.tryMerge() case <-c.lagNotifyChan: c.mu.Lock() - c.lagTargets.Range(func(vchannel, t any) bool { - c.split(t.(*target)) - c.lagTargets.Delete(vchannel) + c.lagTargets.Range(func(vchannel string, t *target) bool { + c.split(t) + c.lagTargets.GetAndRemove(vchannel) return true }) c.mu.Unlock() diff --git a/pkg/util/cgoconverter/bytes_converter.go b/pkg/util/cgoconverter/bytes_converter.go index 2d1741f7bf..0cd7469563 100644 --- a/pkg/util/cgoconverter/bytes_converter.go +++ b/pkg/util/cgoconverter/bytes_converter.go @@ -7,9 +7,10 @@ import "C" import ( "math" - "sync" "sync/atomic" "unsafe" + + "github.com/milvus-io/milvus/pkg/util/typeutil" ) const maxByteArrayLen = math.MaxInt32 @@ -17,20 +18,20 @@ const maxByteArrayLen = math.MaxInt32 var globalConverter = NewBytesConverter() type BytesConverter struct { - pointers sync.Map // leaseId -> unsafe.Pointer + pointers *typeutil.ConcurrentMap[int32, unsafe.Pointer] // leaseId -> unsafe.Pointer nextLease int32 } func NewBytesConverter() *BytesConverter { return &BytesConverter{ - pointers: sync.Map{}, + pointers: typeutil.NewConcurrentMap[int32, unsafe.Pointer](), nextLease: 0, } } func (converter *BytesConverter) add(p unsafe.Pointer) int32 { lease := atomic.AddInt32(&converter.nextLease, 1) - converter.pointers.Store(lease, p) + converter.pointers.Insert(lease, p) return lease } @@ -63,26 +64,19 @@ func (converter *BytesConverter) Release(lease int32) { } func (converter *BytesConverter) Extract(lease int32) unsafe.Pointer { - pI, ok := converter.pointers.LoadAndDelete(lease) + p, ok := converter.pointers.GetAndRemove(lease) if !ok { panic("try to release the resource that doesn't exist") } - p, ok := pI.(unsafe.Pointer) - if !ok { - panic("incorrect value type") - } - return p } // Make sure only the caller own the converter // or this would release someone's memory func (converter *BytesConverter) ReleaseAll() { - converter.pointers.Range(func(key, value interface{}) bool { - pointer := value.(unsafe.Pointer) - - converter.pointers.Delete(key) + converter.pointers.Range(func(lease int32, pointer unsafe.Pointer) bool { + converter.pointers.GetAndRemove(lease) C.free(pointer) return true diff --git a/pkg/util/timerecord/group_checker.go b/pkg/util/timerecord/group_checker.go index a5fd89a28b..d8502884d7 100644 --- a/pkg/util/timerecord/group_checker.go +++ b/pkg/util/timerecord/group_checker.go @@ -19,20 +19,22 @@ package timerecord import ( "sync" "time" + + "github.com/milvus-io/milvus/pkg/util/typeutil" ) // groups maintains string to GroupChecker -var groups sync.Map +var groups = typeutil.NewConcurrentMap[string, *GroupChecker]() // GroupChecker checks members in same group silent for certain period of time // print warning msg if there are item(s) that not reported type GroupChecker struct { groupName string - d time.Duration // check duration - t *time.Ticker // internal ticker - ch chan struct{} // closing signal - lastest sync.Map // map member name => lastest report time + d time.Duration // check duration + t *time.Ticker // internal ticker + ch chan struct{} // closing signal + lastest *typeutil.ConcurrentMap[string, time.Time] // map member name => lastest report time initOnce sync.Once stopOnce sync.Once @@ -52,8 +54,6 @@ func (gc *GroupChecker) init() { func (gc *GroupChecker) work() { gc.t = time.NewTicker(gc.d) defer gc.t.Stop() - var name string - var ts time.Time for { select { @@ -63,9 +63,7 @@ func (gc *GroupChecker) work() { } var list []string - gc.lastest.Range(func(k, v interface{}) bool { - name = k.(string) - ts = v.(time.Time) + gc.lastest.Range(func(name string, ts time.Time) bool { if time.Since(ts) > gc.d { list = append(list, name) } @@ -79,19 +77,19 @@ func (gc *GroupChecker) work() { // Check updates the latest timestamp for provided name func (gc *GroupChecker) Check(name string) { - gc.lastest.Store(name, time.Now()) + gc.lastest.Insert(name, time.Now()) } // Remove deletes name from watch list func (gc *GroupChecker) Remove(name string) { - gc.lastest.Delete(name) + gc.lastest.GetAndRemove(name) } // Stop closes the GroupChecker func (gc *GroupChecker) Stop() { gc.stopOnce.Do(func() { close(gc.ch) - groups.Delete(gc.groupName) + groups.GetAndRemove(gc.groupName) }) } @@ -103,12 +101,12 @@ func GetGroupChecker(groupName string, duration time.Duration, fn func([]string) groupName: groupName, d: duration, fn: fn, + lastest: typeutil.NewConcurrentMap[string, time.Time](), } - actual, loaded := groups.LoadOrStore(groupName, gc) + gc, loaded := groups.GetOrInsert(groupName, gc) if !loaded { gc.init() } - gc = actual.(*GroupChecker) return gc }