mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-03 12:29:36 +08:00
dad060af2e
Signed-off-by: aoiasd <zhicheng.yue@zilliz.com> Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
700 lines
23 KiB
Go
700 lines
23 KiB
Go
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"regexp"
|
|
"strconv"
|
|
|
|
"github.com/milvus-io/milvus/internal/parser/planparserv2"
|
|
|
|
"github.com/golang/protobuf/proto"
|
|
"go.uber.org/zap"
|
|
|
|
"github.com/milvus-io/milvus/internal/log"
|
|
"github.com/milvus-io/milvus/internal/metrics"
|
|
"github.com/milvus-io/milvus/internal/types"
|
|
|
|
"github.com/milvus-io/milvus/internal/util/distance"
|
|
"github.com/milvus-io/milvus/internal/util/funcutil"
|
|
"github.com/milvus-io/milvus/internal/util/grpcclient"
|
|
"github.com/milvus-io/milvus/internal/util/timerecord"
|
|
"github.com/milvus-io/milvus/internal/util/trace"
|
|
"github.com/milvus-io/milvus/internal/util/tsoutil"
|
|
"github.com/milvus-io/milvus/internal/util/typeutil"
|
|
|
|
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
|
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
|
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
|
"github.com/milvus-io/milvus/internal/proto/planpb"
|
|
"github.com/milvus-io/milvus/internal/proto/querypb"
|
|
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
|
)
|
|
|
|
type searchTask struct {
|
|
Condition
|
|
*internalpb.SearchRequest
|
|
ctx context.Context
|
|
|
|
result *milvuspb.SearchResults
|
|
request *milvuspb.SearchRequest
|
|
qc types.QueryCoord
|
|
tr *timerecord.TimeRecorder
|
|
collectionName string
|
|
schema *schemapb.CollectionSchema
|
|
|
|
resultBuf chan *internalpb.SearchResults
|
|
toReduceResults []*internalpb.SearchResults
|
|
|
|
searchShardPolicy pickShardPolicy
|
|
shardMgr *shardClientMgr
|
|
}
|
|
|
|
func getPartitionIDs(ctx context.Context, collectionName string, partitionNames []string) (partitionIDs []UniqueID, err error) {
|
|
for _, tag := range partitionNames {
|
|
if err := validatePartitionTag(tag, false); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
partitionsMap, err := globalMetaCache.GetPartitions(ctx, collectionName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
partitionsRecord := make(map[UniqueID]bool)
|
|
partitionIDs = make([]UniqueID, 0, len(partitionNames))
|
|
for _, partitionName := range partitionNames {
|
|
pattern := fmt.Sprintf("^%s$", partitionName)
|
|
re, err := regexp.Compile(pattern)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid partition: %s", partitionName)
|
|
}
|
|
found := false
|
|
for name, pID := range partitionsMap {
|
|
if re.MatchString(name) {
|
|
if _, exist := partitionsRecord[pID]; !exist {
|
|
partitionIDs = append(partitionIDs, pID)
|
|
partitionsRecord[pID] = true
|
|
}
|
|
found = true
|
|
}
|
|
}
|
|
if !found {
|
|
return nil, fmt.Errorf("partition name %s not found", partitionName)
|
|
}
|
|
}
|
|
return partitionIDs, nil
|
|
}
|
|
|
|
func parseQueryInfo(searchParamsPair []*commonpb.KeyValuePair) (*planpb.QueryInfo, error) {
|
|
topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, searchParamsPair)
|
|
if err != nil {
|
|
return nil, errors.New(TopKKey + " not found in search_params")
|
|
}
|
|
topK, err := strconv.Atoi(topKStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)
|
|
}
|
|
|
|
metricType, err := funcutil.GetAttrByKeyFromRepeatedKV(MetricTypeKey, searchParamsPair)
|
|
if err != nil {
|
|
return nil, errors.New(MetricTypeKey + " not found in search_params")
|
|
}
|
|
|
|
searchParams, err := funcutil.GetAttrByKeyFromRepeatedKV(SearchParamsKey, searchParamsPair)
|
|
if err != nil {
|
|
return nil, errors.New(SearchParamsKey + " not found in search_params")
|
|
}
|
|
|
|
roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, searchParamsPair)
|
|
if err != nil {
|
|
roundDecimalStr = "-1"
|
|
}
|
|
roundDecimal, err := strconv.Atoi(roundDecimalStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
|
|
}
|
|
|
|
if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) {
|
|
return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
|
|
}
|
|
|
|
return &planpb.QueryInfo{
|
|
Topk: int64(topK),
|
|
MetricType: metricType,
|
|
SearchParams: searchParams,
|
|
RoundDecimal: int64(roundDecimal),
|
|
}, nil
|
|
}
|
|
|
|
func getOutputFieldIDs(schema *schemapb.CollectionSchema, outputFields []string) (outputFieldIDs []UniqueID, err error) {
|
|
outputFieldIDs = make([]UniqueID, 0, len(outputFields))
|
|
for _, name := range outputFields {
|
|
hitField := false
|
|
for _, field := range schema.GetFields() {
|
|
if field.Name == name {
|
|
if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector {
|
|
return nil, errors.New("search doesn't support vector field as output_fields")
|
|
}
|
|
outputFieldIDs = append(outputFieldIDs, field.GetFieldID())
|
|
|
|
hitField = true
|
|
break
|
|
}
|
|
}
|
|
if !hitField {
|
|
errMsg := "Field " + name + " not exist"
|
|
return nil, errors.New(errMsg)
|
|
}
|
|
}
|
|
return outputFieldIDs, nil
|
|
}
|
|
|
|
func getNq(req *milvuspb.SearchRequest) (int64, error) {
|
|
if req.GetNq() == 0 {
|
|
// keep compatible with older client version.
|
|
x := &commonpb.PlaceholderGroup{}
|
|
err := proto.Unmarshal(req.GetPlaceholderGroup(), x)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
total := int64(0)
|
|
for _, h := range x.GetPlaceholders() {
|
|
total += int64(len(h.Values))
|
|
}
|
|
return total, nil
|
|
}
|
|
return req.GetNq(), nil
|
|
}
|
|
|
|
func (t *searchTask) PreExecute(ctx context.Context) error {
|
|
sp, ctx := trace.StartSpanFromContextWithOperationName(t.TraceCtx(), "Proxy-Search-PreExecute")
|
|
defer sp.Finish()
|
|
|
|
if t.searchShardPolicy == nil {
|
|
t.searchShardPolicy = mergeRoundRobinPolicy
|
|
}
|
|
|
|
t.Base.MsgType = commonpb.MsgType_Search
|
|
t.Base.SourceID = Params.ProxyCfg.GetNodeID()
|
|
|
|
collectionName := t.request.CollectionName
|
|
t.collectionName = collectionName
|
|
collID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
|
|
if err != nil { // err is not nil if collection not exists
|
|
return err
|
|
}
|
|
|
|
t.SearchRequest.DbID = 0 // todo
|
|
t.SearchRequest.CollectionID = collID
|
|
t.schema, _ = globalMetaCache.GetCollectionSchema(ctx, collectionName)
|
|
|
|
// translate partition name to partition ids. Use regex-pattern to match partition name.
|
|
t.SearchRequest.PartitionIDs, err = getPartitionIDs(ctx, collectionName, t.request.GetPartitionNames())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// check if collection/partitions are loaded into query node
|
|
loaded, err := checkIfLoaded(ctx, t.qc, collectionName, t.SearchRequest.GetPartitionIDs())
|
|
if err != nil {
|
|
return fmt.Errorf("checkIfLoaded failed when search, collection:%v, partitions:%v, err = %s", collectionName, t.request.GetPartitionNames(), err)
|
|
}
|
|
if !loaded {
|
|
return fmt.Errorf("collection:%v or partition:%v not loaded into memory when search", collectionName, t.request.GetPartitionNames())
|
|
}
|
|
|
|
t.request.OutputFields, err = translateOutputFields(t.request.OutputFields, t.schema, false)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
log.Ctx(ctx).Debug("translate output fields", zap.Int64("msgID", t.ID()),
|
|
zap.Strings("output fields", t.request.GetOutputFields()))
|
|
|
|
if t.request.GetDslType() == commonpb.DslType_BoolExprV1 {
|
|
annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams())
|
|
if err != nil {
|
|
return errors.New(AnnsFieldKey + " not found in search_params")
|
|
}
|
|
|
|
queryInfo, err := parseQueryInfo(t.request.GetSearchParams())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
plan, err := planparserv2.CreateSearchPlan(t.schema, t.request.Dsl, annsField, queryInfo)
|
|
if err != nil {
|
|
log.Ctx(ctx).Warn("failed to create query plan", zap.Error(err), zap.Int64("msgID", t.ID()),
|
|
zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
|
|
zap.String("anns field", annsField), zap.Any("query info", queryInfo))
|
|
return fmt.Errorf("failed to create query plan: %v", err)
|
|
}
|
|
log.Ctx(ctx).Debug("create query plan", zap.Int64("msgID", t.ID()),
|
|
zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
|
|
zap.String("anns field", annsField), zap.Any("query info", queryInfo))
|
|
|
|
outputFieldIDs, err := getOutputFieldIDs(t.schema, t.request.GetOutputFields())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
t.SearchRequest.OutputFieldsId = outputFieldIDs
|
|
plan.OutputFieldIds = outputFieldIDs
|
|
|
|
t.SearchRequest.MetricType = queryInfo.GetMetricType()
|
|
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
|
|
t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
t.SearchRequest.Topk = queryInfo.GetTopk()
|
|
if err := validateTopK(queryInfo.GetTopk()); err != nil {
|
|
return err
|
|
}
|
|
log.Ctx(ctx).Debug("Proxy::searchTask::PreExecute", zap.Int64("msgID", t.ID()),
|
|
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
|
|
zap.String("plan", plan.String())) // may be very large if large term passed.
|
|
}
|
|
|
|
travelTimestamp := t.request.TravelTimestamp
|
|
if travelTimestamp == 0 {
|
|
travelTimestamp = typeutil.MaxTimestamp
|
|
}
|
|
err = validateTravelTimestamp(travelTimestamp, t.BeginTs())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
t.SearchRequest.TravelTimestamp = travelTimestamp
|
|
|
|
guaranteeTs := t.request.GetGuaranteeTimestamp()
|
|
guaranteeTs = parseGuaranteeTs(guaranteeTs, t.BeginTs())
|
|
t.SearchRequest.GuaranteeTimestamp = guaranteeTs
|
|
|
|
deadline, ok := t.TraceCtx().Deadline()
|
|
if ok {
|
|
t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
|
|
}
|
|
|
|
t.SearchRequest.Dsl = t.request.Dsl
|
|
t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup
|
|
if t.SearchRequest.Nq, err = getNq(t.request); err != nil {
|
|
return err
|
|
}
|
|
log.Ctx(ctx).Debug("search PreExecute done.", zap.Int64("msgID", t.ID()),
|
|
zap.Uint64("travel_ts", travelTimestamp), zap.Uint64("guarantee_ts", guaranteeTs),
|
|
zap.Uint64("timeout_ts", t.SearchRequest.GetTimeoutTimestamp()))
|
|
|
|
return nil
|
|
}
|
|
|
|
func (t *searchTask) Execute(ctx context.Context) error {
|
|
sp, ctx := trace.StartSpanFromContextWithOperationName(t.TraceCtx(), "Proxy-Search-Execute")
|
|
defer sp.Finish()
|
|
|
|
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute search %d", t.ID()))
|
|
defer tr.CtxElapse(ctx, "done")
|
|
|
|
executeSearch := func(withCache bool) error {
|
|
shard2Leaders, err := globalMetaCache.GetShards(ctx, withCache, t.collectionName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
t.resultBuf = make(chan *internalpb.SearchResults, len(shard2Leaders))
|
|
t.toReduceResults = make([]*internalpb.SearchResults, 0, len(shard2Leaders))
|
|
if err := t.searchShardPolicy(ctx, t.shardMgr, t.searchShard, shard2Leaders); err != nil {
|
|
log.Ctx(ctx).Warn("failed to do search", zap.Error(err), zap.String("Shards", fmt.Sprintf("%v", shard2Leaders)))
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
err := executeSearch(WithCache)
|
|
if errors.Is(err, errInvalidShardLeaders) || funcutil.IsGrpcErr(err) || errors.Is(err, grpcclient.ErrConnect) {
|
|
log.Ctx(ctx).Warn("first search failed, updating shardleader caches and retry search",
|
|
zap.Int64("msgID", t.ID()), zap.Error(err))
|
|
return executeSearch(WithoutCache)
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("fail to search on all shard leaders, err=%v", err)
|
|
}
|
|
|
|
log.Ctx(ctx).Debug("Search Execute done.", zap.Int64("msgID", t.ID()))
|
|
return nil
|
|
}
|
|
|
|
func (t *searchTask) PostExecute(ctx context.Context) error {
|
|
sp, ctx := trace.StartSpanFromContextWithOperationName(t.TraceCtx(), "Proxy-Search-PostExecute")
|
|
defer sp.Finish()
|
|
tr := timerecord.NewTimeRecorder("searchTask PostExecute")
|
|
defer func() {
|
|
tr.CtxElapse(ctx, "done")
|
|
}()
|
|
|
|
select {
|
|
// in case timeout happened
|
|
case <-t.TraceCtx().Done():
|
|
log.Ctx(ctx).Debug("wait to finish timeout!", zap.Int64("msgID", t.ID()))
|
|
return nil
|
|
default:
|
|
log.Ctx(ctx).Debug("all searches are finished or canceled", zap.Int64("msgID", t.ID()))
|
|
close(t.resultBuf)
|
|
for res := range t.resultBuf {
|
|
t.toReduceResults = append(t.toReduceResults, res)
|
|
log.Ctx(ctx).Debug("proxy receives one query result", zap.Int64("sourceID", res.GetBase().GetSourceID()), zap.Int64("msgID", t.ID()))
|
|
}
|
|
}
|
|
tr.CtxRecord(ctx, "decodeResultStart")
|
|
validSearchResults, err := decodeSearchResults(ctx, t.toReduceResults)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
metrics.ProxyDecodeResultLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10),
|
|
metrics.SearchLabel).Observe(float64(tr.RecordSpan().Milliseconds()))
|
|
|
|
log.Ctx(ctx).Debug("proxy search post execute stage 2", zap.Int64("msgID", t.ID()),
|
|
zap.Int("len(validSearchResults)", len(validSearchResults)))
|
|
if len(validSearchResults) <= 0 {
|
|
log.Ctx(ctx).Warn("search result is empty", zap.Int64("msgID", t.ID()))
|
|
|
|
t.result = &milvuspb.SearchResults{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_Success,
|
|
Reason: "search result is empty",
|
|
},
|
|
CollectionName: t.collectionName,
|
|
}
|
|
// add information if any
|
|
if len(t.toReduceResults) > 0 {
|
|
t.result.Results = &schemapb.SearchResultData{
|
|
NumQueries: t.toReduceResults[0].NumQueries,
|
|
Topks: make([]int64, t.toReduceResults[0].NumQueries),
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
tr.CtxRecord(ctx, "reduceResultStart")
|
|
primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(t.schema)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
t.result, err = reduceSearchResultData(ctx, validSearchResults, t.toReduceResults[0].NumQueries,
|
|
t.toReduceResults[0].TopK, t.toReduceResults[0].MetricType, primaryFieldSchema.DataType)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
metrics.ProxyReduceResultLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), metrics.SearchLabel).Observe(float64(tr.RecordSpan().Milliseconds()))
|
|
t.result.CollectionName = t.collectionName
|
|
|
|
schema, err := globalMetaCache.GetCollectionSchema(ctx, t.request.CollectionName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if len(t.request.OutputFields) != 0 && len(t.result.Results.FieldsData) != 0 {
|
|
for k, fieldName := range t.request.OutputFields {
|
|
for _, field := range schema.Fields {
|
|
if t.result.Results.FieldsData[k] != nil && field.Name == fieldName {
|
|
t.result.Results.FieldsData[k].FieldName = field.Name
|
|
t.result.Results.FieldsData[k].FieldId = field.FieldID
|
|
t.result.Results.FieldsData[k].Type = field.DataType
|
|
}
|
|
}
|
|
}
|
|
}
|
|
log.Ctx(ctx).Debug("Search post execute done", zap.Int64("msgID", t.ID()))
|
|
return nil
|
|
}
|
|
|
|
func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs []string) error {
|
|
req := &querypb.SearchRequest{
|
|
Req: t.SearchRequest,
|
|
DmlChannels: channelIDs,
|
|
Scope: querypb.DataScope_All,
|
|
}
|
|
result, err := qn.Search(ctx, req)
|
|
if err != nil {
|
|
log.Ctx(ctx).Warn("QueryNode search return error", zap.Int64("msgID", t.ID()),
|
|
zap.Int64("nodeID", nodeID), zap.Strings("channels", channelIDs), zap.Error(err))
|
|
return err
|
|
}
|
|
if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
|
|
log.Ctx(ctx).Warn("QueryNode is not shardLeader", zap.Int64("msgID", t.ID()),
|
|
zap.Int64("nodeID", nodeID), zap.Strings("channels", channelIDs))
|
|
return errInvalidShardLeaders
|
|
}
|
|
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
|
log.Ctx(ctx).Warn("QueryNode search result error", zap.Int64("msgID", t.ID()), zap.Int64("nodeID", nodeID),
|
|
zap.String("reason", result.GetStatus().GetReason()))
|
|
return fmt.Errorf("fail to Search, QueryNode ID=%d, reason=%s", nodeID, result.GetStatus().GetReason())
|
|
}
|
|
t.resultBuf <- result
|
|
|
|
return nil
|
|
}
|
|
|
|
// checkIfLoaded check if collection was loaded into QueryNode
|
|
func checkIfLoaded(ctx context.Context, qc types.QueryCoord, collectionName string, searchPartitionIDs []UniqueID) (bool, error) {
|
|
info, err := globalMetaCache.GetCollectionInfo(ctx, collectionName)
|
|
if err != nil {
|
|
return false, fmt.Errorf("GetCollectionInfo failed, collection = %s, err = %s", collectionName, err)
|
|
}
|
|
if info.isLoaded {
|
|
return true, nil
|
|
}
|
|
if len(searchPartitionIDs) == 0 {
|
|
return false, nil
|
|
}
|
|
|
|
// If request to search partitions
|
|
resp, err := qc.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_ShowPartitions,
|
|
SourceID: Params.ProxyCfg.GetNodeID(),
|
|
},
|
|
CollectionID: info.collID,
|
|
PartitionIDs: searchPartitionIDs,
|
|
})
|
|
if err != nil {
|
|
return false, fmt.Errorf("showPartitions failed, collection = %s, partitionIDs = %v, err = %s", collectionName, searchPartitionIDs, err)
|
|
}
|
|
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
|
|
return false, fmt.Errorf("showPartitions failed, collection = %s, partitionIDs = %v, reason = %s", collectionName, searchPartitionIDs, resp.GetStatus().GetReason())
|
|
}
|
|
|
|
for _, persent := range resp.InMemoryPercentages {
|
|
if persent < 100 {
|
|
return false, nil
|
|
}
|
|
}
|
|
return true, nil
|
|
}
|
|
|
|
func decodeSearchResults(ctx context.Context, searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) {
|
|
tr := timerecord.NewTimeRecorder("decodeSearchResults")
|
|
results := make([]*schemapb.SearchResultData, 0)
|
|
for _, partialSearchResult := range searchResults {
|
|
if partialSearchResult.SlicedBlob == nil {
|
|
continue
|
|
}
|
|
|
|
var partialResultData schemapb.SearchResultData
|
|
err := proto.Unmarshal(partialSearchResult.SlicedBlob, &partialResultData)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
results = append(results, &partialResultData)
|
|
}
|
|
tr.CtxElapse(ctx, "decodeSearchResults done")
|
|
return results, nil
|
|
}
|
|
|
|
func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64) error {
|
|
if data.NumQueries != nq {
|
|
return fmt.Errorf("search result's nq(%d) mis-match with %d", data.NumQueries, nq)
|
|
}
|
|
if data.TopK != topk {
|
|
return fmt.Errorf("search result's topk(%d) mis-match with %d", data.TopK, topk)
|
|
}
|
|
|
|
pkHitNum := typeutil.GetSizeOfIDs(data.GetIds())
|
|
if len(data.Scores) != pkHitNum {
|
|
return fmt.Errorf("search result's score length invalid, score length=%d, expectedLength=%d",
|
|
len(data.Scores), pkHitNum)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func selectSearchResultData(dataArray []*schemapb.SearchResultData, resultOffsets [][]int64, offsets []int64, qi int64) int {
|
|
sel := -1
|
|
maxDistance := minFloat32 // distance here means score :)
|
|
for i, offset := range offsets { // query num, the number of ways to merge
|
|
if offset >= dataArray[i].Topks[qi] {
|
|
continue
|
|
}
|
|
idx := resultOffsets[i][qi] + offset
|
|
distance := dataArray[i].Scores[idx]
|
|
if distance > maxDistance {
|
|
sel = i
|
|
maxDistance = distance
|
|
}
|
|
}
|
|
return sel
|
|
}
|
|
|
|
func reduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string, pkType schemapb.DataType) (*milvuspb.SearchResults, error) {
|
|
tr := timerecord.NewTimeRecorder("reduceSearchResultData")
|
|
defer func() {
|
|
tr.CtxElapse(ctx, "done")
|
|
}()
|
|
|
|
log.Ctx(ctx).Debug("reduceSearchResultData", zap.Int("len(searchResultData)", len(searchResultData)),
|
|
zap.Int64("nq", nq), zap.Int64("topk", topk), zap.String("metricType", metricType))
|
|
|
|
ret := &milvuspb.SearchResults{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: 0,
|
|
},
|
|
Results: &schemapb.SearchResultData{
|
|
NumQueries: nq,
|
|
TopK: topk,
|
|
FieldsData: make([]*schemapb.FieldData, len(searchResultData[0].FieldsData)),
|
|
Scores: make([]float32, 0),
|
|
Ids: &schemapb.IDs{},
|
|
Topks: make([]int64, 0),
|
|
},
|
|
}
|
|
switch pkType {
|
|
case schemapb.DataType_Int64:
|
|
ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{
|
|
Data: make([]int64, 0),
|
|
},
|
|
}
|
|
case schemapb.DataType_VarChar:
|
|
ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{
|
|
StrId: &schemapb.StringArray{
|
|
Data: make([]string, 0),
|
|
},
|
|
}
|
|
default:
|
|
return nil, errors.New("unsupported pk type")
|
|
}
|
|
|
|
for i, sData := range searchResultData {
|
|
log.Ctx(ctx).Debug("reduceSearchResultData",
|
|
zap.Int("result No.", i),
|
|
zap.Int64("nq", sData.NumQueries),
|
|
zap.Int64("topk", sData.TopK),
|
|
zap.Any("len(topks)", len(sData.Topks)),
|
|
zap.Any("len(FieldsData)", len(sData.FieldsData)))
|
|
if err := checkSearchResultData(sData, nq, topk); err != nil {
|
|
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
|
|
return ret, err
|
|
}
|
|
//printSearchResultData(sData, strconv.FormatInt(int64(i), 10))
|
|
}
|
|
|
|
resultOffsets := make([][]int64, len(searchResultData))
|
|
for i := 0; i < len(searchResultData); i++ {
|
|
resultOffsets[i] = make([]int64, len(searchResultData[i].Topks))
|
|
for j := int64(1); j < nq; j++ {
|
|
resultOffsets[i][j] = resultOffsets[i][j-1] + searchResultData[i].Topks[j-1]
|
|
}
|
|
}
|
|
|
|
var skipDupCnt int64
|
|
var realTopK int64 = -1
|
|
for i := int64(0); i < nq; i++ {
|
|
offsets := make([]int64, len(searchResultData))
|
|
|
|
var idSet = make(map[interface{}]struct{})
|
|
var j int64
|
|
for j = 0; j < topk; {
|
|
sel := selectSearchResultData(searchResultData, resultOffsets, offsets, i)
|
|
if sel == -1 {
|
|
break
|
|
}
|
|
idx := resultOffsets[sel][i] + offsets[sel]
|
|
|
|
id := typeutil.GetPK(searchResultData[sel].GetIds(), idx)
|
|
score := searchResultData[sel].Scores[idx]
|
|
|
|
// remove duplicates
|
|
if _, ok := idSet[id]; !ok {
|
|
typeutil.AppendFieldData(ret.Results.FieldsData, searchResultData[sel].FieldsData, idx)
|
|
typeutil.AppendPKs(ret.Results.Ids, id)
|
|
ret.Results.Scores = append(ret.Results.Scores, score)
|
|
idSet[id] = struct{}{}
|
|
j++
|
|
} else {
|
|
// skip entity with same id
|
|
skipDupCnt++
|
|
}
|
|
offsets[sel]++
|
|
}
|
|
if realTopK != -1 && realTopK != j {
|
|
log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
|
|
// return nil, errors.New("the length (topk) between all result of query is different")
|
|
}
|
|
realTopK = j
|
|
ret.Results.Topks = append(ret.Results.Topks, realTopK)
|
|
}
|
|
log.Ctx(ctx).Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))
|
|
ret.Results.TopK = realTopK
|
|
|
|
if !distance.PositivelyRelated(metricType) {
|
|
for k := range ret.Results.Scores {
|
|
ret.Results.Scores[k] *= -1
|
|
}
|
|
}
|
|
// printSearchResultData(ret.Results, "proxy reduce result")
|
|
return ret, nil
|
|
}
|
|
|
|
//func printSearchResultData(data *schemapb.SearchResultData, header string) {
|
|
// size := len(data.Ids.GetIntId().Data)
|
|
// if size != len(data.Scores) {
|
|
// log.Error("SearchResultData length mis-match")
|
|
// }
|
|
// log.Debug("==== SearchResultData ====",
|
|
// zap.String("header", header), zap.Int64("nq", data.NumQueries), zap.Int64("topk", data.TopK))
|
|
// for i := 0; i < size; i++ {
|
|
// log.Debug("", zap.Int("i", i), zap.Int64("id", data.Ids.GetIntId().Data[i]), zap.Float32("score", data.Scores[i]))
|
|
// }
|
|
//}
|
|
|
|
// func printSearchResult(partialSearchResult *internalpb.SearchResults) {
|
|
// for i := 0; i < len(partialSearchResult.Hits); i++ {
|
|
// testHits := milvuspb.Hits{}
|
|
// err := proto.Unmarshal(partialSearchResult.Hits[i], &testHits)
|
|
// if err != nil {
|
|
// panic(err)
|
|
// }
|
|
// fmt.Println(testHits.IDs)
|
|
// fmt.Println(testHits.Scores)
|
|
// }
|
|
// }
|
|
|
|
func (t *searchTask) TraceCtx() context.Context {
|
|
return t.ctx
|
|
}
|
|
|
|
func (t *searchTask) ID() UniqueID {
|
|
return t.Base.MsgID
|
|
}
|
|
|
|
func (t *searchTask) SetID(uid UniqueID) {
|
|
t.Base.MsgID = uid
|
|
}
|
|
|
|
func (t *searchTask) Name() string {
|
|
return SearchTaskName
|
|
}
|
|
|
|
func (t *searchTask) Type() commonpb.MsgType {
|
|
return t.Base.MsgType
|
|
}
|
|
|
|
func (t *searchTask) BeginTs() Timestamp {
|
|
return t.Base.Timestamp
|
|
}
|
|
|
|
func (t *searchTask) EndTs() Timestamp {
|
|
return t.Base.Timestamp
|
|
}
|
|
|
|
func (t *searchTask) SetTs(ts Timestamp) {
|
|
t.Base.Timestamp = ts
|
|
}
|
|
|
|
func (t *searchTask) OnEnqueue() error {
|
|
t.Base = &commonpb.MsgBase{}
|
|
t.Base.MsgType = commonpb.MsgType_Search
|
|
t.Base.SourceID = Params.ProxyCfg.GetNodeID()
|
|
return nil
|
|
}
|