feat: Support multiple vector search (#29433)

issue: #25639

Signed-off-by: xige-16 <xi.ge@zilliz.com>

parent 7e6f73a12d
commit 9702cef2b5
@@ -281,6 +281,16 @@ class SegmentGrowingImpl : public SegmentGrowing {
    void
    check_search(const query::Plan* plan) const override {
        Assert(plan);
        auto& metric_str = plan->plan_node_->search_info_.metric_type_;
        auto searched_field_id = plan->plan_node_->search_info_.field_id_;
        auto index_meta =
            index_meta_->GetFieldIndexMeta(FieldId(searched_field_id));
        if (metric_str.empty()) {
            metric_str = index_meta.GeMetricType();
        } else {
            AssertInfo(metric_str == index_meta.GeMetricType(),
                       "metric type not match");
        }
    }

    const ConcurrentVector<Timestamp>&
@@ -917,6 +917,17 @@ SegmentSealedImpl::check_search(const query::Plan* plan) const {
    AssertInfo(plan->extra_info_opt_.has_value(),
               "Extra info of search plan doesn't have value");

    auto& metric_str = plan->plan_node_->search_info_.metric_type_;
    auto searched_field_id = plan->plan_node_->search_info_.field_id_;
    auto index_meta =
        col_index_meta_->GetFieldIndexMeta(FieldId(searched_field_id));
    if (metric_str.empty()) {
        metric_str = index_meta.GeMetricType();
    } else {
        AssertInfo(metric_str == index_meta.GeMetricType(),
                   "metric type not match");
    }

    if (!is_system_field_ready()) {
        PanicInfo(
            FieldNotLoaded,
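Both segment implementations apply the same rule: if the search request carries no metric type, it inherits the one recorded in the field's index meta; if it carries one, the two must agree. A minimal Go sketch of that rule (illustrative only; the function and names below are not part of this patch):

package main

import (
    "errors"
    "fmt"
)

// resolveMetricType mirrors the check_search fallback: an empty requested
// metric type is filled from the index meta, a non-empty one must match it.
func resolveMetricType(requested, indexMetric string) (string, error) {
    if requested == "" {
        return indexMetric, nil
    }
    if requested != indexMetric {
        return "", errors.New("metric type not match")
    }
    return requested, nil
}

func main() {
    mt, err := resolveMetricType("", "L2") // inherited from the index
    fmt.Println(mt, err)                   // L2 <nil>
    _, err = resolveMetricType("IP", "L2") // conflicting request is rejected
    fmt.Println(err)                       // metric type not match
}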
@@ -219,7 +219,7 @@ message LoadMetaInfo {
  LoadType load_type = 1;
  int64 collectionID = 2;
  repeated int64 partitionIDs = 3;
-  string metric_type = 4;
+  string metric_type = 4 [deprecated=true];
}

message WatchDmChannelsRequest {
@@ -2726,9 +2726,135 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
}

func (node *Proxy) HybridSearch(ctx context.Context, request *milvuspb.HybridSearchRequest) (*milvuspb.SearchResults, error) {
    receiveSize := proto.Size(request)
    metrics.ProxyReceiveBytes.WithLabelValues(
        strconv.FormatInt(paramtable.GetNodeID(), 10),
        metrics.HybridSearchLabel,
        request.GetCollectionName(),
    ).Add(float64(receiveSize))

    if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
        return &milvuspb.SearchResults{
-            Status: merr.Status(merr.WrapErrServiceInternal("unimplemented")),
+            Status: merr.Status(err),
        }, nil
    }

    method := "HybridSearch"
    tr := timerecord.NewTimeRecorder(method)
    metrics.ProxyFunctionCall.WithLabelValues(
        strconv.FormatInt(paramtable.GetNodeID(), 10),
        method,
        metrics.TotalLabel,
    ).Inc()

    ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch")
    defer sp.End()

    qt := &hybridSearchTask{
        ctx:       ctx,
        Condition: NewTaskCondition(ctx),
        request:   request,
        tr:        timerecord.NewTimeRecorder(method),
        qc:        node.queryCoord,
        node:      node,
        lb:        node.lbPolicy,
    }

    guaranteeTs := request.GuaranteeTimestamp

    log := log.Ctx(ctx).With(
        zap.String("role", typeutil.ProxyRole),
        zap.String("db", request.DbName),
        zap.String("collection", request.CollectionName),
        zap.Any("partitions", request.PartitionNames),
        zap.Any("OutputFields", request.OutputFields),
        zap.Uint64("guarantee_timestamp", guaranteeTs),
    )

    defer func() {
        span := tr.ElapseSpan()
        if span >= SlowReadSpan {
            log.Info(rpcSlow(method), zap.Duration("duration", span))
        }
    }()

    log.Debug(rpcReceived(method))

    if err := node.sched.dqQueue.Enqueue(qt); err != nil {
        log.Warn(
            rpcFailedToEnqueue(method),
            zap.Error(err),
        )

        metrics.ProxyFunctionCall.WithLabelValues(
            strconv.FormatInt(paramtable.GetNodeID(), 10),
            method,
            metrics.AbandonLabel,
        ).Inc()

        return &milvuspb.SearchResults{
            Status: merr.Status(err),
        }, nil
    }
    tr.CtxRecord(ctx, "hybrid search request enqueue")

    log.Debug(
        rpcEnqueued(method),
        zap.Uint64("timestamp", qt.request.Base.Timestamp),
    )

    if err := qt.WaitToFinish(); err != nil {
        log.Warn(
            rpcFailedToWaitToFinish(method),
            zap.Error(err),
        )

        metrics.ProxyFunctionCall.WithLabelValues(
            strconv.FormatInt(paramtable.GetNodeID(), 10),
            method,
            metrics.FailLabel,
        ).Inc()

        return &milvuspb.SearchResults{
            Status: merr.Status(err),
        }, nil
    }

    span := tr.CtxRecord(ctx, "wait hybrid search result")
    metrics.ProxyWaitForSearchResultLatency.WithLabelValues(
        strconv.FormatInt(paramtable.GetNodeID(), 10),
        metrics.HybridSearchLabel,
    ).Observe(float64(span.Milliseconds()))

    tr.CtxRecord(ctx, "wait hybrid search result")
    log.Debug(rpcDone(method))

    metrics.ProxyFunctionCall.WithLabelValues(
        strconv.FormatInt(paramtable.GetNodeID(), 10),
        method,
        metrics.SuccessLabel,
    ).Inc()

    metrics.ProxySearchVectors.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(len(qt.request.GetRequests())))

    searchDur := tr.ElapseSpan().Milliseconds()
    metrics.ProxySQLatency.WithLabelValues(
        strconv.FormatInt(paramtable.GetNodeID(), 10),
        metrics.HybridSearchLabel,
    ).Observe(float64(searchDur))

    metrics.ProxyCollectionSQLatency.WithLabelValues(
        strconv.FormatInt(paramtable.GetNodeID(), 10),
        metrics.HybridSearchLabel,
        request.CollectionName,
    ).Observe(float64(searchDur))

    if qt.result != nil {
        sentSize := proto.Size(qt.result)
        metrics.ProxyReadReqSendBytes.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(sentSize))
        rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize))
    }
    return qt.result, nil
}

func (node *Proxy) getVectorPlaceholderGroupForSearchByPks(ctx context.Context, request *milvuspb.SearchRequest) ([]byte, error) {
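HybridSearch follows the proxy's standard read-task pattern: build the task, enqueue it on the dql queue, then block on WaitToFinish until the scheduler has run it. A self-contained sketch of that enqueue-and-wait contract (illustrative only; it does not use the real scheduler or task types):

package main

import "fmt"

// task is a tiny stand-in for the proxy task interface: a worker runs
// Execute and the caller blocks on WaitToFinish for the outcome.
type task struct {
    done chan error
}

func newTask() *task { return &task{done: make(chan error, 1)} }

func (t *task) Execute() error { return nil } // real tasks fan out the ANN sub-searches here

func (t *task) Notify(err error) { t.done <- err }

func (t *task) WaitToFinish() error { return <-t.done }

// enqueue hands the task to a worker goroutine, like sched.dqQueue.Enqueue.
func enqueue(t *task) {
    go func() { t.Notify(t.Execute()) }()
}

func main() {
    t := newTask()
    enqueue(t)
    fmt.Println("result:", t.WaitToFinish()) // result: <nil>
}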
@@ -1477,7 +1477,7 @@ func TestProxy(t *testing.T) {
    topk := 10
    roundDecimal := 6
    expr := fmt.Sprintf("%s > 0", int64Field)
-    constructVectorsPlaceholderGroup := func() *commonpb.PlaceholderGroup {
+    constructVectorsPlaceholderGroup := func(nq int) *commonpb.PlaceholderGroup {
        values := make([][]byte, 0, nq)
        for i := 0; i < nq; i++ {
            bs := make([]byte, 0, dim*4)
@@ -1502,8 +1502,8 @@ func TestProxy(t *testing.T) {
        }
    }

-    constructSearchRequest := func() *milvuspb.SearchRequest {
-        plg := constructVectorsPlaceholderGroup()
+    constructSearchRequest := func(nq int) *milvuspb.SearchRequest {
+        plg := constructVectorsPlaceholderGroup(nq)
        plgBs, err := proto.Marshal(plg)
        assert.NoError(t, err)

@@ -1538,13 +1538,51 @@ func TestProxy(t *testing.T) {
    wg.Add(1)
    t.Run("search", func(t *testing.T) {
        defer wg.Done()
-        req := constructSearchRequest()
+        req := constructSearchRequest(nq)

        resp, err := proxy.Search(ctx, req)
        assert.NoError(t, err)
        assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
    })

    constructHybridSearchRequest := func(reqs []*milvuspb.SearchRequest) *milvuspb.HybridSearchRequest {
        params := make(map[string]float64)
        params[RRFParamsKey] = 60
        b, err := json.Marshal(params)
        assert.NoError(t, err)
        rankParams := []*commonpb.KeyValuePair{
            {Key: RankTypeKey, Value: "rrf"},
            {Key: RankParamsKey, Value: string(b)},
            {Key: LimitKey, Value: strconv.Itoa(topk)},
            {Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)},
        }

        return &milvuspb.HybridSearchRequest{
            Base:               nil,
            DbName:             dbName,
            CollectionName:     collectionName,
            Requests:           reqs,
            PartitionNames:     nil,
            OutputFields:       nil,
            RankParams:         rankParams,
            TravelTimestamp:    0,
            GuaranteeTimestamp: 0,
        }
    }

    wg.Add(1)
    nq = 1
    t.Run("hybrid search", func(t *testing.T) {
        defer wg.Done()
        req1 := constructSearchRequest(nq)
        req2 := constructSearchRequest(nq)

        resp, err := proxy.HybridSearch(ctx, constructHybridSearchRequest([]*milvuspb.SearchRequest{req1, req2}))
        assert.NoError(t, err)
        assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
    })
    nq = 10

    constructPrimaryKeysPlaceholderGroup := func() *commonpb.PlaceholderGroup {
        expr := fmt.Sprintf("%v in [%v]", int64Field, insertedIds[0])
        exprBytes := []byte(expr)
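The test above exercises the "rrf" strategy; the "weighted" strategy added by this patch takes a weights array instead of a k value. A hedged sketch of what those rank params might look like, reusing the key names declared in task.go (the weight values themselves are made up for illustration):

package main

import (
    "encoding/json"
    "fmt"
)

// Keys as declared in internal/proxy/task.go in this patch.
const (
    RankTypeKey      = "strategy"
    RankParamsKey    = "params"
    WeightsParamsKey = "weights"
)

func main() {
    // One weight per ANN sub-request; NewReScorer rejects a length mismatch.
    params := map[string][]float64{WeightsParamsKey: {0.7, 0.3}}
    b, err := json.Marshal(params)
    if err != nil {
        panic(err)
    }
    // These key/value pairs would go into HybridSearchRequest.RankParams.
    rankParams := [][2]string{
        {RankTypeKey, "weighted"},
        {RankParamsKey, string(b)},
    }
    fmt.Println(rankParams) // [[strategy weighted] [params {"weights":[0.7,0.3]}]]
}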
internal/proxy/reScorer.go (new file, 157 lines)
@@ -0,0 +1,157 @@
package proxy

import (
    "encoding/json"
    "fmt"
    "reflect"

    "github.com/cockroachdb/errors"
    "go.uber.org/zap"

    "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
    "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
    "github.com/milvus-io/milvus/pkg/log"
    "github.com/milvus-io/milvus/pkg/util/funcutil"
    "github.com/milvus-io/milvus/pkg/util/merr"
)

type rankType int

const (
    invalidRankType  rankType = iota // invalidRankType = 0
    rrfRankType                      // rrfRankType = 1
    weightedRankType                 // weightedRankType = 2
    udfExprRankType                  // udfExprRankType = 3
)

var rankTypeMap = map[string]rankType{
    "invalid":  invalidRankType,
    "rrf":      rrfRankType,
    "weighted": weightedRankType,
    "expr":     udfExprRankType,
}

type reScorer interface {
    name() string
    scorerType() rankType
    reScore(input *milvuspb.SearchResults)
}

type baseScorer struct {
    scorerName string
}

func (bs *baseScorer) name() string {
    return bs.scorerName
}

type rrfScorer struct {
    baseScorer
    k float32
}

func (rs *rrfScorer) reScore(input *milvuspb.SearchResults) {
    for i := range input.Results.GetScores() {
        input.Results.Scores[i] = 1 / (rs.k + float32(i+1))
    }
}

func (rs *rrfScorer) scorerType() rankType {
    return rrfRankType
}

type weightedScorer struct {
    baseScorer
    weight float32
}

func (ws *weightedScorer) reScore(input *milvuspb.SearchResults) {
    for i, score := range input.Results.GetScores() {
        input.Results.Scores[i] = ws.weight * score
    }
}

func (ws *weightedScorer) scorerType() rankType {
    return weightedRankType
}

func NewReScorer(reqs []*milvuspb.SearchRequest, rankParams []*commonpb.KeyValuePair) ([]reScorer, error) {
    res := make([]reScorer, len(reqs))
    rankTypeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankTypeKey, rankParams)
    if err != nil {
        log.Info("rank strategy not specified, use rrf instead")
        // if not set rank strategy, use rrf rank as default
        for i := range reqs {
            res[i] = &rrfScorer{
                baseScorer: baseScorer{
                    scorerName: "rrf",
                },
                k: float32(defaultRRFParamsValue),
            }
        }
        return res, nil
    }

    if _, ok := rankTypeMap[rankTypeStr]; !ok {
        return nil, errors.Errorf("unsupported rank type %s", rankTypeStr)
    }

    paramStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankParamsKey, rankParams)
    if err != nil {
        return nil, errors.New(RankParamsKey + " not found in rank_params")
    }

    var params map[string]interface{}
    err = json.Unmarshal([]byte(paramStr), &params)
    if err != nil {
        return nil, err
    }

    switch rankTypeMap[rankTypeStr] {
    case rrfRankType:
        k, ok := params[RRFParamsKey].(float64)
        if !ok {
            return nil, errors.New(RRFParamsKey + " not found in rank_params")
        }
        log.Debug("rrf params", zap.Float64("k", k))
        for i := range reqs {
            res[i] = &rrfScorer{
                baseScorer: baseScorer{
                    scorerName: "rrf",
                },
                k: float32(k),
            }
        }
    case weightedRankType:
        if _, ok := params[WeightsParamsKey]; !ok {
            return nil, errors.New(WeightsParamsKey + " not found in rank_params")
        }
        weights := make([]float32, 0)
        switch reflect.TypeOf(params[WeightsParamsKey]).Kind() {
        case reflect.Slice:
            rs := reflect.ValueOf(params[WeightsParamsKey])
            for i := 0; i < rs.Len(); i++ {
                weights = append(weights, float32(rs.Index(i).Interface().(float64)))
            }
        default:
            return nil, errors.New("The weights param should be an array")
        }

        log.Debug("weights params", zap.Any("weights", weights))
        if len(reqs) != len(weights) {
            return nil, merr.WrapErrParameterInvalid(fmt.Sprint(len(reqs)), fmt.Sprint(len(weights)), "the length of weights param mismatch with ann search requests")
        }
        for i := range reqs {
            res[i] = &weightedScorer{
                baseScorer: baseScorer{
                    scorerName: "weighted",
                },
                weight: weights[i],
            }
        }
    default:
        return nil, errors.Errorf("unsupported rank type %s", rankTypeStr)
    }

    return res, nil
}
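reScore replaces each sub-request's raw distances with reciprocal-rank scores, and rankSearchResultData in task_hybrid_search.go then sums those scores per primary key across sub-requests. A self-contained sketch of that fusion arithmetic (plain Go, no Milvus types; the ids and k=60 are illustrative):

package main

import (
    "fmt"
    "sort"
)

// rrfFuse sums 1/(k+rank) over every ranked id list, which is what the
// rrfScorer plus the per-id accumulation in rankSearchResultData amount to.
func rrfFuse(k float32, rankedLists [][]int64) map[int64]float32 {
    fused := make(map[int64]float32)
    for _, ids := range rankedLists {
        for rank, id := range ids {
            fused[id] += 1 / (k + float32(rank+1))
        }
    }
    return fused
}

func main() {
    // Two ANN sub-requests returning overlapping but differently ordered ids.
    fused := rrfFuse(60, [][]int64{{7, 3, 9}, {7, 5, 3}})

    ids := make([]int64, 0, len(fused))
    for id := range fused {
        ids = append(ids, id)
    }
    // Sort by fused score, highest first, like the final hybrid ranking.
    sort.Slice(ids, func(i, j int) bool { return fused[ids[i]] > fused[ids[j]] })
    for _, id := range ids {
        fmt.Printf("id=%d score=%.5f\n", id, fused[id])
    }
}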
internal/proxy/reScorer_test.go (new file, 55 lines)
@@ -0,0 +1,55 @@
package proxy

import (
    "encoding/json"
    "testing"

    "github.com/stretchr/testify/assert"

    "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
    "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
)

func TestRescorer(t *testing.T) {
    t.Run("default scorer", func(t *testing.T) {
        rescorers, err := NewReScorer([]*milvuspb.SearchRequest{{}, {}}, nil)
        assert.NoError(t, err)
        assert.Equal(t, 2, len(rescorers))
        assert.Equal(t, rrfRankType, rescorers[0].scorerType())
    })

    t.Run("rrf", func(t *testing.T) {
        params := make(map[string]float64)
        params[RRFParamsKey] = 61
        b, err := json.Marshal(params)
        assert.NoError(t, err)
        rankParams := []*commonpb.KeyValuePair{
            {Key: RankTypeKey, Value: "rrf"},
            {Key: RankParamsKey, Value: string(b)},
        }

        rescorers, err := NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
        assert.NoError(t, err)
        assert.Equal(t, 2, len(rescorers))
        assert.Equal(t, rrfRankType, rescorers[0].scorerType())
        assert.Equal(t, float32(61), rescorers[0].(*rrfScorer).k)
    })

    t.Run("weights", func(t *testing.T) {
        weights := []float64{0.5, 0.2}
        params := make(map[string][]float64)
        params[WeightsParamsKey] = weights
        b, err := json.Marshal(params)
        assert.NoError(t, err)
        rankParams := []*commonpb.KeyValuePair{
            {Key: RankTypeKey, Value: "weighted"},
            {Key: RankParamsKey, Value: string(b)},
        }

        rescorers, err := NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
        assert.NoError(t, err)
        assert.Equal(t, 2, len(rescorers))
        assert.Equal(t, weightedRankType, rescorers[0].scorerType())
        assert.Equal(t, float32(weights[0]), rescorers[0].(*weightedScorer).weight)
    })
}
@@ -88,6 +88,11 @@ const (

    // minFloat32 minimum float.
    minFloat32 = -1 * float32(math.MaxFloat32)

    RankTypeKey      = "strategy"
    RankParamsKey    = "params"
    RRFParamsKey     = "k"
    WeightsParamsKey = "weights"
)

type task interface {
461
internal/proxy/task_hybrid_search.go
Normal file
461
internal/proxy/task_hybrid_search.go
Normal file
@ -0,0 +1,461 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
"strconv"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/conc"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
const (
|
||||
HybridSearchTaskName = "HybridSearchTask"
|
||||
)
|
||||
|
||||
type hybridSearchTask struct {
|
||||
Condition
|
||||
ctx context.Context
|
||||
|
||||
result *milvuspb.SearchResults
|
||||
request *milvuspb.HybridSearchRequest
|
||||
|
||||
tr *timerecord.TimeRecorder
|
||||
schema *schemaInfo
|
||||
requery bool
|
||||
|
||||
userOutputFields []string
|
||||
|
||||
qc types.QueryCoordClient
|
||||
node types.ProxyComponent
|
||||
lb LBPolicy
|
||||
|
||||
collectionID UniqueID
|
||||
|
||||
multipleRecallResults *typeutil.ConcurrentSet[*milvuspb.SearchResults]
|
||||
reScorers []reScorer
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) PreExecute(ctx context.Context) error {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch-PreExecute")
|
||||
defer sp.End()
|
||||
|
||||
if len(t.request.Requests) <= 0 {
|
||||
return errors.New("minimum of ann search requests is 1")
|
||||
}
|
||||
|
||||
if len(t.request.Requests) > defaultMaxSearchRequest {
|
||||
return errors.New("maximum of ann search requests is 1024")
|
||||
}
|
||||
for _, req := range t.request.GetRequests() {
|
||||
nq, err := getNq(req)
|
||||
if err != nil {
|
||||
log.Debug("failed to get nq", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
if nq != 1 {
|
||||
err = merr.WrapErrParameterInvalid("1", fmt.Sprint(nq), "nq should be equal to 1")
|
||||
log.Debug(err.Error())
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
collectionName := t.request.CollectionName
|
||||
collID, err := globalMetaCache.GetCollectionID(ctx, t.request.GetDbName(), collectionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.collectionID = collID
|
||||
|
||||
log := log.Ctx(ctx).With(zap.Int64("collID", collID), zap.String("collName", collectionName))
|
||||
t.schema, err = globalMetaCache.GetCollectionSchema(ctx, t.request.GetDbName(), collectionName)
|
||||
if err != nil {
|
||||
log.Warn("get collection schema failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
partitionKeyMode, err := isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName)
|
||||
if err != nil {
|
||||
log.Warn("is partition key mode failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
if partitionKeyMode && len(t.request.GetPartitionNames()) != 0 {
|
||||
return errors.New("not support manually specifying the partition names if partition key mode is used")
|
||||
}
|
||||
|
||||
t.request.OutputFields, t.userOutputFields, err = translateOutputFields(t.request.OutputFields, t.schema, false)
|
||||
if err != nil {
|
||||
log.Warn("translate output fields failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
log.Debug("translate output fields",
|
||||
zap.Strings("output fields", t.request.GetOutputFields()))
|
||||
|
||||
if len(t.request.OutputFields) > 0 {
|
||||
t.requery = true
|
||||
}
|
||||
|
||||
log.Debug("hybrid search preExecute done.",
|
||||
zap.Uint64("guarantee_ts", t.request.GetGuaranteeTimestamp()),
|
||||
zap.Bool("use_default_consistency", t.request.GetUseDefaultConsistency()),
|
||||
zap.Any("consistency level", t.request.GetConsistencyLevel()))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) Execute(ctx context.Context) error {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch-Execute")
|
||||
defer sp.End()
|
||||
|
||||
log := log.Ctx(ctx).With(zap.Int64("collID", t.collectionID), zap.String("collName", t.request.GetCollectionName()))
|
||||
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute hybrid search %d", t.ID()))
|
||||
defer tr.CtxElapse(ctx, "done")
|
||||
|
||||
futures := make([]*conc.Future[*milvuspb.SearchResults], len(t.request.Requests))
|
||||
for index := range t.request.Requests {
|
||||
searchReq := t.request.Requests[index]
|
||||
future := conc.Go(func() (*milvuspb.SearchResults, error) {
|
||||
searchReq.TravelTimestamp = t.request.GetTravelTimestamp()
|
||||
searchReq.GuaranteeTimestamp = t.request.GetGuaranteeTimestamp()
|
||||
searchReq.NotReturnAllMeta = t.request.GetNotReturnAllMeta()
|
||||
searchReq.ConsistencyLevel = t.request.GetConsistencyLevel()
|
||||
searchReq.UseDefaultConsistency = t.request.GetUseDefaultConsistency()
|
||||
searchReq.OutputFields = nil
|
||||
|
||||
return t.node.Search(ctx, searchReq)
|
||||
})
|
||||
futures[index] = future
|
||||
}
|
||||
|
||||
err := conc.AwaitAll(futures...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.reScorers, err = NewReScorer(t.request.GetRequests(), t.request.GetRankParams())
|
||||
if err != nil {
|
||||
log.Info("generate reScorer failed", zap.Any("rank params", t.request.GetRankParams()), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
t.multipleRecallResults = typeutil.NewConcurrentSet[*milvuspb.SearchResults]()
|
||||
for i, future := range futures {
|
||||
err = future.Err()
|
||||
if err != nil {
|
||||
log.Debug("QueryNode search result error", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
result := futures[i].Value()
|
||||
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
log.Debug("QueryNode search result error",
|
||||
zap.String("reason", result.GetStatus().GetReason()))
|
||||
return merr.Error(result.GetStatus())
|
||||
}
|
||||
|
||||
t.reScorers[i].reScore(result)
|
||||
t.multipleRecallResults.Insert(result)
|
||||
}
|
||||
|
||||
log.Debug("hybrid search execute done.")
|
||||
return nil
|
||||
}
|
||||
|
||||
type rankParams struct {
|
||||
limit int64
|
||||
offset int64
|
||||
roundDecimal int64
|
||||
}
|
||||
|
||||
// parseRankParams get limit and offset from rankParams, both are optional.
|
||||
func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, error) {
|
||||
var (
|
||||
limit int64
|
||||
offset int64
|
||||
roundDecimal int64
|
||||
err error
|
||||
)
|
||||
|
||||
limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, rankParamsPair)
|
||||
if err != nil {
|
||||
return nil, errors.New(LimitKey + " not found in search_params")
|
||||
}
|
||||
limit, err = strconv.ParseInt(limitStr, 0, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s [%s] is invalid", LimitKey, limitStr)
|
||||
}
|
||||
|
||||
offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, rankParamsPair)
|
||||
if err == nil {
|
||||
offset, err = strconv.ParseInt(offsetStr, 0, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
|
||||
}
|
||||
}
|
||||
|
||||
// validate max result window.
|
||||
if err = validateMaxQueryResultWindow(offset, limit); err != nil {
|
||||
return nil, fmt.Errorf("invalid max query result window, %w", err)
|
||||
}
|
||||
|
||||
roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, rankParamsPair)
|
||||
if err != nil {
|
||||
roundDecimalStr = "-1"
|
||||
}
|
||||
|
||||
roundDecimal, err = strconv.ParseInt(roundDecimalStr, 0, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
|
||||
}
|
||||
|
||||
if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) {
|
||||
return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
|
||||
}
|
||||
|
||||
return &rankParams{
|
||||
limit: limit,
|
||||
offset: offset,
|
||||
roundDecimal: roundDecimal,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) PostExecute(ctx context.Context) error {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch-PostExecute")
|
||||
defer sp.End()
|
||||
|
||||
log := log.Ctx(ctx).With(zap.Int64("collID", t.collectionID), zap.String("collName", t.request.GetCollectionName()))
|
||||
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy postExecute hybrid search %d", t.ID()))
|
||||
defer func() {
|
||||
tr.CtxElapse(ctx, "done")
|
||||
}()
|
||||
|
||||
primaryFieldSchema, err := t.schema.GetPkField()
|
||||
if err != nil {
|
||||
log.Warn("failed to get primary field schema", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
rankParams, err := parseRankParams(t.request.GetRankParams())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.result, err = rankSearchResultData(ctx, 1,
|
||||
rankParams,
|
||||
primaryFieldSchema.GetDataType(),
|
||||
t.multipleRecallResults.Collect())
|
||||
if err != nil {
|
||||
log.Warn("rank search result failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
t.result.CollectionName = t.request.GetCollectionName()
|
||||
t.fillInFieldInfo()
|
||||
|
||||
if t.requery {
|
||||
err := t.Requery()
|
||||
if err != nil {
|
||||
log.Warn("failed to requery", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
t.result.Results.OutputFields = t.userOutputFields
|
||||
|
||||
log.Debug("hybrid search post execute done")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) Requery() error {
|
||||
queryReq := &milvuspb.QueryRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Retrieve,
|
||||
},
|
||||
DbName: t.request.GetDbName(),
|
||||
CollectionName: t.request.GetCollectionName(),
|
||||
Expr: "",
|
||||
OutputFields: t.request.GetOutputFields(),
|
||||
PartitionNames: t.request.GetPartitionNames(),
|
||||
GuaranteeTimestamp: t.request.GetGuaranteeTimestamp(),
|
||||
TravelTimestamp: t.request.GetTravelTimestamp(),
|
||||
NotReturnAllMeta: t.request.GetNotReturnAllMeta(),
|
||||
ConsistencyLevel: t.request.GetConsistencyLevel(),
|
||||
UseDefaultConsistency: t.request.GetUseDefaultConsistency(),
|
||||
}
|
||||
|
||||
return doRequery(t.ctx, t.collectionID, t.node, t.schema.CollectionSchema, queryReq, t.result)
|
||||
}
|
||||
|
||||
func rankSearchResultData(ctx context.Context,
|
||||
nq int64,
|
||||
params *rankParams,
|
||||
pkType schemapb.DataType,
|
||||
searchResults []*milvuspb.SearchResults,
|
||||
) (*milvuspb.SearchResults, error) {
|
||||
tr := timerecord.NewTimeRecorder("rankSearchResultData")
|
||||
defer func() {
|
||||
tr.CtxElapse(ctx, "done")
|
||||
}()
|
||||
|
||||
offset := params.offset
|
||||
limit := params.limit
|
||||
topk := limit + offset
|
||||
roundDecimal := params.roundDecimal
|
||||
log.Ctx(ctx).Debug("rankSearchResultData",
|
||||
zap.Int("len(searchResults)", len(searchResults)),
|
||||
zap.Int64("nq", nq),
|
||||
zap.Int64("offset", offset),
|
||||
zap.Int64("limit", limit))
|
||||
|
||||
ret := &milvuspb.SearchResults{
|
||||
Status: merr.Success(),
|
||||
Results: &schemapb.SearchResultData{
|
||||
NumQueries: nq,
|
||||
TopK: limit,
|
||||
FieldsData: make([]*schemapb.FieldData, 0),
|
||||
Scores: []float32{},
|
||||
Ids: &schemapb.IDs{},
|
||||
Topks: []int64{},
|
||||
},
|
||||
}
|
||||
|
||||
switch pkType {
|
||||
case schemapb.DataType_Int64:
|
||||
ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: make([]int64, 0),
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_VarChar:
|
||||
ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: make([]string, 0),
|
||||
},
|
||||
}
|
||||
default:
|
||||
return nil, errors.New("unsupported pk type")
|
||||
}
|
||||
|
||||
// []map[id]score
|
||||
accumulatedScores := make([]map[interface{}]float32, nq)
|
||||
for i := int64(0); i < nq; i++ {
|
||||
accumulatedScores[i] = make(map[interface{}]float32)
|
||||
}
|
||||
|
||||
for _, result := range searchResults {
|
||||
scores := result.GetResults().GetScores()
|
||||
start := int64(0)
|
||||
for i := int64(0); i < nq; i++ {
|
||||
realTopk := result.GetResults().Topks[i]
|
||||
for j := start; j < start+realTopk; j++ {
|
||||
id := typeutil.GetPK(result.GetResults().GetIds(), j)
|
||||
accumulatedScores[i][id] += scores[j]
|
||||
}
|
||||
start += realTopk
|
||||
}
|
||||
}
|
||||
|
||||
for i := int64(0); i < nq; i++ {
|
||||
idSet := accumulatedScores[i]
|
||||
keys := make([]interface{}, 0)
|
||||
for key := range idSet {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
if int64(len(keys)) <= offset {
|
||||
ret.Results.Topks = append(ret.Results.Topks, 0)
|
||||
continue
|
||||
}
|
||||
|
||||
// sort id by score
|
||||
sort.Slice(keys, func(i, j int) bool {
|
||||
return idSet[keys[i]] >= idSet[keys[j]]
|
||||
})
|
||||
|
||||
if int64(len(keys)) > topk {
|
||||
keys = keys[:topk]
|
||||
}
|
||||
|
||||
// set real topk
|
||||
ret.Results.Topks = append(ret.Results.Topks, int64(len(keys))-offset)
|
||||
// append id and score
|
||||
for index := offset; index < int64(len(keys)); index++ {
|
||||
typeutil.AppendPKs(ret.Results.Ids, keys[index])
|
||||
score := idSet[keys[index]]
|
||||
if roundDecimal != -1 {
|
||||
multiplier := math.Pow(10.0, float64(roundDecimal))
|
||||
score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier)
|
||||
}
|
||||
ret.Results.Scores = append(ret.Results.Scores, score)
|
||||
}
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) fillInFieldInfo() {
|
||||
if len(t.request.OutputFields) != 0 && len(t.result.Results.FieldsData) != 0 {
|
||||
for i, name := range t.request.OutputFields {
|
||||
for _, field := range t.schema.Fields {
|
||||
if t.result.Results.FieldsData[i] != nil && field.Name == name {
|
||||
t.result.Results.FieldsData[i].FieldName = field.Name
|
||||
t.result.Results.FieldsData[i].FieldId = field.FieldID
|
||||
t.result.Results.FieldsData[i].Type = field.DataType
|
||||
t.result.Results.FieldsData[i].IsDynamic = field.IsDynamic
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) TraceCtx() context.Context {
|
||||
return t.ctx
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) ID() UniqueID {
|
||||
return t.request.Base.MsgID
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) SetID(uid UniqueID) {
|
||||
t.request.Base.MsgID = uid
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) Name() string {
|
||||
return HybridSearchTaskName
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) Type() commonpb.MsgType {
|
||||
return t.request.Base.MsgType
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) BeginTs() Timestamp {
|
||||
return t.request.Base.Timestamp
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) EndTs() Timestamp {
|
||||
return t.request.Base.Timestamp
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) SetTs(ts Timestamp) {
|
||||
t.request.Base.Timestamp = ts
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) OnEnqueue() error {
|
||||
t.request.Base = commonpbutil.NewMsgBase()
|
||||
t.request.Base.MsgType = commonpb.MsgType_Search
|
||||
t.request.Base.SourceID = paramtable.GetNodeID()
|
||||
return nil
|
||||
}
|
330
internal/proxy/task_hybrid_search_test.go
Normal file
330
internal/proxy/task_hybrid_search_test.go
Normal file
@ -0,0 +1,330 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
func createCollWithMultiVecField(t *testing.T, name string, rc types.RootCoordClient) {
|
||||
schema := genCollectionSchema(name)
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
require.NoError(t, err)
|
||||
ctx := context.TODO()
|
||||
|
||||
createColT := &createCollectionTask{
|
||||
Condition: NewTaskCondition(context.TODO()),
|
||||
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
|
||||
CollectionName: name,
|
||||
Schema: marshaledSchema,
|
||||
ShardsNum: common.DefaultShardsNum,
|
||||
},
|
||||
ctx: ctx,
|
||||
rootCoord: rc,
|
||||
}
|
||||
|
||||
require.NoError(t, createColT.OnEnqueue())
|
||||
require.NoError(t, createColT.PreExecute(ctx))
|
||||
require.NoError(t, createColT.Execute(ctx))
|
||||
require.NoError(t, createColT.PostExecute(ctx))
|
||||
}
|
||||
|
||||
func TestHybridSearchTask_PreExecute(t *testing.T) {
|
||||
var err error
|
||||
|
||||
var (
|
||||
rc = NewRootCoordMock()
|
||||
qc = mocks.NewMockQueryCoordClient(t)
|
||||
ctx = context.TODO()
|
||||
)
|
||||
|
||||
defer rc.Close()
|
||||
require.NoError(t, err)
|
||||
mgr := newShardClientMgr()
|
||||
err = InitMetaCache(ctx, rc, qc, mgr)
|
||||
require.NoError(t, err)
|
||||
|
||||
genHybridSearchTaskWithNq := func(t *testing.T, collName string, reqs []*milvuspb.SearchRequest) *hybridSearchTask {
|
||||
task := &hybridSearchTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
request: &milvuspb.HybridSearchRequest{
|
||||
CollectionName: collName,
|
||||
Requests: reqs,
|
||||
},
|
||||
qc: qc,
|
||||
tr: timerecord.NewTimeRecorder("test-hybrid-search"),
|
||||
}
|
||||
require.NoError(t, task.OnEnqueue())
|
||||
return task
|
||||
}
|
||||
|
||||
t.Run("bad nq 0", func(t *testing.T) {
|
||||
collName := "test_bad_nq0_error" + funcutil.GenRandomStr()
|
||||
createCollWithMultiVecField(t, collName, rc)
|
||||
// Nq must be 1.
|
||||
task := genHybridSearchTaskWithNq(t, collName, []*milvuspb.SearchRequest{{Nq: 0}})
|
||||
err = task.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("bad req num 0", func(t *testing.T) {
|
||||
collName := "test_bad_req_num0_error" + funcutil.GenRandomStr()
|
||||
createCollWithMultiVecField(t, collName, rc)
|
||||
// num of reqs must be [1, 1024].
|
||||
task := genHybridSearchTaskWithNq(t, collName, nil)
|
||||
err = task.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("bad req num 1025", func(t *testing.T) {
|
||||
collName := "test_bad_req_num1025_error" + funcutil.GenRandomStr()
|
||||
createCollWithMultiVecField(t, collName, rc)
|
||||
// num of reqs must be [1, 1024].
|
||||
reqs := make([]*milvuspb.SearchRequest, 0)
|
||||
for i := 0; i <= defaultMaxSearchRequest; i++ {
|
||||
reqs = append(reqs, &milvuspb.SearchRequest{
|
||||
CollectionName: collName,
|
||||
Nq: 1,
|
||||
})
|
||||
}
|
||||
task := genHybridSearchTaskWithNq(t, collName, reqs)
|
||||
err = task.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("collection not exist", func(t *testing.T) {
|
||||
collName := "test_collection_not_exist" + funcutil.GenRandomStr()
|
||||
task := genHybridSearchTaskWithNq(t, collName, []*milvuspb.SearchRequest{{Nq: 1}})
|
||||
err = task.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("hybrid search with timeout", func(t *testing.T) {
|
||||
collName := "hybrid_search_with_timeout" + funcutil.GenRandomStr()
|
||||
createCollWithMultiVecField(t, collName, rc)
|
||||
|
||||
task := genHybridSearchTaskWithNq(t, collName, []*milvuspb.SearchRequest{{Nq: 1}})
|
||||
|
||||
ctxTimeout, cancel := context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
|
||||
task.ctx = ctxTimeout
|
||||
task.request.OutputFields = []string{testFloatVecField}
|
||||
assert.NoError(t, task.PreExecute(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func TestHybridSearchTask_ErrExecute(t *testing.T) {
|
||||
var (
|
||||
err error
|
||||
ctx = context.TODO()
|
||||
|
||||
rc = NewRootCoordMock()
|
||||
qc = getQueryCoordClient()
|
||||
qn = getQueryNodeClient()
|
||||
|
||||
collectionName = t.Name() + funcutil.GenRandomStr()
|
||||
)
|
||||
|
||||
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
||||
|
||||
mgr := NewMockShardClientManager(t)
|
||||
mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
|
||||
mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe()
|
||||
lb := NewLBPolicyImpl(mgr)
|
||||
|
||||
factory := dependency.NewDefaultFactory(true)
|
||||
node, err := NewProxy(ctx, factory)
|
||||
assert.NoError(t, err)
|
||||
node.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
node.tsoAllocator = ×tampAllocator{
|
||||
tso: newMockTimestampAllocatorInterface(),
|
||||
}
|
||||
scheduler, err := newTaskScheduler(ctx, node.tsoAllocator, factory)
|
||||
assert.NoError(t, err)
|
||||
node.sched = scheduler
|
||||
err = node.sched.Start()
|
||||
assert.NoError(t, err)
|
||||
err = node.initRateCollector()
|
||||
assert.NoError(t, err)
|
||||
node.rootCoord = rc
|
||||
node.queryCoord = qc
|
||||
|
||||
defer qc.Close()
|
||||
|
||||
err = InitMetaCache(ctx, rc, qc, mgr)
|
||||
assert.NoError(t, err)
|
||||
|
||||
createCollWithMultiVecField(t, collectionName, rc)
|
||||
|
||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
|
||||
assert.NoError(t, err)
|
||||
|
||||
schema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
|
||||
assert.NoError(t, err)
|
||||
|
||||
successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
|
||||
qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(successStatus, nil)
|
||||
qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
|
||||
Status: successStatus,
|
||||
Shards: []*querypb.ShardLeadersList{
|
||||
{
|
||||
ChannelName: "channel-1",
|
||||
NodeIds: []int64{1},
|
||||
NodeAddrs: []string{"localhost:9000"},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
|
||||
Status: successStatus,
|
||||
CollectionIDs: []int64{collectionID},
|
||||
InMemoryPercentages: []int64{100},
|
||||
}, nil)
|
||||
status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_LoadCollection,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
CollectionID: collectionID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
|
||||
|
||||
vectorFields := typeutil.GetVectorFieldSchemas(schema.CollectionSchema)
|
||||
vectorFieldNames := make([]string, len(vectorFields))
|
||||
for i, field := range vectorFields {
|
||||
vectorFieldNames[i] = field.GetName()
|
||||
}
|
||||
|
||||
// test begins
|
||||
task := &hybridSearchTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
ctx: ctx,
|
||||
result: &milvuspb.SearchResults{
|
||||
Status: merr.Success(),
|
||||
},
|
||||
request: &milvuspb.HybridSearchRequest{
|
||||
CollectionName: collectionName,
|
||||
Requests: []*milvuspb.SearchRequest{
|
||||
{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
CollectionName: collectionName,
|
||||
Nq: 1,
|
||||
DslType: commonpb.DslType_BoolExprV1,
|
||||
SearchParams: []*commonpb.KeyValuePair{
|
||||
{Key: AnnsFieldKey, Value: testFloatVecField},
|
||||
{Key: TopKKey, Value: "10"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
CollectionName: collectionName,
|
||||
Nq: 1,
|
||||
DslType: commonpb.DslType_BoolExprV1,
|
||||
SearchParams: []*commonpb.KeyValuePair{
|
||||
{Key: AnnsFieldKey, Value: testBinaryVecField},
|
||||
{Key: TopKKey, Value: "10"},
|
||||
},
|
||||
},
|
||||
},
|
||||
OutputFields: vectorFieldNames,
|
||||
},
|
||||
qc: qc,
|
||||
lb: lb,
|
||||
node: node,
|
||||
}
|
||||
|
||||
assert.NoError(t, task.OnEnqueue())
|
||||
task.ctx = ctx
|
||||
assert.NoError(t, task.PreExecute(ctx))
|
||||
|
||||
qn.EXPECT().Search(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))
|
||||
assert.Error(t, task.Execute(ctx))
|
||||
|
||||
qn.ExpectedCalls = nil
|
||||
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
||||
qn.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
},
|
||||
}, nil)
|
||||
assert.Error(t, task.Execute(ctx))
|
||||
}
|
||||
|
||||
func TestHybridSearchTask_PostExecute(t *testing.T) {
|
||||
var (
|
||||
rc = NewRootCoordMock()
|
||||
qc = getQueryCoordClient()
|
||||
qn = getQueryNodeClient()
|
||||
collectionName = t.Name() + funcutil.GenRandomStr()
|
||||
)
|
||||
|
||||
defer rc.Close()
|
||||
mgr := NewMockShardClientManager(t)
|
||||
mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
|
||||
mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe()
|
||||
|
||||
t.Run("Test empty result", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
err := InitMetaCache(ctx, rc, qc, mgr)
|
||||
assert.NoError(t, err)
|
||||
createCollWithMultiVecField(t, collectionName, rc)
|
||||
|
||||
schema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
|
||||
assert.NoError(t, err)
|
||||
|
||||
rankParams := []*commonpb.KeyValuePair{
|
||||
{Key: LimitKey, Value: strconv.Itoa(3)},
|
||||
{Key: OffsetKey, Value: strconv.Itoa(2)},
|
||||
}
|
||||
qt := &hybridSearchTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(context.TODO()),
|
||||
qc: nil,
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
schema: schema,
|
||||
request: &milvuspb.HybridSearchRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
},
|
||||
CollectionName: collectionName,
|
||||
RankParams: rankParams,
|
||||
},
|
||||
multipleRecallResults: typeutil.NewConcurrentSet[*milvuspb.SearchResults](),
|
||||
}
|
||||
|
||||
err = qt.PostExecute(context.TODO())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, qt.result.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
|
||||
})
|
||||
}
|
@ -606,13 +606,6 @@ func (t *searchTask) estimateResultSize(nq int64, topK int64) (int64, error) {
|
||||
}
|
||||
|
||||
func (t *searchTask) Requery() error {
|
||||
pkField, err := t.schema.GetPkField()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ids := t.result.GetResults().GetIds()
|
||||
plan := planparserv2.CreateRequeryPlan(pkField, ids)
|
||||
|
||||
queryReq := &milvuspb.QueryRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Retrieve,
|
||||
@ -625,69 +618,8 @@ func (t *searchTask) Requery() error {
|
||||
GuaranteeTimestamp: t.request.GetGuaranteeTimestamp(),
|
||||
QueryParams: t.request.GetSearchParams(),
|
||||
}
|
||||
qt := &queryTask{
|
||||
ctx: t.ctx,
|
||||
Condition: NewTaskCondition(t.ctx),
|
||||
RetrieveRequest: &internalpb.RetrieveRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
|
||||
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
||||
),
|
||||
ReqID: paramtable.GetNodeID(),
|
||||
},
|
||||
request: queryReq,
|
||||
plan: plan,
|
||||
qc: t.node.(*Proxy).queryCoord,
|
||||
lb: t.node.(*Proxy).lbPolicy,
|
||||
}
|
||||
queryResult, err := t.node.(*Proxy).query(t.ctx, qt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
return merr.Error(queryResult.GetStatus())
|
||||
}
|
||||
// Reorganize Results. The order of query result ids will be altered and differ from queried ids.
|
||||
// We should reorganize query results to keep the order of original queried ids. For example:
|
||||
// ===========================================
|
||||
// 3 2 5 4 1 (query ids)
|
||||
// ||
|
||||
// || (query)
|
||||
// \/
|
||||
// 4 3 5 1 2 (result ids)
|
||||
// v4 v3 v5 v1 v2 (result vectors)
|
||||
// ||
|
||||
// || (reorganize)
|
||||
// \/
|
||||
// 3 2 5 4 1 (result ids)
|
||||
// v3 v2 v5 v4 v1 (result vectors)
|
||||
// ===========================================
|
||||
pkFieldData, err := typeutil.GetPrimaryFieldData(queryResult.GetFieldsData(), pkField)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
offsets := make(map[any]int)
|
||||
for i := 0; i < typeutil.GetPKSize(pkFieldData); i++ {
|
||||
pk := typeutil.GetData(pkFieldData, i)
|
||||
offsets[pk] = i
|
||||
}
|
||||
|
||||
t.result.Results.FieldsData = make([]*schemapb.FieldData, len(queryResult.GetFieldsData()))
|
||||
for i := 0; i < typeutil.GetSizeOfIDs(ids); i++ {
|
||||
id := typeutil.GetPK(ids, int64(i))
|
||||
if _, ok := offsets[id]; !ok {
|
||||
return fmt.Errorf("incomplete query result, missing id %s, len(searchIDs) = %d, len(queryIDs) = %d, collection=%d",
|
||||
id, typeutil.GetSizeOfIDs(ids), len(offsets), t.GetCollectionID())
|
||||
}
|
||||
typeutil.AppendFieldData(t.result.Results.FieldsData, queryResult.GetFieldsData(), int64(offsets[id]))
|
||||
}
|
||||
|
||||
// filter id field out if it is not specified as output
|
||||
t.result.Results.FieldsData = lo.Filter(t.result.Results.FieldsData, func(fieldData *schemapb.FieldData, i int) bool {
|
||||
return lo.Contains(t.request.GetOutputFields(), fieldData.GetFieldName())
|
||||
})
|
||||
|
||||
return nil
|
||||
return doRequery(t.ctx, t.GetCollectionID(), t.node, t.schema.CollectionSchema, queryReq, t.result)
|
||||
}
|
||||
|
||||
func (t *searchTask) fillInEmptyResult(numQueries int64) {
|
||||
@ -734,6 +666,86 @@ func (t *searchTask) collectSearchResults(ctx context.Context) ([]*internalpb.Se
|
||||
}
|
||||
}
|
||||
|
||||
func doRequery(ctx context.Context,
|
||||
collectionID int64,
|
||||
node types.ProxyComponent,
|
||||
schema *schemapb.CollectionSchema,
|
||||
request *milvuspb.QueryRequest,
|
||||
result *milvuspb.SearchResults,
|
||||
) error {
|
||||
outputFields := request.GetOutputFields()
|
||||
pkField, err := typeutil.GetPrimaryFieldSchema(schema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ids := result.GetResults().GetIds()
|
||||
plan := planparserv2.CreateRequeryPlan(pkField, ids)
|
||||
|
||||
qt := &queryTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
RetrieveRequest: &internalpb.RetrieveRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
|
||||
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
||||
),
|
||||
ReqID: paramtable.GetNodeID(),
|
||||
},
|
||||
request: request,
|
||||
plan: plan,
|
||||
qc: node.(*Proxy).queryCoord,
|
||||
lb: node.(*Proxy).lbPolicy,
|
||||
}
|
||||
queryResult, err := node.(*Proxy).query(ctx, qt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
return merr.Error(queryResult.GetStatus())
|
||||
}
|
||||
// Reorganize Results. The order of query result ids will be altered and differ from queried ids.
|
||||
// We should reorganize query results to keep the order of original queried ids. For example:
|
||||
// ===========================================
|
||||
// 3 2 5 4 1 (query ids)
|
||||
// ||
|
||||
// || (query)
|
||||
// \/
|
||||
// 4 3 5 1 2 (result ids)
|
||||
// v4 v3 v5 v1 v2 (result vectors)
|
||||
// ||
|
||||
// || (reorganize)
|
||||
// \/
|
||||
// 3 2 5 4 1 (result ids)
|
||||
// v3 v2 v5 v4 v1 (result vectors)
|
||||
// ===========================================
|
||||
pkFieldData, err := typeutil.GetPrimaryFieldData(queryResult.GetFieldsData(), pkField)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
offsets := make(map[any]int)
|
||||
for i := 0; i < typeutil.GetPKSize(pkFieldData); i++ {
|
||||
pk := typeutil.GetData(pkFieldData, i)
|
||||
offsets[pk] = i
|
||||
}
|
||||
|
||||
result.Results.FieldsData = make([]*schemapb.FieldData, len(queryResult.GetFieldsData()))
|
||||
for i := 0; i < typeutil.GetSizeOfIDs(ids); i++ {
|
||||
id := typeutil.GetPK(ids, int64(i))
|
||||
if _, ok := offsets[id]; !ok {
|
||||
return fmt.Errorf("incomplete query result, missing id %s, len(searchIDs) = %d, len(queryIDs) = %d, collection=%d",
|
||||
id, typeutil.GetSizeOfIDs(ids), len(offsets), collectionID)
|
||||
}
|
||||
typeutil.AppendFieldData(result.Results.FieldsData, queryResult.GetFieldsData(), int64(offsets[id]))
|
||||
}
|
||||
|
||||
// filter id field out if it is not specified as output
|
||||
result.Results.FieldsData = lo.Filter(result.Results.FieldsData, func(fieldData *schemapb.FieldData, i int) bool {
|
||||
return lo.Contains(outputFields, fieldData.GetFieldName())
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeSearchResults(ctx context.Context, searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) {
|
||||
tr := timerecord.NewTimeRecorder("decodeSearchResults")
|
||||
results := make([]*schemapb.SearchResultData, 0)
|
||||
|
@@ -71,6 +71,20 @@ const (
    testMaxVarCharLength = 100
)

func genCollectionSchema(collectionName string) *schemapb.CollectionSchema {
    return constructCollectionSchemaWithAllType(
        testBoolField,
        testInt32Field,
        testInt64Field,
        testFloatField,
        testDoubleField,
        testFloatVecField,
        testBinaryVecField,
        testFloat16VecField,
        testVecDim,
        collectionName)
}

func constructCollectionSchema(
    int64Field, floatVecField string,
    dim int,
@@ -62,6 +62,8 @@ const (

    defaultMaxArrayCapacity = 4096

    defaultMaxSearchRequest = 1024

    // DefaultArithmeticIndexType name of default index type for scalar field
    DefaultArithmeticIndexType = "STL_SORT"
@@ -69,6 +71,8 @@ const (
    DefaultStringIndexType = "Trie"

    InvertedIndexType = "INVERTED"

    defaultRRFParamsValue = 60
)

var logger = log.L().WithOptions(zap.Fields(zap.String("role", typeutil.ProxyRole)))
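defaultMaxSearchRequest and the nq check in hybridSearchTask.PreExecute bound what a hybrid search may carry: between 1 and 1024 ANN sub-requests, each with exactly one query vector. A small sketch of that validation, using hypothetical local names (only the 1 and 1024 bounds and the nq==1 rule come from this patch):

package main

import (
    "errors"
    "fmt"
)

const defaultMaxSearchRequest = 1024 // mirrors internal/proxy/util.go

// subRequest is a stand-in for milvuspb.SearchRequest in this sketch.
type subRequest struct {
    nq int64
}

func validateHybridSearch(reqs []subRequest) error {
    if len(reqs) < 1 {
        return errors.New("minimum of ann search requests is 1")
    }
    if len(reqs) > defaultMaxSearchRequest {
        return errors.New("maximum of ann search requests is 1024")
    }
    for _, r := range reqs {
        if r.nq != 1 {
            return fmt.Errorf("nq should be equal to 1, got %d", r.nq)
        }
    }
    return nil
}

func main() {
    fmt.Println(validateHybridSearch([]subRequest{{nq: 1}, {nq: 1}})) // <nil>
    fmt.Println(validateHybridSearch([]subRequest{{nq: 5}}))          // nq should be equal to 1, got 5
}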
@@ -175,7 +175,6 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error {

    loadMeta := packLoadMeta(
        ex.meta.GetLoadType(task.CollectionID()),
-        "",
        task.CollectionID(),
        partitions...,
    )
@@ -370,14 +369,8 @@ func (ex *Executor) subscribeChannel(task *ChannelTask, step int) error {
        log.Warn("fail to get index meta of collection")
        return err
    }
-    metricType, err := getMetricType(indexInfo, collectionInfo.GetSchema())
-    if err != nil {
-        log.Warn("failed to get metric type", zap.Error(err))
-        return err
-    }
    loadMeta := packLoadMeta(
        ex.meta.GetLoadType(task.CollectionID()),
-        metricType,
        task.CollectionID(),
        partitions...,
    )
@@ -21,8 +21,6 @@ import (
    "fmt"
    "time"

    "github.com/samber/lo"

    "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
    "github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
    "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
@@ -33,7 +31,6 @@ import (
    "github.com/milvus-io/milvus/internal/querycoordv2/utils"
    "github.com/milvus-io/milvus/pkg/common"
    "github.com/milvus-io/milvus/pkg/util/commonpbutil"
    "github.com/milvus-io/milvus/pkg/util/funcutil"
    "github.com/milvus-io/milvus/pkg/util/typeutil"
)

@@ -162,12 +159,11 @@ func packReleaseSegmentRequest(task *SegmentTask, action *SegmentAction) *queryp
    }
}

-func packLoadMeta(loadType querypb.LoadType, metricType string, collectionID int64, partitions ...int64) *querypb.LoadMetaInfo {
+func packLoadMeta(loadType querypb.LoadType, collectionID int64, partitions ...int64) *querypb.LoadMetaInfo {
    return &querypb.LoadMetaInfo{
        LoadType:     loadType,
        CollectionID: collectionID,
        PartitionIDs: partitions,
-        MetricType:   metricType,
    }
}

@@ -241,22 +237,3 @@ func getShardLeader(replicaMgr *meta.ReplicaManager, distMgr *meta.DistributionM
    }
    return distMgr.GetShardLeader(replica, channel)
}

-func getMetricType(indexInfos []*indexpb.IndexInfo, schema *schemapb.CollectionSchema) (string, error) {
-    vecField, err := typeutil.GetVectorFieldSchema(schema)
-    if err != nil {
-        return "", err
-    }
-    indexInfo, ok := lo.Find(indexInfos, func(info *indexpb.IndexInfo) bool {
-        return info.GetFieldID() == vecField.GetFieldID()
-    })
-    if !ok || indexInfo == nil {
-        err = fmt.Errorf("cannot find index info for %s field", vecField.GetName())
-        return "", err
-    }
-    metricType, err := funcutil.GetAttrByKeyFromRepeatedKV(common.MetricTypeKey, indexInfo.GetIndexParams())
-    if err != nil {
-        return "", err
-    }
-    return metricType, nil
-}
@ -26,7 +26,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/pkg/common"
)
@ -35,57 +34,6 @@ type UtilsSuite struct {
suite.Suite
}

func (s *UtilsSuite) TestGetMetricType() {
collection := int64(1)
schema := &schemapb.CollectionSchema{
Name: "TestGetMetricType",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "vec", DataType: schemapb.DataType_FloatVector},
},
}
indexInfo := &indexpb.IndexInfo{
CollectionID: collection,
FieldID: 100,
IndexParams: []*commonpb.KeyValuePair{
{
Key: common.MetricTypeKey,
Value: "L2",
},
},
}

indexInfo2 := &indexpb.IndexInfo{
CollectionID: collection,
FieldID: 100,
}

s.Run("test normal", func() {
metricType, err := getMetricType([]*indexpb.IndexInfo{indexInfo}, schema)
s.NoError(err)
s.Equal("L2", metricType)
})

s.Run("test get vec field failed", func() {
_, err := getMetricType([]*indexpb.IndexInfo{indexInfo}, &schemapb.CollectionSchema{
Name: "TestGetMetricType",
})
s.Error(err)
})
s.Run("test field id mismatch", func() {
_, err := getMetricType([]*indexpb.IndexInfo{indexInfo}, &schemapb.CollectionSchema{
Name: "TestGetMetricType",
Fields: []*schemapb.FieldSchema{
{FieldID: -1, Name: "vec", DataType: schemapb.DataType_FloatVector},
},
})
s.Error(err)
})
s.Run("test no metric type", func() {
_, err := getMetricType([]*indexpb.IndexInfo{indexInfo2}, schema)
s.Error(err)
})
}

func (s *UtilsSuite) TestPackLoadSegmentRequest() {
ctx := context.Background()

@ -17,12 +17,10 @@
package querynodev2

import (
"fmt"
"math"
"math/rand"
"strconv"

"github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@ -60,45 +58,32 @@ const (

// ---------- unittest util functions ----------
// functions of messages and requests
func genBruteForceDSL(schema *schemapb.CollectionSchema, topK int64, roundDecimal int64) (string, error) {
var vecFieldName string
var metricType string
topKStr := strconv.FormatInt(topK, 10)
nProbStr := strconv.Itoa(defaultNProb)
roundDecimalStr := strconv.FormatInt(roundDecimal, 10)
var fieldID int64
for _, f := range schema.Fields {
if f.DataType == schemapb.DataType_FloatVector {
vecFieldName = f.Name
fieldID = f.FieldID
for _, p := range f.IndexParams {
if p.Key == metricTypeKey {
metricType = p.Value
func genSearchPlan(dataType schemapb.DataType, fieldID int64, metricType string) *planpb.PlanNode {
var vectorType planpb.VectorType
switch dataType {
case schemapb.DataType_FloatVector:
vectorType = planpb.VectorType_FloatVector
case schemapb.DataType_Float16Vector:
vectorType = planpb.VectorType_Float16Vector
case schemapb.DataType_BinaryVector:
vectorType = planpb.VectorType_BinaryVector
}
}
}
}
if vecFieldName == "" || metricType == "" {
err := errors.New("invalid vector field name or metric type")
return "", err
}
return `vector_anns: <
field_id: ` + fmt.Sprintf("%d", fieldID) + `
query_info: <
topk: ` + topKStr + `
round_decimal: ` + roundDecimalStr + `
metric_type: "` + metricType + `"
search_params: "{\"nprobe\": ` + nProbStr + `}"
>
placeholder_tag: "$0"
>`, nil
}

func genDSLByIndexType(schema *schemapb.CollectionSchema, indexType string) (string, error) {
if indexType == IndexFaissIDMap { // float vector
return genBruteForceDSL(schema, defaultTopK, defaultRoundDecimal)
return &planpb.PlanNode{
Node: &planpb.PlanNode_VectorAnns{
VectorAnns: &planpb.VectorANNS{
VectorType: vectorType,
FieldId: fieldID,
QueryInfo: &planpb.QueryInfo{
Topk: defaultTopK,
MetricType: metricType,
SearchParams: "{\"nprobe\":" + strconv.Itoa(defaultNProb) + "}",
RoundDecimal: defaultRoundDecimal,
},
PlaceholderTag: "$0",
},
},
}
return "", fmt.Errorf("Invalid indexType")
}

func genPlaceHolderGroup(nq int64) ([]byte, error) {

@ -83,7 +83,6 @@ func (m *collectionManager) PutOrRef(collectionID int64, schema *schemapb.Collec
}

collection := NewCollection(collectionID, schema, meta, loadMeta.GetLoadType())
collection.metricType.Store(loadMeta.GetMetricType())
collection.AddPartition(loadMeta.GetPartitionIDs()...)
collection.Ref(1)
m.collections[collectionID] = collection
@ -125,7 +124,7 @@ type Collection struct {
id int64
partitions *typeutil.ConcurrentSet[int64]
loadType querypb.LoadType
metricType atomic.String
metricType atomic.String // deprecated
schema atomic.Pointer[schemapb.CollectionSchema]
isGpuIndex bool

@ -175,14 +174,6 @@ func (c *Collection) GetLoadType() querypb.LoadType {
return c.loadType
}

func (c *Collection) SetMetricType(metricType string) {
c.metricType.Store(metricType)
}

func (c *Collection) GetMetricType() string {
return c.metricType.Load()
}

func (c *Collection) Ref(count uint32) uint32 {
refCount := c.refCount.Add(count)
log.Debug("collection ref increment",

@ -291,7 +291,56 @@ func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType) *s
return &schema
}

func GenTestIndexInfoList(collectionID int64, schema *schemapb.CollectionSchema) []*indexpb.IndexInfo {
res := make([]*indexpb.IndexInfo, 0)
vectorFieldSchemas := typeutil.GetVectorFieldSchemas(schema)
for _, field := range vectorFieldSchemas {
index := &indexpb.IndexInfo{
CollectionID: collectionID,
FieldID: field.GetFieldID(),
// For now, a field can only have one index
// using fieldID and fieldName as indexID and indexName, just make sure not repeated.
IndexID: field.GetFieldID(),
IndexName: field.GetName(),
TypeParams: field.GetTypeParams(),
}
switch field.GetDataType() {
case schemapb.DataType_FloatVector, schemapb.DataType_Float16Vector:
{
index.IndexParams = []*commonpb.KeyValuePair{
{Key: common.MetricTypeKey, Value: metric.L2},
{Key: common.IndexTypeKey, Value: IndexFaissIVFFlat},
{Key: "nlist", Value: "128"},
}
}
case schemapb.DataType_BinaryVector:
{
index.IndexParams = []*commonpb.KeyValuePair{
{Key: common.MetricTypeKey, Value: metric.JACCARD},
{Key: common.IndexTypeKey, Value: IndexFaissBinIVFFlat},
{Key: "nlist", Value: "128"},
}
}
}
res = append(res, index)
}
return res
}

func GenTestIndexMeta(collectionID int64, schema *schemapb.CollectionSchema) *segcorepb.CollectionIndexMeta {
indexInfos := GenTestIndexInfoList(collectionID, schema)
fieldIndexMetas := make([]*segcorepb.FieldIndexMeta, 0)
for _, info := range indexInfos {
fieldIndexMetas = append(fieldIndexMetas, &segcorepb.FieldIndexMeta{
CollectionID: info.GetCollectionID(),
FieldID: info.GetFieldID(),
IndexName: info.GetIndexName(),
TypeParams: info.GetTypeParams(),
IndexParams: info.GetIndexParams(),
IsAutoIndex: info.GetIsAutoIndex(),
UserIndexParams: info.GetUserIndexParams(),
})
}
sizePerRecord, err := typeutil.EstimateSizePerRecord(schema)
maxIndexRecordPerSegment := int64(0)
if err != nil || sizePerRecord == 0 {
@ -302,37 +351,6 @@ func GenTestIndexMeta(collectionID int64, schema *schemapb.CollectionSchema) *se
maxIndexRecordPerSegment = int64(threshold * proportion / float64(sizePerRecord))
}

fieldIndexMetas := make([]*segcorepb.FieldIndexMeta, 0)
fieldIndexMetas = append(fieldIndexMetas, &segcorepb.FieldIndexMeta{
CollectionID: collectionID,
FieldID: simpleFloatVecField.id,
IndexName: "querynode-test",
TypeParams: []*commonpb.KeyValuePair{
{
Key: dimKey,
Value: strconv.Itoa(simpleFloatVecField.dim),
},
},
IndexParams: []*commonpb.KeyValuePair{
{
Key: metricTypeKey,
Value: simpleFloatVecField.metricType,
},
{
Key: common.IndexTypeKey,
Value: IndexFaissIVFFlat,
},
{
Key: "nlist",
Value: "128",
},
},
IsAutoIndex: false,
UserIndexParams: []*commonpb.KeyValuePair{
{},
},
})

indexMeta := segcorepb.CollectionIndexMeta{
MaxIndexRowCount: maxIndexRecordPerSegment,
IndexMetas: fieldIndexMetas,
@ -889,6 +907,80 @@ func SaveDeltaLog(collectionID int64,
return fieldBinlog, cm.MultiWrite(context.Background(), kvs)
}

func GenAndSaveIndexV2(collectionID, partitionID, segmentID, buildID int64,
fieldSchema *schemapb.FieldSchema,
indexInfo *indexpb.IndexInfo,
cm storage.ChunkManager,
msgLength int,
) (*querypb.FieldIndexInfo, error) {
typeParams := funcutil.KeyValuePair2Map(indexInfo.GetTypeParams())
indexParams := funcutil.KeyValuePair2Map(indexInfo.GetIndexParams())

index, err := indexcgowrapper.NewCgoIndex(fieldSchema.GetDataType(), typeParams, indexParams)
if err != nil {
return nil, err
}
defer index.Delete()

var dataset *indexcgowrapper.Dataset
switch fieldSchema.DataType {
case schemapb.DataType_BinaryVector:
dataset = indexcgowrapper.GenBinaryVecDataset(generateBinaryVectors(msgLength, defaultDim))
case schemapb.DataType_FloatVector:
dataset = indexcgowrapper.GenFloatVecDataset(generateFloatVectors(msgLength, defaultDim))
}

err = index.Build(dataset)
if err != nil {
return nil, err
}

// save index to minio
binarySet, err := index.Serialize()
if err != nil {
return nil, err
}

// serialize index params
indexCodec := storage.NewIndexFileBinlogCodec()
serializedIndexBlobs, err := indexCodec.Serialize(
buildID,
0,
collectionID,
partitionID,
segmentID,
fieldSchema.GetFieldID(),
indexParams,
indexInfo.GetIndexName(),
indexInfo.GetIndexID(),
binarySet,
)
if err != nil {
return nil, err
}

indexPaths := make([]string, 0)
for _, index := range serializedIndexBlobs {
indexPath := filepath.Join(cm.RootPath(), "index_files",
strconv.Itoa(int(segmentID)), index.Key)
indexPaths = append(indexPaths, indexPath)
err := cm.Write(context.Background(), indexPath, index.Value)
if err != nil {
return nil, err
}
}
_, cCurrentIndexVersion := getIndexEngineVersion()

return &querypb.FieldIndexInfo{
FieldID: fieldSchema.GetFieldID(),
EnableIndex: true,
IndexName: indexInfo.GetIndexName(),
IndexParams: indexInfo.GetIndexParams(),
IndexFilePaths: indexPaths,
CurrentIndexVersion: cCurrentIndexVersion,
}, nil
}

func GenAndSaveIndex(collectionID, partitionID, segmentID, fieldID int64, msgLength int, indexType, metricType string, cm storage.ChunkManager) (*querypb.FieldIndexInfo, error) {
typeParams, indexParams := genIndexParams(indexType, metricType)

@ -54,13 +54,7 @@ func createSearchPlanByExpr(ctx context.Context, col *Collection, expr []byte, m
return nil, err1
}

newPlan := &SearchPlan{cSearchPlan: cPlan}
if len(metricType) != 0 {
newPlan.setMetricType(metricType)
} else {
newPlan.setMetricType(col.GetMetricType())
}
return newPlan, nil
return &SearchPlan{cSearchPlan: cPlan}, nil
}

func (plan *SearchPlan) getTopK() int64 {

@ -205,7 +205,6 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm

log.Info("received watch channel request",
zap.Int64("version", req.GetVersion()),
zap.String("metricType", req.GetLoadMeta().GetMetricType()),
)

// check node healthy
@ -219,12 +218,6 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm
return merr.Status(err), nil
}

// check metric type
if req.GetLoadMeta().GetMetricType() == "" {
err := fmt.Errorf("empty metric type, collection = %d", req.GetCollectionID())
return merr.Status(err), nil
}

// check index
if len(req.GetIndexInfoList()) == 0 {
err := merr.WrapErrIndexNotFoundForCollection(req.GetSchema().GetName())
@ -253,8 +246,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm

node.manager.Collection.PutOrRef(req.GetCollectionID(), req.GetSchema(),
node.composeIndexMeta(req.GetIndexInfoList(), req.Schema), req.GetLoadMeta())
collection := node.manager.Collection.Get(req.GetCollectionID())
collection.SetMetricType(req.GetLoadMeta().GetMetricType())

delegator, err := delegator.NewShardDelegator(
ctx,
req.GetCollectionID(),
@ -769,20 +761,6 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
return resp, nil
}

// Check if the metric type specified in search params matches the metric type in the index info.
if !req.GetFromShardLeader() && req.GetReq().GetMetricType() != "" {
if req.GetReq().GetMetricType() != collection.GetMetricType() {
resp.Status = merr.Status(merr.WrapErrParameterInvalid(collection.GetMetricType(), req.GetReq().GetMetricType(),
fmt.Sprintf("collection:%d, metric type not match", collection.ID())))
return resp, nil
}
}

// Define the metric type when it has not been explicitly assigned by the user.
if !req.GetFromShardLeader() && req.GetReq().GetMetricType() == "" {
req.Req.MetricType = collection.GetMetricType()
}

toReduceResults := make([]*internalpb.SearchResults, len(req.GetDmlChannels()))
runningGp, runningCtx := errgroup.WithContext(ctx)
for i, ch := range req.GetDmlChannels() {

@ -39,7 +39,6 @@ import (
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
@ -52,7 +51,6 @@ import (
"github.com/milvus-io/milvus/pkg/util/etcd"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
@ -257,6 +255,7 @@ func (suite *ServiceSuite) TestWatchDmChannelsInt64() {
ctx := context.Background()

// data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
deltaLogs, err := segments.SaveDeltaLog(suite.collectionID,
suite.partitionIDs[0],
suite.flushedSegmentIDs[0],
@ -292,16 +291,14 @@ func (suite *ServiceSuite) TestWatchDmChannelsInt64() {
Level: datapb.SegmentLevel_L0,
},
},
Schema: segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64),
Schema: schema,
LoadMeta: &querypb.LoadMetaInfo{
LoadType: querypb.LoadType_LoadCollection,
CollectionID: suite.collectionID,
PartitionIDs: suite.partitionIDs,
MetricType: defaultMetricType,
},
IndexInfoList: []*indexpb.IndexInfo{
{},
},
IndexInfoList: segments.GenTestIndexInfoList(suite.collectionID, schema),
}

// mocks
@ -326,6 +323,7 @@ func (suite *ServiceSuite) TestWatchDmChannelsVarchar() {
ctx := context.Background()

// data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_VarChar)
req := &querypb.WatchDmChannelsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchDmChannels,
@ -344,16 +342,14 @@ func (suite *ServiceSuite) TestWatchDmChannelsVarchar() {
DroppedSegmentIds: suite.droppedSegmentIDs,
},
},
Schema: segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_VarChar),
Schema: schema,
LoadMeta: &querypb.LoadMetaInfo{
LoadType: querypb.LoadType_LoadCollection,
CollectionID: suite.collectionID,
PartitionIDs: suite.partitionIDs,
MetricType: defaultMetricType,
},
IndexInfoList: []*indexpb.IndexInfo{
{},
},
IndexInfoList: segments.GenTestIndexInfoList(suite.collectionID, schema),
}

// mocks
@ -378,6 +374,7 @@ func (suite *ServiceSuite) TestWatchDmChannels_Failed() {
ctx := context.Background()

// data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
req := &querypb.WatchDmChannelsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchDmChannels,
@ -396,13 +393,11 @@ func (suite *ServiceSuite) TestWatchDmChannels_Failed() {
DroppedSegmentIds: suite.droppedSegmentIDs,
},
},
Schema: segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64),
Schema: schema,
LoadMeta: &querypb.LoadMetaInfo{
MetricType: defaultMetricType,
},
IndexInfoList: []*indexpb.IndexInfo{
{},
},
IndexInfoList: segments.GenTestIndexInfoList(suite.collectionID, schema),
}

// test channel is unsubscribing
@ -439,14 +434,6 @@ func (suite *ServiceSuite) TestWatchDmChannels_Failed() {
status, err = suite.node.WatchDmChannels(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_NotReadyServe, status.GetErrorCode())

// empty metric type
req.LoadMeta.MetricType = ""
req.Base.TargetID = paramtable.GetNodeID()
suite.node.UpdateStateCode(commonpb.StateCode_Healthy)
status, err = suite.node.WatchDmChannels(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
}

func (suite *ServiceSuite) TestUnsubDmChannels_Normal() {
@ -502,22 +489,9 @@ func (suite *ServiceSuite) TestUnsubDmChannels_Failed() {
suite.Equal(commonpb.ErrorCode_NotReadyServe, status.GetErrorCode())
}

func (suite *ServiceSuite) genSegmentIndexInfos(loadInfo []*querypb.SegmentLoadInfo) []*indexpb.IndexInfo {
indexInfoList := make([]*indexpb.IndexInfo, 0)
seg0LoadInfo := loadInfo[0]
fieldIndexInfos := seg0LoadInfo.IndexInfos
for _, info := range fieldIndexInfos {
indexInfoList = append(indexInfoList, &indexpb.IndexInfo{
CollectionID: suite.collectionID,
FieldID: info.GetFieldID(),
IndexName: info.GetIndexName(),
IndexParams: info.GetIndexParams(),
})
}
return indexInfoList
}

func (suite *ServiceSuite) genSegmentLoadInfos(schema *schemapb.CollectionSchema) []*querypb.SegmentLoadInfo {
func (suite *ServiceSuite) genSegmentLoadInfos(schema *schemapb.CollectionSchema,
indexInfos []*indexpb.IndexInfo,
) []*querypb.SegmentLoadInfo {
ctx := context.Background()

segNum := len(suite.validSegmentIDs)
@ -534,18 +508,25 @@ func (suite *ServiceSuite) genSegmentLoadInfos(schema *schemapb.CollectionSchema
)
suite.Require().NoError(err)

vecFieldIDs := funcutil.GetVecFieldIDs(schema)
indexes, err := segments.GenAndSaveIndex(
vectorFieldSchemas := typeutil.GetVectorFieldSchemas(schema)
indexes := make([]*querypb.FieldIndexInfo, 0)
for offset, field := range vectorFieldSchemas {
indexInfo := lo.FindOrElse(indexInfos, nil, func(info *indexpb.IndexInfo) bool { return info.FieldID == field.GetFieldID() })
if indexInfo != nil {
index, err := segments.GenAndSaveIndexV2(
suite.collectionID,
suite.partitionIDs[i%partNum],
suite.validSegmentIDs[i],
vecFieldIDs[0],
1000,
segments.IndexFaissIVFFlat,
metric.L2,
int64(offset),
field,
indexInfo,
suite.node.chunkManager,
1000,
)
suite.Require().NoError(err)
indexes = append(indexes, index)
}
}

info := &querypb.SegmentLoadInfo{
SegmentID: suite.validSegmentIDs[i],
@ -555,7 +536,7 @@ func (suite *ServiceSuite) genSegmentLoadInfos(schema *schemapb.CollectionSchema
NumOfRows: 1000,
BinlogPaths: binlogs,
Statslogs: statslogs,
IndexInfos: []*querypb.FieldIndexInfo{indexes},
IndexInfos: indexes,
StartPosition: &msgpb.MsgPosition{Timestamp: 20000},
DeltaPosition: &msgpb.MsgPosition{Timestamp: 20000},
}
@ -569,7 +550,8 @@ func (suite *ServiceSuite) TestLoadSegments_Int64() {
suite.TestWatchDmChannelsInt64()
// data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
infos := suite.genSegmentLoadInfos(schema)
indexInfos := segments.GenTestIndexInfoList(suite.collectionID, schema)
infos := suite.genSegmentLoadInfos(schema, indexInfos)
for _, info := range infos {
req := &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{
@ -582,9 +564,7 @@ func (suite *ServiceSuite) TestLoadSegments_Int64() {
Schema: schema,
DeltaPositions: []*msgpb.MsgPosition{{Timestamp: 20000}},
NeedTransfer: true,
IndexInfoList: []*indexpb.IndexInfo{
{},
},
IndexInfoList: indexInfos,
}

// LoadSegment
@ -607,7 +587,7 @@ func (suite *ServiceSuite) TestLoadSegments_VarChar() {
suite.node.manager.Collection = segments.NewCollectionManager()
suite.node.manager.Collection.PutOrRef(suite.collectionID, schema, nil, loadMeta)

infos := suite.genSegmentLoadInfos(schema)
infos := suite.genSegmentLoadInfos(schema, nil)
for _, info := range infos {
req := &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{
@ -643,7 +623,7 @@ func (suite *ServiceSuite) TestLoadDeltaInt64() {
},
CollectionID: suite.collectionID,
DstNodeID: suite.node.session.ServerID,
Infos: suite.genSegmentLoadInfos(schema),
Infos: suite.genSegmentLoadInfos(schema, nil),
Schema: schema,
NeedTransfer: true,
LoadScope: querypb.LoadScope_Delta,
@ -668,7 +648,7 @@ func (suite *ServiceSuite) TestLoadDeltaVarchar() {
},
CollectionID: suite.collectionID,
DstNodeID: suite.node.session.ServerID,
Infos: suite.genSegmentLoadInfos(schema),
Infos: suite.genSegmentLoadInfos(schema, nil),
Schema: schema,
NeedTransfer: true,
LoadScope: querypb.LoadScope_Delta,
@ -687,7 +667,8 @@ func (suite *ServiceSuite) TestLoadIndex_Success() {

schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)

infos := suite.genSegmentLoadInfos(schema)
indexInfos := segments.GenTestIndexInfoList(suite.collectionID, schema)
infos := suite.genSegmentLoadInfos(schema, indexInfos)
infos = lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *querypb.SegmentLoadInfo {
info.SegmentID = info.SegmentID + 1000
return info
@ -697,8 +678,7 @@ func (suite *ServiceSuite) TestLoadIndex_Success() {
info.IndexInfos = nil
return info
})
// generate indexinfos for setting index meta.
indexInfoList := suite.genSegmentIndexInfos(infos)

req := &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{
MsgID: rand.Int63(),
@ -710,7 +690,7 @@ func (suite *ServiceSuite) TestLoadIndex_Success() {
Schema: schema,
NeedTransfer: false,
LoadScope: querypb.LoadScope_Full,
IndexInfoList: indexInfoList,
IndexInfoList: indexInfos,
}

// Load segment
@ -759,7 +739,8 @@ func (suite *ServiceSuite) TestLoadIndex_Failed() {
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)

suite.Run("load_non_exist_segment", func() {
infos := suite.genSegmentLoadInfos(schema)
indexInfos := segments.GenTestIndexInfoList(suite.collectionID, schema)
infos := suite.genSegmentLoadInfos(schema, indexInfos)
infos = lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *querypb.SegmentLoadInfo {
info.SegmentID = info.SegmentID + 1000
return info
@ -780,7 +761,7 @@ func (suite *ServiceSuite) TestLoadIndex_Failed() {
Schema: schema,
NeedTransfer: false,
LoadScope: querypb.LoadScope_Index,
IndexInfoList: []*indexpb.IndexInfo{{}},
IndexInfoList: indexInfos,
}

// Load segment
@ -801,7 +782,8 @@ func (suite *ServiceSuite) TestLoadIndex_Failed() {

mockLoader.EXPECT().LoadIndex(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("mocked error"))

infos := suite.genSegmentLoadInfos(schema)
indexInfos := segments.GenTestIndexInfoList(suite.collectionID, schema)
infos := suite.genSegmentLoadInfos(schema, indexInfos)
req := &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{
MsgID: rand.Int63(),
@ -813,7 +795,7 @@ func (suite *ServiceSuite) TestLoadIndex_Failed() {
Schema: schema,
NeedTransfer: false,
LoadScope: querypb.LoadScope_Index,
IndexInfoList: []*indexpb.IndexInfo{{}},
IndexInfoList: indexInfos,
}

// Load segment
@ -834,7 +816,7 @@ func (suite *ServiceSuite) TestLoadSegments_Failed() {
},
CollectionID: suite.collectionID,
DstNodeID: suite.node.session.ServerID,
Infos: suite.genSegmentLoadInfos(schema),
Infos: suite.genSegmentLoadInfos(schema, nil),
Schema: schema,
NeedTransfer: true,
IndexInfoList: []*indexpb.IndexInfo{
@ -886,7 +868,7 @@ func (suite *ServiceSuite) TestLoadSegments_Transfer() {
},
CollectionID: suite.collectionID,
DstNodeID: suite.node.session.ServerID,
Infos: suite.genSegmentLoadInfos(schema),
Infos: suite.genSegmentLoadInfos(schema, nil),
Schema: schema,
NeedTransfer: true,
IndexInfoList: []*indexpb.IndexInfo{{}},
@ -908,7 +890,7 @@ func (suite *ServiceSuite) TestLoadSegments_Transfer() {
},
CollectionID: suite.collectionID,
DstNodeID: suite.node.session.ServerID,
Infos: suite.genSegmentLoadInfos(schema),
Infos: suite.genSegmentLoadInfos(schema, nil),
Schema: schema,
NeedTransfer: true,
IndexInfoList: []*indexpb.IndexInfo{{}},
@ -935,7 +917,7 @@ func (suite *ServiceSuite) TestLoadSegments_Transfer() {
},
CollectionID: suite.collectionID,
DstNodeID: suite.node.session.ServerID,
Infos: suite.genSegmentLoadInfos(schema),
Infos: suite.genSegmentLoadInfos(schema, nil),
Schema: schema,
NeedTransfer: true,
IndexInfoList: []*indexpb.IndexInfo{{}},
@ -1139,18 +1121,14 @@ func (suite *ServiceSuite) TestGetSegmentInfo_Failed() {
}

// Test Search
func (suite *ServiceSuite) genCSearchRequest(nq int64, indexType string, schema *schemapb.CollectionSchema) (*internalpb.SearchRequest, error) {
func (suite *ServiceSuite) genCSearchRequest(nq int64, dataType schemapb.DataType, fieldID int64, metricType string) (*internalpb.SearchRequest, error) {
placeHolder, err := genPlaceHolderGroup(nq)
if err != nil {
return nil, err
}
planStr, err := genDSLByIndexType(schema, indexType)
if err != nil {
return nil, err
}
var planpb planpb.PlanNode
proto.UnmarshalText(planStr, &planpb)
serializedPlan, err2 := proto.Marshal(&planpb)

plan := genSearchPlan(dataType, fieldID, metricType)
serializedPlan, err2 := proto.Marshal(plan)
if err2 != nil {
return nil, err2
}
@ -1175,9 +1153,7 @@ func (suite *ServiceSuite) TestSearch_Normal() {
suite.TestWatchDmChannelsInt64()
suite.TestLoadSegments_Int64()

// data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
creq, err := suite.genCSearchRequest(10, IndexFaissIDMap, schema)
creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType)
req := &querypb.SearchRequest{
Req: creq,
FromShardLeader: false,
@ -1197,14 +1173,11 @@ func (suite *ServiceSuite) TestSearch_Concurrent() {
suite.TestWatchDmChannelsInt64()
suite.TestLoadSegments_Int64()

// data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)

concurrency := 16
futures := make([]*conc.Future[*internalpb.SearchResults], 0, concurrency)
for i := 0; i < concurrency; i++ {
future := conc.Go(func() (*internalpb.SearchResults, error) {
creq, err := suite.genCSearchRequest(30, IndexFaissIDMap, schema)
creq, err := suite.genCSearchRequest(30, schemapb.DataType_FloatVector, 107, defaultMetricType)
req := &querypb.SearchRequest{
Req: creq,
FromShardLeader: false,
@ -1230,7 +1203,7 @@ func (suite *ServiceSuite) TestSearch_Failed() {

// data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
creq, err := suite.genCSearchRequest(10, IndexFaissIDMap, schema)
creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, "invalidMetricType")
req := &querypb.SearchRequest{
Req: creq,
FromShardLeader: false,
@ -1250,15 +1223,9 @@ func (suite *ServiceSuite) TestSearch_Failed() {
LoadType: querypb.LoadType_LoadCollection,
CollectionID: suite.collectionID,
PartitionIDs: suite.partitionIDs,
MetricType: "L2",
}
suite.node.manager.Collection.PutOrRef(suite.collectionID, schema, nil, LoadMeta)
req.GetReq().MetricType = "IP"
resp, err = suite.node.Search(ctx, req)
suite.NoError(err)
suite.ErrorIs(merr.Error(resp.GetStatus()), merr.ErrParameterInvalid)
suite.Contains(resp.GetStatus().GetReason(), merr.ErrParameterInvalid.Error())
req.GetReq().MetricType = "L2"
indexMeta := suite.node.composeIndexMeta(segments.GenTestIndexInfoList(suite.collectionID, schema), schema)
suite.node.manager.Collection.PutOrRef(suite.collectionID, schema, indexMeta, LoadMeta)

// Delegator not found
resp, err = suite.node.Search(ctx, req)
@ -1268,6 +1235,34 @@ func (suite *ServiceSuite) TestSearch_Failed() {
suite.TestWatchDmChannelsInt64()
suite.TestLoadSegments_Int64()

// sync segment data
syncReq := &querypb.SyncDistributionRequest{
Base: &commonpb.MsgBase{
MsgID: rand.Int63(),
TargetID: suite.node.session.ServerID,
},
CollectionID: suite.collectionID,
Channel: suite.vchannel,
}

syncVersionAction := &querypb.SyncAction{
Type: querypb.SyncType_UpdateVersion,
SealedInTarget: []int64{1, 2, 3, 4},
TargetVersion: time.Now().UnixMilli(),
}

syncReq.Actions = []*querypb.SyncAction{syncVersionAction}
status, err := suite.node.SyncDistribution(ctx, syncReq)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_Success, status.ErrorCode)

// metric type not match
req.GetReq().MetricType = "IP"
resp, err = suite.node.Search(ctx, req)
suite.NoError(err)
suite.Contains(resp.GetStatus().GetReason(), "metric type not match")
req.GetReq().MetricType = "L2"

// target not match
req.Req.Base.TargetID = -1
resp, err = suite.node.Search(ctx, req)
@ -1333,9 +1328,7 @@ func (suite *ServiceSuite) TestSearchSegments_Normal() {
suite.TestWatchDmChannelsInt64()
suite.TestLoadSegments_Int64()

// data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
creq, err := suite.genCSearchRequest(10, IndexFaissIDMap, schema)
creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType)
req := &querypb.SearchRequest{
Req: creq,
FromShardLeader: true,

@ -31,6 +31,8 @@ const (
FailLabel = "fail"
TotalLabel = "total"

HybridSearchLabel = "hybrid_search"

InsertLabel = "insert"
DeleteLabel = "delete"
UpsertLabel = "upsert"

@ -239,9 +239,9 @@ func TestSchema_GetVectorFieldSchema(t *testing.T) {
}

t.Run("GetVectorFieldSchema", func(t *testing.T) {
fieldSchema, err := GetVectorFieldSchema(schemaNormal)
assert.Equal(t, "field_float_vector", fieldSchema.Name)
assert.NoError(t, err)
fieldSchema := GetVectorFieldSchemas(schemaNormal)
assert.Equal(t, 1, len(fieldSchema))
assert.Equal(t, "field_float_vector", fieldSchema[0].Name)
})

schemaInvalid := &schemapb.CollectionSchema{
@ -260,8 +260,8 @@ func TestSchema_GetVectorFieldSchema(t *testing.T) {
}

t.Run("GetVectorFieldSchemaInvalid", func(t *testing.T) {
_, err := GetVectorFieldSchema(schemaInvalid)
assert.Error(t, err)
res := GetVectorFieldSchemas(schemaInvalid)
assert.Equal(t, 0, len(res))
})
}

tests/integration/hybridsearch/hybridsearch_test.go (new file, 225 lines)
@ -0,0 +1,225 @@
package hybridsearch

import (
"context"
"encoding/json"
"fmt"
"strconv"
"testing"

"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/suite"
"go.uber.org/zap"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proxy"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/tests/integration"
)

type HybridSearchSuite struct {
integration.MiniClusterSuite
}

func (s *HybridSearchSuite) TestHybridSearch() {
c := s.Cluster
ctx, cancel := context.WithCancel(c.GetContext())
defer cancel()

prefix := "TestHybridSearch"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
dim := 128
rowNum := 3000

schema := integration.ConstructSchema(collectionName, dim, true,
&schemapb.FieldSchema{Name: integration.Int64Field, DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true},
&schemapb.FieldSchema{Name: integration.FloatVecField, DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "128"}}},
&schemapb.FieldSchema{Name: integration.BinVecField, DataType: schemapb.DataType_BinaryVector, TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "128"}}},
)
marshaledSchema, err := proto.Marshal(schema)
s.NoError(err)

createCollectionStatus, err := c.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: common.DefaultShardsNum,
})
s.NoError(err)

err = merr.Error(createCollectionStatus)
if err != nil {
log.Warn("createCollectionStatus fail reason", zap.Error(err))
}

log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus))
showCollectionsResp, err := c.Proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{})
s.NoError(err)
s.True(merr.Ok(showCollectionsResp.GetStatus()))
log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp))

fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim)
bVecColumn := integration.NewBinaryVectorFieldData(integration.BinVecField, rowNum, dim)
hashKeys := integration.GenerateHashKeys(rowNum)
insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
DbName: dbName,
CollectionName: collectionName,
FieldsData: []*schemapb.FieldData{fVecColumn, bVecColumn},
HashKeys: hashKeys,
NumRows: uint32(rowNum),
})
s.NoError(err)
s.True(merr.Ok(insertResult.GetStatus()))

// flush
flushResp, err := c.Proxy.Flush(ctx, &milvuspb.FlushRequest{
DbName: dbName,
CollectionNames: []string{collectionName},
})
s.NoError(err)
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
ids := segmentIDs.GetData()
s.Require().NotEmpty(segmentIDs)
s.Require().True(has)
flushTs, has := flushResp.GetCollFlushTs()[collectionName]
s.True(has)

segments, err := c.MetaWatcher.ShowSegments()
s.NoError(err)
s.NotEmpty(segments)
for _, segment := range segments {
log.Info("ShowSegments result", zap.String("segment", segment.String()))
}
s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName)

// load without index on vector fields
loadStatus, err := c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
})
s.NoError(err)
s.Error(merr.Error(loadStatus))

// create index for float vector
createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
CollectionName: collectionName,
FieldName: integration.FloatVecField,
IndexName: "_default_float",
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.L2),
})
s.NoError(err)
err = merr.Error(createIndexStatus)
if err != nil {
log.Warn("createIndexStatus fail reason", zap.Error(err))
}
s.WaitForIndexBuilt(ctx, collectionName, integration.FloatVecField)

// load with index on partial vector fields
loadStatus, err = c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
})
s.NoError(err)
s.Error(merr.Error(loadStatus))

// create index for binary vector
createIndexStatus, err = c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
CollectionName: collectionName,
FieldName: integration.BinVecField,
IndexName: "_default_binary",
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissBinIvfFlat, metric.JACCARD),
})
s.NoError(err)
err = merr.Error(createIndexStatus)
if err != nil {
log.Warn("createIndexStatus fail reason", zap.Error(err))
}
s.WaitForIndexBuiltWithIndexName(ctx, collectionName, integration.BinVecField, "_default_binary")

// load with index on all vector fields
loadStatus, err = c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
})
s.NoError(err)
err = merr.Error(loadStatus)
if err != nil {
log.Warn("LoadCollection fail reason", zap.Error(err))
}
s.WaitForLoad(ctx, collectionName)

// search
expr := fmt.Sprintf("%s > 0", integration.Int64Field)
nq := 1
topk := 10
roundDecimal := -1

fParams := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.L2)
bParams := integration.GetSearchParams(integration.IndexFaissBinIvfFlat, metric.L2)
fSearchReq := integration.ConstructSearchRequest("", collectionName, expr,
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.L2, fParams, nq, dim, topk, roundDecimal)

bSearchReq := integration.ConstructSearchRequest("", collectionName, expr,
integration.BinVecField, schemapb.DataType_BinaryVector, nil, metric.JACCARD, bParams, nq, dim, topk, roundDecimal)

hSearchReq := &milvuspb.HybridSearchRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
PartitionNames: nil,
Requests: []*milvuspb.SearchRequest{fSearchReq, bSearchReq},
OutputFields: []string{integration.FloatVecField, integration.BinVecField},
}

// rrf rank hybrid search
rrfParams := make(map[string]float64)
rrfParams[proxy.RRFParamsKey] = 60
b, err := json.Marshal(rrfParams)
s.NoError(err)
hSearchReq.RankParams = []*commonpb.KeyValuePair{
{Key: proxy.RankTypeKey, Value: "rrf"},
{Key: proxy.RankParamsKey, Value: string(b)},
{Key: proxy.LimitKey, Value: strconv.Itoa(topk)},
{Key: proxy.RoundDecimalKey, Value: strconv.Itoa(roundDecimal)},
}

searchResult, err := c.Proxy.HybridSearch(ctx, hSearchReq)

if searchResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("searchResult fail reason", zap.String("reason", searchResult.GetStatus().GetReason()))
}
s.NoError(err)
s.Equal(commonpb.ErrorCode_Success, searchResult.GetStatus().GetErrorCode())

// weighted rank hybrid search
weightsParams := make(map[string][]float64)
weightsParams[proxy.WeightsParamsKey] = []float64{0.5, 0.2}
b, err = json.Marshal(weightsParams)
s.NoError(err)
hSearchReq.RankParams = []*commonpb.KeyValuePair{
{Key: proxy.RankTypeKey, Value: "weighted"},
{Key: proxy.RankParamsKey, Value: string(b)},
{Key: proxy.LimitKey, Value: strconv.Itoa(topk)},
}

searchResult, err = c.Proxy.HybridSearch(ctx, hSearchReq)

if searchResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("searchResult fail reason", zap.String("reason", searchResult.GetStatus().GetReason()))
}
s.NoError(err)
s.Equal(commonpb.ErrorCode_Success, searchResult.GetStatus().GetErrorCode())

log.Info("TestHybridSearch succeed")
}

func TestHybridSearch(t *testing.T) {
suite.Run(t, new(HybridSearchSuite))
}

@ -44,19 +44,24 @@ const (
)

func (s *MiniClusterSuite) WaitForIndexBuiltWithDB(ctx context.Context, dbName, collection, field string) {
s.waitForIndexBuiltInternal(ctx, dbName, collection, field)
s.waitForIndexBuiltInternal(ctx, dbName, collection, field, "")
}

func (s *MiniClusterSuite) WaitForIndexBuilt(ctx context.Context, collection, field string) {
s.waitForIndexBuiltInternal(ctx, "", collection, field)
s.waitForIndexBuiltInternal(ctx, "", collection, field, "")
}

func (s *MiniClusterSuite) waitForIndexBuiltInternal(ctx context.Context, dbName, collection, field string) {
func (s *MiniClusterSuite) WaitForIndexBuiltWithIndexName(ctx context.Context, collection, field, indexName string) {
s.waitForIndexBuiltInternal(ctx, "", collection, field, indexName)
}

func (s *MiniClusterSuite) waitForIndexBuiltInternal(ctx context.Context, dbName, collection, field, indexName string) {
getIndexBuilt := func() bool {
resp, err := s.Cluster.Proxy.DescribeIndex(ctx, &milvuspb.DescribeIndexRequest{
DbName: dbName,
CollectionName: collection,
FieldName: field,
IndexName: indexName,
})
if err != nil {
s.FailNow("failed to describe index")