mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-01 11:29:48 +08:00
Add ProduceMark interface to return messageID (#9556)
Signed-off-by: xige-16 <xi.ge@zilliz.com>
This commit is contained in:
parent
223c330ed5
commit
e99ecc8cab
@ -73,6 +73,9 @@ func (mtm *mockTtMsgStream) GetProduceChannels() []string {
|
|||||||
func (mtm *mockTtMsgStream) Produce(*msgstream.MsgPack) error {
|
func (mtm *mockTtMsgStream) Produce(*msgstream.MsgPack) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
func (mtm *mockTtMsgStream) ProduceMark(*msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
|
||||||
|
return map[string][]msgstream.MessageID{}, nil
|
||||||
|
}
|
||||||
func (mtm *mockTtMsgStream) Broadcast(*msgstream.MsgPack) error {
|
func (mtm *mockTtMsgStream) Broadcast(*msgstream.MsgPack) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -266,6 +266,74 @@ func (ms *mqMsgStream) Produce(msgPack *MsgPack) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProduceMark send msg pack to all producers and returns corresponding msg id
|
||||||
|
// the returned message id serves as marking
|
||||||
|
func (ms *mqMsgStream) ProduceMark(msgPack *MsgPack) (map[string][]MessageID, error) {
|
||||||
|
ids := make(map[string][]MessageID)
|
||||||
|
if msgPack == nil || len(msgPack.Msgs) <= 0 {
|
||||||
|
return ids, errors.New("empty msgs")
|
||||||
|
}
|
||||||
|
if len(ms.producers) <= 0 {
|
||||||
|
return ids, errors.New("nil producer in msg stream")
|
||||||
|
}
|
||||||
|
tsMsgs := msgPack.Msgs
|
||||||
|
reBucketValues := ms.ComputeProduceChannelIndexes(msgPack.Msgs)
|
||||||
|
var result map[int32]*MsgPack
|
||||||
|
var err error
|
||||||
|
if ms.repackFunc != nil {
|
||||||
|
result, err = ms.repackFunc(tsMsgs, reBucketValues)
|
||||||
|
} else {
|
||||||
|
msgType := (tsMsgs[0]).Type()
|
||||||
|
switch msgType {
|
||||||
|
case commonpb.MsgType_Insert:
|
||||||
|
result, err = InsertRepackFunc(tsMsgs, reBucketValues)
|
||||||
|
case commonpb.MsgType_Delete:
|
||||||
|
result, err = DeleteRepackFunc(tsMsgs, reBucketValues)
|
||||||
|
default:
|
||||||
|
result, err = DefaultRepackFunc(tsMsgs, reBucketValues)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return ids, err
|
||||||
|
}
|
||||||
|
for k, v := range result {
|
||||||
|
channel := ms.producerChannels[k]
|
||||||
|
for i, tsMsg := range v.Msgs {
|
||||||
|
sp, spanCtx := MsgSpanFromCtx(v.Msgs[i].TraceCtx(), tsMsg)
|
||||||
|
|
||||||
|
mb, err := tsMsg.Marshal(tsMsg)
|
||||||
|
if err != nil {
|
||||||
|
return ids, err
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := convertToByteArray(mb)
|
||||||
|
if err != nil {
|
||||||
|
return ids, err
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := &mqclient.ProducerMessage{Payload: m, Properties: map[string]string{}}
|
||||||
|
|
||||||
|
trace.InjectContextToPulsarMsgProperties(sp.Context(), msg.Properties)
|
||||||
|
|
||||||
|
ms.producerLock.Lock()
|
||||||
|
id, err := ms.producers[channel].Send(
|
||||||
|
spanCtx,
|
||||||
|
msg,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
ms.producerLock.Unlock()
|
||||||
|
trace.LogError(sp, err)
|
||||||
|
sp.Finish()
|
||||||
|
return ids, err
|
||||||
|
}
|
||||||
|
ids[channel] = append(ids[channel], id)
|
||||||
|
sp.Finish()
|
||||||
|
ms.producerLock.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ids, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Broadcast put msgPack to all producer in current msgstream
|
// Broadcast put msgPack to all producer in current msgstream
|
||||||
// which ignores repackFunc logic
|
// which ignores repackFunc logic
|
||||||
func (ms *mqMsgStream) Broadcast(msgPack *MsgPack) error {
|
func (ms *mqMsgStream) Broadcast(msgPack *MsgPack) error {
|
||||||
|
@ -1263,7 +1263,73 @@ func TestStream_BroadcastMark(t *testing.T) {
|
|||||||
assert.NotNil(t, err)
|
assert.NotNil(t, err)
|
||||||
|
|
||||||
outputStream.Close()
|
outputStream.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStream_ProduceMark(t *testing.T) {
|
||||||
|
pulsarAddress, _ := Params.Load("_PulsarAddress")
|
||||||
|
c1 := funcutil.RandomString(8)
|
||||||
|
c2 := funcutil.RandomString(8)
|
||||||
|
producerChannels := []string{c1, c2}
|
||||||
|
|
||||||
|
factory := ProtoUDFactory{}
|
||||||
|
pulsarClient, err := mqclient.GetPulsarClientInstance(pulsar.ClientOptions{URL: pulsarAddress})
|
||||||
|
assert.Nil(t, err)
|
||||||
|
outputStream, err := NewMqMsgStream(context.Background(), 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
// add producer channels
|
||||||
|
outputStream.AsProducer(producerChannels)
|
||||||
|
outputStream.Start()
|
||||||
|
|
||||||
|
msgPack0 := MsgPack{}
|
||||||
|
msgPack0.Msgs = append(msgPack0.Msgs, getTimeTickMsg(0))
|
||||||
|
|
||||||
|
ids, err := outputStream.ProduceMark(&msgPack0)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, ids)
|
||||||
|
assert.Equal(t, len(msgPack0.Msgs), len(ids))
|
||||||
|
for _, c := range producerChannels {
|
||||||
|
if id, ok := ids[c]; ok {
|
||||||
|
assert.Equal(t, len(msgPack0.Msgs), len(id))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
msgPack1 := MsgPack{}
|
||||||
|
msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1))
|
||||||
|
msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 2))
|
||||||
|
|
||||||
|
ids, err = outputStream.ProduceMark(&msgPack1)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, ids)
|
||||||
|
assert.Equal(t, len(producerChannels), len(ids))
|
||||||
|
for _, c := range producerChannels {
|
||||||
|
ids, ok := ids[c]
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, 1, len(ids))
|
||||||
|
}
|
||||||
|
|
||||||
|
// edge cases
|
||||||
|
_, err = outputStream.ProduceMark(nil)
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
|
||||||
|
msgPack2 := MsgPack{}
|
||||||
|
msgPack2.Msgs = append(msgPack2.Msgs, &MarshalFailTsMsg{BaseMsg: BaseMsg{HashValues: []uint32{1}}})
|
||||||
|
_, err = outputStream.ProduceMark(&msgPack2)
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
|
||||||
|
// mock send fail
|
||||||
|
for k, p := range outputStream.producers {
|
||||||
|
outputStream.producers[k] = &mockSendFailProducer{Producer: p}
|
||||||
|
}
|
||||||
|
_, err = outputStream.ProduceMark(&msgPack1)
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
|
||||||
|
// mock producers is nil
|
||||||
|
outputStream.producers = nil
|
||||||
|
_, err = outputStream.ProduceMark(&msgPack1)
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
|
||||||
|
outputStream.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ TsMsg = (*MarshalFailTsMsg)(nil)
|
var _ TsMsg = (*MarshalFailTsMsg)(nil)
|
||||||
|
@ -57,6 +57,7 @@ type MsgStream interface {
|
|||||||
ComputeProduceChannelIndexes(tsMsgs []TsMsg) [][]int32
|
ComputeProduceChannelIndexes(tsMsgs []TsMsg) [][]int32
|
||||||
GetProduceChannels() []string
|
GetProduceChannels() []string
|
||||||
Produce(*MsgPack) error
|
Produce(*MsgPack) error
|
||||||
|
ProduceMark(*MsgPack) (map[string][]MessageID, error)
|
||||||
Broadcast(*MsgPack) error
|
Broadcast(*MsgPack) error
|
||||||
BroadcastMark(*MsgPack) (map[string][]MessageID, error)
|
BroadcastMark(*MsgPack) (map[string][]MessageID, error)
|
||||||
Consume() *MsgPack
|
Consume() *MsgPack
|
||||||
|
@ -325,6 +325,13 @@ func (ms *simpleMockMsgStream) Produce(pack *msgstream.MsgPack) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ms *simpleMockMsgStream) ProduceMark(pack *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
|
||||||
|
defer ms.increaseMsgCount(1)
|
||||||
|
ms.msgChan <- pack
|
||||||
|
|
||||||
|
return map[string][]msgstream.MessageID{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (ms *simpleMockMsgStream) Broadcast(pack *msgstream.MsgPack) error {
|
func (ms *simpleMockMsgStream) Broadcast(pack *msgstream.MsgPack) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user