k-means L2 (#2258)

* k-means L2

Signed-off-by: shengjun.li <shengjun.li@zilliz.com>

* fix change log

Signed-off-by: shengjun.li <shengjun.li@zilliz.com>

Co-authored-by: Jin Hai <hai.jin@zilliz.com>
This commit is contained in:
shengjun.li 2020-05-08 22:03:50 +08:00 committed by GitHub
parent cf68c9918e
commit 4cea320943
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 122 additions and 78 deletions

View File

@ -53,6 +53,7 @@ Please mark all change in change log and use the issue from GitHub
- \#2190 Fix memory usage is twice of index size when using GPU searching
- \#2248 Use hostname and port as instance label of metrics
- \#2252 Upgrade mishards APIs and requirements
- \#2256 k-means clustering algorithm use only Euclidean distance metric
## Task

View File

@ -191,7 +191,7 @@ void Clustering::train (idx_t nx, const float *x_in, Index & index) {
float err = 0;
for (int i = 0; i < niter; i++) {
double t0s = getmillisecs();
index.search (nx, x, 1, dis, assign);
index.assign(nx, x, assign, dis);
InterruptCallback::check();
t_search_tot += getmillisecs() - t0s;

View File

@ -36,11 +36,13 @@ void Index::range_search (idx_t , const float *, float,
FAISS_THROW_MSG ("range search not implemented");
}
void Index::assign (idx_t n, const float * x, idx_t * labels, idx_t k)
void Index::assign (idx_t n, const float *x, idx_t *labels, float *distance)
{
float * distances = new float[n * k];
ScopeDeleter<float> del(distances);
search (n, x, k, distances, labels);
float *dis_inner = (distance == nullptr) ? new float[n] : distance;
search (n, x, 1, dis_inner, labels);
if (distance == nullptr) {
delete[] dis_inner;
}
}
void Index::add_with_ids(idx_t n, const float* x, const idx_t* xids) {

View File

@ -183,9 +183,9 @@ struct Index {
*
* This function is identical as search but only return labels of neighbors.
* @param x input vectors to search, size n * d
* @param labels output labels of the NNs, size n*k
* @param labels output labels of the NNs, size n
*/
void assign (idx_t n, const float * x, idx_t * labels, idx_t k = 1);
virtual void assign (idx_t n, const float *x, idx_t *labels, float *distance = nullptr);
/// removes all elements from the database.
virtual void reset() = 0;

View File

@ -64,6 +64,30 @@ void IndexFlat::search(idx_t n, const float* x, idx_t k, float* distances, idx_t
}
}
void IndexFlat::assign(idx_t n, const float * x, idx_t * labels, float* distances)
{
// usually used in IVF k-means algorithm
float *dis_inner = (distances == nullptr) ? new float[n] : distances;
switch (metric_type) {
case METRIC_INNER_PRODUCT:
case METRIC_L2: {
// ignore the metric_type, both use L2
elkan_L2_sse(x, xb.data(), d, n, ntotal, labels, dis_inner);
break;
}
default: {
// binary metrics
// There may be something wrong, but maintain the original logic now.
Index::assign(n, x, labels, dis_inner);
break;
}
}
if (distances == nullptr) {
delete[] dis_inner;
}
}
void IndexFlat::range_search (idx_t n, const float *x, float radius,
RangeSearchResult *result,
ConcurrentBitsetPtr bitset) const

View File

@ -36,6 +36,12 @@ struct IndexFlat: Index {
idx_t* labels,
ConcurrentBitsetPtr bitset = nullptr) const override;
void assign (
idx_t n,
const float * x,
idx_t * labels,
float* distances = nullptr) override;
void range_search(
idx_t n,
const float* x,

View File

@ -57,9 +57,9 @@ int faiss_Index_range_search(const FaissIndex* index, idx_t n, const float* x, f
} CATCH_AND_HANDLE
}
int faiss_Index_assign(FaissIndex* index, idx_t n, const float * x, idx_t * labels, idx_t k) {
int faiss_Index_assign(FaissIndex* index, idx_t n, const float * x, idx_t * labels) {
try {
reinterpret_cast<faiss::Index*>(index)->assign(n, x, labels, k);
reinterpret_cast<faiss::Index*>(index)->assign(n, x, labels);
} CATCH_AND_HANDLE
}

View File

@ -106,9 +106,9 @@ int faiss_Index_range_search(const FaissIndex* index, idx_t n, const float* x,
* This function is identical as search but only return labels of neighbors.
* @param index opaque pointer to index object
* @param x input vectors to search, size n * d
* @param labels output labels of the NNs, size n*k
* @param labels output labels of the NNs, size n
*/
int faiss_Index_assign(FaissIndex* index, idx_t n, const float * x, idx_t * labels, idx_t k);
int faiss_Index_assign(FaissIndex* index, idx_t n, const float * x, idx_t * labels);
/** removes all elements from the database.
* @param index opaque pointer to index object

View File

@ -352,68 +352,6 @@ static void knn_L2sqr_sse (
*/
}
static void elkan_L2_sse (
const float * x,
const float * y,
size_t d, size_t nx, size_t ny,
float_maxheap_array_t * res) {
if (nx == 0 || ny == 0) {
return;
}
const size_t bs_y = 1024;
float *data = (float *) malloc((bs_y * (bs_y - 1) / 2) * sizeof (float));
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
size_t j1 = j0 + bs_y;
if (j1 > ny) j1 = ny;
auto Y = [&](size_t i, size_t j) -> float& {
assert(i != j);
i -= j0, j -= j0;
return (i > j) ? data[j + i * (i - 1) / 2] : data[i + j * (j - 1) / 2];
};
#pragma omp parallel for
for (size_t i = j0 + 1; i < j1; i++) {
const float *y_i = y + i * d;
for (size_t j = j0; j < i; j++) {
const float *y_j = y + j * d;
Y(i, j) = sqrt(fvec_L2sqr(y_i, y_j, d));
}
}
#pragma omp parallel for
for (size_t i = 0; i < nx; i++) {
const float *x_i = x + i * d;
int64_t ids_i = j0;
float val_i = sqrt(fvec_L2sqr(x_i, y + j0 * d, d));
float val_i_2 = val_i * 2;
for (size_t j = j0 + 1; j < j1; j++) {
if (val_i_2 <= Y(ids_i, j)) {
continue;
}
const float *y_j = y + j * d;
float disij = sqrt(fvec_L2sqr(x_i, y_j, d));
if (disij < val_i) {
ids_i = j;
val_i = disij;
val_i_2 = val_i * 2;
}
}
if (j0 == 0 || res->val[i] > val_i) {
res->val[i] = val_i;
res->ids[i] = ids_i;
}
}
}
free(data);
}
/** Find the nearest neighbors for nx queries in a set of ny vectors */
static void knn_inner_product_blas (
const float * x,
@ -668,11 +606,7 @@ void knn_L2sqr (const float * x,
float_maxheap_array_t * res,
ConcurrentBitsetPtr bitset)
{
if (bitset == nullptr && res->k == 1 && nx >= ny * 2) {
// Note: L2 but not L2sqr
// usually used in IVF::train
elkan_L2_sse(x, y, d, nx, ny, res);
} else if (d % 4 == 0 && nx < distance_compute_blas_threshold) {
if (d % 4 == 0 && nx < distance_compute_blas_threshold) {
knn_L2sqr_sse (x, y, d, nx, ny, res, bitset);
} else {
NopDistanceCorrection nop;
@ -1067,5 +1001,67 @@ void pairwise_L2sqr (int64_t d,
}
void elkan_L2_sse (
const float * x,
const float * y,
size_t d, size_t nx, size_t ny,
int64_t *ids, float *val) {
if (nx == 0 || ny == 0) {
return;
}
const size_t bs_y = 1024;
float *data = (float *) malloc((bs_y * (bs_y - 1) / 2) * sizeof (float));
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
size_t j1 = j0 + bs_y;
if (j1 > ny) j1 = ny;
auto Y = [&](size_t i, size_t j) -> float& {
assert(i != j);
i -= j0, j -= j0;
return (i > j) ? data[j + i * (i - 1) / 2] : data[i + j * (j - 1) / 2];
};
#pragma omp parallel for
for (size_t i = j0 + 1; i < j1; i++) {
const float *y_i = y + i * d;
for (size_t j = j0; j < i; j++) {
const float *y_j = y + j * d;
Y(i, j) = sqrt(fvec_L2sqr(y_i, y_j, d));
}
}
#pragma omp parallel for
for (size_t i = 0; i < nx; i++) {
const float *x_i = x + i * d;
int64_t ids_i = j0;
float val_i = sqrt(fvec_L2sqr(x_i, y + j0 * d, d));
float val_i_2 = val_i * 2;
for (size_t j = j0 + 1; j < j1; j++) {
if (val_i_2 <= Y(ids_i, j)) {
continue;
}
const float *y_j = y + j * d;
float disij = sqrt(fvec_L2sqr(x_i, y_j, d));
if (disij < val_i) {
ids_i = j;
val_i = disij;
val_i_2 = val_i * 2;
}
}
if (j0 == 0 || val[i] > val_i) {
val[i] = val_i;
ids[i] = ids_i;
}
}
}
free(data);
}
} // namespace faiss

View File

@ -247,6 +247,21 @@ void range_search_inner_product (
RangeSearchResult *result);
/***************************************************************************
* elkan
***************************************************************************/
/** Return the nearest neighors of each of the nx vectors x among the ny
*
* @param x query vectors, size nx * d
* @param y database vectors, size ny * d
* @param ids result array ids
* @param val result array value
*/
void elkan_L2_sse (
const float * x,
const float * y,
size_t d, size_t nx, size_t ny,
int64_t *ids, float *val);
} // namespace faiss