mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-04 04:49:08 +08:00
87a1e0b662
Signed-off-by: Xiangyu Wang <xiangyu.wang@zilliz.com>
194 lines
4.2 KiB
Go
194 lines
4.2 KiB
Go
package msgstream
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"log"
|
|
"strconv"
|
|
"sync"
|
|
|
|
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
|
)
|
|
|
|
type MemMsgStream struct {
|
|
ctx context.Context
|
|
streamCancel func()
|
|
|
|
repackFunc RepackFunc
|
|
|
|
consumers []*MemConsumer
|
|
producers []string
|
|
|
|
receiveBuf chan *MsgPack
|
|
|
|
wait sync.WaitGroup
|
|
}
|
|
|
|
func NewMemMsgStream(ctx context.Context, receiveBufSize int64) (*MemMsgStream, error) {
|
|
streamCtx, streamCancel := context.WithCancel(ctx)
|
|
receiveBuf := make(chan *MsgPack, receiveBufSize)
|
|
channels := make([]string, 0)
|
|
consumers := make([]*MemConsumer, 0)
|
|
|
|
stream := &MemMsgStream{
|
|
ctx: streamCtx,
|
|
streamCancel: streamCancel,
|
|
receiveBuf: receiveBuf,
|
|
consumers: consumers,
|
|
producers: channels,
|
|
}
|
|
|
|
return stream, nil
|
|
}
|
|
|
|
func (mms *MemMsgStream) Start() {
|
|
}
|
|
|
|
func (mms *MemMsgStream) Close() {
|
|
for _, consumer := range mms.consumers {
|
|
Mmq.DestroyConsumerGroup(consumer.GroupName, consumer.ChannelName)
|
|
}
|
|
|
|
mms.streamCancel()
|
|
mms.wait.Wait()
|
|
}
|
|
|
|
func (mms *MemMsgStream) SetRepackFunc(repackFunc RepackFunc) {
|
|
mms.repackFunc = repackFunc
|
|
}
|
|
|
|
func (mms *MemMsgStream) AsProducer(channels []string) {
|
|
for _, channel := range channels {
|
|
err := Mmq.CreateChannel(channel)
|
|
if err == nil {
|
|
mms.producers = append(mms.producers, channel)
|
|
} else {
|
|
errMsg := "Failed to create producer " + channel + ", error = " + err.Error()
|
|
panic(errMsg)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (mms *MemMsgStream) AsConsumer(channels []string, groupName string) {
|
|
for _, channelName := range channels {
|
|
consumer, err := Mmq.CreateConsumerGroup(groupName, channelName)
|
|
if err == nil {
|
|
mms.consumers = append(mms.consumers, consumer)
|
|
|
|
mms.wait.Add(1)
|
|
go mms.receiveMsg(*consumer)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (mms *MemMsgStream) Produce(pack *MsgPack) error {
|
|
tsMsgs := pack.Msgs
|
|
if len(tsMsgs) <= 0 {
|
|
log.Printf("Warning: Receive empty msgPack")
|
|
return nil
|
|
}
|
|
if len(mms.producers) <= 0 {
|
|
return errors.New("nil producer in msg stream")
|
|
}
|
|
reBucketValues := make([][]int32, len(tsMsgs))
|
|
for channelID, tsMsg := range tsMsgs {
|
|
hashValues := tsMsg.HashKeys()
|
|
bucketValues := make([]int32, len(hashValues))
|
|
for index, hashValue := range hashValues {
|
|
if tsMsg.Type() == commonpb.MsgType_SearchResult {
|
|
searchResult := tsMsg.(*SearchResultMsg)
|
|
channelID := searchResult.ResultChannelID
|
|
channelIDInt, _ := strconv.ParseInt(channelID, 10, 64)
|
|
if channelIDInt >= int64(len(mms.producers)) {
|
|
return errors.New("Failed to produce rmq msg to unKnow channel")
|
|
}
|
|
bucketValues[index] = int32(channelIDInt)
|
|
continue
|
|
}
|
|
bucketValues[index] = int32(hashValue % uint32(len(mms.producers)))
|
|
}
|
|
reBucketValues[channelID] = bucketValues
|
|
}
|
|
|
|
var result map[int32]*MsgPack
|
|
var err error
|
|
if mms.repackFunc != nil {
|
|
result, err = mms.repackFunc(tsMsgs, reBucketValues)
|
|
} else {
|
|
msgType := (tsMsgs[0]).Type()
|
|
switch msgType {
|
|
case commonpb.MsgType_Insert:
|
|
result, err = InsertRepackFunc(tsMsgs, reBucketValues)
|
|
case commonpb.MsgType_Delete:
|
|
result, err = DeleteRepackFunc(tsMsgs, reBucketValues)
|
|
default:
|
|
result, err = DefaultRepackFunc(tsMsgs, reBucketValues)
|
|
}
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for k, v := range result {
|
|
err := Mmq.Produce(mms.producers[k], v)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (mms *MemMsgStream) Broadcast(msgPack *MsgPack) error {
|
|
for _, channelName := range mms.producers {
|
|
err := Mmq.Produce(channelName, msgPack)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (mms *MemMsgStream) Consume() *MsgPack {
|
|
for {
|
|
select {
|
|
case cm, ok := <-mms.receiveBuf:
|
|
if !ok {
|
|
log.Println("buf chan closed")
|
|
return nil
|
|
}
|
|
return cm
|
|
case <-mms.ctx.Done():
|
|
log.Printf("context closed")
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
|
|
/**
|
|
receiveMsg func is used to solve search timeout problem
|
|
which is caused by selectcase
|
|
*/
|
|
func (mms *MemMsgStream) receiveMsg(consumer MemConsumer) {
|
|
defer mms.wait.Done()
|
|
for {
|
|
select {
|
|
case <-mms.ctx.Done():
|
|
return
|
|
case msg := <-consumer.MsgChan:
|
|
if msg == nil {
|
|
return
|
|
}
|
|
|
|
mms.receiveBuf <- msg
|
|
}
|
|
}
|
|
}
|
|
|
|
func (mms *MemMsgStream) Chan() <-chan *MsgPack {
|
|
return mms.receiveBuf
|
|
}
|
|
|
|
func (mms *MemMsgStream) Seek(offset *MsgPosition) error {
|
|
return errors.New("MemMsgStream seek not implemented")
|
|
}
|