milvus/internal/proxy/search_util.go
Chun Han 2d29dcd30c
Some checks are pending
Code Checker / Code Checker AMD64 Ubuntu 22.04 (push) Waiting to run
Code Checker / Code Checker Amazonlinux 2023 (push) Waiting to run
Code Checker / Code Checker rockylinux8 (push) Waiting to run
Mac Code Checker / Code Checker MacOS 12 (push) Waiting to run
Build and test / Build and test AMD64 Ubuntu 22.04 (push) Waiting to run
Build and test / UT for Cpp (push) Blocked by required conditions
Build and test / UT for Go (push) Blocked by required conditions
Build and test / Integration Test (push) Blocked by required conditions
Build and test / Upload Code Coverage (push) Blocked by required conditions
Publish Test Images / PyTest (push) Waiting to run
enhance:refine group_strict_size parameter(#37482) (#37483)
related: #37482

Signed-off-by: MrPresent-Han <chun.han@gmail.com>
Co-authored-by: MrPresent-Han <chun.han@gmail.com>
2024-11-12 09:56:28 +08:00

497 lines
15 KiB
Go

package proxy
import (
"context"
"fmt"
"regexp"
"strconv"
"strings"
"github.com/cockroachdb/errors"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
// rankParams holds the values parseRankParams extracts from a request's
// rank parameters: the result window (limit/offset), the score rounding
// precision and the group-by settings.
//
// All accessors tolerate a nil receiver and return a fixed default instead.
type rankParams struct {
	limit           int64
	offset          int64
	roundDecimal    int64
	groupByFieldId  int64
	groupSize       int64
	strictGroupSize bool
}

// GetLimit returns the limit, or 0 for a nil receiver.
func (r *rankParams) GetLimit() int64 {
	if r == nil {
		return 0
	}
	return r.limit
}

// GetOffset returns the offset, or 0 for a nil receiver.
func (r *rankParams) GetOffset() int64 {
	if r == nil {
		return 0
	}
	return r.offset
}

// GetRoundDecimal returns the round-decimal setting, or 0 for a nil receiver.
func (r *rankParams) GetRoundDecimal() int64 {
	if r == nil {
		return 0
	}
	return r.roundDecimal
}

// GetGroupByFieldId returns the group-by field ID, or -1 for a nil receiver.
func (r *rankParams) GetGroupByFieldId() int64 {
	if r == nil {
		return -1
	}
	return r.groupByFieldId
}

// GetGroupSize returns the group size, or 1 for a nil receiver.
func (r *rankParams) GetGroupSize() int64 {
	if r == nil {
		return 1
	}
	return r.groupSize
}

// GetStrictGroupSize returns the strict-group-size flag, or false for a nil
// receiver.
func (r *rankParams) GetStrictGroupSize() bool {
	if r == nil {
		return false
	}
	return r.strictGroupSize
}

// String formats limit, offset and roundDecimal. It goes through the getters,
// so it is safe to call on a nil receiver.
func (r *rankParams) String() string {
	limit, offset, roundDecimal := r.GetLimit(), r.GetOffset(), r.GetRoundDecimal()
	return fmt.Sprintf("limit: %d, offset: %d, roundDecimal: %d", limit, offset, roundDecimal)
}
// SearchInfo is the value returned by parseSearchInfo: either the query info
// built from the search parameters, or the error that prevented building it.
type SearchInfo struct {
// planInfo is the planpb.QueryInfo assembled by parseSearchInfo; it is nil
// on every path where parseSearchInfo sets parseError.
planInfo *planpb.QueryInfo
// offset is the offset parsed from the search parameters. parseSearchInfo
// leaves it 0 when rank parameters are supplied or when parsing fails.
offset int64
// parseError is nil on success; otherwise it describes the invalid parameter.
parseError error
// isIterator reports whether the iterator flag in the search parameters was
// "True" or "true". It is false on every error path of parseSearchInfo.
isIterator bool
}
// parseSearchInfo turns the raw key/value search parameters into a SearchInfo
// holding the QueryInfo and the offset.
//
// A non-nil rankParams marks the request as "advanced": a missing or
// unparsable topK falls back to rankParams limit+offset, the offset in
// searchParamsPair is ignored, and the group-by settings come from rankParams
// rather than from searchParamsPair.
//
// The result is never nil. On failure only parseError is set; on success
// parseError is nil and planInfo.Topk equals topK+offset.
func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, rankParams *rankParams) *SearchInfo {
	// Every failure returns a SearchInfo that carries nothing but the error.
	fail := func(parseErr error) *SearchInfo {
		return &SearchInfo{parseError: parseErr}
	}
	// lookupOr reads an optional string parameter, substituting fallback when
	// the key cannot be fetched.
	lookupOr := func(key, fallback string) string {
		value, lookupErr := funcutil.GetAttrByKeyFromRepeatedKV(key, searchParamsPair)
		if lookupErr != nil {
			return fallback
		}
		return value
	}

	advanced := rankParams != nil
	// The getters are nil-safe, so this is 0 when no rank params were given.
	fallbackTopK := rankParams.GetLimit() + rankParams.GetOffset()

	// 1. parse topK, falling back to the rank-params window when it is
	// missing or malformed and such a window exists.
	topK := fallbackTopK
	topKStr, topKLookupErr := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, searchParamsPair)
	if topKLookupErr != nil {
		if fallbackTopK <= 0 {
			return fail(fmt.Errorf("%s is required", TopKKey))
		}
	} else if parsedTopK, convErr := strconv.ParseInt(topKStr, 0, 64); convErr == nil {
		topK = parsedTopK
	} else if fallbackTopK <= 0 {
		return fail(fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr))
	}

	iteratorFlag, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair)
	isIterator := iteratorFlag == "True" || iteratorFlag == "true"

	if limitErr := validateLimit(topK); limitErr != nil {
		if !isIterator {
			return fail(fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, limitErr))
		}
		// 1. if the request is from iterator, we set topK to QuotaLimit as the iterator can resolve too large topK problem
		// 2. GetAsInt64 has cached inside, no need to worry about cpu cost for parsing here
		topK = Params.QuotaConfig.TopKLimit.GetAsInt64()
	}

	// The per-request offset only applies when no rank params were given.
	var offset int64
	if !advanced {
		if offsetStr, lookupErr := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, searchParamsPair); lookupErr == nil {
			parsedOffset, convErr := strconv.ParseInt(offsetStr, 0, 64)
			if convErr != nil {
				return fail(fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr))
			}
			offset = parsedOffset
			// A zero offset is accepted without going through validateLimit.
			if offset != 0 {
				if limitErr := validateLimit(offset); limitErr != nil {
					return fail(fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, limitErr))
				}
			}
		}
	}

	queryTopK := topK + offset
	if limitErr := validateLimit(queryTopK); limitErr != nil {
		return fail(fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, limitErr))
	}

	// 2. parse metrics type (optional, defaults to empty)
	metricType := lookupOr(common.MetricTypeKey, "")

	// 3. parse round decimal (optional, defaults to -1); accepted values are
	// -1 and 0..6
	roundDecimalStr := lookupOr(RoundDecimalKey, "-1")
	roundDecimal, convErr := strconv.ParseInt(roundDecimalStr, 0, 64)
	if convErr != nil || roundDecimal < -1 || roundDecimal > 6 {
		return fail(fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr))
	}

	// 4. parse search param str (optional, defaults to empty)
	searchParamStr := lookupOr(SearchParamsKey, "")

	// 5. parse group by field and group by size: start from the rank-params
	// values and replace them with the per-request ones for a plain search.
	groupByFieldId := rankParams.GetGroupByFieldId()
	groupSize := rankParams.GetGroupSize()
	strictGroupSize := rankParams.GetStrictGroupSize()
	if !advanced {
		groupBy := parseGroupByInfo(searchParamsPair, schema)
		if groupBy.err != nil {
			return fail(groupBy.err)
		}
		groupByFieldId = groupBy.GetGroupByFieldId()
		groupSize = groupBy.GetGroupSize()
		strictGroupSize = groupBy.GetStrictGroupSize()
	}

	// 6. group-by cannot be combined with iteration or with range search
	if groupByFieldId > 0 {
		if isIterator {
			return fail(merr.WrapErrParameterInvalid("", "",
				"Not allowed to do groupBy when doing iteration"))
		}
		if strings.Contains(searchParamStr, radiusKey) {
			return fail(merr.WrapErrParameterInvalid("", "",
				"Not allowed to do range-search when doing search-group-by"))
		}
	}

	queryInfo := &planpb.QueryInfo{
		Topk:            queryTopK,
		MetricType:      metricType,
		SearchParams:    searchParamStr,
		RoundDecimal:    roundDecimal,
		GroupByFieldId:  groupByFieldId,
		GroupSize:       groupSize,
		StrictGroupSize: strictGroupSize,
	}
	return &SearchInfo{
		planInfo:   queryInfo,
		offset:     offset,
		isIterator: isIterator,
		parseError: nil,
	}
}
// getOutputFieldIDs maps every name in outputFields to its field ID via
// schema.MapFieldID, keeping the input order. It stops at the first name the
// schema cannot map and returns a nil slice together with an error.
func getOutputFieldIDs(schema *schemaInfo, outputFields []string) (outputFieldIDs []UniqueID, err error) {
	resolved := make([]UniqueID, len(outputFields))
	for idx, fieldName := range outputFields {
		fieldID, known := schema.MapFieldID(fieldName)
		if !known {
			return nil, fmt.Errorf("Field %s not exist", fieldName)
		}
		resolved[idx] = fieldID
	}
	return resolved, nil
}
// getNqFromSubSearch returns the nq of a sub-search request. A non-zero Nq on
// the request is returned directly; otherwise the placeholder group is
// unmarshalled and the number of values across all placeholders is summed.
func getNqFromSubSearch(req *milvuspb.SubSearchRequest) (int64, error) {
	if nq := req.GetNq(); nq != 0 {
		return nq, nil
	}

	// keep compatible with older client version.
	group := new(commonpb.PlaceholderGroup)
	if err := proto.Unmarshal(req.GetPlaceholderGroup(), group); err != nil {
		return 0, err
	}

	var total int64
	for _, placeholder := range group.GetPlaceholders() {
		total += int64(len(placeholder.Values))
	}
	return total, nil
}
// getNq returns the nq of a search request. A non-zero Nq on the request is
// returned directly; otherwise the placeholder group is unmarshalled and the
// number of values across all placeholders is summed.
func getNq(req *milvuspb.SearchRequest) (int64, error) {
	if nq := req.GetNq(); nq != 0 {
		return nq, nil
	}

	// keep compatible with older client version.
	group := new(commonpb.PlaceholderGroup)
	if err := proto.Unmarshal(req.GetPlaceholderGroup(), group); err != nil {
		return 0, err
	}

	var total int64
	for _, placeholder := range group.GetPlaceholders() {
		total += int64(len(placeholder.Values))
	}
	return total, nil
}
// getPartitionIDs resolves partitionNames to a de-duplicated list of
// partition IDs for the given collection.
//
// Every name is first checked with validatePartitionTag. The name-to-ID map
// comes from globalMetaCache.GetPartitions. When the PartitionNameRegexp
// proxy setting is enabled, each name is treated as a regular expression
// anchored with ^...$ and contributes every matching partition; otherwise it
// must match a partition name exactly. A name that resolves to nothing is an
// error.
func getPartitionIDs(ctx context.Context, dbName string, collectionName string, partitionNames []string) (partitionIDs []UniqueID, err error) {
	for idx := range partitionNames {
		if tagErr := validatePartitionTag(partitionNames[idx], false); tagErr != nil {
			return nil, tagErr
		}
	}

	nameToID, err := globalMetaCache.GetPartitions(ctx, dbName, collectionName)
	if err != nil {
		return nil, err
	}

	regexpEnabled := Params.ProxyCfg.PartitionNameRegexp.GetAsBool()
	idSet := typeutil.NewSet[int64]()
	for _, requested := range partitionNames {
		if !regexpEnabled {
			id, known := nameToID[requested]
			if !known {
				// TODO change after testcase updated: return nil, merr.WrapErrPartitionNotFound(partitionName)
				return nil, fmt.Errorf("partition name %s not found", requested)
			}
			if !idSet.Contain(id) {
				idSet.Insert(id)
			}
			continue
		}

		// Legacy feature, use partition name as regexp
		matcher, compileErr := regexp.Compile("^" + requested + "$")
		if compileErr != nil {
			return nil, fmt.Errorf("invalid partition: %s", requested)
		}
		matches := 0
		for name, id := range nameToID {
			if matcher.MatchString(name) {
				idSet.Insert(id)
				matches++
			}
		}
		if matches == 0 {
			return nil, fmt.Errorf("partition name %s not found", requested)
		}
	}
	return idSet.Collect(), nil
}
// groupByInfo is the result of parseGroupByInfo: the group-by settings read
// from the search parameters, or the error that made them invalid.
type groupByInfo struct {
	// groupByFieldId is the ID of the field to group by; parseGroupByInfo
	// stores -1 when no group-by field was requested.
	groupByFieldId int64
	// groupSize is the requested group size; parseGroupByInfo stores 1 when
	// the parameter is absent.
	groupSize int64
	// strictGroupSize is the parsed strict-group-size flag.
	strictGroupSize bool
	// err is non-nil when parsing failed.
	err error
}

// GetGroupByFieldId returns the group-by field ID. For a nil receiver it
// returns -1, the same "no group-by" sentinel used by parseGroupByInfo and by
// rankParams.GetGroupByFieldId, so a missing groupByInfo is never mistaken
// for field ID 0.
func (g *groupByInfo) GetGroupByFieldId() int64 {
	if g != nil {
		return g.groupByFieldId
	}
	return -1
}

// GetGroupSize returns the group size. For a nil receiver it returns 1, the
// default parseGroupByInfo and rankParams.GetGroupSize use when no group size
// is given.
func (g *groupByInfo) GetGroupSize() int64 {
	if g != nil {
		return g.groupSize
	}
	return 1
}

// GetStrictGroupSize returns the strict-group-size flag, or false for a nil
// receiver.
func (g *groupByInfo) GetStrictGroupSize() bool {
	if g != nil {
		return g.strictGroupSize
	}
	return false
}

// GetError returns the parse error, or nil for a nil receiver.
func (g *groupByInfo) GetError() error {
	if g != nil {
		return g.err
	}
	return nil
}
// parseGroupByInfo reads the group-by settings (field, group size, strict
// group size) from searchParamsPair.
//
// The result is never nil. When a setting is invalid, err is set on the
// result and parsing stops there, so later fields keep their zero values.
func parseGroupByInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema) *groupByInfo {
	info := &groupByInfo{}

	// 1. parse group_by_field: optional; -1 means "no group-by".
	fieldName, lookupErr := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
	if lookupErr != nil {
		fieldName = ""
	}
	fieldID := int64(-1)
	if fieldName != "" {
		for _, field := range schema.GetFields() {
			if field.Name != fieldName {
				continue
			}
			if field.GetNullable() {
				info.err = merr.WrapErrParameterInvalidMsg(fmt.Sprintf("groupBy field(%s) not support nullable == true", fieldName))
				return info
			}
			fieldID = field.FieldID
			break
		}
		if fieldID == -1 {
			info.err = merr.WrapErrFieldNotFound(fieldName, "groupBy field not found in schema")
			return info
		}
	}
	info.groupByFieldId = fieldID

	// 2. parse group size: optional, defaults to 1; an explicit value must be
	// a positive integer, and the result may not exceed the configured max.
	groupSize := int64(1)
	if sizeStr, sizeLookupErr := funcutil.GetAttrByKeyFromRepeatedKV(GroupSizeKey, searchParamsPair); sizeLookupErr == nil {
		parsedSize, convErr := strconv.ParseInt(sizeStr, 0, 64)
		if convErr != nil {
			info.err = merr.WrapErrParameterInvalidMsg(
				fmt.Sprintf("failed to parse input group size:%s", sizeStr))
			return info
		}
		if parsedSize <= 0 {
			info.err = merr.WrapErrParameterInvalidMsg(
				fmt.Sprintf("input group size:%d is negative, failed to do search_groupby", parsedSize))
			return info
		}
		groupSize = parsedSize
	}
	if maxGroupSize := Params.QuotaConfig.MaxGroupSize.GetAsInt64(); groupSize > maxGroupSize {
		info.err = merr.WrapErrParameterInvalidMsg(
			fmt.Sprintf("input group size:%d exceeds configured max group size:%d", groupSize, maxGroupSize))
		return info
	}
	info.groupSize = groupSize

	// 3. parse group strict size: optional; a missing or unparsable value is
	// treated as false rather than reported as an error.
	if strictStr, strictLookupErr := funcutil.GetAttrByKeyFromRepeatedKV(StrictGroupSize, searchParamsPair); strictLookupErr == nil {
		if strict, convErr := strconv.ParseBool(strictStr); convErr == nil {
			info.strictGroupSize = strict
		}
	}
	return info
}
// parseRankParams builds a rankParams from the rank key/value pairs.
//
// limit is mandatory. offset and round_decimal are optional, defaulting to 0
// and -1. The offset/limit window is checked with
// validateMaxQueryResultWindow, round_decimal must be -1 or within [0, 6],
// and the group-by settings are parsed from the same pairs through
// parseGroupByInfo. Any invalid value yields a nil result and an error.
func parseRankParams(rankParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema) (*rankParams, error) {
	limitStr, lookupErr := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, rankParamsPair)
	if lookupErr != nil {
		return nil, errors.New(LimitKey + " not found in rank_params")
	}
	limit, convErr := strconv.ParseInt(limitStr, 0, 64)
	if convErr != nil {
		return nil, fmt.Errorf("%s [%s] is invalid", LimitKey, limitStr)
	}

	var offset int64
	if offsetStr, offsetLookupErr := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, rankParamsPair); offsetLookupErr == nil {
		parsedOffset, offsetConvErr := strconv.ParseInt(offsetStr, 0, 64)
		if offsetConvErr != nil {
			return nil, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
		}
		offset = parsedOffset
	}

	// validate max result window.
	if windowErr := validateMaxQueryResultWindow(offset, limit); windowErr != nil {
		return nil, fmt.Errorf("invalid max query result window, %w", windowErr)
	}

	roundDecimalStr, roundLookupErr := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, rankParamsPair)
	if roundLookupErr != nil {
		roundDecimalStr = "-1"
	}
	roundDecimal, roundConvErr := strconv.ParseInt(roundDecimalStr, 0, 64)
	if roundConvErr != nil || roundDecimal < -1 || roundDecimal > 6 {
		return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
	}

	// parse group_by parameters from main request body for hybrid search
	groupBy := parseGroupByInfo(rankParamsPair, schema)
	if groupBy.err != nil {
		return nil, groupBy.err
	}

	parsed := &rankParams{
		limit:           limit,
		offset:          offset,
		roundDecimal:    roundDecimal,
		groupByFieldId:  groupBy.GetGroupByFieldId(),
		groupSize:       groupBy.GetGroupSize(),
		strictGroupSize: groupBy.GetStrictGroupSize(),
	}
	return parsed, nil
}
// convertHybridSearchToSearch repackages a HybridSearchRequest as a
// SearchRequest. The request-level fields are copied across, the rank params
// become the SearchParams, and each nested request becomes one
// SubSearchRequest. Nq is left at 0 and SearchByPrimaryKeys at false.
// SubReqs stays nil when the hybrid request contains no nested requests.
func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.SearchRequest {
	var subReqs []*milvuspb.SubSearchRequest
	for _, nested := range req.GetRequests() {
		subReqs = append(subReqs, &milvuspb.SubSearchRequest{
			Dsl:              nested.GetDsl(),
			PlaceholderGroup: nested.GetPlaceholderGroup(),
			DslType:          nested.GetDslType(),
			SearchParams:     nested.GetSearchParams(),
			Nq:               nested.GetNq(),
		})
	}

	return &milvuspb.SearchRequest{
		Base:                  req.GetBase(),
		DbName:                req.GetDbName(),
		CollectionName:        req.GetCollectionName(),
		PartitionNames:        req.GetPartitionNames(),
		OutputFields:          req.GetOutputFields(),
		SearchParams:          req.GetRankParams(),
		TravelTimestamp:       req.GetTravelTimestamp(),
		GuaranteeTimestamp:    req.GetGuaranteeTimestamp(),
		Nq:                    0,
		NotReturnAllMeta:      req.GetNotReturnAllMeta(),
		ConsistencyLevel:      req.GetConsistencyLevel(),
		UseDefaultConsistency: req.GetUseDefaultConsistency(),
		SearchByPrimaryKeys:   false,
		SubReqs:               subReqs,
	}
}