proxy add illegal check for search result (#7227)

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
This commit is contained in:
Cai Yudong 2021-08-23 17:03:51 +08:00 committed by GitHub
parent 6f33214ad3
commit 8405d90f5e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 55 additions and 51 deletions

View File

@ -1659,23 +1659,23 @@ func decodeSearchResults(searchResults []*internalpb.SearchResults) ([]*schemapb
// return decodeSearchResultsParallelByCPU(searchResults)
}
func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultData, nq, availableQueryNodeNum, topk int, metricType string, maxParallel int) (*milvuspb.SearchResults, error) {
log.Debug("reduceSearchResultDataParallel", zap.Any("lenOfsearchResultData", len(searchResultData)),
zap.Any("nq", nq), zap.Any("availableQueryNodeNum", availableQueryNodeNum),
zap.Any("topk", topk), zap.Any("metricType", metricType),
zap.Any("maxParallel", maxParallel))
func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultData, availableQueryNodeNum int64, metricType string, maxParallel int) (*milvuspb.SearchResults, error) {
nq := searchResultData[0].NumQueries
topk := searchResultData[0].TopK
for i, sData := range searchResultData {
log.Debug("reduceSearchResultDataParallel", zap.Any("i", i), zap.Any("len(FieldsData)", len(sData.FieldsData)))
}
log.Debug("reduceSearchResultDataParallel",
zap.Int("len(searchResultData)", len(searchResultData)),
zap.Int64("availableQueryNodeNum", availableQueryNodeNum),
zap.Int64("nq", nq), zap.Int64("topk", topk), zap.String("metricType", metricType),
zap.Int("maxParallel", maxParallel))
ret := &milvuspb.SearchResults{
Status: &commonpb.Status{
ErrorCode: 0,
},
Results: &schemapb.SearchResultData{
NumQueries: int64(nq),
TopK: int64(topk),
NumQueries: nq,
TopK: topk,
FieldsData: make([]*schemapb.FieldData, len(searchResultData[0].FieldsData)),
Scores: make([]float32, 0),
Ids: &schemapb.IDs{
@ -1689,14 +1689,36 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
},
}
for i, sData := range searchResultData {
log.Debug("reduceSearchResultDataParallel",
zap.Int("i", i),
zap.Int64("nq", sData.NumQueries),
zap.Int64("topk", sData.TopK),
zap.Any("len(FieldsData)", len(sData.FieldsData)))
if sData.NumQueries != nq {
return ret, fmt.Errorf("search result's nq(%d) mis-match with %d", sData.NumQueries, nq)
}
if sData.TopK != topk {
return ret, fmt.Errorf("search result's topk(%d) mis-match with %d", sData.TopK, topk)
}
if len(sData.Ids.GetIntId().Data) != (int)(nq*topk) {
return ret, fmt.Errorf("search result's id length %d invalid", len(sData.Ids.GetIntId().Data))
}
if len(sData.Scores) != (int)(nq*topk) {
return ret, fmt.Errorf("search result's score length %d invalid", len(sData.Scores))
}
}
const minFloat32 = -1 * float32(math.MaxFloat32)
// TODO(yukun): Use parallel function
realTopK := -1
for idx := 0; idx < nq; idx++ {
locs := make([]int, availableQueryNodeNum)
var realTopK int64 = -1
var idx int64
var j int64
for idx = 0; idx < nq; idx++ {
locs := make([]int64, availableQueryNodeNum)
j := 0
j = 0
for ; j < topk; j++ {
valid := true
choice, maxDistance := 0, minFloat32
@ -1823,22 +1845,22 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
case *schemapb.VectorField_BinaryVector:
if ret.Results.FieldsData[k].GetVectors().GetBinaryVector() == nil {
bvec := &schemapb.VectorField_BinaryVector{
BinaryVector: vectorType.BinaryVector[curIdx*int((dim/8)) : (curIdx+1)*int((dim/8))],
BinaryVector: vectorType.BinaryVector[curIdx*(dim/8) : (curIdx+1)*(dim/8)],
}
ret.Results.FieldsData[k].GetVectors().Data = bvec
} else {
ret.Results.FieldsData[k].GetVectors().Data.(*schemapb.VectorField_BinaryVector).BinaryVector = append(ret.Results.FieldsData[k].GetVectors().Data.(*schemapb.VectorField_BinaryVector).BinaryVector, vectorType.BinaryVector[curIdx*int((dim/8)):(curIdx+1)*int((dim/8))]...)
ret.Results.FieldsData[k].GetVectors().Data.(*schemapb.VectorField_BinaryVector).BinaryVector = append(ret.Results.FieldsData[k].GetVectors().Data.(*schemapb.VectorField_BinaryVector).BinaryVector, vectorType.BinaryVector[curIdx*(dim/8):(curIdx+1)*(dim/8)]...)
}
case *schemapb.VectorField_FloatVector:
if ret.Results.FieldsData[k].GetVectors().GetFloatVector() == nil {
fvec := &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: vectorType.FloatVector.Data[curIdx*int(dim) : (curIdx+1)*int(dim)],
Data: vectorType.FloatVector.Data[curIdx*dim : (curIdx+1)*dim],
},
}
ret.Results.FieldsData[k].GetVectors().Data = fvec
} else {
ret.Results.FieldsData[k].GetVectors().GetFloatVector().Data = append(ret.Results.FieldsData[k].GetVectors().GetFloatVector().Data, vectorType.FloatVector.Data[curIdx*int(dim):(curIdx+1)*int(dim)]...)
ret.Results.FieldsData[k].GetVectors().GetFloatVector().Data = append(ret.Results.FieldsData[k].GetVectors().GetFloatVector().Data, vectorType.FloatVector.Data[curIdx*dim:(curIdx+1)*dim]...)
}
}
}
@ -1851,10 +1873,10 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
// return nil, errors.New("the length (topk) between all result of query is different")
}
realTopK = j
ret.Results.Topks = append(ret.Results.Topks, int64(realTopK))
ret.Results.Topks = append(ret.Results.Topks, realTopK)
}
ret.Results.TopK = int64(realTopK)
ret.Results.TopK = realTopK
if metricType != "IP" {
for k := range ret.Results.Scores {
@ -1865,12 +1887,12 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
return ret, nil
}
// reduceSearchResultData is a thin wrapper around reduceSearchResultDataParallel:
// it forwards searchResultData, availableQueryNodeNum and metricType, passes
// runtime.NumCPU() as the maxParallel argument, and returns the callee's
// result and error unchanged.
//
// NOTE(review): this span comes from a diff, so the signature and the return
// statement each appear twice (the earlier line of each pair still takes nq
// and topk as parameters, the later one does not) — confirm which pair is
// live in the actual file before editing.
func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq, availableQueryNodeNum, topk int, metricType string) (*milvuspb.SearchResults, error) {
func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, availableQueryNodeNum int64, metricType string) (*milvuspb.SearchResults, error) {
// Capture the start time so the deferred log can report the elapsed time.
t := time.Now()
defer func() {
// Runs on every return path, including when the callee returns an error.
log.Debug("reduceSearchResults", zap.Any("time cost", time.Since(t)))
}()
return reduceSearchResultDataParallel(searchResultData, nq, availableQueryNodeNum, topk, metricType, runtime.NumCPU())
return reduceSearchResultDataParallel(searchResultData, availableQueryNodeNum, metricType, runtime.NumCPU())
}
func printSearchResult(partialSearchResult *internalpb.SearchResults) {
@ -1950,22 +1972,7 @@ func (st *SearchTask) PostExecute(ctx context.Context) error {
return err
}
nq := results[0].NumQueries
topk := 0
for _, partialResult := range results {
topk = getMax(topk, int(partialResult.TopK))
}
if nq <= 0 {
st.result = &milvuspb.SearchResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: filterReason,
},
}
return nil
}
st.result, err = reduceSearchResultData(results, int(nq), availableQueryNodeNum, topk, searchResults[0].MetricType)
st.result, err = reduceSearchResultData(results, int64(availableQueryNodeNum), searchResults[0].MetricType)
if err != nil {
return err
}

View File

@ -470,10 +470,10 @@ func (c *Core) setMsgStreams() error {
Timestamps: pt,
DefaultTimestamp: t,
}
log.Debug("update timetick",
zap.Any("DefaultTs", t),
zap.Any("sourceID", c.session.ServerID),
zap.Any("reason", reason))
//log.Debug("update timetick",
// zap.Any("DefaultTs", t),
// zap.Any("sourceID", c.session.ServerID),
// zap.Any("reason", reason))
return c.chanTimeTick.UpdateTimeTick(&ttMsg, reason)
}

View File

@ -83,13 +83,13 @@ func newTimeTickSync(core *Core) *timetickSync {
// sendToChannel sends all channels' timetick to sendChan.
// The invoker must hold the lock before calling it.
func (t *timetickSync) sendToChannel() error {
func (t *timetickSync) sendToChannel() {
if len(t.proxyTimeTick) == 0 {
return fmt.Errorf("proxyTimeTick empty")
return
}
for _, v := range t.proxyTimeTick {
if v == nil {
return fmt.Errorf("proxyTimeTick has not been fulfilled")
return
}
}
// clear proxyTimeTick and send a clone
@ -99,7 +99,6 @@ func (t *timetickSync) sendToChannel() error {
t.proxyTimeTick[k] = nil
}
t.sendChan <- ptt
return nil
}
// AddDmlTimeTick add ts into ddlTimetickInfos[sourceID],
@ -191,12 +190,10 @@ func (t *timetickSync) UpdateTimeTick(in *internalpb.ChannelTimeTickMsg, reason
}
t.proxyTimeTick[in.Base.SourceID] = newChannelTimeTickMsg(in)
log.Debug("update proxyTimeTick", zap.Int64("source id", in.Base.SourceID),
zap.Uint64("inTs", in.DefaultTimestamp), zap.String("reason", reason))
//log.Debug("update proxyTimeTick", zap.Int64("source id", in.Base.SourceID),
// zap.Uint64("inTs", in.DefaultTimestamp), zap.String("reason", reason))
if err := t.sendToChannel(); err != nil {
log.Debug("sendToChannel fail", zap.Any("err", err.Error()))
}
t.sendToChannel()
return nil
}