mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-02 20:09:57 +08:00
55d894bd5e
Signed-off-by: chasingegg <chao.gao@zilliz.com>
87 lines
3.0 KiB
Go
87 lines
3.0 KiB
Go
package optimizers
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"github.com/golang/protobuf/proto"
|
|
"go.uber.org/zap"
|
|
|
|
"github.com/milvus-io/milvus/internal/proto/planpb"
|
|
"github.com/milvus-io/milvus/internal/proto/querypb"
|
|
"github.com/milvus-io/milvus/pkg/common"
|
|
"github.com/milvus-io/milvus/pkg/log"
|
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
|
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
|
)
|
|
|
|
// QueryHook is the contract for a pluggable search/query parameter optimizer.
//
// OptimizeSearchParams drives an implementation through Run; the remaining
// methods are lifecycle/configuration entry points that implementations must
// also provide.
type QueryHook interface {
	// Init prepares the hook for use.
	// NOTE(review): the meaning of the string argument is defined by the
	// implementation and is not visible here — confirm before relying on it.
	Init(string) error

	// InitTuningConfig installs a set of tuning configuration entries.
	// NOTE(review): key/value semantics are implementation-defined.
	InitTuningConfig(map[string]string) error

	// DeleteTuningConfig removes tuning configuration; the string presumably
	// identifies which entry to remove — verify against implementations.
	DeleteTuningConfig(string) error

	// Run tunes the given parameter map in place. OptimizeSearchParams supplies
	// entries such as the current top-k (int64) and search params (string) and
	// reads those two keys back after Run returns nil.
	Run(map[string]any) error
}
|
|
|
|
func OptimizeSearchParams(ctx context.Context, req *querypb.SearchRequest, queryHook QueryHook, numSegments int) (*querypb.SearchRequest, error) {
|
|
// no hook applied or disabled, just return
|
|
if queryHook == nil || !paramtable.Get().AutoIndexConfig.Enable.GetAsBool() {
|
|
return req, nil
|
|
}
|
|
|
|
log := log.Ctx(ctx).With(zap.Int64("collection", req.GetReq().GetCollectionID()))
|
|
|
|
serializedPlan := req.GetReq().GetSerializedExprPlan()
|
|
// plan not found
|
|
if serializedPlan == nil {
|
|
log.Warn("serialized plan not found")
|
|
return req, merr.WrapErrParameterInvalid("serialized search plan", "nil")
|
|
}
|
|
|
|
channelNum := req.GetTotalChannelNum()
|
|
// not set, change to conservative channel num 1
|
|
if channelNum <= 0 {
|
|
channelNum = 1
|
|
}
|
|
|
|
plan := planpb.PlanNode{}
|
|
err := proto.Unmarshal(serializedPlan, &plan)
|
|
if err != nil {
|
|
log.Warn("failed to unmarshal plan", zap.Error(err))
|
|
return nil, merr.WrapErrParameterInvalid("valid serialized search plan", "no unmarshalable one", err.Error())
|
|
}
|
|
|
|
switch plan.GetNode().(type) {
|
|
case *planpb.PlanNode_VectorAnns:
|
|
// use shardNum * segments num in shard to estimate total segment number
|
|
estSegmentNum := numSegments * int(channelNum)
|
|
withFilter := (plan.GetVectorAnns().GetPredicates() != nil)
|
|
queryInfo := plan.GetVectorAnns().GetQueryInfo()
|
|
params := map[string]any{
|
|
common.TopKKey: queryInfo.GetTopk(),
|
|
common.SearchParamKey: queryInfo.GetSearchParams(),
|
|
common.SegmentNumKey: estSegmentNum,
|
|
common.WithFilterKey: withFilter,
|
|
common.WithOptimizeKey: paramtable.Get().AutoIndexConfig.EnableOptimize.GetAsBool(),
|
|
common.CollectionKey: req.GetReq().GetCollectionID(),
|
|
}
|
|
err := queryHook.Run(params)
|
|
if err != nil {
|
|
log.Warn("failed to execute queryHook", zap.Error(err))
|
|
return nil, merr.WrapErrServiceUnavailable(err.Error(), "queryHook execution failed")
|
|
}
|
|
queryInfo.Topk = params[common.TopKKey].(int64)
|
|
queryInfo.SearchParams = params[common.SearchParamKey].(string)
|
|
serializedExprPlan, err := proto.Marshal(&plan)
|
|
if err != nil {
|
|
log.Warn("failed to marshal optimized plan", zap.Error(err))
|
|
return nil, merr.WrapErrParameterInvalid("marshalable search plan", "plan with marshal error", err.Error())
|
|
}
|
|
req.Req.SerializedExprPlan = serializedExprPlan
|
|
log.Debug("optimized search params done", zap.Any("queryInfo", queryInfo))
|
|
default:
|
|
log.Warn("not supported node type", zap.String("nodeType", fmt.Sprintf("%T", plan.GetNode())))
|
|
}
|
|
return req, nil
|
|
}
|