mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-01 19:39:21 +08:00
Parallelize the processing of SearchTask.PostExecute
Signed-off-by: dragondriver <jiquan.long@zilliz.com>
This commit is contained in:
parent
2842e929ba
commit
1b743e5c6c
@ -6,7 +6,9 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
@ -613,6 +615,378 @@ func (st *SearchTask) Execute(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func decodeSearchResultsSerial(searchResults []*internalpb.SearchResults) ([][]*milvuspb.Hits, error) {
|
||||
hits := make([][]*milvuspb.Hits, 0)
|
||||
for _, partialSearchResult := range searchResults {
|
||||
if partialSearchResult.Hits == nil || len(partialSearchResult.Hits) <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
partialHits := make([]*milvuspb.Hits, len(partialSearchResult.Hits))
|
||||
|
||||
var err error
|
||||
for i := range partialSearchResult.Hits {
|
||||
j := i
|
||||
|
||||
func(idx int) {
|
||||
partialHit := &milvuspb.Hits{}
|
||||
err = proto.Unmarshal(partialSearchResult.Hits[idx], partialHit)
|
||||
if err != nil {
|
||||
log.Debug("proxynode", zap.Any("unmarshal search result error", err))
|
||||
}
|
||||
partialHits[idx] = partialHit
|
||||
}(j)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hits = append(hits, partialHits)
|
||||
}
|
||||
|
||||
return hits, nil
|
||||
}
|
||||
|
||||
// TODO: add benchmark to compare with serial implementation
|
||||
func decodeSearchResultsParallel(searchResults []*internalpb.SearchResults) ([][]*milvuspb.Hits, error) {
|
||||
hits := make([][]*milvuspb.Hits, 0)
|
||||
// necessary to parallel this?
|
||||
for _, partialSearchResult := range searchResults {
|
||||
if partialSearchResult.Hits == nil || len(partialSearchResult.Hits) <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// necessary to check nq (len(partialSearchResult.Hits) here)?
|
||||
partialHits := make([]*milvuspb.Hits, len(partialSearchResult.Hits))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var err error
|
||||
for i := range partialSearchResult.Hits {
|
||||
j := i
|
||||
wg.Add(1)
|
||||
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
partialHit := &milvuspb.Hits{}
|
||||
err = proto.Unmarshal(partialSearchResult.Hits[idx], partialHit)
|
||||
if err != nil {
|
||||
log.Debug("proxynode", zap.Any("unmarshal search result error", err))
|
||||
}
|
||||
partialHits[idx] = partialHit
|
||||
}(j)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
hits = append(hits, partialHits)
|
||||
}
|
||||
|
||||
return hits, nil
|
||||
}
|
||||
|
||||
// TODO: add benchmark to compare with serial implementation
|
||||
func decodeSearchResultsParallelByCPU(searchResults []*internalpb.SearchResults) ([][]*milvuspb.Hits, error) {
|
||||
log.Debug("ProcessSearchResultParallel", zap.Any("runtime.NumCPU", runtime.NumCPU()))
|
||||
hits := make([][]*milvuspb.Hits, 0)
|
||||
// necessary to parallel this?
|
||||
for _, partialSearchResult := range searchResults {
|
||||
if partialSearchResult.Hits == nil || len(partialSearchResult.Hits) <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
nq := len(partialSearchResult.Hits)
|
||||
maxParallel := runtime.NumCPU()
|
||||
nqPerBatch := (nq + maxParallel - 1) / maxParallel
|
||||
partialHits := make([]*milvuspb.Hits, nq)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var err error
|
||||
for i := 0; i < nq; i = i + nqPerBatch {
|
||||
j := i
|
||||
wg.Add(1)
|
||||
|
||||
go func(begin int) {
|
||||
defer wg.Done()
|
||||
end := getMin(nq, begin+nqPerBatch)
|
||||
for idx := begin; idx < end; idx++ {
|
||||
partialHit := &milvuspb.Hits{}
|
||||
err = proto.Unmarshal(partialSearchResult.Hits[idx], partialHit)
|
||||
if err != nil {
|
||||
log.Debug("proxynode", zap.Any("unmarshal search result error", err))
|
||||
}
|
||||
partialHits[idx] = partialHit
|
||||
}
|
||||
}(j)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
hits = append(hits, partialHits)
|
||||
}
|
||||
|
||||
return hits, nil
|
||||
}
|
||||
|
||||
// decodeSearchResults decodes raw partial search results into milvuspb.Hits;
// it currently delegates to the goroutine-per-hit parallel implementation.
func decodeSearchResults(searchResults []*internalpb.SearchResults) ([][]*milvuspb.Hits, error) {
	return decodeSearchResultsParallel(searchResults)
}
|
||||
|
||||
func reduceSearchResultsSerial(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
|
||||
ret := &milvuspb.SearchResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: 0,
|
||||
},
|
||||
Hits: make([][]byte, nq),
|
||||
}
|
||||
|
||||
const minFloat32 = -1 * float32(math.MaxFloat32)
|
||||
|
||||
for i := 0; i < nq; i++ {
|
||||
j := i
|
||||
func(idx int) {
|
||||
locs := make([]int, availableQueryNodeNum)
|
||||
reducedHits := &milvuspb.Hits{
|
||||
IDs: make([]int64, 0),
|
||||
RowData: make([][]byte, 0),
|
||||
Scores: make([]float32, 0),
|
||||
}
|
||||
|
||||
for j := 0; j < topk; j++ {
|
||||
valid := false
|
||||
choice, maxDistance := 0, minFloat32
|
||||
for q, loc := range locs { // query num, the number of ways to merge
|
||||
if loc >= len(hits[q][idx].IDs) {
|
||||
continue
|
||||
}
|
||||
distance := hits[q][idx].Scores[loc]
|
||||
if distance > maxDistance || (math.Abs(float64(distance-maxDistance)) < math.SmallestNonzeroFloat32 && choice != q) {
|
||||
choice = q
|
||||
maxDistance = distance
|
||||
valid = true
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
break
|
||||
}
|
||||
choiceOffset := locs[choice]
|
||||
// check if distance is valid, `invalid` here means very very big,
|
||||
// in this process, distance here is the smallest, so the rest of distance are all invalid
|
||||
if hits[choice][idx].Scores[choiceOffset] <= minFloat32 {
|
||||
break
|
||||
}
|
||||
reducedHits.IDs = append(reducedHits.IDs, hits[choice][idx].IDs[choiceOffset])
|
||||
if hits[choice][idx].RowData != nil && len(hits[choice][idx].RowData) > 0 {
|
||||
reducedHits.RowData = append(reducedHits.RowData, hits[choice][idx].RowData[choiceOffset])
|
||||
}
|
||||
reducedHits.Scores = append(reducedHits.Scores, hits[choice][idx].Scores[choiceOffset])
|
||||
locs[choice]++
|
||||
}
|
||||
|
||||
if metricType != "IP" {
|
||||
for k := range reducedHits.Scores {
|
||||
reducedHits.Scores[k] *= -1
|
||||
}
|
||||
}
|
||||
|
||||
reducedHitsBs, err := proto.Marshal(reducedHits)
|
||||
if err != nil {
|
||||
log.Debug("proxynode", zap.String("error", "marshal error"))
|
||||
}
|
||||
|
||||
ret.Hits[idx] = reducedHitsBs
|
||||
}(j)
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// TODO: add benchmark to compare with simple serial implementation

// reduceSearchResultsParallel merges per-query-node hit lists into a single
// top-k result per query, running the per-query merges concurrently (one
// goroutine per query). Lock-free by construction: goroutine idx writes only
// ret.Hits[idx] and its own locals.
func reduceSearchResultsParallel(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
	ret := &milvuspb.SearchResults{
		Status: &commonpb.Status{
			ErrorCode: 0,
		},
		Hits: make([][]byte, nq),
	}

	// Sentinel: scores at or below this are treated as invalid padding.
	const minFloat32 = -1 * float32(math.MaxFloat32)

	var wg sync.WaitGroup
	for i := 0; i < nq; i++ {
		// Shadow the loop variable for the goroutine (pre-Go-1.22 capture semantics).
		j := i
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()

			// locs[q] is the read cursor into query node q's hit list for this query.
			locs := make([]int, availableQueryNodeNum)
			reducedHits := &milvuspb.Hits{
				IDs:     make([]int64, 0),
				RowData: make([][]byte, 0),
				Scores:  make([]float32, 0),
			}

			for j := 0; j < topk; j++ {
				valid := false
				choice, maxDistance := 0, minFloat32
				for q, loc := range locs { // query num, the number of ways to merge
					if loc >= len(hits[q][idx].IDs) {
						continue
					}
					distance := hits[q][idx].Scores[loc]
					// Prefer the larger score; the epsilon clause breaks exact
					// ties between different source lists deterministically.
					if distance > maxDistance || (math.Abs(float64(distance-maxDistance)) < math.SmallestNonzeroFloat32 && choice != q) {
						choice = q
						maxDistance = distance
						valid = true
					}
				}
				if !valid {
					break // every source list is exhausted
				}
				choiceOffset := locs[choice]
				// check if distance is valid, `invalid` here means very very big,
				// in this process, distance here is the smallest, so the rest of distance are all invalid
				if hits[choice][idx].Scores[choiceOffset] <= minFloat32 {
					break
				}
				reducedHits.IDs = append(reducedHits.IDs, hits[choice][idx].IDs[choiceOffset])
				if hits[choice][idx].RowData != nil && len(hits[choice][idx].RowData) > 0 {
					reducedHits.RowData = append(reducedHits.RowData, hits[choice][idx].RowData[choiceOffset])
				}
				reducedHits.Scores = append(reducedHits.Scores, hits[choice][idx].Scores[choiceOffset])
				locs[choice]++
			}

			// Non-IP metrics were negated upstream so "larger is better" held
			// during the merge; restore the original sign here.
			if metricType != "IP" {
				for k := range reducedHits.Scores {
					reducedHits.Scores[k] *= -1
				}
			}

			reducedHitsBs, err := proto.Marshal(reducedHits)
			if err != nil {
				// NOTE(review): marshal failure is only logged; ret.Hits[idx] stays nil.
				log.Debug("proxynode", zap.String("error", "marshal error"))
			}

			ret.Hits[idx] = reducedHitsBs
		}(j)

	}

	wg.Wait()

	return ret
}
|
||||
|
||||
// TODO: add benchmark to compare with simple serial implementation

// reduceSearchResultsParallelByCPU merges per-query-node hit lists into a
// single top-k result per query, bounding parallelism by runtime.NumCPU():
// the nq queries are split into ceil(nq/NumCPU) sized batches, one goroutine
// per batch. Lock-free by construction: each goroutine writes only its own
// disjoint range of ret.Hits.
func reduceSearchResultsParallelByCPU(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
	maxParallel := runtime.NumCPU()
	// ceil(nq / maxParallel); guarantees at most maxParallel batches.
	nqPerBatch := (nq + maxParallel - 1) / maxParallel

	ret := &milvuspb.SearchResults{
		Status: &commonpb.Status{
			ErrorCode: 0,
		},
		Hits: make([][]byte, nq),
	}

	// Sentinel: scores at or below this are treated as invalid padding.
	const minFloat32 = -1 * float32(math.MaxFloat32)

	var wg sync.WaitGroup
	for begin := 0; begin < nq; begin = begin + nqPerBatch {
		// Shadow the loop variable for the goroutine (pre-Go-1.22 capture semantics).
		j := begin

		wg.Add(1)
		go func(begin int) {
			defer wg.Done()

			end := getMin(nq, begin+nqPerBatch)

			// Merge each query in this batch sequentially.
			for idx := begin; idx < end; idx++ {
				// locs[q] is the read cursor into query node q's hit list for this query.
				locs := make([]int, availableQueryNodeNum)
				reducedHits := &milvuspb.Hits{
					IDs:     make([]int64, 0),
					RowData: make([][]byte, 0),
					Scores:  make([]float32, 0),
				}

				for j := 0; j < topk; j++ {
					valid := false
					choice, maxDistance := 0, minFloat32
					for q, loc := range locs { // query num, the number of ways to merge
						if loc >= len(hits[q][idx].IDs) {
							continue
						}
						distance := hits[q][idx].Scores[loc]
						// Prefer the larger score; the epsilon clause breaks exact
						// ties between different source lists deterministically.
						if distance > maxDistance || (math.Abs(float64(distance-maxDistance)) < math.SmallestNonzeroFloat32 && choice != q) {
							choice = q
							maxDistance = distance
							valid = true
						}
					}
					if !valid {
						break // every source list is exhausted
					}
					choiceOffset := locs[choice]
					// check if distance is valid, `invalid` here means very very big,
					// in this process, distance here is the smallest, so the rest of distance are all invalid
					if hits[choice][idx].Scores[choiceOffset] <= minFloat32 {
						break
					}
					reducedHits.IDs = append(reducedHits.IDs, hits[choice][idx].IDs[choiceOffset])
					if hits[choice][idx].RowData != nil && len(hits[choice][idx].RowData) > 0 {
						reducedHits.RowData = append(reducedHits.RowData, hits[choice][idx].RowData[choiceOffset])
					}
					reducedHits.Scores = append(reducedHits.Scores, hits[choice][idx].Scores[choiceOffset])
					locs[choice]++
				}

				// Non-IP metrics were negated upstream so "larger is better"
				// held during the merge; restore the original sign here.
				if metricType != "IP" {
					for k := range reducedHits.Scores {
						reducedHits.Scores[k] *= -1
					}
				}

				reducedHitsBs, err := proto.Marshal(reducedHits)
				if err != nil {
					// NOTE(review): marshal failure is only logged; ret.Hits[idx] stays nil.
					log.Debug("proxynode", zap.String("error", "marshal error"))
				}

				ret.Hits[idx] = reducedHitsBs
			}
		}(j)

	}

	wg.Wait()

	return ret
}
|
||||
|
||||
// reduceSearchResults merges per-query-node hit lists into the final top-k
// response; it currently delegates to the goroutine-per-query parallel
// implementation.
func reduceSearchResults(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
	return reduceSearchResultsParallel(hits, nq, availableQueryNodeNum, topk, metricType)
}
|
||||
|
||||
func printSearchResult(partialSearchResult *internalpb.SearchResults) {
|
||||
for i := 0; i < len(partialSearchResult.Hits); i++ {
|
||||
testHits := milvuspb.Hits{}
|
||||
err := proto.Unmarshal(partialSearchResult.Hits[i], &testHits)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Println(testHits.IDs)
|
||||
fmt.Println(testHits.Scores)
|
||||
}
|
||||
}
|
||||
|
||||
func (st *SearchTask) PostExecute(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
@ -627,15 +1001,7 @@ func (st *SearchTask) PostExecute(ctx context.Context) error {
|
||||
if partialSearchResult.Status.ErrorCode == commonpb.ErrorCode_Success {
|
||||
filterSearchResult = append(filterSearchResult, partialSearchResult)
|
||||
// For debugging, please don't delete.
|
||||
//for i := 0; i < len(partialSearchResult.Hits); i++ {
|
||||
// testHits := milvuspb.Hits{}
|
||||
// err := proto.Unmarshal(partialSearchResult.Hits[i], &testHits)
|
||||
// if err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
// fmt.Println(testHits.IDs)
|
||||
// fmt.Println(testHits.Scores)
|
||||
//}
|
||||
// printSearchResult(partialSearchResult)
|
||||
} else {
|
||||
filterReason += partialSearchResult.Status.Reason + "\n"
|
||||
}
|
||||
@ -652,26 +1018,15 @@ func (st *SearchTask) PostExecute(ctx context.Context) error {
|
||||
return errors.New(filterReason)
|
||||
}
|
||||
|
||||
hits := make([][]*milvuspb.Hits, 0)
|
||||
availableQueryNodeNum = 0
|
||||
for _, partialSearchResult := range filterSearchResult {
|
||||
if partialSearchResult.Hits == nil || len(partialSearchResult.Hits) <= 0 {
|
||||
filterReason += "nq is zero\n"
|
||||
continue
|
||||
}
|
||||
partialHits := make([]*milvuspb.Hits, 0)
|
||||
for _, bs := range partialSearchResult.Hits {
|
||||
partialHit := &milvuspb.Hits{}
|
||||
err := proto.Unmarshal(bs, partialHit)
|
||||
if err != nil {
|
||||
log.Debug("proxynode", zap.String("error", "unmarshal error"))
|
||||
return err
|
||||
}
|
||||
partialHits = append(partialHits, partialHit)
|
||||
}
|
||||
hits = append(hits, partialHits)
|
||||
availableQueryNodeNum++
|
||||
}
|
||||
|
||||
availableQueryNodeNum = len(hits)
|
||||
if availableQueryNodeNum <= 0 {
|
||||
st.result = &milvuspb.SearchResults{
|
||||
Status: &commonpb.Status{
|
||||
@ -682,6 +1037,11 @@ func (st *SearchTask) PostExecute(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
hits, err := decodeSearchResults(filterSearchResult)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nq := len(hits[0])
|
||||
if nq <= 0 {
|
||||
st.result = &milvuspb.SearchResults{
|
||||
@ -694,73 +1054,12 @@ func (st *SearchTask) PostExecute(ctx context.Context) error {
|
||||
}
|
||||
|
||||
topk := 0
|
||||
getMax := func(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
for _, hit := range hits {
|
||||
topk = getMax(topk, len(hit[0].IDs))
|
||||
}
|
||||
st.result = &milvuspb.SearchResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: 0,
|
||||
},
|
||||
Hits: make([][]byte, 0),
|
||||
}
|
||||
|
||||
const minFloat32 = -1 * float32(math.MaxFloat32)
|
||||
for i := 0; i < nq; i++ {
|
||||
locs := make([]int, availableQueryNodeNum)
|
||||
reducedHits := &milvuspb.Hits{
|
||||
IDs: make([]int64, 0),
|
||||
RowData: make([][]byte, 0),
|
||||
Scores: make([]float32, 0),
|
||||
}
|
||||
st.result = reduceSearchResults(hits, nq, availableQueryNodeNum, topk, searchResults[0].MetricType)
|
||||
|
||||
for j := 0; j < topk; j++ {
|
||||
valid := false
|
||||
choice, maxDistance := 0, minFloat32
|
||||
for q, loc := range locs { // query num, the number of ways to merge
|
||||
if loc >= len(hits[q][i].IDs) {
|
||||
continue
|
||||
}
|
||||
distance := hits[q][i].Scores[loc]
|
||||
if distance > maxDistance || (math.Abs(float64(distance-maxDistance)) < math.SmallestNonzeroFloat32 && choice != q) {
|
||||
choice = q
|
||||
maxDistance = distance
|
||||
valid = true
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
break
|
||||
}
|
||||
choiceOffset := locs[choice]
|
||||
// check if distance is valid, `invalid` here means very very big,
|
||||
// in this process, distance here is the smallest, so the rest of distance are all invalid
|
||||
if hits[choice][i].Scores[choiceOffset] <= minFloat32 {
|
||||
break
|
||||
}
|
||||
reducedHits.IDs = append(reducedHits.IDs, hits[choice][i].IDs[choiceOffset])
|
||||
if hits[choice][i].RowData != nil && len(hits[choice][i].RowData) > 0 {
|
||||
reducedHits.RowData = append(reducedHits.RowData, hits[choice][i].RowData[choiceOffset])
|
||||
}
|
||||
reducedHits.Scores = append(reducedHits.Scores, hits[choice][i].Scores[choiceOffset])
|
||||
locs[choice]++
|
||||
}
|
||||
if searchResults[0].MetricType != "IP" {
|
||||
for k := range reducedHits.Scores {
|
||||
reducedHits.Scores[k] *= -1
|
||||
}
|
||||
}
|
||||
reducedHitsBs, err := proto.Marshal(reducedHits)
|
||||
if err != nil {
|
||||
log.Debug("proxynode", zap.String("error", "marshal error"))
|
||||
return err
|
||||
}
|
||||
st.result.Hits = append(st.result.Hits, reducedHitsBs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
@ -97,3 +97,17 @@ func SortedSliceEqual(s1 interface{}, s2 interface{}) bool {
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// getMax returns the larger of a and b.
func getMax(a, b int) int {
	if b > a {
		return b
	}
	return a
}
|
||||
|
||||
// getMin returns the smaller of a and b.
func getMin(a, b int) int {
	if b < a {
		return b
	}
	return a
}
|
||||
|
Loading…
Reference in New Issue
Block a user