Fix wrong IP distances (#17590)

Signed-off-by: longjiquan <jiquan.long@zilliz.com>
This commit is contained in:
Jiquan Long 2022-06-16 16:34:11 +08:00 committed by GitHub
parent 0c2970f916
commit 2f66531fdf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 0 deletions

View File

@ -152,6 +152,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
if err != nil {
return errors.New(MetricTypeKey + " not found in search_params")
}
t.SearchRequest.MetricType = metricType
searchParams, err := funcutil.GetAttrByKeyFromRepeatedKV(SearchParamsKey, t.request.SearchParams)
if err != nil {

View File

@ -2521,6 +2521,36 @@ class TestSearchBase(TestcaseBase):
default_search_exp)
assert len(res[0]) <= top_k
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.parametrize("dim", [2, 8, 128, 768])
@pytest.mark.parametrize("nb", [1, 2, 10, 100])
def test_search_ip_brute_force(self, nb, dim):
"""
target: https://github.com/milvus-io/milvus/issues/17378. Ensure the logic of IP distances won't be changed.
method: search with the given vectors, check the result
expected: The inner product of vector themselves should be positive.
"""
top_k = 1
# 1. initialize with data
collection_w, insert_entities, _, insert_ids, _ = self.init_collection_general(prefix, True, nb,
is_binary=False,
dim=dim)[0:5]
insert_vectors = insert_entities[0][default_search_field].tolist()
# 2. load collection.
collection_w.load()
# 3. search and then check if the distances are expected.
res, _ = collection_w.search(insert_vectors[:nb], default_search_field,
ct.default_search_ip_params, top_k,
default_search_exp)
for i, v in enumerate(insert_vectors):
assert len(res[i]) == 1
ref = ip(v, v)
got = res[i][0].distance
assert abs(got - ref) <= epsilon
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("index, params",
zip(ct.all_index_types[:9],