milvus/internal/msgstream/pulsarms/msg_test.go
neza2017 7b2d67242d Fix describe segment if index not exist
Signed-off-by: neza2017 <yefu.chen@zilliz.com>
2021-03-10 14:45:35 +08:00

183 lines
5.1 KiB
Go

package pulsarms
import (
"context"
"fmt"
"log"
"testing"
"errors"
"github.com/golang/protobuf/proto"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb2"
)
// Type declarations used by the insert-message stream test in this file.
type (
	// MarshalType re-exports msgstream.MarshalType so the marshal and
	// unmarshal hooks below can name it without the package qualifier.
	MarshalType = msgstream.MarshalType

	// InsertTask is a test message type: a msgstream.InsertMsg extended with
	// a Tag string. Embedding InsertMsg promotes its fields and methods onto
	// InsertTask.
	InsertTask struct {
		Tag string
		msgstream.InsertMsg
	}
)
// Marshal encodes the InsertRequest carried by input with proto.Marshal and
// returns the resulting bytes.
//
// input must hold a non-nil *InsertTask. Any other value is reported as an
// error rather than triggering a type-assertion or nil-pointer panic. The
// receiver itself is not consulted; only input is serialized.
func (tt *InsertTask) Marshal(input msgstream.TsMsg) (MarshalType, error) {
	task, ok := input.(*InsertTask)
	if !ok || task == nil {
		return nil, fmt.Errorf("marshal: input is not a non-nil *InsertTask (got %T)", input)
	}

	mb, err := proto.Marshal(&task.InsertRequest)
	if err != nil {
		return nil, err
	}
	return mb, nil
}
// Unmarshal converts input to a byte slice via msgstream.ConvertToByteArray,
// decodes it as an internalpb2.InsertRequest, and returns an *InsertTask
// wrapping the decoded request. The task's Tag is set to the request's
// PartitionName.
//
// A conversion or decoding failure is returned as-is with a nil message. The
// receiver is not consulted.
func (tt *InsertTask) Unmarshal(input MarshalType) (msgstream.TsMsg, error) {
	raw, err := msgstream.ConvertToByteArray(input)
	if err != nil {
		return nil, err
	}

	req := internalpb2.InsertRequest{}
	if err = proto.Unmarshal(raw, &req); err != nil {
		return nil, err
	}

	task := &InsertTask{
		Tag:       req.PartitionName,
		InsertMsg: msgstream.InsertMsg{InsertRequest: req},
	}
	return task, nil
}
// newRepackFunc splits insert messages into single-row messages and groups
// them by hash key.
//
// tsMsgs[i] must be an insert message whose concrete type is *InsertTask, and
// hashKeys[i] must hold exactly one key per row of that message: its length
// has to equal the lengths of the request's Timestamps, RowIDs and RowData.
// For every row a new InsertTask is built that carries only that row; its
// Base.Timestamp is the row's own timestamp, while MsgID, SourceID,
// collection, partition, segment and channel are copied from the original
// request. The new tasks carry no Tag and no BaseMsg values.
//
// The returned map is keyed by hash key; each MsgPack lists the single-row
// tasks assigned to that key in input order. An error is returned, instead of
// panicking, when a message is not an insert, is not an *InsertTask, has no
// matching hashKeys entry, or when the per-row lengths disagree.
func newRepackFunc(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
	result := make(map[int32]*msgstream.MsgPack)
	for i, request := range tsMsgs {
		if request.Type() != commonpb.MsgType_Insert {
			return nil, errors.New("msg's must be Insert")
		}
		task, ok := request.(*InsertTask)
		if !ok {
			return nil, fmt.Errorf("insert msg must be *InsertTask, got %T", request)
		}
		// Guard the index: hashKeys comes from the caller and may be shorter
		// than tsMsgs.
		if i >= len(hashKeys) {
			return nil, fmt.Errorf("no hash keys for msg %d: got %d key lists for %d msgs", i, len(hashKeys), len(tsMsgs))
		}

		insertRequest := &task.InsertRequest
		keys := hashKeys[i]

		timestampLen := len(insertRequest.Timestamps)
		rowIDLen := len(insertRequest.RowIDs)
		rowDataLen := len(insertRequest.RowData)
		keysLen := len(keys)
		if keysLen != timestampLen || keysLen != rowIDLen || keysLen != rowDataLen {
			return nil, errors.New("the length of hashValue, timestamps, rowIDs, RowData are not equal")
		}

		for index, key := range keys {
			// One map lookup per row: create the pack on first sight of a key.
			pack, ok := result[key]
			if !ok {
				pack = &msgstream.MsgPack{}
				result[key] = pack
			}

			sliceRequest := internalpb2.InsertRequest{
				Base: &commonpb.MsgBase{
					MsgType:   commonpb.MsgType_Insert,
					MsgID:     insertRequest.Base.MsgID,
					Timestamp: insertRequest.Timestamps[index],
					SourceID:  insertRequest.Base.SourceID,
				},
				CollectionName: insertRequest.CollectionName,
				PartitionName:  insertRequest.PartitionName,
				SegmentID:      insertRequest.SegmentID,
				ChannelID:      insertRequest.ChannelID,
				Timestamps:     []msgstream.Timestamp{insertRequest.Timestamps[index]},
				RowIDs:         []int64{insertRequest.RowIDs[index]},
				RowData:        []*commonpb.Blob{insertRequest.RowData[index]},
			}

			insertMsg := &InsertTask{
				InsertMsg: msgstream.InsertMsg{InsertRequest: sliceRequest},
			}
			pack.Msgs = append(pack.Msgs, insertMsg)
		}
	}
	return result, nil
}
// getInsertTask builds a one-row insert message for use as test input.
//
// reqID becomes the request's Base.MsgID and hashValue is the only entry in
// the message's HashValues. Everything else is a fixed placeholder: begin and
// end timestamps of 0, collection "Collection", partition "Partition",
// segment 1, channel "1", and a single row with timestamp 1, row ID 1 and an
// empty blob. The returned value is an *InsertTask whose Tag is left empty.
func getInsertTask(reqID msgstream.UniqueID, hashValue uint32) msgstream.TsMsg {
	return &InsertTask{
		InsertMsg: msgstream.InsertMsg{
			BaseMsg: msgstream.BaseMsg{
				BeginTimestamp: 0,
				EndTimestamp:   0,
				HashValues:     []uint32{hashValue},
			},
			InsertRequest: internalpb2.InsertRequest{
				Base: &commonpb.MsgBase{
					MsgType:   commonpb.MsgType_Insert,
					MsgID:     reqID,
					Timestamp: 1,
					SourceID:  1,
				},
				CollectionName: "Collection",
				PartitionName:  "Partition",
				SegmentID:      1,
				ChannelID:      "1",
				Timestamps:     []msgstream.Timestamp{1},
				RowIDs:         []int64{1},
				RowData:        []*commonpb.Blob{{}},
			},
		},
	}
}
// TestStream_task_Insert sends two InsertTask messages through a producer
// stream that repacks them with newRepackFunc, and reads them back from a
// consumer stream whose dispatcher decodes insert messages with
// InsertTask.Unmarshal.
//
// Both streams use the channels "insert1" and "insert2" at the address stored
// under the "_PulsarAddress" parameter; the consumer subscribes as
// "subInsert". The test finishes once at least as many messages have been
// received as were produced. The receive loop has no timeout of its own.
//
// Setup and produce failures abort through log.Fatalf, the failure mechanism
// this test already used, so a broken setup is reported with its cause
// instead of surfacing as a nil-pointer panic on the first stream call.
func TestStream_task_Insert(t *testing.T) {
	ctx := context.Background()

	pulsarAddress, err := Params.Load("_PulsarAddress")
	if err != nil {
		log.Fatalf("load pulsar address error = %v", err)
	}

	producerChannels := []string{"insert1", "insert2"}
	consumerChannels := []string{"insert1", "insert2"}
	consumerSubName := "subInsert"

	msgPack := msgstream.MsgPack{}
	msgPack.Msgs = append(msgPack.Msgs, getInsertTask(1, 1))
	msgPack.Msgs = append(msgPack.Msgs, getInsertTask(3, 3))

	factory := msgstream.ProtoUDFactory{}

	inputStream, err := newPulsarMsgStream(ctx, pulsarAddress, 100, 100, factory.NewUnmarshalDispatcher())
	if err != nil {
		log.Fatalf("create input stream error = %v", err)
	}
	inputStream.AsProducer(producerChannels)
	inputStream.SetRepackFunc(newRepackFunc)
	inputStream.Start()

	dispatcher := factory.NewUnmarshalDispatcher()
	outputStream, err := newPulsarMsgStream(ctx, pulsarAddress, 100, 100, dispatcher)
	if err != nil {
		log.Fatalf("create output stream error = %v", err)
	}
	// Register the test decoder so consumed insert messages come back as
	// *InsertTask values.
	testTask := InsertTask{}
	dispatcher.AddMsgTemplate(commonpb.MsgType_Insert, testTask.Unmarshal)
	outputStream.AsConsumer(consumerChannels, consumerSubName)
	outputStream.Start()

	err = inputStream.Produce(ctx, &msgPack)
	if err != nil {
		log.Fatalf("produce error = %v", err)
	}

	receiveCount := 0
	for receiveCount < len(msgPack.Msgs) {
		result, _ := outputStream.Consume()
		for _, v := range result.Msgs {
			receiveCount++
			// v may also be a '*msgstream.TimeTickMsg', so check that the
			// type conversion succeeds before reading the tag.
			if task, ok := v.(*InsertTask); ok {
				fmt.Println("msg type: ", v.Type(), ", msg value: ", v, "msg tag: ", task.Tag)
			}
		}
	}

	inputStream.Close()
	outputStream.Close()
}