mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-02 03:48:37 +08:00
Add collectionName/partitionTag/fieldName validation.
Signed-off-by: sunby <bingyi.sun@zilliz.com>
This commit is contained in:
parent
a38f539b9b
commit
24b29bec30
@ -30,10 +30,10 @@ msgChannel:
|
||||
queryNodeSubNamePrefix: "queryNode"
|
||||
writeNodeSubNamePrefix: "writeNode"
|
||||
|
||||
# default channel range [0, 0]
|
||||
# default channel range [0, 1)
|
||||
channelRange:
|
||||
insert: [0, 15]
|
||||
delete: [0, 15]
|
||||
k2s: [0, 15]
|
||||
search: [0, 0]
|
||||
insert: [0, 1]
|
||||
delete: [0, 1]
|
||||
k2s: [0, 1]
|
||||
search: [0, 1]
|
||||
searchResult: [0, 1]
|
@ -25,4 +25,6 @@ proxy:
|
||||
pulsarBufSize: 1024 # pulsar chan buffer size
|
||||
|
||||
timeTick:
|
||||
bufSize: 512
|
||||
bufSize: 512
|
||||
|
||||
maxNameLength: 255
|
||||
|
@ -11,9 +11,9 @@
|
||||
|
||||
|
||||
nodeID: # will be deprecated after v0.2
|
||||
proxyIDList: [1, 2]
|
||||
queryNodeIDList: [3, 4]
|
||||
writeNodeIDList: [5, 6]
|
||||
proxyIDList: [1]
|
||||
queryNodeIDList: [2]
|
||||
writeNodeIDList: [3]
|
||||
|
||||
etcd:
|
||||
address: localhost
|
||||
|
@ -360,18 +360,33 @@ func (p *ParamTable) initInsertChannelNames() {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
id, err := p.Load("nodeID.queryNodeIDList")
|
||||
channelRange, err := p.Load("msgChannel.channelRange.insert")
|
||||
if err != nil {
|
||||
log.Panicf("load query node id list error, %s", err.Error())
|
||||
panic(err)
|
||||
}
|
||||
ids := strings.Split(id, ",")
|
||||
channels := make([]string, 0, len(ids))
|
||||
for _, i := range ids {
|
||||
_, err := strconv.ParseInt(i, 10, 64)
|
||||
if err != nil {
|
||||
log.Panicf("load query node id list error, %s", err.Error())
|
||||
}
|
||||
channels = append(channels, ch+"-"+i)
|
||||
|
||||
chanRange := strings.Split(channelRange, ",")
|
||||
if len(chanRange) != 2 {
|
||||
panic("Illegal channel range num")
|
||||
}
|
||||
channelBegin, err := strconv.Atoi(chanRange[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelEnd, err := strconv.Atoi(chanRange[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if channelBegin < 0 || channelEnd < 0 {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
if channelBegin > channelEnd {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
|
||||
channels := make([]string, channelEnd-channelBegin)
|
||||
for i := 0; i < channelEnd-channelBegin; i++ {
|
||||
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
|
||||
}
|
||||
p.InsertChannelNames = channels
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ func TestParamTable_EtcdRootPath(t *testing.T) {
|
||||
func TestParamTable_TopicNum(t *testing.T) {
|
||||
Params.Init()
|
||||
num := Params.TopicNum
|
||||
assert.Equal(t, num, 15)
|
||||
assert.Equal(t, num, 1)
|
||||
}
|
||||
|
||||
func TestParamTable_SegmentSize(t *testing.T) {
|
||||
@ -73,7 +73,7 @@ func TestParamTable_SegIDAssignExpiration(t *testing.T) {
|
||||
func TestParamTable_QueryNodeNum(t *testing.T) {
|
||||
Params.Init()
|
||||
num := Params.QueryNodeNum
|
||||
assert.Equal(t, num, 2)
|
||||
assert.Equal(t, num, 1)
|
||||
}
|
||||
|
||||
func TestParamTable_QueryNodeStatsChannelName(t *testing.T) {
|
||||
@ -85,17 +85,15 @@ func TestParamTable_QueryNodeStatsChannelName(t *testing.T) {
|
||||
func TestParamTable_ProxyIDList(t *testing.T) {
|
||||
Params.Init()
|
||||
ids := Params.ProxyIDList
|
||||
assert.Equal(t, len(ids), 2)
|
||||
assert.Equal(t, len(ids), 1)
|
||||
assert.Equal(t, ids[0], int64(1))
|
||||
assert.Equal(t, ids[1], int64(2))
|
||||
}
|
||||
|
||||
func TestParamTable_ProxyTimeTickChannelNames(t *testing.T) {
|
||||
Params.Init()
|
||||
names := Params.ProxyTimeTickChannelNames
|
||||
assert.Equal(t, len(names), 2)
|
||||
assert.Equal(t, len(names), 1)
|
||||
assert.Equal(t, names[0], "proxyTimeTick-1")
|
||||
assert.Equal(t, names[1], "proxyTimeTick-2")
|
||||
}
|
||||
|
||||
func TestParamTable_MsgChannelSubName(t *testing.T) {
|
||||
@ -113,31 +111,27 @@ func TestParamTable_SoftTimeTickBarrierInterval(t *testing.T) {
|
||||
func TestParamTable_WriteNodeIDList(t *testing.T) {
|
||||
Params.Init()
|
||||
ids := Params.WriteNodeIDList
|
||||
assert.Equal(t, len(ids), 2)
|
||||
assert.Equal(t, ids[0], int64(5))
|
||||
assert.Equal(t, ids[1], int64(6))
|
||||
assert.Equal(t, len(ids), 1)
|
||||
assert.Equal(t, ids[0], int64(3))
|
||||
}
|
||||
|
||||
func TestParamTable_WriteNodeTimeTickChannelNames(t *testing.T) {
|
||||
Params.Init()
|
||||
names := Params.WriteNodeTimeTickChannelNames
|
||||
assert.Equal(t, len(names), 2)
|
||||
assert.Equal(t, names[0], "writeNodeTimeTick-5")
|
||||
assert.Equal(t, names[1], "writeNodeTimeTick-6")
|
||||
assert.Equal(t, len(names), 1)
|
||||
assert.Equal(t, names[0], "writeNodeTimeTick-3")
|
||||
}
|
||||
|
||||
func TestParamTable_InsertChannelNames(t *testing.T) {
|
||||
Params.Init()
|
||||
names := Params.InsertChannelNames
|
||||
assert.Equal(t, len(names), 2)
|
||||
assert.Equal(t, names[0], "insert-3")
|
||||
assert.Equal(t, names[1], "insert-4")
|
||||
assert.Equal(t, len(names), 1)
|
||||
assert.Equal(t, names[0], "insert-0")
|
||||
}
|
||||
|
||||
func TestParamTable_K2SChannelNames(t *testing.T) {
|
||||
Params.Init()
|
||||
names := Params.K2SChannelNames
|
||||
assert.Equal(t, len(names), 2)
|
||||
assert.Equal(t, names[0], "k2s-5")
|
||||
assert.Equal(t, names[1], "k2s-6")
|
||||
assert.Equal(t, len(names), 1)
|
||||
assert.Equal(t, names[0], "k2s-3")
|
||||
}
|
||||
|
@ -82,9 +82,8 @@ func (p *Proxy) CreateCollection(ctx context.Context, req *schemapb.CollectionSc
|
||||
Schema: &commonpb.Blob{},
|
||||
},
|
||||
masterClient: p.masterClient,
|
||||
schema: req,
|
||||
}
|
||||
schemaBytes, _ := proto.Marshal(req)
|
||||
cct.CreateCollectionRequest.Schema.Value = schemaBytes
|
||||
var cancel func()
|
||||
cct.ctx, cancel = context.WithTimeout(ctx, reqTimeoutInterval)
|
||||
defer cancel()
|
||||
@ -125,6 +124,7 @@ func (p *Proxy) Search(ctx context.Context, req *servicepb.Query) (*servicepb.Qu
|
||||
},
|
||||
queryMsgStream: p.queryMsgStream,
|
||||
resultBuf: make(chan []*internalpb.SearchResult),
|
||||
query: req,
|
||||
}
|
||||
var cancel func()
|
||||
qt.ctx, cancel = context.WithTimeout(ctx, reqTimeoutInterval)
|
||||
|
@ -322,3 +322,123 @@ func (pt *ParamTable) MsgStreamSearchResultPulsarBufSize() int64 {
|
||||
func (pt *ParamTable) MsgStreamTimeTickBufSize() int64 {
|
||||
return pt.parseInt64("proxy.msgStream.timeTick.bufSize")
|
||||
}
|
||||
|
||||
func (pt *ParamTable) insertChannelNames() []string {
|
||||
ch, err := pt.Load("msgChannel.chanNamePrefix.insert")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
channelRange, err := pt.Load("msgChannel.channelRange.insert")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
chanRange := strings.Split(channelRange, ",")
|
||||
if len(chanRange) != 2 {
|
||||
panic("Illegal channel range num")
|
||||
}
|
||||
channelBegin, err := strconv.Atoi(chanRange[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelEnd, err := strconv.Atoi(chanRange[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if channelBegin < 0 || channelEnd < 0 {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
if channelBegin > channelEnd {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
|
||||
channels := make([]string, channelEnd-channelBegin)
|
||||
for i := 0; i < channelEnd-channelBegin; i++ {
|
||||
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
|
||||
}
|
||||
return channels
|
||||
}
|
||||
|
||||
func (pt *ParamTable) searchChannelNames() []string {
|
||||
ch, err := pt.Load("msgChannel.chanNamePrefix.search")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
channelRange, err := pt.Load("msgChannel.channelRange.search")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
chanRange := strings.Split(channelRange, ",")
|
||||
if len(chanRange) != 2 {
|
||||
panic("Illegal channel range num")
|
||||
}
|
||||
channelBegin, err := strconv.Atoi(chanRange[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelEnd, err := strconv.Atoi(chanRange[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if channelBegin < 0 || channelEnd < 0 {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
if channelBegin > channelEnd {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
|
||||
channels := make([]string, channelEnd-channelBegin)
|
||||
for i := 0; i < channelEnd-channelBegin; i++ {
|
||||
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
|
||||
}
|
||||
return channels
|
||||
}
|
||||
|
||||
func (pt *ParamTable) searchResultChannelNames() []string {
|
||||
ch, err := pt.Load("msgChannel.chanNamePrefix.search")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
channelRange, err := pt.Load("msgChannel.channelRange.search")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
chanRange := strings.Split(channelRange, ",")
|
||||
if len(chanRange) != 2 {
|
||||
panic("Illegal channel range num")
|
||||
}
|
||||
channelBegin, err := strconv.Atoi(chanRange[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelEnd, err := strconv.Atoi(chanRange[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if channelBegin < 0 || channelEnd < 0 {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
if channelBegin > channelEnd {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
|
||||
channels := make([]string, channelEnd-channelBegin)
|
||||
for i := 0; i < channelEnd-channelBegin; i++ {
|
||||
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
|
||||
}
|
||||
return channels
|
||||
}
|
||||
|
||||
func (pt *ParamTable) MaxNameLength() int64 {
|
||||
str, err := pt.Load("proxy.maxNameLength")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
maxNameLength, err := strconv.ParseInt(str, 10, 64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return maxNameLength
|
||||
}
|
||||
|
@ -5,11 +5,13 @@ import (
|
||||
"errors"
|
||||
"log"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/zilliztech/milvus-distributed/internal/allocator"
|
||||
"github.com/zilliztech/milvus-distributed/internal/msgstream"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/masterpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
|
||||
)
|
||||
|
||||
@ -64,6 +66,15 @@ func (it *InsertTask) Type() internalpb.MsgType {
|
||||
}
|
||||
|
||||
func (it *InsertTask) PreExecute() error {
|
||||
collectionName := it.BaseInsertTask.CollectionName
|
||||
if err := ValidateCollectionName(collectionName); err != nil {
|
||||
return err
|
||||
}
|
||||
partitionTag := it.BaseInsertTask.PartitionTag
|
||||
if err := ValidatePartitionTag(partitionTag, true); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -120,6 +131,7 @@ type CreateCollectionTask struct {
|
||||
masterClient masterpb.MasterClient
|
||||
result *commonpb.Status
|
||||
ctx context.Context
|
||||
schema *schemapb.CollectionSchema
|
||||
}
|
||||
|
||||
func (cct *CreateCollectionTask) ID() UniqueID {
|
||||
@ -147,10 +159,24 @@ func (cct *CreateCollectionTask) SetTs(ts Timestamp) {
|
||||
}
|
||||
|
||||
func (cct *CreateCollectionTask) PreExecute() error {
|
||||
// validate collection name
|
||||
if err := ValidateCollectionName(cct.schema.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// validate field name
|
||||
for _, field := range cct.schema.Fields {
|
||||
if err := ValidateFieldName(field.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cct *CreateCollectionTask) Execute() error {
|
||||
schemaBytes, _ := proto.Marshal(cct.schema)
|
||||
cct.CreateCollectionRequest.Schema.Value = schemaBytes
|
||||
resp, err := cct.masterClient.CreateCollection(cct.ctx, &cct.CreateCollectionRequest)
|
||||
if err != nil {
|
||||
log.Printf("create collection failed, error= %v", err)
|
||||
@ -201,6 +227,9 @@ func (dct *DropCollectionTask) SetTs(ts Timestamp) {
|
||||
}
|
||||
|
||||
func (dct *DropCollectionTask) PreExecute() error {
|
||||
if err := ValidateCollectionName(dct.CollectionName.CollectionName); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -229,6 +258,7 @@ type QueryTask struct {
|
||||
resultBuf chan []*internalpb.SearchResult
|
||||
result *servicepb.QueryResult
|
||||
ctx context.Context
|
||||
query *servicepb.Query
|
||||
}
|
||||
|
||||
func (qt *QueryTask) ID() UniqueID {
|
||||
@ -256,6 +286,15 @@ func (qt *QueryTask) SetTs(ts Timestamp) {
|
||||
}
|
||||
|
||||
func (qt *QueryTask) PreExecute() error {
|
||||
if err := ValidateCollectionName(qt.query.CollectionName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, tag := range qt.query.PartitionTags {
|
||||
if err := ValidatePartitionTag(tag, false); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -367,6 +406,9 @@ func (hct *HasCollectionTask) SetTs(ts Timestamp) {
|
||||
}
|
||||
|
||||
func (hct *HasCollectionTask) PreExecute() error {
|
||||
if err := ValidateCollectionName(hct.CollectionName.CollectionName); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -424,6 +466,9 @@ func (dct *DescribeCollectionTask) SetTs(ts Timestamp) {
|
||||
}
|
||||
|
||||
func (dct *DescribeCollectionTask) PreExecute() error {
|
||||
if err := ValidateCollectionName(dct.CollectionName.CollectionName); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -532,6 +577,16 @@ func (cpt *CreatePartitionTask) SetTs(ts Timestamp) {
|
||||
}
|
||||
|
||||
func (cpt *CreatePartitionTask) PreExecute() error {
|
||||
collName, partitionTag := cpt.PartitionName.CollectionName, cpt.PartitionName.Tag
|
||||
|
||||
if err := ValidateCollectionName(collName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := ValidatePartitionTag(partitionTag, true); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -577,6 +632,16 @@ func (dpt *DropPartitionTask) SetTs(ts Timestamp) {
|
||||
}
|
||||
|
||||
func (dpt *DropPartitionTask) PreExecute() error {
|
||||
collName, partitionTag := dpt.PartitionName.CollectionName, dpt.PartitionName.Tag
|
||||
|
||||
if err := ValidateCollectionName(collName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := ValidatePartitionTag(partitionTag, true); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -622,6 +687,15 @@ func (hpt *HasPartitionTask) SetTs(ts Timestamp) {
|
||||
}
|
||||
|
||||
func (hpt *HasPartitionTask) PreExecute() error {
|
||||
collName, partitionTag := hpt.PartitionName.CollectionName, hpt.PartitionName.Tag
|
||||
|
||||
if err := ValidateCollectionName(collName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := ValidatePartitionTag(partitionTag, true); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -667,6 +741,15 @@ func (dpt *DescribePartitionTask) SetTs(ts Timestamp) {
|
||||
}
|
||||
|
||||
func (dpt *DescribePartitionTask) PreExecute() error {
|
||||
collName, partitionTag := dpt.PartitionName.CollectionName, dpt.PartitionName.Tag
|
||||
|
||||
if err := ValidateCollectionName(collName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := ValidatePartitionTag(partitionTag, true); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -712,6 +795,9 @@ func (spt *ShowPartitionsTask) SetTs(ts Timestamp) {
|
||||
}
|
||||
|
||||
func (spt *ShowPartitionsTask) PreExecute() error {
|
||||
if err := ValidateCollectionName(spt.CollectionName.CollectionName); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
118
internal/proxy/validate_util.go
Normal file
118
internal/proxy/validate_util.go
Normal file
@ -0,0 +1,118 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/errors"
|
||||
)
|
||||
|
||||
func isAlpha(c uint8) bool {
|
||||
if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isNumber(c uint8) bool {
|
||||
if c < '0' || c > '9' {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func ValidateCollectionName(collName string) error {
|
||||
collName = strings.TrimSpace(collName)
|
||||
|
||||
if collName == "" {
|
||||
return errors.New("Collection name should not be empty")
|
||||
}
|
||||
|
||||
invalidMsg := "Invalid collection name: " + collName + ". "
|
||||
if int64(len(collName)) > Params.MaxNameLength() {
|
||||
msg := invalidMsg + "The length of a collection name must be less than " +
|
||||
strconv.FormatInt(Params.MaxNameLength(), 10) + " characters."
|
||||
return errors.New(msg)
|
||||
}
|
||||
|
||||
firstChar := collName[0]
|
||||
if firstChar != '_' && !isAlpha(firstChar) {
|
||||
msg := invalidMsg + "The first character of a collection name must be an underscore or letter."
|
||||
return errors.New(msg)
|
||||
}
|
||||
|
||||
for i := 1; i < len(collName); i++ {
|
||||
c := collName[i]
|
||||
if c != '_' && c != '$' && !isAlpha(c) && !isNumber(c) {
|
||||
msg := invalidMsg + "Collection name can only contain numbers, letters, dollars and underscores."
|
||||
return errors.New(msg)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidatePartitionTag(partitionTag string, strictCheck bool) error {
|
||||
partitionTag = strings.TrimSpace(partitionTag)
|
||||
|
||||
invalidMsg := "Invalid partition tag: " + partitionTag + ". "
|
||||
if partitionTag == "" {
|
||||
msg := invalidMsg + "Partition tag should not be empty."
|
||||
return errors.New(msg)
|
||||
}
|
||||
|
||||
if int64(len(partitionTag)) > Params.MaxNameLength() {
|
||||
msg := invalidMsg + "The length of a partition tag must be less than " +
|
||||
strconv.FormatInt(Params.MaxNameLength(), 10) + " characters."
|
||||
return errors.New(msg)
|
||||
}
|
||||
|
||||
if strictCheck {
|
||||
firstChar := partitionTag[0]
|
||||
if firstChar != '_' && !isAlpha(firstChar) {
|
||||
msg := invalidMsg + "The first character of a partition tag must be an underscore or letter."
|
||||
return errors.New(msg)
|
||||
}
|
||||
|
||||
tagSize := len(partitionTag)
|
||||
for i := 1; i < tagSize; i++ {
|
||||
c := partitionTag[i]
|
||||
if c != '_' && c != '$' && !isAlpha(c) && !isNumber(c) {
|
||||
msg := invalidMsg + "Partition tag can only contain numbers, letters, dollars and underscores."
|
||||
return errors.New(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateFieldName(fieldName string) error {
|
||||
fieldName = strings.TrimSpace(fieldName)
|
||||
|
||||
if fieldName == "" {
|
||||
return errors.New("Field name should not be empty")
|
||||
}
|
||||
|
||||
invalidMsg := "Invalid field name: " + fieldName + ". "
|
||||
if int64(len(fieldName)) > Params.MaxNameLength() {
|
||||
msg := invalidMsg + "The length of a field name must be less than " +
|
||||
strconv.FormatInt(Params.MaxNameLength(), 10) + " characters."
|
||||
return errors.New(msg)
|
||||
}
|
||||
|
||||
firstChar := fieldName[0]
|
||||
if firstChar != '_' && !isAlpha(firstChar) {
|
||||
msg := invalidMsg + "The first character of a field name must be an underscore or letter."
|
||||
return errors.New(msg)
|
||||
}
|
||||
|
||||
fieldNameSize := len(fieldName)
|
||||
for i := 1; i < fieldNameSize; i++ {
|
||||
c := fieldName[i]
|
||||
if c != '_' && !isAlpha(c) && !isNumber(c) {
|
||||
msg := invalidMsg + "Field name cannot only contain numbers, letters, and underscores."
|
||||
return errors.New(msg)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
84
internal/proxy/validate_util_test.go
Normal file
84
internal/proxy/validate_util_test.go
Normal file
@ -0,0 +1,84 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestValidateCollectionName(t *testing.T) {
|
||||
Params.Init()
|
||||
assert.Nil(t, ValidateCollectionName("abc"))
|
||||
assert.Nil(t, ValidateCollectionName("_123abc"))
|
||||
assert.Nil(t, ValidateCollectionName("abc123_$"))
|
||||
|
||||
longName := make([]byte, 256)
|
||||
for i := 0; i < len(longName); i++ {
|
||||
longName[i] = 'a'
|
||||
}
|
||||
invalidNames := []string{
|
||||
"123abc",
|
||||
"$abc",
|
||||
"_12 ac",
|
||||
" ",
|
||||
"",
|
||||
string(longName),
|
||||
"中文",
|
||||
}
|
||||
|
||||
for _, name := range invalidNames {
|
||||
assert.NotNil(t, ValidateCollectionName(name))
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePartitionTag(t *testing.T) {
|
||||
Params.Init()
|
||||
assert.Nil(t, ValidatePartitionTag("abc", true))
|
||||
assert.Nil(t, ValidatePartitionTag("_123abc", true))
|
||||
assert.Nil(t, ValidatePartitionTag("abc123_$", true))
|
||||
|
||||
longName := make([]byte, 256)
|
||||
for i := 0; i < len(longName); i++ {
|
||||
longName[i] = 'a'
|
||||
}
|
||||
invalidNames := []string{
|
||||
"123abc",
|
||||
"$abc",
|
||||
"_12 ac",
|
||||
" ",
|
||||
"",
|
||||
string(longName),
|
||||
"中文",
|
||||
}
|
||||
|
||||
for _, name := range invalidNames {
|
||||
assert.NotNil(t, ValidatePartitionTag(name, true))
|
||||
}
|
||||
|
||||
assert.Nil(t, ValidatePartitionTag("ab cd", false))
|
||||
assert.Nil(t, ValidatePartitionTag("ab*", false))
|
||||
}
|
||||
|
||||
func TestValidateFieldName(t *testing.T) {
|
||||
Params.Init()
|
||||
assert.Nil(t, ValidateFieldName("abc"))
|
||||
assert.Nil(t, ValidateFieldName("_123abc"))
|
||||
|
||||
longName := make([]byte, 256)
|
||||
for i := 0; i < len(longName); i++ {
|
||||
longName[i] = 'a'
|
||||
}
|
||||
invalidNames := []string{
|
||||
"123abc",
|
||||
"$abc",
|
||||
"_12 ac",
|
||||
" ",
|
||||
"",
|
||||
string(longName),
|
||||
"中文",
|
||||
}
|
||||
|
||||
for _, name := range invalidNames {
|
||||
assert.NotNil(t, ValidateFieldName(name))
|
||||
}
|
||||
}
|
@ -17,7 +17,7 @@ func newDmInputNode(ctx context.Context) *flowgraph.InputNode {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
consumeChannels := []string{"insert"}
|
||||
consumeChannels := Params.insertChannelNames()
|
||||
consumeSubName := "insertSub"
|
||||
|
||||
insertStream := msgstream.NewPulsarTtMsgStream(ctx, receiveBufSize)
|
||||
|
@ -1,7 +1,9 @@
|
||||
package reader
|
||||
|
||||
import (
|
||||
"log"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/util/paramtable"
|
||||
)
|
||||
@ -18,6 +20,10 @@ func (p *ParamTable) Init() {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = p.LoadYaml("advanced/channel.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ParamTable) pulsarAddress() (string, error) {
|
||||
@ -193,3 +199,111 @@ func (p *ParamTable) etcdRootPath() string {
|
||||
}
|
||||
return etcdRootPath
|
||||
}
|
||||
|
||||
func (p *ParamTable) insertChannelNames() []string {
|
||||
ch, err := p.Load("msgChannel.chanNamePrefix.insert")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
channelRange, err := p.Load("msgChannel.channelRange.insert")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
chanRange := strings.Split(channelRange, ",")
|
||||
if len(chanRange) != 2 {
|
||||
panic("Illegal channel range num")
|
||||
}
|
||||
channelBegin, err := strconv.Atoi(chanRange[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelEnd, err := strconv.Atoi(chanRange[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if channelBegin < 0 || channelEnd < 0 {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
if channelBegin > channelEnd {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
|
||||
channels := make([]string, channelEnd-channelBegin)
|
||||
for i := 0; i < channelEnd-channelBegin; i++ {
|
||||
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
|
||||
}
|
||||
return channels
|
||||
}
|
||||
|
||||
func (p *ParamTable) searchChannelNames() []string {
|
||||
ch, err := p.Load("msgChannel.chanNamePrefix.search")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
channelRange, err := p.Load("msgChannel.channelRange.search")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
chanRange := strings.Split(channelRange, ",")
|
||||
if len(chanRange) != 2 {
|
||||
panic("Illegal channel range num")
|
||||
}
|
||||
channelBegin, err := strconv.Atoi(chanRange[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelEnd, err := strconv.Atoi(chanRange[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if channelBegin < 0 || channelEnd < 0 {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
if channelBegin > channelEnd {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
|
||||
channels := make([]string, channelEnd-channelBegin)
|
||||
for i := 0; i < channelEnd-channelBegin; i++ {
|
||||
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
|
||||
}
|
||||
return channels
|
||||
}
|
||||
|
||||
func (p *ParamTable) searchResultChannelNames() []string {
|
||||
ch, err := p.Load("msgChannel.chanNamePrefix.search")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
channelRange, err := p.Load("msgChannel.channelRange.search")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
chanRange := strings.Split(channelRange, ",")
|
||||
if len(chanRange) != 2 {
|
||||
panic("Illegal channel range num")
|
||||
}
|
||||
channelBegin, err := strconv.Atoi(chanRange[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelEnd, err := strconv.Atoi(chanRange[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if channelBegin < 0 || channelEnd < 0 {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
if channelBegin > channelEnd {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
|
||||
channels := make([]string, channelEnd-channelBegin)
|
||||
for i := 0; i < channelEnd-channelBegin; i++ {
|
||||
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
|
||||
}
|
||||
return channels
|
||||
}
|
||||
|
@ -6,7 +6,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const ctxTimeInMillisecond = 2000
|
||||
const ctxTimeInMillisecond = 200
|
||||
const closeWithDeadline = true
|
||||
|
||||
// NOTE: start pulsar and etcd before test
|
||||
|
@ -42,7 +42,7 @@ func newSearchService(ctx context.Context, replica *collectionReplica) *searchSe
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
consumeChannels := []string{"search"}
|
||||
consumeChannels := Params.searchChannelNames()
|
||||
consumeSubName := "subSearch"
|
||||
searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
|
||||
searchStream.SetPulsarClient(msgStreamURL)
|
||||
@ -50,7 +50,7 @@ func newSearchService(ctx context.Context, replica *collectionReplica) *searchSe
|
||||
searchStream.CreatePulsarConsumers(consumeChannels, consumeSubName, unmarshalDispatcher, pulsarBufSize)
|
||||
var inputStream msgstream.MsgStream = searchStream
|
||||
|
||||
producerChannels := []string{"searchResult"}
|
||||
producerChannels := Params.searchResultChannelNames()
|
||||
searchResultStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
|
||||
searchResultStream.SetPulsarClient(msgStreamURL)
|
||||
searchResultStream.CreatePulsarProducers(producerChannels)
|
||||
|
Loading…
Reference in New Issue
Block a user