fix: Fix improper use of offset in HybridSearch (#36253)

pr : https://github.com/milvus-io/milvus/pull/36244
issue : https://github.com/milvus-io/milvus/issues/36243

Signed-off-by: zhenshan.cao <zhenshan.cao@zilliz.com>
This commit is contained in:
zhenshan.cao 2024-09-13 22:05:14 +08:00 committed by GitHub
parent faf5be2e72
commit 34e5f99bd6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 85 additions and 22 deletions

View File

@ -26,32 +26,73 @@ type rankParams struct {
roundDecimal int64
}
// parseSearchInfo returns QueryInfo and offset
func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, ignoreOffset bool) (*planpb.QueryInfo, int64, error) {
// 0. parse iterator field
isIterator, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair)
func (r *rankParams) GetLimit() int64 {
if r != nil {
return r.limit
}
return 0
}
// 1. parse offset and real topk
func (r *rankParams) GetOffset() int64 {
if r != nil {
return r.offset
}
return 0
}
func (r *rankParams) GetRoundDecimal() int64 {
if r != nil {
return r.roundDecimal
}
return 0
}
func (r *rankParams) String() string {
return fmt.Sprintf("limit: %d, offset: %d, roundDecimal: %d", r.GetLimit(), r.GetOffset(), r.GetRoundDecimal())
}
// parseSearchInfo returns QueryInfo and offset
func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, rankParams *rankParams) (*planpb.QueryInfo, int64, error) {
var topK int64
isAdvanced := rankParams != nil
externalLimit := rankParams.GetLimit() + rankParams.GetOffset()
topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, searchParamsPair)
if err != nil {
return nil, 0, errors.New(TopKKey + " not found in search_params")
}
topK, err := strconv.ParseInt(topKStr, 0, 64)
if err != nil {
return nil, 0, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)
if externalLimit <= 0 {
return nil, 0, fmt.Errorf("%s is required", TopKKey)
}
topK = externalLimit
} else {
topKInParam, err := strconv.ParseInt(topKStr, 0, 64)
if err != nil {
if externalLimit <= 0 {
return nil, 0, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)
}
topK = externalLimit
} else {
if topKInParam < externalLimit {
topK = externalLimit
} else {
topK = topKInParam
}
}
}
isIterator, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair)
if err := validateLimit(topK); err != nil {
if isIterator == "True" {
topK = Params.QuotaConfig.TopKLimit.GetAsInt64()
// 1. if the request is from iterator, we set topK to QuotaLimit as the iterator can resolve too large topK problem
// 2. GetAsInt64 has cached inside, no need to worry about cpu cost for parsing here
topK = Params.QuotaConfig.TopKLimit.GetAsInt64()
} else {
return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err)
}
}
var offset int64
if !ignoreOffset {
// ignore offset if isAdvanced
if !isAdvanced {
offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, searchParamsPair)
if err == nil {
offset, err = strconv.ParseInt(offsetStr, 0, 64)

View File

@ -170,8 +170,11 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
if t.SearchRequest.GetIsAdvanced() {
t.rankParams, err = parseRankParams(t.request.GetSearchParams())
if err != nil {
log.Info("parseRankParams failed", zap.Error(err))
return err
}
} else {
t.rankParams = nil
}
// Manually update nq if not set.
nq, err := t.checkNq(ctx)
@ -343,11 +346,12 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
t.partitionIDsSet = typeutil.NewConcurrentSet[UniqueID]()
log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName))
// fetch search_growing from search param
t.SearchRequest.SubReqs = make([]*internalpb.SubSearchRequest, len(t.request.GetSubReqs()))
t.queryInfos = make([]*planpb.QueryInfo, len(t.request.GetSubReqs()))
for index, subReq := range t.request.GetSubReqs() {
plan, queryInfo, offset, err := t.tryGeneratePlan(subReq.GetSearchParams(), subReq.GetDsl(), true)
plan, queryInfo, offset, err := t.tryGeneratePlan(subReq.GetSearchParams(), subReq.GetDsl())
if err != nil {
return err
}
@ -423,7 +427,7 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName))
// fetch search_growing from search param
plan, queryInfo, offset, err := t.tryGeneratePlan(t.request.GetSearchParams(), t.request.GetDsl(), false)
plan, queryInfo, offset, err := t.tryGeneratePlan(t.request.GetSearchParams(), t.request.GetDsl())
if err != nil {
return err
}
@ -469,7 +473,7 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
return nil
}
func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string, ignoreOffset bool) (*planpb.PlanNode, *planpb.QueryInfo, int64, error) {
func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string) (*planpb.PlanNode, *planpb.QueryInfo, int64, error) {
annsFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, params)
if err != nil || len(annsFieldName) == 0 {
vecFields := typeutil.GetVectorFieldSchemas(t.schema.CollectionSchema)
@ -482,7 +486,7 @@ func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string
}
annsFieldName = vecFields[0].Name
}
queryInfo, offset, parseErr := parseSearchInfo(params, t.schema.CollectionSchema, ignoreOffset)
queryInfo, offset, parseErr := parseSearchInfo(params, t.schema.CollectionSchema, t.rankParams)
if parseErr != nil {
return nil, nil, 0, parseErr
}

View File

@ -1935,7 +1935,7 @@ func TestSearchTask_ErrExecute(t *testing.T) {
assert.NoError(t, task.Execute(ctx))
}
func TestTaskSearch_parseQueryInfo(t *testing.T) {
func TestTaskSearch_parseSearchInfo(t *testing.T) {
t.Run("parseSearchInfo no error", func(t *testing.T) {
var targetOffset int64 = 200
@ -1971,7 +1971,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
info, offset, err := parseSearchInfo(test.validParams, nil, false)
info, offset, err := parseSearchInfo(test.validParams, nil, nil)
assert.NoError(t, err)
assert.NotNil(t, info)
if test.description == "offsetParam" {
@ -1981,6 +1981,24 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
}
})
t.Run("parseSearchInfo externalLimit", func(t *testing.T) {
var externalLimit int64 = 200
offsetParam := getValidSearchParams()
offsetParam = append(offsetParam, &commonpb.KeyValuePair{
Key: OffsetKey,
Value: strconv.FormatInt(10, 10),
})
rank := &rankParams{
limit: externalLimit,
}
info, offset, err := parseSearchInfo(offsetParam, nil, rank)
assert.NoError(t, err)
assert.NotNil(t, info)
assert.Equal(t, externalLimit, info.GetTopk())
assert.Equal(t, int64(0), offset)
})
t.Run("parseSearchInfo error", func(t *testing.T) {
spNoTopk := []*commonpb.KeyValuePair{{
Key: AnnsFieldKey,
@ -2060,7 +2078,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
info, offset, err := parseSearchInfo(test.invalidParams, nil, false)
info, offset, err := parseSearchInfo(test.invalidParams, nil, nil)
assert.Error(t, err)
assert.Nil(t, info)
assert.Zero(t, offset)
@ -2087,7 +2105,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
schema := &schemapb.CollectionSchema{
Fields: fields,
}
info, _, err := parseSearchInfo(normalParam, schema, false)
info, _, err := parseSearchInfo(normalParam, schema, nil)
assert.Nil(t, info)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
})
@ -2106,7 +2124,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
schema := &schemapb.CollectionSchema{
Fields: fields,
}
info, _, err := parseSearchInfo(normalParam, schema, false)
info, _, err := parseSearchInfo(normalParam, schema, nil)
assert.Nil(t, info)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
})
@ -2125,7 +2143,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
schema := &schemapb.CollectionSchema{
Fields: fields,
}
info, _, err := parseSearchInfo(normalParam, schema, false)
info, _, err := parseSearchInfo(normalParam, schema, nil)
assert.NotNil(t, info)
assert.NoError(t, err)
assert.Equal(t, Params.QuotaConfig.TopKLimit.GetAsInt64(), info.Topk)