mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-01 19:39:21 +08:00
Add unittest to insertChannelsMap in proxy
Signed-off-by: dragondriver <jiquan.long@zilliz.com>
This commit is contained in:
parent
a250eb370f
commit
57831b9978
@ -1,6 +1,9 @@
|
||||
package msgstream
|
||||
|
||||
import "sync"
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type SimpleMsgStream struct {
|
||||
msgChan chan *MsgPack
|
||||
@ -28,12 +31,28 @@ func (ms *SimpleMsgStream) AsConsumer(channels []string, subName string) {
|
||||
func (ms *SimpleMsgStream) SetRepackFunc(repackFunc RepackFunc) {
|
||||
}
|
||||
|
||||
func (ms *SimpleMsgStream) Produce(pack *MsgPack) error {
|
||||
func (ms *SimpleMsgStream) getMsgCount() int {
|
||||
ms.msgCountMtx.RLock()
|
||||
defer ms.msgCountMtx.RUnlock()
|
||||
|
||||
return ms.msgCount
|
||||
}
|
||||
|
||||
func (ms *SimpleMsgStream) increaseMsgCount(delta int) {
|
||||
ms.msgCountMtx.Lock()
|
||||
defer ms.msgCountMtx.Unlock()
|
||||
|
||||
ms.msgCount += delta
|
||||
}
|
||||
|
||||
func (ms *SimpleMsgStream) decreaseMsgCount(delta int) {
|
||||
ms.increaseMsgCount(-delta)
|
||||
}
|
||||
|
||||
func (ms *SimpleMsgStream) Produce(pack *MsgPack) error {
|
||||
defer ms.increaseMsgCount(1)
|
||||
|
||||
ms.msgChan <- pack
|
||||
ms.msgCount++
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -43,13 +62,12 @@ func (ms *SimpleMsgStream) Broadcast(pack *MsgPack) error {
|
||||
}
|
||||
|
||||
func (ms *SimpleMsgStream) Consume() *MsgPack {
|
||||
ms.msgCountMtx.RLock()
|
||||
defer ms.msgCountMtx.RUnlock()
|
||||
|
||||
if ms.msgCount <= 0 {
|
||||
if ms.getMsgCount() <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
defer ms.decreaseMsgCount(1)
|
||||
|
||||
return <-ms.msgChan
|
||||
}
|
||||
|
||||
@ -63,3 +81,26 @@ func NewSimpleMsgStream() *SimpleMsgStream {
|
||||
msgCount: 0,
|
||||
}
|
||||
}
|
||||
|
||||
type SimpleMsgStreamFactory struct {
|
||||
}
|
||||
|
||||
func (factory *SimpleMsgStreamFactory) SetParams(params map[string]interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (factory *SimpleMsgStreamFactory) NewMsgStream(ctx context.Context) (MsgStream, error) {
|
||||
return NewSimpleMsgStream(), nil
|
||||
}
|
||||
|
||||
func (factory *SimpleMsgStreamFactory) NewTtMsgStream(ctx context.Context) (MsgStream, error) {
|
||||
return NewSimpleMsgStream(), nil
|
||||
}
|
||||
|
||||
func (factory *SimpleMsgStreamFactory) NewQueryMsgStream(ctx context.Context) (MsgStream, error) {
|
||||
return NewSimpleMsgStream(), nil
|
||||
}
|
||||
|
||||
func NewSimpleMsgStreamFactory() *SimpleMsgStreamFactory {
|
||||
return &SimpleMsgStreamFactory{}
|
||||
}
|
||||
|
@ -3,19 +3,19 @@ package proxynode
|
||||
type IndexType = string
|
||||
|
||||
const (
|
||||
IndexFaissIDMap = "FLAT"
|
||||
IndexFaissIvfFlat = "IVF_FLAT"
|
||||
IndexFaissIvfPQ = "IVF_PQ"
|
||||
IndexFaissIvfSQ8 = "IVF_SQ8"
|
||||
IndexFaissIvfSQ8H = "IVF_SQ8_HYBRID"
|
||||
IndexFaissBinIDMap = "BIN_FLAT"
|
||||
IndexFaissBinIvfFlat = "BIN_IVF_FLAT"
|
||||
IndexNSG = "NSG"
|
||||
IndexHNSW = "HNSW"
|
||||
IndexRHNSWFlat = "RHNSW_FLAT"
|
||||
IndexRHNSWPQ = "RHNSW_PQ"
|
||||
IndexRHNSWSQ = "RHNSW_SQ"
|
||||
IndexANNOY = "ANNOY"
|
||||
IndexNGTPANNG = "NGT_PANNG"
|
||||
IndexNGTONNG = "NGT_ONNG"
|
||||
IndexFaissIDMap IndexType = "FLAT"
|
||||
IndexFaissIvfFlat IndexType = "IVF_FLAT"
|
||||
IndexFaissIvfPQ IndexType = "IVF_PQ"
|
||||
IndexFaissIvfSQ8 IndexType = "IVF_SQ8"
|
||||
IndexFaissIvfSQ8H IndexType = "IVF_SQ8_HYBRID"
|
||||
IndexFaissBinIDMap IndexType = "BIN_FLAT"
|
||||
IndexFaissBinIvfFlat IndexType = "BIN_IVF_FLAT"
|
||||
IndexNSG IndexType = "NSG"
|
||||
IndexHNSW IndexType = "HNSW"
|
||||
IndexRHNSWFlat IndexType = "RHNSW_FLAT"
|
||||
IndexRHNSWPQ IndexType = "RHNSW_PQ"
|
||||
IndexRHNSWSQ IndexType = "RHNSW_SQ"
|
||||
IndexANNOY IndexType = "ANNOY"
|
||||
IndexNGTPANNG IndexType = "NGT_PANNG"
|
||||
IndexNGTONNG IndexType = "NGT_ONNG"
|
||||
)
|
||||
|
@ -14,18 +14,19 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type InsertChannelsMap struct {
|
||||
type insertChannelsMap struct {
|
||||
collectionID2InsertChannels map[UniqueID]int // the value of map is the location of insertChannels & insertMsgStreams
|
||||
insertChannels [][]string // it's a little confusing to use []string as the key of map
|
||||
insertMsgStreams []msgstream.MsgStream // maybe there's a better way to implement Set, just agilely now
|
||||
droppedBitMap []int // 0 -> normal, 1 -> dropped
|
||||
usageHistogram []int // message stream can be closed only when the use count is zero
|
||||
mtx sync.RWMutex
|
||||
nodeInstance *ProxyNode
|
||||
msFactory msgstream.Factory
|
||||
// TODO: use fine grained lock
|
||||
mtx sync.RWMutex
|
||||
nodeInstance *ProxyNode
|
||||
msFactory msgstream.Factory
|
||||
}
|
||||
|
||||
func (m *InsertChannelsMap) createInsertMsgStream(collID UniqueID, channels []string) error {
|
||||
func (m *insertChannelsMap) CreateInsertMsgStream(collID UniqueID, channels []string) error {
|
||||
m.mtx.Lock()
|
||||
defer m.mtx.Unlock()
|
||||
|
||||
@ -61,7 +62,7 @@ func (m *InsertChannelsMap) createInsertMsgStream(collID UniqueID, channels []st
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *InsertChannelsMap) closeInsertMsgStream(collID UniqueID) error {
|
||||
func (m *insertChannelsMap) CloseInsertMsgStream(collID UniqueID) error {
|
||||
m.mtx.Lock()
|
||||
defer m.mtx.Unlock()
|
||||
|
||||
@ -80,13 +81,15 @@ func (m *InsertChannelsMap) closeInsertMsgStream(collID UniqueID) error {
|
||||
if m.usageHistogram[loc] <= 0 {
|
||||
m.insertMsgStreams[loc].Close()
|
||||
m.droppedBitMap[loc] = 1
|
||||
delete(m.collectionID2InsertChannels, collID)
|
||||
log.Warn("close insert message stream ...")
|
||||
}
|
||||
|
||||
delete(m.collectionID2InsertChannels, collID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *InsertChannelsMap) getInsertChannels(collID UniqueID) ([]string, error) {
|
||||
func (m *insertChannelsMap) GetInsertChannels(collID UniqueID) ([]string, error) {
|
||||
m.mtx.RLock()
|
||||
defer m.mtx.RUnlock()
|
||||
|
||||
@ -102,7 +105,7 @@ func (m *InsertChannelsMap) getInsertChannels(collID UniqueID) ([]string, error)
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (m *InsertChannelsMap) getInsertMsgStream(collID UniqueID) (msgstream.MsgStream, error) {
|
||||
func (m *insertChannelsMap) GetInsertMsgStream(collID UniqueID) (msgstream.MsgStream, error) {
|
||||
m.mtx.RLock()
|
||||
defer m.mtx.RUnlock()
|
||||
|
||||
@ -118,7 +121,7 @@ func (m *InsertChannelsMap) getInsertMsgStream(collID UniqueID) (msgstream.MsgSt
|
||||
return m.insertMsgStreams[loc], nil
|
||||
}
|
||||
|
||||
func (m *InsertChannelsMap) closeAllMsgStream() {
|
||||
func (m *insertChannelsMap) CloseAllMsgStream() {
|
||||
m.mtx.Lock()
|
||||
defer m.mtx.Unlock()
|
||||
|
||||
@ -135,8 +138,8 @@ func (m *InsertChannelsMap) closeAllMsgStream() {
|
||||
m.usageHistogram = make([]int, 0)
|
||||
}
|
||||
|
||||
func newInsertChannelsMap(node *ProxyNode) *InsertChannelsMap {
|
||||
return &InsertChannelsMap{
|
||||
func newInsertChannelsMap(node *ProxyNode) *insertChannelsMap {
|
||||
return &insertChannelsMap{
|
||||
collectionID2InsertChannels: make(map[UniqueID]int),
|
||||
insertChannels: make([][]string, 0),
|
||||
insertMsgStreams: make([]msgstream.MsgStream, 0),
|
||||
@ -147,8 +150,12 @@ func newInsertChannelsMap(node *ProxyNode) *InsertChannelsMap {
|
||||
}
|
||||
}
|
||||
|
||||
var globalInsertChannelsMap *InsertChannelsMap
|
||||
var globalInsertChannelsMap *insertChannelsMap
|
||||
var initGlobalInsertChannelsMapOnce sync.Once
|
||||
|
||||
// change to singleton mode later? Such as GetInsertChannelsMapInstance like GetConfAdapterMgrInstance.
|
||||
func initGlobalInsertChannelsMap(node *ProxyNode) {
|
||||
globalInsertChannelsMap = newInsertChannelsMap(node)
|
||||
initGlobalInsertChannelsMapOnce.Do(func() {
|
||||
globalInsertChannelsMap = newInsertChannelsMap(node)
|
||||
})
|
||||
}
|
||||
|
228
internal/proxynode/insert_channels_test.go
Normal file
228
internal/proxynode/insert_channels_test.go
Normal file
@ -0,0 +1,228 @@
|
||||
package proxynode
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/util/funcutil"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/msgstream"
|
||||
)
|
||||
|
||||
func TestInsertChannelsMap_CreateInsertMsgStream(t *testing.T) {
|
||||
msFactory := msgstream.NewSimpleMsgStreamFactory()
|
||||
node := &ProxyNode{
|
||||
segAssigner: nil,
|
||||
msFactory: msFactory,
|
||||
}
|
||||
m := newInsertChannelsMap(node)
|
||||
|
||||
var err error
|
||||
|
||||
err = m.CreateInsertMsgStream(1, []string{"1"})
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// duplicated
|
||||
err = m.CreateInsertMsgStream(1, []string{"1"})
|
||||
assert.NotEqual(t, nil, err)
|
||||
|
||||
// duplicated
|
||||
err = m.CreateInsertMsgStream(1, []string{"1", "2"})
|
||||
assert.NotEqual(t, nil, err)
|
||||
|
||||
// use same channels
|
||||
err = m.CreateInsertMsgStream(2, []string{"1"})
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
err = m.CreateInsertMsgStream(3, []string{"3"})
|
||||
assert.Equal(t, nil, err)
|
||||
}
|
||||
|
||||
func TestInsertChannelsMap_CloseInsertMsgStream(t *testing.T) {
|
||||
msFactory := msgstream.NewSimpleMsgStreamFactory()
|
||||
node := &ProxyNode{
|
||||
segAssigner: nil,
|
||||
msFactory: msFactory,
|
||||
}
|
||||
m := newInsertChannelsMap(node)
|
||||
|
||||
var err error
|
||||
|
||||
_ = m.CreateInsertMsgStream(1, []string{"1"})
|
||||
_ = m.CreateInsertMsgStream(2, []string{"1"})
|
||||
_ = m.CreateInsertMsgStream(3, []string{"3"})
|
||||
|
||||
// don't exist
|
||||
err = m.CloseInsertMsgStream(0)
|
||||
assert.NotEqual(t, nil, err)
|
||||
|
||||
err = m.CloseInsertMsgStream(1)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// close twice
|
||||
err = m.CloseInsertMsgStream(1)
|
||||
assert.NotEqual(t, nil, err)
|
||||
|
||||
err = m.CloseInsertMsgStream(2)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// close twice
|
||||
err = m.CloseInsertMsgStream(2)
|
||||
assert.NotEqual(t, nil, err)
|
||||
|
||||
err = m.CloseInsertMsgStream(3)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// close twice
|
||||
err = m.CloseInsertMsgStream(3)
|
||||
assert.NotEqual(t, nil, err)
|
||||
}
|
||||
|
||||
func TestInsertChannelsMap_GetInsertChannels(t *testing.T) {
|
||||
msFactory := msgstream.NewSimpleMsgStreamFactory()
|
||||
node := &ProxyNode{
|
||||
segAssigner: nil,
|
||||
msFactory: msFactory,
|
||||
}
|
||||
m := newInsertChannelsMap(node)
|
||||
|
||||
var err error
|
||||
var channels []string
|
||||
|
||||
_ = m.CreateInsertMsgStream(1, []string{"1"})
|
||||
_ = m.CreateInsertMsgStream(2, []string{"1"})
|
||||
_ = m.CreateInsertMsgStream(3, []string{"3"})
|
||||
|
||||
// don't exist
|
||||
channels, err = m.GetInsertChannels(0)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, 0, len(channels))
|
||||
|
||||
channels, err = m.GetInsertChannels(1)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, true, funcutil.SortedSliceEqual(channels, []string{"1"}))
|
||||
|
||||
channels, err = m.GetInsertChannels(2)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, true, funcutil.SortedSliceEqual(channels, []string{"1"}))
|
||||
|
||||
channels, err = m.GetInsertChannels(3)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, true, funcutil.SortedSliceEqual(channels, []string{"3"}))
|
||||
|
||||
_ = m.CloseInsertMsgStream(1)
|
||||
channels, err = m.GetInsertChannels(1)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, 0, len(channels))
|
||||
|
||||
_ = m.CloseInsertMsgStream(2)
|
||||
channels, err = m.GetInsertChannels(2)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, 0, len(channels))
|
||||
|
||||
_ = m.CloseInsertMsgStream(3)
|
||||
channels, err = m.GetInsertChannels(3)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, 0, len(channels))
|
||||
}
|
||||
|
||||
func TestInsertChannelsMap_GetInsertMsgStream(t *testing.T) {
|
||||
msFactory := msgstream.NewSimpleMsgStreamFactory()
|
||||
node := &ProxyNode{
|
||||
segAssigner: nil,
|
||||
msFactory: msFactory,
|
||||
}
|
||||
m := newInsertChannelsMap(node)
|
||||
|
||||
var err error
|
||||
var stream msgstream.MsgStream
|
||||
|
||||
_ = m.CreateInsertMsgStream(1, []string{"1"})
|
||||
_ = m.CreateInsertMsgStream(2, []string{"1"})
|
||||
_ = m.CreateInsertMsgStream(3, []string{"3"})
|
||||
|
||||
// don't exist
|
||||
stream, err = m.GetInsertMsgStream(0)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, nil, stream)
|
||||
|
||||
stream, err = m.GetInsertMsgStream(1)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, stream)
|
||||
|
||||
stream, err = m.GetInsertMsgStream(2)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, stream)
|
||||
|
||||
stream, err = m.GetInsertMsgStream(3)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, stream)
|
||||
|
||||
_ = m.CloseInsertMsgStream(1)
|
||||
stream, err = m.GetInsertMsgStream(1)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, nil, stream)
|
||||
|
||||
_ = m.CloseInsertMsgStream(2)
|
||||
stream, err = m.GetInsertMsgStream(2)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, nil, stream)
|
||||
|
||||
_ = m.CloseInsertMsgStream(3)
|
||||
stream, err = m.GetInsertMsgStream(3)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, nil, stream)
|
||||
}
|
||||
|
||||
func TestInsertChannelsMap_CloseAllMsgStream(t *testing.T) {
|
||||
msFactory := msgstream.NewSimpleMsgStreamFactory()
|
||||
node := &ProxyNode{
|
||||
segAssigner: nil,
|
||||
msFactory: msFactory,
|
||||
}
|
||||
m := newInsertChannelsMap(node)
|
||||
|
||||
var err error
|
||||
var stream msgstream.MsgStream
|
||||
var channels []string
|
||||
|
||||
_ = m.CreateInsertMsgStream(1, []string{"1"})
|
||||
_ = m.CreateInsertMsgStream(2, []string{"1"})
|
||||
_ = m.CreateInsertMsgStream(3, []string{"3"})
|
||||
|
||||
m.CloseAllMsgStream()
|
||||
|
||||
err = m.CloseInsertMsgStream(1)
|
||||
assert.NotEqual(t, nil, err)
|
||||
|
||||
err = m.CloseInsertMsgStream(2)
|
||||
assert.NotEqual(t, nil, err)
|
||||
|
||||
err = m.CloseInsertMsgStream(3)
|
||||
assert.NotEqual(t, nil, err)
|
||||
|
||||
channels, err = m.GetInsertChannels(1)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, 0, len(channels))
|
||||
|
||||
channels, err = m.GetInsertChannels(2)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, 0, len(channels))
|
||||
|
||||
channels, err = m.GetInsertChannels(3)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, 0, len(channels))
|
||||
|
||||
stream, err = m.GetInsertMsgStream(1)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, nil, stream)
|
||||
|
||||
stream, err = m.GetInsertMsgStream(2)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, nil, stream)
|
||||
|
||||
stream, err = m.GetInsertMsgStream(3)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, nil, stream)
|
||||
}
|
@ -249,7 +249,7 @@ func (node *ProxyNode) Start() error {
|
||||
func (node *ProxyNode) Stop() error {
|
||||
node.cancel()
|
||||
|
||||
globalInsertChannelsMap.closeAllMsgStream()
|
||||
globalInsertChannelsMap.CloseAllMsgStream()
|
||||
node.tsoAllocator.Close()
|
||||
node.idAllocator.Close()
|
||||
node.segAssigner.Close()
|
||||
|
@ -71,7 +71,7 @@ func insertRepackFunc(tsMsgs []msgstream.TsMsg,
|
||||
|
||||
collID := insertRequest.CollectionID
|
||||
if _, ok := channelNamesMap[collID]; !ok {
|
||||
channelNames, err := globalInsertChannelsMap.getInsertChannels(collID)
|
||||
channelNames, err := globalInsertChannelsMap.GetInsertChannels(collID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -204,7 +204,7 @@ func (it *InsertTask) Execute(ctx context.Context) error {
|
||||
|
||||
msgPack.Msgs[0] = tsMsg
|
||||
|
||||
stream, err := globalInsertChannelsMap.getInsertMsgStream(collID)
|
||||
stream, err := globalInsertChannelsMap.GetInsertMsgStream(collID)
|
||||
if err != nil {
|
||||
resp, _ := it.dataService.GetInsertChannels(ctx, &datapb.GetInsertChannelsRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
@ -222,12 +222,12 @@ func (it *InsertTask) Execute(ctx context.Context) error {
|
||||
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
return errors.New(resp.Status.Reason)
|
||||
}
|
||||
err = globalInsertChannelsMap.createInsertMsgStream(collID, resp.Values)
|
||||
err = globalInsertChannelsMap.CreateInsertMsgStream(collID, resp.Values)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
stream, err = globalInsertChannelsMap.getInsertMsgStream(collID)
|
||||
stream, err = globalInsertChannelsMap.GetInsertMsgStream(collID)
|
||||
if err != nil {
|
||||
it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
it.result.Status.Reason = err.Error()
|
||||
@ -386,7 +386,7 @@ func (cct *CreateCollectionTask) Execute(ctx context.Context) error {
|
||||
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
return errors.New(resp.Status.Reason)
|
||||
}
|
||||
err = globalInsertChannelsMap.createInsertMsgStream(collID, resp.Values)
|
||||
err = globalInsertChannelsMap.CreateInsertMsgStream(collID, resp.Values)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -464,7 +464,7 @@ func (dct *DropCollectionTask) Execute(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
err = globalInsertChannelsMap.closeInsertMsgStream(collID)
|
||||
err = globalInsertChannelsMap.CloseInsertMsgStream(collID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user