Mirror of https://gitee.com/milvus-io/milvus.git (synced 2024-12-02 03:48:37 +08:00)
test: add search test cases with text match (#36399)
/kind improvement

---------

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
parent 4fd9b0a8e3
commit 954d8a5f68
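Everything below is the commit's diff. The added cases gate a vector search on a server-side text-match filter; the expression shape used throughout the new test is a one-line string (the field and token here are hypothetical placeholders, not values from the diff):

    expr = "TextMatch(sentence, 'milvus')"  # keep only entities whose 'sentence' field contains the token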
@@ -119,8 +119,8 @@ def split_dataframes(df, fields, language="en"):
         for col in fields:
             new_texts = []
             for doc in df[col]:
-                seg_list = jieba.cut(doc)
-                new_texts.append(seg_list)
+                seg_list = jieba.cut(doc, cut_all=True)
+                new_texts.append(list(seg_list))
             df_copy[col] = new_texts
         return df_copy
     for col in fields:
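This hunk switches split_dataframes from jieba's precise mode to full mode and materializes the result before storing it: jieba.cut returns a lazy generator, so the old new_texts.append(seg_list) stored an iterator object rather than a token list. A minimal standalone sketch of the two modes; the sample text and outputs follow jieba's own README example and may vary by dictionary version:

    import jieba

    doc = "我来到北京清华大学"
    # Precise mode (the old behavior): one segmentation, returned lazily.
    print(list(jieba.cut(doc)))                # e.g. ['我', '来到', '北京', '清华大学']
    # Full mode (the new behavior): every dictionary word found, overlaps included,
    # presumably closer to how the match index tokenizes the text.
    print(list(jieba.cut(doc, cut_all=True)))  # e.g. ['我', '来到', '北京', '清华', '清华大学', '华大', '大学']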
@@ -24,6 +24,11 @@ import numpy
 import threading
 import pytest
 import pandas as pd
+from faker import Faker
 
+Faker.seed(19530)
+fake_en = Faker("en_US")
+fake_zh = Faker("zh_CN")
+pd.set_option("expand_frame_repr", False)
 
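This hunk seeds Faker once at module import, so every parametrized run generates the same corpus and the "most common words" fed into the match expressions stay stable across runs. A minimal sketch of the determinism (standalone, not part of the test file):

    from faker import Faker

    Faker.seed(19530)        # class-level seed shared by all Faker instances
    fake_en = Faker("en_US")
    fake_zh = Faker("zh_CN")

    # Re-running this script prints identical values every time.
    print(fake_en.sentence())
    print(fake_zh.sentence())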
@@ -13154,3 +13159,149 @@ class TestCollectionSearchNoneAndDefaultData(TestcaseBase):
                                  "output_fields": [default_int64_field_name,
                                                    default_float_field_name]})
 
+
+class TestSearchWithTextMatchFilter(TestcaseBase):
+    """
+    ******************************************************************
+      The following cases are used to test query text match
+    ******************************************************************
+    """
+
+    @pytest.mark.tags(CaseLabel.L0)
+    @pytest.mark.parametrize("enable_partition_key", [True, False])
+    @pytest.mark.parametrize("enable_inverted_index", [True, False])
+    @pytest.mark.parametrize("tokenizer", ["jieba", "default"])
+    def test_search_with_text_match_filter_normal(
+        self, tokenizer, enable_inverted_index, enable_partition_key
+    ):
+        """
+        target: test text match normal
+        method: 1. enable text match and insert data with varchar
+                2. get the most common words and query with text match
+                3. verify the result
+        expected: text match successfully and result is correct
+        """
+        analyzer_params = {
+            "tokenizer": tokenizer,
+        }
+        dim = 128
+        fields = [
+            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
+            FieldSchema(
+                name="word",
+                dtype=DataType.VARCHAR,
+                max_length=65535,
+                enable_match=True,
+                is_partition_key=enable_partition_key,
+                analyzer_params=analyzer_params,
+            ),
+            FieldSchema(
+                name="sentence",
+                dtype=DataType.VARCHAR,
+                max_length=65535,
+                enable_match=True,
+                analyzer_params=analyzer_params,
+            ),
+            FieldSchema(
+                name="paragraph",
+                dtype=DataType.VARCHAR,
+                max_length=65535,
+                enable_match=True,
+                analyzer_params=analyzer_params,
+            ),
+            FieldSchema(
+                name="text",
+                dtype=DataType.VARCHAR,
+                max_length=65535,
+                enable_match=True,
+                analyzer_params=analyzer_params,
+            ),
+            FieldSchema(name="emb", dtype=DataType.FLOAT_VECTOR, dim=dim),
+        ]
+        schema = CollectionSchema(fields=fields, description="test collection")
+        data_size = 5000
+        collection_w = self.init_collection_wrap(
+            name=cf.gen_unique_str(prefix), schema=schema
+        )
+        fake = fake_en
+        if tokenizer == "jieba":
+            language = "zh"
+            fake = fake_zh
+        else:
+            language = "en"
+
+        data = [
+            {
+                "id": i,
+                "word": fake.word().lower(),
+                "sentence": fake.sentence().lower(),
+                "paragraph": fake.paragraph().lower(),
+                "text": fake.text().lower(),
+                "emb": [random.random() for _ in range(dim)],
+            }
+            for i in range(data_size)
+        ]
+        df = pd.DataFrame(data)
+        log.info(f"dataframe\n{df}")
+        batch_size = 5000
+        for i in range(0, len(df), batch_size):
+            collection_w.insert(
+                data[i : i + batch_size]
+                if i + batch_size < len(df)
+                else data[i : len(df)]
+            )
+            collection_w.flush()
+        collection_w.create_index(
+            "emb",
+            {"index_type": "HNSW", "metric_type": "L2", "params": {"M": 16, "efConstruction": 500}},
+        )
+        if enable_inverted_index:
+            collection_w.create_index("word", {"index_type": "INVERTED"})
+        collection_w.load()
+        # analyze the corpus
+        text_fields = ["word", "sentence", "paragraph", "text"]
+        wf_map = {}
+        for field in text_fields:
+            wf_map[field] = cf.analyze_documents(df[field].tolist(), language=language)
+        # query single field for one token
+        df_split = cf.split_dataframes(df, text_fields, language=language)
+        log.info(f"df_split\n{df_split}")
+        for field in text_fields:
+            token = wf_map[field].most_common()[0][0]
+            expr = f"TextMatch({field}, '{token}')"
+            manual_result = df_split[
+                df_split.apply(lambda row: token in row[field], axis=1)
+            ]
+            log.info(f"expr: {expr}, manual_check_result\n: {manual_result}")
+            res_list, _ = collection_w.search(
+                data=[[random.random() for _ in range(dim)]],
+                anns_field="emb",
+                param={},
+                limit=100,
+                expr=expr, output_fields=["id", field])
+            for res in res_list:
+                assert len(res) > 0
+                log.info(f"res len {len(res)} res {res}")
+                for r in res:
+                    r = r.to_dict()
+                    assert token in r["entity"][field]
+
+        # query single field for multi-word
+        for field in text_fields:
+            # match top 10 most common words
+            top_10_tokens = []
+            for word, count in wf_map[field].most_common(10):
+                top_10_tokens.append(word)
+            string_of_top_10_words = " ".join(top_10_tokens)
+            expr = f"TextMatch({field}, '{string_of_top_10_words}')"
+            log.info(f"expr {expr}")
+            res_list, _ = collection_w.search(
+                data=[[random.random() for _ in range(dim)]],
+                anns_field="emb",
+                param={},
+                limit=100,
+                expr=expr, output_fields=["id", field])
+            for res in res_list:
+                log.info(f"res len {len(res)} res {res}")
+                for r in res:
+                    r = r.to_dict()
+                    assert any([token in r["entity"][field] for token in top_10_tokens])
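analyze_documents and split_dataframes are helpers from the suite's common module (imported as cf); the test only relies on the returned object supporting most_common(), i.e. a collections.Counter over tokens. A hypothetical equivalent, purely illustrative of where most_common() comes from (the real helper's tokenization rules may differ):

    from collections import Counter
    import jieba

    def analyze_documents(texts, language="en"):
        # Count token frequencies across the corpus; "zh" text is segmented
        # with jieba full mode, English is split on whitespace.
        freq = Counter()
        for text in texts:
            tokens = jieba.cut(text, cut_all=True) if language == "zh" else text.split()
            freq.update(t.lower() for t in tokens if t.strip())
        return freq

    wf = analyze_documents(["the quick brown fox", "the lazy dog"])
    print(wf.most_common(3))  # [('the', 2), ('quick', 1), ('brown', 1)]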
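For readers outside the test harness, the flow the new case exercises condenses to: declare VARCHAR fields with enable_match=True, insert, index, load, then search with a TextMatch expression. A standalone sketch under stated assumptions: the server URI, collection name, and corpus are invented here, and the enable_match / TextMatch spelling is taken from this diff, so it is tied to the Milvus version this commit targets:

    import random
    from pymilvus import (CollectionSchema, DataType, FieldSchema,
                          Collection, connections)

    connections.connect(uri="http://localhost:19530")  # assumed local instance

    dim = 128
    schema = CollectionSchema([
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
        # enable_match builds the text-match structures, as in the diff above.
        FieldSchema(name="sentence", dtype=DataType.VARCHAR, max_length=65535,
                    enable_match=True, analyzer_params={"tokenizer": "default"}),
        FieldSchema(name="emb", dtype=DataType.FLOAT_VECTOR, dim=dim),
    ])
    collection = Collection("text_match_demo", schema)
    collection.insert([
        {"id": i, "sentence": f"doc {i} mentions milvus",
         "emb": [random.random() for _ in range(dim)]}
        for i in range(100)
    ])
    collection.flush()
    collection.create_index("emb", {"index_type": "HNSW", "metric_type": "L2",
                                    "params": {"M": 16, "efConstruction": 500}})
    collection.load()

    # Vector search restricted to entities whose sentence contains the token.
    res, = collection.search(
        data=[[random.random() for _ in range(dim)]],
        anns_field="emb",
        param={},
        limit=10,
        expr="TextMatch(sentence, 'milvus')",
        output_fields=["id", "sentence"],
    )
    for hit in res:
        print(hit.entity.get("sentence"))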