mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-01 03:18:29 +08:00
proxy add illegal check for search result (#7227)
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
This commit is contained in:
parent
6f33214ad3
commit
8405d90f5e
@ -1659,23 +1659,23 @@ func decodeSearchResults(searchResults []*internalpb.SearchResults) ([]*schemapb
|
||||
// return decodeSearchResultsParallelByCPU(searchResults)
|
||||
}
|
||||
|
||||
func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultData, nq, availableQueryNodeNum, topk int, metricType string, maxParallel int) (*milvuspb.SearchResults, error) {
|
||||
log.Debug("reduceSearchResultDataParallel", zap.Any("lenOfsearchResultData", len(searchResultData)),
|
||||
zap.Any("nq", nq), zap.Any("availableQueryNodeNum", availableQueryNodeNum),
|
||||
zap.Any("topk", topk), zap.Any("metricType", metricType),
|
||||
zap.Any("maxParallel", maxParallel))
|
||||
func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultData, availableQueryNodeNum int64, metricType string, maxParallel int) (*milvuspb.SearchResults, error) {
|
||||
nq := searchResultData[0].NumQueries
|
||||
topk := searchResultData[0].TopK
|
||||
|
||||
for i, sData := range searchResultData {
|
||||
log.Debug("reduceSearchResultDataParallel", zap.Any("i", i), zap.Any("len(FieldsData)", len(sData.FieldsData)))
|
||||
}
|
||||
log.Debug("reduceSearchResultDataParallel",
|
||||
zap.Int("len(searchResultData)", len(searchResultData)),
|
||||
zap.Int64("availableQueryNodeNum", availableQueryNodeNum),
|
||||
zap.Int64("nq", nq), zap.Int64("topk", topk), zap.String("metricType", metricType),
|
||||
zap.Int("maxParallel", maxParallel))
|
||||
|
||||
ret := &milvuspb.SearchResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: 0,
|
||||
},
|
||||
Results: &schemapb.SearchResultData{
|
||||
NumQueries: int64(nq),
|
||||
TopK: int64(topk),
|
||||
NumQueries: nq,
|
||||
TopK: topk,
|
||||
FieldsData: make([]*schemapb.FieldData, len(searchResultData[0].FieldsData)),
|
||||
Scores: make([]float32, 0),
|
||||
Ids: &schemapb.IDs{
|
||||
@ -1689,14 +1689,36 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
|
||||
},
|
||||
}
|
||||
|
||||
for i, sData := range searchResultData {
|
||||
log.Debug("reduceSearchResultDataParallel",
|
||||
zap.Int("i", i),
|
||||
zap.Int64("nq", sData.NumQueries),
|
||||
zap.Int64("topk", sData.TopK),
|
||||
zap.Any("len(FieldsData)", len(sData.FieldsData)))
|
||||
if sData.NumQueries != nq {
|
||||
return ret, fmt.Errorf("search result's nq(%d) mis-match with %d", sData.NumQueries, nq)
|
||||
}
|
||||
if sData.TopK != topk {
|
||||
return ret, fmt.Errorf("search result's topk(%d) mis-match with %d", sData.TopK, topk)
|
||||
}
|
||||
if len(sData.Ids.GetIntId().Data) != (int)(nq*topk) {
|
||||
return ret, fmt.Errorf("search result's id length %d invalid", len(sData.Ids.GetIntId().Data))
|
||||
}
|
||||
if len(sData.Scores) != (int)(nq*topk) {
|
||||
return ret, fmt.Errorf("search result's score length %d invalid", len(sData.Scores))
|
||||
}
|
||||
}
|
||||
|
||||
const minFloat32 = -1 * float32(math.MaxFloat32)
|
||||
|
||||
// TODO(yukun): Use parallel function
|
||||
realTopK := -1
|
||||
for idx := 0; idx < nq; idx++ {
|
||||
locs := make([]int, availableQueryNodeNum)
|
||||
var realTopK int64 = -1
|
||||
var idx int64
|
||||
var j int64
|
||||
for idx = 0; idx < nq; idx++ {
|
||||
locs := make([]int64, availableQueryNodeNum)
|
||||
|
||||
j := 0
|
||||
j = 0
|
||||
for ; j < topk; j++ {
|
||||
valid := true
|
||||
choice, maxDistance := 0, minFloat32
|
||||
@ -1823,22 +1845,22 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
|
||||
case *schemapb.VectorField_BinaryVector:
|
||||
if ret.Results.FieldsData[k].GetVectors().GetBinaryVector() == nil {
|
||||
bvec := &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: vectorType.BinaryVector[curIdx*int((dim/8)) : (curIdx+1)*int((dim/8))],
|
||||
BinaryVector: vectorType.BinaryVector[curIdx*(dim/8) : (curIdx+1)*(dim/8)],
|
||||
}
|
||||
ret.Results.FieldsData[k].GetVectors().Data = bvec
|
||||
} else {
|
||||
ret.Results.FieldsData[k].GetVectors().Data.(*schemapb.VectorField_BinaryVector).BinaryVector = append(ret.Results.FieldsData[k].GetVectors().Data.(*schemapb.VectorField_BinaryVector).BinaryVector, vectorType.BinaryVector[curIdx*int((dim/8)):(curIdx+1)*int((dim/8))]...)
|
||||
ret.Results.FieldsData[k].GetVectors().Data.(*schemapb.VectorField_BinaryVector).BinaryVector = append(ret.Results.FieldsData[k].GetVectors().Data.(*schemapb.VectorField_BinaryVector).BinaryVector, vectorType.BinaryVector[curIdx*(dim/8):(curIdx+1)*(dim/8)]...)
|
||||
}
|
||||
case *schemapb.VectorField_FloatVector:
|
||||
if ret.Results.FieldsData[k].GetVectors().GetFloatVector() == nil {
|
||||
fvec := &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: vectorType.FloatVector.Data[curIdx*int(dim) : (curIdx+1)*int(dim)],
|
||||
Data: vectorType.FloatVector.Data[curIdx*dim : (curIdx+1)*dim],
|
||||
},
|
||||
}
|
||||
ret.Results.FieldsData[k].GetVectors().Data = fvec
|
||||
} else {
|
||||
ret.Results.FieldsData[k].GetVectors().GetFloatVector().Data = append(ret.Results.FieldsData[k].GetVectors().GetFloatVector().Data, vectorType.FloatVector.Data[curIdx*int(dim):(curIdx+1)*int(dim)]...)
|
||||
ret.Results.FieldsData[k].GetVectors().GetFloatVector().Data = append(ret.Results.FieldsData[k].GetVectors().GetFloatVector().Data, vectorType.FloatVector.Data[curIdx*dim:(curIdx+1)*dim]...)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1851,10 +1873,10 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
|
||||
// return nil, errors.New("the length (topk) between all result of query is different")
|
||||
}
|
||||
realTopK = j
|
||||
ret.Results.Topks = append(ret.Results.Topks, int64(realTopK))
|
||||
ret.Results.Topks = append(ret.Results.Topks, realTopK)
|
||||
}
|
||||
|
||||
ret.Results.TopK = int64(realTopK)
|
||||
ret.Results.TopK = realTopK
|
||||
|
||||
if metricType != "IP" {
|
||||
for k := range ret.Results.Scores {
|
||||
@ -1865,12 +1887,12 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq, availableQueryNodeNum, topk int, metricType string) (*milvuspb.SearchResults, error) {
|
||||
func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, availableQueryNodeNum int64, metricType string) (*milvuspb.SearchResults, error) {
|
||||
t := time.Now()
|
||||
defer func() {
|
||||
log.Debug("reduceSearchResults", zap.Any("time cost", time.Since(t)))
|
||||
}()
|
||||
return reduceSearchResultDataParallel(searchResultData, nq, availableQueryNodeNum, topk, metricType, runtime.NumCPU())
|
||||
return reduceSearchResultDataParallel(searchResultData, availableQueryNodeNum, metricType, runtime.NumCPU())
|
||||
}
|
||||
|
||||
func printSearchResult(partialSearchResult *internalpb.SearchResults) {
|
||||
@ -1950,22 +1972,7 @@ func (st *SearchTask) PostExecute(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
nq := results[0].NumQueries
|
||||
topk := 0
|
||||
for _, partialResult := range results {
|
||||
topk = getMax(topk, int(partialResult.TopK))
|
||||
}
|
||||
if nq <= 0 {
|
||||
st.result = &milvuspb.SearchResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
Reason: filterReason,
|
||||
},
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
st.result, err = reduceSearchResultData(results, int(nq), availableQueryNodeNum, topk, searchResults[0].MetricType)
|
||||
st.result, err = reduceSearchResultData(results, int64(availableQueryNodeNum), searchResults[0].MetricType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -470,10 +470,10 @@ func (c *Core) setMsgStreams() error {
|
||||
Timestamps: pt,
|
||||
DefaultTimestamp: t,
|
||||
}
|
||||
log.Debug("update timetick",
|
||||
zap.Any("DefaultTs", t),
|
||||
zap.Any("sourceID", c.session.ServerID),
|
||||
zap.Any("reason", reason))
|
||||
//log.Debug("update timetick",
|
||||
// zap.Any("DefaultTs", t),
|
||||
// zap.Any("sourceID", c.session.ServerID),
|
||||
// zap.Any("reason", reason))
|
||||
return c.chanTimeTick.UpdateTimeTick(&ttMsg, reason)
|
||||
}
|
||||
|
||||
|
@ -83,13 +83,13 @@ func newTimeTickSync(core *Core) *timetickSync {
|
||||
|
||||
// sendToChannel sends all channels' timetick to sendChan
|
||||
// the caller must hold the lock when invoking it
|
||||
func (t *timetickSync) sendToChannel() error {
|
||||
func (t *timetickSync) sendToChannel() {
|
||||
if len(t.proxyTimeTick) == 0 {
|
||||
return fmt.Errorf("proxyTimeTick empty")
|
||||
return
|
||||
}
|
||||
for _, v := range t.proxyTimeTick {
|
||||
if v == nil {
|
||||
return fmt.Errorf("proxyTimeTick has not been fulfilled")
|
||||
return
|
||||
}
|
||||
}
|
||||
// clear proxyTimeTick and send a clone
|
||||
@ -99,7 +99,6 @@ func (t *timetickSync) sendToChannel() error {
|
||||
t.proxyTimeTick[k] = nil
|
||||
}
|
||||
t.sendChan <- ptt
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddDmlTimeTick adds ts into ddlTimetickInfos[sourceID],
|
||||
@ -191,12 +190,10 @@ func (t *timetickSync) UpdateTimeTick(in *internalpb.ChannelTimeTickMsg, reason
|
||||
}
|
||||
|
||||
t.proxyTimeTick[in.Base.SourceID] = newChannelTimeTickMsg(in)
|
||||
log.Debug("update proxyTimeTick", zap.Int64("source id", in.Base.SourceID),
|
||||
zap.Uint64("inTs", in.DefaultTimestamp), zap.String("reason", reason))
|
||||
//log.Debug("update proxyTimeTick", zap.Int64("source id", in.Base.SourceID),
|
||||
// zap.Uint64("inTs", in.DefaultTimestamp), zap.String("reason", reason))
|
||||
|
||||
if err := t.sendToChannel(); err != nil {
|
||||
log.Debug("sendToChannel fail", zap.Any("err", err.Error()))
|
||||
}
|
||||
t.sendToChannel()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user