From 1b743e5c6cfcfeadbeeb52fbf7b9e60e82c9778f Mon Sep 17 00:00:00 2001
From: dragondriver
Date: Thu, 25 Mar 2021 10:14:09 +0800
Subject: [PATCH] Parallelize the processing of SearchTask.PostExecute

Signed-off-by: dragondriver
---
 internal/proxynode/task.go | 467 ++++++++++++++++++++++++++++++-------
 internal/proxynode/util.go |  14 ++
 2 files changed, 397 insertions(+), 84 deletions(-)

diff --git a/internal/proxynode/task.go b/internal/proxynode/task.go
index 6a433104f4..0fa481466c 100644
--- a/internal/proxynode/task.go
+++ b/internal/proxynode/task.go
@@ -6,7 +6,9 @@ import (
 	"fmt"
 	"math"
 	"regexp"
+	"runtime"
 	"strconv"
+	"sync"
 
 	"go.uber.org/zap"
 
@@ -613,6 +615,378 @@ func (st *SearchTask) Execute(ctx context.Context) error {
 	return err
 }
 
+func decodeSearchResultsSerial(searchResults []*internalpb.SearchResults) ([][]*milvuspb.Hits, error) {
+	hits := make([][]*milvuspb.Hits, 0)
+	for _, partialSearchResult := range searchResults {
+		if partialSearchResult.Hits == nil || len(partialSearchResult.Hits) <= 0 {
+			continue
+		}
+
+		partialHits := make([]*milvuspb.Hits, len(partialSearchResult.Hits))
+
+		for i := range partialSearchResult.Hits {
+			partialHit := &milvuspb.Hits{}
+			err := proto.Unmarshal(partialSearchResult.Hits[i], partialHit)
+			if err != nil {
+				log.Debug("proxynode", zap.Any("unmarshal search result error", err))
+				return nil, err
+			}
+			partialHits[i] = partialHit
+		}
+
+		hits = append(hits, partialHits)
+	}
+
+	return hits, nil
+}
+
+// TODO: add benchmark to compare with serial implementation
+func decodeSearchResultsParallel(searchResults []*internalpb.SearchResults) ([][]*milvuspb.Hits, error) {
+	hits := make([][]*milvuspb.Hits, 0)
+	// is it necessary to parallelize this outer loop as well?
+	for _, partialSearchResult := range searchResults {
+		if partialSearchResult.Hits == nil || len(partialSearchResult.Hits) <= 0 {
+			continue
+		}
+
+		// is it necessary to check nq (len(partialSearchResult.Hits)) here?
+		nq := len(partialSearchResult.Hits)
+		partialHits := make([]*milvuspb.Hits, nq)
+		// one error slot per goroutine, so no write is shared
+		errs := make([]error, nq)
+
+		var wg sync.WaitGroup
+		for i := range partialSearchResult.Hits {
+			wg.Add(1)
+
+			go func(idx int) {
+				defer wg.Done()
+				partialHit := &milvuspb.Hits{}
+				errs[idx] = proto.Unmarshal(partialSearchResult.Hits[idx], partialHit)
+				if errs[idx] != nil {
+					log.Debug("proxynode", zap.Any("unmarshal search result error", errs[idx]))
+					return
+				}
+				partialHits[idx] = partialHit
+			}(i)
+		}
+
+		// wait before inspecting the errors; Done/Wait orders the writes
+		wg.Wait()
+
+		for _, err := range errs {
+			if err != nil {
+				return nil, err
+			}
+		}
+
+		hits = append(hits, partialHits)
+	}
+
+	return hits, nil
+}
+
+// TODO: add benchmark to compare with serial implementation
+func decodeSearchResultsParallelByCPU(searchResults []*internalpb.SearchResults) ([][]*milvuspb.Hits, error) {
+	log.Debug("decodeSearchResultsParallelByCPU", zap.Any("runtime.NumCPU", runtime.NumCPU()))
+	hits := make([][]*milvuspb.Hits, 0)
+	// is it necessary to parallelize this outer loop as well?
+	for _, partialSearchResult := range searchResults {
+		if partialSearchResult.Hits == nil || len(partialSearchResult.Hits) <= 0 {
+			continue
+		}
+
+		nq := len(partialSearchResult.Hits)
+		maxParallel := runtime.NumCPU()
+		nqPerBatch := (nq + maxParallel - 1) / maxParallel
+		partialHits := make([]*milvuspb.Hits, nq)
+		// one error slot per hit, so no write is shared between batches
+		errs := make([]error, nq)
+
+		var wg sync.WaitGroup
+		for i := 0; i < nq; i = i + nqPerBatch {
+			wg.Add(1)
+
+			go func(begin int) {
+				defer wg.Done()
+				end := getMin(nq, begin+nqPerBatch)
+				for idx := begin; idx < end; idx++ {
+					partialHit := &milvuspb.Hits{}
+					errs[idx] = proto.Unmarshal(partialSearchResult.Hits[idx], partialHit)
+					if errs[idx] != nil {
+						log.Debug("proxynode", zap.Any("unmarshal search result error", errs[idx]))
+						continue
+					}
+					partialHits[idx] = partialHit
+				}
+			}(i)
+		}
+
+		// wait before inspecting the errors, as above
+		wg.Wait()
+
+		for _, err := range errs {
+			if err != nil {
+				return nil, err
+			}
+		}
+
+		hits = append(hits, partialHits)
+	}
+
+	return hits, nil
+}
+
+func decodeSearchResults(searchResults []*internalpb.SearchResults) ([][]*milvuspb.Hits, error) {
+	return decodeSearchResultsParallel(searchResults)
+}
+
+func reduceSearchResultsSerial(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
+	ret := &milvuspb.SearchResults{
+		Status: &commonpb.Status{
+			ErrorCode: 0,
+		},
+		Hits: make([][]byte, nq),
+	}
+
+	const minFloat32 = -1 * float32(math.MaxFloat32)
+
+	for idx := 0; idx < nq; idx++ {
+		locs := make([]int, availableQueryNodeNum)
+		reducedHits := &milvuspb.Hits{
+			IDs:     make([]int64, 0),
+			RowData: make([][]byte, 0),
+			Scores:  make([]float32, 0),
+		}
+
+		for j := 0; j < topk; j++ {
+			valid := false
+			choice, maxDistance := 0, minFloat32
+			for q, loc := range locs { // q indexes the merge ways, one per query node
+				if loc >= len(hits[q][idx].IDs) {
+					continue
+				}
+				distance := hits[q][idx].Scores[loc]
+				if distance > maxDistance || (math.Abs(float64(distance-maxDistance)) < math.SmallestNonzeroFloat32 && choice != q) {
+					choice = q
+					maxDistance = distance
+					valid = true
+				}
+			}
+			if !valid {
+				break
+			}
+			choiceOffset := locs[choice]
+			// check whether the distance is still valid; `invalid` means an
+			// effectively infinite distance, and since the merge always picks
+			// the best remaining candidate, the rest are invalid as well
+			if hits[choice][idx].Scores[choiceOffset] <= minFloat32 {
+				break
+			}
+			reducedHits.IDs = append(reducedHits.IDs, hits[choice][idx].IDs[choiceOffset])
+			if hits[choice][idx].RowData != nil && len(hits[choice][idx].RowData) > 0 {
+				reducedHits.RowData = append(reducedHits.RowData, hits[choice][idx].RowData[choiceOffset])
+			}
+			reducedHits.Scores = append(reducedHits.Scores, hits[choice][idx].Scores[choiceOffset])
+			locs[choice]++
+		}
+
+		if metricType != "IP" {
+			for k := range reducedHits.Scores {
+				reducedHits.Scores[k] *= -1
+			}
+		}
+
+		reducedHitsBs, err := proto.Marshal(reducedHits)
+		if err != nil {
+			log.Debug("proxynode", zap.String("error", "marshal error"))
+		}
+
+		ret.Hits[idx] = reducedHitsBs
+	}
+
+	return ret
+}
+
+// TODO: add benchmark to compare with simple serial implementation
+func reduceSearchResultsParallel(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
+	ret := &milvuspb.SearchResults{
+		Status: &commonpb.Status{
+			ErrorCode: 0,
+		},
+		Hits: make([][]byte, nq),
+	}
+
+	const minFloat32 = -1 * float32(math.MaxFloat32)
+
+	var wg sync.WaitGroup
+	for i := 0; i < nq; i++ {
+		wg.Add(1)
+		go func(idx int) {
+			defer wg.Done()
+
+			locs := make([]int, availableQueryNodeNum)
+			reducedHits := 
&milvuspb.Hits{
+				IDs:     make([]int64, 0),
+				RowData: make([][]byte, 0),
+				Scores:  make([]float32, 0),
+			}
+
+			for j := 0; j < topk; j++ {
+				valid := false
+				choice, maxDistance := 0, minFloat32
+				for q, loc := range locs { // q indexes the merge ways, one per query node
+					if loc >= len(hits[q][idx].IDs) {
+						continue
+					}
+					distance := hits[q][idx].Scores[loc]
+					if distance > maxDistance || (math.Abs(float64(distance-maxDistance)) < math.SmallestNonzeroFloat32 && choice != q) {
+						choice = q
+						maxDistance = distance
+						valid = true
+					}
+				}
+				if !valid {
+					break
+				}
+				choiceOffset := locs[choice]
+				// check whether the distance is still valid; `invalid` means an
+				// effectively infinite distance, and since the merge always picks
+				// the best remaining candidate, the rest are invalid as well
+				if hits[choice][idx].Scores[choiceOffset] <= minFloat32 {
+					break
+				}
+				reducedHits.IDs = append(reducedHits.IDs, hits[choice][idx].IDs[choiceOffset])
+				if hits[choice][idx].RowData != nil && len(hits[choice][idx].RowData) > 0 {
+					reducedHits.RowData = append(reducedHits.RowData, hits[choice][idx].RowData[choiceOffset])
+				}
+				reducedHits.Scores = append(reducedHits.Scores, hits[choice][idx].Scores[choiceOffset])
+				locs[choice]++
+			}
+
+			if metricType != "IP" {
+				for k := range reducedHits.Scores {
+					reducedHits.Scores[k] *= -1
+				}
+			}
+
+			reducedHitsBs, err := proto.Marshal(reducedHits)
+			if err != nil {
+				log.Debug("proxynode", zap.String("error", "marshal error"))
+			}
+
+			ret.Hits[idx] = reducedHitsBs
+		}(i)
+	}
+
+	wg.Wait()
+
+	return ret
+}
+
+// TODO: add benchmark to compare with simple serial implementation
+func reduceSearchResultsParallelByCPU(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
+	maxParallel := runtime.NumCPU()
+	nqPerBatch := (nq + maxParallel - 1) / maxParallel
+
+	ret := &milvuspb.SearchResults{
+		Status: &commonpb.Status{
+			ErrorCode: 0,
+		},
+		Hits: make([][]byte, nq),
+	}
+
+	const minFloat32 = -1 * float32(math.MaxFloat32)
+
+	var wg sync.WaitGroup
+	for begin := 0; begin < nq; begin = begin + nqPerBatch {
+		j := begin
+
+		wg.Add(1)
+		go func(begin int) {
+			defer wg.Done()
+
+			end := getMin(nq, begin+nqPerBatch)
+
+			for idx := begin; idx < end; idx++ {
+				locs := make([]int, availableQueryNodeNum)
+				reducedHits := &milvuspb.Hits{
+					IDs:     make([]int64, 0),
+					RowData: make([][]byte, 0),
+					Scores:  make([]float32, 0),
+				}
+
+				for j := 0; j < topk; j++ {
+					valid := false
+					choice, maxDistance := 0, minFloat32
+					for q, loc := range locs { // q indexes the merge ways, one per query node
+						if loc >= len(hits[q][idx].IDs) {
+							continue
+						}
+						distance := hits[q][idx].Scores[loc]
+						if distance > maxDistance || (math.Abs(float64(distance-maxDistance)) < math.SmallestNonzeroFloat32 && choice != q) {
+							choice = q
+							maxDistance = distance
+							valid = true
+						}
+					}
+					if !valid {
+						break
+					}
+					choiceOffset := locs[choice]
+					// check whether the distance is still valid; `invalid` means an
+					// effectively infinite distance, and since the merge always picks
+					// the best remaining candidate, the rest are invalid as well
+					if hits[choice][idx].Scores[choiceOffset] <= minFloat32 {
+						break
+					}
+					reducedHits.IDs = append(reducedHits.IDs, hits[choice][idx].IDs[choiceOffset])
+					if hits[choice][idx].RowData != nil && len(hits[choice][idx].RowData) > 0 {
+						reducedHits.RowData = append(reducedHits.RowData, hits[choice][idx].RowData[choiceOffset])
+					}
+					reducedHits.Scores = append(reducedHits.Scores, hits[choice][idx].Scores[choiceOffset])
+					locs[choice]++
+				}
+
+				if metricType != "IP" {
+					for k := range reducedHits.Scores {
+						
reducedHits.Scores[k] *= -1 + } + } + + reducedHitsBs, err := proto.Marshal(reducedHits) + if err != nil { + log.Debug("proxynode", zap.String("error", "marshal error")) + } + + ret.Hits[idx] = reducedHitsBs + } + }(j) + + } + + wg.Wait() + + return ret +} + +func reduceSearchResults(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults { + return reduceSearchResultsParallel(hits, nq, availableQueryNodeNum, topk, metricType) +} + +func printSearchResult(partialSearchResult *internalpb.SearchResults) { + for i := 0; i < len(partialSearchResult.Hits); i++ { + testHits := milvuspb.Hits{} + err := proto.Unmarshal(partialSearchResult.Hits[i], &testHits) + if err != nil { + panic(err) + } + fmt.Println(testHits.IDs) + fmt.Println(testHits.Scores) + } +} + func (st *SearchTask) PostExecute(ctx context.Context) error { for { select { @@ -627,15 +1001,7 @@ func (st *SearchTask) PostExecute(ctx context.Context) error { if partialSearchResult.Status.ErrorCode == commonpb.ErrorCode_Success { filterSearchResult = append(filterSearchResult, partialSearchResult) // For debugging, please don't delete. - //for i := 0; i < len(partialSearchResult.Hits); i++ { - // testHits := milvuspb.Hits{} - // err := proto.Unmarshal(partialSearchResult.Hits[i], &testHits) - // if err != nil { - // panic(err) - // } - // fmt.Println(testHits.IDs) - // fmt.Println(testHits.Scores) - //} + // printSearchResult(partialSearchResult) } else { filterReason += partialSearchResult.Status.Reason + "\n" } @@ -652,26 +1018,15 @@ func (st *SearchTask) PostExecute(ctx context.Context) error { return errors.New(filterReason) } - hits := make([][]*milvuspb.Hits, 0) + availableQueryNodeNum = 0 for _, partialSearchResult := range filterSearchResult { if partialSearchResult.Hits == nil || len(partialSearchResult.Hits) <= 0 { filterReason += "nq is zero\n" continue } - partialHits := make([]*milvuspb.Hits, 0) - for _, bs := range partialSearchResult.Hits { - partialHit := &milvuspb.Hits{} - err := proto.Unmarshal(bs, partialHit) - if err != nil { - log.Debug("proxynode", zap.String("error", "unmarshal error")) - return err - } - partialHits = append(partialHits, partialHit) - } - hits = append(hits, partialHits) + availableQueryNodeNum++ } - availableQueryNodeNum = len(hits) if availableQueryNodeNum <= 0 { st.result = &milvuspb.SearchResults{ Status: &commonpb.Status{ @@ -682,6 +1037,11 @@ func (st *SearchTask) PostExecute(ctx context.Context) error { return nil } + hits, err := decodeSearchResults(filterSearchResult) + if err != nil { + return err + } + nq := len(hits[0]) if nq <= 0 { st.result = &milvuspb.SearchResults{ @@ -694,73 +1054,12 @@ func (st *SearchTask) PostExecute(ctx context.Context) error { } topk := 0 - getMax := func(a, b int) int { - if a > b { - return a - } - return b - } for _, hit := range hits { topk = getMax(topk, len(hit[0].IDs)) } - st.result = &milvuspb.SearchResults{ - Status: &commonpb.Status{ - ErrorCode: 0, - }, - Hits: make([][]byte, 0), - } - const minFloat32 = -1 * float32(math.MaxFloat32) - for i := 0; i < nq; i++ { - locs := make([]int, availableQueryNodeNum) - reducedHits := &milvuspb.Hits{ - IDs: make([]int64, 0), - RowData: make([][]byte, 0), - Scores: make([]float32, 0), - } + st.result = reduceSearchResults(hits, nq, availableQueryNodeNum, topk, searchResults[0].MetricType) - for j := 0; j < topk; j++ { - valid := false - choice, maxDistance := 0, minFloat32 - for q, loc := range locs { // query num, the number of ways to merge - if loc >= 
len(hits[q][i].IDs) { - continue - } - distance := hits[q][i].Scores[loc] - if distance > maxDistance || (math.Abs(float64(distance-maxDistance)) < math.SmallestNonzeroFloat32 && choice != q) { - choice = q - maxDistance = distance - valid = true - } - } - if !valid { - break - } - choiceOffset := locs[choice] - // check if distance is valid, `invalid` here means very very big, - // in this process, distance here is the smallest, so the rest of distance are all invalid - if hits[choice][i].Scores[choiceOffset] <= minFloat32 { - break - } - reducedHits.IDs = append(reducedHits.IDs, hits[choice][i].IDs[choiceOffset]) - if hits[choice][i].RowData != nil && len(hits[choice][i].RowData) > 0 { - reducedHits.RowData = append(reducedHits.RowData, hits[choice][i].RowData[choiceOffset]) - } - reducedHits.Scores = append(reducedHits.Scores, hits[choice][i].Scores[choiceOffset]) - locs[choice]++ - } - if searchResults[0].MetricType != "IP" { - for k := range reducedHits.Scores { - reducedHits.Scores[k] *= -1 - } - } - reducedHitsBs, err := proto.Marshal(reducedHits) - if err != nil { - log.Debug("proxynode", zap.String("error", "marshal error")) - return err - } - st.result.Hits = append(st.result.Hits, reducedHitsBs) - } return nil } } diff --git a/internal/proxynode/util.go b/internal/proxynode/util.go index 5b9ab718b3..91b31dca14 100644 --- a/internal/proxynode/util.go +++ b/internal/proxynode/util.go @@ -97,3 +97,17 @@ func SortedSliceEqual(s1 interface{}, s2 interface{}) bool { } return true } + +func getMax(a, b int) int { + if a > b { + return a + } + return b +} + +func getMin(a, b int) int { + if a < b { + return a + } + return b +}
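
The WaitGroup-plus-per-slot-error pattern used by `decodeSearchResultsParallel` can also be written with `golang.org/x/sync/errgroup`, which waits and surfaces the first error in one call. A minimal sketch, assuming the x/sync dependency is acceptable in this tree; the function name and the proto import paths are assumptions, not part of this patch:

```go
package proxynode

import (
	"github.com/golang/protobuf/proto"
	"golang.org/x/sync/errgroup"

	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/internal/proto/milvuspb"
)

// decodeSearchResultsErrgroup is a hypothetical alternative to
// decodeSearchResultsParallel: one goroutine per serialized hit,
// with errgroup collecting the first unmarshal failure.
func decodeSearchResultsErrgroup(searchResults []*internalpb.SearchResults) ([][]*milvuspb.Hits, error) {
	hits := make([][]*milvuspb.Hits, 0)
	for _, partial := range searchResults {
		if len(partial.Hits) == 0 {
			continue
		}
		partialHits := make([]*milvuspb.Hits, len(partial.Hits))
		var g errgroup.Group
		for i := range partial.Hits {
			idx := i // capture the loop variable by value
			g.Go(func() error {
				hit := &milvuspb.Hits{}
				if err := proto.Unmarshal(partial.Hits[idx], hit); err != nil {
					return err
				}
				partialHits[idx] = hit
				return nil
			})
		}
		// Wait blocks until every goroutine finishes and returns
		// the first non-nil error, replacing the manual error scan.
		if err := g.Wait(); err != nil {
			return nil, err
		}
		hits = append(hits, partialHits)
	}
	return hits, nil
}
```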
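The `*ParallelByCPU` variants cap goroutine fan-out at `runtime.NumCPU()` by ceil-dividing the queries: `nqPerBatch = (nq + maxParallel - 1) / maxParallel`, with `getMin` clamping the final batch. A self-contained illustration with fabricated numbers:

```go
package main

import "fmt"

func getMin(a, b int) int {
	if a < b {
		return a
	}
	return b
}

func main() {
	// nq = 10 queries on a hypothetical 4-CPU machine:
	// nqPerBatch = (10+4-1)/4 = 3, so the batches are
	// [0,3) [3,6) [6,9) [9,10).
	nq, maxParallel := 10, 4
	nqPerBatch := (nq + maxParallel - 1) / maxParallel
	for begin := 0; begin < nq; begin += nqPerBatch {
		end := getMin(nq, begin+nqPerBatch)
		fmt.Printf("batch [%d, %d)\n", begin, end)
	}
}
```

The last batch is shorter whenever `nq` is not a multiple of `nqPerBatch`, which is why each goroutine computes `end` with `getMin` rather than assuming a full batch.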
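Each reduce function performs an `availableQueryNodeNum`-way merge per query: every query node contributes a score-sorted hit list, `locs` tracks how far each list has been consumed, and each of the `topk` rounds takes the best remaining score across the lists. For metrics other than IP the scores travel negated (so bigger is better during the merge) and are flipped back by the `*= -1` pass. A stand-alone sketch of just the merge loop, with fabricated data and the tie/sentinel handling omitted:

```go
package main

import "fmt"

func main() {
	const topk = 4
	// Two query nodes' results for one query, already sorted by
	// descending score (fabricated for illustration).
	scores := [][]float32{{0.9, 0.5, 0.1}, {0.8, 0.7, 0.2}}
	ids := [][]int64{{11, 12, 13}, {21, 22, 23}}

	// locs[q] is the consumption offset into query node q's list.
	locs := make([]int, len(scores))
	merged := make([]int64, 0, topk)
	for j := 0; j < topk; j++ {
		choice, maxScore, valid := 0, float32(-1e38), false
		for q, loc := range locs {
			if loc >= len(scores[q]) {
				continue // this way is exhausted
			}
			if scores[q][loc] > maxScore {
				choice, maxScore, valid = q, scores[q][loc], true
			}
		}
		if !valid {
			break // every way is exhausted
		}
		merged = append(merged, ids[choice][locs[choice]])
		locs[choice]++
	}
	fmt.Println(merged) // [11 21 22 12]
}
```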
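The TODO comments above ask for benchmarks against the serial implementations. A sketch of what that could look like for the decode path; `genSearchResults` and its shape parameters are hypothetical, and the import paths are assumed from this tree's layout:

```go
package proxynode

import (
	"testing"

	"github.com/golang/protobuf/proto"

	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/internal/proto/milvuspb"
)

// genSearchResults fabricates nResults partial results, each carrying
// nq serialized milvuspb.Hits of topk zero-valued entries.
func genSearchResults(nResults, nq, topk int) []*internalpb.SearchResults {
	results := make([]*internalpb.SearchResults, 0, nResults)
	for i := 0; i < nResults; i++ {
		hits := make([][]byte, 0, nq)
		for j := 0; j < nq; j++ {
			bs, err := proto.Marshal(&milvuspb.Hits{
				IDs:    make([]int64, topk),
				Scores: make([]float32, topk),
			})
			if err != nil {
				panic(err)
			}
			hits = append(hits, bs)
		}
		results = append(results, &internalpb.SearchResults{Hits: hits})
	}
	return results
}

func benchmarkDecode(b *testing.B, decode func([]*internalpb.SearchResults) ([][]*milvuspb.Hits, error)) {
	results := genSearchResults(8, 1024, 64)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if _, err := decode(results); err != nil {
			b.Fatal(err)
		}
	}
}

func BenchmarkDecodeSearchResultsSerial(b *testing.B)   { benchmarkDecode(b, decodeSearchResultsSerial) }
func BenchmarkDecodeSearchResultsParallel(b *testing.B) { benchmarkDecode(b, decodeSearchResultsParallel) }
func BenchmarkDecodeSearchResultsParallelByCPU(b *testing.B) {
	benchmarkDecode(b, decodeSearchResultsParallelByCPU)
}
```

Runnable with `go test -bench=DecodeSearchResults ./internal/proxynode/`; an analogous trio could cover the `reduceSearchResults*` variants.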