mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-03 12:29:36 +08:00
Caiyd 1965 unify metric cal (#2008)
* #1965 add metric algorithm benchmark for FAISS Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * #1965 add metric algorithm benchmark for NSG Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * #1965 add metric algorithm benchmark for ANNOY Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * #1965 add metric algorithm benchmark for HNSW Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * code opt Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * calculate average time Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * #1965 annoy/nsg/hnsw all use faiss distance algorithm Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * #1965 support AVX512/AVX2/SSE42 Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * #1965 update changelog Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * #1965 fix hnsw ip calculation error Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
This commit is contained in:
parent
a5eec9d7b4
commit
f039032c8c
@ -9,6 +9,7 @@ Please mark all change in change log and use the issue from GitHub
|
||||
- \#1929 Skip MySQL meta schema field width check
|
||||
|
||||
## Feature
|
||||
- \#1965 FAISS/NSG/HNSW/ANNOY use unified distance calculation algorithm
|
||||
|
||||
## Improvement
|
||||
- \#221 Refactor LOG macro
|
||||
|
@ -128,7 +128,7 @@ IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
distances.reserve(k);
|
||||
index_->get_nns_by_vector((const float*)p_data + i * dim, k, search_k, &result, &distances, blacklist);
|
||||
|
||||
size_t result_num = result.size();
|
||||
int64_t result_num = result.size();
|
||||
auto local_p_id = p_id + k * i;
|
||||
auto local_p_dist = p_dist + k * i;
|
||||
memcpy(local_p_id, result.data(), result_num * sizeof(int64_t));
|
||||
|
@ -9,6 +9,7 @@
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <faiss/FaissHook.h>
|
||||
#include <immintrin.h>
|
||||
|
||||
#include "knowhere/index/vector_index/impl/nsg/Distance.h"
|
||||
@ -17,6 +18,8 @@ namespace milvus {
|
||||
namespace knowhere {
|
||||
namespace impl {
|
||||
|
||||
#if 0 /* use FAISS distance calculation algorithm instead */
|
||||
|
||||
float
|
||||
DistanceL2::Compare(const float* a, const float* b, unsigned size) const {
|
||||
float result = 0;
|
||||
@ -225,16 +228,19 @@ DistanceIP::Compare(const float* a, const float* b, unsigned size) const {
|
||||
return result;
|
||||
}
|
||||
|
||||
//#include <faiss/utils/distances.h>
|
||||
// float
|
||||
// DistanceL2::Compare(const float* a, const float* b, unsigned size) const {
|
||||
// return faiss::fvec_L2sqr(a,b,size);
|
||||
//}
|
||||
//
|
||||
// float
|
||||
// DistanceIP::Compare(const float* a, const float* b, unsigned size) const {
|
||||
// return faiss::fvec_inner_product(a,b,size);
|
||||
//}
|
||||
#else
|
||||
|
||||
float
|
||||
DistanceL2::Compare(const float* a, const float* b, unsigned size) const {
|
||||
return faiss::fvec_L2sqr(a, b, (size_t)size);
|
||||
}
|
||||
|
||||
float
|
||||
DistanceIP::Compare(const float* a, const float* b, unsigned size) const {
|
||||
return faiss::fvec_inner_product(a, b, (size_t)size);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace impl
|
||||
} // namespace knowhere
|
||||
|
19
core/src/index/thirdparty/annoy/src/annoylib.h
vendored
19
core/src/index/thirdparty/annoy/src/annoylib.h
vendored
@ -125,6 +125,7 @@ inline void set_error_from_string(char **error, const char* msg) {
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#include <faiss/FaissHook.h>
|
||||
|
||||
using std::vector;
|
||||
using std::pair;
|
||||
@ -184,7 +185,7 @@ inline T euclidean_distance(const T* x, const T* y, int f) {
|
||||
return d;
|
||||
}
|
||||
|
||||
#ifdef USE_AVX
|
||||
//#ifdef USE_AVX
|
||||
// Horizontal single sum of 256bit vector.
|
||||
inline float hsum256_ps_avx(__m256 v) {
|
||||
const __m128 x128 = _mm_add_ps(_mm256_extractf128_ps(v, 1), _mm256_castps256_ps128(v));
|
||||
@ -195,6 +196,7 @@ inline float hsum256_ps_avx(__m256 v) {
|
||||
|
||||
template<>
|
||||
inline float dot<float>(const float* x, const float *y, int f) {
|
||||
#if 0 /* use FAISS distance calculation algorithm instead */
|
||||
float result = 0;
|
||||
if (f > 7) {
|
||||
__m256 d = _mm256_setzero_ps();
|
||||
@ -213,10 +215,14 @@ inline float dot<float>(const float* x, const float *y, int f) {
|
||||
y++;
|
||||
}
|
||||
return result;
|
||||
#else
|
||||
return faiss::fvec_inner_product(x, y, (size_t)f);
|
||||
#endif
|
||||
}
|
||||
|
||||
template<>
|
||||
inline float manhattan_distance<float>(const float* x, const float* y, int f) {
|
||||
#if 0 /* use FAISS distance calculation algorithm instead */
|
||||
float result = 0;
|
||||
int i = f;
|
||||
if (f > 7) {
|
||||
@ -239,10 +245,14 @@ inline float manhattan_distance<float>(const float* x, const float* y, int f) {
|
||||
y++;
|
||||
}
|
||||
return result;
|
||||
#else
|
||||
return faiss::fvec_L1(x, y, (size_t)f);
|
||||
#endif
|
||||
}
|
||||
|
||||
template<>
|
||||
inline float euclidean_distance<float>(const float* x, const float* y, int f) {
|
||||
#if 0 /* use FAISS distance calculation algorithm instead */
|
||||
float result=0;
|
||||
if (f > 7) {
|
||||
__m256 d = _mm256_setzero_ps();
|
||||
@ -263,10 +273,14 @@ inline float euclidean_distance<float>(const float* x, const float* y, int f) {
|
||||
y++;
|
||||
}
|
||||
return result;
|
||||
#else
|
||||
return faiss::fvec_L2sqr(x, y, (size_t)f);
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif
|
||||
//#endif
|
||||
|
||||
#if 0 /* use FAISS distance calculation algorithm instead */
|
||||
#ifdef USE_AVX512
|
||||
template<>
|
||||
inline float dot<float>(const float* x, const float *y, int f) {
|
||||
@ -340,6 +354,7 @@ inline float euclidean_distance<float>(const float* x, const float* y, int f) {
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
|
||||
template<typename T>
|
||||
|
14
core/src/index/thirdparty/faiss/FaissHook.cpp
vendored
14
core/src/index/thirdparty/faiss/FaissHook.cpp
vendored
@ -38,14 +38,14 @@ bool support_avx512() {
|
||||
instruction_set_inst.AVX512BW());
|
||||
}
|
||||
|
||||
bool support_avx() {
|
||||
bool support_avx2() {
|
||||
InstructionSet& instruction_set_inst = InstructionSet::GetInstance();
|
||||
return (instruction_set_inst.AVX2());
|
||||
}
|
||||
|
||||
bool support_sse() {
|
||||
bool support_sse42() {
|
||||
InstructionSet& instruction_set_inst = InstructionSet::GetInstance();
|
||||
return (instruction_set_inst.SSE());
|
||||
return (instruction_set_inst.SSE42());
|
||||
}
|
||||
|
||||
bool hook_init(std::string& cpu_flag) {
|
||||
@ -65,7 +65,7 @@ bool hook_init(std::string& cpu_flag) {
|
||||
sq_sel_quantizer = sq_select_quantizer_avx512;
|
||||
|
||||
cpu_flag = "AVX512";
|
||||
} else if (support_avx()) {
|
||||
} else if (support_avx2()) {
|
||||
/* for IVFFLAT */
|
||||
fvec_inner_product = fvec_inner_product_avx;
|
||||
fvec_L2sqr = fvec_L2sqr_avx;
|
||||
@ -77,8 +77,8 @@ bool hook_init(std::string& cpu_flag) {
|
||||
sq_get_distance_computer_IP = sq_get_distance_computer_IP_avx;
|
||||
sq_sel_quantizer = sq_select_quantizer_avx;
|
||||
|
||||
cpu_flag = "AVX";
|
||||
} else if (support_sse()) {
|
||||
cpu_flag = "AVX2";
|
||||
} else if (support_sse42()) {
|
||||
/* for IVFFLAT */
|
||||
fvec_inner_product = fvec_inner_product_sse;
|
||||
fvec_L2sqr = fvec_L2sqr_sse;
|
||||
@ -90,7 +90,7 @@ bool hook_init(std::string& cpu_flag) {
|
||||
sq_get_distance_computer_IP = sq_get_distance_computer_IP_sse;
|
||||
sq_sel_quantizer = sq_select_quantizer_sse;
|
||||
|
||||
cpu_flag = "SSE";
|
||||
cpu_flag = "SSE42";
|
||||
} else {
|
||||
cpu_flag = "UNSUPPORTED";
|
||||
return false;
|
||||
|
10
core/src/index/thirdparty/hnswlib/space_ip.h
vendored
10
core/src/index/thirdparty/hnswlib/space_ip.h
vendored
@ -1,19 +1,24 @@
|
||||
#pragma once
|
||||
#include "hnswlib.h"
|
||||
#include <faiss/FaissHook.h>
|
||||
|
||||
namespace hnswlib {
|
||||
|
||||
static float
|
||||
InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) {
|
||||
#if 0 /* use FAISS distance calculation algorithm instead */
|
||||
size_t qty = *((size_t *) qty_ptr);
|
||||
float res = 0;
|
||||
for (unsigned i = 0; i < qty; i++) {
|
||||
res += ((float *) pVect1)[i] * ((float *) pVect2)[i];
|
||||
}
|
||||
return (1.0f - res);
|
||||
|
||||
#else
|
||||
return (1.0f - faiss::fvec_inner_product((const float*)pVect1, (const float*)pVect2, *((size_t*)qty_ptr)));
|
||||
#endif
|
||||
}
|
||||
|
||||
#if 0 /* use FAISS distance calculation algorithm instead */
|
||||
#if defined(USE_AVX)
|
||||
|
||||
// Favor using AVX if available.
|
||||
@ -209,6 +214,7 @@ InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_
|
||||
return 1.0f - sum;
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
class InnerProductSpace : public SpaceInterface<float> {
|
||||
@ -218,11 +224,13 @@ class InnerProductSpace : public SpaceInterface<float> {
|
||||
public:
|
||||
InnerProductSpace(size_t dim) {
|
||||
fstdistfunc_ = InnerProduct;
|
||||
#if 0 /* use FAISS distance calculation algorithm instead */
|
||||
#if defined(USE_AVX) || defined(USE_SSE)
|
||||
if (dim % 4 == 0)
|
||||
fstdistfunc_ = InnerProductSIMD4Ext;
|
||||
if (dim % 16 == 0)
|
||||
fstdistfunc_ = InnerProductSIMD16Ext;
|
||||
#endif
|
||||
#endif
|
||||
dim_ = dim;
|
||||
data_size_ = dim * sizeof(float);
|
||||
|
9
core/src/index/thirdparty/hnswlib/space_l2.h
vendored
9
core/src/index/thirdparty/hnswlib/space_l2.h
vendored
@ -1,10 +1,12 @@
|
||||
#pragma once
|
||||
#include "hnswlib.h"
|
||||
#include <faiss/FaissHook.h>
|
||||
|
||||
namespace hnswlib {
|
||||
|
||||
static float
|
||||
L2Sqr(const void *pVect1, const void *pVect2, const void *qty_ptr) {
|
||||
#if 0 /* use FAISS distance calculation algorithm instead */
|
||||
//return *((float *)pVect2);
|
||||
size_t qty = *((size_t *) qty_ptr);
|
||||
float res = 0;
|
||||
@ -13,8 +15,12 @@ L2Sqr(const void *pVect1, const void *pVect2, const void *qty_ptr) {
|
||||
res += t * t;
|
||||
}
|
||||
return (res);
|
||||
#else
|
||||
return faiss::fvec_L2sqr((const float*)pVect1, (const float*)pVect2, *((size_t*)qty_ptr));
|
||||
#endif
|
||||
}
|
||||
|
||||
#if 0 /* use FAISS distance calculation algorithm instead */
|
||||
#if defined(USE_AVX)
|
||||
|
||||
// Favor using AVX if available.
|
||||
@ -140,6 +146,7 @@ L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
|
||||
return (res);
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
class L2Space : public SpaceInterface<float> {
|
||||
DISTFUNC<float> fstdistfunc_;
|
||||
@ -148,6 +155,7 @@ class L2Space : public SpaceInterface<float> {
|
||||
public:
|
||||
L2Space(size_t dim) {
|
||||
fstdistfunc_ = L2Sqr;
|
||||
#if 0 /* use FAISS distance calculation algorithm instead */
|
||||
#if defined(USE_SSE) || defined(USE_AVX)
|
||||
if (dim % 4 == 0)
|
||||
fstdistfunc_ = L2SqrSIMD4Ext;
|
||||
@ -156,6 +164,7 @@ class L2Space : public SpaceInterface<float> {
|
||||
/*else{
|
||||
throw runtime_error("Data type not supported!");
|
||||
}*/
|
||||
#endif
|
||||
#endif
|
||||
dim_ = dim;
|
||||
data_size_ = dim * sizeof(float);
|
||||
|
@ -198,4 +198,4 @@ install(TARGETS test_annoy DESTINATION unittest)
|
||||
|
||||
#add_subdirectory(faiss_ori)
|
||||
#add_subdirectory(faiss_benchmark)
|
||||
|
||||
#add_subdirectory(metric_alg_benchmark)
|
||||
|
17
core/src/index/unittest/metric_alg_benchmark/CMakeLists.txt
Normal file
17
core/src/index/unittest/metric_alg_benchmark/CMakeLists.txt
Normal file
@ -0,0 +1,17 @@
|
||||
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
set(unittest_libs
|
||||
gtest gmock gtest_main gmock_main)
|
||||
|
||||
add_executable(test_metric_benchmark metric_benchmark_test.cpp)
|
||||
target_link_libraries(test_metric_benchmark ${unittest_libs})
|
||||
install(TARGETS test_metric_benchmark DESTINATION unittest)
|
@ -0,0 +1,415 @@
|
||||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <immintrin.h>
|
||||
#include <cassert>
|
||||
#include <chrono>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
typedef float (*metric_func_ptr)(const float*, const float*, size_t);
|
||||
|
||||
constexpr int64_t DIM = 512;
|
||||
constexpr int64_t NB = 10000;
|
||||
constexpr int64_t NQ = 5;
|
||||
constexpr int64_t LOOP = 5;
|
||||
|
||||
// Fill `x` with n * dim values taken from successive drand48() calls.
// Values are written contiguously, vector by vector (element j of vector i
// lands at x[i * dim + j]); the caller must provide room for n * dim floats.
void
GenerateData(const int64_t dim, const int64_t n, float* x) {
    float* cursor = x;
    for (int64_t row = 0; row < n; ++row) {
        for (int64_t col = 0; col < dim; ++col) {
            *cursor++ = drand48();
        }
    }
}
|
||||
|
||||
void
|
||||
TestMetricAlg(std::unordered_map<std::string, metric_func_ptr>& func_map, const std::string& key, int64_t loop,
|
||||
float* distance, const int64_t nb, const float* xb, const int64_t nq, const float* xq,
|
||||
const int64_t dim) {
|
||||
int64_t diff = 0;
|
||||
for (int64_t i = 0; i < loop; i++) {
|
||||
auto t0 = std::chrono::system_clock::now();
|
||||
for (int64_t i = 0; i < nb; i++) {
|
||||
for (int64_t j = 0; j < nq; j++) {
|
||||
distance[i * NQ + j] = func_map[key](xb + i * dim, xq + j * dim, dim);
|
||||
}
|
||||
}
|
||||
auto t1 = std::chrono::system_clock::now();
|
||||
diff += std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0).count();
|
||||
}
|
||||
std::cout << key << " takes average " << diff / loop << "ms" << std::endl;
|
||||
}
|
||||
|
||||
void
|
||||
CheckResult(const float* result1, const float* result2, const size_t size) {
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
ASSERT_FLOAT_EQ(result1[i], result2[i]);
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/* from faiss/utils/distances_simd.cpp */
|
||||
namespace FAISS {
|
||||
// Reads the first d floats of x (0 <= d < 4) into the low lanes of an __m128;
// all remaining lanes are zero. Only x[0] .. x[d - 1] is dereferenced.
// (cannot use AVX2 _mm_mask_set1_epi32)
static inline __m128
masked_read(int d, const float* x) {
    assert(0 <= d && d < 4);
    alignas(16) float lanes[4] = {0, 0, 0, 0};
    // Copy from the highest requested lane down to lane 0. A d outside 1..3
    // (reachable only when asserts are compiled out) copies nothing.
    if (d >= 1 && d <= 3) {
        for (int lane = d - 1; lane >= 0; --lane) {
            lanes[lane] = x[lane];
        }
    }
    return _mm_load_ps(lanes);
}
|
||||
|
||||
static inline __m256
|
||||
masked_read_8(int d, const float* x) {
|
||||
assert(0 <= d && d < 8);
|
||||
if (d < 4) {
|
||||
__m256 res = _mm256_setzero_ps();
|
||||
res = _mm256_insertf128_ps(res, masked_read(d, x), 0);
|
||||
return res;
|
||||
} else {
|
||||
__m256 res = _mm256_setzero_ps();
|
||||
res = _mm256_insertf128_ps(res, _mm_loadu_ps(x), 0);
|
||||
res = _mm256_insertf128_ps(res, masked_read(d - 4, x + 4), 1);
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
float
|
||||
fvec_inner_product_avx(const float* x, const float* y, size_t d) {
|
||||
__m256 msum1 = _mm256_setzero_ps();
|
||||
|
||||
while (d >= 8) {
|
||||
__m256 mx = _mm256_loadu_ps(x);
|
||||
x += 8;
|
||||
__m256 my = _mm256_loadu_ps(y);
|
||||
y += 8;
|
||||
msum1 = _mm256_add_ps(msum1, _mm256_mul_ps(mx, my));
|
||||
d -= 8;
|
||||
}
|
||||
|
||||
__m128 msum2 = _mm256_extractf128_ps(msum1, 1);
|
||||
msum2 += _mm256_extractf128_ps(msum1, 0);
|
||||
|
||||
if (d >= 4) {
|
||||
__m128 mx = _mm_loadu_ps(x);
|
||||
x += 4;
|
||||
__m128 my = _mm_loadu_ps(y);
|
||||
y += 4;
|
||||
msum2 = _mm_add_ps(msum2, _mm_mul_ps(mx, my));
|
||||
d -= 4;
|
||||
}
|
||||
|
||||
if (d > 0) {
|
||||
__m128 mx = masked_read(d, x);
|
||||
__m128 my = masked_read(d, y);
|
||||
msum2 = _mm_add_ps(msum2, _mm_mul_ps(mx, my));
|
||||
}
|
||||
|
||||
msum2 = _mm_hadd_ps(msum2, msum2);
|
||||
msum2 = _mm_hadd_ps(msum2, msum2);
|
||||
return _mm_cvtss_f32(msum2);
|
||||
}
|
||||
|
||||
float
|
||||
fvec_L2sqr_avx(const float* x, const float* y, size_t d) {
|
||||
__m256 msum1 = _mm256_setzero_ps();
|
||||
|
||||
while (d >= 8) {
|
||||
__m256 mx = _mm256_loadu_ps(x);
|
||||
x += 8;
|
||||
__m256 my = _mm256_loadu_ps(y);
|
||||
y += 8;
|
||||
const __m256 a_m_b1 = mx - my;
|
||||
msum1 += a_m_b1 * a_m_b1;
|
||||
d -= 8;
|
||||
}
|
||||
|
||||
__m128 msum2 = _mm256_extractf128_ps(msum1, 1);
|
||||
msum2 += _mm256_extractf128_ps(msum1, 0);
|
||||
|
||||
if (d >= 4) {
|
||||
__m128 mx = _mm_loadu_ps(x);
|
||||
x += 4;
|
||||
__m128 my = _mm_loadu_ps(y);
|
||||
y += 4;
|
||||
const __m128 a_m_b1 = mx - my;
|
||||
msum2 += a_m_b1 * a_m_b1;
|
||||
d -= 4;
|
||||
}
|
||||
|
||||
if (d > 0) {
|
||||
__m128 mx = masked_read(d, x);
|
||||
__m128 my = masked_read(d, y);
|
||||
__m128 a_m_b1 = mx - my;
|
||||
msum2 += a_m_b1 * a_m_b1;
|
||||
}
|
||||
|
||||
msum2 = _mm_hadd_ps(msum2, msum2);
|
||||
msum2 = _mm_hadd_ps(msum2, msum2);
|
||||
return _mm_cvtss_f32(msum2);
|
||||
}
|
||||
} // namespace FAISS
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/* from knowhere/index/vector_index/impl/nsg/Distance.cpp */
|
||||
namespace NSG {
|
||||
float
|
||||
DistanceL2_Compare(const float* a, const float* b, size_t size) {
|
||||
float result = 0;
|
||||
|
||||
#define AVX_L2SQR(addr1, addr2, dest, tmp1, tmp2) \
|
||||
tmp1 = _mm256_loadu_ps(addr1); \
|
||||
tmp2 = _mm256_loadu_ps(addr2); \
|
||||
tmp1 = _mm256_sub_ps(tmp1, tmp2); \
|
||||
tmp1 = _mm256_mul_ps(tmp1, tmp1); \
|
||||
dest = _mm256_add_ps(dest, tmp1);
|
||||
|
||||
__m256 sum;
|
||||
__m256 l0, l1;
|
||||
__m256 r0, r1;
|
||||
unsigned D = (size + 7) & ~7U;
|
||||
unsigned DR = D % 16;
|
||||
unsigned DD = D - DR;
|
||||
const float* l = a;
|
||||
const float* r = b;
|
||||
const float* e_l = l + DD;
|
||||
const float* e_r = r + DD;
|
||||
float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0};
|
||||
|
||||
sum = _mm256_loadu_ps(unpack);
|
||||
if (DR) {
|
||||
AVX_L2SQR(e_l, e_r, sum, l0, r0);
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) {
|
||||
AVX_L2SQR(l, r, sum, l0, r0);
|
||||
AVX_L2SQR(l + 8, r + 8, sum, l1, r1);
|
||||
}
|
||||
_mm256_storeu_ps(unpack, sum);
|
||||
result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + unpack[5] + unpack[6] + unpack[7];
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
float
|
||||
DistanceIP_Compare(const float* a, const float* b, size_t size) {
|
||||
float result = 0;
|
||||
|
||||
#define AVX_DOT(addr1, addr2, dest, tmp1, tmp2) \
|
||||
tmp1 = _mm256_loadu_ps(addr1); \
|
||||
tmp2 = _mm256_loadu_ps(addr2); \
|
||||
tmp1 = _mm256_mul_ps(tmp1, tmp2); \
|
||||
dest = _mm256_add_ps(dest, tmp1);
|
||||
|
||||
__m256 sum;
|
||||
__m256 l0, l1;
|
||||
__m256 r0, r1;
|
||||
unsigned D = (size + 7) & ~7U;
|
||||
unsigned DR = D % 16;
|
||||
unsigned DD = D - DR;
|
||||
const float* l = a;
|
||||
const float* r = b;
|
||||
const float* e_l = l + DD;
|
||||
const float* e_r = r + DD;
|
||||
float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0};
|
||||
|
||||
sum = _mm256_loadu_ps(unpack);
|
||||
if (DR) {
|
||||
AVX_DOT(e_l, e_r, sum, l0, r0);
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) {
|
||||
AVX_DOT(l, r, sum, l0, r0);
|
||||
AVX_DOT(l + 8, r + 8, sum, l1, r1);
|
||||
}
|
||||
_mm256_storeu_ps(unpack, sum);
|
||||
result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + unpack[5] + unpack[6] + unpack[7];
|
||||
|
||||
return result;
|
||||
}
|
||||
} // namespace NSG
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/* from index/thirdparty/annoy/src/annoylib.h */
|
||||
namespace ANNOY {
|
||||
inline float
|
||||
hsum256_ps_avx(__m256 v) {
|
||||
const __m128 x128 = _mm_add_ps(_mm256_extractf128_ps(v, 1), _mm256_castps256_ps128(v));
|
||||
const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128));
|
||||
const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
|
||||
return _mm_cvtss_f32(x32);
|
||||
}
|
||||
|
||||
inline float
|
||||
euclidean_distance(const float* x, const float* y, size_t f) {
|
||||
float result = 0;
|
||||
if (f > 7) {
|
||||
__m256 d = _mm256_setzero_ps();
|
||||
for (; f > 7; f -= 8) {
|
||||
const __m256 diff = _mm256_sub_ps(_mm256_loadu_ps(x), _mm256_loadu_ps(y));
|
||||
d = _mm256_add_ps(d, _mm256_mul_ps(diff, diff)); // no support for fmadd in AVX...
|
||||
x += 8;
|
||||
y += 8;
|
||||
}
|
||||
// Sum all floats in dot register.
|
||||
result = hsum256_ps_avx(d);
|
||||
}
|
||||
// Don't forget the remaining values.
|
||||
for (; f > 0; f--) {
|
||||
float tmp = *x - *y;
|
||||
result += tmp * tmp;
|
||||
x++;
|
||||
y++;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
inline float
|
||||
dot(const float* x, const float* y, size_t f) {
|
||||
float result = 0;
|
||||
if (f > 7) {
|
||||
__m256 d = _mm256_setzero_ps();
|
||||
for (; f > 7; f -= 8) {
|
||||
d = _mm256_add_ps(d, _mm256_mul_ps(_mm256_loadu_ps(x), _mm256_loadu_ps(y)));
|
||||
x += 8;
|
||||
y += 8;
|
||||
}
|
||||
// Sum all floats in dot register.
|
||||
result += hsum256_ps_avx(d);
|
||||
}
|
||||
// Don't forget the remaining values.
|
||||
for (; f > 0; f--) {
|
||||
result += *x * *y;
|
||||
x++;
|
||||
y++;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
} // namespace ANNOY
|
||||
|
||||
namespace HNSW {
|
||||
#define PORTABLE_ALIGN32 __attribute__((aligned(32)))
|
||||
|
||||
static float
|
||||
L2SqrSIMD16Ext(const float* pVect1v, const float* pVect2v, size_t qty) {
|
||||
float* pVect1 = (float*)pVect1v;
|
||||
float* pVect2 = (float*)pVect2v;
|
||||
// size_t qty = *((size_t *) qty_ptr);
|
||||
float PORTABLE_ALIGN32 TmpRes[8];
|
||||
size_t qty16 = qty >> 4;
|
||||
|
||||
const float* pEnd1 = pVect1 + (qty16 << 4);
|
||||
|
||||
__m256 diff, v1, v2;
|
||||
__m256 sum = _mm256_set1_ps(0);
|
||||
|
||||
while (pVect1 < pEnd1) {
|
||||
v1 = _mm256_loadu_ps(pVect1);
|
||||
pVect1 += 8;
|
||||
v2 = _mm256_loadu_ps(pVect2);
|
||||
pVect2 += 8;
|
||||
diff = _mm256_sub_ps(v1, v2);
|
||||
sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
|
||||
|
||||
v1 = _mm256_loadu_ps(pVect1);
|
||||
pVect1 += 8;
|
||||
v2 = _mm256_loadu_ps(pVect2);
|
||||
pVect2 += 8;
|
||||
diff = _mm256_sub_ps(v1, v2);
|
||||
sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
|
||||
}
|
||||
|
||||
_mm256_store_ps(TmpRes, sum);
|
||||
float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
|
||||
|
||||
return (res);
|
||||
}
|
||||
|
||||
static float
|
||||
InnerProductSIMD16Ext(const float* pVect1v, const float* pVect2v, size_t qty) {
|
||||
float PORTABLE_ALIGN32 TmpRes[8];
|
||||
float* pVect1 = (float*)pVect1v;
|
||||
float* pVect2 = (float*)pVect2v;
|
||||
// size_t qty = *((size_t *) qty_ptr);
|
||||
|
||||
size_t qty16 = qty / 16;
|
||||
|
||||
const float* pEnd1 = pVect1 + 16 * qty16;
|
||||
|
||||
__m256 sum256 = _mm256_set1_ps(0);
|
||||
|
||||
while (pVect1 < pEnd1) {
|
||||
//_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
|
||||
|
||||
__m256 v1 = _mm256_loadu_ps(pVect1);
|
||||
pVect1 += 8;
|
||||
__m256 v2 = _mm256_loadu_ps(pVect2);
|
||||
pVect2 += 8;
|
||||
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));
|
||||
|
||||
v1 = _mm256_loadu_ps(pVect1);
|
||||
pVect1 += 8;
|
||||
v2 = _mm256_loadu_ps(pVect2);
|
||||
pVect2 += 8;
|
||||
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));
|
||||
}
|
||||
|
||||
_mm256_store_ps(TmpRes, sum256);
|
||||
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
|
||||
|
||||
return sum;
|
||||
}
|
||||
} // namespace HNSW
|
||||
|
||||
// Benchmarks the L2 and inner-product kernels copied from FAISS, NSG, HNSW and
// ANNOY on the same random data, and checks that every implementation agrees
// with the FAISS result (within ASSERT_FLOAT_EQ tolerance).
TEST(METRICTEST, BENCHMARK) {
    // The NSG kernels consume 8 floats per step and the HNSW kernels ignore
    // any remainder below 16 floats, so results are only comparable when DIM
    // is a multiple of 16.
    static_assert(DIM % 16 == 0, "NSG/HNSW kernels require DIM to be a multiple of 16");

    std::unordered_map<std::string, metric_func_ptr> func_map;
    func_map["FAISS::L2"] = FAISS::fvec_L2sqr_avx;
    func_map["NSG::L2"] = NSG::DistanceL2_Compare;
    func_map["HNSW::L2"] = HNSW::L2SqrSIMD16Ext;
    func_map["ANNOY::L2"] = ANNOY::euclidean_distance;

    func_map["FAISS::IP"] = FAISS::fvec_inner_product_avx;
    func_map["NSG::IP"] = NSG::DistanceIP_Compare;
    func_map["HNSW::IP"] = HNSW::InnerProductSIMD16Ext;
    func_map["ANNOY::IP"] = ANNOY::dot;

    std::vector<float> xb(NB * DIM);
    std::vector<float> xq(NQ * DIM);
    GenerateData(DIM, NB, xb.data());
    GenerateData(DIM, NQ, xq.data());

    std::vector<float> distance_faiss(NB * NQ);
    std::vector<float> distance_nsg(NB * NQ);
    std::vector<float> distance_annoy(NB * NQ);
    std::vector<float> distance_hnsw(NB * NQ);

    const std::string metrics[] = {"L2", "IP"};
    for (const auto& metric : metrics) {
        std::cout << "==========" << std::endl;

        // FAISS is the reference every other implementation is compared against.
        TestMetricAlg(func_map, "FAISS::" + metric, LOOP, distance_faiss.data(), NB, xb.data(), NQ, xq.data(), DIM);

        TestMetricAlg(func_map, "NSG::" + metric, LOOP, distance_nsg.data(), NB, xb.data(), NQ, xq.data(), DIM);
        CheckResult(distance_faiss.data(), distance_nsg.data(), NB * NQ);

        TestMetricAlg(func_map, "HNSW::" + metric, LOOP, distance_hnsw.data(), NB, xb.data(), NQ, xq.data(), DIM);
        CheckResult(distance_faiss.data(), distance_hnsw.data(), NB * NQ);

        TestMetricAlg(func_map, "ANNOY::" + metric, LOOP, distance_annoy.data(), NB, xb.data(), NQ, xq.data(), DIM);
        CheckResult(distance_faiss.data(), distance_annoy.data(), NB * NQ);
    }
}
|
Loading…
Reference in New Issue
Block a user