Support range search (#21652)

Signed-off-by: smellthemoon <xinguo.li@zilliz.com>
Signed-off-by: lixinguo <xinguo.li@zilliz.com>
Co-authored-by: jaime <yun.zhang@zilliz.com>
This commit is contained in:
smellthemoon 2023-02-21 09:48:32 +08:00 committed by GitHub
parent 0a9a9058b9
commit 9e0ec15436
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 746 additions and 103 deletions

View File

@ -18,6 +18,7 @@ set(COMMON_SRC
binary_set_c.cpp
init_c.cpp
Common.cpp
RangeSearchHelper.cpp
)
add_library(milvus_common SHARED ${COMMON_SRC})

View File

@ -43,3 +43,6 @@ const int64_t DEFAULT_THREAD_CORE_COEFFICIENT = 50;
const int64_t DEFAULT_INDEX_FILE_SLICE_SIZE = 4; // megabytes
const int DEFAULT_CPU_NUM = 1;
// Keys under which the optional range-search parameters are looked up in the
// search params / search config.
constexpr const char* RADIUS = "radius";
constexpr const char* RANGE_FILTER = "range_filter";

View File

@ -0,0 +1,112 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <algorithm>
#include <functional>
#include <iostream>
#include <limits>
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include "common/Utils.h"
namespace milvus {
namespace {
using ResultPair = std::pair<float, int64_t>;
}
/**
 * Reduce the raw output of a range search to at most topk hits per query,
 * ordered best-first.
 *
 * Input layout (read from data_set):
 *   lims: nq + 1 offsets; lims[i + 1] - lims[i] is the number of hits of query i.
 *         For example, with nq = 5 and hit counts [1, 2, 3, 4, 5],
 *         lims is [0, 1, 3, 6, 10, 15].
 *   ids:  lims[nq] entries; ids[lims[i]] .. ids[lims[i + 1] - 1] are the hits of query i.
 *   dist: lims[nq] entries, parallel to ids.
 *
 * Ordering:
 *   IP:            larger distance is better  -> kept hits are in descending order.
 *   other metrics: smaller distance is better -> kept hits are in ascending order.
 *
 * Output:
 *   A dataset built by GenResultDataset(nq, topk, ids, distances) over two freshly
 *   allocated buffers of topk * nq entries. The kept hits of consecutive queries are
 *   stored back to back; slots that receive no hit hold id -1 and the worst
 *   representable distance for the metric.
 *
 * NOTE(review): the result advertises nq rows of topk entries, but a query with fewer
 * than topk hits does not leave a gap, so later queries are not aligned to multiples
 * of topk. This layout is what the original implementation produced and is kept
 * here; confirm that consumers expect it.
 */
DatasetPtr
SortRangeSearchResult(DatasetPtr data_set, int64_t topk, int64_t nq, std::string metric_type) {
    const auto lims = GetDatasetLims(data_set);
    const auto id = GetDatasetIDs(data_set);
    const auto dist = GetDatasetDistance(data_set);

    // The metric is the same for every query, so decide the ordering once.
    const bool larger_is_better = IsMetricType(metric_type, knowhere::metric::IP);
    // is_better(a, b) is true when a ranks strictly ahead of b. Used as the heap
    // ordering it puts the worst kept hit at top().
    auto is_better = [larger_is_better](const ResultPair& a, const ResultPair& b) {
        return larger_is_better ? a > b : a < b;
    };

    const size_t per_query = topk > 0 ? static_cast<size_t>(topk) : 0;
    const size_t query_count = nq > 0 ? static_cast<size_t>(nq) : 0;
    const size_t total = per_query * query_count;

    // Owned by smart pointers until they are handed to the result dataset, so an
    // exception while sorting cannot leak them.
    auto p_id = std::make_unique<int64_t[]>(total);
    auto p_dist = std::make_unique<float[]>(total);

    // Callers may read all topk * nq slots; give the slots that receive no hit a
    // recognisable placeholder instead of indeterminate memory.
    const float worst_distance =
        larger_is_better ? std::numeric_limits<float>::lowest() : std::numeric_limits<float>::max();
    std::fill_n(p_id.get(), total, int64_t{-1});
    std::fill_n(p_dist.get(), total, worst_distance);

    // cnt is the next free slot in p_id / p_dist.
    size_t cnt = 0;
    for (size_t q = 0; q < query_count; ++q) {
        const size_t begin = lims[q];
        const size_t end = lims[q + 1];
        const size_t hits = end > begin ? end - begin : 0;
        // A query with fewer than topk hits keeps all of them.
        const size_t capacity = std::min(hits, per_query);
        if (capacity == 0) {
            continue;
        }

        std::priority_queue<ResultPair, std::vector<ResultPair>, decltype(is_better)> kept(is_better);
        for (size_t j = begin; j < end; ++j) {
            const ResultPair current(dist[j], id[j]);
            if (kept.size() == capacity) {
                if (is_better(kept.top(), current)) {
                    // Even the worst kept hit beats this one: drop it.
                    continue;
                }
                kept.pop();
            }
            kept.push(current);
        }

        // The heap yields hits worst-first, so fill this query's slots back to front
        // to end up best-first.
        for (size_t slot = cnt + capacity; slot-- > cnt;) {
            p_dist[slot] = kept.top().first;
            p_id[slot] = kept.top().second;
            kept.pop();
        }
        cnt += capacity;
    }
    return GenResultDataset(nq, topk, p_id.release(), p_dist.release());
}
void
CheckRangeSearchParam(float radius, float range_filter, std::string metric_type) {
/*
* IP: 1.0 range_filter radius
* |------------+---------------| min_heap descending_order
* L2: 1.0 radius range_filter
* |------------+---------------| max_heap ascending_order
*
*/
if (metric_type == knowhere::metric::IP) {
if (range_filter < radius) {
PanicInfo("range_filter must more than radius when IP");
}
} else {
if (range_filter > radius) {
PanicInfo("range_filter must less than radius except IP");
}
}
}
} // namespace milvus

View File

@ -0,0 +1,24 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <string>
#include <common/Types.h>
namespace milvus {
// Keeps, for each of the nq queries of a range-search result, at most topk hits
// ordered best-first (descending distance for IP, ascending otherwise) and returns
// them as a new dataset.
DatasetPtr
SortRangeSearchResult(DatasetPtr data_set, int64_t topk, int64_t nq, std::string metric_type);
// Checks the relative order of radius and range_filter for the given metric:
// panics when range_filter < radius for IP, or when range_filter > radius for any
// other metric.
void
CheckRangeSearchParam(float radius, float range_filter, std::string metric_type);
} // namespace milvus

View File

@ -51,6 +51,11 @@ GetDatasetDim(const DatasetPtr& dataset) {
return dataset->GetDim();
}
// Returns the pointer reported by the dataset's GetLims().
inline const size_t*
GetDatasetLims(const DatasetPtr& dataset) {
    const size_t* lims = dataset->GetLims();
    return lims;
}
inline bool
PrefixMatch(const std::string& str, const std::string& prefix) {
auto ret = strncmp(str.c_str(), prefix.c_str(), prefix.length());
@ -61,6 +66,17 @@ PrefixMatch(const std::string& str, const std::string& prefix) {
return true;
}
// Builds a result dataset with nq rows and "dim" topk around the given id and
// distance buffers, and marks the dataset as owner.
inline DatasetPtr
GenResultDataset(const int64_t nq, const int64_t topk, const int64_t* ids, const float* distance) {
    auto result = std::make_shared<Dataset>();

    // Shape: one row per query, topk entries per row.
    result->SetRows(nq);
    result->SetDim(topk);

    // Payload.
    result->SetIds(ids);
    result->SetDistance(distance);

    result->SetIsOwner(true);
    return result;
}
inline bool
PostfixMatch(const std::string& str, const std::string& postfix) {
if (postfix.length() > str.length()) {

View File

@ -21,7 +21,9 @@
#include "storage/LocalChunkManager.h"
#include "config/ConfigKnowhere.h"
#include "storage/Util.h"
#include "common/Consts.h"
#include "common/Utils.h"
#include "common/RangeSearchHelper.h"
namespace milvus::index {
@ -145,14 +147,29 @@ VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset, const SearchInfo& search_
// set json reset field, will be removed later
search_config[DISK_ANN_PQ_CODE_BUDGET] = 0.0;
auto final = index_.Search(*dataset, search_config, bitset);
if (!final.has_value()) {
PanicCodeInfo(ErrorCodeEnum::UnexpectedError, "failed to search");
}
auto final = [&] {
auto radius = GetValueFromConfig<float>(search_info.search_params_, RADIUS);
if (radius.has_value()) {
search_config[RADIUS] = radius.value();
auto range_filter = GetValueFromConfig<float>(search_info.search_params_, RANGE_FILTER);
if (range_filter.has_value()) {
search_config[RANGE_FILTER] = range_filter.value();
CheckRangeSearchParam(search_config[RADIUS], search_config[RANGE_FILTER], GetMetricType());
}
auto res = index_.RangeSearch(*dataset, search_config, bitset);
return SortRangeSearchResult(res.value(), topk, num_queries, GetMetricType());
} else {
auto res = index_.Search(*dataset, search_config, bitset);
if (!res.has_value()) {
PanicCodeInfo(ErrorCodeEnum::UnexpectedError, "failed to search");
}
return res.value();
}
}();
auto ids = final.value()->GetIds();
float* distances = const_cast<float*>(final.value()->GetDistance());
final.value()->SetIsOwner(true);
auto ids = final->GetIds();
float* distances = const_cast<float*>(final->GetDistance());
final->SetIsOwner(true);
auto round_decimal = search_info.round_decimal_;
auto total_num = num_queries * topk;

View File

@ -24,6 +24,8 @@
#include "knowhere/comp/Timer.h"
#include "common/BitsetView.h"
#include "common/Slice.h"
#include "common/Consts.h"
#include "common/RangeSearchHelper.h"
namespace milvus::index {
@ -82,13 +84,24 @@ VectorMemIndex::Query(const DatasetPtr dataset, const SearchInfo& search_info, c
search_conf[knowhere::meta::TOPK] = topk;
search_conf[knowhere::meta::METRIC_TYPE] = GetMetricType();
auto index_type = GetIndexType();
return index_.Search(*dataset, search_conf, bitset);
if (CheckKeyInConfig(search_conf, RADIUS)) {
if (CheckKeyInConfig(search_conf, RANGE_FILTER)) {
CheckRangeSearchParam(search_conf[RADIUS], search_conf[RANGE_FILTER], GetMetricType());
}
auto res = index_.RangeSearch(*dataset, search_conf, bitset);
return SortRangeSearchResult(res.value(), topk, num_queries, GetMetricType());
} else {
auto res = index_.Search(*dataset, search_conf, bitset);
if (!res.has_value()) {
PanicCodeInfo(ErrorCodeEnum::UnexpectedError, "failed to search");
}
return res.value();
}
}();
if (!final.has_value())
PanicCodeInfo(ErrorCodeEnum::UnexpectedError, "failed to search");
auto ids = final.value()->GetIds();
float* distances = const_cast<float*>(final.value()->GetDistance());
final.value()->SetIsOwner(true);
auto ids = final->GetIds();
float* distances = const_cast<float*>(final->GetDistance());
final->SetIsOwner(true);
auto round_decimal = search_info.round_decimal_;
auto total_num = num_queries * topk;

View File

@ -12,6 +12,8 @@
#include <string>
#include <vector>
#include "common/Consts.h"
#include "common/RangeSearchHelper.h"
#include "SearchBruteForce.h"
#include "SubSearchResult.h"
#include "knowhere/comp/brute_force.h"
@ -34,6 +36,7 @@ SubSearchResult
BruteForceSearch(const dataset::SearchDataset& dataset,
const void* chunk_data_raw,
int64_t chunk_rows,
const knowhere::Json& conf,
const BitsetView& bitset) {
SubSearchResult sub_result(dataset.num_queries, dataset.topk, dataset.metric_type, dataset.round_decimal);
try {
@ -48,15 +51,29 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
{knowhere::meta::DIM, dim},
{knowhere::meta::TOPK, topk},
};
sub_result.mutable_seg_offsets().resize(nq * topk);
sub_result.mutable_distances().resize(nq * topk);
auto stat =
knowhere::BruteForce::SearchWithBuf(base_dataset, query_dataset, sub_result.mutable_seg_offsets().data(),
sub_result.mutable_distances().data(), config, bitset);
if (conf.contains(RADIUS)) {
config[RADIUS] = conf[RADIUS];
if (conf.contains(RANGE_FILTER)) {
config[RANGE_FILTER] = conf[RANGE_FILTER];
CheckRangeSearchParam(config[RADIUS], config[RANGE_FILTER], dataset.metric_type);
}
auto result = SortRangeSearchResult(
knowhere::BruteForce::RangeSearch(base_dataset, query_dataset, config, bitset).value(), topk, nq,
dataset.metric_type);
std::copy_n(GetDatasetIDs(result), nq * topk, sub_result.get_seg_offsets());
std::copy_n(GetDatasetDistance(result), nq * topk, sub_result.get_distances());
} else {
auto stat = knowhere::BruteForce::SearchWithBuf(base_dataset, query_dataset,
sub_result.mutable_seg_offsets().data(),
sub_result.mutable_distances().data(), config, bitset);
if (stat != knowhere::Status::success) {
throw std::invalid_argument("invalid metric type");
if (stat != knowhere::Status::success) {
throw std::invalid_argument("invalid metric type");
}
}
} catch (std::exception& e) {
PanicInfo(e.what());

View File

@ -26,6 +26,7 @@ SubSearchResult
BruteForceSearch(const dataset::SearchDataset& dataset,
const void* chunk_data_raw,
int64_t chunk_rows,
const knowhere::Json& conf,
const BitsetView& bitset);
} // namespace milvus::query

View File

@ -125,7 +125,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
auto size_per_chunk = element_end - element_begin;
auto sub_view = bitset.subview(element_begin, size_per_chunk);
auto sub_qr = BruteForceSearch(search_dataset, chunk_data, size_per_chunk, sub_view);
auto sub_qr = BruteForceSearch(search_dataset, chunk_data, size_per_chunk, info.search_params_, sub_view);
// convert chunk uid to segment uid
for (auto& x : sub_qr.mutable_seg_offsets()) {

View File

@ -86,9 +86,8 @@ SearchOnSealed(const Schema& schema,
auto vec_data = record.get_field_data_base(field_id);
AssertInfo(vec_data->num_chunk() == 1, "num chunk not equal to 1 for sealed segment");
auto chunk_data = vec_data->get_chunk_data(0);
CheckBruteForceSearchParam(field, search_info);
auto sub_qr = BruteForceSearch(dataset, chunk_data, row_count, bitset);
auto sub_qr = BruteForceSearch(dataset, chunk_data, row_count, search_info.search_params_, bitset);
result.distances_ = std::move(sub_qr.mutable_distances());
result.seg_offsets_ = std::move(sub_qr.mutable_seg_offsets());

View File

@ -46,6 +46,7 @@ set(MILVUS_TEST_FILES
test_timestamp_index.cpp
test_utils.cpp
test_data_codec.cpp
test_range_search_sort.cpp
)
if ( BUILD_DISK_ANN STREQUAL "ON" )

View File

@ -120,7 +120,7 @@ class TestFloatSearchBruteForce : public ::testing::Test {
// ASSERT_ANY_THROW(BruteForceSearch(dataset, base.data(), nb, bitset_view));
return;
}
auto result = BruteForceSearch(dataset, base.data(), nb, bitset_view);
auto result = BruteForceSearch(dataset, base.data(), nb, knowhere::Json(), bitset_view);
for (int i = 0; i < nq; i++) {
auto ref = Ref(base.data(), query.data() + i * dim, nb, dim, topk, metric_type);
auto ans = result.get_seg_offsets() + i * topk;

View File

@ -3522,3 +3522,245 @@ TEST(CApiTest, RetriveScalarFieldFromSealedSegmentWithIndex) {
DeleteSegment(segment);
}
// Smoke test: range search with the IP metric and only "radius" set, run against a
// Growing segment. Only the returned status is checked; the result contents are
// not inspected.
TEST(CApiTest, RANGE_SEARCH_WITH_RADIUS_WHEN_IP) {
    auto c_collection = NewCollection(get_default_schema_config());
    auto segment = NewSegment(c_collection, Growing, -1);
    auto col = (milvus::segcore::Collection*)c_collection;

    // Insert N generated rows.
    int N = 10000;
    auto dataset = DataGen(col->get_schema(), N);
    int64_t ts_offset = 1000;

    int64_t offset;
    PreInsert(segment, N, &offset);

    auto insert_data = serialize(dataset.raw_);
    auto ins_res = Insert(segment, offset, N, dataset.row_ids_.data(), dataset.timestamps_.data(), insert_data.data(),
                          insert_data.size());
    ASSERT_EQ(ins_res.error_code, Success);

    const char* dsl_string = R"(
{
"bool": {
"vector": {
"fakevec": {
"metric_type": "IP",
"params": {
"nprobe": 10,
"radius": 10
},
"query": "$0",
"topk": 10,
"round_decimal": 3
}
}
}
})";

    int num_queries = 10;
    auto blob = generate_query_data(num_queries);

    void* plan = nullptr;
    auto status = CreateSearchPlan(c_collection, dsl_string, &plan);
    ASSERT_EQ(status.error_code, Success);

    void* placeholderGroup = nullptr;
    status = ParsePlaceholderGroup(plan, blob.data(), blob.length(), &placeholderGroup);
    ASSERT_EQ(status.error_code, Success);

    CSearchResult search_result;
    auto res = Search(segment, plan, placeholderGroup, ts_offset, &search_result);
    ASSERT_EQ(res.error_code, Success);

    DeleteSearchPlan(plan);
    DeletePlaceholderGroup(placeholderGroup);
    DeleteSearchResult(search_result);
    DeleteCollection(c_collection);
    DeleteSegment(segment);
}
// Smoke test: range search with the IP metric and both "radius" and "range_filter"
// set (radius below range_filter), run against a Growing segment. Only the
// returned status is checked; the result contents are not inspected.
TEST(CApiTest, RANGE_SEARCH_WITH_RADIUS_AND_RANGE_FILTER_WHEN_IP) {
    auto c_collection = NewCollection(get_default_schema_config());
    auto segment = NewSegment(c_collection, Growing, -1);
    auto col = (milvus::segcore::Collection*)c_collection;

    // Insert N generated rows.
    int N = 10000;
    auto dataset = DataGen(col->get_schema(), N);
    int64_t ts_offset = 1000;

    int64_t offset;
    PreInsert(segment, N, &offset);

    auto insert_data = serialize(dataset.raw_);
    auto ins_res = Insert(segment, offset, N, dataset.row_ids_.data(), dataset.timestamps_.data(), insert_data.data(),
                          insert_data.size());
    ASSERT_EQ(ins_res.error_code, Success);

    const char* dsl_string = R"(
{
"bool": {
"vector": {
"fakevec": {
"metric_type": "IP",
"params": {
"nprobe": 10,
"radius": 10,
"range_filter": 20
},
"query": "$0",
"topk": 10,
"round_decimal": 3
}
}
}
})";

    int num_queries = 10;
    auto blob = generate_query_data(num_queries);

    void* plan = nullptr;
    auto status = CreateSearchPlan(c_collection, dsl_string, &plan);
    ASSERT_EQ(status.error_code, Success);

    void* placeholderGroup = nullptr;
    status = ParsePlaceholderGroup(plan, blob.data(), blob.length(), &placeholderGroup);
    ASSERT_EQ(status.error_code, Success);

    CSearchResult search_result;
    auto res = Search(segment, plan, placeholderGroup, ts_offset, &search_result);
    ASSERT_EQ(res.error_code, Success);

    DeleteSearchPlan(plan);
    DeletePlaceholderGroup(placeholderGroup);
    DeleteSearchResult(search_result);
    DeleteCollection(c_collection);
    DeleteSegment(segment);
}
// Smoke test: range search with the L2 metric and only "radius" set, run against a
// Growing segment. Only the returned status is checked; the result contents are
// not inspected.
TEST(CApiTest, RANGE_SEARCH_WITH_RADIUS_WHEN_L2) {
    auto c_collection = NewCollection(get_default_schema_config());
    auto segment = NewSegment(c_collection, Growing, -1);
    auto col = (milvus::segcore::Collection*)c_collection;

    // Insert N generated rows.
    int N = 10000;
    auto dataset = DataGen(col->get_schema(), N);
    int64_t ts_offset = 1000;

    int64_t offset;
    PreInsert(segment, N, &offset);

    auto insert_data = serialize(dataset.raw_);
    auto ins_res = Insert(segment, offset, N, dataset.row_ids_.data(), dataset.timestamps_.data(), insert_data.data(),
                          insert_data.size());
    ASSERT_EQ(ins_res.error_code, Success);

    const char* dsl_string = R"(
{
"bool": {
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10,
"radius": 10
},
"query": "$0",
"topk": 10,
"round_decimal": 3
}
}
}
})";

    int num_queries = 10;
    auto blob = generate_query_data(num_queries);

    void* plan = nullptr;
    auto status = CreateSearchPlan(c_collection, dsl_string, &plan);
    ASSERT_EQ(status.error_code, Success);

    void* placeholderGroup = nullptr;
    status = ParsePlaceholderGroup(plan, blob.data(), blob.length(), &placeholderGroup);
    ASSERT_EQ(status.error_code, Success);

    CSearchResult search_result;
    auto res = Search(segment, plan, placeholderGroup, ts_offset, &search_result);
    ASSERT_EQ(res.error_code, Success);

    DeleteSearchPlan(plan);
    DeletePlaceholderGroup(placeholderGroup);
    DeleteSearchResult(search_result);
    DeleteCollection(c_collection);
    DeleteSegment(segment);
}
// Smoke test: range search with the L2 metric and both "radius" and "range_filter"
// set (range_filter below radius), run against a Growing segment. Only the
// returned status is checked; the result contents are not inspected.
TEST(CApiTest, RANGE_SEARCH_WITH_RADIUS_AND_RANGE_FILTER_WHEN_L2) {
    auto c_collection = NewCollection(get_default_schema_config());
    auto segment = NewSegment(c_collection, Growing, -1);
    auto col = (milvus::segcore::Collection*)c_collection;

    // Insert N generated rows.
    int N = 10000;
    auto dataset = DataGen(col->get_schema(), N);
    int64_t ts_offset = 1000;

    int64_t offset;
    PreInsert(segment, N, &offset);

    auto insert_data = serialize(dataset.raw_);
    auto ins_res = Insert(segment, offset, N, dataset.row_ids_.data(), dataset.timestamps_.data(), insert_data.data(),
                          insert_data.size());
    ASSERT_EQ(ins_res.error_code, Success);

    const char* dsl_string = R"(
{
"bool": {
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10,
"radius": 20,
"range_filter": 10
},
"query": "$0",
"topk": 10,
"round_decimal": 3
}
}
}
})";

    int num_queries = 10;
    auto blob = generate_query_data(num_queries);

    void* plan = nullptr;
    auto status = CreateSearchPlan(c_collection, dsl_string, &plan);
    ASSERT_EQ(status.error_code, Success);

    void* placeholderGroup = nullptr;
    status = ParsePlaceholderGroup(plan, blob.data(), blob.length(), &placeholderGroup);
    ASSERT_EQ(status.error_code, Success);

    CSearchResult search_result;
    auto res = Search(segment, plan, placeholderGroup, ts_offset, &search_result);
    ASSERT_EQ(res.error_code, Success);

    DeleteSearchPlan(plan);
    DeletePlaceholderGroup(placeholderGroup);
    DeleteSearchResult(search_result);
    DeleteCollection(c_collection);
    DeleteSegment(segment);
}

View File

@ -145,6 +145,7 @@ TEST(Indexing, BinaryBruteForce) {
int64_t topk = 5;
int64_t round_decimal = 3;
int64_t dim = 8192;
Config search_params_ = {};
auto metric_type = knowhere::metric::JACCARD;
auto result_count = topk * num_queries;
auto schema = std::make_shared<Schema>();
@ -162,7 +163,7 @@ TEST(Indexing, BinaryBruteForce) {
query_data //
};
auto sub_result = query::BruteForceSearch(search_dataset, bin_vec.data(), N, nullptr);
auto sub_result = query::BruteForceSearch(search_dataset, bin_vec.data(), N, knowhere::Json(), nullptr);
SearchResult sr;
sr.total_nq_ = num_queries;
@ -293,6 +294,7 @@ class IndexTest : public ::testing::TestWithParam<Param> {
build_conf = generate_build_conf(index_type, metric_type);
load_conf = generate_load_conf(index_type, metric_type, NB);
search_conf = generate_search_conf(index_type, metric_type);
range_search_conf = generate_range_search_conf(index_type, metric_type);
std::map<knowhere::MetricType, bool> is_binary_map = {
{knowhere::IndexEnum::INDEX_FAISS_IDMAP, false},
@ -335,6 +337,7 @@ class IndexTest : public ::testing::TestWithParam<Param> {
milvus::Config build_conf;
milvus::Config load_conf;
milvus::Config search_conf;
milvus::Config range_search_conf;
milvus::DataType vec_field_data_type;
knowhere::DataSetPtr xb_dataset;
std::vector<float> xb_data;
@ -357,9 +360,9 @@ INSTANTIATE_TEST_CASE_P(
std::pair(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, knowhere::metric::JACCARD),
std::pair(knowhere::IndexEnum::INDEX_HNSW, knowhere::metric::L2),
// ci ut not start minio, so not run ut about diskann index for now
//#ifdef BUILD_DISK_ANN
// #ifdef BUILD_DISK_ANN
// std::pair(knowhere::IndexEnum::INDEX_DISKANN, knowhere::metric::L2),
//#endif
// #endif
std::pair(knowhere::IndexEnum::INDEX_ANNOY, knowhere::metric::L2)));
TEST_P(IndexTest, BuildAndQuery) {
@ -422,76 +425,79 @@ TEST_P(IndexTest, BuildAndQuery) {
if (!is_binary) {
EXPECT_EQ(result->seg_offsets_[0], query_offset);
}
if (index_type != knowhere::IndexEnum::INDEX_ANNOY) {
search_info.search_params_ = range_search_conf;
vec_index->Query(xq_dataset, search_info, nullptr);
}
}
//#ifdef BUILD_DISK_ANN
// TEST(Indexing, SearchDiskAnnWithInvalidParam) {
// int64_t NB = 10000;
// IndexType index_type = knowhere::IndexEnum::INDEX_DISKANN;
// MetricType metric_type = knowhere::metric::L2;
// milvus::index::CreateIndexInfo create_index_info;
// create_index_info.index_type = index_type;
// create_index_info.metric_type = metric_type;
// create_index_info.field_type = milvus::DataType::VECTOR_FLOAT;
// #ifdef BUILD_DISK_ANN
// TEST(Indexing, SearchDiskAnnWithInvalidParam) {
// int64_t NB = 10000;
// IndexType index_type = knowhere::IndexEnum::INDEX_DISKANN;
// MetricType metric_type = knowhere::metric::L2;
// milvus::index::CreateIndexInfo create_index_info;
// create_index_info.index_type = index_type;
// create_index_info.metric_type = metric_type;
// create_index_info.field_type = milvus::DataType::VECTOR_FLOAT;
//
// StorageConfig storage_config = get_default_storage_config();
// auto rcm = std::make_shared<storage::MinioChunkManager>(storage_config);
// if (!rcm->BucketExists(storage_config.bucket_name)) {
// rcm->CreateBucket(storage_config.bucket_name);
// }
// milvus::storage::FieldDataMeta field_data_meta{1, 2, 3, 100};
// milvus::storage::IndexMeta index_meta{3, 100, 1000, 1};
// auto file_manager =
// std::make_shared<milvus::storage::DiskFileManagerImpl>(field_data_meta, index_meta, storage_config);
// auto index = milvus::index::IndexFactory::GetInstance().CreateIndex(create_index_info, file_manager);
// StorageConfig storage_config = get_default_storage_config();
// auto rcm = std::make_shared<storage::MinioChunkManager>(storage_config);
// if (!rcm->BucketExists(storage_config.bucket_name)) {
// rcm->CreateBucket(storage_config.bucket_name);
// }
// milvus::storage::FieldDataMeta field_data_meta{1, 2, 3, 100};
// milvus::storage::IndexMeta index_meta{3, 100, 1000, 1};
// auto file_manager =
// std::make_shared<milvus::storage::DiskFileManagerImpl>(field_data_meta, index_meta, storage_config);
// auto index = milvus::index::IndexFactory::GetInstance().CreateIndex(create_index_info, file_manager);
//
// auto build_conf = knowhere::Config{
// {knowhere::meta::METRIC_TYPE, metric_type},
// {knowhere::meta::DIM, std::to_string(DIM)},
// {milvus::index::DISK_ANN_MAX_DEGREE, std::to_string(48)},
// {milvus::index::DISK_ANN_SEARCH_LIST_SIZE, std::to_string(128)},
// {milvus::index::DISK_ANN_PQ_CODE_BUDGET, std::to_string(0.001)},
// {milvus::index::DISK_ANN_BUILD_DRAM_BUDGET, std::to_string(2)},
// };
// auto build_conf = knowhere::Config{
// {knowhere::meta::METRIC_TYPE, metric_type},
// {knowhere::meta::DIM, std::to_string(DIM)},
// {milvus::index::DISK_ANN_MAX_DEGREE, std::to_string(48)},
// {milvus::index::DISK_ANN_SEARCH_LIST_SIZE, std::to_string(128)},
// {milvus::index::DISK_ANN_PQ_CODE_BUDGET, std::to_string(0.001)},
// {milvus::index::DISK_ANN_BUILD_DRAM_BUDGET, std::to_string(2)},
// };
//
// // build disk ann index
// auto dataset = GenDataset(NB, metric_type, false);
// std::vector<float> xb_data = dataset.get_col<float>(milvus::FieldId(100));
// knowhere::DatasetPtr xb_dataset = knowhere::GenDataset(NB, DIM, xb_data.data());
// ASSERT_NO_THROW(index->BuildWithDataset(xb_dataset, build_conf));
// // build disk ann index
// auto dataset = GenDataset(NB, metric_type, false);
// std::vector<float> xb_data = dataset.get_col<float>(milvus::FieldId(100));
// knowhere::DatasetPtr xb_dataset = knowhere::GenDataset(NB, DIM, xb_data.data());
// ASSERT_NO_THROW(index->BuildWithDataset(xb_dataset, build_conf));
//
// // serialize and load disk index, disk index can only be search after loading for now
// auto binary_set = index->Serialize(milvus::Config{});
// index.reset();
// // clean local file dir
// file_manager.reset();
// // serialize and load disk index, disk index can only be search after loading for now
// auto binary_set = index->Serialize(milvus::Config{});
// index.reset();
// // clean local file dir
// file_manager.reset();
//
// auto new_file_manager =
// std::make_shared<milvus::storage::DiskFileManagerImpl>(field_data_meta, index_meta, storage_config);
// auto new_index = milvus::index::IndexFactory::GetInstance().CreateIndex(create_index_info, new_file_manager);
// auto vec_index = dynamic_cast<milvus::index::VectorIndex*>(new_index.get());
// std::vector<std::string> index_files;
// for (auto& binary : binary_set.binary_map_) {
// index_files.emplace_back(binary.first);
// }
// auto load_conf = generate_load_conf(index_type, metric_type, NB);
// load_conf["index_files"] = index_files;
// vec_index->Load(binary_set, load_conf);
// EXPECT_EQ(vec_index->Count(), NB);
// auto new_file_manager =
// std::make_shared<milvus::storage::DiskFileManagerImpl>(field_data_meta, index_meta, storage_config);
// auto new_index = milvus::index::IndexFactory::GetInstance().CreateIndex(create_index_info, new_file_manager);
// auto vec_index = dynamic_cast<milvus::index::VectorIndex*>(new_index.get());
// std::vector<std::string> index_files;
// for (auto& binary : binary_set.binary_map_) {
// index_files.emplace_back(binary.first);
// }
// auto load_conf = generate_load_conf(index_type, metric_type, NB);
// load_conf["index_files"] = index_files;
// vec_index->Load(binary_set, load_conf);
// EXPECT_EQ(vec_index->Count(), NB);
//
// // search disk index with search_list == limit
// int query_offset = 100;
// knowhere::DatasetPtr xq_dataset = knowhere::GenDataset(NQ, DIM, xb_data.data() + DIM * query_offset);
// // search disk index with search_list == limit
// int query_offset = 100;
// knowhere::DatasetPtr xq_dataset = knowhere::GenDataset(NQ, DIM, xb_data.data() + DIM * query_offset);
//
// milvus::SearchInfo search_info;
// search_info.topk_ = K;
// search_info.metric_type_ = metric_type;
// search_info.search_params_ = milvus::Config{
// {knowhere::meta::METRIC_TYPE, metric_type},
// {milvus::index::DISK_ANN_QUERY_LIST, K - 1},
// };
// EXPECT_THROW(vec_index->Query(xq_dataset, search_info, nullptr), std::runtime_error);
// // vec_index->Query(xq_dataset, search_info, nullptr);
//}
//#endif
// milvus::SearchInfo search_info;
// search_info.topk_ = K;
// search_info.metric_type_ = metric_type;
// search_info.search_params_ = milvus::Config{
// {knowhere::meta::METRIC_TYPE, metric_type},
// {milvus::index::DISK_ANN_QUERY_LIST, K - 1},
// };
// EXPECT_THROW(vec_index->Query(xq_dataset, search_info, nullptr), std::runtime_error);
// // vec_index->Query(xq_dataset, search_info, nullptr);
// }
// #endif

View File

@ -0,0 +1,165 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <gtest/gtest.h>
#include <queue>
#include <random>
#include <vector>
#include <iostream>
#include "common/RangeSearchHelper.h"
#include "common/Types.h"
#include "common/Utils.h"
#include "common/Schema.h"
#include "test_utils/indexbuilder_test_utils.h"
// Descending-by-distance ordering: true when a's distance is strictly larger than
// b's. The id (second member) takes no part in the comparison.
bool
cmp1(std::pair<float, int64_t> a, std::pair<float, int64_t> b) {
    return b.first < a.first;
}
// Ascending-by-distance ordering: true when a's distance is strictly smaller than
// b's. The id (second member) takes no part in the comparison.
bool
cmp2(std::pair<float, int64_t> a, std::pair<float, int64_t> b) {
    return b.first > a.first;
}
// Reference implementation used to cross-check SortRangeSearchResult: fully sorts
// the hits of every query (descending distance for IP, ascending otherwise) and
// keeps the first min(topk, hit count) of them, packed back to back.
//
// Returns (number of entries written, ids, distances). Both arrays are allocated
// with new[] here and are returned as raw pointers.
auto
RangeSearchSortResultBF(milvus::DatasetPtr data_set, int64_t topk, size_t nq, std::string metric_type) {
    auto lims = milvus::GetDatasetLims(data_set);
    auto id = milvus::GetDatasetIDs(data_set);
    auto dist = milvus::GetDatasetDistance(data_set);

    auto p_id = new int64_t[topk * nq];
    auto p_dist = new float[topk * nq];

    // cnt is the next free slot in p_id / p_dist.
    int cnt = 0;
    for (int i = 0; i < nq; i++) {
        auto size = lims[i + 1] - lims[i];
        int capacity = topk > size ? size : topk;

        // Gather the hits of this query, then sort them in the metric's order.
        const bool descending = milvus::IsMetricType(metric_type, knowhere::metric::IP);
        std::vector<std::pair<float, int64_t>> hits;
        for (size_t j = lims[i]; j < lims[i + 1]; ++j) {
            hits.push_back(std::pair<float, int64_t>(dist[j], id[j]));
        }
        if (descending) {
            std::sort(hits.begin(), hits.end(), cmp1);
        } else {
            std::sort(hits.begin(), hits.end(), cmp2);
        }

        for (int k = 0; k < capacity; ++k, ++cnt) {
            p_dist[cnt] = hits[k].first;
            p_id[cnt] = hits[k].second;
        }
    }
    return std::make_tuple(cnt, p_id, p_dist);
}
// Test helper: wraps the given ids / distances / lims buffers in a dataset with nq
// rows, shaped like a range-search result, and marks the dataset as owner.
milvus::DatasetPtr
genResultDataset(const int64_t nq, const int64_t* ids, const float* distance, const size_t* lims) {
    auto range_result = std::make_shared<milvus::Dataset>();
    range_result->SetRows(nq);

    // Payload buffers.
    range_result->SetIds(ids);
    range_result->SetDistance(distance);
    range_result->SetLims(lims);

    range_result->SetIsOwner(true);
    return range_result;
}
// Compares the first n ids and distances stored in dataset against the expected
// arrays p_id / p_dist, element by element, using AssertInfo.
// Distances are compared with exact float equality.
void
CheckRangeSearchSortResult(int64_t* p_id, float* p_dist, milvus::DatasetPtr dataset, int64_t n) {
auto id = milvus::GetDatasetIDs(dataset);
auto dist = milvus::GetDatasetDistance(dataset);
for (int i = 0; i < n; i++) {
AssertInfo(id[i] == p_id[i], "id of range search result are not the same");
AssertInfo(dist[i] == p_dist[i], "distance of range search result are not the same");
}
}
// Test helper: builds a pseudo-random range-search result for N queries.
//
// Every query gets between 0 and N hits; each hit has an id drawn from
// [id_min, id_max] and a distance drawn from [distance_min, distance_max). The
// generator is seeded with `seed`, so the output is reproducible.
//
// The ids / distances / lims pointer parameters are received by value: the buffers
// allocated here are handed to the returned dataset only and never become visible
// through the caller's variables.
auto
GenRangeSearchResult(int64_t* ids,
                     float* distances,
                     size_t* lims,
                     int64_t N,
                     int64_t id_min,
                     int64_t id_max,
                     float distance_min,
                     float distance_max,
                     int seed = 42) {
    // Incoming pointer values are never read.
    (void)ids;
    (void)distances;
    (void)lims;

    std::mt19937 e(seed);
    std::uniform_int_distribution<> uniform_num(0, N);
    std::uniform_int_distribution<> uniform_ids(id_min, id_max);
    std::uniform_real_distribution<> uniform_distance(distance_min, distance_max);

    size_t* offsets = new size_t[N + 1];
    // Sized for the maximum possible total: N queries with N hits each.
    float* dist_buf = new float[N * N];
    int64_t* id_buf = new int64_t[N * N];

    offsets[0] = 0;
    for (int64_t query = 0; query < N; ++query) {
        const int64_t num = uniform_num(e);
        const size_t base = offsets[query];
        for (int64_t hit = 0; hit < num; ++hit) {
            // Draw order (id first, then distance) determines the generated data.
            auto id = uniform_ids(e);
            auto dis = uniform_distance(e);
            id_buf[base + hit] = id;
            dist_buf[base + hit] = dis;
        }
        offsets[query + 1] = base + num;
    }
    return genResultDataset(N, id_buf, dist_buf, offsets);
}
// Value-parameterized fixture for the range-search sort tests. The test
// parameter is the metric type; SetUp() stores it and generates a random
// range-search result dataset for N queries.
class RangeSearchSortTest : public ::testing::TestWithParam<knowhere::MetricType> {
 protected:
    void
    SetUp() override {
        metric_type = GetParam();
        // GenRangeSearchResult takes the three pointers by value, so `ids`,
        // `distances` and `lims` below are not updated by this call.
        dataset = GenRangeSearchResult(ids, distances, lims, N, id_min, id_max, dist_min, dist_max);
    }

    void
    TearDown() override {
        // As far as this fixture shows, these pointers are still nullptr here,
        // which makes the delete[] calls no-ops.
        delete[] ids;
        delete[] distances;
        delete[] lims;
    }

    // Metric under test, taken from GetParam().
    knowhere::MetricType metric_type;
    // Randomly generated range-search result shared by the test bodies.
    milvus::DatasetPtr dataset = nullptr;
    // Number of queries in the generated dataset.
    int64_t N = 100;
    int64_t TOPK = 10;
    int64_t DIM = 16;
    int64_t* ids = nullptr;
    float* distances = nullptr;
    size_t* lims = nullptr;
    // Inclusive id range and distance range used for data generation.
    int64_t id_min = 0;
    int64_t id_max = 10000;
    float dist_min = 0.0f;
    float dist_max = 100.0f;
};
// Instantiates the RangeSearchSortTest tests once for each listed metric type
// (L2, IP, JACCARD, TANIMOTO, HAMMING).
// NOTE(review): newer GoogleTest releases spell this macro
// INSTANTIATE_TEST_SUITE_P - confirm against the GoogleTest version in use.
INSTANTIATE_TEST_CASE_P(RangeSearchSortParameters,
RangeSearchSortTest,
::testing::Values(knowhere::metric::L2,
knowhere::metric::IP,
knowhere::metric::JACCARD,
knowhere::metric::TANIMOTO,
knowhere::metric::HAMMING));
// Checks milvus::SortRangeSearchResult against the brute-force reference
// RangeSearchSortResultBF on the fixture's generated dataset.
TEST_P(RangeSearchSortTest, CheckRangeSearchSort) {
    // Result produced by the implementation under test.
    auto sorted = milvus::SortRangeSearchResult(dataset, TOPK, N, metric_type);

    // Reference result: (number of entries, ids, distances).
    auto [expected_count, expected_ids, expected_dists] = RangeSearchSortResultBF(dataset, TOPK, N, metric_type);

    CheckRangeSearchSortResult(expected_ids, expected_dists, sorted, expected_count);

    // The reference arrays are released by the test itself.
    delete[] expected_ids;
    delete[] expected_dists;
}

View File

@ -84,8 +84,8 @@ TEST(Sealed, without_predicate) {
auto indexing = milvus::index::IndexFactory::GetInstance().CreateIndex(create_index_info, nullptr);
auto build_conf = knowhere::Json{{knowhere::meta::METRIC_TYPE, knowhere::metric::L2},
{knowhere::meta::DIM, std::to_string(dim)},
{knowhere::indexparam::NLIST, "100"}};
{knowhere::meta::DIM, std::to_string(dim)},
{knowhere::indexparam::NLIST, "100"}};
auto search_conf = knowhere::Json{{knowhere::indexparam::NPROBE, 10}};
@ -190,8 +190,8 @@ TEST(Sealed, with_predicate) {
auto indexing = milvus::index::IndexFactory::GetInstance().CreateIndex(create_index_info, nullptr);
auto build_conf = knowhere::Json{{knowhere::meta::METRIC_TYPE, knowhere::metric::L2},
{knowhere::meta::DIM, std::to_string(dim)},
{knowhere::indexparam::NLIST, "100"}};
{knowhere::meta::DIM, std::to_string(dim)},
{knowhere::indexparam::NLIST, "100"}};
auto database = knowhere::GenDataSet(N, dim, vec_col.data());
indexing->BuildWithDataset(database, build_conf);
@ -288,8 +288,8 @@ TEST(Sealed, with_predicate_filter_all) {
auto ivf_indexing = milvus::index::IndexFactory::GetInstance().CreateIndex(create_index_info, nullptr);
auto ivf_build_conf = knowhere::Json{{knowhere::meta::DIM, std::to_string(dim)},
{knowhere::indexparam::NLIST, "100"},
{knowhere::meta::METRIC_TYPE, knowhere::metric::L2}};
{knowhere::indexparam::NLIST, "100"},
{knowhere::meta::METRIC_TYPE, knowhere::metric::L2}};
auto database = knowhere::GenDataSet(N, dim, vec_col.data());
ivf_indexing->BuildWithDataset(database, ivf_build_conf);
@ -312,10 +312,10 @@ TEST(Sealed, with_predicate_filter_all) {
EXPECT_EQ(sr->get_total_result_count(), 0);
auto hnsw_conf = knowhere::Json{{knowhere::meta::DIM, std::to_string(dim)},
{knowhere::indexparam::HNSW_M, "16"},
{knowhere::indexparam::EFCONSTRUCTION, "200"},
{knowhere::indexparam::EF, "200"},
{knowhere::meta::METRIC_TYPE, knowhere::metric::L2}};
{knowhere::indexparam::HNSW_M, "16"},
{knowhere::indexparam::EFCONSTRUCTION, "200"},
{knowhere::indexparam::EF, "200"},
{knowhere::meta::METRIC_TYPE, knowhere::metric::L2}};
create_index_info.field_type = DataType::VECTOR_FLOAT;
create_index_info.metric_type = knowhere::metric::L2;
@ -437,7 +437,7 @@ TEST(Sealed, LoadFieldData) {
// ASSERT_EQ(json.dump(-2), json2.dump(-2));
// segment->DropFieldData(double_id);
// ASSERT_ANY_THROW(segment->Search(plan.get(), ph_group.get(), time));
//#ifdef __linux__
// #ifdef __linux__
// auto std_json = Json::parse(R"(
//[
// [
@ -448,7 +448,7 @@ TEST(Sealed, LoadFieldData) {
// ["66353->5.696000", "30664->5.881000", "41087->5.917000", "10393->6.633000", "90215->7.202000"]
// ]
//])");
//#else // for mac
// #else // for mac
// auto std_json = Json::parse(R"(
//[
// [
@ -459,7 +459,7 @@ TEST(Sealed, LoadFieldData) {
// ["37759->3.581000", "31292->5.780000", "98124->6.216000", "63535->6.439000", "11707->6.553000"]
// ]
//])");
//#endif
// #endif
// ASSERT_EQ(std_json.dump(-2), json.dump(-2));
}

View File

@ -538,7 +538,7 @@ TEST(AlwaysTrueStringPlan, SearchWithOutputFields) {
dim, //
query_ptr //
};
auto sub_result = BruteForceSearch(search_dataset, vec_col.data(), N, nullptr);
auto sub_result = BruteForceSearch(search_dataset, vec_col.data(), N, knowhere::Json(), nullptr);
auto sr = segment->Search(plan.get(), ph_group.get(), time);
segment->FillPrimaryKeys(plan.get(), *sr);

View File

@ -292,6 +292,32 @@ generate_search_conf(const milvus::IndexType& index_type, const milvus::MetricTy
return conf;
}
// Builds the search configuration used for range-search tests.
//
// The config always carries the metric type plus RADIUS and RANGE_FILTER:
//   - IP:            radius 0.1, range_filter 0.2
//   - other metrics: radius 0.2, range_filter 0.1
// NOTE(review): the swapped bounds presumably reflect that a larger IP score
// means "closer" - confirm against the knowhere range-search documentation.
//
// An index-specific search parameter is added on top when the index type
// matches one of the branches below; otherwise nothing more is set.
auto
generate_range_search_conf(const milvus::IndexType& index_type, const milvus::MetricType& metric_type) {
    auto range_conf = milvus::Config{
        {knowhere::meta::METRIC_TYPE, metric_type},
    };

    const bool is_ip = (metric_type == knowhere::metric::IP);
    range_conf[knowhere::meta::RADIUS] = is_ip ? 0.1 : 0.2;
    range_conf[knowhere::meta::RANGE_FILTER] = is_ip ? 0.2 : 0.1;

    if (milvus::index::is_in_list<milvus::IndexType>(index_type, search_with_nprobe_list)) {
        range_conf[knowhere::indexparam::NPROBE] = 4;
    } else if (index_type == knowhere::IndexEnum::INDEX_HNSW) {
        range_conf[knowhere::indexparam::EF] = 200;
    } else if (index_type == knowhere::IndexEnum::INDEX_ANNOY) {
        range_conf[knowhere::indexparam::SEARCH_K] = 100;
    } else if (index_type == knowhere::IndexEnum::INDEX_DISKANN) {
        range_conf[milvus::index::DISK_ANN_QUERY_LIST] = K * 2;
    }

    return range_conf;
}
auto
generate_params(const knowhere::IndexType& index_type, const knowhere::MetricType& metric_type) {
namespace indexcgo = milvus::proto::indexcgo;