test: add search test cases with text match (#36399)

/kind improvement

---------

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
zhuwenxing 2024-09-26 10:15:14 +08:00 committed by GitHub
parent 4fd9b0a8e3
commit 954d8a5f68
2 changed files with 153 additions and 2 deletions


@@ -119,8 +119,8 @@ def split_dataframes(df, fields, language="en"):
         for col in fields:
             new_texts = []
             for doc in df[col]:
-                seg_list = jieba.cut(doc)
-                new_texts.append(seg_list)
+                seg_list = jieba.cut(doc, cut_all=True)
+                new_texts.append(list(seg_list))
             df_copy[col] = new_texts
         return df_copy
     for col in fields:
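The change matters for the text match tests below: jieba.cut returns a generator, so it has to be materialized with list() before being stored back into the dataframe, and cut_all=True makes jieba emit every candidate word instead of a single best segmentation. A minimal sketch of the difference (the printed tokens depend on the jieba dictionary, so the exact output here is only an assumption):

import jieba

doc = "我爱北京天安门"
# default ("accurate") mode: one best segmentation path
print(list(jieba.cut(doc)))
# full mode: all candidate words, which is what split_dataframes now uses
print(list(jieba.cut(doc, cut_all=True)))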


@@ -24,6 +24,11 @@ import numpy
 import threading
 import pytest
 import pandas as pd
+from faker import Faker
+Faker.seed(19530)
+fake_en = Faker("en_US")
+fake_zh = Faker("zh_CN")
+pd.set_option("expand_frame_repr", False)
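Seeding Faker keeps the generated corpus reproducible between runs, and the two locale instances let the new test class below build either an English or a Chinese corpus depending on the tokenizer. For illustration (the concrete strings are simply whatever Faker yields for each locale under this seed):

from faker import Faker

Faker.seed(19530)
print(Faker("en_US").sentence())  # deterministic English sentence for this seed
print(Faker("zh_CN").sentence())  # deterministic Chinese sentence for this seed
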
@@ -13154,3 +13159,149 @@ class TestCollectionSearchNoneAndDefaultData(TestcaseBase):
                                 "output_fields": [default_int64_field_name,
                                                   default_float_field_name]})

class TestSearchWithTextMatchFilter(TestcaseBase):
    """
    ******************************************************************
      The following cases are used to test search with text match filter
    ******************************************************************
    """
    @pytest.mark.tags(CaseLabel.L0)
    @pytest.mark.parametrize("enable_partition_key", [True, False])
    @pytest.mark.parametrize("enable_inverted_index", [True, False])
    @pytest.mark.parametrize("tokenizer", ["jieba", "default"])
    def test_search_with_text_match_filter_normal(
        self, tokenizer, enable_inverted_index, enable_partition_key
    ):
        """
        target: test text match in a search request
        method: 1. enable text match on the varchar fields and insert data
                2. get the most common words and search with a text match filter
                3. verify the result
        expected: text match filter works and the returned entities are correct
        """
        analyzer_params = {
            "tokenizer": tokenizer,
        }
        dim = 128
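        # the four VARCHAR fields below that are filtered with TextMatch all set
        # enable_match=True and reuse analyzer_params, so the match index is built
        # with the parametrized tokenizer; "word" additionally doubles as the
        # (optional) partition key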
        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
            FieldSchema(
                name="word",
                dtype=DataType.VARCHAR,
                max_length=65535,
                enable_match=True,
                is_partition_key=enable_partition_key,
                analyzer_params=analyzer_params,
            ),
            FieldSchema(
                name="sentence",
                dtype=DataType.VARCHAR,
                max_length=65535,
                enable_match=True,
                analyzer_params=analyzer_params,
            ),
            FieldSchema(
                name="paragraph",
                dtype=DataType.VARCHAR,
                max_length=65535,
                enable_match=True,
                analyzer_params=analyzer_params,
            ),
            FieldSchema(
                name="text",
                dtype=DataType.VARCHAR,
                max_length=65535,
                enable_match=True,
                analyzer_params=analyzer_params,
            ),
            FieldSchema(name="emb", dtype=DataType.FLOAT_VECTOR, dim=dim),
        ]
        schema = CollectionSchema(fields=fields, description="test collection")
        data_size = 5000
        collection_w = self.init_collection_wrap(
            name=cf.gen_unique_str(prefix), schema=schema
        )
        fake = fake_en
        if tokenizer == "jieba":
            language = "zh"
            fake = fake_zh
        else:
            language = "en"
        data = [
            {
                "id": i,
                "word": fake.word().lower(),
                "sentence": fake.sentence().lower(),
                "paragraph": fake.paragraph().lower(),
                "text": fake.text().lower(),
                "emb": [random.random() for _ in range(dim)],
            }
            for i in range(data_size)
        ]
        df = pd.DataFrame(data)
        log.info(f"dataframe\n{df}")
        batch_size = 5000
        for i in range(0, len(df), batch_size):
            collection_w.insert(
                data[i : i + batch_size]
                if i + batch_size < len(df)
                else data[i : len(df)]
            )
            collection_w.flush()
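        # HNSW index on the vector field for the ANN part of the request; when
        # enable_inverted_index is set, an extra scalar INVERTED index is built on
        # "word" to check that text match also works alongside it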
        collection_w.create_index(
            "emb",
            {"index_type": "HNSW", "metric_type": "L2", "params": {"M": 16, "efConstruction": 500}},
        )
        if enable_inverted_index:
            collection_w.create_index("word", {"index_type": "INVERTED"})
        collection_w.load()
        # analyze the corpus
        text_fields = ["word", "sentence", "paragraph", "text"]
        wf_map = {}
        for field in text_fields:
            wf_map[field] = cf.analyze_documents(df[field].tolist(), language=language)
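        # wf_map[field] is expected to be a Counter-like word-frequency map, so
        # most_common() below always yields tokens that really occur in the data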
        # query single field for one token
        df_split = cf.split_dataframes(df, text_fields, language=language)
        log.info(f"df_split\n{df_split}")
        for field in text_fields:
            token = wf_map[field].most_common()[0][0]
            expr = f"TextMatch({field}, '{token}')"
            manual_result = df_split[
                df_split.apply(lambda row: token in row[field], axis=1)
            ]
            log.info(f"expr: {expr}, manual_check_result\n: {manual_result}")
            res_list, _ = collection_w.search(
                data=[[random.random() for _ in range(dim)]],
                anns_field="emb",
                param={},
                limit=100,
                expr=expr, output_fields=["id", field])
            for res in res_list:
                assert len(res) > 0
                log.info(f"res len {len(res)} res {res}")
                for r in res:
                    r = r.to_dict()
                    assert token in r["entity"][field]
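        # matching on a space-separated list of tokens is expected to behave as an
        # OR over the tokens, hence the any(...) assertion below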
        # query single field for multi-word
        for field in text_fields:
            # match top 10 most common words
            top_10_tokens = []
            for word, count in wf_map[field].most_common(10):
                top_10_tokens.append(word)
            string_of_top_10_words = " ".join(top_10_tokens)
            expr = f"TextMatch({field}, '{string_of_top_10_words}')"
            log.info(f"expr {expr}")
            res_list, _ = collection_w.search(
                data=[[random.random() for _ in range(dim)]],
                anns_field="emb",
                param={},
                limit=100,
                expr=expr, output_fields=["id", field])
            for res in res_list:
                log.info(f"res len {len(res)} res {res}")
                for r in res:
                    r = r.to_dict()
                    assert any([token in r["entity"][field] for token in top_10_tokens])
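
For context, cf.analyze_documents and cf.split_dataframes are helpers from the shared test utilities; a minimal sketch of the word-frequency idea the test relies on, with every name and detail below being an assumption rather than the real implementation, might look like this:

from collections import Counter

import jieba


def analyze_documents_sketch(texts, language="en"):
    # tokenize each document and count token frequencies, so that
    # most_common() returns tokens guaranteed to occur in the corpus
    freq = Counter()
    for text in texts:
        if language in ["zh", "cn"]:
            tokens = list(jieba.cut(text, cut_all=True))
        else:
            tokens = text.lower().split()
        freq.update(t for t in tokens if t.strip())
    return freq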