milvus/internal/querynodev2/handlers.go
yah01 081572d31c
Refactor QueryNode (#21625)
Signed-off-by: yah01 <yang.cen@zilliz.com>
Co-authored-by: Congqi Xia <congqi.xia@zilliz.com>
Co-authored-by: aoiasd <zhicheng.yue@zilliz.com>
2023-03-27 00:42:00 +08:00

456 lines
16 KiB
Go

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querynodev2
import (
"context"
"fmt"
"strconv"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/querynodev2/tasks"
"github.com/milvus-io/milvus/internal/util"
"github.com/milvus-io/milvus/internal/util/commonpbutil"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/timerecord"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
)
func loadGrowingSegments(ctx context.Context, delegator delegator.ShardDelegator, req *querypb.WatchDmChannelsRequest) error {
// load growing segments
growingSegments := make([]*querypb.SegmentLoadInfo, 0, len(req.Infos))
for _, info := range req.Infos {
for _, segmentID := range info.GetUnflushedSegmentIds() {
// unFlushed segment may not have binLogs, skip loading
segmentInfo := req.GetSegmentInfos()[segmentID]
if segmentInfo == nil {
log.Warn("an unflushed segment is not found in segment infos", zap.Int64("segment ID", segmentID))
continue
}
if len(segmentInfo.GetBinlogs()) > 0 {
growingSegments = append(growingSegments, &querypb.SegmentLoadInfo{
SegmentID: segmentInfo.ID,
PartitionID: segmentInfo.PartitionID,
CollectionID: segmentInfo.CollectionID,
BinlogPaths: segmentInfo.Binlogs,
NumOfRows: segmentInfo.NumOfRows,
Statslogs: segmentInfo.Statslogs,
Deltalogs: segmentInfo.Deltalogs,
InsertChannel: segmentInfo.InsertChannel,
})
} else {
log.Info("skip segment which binlog is empty", zap.Int64("segmentID", segmentInfo.ID))
}
}
}
return delegator.LoadGrowing(ctx, growingSegments, req.GetVersion())
}
func (node *QueryNode) loadDeltaLogs(ctx context.Context, req *querypb.LoadSegmentsRequest) *commonpb.Status {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
)
loadedSegments := make([]int64, 0, len(req.GetInfos()))
var finalErr error
for _, info := range req.GetInfos() {
segment := node.manager.Segment.GetSealed(info.GetSegmentID())
if segment == nil {
continue
}
local := segment.(*segments.LocalSegment)
err := node.loader.LoadDeltaLogs(ctx, local, info.GetDeltalogs())
if err != nil {
if finalErr == nil {
finalErr = err
}
continue
}
loadedSegments = append(loadedSegments, info.GetSegmentID())
}
if finalErr != nil {
log.Warn("failed to load delta logs", zap.Error(finalErr))
return util.WrapStatus(commonpb.ErrorCode_UnexpectedError, "failed to load delta logs", finalErr)
}
return util.SuccessStatus()
}
func (node *QueryNode) queryChannel(ctx context.Context, req *querypb.QueryRequest, channel string) (*internalpb.RetrieveResults, error) {
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.TotalLabel).Inc()
failRet := WrapRetrieveResult(commonpb.ErrorCode_UnexpectedError, "")
msgID := req.Req.Base.GetMsgID()
traceID := trace.SpanFromContext(ctx).SpanContext().TraceID()
log := log.Ctx(ctx).With(
zap.Int64("msgID", msgID),
zap.Int64("collectionID", req.GetReq().GetCollectionID()),
zap.String("channel", channel),
zap.String("scope", req.GetScope().String()),
)
defer func() {
if failRet.Status.ErrorCode != commonpb.ErrorCode_Success {
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.FailLabel).Inc()
}
}()
if !node.lifetime.Add(commonpbutil.IsHealthy) {
failRet.Status.Reason = msgQueryNodeIsUnhealthy(paramtable.GetNodeID())
return failRet, nil
}
defer node.lifetime.Done()
log.Debug("start do query with channel",
zap.Bool("fromShardLeader", req.GetFromShardLeader()),
zap.Int64s("segmentIDs", req.GetSegmentIDs()),
)
// add cancel when error occurs
queryCtx, cancel := context.WithCancel(ctx)
defer cancel()
// TODO From Shard Delegator
if req.FromShardLeader {
tr := timerecord.NewTimeRecorder("queryChannel")
results, err := node.querySegments(queryCtx, req)
if err != nil {
log.Warn("failed to query channel", zap.Error(err))
failRet.Status.Reason = err.Error()
return failRet, nil
}
tr.CtxElapse(ctx, fmt.Sprintf("do query done, traceID = %s, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
traceID,
req.GetFromShardLeader(),
channel,
req.GetSegmentIDs(),
))
failRet.Status.ErrorCode = commonpb.ErrorCode_Success
// TODO QueryNodeSQLatencyInQueue QueryNodeReduceLatency
latency := tr.ElapseSpan()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel).Inc()
return results, nil
}
// From Proxy
tr := timerecord.NewTimeRecorder("queryDelegator")
// get delegator
sd, ok := node.delegators.Get(channel)
if !ok {
log.Warn("Query failed, failed to get query shard delegator", zap.Error(ErrGetDelegatorFailed))
failRet.Status.Reason = ErrGetDelegatorFailed.Error()
return failRet, nil
}
// do query
results, err := sd.Query(queryCtx, req)
if err != nil {
log.Warn("failed to query on delegator", zap.Error(err))
failRet.Status.Reason = err.Error()
return failRet, nil
}
// reduce result
tr.CtxElapse(ctx, fmt.Sprintf("start reduce query result, traceID = %s, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
traceID,
req.GetFromShardLeader(),
channel,
req.GetSegmentIDs(),
))
collection := node.manager.Collection.Get(req.Req.GetCollectionID())
if collection == nil {
log.Warn("Query failed, failed to get collection")
failRet.Status.Reason = segments.WrapCollectionNotFound(req.Req.CollectionID).Error()
return failRet, nil
}
ret, err := segments.MergeInternalRetrieveResultsAndFillIfEmpty(ctx, results, req.Req.GetLimit(), req.GetReq().GetOutputFieldsId(), collection.Schema())
if err != nil {
failRet.Status.Reason = err.Error()
return failRet, nil
}
tr.CtxElapse(ctx, fmt.Sprintf("do query with channel done , vChannel = %s, segmentIDs = %v",
channel,
req.GetSegmentIDs(),
))
//
failRet.Status.ErrorCode = commonpb.ErrorCode_Success
latency := tr.ElapseSpan()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.Leader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel).Inc()
return ret, nil
}
func (node *QueryNode) querySegments(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
collection := node.manager.Collection.Get(req.Req.GetCollectionID())
if collection == nil {
return nil, segments.ErrCollectionNotFound
}
// build plan
retrievePlan, err := segments.NewRetrievePlan(
collection,
req.Req.GetSerializedExprPlan(),
req.Req.GetTravelTimestamp(),
req.Req.Base.GetMsgID(),
)
if err != nil {
return nil, err
}
defer retrievePlan.Delete()
var results []*segcorepb.RetrieveResults
if req.GetScope() == querypb.DataScope_Historical {
results, _, _, err = segments.RetrieveHistorical(ctx, node.manager, retrievePlan, req.Req.CollectionID, nil, req.GetSegmentIDs(), node.cacheChunkManager)
} else {
results, _, _, err = segments.RetrieveStreaming(ctx, node.manager, retrievePlan, req.Req.CollectionID, nil, req.GetSegmentIDs(), node.cacheChunkManager)
}
if err != nil {
return nil, err
}
reducedResult, err := segments.MergeSegcoreRetrieveResultsAndFillIfEmpty(ctx, results, req.Req.GetLimit(), req.Req.GetOutputFieldsId(), collection.Schema())
if err != nil {
return nil, err
}
return &internalpb.RetrieveResults{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
Ids: reducedResult.Ids,
FieldsData: reducedResult.FieldsData,
}, nil
}
func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchRequest, channel string) (*internalpb.SearchResults, error) {
log := log.Ctx(ctx).With(
zap.Int64("msgID", req.GetReq().GetBase().GetMsgID()),
zap.Int64("collectionID", req.Req.GetCollectionID()),
zap.String("channel", channel),
zap.String("scope", req.GetScope().String()),
)
traceID := trace.SpanFromContext(ctx).SpanContext().TraceID()
if !node.lifetime.Add(commonpbutil.IsHealthy) {
return nil, WrapErrNodeUnhealthy(paramtable.GetNodeID())
}
defer node.lifetime.Done()
log.Debug("start to search channel",
zap.Bool("fromShardLeader", req.GetFromShardLeader()),
zap.Int64s("segmentIDs", req.GetSegmentIDs()),
)
searchCtx, cancel := context.WithCancel(ctx)
defer cancel()
// TODO From Shard Delegator
if req.GetFromShardLeader() {
tr := timerecord.NewTimeRecorder("searchChannel")
log.Debug("search channel...")
collection := node.manager.Collection.Get(req.Req.GetCollectionID())
if collection == nil {
log.Warn("failed to search channel", zap.Error(segments.ErrCollectionNotFound))
return nil, segments.WrapCollectionNotFound(req.GetReq().GetCollectionID())
}
task := tasks.NewSearchTask(searchCtx, collection, node.manager, req)
if !node.scheduler.Add(task) {
log.Warn("failed to search channel", zap.Error(tasks.ErrTaskQueueFull))
return nil, tasks.ErrTaskQueueFull
}
err := task.Wait()
if err != nil {
log.Warn("failed to search channel", zap.Error(err))
return nil, err
}
tr.CtxElapse(ctx, fmt.Sprintf("search channel done, channel = %s, segmentIDs = %v",
channel,
req.GetSegmentIDs(),
))
// TODO QueryNodeSQLatencyInQueue QueryNodeReduceLatency
latency := tr.ElapseSpan()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel).Inc()
return task.Result(), nil
}
// From Proxy
tr := timerecord.NewTimeRecorder("searchDelegator")
// get delegator
sd, ok := node.delegators.Get(channel)
if !ok {
log.Warn("Query failed, failed to get query shard delegator", zap.Error(ErrGetDelegatorFailed))
return nil, ErrGetDelegatorFailed
}
// do search
results, err := sd.Search(searchCtx, req)
if err != nil {
log.Warn("failed to search on delegator", zap.Error(err))
return nil, err
}
// reduce result
tr.CtxElapse(ctx, fmt.Sprintf("start reduce query result, traceID = %s, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
traceID,
req.GetFromShardLeader(),
channel,
req.GetSegmentIDs(),
))
ret, err := segments.ReduceSearchResults(ctx, results, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType())
if err != nil {
return nil, err
}
tr.CtxElapse(ctx, fmt.Sprintf("do search with channel done , vChannel = %s, segmentIDs = %v",
channel,
req.GetSegmentIDs(),
))
// update metric to prometheus
latency := tr.ElapseSpan()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.Leader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel).Inc()
metrics.QueryNodeSearchNQ.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(req.Req.GetNq()))
metrics.QueryNodeSearchTopK.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(req.Req.GetTopk()))
return ret, nil
}
func (node *QueryNode) getChannelStatistics(ctx context.Context, req *querypb.GetStatisticsRequest, channel string) (*internalpb.GetStatisticsResponse, error) {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", req.Req.GetCollectionID()),
zap.String("channel", channel),
zap.String("scope", req.GetScope().String()),
)
failRet := &internalpb.GetStatisticsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
},
}
if req.GetFromShardLeader() {
var results []segments.SegmentStats
var err error
switch req.GetScope() {
case querypb.DataScope_Historical:
results, _, _, err = segments.StatisticsHistorical(ctx, node.manager, req.Req.GetCollectionID(), req.Req.GetPartitionIDs(), req.GetSegmentIDs())
case querypb.DataScope_Streaming:
results, _, _, err = segments.StatisticStreaming(ctx, node.manager, req.Req.GetCollectionID(), req.Req.GetPartitionIDs(), req.GetSegmentIDs())
}
if err != nil {
log.Warn("get segments statistics failed", zap.Error(err))
return nil, err
}
return segmentStatsResponse(results), nil
}
sd, ok := node.delegators.Get(channel)
if !ok {
log.Warn("GetStatistics failed, failed to get query shard delegator")
return failRet, nil
}
results, err := sd.GetStatistics(ctx, req)
if err != nil {
log.Warn("failed to get statistics from delegator", zap.Error(err))
failRet.Status.Reason = err.Error()
return failRet, nil
}
ret, err := reduceStatisticResponse(results)
if err != nil {
failRet.Status.Reason = err.Error()
return failRet, nil
}
return ret, nil
}
func segmentStatsResponse(segStats []segments.SegmentStats) *internalpb.GetStatisticsResponse {
var totalRowNum int64
for _, stats := range segStats {
totalRowNum += stats.RowCount
}
resultMap := make(map[string]string)
resultMap["row_count"] = strconv.FormatInt(totalRowNum, 10)
ret := &internalpb.GetStatisticsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
Stats: funcutil.Map2KeyValuePair(resultMap),
}
return ret
}
func reduceStatisticResponse(results []*internalpb.GetStatisticsResponse) (*internalpb.GetStatisticsResponse, error) {
mergedResults := map[string]interface{}{
"row_count": int64(0),
}
fieldMethod := map[string]func(string) error{
"row_count": func(str string) error {
count, err := strconv.ParseInt(str, 10, 64)
if err != nil {
return err
}
mergedResults["row_count"] = mergedResults["row_count"].(int64) + count
return nil
},
}
for _, partialResult := range results {
for _, pair := range partialResult.Stats {
fn, ok := fieldMethod[pair.Key]
if !ok {
return nil, fmt.Errorf("unknown statistic field: %s", pair.Key)
}
if err := fn(pair.Value); err != nil {
return nil, err
}
}
}
stringMap := make(map[string]string)
for k, v := range mergedResults {
stringMap[k] = fmt.Sprint(v)
}
ret := &internalpb.GetStatisticsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
Stats: funcutil.Map2KeyValuePair(stringMap),
}
return ret, nil
}