Parallelize the processing of SearchTask.PostExecute

Signed-off-by: dragondriver <jiquan.long@zilliz.com>
dragondriver 2021-03-25 10:14:09 +08:00 committed by yefu.chen
parent 2842e929ba
commit 1b743e5c6c
2 changed files with 397 additions and 84 deletions


@@ -6,7 +6,9 @@ import (
"fmt"
"math"
"regexp"
"runtime"
"strconv"
"sync"
"go.uber.org/zap"
@@ -613,6 +615,378 @@ func (st *SearchTask) Execute(ctx context.Context) error {
return err
}
func decodeSearchResultsSerial(searchResults []*internalpb.SearchResults) ([][]*milvuspb.Hits, error) {
hits := make([][]*milvuspb.Hits, 0)
for _, partialSearchResult := range searchResults {
if partialSearchResult.Hits == nil || len(partialSearchResult.Hits) <= 0 {
continue
}
partialHits := make([]*milvuspb.Hits, len(partialSearchResult.Hits))
var err error
for i := range partialSearchResult.Hits {
j := i
func(idx int) {
partialHit := &milvuspb.Hits{}
err = proto.Unmarshal(partialSearchResult.Hits[idx], partialHit)
if err != nil {
log.Debug("proxynode", zap.Any("unmarshal search result error", err))
}
partialHits[idx] = partialHit
}(j)
}
if err != nil {
return nil, err
}
hits = append(hits, partialHits)
}
return hits, nil
}
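
The serial variant wraps each unmarshal in an immediately-invoked closure, presumably so its body stays line-for-line comparable with the goroutine body in the parallel variant below; note that the shared err is only inspected after the loop, so an early failure followed by a later success is lost. A flatter sketch of the same loop that returns on the first error (behavior otherwise equivalent; not part of this commit):

	// Sketch only: return on the first unmarshal error instead of checking a shared err afterwards.
	for i, raw := range partialSearchResult.Hits {
		partialHit := &milvuspb.Hits{}
		if err := proto.Unmarshal(raw, partialHit); err != nil {
			log.Debug("proxynode", zap.Any("unmarshal search result error", err))
			return nil, err
		}
		partialHits[i] = partialHit
	}
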
// TODO: add benchmark to compare with serial implementation
func decodeSearchResultsParallel(searchResults []*internalpb.SearchResults) ([][]*milvuspb.Hits, error) {
hits := make([][]*milvuspb.Hits, 0)
// is it necessary to parallelize this outer loop over partial results?
for _, partialSearchResult := range searchResults {
if partialSearchResult.Hits == nil || len(partialSearchResult.Hits) <= 0 {
continue
}
// is it necessary to check nq (len(partialSearchResult.Hits)) here?
partialHits := make([]*milvuspb.Hits, len(partialSearchResult.Hits))
var wg sync.WaitGroup
var err error
for i := range partialSearchResult.Hits {
j := i
wg.Add(1)
go func(idx int) {
defer wg.Done()
partialHit := &milvuspb.Hits{}
err = proto.Unmarshal(partialSearchResult.Hits[idx], partialHit)
if err != nil {
log.Debug("proxynode", zap.Any("unmarshal search result error", err))
}
partialHits[idx] = partialHit
}(j)
}
wg.Wait()
// read the shared err only after every goroutine has finished
if err != nil {
return nil, err
}
hits = append(hits, partialHits)
}
return hits, nil
}
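
In decodeSearchResultsParallel every goroutine assigns to the same shared err without synchronization, so concurrent failures race with each other and only one of them survives to the check after wg.Wait(). A common alternative is golang.org/x/sync/errgroup, which makes Wait return the first non-nil error; a minimal sketch of decoding one partial result that way (an alternative pattern, not what this commit does):

	// Sketch only: errgroup-based decode of one partial result.
	// Assumes the existing imports plus "golang.org/x/sync/errgroup".
	func decodePartialHits(partial *internalpb.SearchResults) ([]*milvuspb.Hits, error) {
		partialHits := make([]*milvuspb.Hits, len(partial.Hits))
		var g errgroup.Group
		for i := range partial.Hits {
			idx := i
			g.Go(func() error {
				partialHit := &milvuspb.Hits{}
				if err := proto.Unmarshal(partial.Hits[idx], partialHit); err != nil {
					return err
				}
				partialHits[idx] = partialHit // each goroutine owns a distinct index, no race
				return nil
			})
		}
		if err := g.Wait(); err != nil { // first error reported by any goroutine
			return nil, err
		}
		return partialHits, nil
	}
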
// TODO: add benchmark to compare with serial implementation
func decodeSearchResultsParallelByCPU(searchResults []*internalpb.SearchResults) ([][]*milvuspb.Hits, error) {
log.Debug("ProcessSearchResultParallel", zap.Any("runtime.NumCPU", runtime.NumCPU()))
hits := make([][]*milvuspb.Hits, 0)
// is it necessary to parallelize this outer loop over partial results?
for _, partialSearchResult := range searchResults {
if partialSearchResult.Hits == nil || len(partialSearchResult.Hits) <= 0 {
continue
}
nq := len(partialSearchResult.Hits)
maxParallel := runtime.NumCPU()
nqPerBatch := (nq + maxParallel - 1) / maxParallel
partialHits := make([]*milvuspb.Hits, nq)
var wg sync.WaitGroup
var err error
for i := 0; i < nq; i = i + nqPerBatch {
j := i
wg.Add(1)
go func(begin int) {
defer wg.Done()
end := getMin(nq, begin+nqPerBatch)
for idx := begin; idx < end; idx++ {
partialHit := &milvuspb.Hits{}
err = proto.Unmarshal(partialSearchResult.Hits[idx], partialHit)
if err != nil {
log.Debug("proxynode", zap.Any("unmarshal search result error", err))
}
partialHits[idx] = partialHit
}
}(j)
}
wg.Wait()
// read the shared err only after every goroutine has finished
if err != nil {
return nil, err
}
hits = append(hits, partialHits)
}
return hits, nil
}
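
decodeSearchResultsParallelByCPU caps the fan-out at runtime.NumCPU() goroutines by splitting the nq hits into batches of ceil(nq / NumCPU): with nq = 10 on 4 CPUs, nqPerBatch is 3 and the batches are [0,3), [3,6), [6,9), [9,10). A minimal sketch of that ceiling-division batching (values are illustrative):

	// Sketch only: ceiling-division batching as used above.
	nq, maxParallel := 10, 4
	nqPerBatch := (nq + maxParallel - 1) / maxParallel // == 3
	for begin := 0; begin < nq; begin += nqPerBatch {
		end := getMin(nq, begin+nqPerBatch)
		fmt.Printf("batch [%d, %d)\n", begin, end) // [0,3) [3,6) [6,9) [9,10)
	}
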
func decodeSearchResults(searchResults []*internalpb.SearchResults) ([][]*milvuspb.Hits, error) {
return decodeSearchResultsParallel(searchResults)
}
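
The TODOs above ask for a benchmark comparing the variants; a sketch of what it could look like, assuming a hypothetical test helper genFakeSearchResults that builds []*internalpb.SearchResults with marshaled hits (the helper and the benchmark are not part of this commit and would live in a _test.go file):

	// Sketch only: compares the three decode variants.
	// genFakeSearchResults(queryNodes, nq, topk) is a hypothetical helper.
	func BenchmarkDecodeSearchResults(b *testing.B) {
		results := genFakeSearchResults(4, 64, 100)
		b.Run("serial", func(b *testing.B) {
			for i := 0; i < b.N; i++ {
				_, _ = decodeSearchResultsSerial(results)
			}
		})
		b.Run("parallel", func(b *testing.B) {
			for i := 0; i < b.N; i++ {
				_, _ = decodeSearchResultsParallel(results)
			}
		})
		b.Run("parallelByCPU", func(b *testing.B) {
			for i := 0; i < b.N; i++ {
				_, _ = decodeSearchResultsParallelByCPU(results)
			}
		})
	}
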
func reduceSearchResultsSerial(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
ret := &milvuspb.SearchResults{
Status: &commonpb.Status{
ErrorCode: 0,
},
Hits: make([][]byte, nq),
}
const minFloat32 = -1 * float32(math.MaxFloat32)
for i := 0; i < nq; i++ {
j := i
func(idx int) {
locs := make([]int, availableQueryNodeNum)
reducedHits := &milvuspb.Hits{
IDs: make([]int64, 0),
RowData: make([][]byte, 0),
Scores: make([]float32, 0),
}
for j := 0; j < topk; j++ {
valid := false
choice, maxDistance := 0, minFloat32
for q, loc := range locs { // q indexes the query nodes, i.e. the ways of the merge
if loc >= len(hits[q][idx].IDs) {
continue
}
distance := hits[q][idx].Scores[loc]
if distance > maxDistance || (math.Abs(float64(distance-maxDistance)) < math.SmallestNonzeroFloat32 && choice != q) {
choice = q
maxDistance = distance
valid = true
}
}
if !valid {
break
}
choiceOffset := locs[choice]
// check whether the distance is valid; `invalid` here means extremely large.
// The chosen distance is the smallest remaining, so if it is invalid the rest are invalid too.
if hits[choice][idx].Scores[choiceOffset] <= minFloat32 {
break
}
reducedHits.IDs = append(reducedHits.IDs, hits[choice][idx].IDs[choiceOffset])
if hits[choice][idx].RowData != nil && len(hits[choice][idx].RowData) > 0 {
reducedHits.RowData = append(reducedHits.RowData, hits[choice][idx].RowData[choiceOffset])
}
reducedHits.Scores = append(reducedHits.Scores, hits[choice][idx].Scores[choiceOffset])
locs[choice]++
}
if metricType != "IP" {
for k := range reducedHits.Scores {
reducedHits.Scores[k] *= -1
}
}
reducedHitsBs, err := proto.Marshal(reducedHits)
if err != nil {
log.Debug("proxynode", zap.String("error", "marshal error"))
}
ret.Hits[idx] = reducedHitsBs
}(j)
}
return ret
}
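
reduceSearchResultsSerial performs a per-query k-way merge: locs keeps one cursor per query node, and every iteration picks the head with the largest internal score (scores are negated afterwards unless the metric type is IP, so a larger internal score means a smaller distance). Merging [0.9, 0.5] and [0.8, 0.7] with topk = 3 therefore yields [0.9, 0.8, 0.7]. A stripped-down sketch of the same pick-the-max-head loop over plain score slices:

	// Sketch only: k-way merge of descending score lists, same idea as the loop above.
	func mergeTopK(lists [][]float32, topk int) []float32 {
		locs := make([]int, len(lists)) // one cursor per list
		merged := make([]float32, 0, topk)
		for len(merged) < topk {
			choice, best, valid := 0, -float32(math.MaxFloat32), false
			for q, loc := range locs {
				if loc < len(lists[q]) && lists[q][loc] > best {
					choice, best, valid = q, lists[q][loc], true
				}
			}
			if !valid { // all lists exhausted
				break
			}
			merged = append(merged, best)
			locs[choice]++
		}
		return merged
	}

For example, mergeTopK([][]float32{{0.9, 0.5}, {0.8, 0.7}}, 3) returns [0.9 0.8 0.7].
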
// TODO: add benchmark to compare with simple serial implementation
func reduceSearchResultsParallel(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
ret := &milvuspb.SearchResults{
Status: &commonpb.Status{
ErrorCode: 0,
},
Hits: make([][]byte, nq),
}
const minFloat32 = -1 * float32(math.MaxFloat32)
var wg sync.WaitGroup
for i := 0; i < nq; i++ {
j := i
wg.Add(1)
go func(idx int) {
defer wg.Done()
locs := make([]int, availableQueryNodeNum)
reducedHits := &milvuspb.Hits{
IDs: make([]int64, 0),
RowData: make([][]byte, 0),
Scores: make([]float32, 0),
}
for j := 0; j < topk; j++ {
valid := false
choice, maxDistance := 0, minFloat32
for q, loc := range locs { // q indexes the query nodes, i.e. the ways of the merge
if loc >= len(hits[q][idx].IDs) {
continue
}
distance := hits[q][idx].Scores[loc]
if distance > maxDistance || (math.Abs(float64(distance-maxDistance)) < math.SmallestNonzeroFloat32 && choice != q) {
choice = q
maxDistance = distance
valid = true
}
}
if !valid {
break
}
choiceOffset := locs[choice]
// check whether the distance is valid; `invalid` here means extremely large.
// The chosen distance is the smallest remaining, so if it is invalid the rest are invalid too.
if hits[choice][idx].Scores[choiceOffset] <= minFloat32 {
break
}
reducedHits.IDs = append(reducedHits.IDs, hits[choice][idx].IDs[choiceOffset])
if hits[choice][idx].RowData != nil && len(hits[choice][idx].RowData) > 0 {
reducedHits.RowData = append(reducedHits.RowData, hits[choice][idx].RowData[choiceOffset])
}
reducedHits.Scores = append(reducedHits.Scores, hits[choice][idx].Scores[choiceOffset])
locs[choice]++
}
if metricType != "IP" {
for k := range reducedHits.Scores {
reducedHits.Scores[k] *= -1
}
}
reducedHitsBs, err := proto.Marshal(reducedHits)
if err != nil {
log.Debug("proxynode", zap.String("error", "marshal error"))
}
ret.Hits[idx] = reducedHitsBs
}(j)
}
wg.Wait()
return ret
}
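
In the parallel reduce each goroutine writes only its own index of ret.Hits, so the concurrent writes do not race; a proto.Marshal failure, however, is merely logged and leaves ret.Hits[idx] nil. One way to surface such failures without a mutex is one error slot per goroutine, checked after Wait; a self-contained sketch of that pattern (not what this commit does):

	// Sketch only: distinct output and error slots per goroutine, no locking needed.
	func marshalHitsParallel(payloads []*milvuspb.Hits) ([][]byte, error) {
		outputs := make([][]byte, len(payloads))
		errs := make([]error, len(payloads))
		var wg sync.WaitGroup
		for i := range payloads {
			idx := i
			wg.Add(1)
			go func() {
				defer wg.Done()
				outputs[idx], errs[idx] = proto.Marshal(payloads[idx]) // each goroutine owns slot idx
			}()
		}
		wg.Wait()
		for _, err := range errs {
			if err != nil {
				return nil, err
			}
		}
		return outputs, nil
	}
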
// TODO: add benchmark to compare with simple serial implementation
func reduceSearchResultsParallelByCPU(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
maxParallel := runtime.NumCPU()
nqPerBatch := (nq + maxParallel - 1) / maxParallel
ret := &milvuspb.SearchResults{
Status: &commonpb.Status{
ErrorCode: 0,
},
Hits: make([][]byte, nq),
}
const minFloat32 = -1 * float32(math.MaxFloat32)
var wg sync.WaitGroup
for begin := 0; begin < nq; begin = begin + nqPerBatch {
j := begin
wg.Add(1)
go func(begin int) {
defer wg.Done()
end := getMin(nq, begin+nqPerBatch)
for idx := begin; idx < end; idx++ {
locs := make([]int, availableQueryNodeNum)
reducedHits := &milvuspb.Hits{
IDs: make([]int64, 0),
RowData: make([][]byte, 0),
Scores: make([]float32, 0),
}
for j := 0; j < topk; j++ {
valid := false
choice, maxDistance := 0, minFloat32
for q, loc := range locs { // q indexes the query nodes, i.e. the ways of the merge
if loc >= len(hits[q][idx].IDs) {
continue
}
distance := hits[q][idx].Scores[loc]
if distance > maxDistance || (math.Abs(float64(distance-maxDistance)) < math.SmallestNonzeroFloat32 && choice != q) {
choice = q
maxDistance = distance
valid = true
}
}
if !valid {
break
}
choiceOffset := locs[choice]
// check whether the distance is valid; `invalid` here means extremely large.
// The chosen distance is the smallest remaining, so if it is invalid the rest are invalid too.
if hits[choice][idx].Scores[choiceOffset] <= minFloat32 {
break
}
reducedHits.IDs = append(reducedHits.IDs, hits[choice][idx].IDs[choiceOffset])
if hits[choice][idx].RowData != nil && len(hits[choice][idx].RowData) > 0 {
reducedHits.RowData = append(reducedHits.RowData, hits[choice][idx].RowData[choiceOffset])
}
reducedHits.Scores = append(reducedHits.Scores, hits[choice][idx].Scores[choiceOffset])
locs[choice]++
}
if metricType != "IP" {
for k := range reducedHits.Scores {
reducedHits.Scores[k] *= -1
}
}
reducedHitsBs, err := proto.Marshal(reducedHits)
if err != nil {
log.Debug("proxynode", zap.String("error", "marshal error"))
}
ret.Hits[idx] = reducedHitsBs
}
}(j)
}
wg.Wait()
return ret
}
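
Both ByCPU variants size their batches from runtime.NumCPU(), the number of logical CPUs on the machine; when the process runs with a lowered GOMAXPROCS (for example in a CPU-limited container), the scheduler's actual parallelism can be smaller. A possible alternative for the first two lines of the function above, shown only as a sketch:

	// Sketch only: size batches by the scheduler's parallelism instead of the CPU count.
	// runtime.GOMAXPROCS(0) reports the current setting without changing it.
	maxParallel := runtime.GOMAXPROCS(0)
	nqPerBatch := (nq + maxParallel - 1) / maxParallel
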
func reduceSearchResults(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
return reduceSearchResultsParallel(hits, nq, availableQueryNodeNum, topk, metricType)
}
func printSearchResult(partialSearchResult *internalpb.SearchResults) {
for i := 0; i < len(partialSearchResult.Hits); i++ {
testHits := milvuspb.Hits{}
err := proto.Unmarshal(partialSearchResult.Hits[i], &testHits)
if err != nil {
panic(err)
}
fmt.Println(testHits.IDs)
fmt.Println(testHits.Scores)
}
}
func (st *SearchTask) PostExecute(ctx context.Context) error {
for {
select {
@@ -627,15 +1001,7 @@ func (st *SearchTask) PostExecute(ctx context.Context) error {
if partialSearchResult.Status.ErrorCode == commonpb.ErrorCode_Success {
filterSearchResult = append(filterSearchResult, partialSearchResult)
// For debugging, please don't delete.
//for i := 0; i < len(partialSearchResult.Hits); i++ {
// testHits := milvuspb.Hits{}
// err := proto.Unmarshal(partialSearchResult.Hits[i], &testHits)
// if err != nil {
// panic(err)
// }
// fmt.Println(testHits.IDs)
// fmt.Println(testHits.Scores)
//}
// printSearchResult(partialSearchResult)
} else {
filterReason += partialSearchResult.Status.Reason + "\n"
}
@@ -652,26 +1018,15 @@ func (st *SearchTask) PostExecute(ctx context.Context) error {
return errors.New(filterReason)
}
hits := make([][]*milvuspb.Hits, 0)
availableQueryNodeNum = 0
for _, partialSearchResult := range filterSearchResult {
if partialSearchResult.Hits == nil || len(partialSearchResult.Hits) <= 0 {
filterReason += "nq is zero\n"
continue
}
partialHits := make([]*milvuspb.Hits, 0)
for _, bs := range partialSearchResult.Hits {
partialHit := &milvuspb.Hits{}
err := proto.Unmarshal(bs, partialHit)
if err != nil {
log.Debug("proxynode", zap.String("error", "unmarshal error"))
return err
}
partialHits = append(partialHits, partialHit)
}
hits = append(hits, partialHits)
availableQueryNodeNum++
}
availableQueryNodeNum = len(hits)
if availableQueryNodeNum <= 0 {
st.result = &milvuspb.SearchResults{
Status: &commonpb.Status{
@@ -682,6 +1037,11 @@ func (st *SearchTask) PostExecute(ctx context.Context) error {
return nil
}
hits, err := decodeSearchResults(filterSearchResult)
if err != nil {
return err
}
nq := len(hits[0])
if nq <= 0 {
st.result = &milvuspb.SearchResults{
@@ -694,73 +1054,12 @@ func (st *SearchTask) PostExecute(ctx context.Context) error {
}
topk := 0
getMax := func(a, b int) int {
if a > b {
return a
}
return b
}
for _, hit := range hits {
topk = getMax(topk, len(hit[0].IDs))
}
st.result = &milvuspb.SearchResults{
Status: &commonpb.Status{
ErrorCode: 0,
},
Hits: make([][]byte, 0),
}
const minFloat32 = -1 * float32(math.MaxFloat32)
for i := 0; i < nq; i++ {
locs := make([]int, availableQueryNodeNum)
reducedHits := &milvuspb.Hits{
IDs: make([]int64, 0),
RowData: make([][]byte, 0),
Scores: make([]float32, 0),
}
st.result = reduceSearchResults(hits, nq, availableQueryNodeNum, topk, searchResults[0].MetricType)
for j := 0; j < topk; j++ {
valid := false
choice, maxDistance := 0, minFloat32
for q, loc := range locs { // query num, the number of ways to merge
if loc >= len(hits[q][i].IDs) {
continue
}
distance := hits[q][i].Scores[loc]
if distance > maxDistance || (math.Abs(float64(distance-maxDistance)) < math.SmallestNonzeroFloat32 && choice != q) {
choice = q
maxDistance = distance
valid = true
}
}
if !valid {
break
}
choiceOffset := locs[choice]
// check if distance is valid, `invalid` here means very very big,
// in this process, distance here is the smallest, so the rest of distance are all invalid
if hits[choice][i].Scores[choiceOffset] <= minFloat32 {
break
}
reducedHits.IDs = append(reducedHits.IDs, hits[choice][i].IDs[choiceOffset])
if hits[choice][i].RowData != nil && len(hits[choice][i].RowData) > 0 {
reducedHits.RowData = append(reducedHits.RowData, hits[choice][i].RowData[choiceOffset])
}
reducedHits.Scores = append(reducedHits.Scores, hits[choice][i].Scores[choiceOffset])
locs[choice]++
}
if searchResults[0].MetricType != "IP" {
for k := range reducedHits.Scores {
reducedHits.Scores[k] *= -1
}
}
reducedHitsBs, err := proto.Marshal(reducedHits)
if err != nil {
log.Debug("proxynode", zap.String("error", "marshal error"))
return err
}
st.result.Hits = append(st.result.Hits, reducedHitsBs)
}
return nil
}
}


@@ -97,3 +97,17 @@ func SortedSliceEqual(s1 interface{}, s2 interface{}) bool {
}
return true
}
func getMax(a, b int) int {
if a > b {
return a
}
return b
}
func getMin(a, b int) int {
if a < b {
return a
}
return b
}
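
A small table-driven test for the two new helpers could look like the sketch below (it assumes the standard testing package and would live in a _test.go file next to them; not part of this commit):

	// Sketch only: table-driven test for getMax and getMin.
	func TestGetMaxGetMin(t *testing.T) {
		cases := []struct{ a, b, max, min int }{
			{1, 2, 2, 1},
			{2, 1, 2, 1},
			{3, 3, 3, 3},
			{-1, 0, 0, -1},
		}
		for _, c := range cases {
			if got := getMax(c.a, c.b); got != c.max {
				t.Errorf("getMax(%d, %d) = %d, want %d", c.a, c.b, got, c.max)
			}
			if got := getMin(c.a, c.b); got != c.min {
				t.Errorf("getMin(%d, %d) = %d, want %d", c.a, c.b, got, c.min)
			}
		}
	}
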