Read query node num from param table instead of hardcode

Signed-off-by: dragondriver <jiquan.long@zilliz.com>
This commit is contained in:
dragondriver 2020-11-25 16:17:33 +08:00 committed by yefu.chen
parent 4bcb460c98
commit a7b3efecd7
19 changed files with 616 additions and 63 deletions

View File

@ -30,10 +30,10 @@ msgChannel:
queryNodeSubNamePrefix: "queryNode"
writeNodeSubNamePrefix: "writeNode"
# default channel range [0, 0]
# default channel range [0, 1)
channelRange:
insert: [0, 15]
delete: [0, 15]
k2s: [0, 15]
search: [0, 0]
insert: [0, 1]
delete: [0, 1]
k2s: [0, 1]
search: [0, 1]
searchResult: [0, 1]

View File

@ -25,4 +25,6 @@ proxy:
pulsarBufSize: 1024 # pulsar chan buffer size
timeTick:
bufSize: 512
bufSize: 512
maxNameLength: 255

View File

@ -11,9 +11,9 @@
nodeID: # will be deprecated after v0.2
proxyIDList: [1, 2]
queryNodeIDList: [3, 4]
writeNodeIDList: [5, 6]
proxyIDList: [1]
queryNodeIDList: [2]
writeNodeIDList: [3]
etcd:
address: localhost

View File

@ -65,7 +65,8 @@ StructuredIndexFlat<T>::NotIn(const size_t n, const T* values) {
if (!is_built_) {
build();
}
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data_.size(), true);
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data_.size());
bitset->set();
for (size_t i = 0; i < n; ++i) {
for (const auto& index : data_) {
if (index->a_ == *(values + i)) {

View File

@ -120,7 +120,8 @@ StructuredIndexSort<T>::NotIn(const size_t n, const T* values) {
if (!is_built_) {
build();
}
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data_.size(), true);
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data_.size());
bitset->set();
for (size_t i = 0; i < n; ++i) {
auto lb = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(*(values + i)));
auto ub = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(*(values + i)));

View File

@ -130,13 +130,7 @@ ExecExprVisitor::ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType {
}
case OpType::NotEqual: {
auto index_func = [val](Index* index) {
// Note: index->NotIn() is buggy, investigating
// this is a workaround
auto res = index->In(1, &val);
*res = ~std::move(*res);
return res;
};
auto index_func = [val](Index* index) { return index->NotIn(1, &val); };
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x != val); });
}

View File

@ -360,18 +360,33 @@ func (p *ParamTable) initInsertChannelNames() {
if err != nil {
log.Fatal(err)
}
id, err := p.Load("nodeID.queryNodeIDList")
channelRange, err := p.Load("msgChannel.channelRange.insert")
if err != nil {
log.Panicf("load query node id list error, %s", err.Error())
panic(err)
}
ids := strings.Split(id, ",")
channels := make([]string, 0, len(ids))
for _, i := range ids {
_, err := strconv.ParseInt(i, 10, 64)
if err != nil {
log.Panicf("load query node id list error, %s", err.Error())
}
channels = append(channels, ch+"-"+i)
chanRange := strings.Split(channelRange, ",")
if len(chanRange) != 2 {
panic("Illegal channel range num")
}
channelBegin, err := strconv.Atoi(chanRange[0])
if err != nil {
panic(err)
}
channelEnd, err := strconv.Atoi(chanRange[1])
if err != nil {
panic(err)
}
if channelBegin < 0 || channelEnd < 0 {
panic("Illegal channel range value")
}
if channelBegin > channelEnd {
panic("Illegal channel range value")
}
channels := make([]string, channelEnd-channelBegin)
for i := 0; i < channelEnd-channelBegin; i++ {
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
}
p.InsertChannelNames = channels
}

View File

@ -31,7 +31,7 @@ func TestParamTable_EtcdRootPath(t *testing.T) {
func TestParamTable_TopicNum(t *testing.T) {
Params.Init()
num := Params.TopicNum
assert.Equal(t, num, 15)
assert.Equal(t, num, 1)
}
func TestParamTable_SegmentSize(t *testing.T) {
@ -73,7 +73,7 @@ func TestParamTable_SegIDAssignExpiration(t *testing.T) {
func TestParamTable_QueryNodeNum(t *testing.T) {
Params.Init()
num := Params.QueryNodeNum
assert.Equal(t, num, 2)
assert.Equal(t, num, 1)
}
func TestParamTable_QueryNodeStatsChannelName(t *testing.T) {
@ -85,17 +85,15 @@ func TestParamTable_QueryNodeStatsChannelName(t *testing.T) {
func TestParamTable_ProxyIDList(t *testing.T) {
Params.Init()
ids := Params.ProxyIDList
assert.Equal(t, len(ids), 2)
assert.Equal(t, len(ids), 1)
assert.Equal(t, ids[0], int64(1))
assert.Equal(t, ids[1], int64(2))
}
func TestParamTable_ProxyTimeTickChannelNames(t *testing.T) {
Params.Init()
names := Params.ProxyTimeTickChannelNames
assert.Equal(t, len(names), 2)
assert.Equal(t, len(names), 1)
assert.Equal(t, names[0], "proxyTimeTick-1")
assert.Equal(t, names[1], "proxyTimeTick-2")
}
func TestParamTable_MsgChannelSubName(t *testing.T) {
@ -113,31 +111,27 @@ func TestParamTable_SoftTimeTickBarrierInterval(t *testing.T) {
func TestParamTable_WriteNodeIDList(t *testing.T) {
Params.Init()
ids := Params.WriteNodeIDList
assert.Equal(t, len(ids), 2)
assert.Equal(t, ids[0], int64(5))
assert.Equal(t, ids[1], int64(6))
assert.Equal(t, len(ids), 1)
assert.Equal(t, ids[0], int64(3))
}
func TestParamTable_WriteNodeTimeTickChannelNames(t *testing.T) {
Params.Init()
names := Params.WriteNodeTimeTickChannelNames
assert.Equal(t, len(names), 2)
assert.Equal(t, names[0], "writeNodeTimeTick-5")
assert.Equal(t, names[1], "writeNodeTimeTick-6")
assert.Equal(t, len(names), 1)
assert.Equal(t, names[0], "writeNodeTimeTick-3")
}
func TestParamTable_InsertChannelNames(t *testing.T) {
Params.Init()
names := Params.InsertChannelNames
assert.Equal(t, len(names), 2)
assert.Equal(t, names[0], "insert-3")
assert.Equal(t, names[1], "insert-4")
assert.Equal(t, len(names), 1)
assert.Equal(t, names[0], "insert-0")
}
func TestParamTable_K2SChannelNames(t *testing.T) {
Params.Init()
names := Params.K2SChannelNames
assert.Equal(t, len(names), 2)
assert.Equal(t, names[0], "k2s-5")
assert.Equal(t, names[1], "k2s-6")
assert.Equal(t, len(names), 1)
assert.Equal(t, names[0], "k2s-3")
}

View File

@ -82,9 +82,8 @@ func (p *Proxy) CreateCollection(ctx context.Context, req *schemapb.CollectionSc
Schema: &commonpb.Blob{},
},
masterClient: p.masterClient,
schema: req,
}
schemaBytes, _ := proto.Marshal(req)
cct.CreateCollectionRequest.Schema.Value = schemaBytes
var cancel func()
cct.ctx, cancel = context.WithTimeout(ctx, reqTimeoutInterval)
defer cancel()
@ -125,6 +124,7 @@ func (p *Proxy) Search(ctx context.Context, req *servicepb.Query) (*servicepb.Qu
},
queryMsgStream: p.queryMsgStream,
resultBuf: make(chan []*internalpb.SearchResult),
query: req,
}
var cancel func()
qt.ctx, cancel = context.WithTimeout(ctx, reqTimeoutInterval)

View File

@ -96,6 +96,27 @@ func (pt *ParamTable) ProxyIDList() []UniqueID {
return ret
}
func (pt *ParamTable) queryNodeNum() int {
return len(pt.queryNodeIDList())
}
func (pt *ParamTable) queryNodeIDList() []UniqueID {
queryNodeIDStr, err := pt.Load("nodeID.queryNodeIDList")
if err != nil {
panic(err)
}
var ret []UniqueID
queryNodeIDs := strings.Split(queryNodeIDStr, ",")
for _, i := range queryNodeIDs {
v, err := strconv.Atoi(i)
if err != nil {
log.Panicf("load proxy id list error, %s", err.Error())
}
ret = append(ret, UniqueID(v))
}
return ret
}
func (pt *ParamTable) ProxyID() UniqueID {
proxyID, err := pt.Load("_proxyID")
if err != nil {
@ -322,3 +343,123 @@ func (pt *ParamTable) MsgStreamSearchResultPulsarBufSize() int64 {
func (pt *ParamTable) MsgStreamTimeTickBufSize() int64 {
return pt.parseInt64("proxy.msgStream.timeTick.bufSize")
}
func (pt *ParamTable) insertChannelNames() []string {
ch, err := pt.Load("msgChannel.chanNamePrefix.insert")
if err != nil {
log.Fatal(err)
}
channelRange, err := pt.Load("msgChannel.channelRange.insert")
if err != nil {
panic(err)
}
chanRange := strings.Split(channelRange, ",")
if len(chanRange) != 2 {
panic("Illegal channel range num")
}
channelBegin, err := strconv.Atoi(chanRange[0])
if err != nil {
panic(err)
}
channelEnd, err := strconv.Atoi(chanRange[1])
if err != nil {
panic(err)
}
if channelBegin < 0 || channelEnd < 0 {
panic("Illegal channel range value")
}
if channelBegin > channelEnd {
panic("Illegal channel range value")
}
channels := make([]string, channelEnd-channelBegin)
for i := 0; i < channelEnd-channelBegin; i++ {
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
}
return channels
}
func (pt *ParamTable) searchChannelNames() []string {
ch, err := pt.Load("msgChannel.chanNamePrefix.search")
if err != nil {
log.Fatal(err)
}
channelRange, err := pt.Load("msgChannel.channelRange.search")
if err != nil {
panic(err)
}
chanRange := strings.Split(channelRange, ",")
if len(chanRange) != 2 {
panic("Illegal channel range num")
}
channelBegin, err := strconv.Atoi(chanRange[0])
if err != nil {
panic(err)
}
channelEnd, err := strconv.Atoi(chanRange[1])
if err != nil {
panic(err)
}
if channelBegin < 0 || channelEnd < 0 {
panic("Illegal channel range value")
}
if channelBegin > channelEnd {
panic("Illegal channel range value")
}
channels := make([]string, channelEnd-channelBegin)
for i := 0; i < channelEnd-channelBegin; i++ {
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
}
return channels
}
func (pt *ParamTable) searchResultChannelNames() []string {
ch, err := pt.Load("msgChannel.chanNamePrefix.searchResult")
if err != nil {
log.Fatal(err)
}
channelRange, err := pt.Load("msgChannel.channelRange.searchResult")
if err != nil {
panic(err)
}
chanRange := strings.Split(channelRange, ",")
if len(chanRange) != 2 {
panic("Illegal channel range num")
}
channelBegin, err := strconv.Atoi(chanRange[0])
if err != nil {
panic(err)
}
channelEnd, err := strconv.Atoi(chanRange[1])
if err != nil {
panic(err)
}
if channelBegin < 0 || channelEnd < 0 {
panic("Illegal channel range value")
}
if channelBegin > channelEnd {
panic("Illegal channel range value")
}
channels := make([]string, channelEnd-channelBegin)
for i := 0; i < channelEnd-channelBegin; i++ {
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
}
return channels
}
func (pt *ParamTable) MaxNameLength() int64 {
str, err := pt.Load("proxy.maxNameLength")
if err != nil {
panic(err)
}
maxNameLength, err := strconv.ParseInt(str, 10, 64)
if err != nil {
panic(err)
}
return maxNameLength
}

View File

@ -55,12 +55,11 @@ func CreateProxy(ctx context.Context) (*Proxy, error) {
proxyLoopCancel: cancel,
}
// TODO: use config instead
pulsarAddress := Params.PulsarAddress()
p.queryMsgStream = msgstream.NewPulsarMsgStream(p.proxyLoopCtx, Params.MsgStreamSearchBufSize())
p.queryMsgStream.SetPulsarClient(pulsarAddress)
p.queryMsgStream.CreatePulsarProducers(Params.SearchChannelNames())
p.queryMsgStream.CreatePulsarProducers(Params.searchChannelNames())
masterAddr := Params.MasterAddress()
idAllocator, err := allocator.NewIDAllocator(p.proxyLoopCtx, masterAddr)
@ -84,7 +83,7 @@ func CreateProxy(ctx context.Context) (*Proxy, error) {
p.manipulationMsgStream = msgstream.NewPulsarMsgStream(p.proxyLoopCtx, Params.MsgStreamInsertBufSize())
p.manipulationMsgStream.SetPulsarClient(pulsarAddress)
p.manipulationMsgStream.CreatePulsarProducers(Params.InsertChannelNames())
p.manipulationMsgStream.CreatePulsarProducers(Params.insertChannelNames())
repackFuncImpl := func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
return insertRepackFunc(tsMsgs, hashKeys, p.segAssigner, false)
}

View File

@ -5,11 +5,13 @@ import (
"errors"
"log"
"github.com/golang/protobuf/proto"
"github.com/zilliztech/milvus-distributed/internal/allocator"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
"github.com/zilliztech/milvus-distributed/internal/proto/masterpb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
)
@ -32,7 +34,6 @@ type BaseInsertTask = msgstream.InsertMsg
type InsertTask struct {
BaseInsertTask
Condition
ts Timestamp
result *servicepb.IntegerRangeResponse
manipulationMsgStream *msgstream.PulsarMsgStream
ctx context.Context
@ -44,15 +45,21 @@ func (it *InsertTask) SetID(uid UniqueID) {
}
func (it *InsertTask) SetTs(ts Timestamp) {
it.ts = ts
rowNum := len(it.RowData)
it.Timestamps = make([]uint64, rowNum)
for index := range it.Timestamps {
it.Timestamps[index] = ts
}
it.BeginTimestamp = ts
it.EndTimestamp = ts
}
func (it *InsertTask) BeginTs() Timestamp {
return it.ts
return it.BeginTimestamp
}
func (it *InsertTask) EndTs() Timestamp {
return it.ts
return it.EndTimestamp
}
func (it *InsertTask) ID() UniqueID {
@ -64,6 +71,15 @@ func (it *InsertTask) Type() internalpb.MsgType {
}
func (it *InsertTask) PreExecute() error {
collectionName := it.BaseInsertTask.CollectionName
if err := ValidateCollectionName(collectionName); err != nil {
return err
}
partitionTag := it.BaseInsertTask.PartitionTag
if err := ValidatePartitionTag(partitionTag, true); err != nil {
return err
}
return nil
}
@ -120,6 +136,7 @@ type CreateCollectionTask struct {
masterClient masterpb.MasterClient
result *commonpb.Status
ctx context.Context
schema *schemapb.CollectionSchema
}
func (cct *CreateCollectionTask) ID() UniqueID {
@ -147,10 +164,24 @@ func (cct *CreateCollectionTask) SetTs(ts Timestamp) {
}
func (cct *CreateCollectionTask) PreExecute() error {
// validate collection name
if err := ValidateCollectionName(cct.schema.Name); err != nil {
return err
}
// validate field name
for _, field := range cct.schema.Fields {
if err := ValidateFieldName(field.Name); err != nil {
return err
}
}
return nil
}
func (cct *CreateCollectionTask) Execute() error {
schemaBytes, _ := proto.Marshal(cct.schema)
cct.CreateCollectionRequest.Schema.Value = schemaBytes
resp, err := cct.masterClient.CreateCollection(cct.ctx, &cct.CreateCollectionRequest)
if err != nil {
log.Printf("create collection failed, error= %v", err)
@ -201,6 +232,9 @@ func (dct *DropCollectionTask) SetTs(ts Timestamp) {
}
func (dct *DropCollectionTask) PreExecute() error {
if err := ValidateCollectionName(dct.CollectionName.CollectionName); err != nil {
return err
}
return nil
}
@ -229,6 +263,7 @@ type QueryTask struct {
resultBuf chan []*internalpb.SearchResult
result *servicepb.QueryResult
ctx context.Context
query *servicepb.Query
}
func (qt *QueryTask) ID() UniqueID {
@ -256,6 +291,15 @@ func (qt *QueryTask) SetTs(ts Timestamp) {
}
func (qt *QueryTask) PreExecute() error {
if err := ValidateCollectionName(qt.query.CollectionName); err != nil {
return err
}
for _, tag := range qt.query.PartitionTags {
if err := ValidatePartitionTag(tag, false); err != nil {
return err
}
}
return nil
}
@ -367,6 +411,9 @@ func (hct *HasCollectionTask) SetTs(ts Timestamp) {
}
func (hct *HasCollectionTask) PreExecute() error {
if err := ValidateCollectionName(hct.CollectionName.CollectionName); err != nil {
return err
}
return nil
}
@ -424,6 +471,9 @@ func (dct *DescribeCollectionTask) SetTs(ts Timestamp) {
}
func (dct *DescribeCollectionTask) PreExecute() error {
if err := ValidateCollectionName(dct.CollectionName.CollectionName); err != nil {
return err
}
return nil
}
@ -532,6 +582,16 @@ func (cpt *CreatePartitionTask) SetTs(ts Timestamp) {
}
func (cpt *CreatePartitionTask) PreExecute() error {
collName, partitionTag := cpt.PartitionName.CollectionName, cpt.PartitionName.Tag
if err := ValidateCollectionName(collName); err != nil {
return err
}
if err := ValidatePartitionTag(partitionTag, true); err != nil {
return err
}
return nil
}
@ -577,6 +637,16 @@ func (dpt *DropPartitionTask) SetTs(ts Timestamp) {
}
func (dpt *DropPartitionTask) PreExecute() error {
collName, partitionTag := dpt.PartitionName.CollectionName, dpt.PartitionName.Tag
if err := ValidateCollectionName(collName); err != nil {
return err
}
if err := ValidatePartitionTag(partitionTag, true); err != nil {
return err
}
return nil
}
@ -622,6 +692,15 @@ func (hpt *HasPartitionTask) SetTs(ts Timestamp) {
}
func (hpt *HasPartitionTask) PreExecute() error {
collName, partitionTag := hpt.PartitionName.CollectionName, hpt.PartitionName.Tag
if err := ValidateCollectionName(collName); err != nil {
return err
}
if err := ValidatePartitionTag(partitionTag, true); err != nil {
return err
}
return nil
}
@ -667,6 +746,15 @@ func (dpt *DescribePartitionTask) SetTs(ts Timestamp) {
}
func (dpt *DescribePartitionTask) PreExecute() error {
collName, partitionTag := dpt.PartitionName.CollectionName, dpt.PartitionName.Tag
if err := ValidateCollectionName(collName); err != nil {
return err
}
if err := ValidatePartitionTag(partitionTag, true); err != nil {
return err
}
return nil
}
@ -712,6 +800,9 @@ func (spt *ShowPartitionsTask) SetTs(ts Timestamp) {
}
func (spt *ShowPartitionsTask) PreExecute() error {
if err := ValidateCollectionName(spt.CollectionName.CollectionName); err != nil {
return err
}
return nil
}

View File

@ -369,14 +369,14 @@ func (sched *TaskScheduler) queryLoop() {
func (sched *TaskScheduler) queryResultLoop() {
defer sched.wg.Done()
// TODO: use config instead
unmarshal := msgstream.NewUnmarshalDispatcher()
queryResultMsgStream := msgstream.NewPulsarMsgStream(sched.ctx, Params.MsgStreamSearchResultBufSize())
queryResultMsgStream.SetPulsarClient(Params.PulsarAddress())
queryResultMsgStream.CreatePulsarConsumers(Params.SearchResultChannelNames(),
queryResultMsgStream.CreatePulsarConsumers(Params.searchResultChannelNames(),
Params.ProxySubName(),
unmarshal,
Params.MsgStreamSearchResultPulsarBufSize())
queryNodeNum := Params.queryNodeNum()
queryResultMsgStream.Start()
defer queryResultMsgStream.Close()
@ -401,8 +401,7 @@ func (sched *TaskScheduler) queryResultLoop() {
queryResultBuf[reqID] = make([]*internalpb.SearchResult, 0)
}
queryResultBuf[reqID] = append(queryResultBuf[reqID], &searchResultMsg.SearchResult)
if len(queryResultBuf[reqID]) == 4 {
// TODO: use the number of query node instead
if len(queryResultBuf[reqID]) == queryNodeNum {
t := sched.getTaskByReqID(reqID)
if t != nil {
qt, ok := t.(*QueryTask)

View File

@ -0,0 +1,118 @@
package proxy
import (
"strconv"
"strings"
"github.com/zilliztech/milvus-distributed/internal/errors"
)
func isAlpha(c uint8) bool {
if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') {
return false
}
return true
}
func isNumber(c uint8) bool {
if c < '0' || c > '9' {
return false
}
return true
}
func ValidateCollectionName(collName string) error {
collName = strings.TrimSpace(collName)
if collName == "" {
return errors.New("Collection name should not be empty")
}
invalidMsg := "Invalid collection name: " + collName + ". "
if int64(len(collName)) > Params.MaxNameLength() {
msg := invalidMsg + "The length of a collection name must be less than " +
strconv.FormatInt(Params.MaxNameLength(), 10) + " characters."
return errors.New(msg)
}
firstChar := collName[0]
if firstChar != '_' && !isAlpha(firstChar) {
msg := invalidMsg + "The first character of a collection name must be an underscore or letter."
return errors.New(msg)
}
for i := 1; i < len(collName); i++ {
c := collName[i]
if c != '_' && c != '$' && !isAlpha(c) && !isNumber(c) {
msg := invalidMsg + "Collection name can only contain numbers, letters, dollars and underscores."
return errors.New(msg)
}
}
return nil
}
func ValidatePartitionTag(partitionTag string, strictCheck bool) error {
partitionTag = strings.TrimSpace(partitionTag)
invalidMsg := "Invalid partition tag: " + partitionTag + ". "
if partitionTag == "" {
msg := invalidMsg + "Partition tag should not be empty."
return errors.New(msg)
}
if int64(len(partitionTag)) > Params.MaxNameLength() {
msg := invalidMsg + "The length of a partition tag must be less than " +
strconv.FormatInt(Params.MaxNameLength(), 10) + " characters."
return errors.New(msg)
}
if strictCheck {
firstChar := partitionTag[0]
if firstChar != '_' && !isAlpha(firstChar) {
msg := invalidMsg + "The first character of a partition tag must be an underscore or letter."
return errors.New(msg)
}
tagSize := len(partitionTag)
for i := 1; i < tagSize; i++ {
c := partitionTag[i]
if c != '_' && c != '$' && !isAlpha(c) && !isNumber(c) {
msg := invalidMsg + "Partition tag can only contain numbers, letters, dollars and underscores."
return errors.New(msg)
}
}
}
return nil
}
func ValidateFieldName(fieldName string) error {
fieldName = strings.TrimSpace(fieldName)
if fieldName == "" {
return errors.New("Field name should not be empty")
}
invalidMsg := "Invalid field name: " + fieldName + ". "
if int64(len(fieldName)) > Params.MaxNameLength() {
msg := invalidMsg + "The length of a field name must be less than " +
strconv.FormatInt(Params.MaxNameLength(), 10) + " characters."
return errors.New(msg)
}
firstChar := fieldName[0]
if firstChar != '_' && !isAlpha(firstChar) {
msg := invalidMsg + "The first character of a field name must be an underscore or letter."
return errors.New(msg)
}
fieldNameSize := len(fieldName)
for i := 1; i < fieldNameSize; i++ {
c := fieldName[i]
if c != '_' && !isAlpha(c) && !isNumber(c) {
msg := invalidMsg + "Field name cannot only contain numbers, letters, and underscores."
return errors.New(msg)
}
}
return nil
}

View File

@ -0,0 +1,84 @@
package proxy
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestValidateCollectionName(t *testing.T) {
Params.Init()
assert.Nil(t, ValidateCollectionName("abc"))
assert.Nil(t, ValidateCollectionName("_123abc"))
assert.Nil(t, ValidateCollectionName("abc123_$"))
longName := make([]byte, 256)
for i := 0; i < len(longName); i++ {
longName[i] = 'a'
}
invalidNames := []string{
"123abc",
"$abc",
"_12 ac",
" ",
"",
string(longName),
"中文",
}
for _, name := range invalidNames {
assert.NotNil(t, ValidateCollectionName(name))
}
}
func TestValidatePartitionTag(t *testing.T) {
Params.Init()
assert.Nil(t, ValidatePartitionTag("abc", true))
assert.Nil(t, ValidatePartitionTag("_123abc", true))
assert.Nil(t, ValidatePartitionTag("abc123_$", true))
longName := make([]byte, 256)
for i := 0; i < len(longName); i++ {
longName[i] = 'a'
}
invalidNames := []string{
"123abc",
"$abc",
"_12 ac",
" ",
"",
string(longName),
"中文",
}
for _, name := range invalidNames {
assert.NotNil(t, ValidatePartitionTag(name, true))
}
assert.Nil(t, ValidatePartitionTag("ab cd", false))
assert.Nil(t, ValidatePartitionTag("ab*", false))
}
func TestValidateFieldName(t *testing.T) {
Params.Init()
assert.Nil(t, ValidateFieldName("abc"))
assert.Nil(t, ValidateFieldName("_123abc"))
longName := make([]byte, 256)
for i := 0; i < len(longName); i++ {
longName[i] = 'a'
}
invalidNames := []string{
"123abc",
"$abc",
"_12 ac",
" ",
"",
string(longName),
"中文",
}
for _, name := range invalidNames {
assert.NotNil(t, ValidateFieldName(name))
}
}

View File

@ -17,7 +17,7 @@ func newDmInputNode(ctx context.Context) *flowgraph.InputNode {
log.Fatal(err)
}
consumeChannels := []string{"insert"}
consumeChannels := Params.insertChannelNames()
consumeSubName := "insertSub"
insertStream := msgstream.NewPulsarTtMsgStream(ctx, receiveBufSize)

View File

@ -1,7 +1,9 @@
package reader
import (
"log"
"strconv"
"strings"
"github.com/zilliztech/milvus-distributed/internal/util/paramtable"
)
@ -18,6 +20,10 @@ func (p *ParamTable) Init() {
if err != nil {
panic(err)
}
err = p.LoadYaml("advanced/channel.yaml")
if err != nil {
panic(err)
}
}
func (p *ParamTable) pulsarAddress() (string, error) {
@ -193,3 +199,111 @@ func (p *ParamTable) etcdRootPath() string {
}
return etcdRootPath
}
func (p *ParamTable) insertChannelNames() []string {
ch, err := p.Load("msgChannel.chanNamePrefix.insert")
if err != nil {
log.Fatal(err)
}
channelRange, err := p.Load("msgChannel.channelRange.insert")
if err != nil {
panic(err)
}
chanRange := strings.Split(channelRange, ",")
if len(chanRange) != 2 {
panic("Illegal channel range num")
}
channelBegin, err := strconv.Atoi(chanRange[0])
if err != nil {
panic(err)
}
channelEnd, err := strconv.Atoi(chanRange[1])
if err != nil {
panic(err)
}
if channelBegin < 0 || channelEnd < 0 {
panic("Illegal channel range value")
}
if channelBegin > channelEnd {
panic("Illegal channel range value")
}
channels := make([]string, channelEnd-channelBegin)
for i := 0; i < channelEnd-channelBegin; i++ {
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
}
return channels
}
func (p *ParamTable) searchChannelNames() []string {
ch, err := p.Load("msgChannel.chanNamePrefix.search")
if err != nil {
log.Fatal(err)
}
channelRange, err := p.Load("msgChannel.channelRange.search")
if err != nil {
panic(err)
}
chanRange := strings.Split(channelRange, ",")
if len(chanRange) != 2 {
panic("Illegal channel range num")
}
channelBegin, err := strconv.Atoi(chanRange[0])
if err != nil {
panic(err)
}
channelEnd, err := strconv.Atoi(chanRange[1])
if err != nil {
panic(err)
}
if channelBegin < 0 || channelEnd < 0 {
panic("Illegal channel range value")
}
if channelBegin > channelEnd {
panic("Illegal channel range value")
}
channels := make([]string, channelEnd-channelBegin)
for i := 0; i < channelEnd-channelBegin; i++ {
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
}
return channels
}
func (p *ParamTable) searchResultChannelNames() []string {
ch, err := p.Load("msgChannel.chanNamePrefix.searchResult")
if err != nil {
log.Fatal(err)
}
channelRange, err := p.Load("msgChannel.channelRange.searchResult")
if err != nil {
panic(err)
}
chanRange := strings.Split(channelRange, ",")
if len(chanRange) != 2 {
panic("Illegal channel range num")
}
channelBegin, err := strconv.Atoi(chanRange[0])
if err != nil {
panic(err)
}
channelEnd, err := strconv.Atoi(chanRange[1])
if err != nil {
panic(err)
}
if channelBegin < 0 || channelEnd < 0 {
panic("Illegal channel range value")
}
if channelBegin > channelEnd {
panic("Illegal channel range value")
}
channels := make([]string, channelEnd-channelBegin)
for i := 0; i < channelEnd-channelBegin; i++ {
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
}
return channels
}

View File

@ -6,7 +6,7 @@ import (
"time"
)
const ctxTimeInMillisecond = 2000
const ctxTimeInMillisecond = 200
const closeWithDeadline = true
// NOTE: start pulsar and etcd before test

View File

@ -42,7 +42,7 @@ func newSearchService(ctx context.Context, replica *collectionReplica) *searchSe
log.Fatal(err)
}
consumeChannels := []string{"search"}
consumeChannels := Params.searchChannelNames()
consumeSubName := "subSearch"
searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
searchStream.SetPulsarClient(msgStreamURL)
@ -50,7 +50,7 @@ func newSearchService(ctx context.Context, replica *collectionReplica) *searchSe
searchStream.CreatePulsarConsumers(consumeChannels, consumeSubName, unmarshalDispatcher, pulsarBufSize)
var inputStream msgstream.MsgStream = searchStream
producerChannels := []string{"searchResult"}
producerChannels := Params.searchResultChannelNames()
searchResultStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
searchResultStream.SetPulsarClient(msgStreamURL)
searchResultStream.CreatePulsarProducers(producerChannels)