mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-11-30 10:59:32 +08:00
Refine channels management in Proxy. (#17334)
Signed-off-by: longjiquan <jiquan.long@zilliz.com>
This commit is contained in:
parent
5fdbe23779
commit
adf3b14027
@ -18,9 +18,9 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
@ -30,7 +30,6 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/uniquegenerator"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@ -45,60 +44,70 @@ type channelsMgr interface {
|
||||
removeAllDMLStream() error
|
||||
}
|
||||
|
||||
type channelInfos struct {
|
||||
// It seems that there is no need to maintain relationships between vchans & pchans.
|
||||
vchans []vChan
|
||||
pchans []pChan
|
||||
}
|
||||
|
||||
type streamInfos struct {
|
||||
channelInfos channelInfos
|
||||
stream msgstream.MsgStream
|
||||
}
|
||||
|
||||
func removeDuplicate(ss []string) []string {
|
||||
m := make(map[string]struct{})
|
||||
filtered := make([]string, 0, len(ss))
|
||||
for _, s := range ss {
|
||||
if _, ok := m[s]; !ok {
|
||||
filtered = append(filtered, s)
|
||||
m[s] = struct{}{}
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func newChannels(vchans []vChan, pchans []pChan) (channelInfos, error) {
|
||||
if len(vchans) != len(pchans) {
|
||||
err := fmt.Errorf("physical channels mismatch virtual channels, len(VirtualChannelNames): %v, len(PhysicalChannelNames): %v", len(vchans), len(pchans))
|
||||
log.Error(err.Error())
|
||||
return channelInfos{}, err
|
||||
}
|
||||
/*
|
||||
// remove duplicate physical channels.
|
||||
return channelInfos{vchans: vchans, pchans: removeDuplicate(pchans)}, nil
|
||||
*/
|
||||
return channelInfos{vchans: vchans, pchans: pchans}, nil
|
||||
}
|
||||
|
||||
// getChannelsFuncType returns the channel information according to the collection id.
|
||||
type getChannelsFuncType = func(collectionID UniqueID) (map[vChan]pChan, error)
|
||||
type getChannelsFuncType = func(collectionID UniqueID) (channelInfos, error)
|
||||
|
||||
// repackFuncType repacks message into message pack.
|
||||
type repackFuncType = func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error)
|
||||
|
||||
// getDmlChannelsFunc returns a function about how to get dml channels of a collection.
|
||||
func getDmlChannelsFunc(ctx context.Context, rc types.RootCoord) getChannelsFuncType {
|
||||
return func(collectionID UniqueID) (map[vChan]pChan, error) {
|
||||
return func(collectionID UniqueID) (channelInfos, error) {
|
||||
req := &milvuspb.DescribeCollectionRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_DescribeCollection,
|
||||
MsgID: 0, // todo
|
||||
Timestamp: 0, // todo
|
||||
SourceID: 0, // todo
|
||||
},
|
||||
DbName: "", // todo
|
||||
CollectionName: "", // todo
|
||||
CollectionID: collectionID,
|
||||
TimeStamp: 0, // todo
|
||||
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DescribeCollection},
|
||||
CollectionID: collectionID,
|
||||
}
|
||||
|
||||
resp, err := rc.DescribeCollection(ctx, req)
|
||||
if err != nil {
|
||||
log.Warn("failed to describe collection", zap.Error(err), zap.Int64("collection", collectionID))
|
||||
return nil, err
|
||||
}
|
||||
if resp.Status.ErrorCode != 0 {
|
||||
log.Warn("DescribeCollection",
|
||||
zap.Any("ErrorCode", resp.Status.ErrorCode),
|
||||
zap.Any("Reason", resp.Status.Reason))
|
||||
return nil, err
|
||||
}
|
||||
if len(resp.VirtualChannelNames) != len(resp.PhysicalChannelNames) {
|
||||
err := fmt.Errorf(
|
||||
"len(VirtualChannelNames): %v, len(PhysicalChannelNames): %v",
|
||||
len(resp.VirtualChannelNames),
|
||||
len(resp.PhysicalChannelNames))
|
||||
log.Warn("GetDmlChannels", zap.Error(err))
|
||||
return nil, err
|
||||
log.Error("failed to describe collection", zap.Error(err), zap.Int64("collection", collectionID))
|
||||
return channelInfos{}, err
|
||||
}
|
||||
|
||||
ret := make(map[vChan]pChan)
|
||||
for idx, name := range resp.VirtualChannelNames {
|
||||
if _, ok := ret[name]; ok {
|
||||
err := fmt.Errorf(
|
||||
"duplicated virtual channel found, vchan: %v, pchan: %v",
|
||||
name,
|
||||
resp.PhysicalChannelNames[idx])
|
||||
return nil, err
|
||||
}
|
||||
ret[name] = resp.PhysicalChannelNames[idx]
|
||||
if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
log.Error("failed to describe collection",
|
||||
zap.String("error_code", resp.GetStatus().GetErrorCode().String()),
|
||||
zap.String("reason", resp.GetStatus().GetReason()))
|
||||
return channelInfos{}, errors.New(resp.GetStatus().GetReason())
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
return newChannels(resp.GetVirtualChannelNames(), resp.GetPhysicalChannelNames())
|
||||
}
|
||||
}
|
||||
|
||||
@ -111,323 +120,195 @@ const (
|
||||
)
|
||||
|
||||
type singleTypeChannelsMgr struct {
|
||||
collectionID2VIDs map[UniqueID][]int // id are sorted
|
||||
collMtx sync.RWMutex
|
||||
|
||||
id2vchans map[int][]vChan
|
||||
id2vchansMtx sync.RWMutex
|
||||
|
||||
id2Stream map[int]msgstream.MsgStream
|
||||
id2UsageHistogramOfStream map[int]int
|
||||
streamMtx sync.RWMutex
|
||||
|
||||
vchans2pchans map[vChan]pChan
|
||||
vchans2pchansMtx sync.RWMutex
|
||||
|
||||
getChannelsFunc getChannelsFuncType
|
||||
|
||||
repackFunc repackFuncType
|
||||
infos map[UniqueID]streamInfos // collection id -> stream infos
|
||||
mu sync.RWMutex
|
||||
|
||||
getChannelsFunc getChannelsFuncType
|
||||
repackFunc repackFuncType
|
||||
singleStreamType streamType
|
||||
|
||||
msgStreamFactory msgstream.Factory
|
||||
}
|
||||
|
||||
func getAllKeys(m map[vChan]pChan) []vChan {
|
||||
keys := make([]vChan, 0, len(m))
|
||||
for key := range m {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
func (mgr *singleTypeChannelsMgr) getAllChannels(collectionID UniqueID) (channelInfos, error) {
|
||||
mgr.mu.RLock()
|
||||
defer mgr.mu.RUnlock()
|
||||
|
||||
func getAllValues(m map[vChan]pChan) []pChan {
|
||||
values := make([]pChan, 0, len(m))
|
||||
for _, value := range m {
|
||||
values = append(values, value)
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
func (mgr *singleTypeChannelsMgr) getLatestVID(collectionID UniqueID) (int, error) {
|
||||
mgr.collMtx.RLock()
|
||||
defer mgr.collMtx.RUnlock()
|
||||
|
||||
ids, ok := mgr.collectionID2VIDs[collectionID]
|
||||
if !ok || len(ids) <= 0 {
|
||||
return 0, fmt.Errorf("v-channel ID is not found for collection %d", collectionID)
|
||||
infos, ok := mgr.infos[collectionID]
|
||||
if ok {
|
||||
return infos.channelInfos, nil
|
||||
}
|
||||
|
||||
return ids[len(ids)-1], nil
|
||||
return channelInfos{}, fmt.Errorf("collection not found in channels manager: %d", collectionID)
|
||||
}
|
||||
|
||||
func (mgr *singleTypeChannelsMgr) getAllVIDs(collectionID UniqueID) ([]int, error) {
|
||||
mgr.collMtx.RLock()
|
||||
defer mgr.collMtx.RUnlock()
|
||||
|
||||
ids, exist := mgr.collectionID2VIDs[collectionID]
|
||||
if !exist {
|
||||
return nil, fmt.Errorf("collection %d not found", collectionID)
|
||||
func (mgr *singleTypeChannelsMgr) getPChans(collectionID UniqueID) ([]pChan, error) {
|
||||
channelInfos, err := mgr.getChannelsFunc(collectionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
return channelInfos.pchans, nil
|
||||
}
|
||||
|
||||
func (mgr *singleTypeChannelsMgr) getVChansByVID(vid int) ([]vChan, error) {
|
||||
mgr.id2vchansMtx.RLock()
|
||||
defer mgr.id2vchansMtx.RUnlock()
|
||||
|
||||
vchans, ok := mgr.id2vchans[vid]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("vid %d not found", vid)
|
||||
}
|
||||
|
||||
return vchans, nil
|
||||
}
|
||||
|
||||
// getPChansByVChans converts virtual channel names to physical channel names
|
||||
func (mgr *singleTypeChannelsMgr) getPChansByVChans(vchans []vChan) ([]pChan, error) {
|
||||
mgr.vchans2pchansMtx.RLock()
|
||||
defer mgr.vchans2pchansMtx.RUnlock()
|
||||
|
||||
pchans := make([]pChan, 0, len(vchans))
|
||||
for _, vchan := range vchans {
|
||||
pchan, ok := mgr.vchans2pchans[vchan]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("vchan %v not found", vchan)
|
||||
}
|
||||
pchans = append(pchans, pchan)
|
||||
}
|
||||
|
||||
return pchans, nil
|
||||
}
|
||||
|
||||
func (mgr *singleTypeChannelsMgr) updateVChans(vid int, vchans []vChan) {
|
||||
mgr.id2vchansMtx.Lock()
|
||||
defer mgr.id2vchansMtx.Unlock()
|
||||
|
||||
mgr.id2vchans[vid] = vchans
|
||||
}
|
||||
|
||||
func (mgr *singleTypeChannelsMgr) deleteVChansByVID(vid int) {
|
||||
mgr.id2vchansMtx.Lock()
|
||||
defer mgr.id2vchansMtx.Unlock()
|
||||
|
||||
delete(mgr.id2vchans, vid)
|
||||
}
|
||||
|
||||
func (mgr *singleTypeChannelsMgr) deleteVChansByVIDs(vids []int) {
|
||||
mgr.id2vchansMtx.Lock()
|
||||
defer mgr.id2vchansMtx.Unlock()
|
||||
|
||||
for _, vid := range vids {
|
||||
delete(mgr.id2vchans, vid)
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *singleTypeChannelsMgr) deleteStreamByVID(vid int) {
|
||||
mgr.streamMtx.Lock()
|
||||
defer mgr.streamMtx.Unlock()
|
||||
|
||||
delete(mgr.id2Stream, vid)
|
||||
}
|
||||
|
||||
func (mgr *singleTypeChannelsMgr) deleteStreamByVIDs(vids []int) {
|
||||
mgr.streamMtx.Lock()
|
||||
defer mgr.streamMtx.Unlock()
|
||||
|
||||
for _, vid := range vids {
|
||||
delete(mgr.id2Stream, vid)
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *singleTypeChannelsMgr) updateChannels(channels map[vChan]pChan) {
|
||||
mgr.vchans2pchansMtx.Lock()
|
||||
defer mgr.vchans2pchansMtx.Unlock()
|
||||
|
||||
for vchan, pchan := range channels {
|
||||
mgr.vchans2pchans[vchan] = pchan
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *singleTypeChannelsMgr) deleteAllChannels() {
|
||||
mgr.vchans2pchansMtx.Lock()
|
||||
defer mgr.vchans2pchansMtx.Unlock()
|
||||
|
||||
mgr.vchans2pchans = nil
|
||||
}
|
||||
|
||||
func (mgr *singleTypeChannelsMgr) deleteAllStream() {
|
||||
mgr.id2vchansMtx.Lock()
|
||||
defer mgr.id2vchansMtx.Unlock()
|
||||
|
||||
mgr.id2UsageHistogramOfStream = nil
|
||||
mgr.id2Stream = nil
|
||||
}
|
||||
|
||||
func (mgr *singleTypeChannelsMgr) deleteAllVChans() {
|
||||
mgr.id2vchansMtx.Lock()
|
||||
defer mgr.id2vchansMtx.Unlock()
|
||||
|
||||
mgr.id2vchans = nil
|
||||
}
|
||||
|
||||
func (mgr *singleTypeChannelsMgr) deleteAllCollection() {
|
||||
mgr.collMtx.Lock()
|
||||
defer mgr.collMtx.Unlock()
|
||||
|
||||
mgr.collectionID2VIDs = nil
|
||||
}
|
||||
|
||||
func (mgr *singleTypeChannelsMgr) addStream(vid int, stream msgstream.MsgStream) {
|
||||
mgr.streamMtx.Lock()
|
||||
defer mgr.streamMtx.Unlock()
|
||||
|
||||
mgr.id2Stream[vid] = stream
|
||||
mgr.id2UsageHistogramOfStream[vid] = 0
|
||||
}
|
||||
|
||||
func (mgr *singleTypeChannelsMgr) updateCollection(collectionID UniqueID, id int) {
|
||||
mgr.collMtx.Lock()
|
||||
defer mgr.collMtx.Unlock()
|
||||
|
||||
vids, ok := mgr.collectionID2VIDs[collectionID]
|
||||
if !ok {
|
||||
mgr.collectionID2VIDs[collectionID] = make([]int, 1)
|
||||
mgr.collectionID2VIDs[collectionID][0] = id
|
||||
} else {
|
||||
vids = append(vids, id)
|
||||
sort.Slice(vids, func(i, j int) bool {
|
||||
return vids[i] < vids[j]
|
||||
})
|
||||
mgr.collectionID2VIDs[collectionID] = vids
|
||||
func (mgr *singleTypeChannelsMgr) getVChans(collectionID UniqueID) ([]vChan, error) {
|
||||
channelInfos, err := mgr.getChannelsFunc(collectionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return channelInfos.vchans, nil
|
||||
}
|
||||
|
||||
// getChannels returns the physical channels.
|
||||
func (mgr *singleTypeChannelsMgr) getChannels(collectionID UniqueID) ([]pChan, error) {
|
||||
id, err := mgr.getLatestVID(collectionID)
|
||||
if err == nil {
|
||||
vchans, err := mgr.getVChansByVID(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mgr.getPChansByVChans(vchans)
|
||||
var channelInfos channelInfos
|
||||
channelInfos, err := mgr.getAllChannels(collectionID)
|
||||
if err != nil {
|
||||
return mgr.getPChans(collectionID)
|
||||
}
|
||||
|
||||
// TODO(dragondriver): return error or update channel information from master?
|
||||
return nil, err
|
||||
return channelInfos.pchans, nil
|
||||
}
|
||||
|
||||
// getVChannels returns the virtual channels.
|
||||
func (mgr *singleTypeChannelsMgr) getVChannels(collectionID UniqueID) ([]vChan, error) {
|
||||
id, err := mgr.getLatestVID(collectionID)
|
||||
if err == nil {
|
||||
return mgr.getVChansByVID(id)
|
||||
var channelInfos channelInfos
|
||||
channelInfos, err := mgr.getAllChannels(collectionID)
|
||||
if err != nil {
|
||||
return mgr.getVChans(collectionID)
|
||||
}
|
||||
return channelInfos.vchans, nil
|
||||
}
|
||||
|
||||
// TODO(dragondriver): return error or update channel information from master?
|
||||
return nil, err
|
||||
func (mgr *singleTypeChannelsMgr) streamExistPrivate(collectionID UniqueID) bool {
|
||||
streamInfos, ok := mgr.infos[collectionID]
|
||||
return ok && streamInfos.stream != nil
|
||||
}
|
||||
|
||||
func (mgr *singleTypeChannelsMgr) streamExist(collectionID UniqueID) bool {
|
||||
stream, err := mgr.getStream(collectionID)
|
||||
return err == nil && stream != nil
|
||||
mgr.mu.RLock()
|
||||
defer mgr.mu.RUnlock()
|
||||
return mgr.streamExistPrivate(collectionID)
|
||||
}
|
||||
|
||||
func createStream(factory msgstream.Factory, streamType streamType, pchans []pChan, repack repackFuncType) (msgstream.MsgStream, error) {
|
||||
var stream msgstream.MsgStream
|
||||
var err error
|
||||
|
||||
if streamType == dqlStreamType {
|
||||
stream, err = factory.NewQueryMsgStream(context.Background())
|
||||
} else {
|
||||
stream, err = factory.NewMsgStream(context.Background())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stream.AsProducer(pchans)
|
||||
if repack != nil {
|
||||
stream.SetRepackFunc(repack)
|
||||
}
|
||||
runtime.SetFinalizer(stream, func(stream msgstream.MsgStream) {
|
||||
stream.Close()
|
||||
})
|
||||
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (mgr *singleTypeChannelsMgr) updateCollection(collectionID UniqueID, channelInfos channelInfos, stream msgstream.MsgStream) {
|
||||
mgr.mu.Lock()
|
||||
defer mgr.mu.Unlock()
|
||||
if !mgr.streamExistPrivate(collectionID) {
|
||||
mgr.infos[collectionID] = streamInfos{channelInfos: channelInfos, stream: stream}
|
||||
}
|
||||
}
|
||||
|
||||
func incPChansMetrics(pchans []pChan) {
|
||||
for _, pc := range pchans {
|
||||
metrics.ProxyMsgStreamObjectsForPChan.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), pc).Inc()
|
||||
}
|
||||
}
|
||||
|
||||
func decPChanMetrics(pchans []pChan) {
|
||||
for _, pc := range pchans {
|
||||
metrics.ProxyMsgStreamObjectsForPChan.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), pc).Dec()
|
||||
}
|
||||
}
|
||||
|
||||
// createMsgStream create message stream for specified collection. Idempotent.
|
||||
// If stream already exists, directly return nil.
|
||||
func (mgr *singleTypeChannelsMgr) createMsgStream(collectionID UniqueID) error {
|
||||
if mgr.streamExist(collectionID) {
|
||||
log.Info("stream already exist, no need to re-create", zap.Int64("collection_id", collectionID))
|
||||
return nil
|
||||
}
|
||||
|
||||
channels, err := mgr.getChannelsFunc(collectionID)
|
||||
channelInfos, err := mgr.getChannelsFunc(collectionID)
|
||||
if err != nil {
|
||||
log.Warn("failed to create message stream",
|
||||
zap.Int64("collection_id", collectionID),
|
||||
zap.Error(err))
|
||||
log.Error("failed to get channels", zap.Error(err), zap.Int64("collection", collectionID))
|
||||
return err
|
||||
}
|
||||
log.Debug("singleTypeChannelsMgr",
|
||||
|
||||
stream, err := createStream(mgr.msgStreamFactory, mgr.singleStreamType, channelInfos.pchans, mgr.repackFunc)
|
||||
if err != nil {
|
||||
log.Error("failed to create message stream", zap.Error(err), zap.Int64("collection", collectionID))
|
||||
return err
|
||||
}
|
||||
|
||||
mgr.updateCollection(collectionID, channelInfos, stream)
|
||||
|
||||
log.Info("create message stream",
|
||||
zap.Int64("collection_id", collectionID),
|
||||
zap.Any("createMsgStream.getChannels", channels))
|
||||
zap.Strings("virtual_channels", channelInfos.vchans),
|
||||
zap.Strings("physical_channels", channelInfos.pchans))
|
||||
|
||||
mgr.updateChannels(channels)
|
||||
incPChansMetrics(channelInfos.pchans)
|
||||
|
||||
id := uniquegenerator.GetUniqueIntGeneratorIns().GetInt()
|
||||
|
||||
vchans, pchans := make([]string, 0, len(channels)), make([]string, 0, len(channels))
|
||||
for k, v := range channels {
|
||||
vchans = append(vchans, k)
|
||||
pchans = append(pchans, v)
|
||||
}
|
||||
mgr.updateVChans(id, vchans)
|
||||
|
||||
var stream msgstream.MsgStream
|
||||
if mgr.singleStreamType == dqlStreamType {
|
||||
stream, err = mgr.msgStreamFactory.NewQueryMsgStream(context.Background())
|
||||
} else {
|
||||
stream, err = mgr.msgStreamFactory.NewMsgStream(context.Background())
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stream.AsProducer(pchans)
|
||||
if mgr.repackFunc != nil {
|
||||
stream.SetRepackFunc(mgr.repackFunc)
|
||||
}
|
||||
runtime.SetFinalizer(stream, func(stream msgstream.MsgStream) {
|
||||
stream.Close()
|
||||
})
|
||||
mgr.addStream(id, stream)
|
||||
|
||||
mgr.updateCollection(collectionID, id)
|
||||
for _, pc := range pchans {
|
||||
metrics.ProxyMsgStreamObjectsForPChan.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), pc).Inc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mgr *singleTypeChannelsMgr) getStream(collectionID UniqueID) (msgstream.MsgStream, error) {
|
||||
mgr.streamMtx.RLock()
|
||||
defer mgr.streamMtx.RUnlock()
|
||||
func (mgr *singleTypeChannelsMgr) lockGetStream(collectionID UniqueID) (msgstream.MsgStream, error) {
|
||||
mgr.mu.RLock()
|
||||
defer mgr.mu.RUnlock()
|
||||
streamInfos, ok := mgr.infos[collectionID]
|
||||
if ok {
|
||||
return streamInfos.stream, nil
|
||||
}
|
||||
return nil, fmt.Errorf("collection not found: %d", collectionID)
|
||||
}
|
||||
|
||||
vid, err := mgr.getLatestVID(collectionID)
|
||||
if err != nil {
|
||||
// getStream get message stream of specified collection.
|
||||
// If stream don't exists, call createMsgStream to create for it.
|
||||
func (mgr *singleTypeChannelsMgr) getStream(collectionID UniqueID) (msgstream.MsgStream, error) {
|
||||
if stream, err := mgr.lockGetStream(collectionID); err == nil {
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
if err := mgr.createMsgStream(collectionID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stream, ok := mgr.id2Stream[vid]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no dml stream for collection %v", collectionID)
|
||||
}
|
||||
|
||||
return stream, nil
|
||||
return mgr.lockGetStream(collectionID)
|
||||
}
|
||||
|
||||
// removeStream remove the corresponding stream of the specified collection. Idempotent.
|
||||
// If stream already exists, remove it, otherwise do nothing.
|
||||
func (mgr *singleTypeChannelsMgr) removeStream(collectionID UniqueID) error {
|
||||
channels, err := mgr.getChannels(collectionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ids, err2 := mgr.getAllVIDs(collectionID)
|
||||
if err2 != nil {
|
||||
return err2
|
||||
}
|
||||
|
||||
mgr.deleteVChansByVIDs(ids)
|
||||
mgr.deleteStreamByVIDs(ids)
|
||||
for _, pc := range channels {
|
||||
metrics.ProxyMsgStreamObjectsForPChan.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), pc).Dec()
|
||||
mgr.mu.Lock()
|
||||
defer mgr.mu.Unlock()
|
||||
if info, ok := mgr.infos[collectionID]; ok {
|
||||
decPChanMetrics(info.channelInfos.pchans)
|
||||
delete(mgr.infos, collectionID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeAllStream remove all message stream.
|
||||
func (mgr *singleTypeChannelsMgr) removeAllStream() error {
|
||||
mgr.deleteAllChannels()
|
||||
mgr.deleteAllStream()
|
||||
mgr.deleteAllVChans()
|
||||
mgr.deleteAllCollection()
|
||||
|
||||
mgr.mu.Lock()
|
||||
defer mgr.mu.Unlock()
|
||||
for _, info := range mgr.infos {
|
||||
decPChanMetrics(info.channelInfos.pchans)
|
||||
}
|
||||
mgr.infos = make(map[UniqueID]streamInfos)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -438,15 +319,11 @@ func newSingleTypeChannelsMgr(
|
||||
singleStreamType streamType,
|
||||
) *singleTypeChannelsMgr {
|
||||
return &singleTypeChannelsMgr{
|
||||
collectionID2VIDs: make(map[UniqueID][]int),
|
||||
id2vchans: make(map[int][]vChan),
|
||||
id2Stream: make(map[int]msgstream.MsgStream),
|
||||
id2UsageHistogramOfStream: make(map[int]int),
|
||||
vchans2pchans: make(map[vChan]pChan),
|
||||
getChannelsFunc: getChannelsFunc,
|
||||
repackFunc: repackFunc,
|
||||
singleStreamType: singleStreamType,
|
||||
msgStreamFactory: msgStreamFactory,
|
||||
infos: make(map[UniqueID]streamInfos),
|
||||
getChannelsFunc: getChannelsFunc,
|
||||
repackFunc: repackFunc,
|
||||
singleStreamType: singleStreamType,
|
||||
msgStreamFactory: msgStreamFactory,
|
||||
}
|
||||
}
|
||||
|
||||
@ -486,7 +363,6 @@ func (mgr *channelsMgrImpl) removeAllDMLStream() error {
|
||||
func newChannelsMgrImpl(
|
||||
getDmlChannelsFunc getChannelsFuncType,
|
||||
dmlRepackFunc repackFuncType,
|
||||
dqlRepackFunc repackFuncType,
|
||||
msgStreamFactory msgstream.Factory,
|
||||
) *channelsMgrImpl {
|
||||
return &channelsMgrImpl{
|
||||
|
@ -17,153 +17,398 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/uniquegenerator"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestChannelsMgrImpl_getChannels(t *testing.T) {
|
||||
master := newMockGetChannelsService()
|
||||
factory := newSimpleMockMsgStreamFactory()
|
||||
mgr := newChannelsMgrImpl(master.GetChannels, nil, nil, factory)
|
||||
defer mgr.removeAllDMLStream()
|
||||
|
||||
collID := UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
|
||||
_, err := mgr.getChannels(collID)
|
||||
assert.NotEqual(t, nil, err)
|
||||
|
||||
err = mgr.createDMLMsgStream(collID)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
_, err = mgr.getChannels(collID)
|
||||
assert.Equal(t, nil, err)
|
||||
func Test_removeDuplicate(t *testing.T) {
|
||||
s1 := []string{"11", "11"}
|
||||
filtered1 := removeDuplicate(s1)
|
||||
assert.ElementsMatch(t, filtered1, []string{"11"})
|
||||
}
|
||||
|
||||
func TestChannelsMgrImpl_getVChannels(t *testing.T) {
|
||||
master := newMockGetChannelsService()
|
||||
factory := newSimpleMockMsgStreamFactory()
|
||||
mgr := newChannelsMgrImpl(master.GetChannels, nil, nil, factory)
|
||||
defer mgr.removeAllDMLStream()
|
||||
|
||||
collID := UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
|
||||
_, err := mgr.getVChannels(collID)
|
||||
assert.NotEqual(t, nil, err)
|
||||
|
||||
err = mgr.createDMLMsgStream(collID)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
_, err = mgr.getVChannels(collID)
|
||||
assert.Equal(t, nil, err)
|
||||
}
|
||||
|
||||
func TestChannelsMgrImpl_createDMLMsgStream(t *testing.T) {
|
||||
master := newMockGetChannelsService()
|
||||
factory := newSimpleMockMsgStreamFactory()
|
||||
mgr := newChannelsMgrImpl(master.GetChannels, nil, nil, factory)
|
||||
defer mgr.removeAllDMLStream()
|
||||
|
||||
collID := UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
|
||||
_, err := mgr.getChannels(collID)
|
||||
assert.NotEqual(t, nil, err)
|
||||
_, err = mgr.getVChannels(collID)
|
||||
assert.NotEqual(t, nil, err)
|
||||
|
||||
err = mgr.createDMLMsgStream(collID)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// re-create message stream.
|
||||
err = mgr.createDMLMsgStream(collID)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
_, err = mgr.getChannels(collID)
|
||||
assert.Equal(t, nil, err)
|
||||
_, err = mgr.getVChannels(collID)
|
||||
assert.Equal(t, nil, err)
|
||||
}
|
||||
|
||||
func TestChannelsMgrImpl_getDMLMsgStream(t *testing.T) {
|
||||
master := newMockGetChannelsService()
|
||||
factory := newSimpleMockMsgStreamFactory()
|
||||
mgr := newChannelsMgrImpl(master.GetChannels, nil, nil, factory)
|
||||
defer mgr.removeAllDMLStream()
|
||||
|
||||
collID := UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
|
||||
_, err := mgr.getDMLStream(collID)
|
||||
assert.NotEqual(t, nil, err)
|
||||
|
||||
err = mgr.createDMLMsgStream(collID)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
_, err = mgr.getDMLStream(collID)
|
||||
assert.Equal(t, nil, err)
|
||||
}
|
||||
|
||||
func TestChannelsMgrImpl_removeDMLMsgStream(t *testing.T) {
|
||||
master := newMockGetChannelsService()
|
||||
factory := newSimpleMockMsgStreamFactory()
|
||||
mgr := newChannelsMgrImpl(master.GetChannels, nil, nil, factory)
|
||||
defer mgr.removeAllDMLStream()
|
||||
|
||||
collID := UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
|
||||
_, err := mgr.getDMLStream(collID)
|
||||
assert.NotEqual(t, nil, err)
|
||||
|
||||
err = mgr.removeDMLStream(collID)
|
||||
assert.NotEqual(t, nil, err)
|
||||
|
||||
err = mgr.createDMLMsgStream(collID)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
_, err = mgr.getDMLStream(collID)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
err = mgr.removeDMLStream(collID)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
_, err = mgr.getDMLStream(collID)
|
||||
assert.NotEqual(t, nil, err)
|
||||
}
|
||||
|
||||
func TestChannelsMgrImpl_removeAllDMLMsgStream(t *testing.T) {
|
||||
master := newMockGetChannelsService()
|
||||
factory := newSimpleMockMsgStreamFactory()
|
||||
mgr := newChannelsMgrImpl(master.GetChannels, nil, nil, factory)
|
||||
defer mgr.removeAllDMLStream()
|
||||
|
||||
num := 10
|
||||
for i := 0; i < num; i++ {
|
||||
collID := UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
|
||||
err := mgr.createDMLMsgStream(collID)
|
||||
assert.Equal(t, nil, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAllKeysAndGetAllValues(t *testing.T) {
|
||||
chanMapping := make(map[vChan]pChan)
|
||||
chanMapping["v1"] = "p1"
|
||||
chanMapping["v2"] = "p2"
|
||||
|
||||
t.Run("getAllKeys", func(t *testing.T) {
|
||||
vChans := getAllKeys(chanMapping)
|
||||
assert.Equal(t, 2, len(vChans))
|
||||
func Test_newChannels(t *testing.T) {
|
||||
t.Run("length mismatch", func(t *testing.T) {
|
||||
_, err := newChannels([]string{"111", "222"}, []string{"111"})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("getAllValues", func(t *testing.T) {
|
||||
pChans := getAllValues(chanMapping)
|
||||
assert.Equal(t, 2, len(pChans))
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
got, err := newChannels([]string{"111", "222"}, []string{"111", "111"})
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, []string{"111", "222"}, got.vchans)
|
||||
// assert.ElementsMatch(t, []string{"111"}, got.pchans)
|
||||
assert.ElementsMatch(t, []string{"111", "111"}, got.pchans)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteVChansByVID(t *testing.T) {
|
||||
mgr := singleTypeChannelsMgr{
|
||||
id2vchansMtx: sync.RWMutex{},
|
||||
id2vchans: map[int][]vChan{
|
||||
10: {"v1"},
|
||||
func Test_getDmlChannelsFunc(t *testing.T) {
|
||||
t.Run("failed to describe collection", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
rc := newMockRootCoord()
|
||||
rc.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) {
|
||||
return nil, errors.New("mock")
|
||||
}
|
||||
f := getDmlChannelsFunc(ctx, rc)
|
||||
_, err := f(100)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("error code not success", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
rc := newMockRootCoord()
|
||||
rc.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) {
|
||||
return &milvuspb.DescribeCollectionResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}}, nil
|
||||
}
|
||||
f := getDmlChannelsFunc(ctx, rc)
|
||||
_, err := f(100)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
rc := newMockRootCoord()
|
||||
rc.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) {
|
||||
return &milvuspb.DescribeCollectionResponse{
|
||||
VirtualChannelNames: []string{"111", "222"},
|
||||
PhysicalChannelNames: []string{"111", "111"},
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}}, nil
|
||||
}
|
||||
f := getDmlChannelsFunc(ctx, rc)
|
||||
got, err := f(100)
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, []string{"111", "222"}, got.vchans)
|
||||
// assert.ElementsMatch(t, []string{"111"}, got.pchans)
|
||||
assert.ElementsMatch(t, []string{"111", "111"}, got.pchans)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_singleTypeChannelsMgr_getAllChannels(t *testing.T) {
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
infos: map[UniqueID]streamInfos{
|
||||
100: {channelInfos: channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}},
|
||||
},
|
||||
}
|
||||
got, err := m.getAllChannels(100)
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, []string{"111", "222"}, got.vchans)
|
||||
assert.ElementsMatch(t, []string{"111"}, got.pchans)
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
infos: map[UniqueID]streamInfos{},
|
||||
}
|
||||
_, err := m.getAllChannels(100)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_singleTypeChannelsMgr_getPChans(t *testing.T) {
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
|
||||
return channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}, nil
|
||||
},
|
||||
}
|
||||
got, err := m.getPChans(100)
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, []string{"111"}, got)
|
||||
})
|
||||
|
||||
t.Run("error case", func(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
|
||||
return channelInfos{}, errors.New("mock")
|
||||
},
|
||||
}
|
||||
_, err := m.getPChans(100)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_singleTypeChannelsMgr_getVChans(t *testing.T) {
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
|
||||
return channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}, nil
|
||||
},
|
||||
}
|
||||
got, err := m.getVChans(100)
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, []string{"111", "222"}, got)
|
||||
})
|
||||
|
||||
t.Run("error case", func(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
|
||||
return channelInfos{}, errors.New("mock")
|
||||
},
|
||||
}
|
||||
_, err := m.getVChans(100)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_singleTypeChannelsMgr_getChannels(t *testing.T) {
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
infos: map[UniqueID]streamInfos{
|
||||
100: {channelInfos: channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}},
|
||||
},
|
||||
}
|
||||
got, err := m.getChannels(100)
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, []string{"111"}, got)
|
||||
})
|
||||
|
||||
t.Run("error case", func(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
|
||||
return channelInfos{}, errors.New("mock")
|
||||
},
|
||||
}
|
||||
_, err := m.getChannels(100)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_singleTypeChannelsMgr_getVChannels(t *testing.T) {
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
infos: map[UniqueID]streamInfos{
|
||||
100: {channelInfos: channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}},
|
||||
},
|
||||
}
|
||||
got, err := m.getVChannels(100)
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, []string{"111", "222"}, got)
|
||||
})
|
||||
|
||||
t.Run("error case", func(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
|
||||
return channelInfos{}, errors.New("mock")
|
||||
},
|
||||
}
|
||||
_, err := m.getVChannels(100)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_singleTypeChannelsMgr_streamExist(t *testing.T) {
|
||||
t.Run("exist", func(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
infos: map[UniqueID]streamInfos{
|
||||
100: {stream: newSimpleMockMsgStream()},
|
||||
},
|
||||
}
|
||||
exist := m.streamExist(100)
|
||||
assert.True(t, exist)
|
||||
})
|
||||
|
||||
t.Run("not exist", func(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
infos: map[UniqueID]streamInfos{
|
||||
100: {stream: nil},
|
||||
},
|
||||
}
|
||||
exist := m.streamExist(100)
|
||||
assert.False(t, exist)
|
||||
m.infos = make(map[UniqueID]streamInfos)
|
||||
exist = m.streamExist(100)
|
||||
assert.False(t, exist)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_createStream(t *testing.T) {
|
||||
t.Run("failed to create msgstream", func(t *testing.T) {
|
||||
factory := newMockMsgStreamFactory()
|
||||
factory.fQStream = func(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
return nil, errors.New("mock")
|
||||
}
|
||||
_, err := createStream(factory, dmlStreamType, nil, nil)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("failed to create query msgstream", func(t *testing.T) {
|
||||
factory := newMockMsgStreamFactory()
|
||||
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
return nil, errors.New("mock")
|
||||
}
|
||||
_, err := createStream(factory, dqlStreamType, nil, nil)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
factory := newMockMsgStreamFactory()
|
||||
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
return newMockMsgStream(), nil
|
||||
}
|
||||
_, err := createStream(factory, dmlStreamType, []string{"111"}, func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
|
||||
return nil, nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
|
||||
t.Run("re-create", func(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
infos: map[UniqueID]streamInfos{
|
||||
100: {stream: newMockMsgStream()},
|
||||
},
|
||||
}
|
||||
err := m.createMsgStream(100)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("failed to get channels", func(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
|
||||
return channelInfos{}, errors.New("mock")
|
||||
},
|
||||
}
|
||||
err := m.createMsgStream(100)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("failed to create message stream", func(t *testing.T) {
|
||||
factory := newMockMsgStreamFactory()
|
||||
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
return nil, errors.New("mock")
|
||||
}
|
||||
m := &singleTypeChannelsMgr{
|
||||
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
|
||||
return channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}, nil
|
||||
},
|
||||
msgStreamFactory: factory,
|
||||
singleStreamType: dmlStreamType,
|
||||
repackFunc: nil,
|
||||
}
|
||||
err := m.createMsgStream(100)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
factory := newMockMsgStreamFactory()
|
||||
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
return newMockMsgStream(), nil
|
||||
}
|
||||
m := &singleTypeChannelsMgr{
|
||||
infos: make(map[UniqueID]streamInfos),
|
||||
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
|
||||
return channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}, nil
|
||||
},
|
||||
msgStreamFactory: factory,
|
||||
singleStreamType: dmlStreamType,
|
||||
repackFunc: nil,
|
||||
}
|
||||
err := m.createMsgStream(100)
|
||||
assert.NoError(t, err)
|
||||
stream, err := m.getStream(100)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, stream)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_singleTypeChannelsMgr_lockGetStream(t *testing.T) {
|
||||
t.Run("collection not found", func(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
infos: make(map[UniqueID]streamInfos),
|
||||
}
|
||||
_, err := m.lockGetStream(100)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
infos: map[UniqueID]streamInfos{
|
||||
100: {stream: newMockMsgStream()},
|
||||
},
|
||||
}
|
||||
stream, err := m.lockGetStream(100)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, stream)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_singleTypeChannelsMgr_getStream(t *testing.T) {
|
||||
t.Run("exist", func(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
infos: map[UniqueID]streamInfos{
|
||||
100: {stream: newMockMsgStream()},
|
||||
},
|
||||
}
|
||||
stream, err := m.getStream(100)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, stream)
|
||||
})
|
||||
|
||||
t.Run("failed to create", func(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
infos: map[UniqueID]streamInfos{},
|
||||
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
|
||||
return channelInfos{}, errors.New("mock")
|
||||
},
|
||||
}
|
||||
_, err := m.getStream(100)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("get after create", func(t *testing.T) {
|
||||
factory := newMockMsgStreamFactory()
|
||||
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
return newMockMsgStream(), nil
|
||||
}
|
||||
m := &singleTypeChannelsMgr{
|
||||
infos: make(map[UniqueID]streamInfos),
|
||||
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
|
||||
return channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}, nil
|
||||
},
|
||||
msgStreamFactory: factory,
|
||||
singleStreamType: dmlStreamType,
|
||||
repackFunc: nil,
|
||||
}
|
||||
stream, err := m.getStream(100)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, stream)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_singleTypeChannelsMgr_removeStream(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
infos: map[UniqueID]streamInfos{
|
||||
100: {
|
||||
stream: newMockMsgStream(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mgr.deleteVChansByVID(10)
|
||||
err := m.removeStream(100)
|
||||
assert.NoError(t, err)
|
||||
_, err = m.lockGetStream(100)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func Test_singleTypeChannelsMgr_removeAllStream(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
infos: map[UniqueID]streamInfos{
|
||||
100: {
|
||||
stream: newMockMsgStream(),
|
||||
},
|
||||
},
|
||||
}
|
||||
err := m.removeAllStream()
|
||||
assert.NoError(t, err)
|
||||
_, err = m.lockGetStream(100)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
@ -3609,26 +3609,31 @@ func (node *Proxy) Import(ctx context.Context, req *milvuspb.ImportRequest) (*mi
|
||||
zap.Error(err))
|
||||
resp.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
resp.Status.Reason = err.Error()
|
||||
return resp, err
|
||||
return resp, nil
|
||||
}
|
||||
chNames, err := node.chMgr.getVChannels(collID)
|
||||
if err != nil {
|
||||
err = node.chMgr.createDMLMsgStream(collID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
chNames, err = node.chMgr.getVChannels(collID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Error("failed to get virtual channels",
|
||||
zap.Error(err),
|
||||
zap.String("collection", req.GetCollectionName()),
|
||||
zap.Int64("collection_id", collID))
|
||||
resp.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
resp.Status.Reason = err.Error()
|
||||
return resp, nil
|
||||
}
|
||||
req.ChannelNames = chNames
|
||||
if req.GetPartitionName() == "" {
|
||||
req.PartitionName = Params.CommonCfg.DefaultPartitionName
|
||||
}
|
||||
// Call rootCoord to finish import.
|
||||
resp, err = node.rootCoord.Import(ctx, req)
|
||||
return resp, err
|
||||
respFromRC, err := node.rootCoord.Import(ctx, req)
|
||||
if err != nil {
|
||||
log.Error("failed to execute bulk load request", zap.Error(err))
|
||||
resp.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
resp.Status.Reason = err.Error()
|
||||
return resp, nil
|
||||
}
|
||||
return respFromRC, nil
|
||||
}
|
||||
|
||||
// GetImportState checks import task state from datanode
|
||||
|
@ -30,6 +30,9 @@ func (m *mockCache) GetCollectionSchema(ctx context.Context, collectionName stri
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockCache) RemoveCollection(ctx context.Context, collectionName string) {
|
||||
}
|
||||
|
||||
func (m *mockCache) setGetIDFunc(f getCollectionIDFunc) {
|
||||
m.getIDFunc = f
|
||||
}
|
||||
|
28
internal/proxy/mock_channels_mgr_test.go
Normal file
28
internal/proxy/mock_channels_mgr_test.go
Normal file
@ -0,0 +1,28 @@
|
||||
package proxy
|
||||
|
||||
type getVChannelsFuncType = func(collectionID UniqueID) ([]vChan, error)
|
||||
type removeDMLStreamFuncType = func(collectionID UniqueID) error
|
||||
|
||||
type mockChannelsMgr struct {
|
||||
channelsMgr
|
||||
getVChannelsFuncType
|
||||
removeDMLStreamFuncType
|
||||
}
|
||||
|
||||
func (m *mockChannelsMgr) getVChannels(collectionID UniqueID) ([]vChan, error) {
|
||||
if m.getVChannelsFuncType != nil {
|
||||
return m.getVChannelsFuncType(collectionID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockChannelsMgr) removeDMLStream(collectionID UniqueID) error {
|
||||
if m.removeDMLStreamFuncType != nil {
|
||||
return m.removeDMLStreamFuncType(collectionID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newMockChannelsMgr() *mockChannelsMgr {
|
||||
return &mockChannelsMgr{}
|
||||
}
|
69
internal/proxy/mock_msgstream_test.go
Normal file
69
internal/proxy/mock_msgstream_test.go
Normal file
@ -0,0 +1,69 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
)
|
||||
|
||||
type mockMsgStream struct {
|
||||
msgstream.MsgStream
|
||||
asProducer func([]string)
|
||||
setRepack func(repackFunc msgstream.RepackFunc)
|
||||
close func()
|
||||
}
|
||||
|
||||
func (m *mockMsgStream) AsProducer(producers []string) {
|
||||
if m.asProducer != nil {
|
||||
m.asProducer(producers)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockMsgStream) SetRepackFunc(repackFunc msgstream.RepackFunc) {
|
||||
if m.setRepack != nil {
|
||||
m.setRepack(repackFunc)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockMsgStream) Close() {
|
||||
if m.close != nil {
|
||||
m.close()
|
||||
}
|
||||
}
|
||||
|
||||
func newMockMsgStream() *mockMsgStream {
|
||||
return &mockMsgStream{}
|
||||
}
|
||||
|
||||
type mockMsgStreamFactory struct {
|
||||
msgstream.Factory
|
||||
f func(ctx context.Context) (msgstream.MsgStream, error)
|
||||
fQStream func(ctx context.Context) (msgstream.MsgStream, error)
|
||||
fTtStream func(ctx context.Context) (msgstream.MsgStream, error)
|
||||
}
|
||||
|
||||
func (m *mockMsgStreamFactory) NewMsgStream(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
if m.f != nil {
|
||||
return m.f(ctx)
|
||||
}
|
||||
return nil, errors.New("mock")
|
||||
}
|
||||
|
||||
func (m *mockMsgStreamFactory) NewTtMsgStream(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
if m.fTtStream != nil {
|
||||
return m.fTtStream(ctx)
|
||||
}
|
||||
return nil, errors.New("mock")
|
||||
}
|
||||
|
||||
func (m *mockMsgStreamFactory) NewQueryMsgStream(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
if m.fQStream != nil {
|
||||
return m.fQStream(ctx)
|
||||
}
|
||||
return nil, errors.New("mock")
|
||||
}
|
||||
|
||||
func newMockMsgStreamFactory() *mockMsgStreamFactory {
|
||||
return &mockMsgStreamFactory{}
|
||||
}
|
@ -86,37 +86,6 @@ func newMockIDAllocatorInterface() idAllocatorInterface {
|
||||
return &mockIDAllocatorInterface{}
|
||||
}
|
||||
|
||||
type mockGetChannelsService struct {
|
||||
collectionID2Channels map[UniqueID]map[vChan]pChan
|
||||
f getChannelsFuncType
|
||||
}
|
||||
|
||||
func newMockGetChannelsService() *mockGetChannelsService {
|
||||
return &mockGetChannelsService{
|
||||
collectionID2Channels: make(map[UniqueID]map[vChan]pChan),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockGetChannelsService) GetChannels(collectionID UniqueID) (map[vChan]pChan, error) {
|
||||
if m.f != nil {
|
||||
return m.f(collectionID)
|
||||
}
|
||||
|
||||
channels, ok := m.collectionID2Channels[collectionID]
|
||||
if ok {
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
channels = make(map[vChan]pChan)
|
||||
l := rand.Uint64()%10 + 1
|
||||
for i := 0; uint64(i) < l; i++ {
|
||||
channels[funcutil.GenRandomStr()] = funcutil.GenRandomStr()
|
||||
}
|
||||
|
||||
m.collectionID2Channels[collectionID] = channels
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
type mockTask struct {
|
||||
*TaskCondition
|
||||
id UniqueID
|
||||
|
@ -198,7 +198,7 @@ func (node *Proxy) Init() error {
|
||||
|
||||
log.Debug("create channels manager", zap.String("role", typeutil.ProxyRole))
|
||||
dmlChannelsFunc := getDmlChannelsFunc(node.ctx, node.rootCoord)
|
||||
chMgr := newChannelsMgrImpl(dmlChannelsFunc, defaultInsertRepackFunc, nil, node.factory)
|
||||
chMgr := newChannelsMgrImpl(dmlChannelsFunc, defaultInsertRepackFunc, node.factory)
|
||||
node.chMgr = chMgr
|
||||
log.Debug("create channels manager done", zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
|
@ -18,6 +18,7 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
@ -1591,8 +1592,8 @@ func TestProxy(t *testing.T) {
|
||||
}
|
||||
proxy.stateCode.Store(internalpb.StateCode_Healthy)
|
||||
resp, err := proxy.Import(context.TODO(), req)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, commonpb.ErrorCode_UnexpectedError, resp.Status.ErrorCode)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
wg.Add(1)
|
||||
@ -1604,8 +1605,8 @@ func TestProxy(t *testing.T) {
|
||||
}
|
||||
proxy.stateCode.Store(internalpb.StateCode_Healthy)
|
||||
resp, err := proxy.Import(context.TODO(), req)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, commonpb.ErrorCode_UnexpectedError, resp.Status.ErrorCode)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
wg.Add(1)
|
||||
@ -3028,63 +3029,104 @@ func TestProxy_GetComponentStates_state_code(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxy_Import(t *testing.T) {
|
||||
rc := NewRootCoordMock()
|
||||
master := newMockGetChannelsService()
|
||||
msgStreamFactory := newSimpleMockMsgStreamFactory()
|
||||
rc.Start()
|
||||
defer rc.Stop()
|
||||
qc := NewQueryCoordMock()
|
||||
qc.Start()
|
||||
defer qc.Stop()
|
||||
shardMgr := newShardClientMgr()
|
||||
err := InitMetaCache(rc, qc, shardMgr)
|
||||
assert.NoError(t, err)
|
||||
rc.CreateCollection(context.TODO(), &milvuspb.CreateCollectionRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_DropCollection,
|
||||
MsgID: 100,
|
||||
Timestamp: 100,
|
||||
},
|
||||
CollectionName: "import_collection",
|
||||
})
|
||||
localMsg := true
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
factory := dependency.NewDefaultFactory(localMsg)
|
||||
proxy, err := NewProxy(ctx, factory)
|
||||
proxy.rootCoord = rc
|
||||
assert.NoError(t, err)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
t.Run("test import get vChannel failed (the first one)", func(t *testing.T) {
|
||||
defer wg.Done()
|
||||
proxy.stateCode.Store(internalpb.StateCode_Healthy)
|
||||
proxy.chMgr = newChannelsMgrImpl(master.GetChannels, nil, nil, msgStreamFactory)
|
||||
resp, err := proxy.Import(context.TODO(),
|
||||
&milvuspb.ImportRequest{
|
||||
CollectionName: "import_collection",
|
||||
})
|
||||
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
wg.Add(1)
|
||||
t.Run("test import with unhealthy", func(t *testing.T) {
|
||||
defer wg.Done()
|
||||
req := &milvuspb.ImportRequest{
|
||||
CollectionName: "dummy",
|
||||
}
|
||||
proxy.stateCode.Store(internalpb.StateCode_Abnormal)
|
||||
proxy := &Proxy{}
|
||||
proxy.UpdateStateCode(internalpb.StateCode_Abnormal)
|
||||
resp, err := proxy.Import(context.TODO(), req)
|
||||
assert.EqualValues(t, unhealthyStatus(), resp.Status)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, unhealthyStatus(), resp.GetStatus())
|
||||
})
|
||||
resp, err := rc.DropCollection(context.TODO(), &milvuspb.DropCollectionRequest{
|
||||
CollectionName: "import_collection",
|
||||
|
||||
wg.Add(1)
|
||||
t.Run("collection not found", func(t *testing.T) {
|
||||
defer wg.Done()
|
||||
proxy := &Proxy{}
|
||||
proxy.UpdateStateCode(internalpb.StateCode_Healthy)
|
||||
cache := newMockCache()
|
||||
cache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
|
||||
return 0, errors.New("mock")
|
||||
})
|
||||
globalMetaCache = cache
|
||||
req := &milvuspb.ImportRequest{
|
||||
CollectionName: "dummy",
|
||||
}
|
||||
resp, err := proxy.Import(context.TODO(), req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
|
||||
})
|
||||
|
||||
wg.Add(1)
|
||||
t.Run("failed to get virtual channels", func(t *testing.T) {
|
||||
defer wg.Done()
|
||||
proxy := &Proxy{}
|
||||
proxy.UpdateStateCode(internalpb.StateCode_Healthy)
|
||||
cache := newMockCache()
|
||||
globalMetaCache = cache
|
||||
chMgr := newMockChannelsMgr()
|
||||
chMgr.getVChannelsFuncType = func(collectionID UniqueID) ([]vChan, error) {
|
||||
return nil, errors.New("mock")
|
||||
}
|
||||
proxy.chMgr = chMgr
|
||||
req := &milvuspb.ImportRequest{
|
||||
CollectionName: "dummy",
|
||||
}
|
||||
resp, err := proxy.Import(context.TODO(), req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
|
||||
})
|
||||
|
||||
wg.Add(1)
|
||||
t.Run("rootcoord fail", func(t *testing.T) {
|
||||
defer wg.Done()
|
||||
proxy := &Proxy{}
|
||||
proxy.UpdateStateCode(internalpb.StateCode_Healthy)
|
||||
cache := newMockCache()
|
||||
globalMetaCache = cache
|
||||
chMgr := newMockChannelsMgr()
|
||||
proxy.chMgr = chMgr
|
||||
rc := newMockRootCoord()
|
||||
rc.ImportFunc = func(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) {
|
||||
return nil, errors.New("mock")
|
||||
}
|
||||
proxy.rootCoord = rc
|
||||
req := &milvuspb.ImportRequest{
|
||||
CollectionName: "dummy",
|
||||
}
|
||||
resp, err := proxy.Import(context.TODO(), req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
|
||||
})
|
||||
|
||||
wg.Add(1)
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
defer wg.Done()
|
||||
proxy := &Proxy{}
|
||||
proxy.UpdateStateCode(internalpb.StateCode_Healthy)
|
||||
cache := newMockCache()
|
||||
globalMetaCache = cache
|
||||
chMgr := newMockChannelsMgr()
|
||||
proxy.chMgr = chMgr
|
||||
rc := newMockRootCoord()
|
||||
rc.ImportFunc = func(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) {
|
||||
return &milvuspb.ImportResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}}, nil
|
||||
}
|
||||
proxy.rootCoord = rc
|
||||
req := &milvuspb.ImportRequest{
|
||||
CollectionName: "dummy",
|
||||
}
|
||||
resp, err := proxy.Import(context.TODO(), req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
|
||||
})
|
||||
|
||||
wg.Wait()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
|
||||
rc.Stop()
|
||||
}
|
||||
|
||||
func TestProxy_GetImportState(t *testing.T) {
|
||||
|
@ -1066,6 +1066,8 @@ type ShowPartitionsFunc func(ctx context.Context, request *milvuspb.ShowPartitio
|
||||
type DescribeIndexFunc func(ctx context.Context, request *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error)
|
||||
type ShowSegmentsFunc func(ctx context.Context, request *milvuspb.ShowSegmentsRequest) (*milvuspb.ShowSegmentsResponse, error)
|
||||
type DescribeSegmentsFunc func(ctx context.Context, request *rootcoordpb.DescribeSegmentsRequest) (*rootcoordpb.DescribeSegmentsResponse, error)
|
||||
type ImportFunc func(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error)
|
||||
type DropCollectionFunc func(ctx context.Context, request *milvuspb.DropCollectionRequest) (*commonpb.Status, error)
|
||||
|
||||
type mockRootCoord struct {
|
||||
types.RootCoord
|
||||
@ -1074,6 +1076,8 @@ type mockRootCoord struct {
|
||||
DescribeIndexFunc
|
||||
ShowSegmentsFunc
|
||||
DescribeSegmentsFunc
|
||||
ImportFunc
|
||||
DropCollectionFunc
|
||||
}
|
||||
|
||||
func (m *mockRootCoord) DescribeCollection(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) {
|
||||
@ -1111,6 +1115,20 @@ func (m *mockRootCoord) DescribeSegments(ctx context.Context, request *rootcoord
|
||||
return nil, errors.New("mock")
|
||||
}
|
||||
|
||||
func (m *mockRootCoord) Import(ctx context.Context, request *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) {
|
||||
if m.ImportFunc != nil {
|
||||
return m.ImportFunc(ctx, request)
|
||||
}
|
||||
return nil, errors.New("mock")
|
||||
}
|
||||
|
||||
func (m *mockRootCoord) DropCollection(ctx context.Context, request *milvuspb.DropCollectionRequest) (*commonpb.Status, error) {
|
||||
if m.DropCollectionFunc != nil {
|
||||
return m.DropCollectionFunc(ctx, request)
|
||||
}
|
||||
return nil, errors.New("mock")
|
||||
}
|
||||
|
||||
func newMockRootCoord() *mockRootCoord {
|
||||
return &mockRootCoord{}
|
||||
}
|
||||
|
@ -111,7 +111,7 @@ type task interface {
|
||||
|
||||
type dmlTask interface {
|
||||
task
|
||||
getChannels() ([]vChan, error)
|
||||
getChannels() ([]pChan, error)
|
||||
getPChanStats() (map[pChan]pChanStatistics, error)
|
||||
}
|
||||
|
||||
@ -192,16 +192,7 @@ func (it *insertTask) getChannels() ([]pChan, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var channels []pChan
|
||||
channels, err = it.chMgr.getChannels(collID)
|
||||
if err != nil {
|
||||
err = it.chMgr.createDMLMsgStream(collID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
channels, err = it.chMgr.getChannels(collID)
|
||||
}
|
||||
return channels, err
|
||||
return it.chMgr.getChannels(collID)
|
||||
}
|
||||
|
||||
func (it *insertTask) OnEnqueue() error {
|
||||
@ -508,18 +499,7 @@ func (it *insertTask) Execute(ctx context.Context) error {
|
||||
|
||||
stream, err := it.chMgr.getDMLStream(collID)
|
||||
if err != nil {
|
||||
err = it.chMgr.createDMLMsgStream(collID)
|
||||
if err != nil {
|
||||
it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
it.result.Status.Reason = err.Error()
|
||||
return err
|
||||
}
|
||||
stream, err = it.chMgr.getDMLStream(collID)
|
||||
if err != nil {
|
||||
it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
it.result.Status.Reason = err.Error()
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
tr.Record("get used message stream")
|
||||
|
||||
@ -531,6 +511,14 @@ func (it *insertTask) Execute(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("send insert request to virtual channels",
|
||||
zap.String("collection", it.GetCollectionName()),
|
||||
zap.String("partition", it.GetPartitionName()),
|
||||
zap.Int64("collection_id", collID),
|
||||
zap.Int64("partition_id", partitionID),
|
||||
zap.Strings("virtual_channels", channelNames),
|
||||
zap.Int64("task_id", it.ID()))
|
||||
|
||||
// assign segmentID for insert data and repack data by segmentID
|
||||
msgPack, err := it.assignSegmentID(channelNames)
|
||||
if err != nil {
|
||||
@ -3141,16 +3129,7 @@ func (dt *deleteTask) getChannels() ([]pChan, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var channels []pChan
|
||||
channels, err = dt.chMgr.getChannels(collID)
|
||||
if err != nil {
|
||||
err = dt.chMgr.createDMLMsgStream(collID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
channels, err = dt.chMgr.getChannels(collID)
|
||||
}
|
||||
return channels, err
|
||||
return dt.chMgr.getChannels(collID)
|
||||
}
|
||||
|
||||
func getPrimaryKeysFromExpr(schema *schemapb.CollectionSchema, expr string) (res *schemapb.IDs, rowNum int64, err error) {
|
||||
@ -3283,18 +3262,7 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
|
||||
collID := dt.DeleteRequest.CollectionID
|
||||
stream, err := dt.chMgr.getDMLStream(collID)
|
||||
if err != nil {
|
||||
err = dt.chMgr.createDMLMsgStream(collID)
|
||||
if err != nil {
|
||||
dt.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
dt.result.Status.Reason = err.Error()
|
||||
return err
|
||||
}
|
||||
stream, err = dt.chMgr.getDMLStream(collID)
|
||||
if err != nil {
|
||||
dt.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
dt.result.Status.Reason = err.Error()
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// hash primary keys to channels
|
||||
@ -3307,6 +3275,12 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
|
||||
}
|
||||
dt.HashValues = typeutil.HashPK2Channels(dt.result.IDs, channelNames)
|
||||
|
||||
log.Info("send delete request to virtual channels",
|
||||
zap.String("collection", dt.GetCollectionName()),
|
||||
zap.Int64("collection_id", collID),
|
||||
zap.Strings("virtual_channels", channelNames),
|
||||
zap.Int64("task_id", dt.ID()))
|
||||
|
||||
tr.Record("get vchannels")
|
||||
// repack delete msg by dmChannel
|
||||
result := make(map[uint32]msgstream.TsMsg)
|
||||
|
@ -27,15 +27,15 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/allocator"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/allocator"
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
||||
@ -1087,47 +1087,12 @@ func TestCreateCollectionTask(t *testing.T) {
|
||||
|
||||
func TestDropCollectionTask(t *testing.T) {
|
||||
Params.Init()
|
||||
rc := NewRootCoordMock()
|
||||
rc.Start()
|
||||
defer rc.Stop()
|
||||
qc := NewQueryCoordMock()
|
||||
qc.Start()
|
||||
defer qc.Stop()
|
||||
ctx := context.Background()
|
||||
mgr := newShardClientMgr()
|
||||
InitMetaCache(rc, qc, mgr)
|
||||
|
||||
master := newMockGetChannelsService()
|
||||
factory := newSimpleMockMsgStreamFactory()
|
||||
channelMgr := newChannelsMgrImpl(master.GetChannels, nil, nil, factory)
|
||||
defer channelMgr.removeAllDMLStream()
|
||||
|
||||
prefix := "TestDropCollectionTask"
|
||||
dbName := ""
|
||||
collectionName := prefix + funcutil.GenRandomStr()
|
||||
ctx := context.Background()
|
||||
|
||||
shardsNum := int32(2)
|
||||
int64Field := "int64"
|
||||
floatVecField := "fvec"
|
||||
dim := 128
|
||||
|
||||
schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName)
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
assert.NoError(t, err)
|
||||
|
||||
createColReq := &milvuspb.CreateCollectionRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_DropCollection,
|
||||
MsgID: 100,
|
||||
Timestamp: 100,
|
||||
},
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
Schema: marshaledSchema,
|
||||
ShardsNum: shardsNum,
|
||||
}
|
||||
|
||||
//CreateCollection
|
||||
task := &dropCollectionTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
DropCollectionRequest: &milvuspb.DropCollectionRequest{
|
||||
@ -1139,38 +1104,58 @@ func TestDropCollectionTask(t *testing.T) {
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
},
|
||||
ctx: ctx,
|
||||
chMgr: channelMgr,
|
||||
rootCoord: rc,
|
||||
result: nil,
|
||||
ctx: ctx,
|
||||
result: nil,
|
||||
}
|
||||
task.PreExecute(ctx)
|
||||
|
||||
assert.Equal(t, commonpb.MsgType_DropCollection, task.Type())
|
||||
task.SetID(100)
|
||||
assert.Equal(t, UniqueID(100), task.ID())
|
||||
assert.Equal(t, DropCollectionTaskName, task.Name())
|
||||
assert.Equal(t, commonpb.MsgType_DropCollection, task.Type())
|
||||
task.SetTs(100)
|
||||
assert.Equal(t, Timestamp(100), task.BeginTs())
|
||||
assert.Equal(t, Timestamp(100), task.EndTs())
|
||||
|
||||
err := task.PreExecute(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, Params.ProxyCfg.GetNodeID(), task.GetBase().GetSourceID())
|
||||
// missing collectionID in globalMetaCache
|
||||
err = task.Execute(ctx)
|
||||
assert.NotNil(t, err)
|
||||
// createCollection in RootCood and fill GlobalMetaCache
|
||||
rc.CreateCollection(ctx, createColReq)
|
||||
globalMetaCache.GetCollectionID(ctx, collectionName)
|
||||
|
||||
// success to drop collection
|
||||
err = task.Execute(ctx)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// illegal name
|
||||
task.CollectionName = "#0xc0de"
|
||||
err = task.PreExecute(ctx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
assert.Error(t, err)
|
||||
task.CollectionName = collectionName
|
||||
err = task.PreExecute(ctx)
|
||||
assert.Nil(t, err)
|
||||
|
||||
cache := newMockCache()
|
||||
chMgr := newMockChannelsMgr()
|
||||
rc := newMockRootCoord()
|
||||
|
||||
globalMetaCache = cache
|
||||
task.chMgr = chMgr
|
||||
task.rootCoord = rc
|
||||
|
||||
cache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
|
||||
return 0, errors.New("mock")
|
||||
})
|
||||
err = task.Execute(ctx)
|
||||
assert.Error(t, err)
|
||||
cache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
|
||||
return 0, nil
|
||||
})
|
||||
|
||||
rc.DropCollectionFunc = func(ctx context.Context, request *milvuspb.DropCollectionRequest) (*commonpb.Status, error) {
|
||||
return nil, errors.New("mock")
|
||||
}
|
||||
err = task.Execute(ctx)
|
||||
assert.Error(t, err)
|
||||
rc.DropCollectionFunc = func(ctx context.Context, request *milvuspb.DropCollectionRequest) (*commonpb.Status, error) {
|
||||
return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil
|
||||
}
|
||||
|
||||
// normal case
|
||||
err = task.Execute(ctx)
|
||||
assert.NoError(t, err)
|
||||
err = task.PostExecute(ctx)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestHasCollectionTask(t *testing.T) {
|
||||
@ -1728,7 +1713,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
|
||||
|
||||
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
|
||||
factory := newSimpleMockMsgStreamFactory()
|
||||
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, nil, factory)
|
||||
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
|
||||
defer chMgr.removeAllDMLStream()
|
||||
|
||||
err = chMgr.createDMLMsgStream(collectionID)
|
||||
@ -1983,7 +1968,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
|
||||
|
||||
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
|
||||
factory := newSimpleMockMsgStreamFactory()
|
||||
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, nil, factory)
|
||||
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
|
||||
defer chMgr.removeAllDMLStream()
|
||||
|
||||
err = chMgr.createDMLMsgStream(collectionID)
|
||||
|
Loading…
Reference in New Issue
Block a user