diff --git a/configs/advanced/channel.yaml b/configs/advanced/channel.yaml index b166abdf68..13684e4bf4 100644 --- a/configs/advanced/channel.yaml +++ b/configs/advanced/channel.yaml @@ -30,10 +30,10 @@ msgChannel: queryNodeSubNamePrefix: "queryNode" writeNodeSubNamePrefix: "writeNode" - # default channel range [0, 0] + # default channel range [0, 1) channelRange: - insert: [0, 15] - delete: [0, 15] - k2s: [0, 15] - search: [0, 0] + insert: [0, 1] + delete: [0, 1] + k2s: [0, 1] + search: [0, 1] searchResult: [0, 1] \ No newline at end of file diff --git a/configs/advanced/proxy.yaml b/configs/advanced/proxy.yaml index 71cb3006c2..cc98ad4d85 100644 --- a/configs/advanced/proxy.yaml +++ b/configs/advanced/proxy.yaml @@ -25,4 +25,6 @@ proxy: pulsarBufSize: 1024 # pulsar chan buffer size timeTick: - bufSize: 512 \ No newline at end of file + bufSize: 512 + + maxNameLength: 255 diff --git a/configs/milvus.yaml b/configs/milvus.yaml index 505272da2d..f15a5db9f5 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -11,9 +11,9 @@ nodeID: # will be deprecated after v0.2 - proxyIDList: [1, 2] - queryNodeIDList: [3, 4] - writeNodeIDList: [5, 6] + proxyIDList: [1] + queryNodeIDList: [2] + writeNodeIDList: [3] etcd: address: localhost diff --git a/internal/master/param_table.go b/internal/master/param_table.go index 4c03e45c9c..872bdab34e 100644 --- a/internal/master/param_table.go +++ b/internal/master/param_table.go @@ -360,18 +360,33 @@ func (p *ParamTable) initInsertChannelNames() { if err != nil { log.Fatal(err) } - id, err := p.Load("nodeID.queryNodeIDList") + channelRange, err := p.Load("msgChannel.channelRange.insert") if err != nil { - log.Panicf("load query node id list error, %s", err.Error()) + panic(err) } - ids := strings.Split(id, ",") - channels := make([]string, 0, len(ids)) - for _, i := range ids { - _, err := strconv.ParseInt(i, 10, 64) - if err != nil { - log.Panicf("load query node id list error, %s", err.Error()) - } - channels = append(channels, ch+"-"+i) + + chanRange := strings.Split(channelRange, ",") + if len(chanRange) != 2 { + panic("Illegal channel range num") + } + channelBegin, err := strconv.Atoi(chanRange[0]) + if err != nil { + panic(err) + } + channelEnd, err := strconv.Atoi(chanRange[1]) + if err != nil { + panic(err) + } + if channelBegin < 0 || channelEnd < 0 { + panic("Illegal channel range value") + } + if channelBegin > channelEnd { + panic("Illegal channel range value") + } + + channels := make([]string, channelEnd-channelBegin) + for i := 0; i < channelEnd-channelBegin; i++ { + channels[i] = ch + "-" + strconv.Itoa(channelBegin+i) } p.InsertChannelNames = channels } diff --git a/internal/master/param_table_test.go b/internal/master/param_table_test.go index 8940e961c0..c35213e626 100644 --- a/internal/master/param_table_test.go +++ b/internal/master/param_table_test.go @@ -31,7 +31,7 @@ func TestParamTable_EtcdRootPath(t *testing.T) { func TestParamTable_TopicNum(t *testing.T) { Params.Init() num := Params.TopicNum - assert.Equal(t, num, 15) + assert.Equal(t, num, 1) } func TestParamTable_SegmentSize(t *testing.T) { @@ -73,7 +73,7 @@ func TestParamTable_SegIDAssignExpiration(t *testing.T) { func TestParamTable_QueryNodeNum(t *testing.T) { Params.Init() num := Params.QueryNodeNum - assert.Equal(t, num, 2) + assert.Equal(t, num, 1) } func TestParamTable_QueryNodeStatsChannelName(t *testing.T) { @@ -85,17 +85,15 @@ func TestParamTable_QueryNodeStatsChannelName(t *testing.T) { func TestParamTable_ProxyIDList(t *testing.T) { Params.Init() ids := Params.ProxyIDList - assert.Equal(t, len(ids), 2) + assert.Equal(t, len(ids), 1) assert.Equal(t, ids[0], int64(1)) - assert.Equal(t, ids[1], int64(2)) } func TestParamTable_ProxyTimeTickChannelNames(t *testing.T) { Params.Init() names := Params.ProxyTimeTickChannelNames - assert.Equal(t, len(names), 2) + assert.Equal(t, len(names), 1) assert.Equal(t, names[0], "proxyTimeTick-1") - assert.Equal(t, names[1], "proxyTimeTick-2") } func TestParamTable_MsgChannelSubName(t *testing.T) { @@ -113,31 +111,27 @@ func TestParamTable_SoftTimeTickBarrierInterval(t *testing.T) { func TestParamTable_WriteNodeIDList(t *testing.T) { Params.Init() ids := Params.WriteNodeIDList - assert.Equal(t, len(ids), 2) - assert.Equal(t, ids[0], int64(5)) - assert.Equal(t, ids[1], int64(6)) + assert.Equal(t, len(ids), 1) + assert.Equal(t, ids[0], int64(3)) } func TestParamTable_WriteNodeTimeTickChannelNames(t *testing.T) { Params.Init() names := Params.WriteNodeTimeTickChannelNames - assert.Equal(t, len(names), 2) - assert.Equal(t, names[0], "writeNodeTimeTick-5") - assert.Equal(t, names[1], "writeNodeTimeTick-6") + assert.Equal(t, len(names), 1) + assert.Equal(t, names[0], "writeNodeTimeTick-3") } func TestParamTable_InsertChannelNames(t *testing.T) { Params.Init() names := Params.InsertChannelNames - assert.Equal(t, len(names), 2) - assert.Equal(t, names[0], "insert-3") - assert.Equal(t, names[1], "insert-4") + assert.Equal(t, len(names), 1) + assert.Equal(t, names[0], "insert-0") } func TestParamTable_K2SChannelNames(t *testing.T) { Params.Init() names := Params.K2SChannelNames - assert.Equal(t, len(names), 2) - assert.Equal(t, names[0], "k2s-5") - assert.Equal(t, names[1], "k2s-6") + assert.Equal(t, len(names), 1) + assert.Equal(t, names[0], "k2s-3") } diff --git a/internal/proxy/grpc_service.go b/internal/proxy/grpc_service.go index ba0f92dc39..dbccf5ac44 100644 --- a/internal/proxy/grpc_service.go +++ b/internal/proxy/grpc_service.go @@ -82,9 +82,8 @@ func (p *Proxy) CreateCollection(ctx context.Context, req *schemapb.CollectionSc Schema: &commonpb.Blob{}, }, masterClient: p.masterClient, + schema: req, } - schemaBytes, _ := proto.Marshal(req) - cct.CreateCollectionRequest.Schema.Value = schemaBytes var cancel func() cct.ctx, cancel = context.WithTimeout(ctx, reqTimeoutInterval) defer cancel() @@ -125,6 +124,7 @@ func (p *Proxy) Search(ctx context.Context, req *servicepb.Query) (*servicepb.Qu }, queryMsgStream: p.queryMsgStream, resultBuf: make(chan []*internalpb.SearchResult), + query: req, } var cancel func() qt.ctx, cancel = context.WithTimeout(ctx, reqTimeoutInterval) diff --git a/internal/proxy/paramtable.go b/internal/proxy/paramtable.go index 0ad0775a2c..ba60abe634 100644 --- a/internal/proxy/paramtable.go +++ b/internal/proxy/paramtable.go @@ -322,3 +322,123 @@ func (pt *ParamTable) MsgStreamSearchResultPulsarBufSize() int64 { func (pt *ParamTable) MsgStreamTimeTickBufSize() int64 { return pt.parseInt64("proxy.msgStream.timeTick.bufSize") } + +func (pt *ParamTable) insertChannelNames() []string { + ch, err := pt.Load("msgChannel.chanNamePrefix.insert") + if err != nil { + log.Fatal(err) + } + channelRange, err := pt.Load("msgChannel.channelRange.insert") + if err != nil { + panic(err) + } + + chanRange := strings.Split(channelRange, ",") + if len(chanRange) != 2 { + panic("Illegal channel range num") + } + channelBegin, err := strconv.Atoi(chanRange[0]) + if err != nil { + panic(err) + } + channelEnd, err := strconv.Atoi(chanRange[1]) + if err != nil { + panic(err) + } + if channelBegin < 0 || channelEnd < 0 { + panic("Illegal channel range value") + } + if channelBegin > channelEnd { + panic("Illegal channel range value") + } + + channels := make([]string, channelEnd-channelBegin) + for i := 0; i < channelEnd-channelBegin; i++ { + channels[i] = ch + "-" + strconv.Itoa(channelBegin+i) + } + return channels +} + +func (pt *ParamTable) searchChannelNames() []string { + ch, err := pt.Load("msgChannel.chanNamePrefix.search") + if err != nil { + log.Fatal(err) + } + channelRange, err := pt.Load("msgChannel.channelRange.search") + if err != nil { + panic(err) + } + + chanRange := strings.Split(channelRange, ",") + if len(chanRange) != 2 { + panic("Illegal channel range num") + } + channelBegin, err := strconv.Atoi(chanRange[0]) + if err != nil { + panic(err) + } + channelEnd, err := strconv.Atoi(chanRange[1]) + if err != nil { + panic(err) + } + if channelBegin < 0 || channelEnd < 0 { + panic("Illegal channel range value") + } + if channelBegin > channelEnd { + panic("Illegal channel range value") + } + + channels := make([]string, channelEnd-channelBegin) + for i := 0; i < channelEnd-channelBegin; i++ { + channels[i] = ch + "-" + strconv.Itoa(channelBegin+i) + } + return channels +} + +func (pt *ParamTable) searchResultChannelNames() []string { + ch, err := pt.Load("msgChannel.chanNamePrefix.search") + if err != nil { + log.Fatal(err) + } + channelRange, err := pt.Load("msgChannel.channelRange.search") + if err != nil { + panic(err) + } + + chanRange := strings.Split(channelRange, ",") + if len(chanRange) != 2 { + panic("Illegal channel range num") + } + channelBegin, err := strconv.Atoi(chanRange[0]) + if err != nil { + panic(err) + } + channelEnd, err := strconv.Atoi(chanRange[1]) + if err != nil { + panic(err) + } + if channelBegin < 0 || channelEnd < 0 { + panic("Illegal channel range value") + } + if channelBegin > channelEnd { + panic("Illegal channel range value") + } + + channels := make([]string, channelEnd-channelBegin) + for i := 0; i < channelEnd-channelBegin; i++ { + channels[i] = ch + "-" + strconv.Itoa(channelBegin+i) + } + return channels +} + +func (pt *ParamTable) MaxNameLength() int64 { + str, err := pt.Load("proxy.maxNameLength") + if err != nil { + panic(err) + } + maxNameLength, err := strconv.ParseInt(str, 10, 64) + if err != nil { + panic(err) + } + return maxNameLength +} diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 538fd05326..e2f83e50b9 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -5,11 +5,13 @@ import ( "errors" "log" + "github.com/golang/protobuf/proto" "github.com/zilliztech/milvus-distributed/internal/allocator" "github.com/zilliztech/milvus-distributed/internal/msgstream" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" "github.com/zilliztech/milvus-distributed/internal/proto/masterpb" + "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" "github.com/zilliztech/milvus-distributed/internal/proto/servicepb" ) @@ -64,6 +66,15 @@ func (it *InsertTask) Type() internalpb.MsgType { } func (it *InsertTask) PreExecute() error { + collectionName := it.BaseInsertTask.CollectionName + if err := ValidateCollectionName(collectionName); err != nil { + return err + } + partitionTag := it.BaseInsertTask.PartitionTag + if err := ValidatePartitionTag(partitionTag, true); err != nil { + return err + } + return nil } @@ -120,6 +131,7 @@ type CreateCollectionTask struct { masterClient masterpb.MasterClient result *commonpb.Status ctx context.Context + schema *schemapb.CollectionSchema } func (cct *CreateCollectionTask) ID() UniqueID { @@ -147,10 +159,24 @@ func (cct *CreateCollectionTask) SetTs(ts Timestamp) { } func (cct *CreateCollectionTask) PreExecute() error { + // validate collection name + if err := ValidateCollectionName(cct.schema.Name); err != nil { + return err + } + + // validate field name + for _, field := range cct.schema.Fields { + if err := ValidateFieldName(field.Name); err != nil { + return err + } + } + return nil } func (cct *CreateCollectionTask) Execute() error { + schemaBytes, _ := proto.Marshal(cct.schema) + cct.CreateCollectionRequest.Schema.Value = schemaBytes resp, err := cct.masterClient.CreateCollection(cct.ctx, &cct.CreateCollectionRequest) if err != nil { log.Printf("create collection failed, error= %v", err) @@ -201,6 +227,9 @@ func (dct *DropCollectionTask) SetTs(ts Timestamp) { } func (dct *DropCollectionTask) PreExecute() error { + if err := ValidateCollectionName(dct.CollectionName.CollectionName); err != nil { + return err + } return nil } @@ -229,6 +258,7 @@ type QueryTask struct { resultBuf chan []*internalpb.SearchResult result *servicepb.QueryResult ctx context.Context + query *servicepb.Query } func (qt *QueryTask) ID() UniqueID { @@ -256,6 +286,15 @@ func (qt *QueryTask) SetTs(ts Timestamp) { } func (qt *QueryTask) PreExecute() error { + if err := ValidateCollectionName(qt.query.CollectionName); err != nil { + return err + } + + for _, tag := range qt.query.PartitionTags { + if err := ValidatePartitionTag(tag, false); err != nil { + return err + } + } return nil } @@ -367,6 +406,9 @@ func (hct *HasCollectionTask) SetTs(ts Timestamp) { } func (hct *HasCollectionTask) PreExecute() error { + if err := ValidateCollectionName(hct.CollectionName.CollectionName); err != nil { + return err + } return nil } @@ -424,6 +466,9 @@ func (dct *DescribeCollectionTask) SetTs(ts Timestamp) { } func (dct *DescribeCollectionTask) PreExecute() error { + if err := ValidateCollectionName(dct.CollectionName.CollectionName); err != nil { + return err + } return nil } @@ -532,6 +577,16 @@ func (cpt *CreatePartitionTask) SetTs(ts Timestamp) { } func (cpt *CreatePartitionTask) PreExecute() error { + collName, partitionTag := cpt.PartitionName.CollectionName, cpt.PartitionName.Tag + + if err := ValidateCollectionName(collName); err != nil { + return err + } + + if err := ValidatePartitionTag(partitionTag, true); err != nil { + return err + } + return nil } @@ -577,6 +632,16 @@ func (dpt *DropPartitionTask) SetTs(ts Timestamp) { } func (dpt *DropPartitionTask) PreExecute() error { + collName, partitionTag := dpt.PartitionName.CollectionName, dpt.PartitionName.Tag + + if err := ValidateCollectionName(collName); err != nil { + return err + } + + if err := ValidatePartitionTag(partitionTag, true); err != nil { + return err + } + return nil } @@ -622,6 +687,15 @@ func (hpt *HasPartitionTask) SetTs(ts Timestamp) { } func (hpt *HasPartitionTask) PreExecute() error { + collName, partitionTag := hpt.PartitionName.CollectionName, hpt.PartitionName.Tag + + if err := ValidateCollectionName(collName); err != nil { + return err + } + + if err := ValidatePartitionTag(partitionTag, true); err != nil { + return err + } return nil } @@ -667,6 +741,15 @@ func (dpt *DescribePartitionTask) SetTs(ts Timestamp) { } func (dpt *DescribePartitionTask) PreExecute() error { + collName, partitionTag := dpt.PartitionName.CollectionName, dpt.PartitionName.Tag + + if err := ValidateCollectionName(collName); err != nil { + return err + } + + if err := ValidatePartitionTag(partitionTag, true); err != nil { + return err + } return nil } @@ -712,6 +795,9 @@ func (spt *ShowPartitionsTask) SetTs(ts Timestamp) { } func (spt *ShowPartitionsTask) PreExecute() error { + if err := ValidateCollectionName(spt.CollectionName.CollectionName); err != nil { + return err + } return nil } diff --git a/internal/proxy/validate_util.go b/internal/proxy/validate_util.go new file mode 100644 index 0000000000..8049595f28 --- /dev/null +++ b/internal/proxy/validate_util.go @@ -0,0 +1,118 @@ +package proxy + +import ( + "strconv" + "strings" + + "github.com/zilliztech/milvus-distributed/internal/errors" +) + +func isAlpha(c uint8) bool { + if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') { + return false + } + return true +} + +func isNumber(c uint8) bool { + if c < '0' || c > '9' { + return false + } + return true +} + +func ValidateCollectionName(collName string) error { + collName = strings.TrimSpace(collName) + + if collName == "" { + return errors.New("Collection name should not be empty") + } + + invalidMsg := "Invalid collection name: " + collName + ". " + if int64(len(collName)) > Params.MaxNameLength() { + msg := invalidMsg + "The length of a collection name must be less than " + + strconv.FormatInt(Params.MaxNameLength(), 10) + " characters." + return errors.New(msg) + } + + firstChar := collName[0] + if firstChar != '_' && !isAlpha(firstChar) { + msg := invalidMsg + "The first character of a collection name must be an underscore or letter." + return errors.New(msg) + } + + for i := 1; i < len(collName); i++ { + c := collName[i] + if c != '_' && c != '$' && !isAlpha(c) && !isNumber(c) { + msg := invalidMsg + "Collection name can only contain numbers, letters, dollars and underscores." + return errors.New(msg) + } + } + return nil +} + +func ValidatePartitionTag(partitionTag string, strictCheck bool) error { + partitionTag = strings.TrimSpace(partitionTag) + + invalidMsg := "Invalid partition tag: " + partitionTag + ". " + if partitionTag == "" { + msg := invalidMsg + "Partition tag should not be empty." + return errors.New(msg) + } + + if int64(len(partitionTag)) > Params.MaxNameLength() { + msg := invalidMsg + "The length of a partition tag must be less than " + + strconv.FormatInt(Params.MaxNameLength(), 10) + " characters." + return errors.New(msg) + } + + if strictCheck { + firstChar := partitionTag[0] + if firstChar != '_' && !isAlpha(firstChar) { + msg := invalidMsg + "The first character of a partition tag must be an underscore or letter." + return errors.New(msg) + } + + tagSize := len(partitionTag) + for i := 1; i < tagSize; i++ { + c := partitionTag[i] + if c != '_' && c != '$' && !isAlpha(c) && !isNumber(c) { + msg := invalidMsg + "Partition tag can only contain numbers, letters, dollars and underscores." + return errors.New(msg) + } + } + } + + return nil +} + +func ValidateFieldName(fieldName string) error { + fieldName = strings.TrimSpace(fieldName) + + if fieldName == "" { + return errors.New("Field name should not be empty") + } + + invalidMsg := "Invalid field name: " + fieldName + ". " + if int64(len(fieldName)) > Params.MaxNameLength() { + msg := invalidMsg + "The length of a field name must be less than " + + strconv.FormatInt(Params.MaxNameLength(), 10) + " characters." + return errors.New(msg) + } + + firstChar := fieldName[0] + if firstChar != '_' && !isAlpha(firstChar) { + msg := invalidMsg + "The first character of a field name must be an underscore or letter." + return errors.New(msg) + } + + fieldNameSize := len(fieldName) + for i := 1; i < fieldNameSize; i++ { + c := fieldName[i] + if c != '_' && !isAlpha(c) && !isNumber(c) { + msg := invalidMsg + "Field name cannot only contain numbers, letters, and underscores." + return errors.New(msg) + } + } + return nil +} diff --git a/internal/proxy/validate_util_test.go b/internal/proxy/validate_util_test.go new file mode 100644 index 0000000000..336425f9de --- /dev/null +++ b/internal/proxy/validate_util_test.go @@ -0,0 +1,84 @@ +package proxy + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValidateCollectionName(t *testing.T) { + Params.Init() + assert.Nil(t, ValidateCollectionName("abc")) + assert.Nil(t, ValidateCollectionName("_123abc")) + assert.Nil(t, ValidateCollectionName("abc123_$")) + + longName := make([]byte, 256) + for i := 0; i < len(longName); i++ { + longName[i] = 'a' + } + invalidNames := []string{ + "123abc", + "$abc", + "_12 ac", + " ", + "", + string(longName), + "中文", + } + + for _, name := range invalidNames { + assert.NotNil(t, ValidateCollectionName(name)) + } +} + +func TestValidatePartitionTag(t *testing.T) { + Params.Init() + assert.Nil(t, ValidatePartitionTag("abc", true)) + assert.Nil(t, ValidatePartitionTag("_123abc", true)) + assert.Nil(t, ValidatePartitionTag("abc123_$", true)) + + longName := make([]byte, 256) + for i := 0; i < len(longName); i++ { + longName[i] = 'a' + } + invalidNames := []string{ + "123abc", + "$abc", + "_12 ac", + " ", + "", + string(longName), + "中文", + } + + for _, name := range invalidNames { + assert.NotNil(t, ValidatePartitionTag(name, true)) + } + + assert.Nil(t, ValidatePartitionTag("ab cd", false)) + assert.Nil(t, ValidatePartitionTag("ab*", false)) +} + +func TestValidateFieldName(t *testing.T) { + Params.Init() + assert.Nil(t, ValidateFieldName("abc")) + assert.Nil(t, ValidateFieldName("_123abc")) + + longName := make([]byte, 256) + for i := 0; i < len(longName); i++ { + longName[i] = 'a' + } + invalidNames := []string{ + "123abc", + "$abc", + "_12 ac", + " ", + "", + string(longName), + "中文", + } + + for _, name := range invalidNames { + assert.NotNil(t, ValidateFieldName(name)) + } +} diff --git a/internal/reader/flow_graph_msg_stream_input_nodes.go b/internal/reader/flow_graph_msg_stream_input_nodes.go index b5ee6a581b..a1c08e951a 100644 --- a/internal/reader/flow_graph_msg_stream_input_nodes.go +++ b/internal/reader/flow_graph_msg_stream_input_nodes.go @@ -17,7 +17,7 @@ func newDmInputNode(ctx context.Context) *flowgraph.InputNode { log.Fatal(err) } - consumeChannels := []string{"insert"} + consumeChannels := Params.insertChannelNames() consumeSubName := "insertSub" insertStream := msgstream.NewPulsarTtMsgStream(ctx, receiveBufSize) diff --git a/internal/reader/param_table.go b/internal/reader/param_table.go index 81be65b28a..c104130afb 100644 --- a/internal/reader/param_table.go +++ b/internal/reader/param_table.go @@ -1,7 +1,9 @@ package reader import ( + "log" "strconv" + "strings" "github.com/zilliztech/milvus-distributed/internal/util/paramtable" ) @@ -18,6 +20,10 @@ func (p *ParamTable) Init() { if err != nil { panic(err) } + err = p.LoadYaml("advanced/channel.yaml") + if err != nil { + panic(err) + } } func (p *ParamTable) pulsarAddress() (string, error) { @@ -193,3 +199,111 @@ func (p *ParamTable) etcdRootPath() string { } return etcdRootPath } + +func (p *ParamTable) insertChannelNames() []string { + ch, err := p.Load("msgChannel.chanNamePrefix.insert") + if err != nil { + log.Fatal(err) + } + channelRange, err := p.Load("msgChannel.channelRange.insert") + if err != nil { + panic(err) + } + + chanRange := strings.Split(channelRange, ",") + if len(chanRange) != 2 { + panic("Illegal channel range num") + } + channelBegin, err := strconv.Atoi(chanRange[0]) + if err != nil { + panic(err) + } + channelEnd, err := strconv.Atoi(chanRange[1]) + if err != nil { + panic(err) + } + if channelBegin < 0 || channelEnd < 0 { + panic("Illegal channel range value") + } + if channelBegin > channelEnd { + panic("Illegal channel range value") + } + + channels := make([]string, channelEnd-channelBegin) + for i := 0; i < channelEnd-channelBegin; i++ { + channels[i] = ch + "-" + strconv.Itoa(channelBegin+i) + } + return channels +} + +func (p *ParamTable) searchChannelNames() []string { + ch, err := p.Load("msgChannel.chanNamePrefix.search") + if err != nil { + log.Fatal(err) + } + channelRange, err := p.Load("msgChannel.channelRange.search") + if err != nil { + panic(err) + } + + chanRange := strings.Split(channelRange, ",") + if len(chanRange) != 2 { + panic("Illegal channel range num") + } + channelBegin, err := strconv.Atoi(chanRange[0]) + if err != nil { + panic(err) + } + channelEnd, err := strconv.Atoi(chanRange[1]) + if err != nil { + panic(err) + } + if channelBegin < 0 || channelEnd < 0 { + panic("Illegal channel range value") + } + if channelBegin > channelEnd { + panic("Illegal channel range value") + } + + channels := make([]string, channelEnd-channelBegin) + for i := 0; i < channelEnd-channelBegin; i++ { + channels[i] = ch + "-" + strconv.Itoa(channelBegin+i) + } + return channels +} + +func (p *ParamTable) searchResultChannelNames() []string { + ch, err := p.Load("msgChannel.chanNamePrefix.search") + if err != nil { + log.Fatal(err) + } + channelRange, err := p.Load("msgChannel.channelRange.search") + if err != nil { + panic(err) + } + + chanRange := strings.Split(channelRange, ",") + if len(chanRange) != 2 { + panic("Illegal channel range num") + } + channelBegin, err := strconv.Atoi(chanRange[0]) + if err != nil { + panic(err) + } + channelEnd, err := strconv.Atoi(chanRange[1]) + if err != nil { + panic(err) + } + if channelBegin < 0 || channelEnd < 0 { + panic("Illegal channel range value") + } + if channelBegin > channelEnd { + panic("Illegal channel range value") + } + + channels := make([]string, channelEnd-channelBegin) + for i := 0; i < channelEnd-channelBegin; i++ { + channels[i] = ch + "-" + strconv.Itoa(channelBegin+i) + } + return channels +} diff --git a/internal/reader/query_node_test.go b/internal/reader/query_node_test.go index 92b82ef865..af7153cda5 100644 --- a/internal/reader/query_node_test.go +++ b/internal/reader/query_node_test.go @@ -6,7 +6,7 @@ import ( "time" ) -const ctxTimeInMillisecond = 2000 +const ctxTimeInMillisecond = 200 const closeWithDeadline = true // NOTE: start pulsar and etcd before test diff --git a/internal/reader/search_service.go b/internal/reader/search_service.go index 464d7952eb..b34fe3fd72 100644 --- a/internal/reader/search_service.go +++ b/internal/reader/search_service.go @@ -42,7 +42,7 @@ func newSearchService(ctx context.Context, replica *collectionReplica) *searchSe log.Fatal(err) } - consumeChannels := []string{"search"} + consumeChannels := Params.searchChannelNames() consumeSubName := "subSearch" searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize) searchStream.SetPulsarClient(msgStreamURL) @@ -50,7 +50,7 @@ func newSearchService(ctx context.Context, replica *collectionReplica) *searchSe searchStream.CreatePulsarConsumers(consumeChannels, consumeSubName, unmarshalDispatcher, pulsarBufSize) var inputStream msgstream.MsgStream = searchStream - producerChannels := []string{"searchResult"} + producerChannels := Params.searchResultChannelNames() searchResultStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize) searchResultStream.SetPulsarClient(msgStreamURL) searchResultStream.CreatePulsarProducers(producerChannels)