mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-02 03:48:37 +08:00
Add unittest IVFFlatNM for Indexing (#13044)
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
This commit is contained in:
parent
342200ce13
commit
15b932c63c
@ -17,6 +17,8 @@
|
||||
#include <vector>
|
||||
|
||||
#include "faiss/utils/distances.h"
|
||||
#include "knowhere/index/vector_index/IndexIVF.h"
|
||||
#include "knowhere/index/vector_offset_index/IndexIVF_NM.h"
|
||||
#include "query/SearchBruteForce.h"
|
||||
#include "segcore/Reduce.h"
|
||||
#include "test_utils/DataGen.h"
|
||||
@ -177,20 +179,22 @@ TEST(Indexing, Naive) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Indexing, IVFFlatNM) {
|
||||
constexpr auto DIM = 16;
|
||||
constexpr auto K = 10;
|
||||
TEST(Indexing, IVFFlat) {
|
||||
constexpr int N = 100000;
|
||||
constexpr int NQ = 10;
|
||||
constexpr int DIM = 16;
|
||||
constexpr int TOPK = 5;
|
||||
constexpr int NLIST = 128;
|
||||
constexpr int NPROBE = 16;
|
||||
|
||||
auto N = 1024 * 1024;
|
||||
auto num_query = 100;
|
||||
Timer timer;
|
||||
auto [raw_data, timestamps, uids] = generate_data<DIM>(N);
|
||||
std::cout << "generate data: " << timer.get_step_seconds() << " seconds" << std::endl;
|
||||
auto indexing = std::make_shared<knowhere::IVF>();
|
||||
auto conf = knowhere::Config{{knowhere::meta::DIM, DIM},
|
||||
{knowhere::meta::TOPK, K},
|
||||
{knowhere::IndexParams::nlist, 100},
|
||||
{knowhere::IndexParams::nprobe, 4},
|
||||
{knowhere::meta::TOPK, TOPK},
|
||||
{knowhere::IndexParams::nlist, NLIST},
|
||||
{knowhere::IndexParams::nprobe, NPROBE},
|
||||
{knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
|
||||
{knowhere::meta::DEVICEID, 0}};
|
||||
|
||||
@ -203,14 +207,63 @@ TEST(Indexing, IVFFlatNM) {
|
||||
|
||||
EXPECT_EQ(indexing->Count(), N);
|
||||
EXPECT_EQ(indexing->Dim(), DIM);
|
||||
auto dataset = knowhere::GenDataset(num_query, DIM, raw_data.data() + DIM * 4200);
|
||||
auto dataset = knowhere::GenDataset(NQ, DIM, raw_data.data() + DIM * 4200);
|
||||
|
||||
auto result = indexing->Query(dataset, conf, nullptr);
|
||||
std::cout << "query ivf " << timer.get_step_seconds() << " seconds" << std::endl;
|
||||
|
||||
auto ids = result->Get<int64_t*>(milvus::knowhere::meta::IDS);
|
||||
auto dis = result->Get<float*>(milvus::knowhere::meta::DISTANCE);
|
||||
for (int i = 0; i < std::min(num_query * K, 100); ++i) {
|
||||
for (int i = 0; i < std::min(NQ * TOPK, 100); ++i) {
|
||||
std::cout << ids[i] << "->" << dis[i] << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Indexing, IVFFlatNM) {
|
||||
constexpr int N = 100000;
|
||||
constexpr int NQ = 10;
|
||||
constexpr int DIM = 16;
|
||||
constexpr int TOPK = 5;
|
||||
constexpr int NLIST = 128;
|
||||
constexpr int NPROBE = 16;
|
||||
|
||||
Timer timer;
|
||||
auto [raw_data, timestamps, uids] = generate_data<DIM>(N);
|
||||
std::cout << "generate data: " << timer.get_step_seconds() << " seconds" << std::endl;
|
||||
auto indexing = std::make_shared<knowhere::IVF_NM>();
|
||||
auto conf = knowhere::Config{{knowhere::meta::DIM, DIM},
|
||||
{knowhere::meta::TOPK, TOPK},
|
||||
{knowhere::IndexParams::nlist, NLIST},
|
||||
{knowhere::IndexParams::nprobe, NPROBE},
|
||||
{knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
|
||||
{knowhere::meta::DEVICEID, 0}};
|
||||
|
||||
auto database = knowhere::GenDataset(N, DIM, raw_data.data());
|
||||
std::cout << "init ivf " << timer.get_step_seconds() << " seconds" << std::endl;
|
||||
indexing->Train(database, conf);
|
||||
std::cout << "train ivf " << timer.get_step_seconds() << " seconds" << std::endl;
|
||||
indexing->AddWithoutIds(database, conf);
|
||||
std::cout << "insert ivf " << timer.get_step_seconds() << " seconds" << std::endl;
|
||||
|
||||
indexing->SetIndexSize(NQ * DIM * sizeof(float));
|
||||
milvus::knowhere::BinarySet bs = indexing->Serialize(conf);
|
||||
|
||||
milvus::knowhere::BinaryPtr bptr = std::make_shared<milvus::knowhere::Binary>();
|
||||
bptr->data = std::shared_ptr<uint8_t[]>((uint8_t*)raw_data.data(), [&](uint8_t*) {});
|
||||
bptr->size = DIM * N * sizeof(float);
|
||||
bs.Append(RAW_DATA, bptr);
|
||||
indexing->Load(bs);
|
||||
|
||||
EXPECT_EQ(indexing->Count(), N);
|
||||
EXPECT_EQ(indexing->Dim(), DIM);
|
||||
auto dataset = knowhere::GenDataset(NQ, DIM, raw_data.data() + DIM * 4200);
|
||||
|
||||
auto result = indexing->Query(dataset, conf, nullptr);
|
||||
std::cout << "query ivf " << timer.get_step_seconds() << " seconds" << std::endl;
|
||||
|
||||
auto ids = result->Get<int64_t*>(milvus::knowhere::meta::IDS);
|
||||
auto dis = result->Get<float*>(milvus::knowhere::meta::DISTANCE);
|
||||
for (int i = 0; i < std::min(NQ * TOPK, 100); ++i) {
|
||||
std::cout << ids[i] << "->" << dis[i] << std::endl;
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user