Add unittest to insertChannelsMap in proxy

Signed-off-by: dragondriver <jiquan.long@zilliz.com>
This commit is contained in:
dragondriver 2021-04-08 15:41:28 +08:00 committed by yefu.chen
parent a250eb370f
commit 57831b9978
7 changed files with 319 additions and 43 deletions

View File

@ -1,6 +1,9 @@
package msgstream
import "sync"
import (
"context"
"sync"
)
type SimpleMsgStream struct {
msgChan chan *MsgPack
@ -28,12 +31,28 @@ func (ms *SimpleMsgStream) AsConsumer(channels []string, subName string) {
func (ms *SimpleMsgStream) SetRepackFunc(repackFunc RepackFunc) {
}
func (ms *SimpleMsgStream) Produce(pack *MsgPack) error {
func (ms *SimpleMsgStream) getMsgCount() int {
ms.msgCountMtx.RLock()
defer ms.msgCountMtx.RUnlock()
return ms.msgCount
}
func (ms *SimpleMsgStream) increaseMsgCount(delta int) {
ms.msgCountMtx.Lock()
defer ms.msgCountMtx.Unlock()
ms.msgCount += delta
}
func (ms *SimpleMsgStream) decreaseMsgCount(delta int) {
ms.increaseMsgCount(-delta)
}
func (ms *SimpleMsgStream) Produce(pack *MsgPack) error {
defer ms.increaseMsgCount(1)
ms.msgChan <- pack
ms.msgCount++
return nil
}
@ -43,13 +62,12 @@ func (ms *SimpleMsgStream) Broadcast(pack *MsgPack) error {
}
func (ms *SimpleMsgStream) Consume() *MsgPack {
ms.msgCountMtx.RLock()
defer ms.msgCountMtx.RUnlock()
if ms.msgCount <= 0 {
if ms.getMsgCount() <= 0 {
return nil
}
defer ms.decreaseMsgCount(1)
return <-ms.msgChan
}
@ -63,3 +81,26 @@ func NewSimpleMsgStream() *SimpleMsgStream {
msgCount: 0,
}
}
type SimpleMsgStreamFactory struct {
}
func (factory *SimpleMsgStreamFactory) SetParams(params map[string]interface{}) error {
return nil
}
func (factory *SimpleMsgStreamFactory) NewMsgStream(ctx context.Context) (MsgStream, error) {
return NewSimpleMsgStream(), nil
}
func (factory *SimpleMsgStreamFactory) NewTtMsgStream(ctx context.Context) (MsgStream, error) {
return NewSimpleMsgStream(), nil
}
func (factory *SimpleMsgStreamFactory) NewQueryMsgStream(ctx context.Context) (MsgStream, error) {
return NewSimpleMsgStream(), nil
}
func NewSimpleMsgStreamFactory() *SimpleMsgStreamFactory {
return &SimpleMsgStreamFactory{}
}

View File

@ -3,19 +3,19 @@ package proxynode
type IndexType = string
const (
IndexFaissIDMap = "FLAT"
IndexFaissIvfFlat = "IVF_FLAT"
IndexFaissIvfPQ = "IVF_PQ"
IndexFaissIvfSQ8 = "IVF_SQ8"
IndexFaissIvfSQ8H = "IVF_SQ8_HYBRID"
IndexFaissBinIDMap = "BIN_FLAT"
IndexFaissBinIvfFlat = "BIN_IVF_FLAT"
IndexNSG = "NSG"
IndexHNSW = "HNSW"
IndexRHNSWFlat = "RHNSW_FLAT"
IndexRHNSWPQ = "RHNSW_PQ"
IndexRHNSWSQ = "RHNSW_SQ"
IndexANNOY = "ANNOY"
IndexNGTPANNG = "NGT_PANNG"
IndexNGTONNG = "NGT_ONNG"
IndexFaissIDMap IndexType = "FLAT"
IndexFaissIvfFlat IndexType = "IVF_FLAT"
IndexFaissIvfPQ IndexType = "IVF_PQ"
IndexFaissIvfSQ8 IndexType = "IVF_SQ8"
IndexFaissIvfSQ8H IndexType = "IVF_SQ8_HYBRID"
IndexFaissBinIDMap IndexType = "BIN_FLAT"
IndexFaissBinIvfFlat IndexType = "BIN_IVF_FLAT"
IndexNSG IndexType = "NSG"
IndexHNSW IndexType = "HNSW"
IndexRHNSWFlat IndexType = "RHNSW_FLAT"
IndexRHNSWPQ IndexType = "RHNSW_PQ"
IndexRHNSWSQ IndexType = "RHNSW_SQ"
IndexANNOY IndexType = "ANNOY"
IndexNGTPANNG IndexType = "NGT_PANNG"
IndexNGTONNG IndexType = "NGT_ONNG"
)

View File

@ -14,18 +14,19 @@ import (
"go.uber.org/zap"
)
type InsertChannelsMap struct {
type insertChannelsMap struct {
collectionID2InsertChannels map[UniqueID]int // the value of map is the location of insertChannels & insertMsgStreams
insertChannels [][]string // it's a little confusing to use []string as the key of map
insertMsgStreams []msgstream.MsgStream // maybe there's a better way to implement Set, just agilely now
droppedBitMap []int // 0 -> normal, 1 -> dropped
usageHistogram []int // message stream can be closed only when the use count is zero
mtx sync.RWMutex
nodeInstance *ProxyNode
msFactory msgstream.Factory
// TODO: use fine grained lock
mtx sync.RWMutex
nodeInstance *ProxyNode
msFactory msgstream.Factory
}
func (m *InsertChannelsMap) createInsertMsgStream(collID UniqueID, channels []string) error {
func (m *insertChannelsMap) CreateInsertMsgStream(collID UniqueID, channels []string) error {
m.mtx.Lock()
defer m.mtx.Unlock()
@ -61,7 +62,7 @@ func (m *InsertChannelsMap) createInsertMsgStream(collID UniqueID, channels []st
return nil
}
func (m *InsertChannelsMap) closeInsertMsgStream(collID UniqueID) error {
func (m *insertChannelsMap) CloseInsertMsgStream(collID UniqueID) error {
m.mtx.Lock()
defer m.mtx.Unlock()
@ -80,13 +81,15 @@ func (m *InsertChannelsMap) closeInsertMsgStream(collID UniqueID) error {
if m.usageHistogram[loc] <= 0 {
m.insertMsgStreams[loc].Close()
m.droppedBitMap[loc] = 1
delete(m.collectionID2InsertChannels, collID)
log.Warn("close insert message stream ...")
}
delete(m.collectionID2InsertChannels, collID)
return nil
}
func (m *InsertChannelsMap) getInsertChannels(collID UniqueID) ([]string, error) {
func (m *insertChannelsMap) GetInsertChannels(collID UniqueID) ([]string, error) {
m.mtx.RLock()
defer m.mtx.RUnlock()
@ -102,7 +105,7 @@ func (m *InsertChannelsMap) getInsertChannels(collID UniqueID) ([]string, error)
return ret, nil
}
func (m *InsertChannelsMap) getInsertMsgStream(collID UniqueID) (msgstream.MsgStream, error) {
func (m *insertChannelsMap) GetInsertMsgStream(collID UniqueID) (msgstream.MsgStream, error) {
m.mtx.RLock()
defer m.mtx.RUnlock()
@ -118,7 +121,7 @@ func (m *InsertChannelsMap) getInsertMsgStream(collID UniqueID) (msgstream.MsgSt
return m.insertMsgStreams[loc], nil
}
func (m *InsertChannelsMap) closeAllMsgStream() {
func (m *insertChannelsMap) CloseAllMsgStream() {
m.mtx.Lock()
defer m.mtx.Unlock()
@ -135,8 +138,8 @@ func (m *InsertChannelsMap) closeAllMsgStream() {
m.usageHistogram = make([]int, 0)
}
func newInsertChannelsMap(node *ProxyNode) *InsertChannelsMap {
return &InsertChannelsMap{
func newInsertChannelsMap(node *ProxyNode) *insertChannelsMap {
return &insertChannelsMap{
collectionID2InsertChannels: make(map[UniqueID]int),
insertChannels: make([][]string, 0),
insertMsgStreams: make([]msgstream.MsgStream, 0),
@ -147,8 +150,12 @@ func newInsertChannelsMap(node *ProxyNode) *InsertChannelsMap {
}
}
var globalInsertChannelsMap *InsertChannelsMap
var globalInsertChannelsMap *insertChannelsMap
var initGlobalInsertChannelsMapOnce sync.Once
// change to singleton mode later? Such as GetInsertChannelsMapInstance like GetConfAdapterMgrInstance.
func initGlobalInsertChannelsMap(node *ProxyNode) {
globalInsertChannelsMap = newInsertChannelsMap(node)
initGlobalInsertChannelsMapOnce.Do(func() {
globalInsertChannelsMap = newInsertChannelsMap(node)
})
}

View File

@ -0,0 +1,228 @@
package proxynode
import (
"testing"
"github.com/zilliztech/milvus-distributed/internal/util/funcutil"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
)
func TestInsertChannelsMap_CreateInsertMsgStream(t *testing.T) {
msFactory := msgstream.NewSimpleMsgStreamFactory()
node := &ProxyNode{
segAssigner: nil,
msFactory: msFactory,
}
m := newInsertChannelsMap(node)
var err error
err = m.CreateInsertMsgStream(1, []string{"1"})
assert.Equal(t, nil, err)
// duplicated
err = m.CreateInsertMsgStream(1, []string{"1"})
assert.NotEqual(t, nil, err)
// duplicated
err = m.CreateInsertMsgStream(1, []string{"1", "2"})
assert.NotEqual(t, nil, err)
// use same channels
err = m.CreateInsertMsgStream(2, []string{"1"})
assert.Equal(t, nil, err)
err = m.CreateInsertMsgStream(3, []string{"3"})
assert.Equal(t, nil, err)
}
func TestInsertChannelsMap_CloseInsertMsgStream(t *testing.T) {
msFactory := msgstream.NewSimpleMsgStreamFactory()
node := &ProxyNode{
segAssigner: nil,
msFactory: msFactory,
}
m := newInsertChannelsMap(node)
var err error
_ = m.CreateInsertMsgStream(1, []string{"1"})
_ = m.CreateInsertMsgStream(2, []string{"1"})
_ = m.CreateInsertMsgStream(3, []string{"3"})
// don't exist
err = m.CloseInsertMsgStream(0)
assert.NotEqual(t, nil, err)
err = m.CloseInsertMsgStream(1)
assert.Equal(t, nil, err)
// close twice
err = m.CloseInsertMsgStream(1)
assert.NotEqual(t, nil, err)
err = m.CloseInsertMsgStream(2)
assert.Equal(t, nil, err)
// close twice
err = m.CloseInsertMsgStream(2)
assert.NotEqual(t, nil, err)
err = m.CloseInsertMsgStream(3)
assert.Equal(t, nil, err)
// close twice
err = m.CloseInsertMsgStream(3)
assert.NotEqual(t, nil, err)
}
func TestInsertChannelsMap_GetInsertChannels(t *testing.T) {
msFactory := msgstream.NewSimpleMsgStreamFactory()
node := &ProxyNode{
segAssigner: nil,
msFactory: msFactory,
}
m := newInsertChannelsMap(node)
var err error
var channels []string
_ = m.CreateInsertMsgStream(1, []string{"1"})
_ = m.CreateInsertMsgStream(2, []string{"1"})
_ = m.CreateInsertMsgStream(3, []string{"3"})
// don't exist
channels, err = m.GetInsertChannels(0)
assert.NotEqual(t, nil, err)
assert.Equal(t, 0, len(channels))
channels, err = m.GetInsertChannels(1)
assert.Equal(t, nil, err)
assert.Equal(t, true, funcutil.SortedSliceEqual(channels, []string{"1"}))
channels, err = m.GetInsertChannels(2)
assert.Equal(t, nil, err)
assert.Equal(t, true, funcutil.SortedSliceEqual(channels, []string{"1"}))
channels, err = m.GetInsertChannels(3)
assert.Equal(t, nil, err)
assert.Equal(t, true, funcutil.SortedSliceEqual(channels, []string{"3"}))
_ = m.CloseInsertMsgStream(1)
channels, err = m.GetInsertChannels(1)
assert.NotEqual(t, nil, err)
assert.Equal(t, 0, len(channels))
_ = m.CloseInsertMsgStream(2)
channels, err = m.GetInsertChannels(2)
assert.NotEqual(t, nil, err)
assert.Equal(t, 0, len(channels))
_ = m.CloseInsertMsgStream(3)
channels, err = m.GetInsertChannels(3)
assert.NotEqual(t, nil, err)
assert.Equal(t, 0, len(channels))
}
func TestInsertChannelsMap_GetInsertMsgStream(t *testing.T) {
msFactory := msgstream.NewSimpleMsgStreamFactory()
node := &ProxyNode{
segAssigner: nil,
msFactory: msFactory,
}
m := newInsertChannelsMap(node)
var err error
var stream msgstream.MsgStream
_ = m.CreateInsertMsgStream(1, []string{"1"})
_ = m.CreateInsertMsgStream(2, []string{"1"})
_ = m.CreateInsertMsgStream(3, []string{"3"})
// don't exist
stream, err = m.GetInsertMsgStream(0)
assert.NotEqual(t, nil, err)
assert.Equal(t, nil, stream)
stream, err = m.GetInsertMsgStream(1)
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, stream)
stream, err = m.GetInsertMsgStream(2)
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, stream)
stream, err = m.GetInsertMsgStream(3)
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, stream)
_ = m.CloseInsertMsgStream(1)
stream, err = m.GetInsertMsgStream(1)
assert.NotEqual(t, nil, err)
assert.Equal(t, nil, stream)
_ = m.CloseInsertMsgStream(2)
stream, err = m.GetInsertMsgStream(2)
assert.NotEqual(t, nil, err)
assert.Equal(t, nil, stream)
_ = m.CloseInsertMsgStream(3)
stream, err = m.GetInsertMsgStream(3)
assert.NotEqual(t, nil, err)
assert.Equal(t, nil, stream)
}
func TestInsertChannelsMap_CloseAllMsgStream(t *testing.T) {
msFactory := msgstream.NewSimpleMsgStreamFactory()
node := &ProxyNode{
segAssigner: nil,
msFactory: msFactory,
}
m := newInsertChannelsMap(node)
var err error
var stream msgstream.MsgStream
var channels []string
_ = m.CreateInsertMsgStream(1, []string{"1"})
_ = m.CreateInsertMsgStream(2, []string{"1"})
_ = m.CreateInsertMsgStream(3, []string{"3"})
m.CloseAllMsgStream()
err = m.CloseInsertMsgStream(1)
assert.NotEqual(t, nil, err)
err = m.CloseInsertMsgStream(2)
assert.NotEqual(t, nil, err)
err = m.CloseInsertMsgStream(3)
assert.NotEqual(t, nil, err)
channels, err = m.GetInsertChannels(1)
assert.NotEqual(t, nil, err)
assert.Equal(t, 0, len(channels))
channels, err = m.GetInsertChannels(2)
assert.NotEqual(t, nil, err)
assert.Equal(t, 0, len(channels))
channels, err = m.GetInsertChannels(3)
assert.NotEqual(t, nil, err)
assert.Equal(t, 0, len(channels))
stream, err = m.GetInsertMsgStream(1)
assert.NotEqual(t, nil, err)
assert.Equal(t, nil, stream)
stream, err = m.GetInsertMsgStream(2)
assert.NotEqual(t, nil, err)
assert.Equal(t, nil, stream)
stream, err = m.GetInsertMsgStream(3)
assert.NotEqual(t, nil, err)
assert.Equal(t, nil, stream)
}

View File

@ -249,7 +249,7 @@ func (node *ProxyNode) Start() error {
func (node *ProxyNode) Stop() error {
node.cancel()
globalInsertChannelsMap.closeAllMsgStream()
globalInsertChannelsMap.CloseAllMsgStream()
node.tsoAllocator.Close()
node.idAllocator.Close()
node.segAssigner.Close()

View File

@ -71,7 +71,7 @@ func insertRepackFunc(tsMsgs []msgstream.TsMsg,
collID := insertRequest.CollectionID
if _, ok := channelNamesMap[collID]; !ok {
channelNames, err := globalInsertChannelsMap.getInsertChannels(collID)
channelNames, err := globalInsertChannelsMap.GetInsertChannels(collID)
if err != nil {
return nil, err
}

View File

@ -204,7 +204,7 @@ func (it *InsertTask) Execute(ctx context.Context) error {
msgPack.Msgs[0] = tsMsg
stream, err := globalInsertChannelsMap.getInsertMsgStream(collID)
stream, err := globalInsertChannelsMap.GetInsertMsgStream(collID)
if err != nil {
resp, _ := it.dataService.GetInsertChannels(ctx, &datapb.GetInsertChannelsRequest{
Base: &commonpb.MsgBase{
@ -222,12 +222,12 @@ func (it *InsertTask) Execute(ctx context.Context) error {
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
return errors.New(resp.Status.Reason)
}
err = globalInsertChannelsMap.createInsertMsgStream(collID, resp.Values)
err = globalInsertChannelsMap.CreateInsertMsgStream(collID, resp.Values)
if err != nil {
return err
}
}
stream, err = globalInsertChannelsMap.getInsertMsgStream(collID)
stream, err = globalInsertChannelsMap.GetInsertMsgStream(collID)
if err != nil {
it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
it.result.Status.Reason = err.Error()
@ -386,7 +386,7 @@ func (cct *CreateCollectionTask) Execute(ctx context.Context) error {
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
return errors.New(resp.Status.Reason)
}
err = globalInsertChannelsMap.createInsertMsgStream(collID, resp.Values)
err = globalInsertChannelsMap.CreateInsertMsgStream(collID, resp.Values)
if err != nil {
return err
}
@ -464,7 +464,7 @@ func (dct *DropCollectionTask) Execute(ctx context.Context) error {
return err
}
err = globalInsertChannelsMap.closeInsertMsgStream(collID)
err = globalInsertChannelsMap.CloseInsertMsgStream(collID)
if err != nil {
return err
}