diff --git a/internal/core/src/common/Utils.h b/internal/core/src/common/Utils.h index d2954cb8af..008dd58e5e 100644 --- a/internal/core/src/common/Utils.h +++ b/internal/core/src/common/Utils.h @@ -11,12 +11,14 @@ #pragma once -#include -#include "exceptions/EasyAssert.h" -#include "config/ConfigChunkManager.h" -#include "common/Consts.h" #include +#include + +#include "common/Consts.h" +#include "config/ConfigChunkManager.h" +#include "exceptions/EasyAssert.h" #include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" namespace milvus { @@ -96,4 +98,14 @@ upper_div(int64_t value, int64_t align) { return groups; } +inline bool +IsMetricType(const std::string& str, const knowhere::MetricType& metric_type) { + return !strcasecmp(str.c_str(), metric_type.c_str()); +} + +inline bool +PositivelyRelated(const knowhere::MetricType& metric_type) { + return IsMetricType(metric_type, knowhere::metric::IP); +} + } // namespace milvus diff --git a/internal/core/src/query/SubSearchResult.cpp b/internal/core/src/query/SubSearchResult.cpp index 19a88d877e..f46b831b6c 100644 --- a/internal/core/src/query/SubSearchResult.cpp +++ b/internal/core/src/query/SubSearchResult.cpp @@ -22,7 +22,7 @@ SubSearchResult::merge_impl(const SubSearchResult& right) { AssertInfo(num_queries_ == right.num_queries_, "[SubSearchResult]Nq check failed"); AssertInfo(topk_ == right.topk_, "[SubSearchResult]Topk check failed"); AssertInfo(metric_type_ == right.metric_type_, "[SubSearchResult]Metric type check failed"); - AssertInfo(is_desc == is_descending(metric_type_), "[SubSearchResult]Metric type isn't desc"); + AssertInfo(is_desc == PositivelyRelated(metric_type_), "[SubSearchResult]Metric type isn't desc"); for (int64_t qn = 0; qn < num_queries_; ++qn) { auto offset = qn * topk_; @@ -61,7 +61,7 @@ SubSearchResult::merge_impl(const SubSearchResult& right) { void SubSearchResult::merge(const SubSearchResult& sub_result) { AssertInfo(metric_type_ == sub_result.metric_type_, "[SubSearchResult]Metric type check failed when merge"); - if (is_descending(metric_type_)) { + if (PositivelyRelated(metric_type_)) { this->merge_impl(sub_result); } else { this->merge_impl(sub_result); diff --git a/internal/core/src/query/SubSearchResult.h b/internal/core/src/query/SubSearchResult.h index cac498cb7d..0cedf870d8 100644 --- a/internal/core/src/query/SubSearchResult.h +++ b/internal/core/src/query/SubSearchResult.h @@ -14,7 +14,9 @@ #include #include #include + #include "common/Types.h" +#include "common/Utils.h" namespace milvus::query { @@ -41,17 +43,7 @@ class SubSearchResult { public: static float init_value(const MetricType& metric_type) { - return (is_descending(metric_type) ? -1 : 1) * std::numeric_limits::max(); - } - - static bool - is_descending(const MetricType& metric_type) { - // TODO(dog): more types - if (metric_type == knowhere::metric::IP) { - return true; - } else { - return false; - } + return (PositivelyRelated(metric_type) ? -1 : 1) * std::numeric_limits::max(); } public: diff --git a/internal/core/src/segcore/SimilarityCorelation.h b/internal/core/src/segcore/SimilarityCorelation.h deleted file mode 100644 index 8ccf84e6c0..0000000000 --- a/internal/core/src/segcore/SimilarityCorelation.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License - -#pragma once - -#include "common/Types.h" - -namespace milvus::segcore { -static inline bool -PositivelyRelated(const MetricType& metric_type) { - return metric_type == knowhere::metric::IP; -} -} // namespace milvus::segcore diff --git a/internal/core/src/segcore/segment_c.cpp b/internal/core/src/segcore/segment_c.cpp index 8ad40ecdf3..bdd87bb582 100644 --- a/internal/core/src/segcore/segment_c.cpp +++ b/internal/core/src/segcore/segment_c.cpp @@ -18,7 +18,6 @@ #include "segcore/Collection.h" #include "segcore/SegmentGrowingImpl.h" #include "segcore/SegmentSealedImpl.h" -#include "segcore/SimilarityCorelation.h" #include "segcore/segment_c.h" #include "index/IndexInfo.h" #include "google/protobuf/text_format.h" @@ -72,7 +71,7 @@ Search(CSegmentInterface c_segment, auto plan = (milvus::query::Plan*)c_plan; auto phg_ptr = reinterpret_cast(c_placeholder_group); auto search_result = segment->Search(plan, phg_ptr, timestamp); - if (!milvus::segcore::PositivelyRelated(plan->plan_node_->search_info_.metric_type_)) { + if (!milvus::PositivelyRelated(plan->plan_node_->search_info_.metric_type_)) { for (auto& dis : search_result->distances_) { dis *= -1; } diff --git a/internal/core/unittest/test_bf.cpp b/internal/core/unittest/test_bf.cpp index 181fd00fdc..c7d157ee02 100644 --- a/internal/core/unittest/test_bf.cpp +++ b/internal/core/unittest/test_bf.cpp @@ -11,7 +11,9 @@ #include #include -#include + +#include "common/Utils.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" #include "query/SearchBruteForce.h" #include "test_utils/Distance.h" @@ -38,13 +40,13 @@ Distances(const float* base, int nb, int dim, const knowhere::MetricType& metric) { - if (metric == knowhere::metric::L2) { + if (milvus::IsMetricType(metric, knowhere::metric::L2)) { std::vector> res; for (int i = 0; i < nb; i++) { res.emplace_back(i, L2(base + i * dim, query, dim)); } return res; - } else if (metric == knowhere::metric::IP) { + } else if (milvus::IsMetricType(metric, knowhere::metric::IP)) { std::vector> res; for (int i = 0; i < nb; i++) { res.emplace_back(i, IP(base + i * dim, query, dim)); @@ -75,8 +77,9 @@ Ref(const float* base, const knowhere::MetricType& metric) { auto res = Distances(base, query, nb, dim, metric); std::sort(res.begin(), res.end()); - if (metric == knowhere::metric::L2) { - } else if (metric == knowhere::metric::IP) { + if (milvus::IsMetricType(metric, knowhere::metric::L2)) { + // do nothing + } else if (milvus::IsMetricType(metric, knowhere::metric::IP)) { std::reverse(res.begin(), res.end()); } else { PanicInfo("invalid metric type"); @@ -95,8 +98,8 @@ AssertMatch(const std::vector& ref, const int64_t* ans) { } bool -is_supported_float_metric(const knowhere::MetricType& metric) { - return metric == knowhere::metric::L2 || metric == knowhere::metric::IP; +is_supported_float_metric(const std::string& metric) { + return milvus::IsMetricType(metric, knowhere::metric::L2) || milvus::IsMetricType(metric, knowhere::metric::IP); } } // namespace @@ -127,11 +130,13 @@ class TestFloatSearchBruteForce : public ::testing::Test { }; TEST_F(TestFloatSearchBruteForce, L2) { - Run(100, 10, 5, 128, knowhere::metric::L2); + Run(100, 10, 5, 128, "L2"); + Run(100, 10, 5, 128, "l2"); } TEST_F(TestFloatSearchBruteForce, IP) { - Run(100, 10, 5, 128, knowhere::metric::IP); + Run(100, 10, 5, 128, "IP"); + Run(100, 10, 5, 128, "ip"); } TEST_F(TestFloatSearchBruteForce, NotSupported) { diff --git a/internal/core/unittest/test_query.cpp b/internal/core/unittest/test_query.cpp index d3c552169f..1d50258f97 100644 --- a/internal/core/unittest/test_query.cpp +++ b/internal/core/unittest/test_query.cpp @@ -19,6 +19,7 @@ #include "query/generated/ExprVisitor.h" #include "query/generated/ShowPlanNodeVisitor.h" #include "segcore/SegmentSealed.h" +#include "test_utils/AssertUtils.h" #include "test_utils/DataGen.h" using namespace milvus; @@ -504,7 +505,7 @@ TEST(Query, ExecWithoutPredicate) { { "vector": { "fakevec": { - "metric_type": "L2", + "metric_type": "l2", "params": { "nprobe": 10 }, @@ -530,6 +531,7 @@ TEST(Query, ExecWithoutPredicate) { Timestamp time = 1000000; auto sr = segment->Search(plan.get(), ph_group.get(), time); + assert_order(*sr, "l2"); std::vector> results; int topk = 5; auto json = SearchResultToJson(*sr); @@ -572,7 +574,7 @@ TEST(Indexing, InnerProduct) { { "vector": { "normalized": { - "metric_type": "IP", + "metric_type": "ip", "params": { "nprobe": 10 }, @@ -599,6 +601,7 @@ TEST(Indexing, InnerProduct) { auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); Timestamp ts = N * 2; auto sr = segment->Search(plan.get(), ph_group.get(), ts); + assert_order(*sr, "ip"); std::cout << SearchResultToJson(*sr).dump(2); } diff --git a/internal/core/unittest/test_similarity_corelation.cpp b/internal/core/unittest/test_similarity_corelation.cpp index 3081d3c11a..e2122999e2 100644 --- a/internal/core/unittest/test_similarity_corelation.cpp +++ b/internal/core/unittest/test_similarity_corelation.cpp @@ -11,15 +11,15 @@ #include -#include "segcore/SimilarityCorelation.h" +#include "common/Utils.h" TEST(SimilarityCorelation, Naive) { - ASSERT_TRUE(milvus::segcore::PositivelyRelated(knowhere::metric::IP)); + ASSERT_TRUE(milvus::PositivelyRelated(knowhere::metric::IP)); - ASSERT_FALSE(milvus::segcore::PositivelyRelated(knowhere::metric::L2)); - ASSERT_FALSE(milvus::segcore::PositivelyRelated(knowhere::metric::HAMMING)); - ASSERT_FALSE(milvus::segcore::PositivelyRelated(knowhere::metric::JACCARD)); - ASSERT_FALSE(milvus::segcore::PositivelyRelated(knowhere::metric::TANIMOTO)); - ASSERT_FALSE(milvus::segcore::PositivelyRelated(knowhere::metric::SUBSTRUCTURE)); - ASSERT_FALSE(milvus::segcore::PositivelyRelated(knowhere::metric::SUPERSTRUCTURE)); + ASSERT_FALSE(milvus::PositivelyRelated(knowhere::metric::L2)); + ASSERT_FALSE(milvus::PositivelyRelated(knowhere::metric::HAMMING)); + ASSERT_FALSE(milvus::PositivelyRelated(knowhere::metric::JACCARD)); + ASSERT_FALSE(milvus::PositivelyRelated(knowhere::metric::TANIMOTO)); + ASSERT_FALSE(milvus::PositivelyRelated(knowhere::metric::SUBSTRUCTURE)); + ASSERT_FALSE(milvus::PositivelyRelated(knowhere::metric::SUPERSTRUCTURE)); } diff --git a/internal/core/unittest/test_utils/AssertUtils.h b/internal/core/unittest/test_utils/AssertUtils.h index 1bfeae3ff6..0502963a5f 100644 --- a/internal/core/unittest/test_utils/AssertUtils.h +++ b/internal/core/unittest/test_utils/AssertUtils.h @@ -25,6 +25,7 @@ compare_float(float x, float y, float epsilon = 0.000001f) { return true; return false; } + bool compare_double(double x, double y, double epsilon = 0.000001f) { if (fabs(x - y) < epsilon) @@ -32,6 +33,34 @@ compare_double(double x, double y, double epsilon = 0.000001f) { return false; } +inline void +assert_order(const milvus::SearchResult& result, const knowhere::MetricType& metric_type) { + bool dsc = milvus::PositivelyRelated(metric_type); + auto& ids = result.seg_offsets_; + auto& dist = result.distances_; + auto nq = result.total_nq_; + auto topk = result.unity_topK_; + if (dsc) { + for (int i = 0; i < nq; i++) { + for (int j = 1; j < topk; j++) { + auto idx = i * topk + j; + if (ids[idx] != -1) { + ASSERT_GE(dist[idx - 1], dist[idx]); + } + } + } + } else { + for (int i = 0; i < nq; i++) { + for (int j = 1; j < topk; j++) { + auto idx = i * topk + j; + if (ids[idx] != -1) { + ASSERT_LE(dist[idx - 1], dist[idx]); + } + } + } + } +} + template inline void assert_in(ScalarIndex* index, const std::vector& arr) { diff --git a/internal/core/unittest/test_utils/indexbuilder_test_utils.h b/internal/core/unittest/test_utils/indexbuilder_test_utils.h index 4a8ed25ade..80d65fe758 100644 --- a/internal/core/unittest/test_utils/indexbuilder_test_utils.h +++ b/internal/core/unittest/test_utils/indexbuilder_test_utils.h @@ -240,9 +240,9 @@ CountDistance( if (point_a == nullptr || point_b == nullptr) { return std::numeric_limits::max(); } - if (metric == knowhere::metric::L2) { + if (milvus::IsMetricType(metric, knowhere::metric::L2)) { return L2(static_cast(point_a), static_cast(point_b), dim); - } else if (metric == knowhere::metric::JACCARD) { + } else if (milvus::IsMetricType(metric, knowhere::metric::JACCARD)) { return Jaccard(static_cast(point_a), static_cast(point_b), dim); } else { return std::numeric_limits::max();