mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-02 11:59:00 +08:00
365e50b63e
no need to check params in empty segment. #30365 Signed-off-by: lixinguo <xinguo.li@zilliz.com> Co-authored-by: lixinguo <xinguo.li@zilliz.com>
307 lines
9.3 KiB
Go
307 lines
9.3 KiB
Go
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/cockroachdb/errors"
|
|
"github.com/golang/protobuf/proto"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
|
"github.com/milvus-io/milvus/internal/proto/planpb"
|
|
"github.com/milvus-io/milvus/pkg/common"
|
|
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
|
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
|
)
|
|
|
|
// rankParams holds the paging and rounding options that parseRankParams
// extracts from the rank_params key/value pairs of a hybrid search request.
type rankParams struct {
	// limit and offset come from LimitKey and OffsetKey; parseRankParams
	// leaves offset at 0 when OffsetKey is absent.
	limit, offset int64
	// roundDecimal comes from RoundDecimalKey; parseRankParams uses -1 when the
	// key is absent and otherwise only accepts values in [0, 6].
	roundDecimal int64
}
|
|
|
|
// parseSearchInfo returns QueryInfo and offset
|
|
func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, ignoreOffset bool) (*planpb.QueryInfo, int64, error) {
|
|
// 1. parse offset and real topk
|
|
topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, searchParamsPair)
|
|
if err != nil {
|
|
return nil, 0, errors.New(TopKKey + " not found in search_params")
|
|
}
|
|
topK, err := strconv.ParseInt(topKStr, 0, 64)
|
|
if err != nil {
|
|
return nil, 0, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)
|
|
}
|
|
if err := validateTopKLimit(topK); err != nil {
|
|
return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err)
|
|
}
|
|
|
|
var offset int64
|
|
if !ignoreOffset {
|
|
offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, searchParamsPair)
|
|
if err == nil {
|
|
offset, err = strconv.ParseInt(offsetStr, 0, 64)
|
|
if err != nil {
|
|
return nil, 0, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
|
|
}
|
|
|
|
if offset != 0 {
|
|
if err := validateTopKLimit(offset); err != nil {
|
|
return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
queryTopK := topK + offset
|
|
if err := validateTopKLimit(queryTopK); err != nil {
|
|
return nil, 0, fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, err)
|
|
}
|
|
|
|
// 2. parse metrics type
|
|
metricType, err := funcutil.GetAttrByKeyFromRepeatedKV(common.MetricTypeKey, searchParamsPair)
|
|
if err != nil {
|
|
metricType = ""
|
|
}
|
|
|
|
// 3. parse round decimal
|
|
roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, searchParamsPair)
|
|
if err != nil {
|
|
roundDecimalStr = "-1"
|
|
}
|
|
|
|
roundDecimal, err := strconv.ParseInt(roundDecimalStr, 0, 64)
|
|
if err != nil {
|
|
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
|
|
}
|
|
|
|
if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) {
|
|
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
|
|
}
|
|
|
|
// 4. parse search param str
|
|
searchParamStr, err := funcutil.GetAttrByKeyFromRepeatedKV(SearchParamsKey, searchParamsPair)
|
|
if err != nil {
|
|
searchParamStr = ""
|
|
}
|
|
|
|
// 5. parse group by field
|
|
groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
|
|
if err != nil {
|
|
groupByFieldName = ""
|
|
}
|
|
var groupByFieldId int64 = -1
|
|
if groupByFieldName != "" {
|
|
fields := schema.GetFields()
|
|
for _, field := range fields {
|
|
if field.Name == groupByFieldName {
|
|
groupByFieldId = field.FieldID
|
|
break
|
|
}
|
|
}
|
|
if groupByFieldId == -1 {
|
|
return nil, 0, merr.WrapErrFieldNotFound(groupByFieldName, "groupBy field not found in schema")
|
|
}
|
|
}
|
|
|
|
// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
|
|
isIterator, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair)
|
|
if isIterator == "True" && groupByFieldId > 0 {
|
|
return nil, 0, merr.WrapErrParameterInvalid("", "",
|
|
"Not allowed to do groupBy when doing iteration")
|
|
}
|
|
if strings.Contains(searchParamStr, radiusKey) && groupByFieldId > 0 {
|
|
return nil, 0, merr.WrapErrParameterInvalid("", "",
|
|
"Not allowed to do range-search when doing search-group-by")
|
|
}
|
|
|
|
return &planpb.QueryInfo{
|
|
Topk: queryTopK,
|
|
MetricType: metricType,
|
|
SearchParams: searchParamStr,
|
|
RoundDecimal: roundDecimal,
|
|
GroupByFieldId: groupByFieldId,
|
|
}, offset, nil
|
|
}
|
|
|
|
func getOutputFieldIDs(schema *schemaInfo, outputFields []string) (outputFieldIDs []UniqueID, err error) {
|
|
outputFieldIDs = make([]UniqueID, 0, len(outputFields))
|
|
for _, name := range outputFields {
|
|
id, ok := schema.MapFieldID(name)
|
|
if !ok {
|
|
return nil, fmt.Errorf("Field %s not exist", name)
|
|
}
|
|
outputFieldIDs = append(outputFieldIDs, id)
|
|
}
|
|
return outputFieldIDs, nil
|
|
}
|
|
|
|
func getNqFromSubSearch(req *milvuspb.SubSearchRequest) (int64, error) {
|
|
if req.GetNq() == 0 {
|
|
// keep compatible with older client version.
|
|
x := &commonpb.PlaceholderGroup{}
|
|
err := proto.Unmarshal(req.GetPlaceholderGroup(), x)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
total := int64(0)
|
|
for _, h := range x.GetPlaceholders() {
|
|
total += int64(len(h.Values))
|
|
}
|
|
return total, nil
|
|
}
|
|
return req.GetNq(), nil
|
|
}
|
|
|
|
func getNq(req *milvuspb.SearchRequest) (int64, error) {
|
|
if req.GetNq() == 0 {
|
|
// keep compatible with older client version.
|
|
x := &commonpb.PlaceholderGroup{}
|
|
err := proto.Unmarshal(req.GetPlaceholderGroup(), x)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
total := int64(0)
|
|
for _, h := range x.GetPlaceholders() {
|
|
total += int64(len(h.Values))
|
|
}
|
|
return total, nil
|
|
}
|
|
return req.GetNq(), nil
|
|
}
|
|
|
|
func getPartitionIDs(ctx context.Context, dbName string, collectionName string, partitionNames []string) (partitionIDs []UniqueID, err error) {
|
|
for _, tag := range partitionNames {
|
|
if err := validatePartitionTag(tag, false); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
partitionsMap, err := globalMetaCache.GetPartitions(ctx, dbName, collectionName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
useRegexp := Params.ProxyCfg.PartitionNameRegexp.GetAsBool()
|
|
|
|
partitionsSet := typeutil.NewSet[int64]()
|
|
for _, partitionName := range partitionNames {
|
|
if useRegexp {
|
|
// Legacy feature, use partition name as regexp
|
|
pattern := fmt.Sprintf("^%s$", partitionName)
|
|
re, err := regexp.Compile(pattern)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid partition: %s", partitionName)
|
|
}
|
|
var found bool
|
|
for name, pID := range partitionsMap {
|
|
if re.MatchString(name) {
|
|
partitionsSet.Insert(pID)
|
|
found = true
|
|
}
|
|
}
|
|
if !found {
|
|
return nil, fmt.Errorf("partition name %s not found", partitionName)
|
|
}
|
|
} else {
|
|
partitionID, found := partitionsMap[partitionName]
|
|
if !found {
|
|
// TODO change after testcase updated: return nil, merr.WrapErrPartitionNotFound(partitionName)
|
|
return nil, fmt.Errorf("partition name %s not found", partitionName)
|
|
}
|
|
if !partitionsSet.Contain(partitionID) {
|
|
partitionsSet.Insert(partitionID)
|
|
}
|
|
}
|
|
}
|
|
return partitionsSet.Collect(), nil
|
|
}
|
|
|
|
// parseRankParams get limit and offset from rankParams, both are optional.
|
|
func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, error) {
|
|
var (
|
|
limit int64
|
|
offset int64
|
|
roundDecimal int64
|
|
err error
|
|
)
|
|
|
|
limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, rankParamsPair)
|
|
if err != nil {
|
|
return nil, errors.New(LimitKey + " not found in rank_params")
|
|
}
|
|
limit, err = strconv.ParseInt(limitStr, 0, 64)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%s [%s] is invalid", LimitKey, limitStr)
|
|
}
|
|
|
|
offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, rankParamsPair)
|
|
if err == nil {
|
|
offset, err = strconv.ParseInt(offsetStr, 0, 64)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
|
|
}
|
|
}
|
|
|
|
// validate max result window.
|
|
if err = validateMaxQueryResultWindow(offset, limit); err != nil {
|
|
return nil, fmt.Errorf("invalid max query result window, %w", err)
|
|
}
|
|
|
|
roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, rankParamsPair)
|
|
if err != nil {
|
|
roundDecimalStr = "-1"
|
|
}
|
|
|
|
roundDecimal, err = strconv.ParseInt(roundDecimalStr, 0, 64)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
|
|
}
|
|
|
|
if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) {
|
|
return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
|
|
}
|
|
|
|
return &rankParams{
|
|
limit: limit,
|
|
offset: offset,
|
|
roundDecimal: roundDecimal,
|
|
}, nil
|
|
}
|
|
|
|
func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.SearchRequest {
|
|
ret := &milvuspb.SearchRequest{
|
|
Base: req.GetBase(),
|
|
DbName: req.GetDbName(),
|
|
CollectionName: req.GetCollectionName(),
|
|
PartitionNames: req.GetPartitionNames(),
|
|
OutputFields: req.GetOutputFields(),
|
|
SearchParams: req.GetRankParams(),
|
|
TravelTimestamp: req.GetTravelTimestamp(),
|
|
GuaranteeTimestamp: req.GetGuaranteeTimestamp(),
|
|
Nq: 0,
|
|
NotReturnAllMeta: req.GetNotReturnAllMeta(),
|
|
ConsistencyLevel: req.GetConsistencyLevel(),
|
|
UseDefaultConsistency: req.GetUseDefaultConsistency(),
|
|
SearchByPrimaryKeys: false,
|
|
SubReqs: nil,
|
|
}
|
|
|
|
for _, sub := range req.GetRequests() {
|
|
subReq := &milvuspb.SubSearchRequest{
|
|
Dsl: sub.GetDsl(),
|
|
PlaceholderGroup: sub.GetPlaceholderGroup(),
|
|
DslType: sub.GetDslType(),
|
|
SearchParams: sub.GetSearchParams(),
|
|
Nq: sub.GetNq(),
|
|
}
|
|
ret.SubReqs = append(ret.SubReqs, subReq)
|
|
}
|
|
return ret
|
|
}
|