From 5fc875a020a8e76e8403316ff6b8c96cba894a8f Mon Sep 17 00:00:00 2001 From: "shengjun.li" <49774184+shengjun1985@users.noreply.github.com> Date: Tue, 24 Mar 2020 10:10:56 +0800 Subject: [PATCH] #1603 modify substructure/superstructure to perfect match (#1718) * modify substructure/superstructure to perfect match Signed-off-by: shengjun.li * Update cases Signed-off-by: zw * Fix case bug Signed-off-by: zw * set invalid distance infinite Signed-off-by: shengjun.li * Add distance cases Signed-off-by: zw * Fix test cases Signed-off-by: zw * Re-trigger ci Signed-off-by: zw * fix wrong code Signed-off-by: shengjun.li * Fix test case Signed-off-by: zw * Fix case Signed-off-by: zw * Fix cases Signed-off-by: zw Co-authored-by: zw --- .../thirdparty/faiss/IndexBinaryFlat.cpp | 33 ++- .../thirdparty/faiss/utils/BinaryDistance.cpp | 184 +++++++++++-- .../thirdparty/faiss/utils/BinaryDistance.h | 21 ++ .../thirdparty/faiss/utils/substructure-inl.h | 243 ++++++---------- .../faiss/utils/superstructure-inl.h | 260 ++++++------------ .../milvus_python_test/test_search_vectors.py | 63 ++++- tests/milvus_python_test/utils.py | 27 ++ 7 files changed, 444 insertions(+), 387 deletions(-) diff --git a/core/src/index/thirdparty/faiss/IndexBinaryFlat.cpp b/core/src/index/thirdparty/faiss/IndexBinaryFlat.cpp index 50aad366d7..3b1a902b84 100644 --- a/core/src/index/thirdparty/faiss/IndexBinaryFlat.cpp +++ b/core/src/index/thirdparty/faiss/IndexBinaryFlat.cpp @@ -41,8 +41,7 @@ void IndexBinaryFlat::reset() { void IndexBinaryFlat::search(idx_t n, const uint8_t *x, idx_t k, int32_t *distances, idx_t *labels, ConcurrentBitsetPtr bitset) const { const idx_t block_size = query_batch_size; - if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto || - metric_type == METRIC_Substructure || metric_type == METRIC_Superstructure) { + if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto) { float *D = reinterpret_cast(distances); for (idx_t s = 0; s < n; s += block_size) { idx_t nn = block_size; @@ -50,19 +49,14 @@ void IndexBinaryFlat::search(idx_t n, const uint8_t *x, idx_t k, nn = n - s; } - if (use_heap) { - // We see the distances and labels as heaps. + // We see the distances and labels as heaps. + float_maxheap_array_t res = { + size_t(nn), size_t(k), labels + s * k, D + s * k + }; - float_maxheap_array_t res = { - size_t(nn), size_t(k), labels + s * k, D + s * k - }; + binary_distence_knn_hc(metric_type, &res, x + s * code_size, xb.data(), ntotal, code_size, + /* ordered = */ true, bitset); - binary_distence_knn_hc(metric_type, &res, x + s * code_size, xb.data(), ntotal, code_size, - /* ordered = */ true, bitset); - - } else { - FAISS_THROW_MSG("tanimoto_knn_mc not implemented"); - } } if (metric_type == METRIC_Tanimoto) { for (int i = 0; i < k * n; i++) { @@ -70,6 +64,19 @@ void IndexBinaryFlat::search(idx_t n, const uint8_t *x, idx_t k, } } + } else if (metric_type == METRIC_Substructure || metric_type == METRIC_Superstructure) { + float *D = reinterpret_cast(distances); + for (idx_t s = 0; s < n; s += block_size) { + idx_t nn = block_size; + if (s + block_size > n) { + nn = n - s; + } + + // only match ids will be chosed, not to use heap + binary_distence_knn_mc(metric_type, x + s * code_size, xb.data(), nn, ntotal, k, code_size, + D + s * k, labels + s * k, bitset); + } + } else { for (idx_t s = 0; s < n; s += block_size) { idx_t nn = block_size; diff --git a/core/src/index/thirdparty/faiss/utils/BinaryDistance.cpp b/core/src/index/thirdparty/faiss/utils/BinaryDistance.cpp index ff20a21277..7827830452 100644 --- a/core/src/index/thirdparty/faiss/utils/BinaryDistance.cpp +++ b/core/src/index/thirdparty/faiss/utils/BinaryDistance.cpp @@ -158,46 +158,176 @@ void binary_distence_knn_hc ( } break; + default: + break; + } +} + +template +static +void binary_distence_knn_mc( + int bytes_per_code, + const uint8_t * bs1, + const uint8_t * bs2, + size_t n1, + size_t n2, + size_t k, + float *distances, + int64_t *labels, + ConcurrentBitsetPtr bitset) +{ + if ((bytes_per_code + sizeof(size_t) + k * sizeof(int64_t)) * n1 < size_1M) { + int thread_max_num = omp_get_max_threads(); + + size_t group_num = n1 * thread_max_num; + size_t *match_num = new size_t[group_num]; + int64_t *match_data = new int64_t[group_num * k]; + for (size_t i = 0; i < group_num; i++) { + match_num[i] = 0; + } + + T *hc = new T[n1]; + for (size_t i = 0; i < n1; i++) { + hc[i].set(bs1 + i * bytes_per_code, bytes_per_code); + } + +#pragma omp parallel for + for (size_t j = 0; j < n2; j++) { + if(!bitset || !bitset->test(j)) { + int thread_no = omp_get_thread_num(); + + const uint8_t * bs2_ = bs2 + j * bytes_per_code; + for (size_t i = 0; i < n1; i++) { + if (hc[i].compute(bs2_)) { + size_t match_index = thread_no * n1 + i; + size_t &index = match_num[match_index]; + if (index < k) { + match_data[match_index * k + index] = j; + index++; + } + } + } + } + } + for (size_t i = 0, ni = 0; i < n1; i++) { + size_t n_i = 0; + float *distances_i = distances + i * k; + int64_t *labels_i = labels + i * k; + + for (size_t t = 0; t < thread_max_num && n_i < k; t++) { + size_t match_index = t * n1 + i; + size_t copy_num = std::min(k - n_i, match_num[match_index]); + memcpy(labels_i + n_i, match_data + match_index * k, copy_num * sizeof(int64_t)); + memset(distances + n_i, 0, copy_num * sizeof(int32_t)); + n_i += copy_num; + } + for (; n_i < k; n_i++) { + distances_i[n_i] = 1.0 / 0.0; + labels_i[n_i] = -1; + } + } + + delete[] hc; + delete[] match_num; + delete[] match_data; + + } else { + size_t *num = new size_t[n1]; + for (size_t i = 0; i < n1; i++) { + num[i] = 0; + } + + const size_t block_size = batch_size; + for (size_t j0 = 0; j0 < n2; j0 += block_size) { + const size_t j1 = std::min(j0 + block_size, n2); +#pragma omp parallel for + for (size_t i = 0; i < n1; i++) { + size_t num_i = num[i]; + if (num_i == k) continue; + float * dis = distances + i * k; + int64_t * lab = labels + i * k; + + T hc (bs1 + i * bytes_per_code, bytes_per_code); + const uint8_t * bs2_ = bs2 + j0 * bytes_per_code; + for (size_t j = j0; j < j1; j++, bs2_ += bytes_per_code) { + if(!bitset || !bitset->test(j)){ + if (hc.compute (bs2_)) { + dis[num_i] = 0; + lab[num_i] = j; + if (++num_i == k) break; + } + } + } + num[i] = num_i; + } + } + + for (size_t i = 0; i < n1; i++) { + float * dis = distances + i * k; + int64_t * lab = labels + i * k; + for (size_t num_i = num[i]; num_i < k; num_i++) { + dis[num_i] = 1.0 / 0.0; + lab[num_i] = -1; + } + } + + delete[] num; + } +} + +void binary_distence_knn_mc ( + MetricType metric_type, + const uint8_t * a, + const uint8_t * b, + size_t na, + size_t nb, + size_t k, + size_t ncodes, + float *distances, + int64_t *labels, + ConcurrentBitsetPtr bitset) { + + switch (metric_type) { case METRIC_Substructure: switch (ncodes) { -#define binary_distence_knn_hc_Substructure(ncodes) \ +#define binary_distence_knn_mc_Substructure(ncodes) \ case ncodes: \ - binary_distence_knn_hc \ - (ncodes, ha, a, b, nb, order, true, bitset); \ + binary_distence_knn_mc \ + (ncodes, a, b, na, nb, k, distances, labels, bitset); \ break; - binary_distence_knn_hc_Substructure(8); - binary_distence_knn_hc_Substructure(16); - binary_distence_knn_hc_Substructure(32); - binary_distence_knn_hc_Substructure(64); - binary_distence_knn_hc_Substructure(128); - binary_distence_knn_hc_Substructure(256); - binary_distence_knn_hc_Substructure(512); -#undef binary_distence_knn_hc_Substructure + binary_distence_knn_mc_Substructure(8); + binary_distence_knn_mc_Substructure(16); + binary_distence_knn_mc_Substructure(32); + binary_distence_knn_mc_Substructure(64); + binary_distence_knn_mc_Substructure(128); + binary_distence_knn_mc_Substructure(256); + binary_distence_knn_mc_Substructure(512); +#undef binary_distence_knn_mc_Substructure default: - binary_distence_knn_hc - (ncodes, ha, a, b, nb, order, true, bitset); + binary_distence_knn_mc + (ncodes, a, b, na, nb, k, distances, labels, bitset); break; } break; case METRIC_Superstructure: switch (ncodes) { -#define binary_distence_knn_hc_Superstructure(ncodes) \ +#define binary_distence_knn_mc_Superstructure(ncodes) \ case ncodes: \ - binary_distence_knn_hc \ - (ncodes, ha, a, b, nb, order, true, bitset); \ + binary_distence_knn_mc \ + (ncodes, a, b, na, nb, k, distances, labels, bitset); \ break; - binary_distence_knn_hc_Superstructure(8); - binary_distence_knn_hc_Superstructure(16); - binary_distence_knn_hc_Superstructure(32); - binary_distence_knn_hc_Superstructure(64); - binary_distence_knn_hc_Superstructure(128); - binary_distence_knn_hc_Superstructure(256); - binary_distence_knn_hc_Superstructure(512); -#undef binary_distence_knn_hc_Superstructure + binary_distence_knn_mc_Superstructure(8); + binary_distence_knn_mc_Superstructure(16); + binary_distence_knn_mc_Superstructure(32); + binary_distence_knn_mc_Superstructure(64); + binary_distence_knn_mc_Superstructure(128); + binary_distence_knn_mc_Superstructure(256); + binary_distence_knn_mc_Superstructure(512); +#undef binary_distence_knn_mc_Superstructure default: - binary_distence_knn_hc - (ncodes, ha, a, b, nb, order, true, bitset); + binary_distence_knn_mc + (ncodes, a, b, na, nb, k, distances, labels, bitset); break; } break; @@ -207,4 +337,4 @@ void binary_distence_knn_hc ( } } -} +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/utils/BinaryDistance.h b/core/src/index/thirdparty/faiss/utils/BinaryDistance.h index 181accaa3a..fccdfd3674 100644 --- a/core/src/index/thirdparty/faiss/utils/BinaryDistance.h +++ b/core/src/index/thirdparty/faiss/utils/BinaryDistance.h @@ -32,6 +32,27 @@ namespace faiss { int ordered, ConcurrentBitsetPtr bitset = nullptr); + /** Return the k matched distances for a set of binary query vectors, + * using a max heap. + * @param a queries, size ha->nh * ncodes + * @param b database, size nb * ncodes + * @param na number of queries vectors + * @param nb number of database vectors + * @param k number of the matched vectors to return + * @param ncodes size of the binary codes (bytes) + */ + void binary_distence_knn_mc ( + MetricType metric_type, + const uint8_t * a, + const uint8_t * b, + size_t na, + size_t nb, + size_t k, + size_t ncodes, + float *distances, + int64_t *labels, + ConcurrentBitsetPtr bitset); + } // namespace faiss #include diff --git a/core/src/index/thirdparty/faiss/utils/substructure-inl.h b/core/src/index/thirdparty/faiss/utils/substructure-inl.h index 2b55ce6e8a..aa57a5a646 100644 --- a/core/src/index/thirdparty/faiss/utils/substructure-inl.h +++ b/core/src/index/thirdparty/faiss/utils/substructure-inl.h @@ -15,13 +15,9 @@ namespace faiss { a0 = a[0]; } - inline float compute (const uint8_t *b8) const { + inline bool compute (const uint8_t *b8) const { const uint64_t *b = (uint64_t *)b8; - int accu_num = popcount64 (b[0] & a0); - int accu_den = popcount64 (b[0]); - if (accu_num == 0) - return 1.0; - return 1.0 - (float)(accu_num) / (float)(accu_den); + return (a0 & b[0]) == a0; } }; @@ -41,13 +37,9 @@ namespace faiss { a0 = a[0]; a1 = a[1]; } - inline float compute (const uint8_t *b8) const { + inline bool compute (const uint8_t *b8) const { const uint64_t *b = (uint64_t *)b8; - int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1); - int accu_den = popcount64 (b[0]) + popcount64 (b[1]); - if (accu_num == 0) - return 1.0; - return 1.0 - (float)(accu_num) / (float)(accu_den); + return (a0 & b[0]) == a0 && (a1 & b[1]) == a1; } }; @@ -67,15 +59,10 @@ namespace faiss { a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3]; } - inline float compute (const uint8_t *b8) const { + inline bool compute (const uint8_t *b8) const { const uint64_t *b = (uint64_t *)b8; - int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) + - popcount64 (b[2] & a2) + popcount64 (b[3] & a3); - int accu_den = popcount64 (b[0]) + popcount64 (b[1]) + - popcount64 (b[2]) + popcount64 (b[3]); - if (accu_num == 0) - return 1.0; - return 1.0 - (float)(accu_num) / (float)(accu_den); + return (a0 & b[0]) == a0 && (a1 & b[1]) == a1 && + (a2 & b[2]) == a2 && (a3 & b[3]) == a3; } }; @@ -96,19 +83,12 @@ namespace faiss { a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7]; } - inline float compute (const uint8_t *b8) const { + inline bool compute (const uint8_t *b8) const { const uint64_t *b = (uint64_t *)b8; - int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) + - popcount64 (b[2] & a2) + popcount64 (b[3] & a3) + - popcount64 (b[4] & a4) + popcount64 (b[5] & a5) + - popcount64 (b[6] & a6) + popcount64 (b[7] & a7); - int accu_den = popcount64 (b[0]) + popcount64 (b[1]) + - popcount64 (b[2]) + popcount64 (b[3]) + - popcount64 (b[4]) + popcount64 (b[5]) + - popcount64 (b[6]) + popcount64 (b[7]); - if (accu_num == 0) - return 1.0; - return 1.0 - (float)(accu_num) / (float)(accu_den); + return (a0 & b[0]) == a0 && (a1 & b[1]) == a1 && + (a2 & b[2]) == a2 && (a3 & b[3]) == a3 && + (a4 & b[4]) == a4 && (a5 & b[5]) == a5 && + (a6 & b[6]) == a6 && (a7 & b[7]) == a7; } }; @@ -132,27 +112,16 @@ namespace faiss { a12 = a[12]; a13 = a[13]; a14 = a[14]; a15 = a[15]; } - inline float compute (const uint8_t *b16) const { + inline bool compute (const uint8_t *b16) const { const uint64_t *b = (uint64_t *)b16; - int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) + - popcount64 (b[2] & a2) + popcount64 (b[3] & a3) + - popcount64 (b[4] & a4) + popcount64 (b[5] & a5) + - popcount64 (b[6] & a6) + popcount64 (b[7] & a7) + - popcount64 (b[8] & a8) + popcount64 (b[9] & a9) + - popcount64 (b[10] & a10) + popcount64 (b[11] & a11) + - popcount64 (b[12] & a12) + popcount64 (b[13] & a13) + - popcount64 (b[14] & a14) + popcount64 (b[15] & a15); - int accu_den = popcount64 (b[0]) + popcount64 (b[1]) + - popcount64 (b[2]) + popcount64 (b[3]) + - popcount64 (b[4]) + popcount64 (b[5]) + - popcount64 (b[6]) + popcount64 (b[7]) + - popcount64 (b[8]) + popcount64 (b[9]) + - popcount64 (b[10]) + popcount64 (b[11]) + - popcount64 (b[12]) + popcount64 (b[13]) + - popcount64 (b[14]) + popcount64 (b[15]); - if (accu_num == 0) - return 1.0; - return 1.0 - (float)(accu_num) / (float)(accu_den); + return (a0 & b[0]) == a0 && (a1 & b[1]) == a1 && + (a2 & b[2]) == a2 && (a3 & b[3]) == a3 && + (a4 & b[4]) == a4 && (a5 & b[5]) == a5 && + (a6 & b[6]) == a6 && (a7 & b[7]) == a7 && + (a8 & b[8]) == a8 && (a9 & b[9]) == a9 && + (a10 & b[10]) == a10 && (a11 & b[11]) == a11 && + (a12 & b[12]) == a12 && (a13 & b[13]) == a13 && + (a14 & b[14]) == a14 && (a15 & b[15]) == a15; } }; @@ -182,43 +151,24 @@ namespace faiss { a28 = a[28]; a29 = a[29]; a30 = a[30]; a31 = a[31]; } - inline float compute (const uint8_t *b16) const { + inline bool compute (const uint8_t *b16) const { const uint64_t *b = (uint64_t *)b16; - int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) + - popcount64 (b[2] & a2) + popcount64 (b[3] & a3) + - popcount64 (b[4] & a4) + popcount64 (b[5] & a5) + - popcount64 (b[6] & a6) + popcount64 (b[7] & a7) + - popcount64 (b[8] & a8) + popcount64 (b[9] & a9) + - popcount64 (b[10] & a10) + popcount64 (b[11] & a11) + - popcount64 (b[12] & a12) + popcount64 (b[13] & a13) + - popcount64 (b[14] & a14) + popcount64 (b[15] & a15) + - popcount64 (b[16] & a16) + popcount64 (b[17] & a17) + - popcount64 (b[18] & a18) + popcount64 (b[19] & a19) + - popcount64 (b[20] & a20) + popcount64 (b[21] & a21) + - popcount64 (b[22] & a22) + popcount64 (b[23] & a23) + - popcount64 (b[24] & a24) + popcount64 (b[25] & a25) + - popcount64 (b[26] & a26) + popcount64 (b[27] & a27) + - popcount64 (b[28] & a28) + popcount64 (b[29] & a29) + - popcount64 (b[30] & a30) + popcount64 (b[31] & a31); - int accu_den = popcount64 (b[0]) + popcount64 (b[1]) + - popcount64 (b[2]) + popcount64 (b[3]) + - popcount64 (b[4]) + popcount64 (b[5]) + - popcount64 (b[6]) + popcount64 (b[7]) + - popcount64 (b[8]) + popcount64 (b[9]) + - popcount64 (b[10]) + popcount64 (b[11]) + - popcount64 (b[12]) + popcount64 (b[13]) + - popcount64 (b[14]) + popcount64 (b[15]) + - popcount64 (b[16]) + popcount64 (b[17]) + - popcount64 (b[18]) + popcount64 (b[19]) + - popcount64 (b[20]) + popcount64 (b[21]) + - popcount64 (b[22]) + popcount64 (b[23]) + - popcount64 (b[24]) + popcount64 (b[25]) + - popcount64 (b[26]) + popcount64 (b[27]) + - popcount64 (b[28]) + popcount64 (b[29]) + - popcount64 (b[30]) + popcount64 (b[31]); - if (accu_num == 0) - return 1.0; - return 1.0 - (float)(accu_num) / (float)(accu_den); + return (a0 & b[0]) == a0 && (a1 & b[1]) == a1 && + (a2 & b[2]) == a2 && (a3 & b[3]) == a3 && + (a4 & b[4]) == a4 && (a5 & b[5]) == a5 && + (a6 & b[6]) == a6 && (a7 & b[7]) == a7 && + (a8 & b[8]) == a8 && (a9 & b[9]) == a9 && + (a10 & b[10]) == a10 && (a11 & b[11]) == a11 && + (a12 & b[12]) == a12 && (a13 & b[13]) == a13 && + (a14 & b[14]) == a14 && (a15 & b[15]) == a15 && + (a16 & b[16]) == a16 && (a17 & b[17]) == a17 && + (a18 & b[18]) == a18 && (a19 & b[19]) == a19 && + (a20 & b[20]) == a20 && (a21 & b[21]) == a21 && + (a22 & b[22]) == a22 && (a23 & b[23]) == a23 && + (a24 & b[24]) == a24 && (a25 & b[25]) == a25 && + (a26 & b[26]) == a26 && (a27 & b[27]) == a27 && + (a28 & b[28]) == a28 && (a29 & b[29]) == a29 && + (a30 & b[30]) == a30 && (a31 & b[31]) == a31; } }; @@ -260,76 +210,41 @@ namespace faiss { a60 = a[60]; a61 = a[61]; a62 = a[62]; a63 = a[63]; } - inline float compute (const uint8_t *b16) const { + inline bool compute (const uint8_t *b16) const { const uint64_t *b = (uint64_t *)b16; - int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) + - popcount64 (b[2] & a2) + popcount64 (b[3] & a3) + - popcount64 (b[4] & a4) + popcount64 (b[5] & a5) + - popcount64 (b[6] & a6) + popcount64 (b[7] & a7) + - popcount64 (b[8] & a8) + popcount64 (b[9] & a9) + - popcount64 (b[10] & a10) + popcount64 (b[11] & a11) + - popcount64 (b[12] & a12) + popcount64 (b[13] & a13) + - popcount64 (b[14] & a14) + popcount64 (b[15] & a15) + - popcount64 (b[16] & a16) + popcount64 (b[17] & a17) + - popcount64 (b[18] & a18) + popcount64 (b[19] & a19) + - popcount64 (b[20] & a20) + popcount64 (b[21] & a21) + - popcount64 (b[22] & a22) + popcount64 (b[23] & a23) + - popcount64 (b[24] & a24) + popcount64 (b[25] & a25) + - popcount64 (b[26] & a26) + popcount64 (b[27] & a27) + - popcount64 (b[28] & a28) + popcount64 (b[29] & a29) + - popcount64 (b[30] & a30) + popcount64 (b[31] & a31) + - popcount64 (b[32] & a32) + popcount64 (b[33] & a33) + - popcount64 (b[34] & a34) + popcount64 (b[35] & a35) + - popcount64 (b[36] & a36) + popcount64 (b[37] & a37) + - popcount64 (b[38] & a38) + popcount64 (b[39] & a39) + - popcount64 (b[40] & a40) + popcount64 (b[41] & a41) + - popcount64 (b[42] & a42) + popcount64 (b[43] & a43) + - popcount64 (b[44] & a44) + popcount64 (b[45] & a45) + - popcount64 (b[46] & a46) + popcount64 (b[47] & a47) + - popcount64 (b[48] & a48) + popcount64 (b[49] & a49) + - popcount64 (b[50] & a50) + popcount64 (b[51] & a51) + - popcount64 (b[52] & a52) + popcount64 (b[53] & a53) + - popcount64 (b[54] & a54) + popcount64 (b[55] & a55) + - popcount64 (b[56] & a56) + popcount64 (b[57] & a57) + - popcount64 (b[58] & a58) + popcount64 (b[59] & a59) + - popcount64 (b[60] & a60) + popcount64 (b[61] & a61) + - popcount64 (b[62] & a62) + popcount64 (b[63] & a63); - int accu_den = popcount64 (b[0]) + popcount64 (b[1]) + - popcount64 (b[2]) + popcount64 (b[3]) + - popcount64 (b[4]) + popcount64 (b[5]) + - popcount64 (b[6]) + popcount64 (b[7]) + - popcount64 (b[8]) + popcount64 (b[9]) + - popcount64 (b[10]) + popcount64 (b[11]) + - popcount64 (b[12]) + popcount64 (b[13]) + - popcount64 (b[14]) + popcount64 (b[15]) + - popcount64 (b[16]) + popcount64 (b[17]) + - popcount64 (b[18]) + popcount64 (b[19]) + - popcount64 (b[20]) + popcount64 (b[21]) + - popcount64 (b[22]) + popcount64 (b[23]) + - popcount64 (b[24]) + popcount64 (b[25]) + - popcount64 (b[26]) + popcount64 (b[27]) + - popcount64 (b[28]) + popcount64 (b[29]) + - popcount64 (b[30]) + popcount64 (b[31]) + - popcount64 (b[32]) + popcount64 (b[33]) + - popcount64 (b[34]) + popcount64 (b[35]) + - popcount64 (b[36]) + popcount64 (b[37]) + - popcount64 (b[38]) + popcount64 (b[39]) + - popcount64 (b[40]) + popcount64 (b[41]) + - popcount64 (b[42]) + popcount64 (b[43]) + - popcount64 (b[44]) + popcount64 (b[45]) + - popcount64 (b[46]) + popcount64 (b[47]) + - popcount64 (b[48]) + popcount64 (b[49]) + - popcount64 (b[50]) + popcount64 (b[51]) + - popcount64 (b[52]) + popcount64 (b[53]) + - popcount64 (b[54]) + popcount64 (b[55]) + - popcount64 (b[56]) + popcount64 (b[57]) + - popcount64 (b[58]) + popcount64 (b[59]) + - popcount64 (b[60]) + popcount64 (b[61]) + - popcount64 (b[62]) + popcount64 (b[63]); - if (accu_num == 0) - return 1.0; - return 1.0 - (float)(accu_num) / (float)(accu_den); - } + return (a0 & b[0]) == a0 && (a1 & b[1]) == a1 && + (a2 & b[2]) == a2 && (a3 & b[3]) == a3 && + (a4 & b[4]) == a4 && (a5 & b[5]) == a5 && + (a6 & b[6]) == a6 && (a7 & b[7]) == a7 && + (a8 & b[8]) == a8 && (a9 & b[9]) == a9 && + (a10 & b[10]) == a10 && (a11 & b[11]) == a11 && + (a12 & b[12]) == a12 && (a13 & b[13]) == a13 && + (a14 & b[14]) == a14 && (a15 & b[15]) == a15 && + (a16 & b[16]) == a16 && (a17 & b[17]) == a17 && + (a18 & b[18]) == a18 && (a19 & b[19]) == a19 && + (a20 & b[20]) == a20 && (a21 & b[21]) == a21 && + (a22 & b[22]) == a22 && (a23 & b[23]) == a23 && + (a24 & b[24]) == a24 && (a25 & b[25]) == a25 && + (a26 & b[26]) == a26 && (a27 & b[27]) == a27 && + (a28 & b[28]) == a28 && (a29 & b[29]) == a29 && + (a30 & b[30]) == a30 && (a31 & b[31]) == a31 && + (a32 & b[32]) == a32 && (a33 & b[33]) == a33 && + (a34 & b[34]) == a34 && (a35 & b[35]) == a35 && + (a36 & b[36]) == a36 && (a37 & b[37]) == a37 && + (a38 & b[38]) == a38 && (a39 & b[39]) == a39 && + (a40 & b[40]) == a40 && (a41 & b[41]) == a41 && + (a42 & b[42]) == a42 && (a43 & b[43]) == a43 && + (a44 & b[44]) == a44 && (a45 & b[45]) == a45 && + (a46 & b[46]) == a46 && (a47 & b[47]) == a47 && + (a48 & b[48]) == a48 && (a49 & b[49]) == a49 && + (a50 & b[50]) == a50 && (a51 & b[51]) == a51 && + (a52 & b[52]) == a52 && (a53 & b[53]) == a53 && + (a54 & b[54]) == a54 && (a55 & b[55]) == a55 && + (a56 & b[56]) == a56 && (a57 & b[57]) == a57 && + (a58 & b[58]) == a58 && (a59 & b[59]) == a59 && + (a60 & b[60]) == a60 && (a61 & b[61]) == a61 && + (a62 & b[62]) == a62 && (a63 & b[63]) == a63; + } }; @@ -348,16 +263,14 @@ namespace faiss { n = code_size; } - float compute (const uint8_t *b8) const { - int accu_num = 0; - int accu_den = 0; + bool compute (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; for (int i = 0; i < n; i++) { - accu_num += popcount64(a[i] & b8[i]); - accu_den += popcount64(b8[i]); + if ((a[i] & b[i]) != a[i]) { + return false; + } } - if (accu_num == 0) - return 1.0; - return 1.0 - (float)(accu_num) / (float)(accu_den); + return true; } }; diff --git a/core/src/index/thirdparty/faiss/utils/superstructure-inl.h b/core/src/index/thirdparty/faiss/utils/superstructure-inl.h index 1ebf8946a5..e8b384e75f 100644 --- a/core/src/index/thirdparty/faiss/utils/superstructure-inl.h +++ b/core/src/index/thirdparty/faiss/utils/superstructure-inl.h @@ -2,7 +2,6 @@ namespace faiss { struct SuperstructureComputer8 { uint64_t a0; - float accu_den; SuperstructureComputer8 () {} @@ -14,22 +13,17 @@ namespace faiss { assert (code_size == 8); const uint64_t *a = (uint64_t *)a8; a0 = a[0]; - accu_den = (float)(popcount64 (a0)); } - inline float compute (const uint8_t *b8) const { + inline bool compute (const uint8_t *b8) const { const uint64_t *b = (uint64_t *)b8; - int accu_num = popcount64 (b[0] & a0); - if (accu_num == 0) - return 1.0; - return 1.0 - (float)(accu_num) / accu_den; + return (a0 & b[0]) == b[0]; } }; struct SuperstructureComputer16 { uint64_t a0, a1; - float accu_den; SuperstructureComputer16 () {} @@ -41,22 +35,17 @@ namespace faiss { assert (code_size == 16); const uint64_t *a = (uint64_t *)a8; a0 = a[0]; a1 = a[1]; - accu_den = (float)(popcount64 (a0) + popcount64 (a1)); } - inline float compute (const uint8_t *b8) const { + inline bool compute (const uint8_t *b8) const { const uint64_t *b = (uint64_t *)b8; - int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1); - if (accu_num == 0) - return 1.0; - return 1.0 - (float)(accu_num) / accu_den; + return (a0 & b[0]) == b[0] && (a1 & b[1]) == b[1]; } }; struct SuperstructureComputer32 { uint64_t a0, a1, a2, a3; - float accu_den; SuperstructureComputer32 () {} @@ -68,24 +57,18 @@ namespace faiss { assert (code_size == 32); const uint64_t *a = (uint64_t *)a8; a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3]; - accu_den = (float)(popcount64 (a0) + popcount64 (a1) + - popcount64 (a2) + popcount64 (a3)); } - inline float compute (const uint8_t *b8) const { + inline bool compute (const uint8_t *b8) const { const uint64_t *b = (uint64_t *)b8; - int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) + - popcount64 (b[2] & a2) + popcount64 (b[3] & a3); - if (accu_num == 0) - return 1.0; - return 1.0 - (float)(accu_num) / accu_den; + return (a0 & b[0]) == b[0] && (a1 & b[1]) == b[1] && + (a2 & b[2]) == b[2] && (a3 & b[3]) == b[3]; } }; struct SuperstructureComputer64 { uint64_t a0, a1, a2, a3, a4, a5, a6, a7; - float accu_den; SuperstructureComputer64 () {} @@ -98,21 +81,14 @@ namespace faiss { const uint64_t *a = (uint64_t *)a8; a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3]; a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7]; - accu_den = (float)(popcount64 (a0) + popcount64 (a1) + - popcount64 (a2) + popcount64 (a3) + - popcount64 (a4) + popcount64 (a5) + - popcount64 (a6) + popcount64 (a7)); } - inline float compute (const uint8_t *b8) const { + inline bool compute (const uint8_t *b8) const { const uint64_t *b = (uint64_t *)b8; - int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) + - popcount64 (b[2] & a2) + popcount64 (b[3] & a3) + - popcount64 (b[4] & a4) + popcount64 (b[5] & a5) + - popcount64 (b[6] & a6) + popcount64 (b[7] & a7); - if (accu_num == 0) - return 1.0; - return 1.0 - (float)(accu_num) / accu_den; + return (a0 & b[0]) == b[0] && (a1 & b[1]) == b[1] && + (a2 & b[2]) == b[2] && (a3 & b[3]) == b[3] && + (a4 & b[4]) == b[4] && (a5 & b[5]) == b[5] && + (a6 & b[6]) == b[6] && (a7 & b[7]) == b[7]; } }; @@ -120,7 +96,6 @@ namespace faiss { struct SuperstructureComputer128 { uint64_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15; - float accu_den; SuperstructureComputer128 () {} @@ -135,29 +110,18 @@ namespace faiss { a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7]; a8 = a[8]; a9 = a[9]; a10 = a[10]; a11 = a[11]; a12 = a[12]; a13 = a[13]; a14 = a[14]; a15 = a[15]; - accu_den = (float)(popcount64 (a0) + popcount64 (a1) + - popcount64 (a2) + popcount64 (a3) + - popcount64 (a4) + popcount64 (a5) + - popcount64 (a6) + popcount64 (a7) + - popcount64 (a8) + popcount64 (a9) + - popcount64 (a10) + popcount64 (a11) + - popcount64 (a12) + popcount64 (a13) + - popcount64 (a14) + popcount64 (a15)); } - inline float compute (const uint8_t *b16) const { - const uint64_t *b = (uint64_t *)b16; - int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) + - popcount64 (b[2] & a2) + popcount64 (b[3] & a3) + - popcount64 (b[4] & a4) + popcount64 (b[5] & a5) + - popcount64 (b[6] & a6) + popcount64 (b[7] & a7) + - popcount64 (b[8] & a8) + popcount64 (b[9] & a9) + - popcount64 (b[10] & a10) + popcount64 (b[11] & a11) + - popcount64 (b[12] & a12) + popcount64 (b[13] & a13) + - popcount64 (b[14] & a14) + popcount64 (b[15] & a15); - if (accu_num == 0) - return 1.0; - return 1.0 - (float)(accu_num) / accu_den; + inline float compute (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + return (a0 & b[0]) == b[0] && (a1 & b[1]) == b[1] && + (a2 & b[2]) == b[2] && (a3 & b[3]) == b[3] && + (a4 & b[4]) == b[4] && (a5 & b[5]) == b[5] && + (a6 & b[6]) == b[6] && (a7 & b[7]) == b[7] && + (a8 & b[8]) == b[8] && (a9 & b[9]) == b[9] && + (a10 & b[10]) == b[10] && (a11 & b[11]) == b[11] && + (a12 & b[12]) == b[12] && (a13 & b[13]) == b[13] && + (a14 & b[14]) == b[14] && (a15 & b[15]) == b[15]; } }; @@ -167,7 +131,6 @@ namespace faiss { a8,a9,a10,a11,a12,a13,a14,a15, a16,a17,a18,a19,a20,a21,a22,a23, a24,a25,a26,a27,a28,a29,a30,a31; - float accu_den; SuperstructureComputer256 () {} @@ -186,45 +149,26 @@ namespace faiss { a20 = a[20]; a21 = a[21]; a22 = a[22]; a23 = a[23]; a24 = a[24]; a25 = a[25]; a26 = a[26]; a27 = a[27]; a28 = a[28]; a29 = a[29]; a30 = a[30]; a31 = a[31]; - accu_den = (float)(popcount64 (a0) + popcount64 (a1) + - popcount64 (a2) + popcount64 (a3) + - popcount64 (a4) + popcount64 (a5) + - popcount64 (a6) + popcount64 (a7) + - popcount64 (a8) + popcount64 (a9) + - popcount64 (a10) + popcount64 (a11) + - popcount64 (a12) + popcount64 (a13) + - popcount64 (a14) + popcount64 (a15) + - popcount64 (a16) + popcount64 (a17) + - popcount64 (a18) + popcount64 (a19) + - popcount64 (a20) + popcount64 (a21) + - popcount64 (a22) + popcount64 (a23) + - popcount64 (a24) + popcount64 (a25) + - popcount64 (a26) + popcount64 (a27) + - popcount64 (a28) + popcount64 (a29) + - popcount64 (a30) + popcount64 (a31)); } - inline float compute (const uint8_t *b16) const { - const uint64_t *b = (uint64_t *)b16; - int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) + - popcount64 (b[2] & a2) + popcount64 (b[3] & a3) + - popcount64 (b[4] & a4) + popcount64 (b[5] & a5) + - popcount64 (b[6] & a6) + popcount64 (b[7] & a7) + - popcount64 (b[8] & a8) + popcount64 (b[9] & a9) + - popcount64 (b[10] & a10) + popcount64 (b[11] & a11) + - popcount64 (b[12] & a12) + popcount64 (b[13] & a13) + - popcount64 (b[14] & a14) + popcount64 (b[15] & a15) + - popcount64 (b[16] & a16) + popcount64 (b[17] & a17) + - popcount64 (b[18] & a18) + popcount64 (b[19] & a19) + - popcount64 (b[20] & a20) + popcount64 (b[21] & a21) + - popcount64 (b[22] & a22) + popcount64 (b[23] & a23) + - popcount64 (b[24] & a24) + popcount64 (b[25] & a25) + - popcount64 (b[26] & a26) + popcount64 (b[27] & a27) + - popcount64 (b[28] & a28) + popcount64 (b[29] & a29) + - popcount64 (b[30] & a30) + popcount64 (b[31] & a31); - if (accu_num == 0) - return 1.0; - return 1.0 - (float)(accu_num) / accu_den; + inline float compute (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + return (a0 & b[0]) == b[0] && (a1 & b[1]) == b[1] && + (a2 & b[2]) == b[2] && (a3 & b[3]) == b[3] && + (a4 & b[4]) == b[4] && (a5 & b[5]) == b[5] && + (a6 & b[6]) == b[6] && (a7 & b[7]) == b[7] && + (a8 & b[8]) == b[8] && (a9 & b[9]) == b[9] && + (a10 & b[10]) == b[10] && (a11 & b[11]) == b[11] && + (a12 & b[12]) == b[12] && (a13 & b[13]) == b[13] && + (a14 & b[14]) == b[14] && (a15 & b[15]) == b[15] && + (a16 & b[16]) == b[16] && (a17 & b[17]) == b[17] && + (a18 & b[18]) == b[18] && (a19 & b[19]) == b[19] && + (a20 & b[20]) == b[20] && (a21 & b[21]) == b[21] && + (a22 & b[22]) == b[22] && (a23 & b[23]) == b[23] && + (a24 & b[24]) == b[24] && (a25 & b[25]) == b[25] && + (a26 & b[26]) == b[26] && (a27 & b[27]) == b[27] && + (a28 & b[28]) == b[28] && (a29 & b[29]) == b[29] && + (a30 & b[30]) == b[30] && (a31 & b[31]) == b[31]; } }; @@ -238,7 +182,6 @@ namespace faiss { a40,a41,a42,a43,a44,a45,a46,a47, a48,a49,a50,a51,a52,a53,a54,a55, a56,a57,a58,a59,a60,a61,a62,a63; - float accu_den; SuperstructureComputer512 () {} @@ -265,85 +208,49 @@ namespace faiss { a52 = a[52]; a53 = a[53]; a54 = a[54]; a55 = a[55]; a56 = a[56]; a57 = a[57]; a58 = a[58]; a59 = a[59]; a60 = a[60]; a61 = a[61]; a62 = a[62]; a63 = a[63]; - accu_den = (float)(popcount64 (a0) + popcount64 (a1) + - popcount64 (a2) + popcount64 (a3) + - popcount64 (a4) + popcount64 (a5) + - popcount64 (a6) + popcount64 (a7) + - popcount64 (a8) + popcount64 (a9) + - popcount64 (a10) + popcount64 (a11) + - popcount64 (a12) + popcount64 (a13) + - popcount64 (a14) + popcount64 (a15) + - popcount64 (a16) + popcount64 (a17) + - popcount64 (a18) + popcount64 (a19) + - popcount64 (a20) + popcount64 (a21) + - popcount64 (a22) + popcount64 (a23) + - popcount64 (a24) + popcount64 (a25) + - popcount64 (a26) + popcount64 (a27) + - popcount64 (a28) + popcount64 (a29) + - popcount64 (a30) + popcount64 (a31) + - popcount64 (a32) + popcount64 (a33) + - popcount64 (a34) + popcount64 (a35) + - popcount64 (a36) + popcount64 (a37) + - popcount64 (a38) + popcount64 (a39) + - popcount64 (a40) + popcount64 (a41) + - popcount64 (a42) + popcount64 (a43) + - popcount64 (a44) + popcount64 (a45) + - popcount64 (a46) + popcount64 (a47) + - popcount64 (a48) + popcount64 (a49) + - popcount64 (a50) + popcount64 (a51) + - popcount64 (a52) + popcount64 (a53) + - popcount64 (a54) + popcount64 (a55) + - popcount64 (a56) + popcount64 (a57) + - popcount64 (a58) + popcount64 (a59) + - popcount64 (a60) + popcount64 (a61) + - popcount64 (a62) + popcount64 (a63)); } - inline float compute (const uint8_t *b16) const { - const uint64_t *b = (uint64_t *)b16; - int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) + - popcount64 (b[2] & a2) + popcount64 (b[3] & a3) + - popcount64 (b[4] & a4) + popcount64 (b[5] & a5) + - popcount64 (b[6] & a6) + popcount64 (b[7] & a7) + - popcount64 (b[8] & a8) + popcount64 (b[9] & a9) + - popcount64 (b[10] & a10) + popcount64 (b[11] & a11) + - popcount64 (b[12] & a12) + popcount64 (b[13] & a13) + - popcount64 (b[14] & a14) + popcount64 (b[15] & a15) + - popcount64 (b[16] & a16) + popcount64 (b[17] & a17) + - popcount64 (b[18] & a18) + popcount64 (b[19] & a19) + - popcount64 (b[20] & a20) + popcount64 (b[21] & a21) + - popcount64 (b[22] & a22) + popcount64 (b[23] & a23) + - popcount64 (b[24] & a24) + popcount64 (b[25] & a25) + - popcount64 (b[26] & a26) + popcount64 (b[27] & a27) + - popcount64 (b[28] & a28) + popcount64 (b[29] & a29) + - popcount64 (b[30] & a30) + popcount64 (b[31] & a31) + - popcount64 (b[32] & a32) + popcount64 (b[33] & a33) + - popcount64 (b[34] & a34) + popcount64 (b[35] & a35) + - popcount64 (b[36] & a36) + popcount64 (b[37] & a37) + - popcount64 (b[38] & a38) + popcount64 (b[39] & a39) + - popcount64 (b[40] & a40) + popcount64 (b[41] & a41) + - popcount64 (b[42] & a42) + popcount64 (b[43] & a43) + - popcount64 (b[44] & a44) + popcount64 (b[45] & a45) + - popcount64 (b[46] & a46) + popcount64 (b[47] & a47) + - popcount64 (b[48] & a48) + popcount64 (b[49] & a49) + - popcount64 (b[50] & a50) + popcount64 (b[51] & a51) + - popcount64 (b[52] & a52) + popcount64 (b[53] & a53) + - popcount64 (b[54] & a54) + popcount64 (b[55] & a55) + - popcount64 (b[56] & a56) + popcount64 (b[57] & a57) + - popcount64 (b[58] & a58) + popcount64 (b[59] & a59) + - popcount64 (b[60] & a60) + popcount64 (b[61] & a61) + - popcount64 (b[62] & a62) + popcount64 (b[63] & a63); - if (accu_num == 0) - return 1.0; - return 1.0 - (float)(accu_num) / accu_den; - } + inline bool compute (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + return (a0 & b[0]) == b[0] && (a1 & b[1]) == b[1] && + (a2 & b[2]) == b[2] && (a3 & b[3]) == b[3] && + (a4 & b[4]) == b[4] && (a5 & b[5]) == b[5] && + (a6 & b[6]) == b[6] && (a7 & b[7]) == b[7] && + (a8 & b[8]) == b[8] && (a9 & b[9]) == b[9] && + (a10 & b[10]) == b[10] && (a11 & b[11]) == b[11] && + (a12 & b[12]) == b[12] && (a13 & b[13]) == b[13] && + (a14 & b[14]) == b[14] && (a15 & b[15]) == b[15] && + (a16 & b[16]) == b[16] && (a17 & b[17]) == b[17] && + (a18 & b[18]) == b[18] && (a19 & b[19]) == b[19] && + (a20 & b[20]) == b[20] && (a21 & b[21]) == b[21] && + (a22 & b[22]) == b[22] && (a23 & b[23]) == b[23] && + (a24 & b[24]) == b[24] && (a25 & b[25]) == b[25] && + (a26 & b[26]) == b[26] && (a27 & b[27]) == b[27] && + (a28 & b[28]) == b[28] && (a29 & b[29]) == b[29] && + (a30 & b[30]) == b[30] && (a31 & b[31]) == b[31] && + (a32 & b[32]) == b[32] && (a33 & b[33]) == b[33] && + (a34 & b[34]) == b[34] && (a35 & b[35]) == b[35] && + (a36 & b[36]) == b[36] && (a37 & b[37]) == b[37] && + (a38 & b[38]) == b[38] && (a39 & b[39]) == b[39] && + (a40 & b[40]) == b[40] && (a41 & b[41]) == b[41] && + (a42 & b[42]) == b[42] && (a43 & b[43]) == b[43] && + (a44 & b[44]) == b[44] && (a45 & b[45]) == b[45] && + (a46 & b[46]) == b[46] && (a47 & b[47]) == b[47] && + (a48 & b[48]) == b[48] && (a49 & b[49]) == b[49] && + (a50 & b[50]) == b[50] && (a51 & b[51]) == b[51] && + (a52 & b[52]) == b[52] && (a53 & b[53]) == b[53] && + (a54 & b[54]) == b[54] && (a55 & b[55]) == b[55] && + (a56 & b[56]) == b[56] && (a57 & b[57]) == b[57] && + (a58 & b[58]) == b[58] && (a59 & b[59]) == b[59] && + (a60 & b[60]) == b[60] && (a61 & b[61]) == b[61] && + (a62 & b[62]) == b[62] && (a63 & b[63]) == b[63]; + } }; struct SuperstructureComputerDefault { const uint8_t *a; int n; - float accu_den; SuperstructureComputerDefault () {} @@ -354,21 +261,16 @@ namespace faiss { void set (const uint8_t *a8, int code_size) { a = a8; n = code_size; - int i_accu_den = 0; - for (int i = 0; i < n; i++) { - i_accu_den += popcount64(a[i]); - } - accu_den = (float)i_accu_den; } - float compute (const uint8_t *b8) const { - int accu_num = 0; + bool compute (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; for (int i = 0; i < n; i++) { - accu_num += popcount64(a[i] & b8[i]); + if ((a[i] & b[i]) != b[i]) { + return false; + } } - if (accu_num == 0) - return 1.0; - return 1.0 - (float)(accu_num) / accu_den; + return true; } }; diff --git a/tests/milvus_python_test/test_search_vectors.py b/tests/milvus_python_test/test_search_vectors.py index f057fdcaaf..38e91d1d3d 100644 --- a/tests/milvus_python_test/test_search_vectors.py +++ b/tests/milvus_python_test/test_search_vectors.py @@ -1,5 +1,4 @@ import pdb -import copy import struct from random import sample @@ -675,7 +674,36 @@ class TestSearchBase: status, result = connect.search_vectors(substructure_collection, top_k, query_vecs, params=search_param) logging.getLogger().info(status) logging.getLogger().info(result) - assert abs(result[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon + assert result[0][0].id == -1 + + def test_search_distance_substructure_flat_index_B(self, connect, substructure_collection): + ''' + target: search ip_collection, and check the result: distance + method: compare the return distance value with value computed with SUB + expected: the return distance equals to the computed value + ''' + # from scipy.spatial import distance + top_k = 3 + nprobe = 512 + int_vectors, vectors, ids = self.init_binary_data(connect, substructure_collection, nb=2) + index_type = IndexType.FLAT + index_param = { + "nlist": 16384 + } + connect.create_index(substructure_collection, index_type, index_param) + logging.getLogger().info(connect.describe_collection(substructure_collection)) + logging.getLogger().info(connect.describe_index(substructure_collection)) + query_int_vectors, query_vecs = gen_binary_sub_vectors(int_vectors, 2) + search_param = get_search_param(index_type) + status, result = connect.search_vectors(substructure_collection, top_k, query_vecs, params=search_param) + logging.getLogger().info(status) + logging.getLogger().info(result) + assert result[0][0].distance <= epsilon + assert result[0][0].id == ids[0] + assert result[1][0].distance <= epsilon + assert result[1][0].id == ids[1] + assert result[0][1].id == -1 + assert result[1][1].id == -1 def test_search_distance_superstructure_flat_index(self, connect, superstructure_collection): ''' @@ -701,7 +729,36 @@ class TestSearchBase: status, result = connect.search_vectors(superstructure_collection, top_k, query_vecs, params=search_param) logging.getLogger().info(status) logging.getLogger().info(result) - assert abs(result[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon + assert result[0][0].id == -1 + + def test_search_distance_superstructure_flat_index_B(self, connect, superstructure_collection): + ''' + target: search ip_collection, and check the result: distance + method: compare the return distance value with value computed with SUPER + expected: the return distance equals to the computed value + ''' + # from scipy.spatial import distance + top_k = 3 + nprobe = 512 + int_vectors, vectors, ids = self.init_binary_data(connect, superstructure_collection, nb=2) + index_type = IndexType.FLAT + index_param = { + "nlist": 16384 + } + connect.create_index(superstructure_collection, index_type, index_param) + logging.getLogger().info(connect.describe_collection(superstructure_collection)) + logging.getLogger().info(connect.describe_index(superstructure_collection)) + query_int_vectors, query_vecs = gen_binary_super_vectors(int_vectors, 2) + search_param = get_search_param(index_type) + status, result = connect.search_vectors(superstructure_collection, top_k, query_vecs, params=search_param) + logging.getLogger().info(status) + logging.getLogger().info(result) + assert result[0][0].id in ids + assert result[0][0].distance <= epsilon + assert result[1][0].id in ids + assert result[1][0].distance <= epsilon + assert result[0][2].id == -1 + assert result[1][2].id == -1 def test_search_distance_tanimoto_flat_index(self, connect, tanimoto_collection): ''' diff --git a/tests/milvus_python_test/utils.py b/tests/milvus_python_test/utils.py index ceb72e3a32..31f254670f 100644 --- a/tests/milvus_python_test/utils.py +++ b/tests/milvus_python_test/utils.py @@ -67,6 +67,33 @@ def superstructure(x, y): return 1 - np.double(np.bitwise_and(x, y).sum()) / np.count_nonzero(x) +def gen_binary_sub_vectors(vectors, length): + raw_vectors = [] + binary_vectors = [] + dim = len(vectors[0]) + for i in range(length): + raw_vector = [0 for i in range(dim)] + vector = vectors[i] + for index, j in enumerate(vector): + if j == 1: + raw_vector[index] = 1 + raw_vectors.append(raw_vector) + binary_vectors.append(bytes(np.packbits(raw_vector, axis=-1).tolist())) + return raw_vectors, binary_vectors + + +def gen_binary_super_vectors(vectors, length): + raw_vectors = [] + binary_vectors = [] + dim = len(vectors[0]) + for i in range(length): + cnt_1 = np.count_nonzero(vectors[i]) + raw_vector = [1 for i in range(dim)] + raw_vectors.append(raw_vector) + binary_vectors.append(bytes(np.packbits(raw_vector, axis=-1).tolist())) + return raw_vectors, binary_vectors + + def gen_single_vector(dim): return [[random.random() for _ in range(dim)]]