mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-02 20:09:57 +08:00
* optimize sse Signed-off-by: sahuang <xiaohai.xu@zilliz.com> * optimizer BinaryDistance Signed-off-by: shengjun.li <shengjun.li@zilliz.com> * fix superstructure Signed-off-by: sahuang <xiaohai.xu@zilliz.com> Co-authored-by: shengjun.li <shengjun.li@zilliz.com> Co-authored-by: Jin Hai <hai.jin@zilliz.com>
This commit is contained in:
parent
55ecfd5930
commit
1543f8dc37
@ -22,7 +22,7 @@ Please mark all change in change log and use the issue from GitHub
|
||||
- \#1548 Move store/Directory to storage/Operation and add FSHandler
|
||||
- \#1619 Improve compact performance
|
||||
- \#1649 Fix Milvus crash on old CPU
|
||||
- \#1653 IndexFlat performance improvement for NQ less than thread_number
|
||||
- \#1653 IndexFlat (SSE) and IndexBinaryFlat performance improvement for small NQ
|
||||
|
||||
## Task
|
||||
|
||||
|
@ -15,175 +15,196 @@
|
||||
|
||||
namespace faiss {
|
||||
|
||||
size_t batch_size = 65536;
|
||||
static const size_t size_1M = 1 * 1024 * 1024;
|
||||
static const size_t batch_size = 65536;
|
||||
|
||||
template <class T>
|
||||
static
|
||||
void binary_distence_knn_hc(
|
||||
int bytes_per_code,
|
||||
float_maxheap_array_t * ha,
|
||||
const uint8_t * bs1,
|
||||
const uint8_t * bs2,
|
||||
size_t n2,
|
||||
bool order = true,
|
||||
bool init_heap = true,
|
||||
ConcurrentBitsetPtr bitset = nullptr)
|
||||
{
|
||||
size_t k = ha->k;
|
||||
template <class T>
|
||||
static
|
||||
void binary_distence_knn_hc(
|
||||
int bytes_per_code,
|
||||
float_maxheap_array_t * ha,
|
||||
const uint8_t * bs1,
|
||||
const uint8_t * bs2,
|
||||
size_t n2,
|
||||
bool order = true,
|
||||
bool init_heap = true,
|
||||
ConcurrentBitsetPtr bitset = nullptr)
|
||||
{
|
||||
size_t k = ha->k;
|
||||
|
||||
if ((bytes_per_code + k * (sizeof(float) + sizeof(int64_t))) * ha->nh < size_1M) {
|
||||
int thread_max_num = omp_get_max_threads();
|
||||
// init hash
|
||||
size_t thread_hash_size = ha->nh * k;
|
||||
size_t all_hash_size = thread_hash_size * thread_max_num;
|
||||
float *value = new float[all_hash_size];
|
||||
int64_t *labels = new int64_t[all_hash_size];
|
||||
for (int i = 0; i < all_hash_size; i++) {
|
||||
value[i] = 1.0 / 0.0;
|
||||
labels[i] = -1;
|
||||
}
|
||||
|
||||
T *hc = new T[ha->nh];
|
||||
for (size_t i = 0; i < ha->nh; i++) {
|
||||
hc[i].set(bs1 + i * bytes_per_code, bytes_per_code);
|
||||
}
|
||||
|
||||
#pragma omp parallel for
|
||||
for (size_t j = 0; j < n2; j++) {
|
||||
if(!bitset || !bitset->test(j)) {
|
||||
int thread_no = omp_get_thread_num();
|
||||
|
||||
const uint8_t * bs2_ = bs2 + j * bytes_per_code;
|
||||
for (size_t i = 0; i < ha->nh; i++) {
|
||||
tadis_t dis = hc[i].compute (bs2_);
|
||||
|
||||
float * val_ = value + thread_no * thread_hash_size + i * k;
|
||||
int64_t * ids_ = labels + thread_no * thread_hash_size + i * k;
|
||||
if (dis < val_[0]) {
|
||||
faiss::maxheap_pop<tadis_t> (k, val_, ids_);
|
||||
faiss::maxheap_push<tadis_t> (k, val_, ids_, dis, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t t = 1; t < thread_max_num; t++) {
|
||||
// merge hash
|
||||
for (size_t i = 0; i < ha->nh; i++) {
|
||||
float * __restrict value_x = value + i * k;
|
||||
int64_t * __restrict labels_x = labels + i * k;
|
||||
float *value_x_t = value_x + t * thread_hash_size;
|
||||
int64_t *labels_x_t = labels_x + t * thread_hash_size;
|
||||
for (size_t j = 0; j < k; j++) {
|
||||
if (value_x_t[j] < value_x[0]) {
|
||||
faiss::maxheap_pop<tadis_t> (k, value_x, labels_x);
|
||||
faiss::maxheap_push<tadis_t> (k, value_x, labels_x, value_x_t[j], labels_x_t[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// copy result
|
||||
memcpy(ha->val, value, thread_hash_size * sizeof(float));
|
||||
memcpy(ha->ids, labels, thread_hash_size * sizeof(int64_t));
|
||||
|
||||
delete[] hc;
|
||||
delete[] value;
|
||||
delete[] labels;
|
||||
|
||||
} else {
|
||||
if (init_heap) ha->heapify ();
|
||||
|
||||
int thread_max_num = omp_get_max_threads();
|
||||
if (ha->nh < 4) {
|
||||
// omp for n2
|
||||
int all_hash_size = thread_max_num * k;
|
||||
float *value = new float[all_hash_size];
|
||||
int64_t *labels = new int64_t[all_hash_size];
|
||||
|
||||
for (int i = 0; i < ha->nh; i++) {
|
||||
T hc (bs1 + i * bytes_per_code, bytes_per_code);
|
||||
// init hash
|
||||
for (int i = 0; i < all_hash_size; i++) {
|
||||
value[i] = 1.0 / 0.0;
|
||||
}
|
||||
const size_t block_size = batch_size;
|
||||
for (size_t j0 = 0; j0 < n2; j0 += block_size) {
|
||||
const size_t j1 = std::min(j0 + block_size, n2);
|
||||
#pragma omp parallel for
|
||||
for (size_t j = 0; j < n2; j++) {
|
||||
if(!bitset || !bitset->test(j)) {
|
||||
const uint8_t * bs2_ = bs2 + j * bytes_per_code;
|
||||
tadis_t dis = hc.compute (bs2_);
|
||||
for (size_t i = 0; i < ha->nh; i++) {
|
||||
T hc (bs1 + i * bytes_per_code, bytes_per_code);
|
||||
|
||||
int thread_no = omp_get_thread_num();
|
||||
float * __restrict val_ = value + thread_no * k;
|
||||
int64_t * __restrict ids_ = labels + thread_no * k;
|
||||
if (dis < val_[0]) {
|
||||
faiss::maxheap_pop<tadis_t> (k, val_, ids_);
|
||||
faiss::maxheap_push<tadis_t> (k, val_, ids_, dis, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
// merge hash
|
||||
const uint8_t * bs2_ = bs2 + j0 * bytes_per_code;
|
||||
tadis_t dis;
|
||||
tadis_t * __restrict bh_val_ = ha->val + i * k;
|
||||
int64_t * __restrict bh_ids_ = ha->ids + i * k;
|
||||
for (int i = 0; i < all_hash_size; i++) {
|
||||
if (value[i] < bh_val_[0]) {
|
||||
faiss::maxheap_pop<tadis_t> (k, bh_val_, bh_ids_);
|
||||
faiss::maxheap_push<tadis_t> (k, bh_val_, bh_ids_, value[i], labels[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
delete[] value;
|
||||
delete[] labels;
|
||||
|
||||
} else {
|
||||
const size_t block_size = batch_size;
|
||||
for (size_t j0 = 0; j0 < n2; j0 += block_size) {
|
||||
const size_t j1 = std::min(j0 + block_size, n2);
|
||||
#pragma omp parallel for
|
||||
for (size_t i = 0; i < ha->nh; i++) {
|
||||
T hc (bs1 + i * bytes_per_code, bytes_per_code);
|
||||
|
||||
const uint8_t * bs2_ = bs2 + j0 * bytes_per_code;
|
||||
tadis_t dis;
|
||||
tadis_t * __restrict bh_val_ = ha->val + i * k;
|
||||
int64_t * __restrict bh_ids_ = ha->ids + i * k;
|
||||
size_t j;
|
||||
for (j = j0; j < j1; j++, bs2_+= bytes_per_code) {
|
||||
if(!bitset || !bitset->test(j)){
|
||||
dis = hc.compute (bs2_);
|
||||
if (dis < bh_val_[0]) {
|
||||
faiss::maxheap_pop<tadis_t> (k, bh_val_, bh_ids_);
|
||||
faiss::maxheap_push<tadis_t> (k, bh_val_, bh_ids_, dis, j);
|
||||
}
|
||||
size_t j;
|
||||
for (j = j0; j < j1; j++, bs2_+= bytes_per_code) {
|
||||
if(!bitset || !bitset->test(j)){
|
||||
dis = hc.compute (bs2_);
|
||||
if (dis < bh_val_[0]) {
|
||||
faiss::maxheap_pop<tadis_t> (k, bh_val_, bh_ids_);
|
||||
faiss::maxheap_push<tadis_t> (k, bh_val_, bh_ids_, dis, j);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
if (order) ha->reorder ();
|
||||
}
|
||||
|
||||
void binary_distence_knn_hc (
|
||||
MetricType metric_type,
|
||||
float_maxheap_array_t * ha,
|
||||
const uint8_t * a,
|
||||
const uint8_t * b,
|
||||
size_t nb,
|
||||
size_t ncodes,
|
||||
int order,
|
||||
ConcurrentBitsetPtr bitset)
|
||||
{
|
||||
switch (metric_type) {
|
||||
case METRIC_Jaccard:
|
||||
case METRIC_Tanimoto:
|
||||
switch (ncodes) {
|
||||
if (order) ha->reorder ();
|
||||
}
|
||||
|
||||
void binary_distence_knn_hc (
|
||||
MetricType metric_type,
|
||||
float_maxheap_array_t * ha,
|
||||
const uint8_t * a,
|
||||
const uint8_t * b,
|
||||
size_t nb,
|
||||
size_t ncodes,
|
||||
int order,
|
||||
ConcurrentBitsetPtr bitset)
|
||||
{
|
||||
switch (metric_type) {
|
||||
case METRIC_Jaccard:
|
||||
case METRIC_Tanimoto:
|
||||
switch (ncodes) {
|
||||
#define binary_distence_knn_hc_jaccard(ncodes) \
|
||||
case ncodes: \
|
||||
binary_distence_knn_hc<faiss::JaccardComputer ## ncodes> \
|
||||
(ncodes, ha, a, b, nb, order, true, bitset); \
|
||||
break;
|
||||
binary_distence_knn_hc_jaccard(8);
|
||||
binary_distence_knn_hc_jaccard(16);
|
||||
binary_distence_knn_hc_jaccard(32);
|
||||
binary_distence_knn_hc_jaccard(64);
|
||||
binary_distence_knn_hc_jaccard(128);
|
||||
binary_distence_knn_hc_jaccard(256);
|
||||
binary_distence_knn_hc_jaccard(512);
|
||||
case ncodes: \
|
||||
binary_distence_knn_hc<faiss::JaccardComputer ## ncodes> \
|
||||
(ncodes, ha, a, b, nb, order, true, bitset); \
|
||||
break;
|
||||
binary_distence_knn_hc_jaccard(8);
|
||||
binary_distence_knn_hc_jaccard(16);
|
||||
binary_distence_knn_hc_jaccard(32);
|
||||
binary_distence_knn_hc_jaccard(64);
|
||||
binary_distence_knn_hc_jaccard(128);
|
||||
binary_distence_knn_hc_jaccard(256);
|
||||
binary_distence_knn_hc_jaccard(512);
|
||||
#undef binary_distence_knn_hc_jaccard
|
||||
default:
|
||||
binary_distence_knn_hc<faiss::JaccardComputerDefault>
|
||||
(ncodes, ha, a, b, nb, order, true, bitset);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
|
||||
case METRIC_Substructure:
|
||||
switch (ncodes) {
|
||||
#define binary_distence_knn_hc_Substructure(ncodes) \
|
||||
case ncodes: \
|
||||
binary_distence_knn_hc<faiss::SubstructureComputer ## ncodes> \
|
||||
(ncodes, ha, a, b, nb, order, true, bitset); \
|
||||
break;
|
||||
binary_distence_knn_hc_Substructure(8);
|
||||
binary_distence_knn_hc_Substructure(16);
|
||||
binary_distence_knn_hc_Substructure(32);
|
||||
binary_distence_knn_hc_Substructure(64);
|
||||
binary_distence_knn_hc_Substructure(128);
|
||||
binary_distence_knn_hc_Substructure(256);
|
||||
binary_distence_knn_hc_Substructure(512);
|
||||
#undef binary_distence_knn_hc_Substructure
|
||||
default:
|
||||
binary_distence_knn_hc<faiss::SubstructureComputerDefault>
|
||||
(ncodes, ha, a, b, nb, order, true, bitset);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
|
||||
case METRIC_Superstructure:
|
||||
switch (ncodes) {
|
||||
#define binary_distence_knn_hc_Superstructure(ncodes) \
|
||||
case ncodes: \
|
||||
binary_distence_knn_hc<faiss::SuperstructureComputer ## ncodes> \
|
||||
(ncodes, ha, a, b, nb, order, true, bitset); \
|
||||
break;
|
||||
binary_distence_knn_hc_Superstructure(8);
|
||||
binary_distence_knn_hc_Superstructure(16);
|
||||
binary_distence_knn_hc_Superstructure(32);
|
||||
binary_distence_knn_hc_Superstructure(64);
|
||||
binary_distence_knn_hc_Superstructure(128);
|
||||
binary_distence_knn_hc_Superstructure(256);
|
||||
binary_distence_knn_hc_Superstructure(512);
|
||||
#undef binary_distence_knn_hc_Superstructure
|
||||
default:
|
||||
binary_distence_knn_hc<faiss::SuperstructureComputerDefault>
|
||||
(ncodes, ha, a, b, nb, order, true, bitset);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
binary_distence_knn_hc<faiss::JaccardComputerDefault>
|
||||
(ncodes, ha, a, b, nb, order, true, bitset);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
|
||||
case METRIC_Substructure:
|
||||
switch (ncodes) {
|
||||
#define binary_distence_knn_hc_Substructure(ncodes) \
|
||||
case ncodes: \
|
||||
binary_distence_knn_hc<faiss::SubstructureComputer ## ncodes> \
|
||||
(ncodes, ha, a, b, nb, order, true, bitset); \
|
||||
break;
|
||||
binary_distence_knn_hc_Substructure(8);
|
||||
binary_distence_knn_hc_Substructure(16);
|
||||
binary_distence_knn_hc_Substructure(32);
|
||||
binary_distence_knn_hc_Substructure(64);
|
||||
binary_distence_knn_hc_Substructure(128);
|
||||
binary_distence_knn_hc_Substructure(256);
|
||||
binary_distence_knn_hc_Substructure(512);
|
||||
#undef binary_distence_knn_hc_Substructure
|
||||
default:
|
||||
binary_distence_knn_hc<faiss::SubstructureComputerDefault>
|
||||
(ncodes, ha, a, b, nb, order, true, bitset);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
|
||||
case METRIC_Superstructure:
|
||||
switch (ncodes) {
|
||||
#define binary_distence_knn_hc_Superstructure(ncodes) \
|
||||
case ncodes: \
|
||||
binary_distence_knn_hc<faiss::SuperstructureComputer ## ncodes> \
|
||||
(ncodes, ha, a, b, nb, order, true, bitset); \
|
||||
break;
|
||||
binary_distence_knn_hc_Superstructure(8);
|
||||
binary_distence_knn_hc_Superstructure(16);
|
||||
binary_distence_knn_hc_Superstructure(32);
|
||||
binary_distence_knn_hc_Superstructure(64);
|
||||
binary_distence_knn_hc_Superstructure(128);
|
||||
binary_distence_knn_hc_Superstructure(256);
|
||||
binary_distence_knn_hc_Superstructure(512);
|
||||
#undef binary_distence_knn_hc_Superstructure
|
||||
default:
|
||||
binary_distence_knn_hc<faiss::SuperstructureComputerDefault>
|
||||
(ncodes, ha, a, b, nb, order, true, bitset);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
191
core/src/index/thirdparty/faiss/utils/distances.cpp
vendored
191
core/src/index/thirdparty/faiss/utils/distances.cpp
vendored
@ -154,50 +154,68 @@ static void knn_inner_product_sse (const float * x,
|
||||
size_t k = res->k;
|
||||
|
||||
size_t thread_max_num = omp_get_max_threads();
|
||||
if (nx < 4) {
|
||||
// omp for ny
|
||||
size_t all_hash_size = thread_max_num * k;
|
||||
float *value = new float[all_hash_size];
|
||||
int64_t *labels = new int64_t[all_hash_size];
|
||||
|
||||
size_t thread_hash_size = nx * k;
|
||||
size_t all_hash_size = thread_hash_size * thread_max_num;
|
||||
float *value = new float[all_hash_size];
|
||||
int64_t *labels = new int64_t[all_hash_size];
|
||||
|
||||
// init hash
|
||||
for (size_t i = 0; i < all_hash_size; i++) {
|
||||
value[i] = -1.0 / 0.0;
|
||||
labels[i] = -1;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < nx; i++) {
|
||||
// init hash
|
||||
for (size_t i = 0; i < all_hash_size; i++) {
|
||||
value[i] = -1.0 / 0.0;
|
||||
}
|
||||
const float *x_i = x + i * d;
|
||||
#pragma omp parallel for
|
||||
for (size_t j = 0; j < ny; j++) {
|
||||
if(!bitset || !bitset->test(j)) {
|
||||
const float *y_j = y + j * d;
|
||||
float ip = fvec_inner_product (x_i, y_j, d);
|
||||
for (size_t j = 0; j < ny; j++) {
|
||||
if(!bitset || !bitset->test(j)) {
|
||||
size_t thread_no = omp_get_thread_num();
|
||||
const float *y_j = y + j * d;
|
||||
for (size_t i = 0; i < nx; i++) {
|
||||
const float *x_i = x + i * d;
|
||||
float ip = fvec_inner_product (x_i, y_j, d);
|
||||
|
||||
size_t thread_no = omp_get_thread_num();
|
||||
float * __restrict val_ = value + thread_no * k;
|
||||
int64_t * __restrict ids_ = labels + thread_no * k;
|
||||
if (ip > val_[0]) {
|
||||
minheap_pop (k, val_, ids_);
|
||||
minheap_push (k, val_, ids_, ip, j);
|
||||
}
|
||||
float * val_ = value + thread_no * thread_hash_size + i * k;
|
||||
int64_t * ids_ = labels + thread_no * thread_hash_size + i * k;
|
||||
if (ip > val_[0]) {
|
||||
minheap_pop (k, val_, ids_);
|
||||
minheap_push (k, val_, ids_, ip, j);
|
||||
}
|
||||
}
|
||||
|
||||
// merge hash
|
||||
float * __restrict simi = res->get_val(i);
|
||||
int64_t * __restrict idxi = res->get_ids (i);
|
||||
minheap_heapify (k, simi, idxi);
|
||||
for (size_t i = 0; i < all_hash_size; i++) {
|
||||
if (value[i] > simi[0]) {
|
||||
minheap_pop (k, simi, idxi);
|
||||
minheap_push (k, simi, idxi, value[i], labels[i]);
|
||||
}
|
||||
}
|
||||
minheap_reorder (k, simi, idxi);
|
||||
}
|
||||
delete[] value;
|
||||
delete[] labels;
|
||||
}
|
||||
|
||||
} else {
|
||||
for (size_t t = 1; t < thread_max_num; t++) {
|
||||
// merge hash
|
||||
for (size_t i = 0; i < nx; i++) {
|
||||
float * __restrict value_x = value + i * k;
|
||||
int64_t * __restrict labels_x = labels + i * k;
|
||||
float *value_x_t = value_x + t * thread_hash_size;
|
||||
int64_t *labels_x_t = labels_x + t * thread_hash_size;
|
||||
for (size_t j = 0; j < k; j++) {
|
||||
if (value_x_t[j] > value_x[0]) {
|
||||
minheap_pop (k, value_x, labels_x);
|
||||
minheap_push (k, value_x, labels_x, value_x_t[j], labels_x_t[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < nx; i++) {
|
||||
float * value_x = value + i * k;
|
||||
int64_t * labels_x = labels + i * k;
|
||||
minheap_reorder (k, value_x, labels_x);
|
||||
}
|
||||
|
||||
// copy result
|
||||
memcpy(res->val, value, thread_hash_size * sizeof(float));
|
||||
memcpy(res->ids, labels, thread_hash_size * sizeof(int64_t));
|
||||
|
||||
delete[] value;
|
||||
delete[] labels;
|
||||
|
||||
/*
|
||||
else {
|
||||
size_t check_period = InterruptCallback::get_period_hint (ny * d);
|
||||
check_period *= thread_max_num;
|
||||
|
||||
@ -230,6 +248,7 @@ static void knn_inner_product_sse (const float * x,
|
||||
InterruptCallback::check ();
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
static void knn_L2sqr_sse (
|
||||
@ -242,55 +261,68 @@ static void knn_L2sqr_sse (
|
||||
size_t k = res->k;
|
||||
|
||||
size_t thread_max_num = omp_get_max_threads();
|
||||
if (nx < 4) {
|
||||
// omp for ny
|
||||
size_t all_hash_size = thread_max_num * k;
|
||||
float *value = new float[all_hash_size];
|
||||
int64_t *labels = new int64_t[all_hash_size];
|
||||
|
||||
for (size_t i = 0; i < nx; i++) {
|
||||
// init hash
|
||||
for (size_t i = 0; i < all_hash_size; i++) {
|
||||
value[i] = 1.0 / 0.0;
|
||||
}
|
||||
for (size_t i = 0; i < k; i++) {
|
||||
labels[i] = -1;
|
||||
}
|
||||
const float *x_i = x + i * d;
|
||||
size_t thread_hash_size = nx * k;
|
||||
size_t all_hash_size = thread_hash_size * thread_max_num;
|
||||
float *value = new float[all_hash_size];
|
||||
int64_t *labels = new int64_t[all_hash_size];
|
||||
|
||||
// init hash
|
||||
for (size_t i = 0; i < all_hash_size; i++) {
|
||||
value[i] = 1.0 / 0.0;
|
||||
labels[i] = -1;
|
||||
}
|
||||
|
||||
#pragma omp parallel for
|
||||
for (size_t j = 0; j < ny; j++) {
|
||||
if(!bitset || !bitset->test(j)) {
|
||||
const float *y_j = y + j * d;
|
||||
float disij = fvec_L2sqr (x_i, y_j, d);
|
||||
for (size_t j = 0; j < ny; j++) {
|
||||
if(!bitset || !bitset->test(j)) {
|
||||
size_t thread_no = omp_get_thread_num();
|
||||
const float *y_j = y + j * d;
|
||||
for (size_t i = 0; i < nx; i++) {
|
||||
const float *x_i = x + i * d;
|
||||
float disij = fvec_L2sqr (x_i, y_j, d);
|
||||
|
||||
size_t thread_no = omp_get_thread_num();
|
||||
float * __restrict val_ = value + thread_no * k;
|
||||
int64_t * __restrict ids_ = labels + thread_no * k;
|
||||
if (disij < val_[0]) {
|
||||
maxheap_pop (k, val_, ids_);
|
||||
maxheap_push (k, val_, ids_, disij, j);
|
||||
}
|
||||
float * val_ = value + thread_no * thread_hash_size + i * k;
|
||||
int64_t * ids_ = labels + thread_no * thread_hash_size + i * k;
|
||||
if (disij < val_[0]) {
|
||||
maxheap_pop (k, val_, ids_);
|
||||
maxheap_push (k, val_, ids_, disij, j);
|
||||
}
|
||||
}
|
||||
|
||||
// merge hash
|
||||
float * __restrict simi = res->get_val(i);
|
||||
int64_t * __restrict idxi = res->get_ids (i);
|
||||
memcpy(simi, value, k * sizeof(float));
|
||||
memcpy(idxi, labels, k * sizeof(int64_t));
|
||||
maxheap_heapify (k, simi, idxi, value, labels, k);
|
||||
for (size_t i = k; i < all_hash_size; i++) {
|
||||
if (value[i] < simi[0]) {
|
||||
maxheap_pop (k, simi, idxi);
|
||||
maxheap_push (k, simi, idxi, value[i], labels[i]);
|
||||
}
|
||||
}
|
||||
maxheap_reorder (k, simi, idxi);
|
||||
}
|
||||
delete[] value;
|
||||
delete[] labels;
|
||||
}
|
||||
|
||||
} else {
|
||||
for (size_t t = 1; t < thread_max_num; t++) {
|
||||
// merge hash
|
||||
for (size_t i = 0; i < nx; i++) {
|
||||
float * __restrict value_x = value + i * k;
|
||||
int64_t * __restrict labels_x = labels + i * k;
|
||||
float *value_x_t = value_x + t * thread_hash_size;
|
||||
int64_t *labels_x_t = labels_x + t * thread_hash_size;
|
||||
for (size_t j = 0; j < k; j++) {
|
||||
if (value_x_t[j] < value_x[0]) {
|
||||
maxheap_pop (k, value_x, labels_x);
|
||||
maxheap_push (k, value_x, labels_x, value_x_t[j], labels_x_t[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < nx; i++) {
|
||||
float * value_x = value + i * k;
|
||||
int64_t * labels_x = labels + i * k;
|
||||
maxheap_reorder (k, value_x, labels_x);
|
||||
}
|
||||
|
||||
// copy result
|
||||
memcpy(res->val, value, thread_hash_size * sizeof(float));
|
||||
memcpy(res->ids, labels, thread_hash_size * sizeof(int64_t));
|
||||
|
||||
delete[] value;
|
||||
delete[] labels;
|
||||
|
||||
/*
|
||||
else {
|
||||
size_t check_period = InterruptCallback::get_period_hint (ny * d);
|
||||
check_period *= thread_max_num;
|
||||
|
||||
@ -322,6 +354,7 @@ static void knn_L2sqr_sse (
|
||||
InterruptCallback::check ();
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
|
||||
|
121
core/src/index/thirdparty/faiss/utils/hamming.cpp
vendored
121
core/src/index/thirdparty/faiss/utils/hamming.cpp
vendored
@ -40,7 +40,7 @@
|
||||
#include <faiss/utils/utils.h>
|
||||
|
||||
static const size_t BLOCKSIZE_QUERY = 8192;
|
||||
|
||||
static const size_t size_1M = 1 * 1024 * 1024;
|
||||
|
||||
namespace faiss {
|
||||
|
||||
@ -278,50 +278,69 @@ void hammings_knn_hc (
|
||||
ConcurrentBitsetPtr bitset = nullptr)
|
||||
{
|
||||
size_t k = ha->k;
|
||||
if (init_heap) ha->heapify ();
|
||||
|
||||
int thread_max_num = omp_get_max_threads();
|
||||
if (ha->nh < 4) {
|
||||
// omp for n2
|
||||
int all_hash_size = thread_max_num * k;
|
||||
if ((bytes_per_code + k * (sizeof(hamdis_t) + sizeof(int64_t))) * ha->nh < size_1M) {
|
||||
int thread_max_num = omp_get_max_threads();
|
||||
// init hash
|
||||
size_t thread_hash_size = ha->nh * k;
|
||||
size_t all_hash_size = thread_hash_size * thread_max_num;
|
||||
hamdis_t *value = new hamdis_t[all_hash_size];
|
||||
int64_t *labels = new int64_t[all_hash_size];
|
||||
for (int i = 0; i < all_hash_size; i++) {
|
||||
value[i] = 0x7fffffff;
|
||||
labels[i] = -1;
|
||||
}
|
||||
|
||||
HammingComputer *hc = new HammingComputer[ha->nh];
|
||||
for (size_t i = 0; i < ha->nh; i++) {
|
||||
hc[i].set(bs1 + i * bytes_per_code, bytes_per_code);
|
||||
}
|
||||
|
||||
for (int i = 0; i < ha->nh; i++) {
|
||||
HammingComputer hc (bs1 + i * bytes_per_code, bytes_per_code);
|
||||
// init hash
|
||||
for (int i = 0; i < all_hash_size; i++) {
|
||||
value[i] = 0x7fffffff;
|
||||
}
|
||||
#pragma omp parallel for
|
||||
for (size_t j = 0; j < n2; j++) {
|
||||
if(!bitset || !bitset->test(j)) {
|
||||
const uint8_t * bs2_ = bs2 + j * bytes_per_code;
|
||||
hamdis_t dis = hc.hamming (bs2_);
|
||||
for (size_t j = 0; j < n2; j++) {
|
||||
if(!bitset || !bitset->test(j)) {
|
||||
int thread_no = omp_get_thread_num();
|
||||
|
||||
int thread_no = omp_get_thread_num();
|
||||
hamdis_t * __restrict val_ = value + thread_no * k;
|
||||
int64_t * __restrict ids_ = labels + thread_no * k;
|
||||
const uint8_t * bs2_ = bs2 + j * bytes_per_code;
|
||||
for (size_t i = 0; i < ha->nh; i++) {
|
||||
hamdis_t dis = hc[i].hamming (bs2_);
|
||||
|
||||
hamdis_t * val_ = value + thread_no * thread_hash_size + i * k;
|
||||
int64_t * ids_ = labels + thread_no * thread_hash_size + i * k;
|
||||
if (dis < val_[0]) {
|
||||
faiss::maxheap_pop<hamdis_t> (k, val_, ids_);
|
||||
faiss::maxheap_push<hamdis_t> (k, val_, ids_, dis, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t t = 1; t < thread_max_num; t++) {
|
||||
// merge hash
|
||||
hamdis_t * __restrict bh_val_ = ha->val + i * k;
|
||||
int64_t * __restrict bh_ids_ = ha->ids + i * k;
|
||||
for (int i = 0; i < all_hash_size; i++) {
|
||||
if (value[i] < bh_val_[0]) {
|
||||
faiss::maxheap_pop<hamdis_t> (k, bh_val_, bh_ids_);
|
||||
faiss::maxheap_push<hamdis_t> (k, bh_val_, bh_ids_, value[i], labels[i]);
|
||||
for (size_t i = 0; i < ha->nh; i++) {
|
||||
hamdis_t * __restrict value_x = value + i * k;
|
||||
int64_t * __restrict labels_x = labels + i * k;
|
||||
hamdis_t *value_x_t = value_x + t * thread_hash_size;
|
||||
int64_t *labels_x_t = labels_x + t * thread_hash_size;
|
||||
for (size_t j = 0; j < k; j++) {
|
||||
if (value_x_t[j] < value_x[0]) {
|
||||
faiss::maxheap_pop<hamdis_t> (k, value_x, labels_x);
|
||||
faiss::maxheap_push<hamdis_t> (k, value_x, labels_x, value_x_t[j], labels_x_t[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// copy result
|
||||
memcpy(ha->val, value, thread_hash_size * sizeof(hamdis_t));
|
||||
memcpy(ha->ids, labels, thread_hash_size * sizeof(int64_t));
|
||||
|
||||
delete[] hc;
|
||||
delete[] value;
|
||||
delete[] labels;
|
||||
|
||||
} else {
|
||||
if (init_heap) ha->heapify ();
|
||||
const size_t block_size = hamming_batch_size;
|
||||
for (size_t j0 = 0; j0 < n2; j0 += block_size) {
|
||||
const size_t j1 = std::min(j0 + block_size, n2);
|
||||
@ -426,48 +445,46 @@ void hammings_knn_hc_1 (
|
||||
const size_t nwords = 1;
|
||||
size_t k = ha->k;
|
||||
|
||||
|
||||
if (init_heap) {
|
||||
ha->heapify ();
|
||||
}
|
||||
|
||||
int thread_max_num = omp_get_max_threads();
|
||||
if (ha->nh < 4) {
|
||||
if (ha->nh == 1) {
|
||||
// omp for n2
|
||||
int all_hash_size = thread_max_num * k;
|
||||
hamdis_t *value = new hamdis_t[all_hash_size];
|
||||
int64_t *labels = new int64_t[all_hash_size];
|
||||
|
||||
for (int i = 0; i < ha->nh; i++) {
|
||||
// init hash
|
||||
for (int i = 0; i < all_hash_size; i++) {
|
||||
value[i] = 0x7fffffff;
|
||||
}
|
||||
const uint64_t bs1_ = bs1 [i];
|
||||
// init hash
|
||||
for (int i = 0; i < all_hash_size; i++) {
|
||||
value[i] = 0x7fffffff;
|
||||
}
|
||||
const uint64_t bs1_ = bs1[0];
|
||||
#pragma omp parallel for
|
||||
for (size_t j = 0; j < n2; j++) {
|
||||
if(!bitset || !bitset->test(j)) {
|
||||
hamdis_t dis = popcount64 (bs1_ ^ bs2[j]);
|
||||
for (size_t j = 0; j < n2; j++) {
|
||||
if(!bitset || !bitset->test(j)) {
|
||||
hamdis_t dis = popcount64 (bs1_ ^ bs2[j]);
|
||||
|
||||
int thread_no = omp_get_thread_num();
|
||||
hamdis_t * __restrict val_ = value + thread_no * k;
|
||||
int64_t * __restrict ids_ = labels + thread_no * k;
|
||||
if (dis < val_[0]) {
|
||||
faiss::maxheap_pop<hamdis_t> (k, val_, ids_);
|
||||
faiss::maxheap_push<hamdis_t> (k, val_, ids_, dis, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
// merge hash
|
||||
hamdis_t * __restrict bh_val_ = ha->val + i * k;
|
||||
int64_t * __restrict bh_ids_ = ha->ids + i * k;
|
||||
for (int i = 0; i < all_hash_size; i++) {
|
||||
if (value[i] < bh_val_[0]) {
|
||||
faiss::maxheap_pop<hamdis_t> (k, bh_val_, bh_ids_);
|
||||
faiss::maxheap_push<hamdis_t> (k, bh_val_, bh_ids_, value[i], labels[i]);
|
||||
int thread_no = omp_get_thread_num();
|
||||
hamdis_t * __restrict val_ = value + thread_no * k;
|
||||
int64_t * __restrict ids_ = labels + thread_no * k;
|
||||
if (dis < val_[0]) {
|
||||
faiss::maxheap_pop<hamdis_t> (k, val_, ids_);
|
||||
faiss::maxheap_push<hamdis_t> (k, val_, ids_, dis, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
// merge hash
|
||||
hamdis_t * __restrict bh_val_ = ha->val;
|
||||
int64_t * __restrict bh_ids_ = ha->ids;
|
||||
for (int i = 0; i < all_hash_size; i++) {
|
||||
if (value[i] < bh_val_[0]) {
|
||||
faiss::maxheap_pop<hamdis_t> (k, bh_val_, bh_ids_);
|
||||
faiss::maxheap_push<hamdis_t> (k, bh_val_, bh_ids_, value[i], labels[i]);
|
||||
}
|
||||
}
|
||||
|
||||
delete[] value;
|
||||
delete[] labels;
|
||||
|
||||
|
@ -8,13 +8,13 @@ namespace faiss {
|
||||
|
||||
SuperstructureComputer8 (const uint8_t *a8, int code_size) {
|
||||
set (a8, code_size);
|
||||
accu_den = (float)(popcount64 (a0));
|
||||
}
|
||||
|
||||
void set (const uint8_t *a8, int code_size) {
|
||||
assert (code_size == 8);
|
||||
const uint64_t *a = (uint64_t *)a8;
|
||||
a0 = a[0];
|
||||
accu_den = (float)(popcount64 (a0));
|
||||
}
|
||||
|
||||
inline float compute (const uint8_t *b8) const {
|
||||
@ -35,13 +35,13 @@ namespace faiss {
|
||||
|
||||
SuperstructureComputer16 (const uint8_t *a8, int code_size) {
|
||||
set (a8, code_size);
|
||||
accu_den = (float)(popcount64 (a0) + popcount64 (a1));
|
||||
}
|
||||
|
||||
void set (const uint8_t *a8, int code_size) {
|
||||
assert (code_size == 16);
|
||||
const uint64_t *a = (uint64_t *)a8;
|
||||
a0 = a[0]; a1 = a[1];
|
||||
accu_den = (float)(popcount64 (a0) + popcount64 (a1));
|
||||
}
|
||||
|
||||
inline float compute (const uint8_t *b8) const {
|
||||
|
Loading…
Reference in New Issue
Block a user