mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-02 11:59:00 +08:00
fix: Fix improper use of offset in HybridSearch (#36253)
pr : https://github.com/milvus-io/milvus/pull/36244 issue : https://github.com/milvus-io/milvus/issues/36243 Signed-off-by: zhenshan.cao <zhenshan.cao@zilliz.com>
This commit is contained in:
parent
faf5be2e72
commit
34e5f99bd6
@ -26,32 +26,73 @@ type rankParams struct {
|
||||
roundDecimal int64
|
||||
}
|
||||
|
||||
// parseSearchInfo returns QueryInfo and offset
|
||||
func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, ignoreOffset bool) (*planpb.QueryInfo, int64, error) {
|
||||
// 0. parse iterator field
|
||||
isIterator, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair)
|
||||
func (r *rankParams) GetLimit() int64 {
|
||||
if r != nil {
|
||||
return r.limit
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// 1. parse offset and real topk
|
||||
func (r *rankParams) GetOffset() int64 {
|
||||
if r != nil {
|
||||
return r.offset
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (r *rankParams) GetRoundDecimal() int64 {
|
||||
if r != nil {
|
||||
return r.roundDecimal
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (r *rankParams) String() string {
|
||||
return fmt.Sprintf("limit: %d, offset: %d, roundDecimal: %d", r.GetLimit(), r.GetOffset(), r.GetRoundDecimal())
|
||||
}
|
||||
|
||||
// parseSearchInfo returns QueryInfo and offset
|
||||
func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, rankParams *rankParams) (*planpb.QueryInfo, int64, error) {
|
||||
var topK int64
|
||||
isAdvanced := rankParams != nil
|
||||
externalLimit := rankParams.GetLimit() + rankParams.GetOffset()
|
||||
topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, searchParamsPair)
|
||||
if err != nil {
|
||||
return nil, 0, errors.New(TopKKey + " not found in search_params")
|
||||
}
|
||||
topK, err := strconv.ParseInt(topKStr, 0, 64)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)
|
||||
if externalLimit <= 0 {
|
||||
return nil, 0, fmt.Errorf("%s is required", TopKKey)
|
||||
}
|
||||
topK = externalLimit
|
||||
} else {
|
||||
topKInParam, err := strconv.ParseInt(topKStr, 0, 64)
|
||||
if err != nil {
|
||||
if externalLimit <= 0 {
|
||||
return nil, 0, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)
|
||||
}
|
||||
topK = externalLimit
|
||||
} else {
|
||||
if topKInParam < externalLimit {
|
||||
topK = externalLimit
|
||||
} else {
|
||||
topK = topKInParam
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
isIterator, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair)
|
||||
|
||||
if err := validateLimit(topK); err != nil {
|
||||
if isIterator == "True" {
|
||||
topK = Params.QuotaConfig.TopKLimit.GetAsInt64()
|
||||
// 1. if the request is from iterator, we set topK to QuotaLimit as the iterator can resolve too large topK problem
|
||||
// 2. GetAsInt64 has cached inside, no need to worry about cpu cost for parsing here
|
||||
topK = Params.QuotaConfig.TopKLimit.GetAsInt64()
|
||||
} else {
|
||||
return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err)
|
||||
}
|
||||
}
|
||||
|
||||
var offset int64
|
||||
if !ignoreOffset {
|
||||
// ignore offset if isAdvanced
|
||||
if !isAdvanced {
|
||||
offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, searchParamsPair)
|
||||
if err == nil {
|
||||
offset, err = strconv.ParseInt(offsetStr, 0, 64)
|
||||
|
@ -170,8 +170,11 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
||||
if t.SearchRequest.GetIsAdvanced() {
|
||||
t.rankParams, err = parseRankParams(t.request.GetSearchParams())
|
||||
if err != nil {
|
||||
log.Info("parseRankParams failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
t.rankParams = nil
|
||||
}
|
||||
// Manually update nq if not set.
|
||||
nq, err := t.checkNq(ctx)
|
||||
@ -343,11 +346,12 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
||||
t.partitionIDsSet = typeutil.NewConcurrentSet[UniqueID]()
|
||||
|
||||
log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName))
|
||||
|
||||
// fetch search_growing from search param
|
||||
t.SearchRequest.SubReqs = make([]*internalpb.SubSearchRequest, len(t.request.GetSubReqs()))
|
||||
t.queryInfos = make([]*planpb.QueryInfo, len(t.request.GetSubReqs()))
|
||||
for index, subReq := range t.request.GetSubReqs() {
|
||||
plan, queryInfo, offset, err := t.tryGeneratePlan(subReq.GetSearchParams(), subReq.GetDsl(), true)
|
||||
plan, queryInfo, offset, err := t.tryGeneratePlan(subReq.GetSearchParams(), subReq.GetDsl())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -423,7 +427,7 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
|
||||
log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName))
|
||||
// fetch search_growing from search param
|
||||
|
||||
plan, queryInfo, offset, err := t.tryGeneratePlan(t.request.GetSearchParams(), t.request.GetDsl(), false)
|
||||
plan, queryInfo, offset, err := t.tryGeneratePlan(t.request.GetSearchParams(), t.request.GetDsl())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -469,7 +473,7 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string, ignoreOffset bool) (*planpb.PlanNode, *planpb.QueryInfo, int64, error) {
|
||||
func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string) (*planpb.PlanNode, *planpb.QueryInfo, int64, error) {
|
||||
annsFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, params)
|
||||
if err != nil || len(annsFieldName) == 0 {
|
||||
vecFields := typeutil.GetVectorFieldSchemas(t.schema.CollectionSchema)
|
||||
@ -482,7 +486,7 @@ func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string
|
||||
}
|
||||
annsFieldName = vecFields[0].Name
|
||||
}
|
||||
queryInfo, offset, parseErr := parseSearchInfo(params, t.schema.CollectionSchema, ignoreOffset)
|
||||
queryInfo, offset, parseErr := parseSearchInfo(params, t.schema.CollectionSchema, t.rankParams)
|
||||
if parseErr != nil {
|
||||
return nil, nil, 0, parseErr
|
||||
}
|
||||
|
@ -1935,7 +1935,7 @@ func TestSearchTask_ErrExecute(t *testing.T) {
|
||||
assert.NoError(t, task.Execute(ctx))
|
||||
}
|
||||
|
||||
func TestTaskSearch_parseQueryInfo(t *testing.T) {
|
||||
func TestTaskSearch_parseSearchInfo(t *testing.T) {
|
||||
t.Run("parseSearchInfo no error", func(t *testing.T) {
|
||||
var targetOffset int64 = 200
|
||||
|
||||
@ -1971,7 +1971,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
info, offset, err := parseSearchInfo(test.validParams, nil, false)
|
||||
info, offset, err := parseSearchInfo(test.validParams, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, info)
|
||||
if test.description == "offsetParam" {
|
||||
@ -1981,6 +1981,24 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("parseSearchInfo externalLimit", func(t *testing.T) {
|
||||
var externalLimit int64 = 200
|
||||
offsetParam := getValidSearchParams()
|
||||
offsetParam = append(offsetParam, &commonpb.KeyValuePair{
|
||||
Key: OffsetKey,
|
||||
Value: strconv.FormatInt(10, 10),
|
||||
})
|
||||
rank := &rankParams{
|
||||
limit: externalLimit,
|
||||
}
|
||||
|
||||
info, offset, err := parseSearchInfo(offsetParam, nil, rank)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, info)
|
||||
assert.Equal(t, externalLimit, info.GetTopk())
|
||||
assert.Equal(t, int64(0), offset)
|
||||
})
|
||||
|
||||
t.Run("parseSearchInfo error", func(t *testing.T) {
|
||||
spNoTopk := []*commonpb.KeyValuePair{{
|
||||
Key: AnnsFieldKey,
|
||||
@ -2060,7 +2078,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
info, offset, err := parseSearchInfo(test.invalidParams, nil, false)
|
||||
info, offset, err := parseSearchInfo(test.invalidParams, nil, nil)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, info)
|
||||
assert.Zero(t, offset)
|
||||
@ -2087,7 +2105,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: fields,
|
||||
}
|
||||
info, _, err := parseSearchInfo(normalParam, schema, false)
|
||||
info, _, err := parseSearchInfo(normalParam, schema, nil)
|
||||
assert.Nil(t, info)
|
||||
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
|
||||
})
|
||||
@ -2106,7 +2124,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: fields,
|
||||
}
|
||||
info, _, err := parseSearchInfo(normalParam, schema, false)
|
||||
info, _, err := parseSearchInfo(normalParam, schema, nil)
|
||||
assert.Nil(t, info)
|
||||
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
|
||||
})
|
||||
@ -2125,7 +2143,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: fields,
|
||||
}
|
||||
info, _, err := parseSearchInfo(normalParam, schema, false)
|
||||
info, _, err := parseSearchInfo(normalParam, schema, nil)
|
||||
assert.NotNil(t, info)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, Params.QuotaConfig.TopKLimit.GetAsInt64(), info.Topk)
|
||||
|
Loading…
Reference in New Issue
Block a user