mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-11-30 10:59:32 +08:00
089c805e0a
issue: https://github.com/milvus-io/milvus/issues/25639 https://github.com/milvus-io/milvus/issues/31368 Signed-off-by: zhenshan.cao <zhenshan.cao@zilliz.com>
221 lines
5.4 KiB
Go
221 lines
5.4 KiB
Go
package proxy
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"math"
|
|
"reflect"
|
|
"strings"
|
|
|
|
"github.com/cockroachdb/errors"
|
|
"go.uber.org/zap"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
|
"github.com/milvus-io/milvus/pkg/log"
|
|
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
|
"github.com/milvus-io/milvus/pkg/util/metric"
|
|
)
|
|
|
|
type rankType int
|
|
|
|
const (
|
|
invalidRankType rankType = iota // invalidRankType = 0
|
|
rrfRankType // rrfRankType = 1
|
|
weightedRankType // weightedRankType = 2
|
|
udfExprRankType // udfExprRankType = 3
|
|
)
|
|
|
|
var rankTypeMap = map[string]rankType{
|
|
"invalid": invalidRankType,
|
|
"rrf": rrfRankType,
|
|
"weighted": weightedRankType,
|
|
"expr": udfExprRankType,
|
|
}
|
|
|
|
type reScorer interface {
|
|
name() string
|
|
scorerType() rankType
|
|
reScore(input *milvuspb.SearchResults)
|
|
setMetricType(metricType string)
|
|
getMetricType() string
|
|
}
|
|
|
|
type baseScorer struct {
|
|
scorerName string
|
|
metricType string
|
|
}
|
|
|
|
func (bs *baseScorer) name() string {
|
|
return bs.scorerName
|
|
}
|
|
|
|
func (bs *baseScorer) setMetricType(metricType string) {
|
|
bs.metricType = metricType
|
|
}
|
|
|
|
func (bs *baseScorer) getMetricType() string {
|
|
return bs.metricType
|
|
}
|
|
|
|
type rrfScorer struct {
|
|
baseScorer
|
|
k float32
|
|
}
|
|
|
|
func (rs *rrfScorer) reScore(input *milvuspb.SearchResults) {
|
|
for i := range input.Results.GetScores() {
|
|
input.Results.Scores[i] = 1 / (rs.k + float32(i+1))
|
|
}
|
|
}
|
|
|
|
func (rs *rrfScorer) scorerType() rankType {
|
|
return rrfRankType
|
|
}
|
|
|
|
type weightedScorer struct {
|
|
baseScorer
|
|
weight float32
|
|
}
|
|
|
|
type activateFunc func(float32) float32
|
|
|
|
func (ws *weightedScorer) getActivateFunc() activateFunc {
|
|
mUpper := strings.ToUpper(ws.getMetricType())
|
|
isCosine := mUpper == strings.ToUpper(metric.COSINE)
|
|
isIP := mUpper == strings.ToUpper(metric.IP)
|
|
if isCosine {
|
|
f := func(distance float32) float32 {
|
|
return (1 + distance) * 0.5
|
|
}
|
|
return f
|
|
}
|
|
|
|
if isIP {
|
|
f := func(distance float32) float32 {
|
|
return 0.5 + float32(math.Atan(float64(distance)))/math.Pi
|
|
}
|
|
return f
|
|
}
|
|
|
|
f := func(distance float32) float32 {
|
|
return 1.0 - 2*float32(math.Atan(float64(distance)))/math.Pi
|
|
}
|
|
return f
|
|
}
|
|
|
|
func (ws *weightedScorer) reScore(input *milvuspb.SearchResults) {
|
|
activateF := ws.getActivateFunc()
|
|
for i, distance := range input.Results.GetScores() {
|
|
input.Results.Scores[i] = ws.weight * activateF(distance)
|
|
}
|
|
}
|
|
|
|
func (ws *weightedScorer) scorerType() rankType {
|
|
return weightedRankType
|
|
}
|
|
|
|
func NewReScorers(reqCnt int, rankParams []*commonpb.KeyValuePair) ([]reScorer, error) {
|
|
if reqCnt == 0 {
|
|
return []reScorer{}, nil
|
|
}
|
|
|
|
res := make([]reScorer, reqCnt)
|
|
rankTypeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankTypeKey, rankParams)
|
|
if err != nil {
|
|
log.Info("rank strategy not specified, use rrf instead")
|
|
// if not set rank strategy, use rrf rank as default
|
|
for i := 0; i < reqCnt; i++ {
|
|
res[i] = &rrfScorer{
|
|
baseScorer: baseScorer{
|
|
scorerName: "rrf",
|
|
},
|
|
k: float32(defaultRRFParamsValue),
|
|
}
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
if _, ok := rankTypeMap[rankTypeStr]; !ok {
|
|
return nil, errors.Errorf("unsupported rank type %s", rankTypeStr)
|
|
}
|
|
|
|
paramStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankParamsKey, rankParams)
|
|
if err != nil {
|
|
return nil, errors.New(RankParamsKey + " not found in rank_params")
|
|
}
|
|
|
|
var params map[string]interface{}
|
|
err = json.Unmarshal([]byte(paramStr), ¶ms)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
switch rankTypeMap[rankTypeStr] {
|
|
case rrfRankType:
|
|
_, ok := params[RRFParamsKey]
|
|
if !ok {
|
|
return nil, errors.New(RRFParamsKey + " not found in rank_params")
|
|
}
|
|
var k float64
|
|
if reflect.ValueOf(params[RRFParamsKey]).CanFloat() {
|
|
k = reflect.ValueOf(params[RRFParamsKey]).Float()
|
|
} else {
|
|
return nil, errors.New("The type of rank param k should be float")
|
|
}
|
|
if k <= 0 || k >= maxRRFParamsValue {
|
|
return nil, errors.New(fmt.Sprintf("The rank params k should be in range (0, %d)", maxRRFParamsValue))
|
|
}
|
|
log.Debug("rrf params", zap.Float64("k", k))
|
|
for i := 0; i < reqCnt; i++ {
|
|
res[i] = &rrfScorer{
|
|
baseScorer: baseScorer{
|
|
scorerName: "rrf",
|
|
},
|
|
k: float32(k),
|
|
}
|
|
}
|
|
case weightedRankType:
|
|
if _, ok := params[WeightsParamsKey]; !ok {
|
|
return nil, errors.New(WeightsParamsKey + " not found in rank_params")
|
|
}
|
|
weights := make([]float32, 0)
|
|
switch reflect.TypeOf(params[WeightsParamsKey]).Kind() {
|
|
case reflect.Slice:
|
|
rs := reflect.ValueOf(params[WeightsParamsKey])
|
|
for i := 0; i < rs.Len(); i++ {
|
|
v := rs.Index(i).Elem()
|
|
if v.CanFloat() {
|
|
weight := v.Float()
|
|
if weight < 0 || weight > 1 {
|
|
return nil, errors.New("rank param weight should be in range [0, 1]")
|
|
}
|
|
weights = append(weights, float32(weight))
|
|
} else {
|
|
return nil, errors.New("The type of rank param weight should be float")
|
|
}
|
|
}
|
|
default:
|
|
return nil, errors.New("The weights param should be an array")
|
|
}
|
|
|
|
log.Debug("weights params", zap.Any("weights", weights))
|
|
if reqCnt != len(weights) {
|
|
return nil, merr.WrapErrParameterInvalid(fmt.Sprint(reqCnt), fmt.Sprint(len(weights)), "the length of weights param mismatch with ann search requests")
|
|
}
|
|
for i := 0; i < reqCnt; i++ {
|
|
res[i] = &weightedScorer{
|
|
baseScorer: baseScorer{
|
|
scorerName: "weighted",
|
|
},
|
|
weight: weights[i],
|
|
}
|
|
}
|
|
default:
|
|
return nil, errors.Errorf("unsupported rank type %s", rankTypeStr)
|
|
}
|
|
|
|
return res, nil
|
|
}
|