milvus/internal/msgstream/mem_msgstream.go
Xiangyu Wang 87a1e0b662 Reorganize msgstream
Signed-off-by: Xiangyu Wang <xiangyu.wang@zilliz.com>
2021-04-02 13:48:25 +08:00

194 lines
4.2 KiB
Go

package msgstream
import (
"context"
"errors"
"log"
"strconv"
"sync"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
)
// MemMsgStream is a message stream whose producing and consuming sides both
// go through Mmq (presumably an in-memory message queue — confirm against the
// definition of Mmq).
type MemMsgStream struct {
// ctx is the stream's own cancellable context; Consume and receiveMsg
// stop once it is done.
ctx context.Context
// streamCancel cancels ctx; it is invoked by Close.
streamCancel func()
// repackFunc, when non-nil, is used by Produce to split messages per
// bucket instead of the repack function chosen by message type.
repackFunc RepackFunc
// consumers holds the consumer groups created by AsConsumer; Close
// destroys each of them.
consumers []*MemConsumer
// producers holds the channel names registered by AsProducer; Produce
// uses bucket values as indexes into this slice.
producers []string
// receiveBuf is the buffered channel that receiveMsg forwards consumed
// packs into and that Consume and Chan read from.
receiveBuf chan *MsgPack
// wait tracks the receiveMsg goroutines so Close can wait for them.
wait sync.WaitGroup
}
// NewMemMsgStream builds a MemMsgStream with no producers or consumers
// attached yet.
//
// The stream gets its own cancellable context derived from ctx, and a receive
// buffer with capacity receiveBufSize. The returned error is always nil.
func NewMemMsgStream(ctx context.Context, receiveBufSize int64) (*MemMsgStream, error) {
	derivedCtx, cancel := context.WithCancel(ctx)
	return &MemMsgStream{
		ctx:          derivedCtx,
		streamCancel: cancel,
		receiveBuf:   make(chan *MsgPack, receiveBufSize),
		consumers:    make([]*MemConsumer, 0),
		producers:    make([]string, 0),
	}, nil
}
// Start is a no-op for MemMsgStream: the receiving goroutines are launched by
// AsConsumer, so there is nothing left to start here.
func (mms *MemMsgStream) Start() {
}
// Close shuts the stream down. It destroys every consumer group registered
// through AsConsumer, cancels the stream context, and then blocks until all
// receiveMsg goroutines have returned.
func (mms *MemMsgStream) Close() {
	for i := range mms.consumers {
		c := mms.consumers[i]
		Mmq.DestroyConsumerGroup(c.GroupName, c.ChannelName)
	}

	// Cancel first, then wait: the receiving goroutines exit on ctx.Done().
	mms.streamCancel()
	mms.wait.Wait()
}
// SetRepackFunc installs repackFunc as the function Produce uses to group
// messages by bucket. Passing nil makes Produce fall back to selecting a
// repack function from the type of the first message.
func (mms *MemMsgStream) SetRepackFunc(repackFunc RepackFunc) {
mms.repackFunc = repackFunc
}
// AsProducer registers the stream as a producer for each of the given
// channels, creating every channel through Mmq.CreateChannel and recording
// its name in the order given.
//
// It panics with a descriptive string as soon as a channel cannot be created;
// channels registered before the failure stay registered.
func (mms *MemMsgStream) AsProducer(channels []string) {
	for _, name := range channels {
		if err := Mmq.CreateChannel(name); err != nil {
			panic("Failed to create producer " + name + ", error = " + err.Error())
		}
		mms.producers = append(mms.producers, name)
	}
}
// AsConsumer subscribes the stream to each of the given channels under the
// consumer group groupName. For every consumer group that is created
// successfully, a receiveMsg goroutine is started that forwards the
// consumer's messages into the stream's receive buffer.
//
// A channel whose consumer group cannot be created is skipped. The failure is
// logged rather than silently dropped, so a missing subscription can be
// diagnosed; it is deliberately not turned into a panic, to keep the
// best-effort behavior existing callers rely on.
func (mms *MemMsgStream) AsConsumer(channels []string, groupName string) {
	for _, channelName := range channels {
		consumer, err := Mmq.CreateConsumerGroup(groupName, channelName)
		if err != nil {
			log.Printf("Failed to create consumer group %s on channel %s, error = %v",
				groupName, channelName, err)
			continue
		}

		mms.consumers = append(mms.consumers, consumer)
		// Add must happen before the goroutine starts so that Close cannot
		// observe a zero counter while a receiver is still being launched.
		mms.wait.Add(1)
		go mms.receiveMsg(*consumer)
	}
}
// Produce distributes the messages of pack over the registered producer
// channels and publishes the resulting per-channel packs through Mmq.
//
// Each message gets one bucket (producer index) per hash key:
//   - a search result message is routed to the producer index parsed from its
//     ResultChannelID;
//   - every other message is routed to hashKey modulo the number of producers.
//
// The messages are then regrouped per bucket by the stream's repack function
// or, when none is set, by a repack function chosen from the type of the
// first message.
//
// An empty pack is logged and ignored (nil is returned). An error is returned
// when no producer is registered, when a search result carries a channel id
// that is malformed or outside the producer range, when repacking fails, when
// the repack result names a bucket outside the producer range, or when
// publishing fails. Packs already published before a failure are not rolled
// back.
func (mms *MemMsgStream) Produce(pack *MsgPack) error {
	tsMsgs := pack.Msgs
	if len(tsMsgs) <= 0 {
		log.Printf("Warning: Receive empty msgPack")
		return nil
	}
	producerCount := len(mms.producers)
	if producerCount <= 0 {
		return errors.New("nil producer in msg stream")
	}

	reBucketValues := make([][]int32, len(tsMsgs))
	for msgIdx, tsMsg := range tsMsgs {
		hashValues := tsMsg.HashKeys()
		bucketValues := make([]int32, len(hashValues))

		if len(hashValues) > 0 && tsMsg.Type() == commonpb.MsgType_SearchResult {
			// Search results are addressed to one explicit channel rather
			// than being hashed, so all of their buckets share that index.
			searchResult, ok := tsMsg.(*SearchResultMsg)
			if !ok {
				return errors.New("search result msg has unexpected concrete type")
			}
			resultChannelID := searchResult.ResultChannelID
			channelIdx, err := strconv.ParseInt(resultChannelID, 10, 64)
			if err != nil {
				return errors.New("invalid search result channel id " +
					strconv.Quote(resultChannelID) + ": " + err.Error())
			}
			// Reject negative ids as well: they would otherwise be used as
			// a slice index further down.
			if channelIdx < 0 || channelIdx >= int64(producerCount) {
				return errors.New("Failed to produce rmq msg to unKnow channel")
			}
			for i := range bucketValues {
				bucketValues[i] = int32(channelIdx)
			}
		} else {
			for i, hashValue := range hashValues {
				bucketValues[i] = int32(hashValue % uint32(producerCount))
			}
		}
		reBucketValues[msgIdx] = bucketValues
	}

	var result map[int32]*MsgPack
	var err error
	if mms.repackFunc != nil {
		result, err = mms.repackFunc(tsMsgs, reBucketValues)
	} else {
		switch tsMsgs[0].Type() {
		case commonpb.MsgType_Insert:
			result, err = InsertRepackFunc(tsMsgs, reBucketValues)
		case commonpb.MsgType_Delete:
			result, err = DeleteRepackFunc(tsMsgs, reBucketValues)
		default:
			result, err = DefaultRepackFunc(tsMsgs, reBucketValues)
		}
	}
	if err != nil {
		return err
	}

	for bucket, bucketPack := range result {
		// A custom repack function may return arbitrary keys; guard the
		// index instead of panicking on an out-of-range access.
		if bucket < 0 || int(bucket) >= producerCount {
			return errors.New("repack result refers to unknown producer index " +
				strconv.Itoa(int(bucket)))
		}
		if err := Mmq.Produce(mms.producers[bucket], bucketPack); err != nil {
			return err
		}
	}
	return nil
}
// Broadcast publishes the same msgPack to every registered producer channel,
// in registration order. It stops at, and returns, the first publishing
// error; channels already written to are not rolled back. With no producers
// registered it does nothing and returns nil.
func (mms *MemMsgStream) Broadcast(msgPack *MsgPack) error {
	for _, target := range mms.producers {
		if err := Mmq.Produce(target, msgPack); err != nil {
			return err
		}
	}
	return nil
}
// Consume blocks until a message pack is available in the receive buffer and
// returns it. It returns nil, after logging the reason, when the stream
// context is done or when the receive buffer has been closed.
func (mms *MemMsgStream) Consume() *MsgPack {
	// Every branch returns, so a single select is sufficient.
	select {
	case <-mms.ctx.Done():
		log.Printf("context closed")
		return nil
	case pack, open := <-mms.receiveBuf:
		if open {
			return pack
		}
		log.Println("buf chan closed")
		return nil
	}
}
// receiveMsg runs in its own goroutine, one per consumer, and forwards every
// message pack delivered on consumer.MsgChan into the stream's shared receive
// buffer. Using one forwarding goroutine per consumer avoids the search
// timeout problem that the original author attributed to multiplexing all
// consumers through a single select.
//
// The goroutine exits, marking itself done on the stream's WaitGroup, when
// the stream context is done or when a nil message is received (which is also
// what a closed MsgChan yields).
func (mms *MemMsgStream) receiveMsg(consumer MemConsumer) {
	defer mms.wait.Done()
	for {
		select {
		case <-mms.ctx.Done():
			return
		case msg := <-consumer.MsgChan:
			if msg == nil {
				return
			}
			// The forward must also watch the context: with a full buffer
			// and no reader, a bare send would block forever and Close
			// would hang waiting for this goroutine.
			select {
			case mms.receiveBuf <- msg:
			case <-mms.ctx.Done():
				return
			}
		}
	}
}
// Chan returns the stream's receive buffer as a receive-only channel, i.e. the
// same channel that Consume reads from.
func (mms *MemMsgStream) Chan() <-chan *MsgPack {
return mms.receiveBuf
}
// Seek is not supported by MemMsgStream: offset is ignored and a
// "not implemented" error is always returned.
func (mms *MemMsgStream) Seek(offset *MsgPosition) error {
return errors.New("MemMsgStream seek not implemented")
}