mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-04 21:09:06 +08:00
Merge branch 'ms544' into 'branch-0.4.0'
MS-544 fix See merge request megasearch/milvus!552 Former-commit-id: 2c99ce1fce5f15e36cfa854d9d688bc52fc68a02
This commit is contained in:
commit
d4ea0151eb
@ -26,17 +26,17 @@ namespace knowhere {
|
|||||||
|
|
||||||
IndexModelPtr GPUIVF::Train(const DatasetPtr &dataset, const Config &config) {
|
IndexModelPtr GPUIVF::Train(const DatasetPtr &dataset, const Config &config) {
|
||||||
auto nlist = config["nlist"].as<size_t>();
|
auto nlist = config["nlist"].as<size_t>();
|
||||||
auto gpu_device = config.get_with_default("gpu_id", gpu_id_);
|
gpu_id_ = config.get_with_default("gpu_id", gpu_id_);
|
||||||
auto metric_type = config["metric_type"].as_string() == "L2" ?
|
auto metric_type = config["metric_type"].as_string() == "L2" ?
|
||||||
faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
|
faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
|
||||||
|
|
||||||
GETTENSOR(dataset)
|
GETTENSOR(dataset)
|
||||||
|
|
||||||
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_device);
|
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
|
||||||
if (temp_resource != nullptr) {
|
if (temp_resource != nullptr) {
|
||||||
ResScope rs(gpu_device, temp_resource);
|
ResScope rs(gpu_id_, temp_resource);
|
||||||
faiss::gpu::GpuIndexIVFFlatConfig idx_config;
|
faiss::gpu::GpuIndexIVFFlatConfig idx_config;
|
||||||
idx_config.device = gpu_device;
|
idx_config.device = gpu_id_;
|
||||||
faiss::gpu::GpuIndexIVFFlat device_index(temp_resource->faiss_res.get(), dim, nlist, metric_type, idx_config);
|
faiss::gpu::GpuIndexIVFFlat device_index(temp_resource->faiss_res.get(), dim, nlist, metric_type, idx_config);
|
||||||
device_index.train(rows, (float *) p_data);
|
device_index.train(rows, (float *) p_data);
|
||||||
|
|
||||||
@ -204,7 +204,7 @@ VectorIndexPtr GPUIVFPQ::CopyGpuToCpu(const Config &config) {
|
|||||||
IndexModelPtr GPUIVFSQ::Train(const DatasetPtr &dataset, const Config &config) {
|
IndexModelPtr GPUIVFSQ::Train(const DatasetPtr &dataset, const Config &config) {
|
||||||
auto nlist = config["nlist"].as<size_t>();
|
auto nlist = config["nlist"].as<size_t>();
|
||||||
auto nbits = config["nbits"].as<size_t>(); // TODO(linxj): gpu only support SQ4 SQ8 SQ16
|
auto nbits = config["nbits"].as<size_t>(); // TODO(linxj): gpu only support SQ4 SQ8 SQ16
|
||||||
auto gpu_num = config.get_with_default("gpu_id", gpu_id_);
|
gpu_id_ = config.get_with_default("gpu_id", gpu_id_);
|
||||||
auto metric_type = config["metric_type"].as_string() == "L2" ?
|
auto metric_type = config["metric_type"].as_string() == "L2" ?
|
||||||
faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
|
faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
|
||||||
|
|
||||||
@ -214,10 +214,10 @@ IndexModelPtr GPUIVFSQ::Train(const DatasetPtr &dataset, const Config &config) {
|
|||||||
index_type << "IVF" << nlist << "," << "SQ" << nbits;
|
index_type << "IVF" << nlist << "," << "SQ" << nbits;
|
||||||
auto build_index = faiss::index_factory(dim, index_type.str().c_str(), metric_type);
|
auto build_index = faiss::index_factory(dim, index_type.str().c_str(), metric_type);
|
||||||
|
|
||||||
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_num);
|
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
|
||||||
if (temp_resource != nullptr) {
|
if (temp_resource != nullptr) {
|
||||||
ResScope rs(gpu_num, temp_resource );
|
ResScope rs(gpu_id_, temp_resource );
|
||||||
auto device_index = faiss::gpu::index_cpu_to_gpu(temp_resource->faiss_res.get(), gpu_num, build_index);
|
auto device_index = faiss::gpu::index_cpu_to_gpu(temp_resource->faiss_res.get(), gpu_id_, build_index);
|
||||||
device_index->train(rows, (float *) p_data);
|
device_index->train(rows, (float *) p_data);
|
||||||
|
|
||||||
std::shared_ptr<faiss::Index> host_index = nullptr;
|
std::shared_ptr<faiss::Index> host_index = nullptr;
|
||||||
|
Loading…
Reference in New Issue
Block a user