Fix reduce panic (#11325)

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
This commit is contained in:
Cai Yudong 2021-11-05 18:17:00 +08:00 committed by GitHub
parent 6024346195
commit db2a0a3bd3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 89 additions and 125 deletions

View File

@ -13,5 +13,6 @@
#include <stdint.h>
// Sentinel values used by the search/reduce code. Each one is defined exactly
// once; a second definition of the same constant in this header is a
// redefinition error.
// Marks a primary key slot that holds no valid entity.
const int64_t INVALID_ID = -1;
// Marks a reduced-result slot that has no valid offset into a search result.
const int64_t INVALID_OFFSET = -1;
// Marks a missing internal segment offset.
const int64_t INVALID_SEG_OFFSET = -1;

View File

@ -9,32 +9,58 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <cmath> // std::isnan
#include <common/Types.h>
#include <cmath>
#include "common/Consts.h"
#include "common/Types.h"
#include "segcore/Reduce.h"
using milvus::SearchResult;
// One candidate entry taken from a single SearchResult, used while merging
// several results. It caches the primary key and distance found at `offset_`
// and can step forward through that result via reset().
//
// NOTE(review): two revisions of this struct appear interleaved in this view
// (`offset_` is declared twice, there are two constructors and two operator>
// bodies, and `reset_distance()` runs straight into `reset()`). Reconcile the
// two revisions before treating this text as compilable code.
struct SearchResultPair {
// Primary key at the current offset; set to INVALID_ID by reset() once the range is exhausted.
int64_t primary_key_;
// Distance at the current offset; set to MAXFLOAT by reset() once the range is exhausted.
float distance_;
// The search result whose primary_keys_ / result_distances_ this pair reads from.
milvus::SearchResult* search_result_;
int64_t offset_;
// Caller-supplied index stored alongside the pair (presumably the position of
// search_result_ in the caller's list of results; confirm against the caller).
int64_t index_;
// Current read position inside search_result_.
int64_t offset_;
int64_t offset_rb_; // right bound: reset() only reads entries while offset_ < offset_rb_
SearchResultPair(float distance, milvus::SearchResult* search_result, int64_t offset, int64_t index)
: distance_(distance), search_result_(search_result), offset_(offset), index_(index) {
// Six-argument form: `lb` becomes the starting offset_ and `rb` becomes offset_rb_.
SearchResultPair(int64_t primary_key, float distance, SearchResult* result, int64_t index, int64_t lb, int64_t rb)
: primary_key_(primary_key),
distance_(distance),
search_result_(result),
index_(index),
offset_(lb),
offset_rb_(rb) {
}
// Distance ordering in which a NaN right-hand side always yields true, and a
// NaN left-hand side yields true only when the right-hand side is NaN as well.
bool
operator<(const SearchResultPair& pair) const {
return std::isnan(pair.distance_) || (!std::isnan(distance_) && (distance_ < pair.distance_));
}
bool
operator>(const SearchResultPair& pair) const {
return std::isnan(pair.distance_) || (!std::isnan(distance_) && (distance_ > pair.distance_));
// Primary-key-aware ordering:
//   - a pair whose key is INVALID_ID is never greater than anything;
//   - a pair with a valid key is always greater than one with INVALID_ID;
//   - otherwise the pair with the larger distance_ is greater.
operator>(const SearchResultPair& other) const {
if (this->primary_key_ == INVALID_ID) {
return false;
} else {
if (other.primary_key_ == INVALID_ID) {
return true;
} else {
return (distance_ > other.distance_);
}
}
}
void
reset_distance() {
distance_ = search_result_->result_distances_[offset_];
// Steps to the next entry. While offset_ is still below offset_rb_ it is
// incremented; if the new offset is still below the bound, the key and
// distance are reloaded with bounds-checked .at(). In every other case the
// pair is marked exhausted (INVALID_ID / MAXFLOAT).
// NOTE(review): MAXFLOAT is a non-standard <math.h> extension macro;
// std::numeric_limits<float>::max() is the portable spelling. Confirm the
// supported toolchains before relying on it.
reset() {
if (offset_ < offset_rb_) {
offset_++;
if (offset_ < offset_rb_) {
primary_key_ = search_result_->primary_keys_.at(offset_);
distance_ = search_result_->result_distances_.at(offset_);
} else {
primary_key_ = INVALID_ID;
distance_ = MAXFLOAT;
}
} else {
primary_key_ = INVALID_ID;
distance_ = MAXFLOAT;
}
}
};

View File

@ -22,11 +22,9 @@ SegmentInternalInterface::FillPrimaryKeys(const query::Plan* plan, SearchResult&
AssertInfo(results.internal_seg_offsets_.size() == size,
"Size of result distances is not equal to size of segment offsets");
Assert(results.primary_keys_.size() == 0);
results.primary_keys_.resize(size);
auto element_sizeof = sizeof(int64_t);
aligned_vector<char> blob(size * element_sizeof);
if (plan->schema_.get_is_auto_id()) {
bulk_subscript(SystemFieldType::RowId, results.internal_seg_offsets_.data(), size, blob.data());
@ -38,9 +36,7 @@ SegmentInternalInterface::FillPrimaryKeys(const query::Plan* plan, SearchResult&
bulk_subscript(key_offset, results.internal_seg_offsets_.data(), size, blob.data());
}
for (int64_t i = 0; i < size; ++i) {
results.primary_keys_[i] = *(int64_t*)(blob.data() + element_sizeof * i);
}
memcpy(results.primary_keys_.data(), blob.data(), element_sizeof * size);
}
void

View File

@ -59,8 +59,6 @@ class SegmentInterface {
Delete(int64_t reserved_offset, int64_t size, const int64_t* row_ids, const Timestamp* timestamps) = 0;
virtual ~SegmentInterface() = default;
protected:
};
// internal API for DSL calculation

View File

@ -55,6 +55,17 @@ DeleteMarshaledHits(CMarshaledHits c_marshaled_hits) {
delete hits;
}
// void
// PrintSearchResult(char* buf, const milvus::SearchResult* result, int64_t seg_idx, int64_t from, int64_t to) {
// const int64_t MAXLEN = 32;
// snprintf(buf + strlen(buf), MAXLEN, "{ seg No.%ld ", seg_idx);
// for (int64_t i = from; i < to; i++) {
// snprintf(buf + strlen(buf), MAXLEN, "(%ld, %ld, %f), ", i, result->primary_keys_[i],
// result->result_distances_[i]);
// }
// snprintf(buf + strlen(buf), MAXLEN, "} ");
//}
void
GetResultData(std::vector<std::vector<int64_t>>& search_records,
std::vector<SearchResult*>& search_results,
@ -72,8 +83,12 @@ GetResultData(std::vector<std::vector<int64_t>>& search_records,
for (int j = 0; j < num_segments; ++j) {
auto search_result = search_results[j];
AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
AssertInfo(search_result->primary_keys_.size() == nq * topk, "incorrect search result primary key size");
AssertInfo(search_result->result_distances_.size() == nq * topk, "incorrect search result distance size");
auto primary_key = search_result->primary_keys_[base_offset];
auto distance = search_result->result_distances_[base_offset];
result_pairs.push_back(SearchResultPair(distance, search_result, base_offset, j));
result_pairs.push_back(
SearchResultPair(primary_key, distance, search_result, j, base_offset, base_offset + topk));
}
int64_t curr_offset = base_offset;
@ -89,23 +104,24 @@ GetResultData(std::vector<std::vector<int64_t>>& search_records,
#else
pk_set.clear();
while (curr_offset - base_offset < topk) {
result_pairs[0].reset_distance();
std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>());
auto& result_pair = result_pairs[0];
auto index = result_pair.index_;
int64_t curr_pk = result_pair.search_result_->primary_keys_[result_pair.offset_];
auto& pilot = result_pairs[0];
auto index = pilot.index_;
int64_t curr_pk = pilot.primary_key_;
// remove duplicates
if (curr_pk == INVALID_ID || pk_set.count(curr_pk) == 0) {
result_pair.search_result_->result_offsets_.push_back(curr_offset++);
search_records[index].push_back(result_pair.offset_++);
pilot.search_result_->result_offsets_.push_back(curr_offset++);
// when inserted data are dirty, it's possible that primary keys are duplicated,
// in this case, "offset_" may be greater than "offset_rb_" (#10530)
search_records[index].push_back(pilot.offset_ < pilot.offset_rb_ ? pilot.offset_ : INVALID_OFFSET);
if (curr_pk != INVALID_ID) {
pk_set.insert(curr_pk);
}
} else {
// skip entity with same primary key
result_pair.offset_++;
skip_dup_cnt++;
}
pilot.reset();
}
#endif
}
@ -128,12 +144,20 @@ ResetSearchResult(std::vector<std::vector<int64_t>>& search_records, std::vector
std::vector<int64_t> primary_keys;
std::vector<float> result_distances;
std::vector<int64_t> internal_seg_offsets;
int64_t primary_key;
float distance;
int64_t internal_seg_offset;
for (int j = 0; j < search_records[i].size(); j++) {
auto& offset = search_records[i][j];
auto primary_key = search_result->primary_keys_[offset];
auto distance = search_result->result_distances_[offset];
auto internal_seg_offset = search_result->internal_seg_offsets_[offset];
if (offset != INVALID_OFFSET) {
primary_key = search_result->primary_keys_[offset];
distance = search_result->result_distances_[offset];
internal_seg_offset = search_result->internal_seg_offsets_[offset];
} else {
primary_key = INVALID_ID;
distance = MAXFLOAT;
internal_seg_offset = INVALID_SEG_OFFSET;
}
primary_keys.push_back(primary_key);
result_distances.push_back(distance);
internal_seg_offsets.push_back(internal_seg_offset);
@ -152,8 +176,6 @@ ReduceSearchResultsAndFillData(CSearchPlan c_plan, CSearchResult* c_search_resul
std::vector<SearchResult*> search_results;
for (int i = 0; i < num_segments; ++i) {
search_results.push_back((SearchResult*)c_search_results[i]);
LOG_SEGCORE_DEBUG_ << "No." << i << ": search result addr " << c_search_results[i] << ", segment addr "
<< search_results[i]->segment_;
}
auto topk = search_results[0]->topk_;
auto num_queries = search_results[0]->num_queries_;
@ -164,18 +186,15 @@ ReduceSearchResultsAndFillData(CSearchPlan c_plan, CSearchResult* c_search_resul
auto segment = (milvus::segcore::SegmentInterface*)(search_result->segment_);
segment->FillPrimaryKeys(plan, *search_result);
}
LOG_SEGCORE_DEBUG_ << "Fill primary key done";
GetResultData(search_records, search_results, num_queries, topk);
ResetSearchResult(search_records, search_results);
LOG_SEGCORE_DEBUG_ << "Search result reduce done";
// fill in other entities
for (auto& search_result : search_results) {
auto segment = (milvus::segcore::SegmentInterface*)(search_result->segment_);
segment->FillTargetEntry(plan, *search_result);
}
LOG_SEGCORE_DEBUG_ << "Fill target entry done";
auto status = CStatus();
status.error_code = Success;

View File

@ -53,14 +53,12 @@ NewSegment(CCollection collection, uint64_t segment_id, SegmentType seg_type) {
// Destroys the segment object referred to by the opaque handle `c_segment`.
// The handle is reinterpreted as a milvus::segcore::SegmentInterface pointer
// and deleted; no status is reported back to the caller.
void
DeleteSegment(CSegmentInterface c_segment) {
    // TODO: use dynamic cast, and return c status
    LOG_SEGCORE_DEBUG_ << "delete segment " << c_segment;
    milvus::segcore::SegmentInterface* segment = (milvus::segcore::SegmentInterface*)c_segment;
    delete segment;
}
// Destroys the search result referred to by the opaque handle `search_result`.
// The handle is reinterpreted as a milvus::SearchResult pointer and deleted.
void
DeleteSearchResult(CSearchResult search_result) {
    LOG_SEGCORE_DEBUG_ << "delete search result " << search_result;
    milvus::SearchResult* result = (milvus::SearchResult*)search_result;
    delete result;
}

View File

@ -10,98 +10,24 @@
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <gtest/gtest.h>
#include "common/Consts.h"
#include "segcore/ReduceStructure.h"
// Table-driven check of SearchResultPair::operator< over every combination of
// {1.0, 2.0, NaN} distances. Each row is asserted through both the infix form
// and the explicit member call, in the order listed.
TEST(SearchResultPair, Less) {
    struct LessCase {
        float lhs_distance;
        float rhs_distance;
        bool expected;
    };
    const LessCase cases[] = {
        {1.0, 1.0, false},  // equal distances (the state right after construction)
        {1.0, 2.0, true},
        {1.0, NAN, true},   // a NaN right-hand side always yields true
        {2.0, 1.0, false},
        {2.0, 2.0, false},
        {2.0, NAN, true},
        {NAN, 1.0, false},  // a NaN left-hand side is not less than a real number
        {NAN, 2.0, false},
        {NAN, NAN, true},
    };

    auto pair1 = SearchResultPair(1.0, nullptr, 0, 0);
    auto pair2 = SearchResultPair(1.0, nullptr, 0, 0);
    for (const auto& c : cases) {
        pair1.distance_ = c.lhs_distance;
        pair2.distance_ = c.rhs_distance;
        ASSERT_EQ(pair1 < pair2, c.expected);
        ASSERT_EQ(pair1.operator<(pair2), c.expected);
    }
}
// Exercises SearchResultPair::operator>.
// NOTE(review): pair1/pair2 are constructed twice below (once with four
// arguments, once with six), so two revisions of this test appear interleaved
// in this view. The distance-only and NaN cases presumably belong to the
// revision that compared on distance alone; confirm before relying on them.
TEST(SearchResultPair, Greater) {
auto pair1 = SearchResultPair(1.0, nullptr, 0, 0);
auto pair2 = SearchResultPair(1.0, nullptr, 0, 0);
auto pair1 = SearchResultPair(0, 1.0, nullptr, 0, 0, 10);
auto pair2 = SearchResultPair(1, 2.0, nullptr, 1, 0, 10);
ASSERT_EQ(pair1 > pair2, false);
ASSERT_EQ(pair1.operator>(pair2), false);
pair1.distance_ = 1.0;
pair2.distance_ = 2.0;
// Left-hand side carries INVALID_ID: expected to compare as not greater.
pair1.primary_key_ = INVALID_ID;
pair2.primary_key_ = 1;
ASSERT_EQ(pair1 > pair2, false);
ASSERT_EQ(pair1.operator>(pair2), false);
pair1.distance_ = 1.0;
pair2.distance_ = NAN;
// Only the right-hand side carries INVALID_ID: expected to compare as greater.
pair1.primary_key_ = 0;
pair2.primary_key_ = INVALID_ID;
ASSERT_EQ(pair1 > pair2, true);
ASSERT_EQ(pair1.operator>(pair2), true);
pair1.distance_ = 2.0;
pair2.distance_ = 1.0;
ASSERT_EQ(pair1 > pair2, true);
ASSERT_EQ(pair1.operator>(pair2), true);
pair1.distance_ = 2.0;
pair2.distance_ = 2.0;
// Both sides carry INVALID_ID: expected to compare as not greater.
pair1.primary_key_ = INVALID_ID;
pair2.primary_key_ = INVALID_ID;
ASSERT_EQ(pair1 > pair2, false);
ASSERT_EQ(pair1.operator>(pair2), false);
pair1.distance_ = 2.0;
pair2.distance_ = NAN;
ASSERT_EQ(pair1 > pair2, true);
ASSERT_EQ(pair1.operator>(pair2), true);
pair1.distance_ = NAN;
pair2.distance_ = 1.0;
ASSERT_EQ(pair1 > pair2, false);
ASSERT_EQ(pair1.operator>(pair2), false);
pair1.distance_ = NAN;
pair2.distance_ = 2.0;
ASSERT_EQ(pair1 > pair2, false);
ASSERT_EQ(pair1.operator>(pair2), false);
pair1.distance_ = NAN;
pair2.distance_ = NAN;
ASSERT_EQ(pair1 > pair2, true);
ASSERT_EQ(pair1.operator>(pair2), true);
}