Optimize GetResultData in reduce_c.cpp (#10797)

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
This commit is contained in:
Cai Yudong 2021-10-28 14:24:22 +08:00 committed by GitHub
parent 9a0d1d1e74
commit 4186e785ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -58,69 +58,72 @@ DeleteMarshaledHits(CMarshaledHits c_marshaled_hits) {
void void
GetResultData(std::vector<std::vector<int64_t>>& search_records, GetResultData(std::vector<std::vector<int64_t>>& search_records,
std::vector<SearchResult*>& search_results, std::vector<SearchResult*>& search_results,
int64_t query_idx, int64_t nq,
int64_t topk) { int64_t topk) {
AssertInfo(topk > 0, "topk must greater than 0");
auto num_segments = search_results.size(); auto num_segments = search_results.size();
AssertInfo(num_segments > 0, "num segment must greater than 0"); AssertInfo(num_segments > 0, "num segment must greater than 0");
std::vector<SearchResultPair> result_pairs;
int64_t query_offset = query_idx * topk; int64_t skip_dup_cnt = 0;
for (int j = 0; j < num_segments; ++j) { for (int64_t qi = 0; qi < nq; qi++) {
auto search_result = search_results[j]; std::vector<SearchResultPair> result_pairs;
AssertInfo(search_result != nullptr, "search result must not equal to nullptr"); int64_t base_offset = qi * topk;
auto distance = search_result->result_distances_[query_offset]; for (int j = 0; j < num_segments; ++j) {
result_pairs.push_back(SearchResultPair(distance, search_result, query_offset, j)); auto search_result = search_results[j];
} AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
int64_t loc_offset = query_offset; auto distance = search_result->result_distances_[base_offset];
AssertInfo(topk > 0, "topk must greater than 0"); result_pairs.push_back(SearchResultPair(distance, search_result, base_offset, j));
}
int64_t curr_offset = base_offset;
#if 0 #if 0
for (int i = 0; i < topk; ++i) { for (int i = 0; i < topk; ++i) {
result_pairs[0].reset_distance(); result_pairs[0].reset_distance();
std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>()); std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>());
auto& result_pair = result_pairs[0]; auto& result_pair = result_pairs[0];
auto index = result_pair.index_; auto index = result_pair.index_;
result_pair.search_result_->result_offsets_.push_back(loc_offset++);
search_records[index].push_back(result_pair.offset_++);
}
#else
int64_t skip_dup_cnt = 0;
float prev_dis = MAXFLOAT;
std::unordered_set<int64_t> prev_pk_set;
while (loc_offset - query_offset < topk) {
result_pairs[0].reset_distance();
std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>());
auto& result_pair = result_pairs[0];
auto index = result_pair.index_;
int64_t curr_pk = result_pair.search_result_->primary_keys_[result_pair.offset_];
float curr_dis = result_pair.search_result_->result_distances_[result_pair.offset_];
// remove duplicates
if (curr_pk == INVALID_ID || std::abs(curr_dis - prev_dis) > 0.00001) {
result_pair.search_result_->result_offsets_.push_back(loc_offset++); result_pair.search_result_->result_offsets_.push_back(loc_offset++);
search_records[index].push_back(result_pair.offset_); search_records[index].push_back(result_pair.offset_++);
prev_dis = curr_dis; }
prev_pk_set.clear(); #else
prev_pk_set.insert(curr_pk); float prev_dis = MAXFLOAT;
} else { std::unordered_set<int64_t> prev_pk_set;
// To handle this case: while (curr_offset - base_offset < topk) {
// e1: [100, 0.99] result_pairs[0].reset_distance();
// e2: [101, 0.99] ==> not duplicated, should keep std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>());
// e3: [100, 0.99] ==> duplicated, should remove auto& result_pair = result_pairs[0];
if (prev_pk_set.count(curr_pk) == 0) { auto index = result_pair.index_;
result_pair.search_result_->result_offsets_.push_back(loc_offset++); int64_t curr_pk = result_pair.search_result_->primary_keys_[result_pair.offset_];
float curr_dis = result_pair.search_result_->result_distances_[result_pair.offset_];
// remove duplicates
if (curr_pk == INVALID_ID || std::abs(curr_dis - prev_dis) > 0.00001) {
result_pair.search_result_->result_offsets_.push_back(curr_offset++);
search_records[index].push_back(result_pair.offset_); search_records[index].push_back(result_pair.offset_);
// prev_pk_set keeps all primary keys with same distance prev_dis = curr_dis;
prev_pk_set.clear();
prev_pk_set.insert(curr_pk); prev_pk_set.insert(curr_pk);
} else { } else {
// the entity with same distance and same primary key must be duplicated // To handle this case:
skip_dup_cnt++; // e1: [100, 0.99]
// e2: [101, 0.99] ==> not duplicated, should keep
// e3: [100, 0.99] ==> duplicated, should remove
if (prev_pk_set.count(curr_pk) == 0) {
result_pair.search_result_->result_offsets_.push_back(curr_offset++);
search_records[index].push_back(result_pair.offset_);
// prev_pk_set keeps all primary keys with same distance
prev_pk_set.insert(curr_pk);
} else {
// the entity with same distance and same primary key must be duplicated
skip_dup_cnt++;
}
} }
result_pair.offset_++;
} }
result_pair.offset_++; #endif
} }
if (skip_dup_cnt > 0) { if (skip_dup_cnt > 0) {
LOG_SEGCORE_DEBUG_ << "skip duplicated search result, count = " << skip_dup_cnt; LOG_SEGCORE_DEBUG_ << "skip duplicated search result, count = " << skip_dup_cnt;
} }
#endif
} }
void void
@ -172,9 +175,7 @@ ReduceSearchResultsAndFillData(CSearchPlan c_plan, CSearchResult* c_search_resul
segment->FillPrimaryKeys(plan, *search_result); segment->FillPrimaryKeys(plan, *search_result);
} }
for (int i = 0; i < num_queries; ++i) { GetResultData(search_records, search_results, num_queries, topk);
GetResultData(search_records, search_results, i, topk);
}
ResetSearchResult(search_records, search_results); ResetSearchResult(search_records, search_results);
// fill in other entities // fill in other entities