enhance: retrieve output fields after local reduce (#32346)

issue: #31822
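
Previously, the first-pass segment Retrieve materialized every output field for every hit, even though the local (segcore) reduce then discards most rows once a limit is applied. With this change, when a query touches more than one segment, carries a finite limit, and outputs more than the primary key, the first pass fills only the primary keys and row offsets; after the merge-sort reduce has picked and deduplicated the rows within the limit, a per-segment `RetrieveByOffsets` call fetches the remaining output fields only for the selected offsets. Along the way the cgo retrieve path gains an RAII `AutoSpan` tracing helper and swaps the raw `malloc` serialization buffer for a `std::unique_ptr`.

Below is a minimal, self-contained sketch of the reduce-then-fetch idea. The names (`firstPass`, `fetchByOffsets`, `reduceThenFetch`, `label`) are illustrative stand-ins, not the actual querynode API, and the real code additionally handles tracing, growing/sealed labels, and output-size limits.

```go
package main

import "fmt"

// firstPass mimics what a segment returns from the PK-only first phase:
// primary keys plus the segment-local row offsets of the hits.
type firstPass struct {
	pks     []int64
	offsets []int64
}

// fetchByOffsets stands in for Segment.RetrieveByOffsets: it returns the
// output-field values for the given segment-local offsets only.
type fetchByOffsets func(offsets []int64) []string

// reduceThenFetch merges the PK-only results under a global limit, remembers
// which row of which segment produced every kept entry, and only then pulls
// the full output fields for the kept rows.
func reduceThenFetch(results []firstPass, fetch []fetchByOffsets, limit int) []string {
	type pick struct{ seg, idx int }
	var picks []pick
	seen := map[int64]struct{}{}
	cursors := make([]int, len(results))
	for len(picks) < limit {
		sel := -1
		for i, r := range results { // merge-sort step: take the smallest unconsumed PK
			if cursors[i] < len(r.pks) && (sel == -1 || r.pks[cursors[i]] < results[sel].pks[cursors[sel]]) {
				sel = i
			}
		}
		if sel == -1 {
			break
		}
		pk := results[sel].pks[cursors[sel]]
		if _, dup := seen[pk]; !dup {
			seen[pk] = struct{}{}
			picks = append(picks, pick{sel, cursors[sel]})
		}
		cursors[sel]++ // advances past skipped duplicates too
	}

	// Second phase: group the kept rows per segment and fetch their fields by offset.
	perSeg := make([][]int64, len(results))
	for _, p := range picks {
		perSeg[p.seg] = append(perSeg[p.seg], results[p.seg].offsets[p.idx])
	}
	fetched := make([][]string, len(results))
	for i, offs := range perSeg {
		if len(offs) > 0 {
			fetched[i] = fetch[i](offs)
		}
	}

	// Stitch the fetched fields back together in reduced order.
	next := make([]int, len(results))
	out := make([]string, 0, len(picks))
	for _, p := range picks {
		out = append(out, fetched[p.seg][next[p.seg]])
		next[p.seg]++
	}
	return out
}

func main() {
	segs := []firstPass{
		{pks: []int64{1, 3, 5}, offsets: []int64{10, 11, 12}},
		{pks: []int64{2, 3, 6}, offsets: []int64{20, 21, 22}},
	}
	fetch := []fetchByOffsets{
		func(offs []int64) []string { return label("seg0", offs) },
		func(offs []int64) []string { return label("seg1", offs) },
	}
	fmt.Println(reduceThenFetch(segs, fetch, 4)) // [seg0/10 seg1/20 seg0/11 seg0/12]
}

func label(seg string, offs []int64) []string {
	out := make([]string, 0, len(offs))
	for _, o := range offs {
		out = append(out, fmt.Sprintf("%s/%d", seg, o))
	}
	return out
}
```

Running the sketch prints the rows in reduced order while each segment is asked for field data only once, and only for the offsets that survived the reduce.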

---------

Signed-off-by: longjiquan <jiquan.long@zilliz.com>
Jiquan Long, 2024-04-25 09:49:26 +08:00 (committed by GitHub)
parent 5119292411
commit c002745902
21 changed files with 441 additions and 57 deletions

View File

@ -67,4 +67,24 @@ GetTraceIDAsVector(const TraceContext* ctx);
std::vector<uint8_t>
GetSpanIDAsVector(const TraceContext* ctx);
struct AutoSpan {
explicit AutoSpan(const std::string& name,
TraceContext* ctx = nullptr,
bool is_root_span = false) {
span_ = StartSpan(name, ctx);
if (is_root_span) {
SetRootSpan(span_);
}
}
~AutoSpan() {
if (span_ != nullptr) {
span_->End();
}
}
private:
std::shared_ptr<trace::Span> span_;
};
} // namespace milvus::tracer

View File

@ -83,6 +83,14 @@ std::unique_ptr<proto::segcore::RetrieveResults>
SegmentInternalInterface::Retrieve(const query::RetrievePlan* plan,
Timestamp timestamp,
int64_t limit_size) const {
return Retrieve(plan, timestamp, limit_size, false);
}
std::unique_ptr<proto::segcore::RetrieveResults>
SegmentInternalInterface::Retrieve(const query::RetrievePlan* plan,
Timestamp timestamp,
int64_t limit_size,
bool ignore_non_pk) const {
std::shared_lock lck(mutex_);
auto results = std::make_unique<proto::segcore::RetrieveResults>();
query::ExecPlanNodeVisitor visitor(*this, timestamp);
@ -110,21 +118,39 @@ SegmentInternalInterface::Retrieve(const query::RetrievePlan* plan,
results->mutable_offset()->Add(retrieve_results.result_offsets_.begin(),
retrieve_results.result_offsets_.end());
FillTargetEntry(plan,
results,
retrieve_results.result_offsets_.data(),
retrieve_results.result_offsets_.size(),
ignore_non_pk,
true);
return results;
}
void
SegmentInternalInterface::FillTargetEntry(
const query::RetrievePlan* plan,
const std::unique_ptr<proto::segcore::RetrieveResults>& results,
const int64_t* offsets,
int64_t size,
bool ignore_non_pk,
bool fill_ids) const {
auto fields_data = results->mutable_fields_data();
auto ids = results->mutable_ids();
auto pk_field_id = plan->schema_.get_primary_field_id();
auto is_pk_field = [&, pk_field_id](const FieldId& field_id) -> bool {
return pk_field_id.has_value() && pk_field_id.value() == field_id;
};
for (auto field_id : plan->field_ids_) {
if (SystemProperty::Instance().IsSystem(field_id)) {
auto system_type =
SystemProperty::Instance().GetSystemFieldType(field_id);
auto size = retrieve_results.result_offsets_.size();
FixedVector<int64_t> output(size);
bulk_subscript(system_type,
retrieve_results.result_offsets_.data(),
size,
output.data());
bulk_subscript(system_type, offsets, size, output.data());
auto data_array = std::make_unique<DataArray>();
data_array->set_field_id(field_id.get());
@ -138,18 +164,21 @@ SegmentInternalInterface::Retrieve(const query::RetrievePlan* plan,
continue;
}
if (ignore_non_pk && !is_pk_field(field_id)) {
continue;
}
auto& field_meta = plan->schema_[field_id];
auto col = bulk_subscript(field_id,
retrieve_results.result_offsets_.data(),
retrieve_results.result_offsets_.size());
auto col = bulk_subscript(field_id, offsets, size);
if (field_meta.get_data_type() == DataType::ARRAY) {
col->mutable_scalars()->mutable_array_data()->set_element_type(
proto::schema::DataType(field_meta.get_element_type()));
}
auto col_data = col.release();
fields_data->AddAllocated(col_data);
if (pk_field_id.has_value() && pk_field_id.value() == field_id) {
if (fill_ids && is_pk_field(field_id)) {
// fill_ids should be true when the first-phase Retrieve is called; the reduce phase depends on the ids
// to do the merge-sort.
auto col_data = col.get();
switch (field_meta.get_data_type()) {
case DataType::INT64: {
auto int_ids = ids->mutable_int_id();
@ -173,7 +202,25 @@ SegmentInternalInterface::Retrieve(const query::RetrievePlan* plan,
}
}
}
if (!ignore_non_pk) {
// ignore_non_pk == false covers two situations:
// 1. Two-phase retrieval is not needed and the target entries should be returned by the first
// retrieval; this includes two cases:
// a. there is only one segment;
// b. no pagination (limit) is used;
// 2. FillTargetEntry was called by the second retrieval (by offsets).
fields_data->AddAllocated(col.release());
}
}
}
std::unique_ptr<proto::segcore::RetrieveResults>
SegmentInternalInterface::Retrieve(const query::RetrievePlan* Plan,
const int64_t* offsets,
int64_t size) const {
std::shared_lock lck(mutex_);
auto results = std::make_unique<proto::segcore::RetrieveResults>();
FillTargetEntry(Plan, results, offsets, size, false, false);
return results;
}

View File

@ -69,6 +69,17 @@ class SegmentInterface {
Timestamp timestamp,
int64_t limit_size) const = 0;
virtual std::unique_ptr<proto::segcore::RetrieveResults>
Retrieve(const query::RetrievePlan* Plan,
Timestamp timestamp,
int64_t limit_size,
bool ignore_non_pk) const = 0;
virtual std::unique_ptr<proto::segcore::RetrieveResults>
Retrieve(const query::RetrievePlan* Plan,
const int64_t* offsets,
int64_t size) const = 0;
virtual size_t
GetMemoryUsageInBytes() const = 0;
@ -159,6 +170,17 @@ class SegmentInternalInterface : public SegmentInterface {
Timestamp timestamp,
int64_t limit_size) const override;
std::unique_ptr<proto::segcore::RetrieveResults>
Retrieve(const query::RetrievePlan* Plan,
Timestamp timestamp,
int64_t limit_size,
bool ignore_non_pk) const override;
std::unique_ptr<proto::segcore::RetrieveResults>
Retrieve(const query::RetrievePlan* Plan,
const int64_t* offsets,
int64_t size) const override;
virtual bool
HasIndex(FieldId field_id) const = 0;
@ -279,6 +301,15 @@ class SegmentInternalInterface : public SegmentInterface {
const BitsetType& bitset,
bool false_filtered_out) const = 0;
void
FillTargetEntry(
const query::RetrievePlan* plan,
const std::unique_ptr<proto::segcore::RetrieveResults>& results,
const int64_t* offsets,
int64_t size,
bool ignore_non_pk,
bool fill_ids) const;
protected:
// internal API: return chunk_data in span
virtual SpanBase

View File

@ -170,3 +170,13 @@ DeleteRetrievePlan(CRetrievePlan c_plan) {
auto plan = static_cast<milvus::query::RetrievePlan*>(c_plan);
delete plan;
}
bool
ShouldIgnoreNonPk(CRetrievePlan c_plan) {
auto plan = static_cast<milvus::query::RetrievePlan*>(c_plan);
auto pk_field = plan->schema_.get_primary_field_id();
auto only_contain_pk = pk_field.has_value() &&
plan->field_ids_.size() == 1 &&
pk_field.value() == plan->field_ids_[0];
return !only_contain_pk;
}

View File

@ -68,6 +68,9 @@ CreateRetrievePlanByExpr(CCollection c_col,
void
DeleteRetrievePlan(CRetrievePlan plan);
bool
ShouldIgnoreNonPk(CRetrievePlan plan);
#ifdef __cplusplus
}
#endif

View File

@ -130,7 +130,8 @@ Retrieve(CTraceContext c_trace,
CRetrievePlan c_plan,
uint64_t timestamp,
CRetrieveResult* result,
int64_t limit_size) {
int64_t limit_size,
bool ignore_non_pk) {
try {
auto segment =
static_cast<milvus::segcore::SegmentInterface*>(c_segment);
@ -138,19 +139,50 @@ Retrieve(CTraceContext c_trace,
auto trace_ctx = milvus::tracer::TraceContext{
c_trace.traceID, c_trace.spanID, c_trace.traceFlags};
auto span = milvus::tracer::StartSpan("SegCoreRetrieve", &trace_ctx);
milvus::tracer::SetRootSpan(span);
milvus::tracer::AutoSpan span("SegCoreRetrieve", &trace_ctx, true);
auto retrieve_result = segment->Retrieve(plan, timestamp, limit_size);
auto retrieve_result =
segment->Retrieve(plan, timestamp, limit_size, ignore_non_pk);
auto size = retrieve_result->ByteSizeLong();
void* buffer = malloc(size);
retrieve_result->SerializePartialToArray(buffer, size);
std::unique_ptr<uint8_t[]> buffer(new uint8_t[size]);
retrieve_result->SerializePartialToArray(buffer.get(), size);
result->proto_blob = buffer;
result->proto_blob = buffer.release();
result->proto_size = size;
return milvus::SuccessCStatus();
} catch (std::exception& e) {
return milvus::FailureCStatus(&e);
}
}
CStatus
RetrieveByOffsets(CTraceContext c_trace,
CSegmentInterface c_segment,
CRetrievePlan c_plan,
CRetrieveResult* result,
int64_t* offsets,
int64_t len) {
try {
auto segment =
static_cast<milvus::segcore::SegmentInterface*>(c_segment);
auto plan = static_cast<const milvus::query::RetrievePlan*>(c_plan);
auto trace_ctx = milvus::tracer::TraceContext{
c_trace.traceID, c_trace.spanID, c_trace.traceFlags};
milvus::tracer::AutoSpan span(
"SegCoreRetrieveByOffsets", &trace_ctx, true);
auto retrieve_result = segment->Retrieve(plan, offsets, len);
auto size = retrieve_result->ByteSizeLong();
std::unique_ptr<uint8_t[]> buffer(new uint8_t[size]);
retrieve_result->SerializePartialToArray(buffer.get(), size);
result->proto_blob = buffer.release();
result->proto_size = size;
span->End();
return milvus::SuccessCStatus();
} catch (std::exception& e) {
return milvus::FailureCStatus(&e);

View File

@ -60,7 +60,16 @@ Retrieve(CTraceContext c_trace,
CRetrievePlan c_plan,
uint64_t timestamp,
CRetrieveResult* result,
int64_t limit_size);
int64_t limit_size,
bool ignore_non_pk);
CStatus
RetrieveByOffsets(CTraceContext c_trace,
CSegmentInterface c_segment,
CRetrievePlan c_plan,
CRetrieveResult* result,
int64_t* offsets,
int64_t len);
int64_t
GetMemoryUsageInBytes(CSegmentInterface c_segment);

View File

@ -74,8 +74,13 @@ CRetrieve(CSegmentInterface c_segment,
CRetrievePlan c_plan,
uint64_t timestamp,
CRetrieveResult* result) {
return Retrieve(
{}, c_segment, c_plan, timestamp, result, DEFAULT_MAX_OUTPUT_SIZE);
return Retrieve({},
c_segment,
c_plan,
timestamp,
result,
DEFAULT_MAX_OUTPUT_SIZE,
false);
}
const char*

View File

@ -33,7 +33,7 @@ func (r *cntReducer) Reduce(ctx context.Context, results []*internalpb.RetrieveR
type cntReducerSegCore struct{}
func (r *cntReducerSegCore) Reduce(ctx context.Context, results []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error) {
func (r *cntReducerSegCore) Reduce(ctx context.Context, results []*segcorepb.RetrieveResults, _ []Segment, _ *RetrievePlan) (*segcorepb.RetrieveResults, error) {
cnt := int64(0)
allRetrieveCount := int64(0)
for _, res := range results {

View File

@ -76,7 +76,7 @@ func (suite *SegCoreCntReducerSuite) TestInvalid() {
},
}
_, err := suite.r.Reduce(context.TODO(), results)
_, err := suite.r.Reduce(context.TODO(), results, nil, nil)
suite.Error(err)
}
@ -88,7 +88,7 @@ func (suite *SegCoreCntReducerSuite) TestNormalCase() {
funcutil.WrapCntToSegCoreResult(4),
}
res, err := suite.r.Reduce(context.TODO(), results)
res, err := suite.r.Reduce(context.TODO(), results, nil, nil)
suite.NoError(err)
total, err := funcutil.CntOfSegCoreResult(res)

View File

@ -48,9 +48,9 @@ type defaultLimitReducerSegcore struct {
schema *schemapb.CollectionSchema
}
func (r *defaultLimitReducerSegcore) Reduce(ctx context.Context, results []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error) {
func (r *defaultLimitReducerSegcore) Reduce(ctx context.Context, results []*segcorepb.RetrieveResults, segments []Segment, plan *RetrievePlan) (*segcorepb.RetrieveResults, error) {
mergeParam := NewMergeParam(r.req.GetReq().GetLimit(), r.req.GetReq().GetOutputFieldsId(), r.schema, r.req.GetReq().GetReduceStopForBest())
return mergeSegcoreRetrieveResultsAndFillIfEmpty(ctx, results, mergeParam)
return mergeSegcoreRetrieveResultsAndFillIfEmpty(ctx, results, mergeParam, segments, plan)
}
func newDefaultLimitReducerSegcore(req *querypb.QueryRequest, schema *schemapb.CollectionSchema) *defaultLimitReducerSegcore {

View File

@ -1163,6 +1163,62 @@ func (_c *MockSegment_Retrieve_Call) RunAndReturn(run func(context.Context, *Ret
return _c
}
// RetrieveByOffsets provides a mock function with given fields: ctx, plan, offsets
func (_m *MockSegment) RetrieveByOffsets(ctx context.Context, plan *RetrievePlan, offsets []int64) (*segcorepb.RetrieveResults, error) {
ret := _m.Called(ctx, plan, offsets)
var r0 *segcorepb.RetrieveResults
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *RetrievePlan, []int64) (*segcorepb.RetrieveResults, error)); ok {
return rf(ctx, plan, offsets)
}
if rf, ok := ret.Get(0).(func(context.Context, *RetrievePlan, []int64) *segcorepb.RetrieveResults); ok {
r0 = rf(ctx, plan, offsets)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*segcorepb.RetrieveResults)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *RetrievePlan, []int64) error); ok {
r1 = rf(ctx, plan, offsets)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockSegment_RetrieveByOffsets_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveByOffsets'
type MockSegment_RetrieveByOffsets_Call struct {
*mock.Call
}
// RetrieveByOffsets is a helper method to define mock.On call
// - ctx context.Context
// - plan *RetrievePlan
// - offsets []int64
func (_e *MockSegment_Expecter) RetrieveByOffsets(ctx interface{}, plan interface{}, offsets interface{}) *MockSegment_RetrieveByOffsets_Call {
return &MockSegment_RetrieveByOffsets_Call{Call: _e.mock.On("RetrieveByOffsets", ctx, plan, offsets)}
}
func (_c *MockSegment_RetrieveByOffsets_Call) Run(run func(ctx context.Context, plan *RetrievePlan, offsets []int64)) *MockSegment_RetrieveByOffsets_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*RetrievePlan), args[2].([]int64))
})
return _c
}
func (_c *MockSegment_RetrieveByOffsets_Call) Return(_a0 *segcorepb.RetrieveResults, _a1 error) *MockSegment_RetrieveByOffsets_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockSegment_RetrieveByOffsets_Call) RunAndReturn(run func(context.Context, *RetrievePlan, []int64) (*segcorepb.RetrieveResults, error)) *MockSegment_RetrieveByOffsets_Call {
_c.Call.Return(run)
return _c
}
// RowNum provides a mock function with given fields:
func (_m *MockSegment) RowNum() int64 {
ret := _m.Called()

View File

@ -172,6 +172,7 @@ type RetrievePlan struct {
cRetrievePlan C.CRetrievePlan
Timestamp Timestamp
msgID UniqueID // only used to debug.
ignoreNonPk bool
}
func NewRetrievePlan(ctx context.Context, col *Collection, expr []byte, timestamp Timestamp, msgID UniqueID) (*RetrievePlan, error) {
@ -198,6 +199,10 @@ func NewRetrievePlan(ctx context.Context, col *Collection, expr []byte, timestam
return newPlan, nil
}
func (plan *RetrievePlan) ShouldIgnoreNonPk() bool {
return bool(C.ShouldIgnoreNonPk(plan.cRetrievePlan))
}
func (plan *RetrievePlan) Delete() {
C.DeleteRetrievePlan(plan.cRetrievePlan)
}

View File

@ -21,7 +21,7 @@ func CreateInternalReducer(req *querypb.QueryRequest, schema *schemapb.Collectio
}
type segCoreReducer interface {
Reduce(context.Context, []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error)
Reduce(context.Context, []*segcorepb.RetrieveResults, []Segment, *RetrievePlan) (*segcorepb.RetrieveResults, error)
}
func CreateSegCoreReducer(req *querypb.QueryRequest, schema *schemapb.CollectionSchema) segCoreReducer {

View File

@ -23,6 +23,7 @@ import (
"github.com/golang/protobuf/proto"
"github.com/samber/lo"
"go.opentelemetry.io/otel"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
@ -31,6 +32,7 @@ import (
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/conc"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
@ -475,7 +477,10 @@ func getTS(i *internalpb.RetrieveResults, idx int64) uint64 {
return 0
}
func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcorepb.RetrieveResults, param *mergeParam) (*segcorepb.RetrieveResults, error) {
func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcorepb.RetrieveResults, param *mergeParam, segments []Segment, plan *RetrievePlan) (*segcorepb.RetrieveResults, error) {
ctx, span := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "MergeSegcoreResults")
defer span.End()
log.Ctx(ctx).Debug("mergeSegcoreRetrieveResults",
zap.Int64("limit", param.limit),
zap.Int("resultNum", len(retrieveResults)),
@ -490,7 +495,10 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
)
validRetrieveResults := []*segcorepb.RetrieveResults{}
for _, r := range retrieveResults {
validSegments := make([]Segment, 0, len(segments))
selectedOffsets := make([][]int64, 0, len(retrieveResults))
selectedIndexes := make([][]int64, 0, len(retrieveResults))
for i, r := range retrieveResults {
size := typeutil.GetSizeOfIDs(r.GetIds())
ret.AllRetrieveCount += r.GetAllRetrieveCount()
if r == nil || len(r.GetOffset()) == 0 || size == 0 {
@ -498,6 +506,11 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
continue
}
validRetrieveResults = append(validRetrieveResults, r)
if plan.ignoreNonPk {
validSegments = append(validSegments, segments[i])
}
selectedOffsets = append(selectedOffsets, make([]int64, 0, len(r.GetOffset())))
selectedIndexes = append(selectedIndexes, make([]int64, 0, len(r.GetOffset())))
loopEnd += size
}
@ -505,11 +518,12 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
return ret, nil
}
selected := make([]int, 0, ret.GetAllRetrieveCount())
if param.limit != typeutil.Unlimited && !param.mergeStopForBest {
loopEnd = int(param.limit)
}
ret.FieldsData = make([]*schemapb.FieldData, len(validRetrieveResults[0].GetFieldsData()))
idSet := make(map[interface{}]struct{})
cursors := make([]int64, len(validRetrieveResults))
@ -524,18 +538,15 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
pk := typeutil.GetPK(validRetrieveResults[sel].GetIds(), cursors[sel])
if _, ok := idSet[pk]; !ok {
typeutil.AppendPKs(ret.Ids, pk)
retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel])
selected = append(selected, sel)
selectedOffsets[sel] = append(selectedOffsets[sel], validRetrieveResults[sel].GetOffset()[cursors[sel]])
selectedIndexes[sel] = append(selectedIndexes[sel], cursors[sel])
idSet[pk] = struct{}{}
} else {
// primary keys duplicate
skipDupCnt++
}
// limit retrieve result to avoid oom
if retSize > maxOutputSize {
return nil, fmt.Errorf("query results exceed the maxOutputSize Limit %d", maxOutputSize)
}
cursors[sel]++
}
@ -543,6 +554,72 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
log.Debug("skip duplicated query result while reducing segcore.RetrieveResults", zap.Int64("dupCount", skipDupCnt))
}
if !plan.ignoreNonPk {
// The target entries were already retrieved in the first phase. Keep this in its own loop rather than doing it
// right after AppendPKs, so the `!plan.ignoreNonPk` condition is not re-evaluated on every iteration.
_, span2 := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "MergeSegcoreResults-AppendFieldData")
defer span2.End()
ret.FieldsData = make([]*schemapb.FieldData, len(validRetrieveResults[0].GetFieldsData()))
cursors = make([]int64, len(validRetrieveResults))
for _, sel := range selected {
// cannot use `cursors[sel]` directly, since some of them may be skipped.
retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), selectedIndexes[sel][cursors[sel]])
// limit retrieve result to avoid oom
if retSize > maxOutputSize {
return nil, fmt.Errorf("query results exceed the maxOutputSize Limit %d", maxOutputSize)
}
cursors[sel]++
}
} else {
// The target entries were not retrieved in the first phase; fetch them now from each segment by offsets.
ctx, span2 := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "MergeSegcoreResults-RetrieveByOffsets-AppendFieldData")
defer span2.End()
segmentResults := make([]*segcorepb.RetrieveResults, len(validRetrieveResults))
futures := make([]*conc.Future[any], 0, len(validRetrieveResults))
for i, offsets := range selectedOffsets {
if len(offsets) == 0 {
log.Ctx(ctx).Debug("skip empty retrieve results", zap.Int64("segment", validSegments[i].ID()))
continue
}
idx, theOffsets := i, offsets
future := GetSQPool().Submit(func() (any, error) {
r, err := validSegments[idx].RetrieveByOffsets(ctx, plan, theOffsets)
if err != nil {
return nil, err
}
segmentResults[idx] = r
return nil, nil
})
futures = append(futures, future)
}
if err := conc.AwaitAll(futures...); err != nil {
return nil, err
}
for _, r := range segmentResults {
if len(r.GetFieldsData()) != 0 {
ret.FieldsData = make([]*schemapb.FieldData, len(r.GetFieldsData()))
break
}
}
_, span3 := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "MergeSegcoreResults-AppendFieldData")
defer span3.End()
cursors = make([]int64, len(segmentResults))
for _, sel := range selected {
retSize += typeutil.AppendFieldData(ret.FieldsData, segmentResults[sel].GetFieldsData(), cursors[sel])
// limit retrieve result to avoid oom
if retSize > maxOutputSize {
return nil, fmt.Errorf("query results exceed the maxOutputSize Limit %d", maxOutputSize)
}
cursors[sel]++
}
}
return ret, nil
}
@ -567,8 +644,10 @@ func mergeSegcoreRetrieveResultsAndFillIfEmpty(
ctx context.Context,
retrieveResults []*segcorepb.RetrieveResults,
param *mergeParam,
segments []Segment,
plan *RetrievePlan,
) (*segcorepb.RetrieveResults, error) {
mergedResult, err := MergeSegcoreRetrieveResults(ctx, retrieveResults, param)
mergedResult, err := MergeSegcoreRetrieveResults(ctx, retrieveResults, param, segments, plan)
if err != nil {
return nil, err
}
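
Note on the bookkeeping above: `cursors[sel]` is advanced for every row taken off a segment's first-pass result, including rows skipped as duplicate primary keys, so it cannot later be used to index that segment's field data or to drive the second-phase fetch. `selectedIndexes` keeps only the positions that survived deduplication (used by `AppendFieldData` in the non-deferred branch), while `selectedOffsets` keeps the matching segment-local offsets handed to `RetrieveByOffsets`. A small standalone illustration of the difference (plain Go with simplified types, not the actual code):

```go
package main

import "fmt"

func main() {
	// One segment's first-pass result: primary keys and their segment-local offsets.
	pks := []int64{7, 7, 9}
	offsets := []int64{3, 4, 5}

	seen := map[int64]struct{}{}
	var selectedIndexes, selectedOffsets []int64
	cursor := int64(0)
	for i, pk := range pks {
		if _, dup := seen[pk]; !dup {
			seen[pk] = struct{}{}
			selectedIndexes = append(selectedIndexes, int64(i))   // position inside this result
			selectedOffsets = append(selectedOffsets, offsets[i]) // offset for RetrieveByOffsets
		}
		cursor++ // advances past the skipped duplicate as well
	}
	// cursor ends at 3, but only positions 0 and 2 (offsets 3 and 5) were kept.
	fmt.Println(cursor, selectedIndexes, selectedOffsets) // 3 [0 2] [3 5]
}
```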

View File

@ -37,6 +37,11 @@ type ResultSuite struct {
suite.Suite
}
func MergeSegcoreRetrieveResultsV1(ctx context.Context, retrieveResults []*segcorepb.RetrieveResults, param *mergeParam) (*segcorepb.RetrieveResults, error) {
plan := &RetrievePlan{ignoreNonPk: false}
return MergeSegcoreRetrieveResults(ctx, retrieveResults, param, nil, plan)
}
func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() {
const (
Dim = 8
@ -80,7 +85,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() {
FieldsData: fieldDataArray2,
}
result, err := MergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{result1, result2},
result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result1, result2},
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
suite.NoError(err)
suite.Equal(2, len(result.GetFieldsData()))
@ -90,7 +95,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() {
})
suite.Run("test nil results", func() {
ret, err := MergeSegcoreRetrieveResults(context.Background(), nil,
ret, err := MergeSegcoreRetrieveResultsV1(context.Background(), nil,
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
suite.NoError(err)
suite.Empty(ret.GetIds())
@ -109,7 +114,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() {
FieldsData: fieldDataArray1,
}
ret, err := MergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{r},
ret, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{r},
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
suite.NoError(err)
suite.Empty(ret.GetIds())
@ -161,7 +166,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() {
resultField0 := []int64{11, 11, 22, 22}
for _, test := range tests {
suite.Run(test.description, func() {
result, err := MergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{r1, r2},
result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{r1, r2},
NewMergeParam(test.limit, make([]int64, 0), nil, false))
suite.Equal(2, len(result.GetFieldsData()))
suite.Equal(int(test.limit), len(result.GetIds().GetIntId().GetData()))
@ -197,14 +202,14 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() {
FieldsData: []*schemapb.FieldData{fieldData},
}
_, err := MergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{result},
_, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result},
NewMergeParam(reqLimit, make([]int64, 0), nil, false))
suite.Error(err)
paramtable.Get().Save(paramtable.Get().QuotaConfig.MaxOutputSize.Key, "1104857600")
})
suite.Run("test int ID", func() {
result, err := MergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{r1, r2},
result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{r1, r2},
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
suite.Equal(2, len(result.GetFieldsData()))
suite.Equal([]int64{1, 2, 3, 4}, result.GetIds().GetIntId().GetData())
@ -230,7 +235,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() {
},
}
result, err := MergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{r1, r2},
result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{r1, r2},
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
suite.NoError(err)
suite.Equal(2, len(result.GetFieldsData()))
@ -508,7 +513,7 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() {
FieldsData: fieldDataArray2,
}
suite.Run("merge stop finite limited", func() {
result, err := MergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{result1, result2},
result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result1, result2},
NewMergeParam(3, make([]int64, 0), nil, true))
suite.NoError(err)
suite.Equal(2, len(result.GetFieldsData()))
@ -520,7 +525,7 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() {
result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10)
})
suite.Run("merge stop unlimited", func() {
result, err := MergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{result1, result2},
result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result1, result2},
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, true))
suite.NoError(err)
suite.Equal(2, len(result.GetFieldsData()))

View File

@ -34,17 +34,24 @@ import (
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
// retrieveOnSegments performs retrieve on listed segments
// all segment ids are validated before calling this function
func retrieveOnSegments(ctx context.Context, mgr *Manager, segments []Segment, segType SegmentType, plan *RetrievePlan) ([]*segcorepb.RetrieveResults, error) {
func retrieveOnSegments(ctx context.Context, mgr *Manager, segments []Segment, segType SegmentType, plan *RetrievePlan, req *querypb.QueryRequest) ([]*segcorepb.RetrieveResults, []Segment, error) {
type segmentResult struct {
result *segcorepb.RetrieveResults
segment Segment
}
var (
resultCh = make(chan *segcorepb.RetrieveResults, len(segments))
resultCh = make(chan segmentResult, len(segments))
errs = make([]error, len(segments))
wg sync.WaitGroup
)
plan.ignoreNonPk = len(segments) > 1 && req.GetReq().GetLimit() != typeutil.Unlimited && plan.ShouldIgnoreNonPk()
label := metrics.SealedSegmentLabel
if segType == commonpb.SegmentState_Growing {
label = metrics.GrowingSegmentLabel
@ -53,7 +60,10 @@ func retrieveOnSegments(ctx context.Context, mgr *Manager, segments []Segment, s
retriever := func(s Segment) error {
tr := timerecord.NewTimeRecorder("retrieveOnSegments")
result, err := s.Retrieve(ctx, plan)
resultCh <- result
resultCh <- segmentResult{
result,
s,
}
if err != nil {
return err
}
@ -92,16 +102,18 @@ func retrieveOnSegments(ctx context.Context, mgr *Manager, segments []Segment, s
for _, err := range errs {
if err != nil {
return nil, err
return nil, nil, err
}
}
var retrieveSegments []Segment
var retrieveResults []*segcorepb.RetrieveResults
for result := range resultCh {
retrieveResults = append(retrieveResults, result)
retrieveSegments = append(retrieveSegments, result.segment)
retrieveResults = append(retrieveResults, result.result)
}
return retrieveResults, nil
return retrieveResults, retrieveSegments, nil
}
func retrieveOnSegmentsWithStream(ctx context.Context, segments []Segment, segType SegmentType, plan *RetrievePlan, svr streamrpc.QueryStreamServer) error {
@ -172,8 +184,7 @@ func Retrieve(ctx context.Context, manager *Manager, plan *RetrievePlan, req *qu
return retrieveResults, retrieveSegments, err
}
retrieveResults, err = retrieveOnSegments(ctx, manager, retrieveSegments, SegType, plan)
return retrieveResults, retrieveSegments, err
return retrieveOnSegments(ctx, manager, retrieveSegments, SegType, plan, req)
}
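
The gate on two-phase retrieval is the `plan.ignoreNonPk` assignment above: deferring non-PK fields only pays off when more than one segment is involved (otherwise there is no cross-segment reduce), the request carries a finite limit (the saving comes from rows dropped during the reduce), and the plan outputs something besides the primary key (`ShouldIgnoreNonPk`). A hedged restatement as a standalone helper; `shouldUseTwoPhase` and the `unlimited` constant are illustrative, the real code checks `typeutil.Unlimited` and the C-side `ShouldIgnoreNonPk`:

```go
package main

import "fmt"

// unlimited mirrors typeutil.Unlimited; -1 is an assumption for this sketch.
const unlimited int64 = -1

func shouldUseTwoPhase(segmentCount int, limit int64, planOutputsNonPk bool) bool {
	return segmentCount > 1 && // a single segment needs no cross-segment reduce
		limit != unlimited && // the saving comes from rows dropped by the limit during reduce
		planOutputsNonPk // a PK-only plan has no extra fields to fetch later
}

func main() {
	fmt.Println(shouldUseTwoPhase(4, 100, true))       // true: deferring field retrieval pays off
	fmt.Println(shouldUseTwoPhase(1, 100, true))       // false: only one segment
	fmt.Println(shouldUseTwoPhase(4, unlimited, true)) // false: an unlimited query keeps every row anyway
}
```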
// retrieveStreaming will retrieve all the validate target segments and return by stream

View File

@ -533,6 +533,7 @@ func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segco
zap.Int64("msgID", plan.msgID),
zap.String("segmentType", s.segmentType.String()),
)
log.Debug("begin to retrieve")
traceCtx := ParseCTraceContext(ctx)
@ -547,7 +548,8 @@ func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segco
plan.cRetrievePlan,
ts,
&retrieveResult.cRetrieveResult,
C.int64_t(maxLimitSize))
C.int64_t(maxLimitSize),
C.bool(plan.ignoreNonPk))
metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
@ -564,6 +566,9 @@ func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segco
return nil, err
}
_, span := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "partial-segcore-results-deserialization")
defer span.End()
result := new(segcorepb.RetrieveResults)
if err := HandleCProto(&retrieveResult.cRetrieveResult, result); err != nil {
return nil, err
@ -578,6 +583,67 @@ func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segco
return result, nil
}
func (s *LocalSegment) RetrieveByOffsets(ctx context.Context, plan *RetrievePlan, offsets []int64) (*segcorepb.RetrieveResults, error) {
if !s.ptrLock.RLockIf(state.IsNotReleased) {
// TODO: check whether the segment is readable but not released; too much related logic needs to be refactored.
return nil, merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
}
defer s.ptrLock.RUnlock()
if s.ptr == nil {
return nil, merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
}
if len(offsets) == 0 {
return nil, merr.WrapErrParameterInvalid("segment offsets", "empty offsets")
}
fields := []zap.Field{
zap.Int64("collectionID", s.Collection()),
zap.Int64("partitionID", s.Partition()),
zap.Int64("segmentID", s.ID()),
zap.Int64("msgID", plan.msgID),
zap.String("segmentType", s.segmentType.String()),
zap.Int("resultNum", len(offsets)),
}
log := log.Ctx(ctx).With(fields...)
log.Debug("begin to retrieve by offsets")
traceCtx := ParseCTraceContext(ctx)
var retrieveResult RetrieveResult
var status C.CStatus
tr := timerecord.NewTimeRecorder("cgoRetrieveByOffsets")
status = C.RetrieveByOffsets(traceCtx,
s.ptr,
plan.cRetrievePlan,
&retrieveResult.cRetrieveResult,
(*C.int64_t)(unsafe.Pointer(&offsets[0])),
C.int64_t(len(offsets)))
metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
log.Debug("cgo retrieve by offsets done", zap.Duration("timeTaken", tr.ElapseSpan()))
if err := HandleCStatus(ctx, &status, "RetrieveByOffsets failed", fields...); err != nil {
return nil, err
}
_, span := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "reduced-segcore-results-deserialization")
defer span.End()
result := new(segcorepb.RetrieveResults)
if err := HandleCProto(&retrieveResult.cRetrieveResult, result); err != nil {
return nil, err
}
log.Debug("retrieve by segment offsets done")
return result, nil
}
func (s *LocalSegment) GetFieldDataPath(index *IndexedFieldInfo, offset int64) (dataPath string, offsetInBinlog int64) {
offsetInBinlog = offset
for _, binlog := range index.FieldBinlog.Binlogs {

View File

@ -96,6 +96,7 @@ type Segment interface {
// Read operations
Search(ctx context.Context, searchReq *SearchRequest) (*SearchResult, error)
Retrieve(ctx context.Context, plan *RetrievePlan) (*segcorepb.RetrieveResults, error)
RetrieveByOffsets(ctx context.Context, plan *RetrievePlan, offsets []int64) (*segcorepb.RetrieveResults, error)
IsLazyLoad() bool
ResetIndexesLazyLoad(lazyState bool)
}

View File

@ -136,6 +136,10 @@ func (s *L0Segment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segcorep
return nil, nil
}
func (s *L0Segment) RetrieveByOffsets(ctx context.Context, plan *RetrievePlan, offsets []int64) (*segcorepb.RetrieveResults, error) {
return nil, nil
}
func (s *L0Segment) Insert(ctx context.Context, rowIDs []int64, timestamps []typeutil.Timestamp, record *segcorepb.InsertRecord) error {
return merr.WrapErrIoFailedReason("insert not supported for L0 segment")
}

View File

@ -120,7 +120,7 @@ func (t *QueryTask) Execute() error {
t.collection.Schema(),
)
beforeReduce := time.Now()
reducedResult, err := reducer.Reduce(t.ctx, results)
reducedResult, err := reducer.Reduce(t.ctx, results, querySegments, retrievePlan)
metrics.QueryNodeReduceLatency.WithLabelValues(
fmt.Sprint(paramtable.GetNodeID()),