mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-11-30 10:59:32 +08:00
Add unittest cases for proxy (#7364)
Signed-off-by: dragondriver <jiquan.long@zilliz.com>
This commit is contained in:
parent
e6de86a433
commit
84bddd0a03
@ -138,8 +138,11 @@ func (s *Server) init() error {
|
||||
|
||||
proxy.Params.Init()
|
||||
log.Debug("init params done ...")
|
||||
|
||||
// NetworkPort & IP don't matter here, NetworkAddress matters
|
||||
proxy.Params.NetworkPort = Params.Port
|
||||
proxy.Params.IP = Params.IP
|
||||
|
||||
proxy.Params.NetworkAddress = Params.Address
|
||||
// for purpose of ID Allocator
|
||||
proxy.Params.RootCoordAddress = Params.RootCoordAddress
|
||||
|
@ -24,9 +24,6 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type vChan = string
|
||||
type pChan = string
|
||||
|
||||
type channelsMgr interface {
|
||||
getChannels(collectionID UniqueID) ([]pChan, error)
|
||||
getVChannels(collectionID UniqueID) ([]vChan, error)
|
||||
@ -179,7 +176,12 @@ func (mgr *singleTypeChannelsMgr) getAllVIDs(collectionID UniqueID) ([]int, erro
|
||||
mgr.collMtx.RLock()
|
||||
defer mgr.collMtx.RUnlock()
|
||||
|
||||
return mgr.collectionID2VIDs[collectionID], nil
|
||||
ids, exist := mgr.collectionID2VIDs[collectionID]
|
||||
if !exist {
|
||||
return nil, fmt.Errorf("collection %d not found", collectionID)
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (mgr *singleTypeChannelsMgr) getVChansByVID(vid int) ([]vChan, error) {
|
||||
@ -339,10 +341,15 @@ func (mgr *singleTypeChannelsMgr) getVChannels(collectionID UniqueID) ([]vChan,
|
||||
|
||||
func (mgr *singleTypeChannelsMgr) createMsgStream(collectionID UniqueID) error {
|
||||
channels, err := mgr.getChannelsFunc(collectionID)
|
||||
log.Debug("singleTypeChannelsMgr", zap.Any("createMsgStream.getChannels", channels))
|
||||
if err != nil {
|
||||
log.Warn("failed to create message stream",
|
||||
zap.Int64("collection_id", collectionID),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
log.Debug("singleTypeChannelsMgr",
|
||||
zap.Int64("collection_id", collectionID),
|
||||
zap.Any("createMsgStream.getChannels", channels))
|
||||
|
||||
mgr.updateChannels(channels)
|
||||
|
||||
@ -480,13 +487,13 @@ func (mgr *channelsMgrImpl) removeAllDMLStream() error {
|
||||
return mgr.dmlChannelsMgr.removeAllStream()
|
||||
}
|
||||
|
||||
func newChannelsMgr(
|
||||
func newChannelsMgrImpl(
|
||||
getDmlChannelsFunc getChannelsFuncType,
|
||||
dmlRepackFunc repackFuncType,
|
||||
getDqlChannelsFunc getChannelsFuncType,
|
||||
dqlRepackFunc repackFuncType,
|
||||
msgStreamFactory msgstream.Factory,
|
||||
) channelsMgr {
|
||||
) *channelsMgrImpl {
|
||||
return &channelsMgrImpl{
|
||||
dmlChannelsMgr: newSingleTypeChannelsMgr(getDmlChannelsFunc, msgStreamFactory, dmlRepackFunc, dmlStreamType),
|
||||
dqlChannelsMgr: newSingleTypeChannelsMgr(getDqlChannelsFunc, msgStreamFactory, dqlRepackFunc, dqlStreamType),
|
||||
|
@ -11,7 +11,6 @@
|
||||
|
||||
package proxy
|
||||
|
||||
/*
|
||||
import (
|
||||
"testing"
|
||||
|
||||
@ -38,7 +37,7 @@ func TestChannelsMgrImpl_getChannels(t *testing.T) {
|
||||
master := newMockGetChannelsService()
|
||||
query := newMockGetChannelsService()
|
||||
factory := msgstream.NewSimpleMsgStreamFactory()
|
||||
mgr := newChannelsMgr(master.GetChannels, nil, query.GetChannels, nil, factory)
|
||||
mgr := newChannelsMgrImpl(master.GetChannels, nil, query.GetChannels, nil, factory)
|
||||
defer mgr.removeAllDMLStream()
|
||||
|
||||
collID := UniqueID(getUniqueIntGeneratorIns().get())
|
||||
@ -56,7 +55,7 @@ func TestChannelsMgrImpl_getVChannels(t *testing.T) {
|
||||
master := newMockGetChannelsService()
|
||||
query := newMockGetChannelsService()
|
||||
factory := msgstream.NewSimpleMsgStreamFactory()
|
||||
mgr := newChannelsMgr(master.GetChannels, nil, query.GetChannels, nil, factory)
|
||||
mgr := newChannelsMgrImpl(master.GetChannels, nil, query.GetChannels, nil, factory)
|
||||
defer mgr.removeAllDMLStream()
|
||||
|
||||
collID := UniqueID(getUniqueIntGeneratorIns().get())
|
||||
@ -74,7 +73,7 @@ func TestChannelsMgrImpl_createDMLMsgStream(t *testing.T) {
|
||||
master := newMockGetChannelsService()
|
||||
query := newMockGetChannelsService()
|
||||
factory := msgstream.NewSimpleMsgStreamFactory()
|
||||
mgr := newChannelsMgr(master.GetChannels, nil, query.GetChannels, nil, factory)
|
||||
mgr := newChannelsMgrImpl(master.GetChannels, nil, query.GetChannels, nil, factory)
|
||||
defer mgr.removeAllDMLStream()
|
||||
|
||||
collID := UniqueID(getUniqueIntGeneratorIns().get())
|
||||
@ -96,7 +95,7 @@ func TestChannelsMgrImpl_getDMLMsgStream(t *testing.T) {
|
||||
master := newMockGetChannelsService()
|
||||
query := newMockGetChannelsService()
|
||||
factory := msgstream.NewSimpleMsgStreamFactory()
|
||||
mgr := newChannelsMgr(master.GetChannels, nil, query.GetChannels, nil, factory)
|
||||
mgr := newChannelsMgrImpl(master.GetChannels, nil, query.GetChannels, nil, factory)
|
||||
defer mgr.removeAllDMLStream()
|
||||
|
||||
collID := UniqueID(getUniqueIntGeneratorIns().get())
|
||||
@ -114,7 +113,7 @@ func TestChannelsMgrImpl_removeDMLMsgStream(t *testing.T) {
|
||||
master := newMockGetChannelsService()
|
||||
query := newMockGetChannelsService()
|
||||
factory := msgstream.NewSimpleMsgStreamFactory()
|
||||
mgr := newChannelsMgr(master.GetChannels, nil, query.GetChannels, nil, factory)
|
||||
mgr := newChannelsMgrImpl(master.GetChannels, nil, query.GetChannels, nil, factory)
|
||||
defer mgr.removeAllDMLStream()
|
||||
|
||||
collID := UniqueID(getUniqueIntGeneratorIns().get())
|
||||
@ -141,7 +140,7 @@ func TestChannelsMgrImpl_removeAllDMLMsgStream(t *testing.T) {
|
||||
master := newMockGetChannelsService()
|
||||
query := newMockGetChannelsService()
|
||||
factory := msgstream.NewSimpleMsgStreamFactory()
|
||||
mgr := newChannelsMgr(master.GetChannels, nil, query.GetChannels, nil, factory)
|
||||
mgr := newChannelsMgrImpl(master.GetChannels, nil, query.GetChannels, nil, factory)
|
||||
defer mgr.removeAllDMLStream()
|
||||
|
||||
num := 10
|
||||
@ -156,21 +155,12 @@ func TestChannelsMgrImpl_createDQLMsgStream(t *testing.T) {
|
||||
master := newMockGetChannelsService()
|
||||
query := newMockGetChannelsService()
|
||||
factory := msgstream.NewSimpleMsgStreamFactory()
|
||||
mgr := newChannelsMgr(master.GetChannels, nil, query.GetChannels, nil, factory)
|
||||
mgr := newChannelsMgrImpl(master.GetChannels, nil, query.GetChannels, nil, factory)
|
||||
defer mgr.removeAllDMLStream()
|
||||
|
||||
collID := UniqueID(getUniqueIntGeneratorIns().get())
|
||||
_, err := mgr.getChannels(collID)
|
||||
assert.NotEqual(t, nil, err)
|
||||
_, err = mgr.getVChannels(collID)
|
||||
assert.NotEqual(t, nil, err)
|
||||
|
||||
err = mgr.createDQLStream(collID)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
_, err = mgr.getChannels(collID)
|
||||
assert.Equal(t, nil, err)
|
||||
_, err = mgr.getVChannels(collID)
|
||||
err := mgr.createDQLStream(collID)
|
||||
assert.Equal(t, nil, err)
|
||||
}
|
||||
|
||||
@ -178,7 +168,7 @@ func TestChannelsMgrImpl_getDQLMsgStream(t *testing.T) {
|
||||
master := newMockGetChannelsService()
|
||||
query := newMockGetChannelsService()
|
||||
factory := msgstream.NewSimpleMsgStreamFactory()
|
||||
mgr := newChannelsMgr(master.GetChannels, nil, query.GetChannels, nil, factory)
|
||||
mgr := newChannelsMgrImpl(master.GetChannels, nil, query.GetChannels, nil, factory)
|
||||
defer mgr.removeAllDMLStream()
|
||||
|
||||
collID := UniqueID(getUniqueIntGeneratorIns().get())
|
||||
@ -196,7 +186,7 @@ func TestChannelsMgrImpl_removeDQLMsgStream(t *testing.T) {
|
||||
master := newMockGetChannelsService()
|
||||
query := newMockGetChannelsService()
|
||||
factory := msgstream.NewSimpleMsgStreamFactory()
|
||||
mgr := newChannelsMgr(master.GetChannels, nil, query.GetChannels, nil, factory)
|
||||
mgr := newChannelsMgrImpl(master.GetChannels, nil, query.GetChannels, nil, factory)
|
||||
defer mgr.removeAllDMLStream()
|
||||
|
||||
collID := UniqueID(getUniqueIntGeneratorIns().get())
|
||||
@ -223,7 +213,7 @@ func TestChannelsMgrImpl_removeAllDQLMsgStream(t *testing.T) {
|
||||
master := newMockGetChannelsService()
|
||||
query := newMockGetChannelsService()
|
||||
factory := msgstream.NewSimpleMsgStreamFactory()
|
||||
mgr := newChannelsMgr(master.GetChannels, nil, query.GetChannels, nil, factory)
|
||||
mgr := newChannelsMgrImpl(master.GetChannels, nil, query.GetChannels, nil, factory)
|
||||
defer mgr.removeAllDMLStream()
|
||||
|
||||
num := 10
|
||||
@ -233,4 +223,3 @@ func TestChannelsMgrImpl_removeAllDQLMsgStream(t *testing.T) {
|
||||
assert.Equal(t, nil, err)
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
@ -21,11 +21,6 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type pChanStatistics struct {
|
||||
minTs Timestamp
|
||||
maxTs Timestamp
|
||||
}
|
||||
|
||||
// ticker can update ts only when the minTs greater than the ts of ticker, we can use maxTs to update current later
|
||||
type getPChanStatisticsFuncType func() (map[pChan]*pChanStatistics, error)
|
||||
|
||||
|
@ -11,6 +11,181 @@
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_parseDummyRequestType(t *testing.T) {
|
||||
var err error
|
||||
|
||||
// not in json format
|
||||
notInJSONFormatStr := "not in json format string"
|
||||
_, err = parseDummyRequestType(notInJSONFormatStr)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// only contain other field, in json format
|
||||
otherField := "other_field"
|
||||
otherFieldValue := "not important"
|
||||
m1 := make(map[string]string)
|
||||
m1[otherField] = otherFieldValue
|
||||
bs1, err := json.Marshal(m1)
|
||||
assert.Nil(t, err)
|
||||
log.Info("Test_parseDummyRequestType",
|
||||
zap.String("json", string(bs1)))
|
||||
ret1, err := parseDummyRequestType(string(bs1))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 0, len(ret1.RequestType))
|
||||
|
||||
// normal case
|
||||
key := "request_type"
|
||||
value := "value"
|
||||
m2 := make(map[string]string)
|
||||
m2[key] = value
|
||||
bs2, err := json.Marshal(m2)
|
||||
assert.Nil(t, err)
|
||||
log.Info("Test_parseDummyRequestType",
|
||||
zap.String("json", string(bs2)))
|
||||
ret2, err := parseDummyRequestType(string(bs2))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, value, ret2.RequestType)
|
||||
|
||||
// contain other field and request_type
|
||||
m3 := make(map[string]string)
|
||||
m3[key] = value
|
||||
m3[otherField] = otherFieldValue
|
||||
bs3, err := json.Marshal(m3)
|
||||
assert.Nil(t, err)
|
||||
log.Info("Test_parseDummyRequestType",
|
||||
zap.String("json", string(bs3)))
|
||||
ret3, err := parseDummyRequestType(string(bs3))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, value, ret3.RequestType)
|
||||
}
|
||||
|
||||
func Test_parseDummyQueryRequest(t *testing.T) {
|
||||
var err error
|
||||
|
||||
// not in json format
|
||||
notInJSONFormatStr := "not in json format string"
|
||||
_, err = parseDummyQueryRequest(notInJSONFormatStr)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// only contain other field, in json format
|
||||
otherField := "other_field"
|
||||
otherFieldValue := "not important"
|
||||
m1 := make(map[string]interface{})
|
||||
m1[otherField] = otherFieldValue
|
||||
bs1, err := json.Marshal(m1)
|
||||
log.Info("Test_parseDummyQueryRequest",
|
||||
zap.String("json", string(bs1)))
|
||||
assert.Nil(t, err)
|
||||
ret1, err := parseDummyQueryRequest(string(bs1))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 0, len(ret1.RequestType))
|
||||
assert.Equal(t, 0, len(ret1.DbName))
|
||||
assert.Equal(t, 0, len(ret1.CollectionName))
|
||||
assert.Equal(t, 0, len(ret1.PartitionNames))
|
||||
assert.Equal(t, 0, len(ret1.Expr))
|
||||
assert.Equal(t, 0, len(ret1.OutputFields))
|
||||
|
||||
requestTypeKey := "request_type"
|
||||
requestTypeValue := "request_type"
|
||||
dbNameKey := "dbname"
|
||||
dbNameValue := "dbname"
|
||||
collectionNameKey := "collection_name"
|
||||
collectionNameValue := "collection_name"
|
||||
partitionNamesKey := "partition_names"
|
||||
partitionNamesValue := []string{"partition_names"}
|
||||
exprKey := "expr"
|
||||
exprValue := "expr"
|
||||
outputFieldsKey := "output_fields"
|
||||
outputFieldsValue := []string{"output_fields"}
|
||||
|
||||
// all fields
|
||||
m2 := make(map[string]interface{})
|
||||
m2[requestTypeKey] = requestTypeValue
|
||||
m2[dbNameKey] = dbNameValue
|
||||
m2[collectionNameKey] = collectionNameValue
|
||||
m2[partitionNamesKey] = partitionNamesValue
|
||||
m2[exprKey] = exprValue
|
||||
m2[outputFieldsKey] = outputFieldsValue
|
||||
bs2, err := json.Marshal(m2)
|
||||
log.Info("Test_parseDummyQueryRequest",
|
||||
zap.String("json", string(bs2)))
|
||||
assert.Nil(t, err)
|
||||
ret2, err := parseDummyQueryRequest(string(bs2))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, requestTypeValue, ret2.RequestType)
|
||||
assert.Equal(t, dbNameValue, ret2.DbName)
|
||||
assert.Equal(t, collectionNameValue, ret2.CollectionName)
|
||||
assert.Equal(t, partitionNamesValue, ret2.PartitionNames)
|
||||
assert.Equal(t, exprValue, ret2.Expr)
|
||||
assert.Equal(t, outputFieldsValue, ret2.OutputFields)
|
||||
|
||||
// all fields and other field
|
||||
m3 := make(map[string]interface{})
|
||||
m3[otherField] = otherFieldValue
|
||||
m3[requestTypeKey] = requestTypeValue
|
||||
m3[dbNameKey] = dbNameValue
|
||||
m3[collectionNameKey] = collectionNameValue
|
||||
m3[partitionNamesKey] = partitionNamesValue
|
||||
m3[exprKey] = exprValue
|
||||
m3[outputFieldsKey] = outputFieldsValue
|
||||
bs3, err := json.Marshal(m3)
|
||||
log.Info("Test_parseDummyQueryRequest",
|
||||
zap.String("json", string(bs3)))
|
||||
assert.Nil(t, err)
|
||||
ret3, err := parseDummyQueryRequest(string(bs3))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, requestTypeValue, ret3.RequestType)
|
||||
assert.Equal(t, dbNameValue, ret3.DbName)
|
||||
assert.Equal(t, collectionNameValue, ret3.CollectionName)
|
||||
assert.Equal(t, partitionNamesValue, ret3.PartitionNames)
|
||||
assert.Equal(t, exprValue, ret3.Expr)
|
||||
assert.Equal(t, outputFieldsValue, ret3.OutputFields)
|
||||
|
||||
// partial fields
|
||||
m4 := make(map[string]interface{})
|
||||
m4[requestTypeKey] = requestTypeValue
|
||||
m4[dbNameKey] = dbNameValue
|
||||
bs4, err := json.Marshal(m4)
|
||||
log.Info("Test_parseDummyQueryRequest",
|
||||
zap.String("json", string(bs4)))
|
||||
assert.Nil(t, err)
|
||||
ret4, err := parseDummyQueryRequest(string(bs4))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, requestTypeValue, ret4.RequestType)
|
||||
assert.Equal(t, dbNameValue, ret4.DbName)
|
||||
assert.Equal(t, collectionNameValue, ret2.CollectionName)
|
||||
assert.Equal(t, partitionNamesValue, ret2.PartitionNames)
|
||||
assert.Equal(t, exprValue, ret2.Expr)
|
||||
assert.Equal(t, outputFieldsValue, ret2.OutputFields)
|
||||
|
||||
// partial fields and other field
|
||||
m5 := make(map[string]interface{})
|
||||
m5[otherField] = otherFieldValue
|
||||
m5[requestTypeKey] = requestTypeValue
|
||||
m5[dbNameKey] = dbNameValue
|
||||
bs5, err := json.Marshal(m5)
|
||||
log.Info("Test_parseDummyQueryRequest",
|
||||
zap.String("json", string(bs5)))
|
||||
assert.Nil(t, err)
|
||||
ret5, err := parseDummyQueryRequest(string(bs5))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, requestTypeValue, ret5.RequestType)
|
||||
assert.Equal(t, dbNameValue, ret5.DbName)
|
||||
assert.Equal(t, collectionNameValue, ret2.CollectionName)
|
||||
assert.Equal(t, partitionNamesValue, ret2.PartitionNames)
|
||||
assert.Equal(t, exprValue, ret2.Expr)
|
||||
assert.Equal(t, outputFieldsValue, ret2.OutputFields)
|
||||
}
|
||||
|
||||
// func TestParseDummyQueryRequest(t *testing.T) {
|
||||
// invalidStr := `{"request_type":"query"`
|
||||
// _, err := parseDummyQueryRequest(invalidStr)
|
||||
|
@ -14,28 +14,134 @@ package proxy
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestMsgProxyIsUnhealthy(t *testing.T) {
|
||||
func Test_errInvalidNumRows(t *testing.T) {
|
||||
invalidNumRowsList := []uint32{
|
||||
0,
|
||||
16384,
|
||||
}
|
||||
|
||||
for _, invalidNumRows := range invalidNumRowsList {
|
||||
log.Info("Test_errInvalidNumRows",
|
||||
zap.Error(errInvalidNumRows(invalidNumRows)))
|
||||
}
|
||||
}
|
||||
|
||||
func Test_errNumRowsLessThanOrEqualToZero(t *testing.T) {
|
||||
invalidNumRowsList := []uint32{
|
||||
0,
|
||||
16384,
|
||||
}
|
||||
|
||||
for _, invalidNumRows := range invalidNumRowsList {
|
||||
log.Info("Test_errNumRowsLessThanOrEqualToZero",
|
||||
zap.Error(errNumRowsLessThanOrEqualToZero(invalidNumRows)))
|
||||
}
|
||||
}
|
||||
|
||||
func Test_errEmptyFieldData(t *testing.T) {
|
||||
log.Info("Test_errEmptyFieldData",
|
||||
zap.Error(errEmptyFieldData))
|
||||
}
|
||||
|
||||
func Test_errFieldsLessThanNeeded(t *testing.T) {
|
||||
cases := []struct {
|
||||
fieldsNum int
|
||||
neededNum int
|
||||
}{
|
||||
{0, 1},
|
||||
{1, 2},
|
||||
}
|
||||
|
||||
for _, test := range cases {
|
||||
log.Info("Test_errFieldsLessThanNeeded",
|
||||
zap.Error(errFieldsLessThanNeeded(test.fieldsNum, test.neededNum)))
|
||||
}
|
||||
}
|
||||
|
||||
func Test_errUnsupportedDataType(t *testing.T) {
|
||||
unsupportedDTypes := []schemapb.DataType{
|
||||
schemapb.DataType_None,
|
||||
}
|
||||
|
||||
for _, dType := range unsupportedDTypes {
|
||||
log.Info("Test_errUnsupportedDataType",
|
||||
zap.Error(errUnsupportedDataType(dType)))
|
||||
}
|
||||
}
|
||||
|
||||
func Test_errUnsupportedDType(t *testing.T) {
|
||||
unsupportedDTypes := []string{
|
||||
"bytes",
|
||||
"None",
|
||||
}
|
||||
|
||||
for _, dType := range unsupportedDTypes {
|
||||
log.Info("Test_errUnsupportedDType",
|
||||
zap.Error(errUnsupportedDType(dType)))
|
||||
}
|
||||
}
|
||||
|
||||
func Test_errInvalidDim(t *testing.T) {
|
||||
invalidDimList := []int{
|
||||
0,
|
||||
-1,
|
||||
}
|
||||
|
||||
for _, invalidDim := range invalidDimList {
|
||||
log.Info("Test_errInvalidDim",
|
||||
zap.Error(errInvalidDim(invalidDim)))
|
||||
}
|
||||
}
|
||||
|
||||
func Test_errDimLessThanOrEqualToZero(t *testing.T) {
|
||||
invalidDimList := []int{
|
||||
0,
|
||||
-1,
|
||||
}
|
||||
|
||||
for _, invalidDim := range invalidDimList {
|
||||
log.Info("Test_errDimLessThanOrEqualToZero",
|
||||
zap.Error(errDimLessThanOrEqualToZero(invalidDim)))
|
||||
}
|
||||
}
|
||||
|
||||
func Test_errDimShouldDivide8(t *testing.T) {
|
||||
invalidDimList := []int{
|
||||
0,
|
||||
1,
|
||||
7,
|
||||
}
|
||||
|
||||
for _, invalidDim := range invalidDimList {
|
||||
log.Info("Test_errDimShouldDivide8",
|
||||
zap.Error(errDimShouldDivide8(invalidDim)))
|
||||
}
|
||||
}
|
||||
|
||||
func Test_msgProxyIsUnhealthy(t *testing.T) {
|
||||
ids := []UniqueID{
|
||||
1,
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
log.Info("TestMsgProxyIsUnhealthy",
|
||||
log.Info("Test_msgProxyIsUnhealthy",
|
||||
zap.String("msg", msgProxyIsUnhealthy(id)))
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrProxyIsUnhealthy(t *testing.T) {
|
||||
func Test_errProxyIsUnhealthy(t *testing.T) {
|
||||
ids := []UniqueID{
|
||||
1,
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
log.Info("TestErrProxyIsUnhealthy",
|
||||
log.Info("Test_errProxyIsUnhealthy",
|
||||
zap.Error(errProxyIsUnhealthy(id)))
|
||||
}
|
||||
}
|
||||
|
@ -1,169 +0,0 @@
|
||||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/msgstream"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type insertChannelsMap struct {
|
||||
collectionID2InsertChannels map[UniqueID]int // the value of map is the location of insertChannels & insertMsgStreams
|
||||
insertChannels [][]string // it's a little confusing to use []string as the key of map
|
||||
insertMsgStreams []msgstream.MsgStream // maybe there's a better way to implement Set, just agilely now
|
||||
droppedBitMap []int // 0 -> normal, 1 -> dropped
|
||||
usageHistogram []int // message stream can be closed only when the use count is zero
|
||||
// TODO: use fine grained lock
|
||||
mtx sync.RWMutex
|
||||
nodeInstance *Proxy
|
||||
msFactory msgstream.Factory
|
||||
}
|
||||
|
||||
func (m *insertChannelsMap) CreateInsertMsgStream(collID UniqueID, channels []string) error {
|
||||
m.mtx.Lock()
|
||||
defer m.mtx.Unlock()
|
||||
|
||||
_, ok := m.collectionID2InsertChannels[collID]
|
||||
if ok {
|
||||
return errors.New("impossible and forbidden to create message stream twice")
|
||||
}
|
||||
sort.Slice(channels, func(i, j int) bool {
|
||||
return channels[i] <= channels[j]
|
||||
})
|
||||
for loc, existedChannels := range m.insertChannels {
|
||||
if m.droppedBitMap[loc] == 0 && funcutil.SortedSliceEqual(existedChannels, channels) {
|
||||
m.collectionID2InsertChannels[collID] = loc
|
||||
m.usageHistogram[loc]++
|
||||
return nil
|
||||
}
|
||||
}
|
||||
m.insertChannels = append(m.insertChannels, channels)
|
||||
m.collectionID2InsertChannels[collID] = len(m.insertChannels) - 1
|
||||
|
||||
stream, _ := m.msFactory.NewMsgStream(context.Background())
|
||||
stream.AsProducer(channels)
|
||||
log.Debug("proxy", zap.Strings("proxy AsProducer: ", channels))
|
||||
stream.SetRepackFunc(insertRepackFunc)
|
||||
stream.Start()
|
||||
m.insertMsgStreams = append(m.insertMsgStreams, stream)
|
||||
m.droppedBitMap = append(m.droppedBitMap, 0)
|
||||
m.usageHistogram = append(m.usageHistogram, 1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *insertChannelsMap) CloseInsertMsgStream(collID UniqueID) error {
|
||||
m.mtx.Lock()
|
||||
defer m.mtx.Unlock()
|
||||
|
||||
loc, ok := m.collectionID2InsertChannels[collID]
|
||||
if !ok {
|
||||
return fmt.Errorf("cannot find collection with id %d", collID)
|
||||
}
|
||||
if m.droppedBitMap[loc] != 0 {
|
||||
return errors.New("insert message stream already closed")
|
||||
}
|
||||
if m.usageHistogram[loc] <= 0 {
|
||||
return errors.New("insert message stream already closed")
|
||||
}
|
||||
|
||||
m.usageHistogram[loc]--
|
||||
if m.usageHistogram[loc] <= 0 {
|
||||
m.insertMsgStreams[loc].Close()
|
||||
m.droppedBitMap[loc] = 1
|
||||
log.Warn("close insert message stream ...")
|
||||
}
|
||||
|
||||
delete(m.collectionID2InsertChannels, collID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *insertChannelsMap) GetInsertChannels(collID UniqueID) ([]string, error) {
|
||||
m.mtx.RLock()
|
||||
defer m.mtx.RUnlock()
|
||||
|
||||
loc, ok := m.collectionID2InsertChannels[collID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("cannot find collection with id: %d", collID)
|
||||
}
|
||||
|
||||
if m.droppedBitMap[loc] != 0 {
|
||||
return nil, errors.New("insert message stream already closed")
|
||||
}
|
||||
ret := append([]string(nil), m.insertChannels[loc]...)
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (m *insertChannelsMap) GetInsertMsgStream(collID UniqueID) (msgstream.MsgStream, error) {
|
||||
m.mtx.RLock()
|
||||
defer m.mtx.RUnlock()
|
||||
|
||||
loc, ok := m.collectionID2InsertChannels[collID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("cannot find collection with id: %d", collID)
|
||||
}
|
||||
|
||||
if m.droppedBitMap[loc] != 0 {
|
||||
return nil, errors.New("insert message stream already closed")
|
||||
}
|
||||
|
||||
return m.insertMsgStreams[loc], nil
|
||||
}
|
||||
|
||||
func (m *insertChannelsMap) CloseAllMsgStream() {
|
||||
m.mtx.Lock()
|
||||
defer m.mtx.Unlock()
|
||||
|
||||
for loc, stream := range m.insertMsgStreams {
|
||||
if m.droppedBitMap[loc] == 0 && m.usageHistogram[loc] >= 1 {
|
||||
stream.Close()
|
||||
}
|
||||
}
|
||||
|
||||
m.collectionID2InsertChannels = make(map[UniqueID]int)
|
||||
m.insertChannels = make([][]string, 0)
|
||||
m.insertMsgStreams = make([]msgstream.MsgStream, 0)
|
||||
m.droppedBitMap = make([]int, 0)
|
||||
m.usageHistogram = make([]int, 0)
|
||||
}
|
||||
|
||||
func newInsertChannelsMap(node *Proxy) *insertChannelsMap {
|
||||
return &insertChannelsMap{
|
||||
collectionID2InsertChannels: make(map[UniqueID]int),
|
||||
insertChannels: make([][]string, 0),
|
||||
insertMsgStreams: make([]msgstream.MsgStream, 0),
|
||||
droppedBitMap: make([]int, 0),
|
||||
usageHistogram: make([]int, 0),
|
||||
nodeInstance: node,
|
||||
msFactory: node.msFactory,
|
||||
}
|
||||
}
|
||||
|
||||
var globalInsertChannelsMap *insertChannelsMap
|
||||
var initGlobalInsertChannelsMapOnce sync.Once
|
||||
|
||||
// change to singleton mode later? Such as GetInsertChannelsMapInstance like GetConfAdapterMgrInstance.
|
||||
func initGlobalInsertChannelsMap(node *Proxy) {
|
||||
initGlobalInsertChannelsMapOnce.Do(func() {
|
||||
globalInsertChannelsMap = newInsertChannelsMap(node)
|
||||
})
|
||||
}
|
@ -1,239 +0,0 @@
|
||||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/msgstream"
|
||||
)
|
||||
|
||||
func TestInsertChannelsMap_CreateInsertMsgStream(t *testing.T) {
|
||||
msFactory := msgstream.NewSimpleMsgStreamFactory()
|
||||
node := &Proxy{
|
||||
segAssigner: nil,
|
||||
msFactory: msFactory,
|
||||
}
|
||||
m := newInsertChannelsMap(node)
|
||||
|
||||
var err error
|
||||
|
||||
err = m.CreateInsertMsgStream(1, []string{"1"})
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// duplicated
|
||||
err = m.CreateInsertMsgStream(1, []string{"1"})
|
||||
assert.NotEqual(t, nil, err)
|
||||
|
||||
// duplicated
|
||||
err = m.CreateInsertMsgStream(1, []string{"1", "2"})
|
||||
assert.NotEqual(t, nil, err)
|
||||
|
||||
// use same channels
|
||||
err = m.CreateInsertMsgStream(2, []string{"1"})
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
err = m.CreateInsertMsgStream(3, []string{"3"})
|
||||
assert.Equal(t, nil, err)
|
||||
}
|
||||
|
||||
func TestInsertChannelsMap_CloseInsertMsgStream(t *testing.T) {
|
||||
msFactory := msgstream.NewSimpleMsgStreamFactory()
|
||||
node := &Proxy{
|
||||
segAssigner: nil,
|
||||
msFactory: msFactory,
|
||||
}
|
||||
m := newInsertChannelsMap(node)
|
||||
|
||||
var err error
|
||||
|
||||
_ = m.CreateInsertMsgStream(1, []string{"1"})
|
||||
_ = m.CreateInsertMsgStream(2, []string{"1"})
|
||||
_ = m.CreateInsertMsgStream(3, []string{"3"})
|
||||
|
||||
// don't exist
|
||||
err = m.CloseInsertMsgStream(0)
|
||||
assert.NotEqual(t, nil, err)
|
||||
|
||||
err = m.CloseInsertMsgStream(1)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// close twice
|
||||
err = m.CloseInsertMsgStream(1)
|
||||
assert.NotEqual(t, nil, err)
|
||||
|
||||
err = m.CloseInsertMsgStream(2)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// close twice
|
||||
err = m.CloseInsertMsgStream(2)
|
||||
assert.NotEqual(t, nil, err)
|
||||
|
||||
err = m.CloseInsertMsgStream(3)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// close twice
|
||||
err = m.CloseInsertMsgStream(3)
|
||||
assert.NotEqual(t, nil, err)
|
||||
}
|
||||
|
||||
func TestInsertChannelsMap_GetInsertChannels(t *testing.T) {
|
||||
msFactory := msgstream.NewSimpleMsgStreamFactory()
|
||||
node := &Proxy{
|
||||
segAssigner: nil,
|
||||
msFactory: msFactory,
|
||||
}
|
||||
m := newInsertChannelsMap(node)
|
||||
|
||||
var err error
|
||||
var channels []string
|
||||
|
||||
_ = m.CreateInsertMsgStream(1, []string{"1"})
|
||||
_ = m.CreateInsertMsgStream(2, []string{"1"})
|
||||
_ = m.CreateInsertMsgStream(3, []string{"3"})
|
||||
|
||||
// don't exist
|
||||
channels, err = m.GetInsertChannels(0)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, 0, len(channels))
|
||||
|
||||
channels, err = m.GetInsertChannels(1)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, true, funcutil.SortedSliceEqual(channels, []string{"1"}))
|
||||
|
||||
channels, err = m.GetInsertChannels(2)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, true, funcutil.SortedSliceEqual(channels, []string{"1"}))
|
||||
|
||||
channels, err = m.GetInsertChannels(3)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, true, funcutil.SortedSliceEqual(channels, []string{"3"}))
|
||||
|
||||
_ = m.CloseInsertMsgStream(1)
|
||||
channels, err = m.GetInsertChannels(1)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, 0, len(channels))
|
||||
|
||||
_ = m.CloseInsertMsgStream(2)
|
||||
channels, err = m.GetInsertChannels(2)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, 0, len(channels))
|
||||
|
||||
_ = m.CloseInsertMsgStream(3)
|
||||
channels, err = m.GetInsertChannels(3)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, 0, len(channels))
|
||||
}
|
||||
|
||||
func TestInsertChannelsMap_GetInsertMsgStream(t *testing.T) {
|
||||
msFactory := msgstream.NewSimpleMsgStreamFactory()
|
||||
node := &Proxy{
|
||||
segAssigner: nil,
|
||||
msFactory: msFactory,
|
||||
}
|
||||
m := newInsertChannelsMap(node)
|
||||
|
||||
var err error
|
||||
var stream msgstream.MsgStream
|
||||
|
||||
_ = m.CreateInsertMsgStream(1, []string{"1"})
|
||||
_ = m.CreateInsertMsgStream(2, []string{"1"})
|
||||
_ = m.CreateInsertMsgStream(3, []string{"3"})
|
||||
|
||||
// don't exist
|
||||
stream, err = m.GetInsertMsgStream(0)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, nil, stream)
|
||||
|
||||
stream, err = m.GetInsertMsgStream(1)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, stream)
|
||||
|
||||
stream, err = m.GetInsertMsgStream(2)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, stream)
|
||||
|
||||
stream, err = m.GetInsertMsgStream(3)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, stream)
|
||||
|
||||
_ = m.CloseInsertMsgStream(1)
|
||||
stream, err = m.GetInsertMsgStream(1)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, nil, stream)
|
||||
|
||||
_ = m.CloseInsertMsgStream(2)
|
||||
stream, err = m.GetInsertMsgStream(2)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, nil, stream)
|
||||
|
||||
_ = m.CloseInsertMsgStream(3)
|
||||
stream, err = m.GetInsertMsgStream(3)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, nil, stream)
|
||||
}
|
||||
|
||||
func TestInsertChannelsMap_CloseAllMsgStream(t *testing.T) {
|
||||
msFactory := msgstream.NewSimpleMsgStreamFactory()
|
||||
node := &Proxy{
|
||||
segAssigner: nil,
|
||||
msFactory: msFactory,
|
||||
}
|
||||
m := newInsertChannelsMap(node)
|
||||
|
||||
var err error
|
||||
var stream msgstream.MsgStream
|
||||
var channels []string
|
||||
|
||||
_ = m.CreateInsertMsgStream(1, []string{"1"})
|
||||
_ = m.CreateInsertMsgStream(2, []string{"1"})
|
||||
_ = m.CreateInsertMsgStream(3, []string{"3"})
|
||||
|
||||
m.CloseAllMsgStream()
|
||||
|
||||
err = m.CloseInsertMsgStream(1)
|
||||
assert.NotEqual(t, nil, err)
|
||||
|
||||
err = m.CloseInsertMsgStream(2)
|
||||
assert.NotEqual(t, nil, err)
|
||||
|
||||
err = m.CloseInsertMsgStream(3)
|
||||
assert.NotEqual(t, nil, err)
|
||||
|
||||
channels, err = m.GetInsertChannels(1)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, 0, len(channels))
|
||||
|
||||
channels, err = m.GetInsertChannels(2)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, 0, len(channels))
|
||||
|
||||
channels, err = m.GetInsertChannels(3)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, 0, len(channels))
|
||||
|
||||
stream, err = m.GetInsertMsgStream(1)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, nil, stream)
|
||||
|
||||
stream, err = m.GetInsertMsgStream(2)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, nil, stream)
|
||||
|
||||
stream, err = m.GetInsertMsgStream(3)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, nil, stream)
|
||||
}
|
@ -31,16 +31,21 @@ const (
|
||||
type ParamTable struct {
|
||||
paramtable.BaseTable
|
||||
|
||||
NetworkPort int
|
||||
IP string
|
||||
// NetworkPort & IP are not used
|
||||
NetworkPort int
|
||||
IP string
|
||||
|
||||
NetworkAddress string
|
||||
Alias string
|
||||
|
||||
// TODO(dragondriver): maybe using the Proxy + ProxyID as the alias is more reasonable
|
||||
Alias string
|
||||
|
||||
EtcdEndpoints []string
|
||||
MetaRootPath string
|
||||
RootCoordAddress string
|
||||
PulsarAddress string
|
||||
RocksmqPath string
|
||||
|
||||
RocksmqPath string // not used in Proxy
|
||||
|
||||
ProxyID UniqueID
|
||||
TimeTickInterval time.Duration
|
||||
|
78
internal/proxy/paramtable_test.go
Normal file
78
internal/proxy/paramtable_test.go
Normal file
@ -0,0 +1,78 @@
|
||||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
package proxy
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParamTable(t *testing.T) {
|
||||
Params.Init()
|
||||
|
||||
t.Run("EtcdEndPoints", func(t *testing.T) {
|
||||
t.Logf("EtcdEndPoints: %v", Params.EtcdEndpoints)
|
||||
})
|
||||
|
||||
t.Run("MetaRootPath", func(t *testing.T) {
|
||||
t.Logf("MetaRootPath: %s", Params.MetaRootPath)
|
||||
})
|
||||
|
||||
t.Run("PulsarAddress", func(t *testing.T) {
|
||||
t.Logf("PulsarAddress: %s", Params.PulsarAddress)
|
||||
})
|
||||
|
||||
t.Run("RocksmqPath", func(t *testing.T) {
|
||||
t.Logf("RocksmqPath: %s", Params.RocksmqPath)
|
||||
})
|
||||
|
||||
t.Run("TimeTickInterval", func(t *testing.T) {
|
||||
t.Logf("TimeTickInterval: %v", Params.TimeTickInterval)
|
||||
})
|
||||
|
||||
t.Run("ProxySubName", func(t *testing.T) {
|
||||
t.Logf("ProxySubName: %s", Params.ProxySubName)
|
||||
})
|
||||
|
||||
t.Run("ProxyTimeTickChannelNames", func(t *testing.T) {
|
||||
t.Logf("ProxyTimeTickChannelNames: %v", Params.ProxyTimeTickChannelNames)
|
||||
})
|
||||
|
||||
t.Run("MsgStreamTimeTickBufSize", func(t *testing.T) {
|
||||
t.Logf("MsgStreamTimeTickBufSize: %d", Params.MsgStreamTimeTickBufSize)
|
||||
})
|
||||
|
||||
t.Run("MaxNameLength", func(t *testing.T) {
|
||||
t.Logf("MaxNameLength: %d", Params.MaxNameLength)
|
||||
})
|
||||
|
||||
t.Run("MaxFieldNum", func(t *testing.T) {
|
||||
t.Logf("MaxFieldNum: %d", Params.MaxFieldNum)
|
||||
})
|
||||
|
||||
t.Run("MaxDimension", func(t *testing.T) {
|
||||
t.Logf("MaxDimension: %d", Params.MaxDimension)
|
||||
})
|
||||
|
||||
t.Run("DefaultPartitionName", func(t *testing.T) {
|
||||
t.Logf("DefaultPartitionName: %s", Params.DefaultPartitionName)
|
||||
})
|
||||
|
||||
t.Run("DefaultIndexName", func(t *testing.T) {
|
||||
t.Logf("DefaultIndexName: %s", Params.DefaultIndexName)
|
||||
})
|
||||
|
||||
t.Run("PulsarMaxMessageSize", func(t *testing.T) {
|
||||
t.Logf("PulsarMaxMessageSize: %d", Params.PulsarMaxMessageSize)
|
||||
})
|
||||
|
||||
t.Run("RoleName", func(t *testing.T) {
|
||||
t.Logf("RoleName: %s", Params.RoleName)
|
||||
})
|
||||
}
|
@ -250,7 +250,7 @@ func (node *Proxy) Init() error {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
chMgr := newChannelsMgr(getDmlChannelsFunc, defaultInsertRepackFunc, getDqlChannelsFunc, nil, node.msFactory)
|
||||
chMgr := newChannelsMgrImpl(getDmlChannelsFunc, defaultInsertRepackFunc, getDqlChannelsFunc, nil, node.msFactory)
|
||||
node.chMgr = chMgr
|
||||
|
||||
node.sched, err = NewTaskScheduler(node.ctx, node.idAllocator, node.tsoAllocator, node.msFactory)
|
||||
|
@ -12,11 +12,23 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/msgstream"
|
||||
)
|
||||
|
||||
func insertRepackFunc(tsMsgs []msgstream.TsMsg,
|
||||
hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
|
||||
func insertRepackFunc(
|
||||
tsMsgs []msgstream.TsMsg,
|
||||
hashKeys [][]int32,
|
||||
) (map[int32]*msgstream.MsgPack, error) {
|
||||
|
||||
if len(hashKeys) < len(tsMsgs) {
|
||||
return nil, fmt.Errorf(
|
||||
"the length of hash keys (%d) is less than the length of messages (%d)",
|
||||
len(hashKeys),
|
||||
len(tsMsgs),
|
||||
)
|
||||
}
|
||||
|
||||
result := make(map[int32]*msgstream.MsgPack)
|
||||
for i, request := range tsMsgs {
|
||||
@ -28,17 +40,32 @@ func insertRepackFunc(tsMsgs []msgstream.TsMsg,
|
||||
result[key] = &msgstream.MsgPack{}
|
||||
}
|
||||
result[key].Msgs = append(result[key].Msgs, request)
|
||||
} else {
|
||||
return nil, fmt.Errorf("no hash key for %dth message", i)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func defaultInsertRepackFunc(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
|
||||
func defaultInsertRepackFunc(
|
||||
tsMsgs []msgstream.TsMsg,
|
||||
hashKeys [][]int32,
|
||||
) (map[int32]*msgstream.MsgPack, error) {
|
||||
|
||||
if len(hashKeys) < len(tsMsgs) {
|
||||
return nil, fmt.Errorf(
|
||||
"the length of hash keys (%d) is less than the length of messages (%d)",
|
||||
len(hashKeys),
|
||||
len(tsMsgs),
|
||||
)
|
||||
}
|
||||
|
||||
// after assigning segment id to msg, tsMsgs was already re-bucketed
|
||||
pack := make(map[int32]*msgstream.MsgPack)
|
||||
for idx, msg := range tsMsgs {
|
||||
if len(hashKeys[idx]) <= 0 {
|
||||
continue
|
||||
return nil, fmt.Errorf("no hash key for %dth message", idx)
|
||||
}
|
||||
key := hashKeys[idx][0]
|
||||
_, ok := pack[key]
|
||||
|
191
internal/proxy/repack_func_test.go
Normal file
191
internal/proxy/repack_func_test.go
Normal file
@ -0,0 +1,191 @@
|
||||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/msgstream"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_insertRepackFunc(t *testing.T) {
|
||||
var err error
|
||||
|
||||
// tsMsgs is empty
|
||||
ret1, err := insertRepackFunc(nil, [][]int32{{1, 2}})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 0, len(ret1))
|
||||
|
||||
// hashKeys is empty
|
||||
tsMsgs2 := []msgstream.TsMsg{
|
||||
&msgstream.InsertMsg{}, // not important
|
||||
&msgstream.InsertMsg{}, // not important
|
||||
}
|
||||
ret2, err := insertRepackFunc(tsMsgs2, nil)
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, ret2)
|
||||
|
||||
// len(hashKeys) < len(tsMsgs), 1 < 2
|
||||
ret2, err = insertRepackFunc(tsMsgs2, [][]int32{{1, 2}})
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, ret2)
|
||||
|
||||
// both tsMsgs & hashKeys are empty
|
||||
ret3, err := insertRepackFunc(nil, nil)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 0, len(ret3))
|
||||
|
||||
num := rand.Int()%100 + 1
|
||||
tsMsgs4 := make([]msgstream.TsMsg, 0)
|
||||
for i := 0; i < num; i++ {
|
||||
tsMsgs4 = append(tsMsgs4, &msgstream.InsertMsg{
|
||||
// not important
|
||||
})
|
||||
}
|
||||
|
||||
// len(hashKeys) = len(tsMsgs), but no hash key
|
||||
hashKeys1 := make([][]int32, num)
|
||||
ret4, err := insertRepackFunc(tsMsgs4, hashKeys1)
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, ret4)
|
||||
|
||||
// all messages are shuffled to same bucket
|
||||
hashKeys2 := make([][]int32, num)
|
||||
key := int32(0)
|
||||
for i := 0; i < num; i++ {
|
||||
hashKeys2[i] = []int32{key}
|
||||
}
|
||||
ret5, err := insertRepackFunc(tsMsgs4, hashKeys2)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, len(ret5))
|
||||
assert.Equal(t, num, len(ret5[key].Msgs))
|
||||
|
||||
// evenly shuffle
|
||||
hashKeys3 := make([][]int32, num)
|
||||
for i := 0; i < num; i++ {
|
||||
hashKeys3[i] = []int32{int32(i)}
|
||||
}
|
||||
ret6, err := insertRepackFunc(tsMsgs4, hashKeys3)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, num, len(ret6))
|
||||
for key := range ret6 {
|
||||
assert.Equal(t, 1, len(ret6[key].Msgs))
|
||||
}
|
||||
|
||||
// randomly shuffle
|
||||
histogram := make(map[int32]int) // key -> key num
|
||||
hashKeys4 := make([][]int32, num)
|
||||
for i := 0; i < num; i++ {
|
||||
k := int32(rand.Uint32())
|
||||
hashKeys4[i] = []int32{k}
|
||||
_, exist := histogram[k]
|
||||
if exist {
|
||||
histogram[k]++
|
||||
} else {
|
||||
histogram[k] = 1
|
||||
}
|
||||
}
|
||||
ret7, err := insertRepackFunc(tsMsgs4, hashKeys4)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, len(histogram), len(ret7))
|
||||
for key := range ret7 {
|
||||
assert.Equal(t, histogram[key], len(ret7[key].Msgs))
|
||||
}
|
||||
}
|
||||
|
||||
func Test_defaultInsertRepackFunc(t *testing.T) {
|
||||
var err error
|
||||
|
||||
// tsMsgs is empty
|
||||
ret1, err := defaultInsertRepackFunc(nil, [][]int32{{1, 2}})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 0, len(ret1))
|
||||
|
||||
// hashKeys is empty
|
||||
tsMsgs2 := []msgstream.TsMsg{
|
||||
&msgstream.InsertMsg{}, // not important
|
||||
&msgstream.InsertMsg{}, // not important
|
||||
}
|
||||
ret2, err := defaultInsertRepackFunc(tsMsgs2, nil)
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, ret2)
|
||||
|
||||
// len(hashKeys) < len(tsMsgs), 1 < 2
|
||||
ret2, err = defaultInsertRepackFunc(tsMsgs2, [][]int32{{1, 2}})
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, ret2)
|
||||
|
||||
// both tsMsgs & hashKeys are empty
|
||||
ret3, err := defaultInsertRepackFunc(nil, nil)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 0, len(ret3))
|
||||
|
||||
num := rand.Int()%100 + 1
|
||||
tsMsgs4 := make([]msgstream.TsMsg, 0)
|
||||
for i := 0; i < num; i++ {
|
||||
tsMsgs4 = append(tsMsgs4, &msgstream.InsertMsg{
|
||||
// not important
|
||||
})
|
||||
}
|
||||
|
||||
// len(hashKeys) = len(tsMsgs), but no hash key
|
||||
hashKeys1 := make([][]int32, num)
|
||||
ret4, err := defaultInsertRepackFunc(tsMsgs4, hashKeys1)
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, ret4)
|
||||
|
||||
// all messages are shuffled to same bucket
|
||||
hashKeys2 := make([][]int32, num)
|
||||
key := int32(0)
|
||||
for i := 0; i < num; i++ {
|
||||
hashKeys2[i] = []int32{key}
|
||||
}
|
||||
ret5, err := defaultInsertRepackFunc(tsMsgs4, hashKeys2)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, len(ret5))
|
||||
assert.Equal(t, num, len(ret5[key].Msgs))
|
||||
|
||||
// evenly shuffle
|
||||
hashKeys3 := make([][]int32, num)
|
||||
for i := 0; i < num; i++ {
|
||||
hashKeys3[i] = []int32{int32(i)}
|
||||
}
|
||||
ret6, err := defaultInsertRepackFunc(tsMsgs4, hashKeys3)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, num, len(ret6))
|
||||
for key := range ret6 {
|
||||
assert.Equal(t, 1, len(ret6[key].Msgs))
|
||||
}
|
||||
|
||||
// randomly shuffle
|
||||
histogram := make(map[int32]int) // key -> key num
|
||||
hashKeys4 := make([][]int32, num)
|
||||
for i := 0; i < num; i++ {
|
||||
k := int32(rand.Uint32())
|
||||
hashKeys4[i] = []int32{k}
|
||||
_, exist := histogram[k]
|
||||
if exist {
|
||||
histogram[k]++
|
||||
} else {
|
||||
histogram[k] = 1
|
||||
}
|
||||
}
|
||||
ret7, err := defaultInsertRepackFunc(tsMsgs4, hashKeys4)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, len(histogram), len(ret7))
|
||||
for key := range ret7 {
|
||||
assert.Equal(t, histogram[key], len(ret7[key].Msgs))
|
||||
}
|
||||
}
|
20
internal/proxy/type_def.go
Normal file
20
internal/proxy/type_def.go
Normal file
@ -0,0 +1,20 @@
|
||||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
package proxy
|
||||
|
||||
type vChan = string
|
||||
type pChan = string
|
||||
|
||||
type pChanStatistics struct {
|
||||
minTs Timestamp
|
||||
maxTs Timestamp
|
||||
}
|
Loading…
Reference in New Issue
Block a user