milvus/internal/proxy/repack_func.go
dragondriver c7a49c97cc Add repack function of InsertTask in Proxy
Signed-off-by: dragondriver <jiquan.long@zilliz.com>
2020-11-23 16:52:17 +08:00

98 lines
3.2 KiB
Go

package proxy
import (
"github.com/zilliztech/milvus-distributed/internal/allocator"
"github.com/zilliztech/milvus-distributed/internal/errors"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
)
// insertRepackFunc regroups the rows of the given insert messages by hash key.
//
// tsMsgs must contain only *msgstream.InsertMsg values, and hashKeys[i] must
// hold exactly one key for every row of tsMsgs[i]. The returned map associates
// each hash key with a MsgPack built from the rows carrying that key:
//
//   - together == true: all rows sharing a key are merged into one InsertMsg.
//     The request-level fields (ReqID, CollectionName, PartitionTag, ChannelID,
//     ProxyID) are those of the first row seen for the key, and a single
//     segment ID is requested for the whole merged batch once all rows have
//     been collected.
//   - together == false: every row becomes its own InsertMsg and a segment ID
//     is requested for each row individually.
//
// An error is returned when hashKeys has fewer entries than tsMsgs, when a
// message is not an insert message, when the per-row slices of a message
// disagree in length with its hash keys, or when the segment ID assigner
// fails. On error the returned map is nil.
func insertRepackFunc(tsMsgs []msgstream.TsMsg,
	hashKeys [][]int32,
	segIDAssigner *allocator.SegIDAssigner,
	together bool) (map[int32]*msgstream.MsgPack, error) {

	// Guard the hashKeys[i] access below; a short slice would otherwise panic.
	if len(hashKeys) < len(tsMsgs) {
		return nil, errors.New("the length of hashKeys is less than the number of messages")
	}

	result := make(map[int32]*msgstream.MsgPack)

	for i, request := range tsMsgs {
		if request.Type() != internalpb.MsgType_kInsert {
			return nil, errors.New("msg's must be Insert")
		}
		insertRequest, ok := request.(*msgstream.InsertMsg)
		if !ok {
			return nil, errors.New("msg's must be Insert")
		}

		keys := hashKeys[i]
		keysLen := len(keys)
		if keysLen != len(insertRequest.Timestamps) ||
			keysLen != len(insertRequest.RowIDs) ||
			keysLen != len(insertRequest.RowData) {
			return nil, errors.New("the length of hashValue, timestamps, rowIDs, RowData are not equal")
		}

		reqID := insertRequest.ReqID
		collectionName := insertRequest.CollectionName
		partitionTag := insertRequest.PartitionTag
		channelID := insertRequest.ChannelID
		proxyID := insertRequest.ProxyID

		for index, key := range keys {
			ts := insertRequest.Timestamps[index]
			rowID := insertRequest.RowIDs[index]
			row := insertRequest.RowData[index]

			pack, ok := result[key]
			if !ok {
				pack = &msgstream.MsgPack{}
				result[key] = pack
			}

			// In "together" mode every later row for a key is folded into the
			// first message of the pack instead of producing a new message.
			if together && len(pack.Msgs) > 0 {
				accMsg, ok := pack.Msgs[0].(*msgstream.InsertMsg)
				if !ok {
					return nil, errors.New("msg's must be Insert")
				}
				accMsg.Timestamps = append(accMsg.Timestamps, ts)
				accMsg.RowIDs = append(accMsg.RowIDs, rowID)
				accMsg.RowData = append(accMsg.RowData, row)
				continue
			}

			insertMsg := &msgstream.InsertMsg{
				InsertRequest: internalpb.InsertRequest{
					MsgType:        internalpb.MsgType_kInsert,
					ReqID:          reqID,
					CollectionName: collectionName,
					PartitionTag:   partitionTag,
					SegmentID:      0, // assigned below, or after merging when together
					ChannelID:      channelID,
					ProxyID:        proxyID,
					Timestamps:     []uint64{ts},
					RowIDs:         []int64{rowID},
					RowData:        []*commonpb.Blob{row},
				},
			}

			if !together {
				// Every row is its own message, so request a segment for one row.
				segID, err := segIDAssigner.GetSegmentID(collectionName, partitionTag, int32(channelID), 1)
				if err != nil {
					return nil, err
				}
				insertMsg.SegmentID = segID
			}

			pack.Msgs = append(pack.Msgs, insertMsg)
		}
	}

	if together {
		// Each pack holds exactly one merged message; now that the final row
		// count is known, request a segment ID sized for the whole batch.
		for _, pack := range result {
			insertMsg, ok := pack.Msgs[0].(*msgstream.InsertMsg)
			if !ok {
				return nil, errors.New("msg's must be Insert")
			}
			segID, err := segIDAssigner.GetSegmentID(
				insertMsg.CollectionName,
				insertMsg.PartitionTag,
				int32(insertMsg.ChannelID),
				uint32(len(insertMsg.RowIDs)),
			)
			if err != nil {
				return nil, err
			}
			insertMsg.SegmentID = segID
		}
	}

	return result, nil
}