Optimize rounding distances, avoid promoting to double (#22846)

Signed-off-by: yah01 <yang.cen@zilliz.com>
This commit is contained in:
yah01 2023-03-20 14:09:56 +08:00 committed by GitHub
parent 8cf748a375
commit 3202eb0d9f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 13 additions and 11 deletions

View File

@ -20,7 +20,7 @@ Checks: >
-*, clang-diagnostic-*, -clang-diagnostic-error,
clang-analyzer-*, -clang-analyzer-alpha*,
google-*, -google-runtime-references, -google-readability-todo,
modernize-*, -modernize-use-trailing-return-type,
modernize-*, -modernize-use-trailing-return-type, -modernize-use-nodiscard,
performance-*,
bugprone-bool-pointer-implicit-conversion,
bugprone-branch-clone,

View File

@ -216,7 +216,7 @@ VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset,
if (round_decimal != -1) {
const float multiplier = pow(10.0, round_decimal);
for (int i = 0; i < total_num; i++) {
distances[i] = round(distances[i] * multiplier) / multiplier;
distances[i] = std::round(distances[i] * multiplier) / multiplier;
}
}
auto result = std::make_unique<SearchResult>();

View File

@ -15,6 +15,8 @@
// limitations under the License.
#include "index/VectorMemIndex.h"
#include <cmath>
#include "index/Meta.h"
#include "index/Utils.h"
#include "exceptions/EasyAssert.h"
@ -129,7 +131,7 @@ VectorMemIndex::Query(const DatasetPtr dataset,
if (round_decimal != -1) {
const float multiplier = pow(10.0, round_decimal);
for (int i = 0; i < total_num; i++) {
distances[i] = round(distances[i] * multiplier) / multiplier;
distances[i] = std::round(distances[i] * multiplier) / multiplier;
}
}
auto result = std::make_unique<SearchResult>();

View File

@ -58,7 +58,7 @@ SearchOnSealedIndex(const Schema& schema,
if (round_decimal != -1) {
const float multiplier = pow(10.0, round_decimal);
for (int i = 0; i < total_num; i++) {
distances[i] = round(distances[i] * multiplier) / multiplier;
distances[i] = std::round(distances[i] * multiplier) / multiplier;
}
}
result.seg_offsets_.resize(total_num);

View File

@ -89,9 +89,8 @@ SubSearchResult::round_values() {
if (round_decimal_ == -1)
return;
const float multiplier = pow(10.0, round_decimal_);
for (auto it = this->distances_.begin(); it != this->distances_.end();
it++) {
*it = round(*it * multiplier) / multiplier;
for (float& distance : this->distances_) {
distance = std::round(distance * multiplier) / multiplier;
}
}

View File

@ -34,11 +34,11 @@ class SubSearchResult {
distances_(num_queries * topk, init_value(metric_type)) {
}
SubSearchResult(SubSearchResult&& other)
SubSearchResult(SubSearchResult&& other) noexcept
: num_queries_(other.num_queries_),
topk_(other.topk_),
round_decimal_(other.round_decimal_),
metric_type_(other.metric_type_),
metric_type_(std::move(other.metric_type_)),
seg_offsets_(std::move(other.seg_offsets_)),
distances_(std::move(other.distances_)) {
}

View File

@ -12,6 +12,7 @@
#pragma once
#include <limits>
#include <utility>
#include "common/Consts.h"
#include "common/Types.h"
@ -33,7 +34,7 @@ struct SearchResultPair {
int64_t index,
int64_t lb,
int64_t rb)
: primary_key_(primary_key),
: primary_key_(std::move(primary_key)),
distance_(distance),
search_result_(result),
segment_index_(index),
@ -43,7 +44,7 @@ struct SearchResultPair {
bool
operator>(const SearchResultPair& other) const {
if (fabs(distance_ - other.distance_) < 0.000001f) {
if (std::fabs(distance_ - other.distance_) < 0.000001f) {
return primary_key_ < other.primary_key_;
}
return distance_ > other.distance_;