#1653 IndexFlat performance improvement for NQ < thread_number (#1690)

* optimize sse

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>

* optimize BinaryDistance

Signed-off-by: shengjun.li <shengjun.li@zilliz.com>

* fix superstructure

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>

Co-authored-by: shengjun.li <shengjun.li@zilliz.com>
Co-authored-by: Jin Hai <hai.jin@zilliz.com>
Xiaohai Xu 2020-03-18 18:27:57 +08:00 committed by GitHub
parent 55ecfd5930
commit 1543f8dc37
5 changed files with 355 additions and 284 deletions
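
The change in every file below follows the same pattern: when the number of queries (NQ, i.e. `nx` for float vectors or `ha->nh` for binary codes) is smaller than the number of OpenMP threads, parallelizing over the queries leaves most cores idle. The new code therefore parallelizes over the database vectors, gives each thread its own set of per-query top-k heaps, and merges the per-thread heaps after the parallel loop. A minimal, self-contained sketch of that pattern, using hypothetical names and `std::partial_sort` in place of the faiss heap helpers (illustration only, not the Milvus code):

// knn_small_nq_sketch.cpp -- illustrative only
// build: g++ -O2 -fopenmp knn_small_nq_sketch.cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <limits>
#include <omp.h>
#include <vector>

// Brute-force L2 top-k for a small number of queries: parallelize over the
// database, keep one candidate buffer per (thread, query), merge at the end.
static void knn_small_nq(const float* x, const float* y,
                         size_t nx, size_t ny, size_t d, size_t k,
                         float* out_dist, int64_t* out_ids) {
    int nthread = omp_get_max_threads();
    struct Cand { float dist; int64_t id; };
    // cand[t][i] holds the candidates of query i found by thread t
    std::vector<std::vector<std::vector<Cand>>> cand(
        nthread, std::vector<std::vector<Cand>>(nx));

#pragma omp parallel for
    for (int64_t j = 0; j < (int64_t)ny; j++) {        // over the database
        int t = omp_get_thread_num();
        for (size_t i = 0; i < nx; i++) {              // over the few queries
            float dis = 0;
            for (size_t c = 0; c < d; c++) {
                float diff = x[i * d + c] - y[j * d + c];
                dis += diff * diff;
            }
            auto& v = cand[t][i];
            v.push_back({dis, j});
            if (v.size() > 2 * k) {                    // keep the buffer small
                std::partial_sort(v.begin(), v.begin() + k, v.end(),
                                  [](const Cand& a, const Cand& b) { return a.dist < b.dist; });
                v.resize(k);
            }
        }
    }

    // merge: per query, gather all per-thread candidates and keep the best k
    for (size_t i = 0; i < nx; i++) {
        std::vector<Cand> all;
        for (int t = 0; t < nthread; t++)
            all.insert(all.end(), cand[t][i].begin(), cand[t][i].end());
        size_t kk = std::min(k, all.size());
        std::partial_sort(all.begin(), all.begin() + kk, all.end(),
                          [](const Cand& a, const Cand& b) { return a.dist < b.dist; });
        for (size_t r = 0; r < k; r++) {
            out_dist[i * k + r] = r < kk ? all[r].dist
                                         : std::numeric_limits<float>::infinity();
            out_ids[i * k + r]  = r < kk ? all[r].id : int64_t(-1);
        }
    }
}

int main() {
    size_t nx = 2, ny = 1000, d = 8, k = 3;
    std::vector<float> x(nx * d, 0.5f), y(ny * d);
    for (size_t j = 0; j < ny * d; j++) y[j] = float(j % 97) / 97.0f;
    std::vector<float> dist(nx * k);
    std::vector<int64_t> ids(nx * k);
    knn_small_nq(x.data(), y.data(), nx, ny, d, k, dist.data(), ids.data());
    for (size_t r = 0; r < k; r++)
        printf("query 0, rank %zu: id=%lld dist=%f\n", r, (long long)ids[r], dist[r]);
    return 0;
}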


@@ -22,7 +22,7 @@ Please mark all change in change log and use the issue from GitHub
- \#1548 Move store/Directory to storage/Operation and add FSHandler
- \#1619 Improve compact performance
- \#1649 Fix Milvus crash on old CPU
- \#1653 IndexFlat (SSE) and IndexBinaryFlat performance improvement for small NQ

## Task


@@ -15,175 +15,196 @@
namespace faiss {

static const size_t size_1M = 1 * 1024 * 1024;
static const size_t batch_size = 65536;

template <class T>
static
void binary_distence_knn_hc(
        int bytes_per_code,
        float_maxheap_array_t * ha,
        const uint8_t * bs1,
        const uint8_t * bs2,
        size_t n2,
        bool order = true,
        bool init_heap = true,
        ConcurrentBitsetPtr bitset = nullptr)
{
    size_t k = ha->k;

    if ((bytes_per_code + k * (sizeof(float) + sizeof(int64_t))) * ha->nh < size_1M) {
        int thread_max_num = omp_get_max_threads();
        // init hash
        size_t thread_hash_size = ha->nh * k;
        size_t all_hash_size = thread_hash_size * thread_max_num;
        float *value = new float[all_hash_size];
        int64_t *labels = new int64_t[all_hash_size];
        for (int i = 0; i < all_hash_size; i++) {
            value[i] = 1.0 / 0.0;
            labels[i] = -1;
        }
        T *hc = new T[ha->nh];
        for (size_t i = 0; i < ha->nh; i++) {
            hc[i].set(bs1 + i * bytes_per_code, bytes_per_code);
        }

#pragma omp parallel for
        for (size_t j = 0; j < n2; j++) {
            if (!bitset || !bitset->test(j)) {
                int thread_no = omp_get_thread_num();
                const uint8_t * bs2_ = bs2 + j * bytes_per_code;
                for (size_t i = 0; i < ha->nh; i++) {
                    tadis_t dis = hc[i].compute (bs2_);
                    float * val_ = value + thread_no * thread_hash_size + i * k;
                    int64_t * ids_ = labels + thread_no * thread_hash_size + i * k;
                    if (dis < val_[0]) {
                        faiss::maxheap_pop<tadis_t> (k, val_, ids_);
                        faiss::maxheap_push<tadis_t> (k, val_, ids_, dis, j);
                    }
                }
            }
        }

        for (size_t t = 1; t < thread_max_num; t++) {
            // merge hash
            for (size_t i = 0; i < ha->nh; i++) {
                float * __restrict value_x = value + i * k;
                int64_t * __restrict labels_x = labels + i * k;
                float *value_x_t = value_x + t * thread_hash_size;
                int64_t *labels_x_t = labels_x + t * thread_hash_size;
                for (size_t j = 0; j < k; j++) {
                    if (value_x_t[j] < value_x[0]) {
                        faiss::maxheap_pop<tadis_t> (k, value_x, labels_x);
                        faiss::maxheap_push<tadis_t> (k, value_x, labels_x, value_x_t[j], labels_x_t[j]);
                    }
                }
            }
        }

        // copy result
        memcpy(ha->val, value, thread_hash_size * sizeof(float));
        memcpy(ha->ids, labels, thread_hash_size * sizeof(int64_t));

        delete[] hc;
        delete[] value;
        delete[] labels;
    } else {
        if (init_heap) ha->heapify ();

        const size_t block_size = batch_size;
        for (size_t j0 = 0; j0 < n2; j0 += block_size) {
            const size_t j1 = std::min(j0 + block_size, n2);
#pragma omp parallel for
            for (size_t i = 0; i < ha->nh; i++) {
                T hc (bs1 + i * bytes_per_code, bytes_per_code);

                const uint8_t * bs2_ = bs2 + j0 * bytes_per_code;
                tadis_t dis;
                tadis_t * __restrict bh_val_ = ha->val + i * k;
                int64_t * __restrict bh_ids_ = ha->ids + i * k;
                size_t j;
                for (j = j0; j < j1; j++, bs2_ += bytes_per_code) {
                    if (!bitset || !bitset->test(j)) {
                        dis = hc.compute (bs2_);
                        if (dis < bh_val_[0]) {
                            faiss::maxheap_pop<tadis_t> (k, bh_val_, bh_ids_);
                            faiss::maxheap_push<tadis_t> (k, bh_val_, bh_ids_, dis, j);
                        }
                    }
                }
            }
        }
    }
    if (order) ha->reorder ();
}

void binary_distence_knn_hc (
        MetricType metric_type,
        float_maxheap_array_t * ha,
        const uint8_t * a,
        const uint8_t * b,
        size_t nb,
        size_t ncodes,
        int order,
        ConcurrentBitsetPtr bitset)
{
    switch (metric_type) {
    case METRIC_Jaccard:
    case METRIC_Tanimoto:
        switch (ncodes) {
#define binary_distence_knn_hc_jaccard(ncodes) \
        case ncodes: \
            binary_distence_knn_hc<faiss::JaccardComputer ## ncodes> \
                (ncodes, ha, a, b, nb, order, true, bitset); \
            break;
        binary_distence_knn_hc_jaccard(8);
        binary_distence_knn_hc_jaccard(16);
        binary_distence_knn_hc_jaccard(32);
        binary_distence_knn_hc_jaccard(64);
        binary_distence_knn_hc_jaccard(128);
        binary_distence_knn_hc_jaccard(256);
        binary_distence_knn_hc_jaccard(512);
#undef binary_distence_knn_hc_jaccard
        default:
            binary_distence_knn_hc<faiss::JaccardComputerDefault>
                (ncodes, ha, a, b, nb, order, true, bitset);
            break;
        }
        break;

    case METRIC_Substructure:
        switch (ncodes) {
#define binary_distence_knn_hc_Substructure(ncodes) \
        case ncodes: \
            binary_distence_knn_hc<faiss::SubstructureComputer ## ncodes> \
                (ncodes, ha, a, b, nb, order, true, bitset); \
            break;
        binary_distence_knn_hc_Substructure(8);
        binary_distence_knn_hc_Substructure(16);
        binary_distence_knn_hc_Substructure(32);
        binary_distence_knn_hc_Substructure(64);
        binary_distence_knn_hc_Substructure(128);
        binary_distence_knn_hc_Substructure(256);
        binary_distence_knn_hc_Substructure(512);
#undef binary_distence_knn_hc_Substructure
        default:
            binary_distence_knn_hc<faiss::SubstructureComputerDefault>
                (ncodes, ha, a, b, nb, order, true, bitset);
            break;
        }
        break;

    case METRIC_Superstructure:
        switch (ncodes) {
#define binary_distence_knn_hc_Superstructure(ncodes) \
        case ncodes: \
            binary_distence_knn_hc<faiss::SuperstructureComputer ## ncodes> \
                (ncodes, ha, a, b, nb, order, true, bitset); \
            break;
        binary_distence_knn_hc_Superstructure(8);
        binary_distence_knn_hc_Superstructure(16);
        binary_distence_knn_hc_Superstructure(32);
        binary_distence_knn_hc_Superstructure(64);
        binary_distence_knn_hc_Superstructure(128);
        binary_distence_knn_hc_Superstructure(256);
        binary_distence_knn_hc_Superstructure(512);
#undef binary_distence_knn_hc_Superstructure
        default:
            binary_distence_knn_hc<faiss::SuperstructureComputerDefault>
                (ncodes, ha, a, b, nb, order, true, bitset);
            break;
        }
        break;

    default:
        break;
    }
}
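
The `size_1M` check at the top of the templated function above is what selects between the two strategies: the database-parallel per-thread-heap path is only taken when one thread's private state (roughly `(bytes_per_code + k * (sizeof(float) + sizeof(int64_t))) * nh` bytes of precomputed computers plus heaps) stays under 1 MiB, otherwise the batched query-parallel path is kept. A quick back-of-the-envelope check with hypothetical parameter values (not values from the commit):

#include <cstdint>
#include <cstdio>

// Rough footprint of one thread's private state in binary_distence_knn_hc:
// nh precomputed distance computers plus nh * k (distance, id) pairs.
static size_t footprint(size_t bytes_per_code, size_t k, size_t nh) {
    return (bytes_per_code + k * (sizeof(float) + sizeof(int64_t))) * nh;
}

int main() {
    const size_t size_1M = 1 * 1024 * 1024;
    // Hypothetical examples only.
    struct { size_t code, k, nh; } cases[] = {{64, 64, 4}, {64, 1024, 512}};
    for (auto c : cases) {
        size_t f = footprint(c.code, c.k, c.nh);
        printf("bytes_per_code=%zu k=%zu nh=%zu -> %zu bytes -> %s\n",
               c.code, c.k, c.nh, f,
               f < size_1M ? "database-parallel path" : "batched query-parallel path");
    }
    return 0;
}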


@@ -154,50 +154,68 @@ static void knn_inner_product_sse (const float * x,
    size_t k = res->k;
    size_t thread_max_num = omp_get_max_threads();

    size_t thread_hash_size = nx * k;
    size_t all_hash_size = thread_hash_size * thread_max_num;
    float *value = new float[all_hash_size];
    int64_t *labels = new int64_t[all_hash_size];

    // init hash
    for (size_t i = 0; i < all_hash_size; i++) {
        value[i] = -1.0 / 0.0;
        labels[i] = -1;
    }

#pragma omp parallel for
    for (size_t j = 0; j < ny; j++) {
        if (!bitset || !bitset->test(j)) {
            size_t thread_no = omp_get_thread_num();
            const float *y_j = y + j * d;
            for (size_t i = 0; i < nx; i++) {
                const float *x_i = x + i * d;
                float ip = fvec_inner_product (x_i, y_j, d);
                float * val_ = value + thread_no * thread_hash_size + i * k;
                int64_t * ids_ = labels + thread_no * thread_hash_size + i * k;
                if (ip > val_[0]) {
                    minheap_pop (k, val_, ids_);
                    minheap_push (k, val_, ids_, ip, j);
                }
            }
        }
    }

    for (size_t t = 1; t < thread_max_num; t++) {
        // merge hash
        for (size_t i = 0; i < nx; i++) {
            float * __restrict value_x = value + i * k;
            int64_t * __restrict labels_x = labels + i * k;
            float *value_x_t = value_x + t * thread_hash_size;
            int64_t *labels_x_t = labels_x + t * thread_hash_size;
            for (size_t j = 0; j < k; j++) {
                if (value_x_t[j] > value_x[0]) {
                    minheap_pop (k, value_x, labels_x);
                    minheap_push (k, value_x, labels_x, value_x_t[j], labels_x_t[j]);
                }
            }
        }
    }

    for (size_t i = 0; i < nx; i++) {
        float * value_x = value + i * k;
        int64_t * labels_x = labels + i * k;
        minheap_reorder (k, value_x, labels_x);
    }

    // copy result
    memcpy(res->val, value, thread_hash_size * sizeof(float));
    memcpy(res->ids, labels, thread_hash_size * sizeof(int64_t));

    delete[] value;
    delete[] labels;

    /*
    else {
        size_t check_period = InterruptCallback::get_period_hint (ny * d);
        check_period *= thread_max_num;

@@ -230,6 +248,7 @@ static void knn_inner_product_sse (const float * x,
            InterruptCallback::check ();
        }
    }
    */
}

static void knn_L2sqr_sse (

@@ -242,55 +261,68 @@ static void knn_L2sqr_sse (
    size_t k = res->k;
    size_t thread_max_num = omp_get_max_threads();

    size_t thread_hash_size = nx * k;
    size_t all_hash_size = thread_hash_size * thread_max_num;
    float *value = new float[all_hash_size];
    int64_t *labels = new int64_t[all_hash_size];

    // init hash
    for (size_t i = 0; i < all_hash_size; i++) {
        value[i] = 1.0 / 0.0;
        labels[i] = -1;
    }

#pragma omp parallel for
    for (size_t j = 0; j < ny; j++) {
        if (!bitset || !bitset->test(j)) {
            size_t thread_no = omp_get_thread_num();
            const float *y_j = y + j * d;
            for (size_t i = 0; i < nx; i++) {
                const float *x_i = x + i * d;
                float disij = fvec_L2sqr (x_i, y_j, d);
                float * val_ = value + thread_no * thread_hash_size + i * k;
                int64_t * ids_ = labels + thread_no * thread_hash_size + i * k;
                if (disij < val_[0]) {
                    maxheap_pop (k, val_, ids_);
                    maxheap_push (k, val_, ids_, disij, j);
                }
            }
        }
    }

    for (size_t t = 1; t < thread_max_num; t++) {
        // merge hash
        for (size_t i = 0; i < nx; i++) {
            float * __restrict value_x = value + i * k;
            int64_t * __restrict labels_x = labels + i * k;
            float *value_x_t = value_x + t * thread_hash_size;
            int64_t *labels_x_t = labels_x + t * thread_hash_size;
            for (size_t j = 0; j < k; j++) {
                if (value_x_t[j] < value_x[0]) {
                    maxheap_pop (k, value_x, labels_x);
                    maxheap_push (k, value_x, labels_x, value_x_t[j], labels_x_t[j]);
                }
            }
        }
    }

    for (size_t i = 0; i < nx; i++) {
        float * value_x = value + i * k;
        int64_t * labels_x = labels + i * k;
        maxheap_reorder (k, value_x, labels_x);
    }

    // copy result
    memcpy(res->val, value, thread_hash_size * sizeof(float));
    memcpy(res->ids, labels, thread_hash_size * sizeof(int64_t));

    delete[] value;
    delete[] labels;

    /*
    else {
        size_t check_period = InterruptCallback::get_period_hint (ny * d);
        check_period *= thread_max_num;

@@ -322,6 +354,7 @@ static void knn_L2sqr_sse (
            InterruptCallback::check ();
        }
    }
    */
}
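
Two details worth noting in both SSE functions above: the heaps are initialized with `-inf` for inner product (a min-heap whose root is the worst similarity kept so far) and `+inf` for L2 (a max-heap whose root is the worst distance kept), and all per-thread heaps live in one flat allocation where thread `t`'s heap for query `i` starts at offset `t * (nx * k) + i * k`. A tiny standalone sketch of that indexing, with hypothetical sizes:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // Assumed sizes for illustration only.
    size_t nx = 3, k = 4, thread_max_num = 2;
    size_t thread_hash_size = nx * k;                        // one nx-by-k table
    size_t all_hash_size = thread_hash_size * thread_max_num;

    std::vector<float> value(all_hash_size);
    std::vector<int64_t> labels(all_hash_size);

    // Heap of query i as seen by thread t:
    for (size_t t = 0; t < thread_max_num; t++) {
        for (size_t i = 0; i < nx; i++) {
            float* val_ = value.data() + t * thread_hash_size + i * k;
            int64_t* ids_ = labels.data() + t * thread_hash_size + i * k;
            printf("thread %zu, query %zu -> offset %td\n", t, i, val_ - value.data());
            (void)ids_;
        }
    }
    // After merging everything into the thread-0 slice, the first nx*k entries
    // are the final result, which is why a single memcpy into res->val works.
    return 0;
}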


@@ -40,7 +40,7 @@
#include <faiss/utils/utils.h>

static const size_t BLOCKSIZE_QUERY = 8192;
static const size_t size_1M = 1 * 1024 * 1024;

namespace faiss {

@@ -278,50 +278,69 @@ void hammings_knn_hc (
        ConcurrentBitsetPtr bitset = nullptr)
{
    size_t k = ha->k;

    if ((bytes_per_code + k * (sizeof(hamdis_t) + sizeof(int64_t))) * ha->nh < size_1M) {
        int thread_max_num = omp_get_max_threads();
        // init hash
        size_t thread_hash_size = ha->nh * k;
        size_t all_hash_size = thread_hash_size * thread_max_num;
        hamdis_t *value = new hamdis_t[all_hash_size];
        int64_t *labels = new int64_t[all_hash_size];
        for (int i = 0; i < all_hash_size; i++) {
            value[i] = 0x7fffffff;
            labels[i] = -1;
        }
        HammingComputer *hc = new HammingComputer[ha->nh];
        for (size_t i = 0; i < ha->nh; i++) {
            hc[i].set(bs1 + i * bytes_per_code, bytes_per_code);
        }

#pragma omp parallel for
        for (size_t j = 0; j < n2; j++) {
            if (!bitset || !bitset->test(j)) {
                int thread_no = omp_get_thread_num();
                const uint8_t * bs2_ = bs2 + j * bytes_per_code;
                for (size_t i = 0; i < ha->nh; i++) {
                    hamdis_t dis = hc[i].hamming (bs2_);
                    hamdis_t * val_ = value + thread_no * thread_hash_size + i * k;
                    int64_t * ids_ = labels + thread_no * thread_hash_size + i * k;
                    if (dis < val_[0]) {
                        faiss::maxheap_pop<hamdis_t> (k, val_, ids_);
                        faiss::maxheap_push<hamdis_t> (k, val_, ids_, dis, j);
                    }
                }
            }
        }

        for (size_t t = 1; t < thread_max_num; t++) {
            // merge hash
            for (size_t i = 0; i < ha->nh; i++) {
                hamdis_t * __restrict value_x = value + i * k;
                int64_t * __restrict labels_x = labels + i * k;
                hamdis_t *value_x_t = value_x + t * thread_hash_size;
                int64_t *labels_x_t = labels_x + t * thread_hash_size;
                for (size_t j = 0; j < k; j++) {
                    if (value_x_t[j] < value_x[0]) {
                        faiss::maxheap_pop<hamdis_t> (k, value_x, labels_x);
                        faiss::maxheap_push<hamdis_t> (k, value_x, labels_x, value_x_t[j], labels_x_t[j]);
                    }
                }
            }
        }

        // copy result
        memcpy(ha->val, value, thread_hash_size * sizeof(hamdis_t));
        memcpy(ha->ids, labels, thread_hash_size * sizeof(int64_t));

        delete[] hc;
        delete[] value;
        delete[] labels;
    } else {
        if (init_heap) ha->heapify ();

        const size_t block_size = hamming_batch_size;
        for (size_t j0 = 0; j0 < n2; j0 += block_size) {
            const size_t j1 = std::min(j0 + block_size, n2);

@@ -426,48 +445,46 @@ void hammings_knn_hc_1 (
    const size_t nwords = 1;
    size_t k = ha->k;

    if (init_heap) {
        ha->heapify ();
    }

    int thread_max_num = omp_get_max_threads();
    if (ha->nh == 1) {
        // omp for n2
        int all_hash_size = thread_max_num * k;
        hamdis_t *value = new hamdis_t[all_hash_size];
        int64_t *labels = new int64_t[all_hash_size];
        // init hash
        for (int i = 0; i < all_hash_size; i++) {
            value[i] = 0x7fffffff;
        }
        const uint64_t bs1_ = bs1[0];

#pragma omp parallel for
        for (size_t j = 0; j < n2; j++) {
            if (!bitset || !bitset->test(j)) {
                hamdis_t dis = popcount64 (bs1_ ^ bs2[j]);
                int thread_no = omp_get_thread_num();
                hamdis_t * __restrict val_ = value + thread_no * k;
                int64_t * __restrict ids_ = labels + thread_no * k;
                if (dis < val_[0]) {
                    faiss::maxheap_pop<hamdis_t> (k, val_, ids_);
                    faiss::maxheap_push<hamdis_t> (k, val_, ids_, dis, j);
                }
            }
        }

        // merge hash
        hamdis_t * __restrict bh_val_ = ha->val;
        int64_t * __restrict bh_ids_ = ha->ids;
        for (int i = 0; i < all_hash_size; i++) {
            if (value[i] < bh_val_[0]) {
                faiss::maxheap_pop<hamdis_t> (k, bh_val_, bh_ids_);
                faiss::maxheap_push<hamdis_t> (k, bh_val_, bh_ids_, value[i], labels[i]);
            }
        }

        delete[] value;
        delete[] labels;
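
`hammings_knn_hc_1` handles one-word (64-bit) codes, where the distance is simply the popcount of the XOR of two 64-bit integers; with a single query (`ha->nh == 1`) the new branch parallelizes straight over the database and merges the per-thread heaps into `ha->val`/`ha->ids`. A minimal sketch of the distance itself, standalone and using the GCC/Clang builtin in place of faiss's `popcount64`:

#include <cstdint>
#include <cstdio>

// Hamming distance between two 64-bit binary codes.
static inline int hamming64(uint64_t a, uint64_t b) {
    return __builtin_popcountll(a ^ b);   // popcount of the differing bits
}

int main() {
    uint64_t query = 0xF0F0F0F0F0F0F0F0ULL;
    uint64_t codes[3] = {0xF0F0F0F0F0F0F0F0ULL,   // identical        -> 0
                         0xF0F0F0F0F0F0F0F1ULL,   // one bit differs  -> 1
                         0x0F0F0F0F0F0F0F0FULL};  // all bits differ  -> 64
    for (int j = 0; j < 3; j++)
        printf("dist(query, codes[%d]) = %d\n", j, hamming64(query, codes[j]));
    return 0;
}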


@@ -8,13 +8,13 @@ namespace faiss {
    SuperstructureComputer8 (const uint8_t *a8, int code_size) {
        set (a8, code_size);
    }

    void set (const uint8_t *a8, int code_size) {
        assert (code_size == 8);
        const uint64_t *a = (uint64_t *)a8;
        a0 = a[0];
        accu_den = (float)(popcount64 (a0));
    }

    inline float compute (const uint8_t *b8) const {

@@ -35,13 +35,13 @@ namespace faiss {
    SuperstructureComputer16 (const uint8_t *a8, int code_size) {
        set (a8, code_size);
    }

    void set (const uint8_t *a8, int code_size) {
        assert (code_size == 16);
        const uint64_t *a = (uint64_t *)a8;
        a0 = a[0]; a1 = a[1];
        accu_den = (float)(popcount64 (a0) + popcount64 (a1));
    }

    inline float compute (const uint8_t *b8) const {
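
The "fix superstructure" part of the commit matters because the new batched code paths default-construct the distance computers and then call `set()` on them (e.g. `hc[i].set(...)` above); any derived state computed only in the constructor, such as `accu_den`, would be left stale after a later `set()`. Moving the computation into `set()` and having the constructor delegate to it avoids that. A small sketch of the pattern with a hypothetical computer class (not the faiss code):

#include <cassert>
#include <cstdint>
#include <cstdio>

// Toy "computer" that caches a value derived from the code it was set to.
struct PopcountComputer {
    uint64_t a0 = 0;
    float accu_den = 0.0f;

    PopcountComputer() = default;                        // needed for new T[n]
    PopcountComputer(const uint8_t* a8, int code_size) { set(a8, code_size); }

    void set(const uint8_t* a8, int code_size) {
        assert(code_size == 8);
        a0 = *reinterpret_cast<const uint64_t*>(a8);
        // Derived state lives here, so both the constructor and a later
        // set() call leave the object fully initialized.
        accu_den = (float)__builtin_popcountll(a0);
    }
};

int main() {
    uint64_t code = 0xFFULL;                             // 8 bits set
    PopcountComputer c;                                  // default-constructed
    c.set(reinterpret_cast<const uint8_t*>(&code), 8);   // then configured
    printf("accu_den = %.0f\n", c.accu_den);             // prints 8
    return 0;
}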