diff --git a/internal/querynode/query_node_test.go b/internal/querynode/query_node_test.go index fb679333a9..85d0457b63 100644 --- a/internal/querynode/query_node_test.go +++ b/internal/querynode/query_node_test.go @@ -8,8 +8,6 @@ import ( "testing" "time" - "github.com/zilliztech/milvus-distributed/internal/types" - "github.com/stretchr/testify/assert" "github.com/zilliztech/milvus-distributed/internal/msgstream" @@ -18,6 +16,7 @@ import ( "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" "github.com/zilliztech/milvus-distributed/internal/proto/querypb" "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" + "github.com/zilliztech/milvus-distributed/internal/types" ) const ctxTimeInMillisecond = 5000 @@ -204,6 +203,19 @@ func (q *queryServiceMock) RegisterNode(ctx context.Context, req *querypb.Regist }, nil } +func newMessageStreamFactory() (msgstream.Factory, error) { + const receiveBufSize = 1024 + + pulsarURL := Params.PulsarAddress + msFactory := msgstream.NewPmsFactory() + m := map[string]interface{}{ + "receiveBufSize": receiveBufSize, + "pulsarAddress": pulsarURL, + "pulsarBufSize": 1024} + err := msFactory.SetParams(m) + return msFactory, err +} + func TestMain(m *testing.M) { setup() refreshChannelNames() diff --git a/internal/querynode/search_service_test.go b/internal/querynode/search_service_test.go index 31e6c5738c..9df6aa9b49 100644 --- a/internal/querynode/search_service_test.go +++ b/internal/querynode/search_service_test.go @@ -3,8 +3,8 @@ package querynode import ( "context" "encoding/binary" - "log" "math" + "math/rand" "testing" "time" @@ -17,22 +17,51 @@ import ( "github.com/zilliztech/milvus-distributed/internal/proto/milvuspb" ) -func TestSearch_Search(t *testing.T) { - collectionID := UniqueID(0) +func loadFields(segment *Segment, DIM int, N int) error { + // generate vector field + vectorFieldID := int64(100) + vectors := make([]float32, N*DIM) + for i := 0; i < N*DIM; i++ { + vectors[i] = rand.Float32() + } - node := newQueryNodeMock() - initTestMeta(t, node, 0, 0) + // generate int field + agesFieldID := int64(101) + ages := make([]int32, N) + for i := 0; i < N; i++ { + ages[i] = int32(N) + } - pulsarURL := Params.PulsarAddress + err := segment.segmentLoadFieldData(vectorFieldID, N, vectors) + if err != nil { + return err + } + err = segment.segmentLoadFieldData(agesFieldID, N, ages) + if err != nil { + return err + } + rowIDs := ages + err = segment.segmentLoadFieldData(rowIDFieldID, N, rowIDs) + return err +} - // test data generate - const msgLength = 10 - const receiveBufSize = 1024 - const DIM = 16 +func sendSearchRequest(ctx context.Context, DIM int) error { + // init message stream + msFactory, err := newMessageStreamFactory() + if err != nil { + return err + } searchProducerChannels := Params.SearchChannelNames - var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - // start search service + searchStream, _ := msFactory.NewMsgStream(ctx) + searchStream.AsProducer(searchProducerChannels) + searchStream.Start() + + // generate search rawData + var vec = make([]float32, DIM) + for i := 0; i < DIM; i++ { + vec[i] = rand.Float32() + } dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }" var searchRawData1 []byte var searchRawData2 []byte @@ -46,35 +75,35 @@ func TestSearch_Search(t *testing.T) { binary.LittleEndian.PutUint32(buf, math.Float32bits(ele+float32(i*4))) searchRawData2 = append(searchRawData2, buf...) } + + // generate placeholder placeholderValue := milvuspb.PlaceholderValue{ Tag: "$0", Type: milvuspb.PlaceholderType_FloatVector, Values: [][]byte{searchRawData1, searchRawData2}, } - placeholderGroup := milvuspb.PlaceholderGroup{ Placeholders: []*milvuspb.PlaceholderValue{&placeholderValue}, } - placeGroupByte, err := proto.Marshal(&placeholderGroup) if err != nil { - log.Print("marshal placeholderGroup failed") + return err } - query := milvuspb.SearchRequest{ + // generate searchRequest + searchReq := milvuspb.SearchRequest{ Dsl: dslString, PlaceholderGroup: placeGroupByte, } - - queryByte, err := proto.Marshal(&query) + searchReqBytes, err := proto.Marshal(&searchReq) if err != nil { - log.Print("marshal query failed") + return err } - blob := commonpb.Blob{ - Value: queryByte, + Value: searchReqBytes, } + // generate searchMsg searchMsg := &msgstream.SearchMsg{ BaseMsg: msgstream.BaseMsg{ HashValues: []uint32{0}, @@ -83,98 +112,40 @@ func TestSearch_Search(t *testing.T) { Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_Search, MsgID: 1, - Timestamp: uint64(10 + 1000), + Timestamp: Timestamp(10), SourceID: 1, }, ResultChannelID: "0", Query: &blob, }, } - msgPackSearch := msgstream.MsgPack{} msgPackSearch.Msgs = append(msgPackSearch.Msgs, searchMsg) - msFactory := msgstream.NewPmsFactory() - m := map[string]interface{}{ - "receiveBufSize": receiveBufSize, - "pulsarAddress": pulsarURL, - "pulsarBufSize": 1024} - err = msFactory.SetParams(m) - assert.Nil(t, err) - - searchStream, _ := msFactory.NewMsgStream(node.queryNodeLoopCtx) - searchStream.AsProducer(searchProducerChannels) - searchStream.Start() + // produce search message err = searchStream.Produce(&msgPackSearch) - assert.NoError(t, err) + return err +} - node.searchService = newSearchService(node.queryNodeLoopCtx, node.replica, msFactory) - go node.searchService.start() - node.searchService.startSearchCollection(collectionID) - - // start insert - timeRange := TimeRange{ - timestampMin: 0, - timestampMax: math.MaxUint64, - } - - insertMessages := make([]msgstream.TsMsg, 0) - for i := 0; i < msgLength; i++ { - var rawData []byte - for _, ele := range vec { - buf := make([]byte, 4) - binary.LittleEndian.PutUint32(buf, math.Float32bits(ele+float32(i*2))) - rawData = append(rawData, buf...) - } - bs := make([]byte, 4) - binary.LittleEndian.PutUint32(bs, 1) - rawData = append(rawData, bs...) - - var msg msgstream.TsMsg = &msgstream.InsertMsg{ - BaseMsg: msgstream.BaseMsg{ - HashValues: []uint32{ - uint32(i), - }, - }, - InsertRequest: internalpb.InsertRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Insert, - MsgID: int64(i), - Timestamp: uint64(10 + 1000), - SourceID: 0, - }, - CollectionID: collectionID, - PartitionID: defaultPartitionID, - SegmentID: int64(0), - ChannelID: "0", - Timestamps: []uint64{uint64(i + 1000)}, - RowIDs: []int64{int64(i)}, - RowData: []*commonpb.Blob{ - {Value: rawData}, - }, - }, - } - insertMessages = append(insertMessages, msg) - } - - msgPack := msgstream.MsgPack{ - BeginTs: timeRange.timestampMin, - EndTs: timeRange.timestampMax, - Msgs: insertMessages, +func sendTimeTick(ctx context.Context) error { + // init message stream + msFactory, err := newMessageStreamFactory() + if err != nil { + return err } // generate timeTick timeTickMsgPack := msgstream.MsgPack{} baseMsg := msgstream.BaseMsg{ - BeginTimestamp: 0, - EndTimestamp: 0, + BeginTimestamp: Timestamp(20), + EndTimestamp: Timestamp(20), HashValues: []uint32{0}, } timeTickResult := internalpb.TimeTickMsg{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_TimeTick, MsgID: 0, - Timestamp: math.MaxUint64, + Timestamp: Timestamp(20), SourceID: 0, }, } @@ -184,239 +155,114 @@ func TestSearch_Search(t *testing.T) { } timeTickMsgPack.Msgs = append(timeTickMsgPack.Msgs, timeTickMsg) - // pulsar produce + // produce timeTick message insertChannels := Params.InsertChannelNames - ddChannels := Params.DDChannelNames - - insertStream, _ := msFactory.NewMsgStream(node.queryNodeLoopCtx) + insertStream, _ := msFactory.NewMsgStream(ctx) insertStream.AsProducer(insertChannels) + insertStream.Start() - ddStream, _ := msFactory.NewMsgStream(node.queryNodeLoopCtx) - ddStream.AsProducer(ddChannels) + err = insertStream.Broadcast(&timeTickMsgPack) + return err +} - var insertMsgStream msgstream.MsgStream = insertStream - insertMsgStream.Start() +func TestSearch_Search(t *testing.T) { + const N = 10000 + const DIM = 16 - var ddMsgStream msgstream.MsgStream = ddStream - ddMsgStream.Start() + // init queryNode + collectionID := UniqueID(0) + segmentID := UniqueID(1) + node := newQueryNodeMock() + initTestMeta(t, node, collectionID, UniqueID(0)) - err = insertMsgStream.Produce(&msgPack) + msFactory, err := newMessageStreamFactory() assert.NoError(t, err) - err = insertMsgStream.Broadcast(&timeTickMsgPack) + // start dataSync + newDS := newDataSyncService(node.queryNodeLoopCtx, node.replica, msFactory, collectionID) + err = node.addDataSyncService(collectionID, newDS) assert.NoError(t, err) - err = ddMsgStream.Broadcast(&timeTickMsgPack) + ds, err := node.getDataSyncService(collectionID) + assert.NoError(t, err) + go ds.start() + + // start search service + node.searchService = newSearchService(node.queryNodeLoopCtx, node.replica, msFactory) + go node.searchService.start() + node.searchService.startSearchCollection(collectionID) + + tSafe := node.replica.getTSafe(collectionID) + assert.NotNil(t, tSafe) + tSafe.set(1000) + + // load segment + err = node.replica.addSegment(segmentID, defaultPartitionID, collectionID, segmentTypeSealed) + assert.NoError(t, err) + segment, err := node.replica.getSegmentByID(segmentID) + assert.NoError(t, err) + err = loadFields(segment, DIM, N) assert.NoError(t, err) - // dataSync - node.dataSyncServices[collectionID] = newDataSyncService(node.queryNodeLoopCtx, node.replica, msFactory, collectionID) - go node.dataSyncServices[collectionID].start() + err = sendSearchRequest(node.queryNodeLoopCtx, DIM) + assert.NoError(t, err) time.Sleep(1 * time.Second) - node.Stop() + err = node.Stop() + assert.NoError(t, err) } func TestSearch_SearchMultiSegments(t *testing.T) { - collectionID := UniqueID(0) - - pulsarURL := Params.PulsarAddress - const receiveBufSize = 1024 - - msFactory := msgstream.NewPmsFactory() - m := map[string]interface{}{ - "receiveBufSize": receiveBufSize, - "pulsarAddress": pulsarURL, - "pulsarBufSize": 1024} - err := msFactory.SetParams(m) - assert.Nil(t, err) - - node := NewQueryNode(context.Background(), 0, msFactory) - initTestMeta(t, node, 0, 0) - - // test data generate - const msgLength = 10 + const N = 10000 const DIM = 16 - searchProducerChannels := Params.SearchChannelNames - var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - // start search service - dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }" - var searchRawData1 []byte - var searchRawData2 []byte - for i, ele := range vec { - buf := make([]byte, 4) - binary.LittleEndian.PutUint32(buf, math.Float32bits(ele+float32(i*2))) - searchRawData1 = append(searchRawData1, buf...) - } - for i, ele := range vec { - buf := make([]byte, 4) - binary.LittleEndian.PutUint32(buf, math.Float32bits(ele+float32(i*4))) - searchRawData2 = append(searchRawData2, buf...) - } - placeholderValue := milvuspb.PlaceholderValue{ - Tag: "$0", - Type: milvuspb.PlaceholderType_FloatVector, - Values: [][]byte{searchRawData1, searchRawData2}, - } + // init queryNode + collectionID := UniqueID(0) + segmentID1 := UniqueID(1) + segmentID2 := UniqueID(2) + node := newQueryNodeMock() + initTestMeta(t, node, collectionID, UniqueID(0)) - placeholderGroup := milvuspb.PlaceholderGroup{ - Placeholders: []*milvuspb.PlaceholderValue{&placeholderValue}, - } - - placeGroupByte, err := proto.Marshal(&placeholderGroup) - if err != nil { - log.Print("marshal placeholderGroup failed") - } - - query := milvuspb.SearchRequest{ - Dsl: dslString, - PlaceholderGroup: placeGroupByte, - } - - queryByte, err := proto.Marshal(&query) - if err != nil { - log.Print("marshal query failed") - } - - blob := commonpb.Blob{ - Value: queryByte, - } - - searchMsg := &msgstream.SearchMsg{ - BaseMsg: msgstream.BaseMsg{ - HashValues: []uint32{0}, - }, - SearchRequest: internalpb.SearchRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Search, - MsgID: 1, - Timestamp: uint64(10 + 1000), - SourceID: 1, - }, - ResultChannelID: "0", - Query: &blob, - }, - } - - msgPackSearch := msgstream.MsgPack{} - msgPackSearch.Msgs = append(msgPackSearch.Msgs, searchMsg) - - searchStream, _ := msFactory.NewMsgStream(node.queryNodeLoopCtx) - searchStream.AsProducer(searchProducerChannels) - searchStream.Start() - err = searchStream.Produce(&msgPackSearch) + msFactory, err := newMessageStreamFactory() assert.NoError(t, err) + // start dataSync + newDS := newDataSyncService(node.queryNodeLoopCtx, node.replica, msFactory, collectionID) + err = node.addDataSyncService(collectionID, newDS) + assert.NoError(t, err) + ds, err := node.getDataSyncService(collectionID) + assert.NoError(t, err) + go ds.start() + + // start search service node.searchService = newSearchService(node.queryNodeLoopCtx, node.replica, msFactory) go node.searchService.start() node.searchService.startSearchCollection(collectionID) - // start insert - timeRange := TimeRange{ - timestampMin: 0, - timestampMax: math.MaxUint64, - } + tSafe := node.replica.getTSafe(collectionID) + assert.NotNil(t, tSafe) + tSafe.set(1000) - insertMessages := make([]msgstream.TsMsg, 0) - for i := 0; i < msgLength; i++ { - segmentID := 0 - if i >= msgLength/2 { - segmentID = 1 - } - var rawData []byte - for _, ele := range vec { - buf := make([]byte, 4) - binary.LittleEndian.PutUint32(buf, math.Float32bits(ele+float32(i*2))) - rawData = append(rawData, buf...) - } - bs := make([]byte, 4) - binary.LittleEndian.PutUint32(bs, 1) - rawData = append(rawData, bs...) - - var msg msgstream.TsMsg = &msgstream.InsertMsg{ - BaseMsg: msgstream.BaseMsg{ - HashValues: []uint32{ - uint32(i), - }, - }, - InsertRequest: internalpb.InsertRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Insert, - MsgID: int64(i), - Timestamp: uint64(i + 1000), - SourceID: 0, - }, - CollectionID: collectionID, - PartitionID: defaultPartitionID, - SegmentID: int64(segmentID), - ChannelID: "0", - Timestamps: []uint64{uint64(i + 1000)}, - RowIDs: []int64{int64(i)}, - RowData: []*commonpb.Blob{ - {Value: rawData}, - }, - }, - } - insertMessages = append(insertMessages, msg) - } - - msgPack := msgstream.MsgPack{ - BeginTs: timeRange.timestampMin, - EndTs: timeRange.timestampMax, - Msgs: insertMessages, - } - - // generate timeTick - timeTickMsgPack := msgstream.MsgPack{} - baseMsg := msgstream.BaseMsg{ - BeginTimestamp: 0, - EndTimestamp: 0, - HashValues: []uint32{0}, - } - timeTickResult := internalpb.TimeTickMsg{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_TimeTick, - MsgID: 0, - Timestamp: math.MaxUint64, - SourceID: 0, - }, - } - timeTickMsg := &msgstream.TimeTickMsg{ - BaseMsg: baseMsg, - TimeTickMsg: timeTickResult, - } - timeTickMsgPack.Msgs = append(timeTickMsgPack.Msgs, timeTickMsg) - - // pulsar produce - insertChannels := Params.InsertChannelNames - ddChannels := Params.DDChannelNames - - insertStream, _ := msFactory.NewMsgStream(node.queryNodeLoopCtx) - insertStream.AsProducer(insertChannels) - - ddStream, _ := msFactory.NewMsgStream(node.queryNodeLoopCtx) - ddStream.AsProducer(ddChannels) - - var insertMsgStream msgstream.MsgStream = insertStream - insertMsgStream.Start() - - var ddMsgStream msgstream.MsgStream = ddStream - ddMsgStream.Start() - - err = insertMsgStream.Produce(&msgPack) + // load segments + err = node.replica.addSegment(segmentID1, defaultPartitionID, collectionID, segmentTypeSealed) + assert.NoError(t, err) + segment1, err := node.replica.getSegmentByID(segmentID1) + assert.NoError(t, err) + err = loadFields(segment1, DIM, N) assert.NoError(t, err) - err = insertMsgStream.Broadcast(&timeTickMsgPack) + err = node.replica.addSegment(segmentID2, defaultPartitionID, collectionID, segmentTypeSealed) assert.NoError(t, err) - err = ddMsgStream.Broadcast(&timeTickMsgPack) + segment2, err := node.replica.getSegmentByID(segmentID2) + assert.NoError(t, err) + err = loadFields(segment2, DIM, N) assert.NoError(t, err) - // dataSync - node.dataSyncServices[collectionID] = newDataSyncService(node.queryNodeLoopCtx, node.replica, msFactory, collectionID) - go node.dataSyncServices[collectionID].start() + err = sendSearchRequest(node.queryNodeLoopCtx, DIM) + assert.NoError(t, err) time.Sleep(1 * time.Second) - node.Stop() + err = node.Stop() + assert.NoError(t, err) }