diff --git a/core/src/dog_segment/SegmentNaive.cpp b/core/src/dog_segment/SegmentNaive.cpp index 16713502a6..9c7c29b87b 100644 --- a/core/src/dog_segment/SegmentNaive.cpp +++ b/core/src/dog_segment/SegmentNaive.cpp @@ -174,11 +174,12 @@ Status SegmentNaive::Query(const query::QueryPtr& query, Timestamp timestamp, QueryResult& result) { // TODO: enable delete // TODO: enable index - auto& field = schema_->operator[](0); - assert(field.get_name() == "fakevec"); + auto& field = schema_->operator[](query->field_name); assert(field.get_data_type() == DataType::VECTOR_FLOAT); + auto dim = field.get_dim(); - assert(query == nullptr); + auto topK = query->topK; + int64_t barrier = [&] { auto& vec = record_.timestamps_; @@ -187,19 +188,20 @@ SegmentNaive::Query(const query::QueryPtr& query, Timestamp timestamp, QueryResu while (beg < end) { auto mid = (beg + end) / 2; if (vec[mid] < timestamp) { - end = mid + 1; + end = mid; } else { - beg = mid; + beg = mid + 1; } } return beg; }(); - // search until barriers + // TODO: optimize auto vec_ptr = std::static_pointer_cast>(record_.entity_vec_[0]); for(int64_t i = 0; i < barrier; ++i) { -// auto element = + auto element = vec_ptr->get_element(i); + throw std::runtime_error("unimplemented"); } diff --git a/core/src/query/GeneralQuery.h b/core/src/query/GeneralQuery.h index 622aa15f55..b2b070896b 100644 --- a/core/src/query/GeneralQuery.h +++ b/core/src/query/GeneralQuery.h @@ -121,9 +121,14 @@ namespace query { // std::set index_fields; // std::unordered_map metric_types; // }; + struct Query{ - // TODO + int64_t num_queries; // + int topK; // topK of queries + std::string field_name; // must be fakevec, whose data_type must be VEC_FLOAT(DIM) + std::vector query_raw_data; // must be size of num_queries * DIM }; + using QueryPtr = std::shared_ptr; } // namespace query diff --git a/pkg/master/informer/pulsar.go b/pkg/master/informer/pulsar.go index d46485829b..2a3a5ca85a 100644 --- a/pkg/master/informer/pulsar.go +++ b/pkg/master/informer/pulsar.go @@ -39,7 +39,7 @@ func (pc PulsarClient) Listener(ssChan chan mock.SegmentStats) error { if err != nil { log.Fatal(err) } - for i := 0; i < 10; i++ { + for { msg, err := consumer.Receive(context.TODO()) if err != nil { log.Fatal(err) diff --git a/reader/collection.go b/reader/collection.go index 5d94b321cd..09991e9f9d 100644 --- a/reader/collection.go +++ b/reader/collection.go @@ -34,15 +34,3 @@ func (c *Collection) DeletePartition(partition *Partition) { // TODO: remove from c.Partitions } - -func (c *Collection) GetSegments() ([]*Segment, error) { - // TODO: add get segments - //segments, status := C.GetSegments(c.CollectionPtr) - // - //if status != 0 { - // return nil, errors.New("get segments failed") - //} - // - //return segments, nil - return nil, nil -} diff --git a/reader/message_client/message_client.go b/reader/message_client/message_client.go index 154886f6cb..3ab04497fe 100644 --- a/reader/message_client/message_client.go +++ b/reader/message_client/message_client.go @@ -2,9 +2,9 @@ package message_client import ( "context" + "github.com/apache/pulsar-client-go/pulsar" msgpb "github.com/czs007/suvlim/pkg/message" "github.com/golang/protobuf/proto" - "github.com/pulsar-client-go/pulsar" "log" ) @@ -32,7 +32,7 @@ type MessageClient struct { } func (mc *MessageClient) Send(ctx context.Context, msg msgpb.QueryResult) { - if err := mc.searchResultProducer.Send(ctx, pulsar.ProducerMessage{ + if _, err := mc.searchResultProducer.Send(ctx, &pulsar.ProducerMessage{ Payload: []byte(msg.String()), }); err != nil { log.Fatal(err) diff --git a/reader/quety_node_test.go b/reader/quety_node_test.go new file mode 100644 index 0000000000..3cf4dbc26b --- /dev/null +++ b/reader/quety_node_test.go @@ -0,0 +1,12 @@ +package reader + +import ( + "testing" +) + +// TODO: add query node test + +func TestQueryNode_RunInsertDelete(t *testing.T) { + +} + diff --git a/reader/reader_test.go b/reader/reader_test.go new file mode 100644 index 0000000000..e467f29753 --- /dev/null +++ b/reader/reader_test.go @@ -0,0 +1,9 @@ +package reader + +import ( + "testing" +) + +func TestReader_startQueryNode(t *testing.T) { + startQueryNode() +} diff --git a/reader/result_test.go b/reader/result_test.go new file mode 100644 index 0000000000..af854d29ad --- /dev/null +++ b/reader/result_test.go @@ -0,0 +1,31 @@ +package reader + +import ( + msgPb "github.com/czs007/suvlim/pkg/message" + "testing" +) + +func TestResult_PublishSearchResult(t *testing.T) { + // Construct node, collection, partition and segment + node := NewQueryNode(0, 0) + var collection = node.NewCollection("collection0", "fake schema") + var partition = collection.NewPartition("partition0") + var segment = partition.NewSegment(0) + node.SegmentsMap[0] = segment + + // TODO: start pulsar server + // TODO: fix result PublishSearchResult + const N = 10 + var entityIDs = msgPb.Entities { + Ids: make([]int64, N), + } + var results = msgPb.QueryResult { + Entities: &entityIDs, + Distances: make([]float32, N), + } + for i := 0; i < N; i++ { + results.Entities.Ids = append(results.Entities.Ids, int64(i)) + results.Distances = append(results.Distances, float32(i)) + } + node.PublishSearchResult(&results, 0) +} diff --git a/reader/segment_management_test.go b/reader/segment_management_test.go new file mode 100644 index 0000000000..f72b6f29f2 --- /dev/null +++ b/reader/segment_management_test.go @@ -0,0 +1,29 @@ +package reader + +import ( + "testing" +) + +func TestSegmentManagement_SegmentsManagement(t *testing.T) { + // Construct node, collection, partition and segment + node := NewQueryNode(0, 0) + var collection = node.NewCollection("collection0", "fake schema") + var partition = collection.NewPartition("partition0") + var segment = partition.NewSegment(0) + node.SegmentsMap[0] = segment + + // TODO: fix segment management + node.SegmentsManagement() +} + +func TestSegmentManagement_SegmentService(t *testing.T) { + // Construct node, collection, partition and segment + node := NewQueryNode(0, 0) + var collection = node.NewCollection("collection0", "fake schema") + var partition = collection.NewPartition("partition0") + var segment = partition.NewSegment(0) + node.SegmentsMap[0] = segment + + // TODO: fix segment service + node.SegmentService() +} diff --git a/reader/util_functions_test.go b/reader/util_functions_test.go index e2a6de0be3..b8ddefc6ed 100644 --- a/reader/util_functions_test.go +++ b/reader/util_functions_test.go @@ -25,10 +25,11 @@ func TestUtilFunctions_GetSegmentBySegmentID(t *testing.T) { node := NewQueryNode(0, 0) var collection = node.NewCollection("collection0", "fake schema") var partition = collection.NewPartition("partition0") - var _ = partition.NewSegment(0) + var segment = partition.NewSegment(0) + node.SegmentsMap[0] = segment // 2. Get segment by segment id var s0, err = node.GetSegmentBySegmentID(0) assert.NoError(t, err) - assert.Equal(t, s0.SegmentId, 0) + assert.Equal(t, s0.SegmentId, int64(0)) }