mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-02 03:48:37 +08:00
e480b103bd
related: #35096 Signed-off-by: MrPresent-Han <chun.han@gmail.com> Co-authored-by: MrPresent-Han <chun.han@gmail.com>
217 lines
7.1 KiB
Go
217 lines
7.1 KiB
Go
package segments
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"go.opentelemetry.io/otel"
|
|
"go.uber.org/zap"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
|
"github.com/milvus-io/milvus/internal/util/reduce"
|
|
"github.com/milvus-io/milvus/pkg/log"
|
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
|
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
|
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
|
)
|
|
|
|
// SearchReduce merges the search results produced by multiple segments into a
// single result set, honoring the limits (nq, topk, group-by, ...) carried by
// the reduce.ResultInfo.
type SearchReduce interface {
	// ReduceSearchResultData reduces searchResultData into one
	// *schemapb.SearchResultData according to info. Implementations return an
	// error when the merged result would exceed configured size limits.
	ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *reduce.ResultInfo) (*schemapb.SearchResultData, error)
}
|
|
|
|
// SearchCommonReduce implements SearchReduce for ordinary (non-group-by)
// top-k searches.
type SearchCommonReduce struct{}
|
|
|
|
func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *reduce.ResultInfo) (*schemapb.SearchResultData, error) {
|
|
ctx, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "ReduceSearchResultData")
|
|
defer sp.End()
|
|
log := log.Ctx(ctx)
|
|
|
|
if len(searchResultData) == 0 {
|
|
return &schemapb.SearchResultData{
|
|
NumQueries: info.GetNq(),
|
|
TopK: info.GetTopK(),
|
|
FieldsData: make([]*schemapb.FieldData, 0),
|
|
Scores: make([]float32, 0),
|
|
Ids: &schemapb.IDs{},
|
|
Topks: make([]int64, 0),
|
|
}, nil
|
|
}
|
|
ret := &schemapb.SearchResultData{
|
|
NumQueries: info.GetNq(),
|
|
TopK: info.GetTopK(),
|
|
FieldsData: make([]*schemapb.FieldData, len(searchResultData[0].FieldsData)),
|
|
Scores: make([]float32, 0),
|
|
Ids: &schemapb.IDs{},
|
|
Topks: make([]int64, 0),
|
|
}
|
|
|
|
resultOffsets := make([][]int64, len(searchResultData))
|
|
for i := 0; i < len(searchResultData); i++ {
|
|
resultOffsets[i] = make([]int64, len(searchResultData[i].Topks))
|
|
for j := int64(1); j < info.GetNq(); j++ {
|
|
resultOffsets[i][j] = resultOffsets[i][j-1] + searchResultData[i].Topks[j-1]
|
|
}
|
|
ret.AllSearchCount += searchResultData[i].GetAllSearchCount()
|
|
}
|
|
|
|
var skipDupCnt int64
|
|
var retSize int64
|
|
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
|
|
for i := int64(0); i < info.GetNq(); i++ {
|
|
offsets := make([]int64, len(searchResultData))
|
|
idSet := make(map[interface{}]struct{})
|
|
var j int64
|
|
for j = 0; j < info.GetTopK(); {
|
|
sel := SelectSearchResultData(searchResultData, resultOffsets, offsets, i)
|
|
if sel == -1 {
|
|
break
|
|
}
|
|
idx := resultOffsets[sel][i] + offsets[sel]
|
|
|
|
id := typeutil.GetPK(searchResultData[sel].GetIds(), idx)
|
|
score := searchResultData[sel].Scores[idx]
|
|
|
|
// remove duplicates
|
|
if _, ok := idSet[id]; !ok {
|
|
retSize += typeutil.AppendFieldData(ret.FieldsData, searchResultData[sel].FieldsData, idx)
|
|
typeutil.AppendPKs(ret.Ids, id)
|
|
ret.Scores = append(ret.Scores, score)
|
|
idSet[id] = struct{}{}
|
|
j++
|
|
} else {
|
|
// skip entity with same id
|
|
skipDupCnt++
|
|
}
|
|
offsets[sel]++
|
|
}
|
|
|
|
// if realTopK != -1 && realTopK != j {
|
|
// log.Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
|
|
// // return nil, errors.New("the length (topk) between all result of query is different")
|
|
// }
|
|
ret.Topks = append(ret.Topks, j)
|
|
|
|
// limit search result to avoid oom
|
|
if retSize > maxOutputSize {
|
|
return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
|
|
}
|
|
}
|
|
log.Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))
|
|
return ret, nil
|
|
}
|
|
|
|
// SearchGroupByReduce implements SearchReduce for searches carrying a
// group-by field: it enforces the per-group size limit and the maximum
// number of distinct groups while merging.
type SearchGroupByReduce struct{}
|
|
|
|
func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *reduce.ResultInfo) (*schemapb.SearchResultData, error) {
|
|
ctx, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "ReduceSearchResultData")
|
|
defer sp.End()
|
|
log := log.Ctx(ctx)
|
|
|
|
if len(searchResultData) == 0 {
|
|
return &schemapb.SearchResultData{
|
|
NumQueries: info.GetNq(),
|
|
TopK: info.GetTopK(),
|
|
FieldsData: make([]*schemapb.FieldData, 0),
|
|
Scores: make([]float32, 0),
|
|
Ids: &schemapb.IDs{},
|
|
Topks: make([]int64, 0),
|
|
}, nil
|
|
}
|
|
ret := &schemapb.SearchResultData{
|
|
NumQueries: info.GetNq(),
|
|
TopK: info.GetTopK(),
|
|
FieldsData: make([]*schemapb.FieldData, len(searchResultData[0].FieldsData)),
|
|
Scores: make([]float32, 0),
|
|
Ids: &schemapb.IDs{},
|
|
Topks: make([]int64, 0),
|
|
}
|
|
|
|
resultOffsets := make([][]int64, len(searchResultData))
|
|
for i := 0; i < len(searchResultData); i++ {
|
|
resultOffsets[i] = make([]int64, len(searchResultData[i].Topks))
|
|
for j := int64(1); j < info.GetNq(); j++ {
|
|
resultOffsets[i][j] = resultOffsets[i][j-1] + searchResultData[i].Topks[j-1]
|
|
}
|
|
ret.AllSearchCount += searchResultData[i].GetAllSearchCount()
|
|
}
|
|
|
|
var filteredCount int64
|
|
var retSize int64
|
|
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
|
|
groupSize := info.GetGroupSize()
|
|
if groupSize <= 0 {
|
|
groupSize = 1
|
|
}
|
|
groupBound := info.GetTopK() * groupSize
|
|
|
|
for i := int64(0); i < info.GetNq(); i++ {
|
|
offsets := make([]int64, len(searchResultData))
|
|
|
|
idSet := make(map[interface{}]struct{})
|
|
groupByValueMap := make(map[interface{}]int64)
|
|
|
|
var j int64
|
|
for j = 0; j < groupBound; {
|
|
sel := SelectSearchResultData(searchResultData, resultOffsets, offsets, i)
|
|
if sel == -1 {
|
|
break
|
|
}
|
|
idx := resultOffsets[sel][i] + offsets[sel]
|
|
|
|
id := typeutil.GetPK(searchResultData[sel].GetIds(), idx)
|
|
groupByVal := typeutil.GetData(searchResultData[sel].GetGroupByFieldValue(), int(idx))
|
|
score := searchResultData[sel].Scores[idx]
|
|
if _, ok := idSet[id]; !ok {
|
|
if groupByVal == nil {
|
|
return ret, merr.WrapErrParameterMissing("GroupByVal returned from segment cannot be null")
|
|
}
|
|
|
|
groupCount := groupByValueMap[groupByVal]
|
|
if groupCount == 0 && int64(len(groupByValueMap)) >= info.GetTopK() {
|
|
// exceed the limit for group count, filter this entity
|
|
filteredCount++
|
|
} else if groupCount >= groupSize {
|
|
// exceed the limit for each group, filter this entity
|
|
filteredCount++
|
|
} else {
|
|
retSize += typeutil.AppendFieldData(ret.FieldsData, searchResultData[sel].FieldsData, idx)
|
|
typeutil.AppendPKs(ret.Ids, id)
|
|
ret.Scores = append(ret.Scores, score)
|
|
if err := typeutil.AppendGroupByValue(ret, groupByVal, searchResultData[sel].GetGroupByFieldValue().GetType()); err != nil {
|
|
log.Error("Failed to append groupByValues", zap.Error(err))
|
|
return ret, err
|
|
}
|
|
groupByValueMap[groupByVal] += 1
|
|
idSet[id] = struct{}{}
|
|
j++
|
|
}
|
|
} else {
|
|
// skip entity with same pk
|
|
filteredCount++
|
|
}
|
|
offsets[sel]++
|
|
}
|
|
ret.Topks = append(ret.Topks, j)
|
|
|
|
// limit search result to avoid oom
|
|
if retSize > maxOutputSize {
|
|
return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
|
|
}
|
|
}
|
|
if float64(filteredCount) >= 0.3*float64(groupBound) {
|
|
log.Warn("GroupBy reduce filtered too many results, "+
|
|
"this may influence the final result seriously",
|
|
zap.Int64("filteredCount", filteredCount),
|
|
zap.Int64("groupBound", groupBound))
|
|
}
|
|
log.Debug("skip duplicated search result", zap.Int64("count", filteredCount))
|
|
return ret, nil
|
|
}
|
|
|
|
func InitSearchReducer(info *reduce.ResultInfo) SearchReduce {
|
|
if info.GetGroupByFieldId() > 0 {
|
|
return &SearchGroupByReduce{}
|
|
}
|
|
return &SearchCommonReduce{}
|
|
}
|