Load growing segment in query node (#11664)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
Committed by bigsheeper via GitHub on 2021-11-12 18:27:10 +08:00
parent 93ab5b0a8f
commit 93149c5ad9
17 changed files with 824 additions and 130 deletions


@ -783,6 +783,12 @@ func (s *Server) loadCollectionFromRootCoord(ctx context.Context, collectionID i
// GetVChanPositions gets the latest vchannel positions for the provided dml channel names
func (s *Server) GetVChanPositions(channel string, collectionID UniqueID, seekFromStartPosition bool) *datapb.VchannelInfo {
segments := s.meta.GetSegmentsByChannel(channel)
log.Debug("GetSegmentsByChannel",
zap.Any("collectionID", collectionID),
zap.Any("channel", channel),
zap.Any("seekFromStartPosition", seekFromStartPosition),
zap.Any("numOfSegments", len(segments)),
)
flushed := make([]*datapb.SegmentInfo, 0)
unflushed := make([]*datapb.SegmentInfo, 0)
var seekPosition *internalpb.MsgPosition
@ -799,7 +805,7 @@ func (s *Server) GetVChanPositions(channel string, collectionID UniqueID, seekFr
continue
}
unflushed = append(unflushed, trimSegmentInfo(s.SegmentInfo))
unflushed = append(unflushed, s.SegmentInfo)
segmentPosition := s.DmlPosition
if seekFromStartPosition {


@ -1281,7 +1281,7 @@ func TestGetRecoveryInfo(t *testing.T) {
assert.EqualValues(t, 0, len(resp.GetBinlogs()))
assert.EqualValues(t, 1, len(resp.GetChannels()))
assert.NotNil(t, resp.GetChannels()[0].SeekPosition)
assert.EqualValues(t, 0, resp.GetChannels()[0].GetSeekPosition().GetTimestamp())
assert.NotEqual(t, 0, resp.GetChannels()[0].GetSeekPosition().GetTimestamp())
})
t.Run("test get binlogs", func(t *testing.T) {


@ -520,8 +520,12 @@ func (s *Server) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInf
channels := dresp.GetVirtualChannelNames()
channelInfos := make([]*datapb.VchannelInfo, 0, len(channels))
for _, c := range channels {
channelInfo := s.GetVChanPositions(c, collectionID, true)
channelInfo := s.GetVChanPositions(c, collectionID, false)
channelInfos = append(channelInfos, channelInfo)
log.Debug("datacoord append channelInfo in GetRecoveryInfo",
zap.Any("collectionID", collectionID),
zap.Any("channelInfo", channelInfo),
)
}
resp.Binlogs = binlogs


@ -110,7 +110,7 @@ func (iNode *insertNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
iData.insertIDs[task.SegmentID] = append(iData.insertIDs[task.SegmentID], task.RowIDs...)
iData.insertTimestamps[task.SegmentID] = append(iData.insertTimestamps[task.SegmentID], task.Timestamps...)
iData.insertRecords[task.SegmentID] = append(iData.insertRecords[task.SegmentID], task.RowData...)
iData.insertPKs[task.SegmentID] = iNode.getPrimaryKeys(task)
iData.insertPKs[task.SegmentID] = getPrimaryKeys(task, iNode.streamingReplica)
}
// 2. do preInsert
@ -305,14 +305,16 @@ func (iNode *insertNode) delete(deleteData *deleteData, segmentID UniqueID, wg *
log.Debug("Do delete done", zap.Int("len", len(deleteData.deleteIDs[segmentID])), zap.Int64("segmentID", segmentID))
}
func (iNode *insertNode) getPrimaryKeys(msg *msgstream.InsertMsg) []int64 {
// TODO: move this function to a proper file
// TODO: why not return error?
func getPrimaryKeys(msg *msgstream.InsertMsg, streamingReplica ReplicaInterface) []int64 {
if len(msg.RowIDs) != len(msg.Timestamps) || len(msg.RowIDs) != len(msg.RowData) {
log.Warn("misaligned messages detected")
return nil
}
collectionID := msg.GetCollectionID()
collection, err := iNode.streamingReplica.getCollectionByID(collectionID)
collection, err := streamingReplica.getCollectionByID(collectionID)
if err != nil {
log.Warn(err.Error())
return nil


@ -25,11 +25,9 @@ import (
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/msgstream"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/types"
)
const (
@ -41,8 +39,6 @@ type historical struct {
ctx context.Context
replica ReplicaInterface
loader *segmentLoader
statsService *statsService
tSafeReplica TSafeReplicaInterface
mu sync.Mutex // guards globalSealedSegments
@ -54,19 +50,12 @@ type historical struct {
// newHistorical returns a new historical
func newHistorical(ctx context.Context,
replica ReplicaInterface,
rootCoord types.RootCoord,
indexCoord types.IndexCoord,
factory msgstream.Factory,
etcdKV *etcdkv.EtcdKV,
tSafeReplica TSafeReplicaInterface) *historical {
loader := newSegmentLoader(ctx, rootCoord, indexCoord, replica, etcdKV)
ss := newStatsService(ctx, replica, loader.indexLoader.fieldStatsChan, factory)
return &historical{
ctx: ctx,
replica: replica,
loader: loader,
statsService: ss,
globalSealedSegments: make(map[UniqueID]*querypb.SegmentInfo),
etcdKV: etcdKV,
tSafeReplica: tSafeReplica,
@ -74,13 +63,10 @@ func newHistorical(ctx context.Context,
}
func (h *historical) start() {
go h.statsService.start()
go h.watchGlobalSegmentMeta()
}
func (h *historical) close() {
h.statsService.close()
// free collectionReplica
h.replica.freeAll()
}


@ -25,29 +25,31 @@ func TestIndexLoader_setIndexInfo(t *testing.T) {
defer cancel()
t.Run("test setIndexInfo", func(t *testing.T) {
tSafe := newTSafeReplica()
historical, err := genSimpleHistorical(ctx, tSafe)
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
loader := node.loader
assert.NotNil(t, loader)
segment, err := genSimpleSealedSegment()
assert.NoError(t, err)
historical.loader.indexLoader.rootCoord = newMockRootCoord()
historical.loader.indexLoader.indexCoord = newMockIndexCoord()
loader.indexLoader.rootCoord = newMockRootCoord()
loader.indexLoader.indexCoord = newMockIndexCoord()
err = historical.loader.indexLoader.setIndexInfo(defaultCollectionID, segment, rowIDFieldID)
err = loader.indexLoader.setIndexInfo(defaultCollectionID, segment, rowIDFieldID)
assert.NoError(t, err)
})
t.Run("test nil root and index", func(t *testing.T) {
tSafe := newTSafeReplica()
historical, err := genSimpleHistorical(ctx, tSafe)
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
loader := node.loader
assert.NotNil(t, loader)
segment, err := genSimpleSealedSegment()
assert.NoError(t, err)
err = historical.loader.indexLoader.setIndexInfo(defaultCollectionID, segment, rowIDFieldID)
err = loader.indexLoader.setIndexInfo(defaultCollectionID, segment, rowIDFieldID)
assert.NoError(t, err)
})
}
@ -57,23 +59,25 @@ func TestIndexLoader_getIndexBinlog(t *testing.T) {
defer cancel()
t.Run("test getIndexBinlog", func(t *testing.T) {
tSafe := newTSafeReplica()
historical, err := genSimpleHistorical(ctx, tSafe)
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
loader := node.loader
assert.NotNil(t, loader)
paths, err := generateIndex(defaultSegmentID)
assert.NoError(t, err)
_, _, _, err = historical.loader.indexLoader.getIndexBinlog(paths)
_, _, _, err = loader.indexLoader.getIndexBinlog(paths)
assert.NoError(t, err)
})
t.Run("test invalid path", func(t *testing.T) {
tSafe := newTSafeReplica()
historical, err := genSimpleHistorical(ctx, tSafe)
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
loader := node.loader
assert.NotNil(t, loader)
_, _, _, err = historical.loader.indexLoader.getIndexBinlog([]string{""})
_, _, _, err = loader.indexLoader.getIndexBinlog([]string{""})
assert.Error(t, err)
})
}
@ -82,9 +86,10 @@ func TestIndexLoader_printIndexParams(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tSafe := newTSafeReplica()
historical, err := genSimpleHistorical(ctx, tSafe)
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
loader := node.loader
assert.NotNil(t, loader)
indexKV := []*commonpb.KeyValuePair{
{
@ -92,7 +97,7 @@ func TestIndexLoader_printIndexParams(t *testing.T) {
Value: "test-value-0",
},
}
historical.loader.indexLoader.printIndexParams(indexKV)
loader.indexLoader.printIndexParams(indexKV)
}
func TestIndexLoader_loadIndex(t *testing.T) {
@ -100,38 +105,40 @@ func TestIndexLoader_loadIndex(t *testing.T) {
defer cancel()
t.Run("test loadIndex", func(t *testing.T) {
tSafe := newTSafeReplica()
historical, err := genSimpleHistorical(ctx, tSafe)
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
loader := node.loader
assert.NotNil(t, loader)
segment, err := genSimpleSealedSegment()
assert.NoError(t, err)
historical.loader.indexLoader.rootCoord = newMockRootCoord()
historical.loader.indexLoader.indexCoord = newMockIndexCoord()
loader.indexLoader.rootCoord = newMockRootCoord()
loader.indexLoader.indexCoord = newMockIndexCoord()
err = historical.loader.indexLoader.setIndexInfo(defaultCollectionID, segment, simpleVecField.id)
err = loader.indexLoader.setIndexInfo(defaultCollectionID, segment, simpleVecField.id)
assert.NoError(t, err)
err = historical.loader.indexLoader.loadIndex(segment, simpleVecField.id)
err = loader.indexLoader.loadIndex(segment, simpleVecField.id)
assert.NoError(t, err)
})
t.Run("test set indexinfo with empty indexFilePath", func(t *testing.T) {
tSafe := newTSafeReplica()
historical, err := genSimpleHistorical(ctx, tSafe)
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
loader := node.loader
assert.NotNil(t, loader)
segment, err := genSimpleSealedSegment()
assert.NoError(t, err)
historical.loader.indexLoader.rootCoord = newMockRootCoord()
loader.indexLoader.rootCoord = newMockRootCoord()
ic := newMockIndexCoord()
ic.idxFileInfo.IndexFilePaths = []string{}
historical.loader.indexLoader.indexCoord = ic
loader.indexLoader.indexCoord = ic
err = historical.loader.indexLoader.setIndexInfo(defaultCollectionID, segment, simpleVecField.id)
err = loader.indexLoader.setIndexInfo(defaultCollectionID, segment, simpleVecField.id)
assert.Error(t, err)
})
@ -151,22 +158,23 @@ func TestIndexLoader_loadIndex(t *testing.T) {
//})
t.Run("test checkIndexReady failed", func(t *testing.T) {
tSafe := newTSafeReplica()
historical, err := genSimpleHistorical(ctx, tSafe)
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
loader := node.loader
assert.NotNil(t, loader)
segment, err := genSimpleSealedSegment()
assert.NoError(t, err)
historical.loader.indexLoader.rootCoord = newMockRootCoord()
historical.loader.indexLoader.indexCoord = newMockIndexCoord()
loader.indexLoader.rootCoord = newMockRootCoord()
loader.indexLoader.indexCoord = newMockIndexCoord()
err = historical.loader.indexLoader.setIndexInfo(defaultCollectionID, segment, rowIDFieldID)
err = loader.indexLoader.setIndexInfo(defaultCollectionID, segment, rowIDFieldID)
assert.NoError(t, err)
segment.indexInfos[rowIDFieldID].setReadyLoad(false)
err = historical.loader.indexLoader.loadIndex(segment, rowIDFieldID)
err = loader.indexLoader.loadIndex(segment, rowIDFieldID)
assert.Error(t, err)
})
}


@ -882,11 +882,15 @@ func genSimpleReplica() (ReplicaInterface, error) {
return r, err
}
func genSimpleHistorical(ctx context.Context, tSafeReplica TSafeReplicaInterface) (*historical, error) {
fac, err := genFactory()
func genSimpleSegmentLoader(ctx context.Context, historicalReplica ReplicaInterface, streamingReplica ReplicaInterface) (*segmentLoader, error) {
kv, err := genEtcdKV()
if err != nil {
return nil, err
}
return newSegmentLoader(ctx, newMockRootCoord(), newMockIndexCoord(), historicalReplica, streamingReplica, kv), nil
}
func genSimpleHistorical(ctx context.Context, tSafeReplica TSafeReplicaInterface) (*historical, error) {
kv, err := genEtcdKV()
if err != nil {
return nil, err
@ -895,7 +899,7 @@ func genSimpleHistorical(ctx context.Context, tSafeReplica TSafeReplicaInterface
if err != nil {
return nil, err
}
h := newHistorical(ctx, replica, newMockRootCoord(), newMockIndexCoord(), fac, kv, tSafeReplica)
h := newHistorical(ctx, replica, kv, tSafeReplica)
r, err := genSimpleReplica()
if err != nil {
return nil, err
@ -909,7 +913,6 @@ func genSimpleHistorical(ctx context.Context, tSafeReplica TSafeReplicaInterface
return nil, err
}
h.replica = r
h.loader.historicalReplica = r
col, err := h.replica.getCollectionByID(defaultCollectionID)
if err != nil {
return nil, err
@ -1326,6 +1329,12 @@ func genSimpleQueryNode(ctx context.Context) (*QueryNode, error) {
node.streaming = streaming
node.historical = historical
loader, err := genSimpleSegmentLoader(node.queryNodeLoopCtx, historical.replica, streaming.replica)
if err != nil {
return nil, err
}
node.loader = loader
// start task scheduler
go node.scheduler.Start()


@ -136,7 +136,7 @@ func TestQueryCollection_withoutVChannel(t *testing.T) {
historicalReplica := newCollectionReplica(etcdKV)
tsReplica := newTSafeReplica()
streamingReplica := newCollectionReplica(etcdKV)
historical := newHistorical(context.Background(), historicalReplica, nil, nil, factory, etcdKV, tsReplica)
historical := newHistorical(context.Background(), historicalReplica, etcdKV, tsReplica)
//add a segment to historical data
err = historical.replica.addCollection(0, schema)


@ -87,6 +87,10 @@ type QueryNode struct {
// internal services
queryService *queryService
statsService *statsService
// segment loader
loader *segmentLoader
// clients
rootCoord types.RootCoord
@ -192,9 +196,6 @@ func (node *QueryNode) Init() error {
node.historical = newHistorical(node.queryNodeLoopCtx,
historicalReplica,
node.rootCoord,
node.indexCoord,
node.msFactory,
node.etcdKV,
node.tSafeReplica,
)
@ -205,17 +206,33 @@ func (node *QueryNode) Init() error {
node.tSafeReplica,
)
node.loader = newSegmentLoader(node.queryNodeLoopCtx,
node.rootCoord,
node.indexCoord,
node.historical.replica,
node.streaming.replica,
node.etcdKV)
node.statsService = newStatsService(node.queryNodeLoopCtx, node.historical.replica, node.loader.indexLoader.fieldStatsChan, node.msFactory)
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, streamingReplica, historicalReplica, node.tSafeReplica, node.msFactory)
node.InitSegcore()
if node.rootCoord == nil {
log.Error("null root coordinator detected")
initError = errors.New("null root coordinator detected when queryNode init")
return
}
if node.indexCoord == nil {
log.Error("null index coordinator detected")
initError = errors.New("null index coordinator detected when queryNode init")
return
}
log.Debug("query node init successfully",
zap.Any("queryNodeID", Params.QueryNodeID),
zap.Any("IP", Params.QueryNodeIP),
zap.Any("Port", Params.QueryNodePort),
)
})
return initError
@ -246,11 +263,17 @@ func (node *QueryNode) Start() error {
// start services
go node.historical.start()
go node.watchChangeInfo()
go node.statsService.start()
Params.CreatedTime = time.Now()
Params.UpdatedTime = time.Now()
node.UpdateStateCode(internalpb.StateCode_Healthy)
log.Debug("query node start successfully",
zap.Any("queryNodeID", Params.QueryNodeID),
zap.Any("IP", Params.QueryNodeIP),
zap.Any("Port", Params.QueryNodePort),
)
return nil
}
@ -272,6 +295,9 @@ func (node *QueryNode) Stop() error {
if node.queryService != nil {
node.queryService.close()
}
if node.statsService != nil {
node.statsService.close()
}
return nil
}


@ -195,9 +195,11 @@ func newQueryNodeMock() *QueryNode {
tsReplica := newTSafeReplica()
streamingReplica := newCollectionReplica(etcdKV)
historicalReplica := newCollectionReplica(etcdKV)
svr.historical = newHistorical(svr.queryNodeLoopCtx, historicalReplica, nil, nil, svr.msFactory, etcdKV, tsReplica)
svr.historical = newHistorical(svr.queryNodeLoopCtx, historicalReplica, etcdKV, tsReplica)
svr.streaming = newStreaming(ctx, streamingReplica, msFactory, etcdKV, tsReplica)
svr.dataSyncService = newDataSyncService(ctx, svr.streaming.replica, svr.historical.replica, tsReplica, msFactory)
svr.statsService = newStatsService(ctx, svr.historical.replica, nil, msFactory)
svr.loader = newSegmentLoader(ctx, nil, nil, svr.historical.replica, svr.streaming.replica, etcdKV)
svr.etcdKV = etcdKV
return svr
@ -275,7 +277,7 @@ func TestQueryNode_init(t *testing.T) {
assert.NoError(t, err)
err = node.Init()
assert.NoError(t, err)
assert.Error(t, err)
}
func genSimpleQueryNodeToTestWatchChangeInfo(ctx context.Context) (*QueryNode, error) {


@ -25,7 +25,10 @@ import (
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
minioKV "github.com/milvus-io/milvus/internal/kv/minio"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/msgstream"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/types"
@ -39,6 +42,7 @@ const (
// segmentLoader is only responsible for loading the field data from binlog
type segmentLoader struct {
historicalReplica ReplicaInterface
streamingReplica ReplicaInterface
dataCoord types.DataCoord
@ -48,12 +52,18 @@ type segmentLoader struct {
indexLoader *indexLoader
}
func (loader *segmentLoader) loadSegment(req *querypb.LoadSegmentsRequest) error {
func (loader *segmentLoader) loadSegment(req *querypb.LoadSegmentsRequest, segmentType segmentType) error {
// no segment needs to load, return
if len(req.Infos) == 0 {
return nil
}
log.Debug("segmentLoader start loading...",
zap.Any("collectionID", req.CollectionID),
zap.Any("numOfSegments", len(req.Infos)),
zap.Any("loadType", segmentType),
)
newSegments := make(map[UniqueID]*Segment)
segmentGC := func() {
for _, s := range newSegments {
@ -76,7 +86,7 @@ func (loader *segmentLoader) loadSegment(req *querypb.LoadSegmentsRequest) error
segmentGC()
return err
}
segment := newSegment(collection, segmentID, partitionID, collectionID, "", segmentTypeSealed, true)
segment := newSegment(collection, segmentID, partitionID, collectionID, "", segmentType, true)
newSegments[segmentID] = segment
fieldBinlog, indexedFieldID, err := loader.getFieldAndIndexInfo(segment, info)
if err != nil {
@ -110,7 +120,8 @@ func (loader *segmentLoader) loadSegment(req *querypb.LoadSegmentsRequest) error
err = loader.loadSegmentInternal(newSegments[segmentID],
segmentFieldBinLogs[segmentID],
segmentIndexedFieldIDs[segmentID],
info)
info,
segmentType)
if err != nil {
segmentGC()
return err
@ -118,6 +129,16 @@ func (loader *segmentLoader) loadSegment(req *querypb.LoadSegmentsRequest) error
}
// set segments
switch segmentType {
case segmentTypeGrowing:
for _, s := range newSegments {
err := loader.streamingReplica.setSegment(s)
if err != nil {
segmentGC()
return err
}
}
case segmentTypeSealed:
for _, s := range newSegments {
err := loader.historicalReplica.setSegment(s)
if err != nil {
@ -125,15 +146,27 @@ func (loader *segmentLoader) loadSegment(req *querypb.LoadSegmentsRequest) error
return err
}
}
default:
err := errors.New(fmt.Sprintln("illegal segment type when load segment, collectionID = ", req.CollectionID))
segmentGC()
return err
}
return nil
}
func (loader *segmentLoader) loadSegmentInternal(segment *Segment,
fieldBinLogs []*datapb.FieldBinlog,
indexFieldIDs []FieldID,
segmentLoadInfo *querypb.SegmentLoadInfo) error {
log.Debug("loading insert...")
err := loader.loadSegmentFieldsData(segment, fieldBinLogs)
segmentLoadInfo *querypb.SegmentLoadInfo,
segmentType segmentType) error {
log.Debug("loading insert...",
zap.Any("collectionID", segment.collectionID),
zap.Any("segmentID", segment.ID()),
zap.Any("segmentType", segmentType),
zap.Any("fieldBinLogs", fieldBinLogs),
zap.Any("indexFieldIDs", indexFieldIDs),
)
err := loader.loadSegmentFieldsData(segment, fieldBinLogs, segmentType)
if err != nil {
return err
}
@ -190,7 +223,7 @@ func (loader *segmentLoader) filterFieldBinlogs(fieldBinlogs []*datapb.FieldBinl
return result
}
func (loader *segmentLoader) loadSegmentFieldsData(segment *Segment, fieldBinlogs []*datapb.FieldBinlog) error {
func (loader *segmentLoader) loadSegmentFieldsData(segment *Segment, fieldBinlogs []*datapb.FieldBinlog, segmentType segmentType) error {
iCodec := storage.InsertCodec{}
defer func() {
err := iCodec.Close()
@ -226,6 +259,79 @@ func (loader *segmentLoader) loadSegmentFieldsData(segment *Segment, fieldBinlog
return err
}
for i := range insertData.Infos {
log.Debug("segmentLoader deserialize fields",
zap.Any("collectionID", segment.collectionID),
zap.Any("segmentID", segment.ID()),
zap.Any("numRows", insertData.Infos[i].Length),
)
}
switch segmentType {
case segmentTypeGrowing:
timestamps, ids, rowData, err := storage.TransferColumnBasedInsertDataToRowBased(insertData)
if err != nil {
return err
}
return loader.loadGrowingSegments(segment, ids, timestamps, rowData)
case segmentTypeSealed:
return loader.loadSealedSegments(segment, insertData)
default:
err := errors.New(fmt.Sprintln("illegal segment type when load segment, collectionID = ", segment.collectionID))
return err
}
}
func (loader *segmentLoader) loadGrowingSegments(segment *Segment,
ids []UniqueID,
timestamps []Timestamp,
records []*commonpb.Blob) error {
if len(ids) != len(timestamps) || len(timestamps) != len(records) {
return errors.New(fmt.Sprintln("illegal insert data when load segment, collectionID = ", segment.collectionID))
}
log.Debug("start load growing segments...",
zap.Any("collectionID", segment.collectionID),
zap.Any("segmentID", segment.ID()),
zap.Any("numRows", len(ids)),
)
// 1. do preInsert
var numOfRecords = len(ids)
offset, err := segment.segmentPreInsert(numOfRecords)
if err != nil {
return err
}
log.Debug("insertNode operator", zap.Int("insert size", numOfRecords), zap.Int64("insert offset", offset), zap.Int64("segment id", segment.ID()))
// 2. update bloom filter
tmpInsertMsg := &msgstream.InsertMsg{
InsertRequest: internalpb.InsertRequest{
CollectionID: segment.collectionID,
Timestamps: timestamps,
RowIDs: ids,
RowData: records,
},
}
pks := getPrimaryKeys(tmpInsertMsg, loader.streamingReplica)
segment.updateBloomFilter(pks)
// 3. do insert
err = segment.segmentInsert(offset, &ids, &timestamps, &records)
if err != nil {
return err
}
log.Debug("Do insert done in segment loader", zap.Int("len", numOfRecords), zap.Int64("segmentID", segment.ID()))
return nil
}
func (loader *segmentLoader) loadSealedSegments(segment *Segment, insertData *storage.InsertData) error {
log.Debug("start load sealed segments...",
zap.Any("collectionID", segment.collectionID),
zap.Any("segmentID", segment.ID()),
zap.Any("numFields", len(insertData.Data)),
)
for fieldID, value := range insertData.Data {
var numRows []int64
var data interface{}
@ -270,13 +376,12 @@ func (loader *segmentLoader) loadSegmentFieldsData(segment *Segment, fieldBinlog
for _, numRow := range numRows {
totalNumRows += numRow
}
err = segment.segmentLoadFieldData(fieldID, int(totalNumRows), data)
err := segment.segmentLoadFieldData(fieldID, int(totalNumRows), data)
if err != nil {
// TODO: return or continue?
return err
}
}
return nil
}
@ -460,7 +565,12 @@ func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentSize
return nil
}
func newSegmentLoader(ctx context.Context, rootCoord types.RootCoord, indexCoord types.IndexCoord, replica ReplicaInterface, etcdKV *etcdkv.EtcdKV) *segmentLoader {
func newSegmentLoader(ctx context.Context,
rootCoord types.RootCoord,
indexCoord types.IndexCoord,
historicalReplica ReplicaInterface,
streamingReplica ReplicaInterface,
etcdKV *etcdkv.EtcdKV) *segmentLoader {
option := &minioKV.Option{
Address: Params.MinioEndPoint,
AccessKeyID: Params.MinioAccessKeyID,
@ -475,9 +585,10 @@ func newSegmentLoader(ctx context.Context, rootCoord types.RootCoord, indexCoord
panic(err)
}
iLoader := newIndexLoader(ctx, rootCoord, indexCoord, replica)
iLoader := newIndexLoader(ctx, rootCoord, indexCoord, historicalReplica)
return &segmentLoader{
historicalReplica: replica,
historicalReplica: historicalReplica,
streamingReplica: streamingReplica,
minioKV: client,
etcdKV: etcdKV,


@ -28,22 +28,19 @@ func TestSegmentLoader_loadSegment(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
kv, err := genEtcdKV()
assert.NoError(t, err)
schema := genSimpleInsertDataSchema()
fieldBinlog, err := saveSimpleBinLog(ctx)
assert.NoError(t, err)
t.Run("test load segment", func(t *testing.T) {
tSafe := newTSafeReplica()
historical, err := genSimpleHistorical(ctx, tSafe)
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
err = historical.replica.removeSegment(defaultSegmentID)
err = node.historical.replica.removeSegment(defaultSegmentID)
assert.NoError(t, err)
loader := newSegmentLoader(ctx, nil, nil, historical.replica, kv)
loader := node.loader
assert.NotNil(t, loader)
req := &querypb.LoadSegmentsRequest{
@ -64,18 +61,18 @@ func TestSegmentLoader_loadSegment(t *testing.T) {
},
}
err = loader.loadSegment(req)
err = loader.loadSegment(req, segmentTypeSealed)
assert.NoError(t, err)
})
t.Run("test set segment error", func(t *testing.T) {
tSafe := newTSafeReplica()
historical, err := genSimpleHistorical(ctx, tSafe)
t.Run("test set segment error due to without partition", func(t *testing.T) {
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
err = historical.replica.removePartition(defaultPartitionID)
err = node.historical.replica.removePartition(defaultPartitionID)
assert.NoError(t, err)
loader := newSegmentLoader(ctx, nil, nil, historical.replica, kv)
loader := node.loader
assert.NotNil(t, loader)
req := &querypb.LoadSegmentsRequest{
@ -96,7 +93,7 @@ func TestSegmentLoader_loadSegment(t *testing.T) {
},
}
err = loader.loadSegment(req)
err = loader.loadSegment(req, segmentTypeSealed)
assert.Error(t, err)
})
}
@ -106,9 +103,10 @@ func TestSegmentLoader_loadSegmentFieldsData(t *testing.T) {
defer cancel()
runLoadSegmentFieldData := func(dataType schemapb.DataType) {
tSafe := newTSafeReplica()
historical, err := genSimpleHistorical(ctx, tSafe)
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
loader := node.loader
assert.NotNil(t, loader)
fieldUID := genConstantField(uidField)
fieldTimestamp := genConstantField(timestampField)
@ -149,7 +147,7 @@ func TestSegmentLoader_loadSegmentFieldsData(t *testing.T) {
schema.Fields = append(schema.Fields, constField)
err = historical.replica.removeSegment(defaultSegmentID)
err = loader.historicalReplica.removeSegment(defaultSegmentID)
assert.NoError(t, err)
col := newCollection(defaultCollectionID, schema)
@ -168,7 +166,7 @@ func TestSegmentLoader_loadSegmentFieldsData(t *testing.T) {
binlog, err := saveBinLog(ctx, defaultCollectionID, defaultPartitionID, defaultSegmentID, defaultMsgLength, schema)
assert.NoError(t, err)
err = historical.loader.loadSegmentFieldsData(segment, binlog)
err = loader.loadSegmentFieldsData(segment, binlog, segmentTypeSealed)
assert.NoError(t, err)
}
@ -188,11 +186,12 @@ func TestSegmentLoader_invalid(t *testing.T) {
defer cancel()
t.Run("test no collection", func(t *testing.T) {
tSafe := newTSafeReplica()
historical, err := genSimpleHistorical(ctx, tSafe)
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
loader := node.loader
assert.NotNil(t, loader)
err = historical.replica.removeCollection(defaultCollectionID)
err = node.historical.replica.removeCollection(defaultCollectionID)
assert.NoError(t, err)
req := &querypb.LoadSegmentsRequest{
@ -211,7 +210,7 @@ func TestSegmentLoader_invalid(t *testing.T) {
},
}
err = historical.loader.loadSegment(req)
err = loader.loadSegment(req, segmentTypeSealed)
assert.Error(t, err)
})
@ -251,11 +250,12 @@ func TestSegmentLoader_invalid(t *testing.T) {
//})
t.Run("test no vec field 2", func(t *testing.T) {
tSafe := newTSafeReplica()
historical, err := genSimpleHistorical(ctx, tSafe)
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
loader := node.loader
assert.NotNil(t, loader)
err = historical.replica.removeCollection(defaultCollectionID)
err = node.historical.replica.removeCollection(defaultCollectionID)
assert.NoError(t, err)
schema := &schemapb.CollectionSchema{
@ -268,7 +268,7 @@ func TestSegmentLoader_invalid(t *testing.T) {
}),
},
}
err = historical.loader.historicalReplica.addCollection(defaultCollectionID, schema)
err = loader.historicalReplica.addCollection(defaultCollectionID, schema)
assert.NoError(t, err)
req := &querypb.LoadSegmentsRequest{
@ -287,7 +287,7 @@ func TestSegmentLoader_invalid(t *testing.T) {
},
},
}
err = historical.loader.loadSegment(req)
err = loader.loadSegment(req, segmentTypeSealed)
assert.Error(t, err)
})
}
@ -296,11 +296,12 @@ func TestSegmentLoader_checkSegmentSize(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tSafe := newTSafeReplica()
historical, err := genSimpleHistorical(ctx, tSafe)
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
loader := node.loader
assert.NotNil(t, loader)
err = historical.loader.checkSegmentSize(defaultSegmentID, map[UniqueID]int64{defaultSegmentID: 1024})
err = loader.checkSegmentSize(defaultSegmentID, map[UniqueID]int64{defaultSegmentID: 1024})
assert.NoError(t, err)
//totalMem, err := getTotalMemory()
@ -313,11 +314,12 @@ func TestSegmentLoader_estimateSegmentSize(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tSafe := newTSafeReplica()
historical, err := genSimpleHistorical(ctx, tSafe)
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
loader := node.loader
assert.NotNil(t, loader)
seg, err := historical.replica.getSegmentByID(defaultSegmentID)
seg, err := node.historical.replica.getSegmentByID(defaultSegmentID)
assert.NoError(t, err)
binlog := []*datapb.FieldBinlog{
@ -327,13 +329,13 @@ func TestSegmentLoader_estimateSegmentSize(t *testing.T) {
},
}
_, err = historical.loader.estimateSegmentSize(seg, binlog, nil)
_, err = loader.estimateSegmentSize(seg, binlog, nil)
assert.Error(t, err)
binlog, err = saveSimpleBinLog(ctx)
assert.NoError(t, err)
_, err = historical.loader.estimateSegmentSize(seg, binlog, nil)
_, err = loader.estimateSegmentSize(seg, binlog, nil)
assert.NoError(t, err)
indexPath, err := generateIndex(defaultSegmentID)
@ -345,12 +347,133 @@ func TestSegmentLoader_estimateSegmentSize(t *testing.T) {
err = seg.setIndexPaths(simpleVecField.id, indexPath)
assert.NoError(t, err)
_, err = historical.loader.estimateSegmentSize(seg, binlog, []FieldID{simpleVecField.id})
_, err = loader.estimateSegmentSize(seg, binlog, []FieldID{simpleVecField.id})
assert.NoError(t, err)
err = seg.setIndexPaths(simpleVecField.id, []string{"&*^*(^*(&*%^&*^(&"})
assert.NoError(t, err)
_, err = historical.loader.estimateSegmentSize(seg, binlog, []FieldID{simpleVecField.id})
_, err = loader.estimateSegmentSize(seg, binlog, []FieldID{simpleVecField.id})
assert.Error(t, err)
}
func TestSegmentLoader_testLoadGrowing(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
t.Run("test load growing segments", func(t *testing.T) {
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
loader := node.loader
assert.NotNil(t, loader)
collection, err := node.historical.replica.getCollectionByID(defaultCollectionID)
assert.NoError(t, err)
segment := newSegment(collection, defaultSegmentID+1, defaultPartitionID, defaultCollectionID, defaultVChannel, segmentTypeGrowing, true)
insertMsg, err := genSimpleInsertMsg()
assert.NoError(t, err)
err = loader.loadGrowingSegments(segment, insertMsg.RowIDs, insertMsg.Timestamps, insertMsg.RowData)
assert.NoError(t, err)
})
t.Run("test invalid insert data", func(t *testing.T) {
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
loader := node.loader
assert.NotNil(t, loader)
collection, err := node.historical.replica.getCollectionByID(defaultCollectionID)
assert.NoError(t, err)
segment := newSegment(collection, defaultSegmentID+1, defaultPartitionID, defaultCollectionID, defaultVChannel, segmentTypeGrowing, true)
insertMsg, err := genSimpleInsertMsg()
assert.NoError(t, err)
insertMsg.RowData = nil
err = loader.loadGrowingSegments(segment, insertMsg.RowIDs, insertMsg.Timestamps, insertMsg.RowData)
assert.Error(t, err)
})
}
func TestSegmentLoader_testLoadGrowingAndSealed(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
schema := genSimpleInsertDataSchema()
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
FieldID: UniqueID(102),
Name: "pk",
IsPrimaryKey: true,
DataType: schemapb.DataType_Int64,
})
fieldBinlog, err := saveBinLog(ctx, defaultCollectionID, defaultPartitionID, defaultSegmentID, defaultMsgLength, schema)
assert.NoError(t, err)
t.Run("test load growing and sealed segments", func(t *testing.T) {
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
loader := node.loader
assert.NotNil(t, loader)
loader.indexLoader.indexCoord = nil
loader.indexLoader.rootCoord = nil
segmentID1 := UniqueID(100)
req1 := &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchQueryChannels,
MsgID: rand.Int63(),
},
DstNodeID: 0,
Schema: schema,
LoadCondition: querypb.TriggerCondition_grpcRequest,
Infos: []*querypb.SegmentLoadInfo{
{
SegmentID: segmentID1,
PartitionID: defaultPartitionID,
CollectionID: defaultCollectionID,
BinlogPaths: fieldBinlog,
},
},
}
err = loader.loadSegment(req1, segmentTypeSealed)
assert.NoError(t, err)
segmentID2 := UniqueID(101)
req2 := &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchQueryChannels,
MsgID: rand.Int63(),
},
DstNodeID: 0,
Schema: schema,
LoadCondition: querypb.TriggerCondition_grpcRequest,
Infos: []*querypb.SegmentLoadInfo{
{
SegmentID: segmentID2,
PartitionID: defaultPartitionID,
CollectionID: defaultCollectionID,
BinlogPaths: fieldBinlog,
},
},
}
err = loader.loadSegment(req2, segmentTypeGrowing)
assert.NoError(t, err)
segment1, err := loader.historicalReplica.getSegmentByID(segmentID1)
assert.NoError(t, err)
segment2, err := loader.streamingReplica.getSegmentByID(segmentID2)
assert.NoError(t, err)
assert.Equal(t, segment1.getRowCount(), segment2.getRowCount())
})
}
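
For orientation, a hedged sketch (not part of the diff) of the calling convention the new segmentType parameter establishes: sealed segments end up in the historical replica, growing segments in the streaming replica, mirroring the loadSegmentsTask and watchDmChannelsTask changes later in this commit. Request construction is abbreviated; schema and segmentLoadInfos are placeholders.

```go
// Illustrative only: the two ways loadSegment is now invoked.
req := &querypb.LoadSegmentsRequest{
	CollectionID: defaultCollectionID,
	Schema:       schema,           // collection schema, assumed available
	Infos:        segmentLoadInfos, // one SegmentLoadInfo per segment
}

// Flushed segments recovered from storage are loaded as sealed ...
if err := loader.loadSegment(req, segmentTypeSealed); err != nil {
	return err
}

// ... while unflushed segments from a DML channel are loaded as growing.
if err := loader.loadSegment(req, segmentTypeGrowing); err != nil {
	return err
}
```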


@ -12,6 +12,7 @@
package querynode
import (
"context"
"testing"
"github.com/milvus-io/milvus/internal/msgstream"
@ -29,15 +30,18 @@ func TestStatsService_start(t *testing.T) {
"ReceiveBufSize": 1024,
"PulsarBufSize": 1024}
msFactory.SetParams(m)
node.historical.statsService = newStatsService(node.queryNodeLoopCtx, node.historical.replica, nil, msFactory)
node.historical.statsService.start()
node.statsService = newStatsService(node.queryNodeLoopCtx, node.historical.replica, node.loader.indexLoader.fieldStatsChan, msFactory)
node.statsService.start()
node.Stop()
}
//NOTE: start pulsar before test
func TestSegmentManagement_sendSegmentStatistic(t *testing.T) {
node := newQueryNodeMock()
initTestMeta(t, node, 0, 0)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
const receiveBufSize = 1024
// start pulsar
@ -48,7 +52,7 @@ func TestSegmentManagement_sendSegmentStatistic(t *testing.T) {
"receiveBufSize": receiveBufSize,
"pulsarAddress": Params.PulsarAddress,
"pulsarBufSize": 1024}
err := msFactory.SetParams(m)
err = msFactory.SetParams(m)
assert.Nil(t, err)
statsStream, err := msFactory.NewMsgStream(node.queryNodeLoopCtx)
@ -57,11 +61,11 @@ func TestSegmentManagement_sendSegmentStatistic(t *testing.T) {
var statsMsgStream msgstream.MsgStream = statsStream
node.historical.statsService = newStatsService(node.queryNodeLoopCtx, node.historical.replica, nil, msFactory)
node.historical.statsService.statsStream = statsMsgStream
node.historical.statsService.statsStream.Start()
node.statsService = newStatsService(node.queryNodeLoopCtx, node.historical.replica, node.loader.indexLoader.fieldStatsChan, msFactory)
node.statsService.statsStream = statsMsgStream
node.statsService.statsStream.Start()
// send stats
node.historical.statsService.publicStatistic(nil)
node.statsService.publicStatistic(nil)
node.Stop()
}


@ -301,6 +301,41 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) error {
zap.Any("collectionID", collectionID),
zap.Any("toSeekChannels", toSeekChannels))
// load growing segments
unFlushedSegments := make([]*queryPb.SegmentLoadInfo, 0)
unFlushedSegmentIDs := make([]UniqueID, 0)
for _, info := range w.req.Infos {
for _, ufInfo := range info.UnflushedSegments {
unFlushedSegments = append(unFlushedSegments, &queryPb.SegmentLoadInfo{
SegmentID: ufInfo.ID,
PartitionID: ufInfo.PartitionID,
CollectionID: ufInfo.CollectionID,
BinlogPaths: ufInfo.Binlogs,
NumOfRows: ufInfo.NumOfRows,
Statslogs: ufInfo.Statslogs,
Deltalogs: ufInfo.Deltalogs,
})
unFlushedSegmentIDs = append(unFlushedSegmentIDs, ufInfo.ID)
}
}
req := &queryPb.LoadSegmentsRequest{
Infos: unFlushedSegments,
CollectionID: collectionID,
Schema: w.req.Schema,
}
log.Debug("loading growing segments in WatchDmChannels...",
zap.Any("collectionID", collectionID),
zap.Any("unFlushedSegmentIDs", unFlushedSegmentIDs),
)
err = w.node.loader.loadSegment(req, segmentTypeGrowing)
if err != nil {
return err
}
log.Debug("load growing segments done in WatchDmChannels",
zap.Any("collectionID", collectionID),
zap.Any("unFlushedSegmentIDs", unFlushedSegmentIDs),
)
// start flow graphs
if loadPartition {
err = w.node.dataSyncService.startPartitionFlowGraph(partitionID, vChannels)
@ -544,7 +579,7 @@ func (l *loadSegmentsTask) Execute(ctx context.Context) error {
}
}
err = l.node.historical.loader.loadSegment(l.req)
err = l.node.loader.loadSegment(l.req, segmentTypeSealed)
if err != nil {
log.Warn(err.Error())
return err


@ -95,6 +95,9 @@ func (b Blob) GetValue() []byte {
type FieldData interface {
Length() int
Get(i int) interface{}
GetMemorySize() int
RowNum() int
GetRow(i int) interface{}
}
type BoolFieldData struct {
@ -162,6 +165,32 @@ func (data *StringFieldData) Get(i int) interface{} { return data.Data[i]
func (data *BinaryVectorFieldData) Get(i int) interface{} { return data.Data[i] }
func (data *FloatVectorFieldData) Get(i int) interface{} { return data.Data[i] }
func (data *BoolFieldData) RowNum() int { return len(data.Data) }
func (data *Int8FieldData) RowNum() int { return len(data.Data) }
func (data *Int16FieldData) RowNum() int { return len(data.Data) }
func (data *Int32FieldData) RowNum() int { return len(data.Data) }
func (data *Int64FieldData) RowNum() int { return len(data.Data) }
func (data *FloatFieldData) RowNum() int { return len(data.Data) }
func (data *DoubleFieldData) RowNum() int { return len(data.Data) }
func (data *StringFieldData) RowNum() int { return len(data.Data) }
func (data *BinaryVectorFieldData) RowNum() int { return len(data.Data) * 8 / data.Dim }
func (data *FloatVectorFieldData) RowNum() int { return len(data.Data) / data.Dim }
func (data *BoolFieldData) GetRow(i int) interface{} { return data.Data[i] }
func (data *Int8FieldData) GetRow(i int) interface{} { return data.Data[i] }
func (data *Int16FieldData) GetRow(i int) interface{} { return data.Data[i] }
func (data *Int32FieldData) GetRow(i int) interface{} { return data.Data[i] }
func (data *Int64FieldData) GetRow(i int) interface{} { return data.Data[i] }
func (data *FloatFieldData) GetRow(i int) interface{} { return data.Data[i] }
func (data *DoubleFieldData) GetRow(i int) interface{} { return data.Data[i] }
func (data *StringFieldData) GetRow(i int) interface{} { return data.Data[i] }
func (data *BinaryVectorFieldData) GetRow(i int) interface{} {
return data.Data[i*data.Dim/8 : (i+1)*data.Dim/8]
}
func (data *FloatVectorFieldData) GetRow(i int) interface{} {
return data.Data[i*data.Dim : (i+1)*data.Dim]
}
// why not binary.Size(data) directly? binary.Size(data) returns -1
// binary.Size returns how many bytes Write would generate to encode the value v, which
// must be a fixed-size value or a slice of fixed-size values, or a pointer to such data.
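
A hedged usage note for the accessors introduced above (all values invented): RowNum counts logical rows, while GetRow slices a single row out of the flat backing slice, so vector field data yields Dim elements (or Dim/8 bytes) per row.

```go
// Hypothetical field data, illustrating RowNum and GetRow.
vec := &FloatVectorFieldData{Dim: 2, Data: []float32{1, 2, 3, 4}}
_ = vec.RowNum()  // 2: two rows of dimension 2
_ = vec.GetRow(1) // []float32{3, 4}: the second row

bv := &BinaryVectorFieldData{Dim: 8, Data: []byte{0x01, 0x02}}
_ = bv.RowNum()  // 2: Dim bits (one byte) per row
_ = bv.GetRow(0) // []byte{0x01}: the first row
```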


@ -14,9 +14,15 @@ package storage
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"sort"
"strconv"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/kv"
)
@ -92,3 +98,139 @@ func EstimateMemorySize(kv kv.DataKV, key string) (int64, error) {
return total, nil
}
//////////////////////////////////////////////////////////////////////////////////////////////////
func checkTsField(data *InsertData) bool {
tsData, ok := data.Data[common.TimeStampField]
if !ok {
return false
}
_, ok = tsData.(*Int64FieldData)
return ok
}
func checkRowIDField(data *InsertData) bool {
rowIDData, ok := data.Data[common.RowIDField]
if !ok {
return false
}
_, ok = rowIDData.(*Int64FieldData)
return ok
}
func checkNumRows(fieldDatas ...FieldData) bool {
if len(fieldDatas) <= 0 {
return true
}
numRows := fieldDatas[0].RowNum()
for i := 1; i < len(fieldDatas); i++ {
if numRows != fieldDatas[i].RowNum() {
return false
}
}
return true
}
type fieldDataList struct {
IDs []FieldID
datas []FieldData
}
func (ls fieldDataList) Len() int {
return len(ls.IDs)
}
func (ls fieldDataList) Less(i, j int) bool {
return ls.IDs[i] < ls.IDs[j]
}
func (ls fieldDataList) Swap(i, j int) {
ls.IDs[i], ls.IDs[j] = ls.IDs[j], ls.IDs[i]
ls.datas[i], ls.datas[j] = ls.datas[j], ls.datas[i]
}
func sortFieldDataList(ls fieldDataList) {
sort.Sort(ls)
}
// TransferColumnBasedInsertDataToRowBased transfers column-based insert data to row-based rows.
// Note:
// - the ts column must exist in the insert data;
// - the row id column must exist in the insert data;
// - the row num of all columns must be equal;
// - num_rows = len(RowData); each row is assembled into a blob value in ascending field id order;
func TransferColumnBasedInsertDataToRowBased(data *InsertData) (
Timestamps []uint64,
RowIDs []int64,
RowData []*commonpb.Blob,
err error,
) {
if !checkTsField(data) {
return nil, nil, nil,
errors.New("cannot get timestamps from insert data")
}
if !checkRowIDField(data) {
return nil, nil, nil,
errors.New("cannot get row ids from insert data")
}
tss := data.Data[common.TimeStampField].(*Int64FieldData)
rowIds := data.Data[common.RowIDField].(*Int64FieldData)
ls := fieldDataList{
IDs: make([]FieldID, 0),
datas: make([]FieldData, 0),
}
for fieldID := range data.Data {
if fieldID == common.TimeStampField || fieldID == common.RowIDField {
continue
}
ls.IDs = append(ls.IDs, fieldID)
ls.datas = append(ls.datas, data.Data[fieldID])
}
// checkNumRows(tss, rowIds, ls.datas...) // doesn't work
all := []FieldData{tss, rowIds}
all = append(all, ls.datas...)
if !checkNumRows(all...) {
return nil, nil, nil,
errors.New("columns of insert data have different length")
}
sortFieldDataList(ls)
numRows := tss.RowNum()
rows := make([]*commonpb.Blob, numRows)
for i := 0; i < numRows; i++ {
blob := &commonpb.Blob{
Value: make([]byte, 0),
}
var buffer bytes.Buffer
for j := 0; j < ls.Len(); j++ {
d := ls.datas[j].GetRow(i)
err := binary.Write(&buffer, common.Endian, d)
if err != nil {
return nil, nil, nil,
fmt.Errorf("failed to get binary row, err: %v", err)
}
}
blob.Value = buffer.Bytes()
rows[i] = blob
}
utss := make([]uint64, tss.RowNum())
for i := 0; i < tss.RowNum(); i++ {
utss[i] = uint64(tss.Data[i])
}
return utss, rowIds.Data, rows, nil
}
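
To make the contract above concrete, a hedged example of calling the new helper (field ID 101 and all values are invented; it mirrors the accompanying unit test rather than any production caller): the timestamp and row-ID columns must be Int64FieldData, every column must have the same row count, and each returned blob packs one row's remaining fields in ascending field-ID order.

```go
// Hypothetical input for TransferColumnBasedInsertDataToRowBased.
data := &InsertData{Data: map[FieldID]FieldData{
	common.TimeStampField: &Int64FieldData{Data: []int64{10, 11}},
	common.RowIDField:     &Int64FieldData{Data: []int64{1, 2}},
	101:                   &FloatFieldData{Data: []float32{0.5, 1.5}},
}}
tss, rowIDs, rows, err := TransferColumnBasedInsertDataToRowBased(data)
if err != nil {
	panic(err) // illustration only
}
_ = tss    // []uint64{10, 11}
_ = rowIDs // []int64{1, 2}
_ = rows   // two blobs, each holding one row's float32 bytes in common.Endian order
```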


@ -377,3 +377,210 @@ func TestEstimateMemorySize_cannot_convert_original_size_to_int(t *testing.T) {
_, err := EstimateMemorySize(mockKV, key)
assert.Error(t, err)
}
//////////////////////////////////////////////////////////////////////////////////////////////////
func TestCheckTsField(t *testing.T) {
data := &InsertData{
Data: make(map[FieldID]FieldData),
}
assert.False(t, checkTsField(data))
data.Data[common.TimeStampField] = &BoolFieldData{}
assert.False(t, checkTsField(data))
data.Data[common.TimeStampField] = &Int64FieldData{}
assert.True(t, checkTsField(data))
}
func TestCheckRowIDField(t *testing.T) {
data := &InsertData{
Data: make(map[FieldID]FieldData),
}
assert.False(t, checkRowIDField(data))
data.Data[common.RowIDField] = &BoolFieldData{}
assert.False(t, checkRowIDField(data))
data.Data[common.RowIDField] = &Int64FieldData{}
assert.True(t, checkRowIDField(data))
}
func TestCheckNumRows(t *testing.T) {
assert.True(t, checkNumRows())
f1 := &Int64FieldData{
NumRows: nil,
Data: []int64{1, 2, 3},
}
f2 := &Int64FieldData{
NumRows: nil,
Data: []int64{1, 2, 3},
}
f3 := &Int64FieldData{
NumRows: nil,
Data: []int64{1, 2, 3, 4},
}
assert.True(t, checkNumRows(f1, f2))
assert.False(t, checkNumRows(f1, f3))
assert.False(t, checkNumRows(f2, f3))
assert.False(t, checkNumRows(f1, f2, f3))
}
func TestSortFieldDataList(t *testing.T) {
f1 := &Int16FieldData{
NumRows: nil,
Data: []int16{1, 2, 3},
}
f2 := &Int32FieldData{
NumRows: nil,
Data: []int32{4, 5, 6},
}
f3 := &Int64FieldData{
NumRows: nil,
Data: []int64{7, 8, 9},
}
ls := fieldDataList{
IDs: []FieldID{1, 3, 2},
datas: []FieldData{f1, f3, f2},
}
assert.Equal(t, 3, ls.Len())
sortFieldDataList(ls)
assert.ElementsMatch(t, []FieldID{1, 2, 3}, ls.IDs)
assert.ElementsMatch(t, []FieldData{f1, f2, f3}, ls.datas)
}
func TestTransferColumnBasedInsertDataToRowBased(t *testing.T) {
var err error
data := &InsertData{
Data: make(map[FieldID]FieldData),
}
// no ts
_, _, _, err = TransferColumnBasedInsertDataToRowBased(data)
assert.Error(t, err)
tss := &Int64FieldData{
Data: []int64{1, 2, 3},
}
data.Data[common.TimeStampField] = tss
// no row ids
_, _, _, err = TransferColumnBasedInsertDataToRowBased(data)
assert.Error(t, err)
rowIdsF := &Int64FieldData{
Data: []int64{1, 2, 3, 4},
}
data.Data[common.RowIDField] = rowIdsF
// row num mismatch
_, _, _, err = TransferColumnBasedInsertDataToRowBased(data)
assert.Error(t, err)
data.Data[common.RowIDField] = &Int64FieldData{
Data: []int64{1, 2, 3},
}
f1 := &BoolFieldData{
Data: []bool{true, false, true},
}
f2 := &Int8FieldData{
Data: []int8{0, 0xf, 0x1f},
}
f3 := &Int16FieldData{
Data: []int16{0, 0xff, 0x1fff},
}
f4 := &Int32FieldData{
Data: []int32{0, 0xffff, 0x1fffffff},
}
f5 := &Int64FieldData{
Data: []int64{0, 0xffffffff, 0x1fffffffffffffff},
}
f6 := &FloatFieldData{
Data: []float32{0, 0, 0},
}
f7 := &DoubleFieldData{
Data: []float64{0, 0, 0},
}
// maybe we cannot support string now, no matter whether the string length is fixed or not.
// f8 := &StringFieldData{
// Data: []string{"1", "2", "3"},
// }
f9 := &BinaryVectorFieldData{
Dim: 8,
Data: []byte{1, 2, 3},
}
f10 := &FloatVectorFieldData{
Dim: 1,
Data: []float32{0, 0, 0},
}
data.Data[101] = f1
data.Data[102] = f2
data.Data[103] = f3
data.Data[104] = f4
data.Data[105] = f5
data.Data[106] = f6
data.Data[107] = f7
// data.Data[108] = f8
data.Data[109] = f9
data.Data[110] = f10
utss, rowIds, rows, err := TransferColumnBasedInsertDataToRowBased(data)
assert.NoError(t, err)
assert.ElementsMatch(t, []uint64{1, 2, 3}, utss)
assert.ElementsMatch(t, []int64{1, 2, 3}, rowIds)
assert.Equal(t, 3, len(rows))
// b := []byte("1")[0]
if common.Endian == binary.LittleEndian {
// low-order byte at the lowest address
assert.ElementsMatch(t,
[]byte{
1, // true
0, // 0
0, 0, // 0
0, 0, 0, 0, // 0
0, 0, 0, 0, 0, 0, 0, 0, // 0
0, 0, 0, 0, // 0
0, 0, 0, 0, 0, 0, 0, 0, // 0
// b + 1, // "1"
1, // 1
0, 0, 0, 0, // 0
},
rows[0].Value)
assert.ElementsMatch(t,
[]byte{
0, // false
0xf, // 0xf
0, 0xff, // 0xff
0, 0, 0xff, 0xff, // 0xffff
0, 0, 0, 0, 0xff, 0xff, 0xff, 0xff, // 0xffffffff
0, 0, 0, 0, // 0
0, 0, 0, 0, 0, 0, 0, 0, // 0
// b + 2, // "2"
2, // 2
0, 0, 0, 0, // 0
},
rows[1].Value)
assert.ElementsMatch(t,
[]byte{
1, // true
0x1f, // 0x1f
0xff, 0x1f, // 0x1fff
0xff, 0xff, 0xff, 0x1f, // 0x1fffffff
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x1f, // 0x1fffffffffffffff
0, 0, 0, 0, // 0
0, 0, 0, 0, 0, 0, 0, 0, // 0
// b + 3, // "3"
3, // 3
0, 0, 0, 0, // 0
},
rows[2].Value)
}
}