feat: Support multiple vector search (#29433)

issue #25639 

Signed-off-by: xige-16 <xi.ge@zilliz.com>

This commit is contained in:
xige-16 2024-01-08 15:34:48 +08:00 committed by GitHub
parent 7e6f73a12d
commit 9702cef2b5
26 changed files with 1780 additions and 374 deletions


@ -281,6 +281,16 @@ class SegmentGrowingImpl : public SegmentGrowing {
void
check_search(const query::Plan* plan) const override {
Assert(plan);
auto& metric_str = plan->plan_node_->search_info_.metric_type_;
auto searched_field_id = plan->plan_node_->search_info_.field_id_;
auto index_meta =
index_meta_->GetFieldIndexMeta(FieldId(searched_field_id));
if (metric_str.empty()) {
metric_str = index_meta.GeMetricType();
} else {
AssertInfo(metric_str == index_meta.GeMetricType(),
"metric type not match");
}
}
const ConcurrentVector<Timestamp>&


@ -917,6 +917,17 @@ SegmentSealedImpl::check_search(const query::Plan* plan) const {
AssertInfo(plan->extra_info_opt_.has_value(),
"Extra info of search plan doesn't have value");
auto& metric_str = plan->plan_node_->search_info_.metric_type_;
auto searched_field_id = plan->plan_node_->search_info_.field_id_;
auto index_meta =
col_index_meta_->GetFieldIndexMeta(FieldId(searched_field_id));
if (metric_str.empty()) {
metric_str = index_meta.GeMetricType();
} else {
AssertInfo(metric_str == index_meta.GeMetricType(),
"metric type not match");
}
if (!is_system_field_ready()) {
PanicInfo(
FieldNotLoaded,


@ -219,7 +219,7 @@ message LoadMetaInfo {
LoadType load_type = 1;
int64 collectionID = 2;
repeated int64 partitionIDs = 3;
string metric_type = 4;
string metric_type = 4 [deprecated=true];
}
message WatchDmChannelsRequest {


@ -2726,9 +2726,135 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
}
func (node *Proxy) HybridSearch(ctx context.Context, request *milvuspb.HybridSearchRequest) (*milvuspb.SearchResults, error) {
return &milvuspb.SearchResults{
Status: merr.Status(merr.WrapErrServiceInternal("unimplemented")),
}, nil
receiveSize := proto.Size(request)
metrics.ProxyReceiveBytes.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.HybridSearchLabel,
request.GetCollectionName(),
).Add(float64(receiveSize))
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
return &milvuspb.SearchResults{
Status: merr.Status(err),
}, nil
}
method := "HybridSearch"
tr := timerecord.NewTimeRecorder(method)
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
metrics.TotalLabel,
).Inc()
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch")
defer sp.End()
qt := &hybridSearchTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
request: request,
tr: timerecord.NewTimeRecorder(method),
qc: node.queryCoord,
node: node,
lb: node.lbPolicy,
}
guaranteeTs := request.GuaranteeTimestamp
log := log.Ctx(ctx).With(
zap.String("role", typeutil.ProxyRole),
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
zap.Any("partitions", request.PartitionNames),
zap.Any("OutputFields", request.OutputFields),
zap.Uint64("guarantee_timestamp", guaranteeTs),
)
defer func() {
span := tr.ElapseSpan()
if span >= SlowReadSpan {
log.Info(rpcSlow(method), zap.Duration("duration", span))
}
}()
log.Debug(rpcReceived(method))
if err := node.sched.dqQueue.Enqueue(qt); err != nil {
log.Warn(
rpcFailedToEnqueue(method),
zap.Error(err),
)
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
metrics.AbandonLabel,
).Inc()
return &milvuspb.SearchResults{
Status: merr.Status(err),
}, nil
}
tr.CtxRecord(ctx, "hybrid search request enqueue")
log.Debug(
rpcEnqueued(method),
zap.Uint64("timestamp", qt.request.Base.Timestamp),
)
if err := qt.WaitToFinish(); err != nil {
log.Warn(
rpcFailedToWaitToFinish(method),
zap.Error(err),
)
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
metrics.FailLabel,
).Inc()
return &milvuspb.SearchResults{
Status: merr.Status(err),
}, nil
}
span := tr.CtxRecord(ctx, "wait hybrid search result")
metrics.ProxyWaitForSearchResultLatency.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.HybridSearchLabel,
).Observe(float64(span.Milliseconds()))
tr.CtxRecord(ctx, "wait hybrid search result")
log.Debug(rpcDone(method))
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
metrics.SuccessLabel,
).Inc()
metrics.ProxySearchVectors.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(len(qt.request.GetRequests())))
searchDur := tr.ElapseSpan().Milliseconds()
metrics.ProxySQLatency.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.HybridSearchLabel,
).Observe(float64(searchDur))
metrics.ProxyCollectionSQLatency.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.HybridSearchLabel,
request.CollectionName,
).Observe(float64(searchDur))
if qt.result != nil {
sentSize := proto.Size(qt.result)
metrics.ProxyReadReqSendBytes.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(sentSize))
rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize))
}
return qt.result, nil
}
func (node *Proxy) getVectorPlaceholderGroupForSearchByPks(ctx context.Context, request *milvuspb.SearchRequest) ([]byte, error) {


@ -1477,7 +1477,7 @@ func TestProxy(t *testing.T) {
topk := 10
roundDecimal := 6
expr := fmt.Sprintf("%s > 0", int64Field)
constructVectorsPlaceholderGroup := func() *commonpb.PlaceholderGroup {
constructVectorsPlaceholderGroup := func(nq int) *commonpb.PlaceholderGroup {
values := make([][]byte, 0, nq)
for i := 0; i < nq; i++ {
bs := make([]byte, 0, dim*4)
@ -1502,8 +1502,8 @@ func TestProxy(t *testing.T) {
}
}
constructSearchRequest := func() *milvuspb.SearchRequest {
plg := constructVectorsPlaceholderGroup()
constructSearchRequest := func(nq int) *milvuspb.SearchRequest {
plg := constructVectorsPlaceholderGroup(nq)
plgBs, err := proto.Marshal(plg)
assert.NoError(t, err)
@ -1538,13 +1538,51 @@ func TestProxy(t *testing.T) {
wg.Add(1)
t.Run("search", func(t *testing.T) {
defer wg.Done()
req := constructSearchRequest()
req := constructSearchRequest(nq)
resp, err := proxy.Search(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
})
constructHybridSearchRequest := func(reqs []*milvuspb.SearchRequest) *milvuspb.HybridSearchRequest {
params := make(map[string]float64)
params[RRFParamsKey] = 60
b, err := json.Marshal(params)
assert.NoError(t, err)
rankParams := []*commonpb.KeyValuePair{
{Key: RankTypeKey, Value: "rrf"},
{Key: RankParamsKey, Value: string(b)},
{Key: LimitKey, Value: strconv.Itoa(topk)},
{Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)},
}
return &milvuspb.HybridSearchRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
Requests: reqs,
PartitionNames: nil,
OutputFields: nil,
RankParams: rankParams,
TravelTimestamp: 0,
GuaranteeTimestamp: 0,
}
}
wg.Add(1)
nq = 1
t.Run("hybrid search", func(t *testing.T) {
defer wg.Done()
req1 := constructSearchRequest(nq)
req2 := constructSearchRequest(nq)
resp, err := proxy.HybridSearch(ctx, constructHybridSearchRequest([]*milvuspb.SearchRequest{req1, req2}))
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
})
nq = 10
constructPrimaryKeysPlaceholderGroup := func() *commonpb.PlaceholderGroup {
expr := fmt.Sprintf("%v in [%v]", int64Field, insertedIds[0])
exprBytes := []byte(expr)

internal/proxy/reScorer.go (new file)

@ -0,0 +1,157 @@
package proxy
import (
"encoding/json"
"fmt"
"reflect"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
)
type rankType int
const (
invalidRankType rankType = iota // invalidRankType = 0
rrfRankType // rrfRankType = 1
weightedRankType // weightedRankType = 2
udfExprRankType // udfExprRankType = 3
)
var rankTypeMap = map[string]rankType{
"invalid": invalidRankType,
"rrf": rrfRankType,
"weighted": weightedRankType,
"expr": udfExprRankType,
}
type reScorer interface {
name() string
scorerType() rankType
reScore(input *milvuspb.SearchResults)
}
type baseScorer struct {
scorerName string
}
func (bs *baseScorer) name() string {
return bs.scorerName
}
type rrfScorer struct {
baseScorer
k float32
}
func (rs *rrfScorer) reScore(input *milvuspb.SearchResults) {
for i := range input.Results.GetScores() {
input.Results.Scores[i] = 1 / (rs.k + float32(i+1))
}
}
func (rs *rrfScorer) scorerType() rankType {
return rrfRankType
}
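// Aside, illustration only (not part of this patch): the scorer above
// implements reciprocal rank fusion, replacing each hit's score with
// 1/(k + rank), where rank is the hit's 1-based position inside its own
// request's result list. With the default k = 60 the first three hits map
// to roughly 0.0164, 0.0161 and 0.0159 before the per-query accumulation
// in rankSearchResultData. A minimal stand-alone sketch of the arithmetic:
//
//	func rrfRescore(scores []float32, k float32) {
//		for i := range scores {
//			scores[i] = 1 / (k + float32(i+1))
//		}
//	}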
type weightedScorer struct {
baseScorer
weight float32
}
func (ws *weightedScorer) reScore(input *milvuspb.SearchResults) {
for i, score := range input.Results.GetScores() {
input.Results.Scores[i] = ws.weight * score
}
}
func (ws *weightedScorer) scorerType() rankType {
return weightedRankType
}
func NewReScorer(reqs []*milvuspb.SearchRequest, rankParams []*commonpb.KeyValuePair) ([]reScorer, error) {
res := make([]reScorer, len(reqs))
rankTypeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankTypeKey, rankParams)
if err != nil {
log.Info("rank strategy not specified, use rrf instead")
// if no rank strategy is set, fall back to rrf as the default
for i := range reqs {
res[i] = &rrfScorer{
baseScorer: baseScorer{
scorerName: "rrf",
},
k: float32(defaultRRFParamsValue),
}
}
return res, nil
}
if _, ok := rankTypeMap[rankTypeStr]; !ok {
return nil, errors.Errorf("unsupported rank type %s", rankTypeStr)
}
paramStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankParamsKey, rankParams)
if err != nil {
return nil, errors.New(RankParamsKey + " not found in rank_params")
}
var params map[string]interface{}
err = json.Unmarshal([]byte(paramStr), &params)
if err != nil {
return nil, err
}
switch rankTypeMap[rankTypeStr] {
case rrfRankType:
k, ok := params[RRFParamsKey].(float64)
if !ok {
return nil, errors.New(RRFParamsKey + " not found in rank_params")
}
log.Debug("rrf params", zap.Float64("k", k))
for i := range reqs {
res[i] = &rrfScorer{
baseScorer: baseScorer{
scorerName: "rrf",
},
k: float32(k),
}
}
case weightedRankType:
if _, ok := params[WeightsParamsKey]; !ok {
return nil, errors.New(WeightsParamsKey + " not found in rank_params")
}
weights := make([]float32, 0)
switch reflect.TypeOf(params[WeightsParamsKey]).Kind() {
case reflect.Slice:
rs := reflect.ValueOf(params[WeightsParamsKey])
for i := 0; i < rs.Len(); i++ {
weights = append(weights, float32(rs.Index(i).Interface().(float64)))
}
default:
return nil, errors.New("The weights param should be an array")
}
log.Debug("weights params", zap.Any("weights", weights))
if len(reqs) != len(weights) {
return nil, merr.WrapErrParameterInvalid(fmt.Sprint(len(reqs)), fmt.Sprint(len(weights)), "the length of weights param mismatch with ann search requests")
}
for i := range reqs {
res[i] = &weightedScorer{
baseScorer: baseScorer{
scorerName: "weighted",
},
weight: weights[i],
}
}
default:
return nil, errors.Errorf("unsupported rank type %s", rankTypeStr)
}
return res, nil
}
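For orientation, a hedged caller-side sketch (not part of the diff): the rank parameters consumed by NewReScorer are plain key/value pairs in which RankTypeKey ("strategy") selects "rrf" or "weighted", and RankParamsKey ("params") carries a JSON object holding either k or weights. Reusing only the constants added in this patch, and assuming the same json/commonpb imports as the file above:

func buildWeightedRankParams() ([]*commonpb.KeyValuePair, error) {
	// weighted fusion of two ANN requests; the weight count must match the request count
	b, err := json.Marshal(map[string]interface{}{WeightsParamsKey: []float64{0.7, 0.3}})
	if err != nil {
		return nil, err
	}
	return []*commonpb.KeyValuePair{
		{Key: RankTypeKey, Value: "weighted"},
		{Key: RankParamsKey, Value: string(b)},
	}, nil
}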


@ -0,0 +1,55 @@
package proxy
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
)
func TestRescorer(t *testing.T) {
t.Run("default scorer", func(t *testing.T) {
rescorers, err := NewReScorer([]*milvuspb.SearchRequest{{}, {}}, nil)
assert.NoError(t, err)
assert.Equal(t, 2, len(rescorers))
assert.Equal(t, rrfRankType, rescorers[0].scorerType())
})
t.Run("rrf", func(t *testing.T) {
params := make(map[string]float64)
params[RRFParamsKey] = 61
b, err := json.Marshal(params)
assert.NoError(t, err)
rankParams := []*commonpb.KeyValuePair{
{Key: RankTypeKey, Value: "rrf"},
{Key: RankParamsKey, Value: string(b)},
}
rescorers, err := NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
assert.NoError(t, err)
assert.Equal(t, 2, len(rescorers))
assert.Equal(t, rrfRankType, rescorers[0].scorerType())
assert.Equal(t, float32(61), rescorers[0].(*rrfScorer).k)
})
t.Run("weights", func(t *testing.T) {
weights := []float64{0.5, 0.2}
params := make(map[string][]float64)
params[WeightsParamsKey] = weights
b, err := json.Marshal(params)
assert.NoError(t, err)
rankParams := []*commonpb.KeyValuePair{
{Key: RankTypeKey, Value: "weighted"},
{Key: RankParamsKey, Value: string(b)},
}
rescorers, err := NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
assert.NoError(t, err)
assert.Equal(t, 2, len(rescorers))
assert.Equal(t, weightedRankType, rescorers[0].scorerType())
assert.Equal(t, float32(weights[0]), rescorers[0].(*weightedScorer).weight)
})
}


@ -88,6 +88,11 @@ const (
// minFloat32 minimum float.
minFloat32 = -1 * float32(math.MaxFloat32)
RankTypeKey = "strategy"
RankParamsKey = "params"
RRFParamsKey = "k"
WeightsParamsKey = "weights"
)
type task interface {


@ -0,0 +1,461 @@
package proxy
import (
"context"
"fmt"
"math"
"sort"
"strconv"
"github.com/cockroachdb/errors"
"go.opentelemetry.io/otel"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/conc"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
const (
HybridSearchTaskName = "HybridSearchTask"
)
type hybridSearchTask struct {
Condition
ctx context.Context
result *milvuspb.SearchResults
request *milvuspb.HybridSearchRequest
tr *timerecord.TimeRecorder
schema *schemaInfo
requery bool
userOutputFields []string
qc types.QueryCoordClient
node types.ProxyComponent
lb LBPolicy
collectionID UniqueID
multipleRecallResults *typeutil.ConcurrentSet[*milvuspb.SearchResults]
reScorers []reScorer
}
func (t *hybridSearchTask) PreExecute(ctx context.Context) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch-PreExecute")
defer sp.End()
if len(t.request.Requests) <= 0 {
return errors.New("minimum of ann search requests is 1")
}
if len(t.request.Requests) > defaultMaxSearchRequest {
return errors.New("maximum of ann search requests is 1024")
}
for _, req := range t.request.GetRequests() {
nq, err := getNq(req)
if err != nil {
log.Debug("failed to get nq", zap.Error(err))
return err
}
if nq != 1 {
err = merr.WrapErrParameterInvalid("1", fmt.Sprint(nq), "nq should be equal to 1")
log.Debug(err.Error())
return err
}
}
collectionName := t.request.CollectionName
collID, err := globalMetaCache.GetCollectionID(ctx, t.request.GetDbName(), collectionName)
if err != nil {
return err
}
t.collectionID = collID
log := log.Ctx(ctx).With(zap.Int64("collID", collID), zap.String("collName", collectionName))
t.schema, err = globalMetaCache.GetCollectionSchema(ctx, t.request.GetDbName(), collectionName)
if err != nil {
log.Warn("get collection schema failed", zap.Error(err))
return err
}
partitionKeyMode, err := isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName)
if err != nil {
log.Warn("is partition key mode failed", zap.Error(err))
return err
}
if partitionKeyMode && len(t.request.GetPartitionNames()) != 0 {
return errors.New("not support manually specifying the partition names if partition key mode is used")
}
t.request.OutputFields, t.userOutputFields, err = translateOutputFields(t.request.OutputFields, t.schema, false)
if err != nil {
log.Warn("translate output fields failed", zap.Error(err))
return err
}
log.Debug("translate output fields",
zap.Strings("output fields", t.request.GetOutputFields()))
if len(t.request.OutputFields) > 0 {
t.requery = true
}
log.Debug("hybrid search preExecute done.",
zap.Uint64("guarantee_ts", t.request.GetGuaranteeTimestamp()),
zap.Bool("use_default_consistency", t.request.GetUseDefaultConsistency()),
zap.Any("consistency level", t.request.GetConsistencyLevel()))
return nil
}
func (t *hybridSearchTask) Execute(ctx context.Context) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch-Execute")
defer sp.End()
log := log.Ctx(ctx).With(zap.Int64("collID", t.collectionID), zap.String("collName", t.request.GetCollectionName()))
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute hybrid search %d", t.ID()))
defer tr.CtxElapse(ctx, "done")
futures := make([]*conc.Future[*milvuspb.SearchResults], len(t.request.Requests))
for index := range t.request.Requests {
searchReq := t.request.Requests[index]
future := conc.Go(func() (*milvuspb.SearchResults, error) {
searchReq.TravelTimestamp = t.request.GetTravelTimestamp()
searchReq.GuaranteeTimestamp = t.request.GetGuaranteeTimestamp()
searchReq.NotReturnAllMeta = t.request.GetNotReturnAllMeta()
searchReq.ConsistencyLevel = t.request.GetConsistencyLevel()
searchReq.UseDefaultConsistency = t.request.GetUseDefaultConsistency()
searchReq.OutputFields = nil
return t.node.Search(ctx, searchReq)
})
futures[index] = future
}
err := conc.AwaitAll(futures...)
if err != nil {
return err
}
t.reScorers, err = NewReScorer(t.request.GetRequests(), t.request.GetRankParams())
if err != nil {
log.Info("generate reScorer failed", zap.Any("rank params", t.request.GetRankParams()), zap.Error(err))
return err
}
t.multipleRecallResults = typeutil.NewConcurrentSet[*milvuspb.SearchResults]()
for i, future := range futures {
err = future.Err()
if err != nil {
log.Debug("QueryNode search result error", zap.Error(err))
return err
}
result := futures[i].Value()
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Debug("QueryNode search result error",
zap.String("reason", result.GetStatus().GetReason()))
return merr.Error(result.GetStatus())
}
t.reScorers[i].reScore(result)
t.multipleRecallResults.Insert(result)
}
log.Debug("hybrid search execute done.")
return nil
}
type rankParams struct {
limit int64
offset int64
roundDecimal int64
}
// parseRankParams gets limit, offset and roundDecimal from rankParams; limit is required, the others are optional.
func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, error) {
var (
limit int64
offset int64
roundDecimal int64
err error
)
limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, rankParamsPair)
if err != nil {
return nil, errors.New(LimitKey + " not found in search_params")
}
limit, err = strconv.ParseInt(limitStr, 0, 64)
if err != nil {
return nil, fmt.Errorf("%s [%s] is invalid", LimitKey, limitStr)
}
offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, rankParamsPair)
if err == nil {
offset, err = strconv.ParseInt(offsetStr, 0, 64)
if err != nil {
return nil, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
}
}
// validate max result window.
if err = validateMaxQueryResultWindow(offset, limit); err != nil {
return nil, fmt.Errorf("invalid max query result window, %w", err)
}
roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, rankParamsPair)
if err != nil {
roundDecimalStr = "-1"
}
roundDecimal, err = strconv.ParseInt(roundDecimalStr, 0, 64)
if err != nil {
return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
}
if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) {
return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
}
return &rankParams{
limit: limit,
offset: offset,
roundDecimal: roundDecimal,
}, nil
}
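// Illustration only (not part of this patch): the rank params read by
// parseRankParams are key/value pairs. LimitKey is required; OffsetKey
// defaults to 0 and RoundDecimalKey defaults to -1 (no rounding) when
// absent. A hedged example that parses to limit=10, offset=5,
// roundDecimal=2:
//
//	rankParamsPair := []*commonpb.KeyValuePair{
//		{Key: LimitKey, Value: "10"},
//		{Key: OffsetKey, Value: "5"},
//		{Key: RoundDecimalKey, Value: "2"},
//	}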
func (t *hybridSearchTask) PostExecute(ctx context.Context) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch-PostExecute")
defer sp.End()
log := log.Ctx(ctx).With(zap.Int64("collID", t.collectionID), zap.String("collName", t.request.GetCollectionName()))
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy postExecute hybrid search %d", t.ID()))
defer func() {
tr.CtxElapse(ctx, "done")
}()
primaryFieldSchema, err := t.schema.GetPkField()
if err != nil {
log.Warn("failed to get primary field schema", zap.Error(err))
return err
}
rankParams, err := parseRankParams(t.request.GetRankParams())
if err != nil {
return err
}
t.result, err = rankSearchResultData(ctx, 1,
rankParams,
primaryFieldSchema.GetDataType(),
t.multipleRecallResults.Collect())
if err != nil {
log.Warn("rank search result failed", zap.Error(err))
return err
}
t.result.CollectionName = t.request.GetCollectionName()
t.fillInFieldInfo()
if t.requery {
err := t.Requery()
if err != nil {
log.Warn("failed to requery", zap.Error(err))
return err
}
}
t.result.Results.OutputFields = t.userOutputFields
log.Debug("hybrid search post execute done")
return nil
}
func (t *hybridSearchTask) Requery() error {
queryReq := &milvuspb.QueryRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
},
DbName: t.request.GetDbName(),
CollectionName: t.request.GetCollectionName(),
Expr: "",
OutputFields: t.request.GetOutputFields(),
PartitionNames: t.request.GetPartitionNames(),
GuaranteeTimestamp: t.request.GetGuaranteeTimestamp(),
TravelTimestamp: t.request.GetTravelTimestamp(),
NotReturnAllMeta: t.request.GetNotReturnAllMeta(),
ConsistencyLevel: t.request.GetConsistencyLevel(),
UseDefaultConsistency: t.request.GetUseDefaultConsistency(),
}
return doRequery(t.ctx, t.collectionID, t.node, t.schema.CollectionSchema, queryReq, t.result)
}
func rankSearchResultData(ctx context.Context,
nq int64,
params *rankParams,
pkType schemapb.DataType,
searchResults []*milvuspb.SearchResults,
) (*milvuspb.SearchResults, error) {
tr := timerecord.NewTimeRecorder("rankSearchResultData")
defer func() {
tr.CtxElapse(ctx, "done")
}()
offset := params.offset
limit := params.limit
topk := limit + offset
roundDecimal := params.roundDecimal
log.Ctx(ctx).Debug("rankSearchResultData",
zap.Int("len(searchResults)", len(searchResults)),
zap.Int64("nq", nq),
zap.Int64("offset", offset),
zap.Int64("limit", limit))
ret := &milvuspb.SearchResults{
Status: merr.Success(),
Results: &schemapb.SearchResultData{
NumQueries: nq,
TopK: limit,
FieldsData: make([]*schemapb.FieldData, 0),
Scores: []float32{},
Ids: &schemapb.IDs{},
Topks: []int64{},
},
}
switch pkType {
case schemapb.DataType_Int64:
ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: make([]int64, 0),
},
}
case schemapb.DataType_VarChar:
ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: make([]string, 0),
},
}
default:
return nil, errors.New("unsupported pk type")
}
// []map[id]score
accumulatedScores := make([]map[interface{}]float32, nq)
for i := int64(0); i < nq; i++ {
accumulatedScores[i] = make(map[interface{}]float32)
}
for _, result := range searchResults {
scores := result.GetResults().GetScores()
start := int64(0)
for i := int64(0); i < nq; i++ {
realTopk := result.GetResults().Topks[i]
for j := start; j < start+realTopk; j++ {
id := typeutil.GetPK(result.GetResults().GetIds(), j)
accumulatedScores[i][id] += scores[j]
}
start += realTopk
}
}
for i := int64(0); i < nq; i++ {
idSet := accumulatedScores[i]
keys := make([]interface{}, 0)
for key := range idSet {
keys = append(keys, key)
}
if int64(len(keys)) <= offset {
ret.Results.Topks = append(ret.Results.Topks, 0)
continue
}
// sort id by score
sort.Slice(keys, func(i, j int) bool {
return idSet[keys[i]] >= idSet[keys[j]]
})
if int64(len(keys)) > topk {
keys = keys[:topk]
}
// set real topk
ret.Results.Topks = append(ret.Results.Topks, int64(len(keys))-offset)
// append id and score
for index := offset; index < int64(len(keys)); index++ {
typeutil.AppendPKs(ret.Results.Ids, keys[index])
score := idSet[keys[index]]
if roundDecimal != -1 {
multiplier := math.Pow(10.0, float64(roundDecimal))
score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier)
}
ret.Results.Scores = append(ret.Results.Scores, score)
}
}
return ret, nil
}
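// Illustration only (not part of this patch): rankSearchResultData sums the
// already re-scored hits per primary key across all sub-requests, sorts the
// ids of each query by the accumulated score, then applies offset, limit and
// optional rounding. A minimal sketch of the accumulation step on plain maps:
//
//	func fuse(perRequest []map[int64]float32) map[int64]float32 {
//		merged := make(map[int64]float32)
//		for _, scores := range perRequest {
//			for id, s := range scores {
//				merged[id] += s // ids hit by several requests add up
//			}
//		}
//		return merged
//	}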
func (t *hybridSearchTask) fillInFieldInfo() {
if len(t.request.OutputFields) != 0 && len(t.result.Results.FieldsData) != 0 {
for i, name := range t.request.OutputFields {
for _, field := range t.schema.Fields {
if t.result.Results.FieldsData[i] != nil && field.Name == name {
t.result.Results.FieldsData[i].FieldName = field.Name
t.result.Results.FieldsData[i].FieldId = field.FieldID
t.result.Results.FieldsData[i].Type = field.DataType
t.result.Results.FieldsData[i].IsDynamic = field.IsDynamic
}
}
}
}
}
func (t *hybridSearchTask) TraceCtx() context.Context {
return t.ctx
}
func (t *hybridSearchTask) ID() UniqueID {
return t.request.Base.MsgID
}
func (t *hybridSearchTask) SetID(uid UniqueID) {
t.request.Base.MsgID = uid
}
func (t *hybridSearchTask) Name() string {
return HybridSearchTaskName
}
func (t *hybridSearchTask) Type() commonpb.MsgType {
return t.request.Base.MsgType
}
func (t *hybridSearchTask) BeginTs() Timestamp {
return t.request.Base.Timestamp
}
func (t *hybridSearchTask) EndTs() Timestamp {
return t.request.Base.Timestamp
}
func (t *hybridSearchTask) SetTs(ts Timestamp) {
t.request.Base.Timestamp = ts
}
func (t *hybridSearchTask) OnEnqueue() error {
t.request.Base = commonpbutil.NewMsgBase()
t.request.Base.MsgType = commonpb.MsgType_Search
t.request.Base.SourceID = paramtable.GetNodeID()
return nil
}


@ -0,0 +1,330 @@
package proxy
import (
"context"
"strconv"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
func createCollWithMultiVecField(t *testing.T, name string, rc types.RootCoordClient) {
schema := genCollectionSchema(name)
marshaledSchema, err := proto.Marshal(schema)
require.NoError(t, err)
ctx := context.TODO()
createColT := &createCollectionTask{
Condition: NewTaskCondition(context.TODO()),
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
CollectionName: name,
Schema: marshaledSchema,
ShardsNum: common.DefaultShardsNum,
},
ctx: ctx,
rootCoord: rc,
}
require.NoError(t, createColT.OnEnqueue())
require.NoError(t, createColT.PreExecute(ctx))
require.NoError(t, createColT.Execute(ctx))
require.NoError(t, createColT.PostExecute(ctx))
}
func TestHybridSearchTask_PreExecute(t *testing.T) {
var err error
var (
rc = NewRootCoordMock()
qc = mocks.NewMockQueryCoordClient(t)
ctx = context.TODO()
)
defer rc.Close()
require.NoError(t, err)
mgr := newShardClientMgr()
err = InitMetaCache(ctx, rc, qc, mgr)
require.NoError(t, err)
genHybridSearchTaskWithNq := func(t *testing.T, collName string, reqs []*milvuspb.SearchRequest) *hybridSearchTask {
task := &hybridSearchTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
request: &milvuspb.HybridSearchRequest{
CollectionName: collName,
Requests: reqs,
},
qc: qc,
tr: timerecord.NewTimeRecorder("test-hybrid-search"),
}
require.NoError(t, task.OnEnqueue())
return task
}
t.Run("bad nq 0", func(t *testing.T) {
collName := "test_bad_nq0_error" + funcutil.GenRandomStr()
createCollWithMultiVecField(t, collName, rc)
// Nq must be 1.
task := genHybridSearchTaskWithNq(t, collName, []*milvuspb.SearchRequest{{Nq: 0}})
err = task.PreExecute(ctx)
assert.Error(t, err)
})
t.Run("bad req num 0", func(t *testing.T) {
collName := "test_bad_req_num0_error" + funcutil.GenRandomStr()
createCollWithMultiVecField(t, collName, rc)
// num of reqs must be [1, 1024].
task := genHybridSearchTaskWithNq(t, collName, nil)
err = task.PreExecute(ctx)
assert.Error(t, err)
})
t.Run("bad req num 1025", func(t *testing.T) {
collName := "test_bad_req_num1025_error" + funcutil.GenRandomStr()
createCollWithMultiVecField(t, collName, rc)
// num of reqs must be [1, 1024].
reqs := make([]*milvuspb.SearchRequest, 0)
for i := 0; i <= defaultMaxSearchRequest; i++ {
reqs = append(reqs, &milvuspb.SearchRequest{
CollectionName: collName,
Nq: 1,
})
}
task := genHybridSearchTaskWithNq(t, collName, reqs)
err = task.PreExecute(ctx)
assert.Error(t, err)
})
t.Run("collection not exist", func(t *testing.T) {
collName := "test_collection_not_exist" + funcutil.GenRandomStr()
task := genHybridSearchTaskWithNq(t, collName, []*milvuspb.SearchRequest{{Nq: 1}})
err = task.PreExecute(ctx)
assert.Error(t, err)
})
t.Run("hybrid search with timeout", func(t *testing.T) {
collName := "hybrid_search_with_timeout" + funcutil.GenRandomStr()
createCollWithMultiVecField(t, collName, rc)
task := genHybridSearchTaskWithNq(t, collName, []*milvuspb.SearchRequest{{Nq: 1}})
ctxTimeout, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
task.ctx = ctxTimeout
task.request.OutputFields = []string{testFloatVecField}
assert.NoError(t, task.PreExecute(ctx))
})
}
func TestHybridSearchTask_ErrExecute(t *testing.T) {
var (
err error
ctx = context.TODO()
rc = NewRootCoordMock()
qc = getQueryCoordClient()
qn = getQueryNodeClient()
collectionName = t.Name() + funcutil.GenRandomStr()
)
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
mgr := NewMockShardClientManager(t)
mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe()
lb := NewLBPolicyImpl(mgr)
factory := dependency.NewDefaultFactory(true)
node, err := NewProxy(ctx, factory)
assert.NoError(t, err)
node.UpdateStateCode(commonpb.StateCode_Healthy)
node.tsoAllocator = &timestampAllocator{
tso: newMockTimestampAllocatorInterface(),
}
scheduler, err := newTaskScheduler(ctx, node.tsoAllocator, factory)
assert.NoError(t, err)
node.sched = scheduler
err = node.sched.Start()
assert.NoError(t, err)
err = node.initRateCollector()
assert.NoError(t, err)
node.rootCoord = rc
node.queryCoord = qc
defer qc.Close()
err = InitMetaCache(ctx, rc, qc, mgr)
assert.NoError(t, err)
createCollWithMultiVecField(t, collectionName, rc)
collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
assert.NoError(t, err)
schema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
assert.NoError(t, err)
successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(successStatus, nil)
qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
Status: successStatus,
Shards: []*querypb.ShardLeadersList{
{
ChannelName: "channel-1",
NodeIds: []int64{1},
NodeAddrs: []string{"localhost:9000"},
},
},
}, nil)
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: successStatus,
CollectionIDs: []int64{collectionID},
InMemoryPercentages: []int64{100},
}, nil)
status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadCollection,
SourceID: paramtable.GetNodeID(),
},
CollectionID: collectionID,
})
require.NoError(t, err)
require.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
vectorFields := typeutil.GetVectorFieldSchemas(schema.CollectionSchema)
vectorFieldNames := make([]string, len(vectorFields))
for i, field := range vectorFields {
vectorFieldNames[i] = field.GetName()
}
// test begins
task := &hybridSearchTask{
Condition: NewTaskCondition(ctx),
ctx: ctx,
result: &milvuspb.SearchResults{
Status: merr.Success(),
},
request: &milvuspb.HybridSearchRequest{
CollectionName: collectionName,
Requests: []*milvuspb.SearchRequest{
{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: paramtable.GetNodeID(),
},
CollectionName: collectionName,
Nq: 1,
DslType: commonpb.DslType_BoolExprV1,
SearchParams: []*commonpb.KeyValuePair{
{Key: AnnsFieldKey, Value: testFloatVecField},
{Key: TopKKey, Value: "10"},
},
},
{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: paramtable.GetNodeID(),
},
CollectionName: collectionName,
Nq: 1,
DslType: commonpb.DslType_BoolExprV1,
SearchParams: []*commonpb.KeyValuePair{
{Key: AnnsFieldKey, Value: testBinaryVecField},
{Key: TopKKey, Value: "10"},
},
},
},
OutputFields: vectorFieldNames,
},
qc: qc,
lb: lb,
node: node,
}
assert.NoError(t, task.OnEnqueue())
task.ctx = ctx
assert.NoError(t, task.PreExecute(ctx))
qn.EXPECT().Search(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))
assert.Error(t, task.Execute(ctx))
qn.ExpectedCalls = nil
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
qn.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
},
}, nil)
assert.Error(t, task.Execute(ctx))
}
func TestHybridSearchTask_PostExecute(t *testing.T) {
var (
rc = NewRootCoordMock()
qc = getQueryCoordClient()
qn = getQueryNodeClient()
collectionName = t.Name() + funcutil.GenRandomStr()
)
defer rc.Close()
mgr := NewMockShardClientManager(t)
mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe()
t.Run("Test empty result", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := InitMetaCache(ctx, rc, qc, mgr)
assert.NoError(t, err)
createCollWithMultiVecField(t, collectionName, rc)
schema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
assert.NoError(t, err)
rankParams := []*commonpb.KeyValuePair{
{Key: LimitKey, Value: strconv.Itoa(3)},
{Key: OffsetKey, Value: strconv.Itoa(2)},
}
qt := &hybridSearchTask{
ctx: ctx,
Condition: NewTaskCondition(context.TODO()),
qc: nil,
tr: timerecord.NewTimeRecorder("search"),
schema: schema,
request: &milvuspb.HybridSearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
},
CollectionName: collectionName,
RankParams: rankParams,
},
multipleRecallResults: typeutil.NewConcurrentSet[*milvuspb.SearchResults](),
}
err = qt.PostExecute(context.TODO())
assert.NoError(t, err)
assert.Equal(t, qt.result.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
})
}


@ -606,13 +606,6 @@ func (t *searchTask) estimateResultSize(nq int64, topK int64) (int64, error) {
}
func (t *searchTask) Requery() error {
pkField, err := t.schema.GetPkField()
if err != nil {
return err
}
ids := t.result.GetResults().GetIds()
plan := planparserv2.CreateRequeryPlan(pkField, ids)
queryReq := &milvuspb.QueryRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
@ -625,69 +618,8 @@ func (t *searchTask) Requery() error {
GuaranteeTimestamp: t.request.GetGuaranteeTimestamp(),
QueryParams: t.request.GetSearchParams(),
}
qt := &queryTask{
ctx: t.ctx,
Condition: NewTaskCondition(t.ctx),
RetrieveRequest: &internalpb.RetrieveRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
ReqID: paramtable.GetNodeID(),
},
request: queryReq,
plan: plan,
qc: t.node.(*Proxy).queryCoord,
lb: t.node.(*Proxy).lbPolicy,
}
queryResult, err := t.node.(*Proxy).query(t.ctx, qt)
if err != nil {
return err
}
if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
return merr.Error(queryResult.GetStatus())
}
// Reorganize Results. The order of query result ids will be altered and differ from queried ids.
// We should reorganize query results to keep the order of original queried ids. For example:
// ===========================================
// 3 2 5 4 1 (query ids)
// ||
// || (query)
// \/
// 4 3 5 1 2 (result ids)
// v4 v3 v5 v1 v2 (result vectors)
// ||
// || (reorganize)
// \/
// 3 2 5 4 1 (result ids)
// v3 v2 v5 v4 v1 (result vectors)
// ===========================================
pkFieldData, err := typeutil.GetPrimaryFieldData(queryResult.GetFieldsData(), pkField)
if err != nil {
return err
}
offsets := make(map[any]int)
for i := 0; i < typeutil.GetPKSize(pkFieldData); i++ {
pk := typeutil.GetData(pkFieldData, i)
offsets[pk] = i
}
t.result.Results.FieldsData = make([]*schemapb.FieldData, len(queryResult.GetFieldsData()))
for i := 0; i < typeutil.GetSizeOfIDs(ids); i++ {
id := typeutil.GetPK(ids, int64(i))
if _, ok := offsets[id]; !ok {
return fmt.Errorf("incomplete query result, missing id %s, len(searchIDs) = %d, len(queryIDs) = %d, collection=%d",
id, typeutil.GetSizeOfIDs(ids), len(offsets), t.GetCollectionID())
}
typeutil.AppendFieldData(t.result.Results.FieldsData, queryResult.GetFieldsData(), int64(offsets[id]))
}
// filter id field out if it is not specified as output
t.result.Results.FieldsData = lo.Filter(t.result.Results.FieldsData, func(fieldData *schemapb.FieldData, i int) bool {
return lo.Contains(t.request.GetOutputFields(), fieldData.GetFieldName())
})
return nil
return doRequery(t.ctx, t.GetCollectionID(), t.node, t.schema.CollectionSchema, queryReq, t.result)
}
func (t *searchTask) fillInEmptyResult(numQueries int64) {
@ -734,6 +666,86 @@ func (t *searchTask) collectSearchResults(ctx context.Context) ([]*internalpb.Se
}
}
func doRequery(ctx context.Context,
collectionID int64,
node types.ProxyComponent,
schema *schemapb.CollectionSchema,
request *milvuspb.QueryRequest,
result *milvuspb.SearchResults,
) error {
outputFields := request.GetOutputFields()
pkField, err := typeutil.GetPrimaryFieldSchema(schema)
if err != nil {
return err
}
ids := result.GetResults().GetIds()
plan := planparserv2.CreateRequeryPlan(pkField, ids)
qt := &queryTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
RetrieveRequest: &internalpb.RetrieveRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
ReqID: paramtable.GetNodeID(),
},
request: request,
plan: plan,
qc: node.(*Proxy).queryCoord,
lb: node.(*Proxy).lbPolicy,
}
queryResult, err := node.(*Proxy).query(ctx, qt)
if err != nil {
return err
}
if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
return merr.Error(queryResult.GetStatus())
}
// Reorganize Results. The order of query result ids will be altered and differ from queried ids.
// We should reorganize query results to keep the order of original queried ids. For example:
// ===========================================
// 3 2 5 4 1 (query ids)
// ||
// || (query)
// \/
// 4 3 5 1 2 (result ids)
// v4 v3 v5 v1 v2 (result vectors)
// ||
// || (reorganize)
// \/
// 3 2 5 4 1 (result ids)
// v3 v2 v5 v4 v1 (result vectors)
// ===========================================
pkFieldData, err := typeutil.GetPrimaryFieldData(queryResult.GetFieldsData(), pkField)
if err != nil {
return err
}
offsets := make(map[any]int)
for i := 0; i < typeutil.GetPKSize(pkFieldData); i++ {
pk := typeutil.GetData(pkFieldData, i)
offsets[pk] = i
}
result.Results.FieldsData = make([]*schemapb.FieldData, len(queryResult.GetFieldsData()))
for i := 0; i < typeutil.GetSizeOfIDs(ids); i++ {
id := typeutil.GetPK(ids, int64(i))
if _, ok := offsets[id]; !ok {
return fmt.Errorf("incomplete query result, missing id %s, len(searchIDs) = %d, len(queryIDs) = %d, collection=%d",
id, typeutil.GetSizeOfIDs(ids), len(offsets), collectionID)
}
typeutil.AppendFieldData(result.Results.FieldsData, queryResult.GetFieldsData(), int64(offsets[id]))
}
// filter id field out if it is not specified as output
result.Results.FieldsData = lo.Filter(result.Results.FieldsData, func(fieldData *schemapb.FieldData, i int) bool {
return lo.Contains(outputFields, fieldData.GetFieldName())
})
return nil
}
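// Illustration only (not part of this patch): doRequery realigns the rows
// returned by the retrieve with the ranked ids of the original search via a
// pk -> row-offset map, so the output field data follows the search order.
// A minimal sketch of that reordering, assuming int64 primary keys and one
// value per row:
//
//	func reorder(searchIDs, queryIDs []int64, rows []string) ([]string, error) {
//		offsets := make(map[int64]int, len(queryIDs))
//		for i, pk := range queryIDs {
//			offsets[pk] = i
//		}
//		out := make([]string, 0, len(searchIDs))
//		for _, id := range searchIDs {
//			idx, ok := offsets[id]
//			if !ok {
//				return nil, fmt.Errorf("missing id %d in requery result", id)
//			}
//			out = append(out, rows[idx])
//		}
//		return out, nil
//	}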
func decodeSearchResults(ctx context.Context, searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) {
tr := timerecord.NewTimeRecorder("decodeSearchResults")
results := make([]*schemapb.SearchResultData, 0)


@ -71,6 +71,20 @@ const (
testMaxVarCharLength = 100
)
func genCollectionSchema(collectionName string) *schemapb.CollectionSchema {
return constructCollectionSchemaWithAllType(
testBoolField,
testInt32Field,
testInt64Field,
testFloatField,
testDoubleField,
testFloatVecField,
testBinaryVecField,
testFloat16VecField,
testVecDim,
collectionName)
}
func constructCollectionSchema(
int64Field, floatVecField string,
dim int,


@ -62,6 +62,8 @@ const (
defaultMaxArrayCapacity = 4096
defaultMaxSearchRequest = 1024
// DefaultArithmeticIndexType name of default index type for scalar field
DefaultArithmeticIndexType = "STL_SORT"
@ -69,6 +71,8 @@ const (
DefaultStringIndexType = "Trie"
InvertedIndexType = "INVERTED"
defaultRRFParamsValue = 60
)
var logger = log.L().WithOptions(zap.Fields(zap.String("role", typeutil.ProxyRole)))


@ -175,7 +175,6 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error {
loadMeta := packLoadMeta(
ex.meta.GetLoadType(task.CollectionID()),
"",
task.CollectionID(),
partitions...,
)
@ -370,14 +369,8 @@ func (ex *Executor) subscribeChannel(task *ChannelTask, step int) error {
log.Warn("fail to get index meta of collection")
return err
}
metricType, err := getMetricType(indexInfo, collectionInfo.GetSchema())
if err != nil {
log.Warn("failed to get metric type", zap.Error(err))
return err
}
loadMeta := packLoadMeta(
ex.meta.GetLoadType(task.CollectionID()),
metricType,
task.CollectionID(),
partitions...,
)


@ -21,8 +21,6 @@ import (
"fmt"
"time"
"github.com/samber/lo"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
@ -33,7 +31,6 @@ import (
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
@ -162,12 +159,11 @@ func packReleaseSegmentRequest(task *SegmentTask, action *SegmentAction) *queryp
}
}
func packLoadMeta(loadType querypb.LoadType, metricType string, collectionID int64, partitions ...int64) *querypb.LoadMetaInfo {
func packLoadMeta(loadType querypb.LoadType, collectionID int64, partitions ...int64) *querypb.LoadMetaInfo {
return &querypb.LoadMetaInfo{
LoadType: loadType,
CollectionID: collectionID,
PartitionIDs: partitions,
MetricType: metricType,
}
}
@ -241,22 +237,3 @@ func getShardLeader(replicaMgr *meta.ReplicaManager, distMgr *meta.DistributionM
}
return distMgr.GetShardLeader(replica, channel)
}
func getMetricType(indexInfos []*indexpb.IndexInfo, schema *schemapb.CollectionSchema) (string, error) {
vecField, err := typeutil.GetVectorFieldSchema(schema)
if err != nil {
return "", err
}
indexInfo, ok := lo.Find(indexInfos, func(info *indexpb.IndexInfo) bool {
return info.GetFieldID() == vecField.GetFieldID()
})
if !ok || indexInfo == nil {
err = fmt.Errorf("cannot find index info for %s field", vecField.GetName())
return "", err
}
metricType, err := funcutil.GetAttrByKeyFromRepeatedKV(common.MetricTypeKey, indexInfo.GetIndexParams())
if err != nil {
return "", err
}
return metricType, nil
}


@ -26,7 +26,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/pkg/common"
)
@ -35,57 +34,6 @@ type UtilsSuite struct {
suite.Suite
}
func (s *UtilsSuite) TestGetMetricType() {
collection := int64(1)
schema := &schemapb.CollectionSchema{
Name: "TestGetMetricType",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "vec", DataType: schemapb.DataType_FloatVector},
},
}
indexInfo := &indexpb.IndexInfo{
CollectionID: collection,
FieldID: 100,
IndexParams: []*commonpb.KeyValuePair{
{
Key: common.MetricTypeKey,
Value: "L2",
},
},
}
indexInfo2 := &indexpb.IndexInfo{
CollectionID: collection,
FieldID: 100,
}
s.Run("test normal", func() {
metricType, err := getMetricType([]*indexpb.IndexInfo{indexInfo}, schema)
s.NoError(err)
s.Equal("L2", metricType)
})
s.Run("test get vec field failed", func() {
_, err := getMetricType([]*indexpb.IndexInfo{indexInfo}, &schemapb.CollectionSchema{
Name: "TestGetMetricType",
})
s.Error(err)
})
s.Run("test field id mismatch", func() {
_, err := getMetricType([]*indexpb.IndexInfo{indexInfo}, &schemapb.CollectionSchema{
Name: "TestGetMetricType",
Fields: []*schemapb.FieldSchema{
{FieldID: -1, Name: "vec", DataType: schemapb.DataType_FloatVector},
},
})
s.Error(err)
})
s.Run("test no metric type", func() {
_, err := getMetricType([]*indexpb.IndexInfo{indexInfo2}, schema)
s.Error(err)
})
}
func (s *UtilsSuite) TestPackLoadSegmentRequest() {
ctx := context.Background()


@ -17,12 +17,10 @@
package querynodev2
import (
"fmt"
"math"
"math/rand"
"strconv"
"github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@ -60,45 +58,32 @@ const (
// ---------- unittest util functions ----------
// functions of messages and requests
func genBruteForceDSL(schema *schemapb.CollectionSchema, topK int64, roundDecimal int64) (string, error) {
var vecFieldName string
var metricType string
topKStr := strconv.FormatInt(topK, 10)
nProbStr := strconv.Itoa(defaultNProb)
roundDecimalStr := strconv.FormatInt(roundDecimal, 10)
var fieldID int64
for _, f := range schema.Fields {
if f.DataType == schemapb.DataType_FloatVector {
vecFieldName = f.Name
fieldID = f.FieldID
for _, p := range f.IndexParams {
if p.Key == metricTypeKey {
metricType = p.Value
}
}
}
func genSearchPlan(dataType schemapb.DataType, fieldID int64, metricType string) *planpb.PlanNode {
var vectorType planpb.VectorType
switch dataType {
case schemapb.DataType_FloatVector:
vectorType = planpb.VectorType_FloatVector
case schemapb.DataType_Float16Vector:
vectorType = planpb.VectorType_Float16Vector
case schemapb.DataType_BinaryVector:
vectorType = planpb.VectorType_BinaryVector
}
if vecFieldName == "" || metricType == "" {
err := errors.New("invalid vector field name or metric type")
return "", err
}
return `vector_anns: <
field_id: ` + fmt.Sprintf("%d", fieldID) + `
query_info: <
topk: ` + topKStr + `
round_decimal: ` + roundDecimalStr + `
metric_type: "` + metricType + `"
search_params: "{\"nprobe\": ` + nProbStr + `}"
>
placeholder_tag: "$0"
>`, nil
}
func genDSLByIndexType(schema *schemapb.CollectionSchema, indexType string) (string, error) {
if indexType == IndexFaissIDMap { // float vector
return genBruteForceDSL(schema, defaultTopK, defaultRoundDecimal)
return &planpb.PlanNode{
Node: &planpb.PlanNode_VectorAnns{
VectorAnns: &planpb.VectorANNS{
VectorType: vectorType,
FieldId: fieldID,
QueryInfo: &planpb.QueryInfo{
Topk: defaultTopK,
MetricType: metricType,
SearchParams: "{\"nprobe\":" + strconv.Itoa(defaultNProb) + "}",
RoundDecimal: defaultRoundDecimal,
},
PlaceholderTag: "$0",
},
},
}
return "", fmt.Errorf("Invalid indexType")
}
func genPlaceHolderGroup(nq int64) ([]byte, error) {


@ -83,7 +83,6 @@ func (m *collectionManager) PutOrRef(collectionID int64, schema *schemapb.Collec
}
collection := NewCollection(collectionID, schema, meta, loadMeta.GetLoadType())
collection.metricType.Store(loadMeta.GetMetricType())
collection.AddPartition(loadMeta.GetPartitionIDs()...)
collection.Ref(1)
m.collections[collectionID] = collection
@ -125,7 +124,7 @@ type Collection struct {
id int64
partitions *typeutil.ConcurrentSet[int64]
loadType querypb.LoadType
metricType atomic.String
metricType atomic.String // deprecated
schema atomic.Pointer[schemapb.CollectionSchema]
isGpuIndex bool
@ -175,14 +174,6 @@ func (c *Collection) GetLoadType() querypb.LoadType {
return c.loadType
}
func (c *Collection) SetMetricType(metricType string) {
c.metricType.Store(metricType)
}
func (c *Collection) GetMetricType() string {
return c.metricType.Load()
}
func (c *Collection) Ref(count uint32) uint32 {
refCount := c.refCount.Add(count)
log.Debug("collection ref increment",


@ -291,7 +291,56 @@ func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType) *s
return &schema
}
func GenTestIndexInfoList(collectionID int64, schema *schemapb.CollectionSchema) []*indexpb.IndexInfo {
res := make([]*indexpb.IndexInfo, 0)
vectorFieldSchemas := typeutil.GetVectorFieldSchemas(schema)
for _, field := range vectorFieldSchemas {
index := &indexpb.IndexInfo{
CollectionID: collectionID,
FieldID: field.GetFieldID(),
// For now, a field can only have one index,
// so fieldID and fieldName are reused as indexID and indexName to keep them unique.
IndexID: field.GetFieldID(),
IndexName: field.GetName(),
TypeParams: field.GetTypeParams(),
}
switch field.GetDataType() {
case schemapb.DataType_FloatVector, schemapb.DataType_Float16Vector:
{
index.IndexParams = []*commonpb.KeyValuePair{
{Key: common.MetricTypeKey, Value: metric.L2},
{Key: common.IndexTypeKey, Value: IndexFaissIVFFlat},
{Key: "nlist", Value: "128"},
}
}
case schemapb.DataType_BinaryVector:
{
index.IndexParams = []*commonpb.KeyValuePair{
{Key: common.MetricTypeKey, Value: metric.JACCARD},
{Key: common.IndexTypeKey, Value: IndexFaissBinIVFFlat},
{Key: "nlist", Value: "128"},
}
}
}
res = append(res, index)
}
return res
}
func GenTestIndexMeta(collectionID int64, schema *schemapb.CollectionSchema) *segcorepb.CollectionIndexMeta {
indexInfos := GenTestIndexInfoList(collectionID, schema)
fieldIndexMetas := make([]*segcorepb.FieldIndexMeta, 0)
for _, info := range indexInfos {
fieldIndexMetas = append(fieldIndexMetas, &segcorepb.FieldIndexMeta{
CollectionID: info.GetCollectionID(),
FieldID: info.GetFieldID(),
IndexName: info.GetIndexName(),
TypeParams: info.GetTypeParams(),
IndexParams: info.GetIndexParams(),
IsAutoIndex: info.GetIsAutoIndex(),
UserIndexParams: info.GetUserIndexParams(),
})
}
sizePerRecord, err := typeutil.EstimateSizePerRecord(schema)
maxIndexRecordPerSegment := int64(0)
if err != nil || sizePerRecord == 0 {
@ -302,37 +351,6 @@ func GenTestIndexMeta(collectionID int64, schema *schemapb.CollectionSchema) *se
maxIndexRecordPerSegment = int64(threshold * proportion / float64(sizePerRecord))
}
fieldIndexMetas := make([]*segcorepb.FieldIndexMeta, 0)
fieldIndexMetas = append(fieldIndexMetas, &segcorepb.FieldIndexMeta{
CollectionID: collectionID,
FieldID: simpleFloatVecField.id,
IndexName: "querynode-test",
TypeParams: []*commonpb.KeyValuePair{
{
Key: dimKey,
Value: strconv.Itoa(simpleFloatVecField.dim),
},
},
IndexParams: []*commonpb.KeyValuePair{
{
Key: metricTypeKey,
Value: simpleFloatVecField.metricType,
},
{
Key: common.IndexTypeKey,
Value: IndexFaissIVFFlat,
},
{
Key: "nlist",
Value: "128",
},
},
IsAutoIndex: false,
UserIndexParams: []*commonpb.KeyValuePair{
{},
},
})
indexMeta := segcorepb.CollectionIndexMeta{
MaxIndexRowCount: maxIndexRecordPerSegment,
IndexMetas: fieldIndexMetas,
@ -889,6 +907,80 @@ func SaveDeltaLog(collectionID int64,
return fieldBinlog, cm.MultiWrite(context.Background(), kvs)
}
func GenAndSaveIndexV2(collectionID, partitionID, segmentID, buildID int64,
fieldSchema *schemapb.FieldSchema,
indexInfo *indexpb.IndexInfo,
cm storage.ChunkManager,
msgLength int,
) (*querypb.FieldIndexInfo, error) {
typeParams := funcutil.KeyValuePair2Map(indexInfo.GetTypeParams())
indexParams := funcutil.KeyValuePair2Map(indexInfo.GetIndexParams())
index, err := indexcgowrapper.NewCgoIndex(fieldSchema.GetDataType(), typeParams, indexParams)
if err != nil {
return nil, err
}
defer index.Delete()
var dataset *indexcgowrapper.Dataset
switch fieldSchema.DataType {
case schemapb.DataType_BinaryVector:
dataset = indexcgowrapper.GenBinaryVecDataset(generateBinaryVectors(msgLength, defaultDim))
case schemapb.DataType_FloatVector:
dataset = indexcgowrapper.GenFloatVecDataset(generateFloatVectors(msgLength, defaultDim))
}
err = index.Build(dataset)
if err != nil {
return nil, err
}
// save index to minio
binarySet, err := index.Serialize()
if err != nil {
return nil, err
}
// serialize index params
indexCodec := storage.NewIndexFileBinlogCodec()
serializedIndexBlobs, err := indexCodec.Serialize(
buildID,
0,
collectionID,
partitionID,
segmentID,
fieldSchema.GetFieldID(),
indexParams,
indexInfo.GetIndexName(),
indexInfo.GetIndexID(),
binarySet,
)
if err != nil {
return nil, err
}
indexPaths := make([]string, 0)
for _, index := range serializedIndexBlobs {
indexPath := filepath.Join(cm.RootPath(), "index_files",
strconv.Itoa(int(segmentID)), index.Key)
indexPaths = append(indexPaths, indexPath)
err := cm.Write(context.Background(), indexPath, index.Value)
if err != nil {
return nil, err
}
}
_, cCurrentIndexVersion := getIndexEngineVersion()
return &querypb.FieldIndexInfo{
FieldID: fieldSchema.GetFieldID(),
EnableIndex: true,
IndexName: indexInfo.GetIndexName(),
IndexParams: indexInfo.GetIndexParams(),
IndexFilePaths: indexPaths,
CurrentIndexVersion: cCurrentIndexVersion,
}, nil
}
func GenAndSaveIndex(collectionID, partitionID, segmentID, fieldID int64, msgLength int, indexType, metricType string, cm storage.ChunkManager) (*querypb.FieldIndexInfo, error) {
typeParams, indexParams := genIndexParams(indexType, metricType)


@ -54,13 +54,7 @@ func createSearchPlanByExpr(ctx context.Context, col *Collection, expr []byte, m
return nil, err1
}
newPlan := &SearchPlan{cSearchPlan: cPlan}
if len(metricType) != 0 {
newPlan.setMetricType(metricType)
} else {
newPlan.setMetricType(col.GetMetricType())
}
return newPlan, nil
return &SearchPlan{cSearchPlan: cPlan}, nil
}
func (plan *SearchPlan) getTopK() int64 {


@ -205,7 +205,6 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm
log.Info("received watch channel request",
zap.Int64("version", req.GetVersion()),
zap.String("metricType", req.GetLoadMeta().GetMetricType()),
)
// check node healthy
@ -219,12 +218,6 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm
return merr.Status(err), nil
}
// check metric type
if req.GetLoadMeta().GetMetricType() == "" {
err := fmt.Errorf("empty metric type, collection = %d", req.GetCollectionID())
return merr.Status(err), nil
}
// check index
if len(req.GetIndexInfoList()) == 0 {
err := merr.WrapErrIndexNotFoundForCollection(req.GetSchema().GetName())
@ -253,8 +246,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm
node.manager.Collection.PutOrRef(req.GetCollectionID(), req.GetSchema(),
node.composeIndexMeta(req.GetIndexInfoList(), req.Schema), req.GetLoadMeta())
collection := node.manager.Collection.Get(req.GetCollectionID())
collection.SetMetricType(req.GetLoadMeta().GetMetricType())
delegator, err := delegator.NewShardDelegator(
ctx,
req.GetCollectionID(),
@ -769,20 +761,6 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
return resp, nil
}
// Check if the metric type specified in search params matches the metric type in the index info.
if !req.GetFromShardLeader() && req.GetReq().GetMetricType() != "" {
if req.GetReq().GetMetricType() != collection.GetMetricType() {
resp.Status = merr.Status(merr.WrapErrParameterInvalid(collection.GetMetricType(), req.GetReq().GetMetricType(),
fmt.Sprintf("collection:%d, metric type not match", collection.ID())))
return resp, nil
}
}
// Define the metric type when it has not been explicitly assigned by the user.
if !req.GetFromShardLeader() && req.GetReq().GetMetricType() == "" {
req.Req.MetricType = collection.GetMetricType()
}
toReduceResults := make([]*internalpb.SearchResults, len(req.GetDmlChannels()))
runningGp, runningCtx := errgroup.WithContext(ctx)
for i, ch := range req.GetDmlChannels() {


@ -39,7 +39,6 @@ import (
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
@ -52,7 +51,6 @@ import (
"github.com/milvus-io/milvus/pkg/util/etcd"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
@ -257,6 +255,7 @@ func (suite *ServiceSuite) TestWatchDmChannelsInt64() {
ctx := context.Background()
// data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
deltaLogs, err := segments.SaveDeltaLog(suite.collectionID,
suite.partitionIDs[0],
suite.flushedSegmentIDs[0],
@ -292,16 +291,14 @@ func (suite *ServiceSuite) TestWatchDmChannelsInt64() {
Level: datapb.SegmentLevel_L0,
},
},
Schema: segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64),
Schema: schema,
LoadMeta: &querypb.LoadMetaInfo{
LoadType: querypb.LoadType_LoadCollection,
CollectionID: suite.collectionID,
PartitionIDs: suite.partitionIDs,
MetricType: defaultMetricType,
},
IndexInfoList: []*indexpb.IndexInfo{
{},
},
IndexInfoList: segments.GenTestIndexInfoList(suite.collectionID, schema),
}
// mocks
@ -326,6 +323,7 @@ func (suite *ServiceSuite) TestWatchDmChannelsVarchar() {
ctx := context.Background()
// data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_VarChar)
req := &querypb.WatchDmChannelsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchDmChannels,
@ -344,16 +342,14 @@ func (suite *ServiceSuite) TestWatchDmChannelsVarchar() {
DroppedSegmentIds: suite.droppedSegmentIDs,
},
},
Schema: segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_VarChar),
Schema: schema,
LoadMeta: &querypb.LoadMetaInfo{
LoadType: querypb.LoadType_LoadCollection,
CollectionID: suite.collectionID,
PartitionIDs: suite.partitionIDs,
MetricType: defaultMetricType,
},
IndexInfoList: []*indexpb.IndexInfo{
{},
},
IndexInfoList: segments.GenTestIndexInfoList(suite.collectionID, schema),
}
// mocks
@ -378,6 +374,7 @@ func (suite *ServiceSuite) TestWatchDmChannels_Failed() {
ctx := context.Background()
// data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
req := &querypb.WatchDmChannelsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchDmChannels,
@ -396,13 +393,11 @@ func (suite *ServiceSuite) TestWatchDmChannels_Failed() {
DroppedSegmentIds: suite.droppedSegmentIDs,
},
},
Schema: segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64),
Schema: schema,
LoadMeta: &querypb.LoadMetaInfo{
MetricType: defaultMetricType,
},
IndexInfoList: []*indexpb.IndexInfo{
{},
},
IndexInfoList: segments.GenTestIndexInfoList(suite.collectionID, schema),
}
// test channel is unsubscribing
@ -439,14 +434,6 @@ func (suite *ServiceSuite) TestWatchDmChannels_Failed() {
status, err = suite.node.WatchDmChannels(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_NotReadyServe, status.GetErrorCode())
// empty metric type
req.LoadMeta.MetricType = ""
req.Base.TargetID = paramtable.GetNodeID()
suite.node.UpdateStateCode(commonpb.StateCode_Healthy)
status, err = suite.node.WatchDmChannels(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
}
func (suite *ServiceSuite) TestUnsubDmChannels_Normal() {
@ -502,22 +489,9 @@ func (suite *ServiceSuite) TestUnsubDmChannels_Failed() {
suite.Equal(commonpb.ErrorCode_NotReadyServe, status.GetErrorCode())
}
func (suite *ServiceSuite) genSegmentIndexInfos(loadInfo []*querypb.SegmentLoadInfo) []*indexpb.IndexInfo {
indexInfoList := make([]*indexpb.IndexInfo, 0)
seg0LoadInfo := loadInfo[0]
fieldIndexInfos := seg0LoadInfo.IndexInfos
for _, info := range fieldIndexInfos {
indexInfoList = append(indexInfoList, &indexpb.IndexInfo{
CollectionID: suite.collectionID,
FieldID: info.GetFieldID(),
IndexName: info.GetIndexName(),
IndexParams: info.GetIndexParams(),
})
}
return indexInfoList
}
func (suite *ServiceSuite) genSegmentLoadInfos(schema *schemapb.CollectionSchema) []*querypb.SegmentLoadInfo {
func (suite *ServiceSuite) genSegmentLoadInfos(schema *schemapb.CollectionSchema,
indexInfos []*indexpb.IndexInfo,
) []*querypb.SegmentLoadInfo {
ctx := context.Background()
segNum := len(suite.validSegmentIDs)
@ -534,18 +508,25 @@ func (suite *ServiceSuite) genSegmentLoadInfos(schema *schemapb.CollectionSchema
)
suite.Require().NoError(err)
vecFieldIDs := funcutil.GetVecFieldIDs(schema)
indexes, err := segments.GenAndSaveIndex(
suite.collectionID,
suite.partitionIDs[i%partNum],
suite.validSegmentIDs[i],
vecFieldIDs[0],
1000,
segments.IndexFaissIVFFlat,
metric.L2,
suite.node.chunkManager,
)
suite.Require().NoError(err)
vectorFieldSchemas := typeutil.GetVectorFieldSchemas(schema)
indexes := make([]*querypb.FieldIndexInfo, 0)
for offset, field := range vectorFieldSchemas {
indexInfo := lo.FindOrElse(indexInfos, nil, func(info *indexpb.IndexInfo) bool { return info.FieldID == field.GetFieldID() })
if indexInfo != nil {
index, err := segments.GenAndSaveIndexV2(
suite.collectionID,
suite.partitionIDs[i%partNum],
suite.validSegmentIDs[i],
int64(offset),
field,
indexInfo,
suite.node.chunkManager,
1000,
)
suite.Require().NoError(err)
indexes = append(indexes, index)
}
}
info := &querypb.SegmentLoadInfo{
SegmentID: suite.validSegmentIDs[i],
@ -555,7 +536,7 @@ func (suite *ServiceSuite) genSegmentLoadInfos(schema *schemapb.CollectionSchema
NumOfRows: 1000,
BinlogPaths: binlogs,
Statslogs: statslogs,
IndexInfos: []*querypb.FieldIndexInfo{indexes},
IndexInfos: indexes,
StartPosition: &msgpb.MsgPosition{Timestamp: 20000},
DeltaPosition: &msgpb.MsgPosition{Timestamp: 20000},
}
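For readers who have not met it, lo.FindOrElse (from github.com/samber/lo), used in the loop above, returns the first element satisfying the predicate or the supplied fallback; passing nil as the fallback is what lets the later genSegmentLoadInfos(schema, nil) calls skip index generation for every vector field. A tiny standalone example, with indexInfo standing in for indexpb.IndexInfo:

package main

import (
	"fmt"

	"github.com/samber/lo"
)

// indexInfo is a minimal stand-in used only for this illustration.
type indexInfo struct{ FieldID int64 }

func main() {
	infos := []*indexInfo{{FieldID: 101}, {FieldID: 102}}

	// Found: returns the matching element.
	hit := lo.FindOrElse(infos, nil, func(info *indexInfo) bool { return info.FieldID == 102 })
	fmt.Println(hit.FieldID) // 102

	// Not found: returns the fallback, here nil, so callers can simply skip the field.
	miss := lo.FindOrElse(infos, nil, func(info *indexInfo) bool { return info.FieldID == 999 })
	fmt.Println(miss == nil) // true
}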
@ -569,7 +550,8 @@ func (suite *ServiceSuite) TestLoadSegments_Int64() {
suite.TestWatchDmChannelsInt64()
// data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
infos := suite.genSegmentLoadInfos(schema)
indexInfos := segments.GenTestIndexInfoList(suite.collectionID, schema)
infos := suite.genSegmentLoadInfos(schema, indexInfos)
for _, info := range infos {
req := &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{
@ -582,9 +564,7 @@ func (suite *ServiceSuite) TestLoadSegments_Int64() {
Schema: schema,
DeltaPositions: []*msgpb.MsgPosition{{Timestamp: 20000}},
NeedTransfer: true,
IndexInfoList: []*indexpb.IndexInfo{
{},
},
IndexInfoList: indexInfos,
}
// LoadSegment
@ -607,7 +587,7 @@ func (suite *ServiceSuite) TestLoadSegments_VarChar() {
suite.node.manager.Collection = segments.NewCollectionManager()
suite.node.manager.Collection.PutOrRef(suite.collectionID, schema, nil, loadMeta)
infos := suite.genSegmentLoadInfos(schema)
infos := suite.genSegmentLoadInfos(schema, nil)
for _, info := range infos {
req := &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{
@ -643,7 +623,7 @@ func (suite *ServiceSuite) TestLoadDeltaInt64() {
},
CollectionID: suite.collectionID,
DstNodeID: suite.node.session.ServerID,
Infos: suite.genSegmentLoadInfos(schema),
Infos: suite.genSegmentLoadInfos(schema, nil),
Schema: schema,
NeedTransfer: true,
LoadScope: querypb.LoadScope_Delta,
@ -668,7 +648,7 @@ func (suite *ServiceSuite) TestLoadDeltaVarchar() {
},
CollectionID: suite.collectionID,
DstNodeID: suite.node.session.ServerID,
Infos: suite.genSegmentLoadInfos(schema),
Infos: suite.genSegmentLoadInfos(schema, nil),
Schema: schema,
NeedTransfer: true,
LoadScope: querypb.LoadScope_Delta,
@ -687,7 +667,8 @@ func (suite *ServiceSuite) TestLoadIndex_Success() {
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
infos := suite.genSegmentLoadInfos(schema)
indexInfos := segments.GenTestIndexInfoList(suite.collectionID, schema)
infos := suite.genSegmentLoadInfos(schema, indexInfos)
infos = lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *querypb.SegmentLoadInfo {
info.SegmentID = info.SegmentID + 1000
return info
@ -697,8 +678,7 @@ func (suite *ServiceSuite) TestLoadIndex_Success() {
info.IndexInfos = nil
return info
})
// generate indexinfos for setting index meta.
indexInfoList := suite.genSegmentIndexInfos(infos)
req := &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{
MsgID: rand.Int63(),
@ -710,7 +690,7 @@ func (suite *ServiceSuite) TestLoadIndex_Success() {
Schema: schema,
NeedTransfer: false,
LoadScope: querypb.LoadScope_Full,
IndexInfoList: indexInfoList,
IndexInfoList: indexInfos,
}
// Load segment
@ -759,7 +739,8 @@ func (suite *ServiceSuite) TestLoadIndex_Failed() {
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
suite.Run("load_non_exist_segment", func() {
infos := suite.genSegmentLoadInfos(schema)
indexInfos := segments.GenTestIndexInfoList(suite.collectionID, schema)
infos := suite.genSegmentLoadInfos(schema, indexInfos)
infos = lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *querypb.SegmentLoadInfo {
info.SegmentID = info.SegmentID + 1000
return info
@ -780,7 +761,7 @@ func (suite *ServiceSuite) TestLoadIndex_Failed() {
Schema: schema,
NeedTransfer: false,
LoadScope: querypb.LoadScope_Index,
IndexInfoList: []*indexpb.IndexInfo{{}},
IndexInfoList: indexInfos,
}
// Load segment
@ -801,7 +782,8 @@ func (suite *ServiceSuite) TestLoadIndex_Failed() {
mockLoader.EXPECT().LoadIndex(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("mocked error"))
infos := suite.genSegmentLoadInfos(schema)
indexInfos := segments.GenTestIndexInfoList(suite.collectionID, schema)
infos := suite.genSegmentLoadInfos(schema, indexInfos)
req := &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{
MsgID: rand.Int63(),
@ -813,7 +795,7 @@ func (suite *ServiceSuite) TestLoadIndex_Failed() {
Schema: schema,
NeedTransfer: false,
LoadScope: querypb.LoadScope_Index,
IndexInfoList: []*indexpb.IndexInfo{{}},
IndexInfoList: indexInfos,
}
// Load segment
@ -834,7 +816,7 @@ func (suite *ServiceSuite) TestLoadSegments_Failed() {
},
CollectionID: suite.collectionID,
DstNodeID: suite.node.session.ServerID,
Infos: suite.genSegmentLoadInfos(schema),
Infos: suite.genSegmentLoadInfos(schema, nil),
Schema: schema,
NeedTransfer: true,
IndexInfoList: []*indexpb.IndexInfo{
@ -886,7 +868,7 @@ func (suite *ServiceSuite) TestLoadSegments_Transfer() {
},
CollectionID: suite.collectionID,
DstNodeID: suite.node.session.ServerID,
Infos: suite.genSegmentLoadInfos(schema),
Infos: suite.genSegmentLoadInfos(schema, nil),
Schema: schema,
NeedTransfer: true,
IndexInfoList: []*indexpb.IndexInfo{{}},
@ -908,7 +890,7 @@ func (suite *ServiceSuite) TestLoadSegments_Transfer() {
},
CollectionID: suite.collectionID,
DstNodeID: suite.node.session.ServerID,
Infos: suite.genSegmentLoadInfos(schema),
Infos: suite.genSegmentLoadInfos(schema, nil),
Schema: schema,
NeedTransfer: true,
IndexInfoList: []*indexpb.IndexInfo{{}},
@ -935,7 +917,7 @@ func (suite *ServiceSuite) TestLoadSegments_Transfer() {
},
CollectionID: suite.collectionID,
DstNodeID: suite.node.session.ServerID,
Infos: suite.genSegmentLoadInfos(schema),
Infos: suite.genSegmentLoadInfos(schema, nil),
Schema: schema,
NeedTransfer: true,
IndexInfoList: []*indexpb.IndexInfo{{}},
@ -1139,18 +1121,14 @@ func (suite *ServiceSuite) TestGetSegmentInfo_Failed() {
}
// Test Search
func (suite *ServiceSuite) genCSearchRequest(nq int64, indexType string, schema *schemapb.CollectionSchema) (*internalpb.SearchRequest, error) {
func (suite *ServiceSuite) genCSearchRequest(nq int64, dataType schemapb.DataType, fieldID int64, metricType string) (*internalpb.SearchRequest, error) {
placeHolder, err := genPlaceHolderGroup(nq)
if err != nil {
return nil, err
}
planStr, err := genDSLByIndexType(schema, indexType)
if err != nil {
return nil, err
}
var planpb planpb.PlanNode
proto.UnmarshalText(planStr, &planpb)
serializedPlan, err2 := proto.Marshal(&planpb)
plan := genSearchPlan(dataType, fieldID, metricType)
serializedPlan, err2 := proto.Marshal(plan)
if err2 != nil {
return nil, err2
}
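genCSearchRequest pairs the serialized plan with a placeholder group holding the nq query vectors. The sketch below shows roughly what such a placeholder group looks like on the wire; the helper name and the random vectors are assumptions for illustration, and only the commonpb message shapes are taken as given:

package main

import (
	"encoding/binary"
	"fmt"
	"math"
	"math/rand"

	"github.com/golang/protobuf/proto"
	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
)

// buildFloatPlaceholderGroup serializes nq random float32 vectors of the given
// dim into a PlaceholderGroup tagged "$0", the tag a search plan's vector
// placeholder conventionally binds to.
func buildFloatPlaceholderGroup(nq, dim int) ([]byte, error) {
	values := make([][]byte, 0, nq)
	for i := 0; i < nq; i++ {
		vec := make([]byte, 0, dim*4)
		for j := 0; j < dim; j++ {
			b := make([]byte, 4)
			binary.LittleEndian.PutUint32(b, math.Float32bits(rand.Float32()))
			vec = append(vec, b...)
		}
		values = append(values, vec)
	}
	group := &commonpb.PlaceholderGroup{
		Placeholders: []*commonpb.PlaceholderValue{{
			Tag:    "$0",
			Type:   commonpb.PlaceholderType_FloatVector,
			Values: values,
		}},
	}
	return proto.Marshal(group)
}

func main() {
	blob, err := buildFloatPlaceholderGroup(10, 128)
	fmt.Println(len(blob) > 0, err) // true <nil>
}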
@ -1175,9 +1153,7 @@ func (suite *ServiceSuite) TestSearch_Normal() {
suite.TestWatchDmChannelsInt64()
suite.TestLoadSegments_Int64()
// data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
creq, err := suite.genCSearchRequest(10, IndexFaissIDMap, schema)
creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType)
req := &querypb.SearchRequest{
Req: creq,
FromShardLeader: false,
@ -1197,14 +1173,11 @@ func (suite *ServiceSuite) TestSearch_Concurrent() {
suite.TestWatchDmChannelsInt64()
suite.TestLoadSegments_Int64()
// data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
concurrency := 16
futures := make([]*conc.Future[*internalpb.SearchResults], 0, concurrency)
for i := 0; i < concurrency; i++ {
future := conc.Go(func() (*internalpb.SearchResults, error) {
creq, err := suite.genCSearchRequest(30, IndexFaissIDMap, schema)
creq, err := suite.genCSearchRequest(30, schemapb.DataType_FloatVector, 107, defaultMetricType)
req := &querypb.SearchRequest{
Req: creq,
FromShardLeader: false,
@ -1230,7 +1203,7 @@ func (suite *ServiceSuite) TestSearch_Failed() {
// data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
creq, err := suite.genCSearchRequest(10, IndexFaissIDMap, schema)
creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, "invalidMetricType")
req := &querypb.SearchRequest{
Req: creq,
FromShardLeader: false,
@ -1250,15 +1223,9 @@ func (suite *ServiceSuite) TestSearch_Failed() {
LoadType: querypb.LoadType_LoadCollection,
CollectionID: suite.collectionID,
PartitionIDs: suite.partitionIDs,
MetricType: "L2",
}
suite.node.manager.Collection.PutOrRef(suite.collectionID, schema, nil, LoadMeta)
req.GetReq().MetricType = "IP"
resp, err = suite.node.Search(ctx, req)
suite.NoError(err)
suite.ErrorIs(merr.Error(resp.GetStatus()), merr.ErrParameterInvalid)
suite.Contains(resp.GetStatus().GetReason(), merr.ErrParameterInvalid.Error())
req.GetReq().MetricType = "L2"
indexMeta := suite.node.composeIndexMeta(segments.GenTestIndexInfoList(suite.collectionID, schema), schema)
suite.node.manager.Collection.PutOrRef(suite.collectionID, schema, indexMeta, LoadMeta)
// Delegator not found
resp, err = suite.node.Search(ctx, req)
@ -1268,6 +1235,34 @@ func (suite *ServiceSuite) TestSearch_Failed() {
suite.TestWatchDmChannelsInt64()
suite.TestLoadSegments_Int64()
// sync segment data
syncReq := &querypb.SyncDistributionRequest{
Base: &commonpb.MsgBase{
MsgID: rand.Int63(),
TargetID: suite.node.session.ServerID,
},
CollectionID: suite.collectionID,
Channel: suite.vchannel,
}
syncVersionAction := &querypb.SyncAction{
Type: querypb.SyncType_UpdateVersion,
SealedInTarget: []int64{1, 2, 3, 4},
TargetVersion: time.Now().UnixMilli(),
}
syncReq.Actions = []*querypb.SyncAction{syncVersionAction}
status, err := suite.node.SyncDistribution(ctx, syncReq)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_Success, status.ErrorCode)
// metric type not match
req.GetReq().MetricType = "IP"
resp, err = suite.node.Search(ctx, req)
suite.NoError(err)
suite.Contains(resp.GetStatus().GetReason(), "metric type not match")
req.GetReq().MetricType = "L2"
// target not match
req.Req.Base.TargetID = -1
resp, err = suite.node.Search(ctx, req)
@ -1333,9 +1328,7 @@ func (suite *ServiceSuite) TestSearchSegments_Normal() {
suite.TestWatchDmChannelsInt64()
suite.TestLoadSegments_Int64()
// data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
creq, err := suite.genCSearchRequest(10, IndexFaissIDMap, schema)
creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType)
req := &querypb.SearchRequest{
Req: creq,
FromShardLeader: true,

View File

@ -31,6 +31,8 @@ const (
FailLabel = "fail"
TotalLabel = "total"
HybridSearchLabel = "hybrid_search"
InsertLabel = "insert"
DeleteLabel = "delete"
UpsertLabel = "upsert"

View File

@ -239,9 +239,9 @@ func TestSchema_GetVectorFieldSchema(t *testing.T) {
}
t.Run("GetVectorFieldSchema", func(t *testing.T) {
fieldSchema, err := GetVectorFieldSchema(schemaNormal)
assert.Equal(t, "field_float_vector", fieldSchema.Name)
assert.NoError(t, err)
fieldSchema := GetVectorFieldSchemas(schemaNormal)
assert.Equal(t, 1, len(fieldSchema))
assert.Equal(t, "field_float_vector", fieldSchema[0].Name)
})
schemaInvalid := &schemapb.CollectionSchema{
@ -260,8 +260,8 @@ func TestSchema_GetVectorFieldSchema(t *testing.T) {
}
t.Run("GetVectorFieldSchemaInvalid", func(t *testing.T) {
_, err := GetVectorFieldSchema(schemaInvalid)
assert.Error(t, err)
res := GetVectorFieldSchemas(schemaInvalid)
assert.Equal(t, 0, len(res))
})
}
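The rewritten test reflects the move from a single GetVectorFieldSchema, which returned an error when nothing matched, to GetVectorFieldSchemas, which returns every vector field and an empty slice otherwise. A plausible shape for such a helper using the public schemapb types; the switch below is an assumption about the implementation, not a copy of it:

package main

import (
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)

// getVectorFieldSchemas collects every vector-typed field; an empty result is
// not an error, matching the assertions in the test above.
func getVectorFieldSchemas(schema *schemapb.CollectionSchema) []*schemapb.FieldSchema {
	res := make([]*schemapb.FieldSchema, 0)
	for _, field := range schema.GetFields() {
		switch field.GetDataType() {
		case schemapb.DataType_FloatVector, schemapb.DataType_BinaryVector, schemapb.DataType_Float16Vector:
			res = append(res, field)
		}
	}
	return res
}

func main() {
	schema := &schemapb.CollectionSchema{Fields: []*schemapb.FieldSchema{
		{Name: "pk", DataType: schemapb.DataType_Int64},
		{Name: "field_float_vector", DataType: schemapb.DataType_FloatVector},
	}}
	fmt.Println(len(getVectorFieldSchemas(schema))) // 1
}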

View File

@ -0,0 +1,225 @@
package hybridsearch
import (
"context"
"encoding/json"
"fmt"
"strconv"
"testing"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/suite"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proxy"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/tests/integration"
)
type HybridSearchSuite struct {
integration.MiniClusterSuite
}
func (s *HybridSearchSuite) TestHybridSearch() {
c := s.Cluster
ctx, cancel := context.WithCancel(c.GetContext())
defer cancel()
prefix := "TestHybridSearch"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
dim := 128
rowNum := 3000
schema := integration.ConstructSchema(collectionName, dim, true,
&schemapb.FieldSchema{Name: integration.Int64Field, DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true},
&schemapb.FieldSchema{Name: integration.FloatVecField, DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "128"}}},
&schemapb.FieldSchema{Name: integration.BinVecField, DataType: schemapb.DataType_BinaryVector, TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "128"}}},
)
marshaledSchema, err := proto.Marshal(schema)
s.NoError(err)
createCollectionStatus, err := c.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: common.DefaultShardsNum,
})
s.NoError(err)
err = merr.Error(createCollectionStatus)
if err != nil {
log.Warn("createCollectionStatus fail reason", zap.Error(err))
}
log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus))
showCollectionsResp, err := c.Proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{})
s.NoError(err)
s.True(merr.Ok(showCollectionsResp.GetStatus()))
log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp))
fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim)
bVecColumn := integration.NewBinaryVectorFieldData(integration.BinVecField, rowNum, dim)
hashKeys := integration.GenerateHashKeys(rowNum)
insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
DbName: dbName,
CollectionName: collectionName,
FieldsData: []*schemapb.FieldData{fVecColumn, bVecColumn},
HashKeys: hashKeys,
NumRows: uint32(rowNum),
})
s.NoError(err)
s.True(merr.Ok(insertResult.GetStatus()))
// flush
flushResp, err := c.Proxy.Flush(ctx, &milvuspb.FlushRequest{
DbName: dbName,
CollectionNames: []string{collectionName},
})
s.NoError(err)
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
ids := segmentIDs.GetData()
s.Require().NotEmpty(segmentIDs)
s.Require().True(has)
flushTs, has := flushResp.GetCollFlushTs()[collectionName]
s.True(has)
segments, err := c.MetaWatcher.ShowSegments()
s.NoError(err)
s.NotEmpty(segments)
for _, segment := range segments {
log.Info("ShowSegments result", zap.String("segment", segment.String()))
}
s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName)
// load without index on vector fields
loadStatus, err := c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
})
s.NoError(err)
s.Error(merr.Error(loadStatus))
// create index for float vector
createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
CollectionName: collectionName,
FieldName: integration.FloatVecField,
IndexName: "_default_float",
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.L2),
})
s.NoError(err)
err = merr.Error(createIndexStatus)
if err != nil {
log.Warn("createIndexStatus fail reason", zap.Error(err))
}
s.WaitForIndexBuilt(ctx, collectionName, integration.FloatVecField)
// load with index on partial vector fields
loadStatus, err = c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
})
s.NoError(err)
s.Error(merr.Error(loadStatus))
// create index for binary vector
createIndexStatus, err = c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
CollectionName: collectionName,
FieldName: integration.BinVecField,
IndexName: "_default_binary",
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissBinIvfFlat, metric.JACCARD),
})
s.NoError(err)
err = merr.Error(createIndexStatus)
if err != nil {
log.Warn("createIndexStatus fail reason", zap.Error(err))
}
s.WaitForIndexBuiltWithIndexName(ctx, collectionName, integration.BinVecField, "_default_binary")
// load with index on all vector fields
loadStatus, err = c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
})
s.NoError(err)
err = merr.Error(loadStatus)
if err != nil {
log.Warn("LoadCollection fail reason", zap.Error(err))
}
s.WaitForLoad(ctx, collectionName)
// search
expr := fmt.Sprintf("%s > 0", integration.Int64Field)
nq := 1
topk := 10
roundDecimal := -1
fParams := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.L2)
	bParams := integration.GetSearchParams(integration.IndexFaissBinIvfFlat, metric.JACCARD)
fSearchReq := integration.ConstructSearchRequest("", collectionName, expr,
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.L2, fParams, nq, dim, topk, roundDecimal)
bSearchReq := integration.ConstructSearchRequest("", collectionName, expr,
integration.BinVecField, schemapb.DataType_BinaryVector, nil, metric.JACCARD, bParams, nq, dim, topk, roundDecimal)
hSearchReq := &milvuspb.HybridSearchRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
PartitionNames: nil,
Requests: []*milvuspb.SearchRequest{fSearchReq, bSearchReq},
OutputFields: []string{integration.FloatVecField, integration.BinVecField},
}
// rrf rank hybrid search
rrfParams := make(map[string]float64)
rrfParams[proxy.RRFParamsKey] = 60
b, err := json.Marshal(rrfParams)
s.NoError(err)
hSearchReq.RankParams = []*commonpb.KeyValuePair{
{Key: proxy.RankTypeKey, Value: "rrf"},
{Key: proxy.RankParamsKey, Value: string(b)},
{Key: proxy.LimitKey, Value: strconv.Itoa(topk)},
{Key: proxy.RoundDecimalKey, Value: strconv.Itoa(roundDecimal)},
}
searchResult, err := c.Proxy.HybridSearch(ctx, hSearchReq)
if searchResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("searchResult fail reason", zap.String("reason", searchResult.GetStatus().GetReason()))
}
s.NoError(err)
s.Equal(commonpb.ErrorCode_Success, searchResult.GetStatus().GetErrorCode())
// weighted rank hybrid search
weightsParams := make(map[string][]float64)
weightsParams[proxy.WeightsParamsKey] = []float64{0.5, 0.2}
b, err = json.Marshal(weightsParams)
s.NoError(err)
hSearchReq.RankParams = []*commonpb.KeyValuePair{
{Key: proxy.RankTypeKey, Value: "weighted"},
{Key: proxy.RankParamsKey, Value: string(b)},
{Key: proxy.LimitKey, Value: strconv.Itoa(topk)},
}
searchResult, err = c.Proxy.HybridSearch(ctx, hSearchReq)
if searchResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("searchResult fail reason", zap.String("reason", searchResult.GetStatus().GetReason()))
}
s.NoError(err)
s.Equal(commonpb.ErrorCode_Success, searchResult.GetStatus().GetErrorCode())
log.Info("TestHybridSearch succeed")
}
func TestHybridSearch(t *testing.T) {
suite.Run(t, new(HybridSearchSuite))
}
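The two rankers exercised in this test combine the per-request result lists on the proxy side. RRF scores a document by summing 1/(k + rank) over the requests that returned it (the RRFParamsKey value of 60 above is the k), while the weighted ranker sums each request's similarity score multiplied by its weight (0.5 and 0.2 above). A compact standalone sketch of both fusion rules, assuming ranks start at 1 and a higher fused score is better; this is the textbook formulation, not the proxy's exact reduction code:

package main

import "fmt"

// rrfFuse implements Reciprocal Rank Fusion: for every ranked list a document
// appears in, it gains 1/(k + rank). k=60 corresponds to the rrf k rank param
// used in the test above.
func rrfFuse(k float64, rankedLists [][]int64) map[int64]float64 {
	scores := make(map[int64]float64)
	for _, list := range rankedLists {
		for i, id := range list {
			rank := float64(i + 1)
			scores[id] += 1.0 / (k + rank)
		}
	}
	return scores
}

// weightedFuse sums each request's raw similarity score scaled by its weight,
// matching the "weighted" ranker with weights [0.5, 0.2] in the test above.
func weightedFuse(weights []float64, scoreLists []map[int64]float64) map[int64]float64 {
	fused := make(map[int64]float64)
	for i, list := range scoreLists {
		for id, score := range list {
			fused[id] += weights[i] * score
		}
	}
	return fused
}

func main() {
	// Document 7 is ranked first by both searches, document 3 only by one.
	fmt.Println(rrfFuse(60, [][]int64{{7, 3, 5}, {7, 9}}))
	fmt.Println(weightedFuse(
		[]float64{0.5, 0.2},
		[]map[int64]float64{{7: 0.9, 3: 0.4}, {7: 0.8}},
	))
}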

View File

@ -44,19 +44,24 @@ const (
)
func (s *MiniClusterSuite) WaitForIndexBuiltWithDB(ctx context.Context, dbName, collection, field string) {
s.waitForIndexBuiltInternal(ctx, dbName, collection, field)
s.waitForIndexBuiltInternal(ctx, dbName, collection, field, "")
}
func (s *MiniClusterSuite) WaitForIndexBuilt(ctx context.Context, collection, field string) {
s.waitForIndexBuiltInternal(ctx, "", collection, field)
s.waitForIndexBuiltInternal(ctx, "", collection, field, "")
}
func (s *MiniClusterSuite) waitForIndexBuiltInternal(ctx context.Context, dbName, collection, field string) {
func (s *MiniClusterSuite) WaitForIndexBuiltWithIndexName(ctx context.Context, collection, field, indexName string) {
s.waitForIndexBuiltInternal(ctx, "", collection, field, indexName)
}
func (s *MiniClusterSuite) waitForIndexBuiltInternal(ctx context.Context, dbName, collection, field, indexName string) {
getIndexBuilt := func() bool {
resp, err := s.Cluster.Proxy.DescribeIndex(ctx, &milvuspb.DescribeIndexRequest{
DbName: dbName,
CollectionName: collection,
FieldName: field,
IndexName: indexName,
})
if err != nil {
s.FailNow("failed to describe index")