diff --git a/internal/distributed/querynode/client/client.go b/internal/distributed/querynode/client/client.go index 64c0ceec5d..863df81a90 100644 --- a/internal/distributed/querynode/client/client.go +++ b/internal/distributed/querynode/client/client.go @@ -329,3 +329,10 @@ func (c *Client) Delete(ctx context.Context, req *querypb.DeleteRequest, _ ...gr return client.Delete(ctx, req) }) } + +// HybridSearch performs replica hybrid search tasks in QueryNode. +func (c *Client) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest, _ ...grpc.CallOption) (*querypb.HybridSearchResult, error) { + return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*querypb.HybridSearchResult, error) { + return client.HybridSearch(ctx, req) + }) +} diff --git a/internal/distributed/querynode/service.go b/internal/distributed/querynode/service.go index c904cc4e12..2c02285d1b 100644 --- a/internal/distributed/querynode/service.go +++ b/internal/distributed/querynode/service.go @@ -374,3 +374,8 @@ func (s *Server) SyncDistribution(ctx context.Context, req *querypb.SyncDistribu func (s *Server) Delete(ctx context.Context, req *querypb.DeleteRequest) (*commonpb.Status, error) { return s.querynode.Delete(ctx, req) } + +// HybridSearch performs hybrid search of streaming/historical replica on QueryNode. +func (s *Server) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) { + return s.querynode.HybridSearch(ctx, req) +} diff --git a/internal/mocks/mock_querynode.go b/internal/mocks/mock_querynode.go index 9723cf21f1..c96f9b3e2b 100644 --- a/internal/mocks/mock_querynode.go +++ b/internal/mocks/mock_querynode.go @@ -511,6 +511,61 @@ func (_c *MockQueryNode_GetTimeTickChannel_Call) RunAndReturn(run func(context.C return _c } +// HybridSearch provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryNode) HybridSearch(_a0 context.Context, _a1 *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) { + ret := _m.Called(_a0, _a1) + + var r0 *querypb.HybridSearchResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) *querypb.HybridSearchResult); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*querypb.HybridSearchResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.HybridSearchRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNode_HybridSearch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HybridSearch' +type MockQueryNode_HybridSearch_Call struct { + *mock.Call +} + +// HybridSearch is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *querypb.HybridSearchRequest +func (_e *MockQueryNode_Expecter) HybridSearch(_a0 interface{}, _a1 interface{}) *MockQueryNode_HybridSearch_Call { + return &MockQueryNode_HybridSearch_Call{Call: _e.mock.On("HybridSearch", _a0, _a1)} +} + +func (_c *MockQueryNode_HybridSearch_Call) Run(run func(_a0 context.Context, _a1 *querypb.HybridSearchRequest)) *MockQueryNode_HybridSearch_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*querypb.HybridSearchRequest)) + }) + return _c +} + +func (_c *MockQueryNode_HybridSearch_Call) Return(_a0 *querypb.HybridSearchResult, _a1 error) *MockQueryNode_HybridSearch_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNode_HybridSearch_Call) RunAndReturn(run func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)) *MockQueryNode_HybridSearch_Call { + _c.Call.Return(run) + return _c +} + // Init provides a mock function with given fields: func (_m *MockQueryNode) Init() error { ret := _m.Called() diff --git a/internal/mocks/mock_querynode_client.go b/internal/mocks/mock_querynode_client.go index 3621a87884..51f4d9ba3f 100644 --- a/internal/mocks/mock_querynode_client.go +++ b/internal/mocks/mock_querynode_client.go @@ -632,6 +632,76 @@ func (_c *MockQueryNodeClient_GetTimeTickChannel_Call) RunAndReturn(run func(con return _c } +// HybridSearch provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryNodeClient) HybridSearch(ctx context.Context, in *querypb.HybridSearchRequest, opts ...grpc.CallOption) (*querypb.HybridSearchResult, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *querypb.HybridSearchResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest, ...grpc.CallOption) (*querypb.HybridSearchResult, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest, ...grpc.CallOption) *querypb.HybridSearchResult); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*querypb.HybridSearchResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.HybridSearchRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeClient_HybridSearch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HybridSearch' +type MockQueryNodeClient_HybridSearch_Call struct { + *mock.Call +} + +// HybridSearch is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.HybridSearchRequest +// - opts ...grpc.CallOption +func (_e *MockQueryNodeClient_Expecter) HybridSearch(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_HybridSearch_Call { + return &MockQueryNodeClient_HybridSearch_Call{Call: _e.mock.On("HybridSearch", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryNodeClient_HybridSearch_Call) Run(run func(ctx context.Context, in *querypb.HybridSearchRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_HybridSearch_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.HybridSearchRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockQueryNodeClient_HybridSearch_Call) Return(_a0 *querypb.HybridSearchResult, _a1 error) *MockQueryNodeClient_HybridSearch_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeClient_HybridSearch_Call) RunAndReturn(run func(context.Context, *querypb.HybridSearchRequest, ...grpc.CallOption) (*querypb.HybridSearchResult, error)) *MockQueryNodeClient_HybridSearch_Call { + _c.Call.Return(run) + return _c +} + // LoadPartitions provides a mock function with given fields: ctx, in, opts func (_m *MockQueryNodeClient) LoadPartitions(ctx context.Context, in *querypb.LoadPartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { _va := make([]interface{}, len(opts)) diff --git a/internal/proto/internal.proto b/internal/proto/internal.proto index ab6d4a62b5..62582597ea 100644 --- a/internal/proto/internal.proto +++ b/internal/proto/internal.proto @@ -104,6 +104,18 @@ message SearchRequest { string username = 18; } +message HybridSearchRequest { + common.MsgBase base = 1; + int64 reqID = 2; + int64 dbID = 3; + int64 collectionID = 4; + repeated int64 partitionIDs = 5; + repeated SearchRequest reqs = 6; + uint64 mvcc_timestamp = 11; + uint64 guarantee_timestamp = 12; + uint64 timeout_timestamp = 13; +} + message SearchResults { common.MsgBase base = 1; common.Status status = 2; diff --git a/internal/proto/query_coord.proto b/internal/proto/query_coord.proto index 5e3410c9e0..5cafca497a 100644 --- a/internal/proto/query_coord.proto +++ b/internal/proto/query_coord.proto @@ -71,6 +71,7 @@ service QueryNode { rpc GetStatistics(GetStatisticsRequest) returns (internal.GetStatisticsResponse) {} rpc Search(SearchRequest) returns (internal.SearchResults) {} + rpc HybridSearch(HybridSearchRequest) returns (HybridSearchResult) {} rpc SearchSegments(SearchRequest) returns (internal.SearchResults) {} rpc Query(QueryRequest) returns (internal.RetrieveResults) {} rpc QueryStream(QueryRequest) returns (stream internal.RetrieveResults){} @@ -328,6 +329,20 @@ message SearchRequest { int32 total_channel_num = 6; } +message HybridSearchRequest { + internal.HybridSearchRequest req = 1; + repeated string dml_channels = 2; + int32 total_channel_num = 3; +} + +message HybridSearchResult { + common.MsgBase base = 1; + common.Status status = 2; + repeated internal.SearchResults results = 3; + internal.CostAggregation costAggregation = 4; + map channels_mvcc = 5; +} + message QueryRequest { internal.RetrieveRequest req = 1; repeated string dml_channels = 2; diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index 7c12abe812..e9719a088e 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -2784,11 +2784,18 @@ func (node *Proxy) HybridSearch(ctx context.Context, request *milvuspb.HybridSea qt := &hybridSearchTask{ ctx: ctx, Condition: NewTaskCondition(ctx), - request: request, - tr: timerecord.NewTimeRecorder(method), - qc: node.queryCoord, - node: node, - lb: node.lbPolicy, + HybridSearchRequest: &internalpb.HybridSearchRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_Search), + commonpbutil.WithSourceID(paramtable.GetNodeID()), + ), + ReqID: paramtable.GetNodeID(), + }, + request: request, + tr: timerecord.NewTimeRecorder(method), + qc: node.queryCoord, + node: node, + lb: node.lbPolicy, } guaranteeTs := request.GuaranteeTimestamp @@ -2831,7 +2838,7 @@ func (node *Proxy) HybridSearch(ctx context.Context, request *milvuspb.HybridSea log.Debug( rpcEnqueued(method), - zap.Uint64("timestamp", qt.request.Base.Timestamp), + zap.Uint64("timestamp", qt.Base.Timestamp), ) if err := qt.WaitToFinish(); err != nil { diff --git a/internal/proxy/reScorer.go b/internal/proxy/reScorer.go index 67b0f4fdee..07ea6de484 100644 --- a/internal/proxy/reScorer.go +++ b/internal/proxy/reScorer.go @@ -120,7 +120,7 @@ func NewReScorer(reqs []*milvuspb.SearchRequest, rankParams []*commonpb.KeyValue return nil, errors.New("The type of rank param k should be float") } if k <= 0 || k >= maxRRFParamsValue { - return nil, errors.New("The rank params k should be in range (0, 16384)") + return nil, errors.New(fmt.Sprintf("The rank params k should be in range (0, %d)", maxRRFParamsValue)) } log.Debug("rrf params", zap.Float64("k", k)) for i := range reqs { diff --git a/internal/proxy/search_util.go b/internal/proxy/search_util.go new file mode 100644 index 0000000000..b247f7ee18 --- /dev/null +++ b/internal/proxy/search_util.go @@ -0,0 +1,160 @@ +package proxy + +import ( + "context" + "fmt" + "strconv" + + "github.com/cockroachdb/errors" + "github.com/golang/protobuf/proto" + "go.opentelemetry.io/otel" + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/internal/parser/planparserv2" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/tsoutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +func initSearchRequest(ctx context.Context, t *searchTask) error { + ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "init search request") + defer sp.End() + + log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName)) + // fetch search_growing from search param + var ignoreGrowing bool + var err error + for i, kv := range t.request.GetSearchParams() { + if kv.GetKey() == IgnoreGrowingKey { + ignoreGrowing, err = strconv.ParseBool(kv.GetValue()) + if err != nil { + return errors.New("parse search growing failed") + } + t.request.SearchParams = append(t.request.GetSearchParams()[:i], t.request.GetSearchParams()[i+1:]...) + break + } + } + t.SearchRequest.IgnoreGrowing = ignoreGrowing + + // Manually update nq if not set. + nq, err := getNq(t.request) + if err != nil { + log.Warn("failed to get nq", zap.Error(err)) + return err + } + // Check if nq is valid: + // https://milvus.io/docs/limitations.md + if err := validateNQLimit(nq); err != nil { + return fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err) + } + t.SearchRequest.Nq = nq + log = log.With(zap.Int64("nq", nq)) + + outputFieldIDs, err := getOutputFieldIDs(t.schema, t.request.GetOutputFields()) + if err != nil { + log.Warn("fail to get output field ids", zap.Error(err)) + return err + } + t.SearchRequest.OutputFieldsId = outputFieldIDs + + partitionNames := t.request.GetPartitionNames() + if t.request.GetDslType() == commonpb.DslType_BoolExprV1 { + annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams()) + if err != nil || len(annsField) == 0 { + vecFields := typeutil.GetVectorFieldSchemas(t.schema.CollectionSchema) + if len(vecFields) == 0 { + return errors.New(AnnsFieldKey + " not found in schema") + } + + if enableMultipleVectorFields && len(vecFields) > 1 { + return errors.New("multiple anns_fields exist, please specify a anns_field in search_params") + } + + annsField = vecFields[0].Name + } + queryInfo, offset, err := parseSearchInfo(t.request.GetSearchParams(), t.schema.CollectionSchema) + if err != nil { + return err + } + if queryInfo.GroupByFieldId != 0 { + t.SearchRequest.IgnoreGrowing = true + // for group by operation, currently, we ignore growing segments + } + t.offset = offset + + plan, err := planparserv2.CreateSearchPlan(t.schema.CollectionSchema, t.request.Dsl, annsField, queryInfo) + if err != nil { + log.Warn("failed to create query plan", zap.Error(err), + zap.String("dsl", t.request.Dsl), // may be very large if large term passed. + zap.String("anns field", annsField), zap.Any("query info", queryInfo)) + return merr.WrapErrParameterInvalidMsg("failed to create query plan: %v", err) + } + log.Debug("create query plan", + zap.String("dsl", t.request.Dsl), // may be very large if large term passed. + zap.String("anns field", annsField), zap.Any("query info", queryInfo)) + + if t.partitionKeyMode { + expr, err := ParseExprFromPlan(plan) + if err != nil { + log.Warn("failed to parse expr", zap.Error(err)) + return err + } + partitionKeys := ParsePartitionKeys(expr) + hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), t.collectionName, partitionKeys) + if err != nil { + log.Warn("failed to assign partition keys", zap.Error(err)) + return err + } + + partitionNames = append(partitionNames, hashedPartitionNames...) + } + + plan.OutputFieldIds = outputFieldIDs + + t.SearchRequest.Topk = queryInfo.GetTopk() + t.SearchRequest.MetricType = queryInfo.GetMetricType() + t.SearchRequest.DslType = commonpb.DslType_BoolExprV1 + + estimateSize, err := t.estimateResultSize(nq, t.SearchRequest.Topk) + if err != nil { + log.Warn("failed to estimate result size", zap.Error(err)) + return err + } + if estimateSize >= requeryThreshold { + t.requery = true + plan.OutputFieldIds = nil + } + + t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan) + if err != nil { + return err + } + + log.Debug("proxy init search request", + zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()), + zap.Stringer("plan", plan)) // may be very large if large term passed. + } + + // translate partition name to partition ids. Use regex-pattern to match partition name. + t.SearchRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), t.collectionName, partitionNames) + if err != nil { + log.Warn("failed to get partition ids", zap.Error(err)) + return err + } + + if deadline, ok := t.TraceCtx().Deadline(); ok { + t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0) + } + + t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup + + // Set username of this search request for feature like task scheduling. + if username, _ := GetCurUserFromContext(ctx); username != "" { + t.SearchRequest.Username = username + } + + return nil +} diff --git a/internal/proxy/task_hybrid_search.go b/internal/proxy/task_hybrid_search.go index 662a0fc07a..8bb98c71c7 100644 --- a/internal/proxy/task_hybrid_search.go +++ b/internal/proxy/task_hybrid_search.go @@ -14,10 +14,11 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/commonpbutil" - "github.com/milvus-io/milvus/pkg/util/conc" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -32,9 +33,11 @@ const ( type hybridSearchTask struct { Condition ctx context.Context + *internalpb.HybridSearchRequest - result *milvuspb.SearchResults - request *milvuspb.HybridSearchRequest + result *milvuspb.SearchResults + request *milvuspb.HybridSearchRequest + searchTasks []*searchTask tr *timerecord.TimeRecorder schema *schemaInfo @@ -42,15 +45,14 @@ type hybridSearchTask struct { userOutputFields []string - qc types.QueryCoordClient - node types.ProxyComponent - lb LBPolicy - queryChannelsTs map[string]Timestamp - - collectionID UniqueID + qc types.QueryCoordClient + node types.ProxyComponent + lb LBPolicy + resultBuf *typeutil.ConcurrentSet[*querypb.HybridSearchResult] multipleRecallResults *typeutil.ConcurrentSet[*milvuspb.SearchResults] reScorers []reScorer + queryChannelsTs map[string]Timestamp rankParams *rankParams } @@ -63,7 +65,7 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error { } if len(t.request.Requests) > defaultMaxSearchRequest { - return errors.New("maximum of ann search requests is 1024") + return errors.New(fmt.Sprintf("maximum of ann search requests is %d", defaultMaxSearchRequest)) } for _, req := range t.request.GetRequests() { nq, err := getNq(req) @@ -78,12 +80,15 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error { } } + t.Base.MsgType = commonpb.MsgType_Search + t.Base.SourceID = paramtable.GetNodeID() + collectionName := t.request.CollectionName collID, err := globalMetaCache.GetCollectionID(ctx, t.request.GetDbName(), collectionName) if err != nil { return err } - t.collectionID = collID + t.CollectionID = collID log := log.Ctx(ctx).With(zap.Int64("collID", collID), zap.String("collName", collectionName)) t.schema, err = globalMetaCache.GetCollectionSchema(ctx, t.request.GetDbName(), collectionName) @@ -113,6 +118,82 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error { t.requery = true } + collectionInfo, err2 := globalMetaCache.GetCollectionInfo(ctx, t.request.GetDbName(), collectionName, t.CollectionID) + if err2 != nil { + log.Warn("Proxy::hybridSearchTask::PreExecute failed to GetCollectionInfo from cache", + zap.String("collectionName", collectionName), zap.Int64("collectionID", t.CollectionID), zap.Error(err2)) + return err2 + } + guaranteeTs := t.request.GetGuaranteeTimestamp() + var consistencyLevel commonpb.ConsistencyLevel + useDefaultConsistency := t.request.GetUseDefaultConsistency() + if useDefaultConsistency { + consistencyLevel = collectionInfo.consistencyLevel + guaranteeTs = parseGuaranteeTsFromConsistency(guaranteeTs, t.BeginTs(), consistencyLevel) + } else { + consistencyLevel = t.request.GetConsistencyLevel() + // Compatibility logic, parse guarantee timestamp + if consistencyLevel == 0 && guaranteeTs > 0 { + guaranteeTs = parseGuaranteeTs(guaranteeTs, t.BeginTs()) + } else { + // parse from guarantee timestamp and user input consistency level + guaranteeTs = parseGuaranteeTsFromConsistency(guaranteeTs, t.BeginTs(), consistencyLevel) + } + } + + t.reScorers, err = NewReScorer(t.request.GetRequests(), t.request.GetRankParams()) + if err != nil { + log.Info("generate reScorer failed", zap.Any("rank params", t.request.GetRankParams()), zap.Error(err)) + return err + } + + t.searchTasks = make([]*searchTask, len(t.request.GetRequests())) + for index := range t.request.Requests { + searchReq := t.request.Requests[index] + + if len(searchReq.GetCollectionName()) == 0 { + searchReq.CollectionName = t.request.GetCollectionName() + } else if searchReq.GetCollectionName() != t.request.GetCollectionName() { + return errors.New(fmt.Sprintf("inconsistent collection name in hybrid search request, "+ + "expect %s, actual %s", searchReq.GetCollectionName(), t.request.GetCollectionName())) + } + + searchReq.PartitionNames = t.request.GetPartitionNames() + searchReq.ConsistencyLevel = consistencyLevel + searchReq.GuaranteeTimestamp = guaranteeTs + searchReq.UseDefaultConsistency = useDefaultConsistency + searchReq.OutputFields = nil + + t.searchTasks[index] = &searchTask{ + ctx: ctx, + Condition: NewTaskCondition(ctx), + collectionName: collectionName, + SearchRequest: &internalpb.SearchRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_Search), + commonpbutil.WithSourceID(paramtable.GetNodeID()), + ), + ReqID: paramtable.GetNodeID(), + DbID: 0, // todo + CollectionID: collID, + }, + request: searchReq, + schema: t.schema, + tr: timerecord.NewTimeRecorder("hybrid search"), + qc: t.qc, + node: t.node, + lb: t.lb, + + partitionKeyMode: partitionKeyMode, + resultBuf: typeutil.NewConcurrentSet[*internalpb.SearchResults](), + } + err := initSearchRequest(ctx, t.searchTasks[index]) + if err != nil { + log.Debug("init hybrid search request failed", zap.Error(err)) + return err + } + } + log.Debug("hybrid search preExecute done.", zap.Uint64("guarantee_ts", t.request.GetGuaranteeTimestamp()), zap.Bool("use_default_consistency", t.request.GetUseDefaultConsistency()), @@ -121,56 +202,65 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error { return nil } +func (t *hybridSearchTask) hybridSearchShard(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channel string) error { + for _, searchTask := range t.searchTasks { + t.HybridSearchRequest.Reqs = append(t.HybridSearchRequest.Reqs, searchTask.SearchRequest) + } + hybridSearchReq := typeutil.Clone(t.HybridSearchRequest) + hybridSearchReq.GetBase().TargetID = nodeID + req := &querypb.HybridSearchRequest{ + Req: hybridSearchReq, + DmlChannels: []string{channel}, + TotalChannelNum: int32(1), + } + + log := log.Ctx(ctx).With(zap.Int64("collection", t.GetCollectionID()), + zap.Int64s("partitionIDs", t.GetPartitionIDs()), + zap.Int64("nodeID", nodeID), + zap.String("channel", channel)) + + var result *querypb.HybridSearchResult + var err error + + result, err = qn.HybridSearch(ctx, req) + if err != nil { + log.Warn("QueryNode hybrid search return error", zap.Error(err)) + return err + } + if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader { + log.Warn("QueryNode is not shardLeader") + return errInvalidShardLeaders + } + if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + log.Warn("QueryNode hybrid search result error", + zap.String("reason", result.GetStatus().GetReason())) + return errors.Wrapf(merr.Error(result.GetStatus()), "fail to hybrid search on QueryNode %d", nodeID) + } + t.resultBuf.Insert(result) + t.lb.UpdateCostMetrics(nodeID, result.CostAggregation) + + return nil +} + func (t *hybridSearchTask) Execute(ctx context.Context) error { ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch-Execute") defer sp.End() - log := log.Ctx(ctx).With(zap.Int64("collID", t.collectionID), zap.String("collName", t.request.GetCollectionName())) + log := log.Ctx(ctx).With(zap.Int64("collID", t.CollectionID), zap.String("collName", t.request.GetCollectionName())) tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute hybrid search %d", t.ID())) defer tr.CtxElapse(ctx, "done") - futures := make([]*conc.Future[*milvuspb.SearchResults], len(t.request.Requests)) - for index := range t.request.Requests { - searchReq := t.request.Requests[index] - future := conc.Go(func() (*milvuspb.SearchResults, error) { - searchReq.TravelTimestamp = t.request.GetTravelTimestamp() - searchReq.GuaranteeTimestamp = t.request.GetGuaranteeTimestamp() - searchReq.NotReturnAllMeta = t.request.GetNotReturnAllMeta() - searchReq.ConsistencyLevel = t.request.GetConsistencyLevel() - searchReq.UseDefaultConsistency = t.request.GetUseDefaultConsistency() - searchReq.OutputFields = nil - - return t.node.Search(ctx, searchReq) - }) - futures[index] = future - } - - err := conc.AwaitAll(futures...) + t.resultBuf = typeutil.NewConcurrentSet[*querypb.HybridSearchResult]() + err := t.lb.Execute(ctx, CollectionWorkLoad{ + db: t.request.GetDbName(), + collectionID: t.CollectionID, + collectionName: t.request.GetCollectionName(), + nq: 1, + exec: t.hybridSearchShard, + }) if err != nil { - return err - } - - t.reScorers, err = NewReScorer(t.request.GetRequests(), t.request.GetRankParams()) - if err != nil { - log.Info("generate reScorer failed", zap.Any("rank params", t.request.GetRankParams()), zap.Error(err)) - return err - } - t.multipleRecallResults = typeutil.NewConcurrentSet[*milvuspb.SearchResults]() - for i, future := range futures { - err = future.Err() - if err != nil { - log.Debug("QueryNode search result error", zap.Error(err)) - return err - } - result := futures[i].Value() - if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - log.Debug("QueryNode search result error", - zap.String("reason", result.GetStatus().GetReason())) - return merr.Error(result.GetStatus()) - } - - t.reScorers[i].reScore(result) - t.multipleRecallResults.Insert(result) + log.Warn("hybrid search execute failed", zap.Error(err)) + return errors.Wrap(err, "failed to hybrid search") } log.Debug("hybrid search execute done.") @@ -194,7 +284,7 @@ func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, erro limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, rankParamsPair) if err != nil { - return nil, errors.New(LimitKey + " not found in search_params") + return nil, errors.New(LimitKey + " not found in rank_params") } limit, err = strconv.ParseInt(limitStr, 0, 64) if err != nil { @@ -235,16 +325,59 @@ func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, erro }, nil } +func (t *hybridSearchTask) collectHybridSearchResults(ctx context.Context) error { + select { + case <-t.TraceCtx().Done(): + log.Ctx(ctx).Warn("hybrid search task wait to finish timeout!") + return fmt.Errorf("hybrid search task wait to finish timeout, msgID=%d", t.ID()) + default: + log.Ctx(ctx).Debug("all hybrid searches are finished or canceled") + t.resultBuf.Range(func(res *querypb.HybridSearchResult) bool { + for index, searchResult := range res.GetResults() { + t.searchTasks[index].resultBuf.Insert(searchResult) + } + log.Ctx(ctx).Debug("proxy receives one hybrid search result", + zap.Int64("sourceID", res.GetBase().GetSourceID())) + return true + }) + + t.multipleRecallResults = typeutil.NewConcurrentSet[*milvuspb.SearchResults]() + for i, searchTask := range t.searchTasks { + err := searchTask.PostExecute(ctx) + if err != nil { + return err + } + t.reScorers[i].reScore(searchTask.result) + t.multipleRecallResults.Insert(searchTask.result) + } + + return nil + } +} + func (t *hybridSearchTask) PostExecute(ctx context.Context) error { ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch-PostExecute") defer sp.End() - log := log.Ctx(ctx).With(zap.Int64("collID", t.collectionID), zap.String("collName", t.request.GetCollectionName())) + log := log.Ctx(ctx).With(zap.Int64("collID", t.CollectionID), zap.String("collName", t.request.GetCollectionName())) tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy postExecute hybrid search %d", t.ID())) defer func() { tr.CtxElapse(ctx, "done") }() + err := t.collectHybridSearchResults(ctx) + if err != nil { + log.Warn("failed to collect hybrid search results", zap.Error(err)) + return err + } + + t.queryChannelsTs = make(map[string]uint64) + for _, r := range t.resultBuf.Collect() { + for ch, ts := range r.GetChannelsMvcc() { + t.queryChannelsTs[ch] = ts + } + } + primaryFieldSchema, err := t.schema.GetPkField() if err != nil { log.Warn("failed to get primary field schema", zap.Error(err)) @@ -304,9 +437,8 @@ func (t *hybridSearchTask) Requery() error { }, } - // TODO:Xige-16 refine the mvcc functionality of hybrid search // TODO:silverxia move partitionIDs to hybrid search level - return doRequery(t.ctx, t.collectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, []int64{}) + return doRequery(t.ctx, t.CollectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, []int64{}) } func rankSearchResultData(ctx context.Context, @@ -436,11 +568,11 @@ func (t *hybridSearchTask) TraceCtx() context.Context { } func (t *hybridSearchTask) ID() UniqueID { - return t.request.Base.MsgID + return t.Base.MsgID } func (t *hybridSearchTask) SetID(uid UniqueID) { - t.request.Base.MsgID = uid + t.Base.MsgID = uid } func (t *hybridSearchTask) Name() string { @@ -448,24 +580,24 @@ func (t *hybridSearchTask) Name() string { } func (t *hybridSearchTask) Type() commonpb.MsgType { - return t.request.Base.MsgType + return t.Base.MsgType } func (t *hybridSearchTask) BeginTs() Timestamp { - return t.request.Base.Timestamp + return t.Base.Timestamp } func (t *hybridSearchTask) EndTs() Timestamp { - return t.request.Base.Timestamp + return t.Base.Timestamp } func (t *hybridSearchTask) SetTs(ts Timestamp) { - t.request.Base.Timestamp = ts + t.Base.Timestamp = ts } func (t *hybridSearchTask) OnEnqueue() error { - t.request.Base = commonpbutil.NewMsgBase() - t.request.Base.MsgType = commonpb.MsgType_Search - t.request.Base.SourceID = paramtable.GetNodeID() + t.Base = commonpbutil.NewMsgBase() + t.Base.MsgType = commonpb.MsgType_Search + t.Base.SourceID = paramtable.GetNodeID() return nil } diff --git a/internal/proxy/task_hybrid_search_test.go b/internal/proxy/task_hybrid_search_test.go index 0cee1f89db..1ea332891a 100644 --- a/internal/proxy/task_hybrid_search_test.go +++ b/internal/proxy/task_hybrid_search_test.go @@ -20,6 +20,7 @@ import ( "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -67,8 +68,9 @@ func TestHybridSearchTask_PreExecute(t *testing.T) { genHybridSearchTaskWithNq := func(t *testing.T, collName string, reqs []*milvuspb.SearchRequest) *hybridSearchTask { task := &hybridSearchTask{ - ctx: ctx, - Condition: NewTaskCondition(ctx), + ctx: ctx, + Condition: NewTaskCondition(ctx), + HybridSearchRequest: &internalpb.HybridSearchRequest{}, request: &milvuspb.HybridSearchRequest{ CollectionName: collName, Requests: reqs, @@ -225,6 +227,7 @@ func TestHybridSearchTask_ErrExecute(t *testing.T) { result: &milvuspb.SearchResults{ Status: merr.Success(), }, + HybridSearchRequest: &internalpb.HybridSearchRequest{}, request: &milvuspb.HybridSearchRequest{ CollectionName: collectionName, Requests: []*milvuspb.SearchRequest{ @@ -266,12 +269,12 @@ func TestHybridSearchTask_ErrExecute(t *testing.T) { task.ctx = ctx assert.NoError(t, task.PreExecute(ctx)) - qn.EXPECT().Search(mock.Anything, mock.Anything).Return(nil, errors.New("mock error")) + qn.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(nil, errors.New("mock error")) assert.Error(t, task.Execute(ctx)) qn.ExpectedCalls = nil qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() - qn.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{ + qn.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(&querypb.HybridSearchResult{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, }, @@ -291,6 +294,10 @@ func TestHybridSearchTask_PostExecute(t *testing.T) { mgr := NewMockShardClientManager(t) mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe() mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe() + qn.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(&querypb.HybridSearchResult{ + Base: commonpbutil.NewMsgBase(), + Status: merr.Success(), + }, nil) t.Run("Test empty result", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) @@ -313,6 +320,9 @@ func TestHybridSearchTask_PostExecute(t *testing.T) { qc: nil, tr: timerecord.NewTimeRecorder("search"), schema: schema, + HybridSearchRequest: &internalpb.HybridSearchRequest{ + Base: commonpbutil.NewMsgBase(), + }, request: &milvuspb.HybridSearchRequest{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_Search, @@ -320,6 +330,7 @@ func TestHybridSearchTask_PostExecute(t *testing.T) { CollectionName: collectionName, RankParams: rankParams, }, + resultBuf: typeutil.NewConcurrentSet[*querypb.HybridSearchResult](), multipleRecallResults: typeutil.NewConcurrentSet[*milvuspb.SearchResults](), } diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 60af81cbed..5af8a19d79 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -30,7 +30,6 @@ import ( "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/timerecord" - "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -53,10 +52,11 @@ type searchTask struct { result *milvuspb.SearchResults request *milvuspb.SearchRequest - tr *timerecord.TimeRecorder - collectionName string - schema *schemaInfo - requery bool + tr *timerecord.TimeRecorder + collectionName string + schema *schemaInfo + requery bool + partitionKeyMode bool userOutputFields []string @@ -250,22 +250,21 @@ func (t *searchTask) PreExecute(ctx context.Context) error { return err } - log := log.Ctx(ctx).With(zap.Int64("collID", collID), zap.String("collName", collectionName)) - t.SearchRequest.DbID = 0 // todo t.SearchRequest.CollectionID = collID + log := log.Ctx(ctx).With(zap.Int64("collID", collID), zap.String("collName", collectionName)) t.schema, err = globalMetaCache.GetCollectionSchema(ctx, t.request.GetDbName(), collectionName) if err != nil { log.Warn("get collection schema failed", zap.Error(err)) return err } - partitionKeyMode, err := isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName) + t.partitionKeyMode, err = isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName) if err != nil { log.Warn("is partition key mode failed", zap.Error(err)) return err } - if partitionKeyMode && len(t.request.GetPartitionNames()) != 0 { + if t.partitionKeyMode && len(t.request.GetPartitionNames()) != 0 { return errors.New("not support manually specifying the partition names if partition key mode is used") } @@ -277,123 +276,9 @@ func (t *searchTask) PreExecute(ctx context.Context) error { log.Debug("translate output fields", zap.Strings("output fields", t.request.GetOutputFields())) - // fetch search_growing from search param - var ignoreGrowing bool - for i, kv := range t.request.GetSearchParams() { - if kv.GetKey() == IgnoreGrowingKey { - ignoreGrowing, err = strconv.ParseBool(kv.GetValue()) - if err != nil { - return errors.New("parse search growing failed") - } - t.request.SearchParams = append(t.request.GetSearchParams()[:i], t.request.GetSearchParams()[i+1:]...) - break - } - } - t.SearchRequest.IgnoreGrowing = ignoreGrowing - - // Manually update nq if not set. - nq, err := getNq(t.request) + err = initSearchRequest(ctx, t) if err != nil { - log.Warn("failed to get nq", zap.Error(err)) - return err - } - // Check if nq is valid: - // https://milvus.io/docs/limitations.md - if err := validateNQLimit(nq); err != nil { - return fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err) - } - t.SearchRequest.Nq = nq - log = log.With(zap.Int64("nq", nq)) - - outputFieldIDs, err := getOutputFieldIDs(t.schema, t.request.GetOutputFields()) - if err != nil { - log.Warn("fail to get output field ids", zap.Error(err)) - return err - } - t.SearchRequest.OutputFieldsId = outputFieldIDs - - partitionNames := t.request.GetPartitionNames() - if t.request.GetDslType() == commonpb.DslType_BoolExprV1 { - annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams()) - if err != nil || len(annsField) == 0 { - vecFields := typeutil.GetVectorFieldSchemas(t.schema.CollectionSchema) - if len(vecFields) == 0 { - return errors.New(AnnsFieldKey + " not found in schema") - } - - if enableMultipleVectorFields && len(vecFields) > 1 { - return errors.New("multiple anns_fields exist, please specify a anns_field in search_params") - } - - annsField = vecFields[0].Name - } - queryInfo, offset, err := parseSearchInfo(t.request.GetSearchParams(), t.schema.CollectionSchema) - if err != nil { - return err - } - if queryInfo.GroupByFieldId != 0 { - t.SearchRequest.IgnoreGrowing = true - // for group by operation, currently, we ignore growing segments - } - t.offset = offset - - plan, err := planparserv2.CreateSearchPlan(t.schema.CollectionSchema, t.request.Dsl, annsField, queryInfo) - if err != nil { - log.Warn("failed to create query plan", zap.Error(err), - zap.String("dsl", t.request.Dsl), // may be very large if large term passed. - zap.String("anns field", annsField), zap.Any("query info", queryInfo)) - return merr.WrapErrParameterInvalidMsg("failed to create query plan: %v", err) - } - log.Debug("create query plan", - zap.String("dsl", t.request.Dsl), // may be very large if large term passed. - zap.String("anns field", annsField), zap.Any("query info", queryInfo)) - - if partitionKeyMode { - expr, err := ParseExprFromPlan(plan) - if err != nil { - log.Warn("failed to parse expr", zap.Error(err)) - return err - } - partitionKeys := ParsePartitionKeys(expr) - hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), collectionName, partitionKeys) - if err != nil { - log.Warn("failed to assign partition keys", zap.Error(err)) - return err - } - - partitionNames = append(partitionNames, hashedPartitionNames...) - } - - plan.OutputFieldIds = outputFieldIDs - - t.SearchRequest.Topk = queryInfo.GetTopk() - t.SearchRequest.MetricType = queryInfo.GetMetricType() - t.SearchRequest.DslType = commonpb.DslType_BoolExprV1 - - estimateSize, err := t.estimateResultSize(nq, t.SearchRequest.Topk) - if err != nil { - log.Warn("failed to estimate result size", zap.Error(err)) - return err - } - if estimateSize >= requeryThreshold { - t.requery = true - plan.OutputFieldIds = nil - } - - t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan) - if err != nil { - return err - } - - log.Debug("Proxy::searchTask::PreExecute", - zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()), - zap.Stringer("plan", plan)) // may be very large if large term passed. - } - - // translate partition name to partition ids. Use regex-pattern to match partition name. - t.SearchRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), collectionName, partitionNames) - if err != nil { - log.Warn("failed to get partition ids", zap.Error(err)) + log.Debug("init search request failed", zap.Error(err)) return err } @@ -421,17 +306,6 @@ func (t *searchTask) PreExecute(ctx context.Context) error { } t.SearchRequest.GuaranteeTimestamp = guaranteeTs - if deadline, ok := t.TraceCtx().Deadline(); ok { - t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0) - } - - t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup - - // Set username of this search request for feature like task scheduling. - if username, _ := GetCurUserFromContext(ctx); username != "" { - t.SearchRequest.Username = username - } - log.Debug("search PreExecute done.", zap.Uint64("guarantee_ts", guaranteeTs), zap.Bool("use_default_consistency", useDefaultConsistency), diff --git a/internal/querycoordv2/mocks/mock_querynode.go b/internal/querycoordv2/mocks/mock_querynode.go index 039d03ee6b..7227f53451 100644 --- a/internal/querycoordv2/mocks/mock_querynode.go +++ b/internal/querycoordv2/mocks/mock_querynode.go @@ -469,6 +469,61 @@ func (_c *MockQueryNodeServer_GetTimeTickChannel_Call) RunAndReturn(run func(con return _c } +// HybridSearch provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryNodeServer) HybridSearch(_a0 context.Context, _a1 *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) { + ret := _m.Called(_a0, _a1) + + var r0 *querypb.HybridSearchResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) *querypb.HybridSearchResult); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*querypb.HybridSearchResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.HybridSearchRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeServer_HybridSearch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HybridSearch' +type MockQueryNodeServer_HybridSearch_Call struct { + *mock.Call +} + +// HybridSearch is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *querypb.HybridSearchRequest +func (_e *MockQueryNodeServer_Expecter) HybridSearch(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_HybridSearch_Call { + return &MockQueryNodeServer_HybridSearch_Call{Call: _e.mock.On("HybridSearch", _a0, _a1)} +} + +func (_c *MockQueryNodeServer_HybridSearch_Call) Run(run func(_a0 context.Context, _a1 *querypb.HybridSearchRequest)) *MockQueryNodeServer_HybridSearch_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*querypb.HybridSearchRequest)) + }) + return _c +} + +func (_c *MockQueryNodeServer_HybridSearch_Call) Return(_a0 *querypb.HybridSearchResult, _a1 error) *MockQueryNodeServer_HybridSearch_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeServer_HybridSearch_Call) RunAndReturn(run func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)) *MockQueryNodeServer_HybridSearch_Call { + _c.Call.Return(run) + return _c +} + // LoadPartitions provides a mock function with given fields: _a0, _a1 func (_m *MockQueryNodeServer) LoadPartitions(_a0 context.Context, _a1 *querypb.LoadPartitionsRequest) (*commonpb.Status, error) { ret := _m.Called(_a0, _a1) diff --git a/internal/querynodev2/delegator/delegator.go b/internal/querynodev2/delegator/delegator.go index fde526b677..00c35c720f 100644 --- a/internal/querynodev2/delegator/delegator.go +++ b/internal/querynodev2/delegator/delegator.go @@ -62,6 +62,7 @@ type ShardDelegator interface { GetSegmentInfo(readable bool) (sealed []SnapshotItem, growing []SegmentEntry) SyncDistribution(ctx context.Context, entries ...SegmentEntry) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) + HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) Query(ctx context.Context, req *querypb.QueryRequest) ([]*internalpb.RetrieveResults, error) QueryStream(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) error GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) ([]*internalpb.GetStatisticsResponse, error) @@ -184,6 +185,44 @@ func (sd *shardDelegator) modifyQueryRequest(req *querypb.QueryRequest, scope qu return nodeReq } +// Search preforms search operation on shard. +func (sd *shardDelegator) search(ctx context.Context, req *querypb.SearchRequest, sealed []SnapshotItem, growing []SegmentEntry) ([]*internalpb.SearchResults, error) { + log := sd.getLogger(ctx) + if req.Req.IgnoreGrowing { + growing = []SegmentEntry{} + } + + sealedNum := lo.SumBy(sealed, func(item SnapshotItem) int { return len(item.Segments) }) + log.Debug("search segments...", + zap.Int("sealedNum", sealedNum), + zap.Int("growingNum", len(growing)), + ) + + req, err := optimizers.OptimizeSearchParams(ctx, req, sd.queryHook, sealedNum) + if err != nil { + log.Warn("failed to optimize search params", zap.Error(err)) + return nil, err + } + + tasks, err := organizeSubTask(ctx, req, sealed, growing, sd, sd.modifySearchRequest) + if err != nil { + log.Warn("Search organizeSubTask failed", zap.Error(err)) + return nil, err + } + + results, err := executeSubTasks(ctx, tasks, func(ctx context.Context, req *querypb.SearchRequest, worker cluster.Worker) (*internalpb.SearchResults, error) { + return worker.SearchSegments(ctx, req) + }, "Search", log) + if err != nil { + log.Warn("Delegator search failed", zap.Error(err)) + return nil, err + } + + log.Debug("Delegator search done") + + return results, nil +} + // Search preforms search operation on shard. func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) { log := sd.getLogger(ctx) @@ -229,39 +268,113 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest return funcutil.SliceContain(existPartitions, segment.PartitionID) }) - if req.Req.IgnoreGrowing { - growing = []SegmentEntry{} + return sd.search(ctx, req, sealed, growing) +} + +// HybridSearch preforms hybrid search operation on shard. +func (sd *shardDelegator) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) { + log := sd.getLogger(ctx) + if err := sd.lifetime.Add(lifetime.IsWorking); err != nil { + return nil, err + } + defer sd.lifetime.Done() + + if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) { + log.Warn("deletgator received hybrid search request not belongs to it", + zap.Strings("reqChannels", req.GetDmlChannels()), + ) + return nil, fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels()) } - sealedNum := lo.SumBy(sealed, func(item SnapshotItem) int { return len(item.Segments) }) - log.Debug("search segments...", - zap.Int("sealedNum", sealedNum), - zap.Int("growingNum", len(growing)), - ) + partitions := req.GetReq().GetPartitionIDs() + if !sd.collection.ExistPartition(partitions...) { + return nil, merr.WrapErrPartitionNotLoaded(partitions) + } - req, err = optimizers.OptimizeSearchParams(ctx, req, sd.queryHook, sealedNum) + // wait tsafe + waitTr := timerecord.NewTimeRecorder("wait tSafe") + tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp) + if err != nil { + log.Warn("delegator hybrid search failed to wait tsafe", zap.Error(err)) + return nil, err + } + if req.GetReq().GetMvccTimestamp() == 0 { + req.Req.MvccTimestamp = tSafe + } + metrics.QueryNodeSQLatencyWaitTSafe.WithLabelValues( + fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel). + Observe(float64(waitTr.ElapseSpan().Milliseconds())) + + sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...) + if err != nil { + log.Warn("delegator failed to hybrid search, current distribution is not serviceable") + return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not servcieable") + } + defer sd.distribution.Unpin(version) + existPartitions := sd.collection.GetPartitions() + growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool { + return funcutil.SliceContain(existPartitions, segment.PartitionID) + }) + + futures := make([]*conc.Future[*internalpb.SearchResults], len(req.GetReq().GetReqs())) + for index := range req.GetReq().GetReqs() { + request := req.GetReq().Reqs[index] + future := conc.Go(func() (*internalpb.SearchResults, error) { + searchReq := &querypb.SearchRequest{ + Req: request, + DmlChannels: req.GetDmlChannels(), + TotalChannelNum: req.GetTotalChannelNum(), + FromShardLeader: true, + } + searchReq.Req.GuaranteeTimestamp = req.GetReq().GetGuaranteeTimestamp() + searchReq.Req.TimeoutTimestamp = req.GetReq().GetTimeoutTimestamp() + if searchReq.GetReq().GetMvccTimestamp() == 0 { + searchReq.GetReq().MvccTimestamp = tSafe + } + + results, err := sd.search(ctx, searchReq, sealed, growing) + if err != nil { + return nil, err + } + + return segments.ReduceSearchResults(ctx, + results, + searchReq.Req.GetNq(), + searchReq.Req.GetTopk(), + searchReq.Req.GetMetricType()) + }) + futures[index] = future + } + + err = conc.AwaitAll(futures...) if err != nil { - log.Warn("failed to optimize search params", zap.Error(err)) return nil, err } - tasks, err := organizeSubTask(ctx, req, sealed, growing, sd, sd.modifySearchRequest) - if err != nil { - log.Warn("Search organizeSubTask failed", zap.Error(err)) - return nil, err + ret := &querypb.HybridSearchResult{ + Status: merr.Success(), + Results: make([]*internalpb.SearchResults, len(futures)), } - results, err := executeSubTasks(ctx, tasks, func(ctx context.Context, req *querypb.SearchRequest, worker cluster.Worker) (*internalpb.SearchResults, error) { - return worker.SearchSegments(ctx, req) - }, "Search", log) - if err != nil { - log.Warn("Delegator search failed", zap.Error(err)) - return nil, err + channelsMvcc := make(map[string]uint64) + for i, future := range futures { + result := future.Value() + if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + log.Debug("delegator hybrid search failed", + zap.String("reason", result.GetStatus().GetReason())) + return nil, merr.Error(result.GetStatus()) + } + + ret.Results[i] = result + for ch, ts := range result.GetChannelsMvcc() { + channelsMvcc[ch] = ts + } } + ret.ChannelsMvcc = channelsMvcc - log.Debug("Delegator search done") + log.Debug("Delegator hybrid search done") - return results, nil + return ret, nil } func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) error { diff --git a/internal/querynodev2/delegator/delegator_test.go b/internal/querynodev2/delegator/delegator_test.go index a793d3b0d2..2cfcd1115a 100644 --- a/internal/querynodev2/delegator/delegator_test.go +++ b/internal/querynodev2/delegator/delegator_test.go @@ -469,6 +469,251 @@ func (s *DelegatorSuite) TestSearch() { }) } +func (s *DelegatorSuite) TestHybridSearch() { + s.delegator.Start() + paramtable.SetNodeID(1) + s.initSegments() + s.Run("normal", func() { + defer func() { + s.workerManager.ExpectedCalls = nil + }() + workers := make(map[int64]*cluster.MockWorker) + worker1 := &cluster.MockWorker{} + worker2 := &cluster.MockWorker{} + + workers[1] = worker1 + workers[2] = worker2 + + worker1.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")). + Run(func(_ context.Context, req *querypb.SearchRequest) { + s.EqualValues(1, req.Req.GetBase().GetTargetID()) + s.True(req.GetFromShardLeader()) + if req.GetScope() == querypb.DataScope_Streaming { + s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels()) + s.ElementsMatch([]int64{1004}, req.GetSegmentIDs()) + } + if req.GetScope() == querypb.DataScope_Historical { + s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels()) + s.ElementsMatch([]int64{1000, 1001}, req.GetSegmentIDs()) + } + }).Return(&internalpb.SearchResults{}, nil) + worker2.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")). + Run(func(_ context.Context, req *querypb.SearchRequest) { + s.EqualValues(2, req.Req.GetBase().GetTargetID()) + s.True(req.GetFromShardLeader()) + s.Equal(querypb.DataScope_Historical, req.GetScope()) + s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels()) + s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs()) + }).Return(&internalpb.SearchResults{}, nil) + + s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker { + return workers[nodeID] + }, nil) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + results, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{ + Req: &internalpb.HybridSearchRequest{ + Base: commonpbutil.NewMsgBase(), + Reqs: []*internalpb.SearchRequest{ + {Base: commonpbutil.NewMsgBase()}, + {Base: commonpbutil.NewMsgBase()}, + }, + }, + DmlChannels: []string{s.vchannelName}, + }) + + s.NoError(err) + s.Equal(2, len(results.Results)) + }) + + s.Run("partition_not_loaded", func() { + defer func() { + s.workerManager.ExpectedCalls = nil + }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{ + Req: &internalpb.HybridSearchRequest{ + Base: commonpbutil.NewMsgBase(), + // not load partation -1,will return error + PartitionIDs: []int64{-1}, + }, + DmlChannels: []string{s.vchannelName}, + }) + + s.True(errors.Is(err, merr.ErrPartitionNotLoaded)) + }) + + s.Run("worker_return_error", func() { + defer func() { + s.workerManager.ExpectedCalls = nil + }() + workers := make(map[int64]*cluster.MockWorker) + worker1 := &cluster.MockWorker{} + worker2 := &cluster.MockWorker{} + + workers[1] = worker1 + workers[2] = worker2 + + worker1.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).Return(nil, errors.New("mock error")) + worker2.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")). + Run(func(_ context.Context, req *querypb.SearchRequest) { + s.EqualValues(2, req.Req.GetBase().GetTargetID()) + s.True(req.GetFromShardLeader()) + s.Equal(querypb.DataScope_Historical, req.GetScope()) + s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels()) + s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs()) + }).Return(&internalpb.SearchResults{}, nil) + + s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker { + return workers[nodeID] + }, nil) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{ + Req: &internalpb.HybridSearchRequest{ + Base: commonpbutil.NewMsgBase(), + Reqs: []*internalpb.SearchRequest{ + { + Base: commonpbutil.NewMsgBase(), + }, + }, + }, + DmlChannels: []string{s.vchannelName}, + }) + + s.Error(err) + }) + + s.Run("worker_return_failure_code", func() { + defer func() { + s.workerManager.ExpectedCalls = nil + }() + workers := make(map[int64]*cluster.MockWorker) + worker1 := &cluster.MockWorker{} + worker2 := &cluster.MockWorker{} + + workers[1] = worker1 + workers[2] = worker2 + + worker1.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).Return(&internalpb.SearchResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "mocked error", + }, + }, nil) + worker2.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")). + Run(func(_ context.Context, req *querypb.SearchRequest) { + s.EqualValues(2, req.Req.GetBase().GetTargetID()) + s.True(req.GetFromShardLeader()) + s.Equal(querypb.DataScope_Historical, req.GetScope()) + s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels()) + s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs()) + }).Return(&internalpb.SearchResults{}, nil) + + s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker { + return workers[nodeID] + }, nil) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{ + Req: &internalpb.HybridSearchRequest{ + Base: commonpbutil.NewMsgBase(), + Reqs: []*internalpb.SearchRequest{ + { + Base: commonpbutil.NewMsgBase(), + }, + }, + }, + DmlChannels: []string{s.vchannelName}, + }) + + s.Error(err) + }) + + s.Run("wrong_channel", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{ + Req: &internalpb.HybridSearchRequest{ + Base: commonpbutil.NewMsgBase(), + }, + DmlChannels: []string{"non_exist_channel"}, + }) + + s.Error(err) + }) + + s.Run("wait_tsafe_timeout", func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) + defer cancel() + + _, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{ + Req: &internalpb.HybridSearchRequest{ + Base: commonpbutil.NewMsgBase(), + GuaranteeTimestamp: 10100, + }, + DmlChannels: []string{s.vchannelName}, + }) + + s.Error(err) + }) + + s.Run("tsafe_behind_max_lag", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{ + Req: &internalpb.HybridSearchRequest{ + Base: commonpbutil.NewMsgBase(), + GuaranteeTimestamp: uint64(paramtable.Get().QueryNodeCfg.MaxTimestampLag.GetAsDuration(time.Second)) + 10001, + }, + DmlChannels: []string{s.vchannelName}, + }) + + s.Error(err) + }) + + s.Run("distribution_not_serviceable", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sd, ok := s.delegator.(*shardDelegator) + s.Require().True(ok) + sd.distribution.AddOfflines(1001) + + _, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{ + Req: &internalpb.HybridSearchRequest{ + Base: commonpbutil.NewMsgBase(), + }, + DmlChannels: []string{s.vchannelName}, + }) + + s.Error(err) + }) + + s.Run("cluster_not_serviceable", func() { + s.delegator.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{ + Req: &internalpb.HybridSearchRequest{ + Base: commonpbutil.NewMsgBase(), + }, + DmlChannels: []string{s.vchannelName}, + }) + + s.Error(err) + }) +} + func (s *DelegatorSuite) TestQuery() { s.delegator.Start() paramtable.SetNodeID(1) diff --git a/internal/querynodev2/delegator/mock_delegator.go b/internal/querynodev2/delegator/mock_delegator.go index ae0191b443..1e2ddc2960 100644 --- a/internal/querynodev2/delegator/mock_delegator.go +++ b/internal/querynodev2/delegator/mock_delegator.go @@ -253,6 +253,61 @@ func (_c *MockShardDelegator_GetTargetVersion_Call) RunAndReturn(run func() int6 return _c } +// HybridSearch provides a mock function with given fields: ctx, req +func (_m *MockShardDelegator) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) { + ret := _m.Called(ctx, req) + + var r0 *querypb.HybridSearchResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) *querypb.HybridSearchResult); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*querypb.HybridSearchResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.HybridSearchRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockShardDelegator_HybridSearch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HybridSearch' +type MockShardDelegator_HybridSearch_Call struct { + *mock.Call +} + +// HybridSearch is a helper method to define mock.On call +// - ctx context.Context +// - req *querypb.HybridSearchRequest +func (_e *MockShardDelegator_Expecter) HybridSearch(ctx interface{}, req interface{}) *MockShardDelegator_HybridSearch_Call { + return &MockShardDelegator_HybridSearch_Call{Call: _e.mock.On("HybridSearch", ctx, req)} +} + +func (_c *MockShardDelegator_HybridSearch_Call) Run(run func(ctx context.Context, req *querypb.HybridSearchRequest)) *MockShardDelegator_HybridSearch_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*querypb.HybridSearchRequest)) + }) + return _c +} + +func (_c *MockShardDelegator_HybridSearch_Call) Return(_a0 *querypb.HybridSearchResult, _a1 error) *MockShardDelegator_HybridSearch_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockShardDelegator_HybridSearch_Call) RunAndReturn(run func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)) *MockShardDelegator_HybridSearch_Call { + _c.Call.Return(run) + return _c +} + // LoadGrowing provides a mock function with given fields: ctx, infos, version func (_m *MockShardDelegator) LoadGrowing(ctx context.Context, infos []*querypb.SegmentLoadInfo, version int64) error { ret := _m.Called(ctx, infos, version) diff --git a/internal/querynodev2/handlers.go b/internal/querynodev2/handlers.go index c27dd7a64e..d4afc8df1c 100644 --- a/internal/querynodev2/handlers.go +++ b/internal/querynodev2/handlers.go @@ -401,6 +401,63 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq return resp, nil } +func (node *QueryNode) hybridSearchChannel(ctx context.Context, req *querypb.HybridSearchRequest, channel string) (*querypb.HybridSearchResult, error) { + log := log.Ctx(ctx).With( + zap.Int64("msgID", req.GetReq().GetBase().GetMsgID()), + zap.Int64("collectionID", req.Req.GetCollectionID()), + zap.String("channel", channel), + ) + traceID := trace.SpanFromContext(ctx).SpanContext().TraceID() + + if err := node.lifetime.Add(merr.IsHealthy); err != nil { + return nil, err + } + defer node.lifetime.Done() + + var err error + metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.TotalLabel, metrics.Leader).Inc() + defer func() { + if err != nil { + metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.FailLabel, metrics.Leader).Inc() + } + }() + + log.Debug("start to search channel") + searchCtx, cancel := context.WithCancel(ctx) + defer cancel() + + // From Proxy + tr := timerecord.NewTimeRecorder("hybridSearchDelegator") + // get delegator + sd, ok := node.delegators.Get(channel) + if !ok { + err := merr.WrapErrChannelNotFound(channel) + log.Warn("Query failed, failed to get shard delegator for search", zap.Error(err)) + return nil, err + } + // do hybrid search + result, err := sd.HybridSearch(searchCtx, req) + if err != nil { + log.Warn("failed to hybrid search on delegator", zap.Error(err)) + return nil, err + } + + tr.CtxElapse(ctx, fmt.Sprintf("do search with channel done , traceID = %s, vChannel = %s", + traceID, + channel, + )) + + // update metric to prometheus + latency := tr.ElapseSpan() + metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.Leader).Observe(float64(latency.Milliseconds())) + metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.SuccessLabel, metrics.Leader).Inc() + for _, searchReq := range req.GetReq().GetReqs() { + metrics.QueryNodeSearchNQ.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(searchReq.GetNq())) + metrics.QueryNodeSearchTopK.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(searchReq.GetTopk())) + } + return result, nil +} + func (node *QueryNode) getChannelStatistics(ctx context.Context, req *querypb.GetStatisticsRequest, channel string) (*internalpb.GetStatisticsResponse, error) { log := log.Ctx(ctx).With( zap.Int64("collectionID", req.Req.GetCollectionID()), diff --git a/internal/querynodev2/services.go b/internal/querynodev2/services.go index 5254c42464..11de702bfc 100644 --- a/internal/querynodev2/services.go +++ b/internal/querynodev2/services.go @@ -821,6 +821,114 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) ( return result, nil } +// HybridSearch performs replica search tasks. +func (node *QueryNode) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) { + log := log.Ctx(ctx).With( + zap.Int64("collectionID", req.GetReq().GetCollectionID()), + zap.Strings("channels", req.GetDmlChannels())) + + log.Debug("Received HybridSearchRequest", + zap.Uint64("guaranteeTimestamp", req.GetReq().GetGuaranteeTimestamp()), + zap.Uint64("mvccTimestamp", req.GetReq().GetMvccTimestamp())) + + tr := timerecord.NewTimeRecorderWithTrace(ctx, "HybridSearchRequest") + + if err := node.lifetime.Add(merr.IsHealthy); err != nil { + return &querypb.HybridSearchResult{ + Base: &commonpb.MsgBase{ + SourceID: paramtable.GetNodeID(), + }, + Status: merr.Status(err), + }, nil + } + defer node.lifetime.Done() + + err := merr.CheckTargetID(req.GetReq().GetBase()) + if err != nil { + log.Warn("target ID check failed", zap.Error(err)) + return &querypb.HybridSearchResult{ + Base: &commonpb.MsgBase{ + SourceID: paramtable.GetNodeID(), + }, + Status: merr.Status(err), + }, nil + } + + resp := &querypb.HybridSearchResult{ + Base: &commonpb.MsgBase{ + SourceID: paramtable.GetNodeID(), + }, + Status: merr.Success(), + } + collection := node.manager.Collection.Get(req.GetReq().GetCollectionID()) + if collection == nil { + resp.Status = merr.Status(merr.WrapErrCollectionNotFound(req.GetReq().GetCollectionID())) + return resp, nil + } + + MultipleResults := make([]*querypb.HybridSearchResult, len(req.GetDmlChannels())) + runningGp, runningCtx := errgroup.WithContext(ctx) + + for i, ch := range req.GetDmlChannels() { + ch := ch + req := &querypb.HybridSearchRequest{ + Req: req.Req, + DmlChannels: []string{ch}, + TotalChannelNum: 1, + } + + i := i + runningGp.Go(func() error { + ret, err := node.hybridSearchChannel(runningCtx, req, ch) + if err != nil { + return err + } + if err := merr.Error(ret.GetStatus()); err != nil { + return err + } + MultipleResults[i] = ret + return nil + }) + } + if err := runningGp.Wait(); err != nil { + resp.Status = merr.Status(err) + return resp, nil + } + + tr.RecordSpan() + channelsMvcc := make(map[string]uint64) + for i, searchReq := range req.GetReq().GetReqs() { + toReduceResults := make([]*internalpb.SearchResults, len(MultipleResults)) + for index, hs := range MultipleResults { + toReduceResults[index] = hs.Results[i] + } + result, err := segments.ReduceSearchResults(ctx, toReduceResults, searchReq.GetNq(), searchReq.GetTopk(), searchReq.GetMetricType()) + if err != nil { + log.Warn("failed to reduce search results", zap.Error(err)) + resp.Status = merr.Status(err) + return resp, nil + } + for ch, ts := range result.GetChannelsMvcc() { + channelsMvcc[ch] = ts + } + resp.Results = append(resp.Results, result) + } + resp.ChannelsMvcc = channelsMvcc + + reduceLatency := tr.RecordSpan() + metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.ReduceShards). + Observe(float64(reduceLatency.Milliseconds())) + + collector.Rate.Add(metricsinfo.SearchThroughput, float64(proto.Size(req))) + metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.HybridSearchLabel). + Add(float64(proto.Size(req))) + + if resp.GetCostAggregation() != nil { + resp.GetCostAggregation().ResponseTime = tr.ElapseSpan().Milliseconds() + } + return resp, nil +} + // only used for delegator query segments from worker func (node *QueryNode) QuerySegments(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) { resp := &internalpb.RetrieveResults{ diff --git a/internal/querynodev2/services_test.go b/internal/querynodev2/services_test.go index c8048ebca5..308d729433 100644 --- a/internal/querynodev2/services_test.go +++ b/internal/querynodev2/services_test.go @@ -1323,6 +1323,47 @@ func (suite *ServiceSuite) TestSearchSegments_Failed() { suite.Equal(commonpb.ErrorCode_UnexpectedError, rsp.GetStatus().GetErrorCode()) } +func (suite *ServiceSuite) TestHybridSearch_Concurrent() { + ctx := context.Background() + // pre + suite.TestWatchDmChannelsInt64() + suite.TestLoadSegments_Int64() + + concurrency := 16 + futures := make([]*conc.Future[*querypb.HybridSearchResult], 0, concurrency) + for i := 0; i < concurrency; i++ { + future := conc.Go(func() (*querypb.HybridSearchResult, error) { + creq1, err := suite.genCSearchRequest(30, schemapb.DataType_FloatVector, 107, defaultMetricType) + suite.NoError(err) + creq2, err := suite.genCSearchRequest(30, schemapb.DataType_FloatVector, 107, defaultMetricType) + suite.NoError(err) + req := &querypb.HybridSearchRequest{ + Req: &internalpb.HybridSearchRequest{ + Base: &commonpb.MsgBase{ + MsgID: rand.Int63(), + TargetID: suite.node.session.ServerID, + }, + CollectionID: suite.collectionID, + PartitionIDs: suite.partitionIDs, + MvccTimestamp: typeutil.MaxTimestamp, + Reqs: []*internalpb.SearchRequest{creq1, creq2}, + }, + DmlChannels: []string{suite.vchannel}, + } + + return suite.node.HybridSearch(ctx, req) + }) + futures = append(futures, future) + } + + err := conc.AwaitAll(futures...) + suite.NoError(err) + + for i := range futures { + suite.True(merr.Ok(futures[i].Value().GetStatus())) + } +} + func (suite *ServiceSuite) TestSearchSegments_Normal() { ctx := context.Background() // pre diff --git a/internal/util/mock/grpc_querynode_client.go b/internal/util/mock/grpc_querynode_client.go index e20dc0d635..5eaeb1fa2f 100644 --- a/internal/util/mock/grpc_querynode_client.go +++ b/internal/util/mock/grpc_querynode_client.go @@ -82,6 +82,10 @@ func (m *GrpcQueryNodeClient) Search(ctx context.Context, in *querypb.SearchRequ return &internalpb.SearchResults{}, m.Err } +func (m *GrpcQueryNodeClient) HybridSearch(ctx context.Context, in *querypb.HybridSearchRequest, opts ...grpc.CallOption) (*querypb.HybridSearchResult, error) { + return &querypb.HybridSearchResult{}, m.Err +} + func (m *GrpcQueryNodeClient) SearchSegments(ctx context.Context, in *querypb.SearchRequest, opts ...grpc.CallOption) (*internalpb.SearchResults, error) { return &internalpb.SearchResults{}, m.Err } diff --git a/internal/util/wrappers/qn_wrapper.go b/internal/util/wrappers/qn_wrapper.go index 63147c0116..90bbffab15 100644 --- a/internal/util/wrappers/qn_wrapper.go +++ b/internal/util/wrappers/qn_wrapper.go @@ -93,6 +93,10 @@ func (qn *qnServerWrapper) Search(ctx context.Context, in *querypb.SearchRequest return qn.QueryNode.Search(ctx, in) } +func (qn *qnServerWrapper) HybridSearch(ctx context.Context, in *querypb.HybridSearchRequest, opts ...grpc.CallOption) (*querypb.HybridSearchResult, error) { + return qn.QueryNode.HybridSearch(ctx, in) +} + func (qn *qnServerWrapper) SearchSegments(ctx context.Context, in *querypb.SearchRequest, opts ...grpc.CallOption) (*internalpb.SearchResults, error) { return qn.QueryNode.SearchSegments(ctx, in) }