diff --git a/internal/msgstream/rmqms/rmq_msgstream.go b/internal/msgstream/rmqms/rmq_msgstream.go index f5db14bcbf..75455b68c0 100644 --- a/internal/msgstream/rmqms/rmq_msgstream.go +++ b/internal/msgstream/rmqms/rmq_msgstream.go @@ -84,7 +84,7 @@ func (ms *RmqMsgStream) Close() { } for _, consumer := range ms.consumers { _ = rocksmq.Rmq.DestroyConsumerGroup(consumer.GroupName, consumer.ChannelName) - close(consumer.MsgNum) + close(consumer.MsgMutex) } } @@ -112,12 +112,13 @@ func (ms *RmqMsgStream) AsConsumer(channels []string, groupName string) { for _, channelName := range channels { consumer, err := rocksmq.Rmq.CreateConsumerGroup(groupName, channelName) if err == nil { - consumer.MsgNum = make(chan int, ms.rmqBufSize) + consumer.MsgMutex = make(chan struct{}, ms.rmqBufSize) + //consumer.MsgMutex <- struct{}{} ms.consumers = append(ms.consumers, *consumer) ms.consumerChannels = append(ms.consumerChannels, channelName) ms.consumerReflects = append(ms.consumerReflects, reflect.SelectCase{ Dir: reflect.SelectRecv, - Chan: reflect.ValueOf(consumer.MsgNum), + Chan: reflect.ValueOf(consumer.MsgMutex), }) ms.wait.Add(1) go ms.receiveMsg(*consumer) @@ -244,30 +245,35 @@ func (ms *RmqMsgStream) receiveMsg(consumer rocksmq.Consumer) { select { case <-ms.ctx.Done(): return - case msgNum, ok := <-consumer.MsgNum: + case _, ok := <-consumer.MsgMutex: if !ok { return } - rmqMsg, err := rocksmq.Rmq.Consume(consumer.GroupName, consumer.ChannelName, msgNum) - if err != nil { - log.Printf("Failed to consume message in rocksmq, error = %v", err) - continue - } tsMsgList := make([]msgstream.TsMsg, 0) - for j := 0; j < len(rmqMsg); j++ { - headerMsg := commonpb.MsgHeader{} - err := proto.Unmarshal(rmqMsg[j].Payload, &headerMsg) + for { + rmqMsgs, err := rocksmq.Rmq.Consume(consumer.GroupName, consumer.ChannelName, 1) if err != nil { - log.Printf("Failed to unmarshal message header, error = %v", err) + log.Printf("Failed to consume message in rocksmq, error = %v", err) continue } - tsMsg, err := ms.unmarshal.Unmarshal(rmqMsg[j].Payload, headerMsg.Base.MsgType) + if len(rmqMsgs) == 0 { + break + } + rmqMsg := rmqMsgs[0] + headerMsg := commonpb.MsgHeader{} + err = proto.Unmarshal(rmqMsg.Payload, &headerMsg) + if err != nil { + log.Printf("Failed to unmar`shal message header, error = %v", err) + continue + } + tsMsg, err := ms.unmarshal.Unmarshal(rmqMsg.Payload, headerMsg.Base.MsgType) if err != nil { log.Printf("Failed to unmarshal tsMsg, error = %v", err) continue } tsMsgList = append(tsMsgList, tsMsg) } + if len(tsMsgList) > 0 { msgPack := util.MsgPack{Msgs: tsMsgList} ms.receiveBuf <- &msgPack @@ -326,12 +332,13 @@ func (ms *RmqTtMsgStream) AsConsumer(channels []string, if err != nil { panic(err.Error()) } - consumer.MsgNum = make(chan int, ms.rmqBufSize) + consumer.MsgMutex = make(chan struct{}, ms.rmqBufSize) + //consumer.MsgMutex <- struct{}{} ms.consumers = append(ms.consumers, *consumer) ms.consumerChannels = append(ms.consumerChannels, consumer.ChannelName) ms.consumerReflects = append(ms.consumerReflects, reflect.SelectCase{ Dir: reflect.SelectRecv, - Chan: reflect.ValueOf(consumer.MsgNum), + Chan: reflect.ValueOf(consumer.MsgMutex), }) } } @@ -432,25 +439,28 @@ func (ms *RmqTtMsgStream) findTimeTick(consumer rocksmq.Consumer, select { case <-ms.ctx.Done(): return - case num, ok := <-consumer.MsgNum: + case _, ok := <-consumer.MsgMutex: if !ok { log.Printf("consumer closed!") return } - rmqMsg, err := rocksmq.Rmq.Consume(consumer.GroupName, consumer.ChannelName, num) - if err != nil { - log.Printf("Failed to consume message in rocksmq, error = %v", err) - continue - } - - for j := 0; j < len(rmqMsg); j++ { + for { + rmqMsgs, err := rocksmq.Rmq.Consume(consumer.GroupName, consumer.ChannelName, 1) + if err != nil { + log.Printf("Failed to consume message in rocksmq, error = %v", err) + continue + } + if len(rmqMsgs) == 0 { + return + } + rmqMsg := rmqMsgs[0] headerMsg := commonpb.MsgHeader{} - err := proto.Unmarshal(rmqMsg[j].Payload, &headerMsg) + err = proto.Unmarshal(rmqMsg.Payload, &headerMsg) if err != nil { log.Printf("Failed to unmarshal message header, error = %v", err) continue } - tsMsg, err := ms.unmarshal.Unmarshal(rmqMsg[j].Payload, headerMsg.Base.MsgType) + tsMsg, err := ms.unmarshal.Unmarshal(rmqMsg.Payload, headerMsg.Base.MsgType) if err != nil { log.Printf("Failed to unmarshal tsMsg, error = %v", err) continue @@ -458,7 +468,7 @@ func (ms *RmqTtMsgStream) findTimeTick(consumer rocksmq.Consumer, tsMsg.SetPosition(&msgstream.MsgPosition{ ChannelName: filepath.Base(consumer.ChannelName), - MsgID: strconv.Itoa(int(rmqMsg[j].MsgID)), + MsgID: strconv.Itoa(int(rmqMsg.MsgID)), }) ms.unsolvedMutex.Lock() @@ -469,7 +479,8 @@ func (ms *RmqTtMsgStream) findTimeTick(consumer rocksmq.Consumer, findMapMutex.Lock() eofMsgMap[consumer] = tsMsg.(*TimeTickMsg).Base.Timestamp findMapMutex.Unlock() - return + //consumer.MsgMutex <- struct{}{} + //return } } } @@ -504,8 +515,8 @@ func (ms *RmqTtMsgStream) Seek(mp *msgstream.MsgPosition) error { ms.unsolvedMutex.Lock() ms.unsolvedBuf[consumer] = make([]TsMsg, 0) - // When rmq seek is called, msgNum can't be used before current msgs all consumed, because - // new msgNum is not generated. So just try to consume msgs + // When rmq seek is called, msgMutex can't be used before current msgs all consumed, because + // new msgMutex is not generated. So just try to consume msgs for { rmqMsg, err := rocksmq.Rmq.Consume(consumer.GroupName, consumer.ChannelName, 1) if err != nil { diff --git a/internal/util/rocksmq/global_rmq.go b/internal/util/rocksmq/global_rmq.go index 519de496df..072d61f0e4 100644 --- a/internal/util/rocksmq/global_rmq.go +++ b/internal/util/rocksmq/global_rmq.go @@ -15,7 +15,7 @@ var once sync.Once type Consumer struct { GroupName string ChannelName string - MsgNum chan int + MsgMutex chan struct{} } func InitRmq(rocksdbName string, idAllocator allocator.GIDAllocator) error { diff --git a/internal/util/rocksmq/rocksmq.go b/internal/util/rocksmq/rocksmq.go index 1c96bcaf52..2e29edf0e1 100644 --- a/internal/util/rocksmq/rocksmq.go +++ b/internal/util/rocksmq/rocksmq.go @@ -247,8 +247,8 @@ func (rmq *RocksMQ) Produce(channelName string, messages []ProducerMessage) erro } for _, consumer := range rmq.notify[channelName] { - if consumer.MsgNum != nil { - consumer.MsgNum <- msgLen + if consumer.MsgMutex != nil { + consumer.MsgMutex <- struct{}{} } } return nil @@ -308,6 +308,7 @@ func (rmq *RocksMQ) Consume(groupName string, channelName string, n int) ([]Cons return nil, err } + // When already consume to last mes, an empty slice will be returned if len(consumerMessage) == 0 { return consumerMessage, nil }