mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-02 11:59:00 +08:00
c7a49c97cc
Signed-off-by: dragondriver <jiquan.long@zilliz.com>
98 lines
3.2 KiB
Go
98 lines
3.2 KiB
Go
package proxy
|
|
|
|
import (
|
|
"github.com/zilliztech/milvus-distributed/internal/allocator"
|
|
"github.com/zilliztech/milvus-distributed/internal/errors"
|
|
"github.com/zilliztech/milvus-distributed/internal/msgstream"
|
|
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
|
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
|
|
)
|
|
|
|
func insertRepackFunc(tsMsgs []msgstream.TsMsg,
|
|
hashKeys [][]int32,
|
|
segIDAssigner *allocator.SegIDAssigner,
|
|
together bool) (map[int32]*msgstream.MsgPack, error) {
|
|
|
|
result := make(map[int32]*msgstream.MsgPack)
|
|
|
|
for i, request := range tsMsgs {
|
|
if request.Type() != internalpb.MsgType_kInsert {
|
|
return nil, errors.New(string("msg's must be Insert"))
|
|
}
|
|
insertRequest, ok := request.(*msgstream.InsertMsg)
|
|
if !ok {
|
|
return nil, errors.New(string("msg's must be Insert"))
|
|
}
|
|
keys := hashKeys[i]
|
|
|
|
timestampLen := len(insertRequest.Timestamps)
|
|
rowIDLen := len(insertRequest.RowIDs)
|
|
rowDataLen := len(insertRequest.RowData)
|
|
keysLen := len(keys)
|
|
|
|
if keysLen != timestampLen || keysLen != rowIDLen || keysLen != rowDataLen {
|
|
return nil, errors.New(string("the length of hashValue, timestamps, rowIDs, RowData are not equal"))
|
|
}
|
|
|
|
reqID := insertRequest.ReqID
|
|
collectionName := insertRequest.CollectionName
|
|
partitionTag := insertRequest.PartitionTag
|
|
channelID := insertRequest.ChannelID
|
|
proxyID := insertRequest.ProxyID
|
|
for index, key := range keys {
|
|
ts := insertRequest.Timestamps[index]
|
|
rowID := insertRequest.RowIDs[index]
|
|
row := insertRequest.RowData[index]
|
|
_, ok := result[key]
|
|
if !ok {
|
|
msgPack := msgstream.MsgPack{}
|
|
result[key] = &msgPack
|
|
}
|
|
sliceRequest := internalpb.InsertRequest{
|
|
MsgType: internalpb.MsgType_kInsert,
|
|
ReqID: reqID,
|
|
CollectionName: collectionName,
|
|
PartitionTag: partitionTag,
|
|
SegmentID: 0, // will be assigned later if together
|
|
ChannelID: channelID,
|
|
ProxyID: proxyID,
|
|
Timestamps: []uint64{ts},
|
|
RowIDs: []int64{rowID},
|
|
RowData: []*commonpb.Blob{row},
|
|
}
|
|
insertMsg := &msgstream.InsertMsg{
|
|
InsertRequest: sliceRequest,
|
|
}
|
|
if together { // all rows with same hash value are accumulated to only one message
|
|
if len(result[key].Msgs) <= 0 {
|
|
result[key].Msgs = append(result[key].Msgs, insertMsg)
|
|
} else {
|
|
accMsgs, _ := result[key].Msgs[0].(*msgstream.InsertMsg)
|
|
accMsgs.Timestamps = append(accMsgs.Timestamps, ts)
|
|
accMsgs.RowIDs = append(accMsgs.RowIDs, rowID)
|
|
accMsgs.RowData = append(accMsgs.RowData, row)
|
|
}
|
|
} else { // every row is a message
|
|
segID, _ := segIDAssigner.GetSegmentID(collectionName, partitionTag, int32(channelID), 1)
|
|
insertMsg.SegmentID = segID
|
|
result[key].Msgs = append(result[key].Msgs, insertMsg)
|
|
}
|
|
}
|
|
}
|
|
|
|
if together {
|
|
for key := range result {
|
|
insertMsg, _ := result[key].Msgs[0].(*msgstream.InsertMsg)
|
|
rowNums := len(insertMsg.RowIDs)
|
|
collectionName := insertMsg.CollectionName
|
|
partitionTag := insertMsg.PartitionTag
|
|
channelID := insertMsg.ChannelID
|
|
segID, _ := segIDAssigner.GetSegmentID(collectionName, partitionTag, int32(channelID), uint32(rowNums))
|
|
insertMsg.SegmentID = segID
|
|
result[key].Msgs[0] = insertMsg
|
|
}
|
|
}
|
|
|
|
return result, nil
|
|
}
|