From 4186e785ab4ea40f09931535e5ac9d64092e3d6c Mon Sep 17 00:00:00 2001 From: Cai Yudong Date: Thu, 28 Oct 2021 14:24:22 +0800 Subject: [PATCH] Optimize GetResultData in reduce_c.cpp (#10797) Signed-off-by: yudong.cai --- internal/core/src/segcore/reduce_c.cpp | 103 +++++++++++++------------ 1 file changed, 52 insertions(+), 51 deletions(-) diff --git a/internal/core/src/segcore/reduce_c.cpp b/internal/core/src/segcore/reduce_c.cpp index 614963e783..c84bf29697 100644 --- a/internal/core/src/segcore/reduce_c.cpp +++ b/internal/core/src/segcore/reduce_c.cpp @@ -58,69 +58,72 @@ DeleteMarshaledHits(CMarshaledHits c_marshaled_hits) { void GetResultData(std::vector>& search_records, std::vector& search_results, - int64_t query_idx, + int64_t nq, int64_t topk) { + AssertInfo(topk > 0, "topk must greater than 0"); auto num_segments = search_results.size(); AssertInfo(num_segments > 0, "num segment must greater than 0"); - std::vector result_pairs; - int64_t query_offset = query_idx * topk; - for (int j = 0; j < num_segments; ++j) { - auto search_result = search_results[j]; - AssertInfo(search_result != nullptr, "search result must not equal to nullptr"); - auto distance = search_result->result_distances_[query_offset]; - result_pairs.push_back(SearchResultPair(distance, search_result, query_offset, j)); - } - int64_t loc_offset = query_offset; - AssertInfo(topk > 0, "topk must greater than 0"); + + int64_t skip_dup_cnt = 0; + for (int64_t qi = 0; qi < nq; qi++) { + std::vector result_pairs; + int64_t base_offset = qi * topk; + for (int j = 0; j < num_segments; ++j) { + auto search_result = search_results[j]; + AssertInfo(search_result != nullptr, "search result must not equal to nullptr"); + auto distance = search_result->result_distances_[base_offset]; + result_pairs.push_back(SearchResultPair(distance, search_result, base_offset, j)); + } + int64_t curr_offset = base_offset; #if 0 - for (int i = 0; i < topk; ++i) { - result_pairs[0].reset_distance(); - std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>()); - auto& result_pair = result_pairs[0]; - auto index = result_pair.index_; - result_pair.search_result_->result_offsets_.push_back(loc_offset++); - search_records[index].push_back(result_pair.offset_++); - } -#else - int64_t skip_dup_cnt = 0; - float prev_dis = MAXFLOAT; - std::unordered_set prev_pk_set; - while (loc_offset - query_offset < topk) { - result_pairs[0].reset_distance(); - std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>()); - auto& result_pair = result_pairs[0]; - auto index = result_pair.index_; - int64_t curr_pk = result_pair.search_result_->primary_keys_[result_pair.offset_]; - float curr_dis = result_pair.search_result_->result_distances_[result_pair.offset_]; - // remove duplicates - if (curr_pk == INVALID_ID || std::abs(curr_dis - prev_dis) > 0.00001) { + for (int i = 0; i < topk; ++i) { + result_pairs[0].reset_distance(); + std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>()); + auto& result_pair = result_pairs[0]; + auto index = result_pair.index_; result_pair.search_result_->result_offsets_.push_back(loc_offset++); - search_records[index].push_back(result_pair.offset_); - prev_dis = curr_dis; - prev_pk_set.clear(); - prev_pk_set.insert(curr_pk); - } else { - // To handle this case: - // e1: [100, 0.99] - // e2: [101, 0.99] ==> not duplicated, should keep - // e3: [100, 0.99] ==> duplicated, should remove - if (prev_pk_set.count(curr_pk) == 0) { - result_pair.search_result_->result_offsets_.push_back(loc_offset++); + search_records[index].push_back(result_pair.offset_++); + } +#else + float prev_dis = MAXFLOAT; + std::unordered_set prev_pk_set; + while (curr_offset - base_offset < topk) { + result_pairs[0].reset_distance(); + std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>()); + auto& result_pair = result_pairs[0]; + auto index = result_pair.index_; + int64_t curr_pk = result_pair.search_result_->primary_keys_[result_pair.offset_]; + float curr_dis = result_pair.search_result_->result_distances_[result_pair.offset_]; + // remove duplicates + if (curr_pk == INVALID_ID || std::abs(curr_dis - prev_dis) > 0.00001) { + result_pair.search_result_->result_offsets_.push_back(curr_offset++); search_records[index].push_back(result_pair.offset_); - // prev_pk_set keeps all primary keys with same distance + prev_dis = curr_dis; + prev_pk_set.clear(); prev_pk_set.insert(curr_pk); } else { - // the entity with same distance and same primary key must be duplicated - skip_dup_cnt++; + // To handle this case: + // e1: [100, 0.99] + // e2: [101, 0.99] ==> not duplicated, should keep + // e3: [100, 0.99] ==> duplicated, should remove + if (prev_pk_set.count(curr_pk) == 0) { + result_pair.search_result_->result_offsets_.push_back(curr_offset++); + search_records[index].push_back(result_pair.offset_); + // prev_pk_set keeps all primary keys with same distance + prev_pk_set.insert(curr_pk); + } else { + // the entity with same distance and same primary key must be duplicated + skip_dup_cnt++; + } } + result_pair.offset_++; } - result_pair.offset_++; +#endif } if (skip_dup_cnt > 0) { LOG_SEGCORE_DEBUG_ << "skip duplicated search result, count = " << skip_dup_cnt; } -#endif } void @@ -172,9 +175,7 @@ ReduceSearchResultsAndFillData(CSearchPlan c_plan, CSearchResult* c_search_resul segment->FillPrimaryKeys(plan, *search_result); } - for (int i = 0; i < num_queries; ++i) { - GetResultData(search_records, search_results, i, topk); - } + GetResultData(search_records, search_results, num_queries, topk); ResetSearchResult(search_records, search_results); // fill in other entities