diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index c45a1ea2b3..80569240d0 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -546,16 +546,14 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re idSet := make(map[interface{}]struct{}) cursors := make([]int64, len(validRetrieveResults)) - realLimit := typeutil.Unlimited if queryParams != nil && queryParams.limit != typeutil.Unlimited { - realLimit = queryParams.limit if !queryParams.reduceStopForBest { loopEnd = int(queryParams.limit) } if queryParams.offset > 0 { for i := int64(0); i < queryParams.offset; i++ { - sel := typeutil.SelectMinPK(validRetrieveResults, cursors, queryParams.reduceStopForBest, realLimit) - if sel == -1 { + sel, drainOneResult := typeutil.SelectMinPK(validRetrieveResults, cursors) + if sel == -1 || (queryParams.reduceStopForBest && drainOneResult) { return ret, nil } cursors[sel]++ @@ -570,8 +568,8 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re var retSize int64 maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() for j := 0; j < loopEnd; j++ { - sel := typeutil.SelectMinPK(validRetrieveResults, cursors, reduceStopForBest, realLimit) - if sel == -1 { + sel, drainOneResult := typeutil.SelectMinPK(validRetrieveResults, cursors) + if sel == -1 || (reduceStopForBest && drainOneResult) { break } diff --git a/internal/proxy/task_query_test.go b/internal/proxy/task_query_test.go index f3e56080c1..266e2830d7 100644 --- a/internal/proxy/task_query_test.go +++ b/internal/proxy/task_query_test.go @@ -600,7 +600,7 @@ func TestTaskQuery_functions(t *testing.T) { &queryParams{limit: typeutil.Unlimited, reduceStopForBest: true}) assert.NoError(t, err) assert.Equal(t, 2, len(result.GetFieldsData())) - assert.Equal(t, []int64{11, 11, 22, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) + assert.Equal(t, []int64{11, 11, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) len := len(result.GetFieldsData()[0].GetScalars().GetLongData().Data) assert.InDeltaSlice(t, resultFloat[0:(len)*Dim], result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10) }) diff --git a/internal/querynodev2/segments/result.go b/internal/querynodev2/segments/result.go index 020e6c91b9..390f104b13 100644 --- a/internal/querynodev2/segments/result.go +++ b/internal/querynodev2/segments/result.go @@ -282,8 +282,8 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna var retSize int64 maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() for j := 0; j < loopEnd; { - sel := typeutil.SelectMinPK(validRetrieveResults, cursors, param.mergeStopForBest, param.limit) - if sel == -1 { + sel, drainOneResult := typeutil.SelectMinPK(validRetrieveResults, cursors) + if sel == -1 || (param.mergeStopForBest && drainOneResult) { break } @@ -386,8 +386,8 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore var retSize int64 maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() for j := 0; j < loopEnd; j++ { - sel := typeutil.SelectMinPK(validRetrieveResults, cursors, param.mergeStopForBest, param.limit) - if sel == -1 { + sel, drainOneResult := typeutil.SelectMinPK(validRetrieveResults, cursors) + if sel == -1 || (param.mergeStopForBest && drainOneResult) { break } diff --git a/internal/querynodev2/segments/result_test.go b/internal/querynodev2/segments/result_test.go index 80e7a38e6f..4965bbd463 100644 --- a/internal/querynodev2/segments/result_test.go +++ b/internal/querynodev2/segments/result_test.go @@ -524,10 +524,10 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, true)) suite.NoError(err) suite.Equal(2, len(result.GetFieldsData())) - suite.Equal([]int64{0, 1, 2, 3, 4, 6}, result.GetIds().GetIntId().GetData()) + suite.Equal([]int64{0, 1, 2, 3, 4}, result.GetIds().GetIntId().GetData()) // here, we can only get best result from 0 to 4 without 6, because we can never know whether there is // one potential 5 in following result1 - suite.Equal([]int64{11, 22, 11, 22, 33, 33}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) + suite.Equal([]int64{11, 22, 11, 22, 33}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) suite.InDeltaSlice([]float32{1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 11, 22, 33, 44, 11, 22, 33, 44}, result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10) }) diff --git a/pkg/util/typeutil/schema.go b/pkg/util/typeutil/schema.go index 4dd7a80ce3..56dd6600f3 100644 --- a/pkg/util/typeutil/schema.go +++ b/pkg/util/typeutil/schema.go @@ -1022,29 +1022,18 @@ type ResultWithID interface { } // SelectMinPK select the index of the minPK in results T of the cursors. -func SelectMinPK[T ResultWithID](results []T, cursors []int64, stopForBest bool, realLimit int64) int { +func SelectMinPK[T ResultWithID](results []T, cursors []int64) (int, bool) { var ( - sel = -1 - minIntPK int64 = math.MaxInt64 + sel = -1 + drainResult = false + minIntPK int64 = math.MaxInt64 firstStr = true minStrPK string ) - for i, cursor := range cursors { if int(cursor) >= GetSizeOfIDs(results[i].GetIds()) { - if realLimit == Unlimited { - // if there is no limit set and all possible results of one query unit(shard or segment) - // has drained all possible results without any leftover, so it's safe to continue the selection - // under this case - continue - } - if stopForBest && GetSizeOfIDs(results[i].GetIds()) >= int(realLimit) { - // if one query unit(shard or segment) has more than realLimit results, and it has run out of - // all results in this round, then we have to stop select since there may be further the latest result - // in the following result of current query unit - return -1 - } + drainResult = true continue } @@ -1066,5 +1055,5 @@ func SelectMinPK[T ResultWithID](results []T, cursors []int64, stopForBest bool, } } - return sel + return sel, drainResult }