Merge branch 'ms544' into 'branch-0.4.0'

MS-544 fix

See merge request megasearch/milvus!552

Former-commit-id: 2c99ce1fce5f15e36cfa854d9d688bc52fc68a02
This commit is contained in:
jinhai 2019-09-11 20:18:02 +08:00
commit d4ea0151eb

View File

@ -26,17 +26,17 @@ namespace knowhere {
IndexModelPtr GPUIVF::Train(const DatasetPtr &dataset, const Config &config) { IndexModelPtr GPUIVF::Train(const DatasetPtr &dataset, const Config &config) {
auto nlist = config["nlist"].as<size_t>(); auto nlist = config["nlist"].as<size_t>();
auto gpu_device = config.get_with_default("gpu_id", gpu_id_); gpu_id_ = config.get_with_default("gpu_id", gpu_id_);
auto metric_type = config["metric_type"].as_string() == "L2" ? auto metric_type = config["metric_type"].as_string() == "L2" ?
faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT; faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
GETTENSOR(dataset) GETTENSOR(dataset)
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_device); auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
if (temp_resource != nullptr) { if (temp_resource != nullptr) {
ResScope rs(gpu_device, temp_resource); ResScope rs(gpu_id_, temp_resource);
faiss::gpu::GpuIndexIVFFlatConfig idx_config; faiss::gpu::GpuIndexIVFFlatConfig idx_config;
idx_config.device = gpu_device; idx_config.device = gpu_id_;
faiss::gpu::GpuIndexIVFFlat device_index(temp_resource->faiss_res.get(), dim, nlist, metric_type, idx_config); faiss::gpu::GpuIndexIVFFlat device_index(temp_resource->faiss_res.get(), dim, nlist, metric_type, idx_config);
device_index.train(rows, (float *) p_data); device_index.train(rows, (float *) p_data);
@ -204,7 +204,7 @@ VectorIndexPtr GPUIVFPQ::CopyGpuToCpu(const Config &config) {
IndexModelPtr GPUIVFSQ::Train(const DatasetPtr &dataset, const Config &config) { IndexModelPtr GPUIVFSQ::Train(const DatasetPtr &dataset, const Config &config) {
auto nlist = config["nlist"].as<size_t>(); auto nlist = config["nlist"].as<size_t>();
auto nbits = config["nbits"].as<size_t>(); // TODO(linxj): gpu only support SQ4 SQ8 SQ16 auto nbits = config["nbits"].as<size_t>(); // TODO(linxj): gpu only support SQ4 SQ8 SQ16
auto gpu_num = config.get_with_default("gpu_id", gpu_id_); gpu_id_ = config.get_with_default("gpu_id", gpu_id_);
auto metric_type = config["metric_type"].as_string() == "L2" ? auto metric_type = config["metric_type"].as_string() == "L2" ?
faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT; faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
@ -214,10 +214,10 @@ IndexModelPtr GPUIVFSQ::Train(const DatasetPtr &dataset, const Config &config) {
index_type << "IVF" << nlist << "," << "SQ" << nbits; index_type << "IVF" << nlist << "," << "SQ" << nbits;
auto build_index = faiss::index_factory(dim, index_type.str().c_str(), metric_type); auto build_index = faiss::index_factory(dim, index_type.str().c_str(), metric_type);
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_num); auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
if (temp_resource != nullptr) { if (temp_resource != nullptr) {
ResScope rs(gpu_num, temp_resource ); ResScope rs(gpu_id_, temp_resource );
auto device_index = faiss::gpu::index_cpu_to_gpu(temp_resource->faiss_res.get(), gpu_num, build_index); auto device_index = faiss::gpu::index_cpu_to_gpu(temp_resource->faiss_res.get(), gpu_id_, build_index);
device_index->train(rows, (float *) p_data); device_index->train(rows, (float *) p_data);
std::shared_ptr<faiss::Index> host_index = nullptr; std::shared_ptr<faiss::Index> host_index = nullptr;