Refine channels management in Proxy. (#17334)

Signed-off-by: longjiquan <jiquan.long@zilliz.com>
This commit is contained in:
Jiquan Long 2022-06-02 15:34:04 +08:00 committed by GitHub
parent 5fdbe23779
commit adf3b14027
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 858 additions and 644 deletions

View File

@ -18,9 +18,9 @@ package proxy
import (
"context"
"errors"
"fmt"
"runtime"
"sort"
"strconv"
"sync"
@ -30,7 +30,6 @@ import (
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/uniquegenerator"
"go.uber.org/zap"
)
@ -45,60 +44,70 @@ type channelsMgr interface {
removeAllDMLStream() error
}
type channelInfos struct {
// It seems that there is no need to maintain relationships between vchans & pchans.
vchans []vChan
pchans []pChan
}
type streamInfos struct {
channelInfos channelInfos
stream msgstream.MsgStream
}
func removeDuplicate(ss []string) []string {
m := make(map[string]struct{})
filtered := make([]string, 0, len(ss))
for _, s := range ss {
if _, ok := m[s]; !ok {
filtered = append(filtered, s)
m[s] = struct{}{}
}
}
return filtered
}
func newChannels(vchans []vChan, pchans []pChan) (channelInfos, error) {
if len(vchans) != len(pchans) {
err := fmt.Errorf("physical channels mismatch virtual channels, len(VirtualChannelNames): %v, len(PhysicalChannelNames): %v", len(vchans), len(pchans))
log.Error(err.Error())
return channelInfos{}, err
}
/*
// remove duplicate physical channels.
return channelInfos{vchans: vchans, pchans: removeDuplicate(pchans)}, nil
*/
return channelInfos{vchans: vchans, pchans: pchans}, nil
}
// getChannelsFuncType returns the channel information according to the collection id.
type getChannelsFuncType = func(collectionID UniqueID) (map[vChan]pChan, error)
type getChannelsFuncType = func(collectionID UniqueID) (channelInfos, error)
// repackFuncType repacks message into message pack.
type repackFuncType = func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error)
// getDmlChannelsFunc returns a function about how to get dml channels of a collection.
func getDmlChannelsFunc(ctx context.Context, rc types.RootCoord) getChannelsFuncType {
return func(collectionID UniqueID) (map[vChan]pChan, error) {
return func(collectionID UniqueID) (channelInfos, error) {
req := &milvuspb.DescribeCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_DescribeCollection,
MsgID: 0, // todo
Timestamp: 0, // todo
SourceID: 0, // todo
},
DbName: "", // todo
CollectionName: "", // todo
CollectionID: collectionID,
TimeStamp: 0, // todo
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DescribeCollection},
CollectionID: collectionID,
}
resp, err := rc.DescribeCollection(ctx, req)
if err != nil {
log.Warn("failed to describe collection", zap.Error(err), zap.Int64("collection", collectionID))
return nil, err
}
if resp.Status.ErrorCode != 0 {
log.Warn("DescribeCollection",
zap.Any("ErrorCode", resp.Status.ErrorCode),
zap.Any("Reason", resp.Status.Reason))
return nil, err
}
if len(resp.VirtualChannelNames) != len(resp.PhysicalChannelNames) {
err := fmt.Errorf(
"len(VirtualChannelNames): %v, len(PhysicalChannelNames): %v",
len(resp.VirtualChannelNames),
len(resp.PhysicalChannelNames))
log.Warn("GetDmlChannels", zap.Error(err))
return nil, err
log.Error("failed to describe collection", zap.Error(err), zap.Int64("collection", collectionID))
return channelInfos{}, err
}
ret := make(map[vChan]pChan)
for idx, name := range resp.VirtualChannelNames {
if _, ok := ret[name]; ok {
err := fmt.Errorf(
"duplicated virtual channel found, vchan: %v, pchan: %v",
name,
resp.PhysicalChannelNames[idx])
return nil, err
}
ret[name] = resp.PhysicalChannelNames[idx]
if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Error("failed to describe collection",
zap.String("error_code", resp.GetStatus().GetErrorCode().String()),
zap.String("reason", resp.GetStatus().GetReason()))
return channelInfos{}, errors.New(resp.GetStatus().GetReason())
}
return ret, nil
return newChannels(resp.GetVirtualChannelNames(), resp.GetPhysicalChannelNames())
}
}
@ -111,323 +120,195 @@ const (
)
type singleTypeChannelsMgr struct {
collectionID2VIDs map[UniqueID][]int // id are sorted
collMtx sync.RWMutex
id2vchans map[int][]vChan
id2vchansMtx sync.RWMutex
id2Stream map[int]msgstream.MsgStream
id2UsageHistogramOfStream map[int]int
streamMtx sync.RWMutex
vchans2pchans map[vChan]pChan
vchans2pchansMtx sync.RWMutex
getChannelsFunc getChannelsFuncType
repackFunc repackFuncType
infos map[UniqueID]streamInfos // collection id -> stream infos
mu sync.RWMutex
getChannelsFunc getChannelsFuncType
repackFunc repackFuncType
singleStreamType streamType
msgStreamFactory msgstream.Factory
}
func getAllKeys(m map[vChan]pChan) []vChan {
keys := make([]vChan, 0, len(m))
for key := range m {
keys = append(keys, key)
}
return keys
}
func (mgr *singleTypeChannelsMgr) getAllChannels(collectionID UniqueID) (channelInfos, error) {
mgr.mu.RLock()
defer mgr.mu.RUnlock()
func getAllValues(m map[vChan]pChan) []pChan {
values := make([]pChan, 0, len(m))
for _, value := range m {
values = append(values, value)
}
return values
}
func (mgr *singleTypeChannelsMgr) getLatestVID(collectionID UniqueID) (int, error) {
mgr.collMtx.RLock()
defer mgr.collMtx.RUnlock()
ids, ok := mgr.collectionID2VIDs[collectionID]
if !ok || len(ids) <= 0 {
return 0, fmt.Errorf("v-channel ID is not found for collection %d", collectionID)
infos, ok := mgr.infos[collectionID]
if ok {
return infos.channelInfos, nil
}
return ids[len(ids)-1], nil
return channelInfos{}, fmt.Errorf("collection not found in channels manager: %d", collectionID)
}
func (mgr *singleTypeChannelsMgr) getAllVIDs(collectionID UniqueID) ([]int, error) {
mgr.collMtx.RLock()
defer mgr.collMtx.RUnlock()
ids, exist := mgr.collectionID2VIDs[collectionID]
if !exist {
return nil, fmt.Errorf("collection %d not found", collectionID)
func (mgr *singleTypeChannelsMgr) getPChans(collectionID UniqueID) ([]pChan, error) {
channelInfos, err := mgr.getChannelsFunc(collectionID)
if err != nil {
return nil, err
}
return ids, nil
return channelInfos.pchans, nil
}
func (mgr *singleTypeChannelsMgr) getVChansByVID(vid int) ([]vChan, error) {
mgr.id2vchansMtx.RLock()
defer mgr.id2vchansMtx.RUnlock()
vchans, ok := mgr.id2vchans[vid]
if !ok {
return nil, fmt.Errorf("vid %d not found", vid)
}
return vchans, nil
}
// getPChansByVChans converts virtual channel names to physical channel names
func (mgr *singleTypeChannelsMgr) getPChansByVChans(vchans []vChan) ([]pChan, error) {
mgr.vchans2pchansMtx.RLock()
defer mgr.vchans2pchansMtx.RUnlock()
pchans := make([]pChan, 0, len(vchans))
for _, vchan := range vchans {
pchan, ok := mgr.vchans2pchans[vchan]
if !ok {
return nil, fmt.Errorf("vchan %v not found", vchan)
}
pchans = append(pchans, pchan)
}
return pchans, nil
}
func (mgr *singleTypeChannelsMgr) updateVChans(vid int, vchans []vChan) {
mgr.id2vchansMtx.Lock()
defer mgr.id2vchansMtx.Unlock()
mgr.id2vchans[vid] = vchans
}
func (mgr *singleTypeChannelsMgr) deleteVChansByVID(vid int) {
mgr.id2vchansMtx.Lock()
defer mgr.id2vchansMtx.Unlock()
delete(mgr.id2vchans, vid)
}
func (mgr *singleTypeChannelsMgr) deleteVChansByVIDs(vids []int) {
mgr.id2vchansMtx.Lock()
defer mgr.id2vchansMtx.Unlock()
for _, vid := range vids {
delete(mgr.id2vchans, vid)
}
}
func (mgr *singleTypeChannelsMgr) deleteStreamByVID(vid int) {
mgr.streamMtx.Lock()
defer mgr.streamMtx.Unlock()
delete(mgr.id2Stream, vid)
}
func (mgr *singleTypeChannelsMgr) deleteStreamByVIDs(vids []int) {
mgr.streamMtx.Lock()
defer mgr.streamMtx.Unlock()
for _, vid := range vids {
delete(mgr.id2Stream, vid)
}
}
func (mgr *singleTypeChannelsMgr) updateChannels(channels map[vChan]pChan) {
mgr.vchans2pchansMtx.Lock()
defer mgr.vchans2pchansMtx.Unlock()
for vchan, pchan := range channels {
mgr.vchans2pchans[vchan] = pchan
}
}
func (mgr *singleTypeChannelsMgr) deleteAllChannels() {
mgr.vchans2pchansMtx.Lock()
defer mgr.vchans2pchansMtx.Unlock()
mgr.vchans2pchans = nil
}
func (mgr *singleTypeChannelsMgr) deleteAllStream() {
mgr.id2vchansMtx.Lock()
defer mgr.id2vchansMtx.Unlock()
mgr.id2UsageHistogramOfStream = nil
mgr.id2Stream = nil
}
func (mgr *singleTypeChannelsMgr) deleteAllVChans() {
mgr.id2vchansMtx.Lock()
defer mgr.id2vchansMtx.Unlock()
mgr.id2vchans = nil
}
func (mgr *singleTypeChannelsMgr) deleteAllCollection() {
mgr.collMtx.Lock()
defer mgr.collMtx.Unlock()
mgr.collectionID2VIDs = nil
}
func (mgr *singleTypeChannelsMgr) addStream(vid int, stream msgstream.MsgStream) {
mgr.streamMtx.Lock()
defer mgr.streamMtx.Unlock()
mgr.id2Stream[vid] = stream
mgr.id2UsageHistogramOfStream[vid] = 0
}
func (mgr *singleTypeChannelsMgr) updateCollection(collectionID UniqueID, id int) {
mgr.collMtx.Lock()
defer mgr.collMtx.Unlock()
vids, ok := mgr.collectionID2VIDs[collectionID]
if !ok {
mgr.collectionID2VIDs[collectionID] = make([]int, 1)
mgr.collectionID2VIDs[collectionID][0] = id
} else {
vids = append(vids, id)
sort.Slice(vids, func(i, j int) bool {
return vids[i] < vids[j]
})
mgr.collectionID2VIDs[collectionID] = vids
func (mgr *singleTypeChannelsMgr) getVChans(collectionID UniqueID) ([]vChan, error) {
channelInfos, err := mgr.getChannelsFunc(collectionID)
if err != nil {
return nil, err
}
return channelInfos.vchans, nil
}
// getChannels returns the physical channels.
func (mgr *singleTypeChannelsMgr) getChannels(collectionID UniqueID) ([]pChan, error) {
id, err := mgr.getLatestVID(collectionID)
if err == nil {
vchans, err := mgr.getVChansByVID(id)
if err != nil {
return nil, err
}
return mgr.getPChansByVChans(vchans)
var channelInfos channelInfos
channelInfos, err := mgr.getAllChannels(collectionID)
if err != nil {
return mgr.getPChans(collectionID)
}
// TODO(dragondriver): return error or update channel information from master?
return nil, err
return channelInfos.pchans, nil
}
// getVChannels returns the virtual channels.
func (mgr *singleTypeChannelsMgr) getVChannels(collectionID UniqueID) ([]vChan, error) {
id, err := mgr.getLatestVID(collectionID)
if err == nil {
return mgr.getVChansByVID(id)
var channelInfos channelInfos
channelInfos, err := mgr.getAllChannels(collectionID)
if err != nil {
return mgr.getVChans(collectionID)
}
return channelInfos.vchans, nil
}
// TODO(dragondriver): return error or update channel information from master?
return nil, err
func (mgr *singleTypeChannelsMgr) streamExistPrivate(collectionID UniqueID) bool {
streamInfos, ok := mgr.infos[collectionID]
return ok && streamInfos.stream != nil
}
func (mgr *singleTypeChannelsMgr) streamExist(collectionID UniqueID) bool {
stream, err := mgr.getStream(collectionID)
return err == nil && stream != nil
mgr.mu.RLock()
defer mgr.mu.RUnlock()
return mgr.streamExistPrivate(collectionID)
}
func createStream(factory msgstream.Factory, streamType streamType, pchans []pChan, repack repackFuncType) (msgstream.MsgStream, error) {
var stream msgstream.MsgStream
var err error
if streamType == dqlStreamType {
stream, err = factory.NewQueryMsgStream(context.Background())
} else {
stream, err = factory.NewMsgStream(context.Background())
}
if err != nil {
return nil, err
}
stream.AsProducer(pchans)
if repack != nil {
stream.SetRepackFunc(repack)
}
runtime.SetFinalizer(stream, func(stream msgstream.MsgStream) {
stream.Close()
})
return stream, nil
}
func (mgr *singleTypeChannelsMgr) updateCollection(collectionID UniqueID, channelInfos channelInfos, stream msgstream.MsgStream) {
mgr.mu.Lock()
defer mgr.mu.Unlock()
if !mgr.streamExistPrivate(collectionID) {
mgr.infos[collectionID] = streamInfos{channelInfos: channelInfos, stream: stream}
}
}
func incPChansMetrics(pchans []pChan) {
for _, pc := range pchans {
metrics.ProxyMsgStreamObjectsForPChan.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), pc).Inc()
}
}
func decPChanMetrics(pchans []pChan) {
for _, pc := range pchans {
metrics.ProxyMsgStreamObjectsForPChan.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), pc).Dec()
}
}
// createMsgStream create message stream for specified collection. Idempotent.
// If stream already exists, directly return nil.
func (mgr *singleTypeChannelsMgr) createMsgStream(collectionID UniqueID) error {
if mgr.streamExist(collectionID) {
log.Info("stream already exist, no need to re-create", zap.Int64("collection_id", collectionID))
return nil
}
channels, err := mgr.getChannelsFunc(collectionID)
channelInfos, err := mgr.getChannelsFunc(collectionID)
if err != nil {
log.Warn("failed to create message stream",
zap.Int64("collection_id", collectionID),
zap.Error(err))
log.Error("failed to get channels", zap.Error(err), zap.Int64("collection", collectionID))
return err
}
log.Debug("singleTypeChannelsMgr",
stream, err := createStream(mgr.msgStreamFactory, mgr.singleStreamType, channelInfos.pchans, mgr.repackFunc)
if err != nil {
log.Error("failed to create message stream", zap.Error(err), zap.Int64("collection", collectionID))
return err
}
mgr.updateCollection(collectionID, channelInfos, stream)
log.Info("create message stream",
zap.Int64("collection_id", collectionID),
zap.Any("createMsgStream.getChannels", channels))
zap.Strings("virtual_channels", channelInfos.vchans),
zap.Strings("physical_channels", channelInfos.pchans))
mgr.updateChannels(channels)
incPChansMetrics(channelInfos.pchans)
id := uniquegenerator.GetUniqueIntGeneratorIns().GetInt()
vchans, pchans := make([]string, 0, len(channels)), make([]string, 0, len(channels))
for k, v := range channels {
vchans = append(vchans, k)
pchans = append(pchans, v)
}
mgr.updateVChans(id, vchans)
var stream msgstream.MsgStream
if mgr.singleStreamType == dqlStreamType {
stream, err = mgr.msgStreamFactory.NewQueryMsgStream(context.Background())
} else {
stream, err = mgr.msgStreamFactory.NewMsgStream(context.Background())
}
if err != nil {
return err
}
stream.AsProducer(pchans)
if mgr.repackFunc != nil {
stream.SetRepackFunc(mgr.repackFunc)
}
runtime.SetFinalizer(stream, func(stream msgstream.MsgStream) {
stream.Close()
})
mgr.addStream(id, stream)
mgr.updateCollection(collectionID, id)
for _, pc := range pchans {
metrics.ProxyMsgStreamObjectsForPChan.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), pc).Inc()
}
return nil
}
func (mgr *singleTypeChannelsMgr) getStream(collectionID UniqueID) (msgstream.MsgStream, error) {
mgr.streamMtx.RLock()
defer mgr.streamMtx.RUnlock()
func (mgr *singleTypeChannelsMgr) lockGetStream(collectionID UniqueID) (msgstream.MsgStream, error) {
mgr.mu.RLock()
defer mgr.mu.RUnlock()
streamInfos, ok := mgr.infos[collectionID]
if ok {
return streamInfos.stream, nil
}
return nil, fmt.Errorf("collection not found: %d", collectionID)
}
vid, err := mgr.getLatestVID(collectionID)
if err != nil {
// getStream get message stream of specified collection.
// If stream don't exists, call createMsgStream to create for it.
func (mgr *singleTypeChannelsMgr) getStream(collectionID UniqueID) (msgstream.MsgStream, error) {
if stream, err := mgr.lockGetStream(collectionID); err == nil {
return stream, nil
}
if err := mgr.createMsgStream(collectionID); err != nil {
return nil, err
}
stream, ok := mgr.id2Stream[vid]
if !ok {
return nil, fmt.Errorf("no dml stream for collection %v", collectionID)
}
return stream, nil
return mgr.lockGetStream(collectionID)
}
// removeStream remove the corresponding stream of the specified collection. Idempotent.
// If stream already exists, remove it, otherwise do nothing.
func (mgr *singleTypeChannelsMgr) removeStream(collectionID UniqueID) error {
channels, err := mgr.getChannels(collectionID)
if err != nil {
return err
}
ids, err2 := mgr.getAllVIDs(collectionID)
if err2 != nil {
return err2
}
mgr.deleteVChansByVIDs(ids)
mgr.deleteStreamByVIDs(ids)
for _, pc := range channels {
metrics.ProxyMsgStreamObjectsForPChan.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), pc).Dec()
mgr.mu.Lock()
defer mgr.mu.Unlock()
if info, ok := mgr.infos[collectionID]; ok {
decPChanMetrics(info.channelInfos.pchans)
delete(mgr.infos, collectionID)
}
return nil
}
// removeAllStream remove all message stream.
func (mgr *singleTypeChannelsMgr) removeAllStream() error {
mgr.deleteAllChannels()
mgr.deleteAllStream()
mgr.deleteAllVChans()
mgr.deleteAllCollection()
mgr.mu.Lock()
defer mgr.mu.Unlock()
for _, info := range mgr.infos {
decPChanMetrics(info.channelInfos.pchans)
}
mgr.infos = make(map[UniqueID]streamInfos)
return nil
}
@ -438,15 +319,11 @@ func newSingleTypeChannelsMgr(
singleStreamType streamType,
) *singleTypeChannelsMgr {
return &singleTypeChannelsMgr{
collectionID2VIDs: make(map[UniqueID][]int),
id2vchans: make(map[int][]vChan),
id2Stream: make(map[int]msgstream.MsgStream),
id2UsageHistogramOfStream: make(map[int]int),
vchans2pchans: make(map[vChan]pChan),
getChannelsFunc: getChannelsFunc,
repackFunc: repackFunc,
singleStreamType: singleStreamType,
msgStreamFactory: msgStreamFactory,
infos: make(map[UniqueID]streamInfos),
getChannelsFunc: getChannelsFunc,
repackFunc: repackFunc,
singleStreamType: singleStreamType,
msgStreamFactory: msgStreamFactory,
}
}
@ -486,7 +363,6 @@ func (mgr *channelsMgrImpl) removeAllDMLStream() error {
func newChannelsMgrImpl(
getDmlChannelsFunc getChannelsFuncType,
dmlRepackFunc repackFuncType,
dqlRepackFunc repackFuncType,
msgStreamFactory msgstream.Factory,
) *channelsMgrImpl {
return &channelsMgrImpl{

View File

@ -17,153 +17,398 @@
package proxy
import (
"sync"
"context"
"errors"
"testing"
"github.com/milvus-io/milvus/internal/util/uniquegenerator"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/stretchr/testify/assert"
)
func TestChannelsMgrImpl_getChannels(t *testing.T) {
master := newMockGetChannelsService()
factory := newSimpleMockMsgStreamFactory()
mgr := newChannelsMgrImpl(master.GetChannels, nil, nil, factory)
defer mgr.removeAllDMLStream()
collID := UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
_, err := mgr.getChannels(collID)
assert.NotEqual(t, nil, err)
err = mgr.createDMLMsgStream(collID)
assert.Equal(t, nil, err)
_, err = mgr.getChannels(collID)
assert.Equal(t, nil, err)
func Test_removeDuplicate(t *testing.T) {
s1 := []string{"11", "11"}
filtered1 := removeDuplicate(s1)
assert.ElementsMatch(t, filtered1, []string{"11"})
}
func TestChannelsMgrImpl_getVChannels(t *testing.T) {
master := newMockGetChannelsService()
factory := newSimpleMockMsgStreamFactory()
mgr := newChannelsMgrImpl(master.GetChannels, nil, nil, factory)
defer mgr.removeAllDMLStream()
collID := UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
_, err := mgr.getVChannels(collID)
assert.NotEqual(t, nil, err)
err = mgr.createDMLMsgStream(collID)
assert.Equal(t, nil, err)
_, err = mgr.getVChannels(collID)
assert.Equal(t, nil, err)
}
func TestChannelsMgrImpl_createDMLMsgStream(t *testing.T) {
master := newMockGetChannelsService()
factory := newSimpleMockMsgStreamFactory()
mgr := newChannelsMgrImpl(master.GetChannels, nil, nil, factory)
defer mgr.removeAllDMLStream()
collID := UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
_, err := mgr.getChannels(collID)
assert.NotEqual(t, nil, err)
_, err = mgr.getVChannels(collID)
assert.NotEqual(t, nil, err)
err = mgr.createDMLMsgStream(collID)
assert.Equal(t, nil, err)
// re-create message stream.
err = mgr.createDMLMsgStream(collID)
assert.Equal(t, nil, err)
_, err = mgr.getChannels(collID)
assert.Equal(t, nil, err)
_, err = mgr.getVChannels(collID)
assert.Equal(t, nil, err)
}
func TestChannelsMgrImpl_getDMLMsgStream(t *testing.T) {
master := newMockGetChannelsService()
factory := newSimpleMockMsgStreamFactory()
mgr := newChannelsMgrImpl(master.GetChannels, nil, nil, factory)
defer mgr.removeAllDMLStream()
collID := UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
_, err := mgr.getDMLStream(collID)
assert.NotEqual(t, nil, err)
err = mgr.createDMLMsgStream(collID)
assert.Equal(t, nil, err)
_, err = mgr.getDMLStream(collID)
assert.Equal(t, nil, err)
}
func TestChannelsMgrImpl_removeDMLMsgStream(t *testing.T) {
master := newMockGetChannelsService()
factory := newSimpleMockMsgStreamFactory()
mgr := newChannelsMgrImpl(master.GetChannels, nil, nil, factory)
defer mgr.removeAllDMLStream()
collID := UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
_, err := mgr.getDMLStream(collID)
assert.NotEqual(t, nil, err)
err = mgr.removeDMLStream(collID)
assert.NotEqual(t, nil, err)
err = mgr.createDMLMsgStream(collID)
assert.Equal(t, nil, err)
_, err = mgr.getDMLStream(collID)
assert.Equal(t, nil, err)
err = mgr.removeDMLStream(collID)
assert.Equal(t, nil, err)
_, err = mgr.getDMLStream(collID)
assert.NotEqual(t, nil, err)
}
func TestChannelsMgrImpl_removeAllDMLMsgStream(t *testing.T) {
master := newMockGetChannelsService()
factory := newSimpleMockMsgStreamFactory()
mgr := newChannelsMgrImpl(master.GetChannels, nil, nil, factory)
defer mgr.removeAllDMLStream()
num := 10
for i := 0; i < num; i++ {
collID := UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
err := mgr.createDMLMsgStream(collID)
assert.Equal(t, nil, err)
}
}
func TestGetAllKeysAndGetAllValues(t *testing.T) {
chanMapping := make(map[vChan]pChan)
chanMapping["v1"] = "p1"
chanMapping["v2"] = "p2"
t.Run("getAllKeys", func(t *testing.T) {
vChans := getAllKeys(chanMapping)
assert.Equal(t, 2, len(vChans))
func Test_newChannels(t *testing.T) {
t.Run("length mismatch", func(t *testing.T) {
_, err := newChannels([]string{"111", "222"}, []string{"111"})
assert.Error(t, err)
})
t.Run("getAllValues", func(t *testing.T) {
pChans := getAllValues(chanMapping)
assert.Equal(t, 2, len(pChans))
t.Run("normal case", func(t *testing.T) {
got, err := newChannels([]string{"111", "222"}, []string{"111", "111"})
assert.NoError(t, err)
assert.ElementsMatch(t, []string{"111", "222"}, got.vchans)
// assert.ElementsMatch(t, []string{"111"}, got.pchans)
assert.ElementsMatch(t, []string{"111", "111"}, got.pchans)
})
}
func TestDeleteVChansByVID(t *testing.T) {
mgr := singleTypeChannelsMgr{
id2vchansMtx: sync.RWMutex{},
id2vchans: map[int][]vChan{
10: {"v1"},
func Test_getDmlChannelsFunc(t *testing.T) {
t.Run("failed to describe collection", func(t *testing.T) {
ctx := context.Background()
rc := newMockRootCoord()
rc.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) {
return nil, errors.New("mock")
}
f := getDmlChannelsFunc(ctx, rc)
_, err := f(100)
assert.Error(t, err)
})
t.Run("error code not success", func(t *testing.T) {
ctx := context.Background()
rc := newMockRootCoord()
rc.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) {
return &milvuspb.DescribeCollectionResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}}, nil
}
f := getDmlChannelsFunc(ctx, rc)
_, err := f(100)
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
ctx := context.Background()
rc := newMockRootCoord()
rc.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) {
return &milvuspb.DescribeCollectionResponse{
VirtualChannelNames: []string{"111", "222"},
PhysicalChannelNames: []string{"111", "111"},
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}}, nil
}
f := getDmlChannelsFunc(ctx, rc)
got, err := f(100)
assert.NoError(t, err)
assert.ElementsMatch(t, []string{"111", "222"}, got.vchans)
// assert.ElementsMatch(t, []string{"111"}, got.pchans)
assert.ElementsMatch(t, []string{"111", "111"}, got.pchans)
})
}
func Test_singleTypeChannelsMgr_getAllChannels(t *testing.T) {
t.Run("normal case", func(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: map[UniqueID]streamInfos{
100: {channelInfos: channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}},
},
}
got, err := m.getAllChannels(100)
assert.NoError(t, err)
assert.ElementsMatch(t, []string{"111", "222"}, got.vchans)
assert.ElementsMatch(t, []string{"111"}, got.pchans)
})
t.Run("not found", func(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: map[UniqueID]streamInfos{},
}
_, err := m.getAllChannels(100)
assert.Error(t, err)
})
}
func Test_singleTypeChannelsMgr_getPChans(t *testing.T) {
t.Run("normal case", func(t *testing.T) {
m := &singleTypeChannelsMgr{
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
return channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}, nil
},
}
got, err := m.getPChans(100)
assert.NoError(t, err)
assert.ElementsMatch(t, []string{"111"}, got)
})
t.Run("error case", func(t *testing.T) {
m := &singleTypeChannelsMgr{
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
return channelInfos{}, errors.New("mock")
},
}
_, err := m.getPChans(100)
assert.Error(t, err)
})
}
func Test_singleTypeChannelsMgr_getVChans(t *testing.T) {
t.Run("normal case", func(t *testing.T) {
m := &singleTypeChannelsMgr{
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
return channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}, nil
},
}
got, err := m.getVChans(100)
assert.NoError(t, err)
assert.ElementsMatch(t, []string{"111", "222"}, got)
})
t.Run("error case", func(t *testing.T) {
m := &singleTypeChannelsMgr{
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
return channelInfos{}, errors.New("mock")
},
}
_, err := m.getVChans(100)
assert.Error(t, err)
})
}
func Test_singleTypeChannelsMgr_getChannels(t *testing.T) {
t.Run("normal case", func(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: map[UniqueID]streamInfos{
100: {channelInfos: channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}},
},
}
got, err := m.getChannels(100)
assert.NoError(t, err)
assert.ElementsMatch(t, []string{"111"}, got)
})
t.Run("error case", func(t *testing.T) {
m := &singleTypeChannelsMgr{
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
return channelInfos{}, errors.New("mock")
},
}
_, err := m.getChannels(100)
assert.Error(t, err)
})
}
func Test_singleTypeChannelsMgr_getVChannels(t *testing.T) {
t.Run("normal case", func(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: map[UniqueID]streamInfos{
100: {channelInfos: channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}},
},
}
got, err := m.getVChannels(100)
assert.NoError(t, err)
assert.ElementsMatch(t, []string{"111", "222"}, got)
})
t.Run("error case", func(t *testing.T) {
m := &singleTypeChannelsMgr{
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
return channelInfos{}, errors.New("mock")
},
}
_, err := m.getVChannels(100)
assert.Error(t, err)
})
}
func Test_singleTypeChannelsMgr_streamExist(t *testing.T) {
t.Run("exist", func(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: map[UniqueID]streamInfos{
100: {stream: newSimpleMockMsgStream()},
},
}
exist := m.streamExist(100)
assert.True(t, exist)
})
t.Run("not exist", func(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: map[UniqueID]streamInfos{
100: {stream: nil},
},
}
exist := m.streamExist(100)
assert.False(t, exist)
m.infos = make(map[UniqueID]streamInfos)
exist = m.streamExist(100)
assert.False(t, exist)
})
}
func Test_createStream(t *testing.T) {
t.Run("failed to create msgstream", func(t *testing.T) {
factory := newMockMsgStreamFactory()
factory.fQStream = func(ctx context.Context) (msgstream.MsgStream, error) {
return nil, errors.New("mock")
}
_, err := createStream(factory, dmlStreamType, nil, nil)
assert.Error(t, err)
})
t.Run("failed to create query msgstream", func(t *testing.T) {
factory := newMockMsgStreamFactory()
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return nil, errors.New("mock")
}
_, err := createStream(factory, dqlStreamType, nil, nil)
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
factory := newMockMsgStreamFactory()
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return newMockMsgStream(), nil
}
_, err := createStream(factory, dmlStreamType, []string{"111"}, func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
return nil, nil
})
assert.NoError(t, err)
})
}
func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
t.Run("re-create", func(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: map[UniqueID]streamInfos{
100: {stream: newMockMsgStream()},
},
}
err := m.createMsgStream(100)
assert.NoError(t, err)
})
t.Run("failed to get channels", func(t *testing.T) {
m := &singleTypeChannelsMgr{
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
return channelInfos{}, errors.New("mock")
},
}
err := m.createMsgStream(100)
assert.Error(t, err)
})
t.Run("failed to create message stream", func(t *testing.T) {
factory := newMockMsgStreamFactory()
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return nil, errors.New("mock")
}
m := &singleTypeChannelsMgr{
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
return channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}, nil
},
msgStreamFactory: factory,
singleStreamType: dmlStreamType,
repackFunc: nil,
}
err := m.createMsgStream(100)
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
factory := newMockMsgStreamFactory()
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return newMockMsgStream(), nil
}
m := &singleTypeChannelsMgr{
infos: make(map[UniqueID]streamInfos),
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
return channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}, nil
},
msgStreamFactory: factory,
singleStreamType: dmlStreamType,
repackFunc: nil,
}
err := m.createMsgStream(100)
assert.NoError(t, err)
stream, err := m.getStream(100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
}
func Test_singleTypeChannelsMgr_lockGetStream(t *testing.T) {
t.Run("collection not found", func(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: make(map[UniqueID]streamInfos),
}
_, err := m.lockGetStream(100)
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: map[UniqueID]streamInfos{
100: {stream: newMockMsgStream()},
},
}
stream, err := m.lockGetStream(100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
}
func Test_singleTypeChannelsMgr_getStream(t *testing.T) {
t.Run("exist", func(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: map[UniqueID]streamInfos{
100: {stream: newMockMsgStream()},
},
}
stream, err := m.getStream(100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
t.Run("failed to create", func(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: map[UniqueID]streamInfos{},
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
return channelInfos{}, errors.New("mock")
},
}
_, err := m.getStream(100)
assert.Error(t, err)
})
t.Run("get after create", func(t *testing.T) {
factory := newMockMsgStreamFactory()
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return newMockMsgStream(), nil
}
m := &singleTypeChannelsMgr{
infos: make(map[UniqueID]streamInfos),
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
return channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}, nil
},
msgStreamFactory: factory,
singleStreamType: dmlStreamType,
repackFunc: nil,
}
stream, err := m.getStream(100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
}
func Test_singleTypeChannelsMgr_removeStream(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: map[UniqueID]streamInfos{
100: {
stream: newMockMsgStream(),
},
},
}
mgr.deleteVChansByVID(10)
err := m.removeStream(100)
assert.NoError(t, err)
_, err = m.lockGetStream(100)
assert.Error(t, err)
}
func Test_singleTypeChannelsMgr_removeAllStream(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: map[UniqueID]streamInfos{
100: {
stream: newMockMsgStream(),
},
},
}
err := m.removeAllStream()
assert.NoError(t, err)
_, err = m.lockGetStream(100)
assert.Error(t, err)
}

View File

@ -3609,26 +3609,31 @@ func (node *Proxy) Import(ctx context.Context, req *milvuspb.ImportRequest) (*mi
zap.Error(err))
resp.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
resp.Status.Reason = err.Error()
return resp, err
return resp, nil
}
chNames, err := node.chMgr.getVChannels(collID)
if err != nil {
err = node.chMgr.createDMLMsgStream(collID)
if err != nil {
return nil, err
}
chNames, err = node.chMgr.getVChannels(collID)
if err != nil {
return nil, err
}
log.Error("failed to get virtual channels",
zap.Error(err),
zap.String("collection", req.GetCollectionName()),
zap.Int64("collection_id", collID))
resp.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
resp.Status.Reason = err.Error()
return resp, nil
}
req.ChannelNames = chNames
if req.GetPartitionName() == "" {
req.PartitionName = Params.CommonCfg.DefaultPartitionName
}
// Call rootCoord to finish import.
resp, err = node.rootCoord.Import(ctx, req)
return resp, err
respFromRC, err := node.rootCoord.Import(ctx, req)
if err != nil {
log.Error("failed to execute bulk load request", zap.Error(err))
resp.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
resp.Status.Reason = err.Error()
return resp, nil
}
return respFromRC, nil
}
// GetImportState checks import task state from datanode

View File

@ -30,6 +30,9 @@ func (m *mockCache) GetCollectionSchema(ctx context.Context, collectionName stri
return nil, nil
}
func (m *mockCache) RemoveCollection(ctx context.Context, collectionName string) {
}
func (m *mockCache) setGetIDFunc(f getCollectionIDFunc) {
m.getIDFunc = f
}

View File

@ -0,0 +1,28 @@
package proxy
type getVChannelsFuncType = func(collectionID UniqueID) ([]vChan, error)
type removeDMLStreamFuncType = func(collectionID UniqueID) error
type mockChannelsMgr struct {
channelsMgr
getVChannelsFuncType
removeDMLStreamFuncType
}
func (m *mockChannelsMgr) getVChannels(collectionID UniqueID) ([]vChan, error) {
if m.getVChannelsFuncType != nil {
return m.getVChannelsFuncType(collectionID)
}
return nil, nil
}
func (m *mockChannelsMgr) removeDMLStream(collectionID UniqueID) error {
if m.removeDMLStreamFuncType != nil {
return m.removeDMLStreamFuncType(collectionID)
}
return nil
}
func newMockChannelsMgr() *mockChannelsMgr {
return &mockChannelsMgr{}
}

View File

@ -0,0 +1,69 @@
package proxy
import (
"context"
"errors"
"github.com/milvus-io/milvus/internal/mq/msgstream"
)
type mockMsgStream struct {
msgstream.MsgStream
asProducer func([]string)
setRepack func(repackFunc msgstream.RepackFunc)
close func()
}
func (m *mockMsgStream) AsProducer(producers []string) {
if m.asProducer != nil {
m.asProducer(producers)
}
}
func (m *mockMsgStream) SetRepackFunc(repackFunc msgstream.RepackFunc) {
if m.setRepack != nil {
m.setRepack(repackFunc)
}
}
func (m *mockMsgStream) Close() {
if m.close != nil {
m.close()
}
}
func newMockMsgStream() *mockMsgStream {
return &mockMsgStream{}
}
type mockMsgStreamFactory struct {
msgstream.Factory
f func(ctx context.Context) (msgstream.MsgStream, error)
fQStream func(ctx context.Context) (msgstream.MsgStream, error)
fTtStream func(ctx context.Context) (msgstream.MsgStream, error)
}
func (m *mockMsgStreamFactory) NewMsgStream(ctx context.Context) (msgstream.MsgStream, error) {
if m.f != nil {
return m.f(ctx)
}
return nil, errors.New("mock")
}
func (m *mockMsgStreamFactory) NewTtMsgStream(ctx context.Context) (msgstream.MsgStream, error) {
if m.fTtStream != nil {
return m.fTtStream(ctx)
}
return nil, errors.New("mock")
}
func (m *mockMsgStreamFactory) NewQueryMsgStream(ctx context.Context) (msgstream.MsgStream, error) {
if m.fQStream != nil {
return m.fQStream(ctx)
}
return nil, errors.New("mock")
}
func newMockMsgStreamFactory() *mockMsgStreamFactory {
return &mockMsgStreamFactory{}
}

View File

@ -86,37 +86,6 @@ func newMockIDAllocatorInterface() idAllocatorInterface {
return &mockIDAllocatorInterface{}
}
type mockGetChannelsService struct {
collectionID2Channels map[UniqueID]map[vChan]pChan
f getChannelsFuncType
}
func newMockGetChannelsService() *mockGetChannelsService {
return &mockGetChannelsService{
collectionID2Channels: make(map[UniqueID]map[vChan]pChan),
}
}
func (m *mockGetChannelsService) GetChannels(collectionID UniqueID) (map[vChan]pChan, error) {
if m.f != nil {
return m.f(collectionID)
}
channels, ok := m.collectionID2Channels[collectionID]
if ok {
return channels, nil
}
channels = make(map[vChan]pChan)
l := rand.Uint64()%10 + 1
for i := 0; uint64(i) < l; i++ {
channels[funcutil.GenRandomStr()] = funcutil.GenRandomStr()
}
m.collectionID2Channels[collectionID] = channels
return channels, nil
}
type mockTask struct {
*TaskCondition
id UniqueID

View File

@ -198,7 +198,7 @@ func (node *Proxy) Init() error {
log.Debug("create channels manager", zap.String("role", typeutil.ProxyRole))
dmlChannelsFunc := getDmlChannelsFunc(node.ctx, node.rootCoord)
chMgr := newChannelsMgrImpl(dmlChannelsFunc, defaultInsertRepackFunc, nil, node.factory)
chMgr := newChannelsMgrImpl(dmlChannelsFunc, defaultInsertRepackFunc, node.factory)
node.chMgr = chMgr
log.Debug("create channels manager done", zap.String("role", typeutil.ProxyRole))

View File

@ -18,6 +18,7 @@ package proxy
import (
"context"
"errors"
"net"
"os"
"strconv"
@ -1591,8 +1592,8 @@ func TestProxy(t *testing.T) {
}
proxy.stateCode.Store(internalpb.StateCode_Healthy)
resp, err := proxy.Import(context.TODO(), req)
assert.NoError(t, err)
assert.EqualValues(t, commonpb.ErrorCode_UnexpectedError, resp.Status.ErrorCode)
assert.Error(t, err)
})
wg.Add(1)
@ -1604,8 +1605,8 @@ func TestProxy(t *testing.T) {
}
proxy.stateCode.Store(internalpb.StateCode_Healthy)
resp, err := proxy.Import(context.TODO(), req)
assert.NoError(t, err)
assert.EqualValues(t, commonpb.ErrorCode_UnexpectedError, resp.Status.ErrorCode)
assert.Error(t, err)
})
wg.Add(1)
@ -3028,63 +3029,104 @@ func TestProxy_GetComponentStates_state_code(t *testing.T) {
}
func TestProxy_Import(t *testing.T) {
rc := NewRootCoordMock()
master := newMockGetChannelsService()
msgStreamFactory := newSimpleMockMsgStreamFactory()
rc.Start()
defer rc.Stop()
qc := NewQueryCoordMock()
qc.Start()
defer qc.Stop()
shardMgr := newShardClientMgr()
err := InitMetaCache(rc, qc, shardMgr)
assert.NoError(t, err)
rc.CreateCollection(context.TODO(), &milvuspb.CreateCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_DropCollection,
MsgID: 100,
Timestamp: 100,
},
CollectionName: "import_collection",
})
localMsg := true
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
factory := dependency.NewDefaultFactory(localMsg)
proxy, err := NewProxy(ctx, factory)
proxy.rootCoord = rc
assert.NoError(t, err)
var wg sync.WaitGroup
wg.Add(1)
t.Run("test import get vChannel failed (the first one)", func(t *testing.T) {
defer wg.Done()
proxy.stateCode.Store(internalpb.StateCode_Healthy)
proxy.chMgr = newChannelsMgrImpl(master.GetChannels, nil, nil, msgStreamFactory)
resp, err := proxy.Import(context.TODO(),
&milvuspb.ImportRequest{
CollectionName: "import_collection",
})
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
assert.NoError(t, err)
})
wg.Add(1)
t.Run("test import with unhealthy", func(t *testing.T) {
defer wg.Done()
req := &milvuspb.ImportRequest{
CollectionName: "dummy",
}
proxy.stateCode.Store(internalpb.StateCode_Abnormal)
proxy := &Proxy{}
proxy.UpdateStateCode(internalpb.StateCode_Abnormal)
resp, err := proxy.Import(context.TODO(), req)
assert.EqualValues(t, unhealthyStatus(), resp.Status)
assert.NoError(t, err)
assert.EqualValues(t, unhealthyStatus(), resp.GetStatus())
})
resp, err := rc.DropCollection(context.TODO(), &milvuspb.DropCollectionRequest{
CollectionName: "import_collection",
wg.Add(1)
t.Run("collection not found", func(t *testing.T) {
defer wg.Done()
proxy := &Proxy{}
proxy.UpdateStateCode(internalpb.StateCode_Healthy)
cache := newMockCache()
cache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
return 0, errors.New("mock")
})
globalMetaCache = cache
req := &milvuspb.ImportRequest{
CollectionName: "dummy",
}
resp, err := proxy.Import(context.TODO(), req)
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
})
wg.Add(1)
t.Run("failed to get virtual channels", func(t *testing.T) {
defer wg.Done()
proxy := &Proxy{}
proxy.UpdateStateCode(internalpb.StateCode_Healthy)
cache := newMockCache()
globalMetaCache = cache
chMgr := newMockChannelsMgr()
chMgr.getVChannelsFuncType = func(collectionID UniqueID) ([]vChan, error) {
return nil, errors.New("mock")
}
proxy.chMgr = chMgr
req := &milvuspb.ImportRequest{
CollectionName: "dummy",
}
resp, err := proxy.Import(context.TODO(), req)
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
})
wg.Add(1)
t.Run("rootcoord fail", func(t *testing.T) {
defer wg.Done()
proxy := &Proxy{}
proxy.UpdateStateCode(internalpb.StateCode_Healthy)
cache := newMockCache()
globalMetaCache = cache
chMgr := newMockChannelsMgr()
proxy.chMgr = chMgr
rc := newMockRootCoord()
rc.ImportFunc = func(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) {
return nil, errors.New("mock")
}
proxy.rootCoord = rc
req := &milvuspb.ImportRequest{
CollectionName: "dummy",
}
resp, err := proxy.Import(context.TODO(), req)
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
})
wg.Add(1)
t.Run("normal case", func(t *testing.T) {
defer wg.Done()
proxy := &Proxy{}
proxy.UpdateStateCode(internalpb.StateCode_Healthy)
cache := newMockCache()
globalMetaCache = cache
chMgr := newMockChannelsMgr()
proxy.chMgr = chMgr
rc := newMockRootCoord()
rc.ImportFunc = func(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) {
return &milvuspb.ImportResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}}, nil
}
proxy.rootCoord = rc
req := &milvuspb.ImportRequest{
CollectionName: "dummy",
}
resp, err := proxy.Import(context.TODO(), req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
})
wg.Wait()
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
rc.Stop()
}
func TestProxy_GetImportState(t *testing.T) {

View File

@ -1066,6 +1066,8 @@ type ShowPartitionsFunc func(ctx context.Context, request *milvuspb.ShowPartitio
type DescribeIndexFunc func(ctx context.Context, request *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error)
type ShowSegmentsFunc func(ctx context.Context, request *milvuspb.ShowSegmentsRequest) (*milvuspb.ShowSegmentsResponse, error)
type DescribeSegmentsFunc func(ctx context.Context, request *rootcoordpb.DescribeSegmentsRequest) (*rootcoordpb.DescribeSegmentsResponse, error)
type ImportFunc func(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error)
type DropCollectionFunc func(ctx context.Context, request *milvuspb.DropCollectionRequest) (*commonpb.Status, error)
type mockRootCoord struct {
types.RootCoord
@ -1074,6 +1076,8 @@ type mockRootCoord struct {
DescribeIndexFunc
ShowSegmentsFunc
DescribeSegmentsFunc
ImportFunc
DropCollectionFunc
}
func (m *mockRootCoord) DescribeCollection(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) {
@ -1111,6 +1115,20 @@ func (m *mockRootCoord) DescribeSegments(ctx context.Context, request *rootcoord
return nil, errors.New("mock")
}
func (m *mockRootCoord) Import(ctx context.Context, request *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) {
if m.ImportFunc != nil {
return m.ImportFunc(ctx, request)
}
return nil, errors.New("mock")
}
func (m *mockRootCoord) DropCollection(ctx context.Context, request *milvuspb.DropCollectionRequest) (*commonpb.Status, error) {
if m.DropCollectionFunc != nil {
return m.DropCollectionFunc(ctx, request)
}
return nil, errors.New("mock")
}
func newMockRootCoord() *mockRootCoord {
return &mockRootCoord{}
}

View File

@ -111,7 +111,7 @@ type task interface {
type dmlTask interface {
task
getChannels() ([]vChan, error)
getChannels() ([]pChan, error)
getPChanStats() (map[pChan]pChanStatistics, error)
}
@ -192,16 +192,7 @@ func (it *insertTask) getChannels() ([]pChan, error) {
if err != nil {
return nil, err
}
var channels []pChan
channels, err = it.chMgr.getChannels(collID)
if err != nil {
err = it.chMgr.createDMLMsgStream(collID)
if err != nil {
return nil, err
}
channels, err = it.chMgr.getChannels(collID)
}
return channels, err
return it.chMgr.getChannels(collID)
}
func (it *insertTask) OnEnqueue() error {
@ -508,18 +499,7 @@ func (it *insertTask) Execute(ctx context.Context) error {
stream, err := it.chMgr.getDMLStream(collID)
if err != nil {
err = it.chMgr.createDMLMsgStream(collID)
if err != nil {
it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
it.result.Status.Reason = err.Error()
return err
}
stream, err = it.chMgr.getDMLStream(collID)
if err != nil {
it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
it.result.Status.Reason = err.Error()
return err
}
return err
}
tr.Record("get used message stream")
@ -531,6 +511,14 @@ func (it *insertTask) Execute(ctx context.Context) error {
return err
}
log.Info("send insert request to virtual channels",
zap.String("collection", it.GetCollectionName()),
zap.String("partition", it.GetPartitionName()),
zap.Int64("collection_id", collID),
zap.Int64("partition_id", partitionID),
zap.Strings("virtual_channels", channelNames),
zap.Int64("task_id", it.ID()))
// assign segmentID for insert data and repack data by segmentID
msgPack, err := it.assignSegmentID(channelNames)
if err != nil {
@ -3141,16 +3129,7 @@ func (dt *deleteTask) getChannels() ([]pChan, error) {
if err != nil {
return nil, err
}
var channels []pChan
channels, err = dt.chMgr.getChannels(collID)
if err != nil {
err = dt.chMgr.createDMLMsgStream(collID)
if err != nil {
return nil, err
}
channels, err = dt.chMgr.getChannels(collID)
}
return channels, err
return dt.chMgr.getChannels(collID)
}
func getPrimaryKeysFromExpr(schema *schemapb.CollectionSchema, expr string) (res *schemapb.IDs, rowNum int64, err error) {
@ -3283,18 +3262,7 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
collID := dt.DeleteRequest.CollectionID
stream, err := dt.chMgr.getDMLStream(collID)
if err != nil {
err = dt.chMgr.createDMLMsgStream(collID)
if err != nil {
dt.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
dt.result.Status.Reason = err.Error()
return err
}
stream, err = dt.chMgr.getDMLStream(collID)
if err != nil {
dt.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
dt.result.Status.Reason = err.Error()
return err
}
return err
}
// hash primary keys to channels
@ -3307,6 +3275,12 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
}
dt.HashValues = typeutil.HashPK2Channels(dt.result.IDs, channelNames)
log.Info("send delete request to virtual channels",
zap.String("collection", dt.GetCollectionName()),
zap.Int64("collection_id", collID),
zap.Strings("virtual_channels", channelNames),
zap.Int64("task_id", dt.ID()))
tr.Record("get vchannels")
// repack delete msg by dmChannel
result := make(map[uint32]msgstream.TsMsg)

View File

@ -27,15 +27,15 @@ import (
"testing"
"time"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
@ -1087,47 +1087,12 @@ func TestCreateCollectionTask(t *testing.T) {
func TestDropCollectionTask(t *testing.T) {
Params.Init()
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
qc := NewQueryCoordMock()
qc.Start()
defer qc.Stop()
ctx := context.Background()
mgr := newShardClientMgr()
InitMetaCache(rc, qc, mgr)
master := newMockGetChannelsService()
factory := newSimpleMockMsgStreamFactory()
channelMgr := newChannelsMgrImpl(master.GetChannels, nil, nil, factory)
defer channelMgr.removeAllDMLStream()
prefix := "TestDropCollectionTask"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
ctx := context.Background()
shardsNum := int32(2)
int64Field := "int64"
floatVecField := "fvec"
dim := 128
schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName)
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
createColReq := &milvuspb.CreateCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_DropCollection,
MsgID: 100,
Timestamp: 100,
},
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: shardsNum,
}
//CreateCollection
task := &dropCollectionTask{
Condition: NewTaskCondition(ctx),
DropCollectionRequest: &milvuspb.DropCollectionRequest{
@ -1139,38 +1104,58 @@ func TestDropCollectionTask(t *testing.T) {
DbName: dbName,
CollectionName: collectionName,
},
ctx: ctx,
chMgr: channelMgr,
rootCoord: rc,
result: nil,
ctx: ctx,
result: nil,
}
task.PreExecute(ctx)
assert.Equal(t, commonpb.MsgType_DropCollection, task.Type())
task.SetID(100)
assert.Equal(t, UniqueID(100), task.ID())
assert.Equal(t, DropCollectionTaskName, task.Name())
assert.Equal(t, commonpb.MsgType_DropCollection, task.Type())
task.SetTs(100)
assert.Equal(t, Timestamp(100), task.BeginTs())
assert.Equal(t, Timestamp(100), task.EndTs())
err := task.PreExecute(ctx)
assert.NoError(t, err)
assert.Equal(t, Params.ProxyCfg.GetNodeID(), task.GetBase().GetSourceID())
// missing collectionID in globalMetaCache
err = task.Execute(ctx)
assert.NotNil(t, err)
// createCollection in RootCood and fill GlobalMetaCache
rc.CreateCollection(ctx, createColReq)
globalMetaCache.GetCollectionID(ctx, collectionName)
// success to drop collection
err = task.Execute(ctx)
assert.Nil(t, err)
// illegal name
task.CollectionName = "#0xc0de"
err = task.PreExecute(ctx)
assert.NotNil(t, err)
assert.Error(t, err)
task.CollectionName = collectionName
err = task.PreExecute(ctx)
assert.Nil(t, err)
cache := newMockCache()
chMgr := newMockChannelsMgr()
rc := newMockRootCoord()
globalMetaCache = cache
task.chMgr = chMgr
task.rootCoord = rc
cache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
return 0, errors.New("mock")
})
err = task.Execute(ctx)
assert.Error(t, err)
cache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
return 0, nil
})
rc.DropCollectionFunc = func(ctx context.Context, request *milvuspb.DropCollectionRequest) (*commonpb.Status, error) {
return nil, errors.New("mock")
}
err = task.Execute(ctx)
assert.Error(t, err)
rc.DropCollectionFunc = func(ctx context.Context, request *milvuspb.DropCollectionRequest) (*commonpb.Status, error) {
return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil
}
// normal case
err = task.Execute(ctx)
assert.NoError(t, err)
err = task.PostExecute(ctx)
assert.NoError(t, err)
}
func TestHasCollectionTask(t *testing.T) {
@ -1728,7 +1713,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
factory := newSimpleMockMsgStreamFactory()
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, nil, factory)
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
defer chMgr.removeAllDMLStream()
err = chMgr.createDMLMsgStream(collectionID)
@ -1983,7 +1968,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
factory := newSimpleMockMsgStreamFactory()
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, nil, factory)
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
defer chMgr.removeAllDMLStream()
err = chMgr.createDMLMsgStream(collectionID)