diff --git a/internal/rootcoord/dml_channels.go b/internal/rootcoord/dml_channels.go index e744ad42f9..10e51f1936 100644 --- a/internal/rootcoord/dml_channels.go +++ b/internal/rootcoord/dml_channels.go @@ -158,7 +158,7 @@ func newDmlChannels(ctx context.Context, factory msgstream.Factory, chanNamePref // if topic created, use the existed topic if params.PreCreatedTopicEnabled.GetAsBool() { chanNamePrefix = "" - chanNum = int64(len(params.TopicNames.GetAsStrings())) + chanNumDefault + chanNum = int64(len(params.TopicNames.GetAsStrings())) names = params.TopicNames.GetAsStrings() } else { chanNamePrefix = chanNamePrefixDefault @@ -361,31 +361,62 @@ func genChannelNames(prefix string, num int64) []string { return results } -func parseChannelNameIndex(channeName string) int { - index := strings.LastIndex(channeName, "_") +func parseChannelNameIndex(channelName string) int { + index := strings.LastIndex(channelName, "_") if index < 0 { - log.Error("invalid channel name", zap.String("chanName", channeName)) - panic("invalid channel name: " + channeName) + log.Error("invalid channel name", zap.String("chanName", channelName)) + panic("invalid channel name: " + channelName) } - index, err := strconv.Atoi(channeName[index+1:]) + index, err := strconv.Atoi(channelName[index+1:]) if err != nil { - log.Error("invalid channel name", zap.String("chanName", channeName), zap.Error(err)) - panic("invalid channel name: " + channeName) + log.Error("invalid channel name", zap.String("chanName", channelName), zap.Error(err)) + panic("invalid channel name: " + channelName) } return index } func getNeedChanNum(setNum int, chanMap map[typeutil.UniqueID][]string) int { // find the largest number of current channel usage - maxChanUsed := setNum - for _, chanNames := range chanMap { - for _, chanName := range chanNames { - index := parseChannelNameIndex(chanName) - if maxChanUsed < index+1 { - maxChanUsed = index + 1 + maxChanUsed := 0 + isPreCreatedTopicEnabled := paramtable.Get().CommonCfg.PreCreatedTopicEnabled.GetAsBool() + chanNameSet := typeutil.NewSet[string]() + + if isPreCreatedTopicEnabled { + // can only use the topic in the list when preCreatedTopicEnabled + topics := paramtable.Get().CommonCfg.TopicNames.GetAsStrings() + + if len(topics) == 0 { + panic("no topic were specified when pre-created") + } + for _, topic := range topics { + if len(topic) == 0 { + panic("topic were empty") + } + if chanNameSet.Contain(topic) { + log.Error("duplicate topics are pre-created", zap.String("topic", topic)) + panic("duplicate topic: " + topic) + } + chanNameSet.Insert(topic) + } + + for _, chanNames := range chanMap { + for _, chanName := range chanNames { + if !chanNameSet.Contain(chanName) { + log.Error("invalid channel that is not in the list when pre-created topic", zap.String("chanName", chanName)) + panic("invalid chanName: " + chanName) + } + } + } + } else { + maxChanUsed = setNum + for _, chanNames := range chanMap { + for _, chanName := range chanNames { + index := parseChannelNameIndex(chanName) + if maxChanUsed < index+1 { + maxChanUsed = index + 1 + } } } } - return maxChanUsed } diff --git a/internal/rootcoord/dml_channels_test.go b/internal/rootcoord/dml_channels_test.go index b4b13c1b5d..51d382f122 100644 --- a/internal/rootcoord/dml_channels_test.go +++ b/internal/rootcoord/dml_channels_test.go @@ -196,6 +196,65 @@ func TestDmChannelsFailure(t *testing.T) { wg.Wait() } +func TestGetNeedChanNum(t *testing.T) { + paramtable.Get().Save(Params.CommonCfg.PreCreatedTopicEnabled.Key, "true") + defer paramtable.Get().Reset(Params.CommonCfg.PreCreatedTopicEnabled.Key) + chans := map[UniqueID][]string{} + + var wg sync.WaitGroup + wg.Add(1) + t.Run("topic were empty", func(t *testing.T) { + defer wg.Done() + paramtable.Get().Save(Params.CommonCfg.TopicNames.Key, "") + defer paramtable.Get().Reset(Params.CommonCfg.TopicNames.Key) + assert.Panics(t, func() { + getNeedChanNum(10, chans) + }) + }) + + wg.Add(1) + t.Run("duplicated topics", func(t *testing.T) { + defer wg.Done() + paramtable.Get().Save(Params.CommonCfg.TopicNames.Key, "topic1,topic1") + defer paramtable.Get().Reset(Params.CommonCfg.TopicNames.Key) + assert.Panics(t, func() { + getNeedChanNum(10, chans) + }) + }) + + wg.Add(1) + t.Run("invalid channel channel that not in the list", func(t *testing.T) { + defer wg.Done() + paramtable.Get().Save(Params.CommonCfg.TopicNames.Key, "topic1,topic2") + defer paramtable.Get().Reset(Params.CommonCfg.TopicNames.Key) + chans[UniqueID(100)] = []string{"rootcoord-dml_0"} + assert.Panics(t, func() { + getNeedChanNum(10, chans) + }) + }) + + wg.Add(1) + t.Run("normal case when pre-created topic", func(t *testing.T) { + defer wg.Done() + paramtable.Get().Save(Params.CommonCfg.TopicNames.Key, "topic1,topic2") + defer paramtable.Get().Reset(Params.CommonCfg.TopicNames.Key) + chans[UniqueID(100)] = []string{"topic1"} + assert.Equal(t, getNeedChanNum(10, chans), 0) + }) + + wg.Add(1) + t.Run("normal case", func(t *testing.T) { + defer wg.Done() + paramtable.Get().Save(Params.CommonCfg.PreCreatedTopicEnabled.Key, "false") + paramtable.Get().Save(Params.CommonCfg.RootCoordDml.Key, "rootcoord-dml") + defer paramtable.Get().Reset(Params.CommonCfg.RootCoordDml.Key) + chans[UniqueID(100)] = []string{"rootcoord-dml_99"} + assert.Equal(t, getNeedChanNum(10, chans), 100) + }) + + wg.Wait() +} + // FailMessageStreamFactory mock MessageStreamFactory failure type FailMessageStreamFactory struct { msgstream.Factory