Fix search successfully with invalid metric type (#17977)

Signed-off-by: longjiquan <jiquan.long@zilliz.com>
This commit is contained in:
Jiquan Long 2022-07-01 22:28:23 +08:00 committed by GitHub
parent ede529ac30
commit 6954a5ba3e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 179 additions and 1 deletions

View File

@ -101,10 +101,13 @@ FloatSearchBruteForce(const dataset::SearchDataset& dataset,
faiss::float_maxheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_seg_offsets(),
sub_qr.get_distances()};
faiss::knn_L2sqr(query_data, chunk_data, dim, num_queries, size_per_chunk, &buf, nullptr, bitset);
} else {
} else if (metric_type == knowhere::metric::IP) {
faiss::float_minheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_seg_offsets(),
sub_qr.get_distances()};
faiss::knn_inner_product(query_data, chunk_data, dim, num_queries, size_per_chunk, &buf, bitset);
} else {
std::string msg = "search not support metric type: " + metric_type;
PanicInfo(msg);
}
sub_qr.round_values();
return sub_qr;

View File

@ -17,6 +17,7 @@ add_definitions(-DMILVUS_TEST_SEGCORE_YAML_PATH="${CMAKE_SOURCE_DIR}/unittest/te
# TODO: better to use ls/find pattern
set(MILVUS_TEST_FILES
init_gtest.cpp
test_bf.cpp
test_binary.cpp
test_bitmap.cpp
test_bool_index.cpp

View File

@ -0,0 +1,139 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <gtest/gtest.h>
#include <random>
#include <knowhere/index/vector_index/helpers/IndexParameter.h>
#include "query/SearchBruteForce.h"
#include "test_utils/Distance.h"
#include "test_utils/DataGen.h"
using namespace milvus;
using namespace milvus::segcore;
using namespace milvus::query;
namespace {
auto
GenFloatVecs(int dim, int n, const knowhere::MetricType& metric, int seed = 42) {
auto schema = std::make_shared<Schema>();
auto fvec = schema->AddDebugField("fvec", DataType::VECTOR_FLOAT, dim, metric);
auto dataset = DataGen(schema, n, seed);
return dataset.get_col<float>(fvec);
}
// (offset, distance)
std::vector<std::tuple<int, float>>
Distances(const float* base,
const float* query, // one query.
int nb,
int dim,
const knowhere::MetricType& metric) {
if (metric == knowhere::metric::L2) {
std::vector<std::tuple<int, float>> res;
for (int i = 0; i < nb; i++) {
res.emplace_back(i, L2(base + i * dim, query, dim));
}
return res;
} else if (metric == knowhere::metric::IP) {
std::vector<std::tuple<int, float>> res;
for (int i = 0; i < nb; i++) {
res.emplace_back(i, IP(base + i * dim, query, dim));
}
return res;
} else {
PanicInfo("invalid metric type");
}
}
std::vector<int>
GetOffsets(const std::vector<std::tuple<int, float>>& tuples, int k) {
std::vector<int> offsets;
for (int i = 0; i < k; i++) {
auto [offset, distance] = tuples[i];
offsets.push_back(offset);
}
return offsets;
}
// offsets
std::vector<int>
Ref(const float* base,
const float* query, // one query.
int nb,
int dim,
int topk,
const knowhere::MetricType& metric) {
auto res = Distances(base, query, nb, dim, metric);
std::sort(res.begin(), res.end());
if (metric == knowhere::metric::L2) {
} else if (metric == knowhere::metric::IP) {
std::reverse(res.begin(), res.end());
} else {
PanicInfo("invalid metric type");
}
return GetOffsets(res, topk);
}
bool
AssertMatch(const std::vector<int>& ref, const int64_t* ans) {
for (int i = 0; i < ref.size(); i++) {
if (ref[i] != ans[i]) {
return false;
}
}
return true;
}
bool
is_supported_float_metric(const knowhere::MetricType& metric) {
return metric == knowhere::metric::L2 || metric == knowhere::metric::IP;
}
} // namespace
class TestFloatSearchBruteForce : public ::testing::Test {
public:
void
Run(int nb, int nq, int topk, int dim, const knowhere::MetricType& metric_type) {
auto bitset = std::make_shared<BitsetType>();
bitset->resize(nb);
auto bitset_view = BitsetView(*bitset);
auto base = GenFloatVecs(dim, nb, metric_type);
auto query = GenFloatVecs(dim, nq, metric_type);
dataset::SearchDataset dataset{metric_type, nq, topk, -1, dim, query.data()};
if (!is_supported_float_metric(metric_type)) {
ASSERT_ANY_THROW(FloatSearchBruteForce(dataset, base.data(), nb, bitset_view));
return;
}
auto result = FloatSearchBruteForce(dataset, base.data(), nb, bitset_view);
for (int i = 0; i < nq; i++) {
auto ref = Ref(base.data(), query.data() + i * dim, nb, dim, topk, metric_type);
auto ans = result.get_seg_offsets() + i * topk;
AssertMatch(ref, ans);
}
}
};
TEST_F(TestFloatSearchBruteForce, L2) {
Run(100, 10, 5, 128, knowhere::metric::L2);
}
TEST_F(TestFloatSearchBruteForce, IP) {
Run(100, 10, 5, 128, knowhere::metric::IP);
}
TEST_F(TestFloatSearchBruteForce, NotSupported) {
Run(100, 10, 5, 128, "aaaaaaaaaaaa");
}

View File

@ -0,0 +1,35 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
namespace {
float
L2(const float* point_a, const float* point_b, int dim) {
float dis = 0;
for (auto i = 0; i < dim; i++) {
auto c_a = point_a[i];
auto c_b = point_b[i];
dis += pow(c_b - c_a, 2);
}
return dis;
}
float
IP(const float* point_a, const float* point_b, int dim) {
float dis = 0;
for (auto i = 0; i < dim; i++) {
auto c_a = point_a[i];
auto c_b = point_b[i];
dis += c_a * c_b;
}
return dis;
}
} // namespace