fix: Support mvcc with hybrid serach (#30114)

issue: https://github.com/milvus-io/milvus/issues/29656
/kind bug

Signed-off-by: xige-16 <xi.ge@zilliz.com>

---------

Signed-off-by: xige-16 <xi.ge@zilliz.com>
This commit is contained in:
xige-16 2024-02-01 16:03:03 +08:00 committed by GitHub
parent 32914a3ddf
commit 060c8603a3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 1264 additions and 234 deletions

View File

@ -329,3 +329,10 @@ func (c *Client) Delete(ctx context.Context, req *querypb.DeleteRequest, _ ...gr
return client.Delete(ctx, req)
})
}
// HybridSearch performs replica hybrid search tasks in QueryNode.
func (c *Client) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest, _ ...grpc.CallOption) (*querypb.HybridSearchResult, error) {
return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*querypb.HybridSearchResult, error) {
return client.HybridSearch(ctx, req)
})
}

View File

@ -374,3 +374,8 @@ func (s *Server) SyncDistribution(ctx context.Context, req *querypb.SyncDistribu
func (s *Server) Delete(ctx context.Context, req *querypb.DeleteRequest) (*commonpb.Status, error) {
return s.querynode.Delete(ctx, req)
}
// HybridSearch performs hybrid search of streaming/historical replica on QueryNode.
func (s *Server) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
return s.querynode.HybridSearch(ctx, req)
}

View File

@ -511,6 +511,61 @@ func (_c *MockQueryNode_GetTimeTickChannel_Call) RunAndReturn(run func(context.C
return _c
}
// HybridSearch provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNode) HybridSearch(_a0 context.Context, _a1 *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
ret := _m.Called(_a0, _a1)
var r0 *querypb.HybridSearchResult
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) *querypb.HybridSearchResult); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*querypb.HybridSearchResult)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *querypb.HybridSearchRequest) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryNode_HybridSearch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HybridSearch'
type MockQueryNode_HybridSearch_Call struct {
*mock.Call
}
// HybridSearch is a helper method to define mock.On call
// - _a0 context.Context
// - _a1 *querypb.HybridSearchRequest
func (_e *MockQueryNode_Expecter) HybridSearch(_a0 interface{}, _a1 interface{}) *MockQueryNode_HybridSearch_Call {
return &MockQueryNode_HybridSearch_Call{Call: _e.mock.On("HybridSearch", _a0, _a1)}
}
func (_c *MockQueryNode_HybridSearch_Call) Run(run func(_a0 context.Context, _a1 *querypb.HybridSearchRequest)) *MockQueryNode_HybridSearch_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.HybridSearchRequest))
})
return _c
}
func (_c *MockQueryNode_HybridSearch_Call) Return(_a0 *querypb.HybridSearchResult, _a1 error) *MockQueryNode_HybridSearch_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockQueryNode_HybridSearch_Call) RunAndReturn(run func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)) *MockQueryNode_HybridSearch_Call {
_c.Call.Return(run)
return _c
}
// Init provides a mock function with given fields:
func (_m *MockQueryNode) Init() error {
ret := _m.Called()

View File

@ -632,6 +632,76 @@ func (_c *MockQueryNodeClient_GetTimeTickChannel_Call) RunAndReturn(run func(con
return _c
}
// HybridSearch provides a mock function with given fields: ctx, in, opts
func (_m *MockQueryNodeClient) HybridSearch(ctx context.Context, in *querypb.HybridSearchRequest, opts ...grpc.CallOption) (*querypb.HybridSearchResult, error) {
_va := make([]interface{}, len(opts))
for _i := range opts {
_va[_i] = opts[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, in)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 *querypb.HybridSearchResult
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest, ...grpc.CallOption) (*querypb.HybridSearchResult, error)); ok {
return rf(ctx, in, opts...)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest, ...grpc.CallOption) *querypb.HybridSearchResult); ok {
r0 = rf(ctx, in, opts...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*querypb.HybridSearchResult)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *querypb.HybridSearchRequest, ...grpc.CallOption) error); ok {
r1 = rf(ctx, in, opts...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryNodeClient_HybridSearch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HybridSearch'
type MockQueryNodeClient_HybridSearch_Call struct {
*mock.Call
}
// HybridSearch is a helper method to define mock.On call
// - ctx context.Context
// - in *querypb.HybridSearchRequest
// - opts ...grpc.CallOption
func (_e *MockQueryNodeClient_Expecter) HybridSearch(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_HybridSearch_Call {
return &MockQueryNodeClient_HybridSearch_Call{Call: _e.mock.On("HybridSearch",
append([]interface{}{ctx, in}, opts...)...)}
}
func (_c *MockQueryNodeClient_HybridSearch_Call) Run(run func(ctx context.Context, in *querypb.HybridSearchRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_HybridSearch_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]grpc.CallOption, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(grpc.CallOption)
}
}
run(args[0].(context.Context), args[1].(*querypb.HybridSearchRequest), variadicArgs...)
})
return _c
}
func (_c *MockQueryNodeClient_HybridSearch_Call) Return(_a0 *querypb.HybridSearchResult, _a1 error) *MockQueryNodeClient_HybridSearch_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockQueryNodeClient_HybridSearch_Call) RunAndReturn(run func(context.Context, *querypb.HybridSearchRequest, ...grpc.CallOption) (*querypb.HybridSearchResult, error)) *MockQueryNodeClient_HybridSearch_Call {
_c.Call.Return(run)
return _c
}
// LoadPartitions provides a mock function with given fields: ctx, in, opts
func (_m *MockQueryNodeClient) LoadPartitions(ctx context.Context, in *querypb.LoadPartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
_va := make([]interface{}, len(opts))

View File

@ -104,6 +104,18 @@ message SearchRequest {
string username = 18;
}
message HybridSearchRequest {
common.MsgBase base = 1;
int64 reqID = 2;
int64 dbID = 3;
int64 collectionID = 4;
repeated int64 partitionIDs = 5;
repeated SearchRequest reqs = 6;
uint64 mvcc_timestamp = 11;
uint64 guarantee_timestamp = 12;
uint64 timeout_timestamp = 13;
}
message SearchResults {
common.MsgBase base = 1;
common.Status status = 2;

View File

@ -71,6 +71,7 @@ service QueryNode {
rpc GetStatistics(GetStatisticsRequest) returns (internal.GetStatisticsResponse) {}
rpc Search(SearchRequest) returns (internal.SearchResults) {}
rpc HybridSearch(HybridSearchRequest) returns (HybridSearchResult) {}
rpc SearchSegments(SearchRequest) returns (internal.SearchResults) {}
rpc Query(QueryRequest) returns (internal.RetrieveResults) {}
rpc QueryStream(QueryRequest) returns (stream internal.RetrieveResults){}
@ -328,6 +329,20 @@ message SearchRequest {
int32 total_channel_num = 6;
}
message HybridSearchRequest {
internal.HybridSearchRequest req = 1;
repeated string dml_channels = 2;
int32 total_channel_num = 3;
}
message HybridSearchResult {
common.MsgBase base = 1;
common.Status status = 2;
repeated internal.SearchResults results = 3;
internal.CostAggregation costAggregation = 4;
map<string, uint64> channels_mvcc = 5;
}
message QueryRequest {
internal.RetrieveRequest req = 1;
repeated string dml_channels = 2;

View File

@ -2784,11 +2784,18 @@ func (node *Proxy) HybridSearch(ctx context.Context, request *milvuspb.HybridSea
qt := &hybridSearchTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
request: request,
tr: timerecord.NewTimeRecorder(method),
qc: node.queryCoord,
node: node,
lb: node.lbPolicy,
HybridSearchRequest: &internalpb.HybridSearchRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_Search),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
ReqID: paramtable.GetNodeID(),
},
request: request,
tr: timerecord.NewTimeRecorder(method),
qc: node.queryCoord,
node: node,
lb: node.lbPolicy,
}
guaranteeTs := request.GuaranteeTimestamp
@ -2831,7 +2838,7 @@ func (node *Proxy) HybridSearch(ctx context.Context, request *milvuspb.HybridSea
log.Debug(
rpcEnqueued(method),
zap.Uint64("timestamp", qt.request.Base.Timestamp),
zap.Uint64("timestamp", qt.Base.Timestamp),
)
if err := qt.WaitToFinish(); err != nil {

View File

@ -120,7 +120,7 @@ func NewReScorer(reqs []*milvuspb.SearchRequest, rankParams []*commonpb.KeyValue
return nil, errors.New("The type of rank param k should be float")
}
if k <= 0 || k >= maxRRFParamsValue {
return nil, errors.New("The rank params k should be in range (0, 16384)")
return nil, errors.New(fmt.Sprintf("The rank params k should be in range (0, %d)", maxRRFParamsValue))
}
log.Debug("rrf params", zap.Float64("k", k))
for i := range reqs {

View File

@ -0,0 +1,160 @@
package proxy
import (
"context"
"fmt"
"strconv"
"github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto"
"go.opentelemetry.io/otel"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
func initSearchRequest(ctx context.Context, t *searchTask) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "init search request")
defer sp.End()
log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName))
// fetch search_growing from search param
var ignoreGrowing bool
var err error
for i, kv := range t.request.GetSearchParams() {
if kv.GetKey() == IgnoreGrowingKey {
ignoreGrowing, err = strconv.ParseBool(kv.GetValue())
if err != nil {
return errors.New("parse search growing failed")
}
t.request.SearchParams = append(t.request.GetSearchParams()[:i], t.request.GetSearchParams()[i+1:]...)
break
}
}
t.SearchRequest.IgnoreGrowing = ignoreGrowing
// Manually update nq if not set.
nq, err := getNq(t.request)
if err != nil {
log.Warn("failed to get nq", zap.Error(err))
return err
}
// Check if nq is valid:
// https://milvus.io/docs/limitations.md
if err := validateNQLimit(nq); err != nil {
return fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err)
}
t.SearchRequest.Nq = nq
log = log.With(zap.Int64("nq", nq))
outputFieldIDs, err := getOutputFieldIDs(t.schema, t.request.GetOutputFields())
if err != nil {
log.Warn("fail to get output field ids", zap.Error(err))
return err
}
t.SearchRequest.OutputFieldsId = outputFieldIDs
partitionNames := t.request.GetPartitionNames()
if t.request.GetDslType() == commonpb.DslType_BoolExprV1 {
annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams())
if err != nil || len(annsField) == 0 {
vecFields := typeutil.GetVectorFieldSchemas(t.schema.CollectionSchema)
if len(vecFields) == 0 {
return errors.New(AnnsFieldKey + " not found in schema")
}
if enableMultipleVectorFields && len(vecFields) > 1 {
return errors.New("multiple anns_fields exist, please specify a anns_field in search_params")
}
annsField = vecFields[0].Name
}
queryInfo, offset, err := parseSearchInfo(t.request.GetSearchParams(), t.schema.CollectionSchema)
if err != nil {
return err
}
if queryInfo.GroupByFieldId != 0 {
t.SearchRequest.IgnoreGrowing = true
// for group by operation, currently, we ignore growing segments
}
t.offset = offset
plan, err := planparserv2.CreateSearchPlan(t.schema.CollectionSchema, t.request.Dsl, annsField, queryInfo)
if err != nil {
log.Warn("failed to create query plan", zap.Error(err),
zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
zap.String("anns field", annsField), zap.Any("query info", queryInfo))
return merr.WrapErrParameterInvalidMsg("failed to create query plan: %v", err)
}
log.Debug("create query plan",
zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
zap.String("anns field", annsField), zap.Any("query info", queryInfo))
if t.partitionKeyMode {
expr, err := ParseExprFromPlan(plan)
if err != nil {
log.Warn("failed to parse expr", zap.Error(err))
return err
}
partitionKeys := ParsePartitionKeys(expr)
hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), t.collectionName, partitionKeys)
if err != nil {
log.Warn("failed to assign partition keys", zap.Error(err))
return err
}
partitionNames = append(partitionNames, hashedPartitionNames...)
}
plan.OutputFieldIds = outputFieldIDs
t.SearchRequest.Topk = queryInfo.GetTopk()
t.SearchRequest.MetricType = queryInfo.GetMetricType()
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
estimateSize, err := t.estimateResultSize(nq, t.SearchRequest.Topk)
if err != nil {
log.Warn("failed to estimate result size", zap.Error(err))
return err
}
if estimateSize >= requeryThreshold {
t.requery = true
plan.OutputFieldIds = nil
}
t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan)
if err != nil {
return err
}
log.Debug("proxy init search request",
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
zap.Stringer("plan", plan)) // may be very large if large term passed.
}
// translate partition name to partition ids. Use regex-pattern to match partition name.
t.SearchRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), t.collectionName, partitionNames)
if err != nil {
log.Warn("failed to get partition ids", zap.Error(err))
return err
}
if deadline, ok := t.TraceCtx().Deadline(); ok {
t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
}
t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup
// Set username of this search request for feature like task scheduling.
if username, _ := GetCurUserFromContext(ctx); username != "" {
t.SearchRequest.Username = username
}
return nil
}

View File

@ -14,10 +14,11 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/conc"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
@ -32,9 +33,11 @@ const (
type hybridSearchTask struct {
Condition
ctx context.Context
*internalpb.HybridSearchRequest
result *milvuspb.SearchResults
request *milvuspb.HybridSearchRequest
result *milvuspb.SearchResults
request *milvuspb.HybridSearchRequest
searchTasks []*searchTask
tr *timerecord.TimeRecorder
schema *schemaInfo
@ -42,15 +45,14 @@ type hybridSearchTask struct {
userOutputFields []string
qc types.QueryCoordClient
node types.ProxyComponent
lb LBPolicy
queryChannelsTs map[string]Timestamp
collectionID UniqueID
qc types.QueryCoordClient
node types.ProxyComponent
lb LBPolicy
resultBuf *typeutil.ConcurrentSet[*querypb.HybridSearchResult]
multipleRecallResults *typeutil.ConcurrentSet[*milvuspb.SearchResults]
reScorers []reScorer
queryChannelsTs map[string]Timestamp
rankParams *rankParams
}
@ -63,7 +65,7 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error {
}
if len(t.request.Requests) > defaultMaxSearchRequest {
return errors.New("maximum of ann search requests is 1024")
return errors.New(fmt.Sprintf("maximum of ann search requests is %d", defaultMaxSearchRequest))
}
for _, req := range t.request.GetRequests() {
nq, err := getNq(req)
@ -78,12 +80,15 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error {
}
}
t.Base.MsgType = commonpb.MsgType_Search
t.Base.SourceID = paramtable.GetNodeID()
collectionName := t.request.CollectionName
collID, err := globalMetaCache.GetCollectionID(ctx, t.request.GetDbName(), collectionName)
if err != nil {
return err
}
t.collectionID = collID
t.CollectionID = collID
log := log.Ctx(ctx).With(zap.Int64("collID", collID), zap.String("collName", collectionName))
t.schema, err = globalMetaCache.GetCollectionSchema(ctx, t.request.GetDbName(), collectionName)
@ -113,6 +118,82 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error {
t.requery = true
}
collectionInfo, err2 := globalMetaCache.GetCollectionInfo(ctx, t.request.GetDbName(), collectionName, t.CollectionID)
if err2 != nil {
log.Warn("Proxy::hybridSearchTask::PreExecute failed to GetCollectionInfo from cache",
zap.String("collectionName", collectionName), zap.Int64("collectionID", t.CollectionID), zap.Error(err2))
return err2
}
guaranteeTs := t.request.GetGuaranteeTimestamp()
var consistencyLevel commonpb.ConsistencyLevel
useDefaultConsistency := t.request.GetUseDefaultConsistency()
if useDefaultConsistency {
consistencyLevel = collectionInfo.consistencyLevel
guaranteeTs = parseGuaranteeTsFromConsistency(guaranteeTs, t.BeginTs(), consistencyLevel)
} else {
consistencyLevel = t.request.GetConsistencyLevel()
// Compatibility logic, parse guarantee timestamp
if consistencyLevel == 0 && guaranteeTs > 0 {
guaranteeTs = parseGuaranteeTs(guaranteeTs, t.BeginTs())
} else {
// parse from guarantee timestamp and user input consistency level
guaranteeTs = parseGuaranteeTsFromConsistency(guaranteeTs, t.BeginTs(), consistencyLevel)
}
}
t.reScorers, err = NewReScorer(t.request.GetRequests(), t.request.GetRankParams())
if err != nil {
log.Info("generate reScorer failed", zap.Any("rank params", t.request.GetRankParams()), zap.Error(err))
return err
}
t.searchTasks = make([]*searchTask, len(t.request.GetRequests()))
for index := range t.request.Requests {
searchReq := t.request.Requests[index]
if len(searchReq.GetCollectionName()) == 0 {
searchReq.CollectionName = t.request.GetCollectionName()
} else if searchReq.GetCollectionName() != t.request.GetCollectionName() {
return errors.New(fmt.Sprintf("inconsistent collection name in hybrid search request, "+
"expect %s, actual %s", searchReq.GetCollectionName(), t.request.GetCollectionName()))
}
searchReq.PartitionNames = t.request.GetPartitionNames()
searchReq.ConsistencyLevel = consistencyLevel
searchReq.GuaranteeTimestamp = guaranteeTs
searchReq.UseDefaultConsistency = useDefaultConsistency
searchReq.OutputFields = nil
t.searchTasks[index] = &searchTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
collectionName: collectionName,
SearchRequest: &internalpb.SearchRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_Search),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
ReqID: paramtable.GetNodeID(),
DbID: 0, // todo
CollectionID: collID,
},
request: searchReq,
schema: t.schema,
tr: timerecord.NewTimeRecorder("hybrid search"),
qc: t.qc,
node: t.node,
lb: t.lb,
partitionKeyMode: partitionKeyMode,
resultBuf: typeutil.NewConcurrentSet[*internalpb.SearchResults](),
}
err := initSearchRequest(ctx, t.searchTasks[index])
if err != nil {
log.Debug("init hybrid search request failed", zap.Error(err))
return err
}
}
log.Debug("hybrid search preExecute done.",
zap.Uint64("guarantee_ts", t.request.GetGuaranteeTimestamp()),
zap.Bool("use_default_consistency", t.request.GetUseDefaultConsistency()),
@ -121,56 +202,65 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error {
return nil
}
func (t *hybridSearchTask) hybridSearchShard(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channel string) error {
for _, searchTask := range t.searchTasks {
t.HybridSearchRequest.Reqs = append(t.HybridSearchRequest.Reqs, searchTask.SearchRequest)
}
hybridSearchReq := typeutil.Clone(t.HybridSearchRequest)
hybridSearchReq.GetBase().TargetID = nodeID
req := &querypb.HybridSearchRequest{
Req: hybridSearchReq,
DmlChannels: []string{channel},
TotalChannelNum: int32(1),
}
log := log.Ctx(ctx).With(zap.Int64("collection", t.GetCollectionID()),
zap.Int64s("partitionIDs", t.GetPartitionIDs()),
zap.Int64("nodeID", nodeID),
zap.String("channel", channel))
var result *querypb.HybridSearchResult
var err error
result, err = qn.HybridSearch(ctx, req)
if err != nil {
log.Warn("QueryNode hybrid search return error", zap.Error(err))
return err
}
if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
log.Warn("QueryNode is not shardLeader")
return errInvalidShardLeaders
}
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("QueryNode hybrid search result error",
zap.String("reason", result.GetStatus().GetReason()))
return errors.Wrapf(merr.Error(result.GetStatus()), "fail to hybrid search on QueryNode %d", nodeID)
}
t.resultBuf.Insert(result)
t.lb.UpdateCostMetrics(nodeID, result.CostAggregation)
return nil
}
func (t *hybridSearchTask) Execute(ctx context.Context) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch-Execute")
defer sp.End()
log := log.Ctx(ctx).With(zap.Int64("collID", t.collectionID), zap.String("collName", t.request.GetCollectionName()))
log := log.Ctx(ctx).With(zap.Int64("collID", t.CollectionID), zap.String("collName", t.request.GetCollectionName()))
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute hybrid search %d", t.ID()))
defer tr.CtxElapse(ctx, "done")
futures := make([]*conc.Future[*milvuspb.SearchResults], len(t.request.Requests))
for index := range t.request.Requests {
searchReq := t.request.Requests[index]
future := conc.Go(func() (*milvuspb.SearchResults, error) {
searchReq.TravelTimestamp = t.request.GetTravelTimestamp()
searchReq.GuaranteeTimestamp = t.request.GetGuaranteeTimestamp()
searchReq.NotReturnAllMeta = t.request.GetNotReturnAllMeta()
searchReq.ConsistencyLevel = t.request.GetConsistencyLevel()
searchReq.UseDefaultConsistency = t.request.GetUseDefaultConsistency()
searchReq.OutputFields = nil
return t.node.Search(ctx, searchReq)
})
futures[index] = future
}
err := conc.AwaitAll(futures...)
t.resultBuf = typeutil.NewConcurrentSet[*querypb.HybridSearchResult]()
err := t.lb.Execute(ctx, CollectionWorkLoad{
db: t.request.GetDbName(),
collectionID: t.CollectionID,
collectionName: t.request.GetCollectionName(),
nq: 1,
exec: t.hybridSearchShard,
})
if err != nil {
return err
}
t.reScorers, err = NewReScorer(t.request.GetRequests(), t.request.GetRankParams())
if err != nil {
log.Info("generate reScorer failed", zap.Any("rank params", t.request.GetRankParams()), zap.Error(err))
return err
}
t.multipleRecallResults = typeutil.NewConcurrentSet[*milvuspb.SearchResults]()
for i, future := range futures {
err = future.Err()
if err != nil {
log.Debug("QueryNode search result error", zap.Error(err))
return err
}
result := futures[i].Value()
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Debug("QueryNode search result error",
zap.String("reason", result.GetStatus().GetReason()))
return merr.Error(result.GetStatus())
}
t.reScorers[i].reScore(result)
t.multipleRecallResults.Insert(result)
log.Warn("hybrid search execute failed", zap.Error(err))
return errors.Wrap(err, "failed to hybrid search")
}
log.Debug("hybrid search execute done.")
@ -194,7 +284,7 @@ func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, erro
limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, rankParamsPair)
if err != nil {
return nil, errors.New(LimitKey + " not found in search_params")
return nil, errors.New(LimitKey + " not found in rank_params")
}
limit, err = strconv.ParseInt(limitStr, 0, 64)
if err != nil {
@ -235,16 +325,59 @@ func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, erro
}, nil
}
func (t *hybridSearchTask) collectHybridSearchResults(ctx context.Context) error {
select {
case <-t.TraceCtx().Done():
log.Ctx(ctx).Warn("hybrid search task wait to finish timeout!")
return fmt.Errorf("hybrid search task wait to finish timeout, msgID=%d", t.ID())
default:
log.Ctx(ctx).Debug("all hybrid searches are finished or canceled")
t.resultBuf.Range(func(res *querypb.HybridSearchResult) bool {
for index, searchResult := range res.GetResults() {
t.searchTasks[index].resultBuf.Insert(searchResult)
}
log.Ctx(ctx).Debug("proxy receives one hybrid search result",
zap.Int64("sourceID", res.GetBase().GetSourceID()))
return true
})
t.multipleRecallResults = typeutil.NewConcurrentSet[*milvuspb.SearchResults]()
for i, searchTask := range t.searchTasks {
err := searchTask.PostExecute(ctx)
if err != nil {
return err
}
t.reScorers[i].reScore(searchTask.result)
t.multipleRecallResults.Insert(searchTask.result)
}
return nil
}
}
func (t *hybridSearchTask) PostExecute(ctx context.Context) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch-PostExecute")
defer sp.End()
log := log.Ctx(ctx).With(zap.Int64("collID", t.collectionID), zap.String("collName", t.request.GetCollectionName()))
log := log.Ctx(ctx).With(zap.Int64("collID", t.CollectionID), zap.String("collName", t.request.GetCollectionName()))
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy postExecute hybrid search %d", t.ID()))
defer func() {
tr.CtxElapse(ctx, "done")
}()
err := t.collectHybridSearchResults(ctx)
if err != nil {
log.Warn("failed to collect hybrid search results", zap.Error(err))
return err
}
t.queryChannelsTs = make(map[string]uint64)
for _, r := range t.resultBuf.Collect() {
for ch, ts := range r.GetChannelsMvcc() {
t.queryChannelsTs[ch] = ts
}
}
primaryFieldSchema, err := t.schema.GetPkField()
if err != nil {
log.Warn("failed to get primary field schema", zap.Error(err))
@ -304,9 +437,8 @@ func (t *hybridSearchTask) Requery() error {
},
}
// TODO:Xige-16 refine the mvcc functionality of hybrid search
// TODO:silverxia move partitionIDs to hybrid search level
return doRequery(t.ctx, t.collectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, []int64{})
return doRequery(t.ctx, t.CollectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, []int64{})
}
func rankSearchResultData(ctx context.Context,
@ -436,11 +568,11 @@ func (t *hybridSearchTask) TraceCtx() context.Context {
}
func (t *hybridSearchTask) ID() UniqueID {
return t.request.Base.MsgID
return t.Base.MsgID
}
func (t *hybridSearchTask) SetID(uid UniqueID) {
t.request.Base.MsgID = uid
t.Base.MsgID = uid
}
func (t *hybridSearchTask) Name() string {
@ -448,24 +580,24 @@ func (t *hybridSearchTask) Name() string {
}
func (t *hybridSearchTask) Type() commonpb.MsgType {
return t.request.Base.MsgType
return t.Base.MsgType
}
func (t *hybridSearchTask) BeginTs() Timestamp {
return t.request.Base.Timestamp
return t.Base.Timestamp
}
func (t *hybridSearchTask) EndTs() Timestamp {
return t.request.Base.Timestamp
return t.Base.Timestamp
}
func (t *hybridSearchTask) SetTs(ts Timestamp) {
t.request.Base.Timestamp = ts
t.Base.Timestamp = ts
}
func (t *hybridSearchTask) OnEnqueue() error {
t.request.Base = commonpbutil.NewMsgBase()
t.request.Base.MsgType = commonpb.MsgType_Search
t.request.Base.SourceID = paramtable.GetNodeID()
t.Base = commonpbutil.NewMsgBase()
t.Base.MsgType = commonpb.MsgType_Search
t.Base.SourceID = paramtable.GetNodeID()
return nil
}

View File

@ -20,6 +20,7 @@ import (
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
@ -67,8 +68,9 @@ func TestHybridSearchTask_PreExecute(t *testing.T) {
genHybridSearchTaskWithNq := func(t *testing.T, collName string, reqs []*milvuspb.SearchRequest) *hybridSearchTask {
task := &hybridSearchTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
ctx: ctx,
Condition: NewTaskCondition(ctx),
HybridSearchRequest: &internalpb.HybridSearchRequest{},
request: &milvuspb.HybridSearchRequest{
CollectionName: collName,
Requests: reqs,
@ -225,6 +227,7 @@ func TestHybridSearchTask_ErrExecute(t *testing.T) {
result: &milvuspb.SearchResults{
Status: merr.Success(),
},
HybridSearchRequest: &internalpb.HybridSearchRequest{},
request: &milvuspb.HybridSearchRequest{
CollectionName: collectionName,
Requests: []*milvuspb.SearchRequest{
@ -266,12 +269,12 @@ func TestHybridSearchTask_ErrExecute(t *testing.T) {
task.ctx = ctx
assert.NoError(t, task.PreExecute(ctx))
qn.EXPECT().Search(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))
qn.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))
assert.Error(t, task.Execute(ctx))
qn.ExpectedCalls = nil
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
qn.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{
qn.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(&querypb.HybridSearchResult{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
},
@ -291,6 +294,10 @@ func TestHybridSearchTask_PostExecute(t *testing.T) {
mgr := NewMockShardClientManager(t)
mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe()
qn.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(&querypb.HybridSearchResult{
Base: commonpbutil.NewMsgBase(),
Status: merr.Success(),
}, nil)
t.Run("Test empty result", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
@ -313,6 +320,9 @@ func TestHybridSearchTask_PostExecute(t *testing.T) {
qc: nil,
tr: timerecord.NewTimeRecorder("search"),
schema: schema,
HybridSearchRequest: &internalpb.HybridSearchRequest{
Base: commonpbutil.NewMsgBase(),
},
request: &milvuspb.HybridSearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
@ -320,6 +330,7 @@ func TestHybridSearchTask_PostExecute(t *testing.T) {
CollectionName: collectionName,
RankParams: rankParams,
},
resultBuf: typeutil.NewConcurrentSet[*querypb.HybridSearchResult](),
multipleRecallResults: typeutil.NewConcurrentSet[*milvuspb.SearchResults](),
}

View File

@ -30,7 +30,6 @@ import (
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
@ -53,10 +52,11 @@ type searchTask struct {
result *milvuspb.SearchResults
request *milvuspb.SearchRequest
tr *timerecord.TimeRecorder
collectionName string
schema *schemaInfo
requery bool
tr *timerecord.TimeRecorder
collectionName string
schema *schemaInfo
requery bool
partitionKeyMode bool
userOutputFields []string
@ -250,22 +250,21 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
return err
}
log := log.Ctx(ctx).With(zap.Int64("collID", collID), zap.String("collName", collectionName))
t.SearchRequest.DbID = 0 // todo
t.SearchRequest.CollectionID = collID
log := log.Ctx(ctx).With(zap.Int64("collID", collID), zap.String("collName", collectionName))
t.schema, err = globalMetaCache.GetCollectionSchema(ctx, t.request.GetDbName(), collectionName)
if err != nil {
log.Warn("get collection schema failed", zap.Error(err))
return err
}
partitionKeyMode, err := isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName)
t.partitionKeyMode, err = isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName)
if err != nil {
log.Warn("is partition key mode failed", zap.Error(err))
return err
}
if partitionKeyMode && len(t.request.GetPartitionNames()) != 0 {
if t.partitionKeyMode && len(t.request.GetPartitionNames()) != 0 {
return errors.New("not support manually specifying the partition names if partition key mode is used")
}
@ -277,123 +276,9 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
log.Debug("translate output fields",
zap.Strings("output fields", t.request.GetOutputFields()))
// fetch search_growing from search param
var ignoreGrowing bool
for i, kv := range t.request.GetSearchParams() {
if kv.GetKey() == IgnoreGrowingKey {
ignoreGrowing, err = strconv.ParseBool(kv.GetValue())
if err != nil {
return errors.New("parse search growing failed")
}
t.request.SearchParams = append(t.request.GetSearchParams()[:i], t.request.GetSearchParams()[i+1:]...)
break
}
}
t.SearchRequest.IgnoreGrowing = ignoreGrowing
// Manually update nq if not set.
nq, err := getNq(t.request)
err = initSearchRequest(ctx, t)
if err != nil {
log.Warn("failed to get nq", zap.Error(err))
return err
}
// Check if nq is valid:
// https://milvus.io/docs/limitations.md
if err := validateNQLimit(nq); err != nil {
return fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err)
}
t.SearchRequest.Nq = nq
log = log.With(zap.Int64("nq", nq))
outputFieldIDs, err := getOutputFieldIDs(t.schema, t.request.GetOutputFields())
if err != nil {
log.Warn("fail to get output field ids", zap.Error(err))
return err
}
t.SearchRequest.OutputFieldsId = outputFieldIDs
partitionNames := t.request.GetPartitionNames()
if t.request.GetDslType() == commonpb.DslType_BoolExprV1 {
annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams())
if err != nil || len(annsField) == 0 {
vecFields := typeutil.GetVectorFieldSchemas(t.schema.CollectionSchema)
if len(vecFields) == 0 {
return errors.New(AnnsFieldKey + " not found in schema")
}
if enableMultipleVectorFields && len(vecFields) > 1 {
return errors.New("multiple anns_fields exist, please specify a anns_field in search_params")
}
annsField = vecFields[0].Name
}
queryInfo, offset, err := parseSearchInfo(t.request.GetSearchParams(), t.schema.CollectionSchema)
if err != nil {
return err
}
if queryInfo.GroupByFieldId != 0 {
t.SearchRequest.IgnoreGrowing = true
// for group by operation, currently, we ignore growing segments
}
t.offset = offset
plan, err := planparserv2.CreateSearchPlan(t.schema.CollectionSchema, t.request.Dsl, annsField, queryInfo)
if err != nil {
log.Warn("failed to create query plan", zap.Error(err),
zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
zap.String("anns field", annsField), zap.Any("query info", queryInfo))
return merr.WrapErrParameterInvalidMsg("failed to create query plan: %v", err)
}
log.Debug("create query plan",
zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
zap.String("anns field", annsField), zap.Any("query info", queryInfo))
if partitionKeyMode {
expr, err := ParseExprFromPlan(plan)
if err != nil {
log.Warn("failed to parse expr", zap.Error(err))
return err
}
partitionKeys := ParsePartitionKeys(expr)
hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), collectionName, partitionKeys)
if err != nil {
log.Warn("failed to assign partition keys", zap.Error(err))
return err
}
partitionNames = append(partitionNames, hashedPartitionNames...)
}
plan.OutputFieldIds = outputFieldIDs
t.SearchRequest.Topk = queryInfo.GetTopk()
t.SearchRequest.MetricType = queryInfo.GetMetricType()
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
estimateSize, err := t.estimateResultSize(nq, t.SearchRequest.Topk)
if err != nil {
log.Warn("failed to estimate result size", zap.Error(err))
return err
}
if estimateSize >= requeryThreshold {
t.requery = true
plan.OutputFieldIds = nil
}
t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan)
if err != nil {
return err
}
log.Debug("Proxy::searchTask::PreExecute",
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
zap.Stringer("plan", plan)) // may be very large if large term passed.
}
// translate partition name to partition ids. Use regex-pattern to match partition name.
t.SearchRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), collectionName, partitionNames)
if err != nil {
log.Warn("failed to get partition ids", zap.Error(err))
log.Debug("init search request failed", zap.Error(err))
return err
}
@ -421,17 +306,6 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
}
t.SearchRequest.GuaranteeTimestamp = guaranteeTs
if deadline, ok := t.TraceCtx().Deadline(); ok {
t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
}
t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup
// Set username of this search request for feature like task scheduling.
if username, _ := GetCurUserFromContext(ctx); username != "" {
t.SearchRequest.Username = username
}
log.Debug("search PreExecute done.",
zap.Uint64("guarantee_ts", guaranteeTs),
zap.Bool("use_default_consistency", useDefaultConsistency),

View File

@ -469,6 +469,61 @@ func (_c *MockQueryNodeServer_GetTimeTickChannel_Call) RunAndReturn(run func(con
return _c
}
// HybridSearch provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) HybridSearch(_a0 context.Context, _a1 *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
ret := _m.Called(_a0, _a1)
var r0 *querypb.HybridSearchResult
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) *querypb.HybridSearchResult); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*querypb.HybridSearchResult)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *querypb.HybridSearchRequest) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryNodeServer_HybridSearch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HybridSearch'
type MockQueryNodeServer_HybridSearch_Call struct {
*mock.Call
}
// HybridSearch is a helper method to define mock.On call
// - _a0 context.Context
// - _a1 *querypb.HybridSearchRequest
func (_e *MockQueryNodeServer_Expecter) HybridSearch(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_HybridSearch_Call {
return &MockQueryNodeServer_HybridSearch_Call{Call: _e.mock.On("HybridSearch", _a0, _a1)}
}
func (_c *MockQueryNodeServer_HybridSearch_Call) Run(run func(_a0 context.Context, _a1 *querypb.HybridSearchRequest)) *MockQueryNodeServer_HybridSearch_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.HybridSearchRequest))
})
return _c
}
func (_c *MockQueryNodeServer_HybridSearch_Call) Return(_a0 *querypb.HybridSearchResult, _a1 error) *MockQueryNodeServer_HybridSearch_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockQueryNodeServer_HybridSearch_Call) RunAndReturn(run func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)) *MockQueryNodeServer_HybridSearch_Call {
_c.Call.Return(run)
return _c
}
// LoadPartitions provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) LoadPartitions(_a0 context.Context, _a1 *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1)

View File

@ -62,6 +62,7 @@ type ShardDelegator interface {
GetSegmentInfo(readable bool) (sealed []SnapshotItem, growing []SegmentEntry)
SyncDistribution(ctx context.Context, entries ...SegmentEntry)
Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error)
HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)
Query(ctx context.Context, req *querypb.QueryRequest) ([]*internalpb.RetrieveResults, error)
QueryStream(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) error
GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) ([]*internalpb.GetStatisticsResponse, error)
@ -184,6 +185,44 @@ func (sd *shardDelegator) modifyQueryRequest(req *querypb.QueryRequest, scope qu
return nodeReq
}
// Search preforms search operation on shard.
func (sd *shardDelegator) search(ctx context.Context, req *querypb.SearchRequest, sealed []SnapshotItem, growing []SegmentEntry) ([]*internalpb.SearchResults, error) {
log := sd.getLogger(ctx)
if req.Req.IgnoreGrowing {
growing = []SegmentEntry{}
}
sealedNum := lo.SumBy(sealed, func(item SnapshotItem) int { return len(item.Segments) })
log.Debug("search segments...",
zap.Int("sealedNum", sealedNum),
zap.Int("growingNum", len(growing)),
)
req, err := optimizers.OptimizeSearchParams(ctx, req, sd.queryHook, sealedNum)
if err != nil {
log.Warn("failed to optimize search params", zap.Error(err))
return nil, err
}
tasks, err := organizeSubTask(ctx, req, sealed, growing, sd, sd.modifySearchRequest)
if err != nil {
log.Warn("Search organizeSubTask failed", zap.Error(err))
return nil, err
}
results, err := executeSubTasks(ctx, tasks, func(ctx context.Context, req *querypb.SearchRequest, worker cluster.Worker) (*internalpb.SearchResults, error) {
return worker.SearchSegments(ctx, req)
}, "Search", log)
if err != nil {
log.Warn("Delegator search failed", zap.Error(err))
return nil, err
}
log.Debug("Delegator search done")
return results, nil
}
// Search preforms search operation on shard.
func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) {
log := sd.getLogger(ctx)
@ -229,39 +268,113 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
return funcutil.SliceContain(existPartitions, segment.PartitionID)
})
if req.Req.IgnoreGrowing {
growing = []SegmentEntry{}
return sd.search(ctx, req, sealed, growing)
}
// HybridSearch preforms hybrid search operation on shard.
func (sd *shardDelegator) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
log := sd.getLogger(ctx)
if err := sd.lifetime.Add(lifetime.IsWorking); err != nil {
return nil, err
}
defer sd.lifetime.Done()
if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) {
log.Warn("deletgator received hybrid search request not belongs to it",
zap.Strings("reqChannels", req.GetDmlChannels()),
)
return nil, fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels())
}
sealedNum := lo.SumBy(sealed, func(item SnapshotItem) int { return len(item.Segments) })
log.Debug("search segments...",
zap.Int("sealedNum", sealedNum),
zap.Int("growingNum", len(growing)),
)
partitions := req.GetReq().GetPartitionIDs()
if !sd.collection.ExistPartition(partitions...) {
return nil, merr.WrapErrPartitionNotLoaded(partitions)
}
req, err = optimizers.OptimizeSearchParams(ctx, req, sd.queryHook, sealedNum)
// wait tsafe
waitTr := timerecord.NewTimeRecorder("wait tSafe")
tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
if err != nil {
log.Warn("delegator hybrid search failed to wait tsafe", zap.Error(err))
return nil, err
}
if req.GetReq().GetMvccTimestamp() == 0 {
req.Req.MvccTimestamp = tSafe
}
metrics.QueryNodeSQLatencyWaitTSafe.WithLabelValues(
fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel).
Observe(float64(waitTr.ElapseSpan().Milliseconds()))
sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...)
if err != nil {
log.Warn("delegator failed to hybrid search, current distribution is not serviceable")
return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not servcieable")
}
defer sd.distribution.Unpin(version)
existPartitions := sd.collection.GetPartitions()
growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool {
return funcutil.SliceContain(existPartitions, segment.PartitionID)
})
futures := make([]*conc.Future[*internalpb.SearchResults], len(req.GetReq().GetReqs()))
for index := range req.GetReq().GetReqs() {
request := req.GetReq().Reqs[index]
future := conc.Go(func() (*internalpb.SearchResults, error) {
searchReq := &querypb.SearchRequest{
Req: request,
DmlChannels: req.GetDmlChannels(),
TotalChannelNum: req.GetTotalChannelNum(),
FromShardLeader: true,
}
searchReq.Req.GuaranteeTimestamp = req.GetReq().GetGuaranteeTimestamp()
searchReq.Req.TimeoutTimestamp = req.GetReq().GetTimeoutTimestamp()
if searchReq.GetReq().GetMvccTimestamp() == 0 {
searchReq.GetReq().MvccTimestamp = tSafe
}
results, err := sd.search(ctx, searchReq, sealed, growing)
if err != nil {
return nil, err
}
return segments.ReduceSearchResults(ctx,
results,
searchReq.Req.GetNq(),
searchReq.Req.GetTopk(),
searchReq.Req.GetMetricType())
})
futures[index] = future
}
err = conc.AwaitAll(futures...)
if err != nil {
log.Warn("failed to optimize search params", zap.Error(err))
return nil, err
}
tasks, err := organizeSubTask(ctx, req, sealed, growing, sd, sd.modifySearchRequest)
if err != nil {
log.Warn("Search organizeSubTask failed", zap.Error(err))
return nil, err
ret := &querypb.HybridSearchResult{
Status: merr.Success(),
Results: make([]*internalpb.SearchResults, len(futures)),
}
results, err := executeSubTasks(ctx, tasks, func(ctx context.Context, req *querypb.SearchRequest, worker cluster.Worker) (*internalpb.SearchResults, error) {
return worker.SearchSegments(ctx, req)
}, "Search", log)
if err != nil {
log.Warn("Delegator search failed", zap.Error(err))
return nil, err
channelsMvcc := make(map[string]uint64)
for i, future := range futures {
result := future.Value()
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Debug("delegator hybrid search failed",
zap.String("reason", result.GetStatus().GetReason()))
return nil, merr.Error(result.GetStatus())
}
ret.Results[i] = result
for ch, ts := range result.GetChannelsMvcc() {
channelsMvcc[ch] = ts
}
}
ret.ChannelsMvcc = channelsMvcc
log.Debug("Delegator search done")
log.Debug("Delegator hybrid search done")
return results, nil
return ret, nil
}
func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) error {

View File

@ -469,6 +469,251 @@ func (s *DelegatorSuite) TestSearch() {
})
}
func (s *DelegatorSuite) TestHybridSearch() {
s.delegator.Start()
paramtable.SetNodeID(1)
s.initSegments()
s.Run("normal", func() {
defer func() {
s.workerManager.ExpectedCalls = nil
}()
workers := make(map[int64]*cluster.MockWorker)
worker1 := &cluster.MockWorker{}
worker2 := &cluster.MockWorker{}
workers[1] = worker1
workers[2] = worker2
worker1.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
Run(func(_ context.Context, req *querypb.SearchRequest) {
s.EqualValues(1, req.Req.GetBase().GetTargetID())
s.True(req.GetFromShardLeader())
if req.GetScope() == querypb.DataScope_Streaming {
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
s.ElementsMatch([]int64{1004}, req.GetSegmentIDs())
}
if req.GetScope() == querypb.DataScope_Historical {
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
s.ElementsMatch([]int64{1000, 1001}, req.GetSegmentIDs())
}
}).Return(&internalpb.SearchResults{}, nil)
worker2.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
Run(func(_ context.Context, req *querypb.SearchRequest) {
s.EqualValues(2, req.Req.GetBase().GetTargetID())
s.True(req.GetFromShardLeader())
s.Equal(querypb.DataScope_Historical, req.GetScope())
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
}).Return(&internalpb.SearchResults{}, nil)
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
results, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
Req: &internalpb.HybridSearchRequest{
Base: commonpbutil.NewMsgBase(),
Reqs: []*internalpb.SearchRequest{
{Base: commonpbutil.NewMsgBase()},
{Base: commonpbutil.NewMsgBase()},
},
},
DmlChannels: []string{s.vchannelName},
})
s.NoError(err)
s.Equal(2, len(results.Results))
})
s.Run("partition_not_loaded", func() {
defer func() {
s.workerManager.ExpectedCalls = nil
}()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
Req: &internalpb.HybridSearchRequest{
Base: commonpbutil.NewMsgBase(),
// not load partation -1,will return error
PartitionIDs: []int64{-1},
},
DmlChannels: []string{s.vchannelName},
})
s.True(errors.Is(err, merr.ErrPartitionNotLoaded))
})
s.Run("worker_return_error", func() {
defer func() {
s.workerManager.ExpectedCalls = nil
}()
workers := make(map[int64]*cluster.MockWorker)
worker1 := &cluster.MockWorker{}
worker2 := &cluster.MockWorker{}
workers[1] = worker1
workers[2] = worker2
worker1.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).Return(nil, errors.New("mock error"))
worker2.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
Run(func(_ context.Context, req *querypb.SearchRequest) {
s.EqualValues(2, req.Req.GetBase().GetTargetID())
s.True(req.GetFromShardLeader())
s.Equal(querypb.DataScope_Historical, req.GetScope())
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
}).Return(&internalpb.SearchResults{}, nil)
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
Req: &internalpb.HybridSearchRequest{
Base: commonpbutil.NewMsgBase(),
Reqs: []*internalpb.SearchRequest{
{
Base: commonpbutil.NewMsgBase(),
},
},
},
DmlChannels: []string{s.vchannelName},
})
s.Error(err)
})
s.Run("worker_return_failure_code", func() {
defer func() {
s.workerManager.ExpectedCalls = nil
}()
workers := make(map[int64]*cluster.MockWorker)
worker1 := &cluster.MockWorker{}
worker2 := &cluster.MockWorker{}
workers[1] = worker1
workers[2] = worker2
worker1.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).Return(&internalpb.SearchResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "mocked error",
},
}, nil)
worker2.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
Run(func(_ context.Context, req *querypb.SearchRequest) {
s.EqualValues(2, req.Req.GetBase().GetTargetID())
s.True(req.GetFromShardLeader())
s.Equal(querypb.DataScope_Historical, req.GetScope())
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
}).Return(&internalpb.SearchResults{}, nil)
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
Req: &internalpb.HybridSearchRequest{
Base: commonpbutil.NewMsgBase(),
Reqs: []*internalpb.SearchRequest{
{
Base: commonpbutil.NewMsgBase(),
},
},
},
DmlChannels: []string{s.vchannelName},
})
s.Error(err)
})
s.Run("wrong_channel", func() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
Req: &internalpb.HybridSearchRequest{
Base: commonpbutil.NewMsgBase(),
},
DmlChannels: []string{"non_exist_channel"},
})
s.Error(err)
})
s.Run("wait_tsafe_timeout", func() {
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
defer cancel()
_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
Req: &internalpb.HybridSearchRequest{
Base: commonpbutil.NewMsgBase(),
GuaranteeTimestamp: 10100,
},
DmlChannels: []string{s.vchannelName},
})
s.Error(err)
})
s.Run("tsafe_behind_max_lag", func() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
Req: &internalpb.HybridSearchRequest{
Base: commonpbutil.NewMsgBase(),
GuaranteeTimestamp: uint64(paramtable.Get().QueryNodeCfg.MaxTimestampLag.GetAsDuration(time.Second)) + 10001,
},
DmlChannels: []string{s.vchannelName},
})
s.Error(err)
})
s.Run("distribution_not_serviceable", func() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sd, ok := s.delegator.(*shardDelegator)
s.Require().True(ok)
sd.distribution.AddOfflines(1001)
_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
Req: &internalpb.HybridSearchRequest{
Base: commonpbutil.NewMsgBase(),
},
DmlChannels: []string{s.vchannelName},
})
s.Error(err)
})
s.Run("cluster_not_serviceable", func() {
s.delegator.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
Req: &internalpb.HybridSearchRequest{
Base: commonpbutil.NewMsgBase(),
},
DmlChannels: []string{s.vchannelName},
})
s.Error(err)
})
}
func (s *DelegatorSuite) TestQuery() {
s.delegator.Start()
paramtable.SetNodeID(1)

View File

@ -253,6 +253,61 @@ func (_c *MockShardDelegator_GetTargetVersion_Call) RunAndReturn(run func() int6
return _c
}
// HybridSearch provides a mock function with given fields: ctx, req
func (_m *MockShardDelegator) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
ret := _m.Called(ctx, req)
var r0 *querypb.HybridSearchResult
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)); ok {
return rf(ctx, req)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) *querypb.HybridSearchResult); ok {
r0 = rf(ctx, req)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*querypb.HybridSearchResult)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *querypb.HybridSearchRequest) error); ok {
r1 = rf(ctx, req)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockShardDelegator_HybridSearch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HybridSearch'
type MockShardDelegator_HybridSearch_Call struct {
*mock.Call
}
// HybridSearch is a helper method to define mock.On call
// - ctx context.Context
// - req *querypb.HybridSearchRequest
func (_e *MockShardDelegator_Expecter) HybridSearch(ctx interface{}, req interface{}) *MockShardDelegator_HybridSearch_Call {
return &MockShardDelegator_HybridSearch_Call{Call: _e.mock.On("HybridSearch", ctx, req)}
}
func (_c *MockShardDelegator_HybridSearch_Call) Run(run func(ctx context.Context, req *querypb.HybridSearchRequest)) *MockShardDelegator_HybridSearch_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.HybridSearchRequest))
})
return _c
}
func (_c *MockShardDelegator_HybridSearch_Call) Return(_a0 *querypb.HybridSearchResult, _a1 error) *MockShardDelegator_HybridSearch_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockShardDelegator_HybridSearch_Call) RunAndReturn(run func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)) *MockShardDelegator_HybridSearch_Call {
_c.Call.Return(run)
return _c
}
// LoadGrowing provides a mock function with given fields: ctx, infos, version
func (_m *MockShardDelegator) LoadGrowing(ctx context.Context, infos []*querypb.SegmentLoadInfo, version int64) error {
ret := _m.Called(ctx, infos, version)

View File

@ -401,6 +401,63 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq
return resp, nil
}
func (node *QueryNode) hybridSearchChannel(ctx context.Context, req *querypb.HybridSearchRequest, channel string) (*querypb.HybridSearchResult, error) {
log := log.Ctx(ctx).With(
zap.Int64("msgID", req.GetReq().GetBase().GetMsgID()),
zap.Int64("collectionID", req.Req.GetCollectionID()),
zap.String("channel", channel),
)
traceID := trace.SpanFromContext(ctx).SpanContext().TraceID()
if err := node.lifetime.Add(merr.IsHealthy); err != nil {
return nil, err
}
defer node.lifetime.Done()
var err error
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.TotalLabel, metrics.Leader).Inc()
defer func() {
if err != nil {
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.FailLabel, metrics.Leader).Inc()
}
}()
log.Debug("start to search channel")
searchCtx, cancel := context.WithCancel(ctx)
defer cancel()
// From Proxy
tr := timerecord.NewTimeRecorder("hybridSearchDelegator")
// get delegator
sd, ok := node.delegators.Get(channel)
if !ok {
err := merr.WrapErrChannelNotFound(channel)
log.Warn("Query failed, failed to get shard delegator for search", zap.Error(err))
return nil, err
}
// do hybrid search
result, err := sd.HybridSearch(searchCtx, req)
if err != nil {
log.Warn("failed to hybrid search on delegator", zap.Error(err))
return nil, err
}
tr.CtxElapse(ctx, fmt.Sprintf("do search with channel done , traceID = %s, vChannel = %s",
traceID,
channel,
))
// update metric to prometheus
latency := tr.ElapseSpan()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.Leader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.SuccessLabel, metrics.Leader).Inc()
for _, searchReq := range req.GetReq().GetReqs() {
metrics.QueryNodeSearchNQ.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(searchReq.GetNq()))
metrics.QueryNodeSearchTopK.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(searchReq.GetTopk()))
}
return result, nil
}
func (node *QueryNode) getChannelStatistics(ctx context.Context, req *querypb.GetStatisticsRequest, channel string) (*internalpb.GetStatisticsResponse, error) {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", req.Req.GetCollectionID()),

View File

@ -821,6 +821,114 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
return result, nil
}
// HybridSearch performs replica search tasks.
func (node *QueryNode) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", req.GetReq().GetCollectionID()),
zap.Strings("channels", req.GetDmlChannels()))
log.Debug("Received HybridSearchRequest",
zap.Uint64("guaranteeTimestamp", req.GetReq().GetGuaranteeTimestamp()),
zap.Uint64("mvccTimestamp", req.GetReq().GetMvccTimestamp()))
tr := timerecord.NewTimeRecorderWithTrace(ctx, "HybridSearchRequest")
if err := node.lifetime.Add(merr.IsHealthy); err != nil {
return &querypb.HybridSearchResult{
Base: &commonpb.MsgBase{
SourceID: paramtable.GetNodeID(),
},
Status: merr.Status(err),
}, nil
}
defer node.lifetime.Done()
err := merr.CheckTargetID(req.GetReq().GetBase())
if err != nil {
log.Warn("target ID check failed", zap.Error(err))
return &querypb.HybridSearchResult{
Base: &commonpb.MsgBase{
SourceID: paramtable.GetNodeID(),
},
Status: merr.Status(err),
}, nil
}
resp := &querypb.HybridSearchResult{
Base: &commonpb.MsgBase{
SourceID: paramtable.GetNodeID(),
},
Status: merr.Success(),
}
collection := node.manager.Collection.Get(req.GetReq().GetCollectionID())
if collection == nil {
resp.Status = merr.Status(merr.WrapErrCollectionNotFound(req.GetReq().GetCollectionID()))
return resp, nil
}
MultipleResults := make([]*querypb.HybridSearchResult, len(req.GetDmlChannels()))
runningGp, runningCtx := errgroup.WithContext(ctx)
for i, ch := range req.GetDmlChannels() {
ch := ch
req := &querypb.HybridSearchRequest{
Req: req.Req,
DmlChannels: []string{ch},
TotalChannelNum: 1,
}
i := i
runningGp.Go(func() error {
ret, err := node.hybridSearchChannel(runningCtx, req, ch)
if err != nil {
return err
}
if err := merr.Error(ret.GetStatus()); err != nil {
return err
}
MultipleResults[i] = ret
return nil
})
}
if err := runningGp.Wait(); err != nil {
resp.Status = merr.Status(err)
return resp, nil
}
tr.RecordSpan()
channelsMvcc := make(map[string]uint64)
for i, searchReq := range req.GetReq().GetReqs() {
toReduceResults := make([]*internalpb.SearchResults, len(MultipleResults))
for index, hs := range MultipleResults {
toReduceResults[index] = hs.Results[i]
}
result, err := segments.ReduceSearchResults(ctx, toReduceResults, searchReq.GetNq(), searchReq.GetTopk(), searchReq.GetMetricType())
if err != nil {
log.Warn("failed to reduce search results", zap.Error(err))
resp.Status = merr.Status(err)
return resp, nil
}
for ch, ts := range result.GetChannelsMvcc() {
channelsMvcc[ch] = ts
}
resp.Results = append(resp.Results, result)
}
resp.ChannelsMvcc = channelsMvcc
reduceLatency := tr.RecordSpan()
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.ReduceShards).
Observe(float64(reduceLatency.Milliseconds()))
collector.Rate.Add(metricsinfo.SearchThroughput, float64(proto.Size(req)))
metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.HybridSearchLabel).
Add(float64(proto.Size(req)))
if resp.GetCostAggregation() != nil {
resp.GetCostAggregation().ResponseTime = tr.ElapseSpan().Milliseconds()
}
return resp, nil
}
// only used for delegator query segments from worker
func (node *QueryNode) QuerySegments(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
resp := &internalpb.RetrieveResults{

View File

@ -1323,6 +1323,47 @@ func (suite *ServiceSuite) TestSearchSegments_Failed() {
suite.Equal(commonpb.ErrorCode_UnexpectedError, rsp.GetStatus().GetErrorCode())
}
func (suite *ServiceSuite) TestHybridSearch_Concurrent() {
ctx := context.Background()
// pre
suite.TestWatchDmChannelsInt64()
suite.TestLoadSegments_Int64()
concurrency := 16
futures := make([]*conc.Future[*querypb.HybridSearchResult], 0, concurrency)
for i := 0; i < concurrency; i++ {
future := conc.Go(func() (*querypb.HybridSearchResult, error) {
creq1, err := suite.genCSearchRequest(30, schemapb.DataType_FloatVector, 107, defaultMetricType)
suite.NoError(err)
creq2, err := suite.genCSearchRequest(30, schemapb.DataType_FloatVector, 107, defaultMetricType)
suite.NoError(err)
req := &querypb.HybridSearchRequest{
Req: &internalpb.HybridSearchRequest{
Base: &commonpb.MsgBase{
MsgID: rand.Int63(),
TargetID: suite.node.session.ServerID,
},
CollectionID: suite.collectionID,
PartitionIDs: suite.partitionIDs,
MvccTimestamp: typeutil.MaxTimestamp,
Reqs: []*internalpb.SearchRequest{creq1, creq2},
},
DmlChannels: []string{suite.vchannel},
}
return suite.node.HybridSearch(ctx, req)
})
futures = append(futures, future)
}
err := conc.AwaitAll(futures...)
suite.NoError(err)
for i := range futures {
suite.True(merr.Ok(futures[i].Value().GetStatus()))
}
}
func (suite *ServiceSuite) TestSearchSegments_Normal() {
ctx := context.Background()
// pre

View File

@ -82,6 +82,10 @@ func (m *GrpcQueryNodeClient) Search(ctx context.Context, in *querypb.SearchRequ
return &internalpb.SearchResults{}, m.Err
}
func (m *GrpcQueryNodeClient) HybridSearch(ctx context.Context, in *querypb.HybridSearchRequest, opts ...grpc.CallOption) (*querypb.HybridSearchResult, error) {
return &querypb.HybridSearchResult{}, m.Err
}
func (m *GrpcQueryNodeClient) SearchSegments(ctx context.Context, in *querypb.SearchRequest, opts ...grpc.CallOption) (*internalpb.SearchResults, error) {
return &internalpb.SearchResults{}, m.Err
}

View File

@ -93,6 +93,10 @@ func (qn *qnServerWrapper) Search(ctx context.Context, in *querypb.SearchRequest
return qn.QueryNode.Search(ctx, in)
}
func (qn *qnServerWrapper) HybridSearch(ctx context.Context, in *querypb.HybridSearchRequest, opts ...grpc.CallOption) (*querypb.HybridSearchResult, error) {
return qn.QueryNode.HybridSearch(ctx, in)
}
func (qn *qnServerWrapper) SearchSegments(ctx context.Context, in *querypb.SearchRequest, opts ...grpc.CallOption) (*internalpb.SearchResults, error) {
return qn.QueryNode.SearchSegments(ctx, in)
}