2024-01-08 15:34:48 +08:00
|
|
|
package proxy
|
|
|
|
|
|
|
|
import (
|
|
|
|
"encoding/json"
|
|
|
|
"testing"
|
|
|
|
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
|
|
|
)
|
|
|
|
|
|
|
|
func TestRescorer(t *testing.T) {
|
|
|
|
t.Run("default scorer", func(t *testing.T) {
|
|
|
|
rescorers, err := NewReScorer([]*milvuspb.SearchRequest{{}, {}}, nil)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
assert.Equal(t, 2, len(rescorers))
|
|
|
|
assert.Equal(t, rrfRankType, rescorers[0].scorerType())
|
|
|
|
})
|
|
|
|
|
2024-01-17 20:28:58 +08:00
|
|
|
t.Run("rrf without param", func(t *testing.T) {
|
|
|
|
params := make(map[string]float64)
|
|
|
|
b, err := json.Marshal(params)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
rankParams := []*commonpb.KeyValuePair{
|
|
|
|
{Key: RankTypeKey, Value: "rrf"},
|
|
|
|
{Key: RankParamsKey, Value: string(b)},
|
|
|
|
}
|
|
|
|
|
|
|
|
_, err = NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
|
|
|
|
assert.Error(t, err)
|
|
|
|
assert.Contains(t, err.Error(), "k not found in rank_params")
|
|
|
|
})
|
|
|
|
|
|
|
|
t.Run("rrf param out of range", func(t *testing.T) {
|
|
|
|
params := make(map[string]float64)
|
|
|
|
params[RRFParamsKey] = -1
|
|
|
|
b, err := json.Marshal(params)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
rankParams := []*commonpb.KeyValuePair{
|
|
|
|
{Key: RankTypeKey, Value: "rrf"},
|
|
|
|
{Key: RankParamsKey, Value: string(b)},
|
|
|
|
}
|
|
|
|
|
|
|
|
_, err = NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
|
|
|
|
assert.Error(t, err)
|
|
|
|
|
|
|
|
params[RRFParamsKey] = maxRRFParamsValue + 1
|
|
|
|
b, err = json.Marshal(params)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
rankParams = []*commonpb.KeyValuePair{
|
|
|
|
{Key: RankTypeKey, Value: "rrf"},
|
|
|
|
{Key: RankParamsKey, Value: string(b)},
|
|
|
|
}
|
|
|
|
|
|
|
|
_, err = NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
|
|
|
|
assert.Error(t, err)
|
|
|
|
})
|
|
|
|
|
2024-01-08 15:34:48 +08:00
|
|
|
t.Run("rrf", func(t *testing.T) {
|
|
|
|
params := make(map[string]float64)
|
|
|
|
params[RRFParamsKey] = 61
|
|
|
|
b, err := json.Marshal(params)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
rankParams := []*commonpb.KeyValuePair{
|
|
|
|
{Key: RankTypeKey, Value: "rrf"},
|
|
|
|
{Key: RankParamsKey, Value: string(b)},
|
|
|
|
}
|
|
|
|
|
|
|
|
rescorers, err := NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
assert.Equal(t, 2, len(rescorers))
|
|
|
|
assert.Equal(t, rrfRankType, rescorers[0].scorerType())
|
|
|
|
assert.Equal(t, float32(61), rescorers[0].(*rrfScorer).k)
|
|
|
|
})
|
|
|
|
|
2024-01-17 20:28:58 +08:00
|
|
|
t.Run("weights without param", func(t *testing.T) {
|
|
|
|
params := make(map[string][]float64)
|
|
|
|
b, err := json.Marshal(params)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
rankParams := []*commonpb.KeyValuePair{
|
|
|
|
{Key: RankTypeKey, Value: "weighted"},
|
|
|
|
{Key: RankParamsKey, Value: string(b)},
|
|
|
|
}
|
|
|
|
|
|
|
|
_, err = NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
|
|
|
|
assert.Error(t, err)
|
|
|
|
assert.Contains(t, err.Error(), "not found in rank_params")
|
|
|
|
})
|
|
|
|
|
|
|
|
t.Run("weights out of range", func(t *testing.T) {
|
|
|
|
weights := []float64{1.2, 2.3}
|
|
|
|
params := make(map[string][]float64)
|
|
|
|
params[WeightsParamsKey] = weights
|
|
|
|
b, err := json.Marshal(params)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
rankParams := []*commonpb.KeyValuePair{
|
|
|
|
{Key: RankTypeKey, Value: "weighted"},
|
|
|
|
{Key: RankParamsKey, Value: string(b)},
|
|
|
|
}
|
|
|
|
|
|
|
|
_, err = NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
|
|
|
|
assert.Error(t, err)
|
|
|
|
assert.Contains(t, err.Error(), "rank param weight should be in range [0, 1]")
|
|
|
|
})
|
|
|
|
|
2024-01-08 15:34:48 +08:00
|
|
|
t.Run("weights", func(t *testing.T) {
|
|
|
|
weights := []float64{0.5, 0.2}
|
|
|
|
params := make(map[string][]float64)
|
|
|
|
params[WeightsParamsKey] = weights
|
|
|
|
b, err := json.Marshal(params)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
rankParams := []*commonpb.KeyValuePair{
|
|
|
|
{Key: RankTypeKey, Value: "weighted"},
|
|
|
|
{Key: RankParamsKey, Value: string(b)},
|
|
|
|
}
|
|
|
|
|
|
|
|
rescorers, err := NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
assert.Equal(t, 2, len(rescorers))
|
|
|
|
assert.Equal(t, weightedRankType, rescorers[0].scorerType())
|
|
|
|
assert.Equal(t, float32(weights[0]), rescorers[0].(*weightedScorer).weight)
|
|
|
|
})
|
|
|
|
}
|