mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-02 20:09:57 +08:00
Fix wrong IP distances (#17590)
Signed-off-by: longjiquan <jiquan.long@zilliz.com>
This commit is contained in:
parent
0c2970f916
commit
2f66531fdf
@ -152,6 +152,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.New(MetricTypeKey + " not found in search_params")
|
return errors.New(MetricTypeKey + " not found in search_params")
|
||||||
}
|
}
|
||||||
|
t.SearchRequest.MetricType = metricType
|
||||||
|
|
||||||
searchParams, err := funcutil.GetAttrByKeyFromRepeatedKV(SearchParamsKey, t.request.SearchParams)
|
searchParams, err := funcutil.GetAttrByKeyFromRepeatedKV(SearchParamsKey, t.request.SearchParams)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -2521,6 +2521,36 @@ class TestSearchBase(TestcaseBase):
|
|||||||
default_search_exp)
|
default_search_exp)
|
||||||
assert len(res[0]) <= top_k
|
assert len(res[0]) <= top_k
|
||||||
|
|
||||||
|
@pytest.mark.tags(CaseLabel.L1)
|
||||||
|
@pytest.mark.parametrize("dim", [2, 8, 128, 768])
|
||||||
|
@pytest.mark.parametrize("nb", [1, 2, 10, 100])
|
||||||
|
def test_search_ip_brute_force(self, nb, dim):
|
||||||
|
"""
|
||||||
|
target: https://github.com/milvus-io/milvus/issues/17378. Ensure the logic of IP distances won't be changed.
|
||||||
|
method: search with the given vectors, check the result
|
||||||
|
expected: The inner product of vector themselves should be positive.
|
||||||
|
"""
|
||||||
|
top_k = 1
|
||||||
|
|
||||||
|
# 1. initialize with data
|
||||||
|
collection_w, insert_entities, _, insert_ids, _ = self.init_collection_general(prefix, True, nb,
|
||||||
|
is_binary=False,
|
||||||
|
dim=dim)[0:5]
|
||||||
|
insert_vectors = insert_entities[0][default_search_field].tolist()
|
||||||
|
|
||||||
|
# 2. load collection.
|
||||||
|
collection_w.load()
|
||||||
|
|
||||||
|
# 3. search and then check if the distances are expected.
|
||||||
|
res, _ = collection_w.search(insert_vectors[:nb], default_search_field,
|
||||||
|
ct.default_search_ip_params, top_k,
|
||||||
|
default_search_exp)
|
||||||
|
for i, v in enumerate(insert_vectors):
|
||||||
|
assert len(res[i]) == 1
|
||||||
|
ref = ip(v, v)
|
||||||
|
got = res[i][0].distance
|
||||||
|
assert abs(got - ref) <= epsilon
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L2)
|
@pytest.mark.tags(CaseLabel.L2)
|
||||||
@pytest.mark.parametrize("index, params",
|
@pytest.mark.parametrize("index, params",
|
||||||
zip(ct.all_index_types[:9],
|
zip(ct.all_index_types[:9],
|
||||||
|
Loading…
Reference in New Issue
Block a user