milvus/internal/querynodev2/tasks/search_task.go
SimFG 90bed1caf9
enhance: add the related data size for the read apis (#31816)
issue: #30436
origin pr: #30438
related pr: #31772

---------

Signed-off-by: SimFG <bang.fu@zilliz.com>
2024-04-10 15:07:17 +08:00

385 lines
10 KiB
Go

package tasks
// TODO: rename this file to search_task.go (NOTE(review): this file already appears to be named search_task.go, so this TODO may be stale).
import (
"bytes"
"context"
"fmt"
"strconv"
"github.com/golang/protobuf/proto"
"github.com/samber/lo"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/collector"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
// Compile-time checks that *SearchTask implements both the Task and the
// MergeTask interfaces.
var (
	_ Task      = (*SearchTask)(nil)
	_ MergeTask = (*SearchTask)(nil)
)
// SearchTask is the query node's schedulable unit for one search request.
// Compatible tasks can be merged (see Merge). The task that absorbs the
// others runs a single combined search in Execute and hands each merged
// task its own slice of the results.
type SearchTask struct {
	// ctx carries the "schedule" tracing span started in NewSearchTask.
	ctx context.Context
	// collection is the collection being searched. It is also used for
	// metric labels (db name, resource group) and for the GPU-index check.
	collection *segments.Collection
	// segmentManager is used to search segments and to unpin them afterwards.
	segmentManager *segments.Manager
	// req is the search request this task was created for.
	req *querypb.SearchRequest
	// result is filled in by Execute and returned from Result.
	result *internalpb.SearchResults
	// merged is set once this task has been folded into another task via
	// Merge. Done then skips the group-level metrics for it.
	merged bool
	// groupSize counts the tasks in this task's group, itself included.
	groupSize int64
	// topk is the largest topk across the group.
	topk int64
	// nq is the total number of queries across the group.
	nq int64
	// placeholderGroup holds the serialized query vectors. After
	// combinePlaceHolderGroups it holds the vectors of the whole group.
	placeholderGroup []byte
	// originTopks and originNqs hold each task's own topk/nq in merge order.
	// Entry 0 is this task and entry i > 0 is others[i-1].
	originTopks []int64
	originNqs   []int64
	// others are the tasks merged into this one.
	others []*SearchTask
	// notifier delivers the task's completion error from Done to Wait
	// (buffered, capacity 1).
	notifier chan error
	// serverID is the query node ID, used as SourceID and as a metric label.
	serverID int64
	// tr is started at construction. PreExecute reads its elapsed span as the
	// time spent waiting in the queue.
	tr *timerecord.TimeRecorder
	// scheduleSpan is the tracing span ended when Execute starts.
	scheduleSpan trace.Span
}
// NewSearchTask builds a SearchTask for req.
//
// It opens a "schedule" tracing span on ctx, which Execute closes once the
// task starts running. It also seeds the merge bookkeeping (group size,
// nq/topk and the per-task origin slices) from this single request.
func NewSearchTask(ctx context.Context,
	collection *segments.Collection,
	manager *segments.Manager,
	req *querypb.SearchRequest,
	serverID int64,
) *SearchTask {
	spanCtx, scheduleSpan := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "schedule")

	inner := req.GetReq()
	topk, nq := inner.GetTopk(), inner.GetNq()

	return &SearchTask{
		ctx:              spanCtx,
		collection:       collection,
		segmentManager:   manager,
		req:              req,
		merged:           false,
		groupSize:        1,
		topk:             topk,
		nq:               nq,
		placeholderGroup: inner.GetPlaceholderGroup(),
		originTopks:      []int64{topk},
		originNqs:        []int64{nq},
		notifier:         make(chan error, 1),
		tr:               timerecord.NewTimeRecorderWithTrace(spanCtx, "searchTask"),
		scheduleSpan:     scheduleSpan,
		serverID:         serverID,
	}
}
// Username returns the name of the user who issued the search, or "" when the
// request carries no user info.
//
// It goes through the nil-safe GetReq accessor, like the rest of this file,
// so a task without an inner request yields "" instead of panicking.
func (t *SearchTask) Username() string {
	return t.req.GetReq().GetUsername()
}
// GetNodeID reports the ID of the query node that owns this task.
func (t *SearchTask) GetNodeID() int64 {
	nodeID := t.serverID
	return nodeID
}
// IsGpuIndex reports whether the searched collection uses a GPU index,
// deferring to the collection itself.
func (t *SearchTask) IsGpuIndex() bool {
	coll := t.collection
	return coll.IsGpuIndex()
}
// PreExecute records how long this task waited in the queue before running.
//
// The wait is recorded in the per-collection and per-user prometheus
// histograms (milliseconds) and in the quota collector (microseconds). The
// same is then done for every task merged into this one. The first error
// from a merged task's PreExecute is returned.
func (t *SearchTask) PreExecute() error {
	nodeLabel := strconv.FormatInt(t.GetNodeID(), 10)
	waited := t.tr.ElapseSpan()
	waitedMS := waited.Seconds() * 1000

	// Per-collection queue latency for prometheus.
	// TODO: resource group and db name may be removed at runtime,
	// should be refactor into metricsutil.observer in the future.
	dbName := t.collection.GetDBName()
	resourceGroup := t.collection.GetResourceGroup()
	metrics.QueryNodeSQLatencyInQueue.
		WithLabelValues(nodeLabel, metrics.SearchLabel, dbName, resourceGroup).
		Observe(waitedMS)

	// Per-user queue latency.
	metrics.QueryNodeSQPerUserLatencyInQueue.
		WithLabelValues(nodeLabel, metrics.SearchLabel, t.Username()).
		Observe(waitedMS)

	// Feed the query node quota collector.
	collector.Average.Add(metricsinfo.SearchQueueMetric, float64(waited.Microseconds()))

	// Merged tasks waited in the queue too; record their waits as well.
	for _, sub := range t.others {
		if err := sub.PreExecute(); err != nil {
			return err
		}
	}
	return nil
}
// Execute runs the search for this task and for every task merged into it.
//
// The steps are:
//   - combine the merged placeholder groups into one request;
//   - search the segments selected by the request scope (historical or
//     streaming);
//   - reduce the per-segment results into one blob per original task;
//   - store a SearchResults on t (slot 0) and on t.others[i-1] (slot i).
//
// When the search yields no results, every task in the group receives an
// empty, successful SearchResults instead. Errors from combining, building
// the request, searching, reducing or fetching a blob are returned as-is.
func (t *SearchTask) Execute() error {
	// NOTE(review): GetDmlChannels()[0] panics if the request carries no DML
	// channel. Presumably callers always set one; confirm.
	log := log.Ctx(t.ctx).With(
		zap.Int64("collectionID", t.collection.ID()),
		zap.String("shard", t.req.GetDmlChannels()[0]),
	)

	// Execution has started, so close the span opened in NewSearchTask.
	if t.scheduleSpan != nil {
		t.scheduleSpan.End()
	}
	tr := timerecord.NewTimeRecorderWithTrace(t.ctx, "SearchTask")

	req := t.req
	// Fold the vectors of merged tasks into t.placeholderGroup so that one
	// search request serves the whole group.
	err := t.combinePlaceHolderGroups()
	if err != nil {
		return err
	}
	searchReq, err := segments.NewSearchRequest(t.ctx, t.collection, req, t.placeholderGroup)
	if err != nil {
		return err
	}
	defer searchReq.Delete()

	var (
		results          []*segments.SearchResult
		searchedSegments []segments.Segment
	)
	// NOTE(review): any scope other than Historical/Streaming leaves results
	// empty and err nil. The group then silently gets empty, successful
	// results below. Confirm that no caller sends another scope.
	if req.GetScope() == querypb.DataScope_Historical {
		results, searchedSegments, err = segments.SearchHistorical(
			t.ctx,
			t.segmentManager,
			searchReq,
			req.GetReq().GetCollectionID(),
			nil,
			req.GetSegmentIDs(),
		)
	} else if req.GetScope() == querypb.DataScope_Streaming {
		results, searchedSegments, err = segments.SearchStreaming(
			t.ctx,
			t.segmentManager,
			searchReq,
			req.GetReq().GetCollectionID(),
			nil,
			req.GetSegmentIDs(),
		)
	}
	// Unpin is deferred before the error check so that searched segments are
	// unpinned on the failure path too.
	defer t.segmentManager.Segment.Unpin(searchedSegments)
	if err != nil {
		return err
	}
	defer segments.DeleteSearchResults(results)

	// plan.MetricType is accurate, though req.MetricType may be empty
	metricType := searchReq.Plan().GetMetricType()

	// No results: give every task in the group an empty result. Slot i maps
	// to t when i == 0 and to t.others[i-1] otherwise, which is the order
	// Merge used when appending to originNqs/originTopks.
	if len(results) == 0 {
		for i := range t.originNqs {
			var task *SearchTask
			if i == 0 {
				task = t
			} else {
				task = t.others[i-1]
			}

			task.result = &internalpb.SearchResults{
				Base: &commonpb.MsgBase{
					SourceID: t.GetNodeID(),
				},
				Status:         merr.Success(),
				MetricType:     metricType,
				NumQueries:     t.originNqs[i],
				TopK:           t.originTopks[i],
				SlicedOffset:   1,
				SlicedNumCount: 1,
				CostAggregation: &internalpb.CostAggregation{
					ServiceTime: tr.ElapseSpan().Milliseconds(),
				},
			}
		}
		return nil
	}

	// Sum of MemSize over all searched segments. It is reported to the
	// caller as TotalRelatedDataSize.
	relatedDataSize := lo.Reduce(searchedSegments, func(acc int64, seg segments.Segment, _ int) int64 {
		return acc + seg.MemSize()
	}, 0)

	// Mark the start of the reduce phase for the reduce-latency metric below.
	tr.RecordSpan()
	blobs, err := segments.ReduceSearchResultsAndFillData(
		t.ctx,
		searchReq.Plan(),
		results,
		int64(len(results)),
		t.originNqs,
		t.originTopks,
	)
	if err != nil {
		log.Warn("failed to reduce search results", zap.Error(err))
		return err
	}
	defer segments.DeleteSearchResultDataBlobs(blobs)
	metrics.QueryNodeReduceLatency.WithLabelValues(
		fmt.Sprint(t.GetNodeID()),
		metrics.SearchLabel,
		metrics.ReduceSegments).
		Observe(float64(tr.RecordSpan().Milliseconds()))

	// Split the reduced output back out, one blob per original task, using
	// the same slot order as above.
	for i := range t.originNqs {
		blob, err := segments.GetSearchResultDataBlob(t.ctx, blobs, i)
		if err != nil {
			return err
		}

		var task *SearchTask
		if i == 0 {
			task = t
		} else {
			task = t.others[i-1]
		}

		// Note: blob is unsafe because get from C
		// It presumably aliases memory owned by blobs, which the deferred
		// DeleteSearchResultDataBlobs frees. Copy it into Go-owned memory
		// before storing it in the result.
		bs := make([]byte, len(blob))
		copy(bs, blob)

		task.result = &internalpb.SearchResults{
			Base: &commonpb.MsgBase{
				SourceID: t.GetNodeID(),
			},
			Status:         merr.Success(),
			MetricType:     metricType,
			NumQueries:     t.originNqs[i],
			TopK:           t.originTopks[i],
			SlicedBlob:     bs,
			SlicedOffset:   1,
			SlicedNumCount: 1,
			CostAggregation: &internalpb.CostAggregation{
				ServiceTime:          tr.ElapseSpan().Milliseconds(),
				TotalRelatedDataSize: relatedDataSize,
			},
		}
	}
	return nil
}
// Merge tries to fold other into t so that one search serves both tasks.
//
// It returns false and leaves both tasks untouched when the requests are
// incompatible, or when merging would exceed the configured group NQ or
// topk-inflation limits. On success t leads the group: its nq, topk and
// groupSize cover every task, and other is marked as merged.
//
// Execute maps result slot i to t (i == 0) or t.others[i-1] (i > 0), and
// combinePlaceHolderGroups reads vectors from t.others. t.others must
// therefore hold exactly one task per entry appended to originNqs and
// originTopks. If other already leads an earlier merge, its followers are
// flattened into t.others and detached from other, so Done still notifies
// each task exactly once.
func (t *SearchTask) Merge(other *SearchTask) bool {
	var (
		nq        = t.nq
		topk      = t.topk
		otherNq   = other.nq
		otherTopk = other.topk
	)

	diffTopk := topk != otherTopk
	pre := funcutil.Min(nq*topk, otherNq*otherTopk)
	maxTopk := funcutil.Max(topk, otherTopk)
	after := (nq + otherNq) * maxTopk
	ratio := float64(after) / float64(pre)

	// Check mergeable. Execute dispatches on the leader's scope only, so
	// tasks that target different data scopes must never share a search.
	if t.req.GetScope() != other.req.GetScope() ||
		t.req.GetReq().GetDbID() != other.req.GetReq().GetDbID() ||
		t.req.GetReq().GetCollectionID() != other.req.GetReq().GetCollectionID() ||
		t.req.GetReq().GetMvccTimestamp() != other.req.GetReq().GetMvccTimestamp() ||
		t.req.GetReq().GetDslType() != other.req.GetReq().GetDslType() ||
		t.req.GetDmlChannels()[0] != other.req.GetDmlChannels()[0] ||
		nq+otherNq > paramtable.Get().QueryNodeCfg.MaxGroupNQ.GetAsInt64() ||
		diffTopk && ratio > paramtable.Get().QueryNodeCfg.TopKMergeRatio.GetAsFloat() ||
		!funcutil.SliceSetEqual(t.req.GetReq().GetPartitionIDs(), other.req.GetReq().GetPartitionIDs()) ||
		!funcutil.SliceSetEqual(t.req.GetSegmentIDs(), other.req.GetSegmentIDs()) ||
		!bytes.Equal(t.req.GetReq().GetSerializedExprPlan(), other.req.GetReq().GetSerializedExprPlan()) {
		return false
	}

	// Merge
	t.groupSize += other.groupSize
	t.topk = maxTopk
	t.nq += otherNq
	// other.originNqs/originTopks list other first and then its followers,
	// so append tasks to t.others in that same order to keep slots aligned.
	t.originTopks = append(t.originTopks, other.originTopks...)
	t.originNqs = append(t.originNqs, other.originNqs...)
	t.others = append(t.others, other)
	t.others = append(t.others, other.others...)
	// t now owns other's followers. Detach them so that other.Done (called
	// from t.Done) does not notify them a second time.
	other.others = nil
	other.merged = true
	return true
}
// Done sends err to this task's waiter, then to every task merged into it.
//
// Group-level metrics (group size, nq, topk) are recorded only by tasks that
// were not themselves merged into another task.
func (t *SearchTask) Done(err error) {
	if !t.merged {
		node := fmt.Sprint(t.GetNodeID())
		metrics.QueryNodeSearchGroupSize.WithLabelValues(node).Observe(float64(t.groupSize))
		metrics.QueryNodeSearchGroupNQ.WithLabelValues(node).Observe(float64(t.nq))
		metrics.QueryNodeSearchGroupTopK.WithLabelValues(node).Observe(float64(t.topk))
	}

	t.notifier <- err
	for i := range t.others {
		t.others[i].Done(err)
	}
}
// Canceled returns the error of the task's context. It is nil while the
// context is still live.
func (t *SearchTask) Canceled() error {
	taskCtx := t.ctx
	return taskCtx.Err()
}
// Wait blocks until Done delivers this task's completion error, then
// returns that error.
func (t *SearchTask) Wait() error {
	err := <-t.notifier
	return err
}
// Result returns the search result that Execute stored on this task, or nil
// if none was stored.
//
// A non-nil result has its ChannelsMvcc field set to a map from each of the
// request's DML channels to the request's MVCC timestamp.
func (t *SearchTask) Result() *internalpb.SearchResults {
	if t.result == nil {
		return nil
	}

	mvccTs := t.req.GetReq().GetMvccTimestamp()
	perChannel := make(map[string]uint64)
	for _, channel := range t.req.GetDmlChannels() {
		perChannel[channel] = mvccTs
	}
	t.result.ChannelsMvcc = perChannel
	return t.result
}
// NQ reports the task's number of queries. For a task that has absorbed
// others via Merge, this is the total across the whole group.
func (t *SearchTask) NQ() int64 {
	total := t.nq
	return total
}
// MergeWith attempts to merge other into t. Only *SearchTask values can be
// merged; any other kind of Task is rejected with false.
func (t *SearchTask) MergeWith(other Task) bool {
	searchTask, ok := other.(*SearchTask)
	if !ok {
		return false
	}
	return t.Merge(searchTask)
}
// combinePlaceHolderGroups appends the query vectors of every task merged
// into t to t's own placeholder group. It then stores the re-serialized
// group in t.placeholderGroup, so that one search covers the whole group.
//
// Vectors are appended in t.others order, which is the slot order Execute
// uses to split results back out. Only the first placeholder of each group
// is combined. The call is a no-op when nothing has been merged.
//
// It returns a parameter-invalid error if any group fails to decode or has
// no placeholders, and a wrapped error if the combined group fails to
// encode. On error, t.placeholderGroup is left unchanged.
func (t *SearchTask) combinePlaceHolderGroups() error {
	if len(t.others) == 0 {
		return nil
	}

	combined := &commonpb.PlaceholderGroup{}
	if err := proto.Unmarshal(t.placeholderGroup, combined); err != nil {
		return merr.WrapErrParameterInvalidMsg("invalid search vector placeholder: %v", err)
	}
	if len(combined.GetPlaceholders()) == 0 {
		return merr.WrapErrParameterInvalidMsg("empty search vector is not allowed")
	}

	// The loop variable is named sub so that it does not shadow the receiver t.
	for _, sub := range t.others {
		group := &commonpb.PlaceholderGroup{}
		if err := proto.Unmarshal(sub.placeholderGroup, group); err != nil {
			return merr.WrapErrParameterInvalidMsg("invalid search vector placeholder: %v", err)
		}
		if len(group.GetPlaceholders()) == 0 {
			return merr.WrapErrParameterInvalidMsg("empty search vector is not allowed")
		}
		combined.Placeholders[0].Values = append(combined.Placeholders[0].Values, group.Placeholders[0].Values...)
	}

	encoded, err := proto.Marshal(combined)
	if err != nil {
		// Report the failure instead of storing whatever Marshal returned;
		// otherwise the search would run on a bad placeholder group.
		return fmt.Errorf("marshal combined placeholder group: %w", err)
	}
	t.placeholderGroup = encoded
	return nil
}