From 1b3c4b26f12b5180e995dc2a478c082178938246 Mon Sep 17 00:00:00 2001 From: Cai Yudong Date: Tue, 13 Jun 2023 19:22:38 +0800 Subject: [PATCH] Optimize range search result sort in segcore (#24837) Signed-off-by: Yudong Cai --- .../core/src/common/RangeSearchHelper.cpp | 114 ++++++++++-------- internal/core/src/common/RangeSearchHelper.h | 11 +- internal/core/src/index/VectorDiskIndex.cpp | 2 +- internal/core/src/index/VectorMemIndex.cpp | 2 +- internal/core/src/query/SearchBruteForce.cpp | 2 +- .../core/unittest/test_range_search_sort.cpp | 51 ++++---- 6 files changed, 94 insertions(+), 88 deletions(-) diff --git a/internal/core/src/common/RangeSearchHelper.cpp b/internal/core/src/common/RangeSearchHelper.cpp index 126de23d6a..4fb0014cc0 100644 --- a/internal/core/src/common/RangeSearchHelper.cpp +++ b/internal/core/src/common/RangeSearchHelper.cpp @@ -12,37 +12,43 @@ #include #include #include -#include #include "common/Utils.h" #include "common/RangeSearchHelper.h" namespace milvus { + namespace { using ResultPair = std::pair; } + +/* Sort and return TOPK items as final range search result */ DatasetPtr -SortRangeSearchResult(DatasetPtr data_set, - int64_t topk, - int64_t nq, - const std::string_view metric_type) { +ReGenRangeSearchResult(DatasetPtr data_set, + int64_t topk, + int64_t nq, + const std::string& metric_type) { /** * nq: number of queries; * lims: the size of lims is nq + 1, lims[i+1] - lims[i] refers to the size of RangeSearch result queries[i] - * for example, the nq is 5. In the selected range, - * the size of RangeSearch result for each nq is [1, 2, 3, 4, 5], - * the lims will be [0, 1, 3, 6, 10, 15]; + * for example, the nq is 5. 
In the selected range, + * the size of RangeSearch result for each nq is [1, 2, 3, 4, 5], + * the lims will be [0, 1, 3, 6, 10, 15]; * ids: the size of ids is lim[nq], - * { i(0,0), i(0,1), …, i(0,k0-1), + * { + * i(0,0), i(0,1), …, i(0,k0-1), * i(1,0), i(1,1), …, i(1,k1-1), - * …, - * i(n-1,0), i(n-1,1), …, i(n-1,kn-1)}, + * ... ... + * i(n-1,0), i(n-1,1), …, i(n-1,kn-1) + * } * i(0,0), i(0,1), …, i(0,k0-1) means the ids of RangeSearch result queries[0], k0 equals lim[1] - lim[0]; * dist: the size of ids is lim[nq], - * { d(0,0), d(0,1), …, d(0,k0-1), + * { + * d(0,0), d(0,1), …, d(0,k0-1), * d(1,0), d(1,1), …, d(1,k1-1), - * …, - * d(n-1,0), d(n-1,1), …, d(n-1,kn-1)}, + * ... ... + * d(n-1,0), d(n-1,1), …, d(n-1,kn-1) + * } * d(0,0), d(0,1), …, d(0,k0-1) means the distances of RangeSearch result queries[0], k0 equals lim[1] - lim[0]; */ auto lims = GetDatasetLims(data_set); @@ -51,50 +57,53 @@ SortRangeSearchResult(DatasetPtr data_set, // use p_id and p_dist to GenResultDataset after sorted auto p_id = new int64_t[topk * nq]; - memset(p_id, -1, sizeof(int64_t) * topk * nq); auto p_dist = new float[topk * nq]; + std::fill_n(p_id, topk * nq, -1); std::fill_n(p_dist, topk * nq, std::numeric_limits::max()); /* - * get result for one nq - * IP: 1.0 range_filter radius - * |------------+---------------| min_heap descending_order - * L2: 0.0 range_filter radius - * |------------+---------------| max_heap ascending_order - * - */ + * get result for one nq + * IP: 1.0 range_filter radius + * |------------+---------------| min_heap descending_order + * |___ ___| + * V + * topk + * + * L2: 0.0 range_filter radius + * |------------+---------------| max_heap ascending_order + * |___ ___| + * V + * topk + */ std::function cmp = std::less<>(); - if (IsMetricType(metric_type, knowhere::metric::IP)) { + if (PositivelyRelated(metric_type)) { cmp = std::greater<>(); } - std::priority_queue, decltype(cmp)> - sub_result(cmp); // The subscript of p_id and p_dist - int cnt = 0; 
+#pragma omp parallel for for (int i = 0; i < nq; i++) { - // if RangeSearch answer size of one nq is less than topk, set the capacity to size - int size = lims[i + 1] - lims[i]; - int capacity = topk > size ? size : topk; + std::priority_queue, decltype(cmp)> + pq(cmp); + auto capacity = std::min(lims[i + 1] - lims[i], topk); for (int j = lims[i]; j < lims[i + 1]; j++) { - auto current = ResultPair(dist[j], id[j]); - if (sub_result.size() == capacity) { - if (cmp(sub_result.top(), current)) { - current = sub_result.top(); - } - sub_result.pop(); + auto curr = ResultPair(dist[j], id[j]); + if (pq.size() < capacity) { + pq.push(curr); + } else if (cmp(curr, pq.top())) { + pq.pop(); + pq.push(curr); } - sub_result.push(current); } - for (int i = capacity + cnt - 1; i > cnt - 1; i--) { - p_dist[i] = sub_result.top().first; - p_id[i] = sub_result.top().second; - sub_result.pop(); + for (int j = capacity - 1; j >= 0; j--) { + auto& node = pq.top(); + p_dist[i * topk + j] = node.first; + p_id[i * topk + j] = node.second; + pq.pop(); } - cnt += topk; } return GenResultDataset(nq, topk, p_id, p_dist); } @@ -102,22 +111,27 @@ SortRangeSearchResult(DatasetPtr data_set, void CheckRangeSearchParam(float radius, float range_filter, - const std::string_view metric_type) { + const std::string& metric_type) { /* * IP: 1.0 range_filter radius - * |------------+---------------| min_heap descending_order - * L2: 1.0 radius range_filter - * |------------+---------------| max_heap ascending_order + * |------------+---------------| range_filter > radius + * L2: 0.0 range_filter radius + * |------------+---------------| range_filter < radius * */ - if (metric_type == knowhere::metric::IP) { - if (range_filter < radius) { - PanicInfo("range_filter must more than radius when IP"); + if (PositivelyRelated(metric_type)) { + if (range_filter <= radius) { + PanicInfo( + "range_filter must be greater than radius for IP " + "and COSINE"); } } else { - if (range_filter > radius) { - 
PanicInfo("range_filter must less than radius except IP"); + if (range_filter >= radius) { + PanicInfo( + "range_filter must be less than radius for " + "L2/HAMMING/JACCARD/TANIMOTO"); + } } } + } // namespace milvus diff --git a/internal/core/src/common/RangeSearchHelper.h b/internal/core/src/common/RangeSearchHelper.h index b401be04c3..c2ac044e76 100644 --- a/internal/core/src/common/RangeSearchHelper.h +++ b/internal/core/src/common/RangeSearchHelper.h @@ -17,13 +17,14 @@ namespace milvus { DatasetPtr -SortRangeSearchResult(DatasetPtr data_set, - int64_t topk, - int64_t nq, - const std::string_view metric_type); +ReGenRangeSearchResult(DatasetPtr data_set, + int64_t topk, + int64_t nq, + const std::string& metric_type); void CheckRangeSearchParam(float radius, float range_filter, - const std::string_view metric_type); + const std::string& metric_type); + } // namespace milvus diff --git a/internal/core/src/index/VectorDiskIndex.cpp b/internal/core/src/index/VectorDiskIndex.cpp index ff75111864..8fd6d73636 100644 --- a/internal/core/src/index/VectorDiskIndex.cpp +++ b/internal/core/src/index/VectorDiskIndex.cpp @@ -192,7 +192,7 @@ VectorDiskAnnIndex::Query(const DatasetPtr dataset, "failed to range search, " + MatchKnowhereError(res.error())); } - return SortRangeSearchResult( + return ReGenRangeSearchResult( res.value(), topk, num_queries, GetMetricType()); } else { auto res = index_.Search(*dataset, search_config, bitset); diff --git a/internal/core/src/index/VectorMemIndex.cpp b/internal/core/src/index/VectorMemIndex.cpp index ed263bb51b..c9f5964209 100644 --- a/internal/core/src/index/VectorMemIndex.cpp +++ b/internal/core/src/index/VectorMemIndex.cpp @@ -122,7 +122,7 @@ VectorMemIndex::Query(const DatasetPtr dataset, "failed to range search, " + MatchKnowhereError(res.error())); } - return SortRangeSearchResult( + return ReGenRangeSearchResult( res.value(), topk, num_queries, GetMetricType()); } else { auto res = index_.Search(*dataset, 
search_conf, bitset); diff --git a/internal/core/src/query/SearchBruteForce.cpp b/internal/core/src/query/SearchBruteForce.cpp index 8afde0e032..fb55590b0f 100644 --- a/internal/core/src/query/SearchBruteForce.cpp +++ b/internal/core/src/query/SearchBruteForce.cpp @@ -77,7 +77,7 @@ BruteForceSearch(const dataset::SearchDataset& dataset, "failed to range search, " + MatchKnowhereError(res.error())); } - auto result = SortRangeSearchResult( + auto result = ReGenRangeSearchResult( res.value(), topk, nq, dataset.metric_type); std::copy_n( GetDatasetIDs(result), nq * topk, sub_result.get_seg_offsets()); diff --git a/internal/core/unittest/test_range_search_sort.cpp b/internal/core/unittest/test_range_search_sort.cpp index 7a6e6ce946..c81221af80 100644 --- a/internal/core/unittest/test_range_search_sort.cpp +++ b/internal/core/unittest/test_range_search_sort.cpp @@ -22,12 +22,12 @@ #include "test_utils/indexbuilder_test_utils.h" bool -cmp1(std::pair a, std::pair b) { +greater(std::pair a, std::pair b) { return a.first > b.first; } bool -cmp2(std::pair a, std::pair b) { +less(std::pair a, std::pair b) { return a.first < b.first; } @@ -35,7 +35,7 @@ auto RangeSearchSortResultBF(milvus::DatasetPtr data_set, int64_t topk, size_t nq, - std::string metric_type) { + std::string& metric_type) { auto lims = milvus::GetDatasetLims(data_set); auto id = milvus::GetDatasetIDs(data_set); auto dist = milvus::GetDatasetDistance(data_set); @@ -43,32 +43,26 @@ RangeSearchSortResultBF(milvus::DatasetPtr data_set, memset(p_id, -1, sizeof(int64_t) * topk * nq); auto p_dist = new float[topk * nq]; std::fill_n(p_dist, topk * nq, std::numeric_limits::max()); + + auto cmp_func = (milvus::PositivelyRelated(metric_type)) ? greater : less; + // cnt means the subscript of p_id and p_dist - int cnt = 0; for (int i = 0; i < nq; i++) { - auto size = lims[i + 1] - lims[i]; - int capacity = topk > size ? 
size : topk; + auto capacity = std::min(lims[i + 1] - lims[i], topk); + // sort each layer std::vector> list; - if (milvus::IsMetricType(metric_type, knowhere::metric::IP)) { - for (int j = lims[i]; j < lims[i + 1]; j++) { - list.push_back(std::pair(dist[j], id[j])); - } - std::sort(list.begin(), list.end(), cmp1); + for (int j = lims[i]; j < lims[i + 1]; j++) { + list.emplace_back(dist[j], id[j]); + } + std::sort(list.begin(), list.end(), cmp_func); - } else { - for (int j = lims[i]; j < lims[i + 1]; j++) { - list.push_back(std::pair(dist[j], id[j])); - } - std::sort(list.begin(), list.end(), cmp2); + for (int k = 0; k < capacity; k++) { + p_dist[i * topk + k] = list[k].first; + p_id[i * topk + k] = list[k].second; } - for (int k = cnt; k < capacity + cnt; k++) { - p_dist[k] = list[k - cnt].first; - p_id[k] = list[k - cnt].second; - } - cnt += topk; } - return std::make_tuple(cnt, p_id, p_dist); + return std::make_tuple(p_id, p_dist); } milvus::DatasetPtr @@ -93,10 +87,8 @@ CheckRangeSearchSortResult(int64_t* p_id, auto id = milvus::GetDatasetIDs(dataset); auto dist = milvus::GetDatasetDistance(dataset); for (int i = 0; i < n; i++) { - AssertInfo(id[i] == p_id[i], - "id of range search result are not the same"); - AssertInfo(dist[i] == p_dist[i], - "distance of range search result are not the same"); + AssertInfo(id[i] == p_id[i], "id of range search result not same"); + AssertInfo(dist[i] == p_dist[i], "distance of range search result not same"); } } @@ -173,10 +165,9 @@ INSTANTIATE_TEST_CASE_P(RangeSearchSortParameters, knowhere::metric::HAMMING)); TEST_P(RangeSearchSortTest, CheckRangeSearchSort) { - auto res = milvus::SortRangeSearchResult(dataset, TOPK, N, metric_type); - auto [real_num, p_id, p_dist] = - RangeSearchSortResultBF(dataset, TOPK, N, metric_type); - CheckRangeSearchSortResult(p_id, p_dist, res, real_num); + auto res = milvus::ReGenRangeSearchResult(dataset, TOPK, N, metric_type); + auto [p_id, p_dist] = RangeSearchSortResultBF(dataset, TOPK, N, 
metric_type); + CheckRangeSearchSortResult(p_id, p_dist, res, N * TOPK); delete[] p_id; delete[] p_dist; }