diff --git a/internal/mq/msgstream/mq_msgstream.go b/internal/mq/msgstream/mq_msgstream.go index 1f9f4f8274..429dec7bf8 100644 --- a/internal/mq/msgstream/mq_msgstream.go +++ b/internal/mq/msgstream/mq_msgstream.go @@ -669,6 +669,10 @@ func (ms *MqTtMsgStream) Close() { ms.mqMsgStream.Close() } +func isDMLMsg(msg TsMsg) bool { + return msg.Type() == commonpb.MsgType_Insert || msg.Type() == commonpb.MsgType_Delete +} + func (ms *MqTtMsgStream) bufMsgPackToChannel() { ms.closeRWMutex.RLock() defer ms.closeRWMutex.RUnlock() @@ -760,7 +764,7 @@ func (ms *MqTtMsgStream) bufMsgPackToChannel() { idset := make(typeutil.UniqueSet) uniqueMsgs := make([]TsMsg, 0, len(timeTickBuf)) for _, msg := range timeTickBuf { - if idset.Contain(msg.ID()) { + if isDMLMsg(msg) && idset.Contain(msg.ID()) { log.Warn("mqTtMsgStream, found duplicated msg", zap.Int64("msgID", msg.ID())) continue } diff --git a/internal/mq/msgstream/mq_msgstream_test.go b/internal/mq/msgstream/mq_msgstream_test.go index 76dee14816..f5883b1b12 100644 --- a/internal/mq/msgstream/mq_msgstream_test.go +++ b/internal/mq/msgstream/mq_msgstream_test.go @@ -1570,8 +1570,13 @@ func TestStream_RmqTtMsgStream_DuplicatedIDs(t *testing.T) { msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1)) msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1)) + // would not dedup for non-dml messages msgPack2 := MsgPack{} - msgPack2.Msgs = append(msgPack2.Msgs, getTimeTickMsg(15)) + msgPack2.Msgs = append(msgPack2.Msgs, getTsMsg(commonpb.MsgType_Search, 2)) + msgPack2.Msgs = append(msgPack2.Msgs, getTsMsg(commonpb.MsgType_Search, 2)) + + msgPack3 := MsgPack{} + msgPack3.Msgs = append(msgPack3.Msgs, getTimeTickMsg(15)) ctx := context.Background() inputStream, outputStream := initRmqTtStream(ctx, producerChannels, consumerChannels, consumerSubName) @@ -1580,11 +1585,13 @@ func TestStream_RmqTtMsgStream_DuplicatedIDs(t *testing.T) { assert.Nil(t, err) err = inputStream.Produce(&msgPack1) assert.Nil(t, err) - err = inputStream.Broadcast(&msgPack2) + err = inputStream.Produce(&msgPack2) + assert.Nil(t, err) + err = inputStream.Broadcast(&msgPack3) assert.Nil(t, err) receivedMsg := consumer(ctx, outputStream) - assert.Equal(t, len(receivedMsg.Msgs), 1) + assert.Equal(t, len(receivedMsg.Msgs), 3) assert.Equal(t, receivedMsg.BeginTs, uint64(0)) assert.Equal(t, receivedMsg.EndTs, uint64(15)) @@ -1600,13 +1607,12 @@ func TestStream_RmqTtMsgStream_DuplicatedIDs(t *testing.T) { outputStream.Seek(receivedMsg.StartPositions) outputStream.Start() seekMsg := consumer(ctx, outputStream) - assert.Equal(t, len(seekMsg.Msgs), 1) - for _, msg := range seekMsg.Msgs { - assert.EqualValues(t, msg.BeginTs(), 1) - } + assert.Equal(t, len(seekMsg.Msgs), 1+2) + assert.EqualValues(t, seekMsg.Msgs[0].BeginTs(), 1) + assert.Equal(t, commonpb.MsgType_Search, seekMsg.Msgs[1].Type()) + assert.Equal(t, commonpb.MsgType_Search, seekMsg.Msgs[2].Type()) Close(rocksdbName, inputStream, outputStream, etcdKV) - } func TestStream_RmqTtMsgStream_Seek(t *testing.T) {