mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-01 03:18:29 +08:00
Optimize GetResultData in reduce_c.cpp (#10797)
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
This commit is contained in:
parent
9a0d1d1e74
commit
4186e785ab
@ -58,69 +58,72 @@ DeleteMarshaledHits(CMarshaledHits c_marshaled_hits) {
|
|||||||
void
|
void
|
||||||
GetResultData(std::vector<std::vector<int64_t>>& search_records,
|
GetResultData(std::vector<std::vector<int64_t>>& search_records,
|
||||||
std::vector<SearchResult*>& search_results,
|
std::vector<SearchResult*>& search_results,
|
||||||
int64_t query_idx,
|
int64_t nq,
|
||||||
int64_t topk) {
|
int64_t topk) {
|
||||||
|
AssertInfo(topk > 0, "topk must greater than 0");
|
||||||
auto num_segments = search_results.size();
|
auto num_segments = search_results.size();
|
||||||
AssertInfo(num_segments > 0, "num segment must greater than 0");
|
AssertInfo(num_segments > 0, "num segment must greater than 0");
|
||||||
std::vector<SearchResultPair> result_pairs;
|
|
||||||
int64_t query_offset = query_idx * topk;
|
int64_t skip_dup_cnt = 0;
|
||||||
for (int j = 0; j < num_segments; ++j) {
|
for (int64_t qi = 0; qi < nq; qi++) {
|
||||||
auto search_result = search_results[j];
|
std::vector<SearchResultPair> result_pairs;
|
||||||
AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
|
int64_t base_offset = qi * topk;
|
||||||
auto distance = search_result->result_distances_[query_offset];
|
for (int j = 0; j < num_segments; ++j) {
|
||||||
result_pairs.push_back(SearchResultPair(distance, search_result, query_offset, j));
|
auto search_result = search_results[j];
|
||||||
}
|
AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
|
||||||
int64_t loc_offset = query_offset;
|
auto distance = search_result->result_distances_[base_offset];
|
||||||
AssertInfo(topk > 0, "topk must greater than 0");
|
result_pairs.push_back(SearchResultPair(distance, search_result, base_offset, j));
|
||||||
|
}
|
||||||
|
int64_t curr_offset = base_offset;
|
||||||
|
|
||||||
#if 0
|
#if 0
|
||||||
for (int i = 0; i < topk; ++i) {
|
for (int i = 0; i < topk; ++i) {
|
||||||
result_pairs[0].reset_distance();
|
result_pairs[0].reset_distance();
|
||||||
std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>());
|
std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>());
|
||||||
auto& result_pair = result_pairs[0];
|
auto& result_pair = result_pairs[0];
|
||||||
auto index = result_pair.index_;
|
auto index = result_pair.index_;
|
||||||
result_pair.search_result_->result_offsets_.push_back(loc_offset++);
|
|
||||||
search_records[index].push_back(result_pair.offset_++);
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
int64_t skip_dup_cnt = 0;
|
|
||||||
float prev_dis = MAXFLOAT;
|
|
||||||
std::unordered_set<int64_t> prev_pk_set;
|
|
||||||
while (loc_offset - query_offset < topk) {
|
|
||||||
result_pairs[0].reset_distance();
|
|
||||||
std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>());
|
|
||||||
auto& result_pair = result_pairs[0];
|
|
||||||
auto index = result_pair.index_;
|
|
||||||
int64_t curr_pk = result_pair.search_result_->primary_keys_[result_pair.offset_];
|
|
||||||
float curr_dis = result_pair.search_result_->result_distances_[result_pair.offset_];
|
|
||||||
// remove duplicates
|
|
||||||
if (curr_pk == INVALID_ID || std::abs(curr_dis - prev_dis) > 0.00001) {
|
|
||||||
result_pair.search_result_->result_offsets_.push_back(loc_offset++);
|
result_pair.search_result_->result_offsets_.push_back(loc_offset++);
|
||||||
search_records[index].push_back(result_pair.offset_);
|
search_records[index].push_back(result_pair.offset_++);
|
||||||
prev_dis = curr_dis;
|
}
|
||||||
prev_pk_set.clear();
|
#else
|
||||||
prev_pk_set.insert(curr_pk);
|
float prev_dis = MAXFLOAT;
|
||||||
} else {
|
std::unordered_set<int64_t> prev_pk_set;
|
||||||
// To handle this case:
|
while (curr_offset - base_offset < topk) {
|
||||||
// e1: [100, 0.99]
|
result_pairs[0].reset_distance();
|
||||||
// e2: [101, 0.99] ==> not duplicated, should keep
|
std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>());
|
||||||
// e3: [100, 0.99] ==> duplicated, should remove
|
auto& result_pair = result_pairs[0];
|
||||||
if (prev_pk_set.count(curr_pk) == 0) {
|
auto index = result_pair.index_;
|
||||||
result_pair.search_result_->result_offsets_.push_back(loc_offset++);
|
int64_t curr_pk = result_pair.search_result_->primary_keys_[result_pair.offset_];
|
||||||
|
float curr_dis = result_pair.search_result_->result_distances_[result_pair.offset_];
|
||||||
|
// remove duplicates
|
||||||
|
if (curr_pk == INVALID_ID || std::abs(curr_dis - prev_dis) > 0.00001) {
|
||||||
|
result_pair.search_result_->result_offsets_.push_back(curr_offset++);
|
||||||
search_records[index].push_back(result_pair.offset_);
|
search_records[index].push_back(result_pair.offset_);
|
||||||
// prev_pk_set keeps all primary keys with same distance
|
prev_dis = curr_dis;
|
||||||
|
prev_pk_set.clear();
|
||||||
prev_pk_set.insert(curr_pk);
|
prev_pk_set.insert(curr_pk);
|
||||||
} else {
|
} else {
|
||||||
// the entity with same distance and same primary key must be duplicated
|
// To handle this case:
|
||||||
skip_dup_cnt++;
|
// e1: [100, 0.99]
|
||||||
|
// e2: [101, 0.99] ==> not duplicated, should keep
|
||||||
|
// e3: [100, 0.99] ==> duplicated, should remove
|
||||||
|
if (prev_pk_set.count(curr_pk) == 0) {
|
||||||
|
result_pair.search_result_->result_offsets_.push_back(curr_offset++);
|
||||||
|
search_records[index].push_back(result_pair.offset_);
|
||||||
|
// prev_pk_set keeps all primary keys with same distance
|
||||||
|
prev_pk_set.insert(curr_pk);
|
||||||
|
} else {
|
||||||
|
// the entity with same distance and same primary key must be duplicated
|
||||||
|
skip_dup_cnt++;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
result_pair.offset_++;
|
||||||
}
|
}
|
||||||
result_pair.offset_++;
|
#endif
|
||||||
}
|
}
|
||||||
if (skip_dup_cnt > 0) {
|
if (skip_dup_cnt > 0) {
|
||||||
LOG_SEGCORE_DEBUG_ << "skip duplicated search result, count = " << skip_dup_cnt;
|
LOG_SEGCORE_DEBUG_ << "skip duplicated search result, count = " << skip_dup_cnt;
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void
|
void
|
||||||
@ -172,9 +175,7 @@ ReduceSearchResultsAndFillData(CSearchPlan c_plan, CSearchResult* c_search_resul
|
|||||||
segment->FillPrimaryKeys(plan, *search_result);
|
segment->FillPrimaryKeys(plan, *search_result);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < num_queries; ++i) {
|
GetResultData(search_records, search_results, num_queries, topk);
|
||||||
GetResultData(search_records, search_results, i, topk);
|
|
||||||
}
|
|
||||||
ResetSearchResult(search_records, search_results);
|
ResetSearchResult(search_records, search_results);
|
||||||
|
|
||||||
// fill in other entities
|
// fill in other entities
|
||||||
|
Loading…
Reference in New Issue
Block a user