mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-02 11:59:00 +08:00
Support range search (#21652)
Signed-off-by: smellthemoon <xinguo.li@zilliz.com> Signed-off-by: lixinguo <xinguo.li@zilliz.com> Co-authored-by: jaime <yun.zhang@zilliz.com>
This commit is contained in:
parent
0a9a9058b9
commit
9e0ec15436
@ -18,6 +18,7 @@ set(COMMON_SRC
|
||||
binary_set_c.cpp
|
||||
init_c.cpp
|
||||
Common.cpp
|
||||
RangeSearchHelper.cpp
|
||||
)
|
||||
|
||||
add_library(milvus_common SHARED ${COMMON_SRC})
|
||||
|
@ -43,3 +43,6 @@ const int64_t DEFAULT_THREAD_CORE_COEFFICIENT = 50;
|
||||
const int64_t DEFAULT_INDEX_FILE_SLICE_SIZE = 4; // megabytes
|
||||
|
||||
const int DEFAULT_CPU_NUM = 1;
|
||||
|
||||
constexpr const char* RADIUS = "radius";
|
||||
constexpr const char* RANGE_FILTER = "range_filter";
|
||||
|
112
internal/core/src/common/RangeSearchHelper.cpp
Normal file
112
internal/core/src/common/RangeSearchHelper.cpp
Normal file
@ -0,0 +1,112 @@
|
||||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
#include "common/Utils.h"
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
|
||||
namespace milvus {
|
||||
namespace {
|
||||
using ResultPair = std::pair<float, int64_t>;
|
||||
}
|
||||
DatasetPtr
|
||||
SortRangeSearchResult(DatasetPtr data_set, int64_t topk, int64_t nq, std::string metric_type) {
|
||||
/**
|
||||
* nq: number of querys;
|
||||
* lims: the size of lims is nq + 1, lims[i+1] - lims[i] refers to the size of RangeSearch result querys[i]
|
||||
* for example, the nq is 5. In the seleted range,
|
||||
* the size of RangeSearch result for each nq is [1, 2, 3, 4, 5],
|
||||
* the lims will be [0, 1, 3, 6, 10, 15];
|
||||
* ids: the size of ids is lim[nq],
|
||||
* { i(0,0), i(0,1), …, i(0,k0-1),
|
||||
* i(1,0), i(1,1), …, i(1,k1-1),
|
||||
* …,
|
||||
* i(n-1,0), i(n-1,1), …, i(n-1,kn-1)},
|
||||
* i(0,0), i(0,1), …, i(0,k0-1) means the ids of RangeSearch result querys[0], k0 equals lim[1] - lim[0];
|
||||
* dist: the size of ids is lim[nq],
|
||||
* { d(0,0), d(0,1), …, d(0,k0-1),
|
||||
* d(1,0), d(1,1), …, d(1,k1-1),
|
||||
* …,
|
||||
* d(n-1,0), d(n-1,1), …, d(n-1,kn-1)},
|
||||
* d(0,0), d(0,1), …, d(0,k0-1) means the distances of RangeSearch result querys[0], k0 equals lim[1] - lim[0];
|
||||
*/
|
||||
auto lims = GetDatasetLims(data_set);
|
||||
auto id = GetDatasetIDs(data_set);
|
||||
auto dist = GetDatasetDistance(data_set);
|
||||
|
||||
// use p_id and p_dist to GenResultDataset after sorted
|
||||
auto p_id = new int64_t[topk * nq];
|
||||
auto p_dist = new float[topk * nq];
|
||||
// cnt means the subscript of p_id and p_dist
|
||||
int cnt = 0;
|
||||
|
||||
for (int i = 0; i < nq; i++) {
|
||||
// if RangeSearch answer size of one nq is less than topk, set the capacity to size
|
||||
int size = lims[i + 1] - lims[i];
|
||||
int capacity = topk > size ? size : topk;
|
||||
/*
|
||||
* get result for one nq
|
||||
* IP: 1.0 range_filter radius
|
||||
* |------------+---------------| min_heap descending_order
|
||||
* L2: 0.0 range_filter radius
|
||||
* |------------+---------------| max_heap ascending_order
|
||||
*
|
||||
*/
|
||||
std::function<bool(const ResultPair&, const ResultPair&)> cmp = std::less<std::pair<float, int64_t>>();
|
||||
if (IsMetricType(metric_type, knowhere::metric::IP)) {
|
||||
cmp = std::greater<std::pair<float, int64_t>>();
|
||||
}
|
||||
std::priority_queue<std::pair<float, int64_t>, std::vector<std::pair<float, int64_t>>, decltype(cmp)>
|
||||
sub_result(cmp);
|
||||
|
||||
for (int j = lims[i]; j < lims[i + 1]; j++) {
|
||||
auto current = ResultPair(dist[j], id[j]);
|
||||
if (sub_result.size() == capacity) {
|
||||
if (cmp(sub_result.top(), current)) {
|
||||
current = sub_result.top();
|
||||
}
|
||||
sub_result.pop();
|
||||
}
|
||||
sub_result.push(current);
|
||||
}
|
||||
|
||||
for (int i = capacity + cnt - 1; i > cnt - 1; i--) {
|
||||
p_dist[i] = sub_result.top().first;
|
||||
p_id[i] = sub_result.top().second;
|
||||
sub_result.pop();
|
||||
}
|
||||
cnt += capacity;
|
||||
}
|
||||
return GenResultDataset(nq, topk, p_id, p_dist);
|
||||
}
|
||||
|
||||
void
|
||||
CheckRangeSearchParam(float radius, float range_filter, std::string metric_type) {
|
||||
/*
|
||||
* IP: 1.0 range_filter radius
|
||||
* |------------+---------------| min_heap descending_order
|
||||
* L2: 1.0 radius range_filter
|
||||
* |------------+---------------| max_heap ascending_order
|
||||
*
|
||||
*/
|
||||
if (metric_type == knowhere::metric::IP) {
|
||||
if (range_filter < radius) {
|
||||
PanicInfo("range_filter must more than radius when IP");
|
||||
}
|
||||
} else {
|
||||
if (range_filter > radius) {
|
||||
PanicInfo("range_filter must less than radius except IP");
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace milvus
|
24
internal/core/src/common/RangeSearchHelper.h
Normal file
24
internal/core/src/common/RangeSearchHelper.h
Normal file
@ -0,0 +1,24 @@
|
||||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <common/Types.h>
|
||||
|
||||
namespace milvus {
|
||||
|
||||
DatasetPtr
|
||||
SortRangeSearchResult(DatasetPtr data_set, int64_t topk, int64_t nq, std::string metric_type);
|
||||
|
||||
void
|
||||
CheckRangeSearchParam(float radius, float range_filter, std::string metric_type);
|
||||
} // namespace milvus
|
@ -51,6 +51,11 @@ GetDatasetDim(const DatasetPtr& dataset) {
|
||||
return dataset->GetDim();
|
||||
}
|
||||
|
||||
inline const size_t*
|
||||
GetDatasetLims(const DatasetPtr& dataset) {
|
||||
return dataset->GetLims();
|
||||
}
|
||||
|
||||
inline bool
|
||||
PrefixMatch(const std::string& str, const std::string& prefix) {
|
||||
auto ret = strncmp(str.c_str(), prefix.c_str(), prefix.length());
|
||||
@ -61,6 +66,17 @@ PrefixMatch(const std::string& str, const std::string& prefix) {
|
||||
return true;
|
||||
}
|
||||
|
||||
inline DatasetPtr
|
||||
GenResultDataset(const int64_t nq, const int64_t topk, const int64_t* ids, const float* distance) {
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
ret_ds->SetRows(nq);
|
||||
ret_ds->SetDim(topk);
|
||||
ret_ds->SetIds(ids);
|
||||
ret_ds->SetDistance(distance);
|
||||
ret_ds->SetIsOwner(true);
|
||||
return ret_ds;
|
||||
}
|
||||
|
||||
inline bool
|
||||
PostfixMatch(const std::string& str, const std::string& postfix) {
|
||||
if (postfix.length() > str.length()) {
|
||||
|
@ -21,7 +21,9 @@
|
||||
#include "storage/LocalChunkManager.h"
|
||||
#include "config/ConfigKnowhere.h"
|
||||
#include "storage/Util.h"
|
||||
#include "common/Consts.h"
|
||||
#include "common/Utils.h"
|
||||
#include "common/RangeSearchHelper.h"
|
||||
|
||||
namespace milvus::index {
|
||||
|
||||
@ -145,14 +147,29 @@ VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset, const SearchInfo& search_
|
||||
// set json reset field, will be removed later
|
||||
search_config[DISK_ANN_PQ_CODE_BUDGET] = 0.0;
|
||||
|
||||
auto final = index_.Search(*dataset, search_config, bitset);
|
||||
if (!final.has_value()) {
|
||||
PanicCodeInfo(ErrorCodeEnum::UnexpectedError, "failed to search");
|
||||
}
|
||||
auto final = [&] {
|
||||
auto radius = GetValueFromConfig<float>(search_info.search_params_, RADIUS);
|
||||
if (radius.has_value()) {
|
||||
search_config[RADIUS] = radius.value();
|
||||
auto range_filter = GetValueFromConfig<float>(search_info.search_params_, RANGE_FILTER);
|
||||
if (range_filter.has_value()) {
|
||||
search_config[RANGE_FILTER] = range_filter.value();
|
||||
CheckRangeSearchParam(search_config[RADIUS], search_config[RANGE_FILTER], GetMetricType());
|
||||
}
|
||||
auto res = index_.RangeSearch(*dataset, search_config, bitset);
|
||||
return SortRangeSearchResult(res.value(), topk, num_queries, GetMetricType());
|
||||
} else {
|
||||
auto res = index_.Search(*dataset, search_config, bitset);
|
||||
if (!res.has_value()) {
|
||||
PanicCodeInfo(ErrorCodeEnum::UnexpectedError, "failed to search");
|
||||
}
|
||||
return res.value();
|
||||
}
|
||||
}();
|
||||
|
||||
auto ids = final.value()->GetIds();
|
||||
float* distances = const_cast<float*>(final.value()->GetDistance());
|
||||
final.value()->SetIsOwner(true);
|
||||
auto ids = final->GetIds();
|
||||
float* distances = const_cast<float*>(final->GetDistance());
|
||||
final->SetIsOwner(true);
|
||||
|
||||
auto round_decimal = search_info.round_decimal_;
|
||||
auto total_num = num_queries * topk;
|
||||
|
@ -24,6 +24,8 @@
|
||||
#include "knowhere/comp/Timer.h"
|
||||
#include "common/BitsetView.h"
|
||||
#include "common/Slice.h"
|
||||
#include "common/Consts.h"
|
||||
#include "common/RangeSearchHelper.h"
|
||||
|
||||
namespace milvus::index {
|
||||
|
||||
@ -82,13 +84,24 @@ VectorMemIndex::Query(const DatasetPtr dataset, const SearchInfo& search_info, c
|
||||
search_conf[knowhere::meta::TOPK] = topk;
|
||||
search_conf[knowhere::meta::METRIC_TYPE] = GetMetricType();
|
||||
auto index_type = GetIndexType();
|
||||
return index_.Search(*dataset, search_conf, bitset);
|
||||
if (CheckKeyInConfig(search_conf, RADIUS)) {
|
||||
if (CheckKeyInConfig(search_conf, RANGE_FILTER)) {
|
||||
CheckRangeSearchParam(search_conf[RADIUS], search_conf[RANGE_FILTER], GetMetricType());
|
||||
}
|
||||
auto res = index_.RangeSearch(*dataset, search_conf, bitset);
|
||||
return SortRangeSearchResult(res.value(), topk, num_queries, GetMetricType());
|
||||
} else {
|
||||
auto res = index_.Search(*dataset, search_conf, bitset);
|
||||
if (!res.has_value()) {
|
||||
PanicCodeInfo(ErrorCodeEnum::UnexpectedError, "failed to search");
|
||||
}
|
||||
return res.value();
|
||||
}
|
||||
}();
|
||||
if (!final.has_value())
|
||||
PanicCodeInfo(ErrorCodeEnum::UnexpectedError, "failed to search");
|
||||
auto ids = final.value()->GetIds();
|
||||
float* distances = const_cast<float*>(final.value()->GetDistance());
|
||||
final.value()->SetIsOwner(true);
|
||||
|
||||
auto ids = final->GetIds();
|
||||
float* distances = const_cast<float*>(final->GetDistance());
|
||||
final->SetIsOwner(true);
|
||||
auto round_decimal = search_info.round_decimal_;
|
||||
auto total_num = num_queries * topk;
|
||||
|
||||
|
@ -12,6 +12,8 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "common/Consts.h"
|
||||
#include "common/RangeSearchHelper.h"
|
||||
#include "SearchBruteForce.h"
|
||||
#include "SubSearchResult.h"
|
||||
#include "knowhere/comp/brute_force.h"
|
||||
@ -34,6 +36,7 @@ SubSearchResult
|
||||
BruteForceSearch(const dataset::SearchDataset& dataset,
|
||||
const void* chunk_data_raw,
|
||||
int64_t chunk_rows,
|
||||
const knowhere::Json& conf,
|
||||
const BitsetView& bitset) {
|
||||
SubSearchResult sub_result(dataset.num_queries, dataset.topk, dataset.metric_type, dataset.round_decimal);
|
||||
try {
|
||||
@ -48,15 +51,29 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
|
||||
{knowhere::meta::DIM, dim},
|
||||
{knowhere::meta::TOPK, topk},
|
||||
};
|
||||
|
||||
sub_result.mutable_seg_offsets().resize(nq * topk);
|
||||
sub_result.mutable_distances().resize(nq * topk);
|
||||
|
||||
auto stat =
|
||||
knowhere::BruteForce::SearchWithBuf(base_dataset, query_dataset, sub_result.mutable_seg_offsets().data(),
|
||||
sub_result.mutable_distances().data(), config, bitset);
|
||||
if (conf.contains(RADIUS)) {
|
||||
config[RADIUS] = conf[RADIUS];
|
||||
if (conf.contains(RANGE_FILTER)) {
|
||||
config[RANGE_FILTER] = conf[RANGE_FILTER];
|
||||
CheckRangeSearchParam(config[RADIUS], config[RANGE_FILTER], dataset.metric_type);
|
||||
}
|
||||
auto result = SortRangeSearchResult(
|
||||
knowhere::BruteForce::RangeSearch(base_dataset, query_dataset, config, bitset).value(), topk, nq,
|
||||
dataset.metric_type);
|
||||
std::copy_n(GetDatasetIDs(result), nq * topk, sub_result.get_seg_offsets());
|
||||
std::copy_n(GetDatasetDistance(result), nq * topk, sub_result.get_distances());
|
||||
} else {
|
||||
auto stat = knowhere::BruteForce::SearchWithBuf(base_dataset, query_dataset,
|
||||
sub_result.mutable_seg_offsets().data(),
|
||||
sub_result.mutable_distances().data(), config, bitset);
|
||||
|
||||
if (stat != knowhere::Status::success) {
|
||||
throw std::invalid_argument("invalid metric type");
|
||||
if (stat != knowhere::Status::success) {
|
||||
throw std::invalid_argument("invalid metric type");
|
||||
}
|
||||
}
|
||||
} catch (std::exception& e) {
|
||||
PanicInfo(e.what());
|
||||
|
@ -26,6 +26,7 @@ SubSearchResult
|
||||
BruteForceSearch(const dataset::SearchDataset& dataset,
|
||||
const void* chunk_data_raw,
|
||||
int64_t chunk_rows,
|
||||
const knowhere::Json& conf,
|
||||
const BitsetView& bitset);
|
||||
|
||||
} // namespace milvus::query
|
||||
|
@ -125,7 +125,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
||||
auto size_per_chunk = element_end - element_begin;
|
||||
|
||||
auto sub_view = bitset.subview(element_begin, size_per_chunk);
|
||||
auto sub_qr = BruteForceSearch(search_dataset, chunk_data, size_per_chunk, sub_view);
|
||||
auto sub_qr = BruteForceSearch(search_dataset, chunk_data, size_per_chunk, info.search_params_, sub_view);
|
||||
|
||||
// convert chunk uid to segment uid
|
||||
for (auto& x : sub_qr.mutable_seg_offsets()) {
|
||||
|
@ -86,9 +86,8 @@ SearchOnSealed(const Schema& schema,
|
||||
auto vec_data = record.get_field_data_base(field_id);
|
||||
AssertInfo(vec_data->num_chunk() == 1, "num chunk not equal to 1 for sealed segment");
|
||||
auto chunk_data = vec_data->get_chunk_data(0);
|
||||
|
||||
CheckBruteForceSearchParam(field, search_info);
|
||||
auto sub_qr = BruteForceSearch(dataset, chunk_data, row_count, bitset);
|
||||
auto sub_qr = BruteForceSearch(dataset, chunk_data, row_count, search_info.search_params_, bitset);
|
||||
|
||||
result.distances_ = std::move(sub_qr.mutable_distances());
|
||||
result.seg_offsets_ = std::move(sub_qr.mutable_seg_offsets());
|
||||
|
@ -46,6 +46,7 @@ set(MILVUS_TEST_FILES
|
||||
test_timestamp_index.cpp
|
||||
test_utils.cpp
|
||||
test_data_codec.cpp
|
||||
test_range_search_sort.cpp
|
||||
)
|
||||
|
||||
if ( BUILD_DISK_ANN STREQUAL "ON" )
|
||||
|
@ -120,7 +120,7 @@ class TestFloatSearchBruteForce : public ::testing::Test {
|
||||
// ASSERT_ANY_THROW(BruteForceSearch(dataset, base.data(), nb, bitset_view));
|
||||
return;
|
||||
}
|
||||
auto result = BruteForceSearch(dataset, base.data(), nb, bitset_view);
|
||||
auto result = BruteForceSearch(dataset, base.data(), nb, knowhere::Json(), bitset_view);
|
||||
for (int i = 0; i < nq; i++) {
|
||||
auto ref = Ref(base.data(), query.data() + i * dim, nb, dim, topk, metric_type);
|
||||
auto ans = result.get_seg_offsets() + i * topk;
|
||||
|
@ -3522,3 +3522,245 @@ TEST(CApiTest, RetriveScalarFieldFromSealedSegmentWithIndex) {
|
||||
|
||||
DeleteSegment(segment);
|
||||
}
|
||||
|
||||
TEST(CApiTest, RANGE_SEARCH_WITH_RADIUS_WHEN_IP) {
|
||||
auto c_collection = NewCollection(get_default_schema_config());
|
||||
auto segment = NewSegment(c_collection, Growing, -1);
|
||||
auto col = (milvus::segcore::Collection*)c_collection;
|
||||
|
||||
int N = 10000;
|
||||
auto dataset = DataGen(col->get_schema(), N);
|
||||
int64_t ts_offset = 1000;
|
||||
|
||||
int64_t offset;
|
||||
PreInsert(segment, N, &offset);
|
||||
|
||||
auto insert_data = serialize(dataset.raw_);
|
||||
auto ins_res = Insert(segment, offset, N, dataset.row_ids_.data(), dataset.timestamps_.data(), insert_data.data(),
|
||||
insert_data.size());
|
||||
ASSERT_EQ(ins_res.error_code, Success);
|
||||
|
||||
const char* dsl_string = R"(
|
||||
{
|
||||
"bool": {
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "IP",
|
||||
"params": {
|
||||
"nprobe": 10,
|
||||
"radius": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 10,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
})";
|
||||
|
||||
int num_queries = 10;
|
||||
auto blob = generate_query_data(num_queries);
|
||||
|
||||
void* plan = nullptr;
|
||||
auto status = CreateSearchPlan(c_collection, dsl_string, &plan);
|
||||
ASSERT_EQ(status.error_code, Success);
|
||||
|
||||
void* placeholderGroup = nullptr;
|
||||
status = ParsePlaceholderGroup(plan, blob.data(), blob.length(), &placeholderGroup);
|
||||
ASSERT_EQ(status.error_code, Success);
|
||||
|
||||
std::vector<CPlaceholderGroup> placeholderGroups;
|
||||
placeholderGroups.push_back(placeholderGroup);
|
||||
|
||||
CSearchResult search_result;
|
||||
auto res = Search(segment, plan, placeholderGroup, ts_offset, &search_result);
|
||||
ASSERT_EQ(res.error_code, Success);
|
||||
|
||||
DeleteSearchPlan(plan);
|
||||
DeletePlaceholderGroup(placeholderGroup);
|
||||
DeleteSearchResult(search_result);
|
||||
DeleteCollection(c_collection);
|
||||
DeleteSegment(segment);
|
||||
}
|
||||
|
||||
TEST(CApiTest, RANGE_SEARCH_WITH_RADIUS_AND_RANGE_FILTER_WHEN_IP) {
|
||||
auto c_collection = NewCollection(get_default_schema_config());
|
||||
auto segment = NewSegment(c_collection, Growing, -1);
|
||||
auto col = (milvus::segcore::Collection*)c_collection;
|
||||
|
||||
int N = 10000;
|
||||
auto dataset = DataGen(col->get_schema(), N);
|
||||
int64_t ts_offset = 1000;
|
||||
|
||||
int64_t offset;
|
||||
PreInsert(segment, N, &offset);
|
||||
|
||||
auto insert_data = serialize(dataset.raw_);
|
||||
auto ins_res = Insert(segment, offset, N, dataset.row_ids_.data(), dataset.timestamps_.data(), insert_data.data(),
|
||||
insert_data.size());
|
||||
ASSERT_EQ(ins_res.error_code, Success);
|
||||
|
||||
const char* dsl_string = R"(
|
||||
{
|
||||
"bool": {
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "IP",
|
||||
"params": {
|
||||
"nprobe": 10,
|
||||
"radius": 10,
|
||||
"range_filter": 20
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 10,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
})";
|
||||
|
||||
int num_queries = 10;
|
||||
auto blob = generate_query_data(num_queries);
|
||||
|
||||
void* plan = nullptr;
|
||||
auto status = CreateSearchPlan(c_collection, dsl_string, &plan);
|
||||
ASSERT_EQ(status.error_code, Success);
|
||||
|
||||
void* placeholderGroup = nullptr;
|
||||
status = ParsePlaceholderGroup(plan, blob.data(), blob.length(), &placeholderGroup);
|
||||
ASSERT_EQ(status.error_code, Success);
|
||||
|
||||
std::vector<CPlaceholderGroup> placeholderGroups;
|
||||
placeholderGroups.push_back(placeholderGroup);
|
||||
|
||||
CSearchResult search_result;
|
||||
auto res = Search(segment, plan, placeholderGroup, ts_offset, &search_result);
|
||||
ASSERT_EQ(res.error_code, Success);
|
||||
|
||||
DeleteSearchPlan(plan);
|
||||
DeletePlaceholderGroup(placeholderGroup);
|
||||
DeleteSearchResult(search_result);
|
||||
DeleteCollection(c_collection);
|
||||
DeleteSegment(segment);
|
||||
}
|
||||
|
||||
TEST(CApiTest, RANGE_SEARCH_WITH_RADIUS_WHEN_L2) {
|
||||
auto c_collection = NewCollection(get_default_schema_config());
|
||||
auto segment = NewSegment(c_collection, Growing, -1);
|
||||
auto col = (milvus::segcore::Collection*)c_collection;
|
||||
|
||||
int N = 10000;
|
||||
auto dataset = DataGen(col->get_schema(), N);
|
||||
int64_t ts_offset = 1000;
|
||||
|
||||
int64_t offset;
|
||||
PreInsert(segment, N, &offset);
|
||||
|
||||
auto insert_data = serialize(dataset.raw_);
|
||||
auto ins_res = Insert(segment, offset, N, dataset.row_ids_.data(), dataset.timestamps_.data(), insert_data.data(),
|
||||
insert_data.size());
|
||||
ASSERT_EQ(ins_res.error_code, Success);
|
||||
|
||||
const char* dsl_string = R"(
|
||||
{
|
||||
"bool": {
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10,
|
||||
"radius": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 10,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
})";
|
||||
|
||||
int num_queries = 10;
|
||||
auto blob = generate_query_data(num_queries);
|
||||
|
||||
void* plan = nullptr;
|
||||
auto status = CreateSearchPlan(c_collection, dsl_string, &plan);
|
||||
ASSERT_EQ(status.error_code, Success);
|
||||
|
||||
void* placeholderGroup = nullptr;
|
||||
status = ParsePlaceholderGroup(plan, blob.data(), blob.length(), &placeholderGroup);
|
||||
ASSERT_EQ(status.error_code, Success);
|
||||
|
||||
std::vector<CPlaceholderGroup> placeholderGroups;
|
||||
placeholderGroups.push_back(placeholderGroup);
|
||||
|
||||
CSearchResult search_result;
|
||||
auto res = Search(segment, plan, placeholderGroup, ts_offset, &search_result);
|
||||
ASSERT_EQ(res.error_code, Success);
|
||||
|
||||
DeleteSearchPlan(plan);
|
||||
DeletePlaceholderGroup(placeholderGroup);
|
||||
DeleteSearchResult(search_result);
|
||||
DeleteCollection(c_collection);
|
||||
DeleteSegment(segment);
|
||||
}
|
||||
|
||||
TEST(CApiTest, RANGE_SEARCH_WITH_RADIUS_AND_RANGE_FILTER_WHEN_L2) {
|
||||
auto c_collection = NewCollection(get_default_schema_config());
|
||||
auto segment = NewSegment(c_collection, Growing, -1);
|
||||
auto col = (milvus::segcore::Collection*)c_collection;
|
||||
|
||||
int N = 10000;
|
||||
auto dataset = DataGen(col->get_schema(), N);
|
||||
int64_t ts_offset = 1000;
|
||||
|
||||
int64_t offset;
|
||||
PreInsert(segment, N, &offset);
|
||||
|
||||
auto insert_data = serialize(dataset.raw_);
|
||||
auto ins_res = Insert(segment, offset, N, dataset.row_ids_.data(), dataset.timestamps_.data(), insert_data.data(),
|
||||
insert_data.size());
|
||||
ASSERT_EQ(ins_res.error_code, Success);
|
||||
|
||||
const char* dsl_string = R"(
|
||||
{
|
||||
"bool": {
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10,
|
||||
"radius": 20,
|
||||
"range_filter": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 10,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
})";
|
||||
|
||||
int num_queries = 10;
|
||||
auto blob = generate_query_data(num_queries);
|
||||
|
||||
void* plan = nullptr;
|
||||
auto status = CreateSearchPlan(c_collection, dsl_string, &plan);
|
||||
ASSERT_EQ(status.error_code, Success);
|
||||
|
||||
void* placeholderGroup = nullptr;
|
||||
status = ParsePlaceholderGroup(plan, blob.data(), blob.length(), &placeholderGroup);
|
||||
ASSERT_EQ(status.error_code, Success);
|
||||
|
||||
std::vector<CPlaceholderGroup> placeholderGroups;
|
||||
placeholderGroups.push_back(placeholderGroup);
|
||||
|
||||
CSearchResult search_result;
|
||||
auto res = Search(segment, plan, placeholderGroup, ts_offset, &search_result);
|
||||
ASSERT_EQ(res.error_code, Success);
|
||||
|
||||
DeleteSearchPlan(plan);
|
||||
DeletePlaceholderGroup(placeholderGroup);
|
||||
DeleteSearchResult(search_result);
|
||||
DeleteCollection(c_collection);
|
||||
DeleteSegment(segment);
|
||||
}
|
||||
|
@ -145,6 +145,7 @@ TEST(Indexing, BinaryBruteForce) {
|
||||
int64_t topk = 5;
|
||||
int64_t round_decimal = 3;
|
||||
int64_t dim = 8192;
|
||||
Config search_params_ = {};
|
||||
auto metric_type = knowhere::metric::JACCARD;
|
||||
auto result_count = topk * num_queries;
|
||||
auto schema = std::make_shared<Schema>();
|
||||
@ -162,7 +163,7 @@ TEST(Indexing, BinaryBruteForce) {
|
||||
query_data //
|
||||
};
|
||||
|
||||
auto sub_result = query::BruteForceSearch(search_dataset, bin_vec.data(), N, nullptr);
|
||||
auto sub_result = query::BruteForceSearch(search_dataset, bin_vec.data(), N, knowhere::Json(), nullptr);
|
||||
|
||||
SearchResult sr;
|
||||
sr.total_nq_ = num_queries;
|
||||
@ -293,6 +294,7 @@ class IndexTest : public ::testing::TestWithParam<Param> {
|
||||
build_conf = generate_build_conf(index_type, metric_type);
|
||||
load_conf = generate_load_conf(index_type, metric_type, NB);
|
||||
search_conf = generate_search_conf(index_type, metric_type);
|
||||
range_search_conf = generate_range_search_conf(index_type, metric_type);
|
||||
|
||||
std::map<knowhere::MetricType, bool> is_binary_map = {
|
||||
{knowhere::IndexEnum::INDEX_FAISS_IDMAP, false},
|
||||
@ -335,6 +337,7 @@ class IndexTest : public ::testing::TestWithParam<Param> {
|
||||
milvus::Config build_conf;
|
||||
milvus::Config load_conf;
|
||||
milvus::Config search_conf;
|
||||
milvus::Config range_search_conf;
|
||||
milvus::DataType vec_field_data_type;
|
||||
knowhere::DataSetPtr xb_dataset;
|
||||
std::vector<float> xb_data;
|
||||
@ -357,9 +360,9 @@ INSTANTIATE_TEST_CASE_P(
|
||||
std::pair(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, knowhere::metric::JACCARD),
|
||||
std::pair(knowhere::IndexEnum::INDEX_HNSW, knowhere::metric::L2),
|
||||
// ci ut not start minio, so not run ut about diskann index for now
|
||||
//#ifdef BUILD_DISK_ANN
|
||||
// #ifdef BUILD_DISK_ANN
|
||||
// std::pair(knowhere::IndexEnum::INDEX_DISKANN, knowhere::metric::L2),
|
||||
//#endif
|
||||
// #endif
|
||||
std::pair(knowhere::IndexEnum::INDEX_ANNOY, knowhere::metric::L2)));
|
||||
|
||||
TEST_P(IndexTest, BuildAndQuery) {
|
||||
@ -422,76 +425,79 @@ TEST_P(IndexTest, BuildAndQuery) {
|
||||
if (!is_binary) {
|
||||
EXPECT_EQ(result->seg_offsets_[0], query_offset);
|
||||
}
|
||||
if (index_type != knowhere::IndexEnum::INDEX_ANNOY) {
|
||||
search_info.search_params_ = range_search_conf;
|
||||
vec_index->Query(xq_dataset, search_info, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
//#ifdef BUILD_DISK_ANN
|
||||
// TEST(Indexing, SearchDiskAnnWithInvalidParam) {
|
||||
// int64_t NB = 10000;
|
||||
// IndexType index_type = knowhere::IndexEnum::INDEX_DISKANN;
|
||||
// MetricType metric_type = knowhere::metric::L2;
|
||||
// milvus::index::CreateIndexInfo create_index_info;
|
||||
// create_index_info.index_type = index_type;
|
||||
// create_index_info.metric_type = metric_type;
|
||||
// create_index_info.field_type = milvus::DataType::VECTOR_FLOAT;
|
||||
// #ifdef BUILD_DISK_ANN
|
||||
// TEST(Indexing, SearchDiskAnnWithInvalidParam) {
|
||||
// int64_t NB = 10000;
|
||||
// IndexType index_type = knowhere::IndexEnum::INDEX_DISKANN;
|
||||
// MetricType metric_type = knowhere::metric::L2;
|
||||
// milvus::index::CreateIndexInfo create_index_info;
|
||||
// create_index_info.index_type = index_type;
|
||||
// create_index_info.metric_type = metric_type;
|
||||
// create_index_info.field_type = milvus::DataType::VECTOR_FLOAT;
|
||||
//
|
||||
// StorageConfig storage_config = get_default_storage_config();
|
||||
// auto rcm = std::make_shared<storage::MinioChunkManager>(storage_config);
|
||||
// if (!rcm->BucketExists(storage_config.bucket_name)) {
|
||||
// rcm->CreateBucket(storage_config.bucket_name);
|
||||
// }
|
||||
// milvus::storage::FieldDataMeta field_data_meta{1, 2, 3, 100};
|
||||
// milvus::storage::IndexMeta index_meta{3, 100, 1000, 1};
|
||||
// auto file_manager =
|
||||
// std::make_shared<milvus::storage::DiskFileManagerImpl>(field_data_meta, index_meta, storage_config);
|
||||
// auto index = milvus::index::IndexFactory::GetInstance().CreateIndex(create_index_info, file_manager);
|
||||
// StorageConfig storage_config = get_default_storage_config();
|
||||
// auto rcm = std::make_shared<storage::MinioChunkManager>(storage_config);
|
||||
// if (!rcm->BucketExists(storage_config.bucket_name)) {
|
||||
// rcm->CreateBucket(storage_config.bucket_name);
|
||||
// }
|
||||
// milvus::storage::FieldDataMeta field_data_meta{1, 2, 3, 100};
|
||||
// milvus::storage::IndexMeta index_meta{3, 100, 1000, 1};
|
||||
// auto file_manager =
|
||||
// std::make_shared<milvus::storage::DiskFileManagerImpl>(field_data_meta, index_meta, storage_config);
|
||||
// auto index = milvus::index::IndexFactory::GetInstance().CreateIndex(create_index_info, file_manager);
|
||||
//
|
||||
// auto build_conf = knowhere::Config{
|
||||
// {knowhere::meta::METRIC_TYPE, metric_type},
|
||||
// {knowhere::meta::DIM, std::to_string(DIM)},
|
||||
// {milvus::index::DISK_ANN_MAX_DEGREE, std::to_string(48)},
|
||||
// {milvus::index::DISK_ANN_SEARCH_LIST_SIZE, std::to_string(128)},
|
||||
// {milvus::index::DISK_ANN_PQ_CODE_BUDGET, std::to_string(0.001)},
|
||||
// {milvus::index::DISK_ANN_BUILD_DRAM_BUDGET, std::to_string(2)},
|
||||
// };
|
||||
// auto build_conf = knowhere::Config{
|
||||
// {knowhere::meta::METRIC_TYPE, metric_type},
|
||||
// {knowhere::meta::DIM, std::to_string(DIM)},
|
||||
// {milvus::index::DISK_ANN_MAX_DEGREE, std::to_string(48)},
|
||||
// {milvus::index::DISK_ANN_SEARCH_LIST_SIZE, std::to_string(128)},
|
||||
// {milvus::index::DISK_ANN_PQ_CODE_BUDGET, std::to_string(0.001)},
|
||||
// {milvus::index::DISK_ANN_BUILD_DRAM_BUDGET, std::to_string(2)},
|
||||
// };
|
||||
//
|
||||
// // build disk ann index
|
||||
// auto dataset = GenDataset(NB, metric_type, false);
|
||||
// std::vector<float> xb_data = dataset.get_col<float>(milvus::FieldId(100));
|
||||
// knowhere::DatasetPtr xb_dataset = knowhere::GenDataset(NB, DIM, xb_data.data());
|
||||
// ASSERT_NO_THROW(index->BuildWithDataset(xb_dataset, build_conf));
|
||||
// // build disk ann index
|
||||
// auto dataset = GenDataset(NB, metric_type, false);
|
||||
// std::vector<float> xb_data = dataset.get_col<float>(milvus::FieldId(100));
|
||||
// knowhere::DatasetPtr xb_dataset = knowhere::GenDataset(NB, DIM, xb_data.data());
|
||||
// ASSERT_NO_THROW(index->BuildWithDataset(xb_dataset, build_conf));
|
||||
//
|
||||
// // serialize and load disk index, disk index can only be search after loading for now
|
||||
// auto binary_set = index->Serialize(milvus::Config{});
|
||||
// index.reset();
|
||||
// // clean local file dir
|
||||
// file_manager.reset();
|
||||
// // serialize and load disk index, disk index can only be search after loading for now
|
||||
// auto binary_set = index->Serialize(milvus::Config{});
|
||||
// index.reset();
|
||||
// // clean local file dir
|
||||
// file_manager.reset();
|
||||
//
|
||||
// auto new_file_manager =
|
||||
// std::make_shared<milvus::storage::DiskFileManagerImpl>(field_data_meta, index_meta, storage_config);
|
||||
// auto new_index = milvus::index::IndexFactory::GetInstance().CreateIndex(create_index_info, new_file_manager);
|
||||
// auto vec_index = dynamic_cast<milvus::index::VectorIndex*>(new_index.get());
|
||||
// std::vector<std::string> index_files;
|
||||
// for (auto& binary : binary_set.binary_map_) {
|
||||
// index_files.emplace_back(binary.first);
|
||||
// }
|
||||
// auto load_conf = generate_load_conf(index_type, metric_type, NB);
|
||||
// load_conf["index_files"] = index_files;
|
||||
// vec_index->Load(binary_set, load_conf);
|
||||
// EXPECT_EQ(vec_index->Count(), NB);
|
||||
// auto new_file_manager =
|
||||
// std::make_shared<milvus::storage::DiskFileManagerImpl>(field_data_meta, index_meta, storage_config);
|
||||
// auto new_index = milvus::index::IndexFactory::GetInstance().CreateIndex(create_index_info, new_file_manager);
|
||||
// auto vec_index = dynamic_cast<milvus::index::VectorIndex*>(new_index.get());
|
||||
// std::vector<std::string> index_files;
|
||||
// for (auto& binary : binary_set.binary_map_) {
|
||||
// index_files.emplace_back(binary.first);
|
||||
// }
|
||||
// auto load_conf = generate_load_conf(index_type, metric_type, NB);
|
||||
// load_conf["index_files"] = index_files;
|
||||
// vec_index->Load(binary_set, load_conf);
|
||||
// EXPECT_EQ(vec_index->Count(), NB);
|
||||
//
|
||||
// // search disk index with search_list == limit
|
||||
// int query_offset = 100;
|
||||
// knowhere::DatasetPtr xq_dataset = knowhere::GenDataset(NQ, DIM, xb_data.data() + DIM * query_offset);
|
||||
// // search disk index with search_list == limit
|
||||
// int query_offset = 100;
|
||||
// knowhere::DatasetPtr xq_dataset = knowhere::GenDataset(NQ, DIM, xb_data.data() + DIM * query_offset);
|
||||
//
|
||||
// milvus::SearchInfo search_info;
|
||||
// search_info.topk_ = K;
|
||||
// search_info.metric_type_ = metric_type;
|
||||
// search_info.search_params_ = milvus::Config{
|
||||
// {knowhere::meta::METRIC_TYPE, metric_type},
|
||||
// {milvus::index::DISK_ANN_QUERY_LIST, K - 1},
|
||||
// };
|
||||
// EXPECT_THROW(vec_index->Query(xq_dataset, search_info, nullptr), std::runtime_error);
|
||||
// // vec_index->Query(xq_dataset, search_info, nullptr);
|
||||
//}
|
||||
//#endif
|
||||
|
||||
// milvus::SearchInfo search_info;
|
||||
// search_info.topk_ = K;
|
||||
// search_info.metric_type_ = metric_type;
|
||||
// search_info.search_params_ = milvus::Config{
|
||||
// {knowhere::meta::METRIC_TYPE, metric_type},
|
||||
// {milvus::index::DISK_ANN_QUERY_LIST, K - 1},
|
||||
// };
|
||||
// EXPECT_THROW(vec_index->Query(xq_dataset, search_info, nullptr), std::runtime_error);
|
||||
// // vec_index->Query(xq_dataset, search_info, nullptr);
|
||||
// }
|
||||
// #endif
|
||||
|
165
internal/core/unittest/test_range_search_sort.cpp
Normal file
165
internal/core/unittest/test_range_search_sort.cpp
Normal file
@ -0,0 +1,165 @@
|
||||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <queue>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
|
||||
#include "common/RangeSearchHelper.h"
|
||||
#include "common/Types.h"
|
||||
#include "common/Utils.h"
|
||||
#include "common/Schema.h"
|
||||
#include "test_utils/indexbuilder_test_utils.h"
|
||||
|
||||
bool
|
||||
cmp1(std::pair<float, int64_t> a, std::pair<float, int64_t> b) {
|
||||
return a.first > b.first;
|
||||
}
|
||||
|
||||
bool
|
||||
cmp2(std::pair<float, int64_t> a, std::pair<float, int64_t> b) {
|
||||
return a.first < b.first;
|
||||
}
|
||||
|
||||
auto
|
||||
RangeSearchSortResultBF(milvus::DatasetPtr data_set, int64_t topk, size_t nq, std::string metric_type) {
|
||||
auto lims = milvus::GetDatasetLims(data_set);
|
||||
auto id = milvus::GetDatasetIDs(data_set);
|
||||
auto dist = milvus::GetDatasetDistance(data_set);
|
||||
auto p_id = new int64_t[topk * nq];
|
||||
auto p_dist = new float[topk * nq];
|
||||
// cnt means the subscript of p_id and p_dist
|
||||
int cnt = 0;
|
||||
for (int i = 0; i < nq; i++) {
|
||||
auto size = lims[i + 1] - lims[i];
|
||||
int capacity = topk > size ? size : topk;
|
||||
// sort each layer
|
||||
std::vector<std::pair<float, int64_t>> list;
|
||||
if (milvus::IsMetricType(metric_type, knowhere::metric::IP)) {
|
||||
for (int j = lims[i]; j < lims[i + 1]; j++) {
|
||||
list.push_back(std::pair<float, int64_t>(dist[j], id[j]));
|
||||
}
|
||||
std::sort(list.begin(), list.end(), cmp1);
|
||||
|
||||
} else {
|
||||
for (int j = lims[i]; j < lims[i + 1]; j++) {
|
||||
list.push_back(std::pair<float, int64_t>(dist[j], id[j]));
|
||||
}
|
||||
std::sort(list.begin(), list.end(), cmp2);
|
||||
}
|
||||
for (int k = 0; k < capacity; k++) {
|
||||
p_dist[cnt] = list[k].first;
|
||||
p_id[cnt] = list[k].second;
|
||||
cnt++;
|
||||
}
|
||||
}
|
||||
return std::make_tuple(cnt, p_id, p_dist);
|
||||
}
|
||||
|
||||
milvus::DatasetPtr
|
||||
genResultDataset(const int64_t nq, const int64_t* ids, const float* distance, const size_t* lims) {
|
||||
auto ret_ds = std::make_shared<milvus::Dataset>();
|
||||
ret_ds->SetRows(nq);
|
||||
ret_ds->SetIds(ids);
|
||||
ret_ds->SetDistance(distance);
|
||||
ret_ds->SetLims(lims);
|
||||
ret_ds->SetIsOwner(true);
|
||||
return ret_ds;
|
||||
}
|
||||
|
||||
void
|
||||
CheckRangeSearchSortResult(int64_t* p_id, float* p_dist, milvus::DatasetPtr dataset, int64_t n) {
|
||||
auto id = milvus::GetDatasetIDs(dataset);
|
||||
auto dist = milvus::GetDatasetDistance(dataset);
|
||||
for (int i = 0; i < n; i++) {
|
||||
AssertInfo(id[i] == p_id[i], "id of range search result are not the same");
|
||||
AssertInfo(dist[i] == p_dist[i], "distance of range search result are not the same");
|
||||
}
|
||||
}
|
||||
|
||||
auto
|
||||
GenRangeSearchResult(int64_t* ids,
|
||||
float* distances,
|
||||
size_t* lims,
|
||||
int64_t N,
|
||||
int64_t id_min,
|
||||
int64_t id_max,
|
||||
float distance_min,
|
||||
float distance_max,
|
||||
int seed = 42) {
|
||||
std::mt19937 e(seed);
|
||||
std::uniform_int_distribution<> uniform_num(0, N);
|
||||
std::uniform_int_distribution<> uniform_ids(id_min, id_max);
|
||||
std::uniform_real_distribution<> uniform_distance(distance_min, distance_max);
|
||||
|
||||
lims = new size_t[N + 1];
|
||||
// alloc max memory
|
||||
distances = new float[N * N];
|
||||
ids = new int64_t[N * N];
|
||||
lims[0] = 0;
|
||||
for (int64_t i = 0; i < N; i++) {
|
||||
int64_t num = uniform_num(e);
|
||||
for (int64_t j = 0; j < num; j++) {
|
||||
auto id = uniform_ids(e);
|
||||
auto dis = uniform_distance(e);
|
||||
ids[lims[i] + j] = id;
|
||||
distances[lims[i] + j] = dis;
|
||||
}
|
||||
lims[i + 1] = lims[i] + num;
|
||||
}
|
||||
return genResultDataset(N, ids, distances, lims);
|
||||
}
|
||||
|
||||
class RangeSearchSortTest : public ::testing::TestWithParam<knowhere::MetricType> {
|
||||
protected:
|
||||
void
|
||||
SetUp() override {
|
||||
metric_type = GetParam();
|
||||
dataset = GenRangeSearchResult(ids, distances, lims, N, id_min, id_max, dist_min, dist_max);
|
||||
}
|
||||
|
||||
void
|
||||
TearDown() override {
|
||||
delete[] ids;
|
||||
delete[] distances;
|
||||
delete[] lims;
|
||||
}
|
||||
|
||||
protected:
|
||||
knowhere::MetricType metric_type;
|
||||
milvus::DatasetPtr dataset = nullptr;
|
||||
int64_t N = 100;
|
||||
int64_t TOPK = 10;
|
||||
int64_t DIM = 16;
|
||||
int64_t* ids = nullptr;
|
||||
float* distances = nullptr;
|
||||
size_t* lims = nullptr;
|
||||
int64_t id_min = 0, id_max = 10000;
|
||||
float dist_min = 0.0, dist_max = 100.0;
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(RangeSearchSortParameters,
|
||||
RangeSearchSortTest,
|
||||
::testing::Values(knowhere::metric::L2,
|
||||
knowhere::metric::IP,
|
||||
knowhere::metric::JACCARD,
|
||||
knowhere::metric::TANIMOTO,
|
||||
knowhere::metric::HAMMING));
|
||||
|
||||
TEST_P(RangeSearchSortTest, CheckRangeSearchSort) {
|
||||
auto res = milvus::SortRangeSearchResult(dataset, TOPK, N, metric_type);
|
||||
auto [real_num, p_id, p_dist] = RangeSearchSortResultBF(dataset, TOPK, N, metric_type);
|
||||
CheckRangeSearchSortResult(p_id, p_dist, res, real_num);
|
||||
delete[] p_id;
|
||||
delete[] p_dist;
|
||||
}
|
@ -84,8 +84,8 @@ TEST(Sealed, without_predicate) {
|
||||
auto indexing = milvus::index::IndexFactory::GetInstance().CreateIndex(create_index_info, nullptr);
|
||||
|
||||
auto build_conf = knowhere::Json{{knowhere::meta::METRIC_TYPE, knowhere::metric::L2},
|
||||
{knowhere::meta::DIM, std::to_string(dim)},
|
||||
{knowhere::indexparam::NLIST, "100"}};
|
||||
{knowhere::meta::DIM, std::to_string(dim)},
|
||||
{knowhere::indexparam::NLIST, "100"}};
|
||||
|
||||
auto search_conf = knowhere::Json{{knowhere::indexparam::NPROBE, 10}};
|
||||
|
||||
@ -190,8 +190,8 @@ TEST(Sealed, with_predicate) {
|
||||
auto indexing = milvus::index::IndexFactory::GetInstance().CreateIndex(create_index_info, nullptr);
|
||||
|
||||
auto build_conf = knowhere::Json{{knowhere::meta::METRIC_TYPE, knowhere::metric::L2},
|
||||
{knowhere::meta::DIM, std::to_string(dim)},
|
||||
{knowhere::indexparam::NLIST, "100"}};
|
||||
{knowhere::meta::DIM, std::to_string(dim)},
|
||||
{knowhere::indexparam::NLIST, "100"}};
|
||||
|
||||
auto database = knowhere::GenDataSet(N, dim, vec_col.data());
|
||||
indexing->BuildWithDataset(database, build_conf);
|
||||
@ -288,8 +288,8 @@ TEST(Sealed, with_predicate_filter_all) {
|
||||
auto ivf_indexing = milvus::index::IndexFactory::GetInstance().CreateIndex(create_index_info, nullptr);
|
||||
|
||||
auto ivf_build_conf = knowhere::Json{{knowhere::meta::DIM, std::to_string(dim)},
|
||||
{knowhere::indexparam::NLIST, "100"},
|
||||
{knowhere::meta::METRIC_TYPE, knowhere::metric::L2}};
|
||||
{knowhere::indexparam::NLIST, "100"},
|
||||
{knowhere::meta::METRIC_TYPE, knowhere::metric::L2}};
|
||||
|
||||
auto database = knowhere::GenDataSet(N, dim, vec_col.data());
|
||||
ivf_indexing->BuildWithDataset(database, ivf_build_conf);
|
||||
@ -312,10 +312,10 @@ TEST(Sealed, with_predicate_filter_all) {
|
||||
EXPECT_EQ(sr->get_total_result_count(), 0);
|
||||
|
||||
auto hnsw_conf = knowhere::Json{{knowhere::meta::DIM, std::to_string(dim)},
|
||||
{knowhere::indexparam::HNSW_M, "16"},
|
||||
{knowhere::indexparam::EFCONSTRUCTION, "200"},
|
||||
{knowhere::indexparam::EF, "200"},
|
||||
{knowhere::meta::METRIC_TYPE, knowhere::metric::L2}};
|
||||
{knowhere::indexparam::HNSW_M, "16"},
|
||||
{knowhere::indexparam::EFCONSTRUCTION, "200"},
|
||||
{knowhere::indexparam::EF, "200"},
|
||||
{knowhere::meta::METRIC_TYPE, knowhere::metric::L2}};
|
||||
|
||||
create_index_info.field_type = DataType::VECTOR_FLOAT;
|
||||
create_index_info.metric_type = knowhere::metric::L2;
|
||||
@ -437,7 +437,7 @@ TEST(Sealed, LoadFieldData) {
|
||||
// ASSERT_EQ(json.dump(-2), json2.dump(-2));
|
||||
// segment->DropFieldData(double_id);
|
||||
// ASSERT_ANY_THROW(segment->Search(plan.get(), ph_group.get(), time));
|
||||
//#ifdef __linux__
|
||||
// #ifdef __linux__
|
||||
// auto std_json = Json::parse(R"(
|
||||
//[
|
||||
// [
|
||||
@ -448,7 +448,7 @@ TEST(Sealed, LoadFieldData) {
|
||||
// ["66353->5.696000", "30664->5.881000", "41087->5.917000", "10393->6.633000", "90215->7.202000"]
|
||||
// ]
|
||||
//])");
|
||||
//#else // for mac
|
||||
// #else // for mac
|
||||
// auto std_json = Json::parse(R"(
|
||||
//[
|
||||
// [
|
||||
@ -459,7 +459,7 @@ TEST(Sealed, LoadFieldData) {
|
||||
// ["37759->3.581000", "31292->5.780000", "98124->6.216000", "63535->6.439000", "11707->6.553000"]
|
||||
// ]
|
||||
//])");
|
||||
//#endif
|
||||
// #endif
|
||||
// ASSERT_EQ(std_json.dump(-2), json.dump(-2));
|
||||
}
|
||||
|
||||
|
@ -538,7 +538,7 @@ TEST(AlwaysTrueStringPlan, SearchWithOutputFields) {
|
||||
dim, //
|
||||
query_ptr //
|
||||
};
|
||||
auto sub_result = BruteForceSearch(search_dataset, vec_col.data(), N, nullptr);
|
||||
auto sub_result = BruteForceSearch(search_dataset, vec_col.data(), N, knowhere::Json(), nullptr);
|
||||
|
||||
auto sr = segment->Search(plan.get(), ph_group.get(), time);
|
||||
segment->FillPrimaryKeys(plan.get(), *sr);
|
||||
|
@ -292,6 +292,32 @@ generate_search_conf(const milvus::IndexType& index_type, const milvus::MetricTy
|
||||
return conf;
|
||||
}
|
||||
|
||||
auto
|
||||
generate_range_search_conf(const milvus::IndexType& index_type, const milvus::MetricType& metric_type) {
|
||||
auto conf = milvus::Config{
|
||||
{knowhere::meta::METRIC_TYPE, metric_type},
|
||||
};
|
||||
|
||||
if (metric_type == knowhere::metric::IP) {
|
||||
conf[knowhere::meta::RADIUS] = 0.1;
|
||||
conf[knowhere::meta::RANGE_FILTER] = 0.2;
|
||||
} else {
|
||||
conf[knowhere::meta::RADIUS] = 0.2;
|
||||
conf[knowhere::meta::RANGE_FILTER] = 0.1;
|
||||
}
|
||||
|
||||
if (milvus::index::is_in_list<milvus::IndexType>(index_type, search_with_nprobe_list)) {
|
||||
conf[knowhere::indexparam::NPROBE] = 4;
|
||||
} else if (index_type == knowhere::IndexEnum::INDEX_HNSW) {
|
||||
conf[knowhere::indexparam::EF] = 200;
|
||||
} else if (index_type == knowhere::IndexEnum::INDEX_ANNOY) {
|
||||
conf[knowhere::indexparam::SEARCH_K] = 100;
|
||||
} else if (index_type == knowhere::IndexEnum::INDEX_DISKANN) {
|
||||
conf[milvus::index::DISK_ANN_QUERY_LIST] = K * 2;
|
||||
}
|
||||
return conf;
|
||||
}
|
||||
|
||||
auto
|
||||
generate_params(const knowhere::IndexType& index_type, const knowhere::MetricType& metric_type) {
|
||||
namespace indexcgo = milvus::proto::indexcgo;
|
||||
|
Loading…
Reference in New Issue
Block a user