milvus/internal/querynodev2/optimizers/query_hook.go
Gao 55d894bd5e
enhance: support disable search optimization (#32141)
Signed-off-by: chasingegg <chao.gao@zilliz.com>
2024-04-16 10:51:20 +08:00

87 lines
3.0 KiB
Go

package optimizers
import (
"context"
"fmt"
"github.com/golang/protobuf/proto"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
// QueryHook is the interface for search/query parameter optimizer.
//
// OptimizeSearchParams in this file drives an implementation through Run; the
// remaining methods configure the hook, and their exact semantics are defined
// by the concrete implementation (not visible here).
type QueryHook interface {
	// Run receives a mutable parameter map built by OptimizeSearchParams.
	// After Run returns nil, the caller reads the common.TopKKey entry back as
	// an int64 and the common.SearchParamKey entry back as a string, so an
	// implementation reports tuned values by overwriting those entries and
	// must keep those types. A non-nil error aborts the optimization.
	Run(params map[string]any) error

	// Init initializes the hook from a single string argument — presumably a
	// serialized configuration; confirm against the implementation.
	Init(param string) error

	// InitTuningConfig installs tuning configuration given as string
	// key/value pairs (semantics defined by the implementation).
	InitTuningConfig(config map[string]string) error

	// DeleteTuningConfig removes tuning configuration identified by the given
	// string — presumably a key previously passed to InitTuningConfig; verify.
	DeleteTuningConfig(key string) error
}
// OptimizeSearchParams lets queryHook tune the vector-search parameters embedded
// in req's serialized plan.
//
// It returns req unchanged when queryHook is nil or AutoIndexConfig.Enable is
// false. Otherwise it unmarshals the serialized plan and, for VectorAnns plans,
// calls queryHook.Run with a map holding the current top-k, search params, an
// estimated segment count, whether a filter predicate is present, the
// AutoIndexConfig.EnableOptimize flag and the collection ID. The top-k (int64)
// and search params (string) left in the map are written back into the plan's
// query info, and the re-serialized plan replaces req's serialized plan (req is
// modified in place and returned). Other plan node types are logged and left
// untouched.
//
// numSegments is the per-shard segment count; it is multiplied by req's
// TotalChannelNum (treated as 1 when unset) to estimate the total number of
// segments searched.
//
// On any error the returned request is nil. Errors are merr ParameterInvalid
// (missing, undecodable or unmarshalable plan, or a plan without query info) or
// merr ServiceUnavailable (hook failure, or hook output of an unexpected type).
func OptimizeSearchParams(ctx context.Context, req *querypb.SearchRequest, queryHook QueryHook, numSegments int) (*querypb.SearchRequest, error) {
	// no hook applied or disabled, just return
	if queryHook == nil || !paramtable.Get().AutoIndexConfig.Enable.GetAsBool() {
		return req, nil
	}

	log := log.Ctx(ctx).With(zap.Int64("collection", req.GetReq().GetCollectionID()))

	serializedPlan := req.GetReq().GetSerializedExprPlan()
	// plan not found
	if serializedPlan == nil {
		log.Warn("serialized plan not found")
		// Return nil like every other error path: the result is meaningless on error.
		return nil, merr.WrapErrParameterInvalid("serialized search plan", "nil")
	}

	channelNum := req.GetTotalChannelNum()
	// not set, change to conservative channel num 1
	if channelNum <= 0 {
		channelNum = 1
	}

	plan := planpb.PlanNode{}
	if err := proto.Unmarshal(serializedPlan, &plan); err != nil {
		log.Warn("failed to unmarshal plan", zap.Error(err))
		return nil, merr.WrapErrParameterInvalid("valid serialized search plan", "no unmarshalable one", err.Error())
	}

	switch plan.GetNode().(type) {
	case *planpb.PlanNode_VectorAnns:
		queryInfo := plan.GetVectorAnns().GetQueryInfo()
		// A plan without query info cannot be tuned; writing the hook's results
		// back below would dereference a nil pointer and panic.
		if queryInfo == nil {
			log.Warn("query info not found in search plan")
			return nil, merr.WrapErrParameterInvalid("query info in search plan", "nil")
		}

		// use shardNum * segments num in shard to estimate total segment number
		estSegmentNum := numSegments * int(channelNum)
		withFilter := plan.GetVectorAnns().GetPredicates() != nil
		params := map[string]any{
			common.TopKKey:         queryInfo.GetTopk(),
			common.SearchParamKey:  queryInfo.GetSearchParams(),
			common.SegmentNumKey:   estSegmentNum,
			common.WithFilterKey:   withFilter,
			common.WithOptimizeKey: paramtable.Get().AutoIndexConfig.EnableOptimize.GetAsBool(),
			common.CollectionKey:   req.GetReq().GetCollectionID(),
		}
		if err := queryHook.Run(params); err != nil {
			log.Warn("failed to execute queryHook", zap.Error(err))
			return nil, merr.WrapErrServiceUnavailable(err.Error(), "queryHook execution failed")
		}

		// The hook is external code: verify the types it left in the map instead
		// of letting a bad value panic the query node.
		topk, ok := params[common.TopKKey].(int64)
		if !ok {
			log.Warn("queryHook returned unexpected topk", zap.Any("topk", params[common.TopKKey]))
			return nil, merr.WrapErrServiceUnavailable(
				fmt.Sprintf("unexpected %s value of type %T", common.TopKKey, params[common.TopKKey]),
				"queryHook execution failed")
		}
		searchParams, ok := params[common.SearchParamKey].(string)
		if !ok {
			log.Warn("queryHook returned unexpected search params", zap.Any("searchParams", params[common.SearchParamKey]))
			return nil, merr.WrapErrServiceUnavailable(
				fmt.Sprintf("unexpected %s value of type %T", common.SearchParamKey, params[common.SearchParamKey]),
				"queryHook execution failed")
		}
		queryInfo.Topk = topk
		queryInfo.SearchParams = searchParams

		serializedExprPlan, err := proto.Marshal(&plan)
		if err != nil {
			log.Warn("failed to marshal optimized plan", zap.Error(err))
			return nil, merr.WrapErrParameterInvalid("marshalable search plan", "plan with marshal error", err.Error())
		}
		// serializedPlan was non-nil, so req.Req is non-nil here.
		req.Req.SerializedExprPlan = serializedExprPlan
		log.Debug("optimized search params done", zap.Any("queryInfo", queryInfo))
	default:
		log.Warn("not supported node type", zap.String("nodeType", fmt.Sprintf("%T", plan.GetNode())))
	}
	return req, nil
}