mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-11-30 02:48:45 +08:00
Optimize range search result sort in segcore (#24837)
Signed-off-by: Yudong Cai <yudong.cai@zilliz.com>
This commit is contained in:
parent
028cbee519
commit
1b3c4b26f1
@ -12,37 +12,43 @@
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
|
||||
#include "common/Utils.h"
|
||||
#include "common/RangeSearchHelper.h"
|
||||
|
||||
namespace milvus {
|
||||
|
||||
namespace {
|
||||
using ResultPair = std::pair<float, int64_t>;
|
||||
}
|
||||
|
||||
/* Sort and return TOPK items as final range search result */
|
||||
DatasetPtr
|
||||
SortRangeSearchResult(DatasetPtr data_set,
|
||||
int64_t topk,
|
||||
int64_t nq,
|
||||
const std::string_view metric_type) {
|
||||
ReGenRangeSearchResult(DatasetPtr data_set,
|
||||
int64_t topk,
|
||||
int64_t nq,
|
||||
const std::string& metric_type) {
|
||||
/**
|
||||
* nq: number of queries;
|
||||
* lims: the size of lims is nq + 1; lims[i+1] - lims[i] is the number of RangeSearch results for queries[i]
|
||||
* for example, the nq is 5. In the selected range,
|
||||
* the size of RangeSearch result for each nq is [1, 2, 3, 4, 5],
|
||||
* the lims will be [0, 1, 3, 6, 10, 15];
|
||||
* for example, the nq is 5. In the selected range,
|
||||
* the size of RangeSearch result for each nq is [1, 2, 3, 4, 5],
|
||||
* the lims will be [0, 1, 3, 6, 10, 15];
|
||||
* ids: the size of ids is lim[nq],
|
||||
* { i(0,0), i(0,1), …, i(0,k0-1),
|
||||
* {
|
||||
* i(0,0), i(0,1), …, i(0,k0-1),
|
||||
* i(1,0), i(1,1), …, i(1,k1-1),
|
||||
* …,
|
||||
* i(n-1,0), i(n-1,1), …, i(n-1,kn-1)},
|
||||
* ... ...
|
||||
* i(n-1,0), i(n-1,1), …, i(n-1,kn-1)
|
||||
* }
|
||||
* i(0,0), i(0,1), …, i(0,k0-1) means the ids of RangeSearch result queries[0], k0 equals lim[1] - lim[0];
|
||||
* dist: the size of dist is lim[nq],
|
||||
* { d(0,0), d(0,1), …, d(0,k0-1),
|
||||
* {
|
||||
* d(0,0), d(0,1), …, d(0,k0-1),
|
||||
* d(1,0), d(1,1), …, d(1,k1-1),
|
||||
* …,
|
||||
* d(n-1,0), d(n-1,1), …, d(n-1,kn-1)},
|
||||
* ... ...
|
||||
* d(n-1,0), d(n-1,1), …, d(n-1,kn-1)
|
||||
* }
|
||||
* d(0,0), d(0,1), …, d(0,k0-1) means the distances of RangeSearch result queries[0], k0 equals lim[1] - lim[0];
|
||||
*/
|
||||
auto lims = GetDatasetLims(data_set);
|
||||
@ -51,50 +57,53 @@ SortRangeSearchResult(DatasetPtr data_set,
|
||||
|
||||
// use p_id and p_dist to GenResultDataset after sorted
|
||||
auto p_id = new int64_t[topk * nq];
|
||||
memset(p_id, -1, sizeof(int64_t) * topk * nq);
|
||||
auto p_dist = new float[topk * nq];
|
||||
std::fill_n(p_id, topk * nq, -1);
|
||||
std::fill_n(p_dist, topk * nq, std::numeric_limits<float>::max());
|
||||
|
||||
/*
|
||||
* get result for one nq
|
||||
* IP: 1.0 range_filter radius
|
||||
* |------------+---------------| min_heap descending_order
|
||||
* L2: 0.0 range_filter radius
|
||||
* |------------+---------------| max_heap ascending_order
|
||||
*
|
||||
*/
|
||||
* get result for one nq
|
||||
* IP: 1.0 range_filter radius
|
||||
* |------------+---------------| min_heap descending_order
|
||||
* |___ ___|
|
||||
* V
|
||||
* topk
|
||||
*
|
||||
* L2: 0.0 range_filter radius
|
||||
* |------------+---------------| max_heap ascending_order
|
||||
* |___ ___|
|
||||
* V
|
||||
* topk
|
||||
*/
|
||||
std::function<bool(const ResultPair&, const ResultPair&)> cmp =
|
||||
std::less<>();
|
||||
if (IsMetricType(metric_type, knowhere::metric::IP)) {
|
||||
if (PositivelyRelated(metric_type)) {
|
||||
cmp = std::greater<>();
|
||||
}
|
||||
std::priority_queue<ResultPair, std::vector<ResultPair>, decltype(cmp)>
|
||||
sub_result(cmp);
|
||||
|
||||
// The subscript of p_id and p_dist
|
||||
int cnt = 0;
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < nq; i++) {
|
||||
// if RangeSearch answer size of one nq is less than topk, set the capacity to size
|
||||
int size = lims[i + 1] - lims[i];
|
||||
int capacity = topk > size ? size : topk;
|
||||
std::priority_queue<ResultPair, std::vector<ResultPair>, decltype(cmp)>
|
||||
pq(cmp);
|
||||
auto capacity = std::min<int64_t>(lims[i + 1] - lims[i], topk);
|
||||
|
||||
for (int j = lims[i]; j < lims[i + 1]; j++) {
|
||||
auto current = ResultPair(dist[j], id[j]);
|
||||
if (sub_result.size() == capacity) {
|
||||
if (cmp(sub_result.top(), current)) {
|
||||
current = sub_result.top();
|
||||
}
|
||||
sub_result.pop();
|
||||
auto curr = ResultPair(dist[j], id[j]);
|
||||
if (pq.size() < capacity) {
|
||||
pq.push(curr);
|
||||
} else if (cmp(curr, pq.top())) {
|
||||
pq.pop();
|
||||
pq.push(curr);
|
||||
}
|
||||
sub_result.push(current);
|
||||
}
|
||||
|
||||
for (int i = capacity + cnt - 1; i > cnt - 1; i--) {
|
||||
p_dist[i] = sub_result.top().first;
|
||||
p_id[i] = sub_result.top().second;
|
||||
sub_result.pop();
|
||||
for (int j = capacity - 1; j >= 0; j--) {
|
||||
auto& node = pq.top();
|
||||
p_dist[i * topk + j] = node.first;
|
||||
p_id[i * topk + j] = node.second;
|
||||
pq.pop();
|
||||
}
|
||||
cnt += topk;
|
||||
}
|
||||
return GenResultDataset(nq, topk, p_id, p_dist);
|
||||
}
|
||||
@ -102,22 +111,27 @@ SortRangeSearchResult(DatasetPtr data_set,
|
||||
void
|
||||
CheckRangeSearchParam(float radius,
|
||||
float range_filter,
|
||||
const std::string_view metric_type) {
|
||||
const std::string& metric_type) {
|
||||
/*
|
||||
* IP: 1.0 range_filter radius
|
||||
* |------------+---------------| min_heap descending_order
|
||||
* L2: 1.0 radius range_filter
|
||||
* |------------+---------------| max_heap ascending_order
|
||||
* |------------+---------------| range_filter > radius
|
||||
* L2: 0.0 range_filter radius
|
||||
* |------------+---------------| range_filter < radius
|
||||
*
|
||||
*/
|
||||
if (metric_type == knowhere::metric::IP) {
|
||||
if (range_filter < radius) {
|
||||
PanicInfo("range_filter must more than radius when IP");
|
||||
if (PositivelyRelated(metric_type)) {
|
||||
if (range_filter <= radius) {
|
||||
PanicInfo(
|
||||
"range_filter must be greater than or equal to radius for IP "
|
||||
"and COSINE");
|
||||
}
|
||||
} else {
|
||||
if (range_filter > radius) {
|
||||
PanicInfo("range_filter must less than radius except IP");
|
||||
if (range_filter >= radius) {
|
||||
PanicInfo(
|
||||
"range_filter must be less than or equal to radius for "
|
||||
"L2/HAMMING/JACCARD/TANIMOTO");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace milvus
|
||||
|
@ -17,13 +17,14 @@
|
||||
namespace milvus {
|
||||
|
||||
DatasetPtr
|
||||
SortRangeSearchResult(DatasetPtr data_set,
|
||||
int64_t topk,
|
||||
int64_t nq,
|
||||
const std::string_view metric_type);
|
||||
ReGenRangeSearchResult(DatasetPtr data_set,
|
||||
int64_t topk,
|
||||
int64_t nq,
|
||||
const std::string& metric_type);
|
||||
|
||||
void
|
||||
CheckRangeSearchParam(float radius,
|
||||
float range_filter,
|
||||
const std::string_view metric_type);
|
||||
const std::string& metric_type);
|
||||
|
||||
} // namespace milvus
|
||||
|
@ -192,7 +192,7 @@ VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset,
|
||||
"failed to range search, " +
|
||||
MatchKnowhereError(res.error()));
|
||||
}
|
||||
return SortRangeSearchResult(
|
||||
return ReGenRangeSearchResult(
|
||||
res.value(), topk, num_queries, GetMetricType());
|
||||
} else {
|
||||
auto res = index_.Search(*dataset, search_config, bitset);
|
||||
|
@ -122,7 +122,7 @@ VectorMemIndex::Query(const DatasetPtr dataset,
|
||||
"failed to range search, " +
|
||||
MatchKnowhereError(res.error()));
|
||||
}
|
||||
return SortRangeSearchResult(
|
||||
return ReGenRangeSearchResult(
|
||||
res.value(), topk, num_queries, GetMetricType());
|
||||
} else {
|
||||
auto res = index_.Search(*dataset, search_conf, bitset);
|
||||
|
@ -77,7 +77,7 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
|
||||
"failed to range search, " +
|
||||
MatchKnowhereError(res.error()));
|
||||
}
|
||||
auto result = SortRangeSearchResult(
|
||||
auto result = ReGenRangeSearchResult(
|
||||
res.value(), topk, nq, dataset.metric_type);
|
||||
std::copy_n(
|
||||
GetDatasetIDs(result), nq * topk, sub_result.get_seg_offsets());
|
||||
|
@ -22,12 +22,12 @@
|
||||
#include "test_utils/indexbuilder_test_utils.h"
|
||||
|
||||
bool
|
||||
cmp1(std::pair<float, int64_t> a, std::pair<float, int64_t> b) {
|
||||
greater(std::pair<float, int64_t> a, std::pair<float, int64_t> b) {
|
||||
return a.first > b.first;
|
||||
}
|
||||
|
||||
bool
|
||||
cmp2(std::pair<float, int64_t> a, std::pair<float, int64_t> b) {
|
||||
less(std::pair<float, int64_t> a, std::pair<float, int64_t> b) {
|
||||
return a.first < b.first;
|
||||
}
|
||||
|
||||
@ -35,7 +35,7 @@ auto
|
||||
RangeSearchSortResultBF(milvus::DatasetPtr data_set,
|
||||
int64_t topk,
|
||||
size_t nq,
|
||||
std::string metric_type) {
|
||||
std::string& metric_type) {
|
||||
auto lims = milvus::GetDatasetLims(data_set);
|
||||
auto id = milvus::GetDatasetIDs(data_set);
|
||||
auto dist = milvus::GetDatasetDistance(data_set);
|
||||
@ -43,32 +43,26 @@ RangeSearchSortResultBF(milvus::DatasetPtr data_set,
|
||||
memset(p_id, -1, sizeof(int64_t) * topk * nq);
|
||||
auto p_dist = new float[topk * nq];
|
||||
std::fill_n(p_dist, topk * nq, std::numeric_limits<float>::max());
|
||||
|
||||
auto cmp_func = (milvus::PositivelyRelated(metric_type)) ? greater : less;
|
||||
|
||||
// cnt means the subscript of p_id and p_dist
|
||||
int cnt = 0;
|
||||
for (int i = 0; i < nq; i++) {
|
||||
auto size = lims[i + 1] - lims[i];
|
||||
int capacity = topk > size ? size : topk;
|
||||
auto capacity = std::min<int64_t>(lims[i + 1] - lims[i], topk);
|
||||
|
||||
// sort each layer
|
||||
std::vector<std::pair<float, int64_t>> list;
|
||||
if (milvus::IsMetricType(metric_type, knowhere::metric::IP)) {
|
||||
for (int j = lims[i]; j < lims[i + 1]; j++) {
|
||||
list.push_back(std::pair<float, int64_t>(dist[j], id[j]));
|
||||
}
|
||||
std::sort(list.begin(), list.end(), cmp1);
|
||||
for (int j = lims[i]; j < lims[i + 1]; j++) {
|
||||
list.emplace_back(dist[j], id[j]);
|
||||
}
|
||||
std::sort(list.begin(), list.end(), cmp_func);
|
||||
|
||||
} else {
|
||||
for (int j = lims[i]; j < lims[i + 1]; j++) {
|
||||
list.push_back(std::pair<float, int64_t>(dist[j], id[j]));
|
||||
}
|
||||
std::sort(list.begin(), list.end(), cmp2);
|
||||
for (int k = 0; k < capacity; k++) {
|
||||
p_dist[i * topk + k] = list[k].first;
|
||||
p_id[i * topk + k] = list[k].second;
|
||||
}
|
||||
for (int k = cnt; k < capacity + cnt; k++) {
|
||||
p_dist[k] = list[k - cnt].first;
|
||||
p_id[k] = list[k - cnt].second;
|
||||
}
|
||||
cnt += topk;
|
||||
}
|
||||
return std::make_tuple(cnt, p_id, p_dist);
|
||||
return std::make_tuple(p_id, p_dist);
|
||||
}
|
||||
|
||||
milvus::DatasetPtr
|
||||
@ -93,10 +87,8 @@ CheckRangeSearchSortResult(int64_t* p_id,
|
||||
auto id = milvus::GetDatasetIDs(dataset);
|
||||
auto dist = milvus::GetDatasetDistance(dataset);
|
||||
for (int i = 0; i < n; i++) {
|
||||
AssertInfo(id[i] == p_id[i],
|
||||
"id of range search result are not the same");
|
||||
AssertInfo(dist[i] == p_dist[i],
|
||||
"distance of range search result are not the same");
|
||||
AssertInfo(id[i] == p_id[i], "id of range search result not same");
|
||||
AssertInfo(dist[i] == p_dist[i], "distance of range search result not same");
|
||||
}
|
||||
}
|
||||
|
||||
@ -173,10 +165,9 @@ INSTANTIATE_TEST_CASE_P(RangeSearchSortParameters,
|
||||
knowhere::metric::HAMMING));
|
||||
|
||||
TEST_P(RangeSearchSortTest, CheckRangeSearchSort) {
|
||||
auto res = milvus::SortRangeSearchResult(dataset, TOPK, N, metric_type);
|
||||
auto [real_num, p_id, p_dist] =
|
||||
RangeSearchSortResultBF(dataset, TOPK, N, metric_type);
|
||||
CheckRangeSearchSortResult(p_id, p_dist, res, real_num);
|
||||
auto res = milvus::ReGenRangeSearchResult(dataset, TOPK, N, metric_type);
|
||||
auto [p_id, p_dist] = RangeSearchSortResultBF(dataset, TOPK, N, metric_type);
|
||||
CheckRangeSearchSortResult(p_id, p_dist, res, N * TOPK);
|
||||
delete[] p_id;
|
||||
delete[] p_dist;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user