mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-11-30 19:08:30 +08:00
Fix reduce panic (#11325)
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
This commit is contained in:
parent
6024346195
commit
db2a0a3bd3
@ -13,5 +13,6 @@
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
const int64_t INVALID_SEG_OFFSET = -1;
|
||||
const int64_t INVALID_ID = -1;
|
||||
const int64_t INVALID_OFFSET = -1;
|
||||
const int64_t INVALID_SEG_OFFSET = -1;
|
||||
|
@ -9,32 +9,58 @@
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <cmath> // std::isnan
|
||||
#include <common/Types.h>
|
||||
#include <cmath>
|
||||
|
||||
#include "common/Consts.h"
|
||||
#include "common/Types.h"
|
||||
#include "segcore/Reduce.h"
|
||||
|
||||
using milvus::SearchResult;
|
||||
|
||||
struct SearchResultPair {
|
||||
int64_t primary_key_;
|
||||
float distance_;
|
||||
milvus::SearchResult* search_result_;
|
||||
int64_t offset_;
|
||||
int64_t index_;
|
||||
int64_t offset_;
|
||||
int64_t offset_rb_; // right bound
|
||||
|
||||
SearchResultPair(float distance, milvus::SearchResult* search_result, int64_t offset, int64_t index)
|
||||
: distance_(distance), search_result_(search_result), offset_(offset), index_(index) {
|
||||
SearchResultPair(int64_t primary_key, float distance, SearchResult* result, int64_t index, int64_t lb, int64_t rb)
|
||||
: primary_key_(primary_key),
|
||||
distance_(distance),
|
||||
search_result_(result),
|
||||
index_(index),
|
||||
offset_(lb),
|
||||
offset_rb_(rb) {
|
||||
}
|
||||
|
||||
bool
|
||||
operator<(const SearchResultPair& pair) const {
|
||||
return std::isnan(pair.distance_) || (!std::isnan(distance_) && (distance_ < pair.distance_));
|
||||
}
|
||||
|
||||
bool
|
||||
operator>(const SearchResultPair& pair) const {
|
||||
return std::isnan(pair.distance_) || (!std::isnan(distance_) && (distance_ > pair.distance_));
|
||||
operator>(const SearchResultPair& other) const {
|
||||
if (this->primary_key_ == INVALID_ID) {
|
||||
return false;
|
||||
} else {
|
||||
if (other.primary_key_ == INVALID_ID) {
|
||||
return true;
|
||||
} else {
|
||||
return (distance_ > other.distance_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
reset_distance() {
|
||||
distance_ = search_result_->result_distances_[offset_];
|
||||
reset() {
|
||||
if (offset_ < offset_rb_) {
|
||||
offset_++;
|
||||
if (offset_ < offset_rb_) {
|
||||
primary_key_ = search_result_->primary_keys_.at(offset_);
|
||||
distance_ = search_result_->result_distances_.at(offset_);
|
||||
} else {
|
||||
primary_key_ = INVALID_ID;
|
||||
distance_ = MAXFLOAT;
|
||||
}
|
||||
} else {
|
||||
primary_key_ = INVALID_ID;
|
||||
distance_ = MAXFLOAT;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -22,11 +22,9 @@ SegmentInternalInterface::FillPrimaryKeys(const query::Plan* plan, SearchResult&
|
||||
AssertInfo(results.internal_seg_offsets_.size() == size,
|
||||
"Size of result distances is not equal to size of segment offsets");
|
||||
Assert(results.primary_keys_.size() == 0);
|
||||
|
||||
results.primary_keys_.resize(size);
|
||||
|
||||
auto element_sizeof = sizeof(int64_t);
|
||||
|
||||
aligned_vector<char> blob(size * element_sizeof);
|
||||
if (plan->schema_.get_is_auto_id()) {
|
||||
bulk_subscript(SystemFieldType::RowId, results.internal_seg_offsets_.data(), size, blob.data());
|
||||
@ -38,9 +36,7 @@ SegmentInternalInterface::FillPrimaryKeys(const query::Plan* plan, SearchResult&
|
||||
bulk_subscript(key_offset, results.internal_seg_offsets_.data(), size, blob.data());
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < size; ++i) {
|
||||
results.primary_keys_[i] = *(int64_t*)(blob.data() + element_sizeof * i);
|
||||
}
|
||||
memcpy(results.primary_keys_.data(), blob.data(), element_sizeof * size);
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -59,8 +59,6 @@ class SegmentInterface {
|
||||
Delete(int64_t reserved_offset, int64_t size, const int64_t* row_ids, const Timestamp* timestamps) = 0;
|
||||
|
||||
virtual ~SegmentInterface() = default;
|
||||
|
||||
protected:
|
||||
};
|
||||
|
||||
// internal API for DSL calculation
|
||||
|
@ -55,6 +55,17 @@ DeleteMarshaledHits(CMarshaledHits c_marshaled_hits) {
|
||||
delete hits;
|
||||
}
|
||||
|
||||
// void
|
||||
// PrintSearchResult(char* buf, const milvus::SearchResult* result, int64_t seg_idx, int64_t from, int64_t to) {
|
||||
// const int64_t MAXLEN = 32;
|
||||
// snprintf(buf + strlen(buf), MAXLEN, "{ seg No.%ld ", seg_idx);
|
||||
// for (int64_t i = from; i < to; i++) {
|
||||
// snprintf(buf + strlen(buf), MAXLEN, "(%ld, %ld, %f), ", i, result->primary_keys_[i],
|
||||
// result->result_distances_[i]);
|
||||
// }
|
||||
// snprintf(buf + strlen(buf), MAXLEN, "} ");
|
||||
//}
|
||||
|
||||
void
|
||||
GetResultData(std::vector<std::vector<int64_t>>& search_records,
|
||||
std::vector<SearchResult*>& search_results,
|
||||
@ -72,8 +83,12 @@ GetResultData(std::vector<std::vector<int64_t>>& search_records,
|
||||
for (int j = 0; j < num_segments; ++j) {
|
||||
auto search_result = search_results[j];
|
||||
AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
|
||||
AssertInfo(search_result->primary_keys_.size() == nq * topk, "incorrect search result primary key size");
|
||||
AssertInfo(search_result->result_distances_.size() == nq * topk, "incorrect search result distance size");
|
||||
auto primary_key = search_result->primary_keys_[base_offset];
|
||||
auto distance = search_result->result_distances_[base_offset];
|
||||
result_pairs.push_back(SearchResultPair(distance, search_result, base_offset, j));
|
||||
result_pairs.push_back(
|
||||
SearchResultPair(primary_key, distance, search_result, j, base_offset, base_offset + topk));
|
||||
}
|
||||
int64_t curr_offset = base_offset;
|
||||
|
||||
@ -89,23 +104,24 @@ GetResultData(std::vector<std::vector<int64_t>>& search_records,
|
||||
#else
|
||||
pk_set.clear();
|
||||
while (curr_offset - base_offset < topk) {
|
||||
result_pairs[0].reset_distance();
|
||||
std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>());
|
||||
auto& result_pair = result_pairs[0];
|
||||
auto index = result_pair.index_;
|
||||
int64_t curr_pk = result_pair.search_result_->primary_keys_[result_pair.offset_];
|
||||
auto& pilot = result_pairs[0];
|
||||
auto index = pilot.index_;
|
||||
int64_t curr_pk = pilot.primary_key_;
|
||||
// remove duplicates
|
||||
if (curr_pk == INVALID_ID || pk_set.count(curr_pk) == 0) {
|
||||
result_pair.search_result_->result_offsets_.push_back(curr_offset++);
|
||||
search_records[index].push_back(result_pair.offset_++);
|
||||
pilot.search_result_->result_offsets_.push_back(curr_offset++);
|
||||
// when inserted data is dirty, primary keys may be duplicated and the duplicates are skipped;
|
||||
// in that case "offset_" may already have reached "offset_rb_", so INVALID_OFFSET is recorded instead (#10530)
|
||||
search_records[index].push_back(pilot.offset_ < pilot.offset_rb_ ? pilot.offset_ : INVALID_OFFSET);
|
||||
if (curr_pk != INVALID_ID) {
|
||||
pk_set.insert(curr_pk);
|
||||
}
|
||||
} else {
|
||||
// skip entity with same primary key
|
||||
result_pair.offset_++;
|
||||
skip_dup_cnt++;
|
||||
}
|
||||
pilot.reset();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@ -128,12 +144,20 @@ ResetSearchResult(std::vector<std::vector<int64_t>>& search_records, std::vector
|
||||
std::vector<int64_t> primary_keys;
|
||||
std::vector<float> result_distances;
|
||||
std::vector<int64_t> internal_seg_offsets;
|
||||
|
||||
int64_t primary_key;
|
||||
float distance;
|
||||
int64_t internal_seg_offset;
|
||||
for (int j = 0; j < search_records[i].size(); j++) {
|
||||
auto& offset = search_records[i][j];
|
||||
auto primary_key = search_result->primary_keys_[offset];
|
||||
auto distance = search_result->result_distances_[offset];
|
||||
auto internal_seg_offset = search_result->internal_seg_offsets_[offset];
|
||||
if (offset != INVALID_OFFSET) {
|
||||
primary_key = search_result->primary_keys_[offset];
|
||||
distance = search_result->result_distances_[offset];
|
||||
internal_seg_offset = search_result->internal_seg_offsets_[offset];
|
||||
} else {
|
||||
primary_key = INVALID_ID;
|
||||
distance = MAXFLOAT;
|
||||
internal_seg_offset = INVALID_SEG_OFFSET;
|
||||
}
|
||||
primary_keys.push_back(primary_key);
|
||||
result_distances.push_back(distance);
|
||||
internal_seg_offsets.push_back(internal_seg_offset);
|
||||
@ -152,8 +176,6 @@ ReduceSearchResultsAndFillData(CSearchPlan c_plan, CSearchResult* c_search_resul
|
||||
std::vector<SearchResult*> search_results;
|
||||
for (int i = 0; i < num_segments; ++i) {
|
||||
search_results.push_back((SearchResult*)c_search_results[i]);
|
||||
LOG_SEGCORE_DEBUG_ << "No." << i << ": search result addr " << c_search_results[i] << ", segment addr "
|
||||
<< search_results[i]->segment_;
|
||||
}
|
||||
auto topk = search_results[0]->topk_;
|
||||
auto num_queries = search_results[0]->num_queries_;
|
||||
@ -164,18 +186,15 @@ ReduceSearchResultsAndFillData(CSearchPlan c_plan, CSearchResult* c_search_resul
|
||||
auto segment = (milvus::segcore::SegmentInterface*)(search_result->segment_);
|
||||
segment->FillPrimaryKeys(plan, *search_result);
|
||||
}
|
||||
LOG_SEGCORE_DEBUG_ << "Fill primary key done";
|
||||
|
||||
GetResultData(search_records, search_results, num_queries, topk);
|
||||
ResetSearchResult(search_records, search_results);
|
||||
LOG_SEGCORE_DEBUG_ << "Search result reduce done";
|
||||
|
||||
// fill in other entities
|
||||
for (auto& search_result : search_results) {
|
||||
auto segment = (milvus::segcore::SegmentInterface*)(search_result->segment_);
|
||||
segment->FillTargetEntry(plan, *search_result);
|
||||
}
|
||||
LOG_SEGCORE_DEBUG_ << "Fill target entry done";
|
||||
|
||||
auto status = CStatus();
|
||||
status.error_code = Success;
|
||||
|
@ -53,14 +53,12 @@ NewSegment(CCollection collection, uint64_t segment_id, SegmentType seg_type) {
|
||||
void
|
||||
DeleteSegment(CSegmentInterface c_segment) {
|
||||
// TODO: use dynamic_cast, and return a CStatus
|
||||
LOG_SEGCORE_DEBUG_ << "delete segment " << c_segment;
|
||||
auto s = (milvus::segcore::SegmentInterface*)c_segment;
|
||||
delete s;
|
||||
}
|
||||
|
||||
void
|
||||
DeleteSearchResult(CSearchResult search_result) {
|
||||
LOG_SEGCORE_DEBUG_ << "delete search result " << search_result;
|
||||
auto res = (milvus::SearchResult*)search_result;
|
||||
delete res;
|
||||
}
|
||||
|
@ -10,98 +10,24 @@
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "common/Consts.h"
|
||||
#include "segcore/ReduceStructure.h"
|
||||
|
||||
TEST(SearchResultPair, Less) {
|
||||
auto pair1 = SearchResultPair(1.0, nullptr, 0, 0);
|
||||
auto pair2 = SearchResultPair(1.0, nullptr, 0, 0);
|
||||
ASSERT_EQ(pair1 < pair2, false);
|
||||
ASSERT_EQ(pair1.operator<(pair2), false);
|
||||
|
||||
pair1.distance_ = 1.0;
|
||||
pair2.distance_ = 2.0;
|
||||
ASSERT_EQ(pair1 < pair2, true);
|
||||
ASSERT_EQ(pair1.operator<(pair2), true);
|
||||
|
||||
pair1.distance_ = 1.0;
|
||||
pair2.distance_ = NAN;
|
||||
ASSERT_EQ(pair1 < pair2, true);
|
||||
ASSERT_EQ(pair1.operator<(pair2), true);
|
||||
|
||||
pair1.distance_ = 2.0;
|
||||
pair2.distance_ = 1.0;
|
||||
ASSERT_EQ(pair1 < pair2, false);
|
||||
ASSERT_EQ(pair1.operator<(pair2), false);
|
||||
|
||||
pair1.distance_ = 2.0;
|
||||
pair2.distance_ = 2.0;
|
||||
ASSERT_EQ(pair1 < pair2, false);
|
||||
ASSERT_EQ(pair1.operator<(pair2), false);
|
||||
|
||||
pair1.distance_ = 2.0;
|
||||
pair2.distance_ = NAN;
|
||||
ASSERT_EQ(pair1 < pair2, true);
|
||||
ASSERT_EQ(pair1.operator<(pair2), true);
|
||||
|
||||
pair1.distance_ = NAN;
|
||||
pair2.distance_ = 1.0;
|
||||
ASSERT_EQ(pair1 < pair2, false);
|
||||
ASSERT_EQ(pair1.operator<(pair2), false);
|
||||
|
||||
pair1.distance_ = NAN;
|
||||
pair2.distance_ = 2.0;
|
||||
ASSERT_EQ(pair1 < pair2, false);
|
||||
ASSERT_EQ(pair1.operator<(pair2), false);
|
||||
|
||||
pair1.distance_ = NAN;
|
||||
pair2.distance_ = NAN;
|
||||
ASSERT_EQ(pair1 < pair2, true);
|
||||
ASSERT_EQ(pair1.operator<(pair2), true);
|
||||
}
|
||||
|
||||
TEST(SearchResultPair, Greater) {
|
||||
auto pair1 = SearchResultPair(1.0, nullptr, 0, 0);
|
||||
auto pair2 = SearchResultPair(1.0, nullptr, 0, 0);
|
||||
auto pair1 = SearchResultPair(0, 1.0, nullptr, 0, 0, 10);
|
||||
auto pair2 = SearchResultPair(1, 2.0, nullptr, 1, 0, 10);
|
||||
ASSERT_EQ(pair1 > pair2, false);
|
||||
ASSERT_EQ(pair1.operator>(pair2), false);
|
||||
|
||||
pair1.distance_ = 1.0;
|
||||
pair2.distance_ = 2.0;
|
||||
pair1.primary_key_ = INVALID_ID;
|
||||
pair2.primary_key_ = 1;
|
||||
ASSERT_EQ(pair1 > pair2, false);
|
||||
ASSERT_EQ(pair1.operator>(pair2), false);
|
||||
|
||||
pair1.distance_ = 1.0;
|
||||
pair2.distance_ = NAN;
|
||||
pair1.primary_key_ = 0;
|
||||
pair2.primary_key_ = INVALID_ID;
|
||||
ASSERT_EQ(pair1 > pair2, true);
|
||||
ASSERT_EQ(pair1.operator>(pair2), true);
|
||||
|
||||
pair1.distance_ = 2.0;
|
||||
pair2.distance_ = 1.0;
|
||||
ASSERT_EQ(pair1 > pair2, true);
|
||||
ASSERT_EQ(pair1.operator>(pair2), true);
|
||||
|
||||
pair1.distance_ = 2.0;
|
||||
pair2.distance_ = 2.0;
|
||||
pair1.primary_key_ = INVALID_ID;
|
||||
pair2.primary_key_ = INVALID_ID;
|
||||
ASSERT_EQ(pair1 > pair2, false);
|
||||
ASSERT_EQ(pair1.operator>(pair2), false);
|
||||
|
||||
pair1.distance_ = 2.0;
|
||||
pair2.distance_ = NAN;
|
||||
ASSERT_EQ(pair1 > pair2, true);
|
||||
ASSERT_EQ(pair1.operator>(pair2), true);
|
||||
|
||||
pair1.distance_ = NAN;
|
||||
pair2.distance_ = 1.0;
|
||||
ASSERT_EQ(pair1 > pair2, false);
|
||||
ASSERT_EQ(pair1.operator>(pair2), false);
|
||||
|
||||
pair1.distance_ = NAN;
|
||||
pair2.distance_ = 2.0;
|
||||
ASSERT_EQ(pair1 > pair2, false);
|
||||
ASSERT_EQ(pair1.operator>(pair2), false);
|
||||
|
||||
pair1.distance_ = NAN;
|
||||
pair2.distance_ = NAN;
|
||||
ASSERT_EQ(pair1 > pair2, true);
|
||||
ASSERT_EQ(pair1.operator>(pair2), true);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user